VulkanBenchmarks: refactor TriangleBenchmark into a more reusable base class

* Renamed TriangleBenchmark to DrawBenchmark
* Added a set of hooks (virtual functions) that are invoked at certain
points during execution of DrawBenchmark::init and renderFrame. These
functions are prefixed with "do" to distinguish them.
* Added some resource management functions in DrawBenchmark, such as
addImage and addSampler for child classes to create these resources. We
want these owned by the base class so that resources can be properly
disposed of in the right order.
* Removed enum class FragShadeType, and replaced with three classes
derived from DrawBenchmark that implement the necessary hooks.
* DrawBenchmark tracks the number of vertices provided via
addVertexBuffer so that it can pass it to vk::CommandBuffer::draw().
Derived types can therefore provide as many (triangle) vertices as they
like from doCreateVertexBuffers.

Bug: b/176981107
Change-Id: I687f3d5ca09f7f93a3d6d7f68871a95bd083bf89
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/52108
Tested-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Kokoro-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
diff --git a/tests/VulkanBenchmarks/VulkanBenchmarks.cpp b/tests/VulkanBenchmarks/VulkanBenchmarks.cpp
index 91e747e..2647640 100644
--- a/tests/VulkanBenchmarks/VulkanBenchmarks.cpp
+++ b/tests/VulkanBenchmarks/VulkanBenchmarks.cpp
@@ -32,6 +32,18 @@
 public:
 	VulkanBenchmark()
 	{
+	}
+
+	virtual ~VulkanBenchmark()
+	{
+		device.waitIdle();
+		device.destroy(nullptr);
+		instance.destroy(nullptr);
+	}
+
+	// Call once after construction so that virtual functions may be called during init
+	void initialize()
+	{
 		// TODO(b/158231104): Other platforms
 #if defined(_WIN32)
 		dl = std::make_unique<vk::DynamicLoader>("./vk_swiftshader.dll");
@@ -67,13 +79,6 @@
 		queue = device.getQueue(queueFamilyIndex, 0);
 	}
 
-	virtual ~VulkanBenchmark()
-	{
-		device.waitIdle();
-		device.destroy(nullptr);
-		instance.destroy(nullptr);
-	}
-
 private:
 	std::unique_ptr<vk::DynamicLoader> dl;
 
@@ -89,8 +94,10 @@
 class ClearImageBenchmark : public VulkanBenchmark
 {
 public:
-	ClearImageBenchmark(vk::Format clearFormat, vk::ImageAspectFlagBits clearAspect)
+	void initialize(vk::Format clearFormat, vk::ImageAspectFlagBits clearAspect)
 	{
+		VulkanBenchmark::initialize();
+
 		vk::ImageCreateInfo imageInfo;
 		imageInfo.imageType = vk::ImageType::e2D;
 		imageInfo.format = clearFormat;
@@ -188,7 +195,8 @@
 
 static void ClearImage(benchmark::State &state, vk::Format clearFormat, vk::ImageAspectFlagBits clearAspect)
 {
-	ClearImageBenchmark benchmark(clearFormat, clearAspect);
+	ClearImageBenchmark benchmark;
+	benchmark.initialize(clearFormat, clearAspect);
 
 	// Execute once to have the Reactor routine generated.
 	benchmark.clear();
@@ -199,26 +207,24 @@
 	}
 }
 
-enum class FragShadeType
-{
-	Solid,
-	Interpolate,
-	Sample,
-};
-
 enum class Multisample
 {
 	False,
 	True
 };
 
-class TriangleBenchmark : public VulkanBenchmark
+class DrawBenchmark : public VulkanBenchmark
 {
 public:
-	TriangleBenchmark(FragShadeType fragShadeType, Multisample multisample)
-	    : fragShadeType(fragShadeType)
-	    , multisample(multisample == Multisample::True)
+	DrawBenchmark(Multisample multisample)
+	    : multisample(multisample == Multisample::True)
 	{
+	}
+
+	void initialize()
+	{
+		VulkanBenchmark::initialize();
+
 		window.reset(new Window(instance, windowSize));
 		swapchain.reset(new Swapchain(physicalDevice, device, *window));
 
@@ -234,13 +240,16 @@
 		createCommandBuffers(renderPass);
 	}
 
-	~TriangleBenchmark()
+	~DrawBenchmark()
 	{
 		device.freeCommandBuffers(commandPool, commandBuffers);
 
 		device.destroyDescriptorPool(descriptorPool);
-		device.destroySampler(sampler, nullptr);
-		texture.reset();
+		for(auto &sampler : samplers)
+		{
+			device.destroySampler(sampler, nullptr);
+		}
+		images.clear();
 		device.destroyCommandPool(commandPool, nullptr);
 
 		for(auto &fence : waitFences)
@@ -297,7 +306,7 @@
 		window->show();
 	}
 
-protected:
+private:
 	void createSynchronizationPrimitives()
 	{
 		vk::SemaphoreCreateInfo semaphoreCreateInfo;
@@ -321,36 +330,8 @@
 		commandPool = device.createCommandPool(commandPoolCreateInfo);
 
 		std::vector<vk::DescriptorSet> descriptorSets;
-		if(fragShadeType == FragShadeType::Sample)
+		if(descriptorSetLayout)
 		{
-			texture.reset(new Image(device, 16, 16, vk::Format::eR8G8B8A8Unorm));
-
-			// Fill texture with white
-			vk::DeviceSize bufferSize = 16 * 16 * 4;
-			Buffer buffer(device, bufferSize, vk::BufferUsageFlagBits::eTransferSrc);
-			void *data = buffer.mapMemory();
-			memset(data, 255, bufferSize);
-			buffer.unmapMemory();
-
-			Util::transitionImageLayout(device, commandPool, queue, texture->getImage(), vk::Format::eR8G8B8A8Unorm, vk::ImageLayout::eUndefined, vk::ImageLayout::eTransferDstOptimal);
-			Util::copyBufferToImage(device, commandPool, queue, buffer.getBuffer(), texture->getImage(), 16, 16);
-			Util::transitionImageLayout(device, commandPool, queue, texture->getImage(), vk::Format::eR8G8B8A8Unorm, vk::ImageLayout::eTransferDstOptimal, vk::ImageLayout::eShaderReadOnlyOptimal);
-
-			vk::SamplerCreateInfo samplerInfo;
-			samplerInfo.magFilter = vk::Filter::eLinear;
-			samplerInfo.minFilter = vk::Filter::eLinear;
-			samplerInfo.addressModeU = vk::SamplerAddressMode::eRepeat;
-			samplerInfo.addressModeV = vk::SamplerAddressMode::eRepeat;
-			samplerInfo.addressModeW = vk::SamplerAddressMode::eRepeat;
-			samplerInfo.anisotropyEnable = VK_FALSE;
-			samplerInfo.unnormalizedCoordinates = VK_FALSE;
-			samplerInfo.mipmapMode = vk::SamplerMipmapMode::eLinear;
-			samplerInfo.mipLodBias = 0.0f;
-			samplerInfo.minLod = 0.0f;
-			samplerInfo.maxLod = 0.0f;
-
-			sampler = device.createSampler(samplerInfo);
-
 			std::array<vk::DescriptorPoolSize, 1> poolSizes = {};
 			poolSizes[0].type = vk::DescriptorType::eCombinedImageSampler;
 			poolSizes[0].descriptorCount = 1;
@@ -370,21 +351,7 @@
 
 			descriptorSets = device.allocateDescriptorSets(allocInfo);
 
-			vk::DescriptorImageInfo imageInfo;
-			imageInfo.imageLayout = vk::ImageLayout::eShaderReadOnlyOptimal;
-			imageInfo.imageView = texture->getImageView();
-			imageInfo.sampler = sampler;
-
-			std::array<vk::WriteDescriptorSet, 1> descriptorWrites = {};
-
-			descriptorWrites[0].dstSet = descriptorSets[0];
-			descriptorWrites[0].dstBinding = 1;
-			descriptorWrites[0].dstArrayElement = 0;
-			descriptorWrites[0].descriptorType = vk::DescriptorType::eCombinedImageSampler;
-			descriptorWrites[0].descriptorCount = 1;
-			descriptorWrites[0].pImageInfo = &imageInfo;
-
-			device.updateDescriptorSets(static_cast<uint32_t>(descriptorWrites.size()), descriptorWrites.data(), 0, nullptr);
+			doUpdateDescriptorSet(commandPool, descriptorSets[0]);
 		}
 
 		vk::CommandBufferAllocateInfo commandBufferAllocateInfo;
@@ -424,11 +391,11 @@
 				commandBuffers[i].bindDescriptorSets(vk::PipelineBindPoint::eGraphics, pipelineLayout, 0, 1, &descriptorSets[0], 0, nullptr);
 			}
 
-			// Draw a triangle
+			// Draw
 			commandBuffers[i].bindPipeline(vk::PipelineBindPoint::eGraphics, pipeline);
 			VULKAN_HPP_NAMESPACE::DeviceSize offset = 0;
 			commandBuffers[i].bindVertexBuffers(0, 1, &vertices.buffer, &offset);
-			commandBuffers[i].draw(3, 1, 0, 0);
+			commandBuffers[i].draw(vertices.numVertices, 1, 0, 0);
 
 			commandBuffers[i].endRenderPass();
 			commandBuffers[i].end();
@@ -437,64 +404,7 @@
 
 	void prepareVertices()
 	{
-		struct Vertex
-		{
-			float position[3];
-			float color[3];
-			float texCoord[2];
-		};
-
-		Vertex vertexBufferData[] = {
-			{ { 1.0f, 1.0f, 0.05f }, { 1.0f, 0.0f, 0.0f }, { 1.0f, 0.0f } },
-			{ { -1.0f, 1.0f, 0.5f }, { 0.0f, 1.0f, 0.0f }, { 0.0f, 1.0f } },
-			{ { 0.0f, -1.0f, 0.5f }, { 0.0f, 0.0f, 1.0f }, { 0.0f, 0.0f } }
-		};
-
-		vk::BufferCreateInfo vertexBufferInfo;
-		vertexBufferInfo.size = sizeof(vertexBufferData);
-		vertexBufferInfo.usage = vk::BufferUsageFlagBits::eVertexBuffer;
-		vertices.buffer = device.createBuffer(vertexBufferInfo);
-
-		vk::MemoryAllocateInfo memoryAllocateInfo;
-		vk::MemoryRequirements memoryRequirements = device.getBufferMemoryRequirements(vertices.buffer);
-		memoryAllocateInfo.allocationSize = memoryRequirements.size;
-		memoryAllocateInfo.memoryTypeIndex = Util::getMemoryTypeIndex(physicalDevice, memoryRequirements.memoryTypeBits, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
-		vertices.memory = device.allocateMemory(memoryAllocateInfo);
-
-		void *data = device.mapMemory(vertices.memory, 0, VK_WHOLE_SIZE);
-		memcpy(data, vertexBufferData, sizeof(vertexBufferData));
-		device.unmapMemory(vertices.memory);
-		device.bindBufferMemory(vertices.buffer, vertices.memory, 0);
-
-		vertices.inputBinding.binding = 0;
-		vertices.inputBinding.stride = sizeof(Vertex);
-		vertices.inputBinding.inputRate = vk::VertexInputRate::eVertex;
-
-		switch(fragShadeType)
-		{
-			case FragShadeType::Solid:
-				vertices.inputAttributes.push_back(vk::VertexInputAttributeDescription(0, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, position)));
-				break;
-
-			case FragShadeType::Interpolate:
-				vertices.inputAttributes.push_back(vk::VertexInputAttributeDescription(0, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, position)));
-				vertices.inputAttributes.push_back(vk::VertexInputAttributeDescription(1, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, color)));
-				break;
-
-			case FragShadeType::Sample:
-				vertices.inputAttributes.push_back(vk::VertexInputAttributeDescription(0, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, position)));
-				vertices.inputAttributes.push_back(vk::VertexInputAttributeDescription(1, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, color)));
-				vertices.inputAttributes.push_back(vk::VertexInputAttributeDescription(2, 0, vk::Format::eR32G32Sfloat, offsetof(Vertex, texCoord)));
-				break;
-
-			default:
-				assert(false && "Unhandled fragShadeType");
-		}
-
-		vertices.inputState.vertexBindingDescriptionCount = 1;
-		vertices.inputState.pVertexBindingDescriptions = &vertices.inputBinding;
-		vertices.inputState.vertexAttributeDescriptionCount = static_cast<uint32_t>(vertices.inputAttributes.size());
-		vertices.inputState.pVertexAttributeDescriptions = vertices.inputAttributes.data();
+		doCreateVertexBuffers();
 	}
 
 	void createFramebuffers(vk::RenderPass renderPass)
@@ -588,34 +498,16 @@
 		return device.createRenderPass(renderPassInfo);
 	}
 
-	vk::ShaderModule createShaderModule(const char *glslSource, EShLanguage glslLanguage)
-	{
-		auto spirv = Util::compileGLSLtoSPIRV(glslSource, glslLanguage);
-
-		vk::ShaderModuleCreateInfo moduleCreateInfo;
-		moduleCreateInfo.codeSize = spirv.size() * sizeof(uint32_t);
-		moduleCreateInfo.pCode = (uint32_t *)spirv.data();
-
-		return device.createShaderModule(moduleCreateInfo);
-	}
-
 	vk::Pipeline createGraphicsPipeline(vk::RenderPass renderPass)
 	{
+		auto setLayoutBindings = doCreateDescriptorSetLayouts();
+
 		std::vector<vk::DescriptorSetLayout> setLayouts;
-		if(fragShadeType == FragShadeType::Sample)
+		if(!setLayoutBindings.empty())
 		{
-			vk::DescriptorSetLayoutBinding samplerLayoutBinding;
-			samplerLayoutBinding.binding = 1;
-			samplerLayoutBinding.descriptorCount = 1;
-			samplerLayoutBinding.descriptorType = vk::DescriptorType::eCombinedImageSampler;
-			samplerLayoutBinding.pImmutableSamplers = nullptr;
-			samplerLayoutBinding.stageFlags = vk::ShaderStageFlagBits::eFragment;
-
-			std::array<vk::DescriptorSetLayoutBinding, 1> bindings = { samplerLayoutBinding };
 			vk::DescriptorSetLayoutCreateInfo layoutInfo;
-			layoutInfo.bindingCount = static_cast<uint32_t>(bindings.size());
-			layoutInfo.pBindings = bindings.data();
-
+			layoutInfo.bindingCount = static_cast<uint32_t>(setLayoutBindings.size());
+			layoutInfo.pBindings = setLayoutBindings.data();
 			descriptorSetLayout = device.createDescriptorSetLayout(layoutInfo);
 
 			setLayouts.push_back(descriptorSetLayout);
@@ -675,101 +567,11 @@
 		multisampleState.rasterizationSamples = multisample ? vk::SampleCountFlagBits::e4 : vk::SampleCountFlagBits::e1;
 		multisampleState.pSampleMask = nullptr;
 
-		const char *vertexShader = nullptr;
-		const char *fragmentShader = nullptr;
+		vk::ShaderModule vertexModule = doCreateVertexShader();
+		vk::ShaderModule fragmentModule = doCreateFragmentShader();
 
-		switch(fragShadeType)
-		{
-			case FragShadeType::Solid:
-			{
-				vertexShader = R"(#version 310 es
-					layout(location = 0) in vec3 inPos;
-
-					void main()
-					{
-						gl_Position = vec4(inPos.xyz, 1.0);
-					})";
-
-				fragmentShader = R"(#version 310 es
-					precision highp float;
-
-					layout(location = 0) out vec4 outColor;
-
-					void main()
-					{
-						outColor = vec4(1.0, 1.0, 1.0, 1.0);
-					})";
-			}
-			break;
-
-			case FragShadeType::Interpolate:
-			{
-				vertexShader = R"(#version 310 es
-					layout(location = 0) in vec3 inPos;
-					layout(location = 1) in vec3 inColor;
-
-					layout(location = 0) out vec3 outColor;
-
-					void main()
-					{
-						outColor = inColor;
-						gl_Position = vec4(inPos.xyz, 1.0);
-					})";
-
-				fragmentShader = R"(#version 310 es
-					precision highp float;
-
-					layout(location = 0) in vec3 inColor;
-
-					layout(location = 0) out vec4 outColor;
-
-					void main()
-					{
-						outColor = vec4(inColor, 1.0);
-					})";
-			}
-			break;
-
-			case FragShadeType::Sample:
-			{
-				vertexShader = R"(#version 310 es
-					layout(location = 0) in vec3 inPos;
-					layout(location = 1) in vec3 inColor;
-
-					layout(location = 0) out vec3 outColor;
-					layout(location = 1) out vec2 fragTexCoord;
-
-					void main()
-					{
-						outColor = inColor;
-						gl_Position = vec4(inPos.xyz, 1.0);
-						fragTexCoord = inPos.xy;
-					})";
-
-				fragmentShader = R"(#version 310 es
-					precision highp float;
-
-					layout(location = 0) in vec3 inColor;
-					layout(location = 1) in vec2 fragTexCoord;
-
-					layout(location = 0) out vec4 outColor;
-
-					layout(binding = 0) uniform sampler2D texSampler;
-
-					void main()
-					{
-						outColor = texture(texSampler, fragTexCoord) * vec4(inColor, 1.0);
-					})";
-			}
-			break;
-
-			default:
-				assert(false && "Unhandled fragShadeType");
-				break;
-		}
-
-		vk::ShaderModule vertexModule = createShaderModule(vertexShader, EShLanguage::EShLangVertex);
-		vk::ShaderModule fragmentModule = createShaderModule(fragmentShader, EShLanguage::EShLangFragment);
+		assert(vertexModule);    // TODO: if nullptr, use a default
+		assert(fragmentModule);  // TODO: if nullptr, use a default
 
 		std::array<vk::PipelineShaderStageCreateInfo, 2> shaderStages;
 
@@ -801,8 +603,130 @@
 		return pipeline;
 	}
 
+protected:
+	/////////////////////////
+	// Hooks
+	/////////////////////////
+
+	// Called from prepareVertices.
+	// Child type may call addVertexBuffer() from this function.
+	virtual void doCreateVertexBuffers() {}
+
+	// Called from createGraphicsPipeline.
+	// Child type may return vector of DescriptorSetLayoutBindings for which a DescriptorSetLayout
+	// will be created and stored in this->descriptorSetLayout.
+	virtual std::vector<vk::DescriptorSetLayoutBinding> doCreateDescriptorSetLayouts()
+	{
+		return {};
+	}
+
+	// Called from createGraphicsPipeline.
+	// Child type may call createShaderModule() and return the result.
+	virtual vk::ShaderModule doCreateVertexShader()
+	{
+		return nullptr;
+	}
+
+	// Called from createGraphicsPipeline.
+	// Child type may call createShaderModule() and return the result.
+	virtual vk::ShaderModule doCreateFragmentShader()
+	{
+		return nullptr;
+	}
+
+	// Called from createCommandBuffers.
+	// Child type may create resources (addImage, addSampler, etc.), and make sure to
+	// call device.updateDescriptorSets.
+	virtual void doUpdateDescriptorSet(vk::CommandPool &commandPool, vk::DescriptorSet &descriptorSet)
+	{
+	}
+
+	/////////////////////////
+	// Resource Management
+	/////////////////////////
+
+	// Call from doCreateFragmentShader()
+	vk::ShaderModule createShaderModule(const char *glslSource, EShLanguage glslLanguage)
+	{
+		auto spirv = Util::compileGLSLtoSPIRV(glslSource, glslLanguage);
+
+		vk::ShaderModuleCreateInfo moduleCreateInfo;
+		moduleCreateInfo.codeSize = spirv.size() * sizeof(uint32_t);
+		moduleCreateInfo.pCode = (uint32_t *)spirv.data();
+
+		return device.createShaderModule(moduleCreateInfo);
+	}
+
+	// Call from doCreateVertexBuffers()
+	template<typename VertexType>
+	void addVertexBuffer(VertexType *vertexBufferData, size_t vertexBufferDataSize, std::vector<vk::VertexInputAttributeDescription> inputAttributes)
+	{
+		assert(!vertices.buffer);  // For now, only support adding once
+
+		vk::BufferCreateInfo vertexBufferInfo;
+		vertexBufferInfo.size = vertexBufferDataSize;
+		vertexBufferInfo.usage = vk::BufferUsageFlagBits::eVertexBuffer;
+		vertices.buffer = device.createBuffer(vertexBufferInfo);
+
+		vk::MemoryAllocateInfo memoryAllocateInfo;
+		vk::MemoryRequirements memoryRequirements = device.getBufferMemoryRequirements(vertices.buffer);
+		memoryAllocateInfo.allocationSize = memoryRequirements.size;
+		memoryAllocateInfo.memoryTypeIndex = Util::getMemoryTypeIndex(physicalDevice, memoryRequirements.memoryTypeBits, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+		vertices.memory = device.allocateMemory(memoryAllocateInfo);
+
+		void *data = device.mapMemory(vertices.memory, 0, VK_WHOLE_SIZE);
+		memcpy(data, vertexBufferData, vertexBufferDataSize);
+		device.unmapMemory(vertices.memory);
+		device.bindBufferMemory(vertices.buffer, vertices.memory, 0);
+
+		vertices.inputBinding.binding = 0;
+		vertices.inputBinding.stride = sizeof(VertexType);
+		vertices.inputBinding.inputRate = vk::VertexInputRate::eVertex;
+
+		vertices.inputAttributes = std::move(inputAttributes);
+
+		vertices.inputState.vertexBindingDescriptionCount = 1;
+		vertices.inputState.pVertexBindingDescriptions = &vertices.inputBinding;
+		vertices.inputState.vertexAttributeDescriptionCount = static_cast<uint32_t>(vertices.inputAttributes.size());
+		vertices.inputState.pVertexAttributeDescriptions = vertices.inputAttributes.data();
+
+		// Note that we assume data is tightly packed
+		vertices.numVertices = static_cast<uint32_t>(vertexBufferDataSize / sizeof(VertexType));
+	}
+
+	template<typename T>
+	struct Resource
+	{
+		size_t id;
+		T &obj;
+	};
+
+	template<typename... Args>
+	Resource<Image> addImage(Args &&... args)
+	{
+		images.emplace_back(std::make_unique<Image>(std::forward<Args>(args)...));
+		return { images.size() - 1, *images.back() };
+	}
+
+	Image &getImageById(size_t id)
+	{
+		return *images[id].get();
+	}
+
+	Resource<vk::Sampler> addSampler(const vk::SamplerCreateInfo &samplerCreateInfo)
+	{
+		auto sampler = device.createSampler(samplerCreateInfo);
+		samplers.push_back(sampler);
+		return { samplers.size() - 1, sampler };
+	}
+
+	vk::Sampler &getSamplerById(size_t id)
+	{
+		return samplers[id];
+	}
+
+private:
 	const vk::Extent2D windowSize = { 1280, 720 };
-	const FragShadeType fragShadeType;
 	const bool multisample;
 
 	std::unique_ptr<Window> window;
@@ -820,6 +744,8 @@
 		vk::VertexInputBindingDescription inputBinding;
 		std::vector<vk::VertexInputAttributeDescription> inputAttributes;
 		vk::PipelineVertexInputStateCreateInfo inputState;
+
+		uint32_t numVertices;
 	} vertices;
 
 	vk::DescriptorSetLayout descriptorSetLayout;  // Owning handle
@@ -830,16 +756,270 @@
 	vk::Semaphore renderCompleteSemaphore;   // Owning handle
 	std::vector<vk::Fence> waitFences;       // Owning handles
 
-	vk::CommandPool commandPool;  // Owning handle
-	std::unique_ptr<Image> texture;
-	vk::Sampler sampler;                            // Owning handle
-	vk::DescriptorPool descriptorPool;              // Owning handle
+	vk::CommandPool commandPool;        // Owning handle
+	vk::DescriptorPool descriptorPool;  // Owning handle
+
+	// Resources
+	std::vector<std::unique_ptr<Image>> images;
+	std::vector<vk::Sampler> samplers;  // Owning handles
+
 	std::vector<vk::CommandBuffer> commandBuffers;  // Owning handles
 };
 
-static void Triangle(benchmark::State &state, FragShadeType fragShadeType, Multisample multisample)
+class TriangleSolidColorBenchmark : public DrawBenchmark
 {
-	TriangleBenchmark benchmark(fragShadeType, multisample);
+public:
+	TriangleSolidColorBenchmark(Multisample multisample)
+	    : DrawBenchmark(multisample)
+	{}
+
+protected:
+	void doCreateVertexBuffers() override
+	{
+		struct Vertex
+		{
+			float position[3];
+		};
+
+		Vertex vertexBufferData[] = {
+			{ { 1.0f, 1.0f, 0.5f } },
+			{ { -1.0f, 1.0f, 0.5f } },
+			{ { 0.0f, -1.0f, 0.5f } }
+		};
+
+		std::vector<vk::VertexInputAttributeDescription> inputAttributes;
+		inputAttributes.push_back(vk::VertexInputAttributeDescription(0, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, position)));
+
+		addVertexBuffer(vertexBufferData, sizeof(vertexBufferData), std::move(inputAttributes));
+	}
+
+	vk::ShaderModule doCreateVertexShader() override
+	{
+		const char *vertexShader = R"(#version 310 es
+			layout(location = 0) in vec3 inPos;
+
+			void main()
+			{
+				gl_Position = vec4(inPos.xyz, 1.0);
+			})";
+
+		return createShaderModule(vertexShader, EShLanguage::EShLangVertex);
+	}
+
+	vk::ShaderModule doCreateFragmentShader() override
+	{
+		const char *fragmentShader = R"(#version 310 es
+			precision highp float;
+
+			layout(location = 0) out vec4 outColor;
+
+			void main()
+			{
+				outColor = vec4(1.0, 1.0, 1.0, 1.0);
+			})";
+
+		return createShaderModule(fragmentShader, EShLanguage::EShLangFragment);
+	}
+};
+
+class TriangleInterpolateColorBenchmark : public DrawBenchmark
+{
+public:
+	TriangleInterpolateColorBenchmark(Multisample multisample)
+	    : DrawBenchmark(multisample)
+	{}
+
+protected:
+	void doCreateVertexBuffers() override
+	{
+		struct Vertex
+		{
+			float position[3];
+			float color[3];
+		};
+
+		Vertex vertexBufferData[] = {
+			{ { 1.0f, 1.0f, 0.05f }, { 1.0f, 0.0f, 0.0f } },
+			{ { -1.0f, 1.0f, 0.5f }, { 0.0f, 1.0f, 0.0f } },
+			{ { 0.0f, -1.0f, 0.5f }, { 0.0f, 0.0f, 1.0f } }
+		};
+
+		std::vector<vk::VertexInputAttributeDescription> inputAttributes;
+		inputAttributes.push_back(vk::VertexInputAttributeDescription(0, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, position)));
+		inputAttributes.push_back(vk::VertexInputAttributeDescription(1, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, color)));
+
+		addVertexBuffer(vertexBufferData, sizeof(vertexBufferData), std::move(inputAttributes));
+	}
+
+	vk::ShaderModule doCreateVertexShader() override
+	{
+		const char *vertexShader = R"(#version 310 es
+			layout(location = 0) in vec3 inPos;
+			layout(location = 1) in vec3 inColor;
+
+			layout(location = 0) out vec3 outColor;
+
+			void main()
+			{
+				outColor = inColor;
+				gl_Position = vec4(inPos.xyz, 1.0);
+			})";
+
+		return createShaderModule(vertexShader, EShLanguage::EShLangVertex);
+	}
+
+	vk::ShaderModule doCreateFragmentShader() override
+	{
+		const char *fragmentShader = R"(#version 310 es
+			precision highp float;
+
+			layout(location = 0) in vec3 inColor;
+
+			layout(location = 0) out vec4 outColor;
+
+			void main()
+			{
+				outColor = vec4(inColor, 1.0);
+			})";
+
+		return createShaderModule(fragmentShader, EShLanguage::EShLangFragment);
+	}
+};
+
+class TriangleSampleTextureBenchmark : public DrawBenchmark
+{
+public:
+	TriangleSampleTextureBenchmark(Multisample multisample)
+	    : DrawBenchmark(multisample)
+	{}
+
+protected:
+	void doCreateVertexBuffers() override
+	{
+		struct Vertex
+		{
+			float position[3];
+			float color[3];
+			float texCoord[2];
+		};
+
+		Vertex vertexBufferData[] = {
+			{ { 1.0f, 1.0f, 0.5f }, { 1.0f, 0.0f, 0.0f }, { 1.0f, 0.0f } },
+			{ { -1.0f, 1.0f, 0.5f }, { 0.0f, 1.0f, 0.0f }, { 0.0f, 1.0f } },
+			{ { 0.0f, -1.0f, 0.5f }, { 0.0f, 0.0f, 1.0f }, { 0.0f, 0.0f } }
+		};
+
+		std::vector<vk::VertexInputAttributeDescription> inputAttributes;
+		inputAttributes.push_back(vk::VertexInputAttributeDescription(0, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, position)));
+		inputAttributes.push_back(vk::VertexInputAttributeDescription(1, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, color)));
+		inputAttributes.push_back(vk::VertexInputAttributeDescription(2, 0, vk::Format::eR32G32Sfloat, offsetof(Vertex, texCoord)));
+
+		addVertexBuffer(vertexBufferData, sizeof(vertexBufferData), std::move(inputAttributes));
+	}
+
+	vk::ShaderModule doCreateVertexShader() override
+	{
+		const char *vertexShader = R"(#version 310 es
+			layout(location = 0) in vec3 inPos;
+			layout(location = 1) in vec3 inColor;
+
+			layout(location = 0) out vec3 outColor;
+			layout(location = 1) out vec2 fragTexCoord;
+
+			void main()
+			{
+				outColor = inColor;
+				gl_Position = vec4(inPos.xyz, 1.0);
+				fragTexCoord = inPos.xy;
+			})";
+
+		return createShaderModule(vertexShader, EShLanguage::EShLangVertex);
+	}
+
+	vk::ShaderModule doCreateFragmentShader() override
+	{
+		const char *fragmentShader = R"(#version 310 es
+			precision highp float;
+
+			layout(location = 0) in vec3 inColor;
+			layout(location = 1) in vec2 fragTexCoord;
+
+			layout(location = 0) out vec4 outColor;
+
+			layout(binding = 0) uniform sampler2D texSampler;
+
+			void main()
+			{
+				outColor = texture(texSampler, fragTexCoord) * vec4(inColor, 1.0);
+			})";
+
+		return createShaderModule(fragmentShader, EShLanguage::EShLangFragment);
+	}
+
+	std::vector<vk::DescriptorSetLayoutBinding> doCreateDescriptorSetLayouts() override
+	{
+		vk::DescriptorSetLayoutBinding samplerLayoutBinding;
+		samplerLayoutBinding.binding = 1;
+		samplerLayoutBinding.descriptorCount = 1;
+		samplerLayoutBinding.descriptorType = vk::DescriptorType::eCombinedImageSampler;
+		samplerLayoutBinding.pImmutableSamplers = nullptr;
+		samplerLayoutBinding.stageFlags = vk::ShaderStageFlagBits::eFragment;
+
+		return { samplerLayoutBinding };
+	}
+
+	void doUpdateDescriptorSet(vk::CommandPool &commandPool, vk::DescriptorSet &descriptorSet) override
+	{
+		auto &texture = addImage(device, 16, 16, vk::Format::eR8G8B8A8Unorm).obj;
+
+		// Fill texture with white
+		vk::DeviceSize bufferSize = 16 * 16 * 4;
+		Buffer buffer(device, bufferSize, vk::BufferUsageFlagBits::eTransferSrc);
+		void *data = buffer.mapMemory();
+		memset(data, 255, bufferSize);
+		buffer.unmapMemory();
+
+		Util::transitionImageLayout(device, commandPool, queue, texture.getImage(), vk::Format::eR8G8B8A8Unorm, vk::ImageLayout::eUndefined, vk::ImageLayout::eTransferDstOptimal);
+		Util::copyBufferToImage(device, commandPool, queue, buffer.getBuffer(), texture.getImage(), 16, 16);
+		Util::transitionImageLayout(device, commandPool, queue, texture.getImage(), vk::Format::eR8G8B8A8Unorm, vk::ImageLayout::eTransferDstOptimal, vk::ImageLayout::eShaderReadOnlyOptimal);
+
+		vk::SamplerCreateInfo samplerInfo;
+		samplerInfo.magFilter = vk::Filter::eLinear;
+		samplerInfo.minFilter = vk::Filter::eLinear;
+		samplerInfo.addressModeU = vk::SamplerAddressMode::eRepeat;
+		samplerInfo.addressModeV = vk::SamplerAddressMode::eRepeat;
+		samplerInfo.addressModeW = vk::SamplerAddressMode::eRepeat;
+		samplerInfo.anisotropyEnable = VK_FALSE;
+		samplerInfo.unnormalizedCoordinates = VK_FALSE;
+		samplerInfo.mipmapMode = vk::SamplerMipmapMode::eLinear;
+		samplerInfo.mipLodBias = 0.0f;
+		samplerInfo.minLod = 0.0f;
+		samplerInfo.maxLod = 0.0f;
+
+		auto sampler = addSampler(samplerInfo);
+
+		vk::DescriptorImageInfo imageInfo;
+		imageInfo.imageLayout = vk::ImageLayout::eShaderReadOnlyOptimal;
+		imageInfo.imageView = texture.getImageView();
+		imageInfo.sampler = sampler.obj;
+
+		std::array<vk::WriteDescriptorSet, 1> descriptorWrites = {};
+
+		descriptorWrites[0].dstSet = descriptorSet;
+		descriptorWrites[0].dstBinding = 1;
+		descriptorWrites[0].dstArrayElement = 0;
+		descriptorWrites[0].descriptorType = vk::DescriptorType::eCombinedImageSampler;
+		descriptorWrites[0].descriptorCount = 1;
+		descriptorWrites[0].pImageInfo = &imageInfo;
+
+		device.updateDescriptorSets(static_cast<uint32_t>(descriptorWrites.size()), descriptorWrites.data(), 0, nullptr);
+	}
+};
+
+template<typename T>
+static void RunBenchmark(benchmark::State &state, T &benchmark)
+{
+	benchmark.initialize();
 
 	if(false) benchmark.show();  // Enable for visual verification.
 
@@ -852,13 +1032,31 @@
 	}
 }
 
+static void TriangleSolidColor(benchmark::State &state, Multisample multisample)
+{
+	TriangleSolidColorBenchmark benchmark(multisample);
+	RunBenchmark(state, benchmark);
+}
+
+static void TriangleInterpolateColor(benchmark::State &state, Multisample multisample)
+{
+	TriangleInterpolateColorBenchmark benchmark(multisample);
+	RunBenchmark(state, benchmark);
+}
+
+static void TriangleSampleTexture(benchmark::State &state, Multisample multisample)
+{
+	TriangleSampleTextureBenchmark benchmark(multisample);
+	RunBenchmark(state, benchmark);
+}
+
 BENCHMARK_CAPTURE(ClearImage, VK_FORMAT_R8G8B8A8_UNORM, vk::Format::eR8G8B8A8Unorm, vk::ImageAspectFlagBits::eColor)->Unit(benchmark::kMillisecond);
 BENCHMARK_CAPTURE(ClearImage, VK_FORMAT_R32_SFLOAT, vk::Format::eR32Sfloat, vk::ImageAspectFlagBits::eColor)->Unit(benchmark::kMillisecond);
 BENCHMARK_CAPTURE(ClearImage, VK_FORMAT_D32_SFLOAT, vk::Format::eD32Sfloat, vk::ImageAspectFlagBits::eDepth)->Unit(benchmark::kMillisecond);
 
-BENCHMARK_CAPTURE(Triangle, Solid, FragShadeType::Solid, Multisample::False)->Unit(benchmark::kMillisecond);
-BENCHMARK_CAPTURE(Triangle, Interpolate, FragShadeType::Interpolate, Multisample::False)->Unit(benchmark::kMillisecond);
-BENCHMARK_CAPTURE(Triangle, Sample, FragShadeType::Sample, Multisample::False)->Unit(benchmark::kMillisecond);
-BENCHMARK_CAPTURE(Triangle, Solid_Multisample, FragShadeType::Solid, Multisample::True)->Unit(benchmark::kMillisecond);
-BENCHMARK_CAPTURE(Triangle, Interpolate_Multisample, FragShadeType::Interpolate, Multisample::True)->Unit(benchmark::kMillisecond);
-BENCHMARK_CAPTURE(Triangle, Sample_Multisample, FragShadeType::Sample, Multisample::True)->Unit(benchmark::kMillisecond);
+BENCHMARK_CAPTURE(TriangleSolidColor, TriangleSolidColor, Multisample::False)->Unit(benchmark::kMillisecond);
+BENCHMARK_CAPTURE(TriangleInterpolateColor, TriangleInterpolateColor, Multisample::False)->Unit(benchmark::kMillisecond);
+BENCHMARK_CAPTURE(TriangleSampleTexture, TriangleSampleTexture, Multisample::False)->Unit(benchmark::kMillisecond);
+BENCHMARK_CAPTURE(TriangleSolidColor, TriangleSolidColor_Multisample, Multisample::True)->Unit(benchmark::kMillisecond);
+BENCHMARK_CAPTURE(TriangleInterpolateColor, TriangleInterpolateColor_Multisample, Multisample::True)->Unit(benchmark::kMillisecond);
+BENCHMARK_CAPTURE(TriangleSampleTexture, TriangleSampleTexture_Multisample, Multisample::True)->Unit(benchmark::kMillisecond);