Make vk-unittests use VulkanWrapper

This change moves the VulkanBenchmark and DrawBenchmark classes to
VulkanWrapper so that they can be used from other unit tests -- namely,
vk-unittests. In doing so, it became clear that using these as base
classes wasn't great for writing googletests, as text fixtures are
classes themselves, and this resulted in messy multiple inheritance. So
I modified the two classes to use callback registration instead of
virtual functions.

Apart from reworking existing tests (e.g. see TriangleBenchmark.cpp), I
also added a new DrawTests.cpp to vk-unittests with a unit test to make
sure we don't crash when leaving out "gl_Position", a bug that sugoi@
fixed in swiftshader-cl/51808. This is a good example of how easy it can
be to write such unit tests now.

List of changes:

* Moved VulkanBenchmark and DrawBenchmark to VulkanWrapper, and renamed
  VulkanTester and DrawTester respectively.
* ClearImageBenchmark refactored to aggregate a VulkanTester. This is an
  example where using a class is fine as we can still use the testers
  via aggregation.
* TriangleBenchmark tests refactored to use DrawTester and register
  callbacks.
* Moved compute tests to a ComputeTests.cpp.
* Moved the other tests to BasicTests.cpp.
* Added DrawTests.cpp with new DrawTests.VertexShaderNoPositionOutput
  test.
* CMake: add VulkanWrapper target for unittests as well as benchmarks.
* CMake: change FOLDER to better organize the tests and benchmarks for
  VS.

Bug: b/176981107
Change-Id: Ib1a0b85b3df787d2e39da08930414f9a14954a73
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/52348
Kokoro-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Alexis Hétu <sugoi@google.com>
Tested-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ea63657..00cf78b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -243,13 +243,20 @@
     endif()
 endfunction()
 
+if (SWIFTSHADER_BUILD_TESTS OR SWIFTSHADER_BUILD_BENCHMARKS)
+    set(BUILD_VULKAN_WRAPPER TRUE)
+endif()
+
+if (BUILD_VULKAN_WRAPPER)
+    InitSubmodule(glslang ${THIRD_PARTY_DIR}/glslang)
+endif()
+
 if (SWIFTSHADER_BUILD_TESTS)
     InitSubmodule(gtest ${THIRD_PARTY_DIR}/googletest)
 endif()
 
 if(SWIFTSHADER_BUILD_BENCHMARKS)
     InitSubmodule(benchmark::benchmark ${THIRD_PARTY_DIR}/benchmark)
-    InitSubmodule(glslang ${THIRD_PARTY_DIR}/glslang)
 endif()
 
 if(REACTOR_EMIT_DEBUG_INFO)
@@ -992,6 +999,13 @@
     endif()
 endif()
 
+if(BUILD_VULKAN_WRAPPER)
+    if (NOT TARGET glslang)
+        add_subdirectory(${THIRD_PARTY_DIR}/glslang)
+    endif()
+    add_subdirectory(${TESTS_DIR}/VulkanWrapper) # Add VulkanWrapper target
+endif()
+
 if(SWIFTSHADER_BUILD_TESTS)
     add_subdirectory(${TESTS_DIR}/ReactorUnitTests) # Add ReactorUnitTests target
     add_subdirectory(${TESTS_DIR}/GLESUnitTests) # Add gles-unittests target
@@ -1000,11 +1014,6 @@
 endif()
 
 if(SWIFTSHADER_BUILD_BENCHMARKS)
-    if (NOT TARGET glslang)
-        add_subdirectory(${THIRD_PARTY_DIR}/glslang)
-    endif()
-    add_subdirectory(${TESTS_DIR}/VulkanWrapper) # Add VulkanWrapper target
-
     if (NOT TARGET benchmark::benchmark)
         set(BENCHMARK_ENABLE_TESTING FALSE CACHE BOOL FALSE FORCE)
         add_subdirectory(${THIRD_PARTY_DIR}/benchmark)
diff --git a/tests/ReactorBenchmarks/CMakeLists.txt b/tests/ReactorBenchmarks/CMakeLists.txt
index 5f25670..5f85e3f 100644
--- a/tests/ReactorBenchmarks/CMakeLists.txt
+++ b/tests/ReactorBenchmarks/CMakeLists.txt
@@ -31,7 +31,7 @@
 )
 
 set_target_properties(ReactorBenchmarks PROPERTIES
-    FOLDER "Benchmarks"
+    FOLDER "Tests/Benchmarks"
     RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}"
 )
 
diff --git a/tests/SystemBenchmarks/CMakeLists.txt b/tests/SystemBenchmarks/CMakeLists.txt
index d411303..39a5a57 100644
--- a/tests/SystemBenchmarks/CMakeLists.txt
+++ b/tests/SystemBenchmarks/CMakeLists.txt
@@ -32,7 +32,7 @@
 )
 
 set_target_properties(system-benchmarks PROPERTIES
-    FOLDER "Benchmarks"
+    FOLDER "Tests/Benchmarks"
     RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}"
 )
 
diff --git a/tests/VulkanBenchmarks/CMakeLists.txt b/tests/VulkanBenchmarks/CMakeLists.txt
index cbcdaed..62e699f 100644
--- a/tests/VulkanBenchmarks/CMakeLists.txt
+++ b/tests/VulkanBenchmarks/CMakeLists.txt
@@ -23,12 +23,8 @@
 
 set(VULKAN_BENCHMARKS_SRC_FILES
     ClearImageBenchmarks.cpp
-    DrawBenchmark.cpp
-    DrawBenchmark.hpp
     main.cpp
     TriangleBenchmarks.cpp
-    VulkanBenchmark.cpp
-    VulkanBenchmark.hpp
 )
 
 add_executable(VulkanBenchmarks
@@ -48,7 +44,7 @@
 )
 
 set_target_properties(VulkanBenchmarks PROPERTIES
-    FOLDER "Benchmarks"
+    FOLDER "Tests/Benchmarks"
     RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}"
 )
 
diff --git a/tests/VulkanBenchmarks/ClearImageBenchmarks.cpp b/tests/VulkanBenchmarks/ClearImageBenchmarks.cpp
index 0cc64d5..36d0449 100644
--- a/tests/VulkanBenchmarks/ClearImageBenchmarks.cpp
+++ b/tests/VulkanBenchmarks/ClearImageBenchmarks.cpp
@@ -12,17 +12,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "VulkanBenchmark.hpp"
+#include "VulkanTester.hpp"
 #include "benchmark/benchmark.h"
 
 #include <cassert>
 
-class ClearImageBenchmark : public VulkanBenchmark
+class ClearImageBenchmark
 {
 public:
 	void initialize(vk::Format clearFormat, vk::ImageAspectFlagBits clearAspect)
 	{
-		VulkanBenchmark::initialize();
+		tester.initialize();
+		auto &device = tester.getDevice();
 
 		vk::ImageCreateInfo imageInfo;
 		imageInfo.imageType = vk::ImageType::e2D;
@@ -48,7 +49,7 @@
 		device.bindImageMemory(image, memory, 0);
 
 		vk::CommandPoolCreateInfo commandPoolCreateInfo;
-		commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
+		commandPoolCreateInfo.queueFamilyIndex = tester.getQueueFamilyIndex();
 
 		commandPool = device.createCommandPool(commandPoolCreateInfo);
 
@@ -96,6 +97,7 @@
 
 	~ClearImageBenchmark()
 	{
+		auto &device = tester.getDevice();
 		device.freeCommandBuffers(commandPool, 1, &commandBuffer);
 		device.destroyCommandPool(commandPool, nullptr);
 		device.freeMemory(memory, nullptr);
@@ -104,6 +106,8 @@
 
 	void clear()
 	{
+		auto &queue = tester.getQueue();
+
 		vk::SubmitInfo submitInfo;
 		submitInfo.commandBufferCount = 1;
 		submitInfo.pCommandBuffers = &commandBuffer;
@@ -113,6 +117,7 @@
 	}
 
 private:
+	VulkanTester tester;
 	vk::Image image;                  // Owning handle
 	vk::DeviceMemory memory;          // Owning handle
 	vk::CommandPool commandPool;      // Owning handle
diff --git a/tests/VulkanBenchmarks/DrawBenchmark.hpp b/tests/VulkanBenchmarks/DrawBenchmark.hpp
deleted file mode 100644
index 2c30609..0000000
--- a/tests/VulkanBenchmarks/DrawBenchmark.hpp
+++ /dev/null
@@ -1,174 +0,0 @@
-// Copyright 2021 The SwiftShader Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef DRAW_BENCHMARK_HPP_
-#define DRAW_BENCHMARK_HPP_
-
-#include "Framebuffer.hpp"
-#include "Image.hpp"
-#include "Swapchain.hpp"
-#include "Util.hpp"
-#include "VulkanBenchmark.hpp"
-#include "Window.hpp"
-
-enum class Multisample
-{
-	False,
-	True
-};
-
-class DrawBenchmark : public VulkanBenchmark
-{
-public:
-	DrawBenchmark(Multisample multisample);
-	~DrawBenchmark();
-
-	void initialize();
-	void renderFrame();
-	void show();
-
-private:
-	void createSynchronizationPrimitives();
-	void createCommandBuffers(vk::RenderPass renderPass);
-	void prepareVertices();
-	void createFramebuffers(vk::RenderPass renderPass);
-	vk::RenderPass createRenderPass(vk::Format colorFormat);
-	vk::Pipeline createGraphicsPipeline(vk::RenderPass renderPass);
-	void addVertexBuffer(void *vertexBufferData, size_t vertexBufferDataSize, size_t vertexSize, std::vector<vk::VertexInputAttributeDescription> inputAttributes);
-
-protected:
-	/////////////////////////
-	// Hooks
-	/////////////////////////
-
-	// Called from prepareVertices.
-	// Child type may call addVertexBuffer() from this function.
-	virtual void doCreateVertexBuffers() {}
-
-	// Called from createGraphicsPipeline.
-	// Child type may return vector of DescriptorSetLayoutBindings for which a DescriptorSetLayout
-	// will be created and stored in this->descriptorSetLayout.
-	virtual std::vector<vk::DescriptorSetLayoutBinding> doCreateDescriptorSetLayouts()
-	{
-		return {};
-	}
-
-	// Called from createGraphicsPipeline.
-	// Child type may call createShaderModule() and return the result.
-	virtual vk::ShaderModule doCreateVertexShader()
-	{
-		return nullptr;
-	}
-
-	// Called from createGraphicsPipeline.
-	// Child type may call createShaderModule() and return the result.
-	virtual vk::ShaderModule doCreateFragmentShader()
-	{
-		return nullptr;
-	}
-
-	// Called from createCommandBuffers.
-	// Child type may create resources (addImage, addSampler, etc.), and make sure to
-	// call device.updateDescriptorSets.
-	virtual void doUpdateDescriptorSet(vk::CommandPool &commandPool, vk::DescriptorSet &descriptorSet)
-	{
-	}
-
-	/////////////////////////
-	// Resource Management
-	/////////////////////////
-
-	// Call from doCreateFragmentShader()
-	vk::ShaderModule createShaderModule(const char *glslSource, EShLanguage glslLanguage);
-
-	// Call from doCreateVertexBuffers()
-	template<typename VertexType>
-	void addVertexBuffer(VertexType *vertexBufferData, size_t vertexBufferDataSize, std::vector<vk::VertexInputAttributeDescription> inputAttributes)
-	{
-		addVertexBuffer(vertexBufferData, vertexBufferDataSize, sizeof(VertexType), std::move(inputAttributes));
-	}
-
-	template<typename T>
-	struct Resource
-	{
-		size_t id;
-		T &obj;
-	};
-
-	template<typename... Args>
-	Resource<Image> addImage(Args &&... args)
-	{
-		images.emplace_back(std::make_unique<Image>(std::forward<Args>(args)...));
-		return { images.size() - 1, *images.back() };
-	}
-
-	Image &getImageById(size_t id)
-	{
-		return *images[id].get();
-	}
-
-	Resource<vk::Sampler> addSampler(const vk::SamplerCreateInfo &samplerCreateInfo)
-	{
-		auto sampler = device.createSampler(samplerCreateInfo);
-		samplers.push_back(sampler);
-		return { samplers.size() - 1, sampler };
-	}
-
-	vk::Sampler &getSamplerById(size_t id)
-	{
-		return samplers[id];
-	}
-
-private:
-	const vk::Extent2D windowSize = { 1280, 720 };
-	const bool multisample;
-
-	std::unique_ptr<Window> window;
-	std::unique_ptr<Swapchain> swapchain;
-
-	vk::RenderPass renderPass;  // Owning handle
-	std::vector<std::unique_ptr<Framebuffer>> framebuffers;
-	uint32_t currentFrameBuffer = 0;
-
-	struct VertexBuffer
-	{
-		vk::Buffer buffer;        // Owning handle
-		vk::DeviceMemory memory;  // Owning handle
-
-		vk::VertexInputBindingDescription inputBinding;
-		std::vector<vk::VertexInputAttributeDescription> inputAttributes;
-		vk::PipelineVertexInputStateCreateInfo inputState;
-
-		uint32_t numVertices;
-	} vertices;
-
-	vk::DescriptorSetLayout descriptorSetLayout;  // Owning handle
-	vk::PipelineLayout pipelineLayout;            // Owning handle
-	vk::Pipeline pipeline;                        // Owning handle
-
-	vk::Semaphore presentCompleteSemaphore;  // Owning handle
-	vk::Semaphore renderCompleteSemaphore;   // Owning handle
-	std::vector<vk::Fence> waitFences;       // Owning handles
-
-	vk::CommandPool commandPool;        // Owning handle
-	vk::DescriptorPool descriptorPool;  // Owning handle
-
-	// Resources
-	std::vector<std::unique_ptr<Image>> images;
-	std::vector<vk::Sampler> samplers;  // Owning handles
-
-	std::vector<vk::CommandBuffer> commandBuffers;  // Owning handles
-};
-
-#endif  // DRAW_BENCHMARK_HPP_
diff --git a/tests/VulkanBenchmarks/TriangleBenchmarks.cpp b/tests/VulkanBenchmarks/TriangleBenchmarks.cpp
index 832aa69..fc342a9 100644
--- a/tests/VulkanBenchmarks/TriangleBenchmarks.cpp
+++ b/tests/VulkanBenchmarks/TriangleBenchmarks.cpp
@@ -13,22 +13,33 @@
 // limitations under the License.
 
 #include "Buffer.hpp"
-#include "DrawBenchmark.hpp"
+#include "DrawTester.hpp"
 #include "benchmark/benchmark.h"
 
 #include <cassert>
 #include <vector>
 
-class TriangleSolidColorBenchmark : public DrawBenchmark
+template<typename T>
+static void RunBenchmark(benchmark::State &state, T &tester)
 {
-public:
-	TriangleSolidColorBenchmark(Multisample multisample)
-	    : DrawBenchmark(multisample)
-	{}
+	tester.initialize();
 
-protected:
-	void doCreateVertexBuffers() override
+	if(false) tester.show();  // Enable for visual verification.
+
+	// Warmup
+	tester.renderFrame();
+
+	for(auto _ : state)
 	{
+		tester.renderFrame();
+	}
+}
+
+static void TriangleSolidColor(benchmark::State &state, Multisample multisample)
+{
+	DrawTester tester(multisample);
+
+	tester.onCreateVertexBuffers([](DrawTester &tester) {
 		struct Vertex
 		{
 			float position[3];
@@ -43,11 +54,10 @@
 		std::vector<vk::VertexInputAttributeDescription> inputAttributes;
 		inputAttributes.push_back(vk::VertexInputAttributeDescription(0, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, position)));
 
-		addVertexBuffer(vertexBufferData, sizeof(vertexBufferData), std::move(inputAttributes));
-	}
+		tester.addVertexBuffer(vertexBufferData, sizeof(vertexBufferData), std::move(inputAttributes));
+	});
 
-	vk::ShaderModule doCreateVertexShader() override
-	{
+	tester.onCreateVertexShader([](DrawTester &tester) {
 		const char *vertexShader = R"(#version 310 es
 			layout(location = 0) in vec3 inPos;
 
@@ -56,11 +66,10 @@
 				gl_Position = vec4(inPos.xyz, 1.0);
 			})";
 
-		return createShaderModule(vertexShader, EShLanguage::EShLangVertex);
-	}
+		return tester.createShaderModule(vertexShader, EShLanguage::EShLangVertex);
+	});
 
-	vk::ShaderModule doCreateFragmentShader() override
-	{
+	tester.onCreateFragmentShader([](DrawTester &tester) {
 		const char *fragmentShader = R"(#version 310 es
 			precision highp float;
 
@@ -71,20 +80,17 @@
 				outColor = vec4(1.0, 1.0, 1.0, 1.0);
 			})";
 
-		return createShaderModule(fragmentShader, EShLanguage::EShLangFragment);
-	}
-};
+		return tester.createShaderModule(fragmentShader, EShLanguage::EShLangFragment);
+	});
 
-class TriangleInterpolateColorBenchmark : public DrawBenchmark
+	RunBenchmark(state, tester);
+}
+
+static void TriangleInterpolateColor(benchmark::State &state, Multisample multisample)
 {
-public:
-	TriangleInterpolateColorBenchmark(Multisample multisample)
-	    : DrawBenchmark(multisample)
-	{}
+	DrawTester tester(multisample);
 
-protected:
-	void doCreateVertexBuffers() override
-	{
+	tester.onCreateVertexBuffers([](DrawTester &tester) {
 		struct Vertex
 		{
 			float position[3];
@@ -101,11 +107,10 @@
 		inputAttributes.push_back(vk::VertexInputAttributeDescription(0, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, position)));
 		inputAttributes.push_back(vk::VertexInputAttributeDescription(1, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, color)));
 
-		addVertexBuffer(vertexBufferData, sizeof(vertexBufferData), std::move(inputAttributes));
-	}
+		tester.addVertexBuffer(vertexBufferData, sizeof(vertexBufferData), std::move(inputAttributes));
+	});
 
-	vk::ShaderModule doCreateVertexShader() override
-	{
+	tester.onCreateVertexShader([](DrawTester &tester) {
 		const char *vertexShader = R"(#version 310 es
 			layout(location = 0) in vec3 inPos;
 			layout(location = 1) in vec3 inColor;
@@ -118,11 +123,10 @@
 				gl_Position = vec4(inPos.xyz, 1.0);
 			})";
 
-		return createShaderModule(vertexShader, EShLanguage::EShLangVertex);
-	}
+		return tester.createShaderModule(vertexShader, EShLanguage::EShLangVertex);
+	});
 
-	vk::ShaderModule doCreateFragmentShader() override
-	{
+	tester.onCreateFragmentShader([](DrawTester &tester) {
 		const char *fragmentShader = R"(#version 310 es
 			precision highp float;
 
@@ -135,20 +139,17 @@
 				outColor = vec4(inColor, 1.0);
 			})";
 
-		return createShaderModule(fragmentShader, EShLanguage::EShLangFragment);
-	}
-};
+		return tester.createShaderModule(fragmentShader, EShLanguage::EShLangFragment);
+	});
 
-class TriangleSampleTextureBenchmark : public DrawBenchmark
+	RunBenchmark(state, tester);
+}
+
+static void TriangleSampleTexture(benchmark::State &state, Multisample multisample)
 {
-public:
-	TriangleSampleTextureBenchmark(Multisample multisample)
-	    : DrawBenchmark(multisample)
-	{}
+	DrawTester tester(multisample);
 
-protected:
-	void doCreateVertexBuffers() override
-	{
+	tester.onCreateVertexBuffers([](DrawTester &tester) {
 		struct Vertex
 		{
 			float position[3];
@@ -167,11 +168,10 @@
 		inputAttributes.push_back(vk::VertexInputAttributeDescription(1, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, color)));
 		inputAttributes.push_back(vk::VertexInputAttributeDescription(2, 0, vk::Format::eR32G32Sfloat, offsetof(Vertex, texCoord)));
 
-		addVertexBuffer(vertexBufferData, sizeof(vertexBufferData), std::move(inputAttributes));
-	}
+		tester.addVertexBuffer(vertexBufferData, sizeof(vertexBufferData), std::move(inputAttributes));
+	});
 
-	vk::ShaderModule doCreateVertexShader() override
-	{
+	tester.onCreateVertexShader([](DrawTester &tester) {
 		const char *vertexShader = R"(#version 310 es
 			layout(location = 0) in vec3 inPos;
 			layout(location = 1) in vec3 inColor;
@@ -186,11 +186,10 @@
 				fragTexCoord = inPos.xy;
 			})";
 
-		return createShaderModule(vertexShader, EShLanguage::EShLangVertex);
-	}
+		return tester.createShaderModule(vertexShader, EShLanguage::EShLangVertex);
+	});
 
-	vk::ShaderModule doCreateFragmentShader() override
-	{
+	tester.onCreateFragmentShader([](DrawTester &tester) {
 		const char *fragmentShader = R"(#version 310 es
 			precision highp float;
 
@@ -206,11 +205,10 @@
 				outColor = texture(texSampler, fragTexCoord) * vec4(inColor, 1.0);
 			})";
 
-		return createShaderModule(fragmentShader, EShLanguage::EShLangFragment);
-	}
+		return tester.createShaderModule(fragmentShader, EShLanguage::EShLangFragment);
+	});
 
-	std::vector<vk::DescriptorSetLayoutBinding> doCreateDescriptorSetLayouts() override
-	{
+	tester.onCreateDescriptorSetLayouts([](DrawTester &tester) -> std::vector<vk::DescriptorSetLayoutBinding> {
 		vk::DescriptorSetLayoutBinding samplerLayoutBinding;
 		samplerLayoutBinding.binding = 1;
 		samplerLayoutBinding.descriptorCount = 1;
@@ -219,11 +217,13 @@
 		samplerLayoutBinding.stageFlags = vk::ShaderStageFlagBits::eFragment;
 
 		return { samplerLayoutBinding };
-	}
+	});
 
-	void doUpdateDescriptorSet(vk::CommandPool &commandPool, vk::DescriptorSet &descriptorSet) override
-	{
-		auto &texture = addImage(device, 16, 16, vk::Format::eR8G8B8A8Unorm).obj;
+	tester.onUpdateDescriptorSet([](DrawTester &tester, vk::CommandPool &commandPool, vk::DescriptorSet &descriptorSet) {
+		auto &device = tester.getDevice();
+		auto &queue = tester.getQueue();
+
+		auto &texture = tester.addImage(device, 16, 16, vk::Format::eR8G8B8A8Unorm).obj;
 
 		// Fill texture with white
 		vk::DeviceSize bufferSize = 16 * 16 * 4;
@@ -249,7 +249,7 @@
 		samplerInfo.minLod = 0.0f;
 		samplerInfo.maxLod = 0.0f;
 
-		auto sampler = addSampler(samplerInfo);
+		auto sampler = tester.addSampler(samplerInfo);
 
 		vk::DescriptorImageInfo imageInfo;
 		imageInfo.imageLayout = vk::ImageLayout::eShaderReadOnlyOptimal;
@@ -266,41 +266,9 @@
 		descriptorWrites[0].pImageInfo = &imageInfo;
 
 		device.updateDescriptorSets(static_cast<uint32_t>(descriptorWrites.size()), descriptorWrites.data(), 0, nullptr);
-	}
-};
+	});
 
-template<typename T>
-static void RunBenchmark(benchmark::State &state, T &benchmark)
-{
-	benchmark.initialize();
-
-	if(false) benchmark.show();  // Enable for visual verification.
-
-	// Warmup
-	benchmark.renderFrame();
-
-	for(auto _ : state)
-	{
-		benchmark.renderFrame();
-	}
-}
-
-static void TriangleSolidColor(benchmark::State &state, Multisample multisample)
-{
-	TriangleSolidColorBenchmark benchmark(multisample);
-	RunBenchmark(state, benchmark);
-}
-
-static void TriangleInterpolateColor(benchmark::State &state, Multisample multisample)
-{
-	TriangleInterpolateColorBenchmark benchmark(multisample);
-	RunBenchmark(state, benchmark);
-}
-
-static void TriangleSampleTexture(benchmark::State &state, Multisample multisample)
-{
-	TriangleSampleTextureBenchmark benchmark(multisample);
-	RunBenchmark(state, benchmark);
+	RunBenchmark(state, tester);
 }
 
 BENCHMARK_CAPTURE(TriangleSolidColor, TriangleSolidColor, Multisample::False)->Unit(benchmark::kMillisecond);
diff --git a/tests/VulkanBenchmarks/VulkanBenchmark.cpp b/tests/VulkanBenchmarks/VulkanBenchmark.cpp
deleted file mode 100644
index 41dda63..0000000
--- a/tests/VulkanBenchmarks/VulkanBenchmark.cpp
+++ /dev/null
@@ -1,59 +0,0 @@
-// Copyright 2021 The SwiftShader Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "VulkanBenchmark.hpp"
-
-VulkanBenchmark::~VulkanBenchmark()
-{
-	device.waitIdle();
-	device.destroy(nullptr);
-	instance.destroy(nullptr);
-}
-
-void VulkanBenchmark::initialize()
-{
-	// TODO(b/158231104): Other platforms
-#if defined(_WIN32)
-	dl = std::make_unique<vk::DynamicLoader>("./vk_swiftshader.dll");
-#elif defined(__linux__)
-	dl = std::make_unique<vk::DynamicLoader>("./libvk_swiftshader.so");
-#else
-#	error Unimplemented platform
-#endif
-	assert(dl->success());
-
-	PFN_vkGetInstanceProcAddr vkGetInstanceProcAddr = dl->getProcAddress<PFN_vkGetInstanceProcAddr>("vkGetInstanceProcAddr");
-	VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr);
-
-	instance = vk::createInstance({}, nullptr);
-	VULKAN_HPP_DEFAULT_DISPATCHER.init(instance);
-
-	std::vector<vk::PhysicalDevice> physicalDevices = instance.enumeratePhysicalDevices();
-	assert(!physicalDevices.empty());
-	physicalDevice = physicalDevices[0];
-
-	const float defaultQueuePriority = 0.0f;
-	vk::DeviceQueueCreateInfo queueCreatInfo;
-	queueCreatInfo.queueFamilyIndex = queueFamilyIndex;
-	queueCreatInfo.queueCount = 1;
-	queueCreatInfo.pQueuePriorities = &defaultQueuePriority;
-
-	vk::DeviceCreateInfo deviceCreateInfo;
-	deviceCreateInfo.queueCreateInfoCount = 1;
-	deviceCreateInfo.pQueueCreateInfos = &queueCreatInfo;
-
-	device = physicalDevice.createDevice(deviceCreateInfo, nullptr);
-
-	queue = device.getQueue(queueFamilyIndex, 0);
-}
diff --git a/tests/VulkanUnitTests/BUILD.gn b/tests/VulkanUnitTests/BUILD.gn
index f33f59e..1adc6f0 100644
--- a/tests/VulkanUnitTests/BUILD.gn
+++ b/tests/VulkanUnitTests/BUILD.gn
@@ -26,9 +26,12 @@
 

   sources = [

     "//gpu/swiftshader_tests_main.cc",

-    "Device.cpp",

-    "Driver.cpp",

-    "unittests.cpp",

+    "BasicTests.cpp"

+    "ComputeTests.cpp"

+    "Device.cpp"

+    "DrawTests.cpp"

+    "Driver.cpp"

+    "main.cpp"

   ]

 

   include_dirs = [

diff --git a/tests/VulkanUnitTests/BasicTests.cpp b/tests/VulkanUnitTests/BasicTests.cpp
new file mode 100644
index 0000000..b71348d
--- /dev/null
+++ b/tests/VulkanUnitTests/BasicTests.cpp
@@ -0,0 +1,184 @@
+// Copyright 2018 The SwiftShader Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Vulkan unit tests that provide coverage for functionality not tested by
+// the dEQP test suite. Also used as a smoke test.
+
+#include "Device.hpp"
+#include "Driver.hpp"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+class BasicTest : public testing::Test
+{
+protected:
+	static Driver driver;
+
+	void SetUp() override
+	{
+		ASSERT_TRUE(driver.loadSwiftShader());
+	}
+
+	void TearDown() override
+	{
+		driver.unload();
+	}
+};
+
+Driver BasicTest::driver;
+
+TEST_F(BasicTest, ICD_Check)
+{
+	auto createInstance = driver.vk_icdGetInstanceProcAddr(VK_NULL_HANDLE, "vkCreateInstance");
+	EXPECT_NE(createInstance, nullptr);
+
+	auto enumerateInstanceExtensionProperties =
+	    driver.vk_icdGetInstanceProcAddr(VK_NULL_HANDLE, "vkEnumerateInstanceExtensionProperties");
+	EXPECT_NE(enumerateInstanceExtensionProperties, nullptr);
+
+	auto enumerateInstanceLayerProperties =
+	    driver.vk_icdGetInstanceProcAddr(VK_NULL_HANDLE, "vkEnumerateInstanceLayerProperties");
+	EXPECT_NE(enumerateInstanceLayerProperties, nullptr);
+
+	auto enumerateInstanceVersion = driver.vk_icdGetInstanceProcAddr(VK_NULL_HANDLE, "vkEnumerateInstanceVersion");
+	EXPECT_NE(enumerateInstanceVersion, nullptr);
+
+	auto bad_function = driver.vk_icdGetInstanceProcAddr(VK_NULL_HANDLE, "bad_function");
+	EXPECT_EQ(bad_function, nullptr);
+}
+
+TEST_F(BasicTest, Version)
+{
+	uint32_t apiVersion = 0;
+	VkResult result = driver.vkEnumerateInstanceVersion(&apiVersion);
+	EXPECT_EQ(apiVersion, (uint32_t)VK_API_VERSION_1_1);
+
+	const VkInstanceCreateInfo createInfo = {
+		VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,  // sType
+		nullptr,                                 // pNext
+		0,                                       // flags
+		nullptr,                                 // pApplicationInfo
+		0,                                       // enabledLayerCount
+		nullptr,                                 // ppEnabledLayerNames
+		0,                                       // enabledExtensionCount
+		nullptr,                                 // ppEnabledExtensionNames
+	};
+	VkInstance instance = VK_NULL_HANDLE;
+	result = driver.vkCreateInstance(&createInfo, nullptr, &instance);
+	EXPECT_EQ(result, VK_SUCCESS);
+
+	ASSERT_TRUE(driver.resolve(instance));
+
+	uint32_t pPhysicalDeviceCount = 0;
+	result = driver.vkEnumeratePhysicalDevices(instance, &pPhysicalDeviceCount, nullptr);
+	EXPECT_EQ(result, VK_SUCCESS);
+	EXPECT_EQ(pPhysicalDeviceCount, 1U);
+
+	VkPhysicalDevice pPhysicalDevice = VK_NULL_HANDLE;
+	result = driver.vkEnumeratePhysicalDevices(instance, &pPhysicalDeviceCount, &pPhysicalDevice);
+	EXPECT_EQ(result, VK_SUCCESS);
+	EXPECT_NE(pPhysicalDevice, (VkPhysicalDevice)VK_NULL_HANDLE);
+
+	VkPhysicalDeviceProperties physicalDeviceProperties;
+	driver.vkGetPhysicalDeviceProperties(pPhysicalDevice, &physicalDeviceProperties);
+	EXPECT_EQ(physicalDeviceProperties.apiVersion, (uint32_t)VK_API_VERSION_1_1);
+	EXPECT_EQ(physicalDeviceProperties.deviceID, 0xC0DEU);
+	EXPECT_EQ(physicalDeviceProperties.deviceType, VK_PHYSICAL_DEVICE_TYPE_CPU);
+
+	EXPECT_NE(strstr(physicalDeviceProperties.deviceName, "SwiftShader Device"), nullptr);
+
+	VkPhysicalDeviceProperties2 physicalDeviceProperties2;
+	VkPhysicalDeviceDriverPropertiesKHR physicalDeviceDriverProperties;
+	physicalDeviceProperties2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
+	physicalDeviceProperties2.pNext = &physicalDeviceDriverProperties;
+	physicalDeviceDriverProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES_KHR;
+	physicalDeviceDriverProperties.pNext = nullptr;
+	physicalDeviceDriverProperties.driverID = (VkDriverIdKHR)0;
+	driver.vkGetPhysicalDeviceProperties2(pPhysicalDevice, &physicalDeviceProperties2);
+	EXPECT_EQ(physicalDeviceDriverProperties.driverID, VK_DRIVER_ID_GOOGLE_SWIFTSHADER_KHR);
+
+	driver.vkDestroyInstance(instance, nullptr);
+}
+/*
+TEST_F(BasicTest, UnsupportedDeviceExtension_DISABLED)
+{
+	uint32_t apiVersion = 0;
+	VkResult result = driver.vkEnumerateInstanceVersion(&apiVersion);
+	EXPECT_EQ(apiVersion, (uint32_t)VK_API_VERSION_1_1);
+
+	const VkInstanceCreateInfo createInfo = {
+		VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,  // sType
+		nullptr,                                 // pNext
+		0,                                       // flags
+		nullptr,                                 // pApplicationInfo
+		0,                                       // enabledLayerCount
+		nullptr,                                 // ppEnabledLayerNames
+		0,                                       // enabledExtensionCount
+		nullptr,                                 // ppEnabledExtensionNames
+	};
+	VkInstance instance = VK_NULL_HANDLE;
+	result = driver.vkCreateInstance(&createInfo, nullptr, &instance);
+	EXPECT_EQ(result, VK_SUCCESS);
+
+	ASSERT_TRUE(driver.resolve(instance));
+
+	VkBaseInStructure unsupportedExt = { VK_STRUCTURE_TYPE_SHADER_MODULE_VALIDATION_CACHE_CREATE_INFO_EXT, nullptr };
+
+	// Gather all physical devices
+	std::vector<VkPhysicalDevice> physicalDevices;
+	result = Device::GetPhysicalDevices(&driver, instance, physicalDevices);
+	EXPECT_EQ(result, VK_SUCCESS);
+
+	// Inspect each physical device's queue families for compute support.
+	for(auto physicalDevice : physicalDevices)
+	{
+		int queueFamilyIndex = Device::GetComputeQueueFamilyIndex(&driver, physicalDevice);
+		if(queueFamilyIndex < 0)
+		{
+			continue;
+		}
+
+		const float queuePrioritory = 1.0f;
+		const VkDeviceQueueCreateInfo deviceQueueCreateInfo = {
+			VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO,  // sType
+			nullptr,                                     // pNext
+			0,                                           // flags
+			(uint32_t)queueFamilyIndex,                  // queueFamilyIndex
+			1,                                           // queueCount
+			&queuePrioritory,                            // pQueuePriorities
+		};
+
+		const VkDeviceCreateInfo deviceCreateInfo = {
+			VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO,  // sType
+			&unsupportedExt,                       // pNext
+			0,                                     // flags
+			1,                                     // queueCreateInfoCount
+			&deviceQueueCreateInfo,                // pQueueCreateInfos
+			0,                                     // enabledLayerCount
+			nullptr,                               // ppEnabledLayerNames
+			0,                                     // enabledExtensionCount
+			nullptr,                               // ppEnabledExtensionNames
+			nullptr,                               // pEnabledFeatures
+		};
+
+		VkDevice device;
+		result = driver.vkCreateDevice(physicalDevice, &deviceCreateInfo, nullptr, &device);
+		EXPECT_EQ(result, VK_SUCCESS);
+		driver.vkDestroyDevice(device, nullptr);
+	}
+
+	driver.vkDestroyInstance(instance, nullptr);
+}
+*/
diff --git a/tests/VulkanUnitTests/CMakeLists.txt b/tests/VulkanUnitTests/CMakeLists.txt
index 8640ccc..1f67ea5 100644
--- a/tests/VulkanUnitTests/CMakeLists.txt
+++ b/tests/VulkanUnitTests/CMakeLists.txt
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 set(ROOT_PROJECT_COMPILE_OPTIONS
-    ${SWIFTSHADER_COMPILE_OPTIONS}
     ${WARNINGS_AS_ERRORS}
 )
 
@@ -23,14 +22,16 @@
 )
 
 set(VULKAN_UNIT_TESTS_SRC_FILES
+    BasicTests.cpp
+    ComputeTests.cpp
+    Device.cpp
     Device.hpp
+    DrawTests.cpp
+    Driver.cpp
     Driver.hpp
+    main.cpp
     VkGlobalFuncs.hpp
     VkInstanceFuncs.hpp
-    Device.cpp
-    Driver.cpp
-    main.cpp
-    unittests.cpp
 )
 
 add_executable(vk-unittests
@@ -72,5 +73,6 @@
         gtest
         gmock
         SPIRV-Tools
+        VulkanWrapper
         ${ROOT_PROJECT_LINK_LIBRARIES}
 )
diff --git a/tests/VulkanUnitTests/ComputeTests.cpp b/tests/VulkanUnitTests/ComputeTests.cpp
new file mode 100644
index 0000000..9d94987
--- /dev/null
+++ b/tests/VulkanUnitTests/ComputeTests.cpp
@@ -0,0 +1,1481 @@
+// Copyright 2021 The SwiftShader Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "Device.hpp"
+#include "Driver.hpp"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "spirv-tools/libspirv.hpp"
+
+#include <cstring>
+#include <sstream>
+
+namespace {
+size_t alignUp(size_t val, size_t alignment)
+{
+	return alignment * ((val + alignment - 1) / alignment);
+}
+}  // anonymous namespace
+
+struct ComputeParams
+{
+	size_t numElements;
+	int localSizeX;
+	int localSizeY;
+	int localSizeZ;
+
+	friend std::ostream &operator<<(std::ostream &os, const ComputeParams &params)
+	{
+		return os << "ComputeParams{"
+		          << "numElements: " << params.numElements << ", "
+		          << "localSizeX: " << params.localSizeX << ", "
+		          << "localSizeY: " << params.localSizeY << ", "
+		          << "localSizeZ: " << params.localSizeZ << "}";
+	}
+};
+
+class ComputeTest : public testing::TestWithParam<ComputeParams>
+{
+protected:
+	static Driver driver;
+
+	static void SetUpTestSuite()
+	{
+		ASSERT_TRUE(driver.loadSwiftShader());
+	}
+
+	static void TearDownTestSuite()
+	{
+		driver.unload();
+	}
+};
+
+Driver ComputeTest::driver;
+
+std::vector<uint32_t> compileSpirv(const char *assembly)
+{
+	spvtools::SpirvTools core(SPV_ENV_VULKAN_1_0);
+
+	core.SetMessageConsumer([](spv_message_level_t, const char *, const spv_position_t &p, const char *m) {
+		FAIL() << p.line << ":" << p.column << ": " << m;
+	});
+
+	std::vector<uint32_t> spirv;
+	EXPECT_TRUE(core.Assemble(assembly, &spirv));
+	EXPECT_TRUE(core.Validate(spirv));
+
+	// Warn if the disassembly does not match the source assembly.
+	// We do this as debugging tests in the debugger is often made much harder
+	// if the SSA names (%X) in the debugger do not match the source.
+	std::string disassembled;
+	core.Disassemble(spirv, &disassembled, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+	if(disassembled != assembly)
+	{
+		printf("-- WARNING: Disassembly does not match assembly: ---\n\n");
+
+		auto splitLines = [](const std::string &str) -> std::vector<std::string> {
+			std::stringstream ss(str);
+			std::vector<std::string> out;
+			std::string line;
+			while(std::getline(ss, line, '\n')) { out.push_back(line); }
+			return out;
+		};
+
+		auto srcLines = splitLines(std::string(assembly));
+		auto disLines = splitLines(disassembled);
+
+		for(size_t line = 0; line < srcLines.size() && line < disLines.size(); line++)
+		{
+			auto srcLine = (line < srcLines.size()) ? srcLines[line] : "<missing>";
+			auto disLine = (line < disLines.size()) ? disLines[line] : "<missing>";
+			if(srcLine != disLine)
+			{
+				printf("%zu: '%s' != '%s'\n", line, srcLine.c_str(), disLine.c_str());
+			}
+		}
+		printf("\n\n---\nExpected:\n\n%s", disassembled.c_str());
+	}
+
+	return spirv;
+}
+
+#define VK_ASSERT(x) ASSERT_EQ(x, VK_SUCCESS)
+
+// Base class for compute tests that read from an input buffer and write to an
+// output buffer of same length.
+class SwiftShaderVulkanBufferToBufferComputeTest : public ComputeTest
+{
+public:
+	void test(const std::string &shader,
+	          std::function<uint32_t(uint32_t idx)> input,
+	          std::function<uint32_t(uint32_t idx)> expected);
+};
+
+void SwiftShaderVulkanBufferToBufferComputeTest::test(
+    const std::string &shader,
+    std::function<uint32_t(uint32_t idx)> input,
+    std::function<uint32_t(uint32_t idx)> expected)
+{
+	auto code = compileSpirv(shader.c_str());
+
+	const VkInstanceCreateInfo createInfo = {
+		VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,  // sType
+		nullptr,                                 // pNext
+		0,                                       // flags
+		nullptr,                                 // pApplicationInfo
+		0,                                       // enabledLayerCount
+		nullptr,                                 // ppEnabledLayerNames
+		0,                                       // enabledExtensionCount
+		nullptr,                                 // ppEnabledExtensionNames
+	};
+
+	VkInstance instance = VK_NULL_HANDLE;
+	VK_ASSERT(driver.vkCreateInstance(&createInfo, nullptr, &instance));
+
+	ASSERT_TRUE(driver.resolve(instance));
+
+	std::unique_ptr<Device> device;
+	VK_ASSERT(Device::CreateComputeDevice(&driver, instance, device));
+	ASSERT_TRUE(device->IsValid());
+
+	// struct Buffers
+	// {
+	//     uint32_t pad0[63];
+	//     uint32_t magic0;
+	//     uint32_t in[NUM_ELEMENTS]; // Aligned to 0x100
+	//     uint32_t magic1;
+	//     uint32_t pad1[N];
+	//     uint32_t magic2;
+	//     uint32_t out[NUM_ELEMENTS]; // Aligned to 0x100
+	//     uint32_t magic3;
+	// };
+	static constexpr uint32_t magic0 = 0x01234567;
+	static constexpr uint32_t magic1 = 0x89abcdef;
+	static constexpr uint32_t magic2 = 0xfedcba99;
+	static constexpr uint32_t magic3 = 0x87654321;
+	size_t numElements = GetParam().numElements;
+	size_t alignElements = 0x100 / sizeof(uint32_t);
+	size_t magic0Offset = alignElements - 1;
+	size_t inOffset = 1 + magic0Offset;
+	size_t magic1Offset = numElements + inOffset;
+	size_t magic2Offset = alignUp(magic1Offset + 1, alignElements) - 1;
+	size_t outOffset = 1 + magic2Offset;
+	size_t magic3Offset = numElements + outOffset;
+	size_t buffersTotalElements = alignUp(1 + magic3Offset, alignElements);
+	size_t buffersSize = sizeof(uint32_t) * buffersTotalElements;
+
+	VkDeviceMemory memory;
+	VK_ASSERT(device->AllocateMemory(buffersSize,
+	                                 VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
+	                                 &memory));
+
+	uint32_t *buffers;
+	VK_ASSERT(device->MapMemory(memory, 0, buffersSize, 0, (void **)&buffers));
+
+	buffers[magic0Offset] = magic0;
+	buffers[magic1Offset] = magic1;
+	buffers[magic2Offset] = magic2;
+	buffers[magic3Offset] = magic3;
+
+	for(size_t i = 0; i < numElements; i++)
+	{
+		buffers[inOffset + i] = input((uint32_t)i);
+	}
+
+	device->UnmapMemory(memory);
+	buffers = nullptr;
+
+	VkBuffer bufferIn;
+	VK_ASSERT(device->CreateStorageBuffer(memory,
+	                                      sizeof(uint32_t) * numElements,
+	                                      sizeof(uint32_t) * inOffset,
+	                                      &bufferIn));
+
+	VkBuffer bufferOut;
+	VK_ASSERT(device->CreateStorageBuffer(memory,
+	                                      sizeof(uint32_t) * numElements,
+	                                      sizeof(uint32_t) * outOffset,
+	                                      &bufferOut));
+
+	VkShaderModule shaderModule;
+	VK_ASSERT(device->CreateShaderModule(code, &shaderModule));
+
+	std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings = {
+		{
+		    0,                                  // binding
+		    VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,  // descriptorType
+		    1,                                  // descriptorCount
+		    VK_SHADER_STAGE_COMPUTE_BIT,        // stageFlags
+		    0,                                  // pImmutableSamplers
+		},
+		{
+		    1,                                  // binding
+		    VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,  // descriptorType
+		    1,                                  // descriptorCount
+		    VK_SHADER_STAGE_COMPUTE_BIT,        // stageFlags
+		    0,                                  // pImmutableSamplers
+		}
+	};
+
+	VkDescriptorSetLayout descriptorSetLayout;
+	VK_ASSERT(device->CreateDescriptorSetLayout(descriptorSetLayoutBindings, &descriptorSetLayout));
+
+	VkPipelineLayout pipelineLayout;
+	VK_ASSERT(device->CreatePipelineLayout(descriptorSetLayout, &pipelineLayout));
+
+	VkPipeline pipeline;
+	VK_ASSERT(device->CreateComputePipeline(shaderModule, pipelineLayout, &pipeline));
+
+	VkDescriptorPool descriptorPool;
+	VK_ASSERT(device->CreateStorageBufferDescriptorPool(2, &descriptorPool));
+
+	VkDescriptorSet descriptorSet;
+	VK_ASSERT(device->AllocateDescriptorSet(descriptorPool, descriptorSetLayout, &descriptorSet));
+
+	std::vector<VkDescriptorBufferInfo> descriptorBufferInfos = {
+		{
+		    bufferIn,       // buffer
+		    0,              // offset
+		    VK_WHOLE_SIZE,  // range
+		},
+		{
+		    bufferOut,      // buffer
+		    0,              // offset
+		    VK_WHOLE_SIZE,  // range
+		}
+	};
+	device->UpdateStorageBufferDescriptorSets(descriptorSet, descriptorBufferInfos);
+
+	VkCommandPool commandPool;
+	VK_ASSERT(device->CreateCommandPool(&commandPool));
+
+	VkCommandBuffer commandBuffer;
+	VK_ASSERT(device->AllocateCommandBuffer(commandPool, &commandBuffer));
+
+	VK_ASSERT(device->BeginCommandBuffer(VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT, commandBuffer));
+
+	driver.vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
+
+	driver.vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, 1, &descriptorSet,
+	                               0, nullptr);
+
+	driver.vkCmdDispatch(commandBuffer, (uint32_t)(numElements / GetParam().localSizeX), 1, 1);
+
+	VK_ASSERT(driver.vkEndCommandBuffer(commandBuffer));
+
+	VK_ASSERT(device->QueueSubmitAndWait(commandBuffer));
+
+	VK_ASSERT(device->MapMemory(memory, 0, buffersSize, 0, (void **)&buffers));
+
+	for(size_t i = 0; i < numElements; ++i)
+	{
+		auto got = buffers[i + outOffset];
+		EXPECT_EQ(expected((uint32_t)i), got) << "Unexpected output at " << i;
+	}
+
+	// Check for writes outside of bounds.
+	EXPECT_EQ(buffers[magic0Offset], magic0);
+	EXPECT_EQ(buffers[magic1Offset], magic1);
+	EXPECT_EQ(buffers[magic2Offset], magic2);
+	EXPECT_EQ(buffers[magic3Offset], magic3);
+
+	device->UnmapMemory(memory);
+	buffers = nullptr;
+
+	device->FreeCommandBuffer(commandPool, commandBuffer);
+	device->FreeMemory(memory);
+	device->DestroyPipeline(pipeline);
+	device->DestroyCommandPool(commandPool);
+	device->DestroyPipelineLayout(pipelineLayout);
+	device->DestroyDescriptorSetLayout(descriptorSetLayout);
+	device->DestroyDescriptorPool(descriptorPool);
+	device->DestroyBuffer(bufferIn);
+	device->DestroyBuffer(bufferOut);
+	device->DestroyShaderModule(shaderModule);
+	device.reset(nullptr);
+	driver.vkDestroyInstance(instance, nullptr);
+}
+
+INSTANTIATE_TEST_SUITE_P(ComputeParams, SwiftShaderVulkanBufferToBufferComputeTest, testing::Values(ComputeParams{ 512, 1, 1, 1 }, ComputeParams{ 512, 2, 1, 1 }, ComputeParams{ 512, 4, 1, 1 }, ComputeParams{ 512, 8, 1, 1 }, ComputeParams{ 512, 16, 1, 1 }, ComputeParams{ 512, 32, 1, 1 },
+
+                                                                                                    // Non-multiple of SIMD-lane.
+                                                                                                    ComputeParams{ 3, 1, 1, 1 }, ComputeParams{ 2, 1, 1, 1 }));
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, Memcpy)
+{
+	std::stringstream src;
+	// #version 450
+	// layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+	// layout(binding = 0, std430) buffer InBuffer
+	// {
+	//     int Data[];
+	// } In;
+	// layout(binding = 1, std430) buffer OutBuffer
+	// {
+	//     int Data[];
+	// } Out;
+	// void main()
+	// {
+	//     Out.Data[gl_GlobalInvocationID.x] = In.Data[gl_GlobalInvocationID.x];
+	// }
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in
+        "%12 = OpConstant %9 0\n"               // int32(0)
+        "%13 = OpConstant %10 0\n"              // uint32(0)
+        "%14 = OpTypeVector %10 3\n"            // vec3<int32>
+        "%15 = OpTypePointer Input %14\n"       // vec3<int32>*
+        "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId
+        "%16 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out
+        "%17 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%18 = OpLabel\n"
+        "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x
+        "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x
+        "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]
+        "OpStore %23 %22\n"               // out.arr[gl_GlobalInvocationId.x] = in[gl_GlobalInvocationId.x]
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, GlobalInvocationId)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in
+        "%12 = OpConstant %9 0\n"               // int32(0)
+        "%13 = OpConstant %9 1\n"               // int32(1)
+        "%14 = OpConstant %10 0\n"              // uint32(0)
+        "%15 = OpConstant %10 1\n"              // uint32(1)
+        "%16 = OpConstant %10 2\n"              // uint32(2)
+        "%17 = OpTypeVector %10 3\n"            // vec3<int32>
+        "%18 = OpTypePointer Input %17\n"       // vec3<int32>*
+        "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
+        "%19 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out
+        "%20 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%21 = OpLabel\n"
+        "%22 = OpAccessChain %19 %2 %14\n"      // &gl_GlobalInvocationId.x
+        "%23 = OpAccessChain %19 %2 %15\n"      // &gl_GlobalInvocationId.y
+        "%24 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.z
+        "%25 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
+        "%26 = OpLoad %10 %23\n"                // gl_GlobalInvocationId.y
+        "%27 = OpLoad %10 %24\n"                // gl_GlobalInvocationId.z
+        "%28 = OpAccessChain %20 %6 %12 %25\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%29 = OpLoad %9 %28\n"                 // out.arr[gl_GlobalInvocationId.x]
+        "%30 = OpIAdd %9 %29 %26\n"             // in[gl_GlobalInvocationId.x] + gl_GlobalInvocationId.y
+        "%31 = OpIAdd %9 %30 %27\n"             // in[gl_GlobalInvocationId.x] + gl_GlobalInvocationId.y + gl_GlobalInvocationId.z
+        "%32 = OpAccessChain %20 %5 %12 %25\n"  // &out.arr[gl_GlobalInvocationId.x]
+        "OpStore %32 %31\n"               // out.arr[gl_GlobalInvocationId.x] = in[gl_GlobalInvocationId.x] + gl_GlobalInvocationId.y + gl_GlobalInvocationId.z
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	// gl_GlobalInvocationId.y and gl_GlobalInvocationId.z should both be zero.
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchSimple)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in
+        "%12 = OpConstant %9 0\n"               // int32(0)
+        "%13 = OpConstant %10 0\n"              // uint32(0)
+        "%14 = OpTypeVector %10 3\n"            // vec3<int32>
+        "%15 = OpTypePointer Input %14\n"       // vec3<int32>*
+        "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId
+        "%16 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out
+        "%17 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%18 = OpLabel\n"
+        "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x
+        "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x
+        "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %22 = in value
+        "OpBranch %24\n"
+        "%24 = OpLabel\n"
+        "OpBranch %25\n"
+        "%25 = OpLabel\n"
+        "OpBranch %26\n"
+        "%26 = OpLabel\n"
+        // %22 = out value
+        // End of branch logic
+        "OpStore %23 %22\n"
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchDeclareSSA)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in
+        "%12 = OpConstant %9 0\n"               // int32(0)
+        "%13 = OpConstant %10 0\n"              // uint32(0)
+        "%14 = OpTypeVector %10 3\n"            // vec3<int32>
+        "%15 = OpTypePointer Input %14\n"       // vec3<int32>*
+        "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId
+        "%16 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out
+        "%17 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%18 = OpLabel\n"
+        "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x
+        "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x
+        "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %22 = in value
+        "OpBranch %24\n"
+        "%24 = OpLabel\n"
+        "%25 = OpIAdd %9 %22 %22\n"             // %25 = in*2
+        "OpBranch %26\n"
+        "%26 = OpLabel\n"
+        "OpBranch %27\n"
+        "%27 = OpLabel\n"
+        // %25 = out value
+        // End of branch logic
+        "OpStore %23 %25\n"               // use SSA value from previous block
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i * 2; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalSimple)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%11 = OpTypeBool\n"
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
+        "%13 = OpConstant %9 0\n"               // int32(0)
+        "%14 = OpConstant %9 2\n"               // int32(2)
+        "%15 = OpConstant %10 0\n"              // uint32(0)
+        "%16 = OpTypeVector %10 3\n"            // vec4<int32>
+        "%17 = OpTypePointer Input %16\n"       // vec4<int32>*
+        "%2 = OpVariable %17 Input\n"          // gl_GlobalInvocationId
+        "%18 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
+        "%19 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%20 = OpLabel\n"
+        "%21 = OpAccessChain %18 %2 %15\n"      // &gl_GlobalInvocationId.x
+        "%22 = OpLoad %10 %21\n"                // gl_GlobalInvocationId.x
+        "%23 = OpAccessChain %19 %6 %13 %22\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%24 = OpLoad %9 %23\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%25 = OpAccessChain %19 %5 %13 %22\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %24 = in value
+        "%26 = OpSMod %9 %24 %14\n"             // in % 2
+        "%27 = OpIEqual %11 %26 %13\n"          // (in % 2) == 0
+        "OpSelectionMerge %28 None\n"
+        "OpBranchConditional %27 %28 %28\n" // Both go to %28
+        "%28 = OpLabel\n"
+        // %26 = out value
+        // End of branch logic
+        "OpStore %25 %26\n"               // use SSA value from previous block
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i % 2; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalTwoEmptyBlocks)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%11 = OpTypeBool\n"
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
+        "%13 = OpConstant %9 0\n"               // int32(0)
+        "%14 = OpConstant %9 2\n"               // int32(2)
+        "%15 = OpConstant %10 0\n"              // uint32(0)
+        "%16 = OpTypeVector %10 3\n"            // vec4<int32>
+        "%17 = OpTypePointer Input %16\n"       // vec4<int32>*
+        "%2 = OpVariable %17 Input\n"          // gl_GlobalInvocationId
+        "%18 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
+        "%19 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%20 = OpLabel\n"
+        "%21 = OpAccessChain %18 %2 %15\n"      // &gl_GlobalInvocationId.x
+        "%22 = OpLoad %10 %21\n"                // gl_GlobalInvocationId.x
+        "%23 = OpAccessChain %19 %6 %13 %22\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%24 = OpLoad %9 %23\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%25 = OpAccessChain %19 %5 %13 %22\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %24 = in value
+        "%26 = OpSMod %9 %24 %14\n"             // in % 2
+        "%27 = OpIEqual %11 %26 %13\n"          // (in % 2) == 0
+        "OpSelectionMerge %28 None\n"
+        "OpBranchConditional %27 %29 %30\n"
+        "%29 = OpLabel\n"                       // (in % 2) == 0
+        "OpBranch %28\n"
+        "%30 = OpLabel\n"                       // (in % 2) != 0
+        "OpBranch %28\n"
+        "%28 = OpLabel\n"
+        // %26 = out value
+        // End of branch logic
+        "OpStore %25 %26\n"               // use SSA value from previous block
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i % 2; });
+}
+
+// TODO: Test for parallel assignment
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalStore)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%11 = OpTypeBool\n"
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
+        "%13 = OpConstant %9 0\n"               // int32(0)
+        "%14 = OpConstant %9 1\n"               // int32(1)
+        "%15 = OpConstant %9 2\n"               // int32(2)
+        "%16 = OpConstant %10 0\n"              // uint32(0)
+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>
+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
+        "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
+        "%19 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
+        "%20 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%21 = OpLabel\n"
+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %25 = in value
+        "%27 = OpSMod %9 %25 %15\n"             // in % 2
+        "%28 = OpIEqual %11 %27 %13\n"          // (in % 2) == 0
+        "OpSelectionMerge %29 None\n"
+        "OpBranchConditional %28 %30 %31\n"
+        "%30 = OpLabel\n"                       // (in % 2) == 0
+        "OpStore %26 %14\n"               // write 1
+        "OpBranch %29\n"
+        "%31 = OpLabel\n"                       // (in % 2) != 0
+        "OpStore %26 %15\n"               // write 2
+        "OpBranch %29\n"
+        "%29 = OpLabel\n"
+        // End of branch logic
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 1 : 2; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalReturnTrue)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%11 = OpTypeBool\n"
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
+        "%13 = OpConstant %9 0\n"               // int32(0)
+        "%14 = OpConstant %9 1\n"               // int32(1)
+        "%15 = OpConstant %9 2\n"               // int32(2)
+        "%16 = OpConstant %10 0\n"              // uint32(0)
+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>
+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
+        "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
+        "%19 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
+        "%20 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%21 = OpLabel\n"
+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %25 = in value
+        "%27 = OpSMod %9 %25 %15\n"             // in % 2
+        "%28 = OpIEqual %11 %27 %13\n"          // (in % 2) == 0
+        "OpSelectionMerge %29 None\n"
+        "OpBranchConditional %28 %30 %29\n"
+        "%30 = OpLabel\n"                       // (in % 2) == 0
+        "OpReturn\n"
+        "%29 = OpLabel\n"                       // merge
+        "OpStore %26 %15\n"               // write 2
+                                          // End of branch logic
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 0 : 2; });
+}
+
+// TODO: Test for parallel assignment
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalPhi)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%11 = OpTypeBool\n"
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
+        "%13 = OpConstant %9 0\n"               // int32(0)
+        "%14 = OpConstant %9 1\n"               // int32(1)
+        "%15 = OpConstant %9 2\n"               // int32(2)
+        "%16 = OpConstant %10 0\n"              // uint32(0)
+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>
+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
+        "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
+        "%19 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
+        "%20 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%21 = OpLabel\n"
+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %25 = in value
+        "%27 = OpSMod %9 %25 %15\n"             // in % 2
+        "%28 = OpIEqual %11 %27 %13\n"          // (in % 2) == 0
+        "OpSelectionMerge %29 None\n"
+        "OpBranchConditional %28 %30 %31\n"
+        "%30 = OpLabel\n"                       // (in % 2) == 0
+        "OpBranch %29\n"
+        "%31 = OpLabel\n"                       // (in % 2) != 0
+        "OpBranch %29\n"
+        "%29 = OpLabel\n"
+        "%32 = OpPhi %9 %14 %30 %15 %31\n"      // (in % 2) == 0 ? 1 : 2
+                                                // End of branch logic
+        "OpStore %26 %32\n"
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 1 : 2; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchEmptyCases)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%11 = OpTypeBool\n"
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
+        "%13 = OpConstant %9 0\n"               // int32(0)
+        "%14 = OpConstant %9 2\n"               // int32(2)
+        "%15 = OpConstant %10 0\n"              // uint32(0)
+        "%16 = OpTypeVector %10 3\n"            // vec4<int32>
+        "%17 = OpTypePointer Input %16\n"       // vec4<int32>*
+        "%2 = OpVariable %17 Input\n"          // gl_GlobalInvocationId
+        "%18 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
+        "%19 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%20 = OpLabel\n"
+        "%21 = OpAccessChain %18 %2 %15\n"      // &gl_GlobalInvocationId.x
+        "%22 = OpLoad %10 %21\n"                // gl_GlobalInvocationId.x
+        "%23 = OpAccessChain %19 %6 %13 %22\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%24 = OpLoad %9 %23\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%25 = OpAccessChain %19 %5 %13 %22\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %24 = in value
+        "%26 = OpSMod %9 %24 %14\n"             // in % 2
+        "OpSelectionMerge %27 None\n"
+        "OpSwitch %26 %27 0 %28 1 %29\n"
+        "%28 = OpLabel\n"                       // (in % 2) == 0
+        "OpBranch %27\n"
+        "%29 = OpLabel\n"                       // (in % 2) == 1
+        "OpBranch %27\n"
+        "%27 = OpLabel\n"
+        // %26 = out value
+        // End of branch logic
+        "OpStore %25 %26\n"               // use SSA value from previous block
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i % 2; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchStore)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%11 = OpTypeBool\n"
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
+        "%13 = OpConstant %9 0\n"               // int32(0)
+        "%14 = OpConstant %9 1\n"               // int32(1)
+        "%15 = OpConstant %9 2\n"               // int32(2)
+        "%16 = OpConstant %10 0\n"              // uint32(0)
+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>
+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
+        "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
+        "%19 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
+        "%20 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%21 = OpLabel\n"
+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %25 = in value
+        "%27 = OpSMod %9 %25 %15\n"             // in % 2
+        "OpSelectionMerge %28 None\n"
+        "OpSwitch %27 %28 0 %29 1 %30\n"
+        "%29 = OpLabel\n"                       // (in % 2) == 0
+        "OpStore %26 %15\n"               // write 2
+        "OpBranch %28\n"
+        "%30 = OpLabel\n"                       // (in % 2) == 1
+        "OpStore %26 %14\n"               // write 1
+        "OpBranch %28\n"
+        "%28 = OpLabel\n"
+        // End of branch logic
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 2 : 1; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseReturn)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%11 = OpTypeBool\n"
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
+        "%13 = OpConstant %9 0\n"               // int32(0)
+        "%14 = OpConstant %9 1\n"               // int32(1)
+        "%15 = OpConstant %9 2\n"               // int32(2)
+        "%16 = OpConstant %10 0\n"              // uint32(0)
+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>
+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
+        "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
+        "%19 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
+        "%20 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%21 = OpLabel\n"
+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %25 = in value
+        "%27 = OpSMod %9 %25 %15\n"             // in % 2
+        "OpSelectionMerge %28 None\n"
+        "OpSwitch %27 %28 0 %29 1 %30\n"
+        "%29 = OpLabel\n"                       // (in % 2) == 0
+        "OpBranch %28\n"
+        "%30 = OpLabel\n"                       // (in % 2) == 1
+        "OpReturn\n"
+        "%28 = OpLabel\n"
+        "OpStore %26 %14\n"               // write 1
+                                          // End of branch logic
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 0 : 1; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultReturn)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%11 = OpTypeBool\n"
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
+        "%13 = OpConstant %9 0\n"               // int32(0)
+        "%14 = OpConstant %9 1\n"               // int32(1)
+        "%15 = OpConstant %9 2\n"               // int32(2)
+        "%16 = OpConstant %10 0\n"              // uint32(0)
+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>
+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
+        "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
+        "%19 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
+        "%20 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%21 = OpLabel\n"
+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %25 = in value
+        "%27 = OpSMod %9 %25 %15\n"             // in % 2
+        "OpSelectionMerge %28 None\n"
+        "OpSwitch %27 %29 1 %30\n"
+        "%30 = OpLabel\n"                       // (in % 2) == 1
+        "OpBranch %28\n"
+        "%29 = OpLabel\n"                       // (in % 2) != 1
+        "OpReturn\n"
+        "%28 = OpLabel\n"                       // merge
+        "OpStore %26 %14\n"               // write 1
+                                          // End of branch logic
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 0; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseFallthrough)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%11 = OpTypeBool\n"
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
+        "%13 = OpConstant %9 0\n"               // int32(0)
+        "%14 = OpConstant %9 1\n"               // int32(1)
+        "%15 = OpConstant %9 2\n"               // int32(2)
+        "%16 = OpConstant %10 0\n"              // uint32(0)
+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>
+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
+        "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
+        "%19 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
+        "%20 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%21 = OpLabel\n"
+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %25 = in value
+        "%27 = OpSMod %9 %25 %15\n"             // in % 2
+        "OpSelectionMerge %28 None\n"
+        "OpSwitch %27 %29 0 %30 1 %31\n"
+        "%30 = OpLabel\n"                       // (in % 2) == 0
+        "%32 = OpIAdd %9 %27 %14\n"             // generate an intermediate
+        "OpStore %26 %32\n"               // write a value (overwritten later)
+        "OpBranch %31\n"                  // fallthrough
+        "%31 = OpLabel\n"                       // (in % 2) == 1
+        "OpStore %26 %15\n"               // write 2
+        "OpBranch %28\n"
+        "%29 = OpLabel\n"                       // unreachable
+        "OpUnreachable\n"
+        "%28 = OpLabel\n"                       // merge
+                                                // End of branch logic
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultFallthrough)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%11 = OpTypeBool\n"
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
+        "%13 = OpConstant %9 0\n"               // int32(0)
+        "%14 = OpConstant %9 1\n"               // int32(1)
+        "%15 = OpConstant %9 2\n"               // int32(2)
+        "%16 = OpConstant %10 0\n"              // uint32(0)
+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>
+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
+        "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
+        "%19 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
+        "%20 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%21 = OpLabel\n"
+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %25 = in value
+        "%27 = OpSMod %9 %25 %15\n"             // in % 2
+        "OpSelectionMerge %28 None\n"
+        "OpSwitch %27 %29 0 %30 1 %31\n"
+        "%30 = OpLabel\n"                       // (in % 2) == 0
+        "%32 = OpIAdd %9 %27 %14\n"             // generate an intermediate
+        "OpStore %26 %32\n"               // write a value (overwritten later)
+        "OpBranch %29\n"                  // fallthrough
+        "%29 = OpLabel\n"                       // default
+        "%33 = OpIAdd %9 %27 %14\n"             // generate an intermediate
+        "OpStore %26 %33\n"               // write a value (overwritten later)
+        "OpBranch %31\n"                  // fallthrough
+        "%31 = OpLabel\n"                       // (in % 2) == 1
+        "OpStore %26 %15\n"               // write 2
+        "OpBranch %28\n"
+        "%28 = OpLabel\n"                       // merge
+                                                // End of branch logic
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchPhi)
+{
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %1 \"main\" %2\n"
+        "OpExecutionMode %1 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 ArrayStride 4\n"
+        "OpMemberDecorate %4 0 Offset 0\n"
+        "OpDecorate %4 BufferBlock\n"
+        "OpDecorate %5 DescriptorSet 0\n"
+        "OpDecorate %5 Binding 1\n"
+        "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "%7 = OpTypeVoid\n"
+        "%8 = OpTypeFunction %7\n"             // void()
+        "%9 = OpTypeInt 32 1\n"                // int32
+        "%10 = OpTypeInt 32 0\n"                // uint32
+        "%11 = OpTypeBool\n"
+        "%3 = OpTypeRuntimeArray %9\n"         // int32[]
+        "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
+        "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
+        "%13 = OpConstant %9 0\n"               // int32(0)
+        "%14 = OpConstant %9 1\n"               // int32(1)
+        "%15 = OpConstant %9 2\n"               // int32(2)
+        "%16 = OpConstant %10 0\n"              // uint32(0)
+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>
+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
+        "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
+        "%19 = OpTypePointer Input %10\n"       // uint32*
+        "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
+        "%20 = OpTypePointer Uniform %9\n"      // int32*
+        "%1 = OpFunction %7 None %8\n"         // -- Function begin --
+        "%21 = OpLabel\n"
+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
+                                                // Start of branch logic
+                                                // %25 = in value
+        "%27 = OpSMod %9 %25 %15\n"             // in % 2
+        "OpSelectionMerge %28 None\n"
+        "OpSwitch %27 %29 1 %30\n"
+        "%30 = OpLabel\n"                       // (in % 2) == 1
+        "OpBranch %28\n"
+        "%29 = OpLabel\n"                       // (in % 2) != 1
+        "OpBranch %28\n"
+        "%28 = OpLabel\n"                       // merge
+        "%31 = OpPhi %9 %14 %30 %15 %29\n"      // (in % 2) == 1 ? 1 : 2
+        "OpStore %26 %31\n"
+        // End of branch logic
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 2; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, LoopDivergentMergePhi)
+{
+	// #version 450
+	// layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+	// layout(binding = 0, std430) buffer InBuffer
+	// {
+	//     int Data[];
+	// } In;
+	// layout(binding = 1, std430) buffer OutBuffer
+	// {
+	//     int Data[];
+	// } Out;
+	// void main()
+	// {
+	//     int phi = 0;
+	//     uint lane = gl_GlobalInvocationID.x % 4;
+	//     for (uint i = 0; i < 4; i++)
+	//     {
+	//         if (lane == i)
+	//         {
+	//             phi = In.Data[gl_GlobalInvocationID.x];
+	//             break;
+	//         }
+	//     }
+	//     Out.Data[gl_GlobalInvocationID.x] = phi;
+	// }
+	std::stringstream src;
+	// clang-format off
+    src <<
+        "OpCapability Shader\n"
+        "%1 = OpExtInstImport \"GLSL.std.450\"\n"
+        "OpMemoryModel Logical GLSL450\n"
+        "OpEntryPoint GLCompute %2 \"main\" %3\n"
+        "OpExecutionMode %2 LocalSize " <<
+        GetParam().localSizeX << " " <<
+        GetParam().localSizeY << " " <<
+        GetParam().localSizeZ << "\n" <<
+        "OpDecorate %3 BuiltIn GlobalInvocationId\n"
+        "OpDecorate %4 ArrayStride 4\n"
+        "OpMemberDecorate %5 0 Offset 0\n"
+        "OpDecorate %5 BufferBlock\n"
+        "OpDecorate %6 DescriptorSet 0\n"
+        "OpDecorate %6 Binding 0\n"
+        "OpDecorate %7 ArrayStride 4\n"
+        "OpMemberDecorate %8 0 Offset 0\n"
+        "OpDecorate %8 BufferBlock\n"
+        "OpDecorate %9 DescriptorSet 0\n"
+        "OpDecorate %9 Binding 1\n"
+        "%10 = OpTypeVoid\n"
+        "%11 = OpTypeFunction %10\n"
+        "%12 = OpTypeInt 32 1\n"
+        "%13 = OpConstant %12 0\n"
+        "%14 = OpTypeInt 32 0\n"
+        "%15 = OpTypeVector %14 3\n"
+        "%16 = OpTypePointer Input %15\n"
+        "%3 = OpVariable %16 Input\n"
+        "%17 = OpConstant %14 0\n"
+        "%18 = OpTypePointer Input %14\n"
+        "%19 = OpConstant %14 4\n"
+        "%20 = OpTypeBool\n"
+        "%4 = OpTypeRuntimeArray %12\n"
+        "%5 = OpTypeStruct %4\n"
+        "%21 = OpTypePointer Uniform %5\n"
+        "%6 = OpVariable %21 Uniform\n"
+        "%22 = OpTypePointer Uniform %12\n"
+        "%23 = OpConstant %12 1\n"
+        "%7 = OpTypeRuntimeArray %12\n"
+        "%8 = OpTypeStruct %7\n"
+        "%24 = OpTypePointer Uniform %8\n"
+        "%9 = OpVariable %24 Uniform\n"
+        "%2 = OpFunction %10 None %11\n"
+        "%25 = OpLabel\n"
+        "%26 = OpAccessChain %18 %3 %17\n"
+        "%27 = OpLoad %14 %26\n"
+        "%28 = OpUMod %14 %27 %19\n"
+        "OpBranch %29\n"
+        "%29 = OpLabel\n"
+        "%30 = OpPhi %14 %17 %25 %31 %32\n"
+        "%33 = OpULessThan %20 %30 %19\n"
+        "OpLoopMerge %34 %32 None\n"
+        "OpBranchConditional %33 %35 %34\n"
+        "%35 = OpLabel\n"
+        "%36 = OpIEqual %20 %28 %30\n"
+        "OpSelectionMerge %37 None\n"
+        "OpBranchConditional %36 %38 %37\n"
+        "%38 = OpLabel\n"
+        "%39 = OpAccessChain %22 %6 %13 %27\n"
+        "%40 = OpLoad %12 %39\n"
+        "OpBranch %34\n"
+        "%37 = OpLabel\n"
+        "OpBranch %32\n"
+        "%32 = OpLabel\n"
+        "%31 = OpIAdd %14 %30 %23\n"
+        "OpBranch %29\n"
+        "%34 = OpLabel\n"
+        "%41 = OpPhi %12 %13 %29 %40 %38\n" // %40: phi
+        "%42 = OpAccessChain %22 %9 %13 %27\n"
+        "OpStore %42 %41\n"
+        "OpReturn\n"
+        "OpFunctionEnd\n";
+	// clang-format on
+
+	test(
+	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
+}
diff --git a/tests/VulkanUnitTests/Device.cpp b/tests/VulkanUnitTests/Device.cpp
index 7e892e0..462492f 100644
--- a/tests/VulkanUnitTests/Device.cpp
+++ b/tests/VulkanUnitTests/Device.cpp
@@ -333,7 +333,7 @@
 		});
 	}
 
-	driver->vkUpdateDescriptorSets(device, writes.size(), writes.data(), 0, nullptr);
+	driver->vkUpdateDescriptorSets(device, (uint32_t)writes.size(), writes.data(), 0, nullptr);
 }
 
 VkResult Device::AllocateMemory(size_t size, VkMemoryPropertyFlags flags, VkDeviceMemory *out) const
diff --git a/tests/VulkanUnitTests/DrawTests.cpp b/tests/VulkanUnitTests/DrawTests.cpp
new file mode 100644
index 0000000..2167c80
--- /dev/null
+++ b/tests/VulkanUnitTests/DrawTests.cpp
@@ -0,0 +1,76 @@
+// Copyright 2021 The SwiftShader Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "DrawTester.hpp"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+class DrawTest : public testing::Test
+{
+};
+
+// Test that a vertex shader with no gl_Position works.
+// This was fixed in swiftshader-cl/51808
+TEST_F(DrawTest, VertexShaderNoPositionOutput)
+{
+	DrawTester tester;
+	tester.onCreateVertexBuffers([](DrawTester &tester) {
+		struct Vertex
+		{
+			float position[3];
+		};
+
+		Vertex vertexBufferData[] = {
+			{ { 1.0f, 1.0f, 0.5f } },
+			{ { -1.0f, 1.0f, 0.5f } },
+			{ { 0.0f, -1.0f, 0.5f } }
+		};
+
+		std::vector<vk::VertexInputAttributeDescription> inputAttributes;
+		inputAttributes.push_back(vk::VertexInputAttributeDescription(0, 0, vk::Format::eR32G32B32Sfloat, offsetof(Vertex, position)));
+
+		tester.addVertexBuffer(vertexBufferData, sizeof(vertexBufferData), std::move(inputAttributes));
+	});
+
+	tester.onCreateVertexShader([](DrawTester &tester) {
+		const char *vertexShader = R"(#version 310 es
+			layout(location = 0) in vec3 inPos;
+
+			void main()
+			{
+				// Remove gl_Position on purpose for the test
+				//gl_Position = vec4(inPos.xyz, 1.0);
+			})";
+
+		return tester.createShaderModule(vertexShader, EShLanguage::EShLangVertex);
+	});
+
+	tester.onCreateFragmentShader([](DrawTester &tester) {
+		const char *fragmentShader = R"(#version 310 es
+			precision highp float;
+
+			layout(location = 0) out vec4 outColor;
+
+			void main()
+			{
+				outColor = vec4(1.0, 1.0, 1.0, 1.0);
+			})";
+
+		return tester.createShaderModule(fragmentShader, EShLanguage::EShLangFragment);
+	});
+
+	tester.initialize();
+	tester.renderFrame();
+}
diff --git a/tests/VulkanUnitTests/unittests.cpp b/tests/VulkanUnitTests/unittests.cpp
deleted file mode 100644
index 8470d4d..0000000
--- a/tests/VulkanUnitTests/unittests.cpp
+++ /dev/null
@@ -1,1661 +0,0 @@
-// Copyright 2018 The SwiftShader Authors. All Rights Reserved.

-//

-// Licensed under the Apache License, Version 2.0 (the "License");

-// you may not use this file except in compliance with the License.

-// You may obtain a copy of the License at

-//

-//    http://www.apache.org/licenses/LICENSE-2.0

-//

-// Unless required by applicable law or agreed to in writing, software

-// distributed under the License is distributed on an "AS IS" BASIS,

-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

-// See the License for the specific language governing permissions and

-// limitations under the License.

-

-// Vulkan unit tests that provide coverage for functionality not tested by

-// the dEQP test suite. Also used as a smoke test.

-

-#include "Device.hpp"

-#include "Driver.hpp"

-

-#include "gmock/gmock.h"

-#include "gtest/gtest.h"

-

-#include "spirv-tools/libspirv.hpp"

-

-#include <cstring>

-#include <sstream>

-

-namespace {

-size_t alignUp(size_t val, size_t alignment)

-{

-	return alignment * ((val + alignment - 1) / alignment);

-}

-}  // anonymous namespace

-

-enum class LoadDriver

-{

-	PerSuite,

-	PerTest

-};

-

-template<typename TestBase, LoadDriver loadDriver>

-class SwiftShaderTest : public TestBase

-{

-protected:

-	static Driver driver;

-

-	static void SetUpTestSuite()

-	{

-		if(loadDriver == LoadDriver::PerSuite)

-		{

-			ASSERT_TRUE(driver.loadSwiftShader());

-		}

-	}

-

-	static void TearDownTestSuite()

-	{

-		if(loadDriver == LoadDriver::PerSuite)

-		{

-			driver.unload();

-		}

-	}

-

-	virtual void SetUp()

-	{

-		if(loadDriver == LoadDriver::PerTest)

-		{

-			ASSERT_TRUE(driver.loadSwiftShader());

-		}

-	}

-

-	virtual void TearDown()

-	{

-		if(loadDriver == LoadDriver::PerTest)

-		{

-			driver.unload();

-		}

-	}

-};

-

-template<typename TestBase, LoadDriver loadDriver>

-Driver SwiftShaderTest<TestBase, loadDriver>::driver;

-

-class SwiftShaderVulkanTest : public SwiftShaderTest<testing::Test, LoadDriver::PerTest>

-{

-};

-

-TEST_F(SwiftShaderVulkanTest, ICD_Check)

-{

-	auto createInstance = driver.vk_icdGetInstanceProcAddr(VK_NULL_HANDLE, "vkCreateInstance");

-	EXPECT_NE(createInstance, nullptr);

-

-	auto enumerateInstanceExtensionProperties =

-	    driver.vk_icdGetInstanceProcAddr(VK_NULL_HANDLE, "vkEnumerateInstanceExtensionProperties");

-	EXPECT_NE(enumerateInstanceExtensionProperties, nullptr);

-

-	auto enumerateInstanceLayerProperties =

-	    driver.vk_icdGetInstanceProcAddr(VK_NULL_HANDLE, "vkEnumerateInstanceLayerProperties");

-	EXPECT_NE(enumerateInstanceLayerProperties, nullptr);

-

-	auto enumerateInstanceVersion = driver.vk_icdGetInstanceProcAddr(VK_NULL_HANDLE, "vkEnumerateInstanceVersion");

-	EXPECT_NE(enumerateInstanceVersion, nullptr);

-

-	auto bad_function = driver.vk_icdGetInstanceProcAddr(VK_NULL_HANDLE, "bad_function");

-	EXPECT_EQ(bad_function, nullptr);

-}

-

-TEST_F(SwiftShaderVulkanTest, Version)

-{

-	uint32_t apiVersion = 0;

-	VkResult result = driver.vkEnumerateInstanceVersion(&apiVersion);

-	EXPECT_EQ(apiVersion, (uint32_t)VK_API_VERSION_1_1);

-

-	const VkInstanceCreateInfo createInfo = {

-		VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,  // sType

-		nullptr,                                 // pNext

-		0,                                       // flags

-		nullptr,                                 // pApplicationInfo

-		0,                                       // enabledLayerCount

-		nullptr,                                 // ppEnabledLayerNames

-		0,                                       // enabledExtensionCount

-		nullptr,                                 // ppEnabledExtensionNames

-	};

-	VkInstance instance = VK_NULL_HANDLE;

-	result = driver.vkCreateInstance(&createInfo, nullptr, &instance);

-	EXPECT_EQ(result, VK_SUCCESS);

-

-	ASSERT_TRUE(driver.resolve(instance));

-

-	uint32_t pPhysicalDeviceCount = 0;

-	result = driver.vkEnumeratePhysicalDevices(instance, &pPhysicalDeviceCount, nullptr);

-	EXPECT_EQ(result, VK_SUCCESS);

-	EXPECT_EQ(pPhysicalDeviceCount, 1U);

-

-	VkPhysicalDevice pPhysicalDevice = VK_NULL_HANDLE;

-	result = driver.vkEnumeratePhysicalDevices(instance, &pPhysicalDeviceCount, &pPhysicalDevice);

-	EXPECT_EQ(result, VK_SUCCESS);

-	EXPECT_NE(pPhysicalDevice, (VkPhysicalDevice)VK_NULL_HANDLE);

-

-	VkPhysicalDeviceProperties physicalDeviceProperties;

-	driver.vkGetPhysicalDeviceProperties(pPhysicalDevice, &physicalDeviceProperties);

-	EXPECT_EQ(physicalDeviceProperties.apiVersion, (uint32_t)VK_API_VERSION_1_1);

-	EXPECT_EQ(physicalDeviceProperties.deviceID, 0xC0DEU);

-	EXPECT_EQ(physicalDeviceProperties.deviceType, VK_PHYSICAL_DEVICE_TYPE_CPU);

-

-	EXPECT_NE(strstr(physicalDeviceProperties.deviceName, "SwiftShader Device"), nullptr);

-

-	VkPhysicalDeviceProperties2 physicalDeviceProperties2;

-	VkPhysicalDeviceDriverPropertiesKHR physicalDeviceDriverProperties;

-	physicalDeviceProperties2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;

-	physicalDeviceProperties2.pNext = &physicalDeviceDriverProperties;

-	physicalDeviceDriverProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES_KHR;

-	physicalDeviceDriverProperties.pNext = nullptr;

-	physicalDeviceDriverProperties.driverID = (VkDriverIdKHR)0;

-	driver.vkGetPhysicalDeviceProperties2(pPhysicalDevice, &physicalDeviceProperties2);

-	EXPECT_EQ(physicalDeviceDriverProperties.driverID, VK_DRIVER_ID_GOOGLE_SWIFTSHADER_KHR);

-

-	driver.vkDestroyInstance(instance, nullptr);

-}

-/*

-TEST_F(SwiftShaderVulkanTest, UnsupportedDeviceExtension_DISABLED)

-{

-	uint32_t apiVersion = 0;

-	VkResult result = driver.vkEnumerateInstanceVersion(&apiVersion);

-	EXPECT_EQ(apiVersion, (uint32_t)VK_API_VERSION_1_1);

-

-	const VkInstanceCreateInfo createInfo = {

-		VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,  // sType

-		nullptr,                                 // pNext

-		0,                                       // flags

-		nullptr,                                 // pApplicationInfo

-		0,                                       // enabledLayerCount

-		nullptr,                                 // ppEnabledLayerNames

-		0,                                       // enabledExtensionCount

-		nullptr,                                 // ppEnabledExtensionNames

-	};

-	VkInstance instance = VK_NULL_HANDLE;

-	result = driver.vkCreateInstance(&createInfo, nullptr, &instance);

-	EXPECT_EQ(result, VK_SUCCESS);

-

-	ASSERT_TRUE(driver.resolve(instance));

-

-	VkBaseInStructure unsupportedExt = { VK_STRUCTURE_TYPE_SHADER_MODULE_VALIDATION_CACHE_CREATE_INFO_EXT, nullptr };

-

-	// Gather all physical devices

-	std::vector<VkPhysicalDevice> physicalDevices;

-	result = Device::GetPhysicalDevices(&driver, instance, physicalDevices);

-	EXPECT_EQ(result, VK_SUCCESS);

-

-	// Inspect each physical device's queue families for compute support.

-	for(auto physicalDevice : physicalDevices)

-	{

-		int queueFamilyIndex = Device::GetComputeQueueFamilyIndex(&driver, physicalDevice);

-		if(queueFamilyIndex < 0)

-		{

-			continue;

-		}

-

-		const float queuePrioritory = 1.0f;

-		const VkDeviceQueueCreateInfo deviceQueueCreateInfo = {

-			VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO,  // sType

-			nullptr,                                     // pNext

-			0,                                           // flags

-			(uint32_t)queueFamilyIndex,                  // queueFamilyIndex

-			1,                                           // queueCount

-			&queuePrioritory,                            // pQueuePriorities

-		};

-

-		const VkDeviceCreateInfo deviceCreateInfo = {

-			VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO,  // sType

-			&unsupportedExt,                       // pNext

-			0,                                     // flags

-			1,                                     // queueCreateInfoCount

-			&deviceQueueCreateInfo,                // pQueueCreateInfos

-			0,                                     // enabledLayerCount

-			nullptr,                               // ppEnabledLayerNames

-			0,                                     // enabledExtensionCount

-			nullptr,                               // ppEnabledExtensionNames

-			nullptr,                               // pEnabledFeatures

-		};

-

-		VkDevice device;

-		result = driver.vkCreateDevice(physicalDevice, &deviceCreateInfo, nullptr, &device);

-		EXPECT_EQ(result, VK_SUCCESS);

-		driver.vkDestroyDevice(device, nullptr);

-	}

-

-	driver.vkDestroyInstance(instance, nullptr); 

-}

-*/

-std::vector<uint32_t> compileSpirv(const char *assembly)

-{

-	spvtools::SpirvTools core(SPV_ENV_VULKAN_1_0);

-

-	core.SetMessageConsumer([](spv_message_level_t, const char *, const spv_position_t &p, const char *m) {

-		FAIL() << p.line << ":" << p.column << ": " << m;

-	});

-

-	std::vector<uint32_t> spirv;

-	EXPECT_TRUE(core.Assemble(assembly, &spirv));

-	EXPECT_TRUE(core.Validate(spirv));

-

-	// Warn if the disassembly does not match the source assembly.

-	// We do this as debugging tests in the debugger is often made much harder

-	// if the SSA names (%X) in the debugger do not match the source.

-	std::string disassembled;

-	core.Disassemble(spirv, &disassembled, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);

-	if(disassembled != assembly)

-	{

-		printf("-- WARNING: Disassembly does not match assembly: ---\n\n");

-

-		auto splitLines = [](const std::string &str) -> std::vector<std::string> {

-			std::stringstream ss(str);

-			std::vector<std::string> out;

-			std::string line;

-			while(std::getline(ss, line, '\n')) { out.push_back(line); }

-			return out;

-		};

-

-		auto srcLines = splitLines(std::string(assembly));

-		auto disLines = splitLines(disassembled);

-

-		for(size_t line = 0; line < srcLines.size() && line < disLines.size(); line++)

-		{

-			auto srcLine = (line < srcLines.size()) ? srcLines[line] : "<missing>";

-			auto disLine = (line < disLines.size()) ? disLines[line] : "<missing>";

-			if(srcLine != disLine)

-			{

-				printf("%zu: '%s' != '%s'\n", line, srcLine.c_str(), disLine.c_str());

-			}

-		}

-		printf("\n\n---\nExpected:\n\n%s", disassembled.c_str());

-	}

-

-	return spirv;

-}

-

-#define VK_ASSERT(x) ASSERT_EQ(x, VK_SUCCESS)

-

-struct ComputeParams

-{

-	size_t numElements;

-	int localSizeX;

-	int localSizeY;

-	int localSizeZ;

-

-	friend std::ostream &operator<<(std::ostream &os, const ComputeParams &params)

-	{

-		return os << "ComputeParams{"

-		          << "numElements: " << params.numElements << ", "

-		          << "localSizeX: " << params.localSizeX << ", "

-		          << "localSizeY: " << params.localSizeY << ", "

-		          << "localSizeZ: " << params.localSizeZ << "}";

-	}

-};

-

-// Base class for compute tests that read from an input buffer and write to an

-// output buffer of same length.

-class SwiftShaderVulkanBufferToBufferComputeTest : public SwiftShaderTest<testing::TestWithParam<ComputeParams>, LoadDriver::PerSuite>

-{

-public:

-	void test(const std::string &shader,

-	          std::function<uint32_t(uint32_t idx)> input,

-	          std::function<uint32_t(uint32_t idx)> expected);

-};

-

-void SwiftShaderVulkanBufferToBufferComputeTest::test(

-    const std::string &shader,

-    std::function<uint32_t(uint32_t idx)> input,

-    std::function<uint32_t(uint32_t idx)> expected)

-{

-	auto code = compileSpirv(shader.c_str());

-

-	const VkInstanceCreateInfo createInfo = {

-		VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,  // sType

-		nullptr,                                 // pNext

-		0,                                       // flags

-		nullptr,                                 // pApplicationInfo

-		0,                                       // enabledLayerCount

-		nullptr,                                 // ppEnabledLayerNames

-		0,                                       // enabledExtensionCount

-		nullptr,                                 // ppEnabledExtensionNames

-	};

-

-	VkInstance instance = VK_NULL_HANDLE;

-	VK_ASSERT(driver.vkCreateInstance(&createInfo, nullptr, &instance));

-

-	ASSERT_TRUE(driver.resolve(instance));

-

-	std::unique_ptr<Device> device;

-	VK_ASSERT(Device::CreateComputeDevice(&driver, instance, device));

-	ASSERT_TRUE(device->IsValid());

-

-	// struct Buffers

-	// {

-	//     uint32_t pad0[63];

-	//     uint32_t magic0;

-	//     uint32_t in[NUM_ELEMENTS]; // Aligned to 0x100

-	//     uint32_t magic1;

-	//     uint32_t pad1[N];

-	//     uint32_t magic2;

-	//     uint32_t out[NUM_ELEMENTS]; // Aligned to 0x100

-	//     uint32_t magic3;

-	// };

-	static constexpr uint32_t magic0 = 0x01234567;

-	static constexpr uint32_t magic1 = 0x89abcdef;

-	static constexpr uint32_t magic2 = 0xfedcba99;

-	static constexpr uint32_t magic3 = 0x87654321;

-	size_t numElements = GetParam().numElements;

-	size_t alignElements = 0x100 / sizeof(uint32_t);

-	size_t magic0Offset = alignElements - 1;

-	size_t inOffset = 1 + magic0Offset;

-	size_t magic1Offset = numElements + inOffset;

-	size_t magic2Offset = alignUp(magic1Offset + 1, alignElements) - 1;

-	size_t outOffset = 1 + magic2Offset;

-	size_t magic3Offset = numElements + outOffset;

-	size_t buffersTotalElements = alignUp(1 + magic3Offset, alignElements);

-	size_t buffersSize = sizeof(uint32_t) * buffersTotalElements;

-

-	VkDeviceMemory memory;

-	VK_ASSERT(device->AllocateMemory(buffersSize,

-	                                 VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,

-	                                 &memory));

-

-	uint32_t *buffers;

-	VK_ASSERT(device->MapMemory(memory, 0, buffersSize, 0, (void **)&buffers));

-

-	buffers[magic0Offset] = magic0;

-	buffers[magic1Offset] = magic1;

-	buffers[magic2Offset] = magic2;

-	buffers[magic3Offset] = magic3;

-

-	for(size_t i = 0; i < numElements; i++)

-	{

-		buffers[inOffset + i] = input(i);

-	}

-

-	device->UnmapMemory(memory);

-	buffers = nullptr;

-

-	VkBuffer bufferIn;

-	VK_ASSERT(device->CreateStorageBuffer(memory,

-	                                      sizeof(uint32_t) * numElements,

-	                                      sizeof(uint32_t) * inOffset,

-	                                      &bufferIn));

-

-	VkBuffer bufferOut;

-	VK_ASSERT(device->CreateStorageBuffer(memory,

-	                                      sizeof(uint32_t) * numElements,

-	                                      sizeof(uint32_t) * outOffset,

-	                                      &bufferOut));

-

-	VkShaderModule shaderModule;

-	VK_ASSERT(device->CreateShaderModule(code, &shaderModule));

-

-	std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings = {

-		{

-		    0,                                  // binding

-		    VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,  // descriptorType

-		    1,                                  // descriptorCount

-		    VK_SHADER_STAGE_COMPUTE_BIT,        // stageFlags

-		    0,                                  // pImmutableSamplers

-		},

-		{

-		    1,                                  // binding

-		    VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,  // descriptorType

-		    1,                                  // descriptorCount

-		    VK_SHADER_STAGE_COMPUTE_BIT,        // stageFlags

-		    0,                                  // pImmutableSamplers

-		}

-	};

-

-	VkDescriptorSetLayout descriptorSetLayout;

-	VK_ASSERT(device->CreateDescriptorSetLayout(descriptorSetLayoutBindings, &descriptorSetLayout));

-

-	VkPipelineLayout pipelineLayout;

-	VK_ASSERT(device->CreatePipelineLayout(descriptorSetLayout, &pipelineLayout));

-

-	VkPipeline pipeline;

-	VK_ASSERT(device->CreateComputePipeline(shaderModule, pipelineLayout, &pipeline));

-

-	VkDescriptorPool descriptorPool;

-	VK_ASSERT(device->CreateStorageBufferDescriptorPool(2, &descriptorPool));

-

-	VkDescriptorSet descriptorSet;

-	VK_ASSERT(device->AllocateDescriptorSet(descriptorPool, descriptorSetLayout, &descriptorSet));

-

-	std::vector<VkDescriptorBufferInfo> descriptorBufferInfos = {

-		{

-		    bufferIn,       // buffer

-		    0,              // offset

-		    VK_WHOLE_SIZE,  // range

-		},

-		{

-		    bufferOut,      // buffer

-		    0,              // offset

-		    VK_WHOLE_SIZE,  // range

-		}

-	};

-	device->UpdateStorageBufferDescriptorSets(descriptorSet, descriptorBufferInfos);

-

-	VkCommandPool commandPool;

-	VK_ASSERT(device->CreateCommandPool(&commandPool));

-

-	VkCommandBuffer commandBuffer;

-	VK_ASSERT(device->AllocateCommandBuffer(commandPool, &commandBuffer));

-

-	VK_ASSERT(device->BeginCommandBuffer(VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT, commandBuffer));

-

-	driver.vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);

-

-	driver.vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, 1, &descriptorSet,

-	                               0, nullptr);

-

-	driver.vkCmdDispatch(commandBuffer, numElements / GetParam().localSizeX, 1, 1);

-

-	VK_ASSERT(driver.vkEndCommandBuffer(commandBuffer));

-

-	VK_ASSERT(device->QueueSubmitAndWait(commandBuffer));

-

-	VK_ASSERT(device->MapMemory(memory, 0, buffersSize, 0, (void **)&buffers));

-

-	for(size_t i = 0; i < numElements; ++i)

-	{

-		auto got = buffers[i + outOffset];

-		EXPECT_EQ(expected(i), got) << "Unexpected output at " << i;

-	}

-

-	// Check for writes outside of bounds.

-	EXPECT_EQ(buffers[magic0Offset], magic0);

-	EXPECT_EQ(buffers[magic1Offset], magic1);

-	EXPECT_EQ(buffers[magic2Offset], magic2);

-	EXPECT_EQ(buffers[magic3Offset], magic3);

-

-	device->UnmapMemory(memory);

-	buffers = nullptr;

-

-	device->FreeCommandBuffer(commandPool, commandBuffer);

-	device->FreeMemory(memory);

-	device->DestroyPipeline(pipeline);

-	device->DestroyCommandPool(commandPool);

-	device->DestroyPipelineLayout(pipelineLayout);

-	device->DestroyDescriptorSetLayout(descriptorSetLayout);

-	device->DestroyDescriptorPool(descriptorPool);

-	device->DestroyBuffer(bufferIn);

-	device->DestroyBuffer(bufferOut);

-	device->DestroyShaderModule(shaderModule);

-	device.reset(nullptr);

-	driver.vkDestroyInstance(instance, nullptr);

-}

-

-INSTANTIATE_TEST_SUITE_P(ComputeParams, SwiftShaderVulkanBufferToBufferComputeTest, testing::Values(ComputeParams{ 512, 1, 1, 1 }, ComputeParams{ 512, 2, 1, 1 }, ComputeParams{ 512, 4, 1, 1 }, ComputeParams{ 512, 8, 1, 1 }, ComputeParams{ 512, 16, 1, 1 }, ComputeParams{ 512, 32, 1, 1 },

-

-                                                                                                    // Non-multiple of SIMD-lane.

-                                                                                                    ComputeParams{ 3, 1, 1, 1 }, ComputeParams{ 2, 1, 1, 1 }));

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, Memcpy)

-{

-	std::stringstream src;

-	// #version 450

-	// layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

-	// layout(binding = 0, std430) buffer InBuffer

-	// {

-	//     int Data[];

-	// } In;

-	// layout(binding = 1, std430) buffer OutBuffer

-	// {

-	//     int Data[];

-	// } Out;

-	// void main()

-	// {

-	//     Out.Data[gl_GlobalInvocationID.x] = In.Data[gl_GlobalInvocationID.x];

-	// }

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in

-        "%12 = OpConstant %9 0\n"               // int32(0)

-        "%13 = OpConstant %10 0\n"              // uint32(0)

-        "%14 = OpTypeVector %10 3\n"            // vec3<int32>

-        "%15 = OpTypePointer Input %14\n"       // vec3<int32>*

-         "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId

-        "%16 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out

-        "%17 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%18 = OpLabel\n"

-        "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x

-        "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x

-        "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]

-              "OpStore %23 %22\n"               // out.arr[gl_GlobalInvocationId.x] = in[gl_GlobalInvocationId.x]

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, GlobalInvocationId)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in

-        "%12 = OpConstant %9 0\n"               // int32(0)

-        "%13 = OpConstant %9 1\n"               // int32(1)

-        "%14 = OpConstant %10 0\n"              // uint32(0)

-        "%15 = OpConstant %10 1\n"              // uint32(1)

-        "%16 = OpConstant %10 2\n"              // uint32(2)

-        "%17 = OpTypeVector %10 3\n"            // vec3<int32>

-        "%18 = OpTypePointer Input %17\n"       // vec3<int32>*

-         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

-        "%19 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out

-        "%20 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%21 = OpLabel\n"

-        "%22 = OpAccessChain %19 %2 %14\n"      // &gl_GlobalInvocationId.x

-        "%23 = OpAccessChain %19 %2 %15\n"      // &gl_GlobalInvocationId.y

-        "%24 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.z

-        "%25 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

-        "%26 = OpLoad %10 %23\n"                // gl_GlobalInvocationId.y

-        "%27 = OpLoad %10 %24\n"                // gl_GlobalInvocationId.z

-        "%28 = OpAccessChain %20 %6 %12 %25\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%29 = OpLoad %9 %28\n"                 // out.arr[gl_GlobalInvocationId.x]

-        "%30 = OpIAdd %9 %29 %26\n"             // in[gl_GlobalInvocationId.x] + gl_GlobalInvocationId.y

-        "%31 = OpIAdd %9 %30 %27\n"             // in[gl_GlobalInvocationId.x] + gl_GlobalInvocationId.y + gl_GlobalInvocationId.z

-        "%32 = OpAccessChain %20 %5 %12 %25\n"  // &out.arr[gl_GlobalInvocationId.x]

-              "OpStore %32 %31\n"               // out.arr[gl_GlobalInvocationId.x] = in[gl_GlobalInvocationId.x] + gl_GlobalInvocationId.y + gl_GlobalInvocationId.z

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	// gl_GlobalInvocationId.y and gl_GlobalInvocationId.z should both be zero.

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchSimple)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in

-        "%12 = OpConstant %9 0\n"               // int32(0)

-        "%13 = OpConstant %10 0\n"              // uint32(0)

-        "%14 = OpTypeVector %10 3\n"            // vec3<int32>

-        "%15 = OpTypePointer Input %14\n"       // vec3<int32>*

-         "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId

-        "%16 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out

-        "%17 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%18 = OpLabel\n"

-        "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x

-        "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x

-        "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %22 = in value

-              "OpBranch %24\n"

-        "%24 = OpLabel\n"

-              "OpBranch %25\n"

-        "%25 = OpLabel\n"

-              "OpBranch %26\n"

-        "%26 = OpLabel\n"

-    // %22 = out value

-    // End of branch logic

-              "OpStore %23 %22\n"

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchDeclareSSA)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in

-        "%12 = OpConstant %9 0\n"               // int32(0)

-        "%13 = OpConstant %10 0\n"              // uint32(0)

-        "%14 = OpTypeVector %10 3\n"            // vec3<int32>

-        "%15 = OpTypePointer Input %14\n"       // vec3<int32>*

-         "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId

-        "%16 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out

-        "%17 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%18 = OpLabel\n"

-        "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x

-        "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x

-        "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %22 = in value

-              "OpBranch %24\n"

-        "%24 = OpLabel\n"

-        "%25 = OpIAdd %9 %22 %22\n"             // %25 = in*2

-              "OpBranch %26\n"

-        "%26 = OpLabel\n"

-              "OpBranch %27\n"

-        "%27 = OpLabel\n"

-    // %25 = out value

-    // End of branch logic

-              "OpStore %23 %25\n"               // use SSA value from previous block

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i * 2; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalSimple)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-        "%11 = OpTypeBool\n"

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

-        "%13 = OpConstant %9 0\n"               // int32(0)

-        "%14 = OpConstant %9 2\n"               // int32(2)

-        "%15 = OpConstant %10 0\n"              // uint32(0)

-        "%16 = OpTypeVector %10 3\n"            // vec4<int32>

-        "%17 = OpTypePointer Input %16\n"       // vec4<int32>*

-         "%2 = OpVariable %17 Input\n"          // gl_GlobalInvocationId

-        "%18 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

-        "%19 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%20 = OpLabel\n"

-        "%21 = OpAccessChain %18 %2 %15\n"      // &gl_GlobalInvocationId.x

-        "%22 = OpLoad %10 %21\n"                // gl_GlobalInvocationId.x

-        "%23 = OpAccessChain %19 %6 %13 %22\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%24 = OpLoad %9 %23\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%25 = OpAccessChain %19 %5 %13 %22\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %24 = in value

-        "%26 = OpSMod %9 %24 %14\n"             // in % 2

-        "%27 = OpIEqual %11 %26 %13\n"          // (in % 2) == 0

-              "OpSelectionMerge %28 None\n"

-              "OpBranchConditional %27 %28 %28\n" // Both go to %28

-        "%28 = OpLabel\n"

-    // %26 = out value

-    // End of branch logic

-              "OpStore %25 %26\n"               // use SSA value from previous block

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i % 2; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalTwoEmptyBlocks)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-        "%11 = OpTypeBool\n"

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

-        "%13 = OpConstant %9 0\n"               // int32(0)

-        "%14 = OpConstant %9 2\n"               // int32(2)

-        "%15 = OpConstant %10 0\n"              // uint32(0)

-        "%16 = OpTypeVector %10 3\n"            // vec4<int32>

-        "%17 = OpTypePointer Input %16\n"       // vec4<int32>*

-         "%2 = OpVariable %17 Input\n"          // gl_GlobalInvocationId

-        "%18 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

-        "%19 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%20 = OpLabel\n"

-        "%21 = OpAccessChain %18 %2 %15\n"      // &gl_GlobalInvocationId.x

-        "%22 = OpLoad %10 %21\n"                // gl_GlobalInvocationId.x

-        "%23 = OpAccessChain %19 %6 %13 %22\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%24 = OpLoad %9 %23\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%25 = OpAccessChain %19 %5 %13 %22\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %24 = in value

-        "%26 = OpSMod %9 %24 %14\n"             // in % 2

-        "%27 = OpIEqual %11 %26 %13\n"          // (in % 2) == 0

-              "OpSelectionMerge %28 None\n"

-              "OpBranchConditional %27 %29 %30\n"

-        "%29 = OpLabel\n"                       // (in % 2) == 0

-              "OpBranch %28\n"

-        "%30 = OpLabel\n"                       // (in % 2) != 0

-              "OpBranch %28\n"

-        "%28 = OpLabel\n"

-    // %26 = out value

-    // End of branch logic

-              "OpStore %25 %26\n"               // use SSA value from previous block

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i % 2; });

-}

-

-// TODO: Test for parallel assignment

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalStore)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-        "%11 = OpTypeBool\n"

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

-        "%13 = OpConstant %9 0\n"               // int32(0)

-        "%14 = OpConstant %9 1\n"               // int32(1)

-        "%15 = OpConstant %9 2\n"               // int32(2)

-        "%16 = OpConstant %10 0\n"              // uint32(0)

-        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

-        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

-         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

-        "%19 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

-        "%20 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%21 = OpLabel\n"

-        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

-        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

-        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %25 = in value

-        "%27 = OpSMod %9 %25 %15\n"             // in % 2

-        "%28 = OpIEqual %11 %27 %13\n"          // (in % 2) == 0

-              "OpSelectionMerge %29 None\n"

-              "OpBranchConditional %28 %30 %31\n"

-        "%30 = OpLabel\n"                       // (in % 2) == 0

-              "OpStore %26 %14\n"               // write 1

-              "OpBranch %29\n"

-        "%31 = OpLabel\n"                       // (in % 2) != 0

-              "OpStore %26 %15\n"               // write 2

-              "OpBranch %29\n"

-        "%29 = OpLabel\n"

-    // End of branch logic

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 1 : 2; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalReturnTrue)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-        "%11 = OpTypeBool\n"

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

-        "%13 = OpConstant %9 0\n"               // int32(0)

-        "%14 = OpConstant %9 1\n"               // int32(1)

-        "%15 = OpConstant %9 2\n"               // int32(2)

-        "%16 = OpConstant %10 0\n"              // uint32(0)

-        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

-        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

-         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

-        "%19 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

-        "%20 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%21 = OpLabel\n"

-        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

-        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

-        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %25 = in value

-        "%27 = OpSMod %9 %25 %15\n"             // in % 2

-        "%28 = OpIEqual %11 %27 %13\n"          // (in % 2) == 0

-              "OpSelectionMerge %29 None\n"

-              "OpBranchConditional %28 %30 %29\n"

-        "%30 = OpLabel\n"                       // (in % 2) == 0

-              "OpReturn\n"

-        "%29 = OpLabel\n"                       // merge

-              "OpStore %26 %15\n"               // write 2

-    // End of branch logic

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 0 : 2; });

-}

-

-// TODO: Test for parallel assignment

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalPhi)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-        "%11 = OpTypeBool\n"

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

-        "%13 = OpConstant %9 0\n"               // int32(0)

-        "%14 = OpConstant %9 1\n"               // int32(1)

-        "%15 = OpConstant %9 2\n"               // int32(2)

-        "%16 = OpConstant %10 0\n"              // uint32(0)

-        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

-        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

-         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

-        "%19 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

-        "%20 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%21 = OpLabel\n"

-        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

-        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

-        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %25 = in value

-        "%27 = OpSMod %9 %25 %15\n"             // in % 2

-        "%28 = OpIEqual %11 %27 %13\n"          // (in % 2) == 0

-              "OpSelectionMerge %29 None\n"

-              "OpBranchConditional %28 %30 %31\n"

-        "%30 = OpLabel\n"                       // (in % 2) == 0

-              "OpBranch %29\n"

-        "%31 = OpLabel\n"                       // (in % 2) != 0

-              "OpBranch %29\n"

-        "%29 = OpLabel\n"

-        "%32 = OpPhi %9 %14 %30 %15 %31\n"      // (in % 2) == 0 ? 1 : 2

-    // End of branch logic

-              "OpStore %26 %32\n"

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 1 : 2; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchEmptyCases)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-        "%11 = OpTypeBool\n"

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

-        "%13 = OpConstant %9 0\n"               // int32(0)

-        "%14 = OpConstant %9 2\n"               // int32(2)

-        "%15 = OpConstant %10 0\n"              // uint32(0)

-        "%16 = OpTypeVector %10 3\n"            // vec4<int32>

-        "%17 = OpTypePointer Input %16\n"       // vec4<int32>*

-         "%2 = OpVariable %17 Input\n"          // gl_GlobalInvocationId

-        "%18 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

-        "%19 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%20 = OpLabel\n"

-        "%21 = OpAccessChain %18 %2 %15\n"      // &gl_GlobalInvocationId.x

-        "%22 = OpLoad %10 %21\n"                // gl_GlobalInvocationId.x

-        "%23 = OpAccessChain %19 %6 %13 %22\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%24 = OpLoad %9 %23\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%25 = OpAccessChain %19 %5 %13 %22\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %24 = in value

-        "%26 = OpSMod %9 %24 %14\n"             // in % 2

-              "OpSelectionMerge %27 None\n"

-              "OpSwitch %26 %27 0 %28 1 %29\n"

-        "%28 = OpLabel\n"                       // (in % 2) == 0

-              "OpBranch %27\n"

-        "%29 = OpLabel\n"                       // (in % 2) == 1

-              "OpBranch %27\n"

-        "%27 = OpLabel\n"

-    // %26 = out value

-    // End of branch logic

-              "OpStore %25 %26\n"               // use SSA value from previous block

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i % 2; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchStore)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-        "%11 = OpTypeBool\n"

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

-        "%13 = OpConstant %9 0\n"               // int32(0)

-        "%14 = OpConstant %9 1\n"               // int32(1)

-        "%15 = OpConstant %9 2\n"               // int32(2)

-        "%16 = OpConstant %10 0\n"              // uint32(0)

-        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

-        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

-         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

-        "%19 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

-        "%20 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%21 = OpLabel\n"

-        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

-        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

-        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %25 = in value

-        "%27 = OpSMod %9 %25 %15\n"             // in % 2

-              "OpSelectionMerge %28 None\n"

-              "OpSwitch %27 %28 0 %29 1 %30\n"

-        "%29 = OpLabel\n"                       // (in % 2) == 0

-              "OpStore %26 %15\n"               // write 2

-              "OpBranch %28\n"

-        "%30 = OpLabel\n"                       // (in % 2) == 1

-              "OpStore %26 %14\n"               // write 1

-              "OpBranch %28\n"

-        "%28 = OpLabel\n"

-    // End of branch logic

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 2 : 1; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseReturn)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-        "%11 = OpTypeBool\n"

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

-        "%13 = OpConstant %9 0\n"               // int32(0)

-        "%14 = OpConstant %9 1\n"               // int32(1)

-        "%15 = OpConstant %9 2\n"               // int32(2)

-        "%16 = OpConstant %10 0\n"              // uint32(0)

-        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

-        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

-         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

-        "%19 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

-        "%20 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%21 = OpLabel\n"

-        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

-        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

-        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %25 = in value

-        "%27 = OpSMod %9 %25 %15\n"             // in % 2

-              "OpSelectionMerge %28 None\n"

-              "OpSwitch %27 %28 0 %29 1 %30\n"

-        "%29 = OpLabel\n"                       // (in % 2) == 0

-              "OpBranch %28\n"

-        "%30 = OpLabel\n"                       // (in % 2) == 1

-              "OpReturn\n"

-        "%28 = OpLabel\n"

-              "OpStore %26 %14\n"               // write 1

-    // End of branch logic

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 0 : 1; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultReturn)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-        "%11 = OpTypeBool\n"

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

-        "%13 = OpConstant %9 0\n"               // int32(0)

-        "%14 = OpConstant %9 1\n"               // int32(1)

-        "%15 = OpConstant %9 2\n"               // int32(2)

-        "%16 = OpConstant %10 0\n"              // uint32(0)

-        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

-        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

-         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

-        "%19 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

-        "%20 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%21 = OpLabel\n"

-        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

-        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

-        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %25 = in value

-        "%27 = OpSMod %9 %25 %15\n"             // in % 2

-              "OpSelectionMerge %28 None\n"

-              "OpSwitch %27 %29 1 %30\n"

-        "%30 = OpLabel\n"                       // (in % 2) == 1

-              "OpBranch %28\n"

-        "%29 = OpLabel\n"                       // (in % 2) != 1

-              "OpReturn\n"

-        "%28 = OpLabel\n"                       // merge

-              "OpStore %26 %14\n"               // write 1

-    // End of branch logic

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 0; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseFallthrough)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-        "%11 = OpTypeBool\n"

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

-        "%13 = OpConstant %9 0\n"               // int32(0)

-        "%14 = OpConstant %9 1\n"               // int32(1)

-        "%15 = OpConstant %9 2\n"               // int32(2)

-        "%16 = OpConstant %10 0\n"              // uint32(0)

-        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

-        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

-         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

-        "%19 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

-        "%20 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%21 = OpLabel\n"

-        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

-        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

-        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %25 = in value

-        "%27 = OpSMod %9 %25 %15\n"             // in % 2

-              "OpSelectionMerge %28 None\n"

-              "OpSwitch %27 %29 0 %30 1 %31\n"

-        "%30 = OpLabel\n"                       // (in % 2) == 0

-        "%32 = OpIAdd %9 %27 %14\n"             // generate an intermediate

-              "OpStore %26 %32\n"               // write a value (overwritten later)

-              "OpBranch %31\n"                  // fallthrough

-        "%31 = OpLabel\n"                       // (in % 2) == 1

-              "OpStore %26 %15\n"               // write 2

-              "OpBranch %28\n"

-        "%29 = OpLabel\n"                       // unreachable

-              "OpUnreachable\n"

-        "%28 = OpLabel\n"                       // merge

-    // End of branch logic

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultFallthrough)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-        "%11 = OpTypeBool\n"

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

-        "%13 = OpConstant %9 0\n"               // int32(0)

-        "%14 = OpConstant %9 1\n"               // int32(1)

-        "%15 = OpConstant %9 2\n"               // int32(2)

-        "%16 = OpConstant %10 0\n"              // uint32(0)

-        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

-        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

-         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

-        "%19 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

-        "%20 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%21 = OpLabel\n"

-        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

-        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

-        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %25 = in value

-        "%27 = OpSMod %9 %25 %15\n"             // in % 2

-              "OpSelectionMerge %28 None\n"

-              "OpSwitch %27 %29 0 %30 1 %31\n"

-        "%30 = OpLabel\n"                       // (in % 2) == 0

-        "%32 = OpIAdd %9 %27 %14\n"             // generate an intermediate

-              "OpStore %26 %32\n"               // write a value (overwritten later)

-              "OpBranch %29\n"                  // fallthrough

-        "%29 = OpLabel\n"                       // default

-        "%33 = OpIAdd %9 %27 %14\n"             // generate an intermediate

-              "OpStore %26 %33\n"               // write a value (overwritten later)

-              "OpBranch %31\n"                  // fallthrough

-        "%31 = OpLabel\n"                       // (in % 2) == 1

-              "OpStore %26 %15\n"               // write 2

-              "OpBranch %28\n"

-        "%28 = OpLabel\n"                       // merge

-    // End of branch logic

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchPhi)

-{

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                GetParam().localSizeX << " " <<

-                GetParam().localSizeY << " " <<

-                GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-         "%7 = OpTypeVoid\n"

-         "%8 = OpTypeFunction %7\n"             // void()

-         "%9 = OpTypeInt 32 1\n"                // int32

-        "%10 = OpTypeInt 32 0\n"                // uint32

-        "%11 = OpTypeBool\n"

-         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

-         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

-        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

-         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

-        "%13 = OpConstant %9 0\n"               // int32(0)

-        "%14 = OpConstant %9 1\n"               // int32(1)

-        "%15 = OpConstant %9 2\n"               // int32(2)

-        "%16 = OpConstant %10 0\n"              // uint32(0)

-        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

-        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

-         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

-        "%19 = OpTypePointer Input %10\n"       // uint32*

-         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

-        "%20 = OpTypePointer Uniform %9\n"      // int32*

-         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

-        "%21 = OpLabel\n"

-        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

-        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

-        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

-        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

-        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

-    // Start of branch logic

-    // %25 = in value

-        "%27 = OpSMod %9 %25 %15\n"             // in % 2

-              "OpSelectionMerge %28 None\n"

-              "OpSwitch %27 %29 1 %30\n"

-        "%30 = OpLabel\n"                       // (in % 2) == 1

-              "OpBranch %28\n"

-        "%29 = OpLabel\n"                       // (in % 2) != 1

-              "OpBranch %28\n"

-        "%28 = OpLabel\n"                       // merge

-        "%31 = OpPhi %9 %14 %30 %15 %29\n"      // (in % 2) == 1 ? 1 : 2

-              "OpStore %26 %31\n"

-    // End of branch logic

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 2; });

-}

-

-TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, LoopDivergentMergePhi)

-{

-	// #version 450

-	// layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

-	// layout(binding = 0, std430) buffer InBuffer

-	// {

-	//     int Data[];

-	// } In;

-	// layout(binding = 1, std430) buffer OutBuffer

-	// {

-	//     int Data[];

-	// } Out;

-	// void main()

-	// {

-	//     int phi = 0;

-	//     uint lane = gl_GlobalInvocationID.x % 4;

-	//     for (uint i = 0; i < 4; i++)

-	//     {

-	//         if (lane == i)

-	//         {

-	//             phi = In.Data[gl_GlobalInvocationID.x];

-	//             break;

-	//         }

-	//     }

-	//     Out.Data[gl_GlobalInvocationID.x] = phi;

-	// }

-	std::stringstream src;

-	// clang-format off

-    src <<

-              "OpCapability Shader\n"

-         "%1 = OpExtInstImport \"GLSL.std.450\"\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %2 \"main\" %3\n"

-              "OpExecutionMode %2 LocalSize " <<

-                              GetParam().localSizeX << " " <<

-                              GetParam().localSizeY << " " <<

-                              GetParam().localSizeZ << "\n" <<

-              "OpDecorate %3 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %4 ArrayStride 4\n"

-              "OpMemberDecorate %5 0 Offset 0\n"

-              "OpDecorate %5 BufferBlock\n"

-              "OpDecorate %6 DescriptorSet 0\n"

-              "OpDecorate %6 Binding 0\n"

-              "OpDecorate %7 ArrayStride 4\n"

-              "OpMemberDecorate %8 0 Offset 0\n"

-              "OpDecorate %8 BufferBlock\n"

-              "OpDecorate %9 DescriptorSet 0\n"

-              "OpDecorate %9 Binding 1\n"

-        "%10 = OpTypeVoid\n"

-        "%11 = OpTypeFunction %10\n"

-        "%12 = OpTypeInt 32 1\n"

-        "%13 = OpConstant %12 0\n"

-        "%14 = OpTypeInt 32 0\n"

-        "%15 = OpTypeVector %14 3\n"

-        "%16 = OpTypePointer Input %15\n"

-         "%3 = OpVariable %16 Input\n"

-        "%17 = OpConstant %14 0\n"

-        "%18 = OpTypePointer Input %14\n"

-        "%19 = OpConstant %14 4\n"

-        "%20 = OpTypeBool\n"

-         "%4 = OpTypeRuntimeArray %12\n"

-         "%5 = OpTypeStruct %4\n"

-        "%21 = OpTypePointer Uniform %5\n"

-         "%6 = OpVariable %21 Uniform\n"

-        "%22 = OpTypePointer Uniform %12\n"

-        "%23 = OpConstant %12 1\n"

-         "%7 = OpTypeRuntimeArray %12\n"

-         "%8 = OpTypeStruct %7\n"

-        "%24 = OpTypePointer Uniform %8\n"

-         "%9 = OpVariable %24 Uniform\n"

-         "%2 = OpFunction %10 None %11\n"

-        "%25 = OpLabel\n"

-        "%26 = OpAccessChain %18 %3 %17\n"

-        "%27 = OpLoad %14 %26\n"

-        "%28 = OpUMod %14 %27 %19\n"

-              "OpBranch %29\n"

-        "%29 = OpLabel\n"

-        "%30 = OpPhi %14 %17 %25 %31 %32\n"

-        "%33 = OpULessThan %20 %30 %19\n"

-              "OpLoopMerge %34 %32 None\n"

-              "OpBranchConditional %33 %35 %34\n"

-        "%35 = OpLabel\n"

-        "%36 = OpIEqual %20 %28 %30\n"

-              "OpSelectionMerge %37 None\n"

-              "OpBranchConditional %36 %38 %37\n"

-        "%38 = OpLabel\n"

-        "%39 = OpAccessChain %22 %6 %13 %27\n"

-        "%40 = OpLoad %12 %39\n"

-              "OpBranch %34\n"

-        "%37 = OpLabel\n"

-              "OpBranch %32\n"

-        "%32 = OpLabel\n"

-        "%31 = OpIAdd %14 %30 %23\n"

-              "OpBranch %29\n"

-        "%34 = OpLabel\n"

-        "%41 = OpPhi %12 %13 %29 %40 %38\n" // %40: phi

-        "%42 = OpAccessChain %22 %9 %13 %27\n"

-              "OpStore %42 %41\n"

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-	// clang-format on

-

-	test(

-	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });

-}

diff --git a/tests/VulkanWrapper/CMakeLists.txt b/tests/VulkanWrapper/CMakeLists.txt
index b7851c7..4cf4479 100644
--- a/tests/VulkanWrapper/CMakeLists.txt
+++ b/tests/VulkanWrapper/CMakeLists.txt
@@ -17,20 +17,24 @@
 )
 
 set(VULKAN_WRAPPER_SRC_FILES
-    Buffer.hpp
-    Framebuffer.hpp
-    Image.hpp
-    Swapchain.hpp
-    Util.hpp
-    VulkanHeaders.hpp
-    Window.hpp
     Buffer.cpp
+    Buffer.hpp
+    DrawTester.cpp
+    DrawTester.hpp
     Framebuffer.cpp
+    Framebuffer.hpp
     Image.cpp
+    Image.hpp
     Swapchain.cpp
+    Swapchain.hpp
     Util.cpp
+    Util.hpp
     VulkanHeaders.cpp
+    VulkanHeaders.hpp
+    VulkanTester.cpp
+    VulkanTester.hpp
     Window.cpp
+    Window.hpp
 )
 
 add_library(VulkanWrapper STATIC
@@ -50,7 +54,7 @@
 endif()
 
 set_target_properties(VulkanWrapper PROPERTIES
-    FOLDER "Benchmarks"
+    FOLDER "Tests"
     RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}"
 )
 
@@ -60,6 +64,11 @@
         "${SWIFTSHADER_DIR}/include"
 )
 
+target_compile_definitions(VulkanWrapper
+    PUBLIC
+        "STANDALONE"
+)
+
 target_compile_options(VulkanWrapper
     PRIVATE
         ${ROOT_PROJECT_COMPILE_OPTIONS}
diff --git a/tests/VulkanBenchmarks/DrawBenchmark.cpp b/tests/VulkanWrapper/DrawTester.cpp
similarity index 92%
rename from tests/VulkanBenchmarks/DrawBenchmark.cpp
rename to tests/VulkanWrapper/DrawTester.cpp
index 567e542..2e2e29e 100644
--- a/tests/VulkanBenchmarks/DrawBenchmark.cpp
+++ b/tests/VulkanWrapper/DrawTester.cpp
@@ -12,16 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "DrawBenchmark.hpp"
+#include "DrawTester.hpp"
 
 #define ARRAY_SIZE(arr) (sizeof(arr) / sizeof(arr[0]))
 
-DrawBenchmark::DrawBenchmark(Multisample multisample)
+DrawTester::DrawTester(Multisample multisample)
     : multisample(multisample == Multisample::True)
 {
 }
 
-DrawBenchmark::~DrawBenchmark()
+DrawTester::~DrawTester()
 {
 	device.freeCommandBuffers(commandPool, commandBuffers);
 
@@ -59,9 +59,9 @@
 	window.reset();
 }
 
-void DrawBenchmark::initialize()
+void DrawTester::initialize()
 {
-	VulkanBenchmark::initialize();
+	VulkanTester::initialize();
 
 	window.reset(new Window(instance, windowSize));
 	swapchain.reset(new Swapchain(physicalDevice, device, *window));
@@ -78,7 +78,7 @@
 	createCommandBuffers(renderPass);
 }
 
-void DrawBenchmark::renderFrame()
+void DrawTester::renderFrame()
 {
 	swapchain->acquireNextImage(presentCompleteSemaphore, currentFrameBuffer);
 
@@ -101,12 +101,12 @@
 	swapchain->queuePresent(queue, currentFrameBuffer, renderCompleteSemaphore);
 }
 
-void DrawBenchmark::show()
+void DrawTester::show()
 {
 	window->show();
 }
 
-vk::RenderPass DrawBenchmark::createRenderPass(vk::Format colorFormat)
+vk::RenderPass DrawTester::createRenderPass(vk::Format colorFormat)
 {
 	std::vector<vk::AttachmentDescription> attachments(multisample ? 2 : 1);
 
@@ -187,7 +187,7 @@
 	return device.createRenderPass(renderPassInfo);
 }
 
-void DrawBenchmark::createFramebuffers(vk::RenderPass renderPass)
+void DrawTester::createFramebuffers(vk::RenderPass renderPass)
 {
 	framebuffers.resize(swapchain->imageCount());
 
@@ -197,14 +197,14 @@
 	}
 }
 
-void DrawBenchmark::prepareVertices()
+void DrawTester::prepareVertices()
 {
-	doCreateVertexBuffers();
+	hooks.createVertexBuffers(*this);
 }
 
-vk::Pipeline DrawBenchmark::createGraphicsPipeline(vk::RenderPass renderPass)
+vk::Pipeline DrawTester::createGraphicsPipeline(vk::RenderPass renderPass)
 {
-	auto setLayoutBindings = doCreateDescriptorSetLayouts();
+	auto setLayoutBindings = hooks.createDescriptorSetLayout(*this);
 
 	std::vector<vk::DescriptorSetLayout> setLayouts;
 	if(!setLayoutBindings.empty())
@@ -271,8 +271,8 @@
 	multisampleState.rasterizationSamples = multisample ? vk::SampleCountFlagBits::e4 : vk::SampleCountFlagBits::e1;
 	multisampleState.pSampleMask = nullptr;
 
-	vk::ShaderModule vertexModule = doCreateVertexShader();
-	vk::ShaderModule fragmentModule = doCreateFragmentShader();
+	vk::ShaderModule vertexModule = hooks.createVertexShader(*this);
+	vk::ShaderModule fragmentModule = hooks.createFragmentShader(*this);
 
 	assert(vertexModule);    // TODO: if nullptr, use a default
 	assert(fragmentModule);  // TODO: if nullptr, use a default
@@ -307,7 +307,7 @@
 	return pipeline;
 }
 
-void DrawBenchmark::createSynchronizationPrimitives()
+void DrawTester::createSynchronizationPrimitives()
 {
 	vk::SemaphoreCreateInfo semaphoreCreateInfo;
 	presentCompleteSemaphore = device.createSemaphore(semaphoreCreateInfo);
@@ -322,7 +322,7 @@
 	}
 }
 
-void DrawBenchmark::createCommandBuffers(vk::RenderPass renderPass)
+void DrawTester::createCommandBuffers(vk::RenderPass renderPass)
 {
 	vk::CommandPoolCreateInfo commandPoolCreateInfo;
 	commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
@@ -351,7 +351,7 @@
 
 		descriptorSets = device.allocateDescriptorSets(allocInfo);
 
-		doUpdateDescriptorSet(commandPool, descriptorSets[0]);
+		hooks.updateDescriptorSet(*this, commandPool, descriptorSets[0]);
 	}
 
 	vk::CommandBufferAllocateInfo commandBufferAllocateInfo;
@@ -402,7 +402,7 @@
 	}
 }
 
-void DrawBenchmark::addVertexBuffer(void *vertexBufferData, size_t vertexBufferDataSize, size_t vertexSize, std::vector<vk::VertexInputAttributeDescription> inputAttributes)
+void DrawTester::addVertexBuffer(void *vertexBufferData, size_t vertexBufferDataSize, size_t vertexSize, std::vector<vk::VertexInputAttributeDescription> inputAttributes)
 {
 	assert(!vertices.buffer);  // For now, only support adding once
 
@@ -437,7 +437,7 @@
 	vertices.numVertices = static_cast<uint32_t>(vertexBufferDataSize / vertexSize);
 }
 
-vk::ShaderModule DrawBenchmark::createShaderModule(const char *glslSource, EShLanguage glslLanguage)
+vk::ShaderModule DrawTester::createShaderModule(const char *glslSource, EShLanguage glslLanguage)
 {
 	auto spirv = Util::compileGLSLtoSPIRV(glslSource, glslLanguage);
 
diff --git a/tests/VulkanWrapper/DrawTester.hpp b/tests/VulkanWrapper/DrawTester.hpp
new file mode 100644
index 0000000..431a6c5
--- /dev/null
+++ b/tests/VulkanWrapper/DrawTester.hpp
@@ -0,0 +1,197 @@
+// Copyright 2021 The SwiftShader Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DRAW_TESTER_HPP_
+#define DRAW_TESTER_HPP_
+
+#include "Framebuffer.hpp"
+#include "Image.hpp"
+#include "Swapchain.hpp"
+#include "Util.hpp"
+#include "VulkanTester.hpp"
+#include "Window.hpp"
+
+enum class Multisample
+{
+	False,
+	True
+};
+
+class DrawTester : public VulkanTester
+{
+public:
+	using ThisType = DrawTester;
+
+	DrawTester(Multisample multisample = Multisample::False);
+	~DrawTester();
+
+	void initialize();
+	void renderFrame();
+	void show();
+
+	/////////////////////////
+	// Hooks
+	/////////////////////////
+
+	// Called from prepareVertices.
+	// Callback may call tester.addVertexBuffer() from this function.
+	void onCreateVertexBuffers(std::function<void(ThisType &tester)> callback);
+
+	// Called from createGraphicsPipeline.
+	// Callback must return vector of DescriptorSetLayoutBindings for which a DescriptorSetLayout
+	// will be created and stored in this->descriptorSetLayout.
+	void onCreateDescriptorSetLayouts(std::function<std::vector<vk::DescriptorSetLayoutBinding>(ThisType &tester)> callback);
+
+	// Called from createGraphicsPipeline.
+	// Callback should call tester.createShaderModule() and return the result.
+	void onCreateVertexShader(std::function<vk::ShaderModule(ThisType &tester)> callback);
+
+	// Called from createGraphicsPipeline.
+	// Callback should call tester.createShaderModule() and return the result.
+	void onCreateFragmentShader(std::function<vk::ShaderModule(ThisType &tester)> callback);
+
+	// Called from createCommandBuffers.
+	// Callback may create resources (tester.addImage, tester.addSampler, etc.), and make sure to
+	// call tester.device().updateDescriptorSets.
+	void onUpdateDescriptorSet(std::function<void(ThisType &tester, vk::CommandPool &commandPool, vk::DescriptorSet &descriptorSet)> callback);
+
+	/////////////////////////
+	// Resource Management
+	/////////////////////////
+
+	// Call from doCreateFragmentShader()
+	vk::ShaderModule createShaderModule(const char *glslSource, EShLanguage glslLanguage);
+
+	// Call from doCreateVertexBuffers()
+	template<typename VertexType>
+	void addVertexBuffer(VertexType *vertexBufferData, size_t vertexBufferDataSize, std::vector<vk::VertexInputAttributeDescription> inputAttributes)
+	{
+		addVertexBuffer(vertexBufferData, vertexBufferDataSize, sizeof(VertexType), std::move(inputAttributes));
+	}
+
+	template<typename T>
+	struct Resource
+	{
+		size_t id;
+		T &obj;
+	};
+
+	template<typename... Args>
+	Resource<Image> addImage(Args &&... args)
+	{
+		images.emplace_back(std::make_unique<Image>(std::forward<Args>(args)...));
+		return { images.size() - 1, *images.back() };
+	}
+
+	Image &getImageById(size_t id)
+	{
+		return *images[id].get();
+	}
+
+	Resource<vk::Sampler> addSampler(const vk::SamplerCreateInfo &samplerCreateInfo)
+	{
+		auto sampler = device.createSampler(samplerCreateInfo);
+		samplers.push_back(sampler);
+		return { samplers.size() - 1, samplers.back() };
+	}
+
+	vk::Sampler &getSamplerById(size_t id)
+	{
+		return samplers[id];
+	}
+
+private:
+	void createSynchronizationPrimitives();
+	void createCommandBuffers(vk::RenderPass renderPass);
+	void prepareVertices();
+	void createFramebuffers(vk::RenderPass renderPass);
+	vk::RenderPass createRenderPass(vk::Format colorFormat);
+	vk::Pipeline createGraphicsPipeline(vk::RenderPass renderPass);
+	void addVertexBuffer(void *vertexBufferData, size_t vertexBufferDataSize, size_t vertexSize, std::vector<vk::VertexInputAttributeDescription> inputAttributes);
+
+	struct Hook
+	{
+		std::function<void(ThisType &tester)> createVertexBuffers = [](auto &) {};
+		std::function<std::vector<vk::DescriptorSetLayoutBinding>(ThisType &tester)> createDescriptorSetLayout = [](auto &) { return std::vector<vk::DescriptorSetLayoutBinding>{}; };
+		std::function<vk::ShaderModule(ThisType &tester)> createVertexShader = [](auto &) { return vk::ShaderModule{}; };
+		std::function<vk::ShaderModule(ThisType &tester)> createFragmentShader = [](auto &) { return vk::ShaderModule{}; };
+		std::function<void(ThisType &tester, vk::CommandPool &commandPool, vk::DescriptorSet &descriptorSet)> updateDescriptorSet = [](auto &, auto &, auto &) {};
+	} hooks;
+
+	const vk::Extent2D windowSize = { 1280, 720 };
+	const bool multisample;
+
+	std::unique_ptr<Window> window;
+	std::unique_ptr<Swapchain> swapchain;
+
+	vk::RenderPass renderPass;  // Owning handle
+	std::vector<std::unique_ptr<Framebuffer>> framebuffers;
+	uint32_t currentFrameBuffer = 0;
+
+	struct VertexBuffer
+	{
+		vk::Buffer buffer;        // Owning handle
+		vk::DeviceMemory memory;  // Owning handle
+
+		vk::VertexInputBindingDescription inputBinding;
+		std::vector<vk::VertexInputAttributeDescription> inputAttributes;
+		vk::PipelineVertexInputStateCreateInfo inputState;
+
+		uint32_t numVertices;
+	} vertices;
+
+	vk::DescriptorSetLayout descriptorSetLayout;  // Owning handle
+	vk::PipelineLayout pipelineLayout;            // Owning handle
+	vk::Pipeline pipeline;                        // Owning handle
+
+	vk::Semaphore presentCompleteSemaphore;  // Owning handle
+	vk::Semaphore renderCompleteSemaphore;   // Owning handle
+	std::vector<vk::Fence> waitFences;       // Owning handles
+
+	vk::CommandPool commandPool;        // Owning handle
+	vk::DescriptorPool descriptorPool;  // Owning handle
+
+	// Resources
+	std::vector<std::unique_ptr<Image>> images;
+	std::vector<vk::Sampler> samplers;  // Owning handles
+
+	std::vector<vk::CommandBuffer> commandBuffers;  // Owning handles
+};
+
+inline void DrawTester::onCreateVertexBuffers(std::function<void(ThisType &tester)> callback)
+{
+	hooks.createVertexBuffers = std::move(callback);
+}
+
+inline void DrawTester::onCreateDescriptorSetLayouts(std::function<std::vector<vk::DescriptorSetLayoutBinding>(ThisType &tester)> callback)
+{
+	hooks.createDescriptorSetLayout = std::move(callback);
+}
+
+inline void DrawTester::onCreateVertexShader(std::function<vk::ShaderModule(ThisType &tester)> callback)
+{
+	hooks.createVertexShader = std::move(callback);
+}
+
+inline void DrawTester::onCreateFragmentShader(std::function<vk::ShaderModule(ThisType &tester)> callback)
+{
+	hooks.createFragmentShader = std::move(callback);
+}
+
+inline void DrawTester::onUpdateDescriptorSet(std::function<void(ThisType &tester, vk::CommandPool &commandPool, vk::DescriptorSet &descriptorSet)> callback)
+{
+	hooks.updateDescriptorSet = std::move(callback);
+}
+
+#endif  // DRAW_TESTER_HPP_
diff --git a/tests/VulkanWrapper/Swapchain.cpp b/tests/VulkanWrapper/Swapchain.cpp
index 843fc63..772f489 100644
--- a/tests/VulkanWrapper/Swapchain.cpp
+++ b/tests/VulkanWrapper/Swapchain.cpp
@@ -70,7 +70,7 @@
 	device.destroySwapchainKHR(swapchain, nullptr);
 }
 
-void Swapchain::acquireNextImage(VkSemaphore presentCompleteSemaphore, uint32_t &imageIndex)
+void Swapchain::acquireNextImage(vk::Semaphore presentCompleteSemaphore, uint32_t &imageIndex)
 {
 	auto result = device.acquireNextImageKHR(swapchain, UINT64_MAX, presentCompleteSemaphore, vk::Fence());
 	imageIndex = result.value;
diff --git a/tests/VulkanWrapper/Swapchain.hpp b/tests/VulkanWrapper/Swapchain.hpp
index d7ab6fc..750b056 100644
--- a/tests/VulkanWrapper/Swapchain.hpp
+++ b/tests/VulkanWrapper/Swapchain.hpp
@@ -26,7 +26,7 @@
 	Swapchain(vk::PhysicalDevice physicalDevice, vk::Device device, Window &window);
 	~Swapchain();
 
-	void acquireNextImage(VkSemaphore presentCompleteSemaphore, uint32_t &imageIndex);
+	void acquireNextImage(vk::Semaphore presentCompleteSemaphore, uint32_t &imageIndex);
 	void queuePresent(vk::Queue queue, uint32_t imageIndex, vk::Semaphore waitSemaphore);
 
 	size_t imageCount() const
diff --git a/tests/VulkanWrapper/Util.cpp b/tests/VulkanWrapper/Util.cpp
index af2980c..37b8a9b 100644
--- a/tests/VulkanWrapper/Util.cpp
+++ b/tests/VulkanWrapper/Util.cpp
@@ -157,7 +157,7 @@
 	{
 		std::string debugLog = glslangShader->getInfoDebugLog();
 		std::string infoLog = glslangShader->getInfoLog();
-		assert(false);
+		assert(false && "Failed to parse shader");
 	}
 
 	glslang::TIntermediate *intermediateRepresentation = glslangShader->getIntermediate();
diff --git a/tests/VulkanWrapper/VulkanTester.cpp b/tests/VulkanWrapper/VulkanTester.cpp
new file mode 100644
index 0000000..d3ad036
--- /dev/null
+++ b/tests/VulkanWrapper/VulkanTester.cpp
@@ -0,0 +1,138 @@
+// Copyright 2021 The SwiftShader Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "VulkanTester.hpp"
+#include <fstream>
+
+#if defined(_WIN32)
+#	define OS_WINDOWS 1
+#elif defined(__APPLE__)
+#	define OS_MAC 1
+#elif defined(__ANDROID__)
+#	define OS_ANDROID 1
+#elif defined(__linux__)
+#	define OS_LINUX 1
+#elif defined(__Fuchsia__)
+#	define OS_FUCHSIA 1
+#else
+#	error Unimplemented platform
+#endif
+
+namespace {
+std::vector<const char *> getDriverPaths()
+{
+#if OS_WINDOWS
+#	if !defined(STANDALONE)
+// The DLL is delay loaded (see BUILD.gn), so we can load
+// the correct ones from Chrome's swiftshader subdirectory.
+// HMODULE libvulkan = LoadLibraryA("swiftshader\\libvulkan.dll");
+// EXPECT_NE((HMODULE)NULL, libvulkan);
+// return true;
+#		error TODO: !STANDALONE
+#	elif defined(NDEBUG)
+#		if defined(_WIN64)
+	return { "./build/Release_x64/vk_swiftshader.dll",
+		     "./build/Release/vk_swiftshader.dll",
+		     "./vk_swiftshader.dll" };
+#		else
+	return { "./build/Release_Win32/vk_swiftshader.dll",
+		     "./build/Release/vk_swiftshader.dll",
+		     "./vk_swiftshader.dll" };
+#		endif
+#	else
+#		if defined(_WIN64)
+	return { "./build/Debug_x64/vk_swiftshader.dll",
+		     "./build/Debug/vk_swiftshader.dll",
+		     "./vk_swiftshader.dll" };
+#		else
+	return { "./build/Debug_Win32/vk_swiftshader.dll",
+		     "./build/Debug/vk_swiftshader.dll",
+		     "./vk_swiftshader.dll" };
+#		endif
+#	endif
+#elif OS_MAC
+	return { "./build/Darwin/libvk_swiftshader.dylib",
+		     "swiftshader/libvk_swiftshader.dylib",
+		     "libvk_swiftshader.dylib" };
+#elif OS_LINUX
+	return { "./build/Linux/libvk_swiftshader.so",
+		     "swiftshader/libvk_swiftshader.so",
+		     "./libvk_swiftshader.so",
+		     "libvk_swiftshader.so" };
+#elif OS_ANDROID || OS_FUCHSIA
+	return
+	{
+		"libvk_swiftshader.so"
+	}
+#else
+#	error Unimplemented platform
+	return {};
+#endif
+}
+
+bool fileExists(const char *path)
+{
+	std::ifstream f(path);
+	return f.good();
+}
+
+std::unique_ptr<vk::DynamicLoader> loadDriver()
+{
+	for(auto &p : getDriverPaths())
+	{
+		if(!fileExists(p))
+			continue;
+		return std::make_unique<vk::DynamicLoader>(p);
+	}
+	return {};
+}
+
+}  // namespace
+
+VulkanTester::~VulkanTester()
+{
+	device.waitIdle();
+	device.destroy(nullptr);
+	instance.destroy(nullptr);
+}
+
+void VulkanTester::initialize()
+{
+	dl = loadDriver();
+	assert(dl && dl->success());
+
+	PFN_vkGetInstanceProcAddr vkGetInstanceProcAddr = dl->getProcAddress<PFN_vkGetInstanceProcAddr>("vkGetInstanceProcAddr");
+	VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr);
+
+	instance = vk::createInstance({}, nullptr);
+	VULKAN_HPP_DEFAULT_DISPATCHER.init(instance);
+
+	std::vector<vk::PhysicalDevice> physicalDevices = instance.enumeratePhysicalDevices();
+	assert(!physicalDevices.empty());
+	physicalDevice = physicalDevices[0];
+
+	const float defaultQueuePriority = 0.0f;
+	vk::DeviceQueueCreateInfo queueCreateInfo;
+	queueCreateInfo.queueFamilyIndex = queueFamilyIndex;
+	queueCreateInfo.queueCount = 1;
+	queueCreateInfo.pQueuePriorities = &defaultQueuePriority;
+
+	vk::DeviceCreateInfo deviceCreateInfo;
+	deviceCreateInfo.queueCreateInfoCount = 1;
+	deviceCreateInfo.pQueueCreateInfos = &queueCreateInfo;
+
+	device = physicalDevice.createDevice(deviceCreateInfo, nullptr);
+
+	queue = device.getQueue(queueFamilyIndex, 0);
+}
diff --git a/tests/VulkanBenchmarks/VulkanBenchmark.hpp b/tests/VulkanWrapper/VulkanTester.hpp
similarity index 71%
rename from tests/VulkanBenchmarks/VulkanBenchmark.hpp
rename to tests/VulkanWrapper/VulkanTester.hpp
index f724bcc..89b7d05 100644
--- a/tests/VulkanBenchmarks/VulkanBenchmark.hpp
+++ b/tests/VulkanWrapper/VulkanTester.hpp
@@ -12,20 +12,25 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef VULKAN_BENCHMARK_HPP_
-#define VULKAN_BENCHMARK_HPP_
+#ifndef VULKAN_TESTER_HPP_
+#define VULKAN_TESTER_HPP_
 
 #include "VulkanHeaders.hpp"
 
-class VulkanBenchmark
+class VulkanTester
 {
 public:
-	VulkanBenchmark() = default;
-	virtual ~VulkanBenchmark();
+	VulkanTester() = default;
+	virtual ~VulkanTester();
 
 	// Call once after construction so that virtual functions may be called during init
 	void initialize();
 
+	const vk::DynamicLoader &dynamicLoader() const { return *dl; }
+	vk::Device &getDevice() { return this->device; }
+	vk::Queue &getQueue() { return this->queue; }
+	uint32_t getQueueFamilyIndex() const { return queueFamilyIndex; }
+
 private:
 	std::unique_ptr<vk::DynamicLoader> dl;
 
@@ -38,4 +43,4 @@
 	vk::Queue queue;
 };
 
-#endif  // VULKAN_BENCHMARK_HPP_
+#endif  // VULKAN_TESTER_HPP_