Support VK_EXT_vertex_input_dynamic_state

Test: dEQP-VK.pipeline.*dyn*vertex_input*
Bug: angleproject:7162
Change-Id: I0d7ad8ec3d2c5ff47a2e013280f430222da3697a
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/72908
Kokoro-Result: kokoro <noreply+kokoro@google.com>
Commit-Queue: Shahbaz Youssefi <syoussefi@google.com>
Tested-by: Shahbaz Youssefi <syoussefi@google.com>
Reviewed-by: Chris Forbes <chrisforbes@google.com>
Presubmit-Ready: Shahbaz Youssefi <syoussefi@google.com>
diff --git a/src/Device/Context.cpp b/src/Device/Context.cpp
index cb8131f..1cecc19 100644
--- a/src/Device/Context.cpp
+++ b/src/Device/Context.cpp
@@ -91,6 +91,37 @@
 	}
 }
 
+vk::InputsDynamicStateFlags ParseInputsDynamicStateFlags(const VkPipelineDynamicStateCreateInfo *dynamicStateCreateInfo)
+{
+	vk::InputsDynamicStateFlags dynamicStateFlags = {};
+
+	if(dynamicStateCreateInfo == nullptr)
+	{
+		return dynamicStateFlags;
+	}
+
+	for(uint32_t i = 0; i < dynamicStateCreateInfo->dynamicStateCount; i++)
+	{
+		VkDynamicState dynamicState = dynamicStateCreateInfo->pDynamicStates[i];
+		switch(dynamicState)
+		{
+		case VK_DYNAMIC_STATE_VERTEX_INPUT_BINDING_STRIDE:
+			dynamicStateFlags.dynamicVertexInputBindingStride = true;
+			break;
+		case VK_DYNAMIC_STATE_VERTEX_INPUT_EXT:
+			dynamicStateFlags.dynamicVertexInput = true;
+			dynamicStateFlags.dynamicVertexInputBindingStride = true;
+			break;
+
+		default:
+			// The rest of the dynamic state is handled by ParseDynamicStateFlags.
+			break;
+		}
+	}
+
+	return dynamicStateFlags;
+}
+
 vk::DynamicStateFlags ParseDynamicStateFlags(const VkPipelineDynamicStateCreateInfo *dynamicStateCreateInfo)
 {
 	vk::DynamicStateFlags dynamicStateFlags = {};
@@ -119,7 +150,8 @@
 			dynamicStateFlags.vertexInputInterface.dynamicPrimitiveTopology = true;
 			break;
 		case VK_DYNAMIC_STATE_VERTEX_INPUT_BINDING_STRIDE:
-			dynamicStateFlags.vertexInputInterface.dynamicVertexInputBindingStride = true;
+		case VK_DYNAMIC_STATE_VERTEX_INPUT_EXT:
+			// Handled by ParseInputsDynamicStateFlags
 			break;
 
 		// Pre-rasterization:
@@ -281,7 +313,7 @@
 	}
 }
 
-void Inputs::initialize(const VkPipelineVertexInputStateCreateInfo *vertexInputState)
+void Inputs::initialize(const VkPipelineVertexInputStateCreateInfo *vertexInputState, const VkPipelineDynamicStateCreateInfo *dynamicStateCreateInfo)
 {
 	if(vertexInputState->flags != 0)
 	{
@@ -289,6 +321,13 @@
 		UNSUPPORTED("vertexInputState->flags");
 	}
 
+	dynamicStateFlags = ParseInputsDynamicStateFlags(dynamicStateCreateInfo);
+
+	if(dynamicStateFlags.dynamicVertexInput)
+	{
+		return;
+	}
+
 	// Temporary in-binding-order representation of buffer strides, to be consumed below
 	// when considering attributes. TODO: unfuse buffers from attributes in backend, is old GL model.
 	uint32_t vertexStrides[MAX_VERTEX_INPUT_BINDINGS];
@@ -310,8 +349,14 @@
 		input.offset = desc.offset;
 		input.binding = desc.binding;
 		input.inputRate = inputRates[desc.binding];
-		input.vertexStride = vertexStrides[desc.binding];
-		input.instanceStride = instanceStrides[desc.binding];
+		if(!dynamicStateFlags.dynamicVertexInputBindingStride)
+		{
+			// The following gets overriden with dynamic state anyway and setting it is
+			// harmless.  But it is not done to be able to catch bugs with this dynamic
+			// state easier.
+			input.vertexStride = vertexStrides[desc.binding];
+			input.instanceStride = instanceStrides[desc.binding];
+		}
 	}
 }
 
@@ -324,7 +369,7 @@
 	descriptorDynamicOffsets = ddo;
 }
 
-void Inputs::bindVertexInputs(int firstInstance, bool dynamicInstanceStride)
+void Inputs::bindVertexInputs(int firstInstance)
 {
 	for(uint32_t i = 0; i < MAX_VERTEX_INPUT_BINDINGS; i++)
 	{
@@ -333,7 +378,7 @@
 		{
 			const auto &vertexInput = vertexInputBindings[attrib.binding];
 			VkDeviceSize offset = attrib.offset + vertexInput.offset +
-			                      getInstanceStride(i, dynamicInstanceStride) * firstInstance;
+			                      getInstanceStride(i) * firstInstance;
 			attrib.buffer = vertexInput.buffer ? vertexInput.buffer->getOffsetPointer(offset) : nullptr;
 
 			VkDeviceSize size = vertexInput.buffer ? vertexInput.buffer->getSize() : 0;
@@ -342,21 +387,50 @@
 	}
 }
 
-void Inputs::setVertexInputBinding(const VertexInputBinding bindings[])
+void Inputs::setVertexInputBinding(const VertexInputBinding bindings[], const DynamicState &dynamicState)
 {
 	for(uint32_t i = 0; i < MAX_VERTEX_INPUT_BINDINGS; ++i)
 	{
 		vertexInputBindings[i] = bindings[i];
 	}
+
+	if(dynamicStateFlags.dynamicVertexInput)
+	{
+		// If the entire vertex input state is dynamic, recalculate the contents of `stream`.
+		// This is similar to Inputs::initialize.
+		for(uint32_t i = 0; i < sw::MAX_INTERFACE_COMPONENTS / 4; i++)
+		{
+			const auto &desc = dynamicState.vertexInputAttributes[i];
+			const auto &bindingDesc = dynamicState.vertexInputBindings[desc.binding];
+			sw::Stream &input = stream[i];
+			input.format = desc.format;
+			input.offset = desc.offset;
+			input.binding = desc.binding;
+			input.inputRate = bindingDesc.inputRate;
+		}
+	}
+
+	// Stride may come from two different dynamic states
+	if(dynamicStateFlags.dynamicVertexInput || dynamicStateFlags.dynamicVertexInputBindingStride)
+	{
+		for(uint32_t i = 0; i < sw::MAX_INTERFACE_COMPONENTS / 4; i++)
+		{
+			sw::Stream &input = stream[i];
+			const VkDeviceSize stride = dynamicState.vertexInputBindings[input.binding].stride;
+
+			input.vertexStride = input.inputRate == VK_VERTEX_INPUT_RATE_VERTEX ? stride : 0;
+			input.instanceStride = input.inputRate == VK_VERTEX_INPUT_RATE_INSTANCE ? stride : 0;
+		}
+	}
 }
 
-void Inputs::advanceInstanceAttributes(bool dynamicInstanceStride)
+void Inputs::advanceInstanceAttributes()
 {
 	for(uint32_t i = 0; i < vk::MAX_VERTEX_INPUT_BINDINGS; i++)
 	{
 		auto &attrib = stream[i];
 
-		VkDeviceSize instanceStride = getInstanceStride(i, dynamicInstanceStride);
+		VkDeviceSize instanceStride = getInstanceStride(i);
 		if((attrib.format != VK_FORMAT_UNDEFINED) && instanceStride && (instanceStride < attrib.robustnessSize))
 		{
 			// Under the casts: attrib.buffer += instanceStride
@@ -366,37 +440,23 @@
 	}
 }
 
-VkDeviceSize Inputs::getVertexStride(uint32_t i, bool dynamicVertexStride) const
+VkDeviceSize Inputs::getVertexStride(uint32_t i) const
 {
 	auto &attrib = stream[i];
-	if(attrib.format != VK_FORMAT_UNDEFINED && attrib.inputRate == VK_VERTEX_INPUT_RATE_VERTEX)
+	if(attrib.format != VK_FORMAT_UNDEFINED)
 	{
-		if(dynamicVertexStride)
-		{
-			return vertexInputBindings[attrib.binding].stride;
-		}
-		else
-		{
-			return attrib.vertexStride;
-		}
+		return attrib.vertexStride;
 	}
 
 	return 0;
 }
 
-VkDeviceSize Inputs::getInstanceStride(uint32_t i, bool dynamicInstanceStride) const
+VkDeviceSize Inputs::getInstanceStride(uint32_t i) const
 {
 	auto &attrib = stream[i];
-	if(attrib.format != VK_FORMAT_UNDEFINED && attrib.inputRate == VK_VERTEX_INPUT_RATE_INSTANCE)
+	if(attrib.format != VK_FORMAT_UNDEFINED)
 	{
-		if(dynamicInstanceStride)
-		{
-			return vertexInputBindings[attrib.binding].stride;
-		}
-		else
-		{
-			return attrib.instanceStride;
-		}
+		return attrib.instanceStride;
 	}
 
 	return 0;
diff --git a/src/Device/Context.hpp b/src/Device/Context.hpp
index fc0f911..e859aff 100644
--- a/src/Device/Context.hpp
+++ b/src/Device/Context.hpp
@@ -32,12 +32,68 @@
 class PipelineLayout;
 class RenderPass;
 
+struct InputsDynamicStateFlags
+{
+	bool dynamicVertexInputBindingStride : 1;
+        bool dynamicVertexInput : 1;
+};
+
+// Note: The split between Inputs and VertexInputInterfaceState is mostly superficial.  The state
+// (be it dynamic or static) in Inputs should have been mostly a part of VertexInputInterfaceState.
+// Changing that requires some surgery.
+struct VertexInputInterfaceDynamicStateFlags
+{
+	bool dynamicPrimitiveRestartEnable : 1;
+	bool dynamicPrimitiveTopology : 1;
+};
+
+struct PreRasterizationDynamicStateFlags
+{
+	bool dynamicLineWidth : 1;
+	bool dynamicDepthBias : 1;
+	bool dynamicDepthBiasEnable : 1;
+	bool dynamicCullMode : 1;
+	bool dynamicFrontFace : 1;
+	bool dynamicViewport : 1;
+	bool dynamicScissor : 1;
+	bool dynamicViewportWithCount : 1;
+	bool dynamicScissorWithCount : 1;
+	bool dynamicRasterizerDiscardEnable : 1;
+};
+
+struct FragmentDynamicStateFlags
+{
+	bool dynamicDepthTestEnable : 1;
+	bool dynamicDepthWriteEnable : 1;
+	bool dynamicDepthBoundsTestEnable : 1;
+	bool dynamicDepthBounds : 1;
+	bool dynamicDepthCompareOp : 1;
+	bool dynamicStencilTestEnable : 1;
+	bool dynamicStencilOp : 1;
+	bool dynamicStencilCompareMask : 1;
+	bool dynamicStencilWriteMask : 1;
+	bool dynamicStencilReference : 1;
+};
+
+struct FragmentOutputInterfaceDynamicStateFlags
+{
+	bool dynamicBlendConstants : 1;
+};
+
+struct DynamicStateFlags
+{
+    // Note: InputsDynamicStateFlags is kept local to Inputs
+	VertexInputInterfaceDynamicStateFlags vertexInputInterface;
+	PreRasterizationDynamicStateFlags preRasterization;
+	FragmentDynamicStateFlags fragment;
+	FragmentOutputInterfaceDynamicStateFlags fragmentOutputInterface;
+};
+
 struct VertexInputBinding
 {
 	Buffer *buffer = nullptr;
 	VkDeviceSize offset = 0;
 	VkDeviceSize size = 0;
-	VkDeviceSize stride = 0;
 };
 
 struct IndexBuffer
@@ -63,9 +119,10 @@
 	VkFormat depthFormat() const;
 };
 
+struct DynamicState;
 struct Inputs
 {
-	void initialize(const VkPipelineVertexInputStateCreateInfo *vertexInputState);
+	void initialize(const VkPipelineVertexInputStateCreateInfo *vertexInputState, const VkPipelineDynamicStateCreateInfo *dynamicStateCreateInfo);
 
 	void updateDescriptorSets(const DescriptorSet::Array &dso,
 	                          const DescriptorSet::Bindings &ds,
@@ -75,13 +132,14 @@
 	inline const DescriptorSet::DynamicOffsets &getDescriptorDynamicOffsets() const { return descriptorDynamicOffsets; }
 	inline const sw::Stream &getStream(uint32_t i) const { return stream[i]; }
 
-	void bindVertexInputs(int firstInstance, bool dynamicInstanceStride);
-	void setVertexInputBinding(const VertexInputBinding vertexInputBindings[]);
-	void advanceInstanceAttributes(bool dynamicInstanceStride);
-	VkDeviceSize getVertexStride(uint32_t i, bool dynamicVertexStride) const;
-	VkDeviceSize getInstanceStride(uint32_t i, bool dynamicVertexStride) const;
+	void bindVertexInputs(int firstInstance);
+	void setVertexInputBinding(const VertexInputBinding vertexInputBindings[], const DynamicState &dynamicState);
+	void advanceInstanceAttributes();
+	VkDeviceSize getVertexStride(uint32_t i) const;
+	VkDeviceSize getInstanceStride(uint32_t i) const;
 
 private:
+	InputsDynamicStateFlags dynamicStateFlags = {};
 	VertexInputBinding vertexInputBindings[MAX_VERTEX_INPUT_BINDINGS] = {};
 	DescriptorSet::Array descriptorSetObjects = {};
 	DescriptorSet::Bindings descriptorSets = {};
@@ -133,6 +191,20 @@
 	VkBlendOp blendOperationAlpha;
 };
 
+struct DynamicVertexInputBindingState
+{
+	VkVertexInputRate inputRate = VK_VERTEX_INPUT_RATE_VERTEX;
+	VkDeviceSize stride = 0;
+	unsigned int divisor = 0;
+};
+
+struct DynamicVertexInputAttributeState
+{
+	VkFormat format = VK_FORMAT_UNDEFINED;
+	unsigned int offset = 0;
+	unsigned int binding = 0;
+};
+
 struct DynamicState
 {
 	VkViewport viewport = {};
@@ -163,54 +235,8 @@
 	VkBool32 rasterizerDiscardEnable = VK_FALSE;
 	VkBool32 depthBiasEnable = VK_FALSE;
 	VkBool32 primitiveRestartEnable = VK_FALSE;
-};
-
-struct VertexInputInterfaceDynamicStateFlags
-{
-	bool dynamicPrimitiveRestartEnable : 1;
-	bool dynamicPrimitiveTopology : 1;
-	bool dynamicVertexInputBindingStride : 1;
-};
-
-struct PreRasterizationDynamicStateFlags
-{
-	bool dynamicLineWidth : 1;
-	bool dynamicDepthBias : 1;
-	bool dynamicDepthBiasEnable : 1;
-	bool dynamicCullMode : 1;
-	bool dynamicFrontFace : 1;
-	bool dynamicViewport : 1;
-	bool dynamicScissor : 1;
-	bool dynamicViewportWithCount : 1;
-	bool dynamicScissorWithCount : 1;
-	bool dynamicRasterizerDiscardEnable : 1;
-};
-
-struct FragmentDynamicStateFlags
-{
-	bool dynamicDepthTestEnable : 1;
-	bool dynamicDepthWriteEnable : 1;
-	bool dynamicDepthBoundsTestEnable : 1;
-	bool dynamicDepthBounds : 1;
-	bool dynamicDepthCompareOp : 1;
-	bool dynamicStencilTestEnable : 1;
-	bool dynamicStencilOp : 1;
-	bool dynamicStencilCompareMask : 1;
-	bool dynamicStencilWriteMask : 1;
-	bool dynamicStencilReference : 1;
-};
-
-struct FragmentOutputInterfaceDynamicStateFlags
-{
-	bool dynamicBlendConstants : 1;
-};
-
-struct DynamicStateFlags
-{
-	VertexInputInterfaceDynamicStateFlags vertexInputInterface;
-	PreRasterizationDynamicStateFlags preRasterization;
-	FragmentDynamicStateFlags fragment;
-	FragmentOutputInterfaceDynamicStateFlags fragmentOutputInterface;
+	DynamicVertexInputBindingState vertexInputBindings[MAX_VERTEX_INPUT_BINDINGS];
+	DynamicVertexInputAttributeState vertexInputAttributes[sw::MAX_INTERFACE_COMPONENTS / 4];
 };
 
 struct VertexInputInterfaceState
@@ -224,7 +250,6 @@
 	inline VkPrimitiveTopology getTopology() const { return topology; }
 	inline bool hasPrimitiveRestartEnable() const { return primitiveRestartEnable; }
 
-	inline bool hasDynamicVertexStride() const { return dynamicStateFlags.dynamicVertexInputBindingStride; }
 	inline bool hasDynamicTopology() const { return dynamicStateFlags.dynamicPrimitiveTopology; }
 	inline bool hasDynamicPrimitiveRestartEnable() const { return dynamicStateFlags.dynamicPrimitiveRestartEnable; }
 
diff --git a/src/Device/Renderer.cpp b/src/Device/Renderer.cpp
index b79c740..b80ee46 100644
--- a/src/Device/Renderer.cpp
+++ b/src/Device/Renderer.cpp
@@ -278,7 +278,7 @@
 		const sw::Stream &stream = inputs.getStream(i);
 		data->input[i] = stream.buffer;
 		data->robustnessSize[i] = stream.robustnessSize;
-		data->stride[i] = inputs.getVertexStride(i, vertexInputInterfaceState.hasDynamicVertexStride());
+		data->stride[i] = inputs.getVertexStride(i);
 	}
 
 	data->indices = indexBuffer;
diff --git a/src/Vulkan/VkCommandBuffer.cpp b/src/Vulkan/VkCommandBuffer.cpp
index 5649c12..7699145 100644
--- a/src/Vulkan/VkCommandBuffer.cpp
+++ b/src/Vulkan/VkCommandBuffer.cpp
@@ -361,18 +361,23 @@
 class CmdVertexBufferBind : public vk::CommandBuffer::Command
 {
 public:
-	CmdVertexBufferBind(uint32_t binding, vk::Buffer *buffer, const VkDeviceSize offset, const VkDeviceSize size, const VkDeviceSize stride)
+	CmdVertexBufferBind(uint32_t binding, vk::Buffer *buffer, const VkDeviceSize offset, const VkDeviceSize size, const VkDeviceSize stride, bool hasStride)
 	    : binding(binding)
 	    , buffer(buffer)
 	    , offset(offset)
 	    , size(size)
 	    , stride(stride)
+	    , hasStride(hasStride)
 	{
 	}
 
 	void execute(vk::CommandBuffer::ExecutionState &executionState) override
 	{
-		executionState.vertexInputBindings[binding] = { buffer, offset, size, stride };
+		executionState.vertexInputBindings[binding] = { buffer, offset, size };
+		if(hasStride)
+		{
+			executionState.dynamicState.vertexInputBindings[binding].stride = stride;
+		}
 	}
 
 	std::string description() override { return "vkCmdVertexBufferBind()"; }
@@ -383,6 +388,7 @@
 	const VkDeviceSize offset;
 	const VkDeviceSize size;
 	const VkDeviceSize stride;
+	const bool hasStride;
 };
 
 class CmdIndexBufferBind : public vk::CommandBuffer::Command
@@ -397,7 +403,7 @@
 
 	void execute(vk::CommandBuffer::ExecutionState &executionState) override
 	{
-		executionState.indexBufferBinding = { buffer, offset, 0, 0 };
+		executionState.indexBufferBinding = { buffer, offset, 0 };
 		executionState.indexType = indexType;
 	}
 
@@ -911,6 +917,44 @@
 	const VkBool32 primitiveRestartEnable;
 };
 
+class CmdSetVertexInput : public vk::CommandBuffer::Command
+{
+public:
+	CmdSetVertexInput(uint32_t vertexBindingDescriptionCount,
+	                  const VkVertexInputBindingDescription2EXT *pVertexBindingDescriptions,
+	                  uint32_t vertexAttributeDescriptionCount,
+	                  const VkVertexInputAttributeDescription2EXT *pVertexAttributeDescriptions)
+	    :  // Note: the pNext values are unused, so this copy is currently safe.
+	    vertexBindingDescriptions(pVertexBindingDescriptions, pVertexBindingDescriptions + vertexBindingDescriptionCount)
+	    , vertexAttributeDescriptions(pVertexAttributeDescriptions, pVertexAttributeDescriptions + vertexAttributeDescriptionCount)
+	{}
+
+	void execute(vk::CommandBuffer::ExecutionState &executionState) override
+	{
+		for(const auto &desc : vertexBindingDescriptions)
+		{
+			vk::DynamicVertexInputBindingState &state = executionState.dynamicState.vertexInputBindings[desc.binding];
+			state.inputRate = desc.inputRate;
+			state.stride = desc.stride;
+			state.divisor = desc.divisor;
+		}
+
+		for(const auto &desc : vertexAttributeDescriptions)
+		{
+			vk::DynamicVertexInputAttributeState &state = executionState.dynamicState.vertexInputAttributes[desc.location];
+			state.format = desc.format;
+			state.offset = desc.offset;
+			state.binding = desc.binding;
+		}
+	}
+
+	std::string description() override { return "vkCmdSetVertexInputEXT()"; }
+
+private:
+	const std::vector<VkVertexInputBindingDescription2EXT> vertexBindingDescriptions;
+	const std::vector<VkVertexInputAttributeDescription2EXT> vertexAttributeDescriptions;
+};
+
 class CmdDrawBase : public vk::CommandBuffer::Command
 {
 public:
@@ -920,7 +964,6 @@
 		const auto &pipelineState = executionState.pipelineState[VK_PIPELINE_BIND_POINT_GRAPHICS];
 
 		auto *pipeline = static_cast<vk::GraphicsPipeline *>(pipelineState.pipeline);
-		bool hasDynamicVertexStride = pipeline->hasDynamicVertexStride();
 
 		vk::Attachments &attachments = pipeline->getAttachments();
 		executionState.bindAttachments(&attachments);
@@ -929,8 +972,8 @@
 		inputs.updateDescriptorSets(pipelineState.descriptorSetObjects,
 		                            pipelineState.descriptorSets,
 		                            pipelineState.descriptorDynamicOffsets);
-		inputs.setVertexInputBinding(executionState.vertexInputBindings);
-		inputs.bindVertexInputs(firstInstance, hasDynamicVertexStride);
+		inputs.setVertexInputBinding(executionState.vertexInputBindings, executionState.dynamicState);
+		inputs.bindVertexInputs(firstInstance);
 
 		if(indexed)
 		{
@@ -963,7 +1006,7 @@
 			if(instanceCount > 1)
 			{
 				UNOPTIMIZED("Optimize instancing to use a single draw call.");  // TODO(b/137740918)
-				inputs.advanceInstanceAttributes(hasDynamicVertexStride);
+				inputs.advanceInstanceAttributes();
 			}
 		}
 	}
@@ -1909,7 +1952,8 @@
 	{
 		addCommand<::CmdVertexBufferBind>(i + firstBinding, vk::Cast(pBuffers[i]), pOffsets[i],
 		                                  pSizes ? pSizes[i] : 0,
-		                                  pStrides ? pStrides[i] : 0);
+		                                  pStrides ? pStrides[i] : 0,
+		                                  pStrides);
 	}
 }
 
@@ -2085,6 +2129,15 @@
 	addCommand<::CmdSetPrimitiveRestartEnable>(primitiveRestartEnable);
 }
 
+void CommandBuffer::setVertexInput(uint32_t vertexBindingDescriptionCount,
+                                   const VkVertexInputBindingDescription2EXT *pVertexBindingDescriptions,
+                                   uint32_t vertexAttributeDescriptionCount,
+                                   const VkVertexInputAttributeDescription2EXT *pVertexAttributeDescriptions)
+{
+	addCommand<::CmdSetVertexInput>(vertexBindingDescriptionCount, pVertexBindingDescriptions,
+	                                vertexAttributeDescriptionCount, pVertexAttributeDescriptions);
+}
+
 void CommandBuffer::bindDescriptorSets(VkPipelineBindPoint pipelineBindPoint, const PipelineLayout *pipelineLayout,
                                        uint32_t firstSet, uint32_t descriptorSetCount, const VkDescriptorSet *pDescriptorSets,
                                        uint32_t dynamicOffsetCount, const uint32_t *pDynamicOffsets)
diff --git a/src/Vulkan/VkCommandBuffer.hpp b/src/Vulkan/VkCommandBuffer.hpp
index 37b3c1d..f7345c4 100644
--- a/src/Vulkan/VkCommandBuffer.hpp
+++ b/src/Vulkan/VkCommandBuffer.hpp
@@ -143,6 +143,10 @@
 	void setRasterizerDiscardEnable(VkBool32 rasterizerDiscardEnable);
 	void setDepthBiasEnable(VkBool32 depthBiasEnable);
 	void setPrimitiveRestartEnable(VkBool32 primitiveRestartEnable);
+	void setVertexInput(uint32_t vertexBindingDescriptionCount,
+			const VkVertexInputBindingDescription2EXT*  pVertexBindingDescriptions,
+			uint32_t vertexAttributeDescriptionCount,
+			const VkVertexInputAttributeDescription2EXT* pVertexAttributeDescriptions);
 	void bindDescriptorSets(VkPipelineBindPoint pipelineBindPoint, const PipelineLayout *layout,
 	                        uint32_t firstSet, uint32_t descriptorSetCount, const VkDescriptorSet *pDescriptorSets,
 	                        uint32_t dynamicOffsetCount, const uint32_t *pDynamicOffsets);
diff --git a/src/Vulkan/VkGetProcAddress.cpp b/src/Vulkan/VkGetProcAddress.cpp
index 8a8b759..aed7999 100644
--- a/src/Vulkan/VkGetProcAddress.cpp
+++ b/src/Vulkan/VkGetProcAddress.cpp
@@ -545,6 +545,12 @@
 	        MAKE_VULKAN_DEVICE_ENTRY(vkCmdSetStencilTestEnableEXT),
 	        MAKE_VULKAN_DEVICE_ENTRY(vkCmdSetViewportWithCountEXT),
 	    } },
+	// VK_EXT_vertex_input_dynamic_state
+	{
+	    VK_EXT_VERTEX_INPUT_DYNAMIC_STATE_EXTENSION_NAME,
+	    {
+	        MAKE_VULKAN_DEVICE_ENTRY(vkCmdSetVertexInputEXT),
+	    } },
 	// VK_EXT_line_rasterization
 	{
 	    VK_EXT_LINE_RASTERIZATION_EXTENSION_NAME,
diff --git a/src/Vulkan/VkPhysicalDevice.cpp b/src/Vulkan/VkPhysicalDevice.cpp
index 3866d91..3b447db 100644
--- a/src/Vulkan/VkPhysicalDevice.cpp
+++ b/src/Vulkan/VkPhysicalDevice.cpp
@@ -471,6 +471,11 @@
 	features->extendedDynamicState = VK_TRUE;
 }
 
+static void getPhysicalDeviceVertexInputDynamicStateFeaturesEXT(VkPhysicalDeviceVertexInputDynamicStateFeaturesEXT *features)
+{
+	features->vertexInputDynamicState = VK_TRUE;
+}
+
 static void getPhysicalDevice4444FormatsFeaturesEXT(VkPhysicalDevice4444FormatsFeaturesEXT *features)
 {
 	features->formatA4R4G4B4 = VK_TRUE;
@@ -598,6 +603,9 @@
 		case VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_EXTENDED_DYNAMIC_STATE_FEATURES_EXT:
 			getPhysicalDeviceExtendedDynamicStateFeaturesEXT(reinterpret_cast<VkPhysicalDeviceExtendedDynamicStateFeaturesEXT *>(curExtension));
 			break;
+		case VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VERTEX_INPUT_DYNAMIC_STATE_FEATURES_EXT:
+			getPhysicalDeviceVertexInputDynamicStateFeaturesEXT(reinterpret_cast<VkPhysicalDeviceVertexInputDynamicStateFeaturesEXT *>(curExtension));
+			break;
 		case VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PRIVATE_DATA_FEATURES:
 			getPhysicalDevicePrivateDataFeatures(reinterpret_cast<VkPhysicalDevicePrivateDataFeatures *>(curExtension));
 			break;
@@ -1670,6 +1678,13 @@
 	return CheckFeature(requested, supported, extendedDynamicState);
 }
 
+bool PhysicalDevice::hasExtendedFeatures(const VkPhysicalDeviceVertexInputDynamicStateFeaturesEXT *requested) const
+{
+	auto supported = getSupportedFeatures(requested);
+
+	return CheckFeature(requested, supported, vertexInputDynamicState);
+}
+
 bool PhysicalDevice::hasExtendedFeatures(const VkPhysicalDevicePrivateDataFeatures *requested) const
 {
 	auto supported = getSupportedFeatures(requested);
diff --git a/src/Vulkan/VkPhysicalDevice.hpp b/src/Vulkan/VkPhysicalDevice.hpp
index 60454e3..591e128 100644
--- a/src/Vulkan/VkPhysicalDevice.hpp
+++ b/src/Vulkan/VkPhysicalDevice.hpp
@@ -46,6 +46,7 @@
 	bool hasExtendedFeatures(const VkPhysicalDeviceDepthClipEnableFeaturesEXT *features) const;
 	bool hasExtendedFeatures(const VkPhysicalDeviceBlendOperationAdvancedFeaturesEXT *features) const;
 	bool hasExtendedFeatures(const VkPhysicalDeviceExtendedDynamicStateFeaturesEXT *features) const;
+	bool hasExtendedFeatures(const VkPhysicalDeviceVertexInputDynamicStateFeaturesEXT *features) const;
 	bool hasExtendedFeatures(const VkPhysicalDevicePrivateDataFeatures *features) const;
 	bool hasExtendedFeatures(const VkPhysicalDeviceTextureCompressionASTCHDRFeatures *features) const;
 	bool hasExtendedFeatures(const VkPhysicalDeviceShaderDemoteToHelperInvocationFeatures *features) const;
diff --git a/src/Vulkan/VkPipeline.cpp b/src/Vulkan/VkPipeline.cpp
index e292044..6d93e24 100644
--- a/src/Vulkan/VkPipeline.cpp
+++ b/src/Vulkan/VkPipeline.cpp
@@ -359,7 +359,7 @@
 	}
 	if(state.hasVertexInputInterfaceState() && !vertexInputInterfaceInLibraries)
 	{
-		inputs.initialize(pCreateInfo->pVertexInputState);
+		inputs.initialize(pCreateInfo->pVertexInputState, pCreateInfo->pDynamicState);
 	}
 }
 
diff --git a/src/Vulkan/VkPipeline.hpp b/src/Vulkan/VkPipeline.hpp
index 650563f..8c0c3c2 100644
--- a/src/Vulkan/VkPipeline.hpp
+++ b/src/Vulkan/VkPipeline.hpp
@@ -96,7 +96,6 @@
 	const GraphicsState &getState() const { return state; }
 
 	void getIndexBuffers(const vk::DynamicState &dynamicState, uint32_t count, uint32_t first, bool indexed, std::vector<std::pair<uint32_t, void *>> *indexBuffers) const;
-	bool hasDynamicVertexStride() const { return state.getVertexInputInterfaceState().hasDynamicVertexStride(); }
 
 	IndexBuffer &getIndexBuffer() { return indexBuffer; }
 	const IndexBuffer &getIndexBuffer() const { return indexBuffer; }
diff --git a/src/Vulkan/libVulkan.cpp b/src/Vulkan/libVulkan.cpp
index 30b34b6..d846234 100644
--- a/src/Vulkan/libVulkan.cpp
+++ b/src/Vulkan/libVulkan.cpp
@@ -471,6 +471,7 @@
 	{ { VK_EXT_PIPELINE_ROBUSTNESS_EXTENSION_NAME, VK_EXT_PIPELINE_ROBUSTNESS_SPEC_VERSION } },
 	{ { VK_EXT_RASTERIZATION_ORDER_ATTACHMENT_ACCESS_EXTENSION_NAME, VK_EXT_RASTERIZATION_ORDER_ATTACHMENT_ACCESS_SPEC_VERSION } },
 	{ { VK_EXT_HOST_IMAGE_COPY_EXTENSION_NAME, VK_EXT_HOST_IMAGE_COPY_SPEC_VERSION } },
+	{ { VK_EXT_VERTEX_INPUT_DYNAMIC_STATE_EXTENSION_NAME, VK_EXT_VERTEX_INPUT_DYNAMIC_STATE_SPEC_VERSION } },
 };
 
 static uint32_t numSupportedExtensions(const ExtensionProperties *extensionProperties, uint32_t extensionPropertiesCount)
@@ -1041,6 +1042,16 @@
 				}
 			}
 			break;
+		case VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VERTEX_INPUT_DYNAMIC_STATE_FEATURES_EXT:
+			{
+				const auto *dynamicStateFeatures = reinterpret_cast<const VkPhysicalDeviceVertexInputDynamicStateFeaturesEXT *>(extensionCreateInfo);
+				bool hasFeatures = vk::Cast(physicalDevice)->hasExtendedFeatures(dynamicStateFeatures);
+				if(!hasFeatures)
+				{
+					return VK_ERROR_FEATURE_NOT_PRESENT;
+				}
+			}
+			break;
 		case VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PRIVATE_DATA_FEATURES:
 			{
 				const auto *privateDataFeatures = reinterpret_cast<const VkPhysicalDevicePrivateDataFeatures *>(extensionCreateInfo);
@@ -3052,6 +3063,17 @@
 	vk::Cast(commandBuffer)->setPrimitiveRestartEnable(primitiveRestartEnable);
 }
 
+VKAPI_ATTR void VKAPI_CALL vkCmdSetVertexInputEXT(VkCommandBuffer commandBuffer, uint32_t vertexBindingDescriptionCount,
+                                                  const VkVertexInputBindingDescription2EXT *pVertexBindingDescriptions,
+                                                  uint32_t vertexAttributeDescriptionCount,
+                                                  const VkVertexInputAttributeDescription2EXT *pVertexAttributeDescriptions)
+{
+	TRACE("(VkCommandBuffer commandBuffer = %p, uint32_t vertexBindingDescriptionCount = %d, const VkVertexInputBindingDescription2EXT *pVertexBindingDescriptions = %p, uint32_t vertexAttributeDescriptionCount = %d, const VkVertexInputAttributeDescription2EXT *pVertexAttributeDescriptions = %p)",
+	      commandBuffer, vertexBindingDescriptionCount, pVertexBindingDescriptions, vertexAttributeDescriptionCount, pVertexAttributeDescriptions);
+
+	vk::Cast(commandBuffer)->setVertexInput(vertexBindingDescriptionCount, pVertexBindingDescriptions, vertexAttributeDescriptionCount, pVertexAttributeDescriptions);
+}
+
 VKAPI_ATTR void VKAPI_CALL vkCmdDraw(VkCommandBuffer commandBuffer, uint32_t vertexCount, uint32_t instanceCount, uint32_t firstVertex, uint32_t firstInstance)
 {
 	TRACE("(VkCommandBuffer commandBuffer = %p, uint32_t vertexCount = %d, uint32_t instanceCount = %d, uint32_t firstVertex = %d, uint32_t firstInstance = %d)",