SpirvShader: Implement OpBranch

Bug: b/128527271
Change-Id: I367ed0d578e36a56baf4b8c4c2256ee4de1297cc
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/27097
Presubmit-Ready: Ben Clayton <bclayton@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Reviewed-by: Chris Forbes <chrisforbes@google.com>
Tested-by: Ben Clayton <bclayton@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 9f2b2d5..794f671 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -1185,6 +1185,10 @@
 			EmitAll(insn, routine);
 			break;
 
+		case spv::OpBranch:
+			EmitBranch(insn, routine);
+			break;
+
 		default:
 			UNIMPLEMENTED(OpcodeName(insn.opcode()).c_str());
 			break;
@@ -2148,6 +2152,12 @@
 		dst.emplace(0, result);
 	}
 
+	void SpirvShader::EmitBranch(InsnIterator insn, SpirvRoutine *routine) const
+	{
+		auto blockId = Block::ID(insn.word(1));
+		EmitBlock(routine, getBlock(blockId));
+	}
+
 	void SpirvShader::emitEpilog(SpirvRoutine *routine) const
 	{
 		for (auto insn : *this)
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 34c7951..ecfa91c 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -490,6 +490,7 @@
 		void EmitExtendedInstruction(InsnIterator insn, SpirvRoutine *routine) const;
 		void EmitAny(InsnIterator insn, SpirvRoutine *routine) const;
 		void EmitAll(InsnIterator insn, SpirvRoutine *routine) const;
+		void EmitBranch(InsnIterator insn, SpirvRoutine *routine) const;
 
 		// OpcodeName returns the name of the opcode op.
 		// If NDEBUG is defined, then OpcodeName will only return the numerical code.
diff --git a/tests/VulkanUnitTests/unittests.cpp b/tests/VulkanUnitTests/unittests.cpp
index 78eb489..bc57e41 100644
--- a/tests/VulkanUnitTests/unittests.cpp
+++ b/tests/VulkanUnitTests/unittests.cpp
@@ -149,82 +149,41 @@
 

 struct ComputeParams

 {

+    size_t numElements;

     int localSizeX;

     int localSizeY;

     int localSizeZ;

+

+    friend std::ostream& operator<<(std::ostream& os, const ComputeParams& params) {

+        return os << "ComputeParams{" <<

+            "numElements: " << params.numElements << ", " <<

+            "localSizeX: " << params.localSizeX << ", " <<

+            "localSizeY: " << params.localSizeY << ", " <<

+            "localSizeZ: " << params.localSizeZ <<

+            "}";

+    }

 };

 

-class SwiftShaderVulkanComputeTest : public testing::TestWithParam<ComputeParams> {};

-

-INSTANTIATE_TEST_CASE_P(ComputeParams, SwiftShaderVulkanComputeTest, testing::Values(

-    ComputeParams{1, 1, 1},

-    ComputeParams{2, 1, 1},

-    ComputeParams{4, 1, 1},

-    ComputeParams{8, 1, 1},

-    ComputeParams{16, 1, 1},

-    ComputeParams{32, 1, 1}

-));

-

-TEST_P(SwiftShaderVulkanComputeTest, Memcpy)

+// Base class for compute tests that read from an input buffer and write to an

+// output buffer of same length.

+class SwiftShaderVulkanBufferToBufferComputeTest : public testing::TestWithParam<ComputeParams>

 {

+public:

+    void test(const std::string& shader,

+        std::function<uint32_t(uint32_t idx)> input,

+        std::function<uint32_t(uint32_t idx)> expected);

+};

+

+void SwiftShaderVulkanBufferToBufferComputeTest::test(

+        const std::string& shader,

+        std::function<uint32_t(uint32_t idx)> input,

+        std::function<uint32_t(uint32_t idx)> expected)

+{

+    auto code = compileSpirv(shader.c_str());

+

     Driver driver;

     ASSERT_TRUE(driver.loadSwiftShader());

 

-    auto params = GetParam();

-

-    std::stringstream src;

-    src <<

-              "OpCapability Shader\n"

-              "OpMemoryModel Logical GLSL450\n"

-              "OpEntryPoint GLCompute %1 \"main\" %2\n"

-              "OpExecutionMode %1 LocalSize " <<

-                params.localSizeX << " " <<

-                params.localSizeY << " " <<

-                params.localSizeZ << "\n" <<

-              "OpDecorate %3 ArrayStride 4\n"

-              "OpMemberDecorate %4 0 Offset 0\n"

-              "OpDecorate %4 BufferBlock\n"

-              "OpDecorate %5 DescriptorSet 0\n"

-              "OpDecorate %5 Binding 1\n"

-              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

-              "OpDecorate %6 ArrayStride 4\n"

-              "OpMemberDecorate %7 0 Offset 0\n"

-              "OpDecorate %7 BufferBlock\n"

-              "OpDecorate %8 DescriptorSet 0\n"

-              "OpDecorate %8 Binding 0\n"

-         "%9 = OpTypeVoid\n"

-        "%10 = OpTypeFunction %9\n"

-        "%11 = OpTypeInt 32 1\n"

-         "%3 = OpTypeRuntimeArray %11\n"

-         "%4 = OpTypeStruct %3\n"

-        "%12 = OpTypePointer Uniform %4\n"

-         "%5 = OpVariable %12 Uniform\n"

-        "%13 = OpConstant %11 0\n"

-        "%14 = OpTypeInt 32 0\n"

-        "%15 = OpTypeVector %14 3\n"

-        "%16 = OpTypePointer Input %15\n"

-         "%2 = OpVariable %16 Input\n"

-        "%17 = OpConstant %14 0\n"

-        "%18 = OpTypePointer Input %14\n"

-         "%6 = OpTypeRuntimeArray %11\n"

-         "%7 = OpTypeStruct %6\n"

-        "%19 = OpTypePointer Uniform %7\n"

-         "%8 = OpVariable %19 Uniform\n"

-        "%20 = OpTypePointer Uniform %11\n"

-        "%21 = OpConstant %11 1\n"

-         "%1 = OpFunction %9 None %10\n"

-        "%22 = OpLabel\n"

-        "%23 = OpAccessChain %18 %2 %17\n"

-        "%24 = OpLoad %14 %23\n"

-        "%25 = OpAccessChain %20 %8 %13 %24\n"

-        "%26 = OpLoad %11 %25\n"

-        "%27 = OpAccessChain %20 %5 %13 %24\n"

-              "OpStore %27 %26\n"

-              "OpReturn\n"

-              "OpFunctionEnd\n";

-

-    auto code = compileSpirv(src.str().c_str());

-

     const VkInstanceCreateInfo createInfo = {

         VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,  // sType

         nullptr,                                 // pNext

@@ -245,48 +204,57 @@
     VK_ASSERT(Device::CreateComputeDevice(&driver, instance, &device));

     ASSERT_TRUE(device.IsValid());

 

-    constexpr int NUM_ELEMENTS = 256;

-

-    struct Buffers

-    {

-        uint32_t magic0;

-        uint32_t in[NUM_ELEMENTS];

-        uint32_t magic1;

-        uint32_t out[NUM_ELEMENTS];

-        uint32_t magic2;

-    };

-

-    constexpr uint32_t magic0 = 0x01234567;

-    constexpr uint32_t magic1 = 0x89abcdef;

-    constexpr uint32_t magic2 = 0xfedcba99;

+    // struct Buffers

+    // {

+    //     uint32_t magic0;

+    //     uint32_t in[NUM_ELEMENTS];

+    //     uint32_t magic1;

+    //     uint32_t out[NUM_ELEMENTS];

+    //     uint32_t magic2;

+    // };

+    static constexpr uint32_t magic0 = 0x01234567;

+    static constexpr uint32_t magic1 = 0x89abcdef;

+    static constexpr uint32_t magic2 = 0xfedcba99;

+    size_t numElements = GetParam().numElements;

+    size_t magic0Offset = 0;

+    size_t inOffset = 1 + magic0Offset;

+    size_t magic1Offset = numElements + inOffset;

+    size_t outOffset = 1 + magic1Offset;

+    size_t magic2Offset = numElements + outOffset;

+    size_t buffersTotalElements = 1 + magic2Offset;

+    size_t buffersSize = sizeof(uint32_t) * buffersTotalElements;

 

     VkDeviceMemory memory;

-    VK_ASSERT(device.AllocateMemory(sizeof(Buffers),

+    VK_ASSERT(device.AllocateMemory(buffersSize,

             VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,

             &memory));

 

-    Buffers* buffers;

-    VK_ASSERT(device.MapMemory(memory, 0, sizeof(Buffers), 0, (void**)&buffers));

+    uint32_t* buffers;

+    VK_ASSERT(device.MapMemory(memory, 0, buffersSize, 0, (void**)&buffers));

 

-    memset(buffers, 0, sizeof(Buffers));

+    buffers[magic0Offset] = magic0;

+    buffers[magic1Offset] = magic1;

+    buffers[magic2Offset] = magic2;

 

-    buffers->magic0 = magic0;

-    buffers->magic1 = magic1;

-    buffers->magic2 = magic2;

-

-    for(int i = 0; i < NUM_ELEMENTS; i++)

+    for(size_t i = 0; i < numElements; i++)

     {

-        buffers->in[i] = (uint32_t)i;

+        buffers[inOffset + i] = input(i);

     }

 

     device.UnmapMemory(memory);

     buffers = nullptr;

 

     VkBuffer bufferIn;

-    VK_ASSERT(device.CreateStorageBuffer(memory, sizeof(Buffers::in), offsetof(Buffers, in), &bufferIn));

+    VK_ASSERT(device.CreateStorageBuffer(memory,

+            sizeof(uint32_t) * numElements,

+            sizeof(uint32_t) * inOffset,

+            &bufferIn));

 

     VkBuffer bufferOut;

-    VK_ASSERT(device.CreateStorageBuffer(memory, sizeof(Buffers::out), offsetof(Buffers, out), &bufferOut));

+    VK_ASSERT(device.CreateStorageBuffer(memory,

+            sizeof(uint32_t) * numElements,

+            sizeof(uint32_t) * outOffset,

+            &bufferOut));

 

     VkShaderModule shaderModule;

     VK_ASSERT(device.CreateShaderModule(code, &shaderModule));

@@ -352,24 +320,206 @@
     driver.vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, 1, &descriptorSet,

                                    0, nullptr);

 

-    driver.vkCmdDispatch(commandBuffer, NUM_ELEMENTS / params.localSizeX, 1, 1);

+    driver.vkCmdDispatch(commandBuffer, numElements / GetParam().localSizeX, 1, 1);

 

     VK_ASSERT(driver.vkEndCommandBuffer(commandBuffer));

 

     VK_ASSERT(device.QueueSubmitAndWait(commandBuffer));

 

-    VK_ASSERT(device.MapMemory(memory, 0, sizeof(Buffers), 0, (void**)&buffers));

+    VK_ASSERT(device.MapMemory(memory, 0, buffersSize, 0, (void**)&buffers));

 

-    for (int i = 0; i < NUM_ELEMENTS; ++i)

+    for (size_t i = 0; i < numElements; ++i)

     {

-        EXPECT_EQ(buffers->in[i], buffers->out[i]) << "Unexpected output at " << i;

+        auto got = buffers[i + outOffset];

+        EXPECT_EQ(expected(i), got) << "Unexpected output at " << i;

     }

 

     // Check for writes outside of bounds.

-    EXPECT_EQ(buffers->magic0, magic0);

-    EXPECT_EQ(buffers->magic1, magic1);

-    EXPECT_EQ(buffers->magic2, magic2);

+    EXPECT_EQ(buffers[magic0Offset], magic0);

+    EXPECT_EQ(buffers[magic1Offset], magic1);

+    EXPECT_EQ(buffers[magic2Offset], magic2);

 

     device.UnmapMemory(memory);

     buffers = nullptr;

 }

+

+INSTANTIATE_TEST_CASE_P(ComputeParams, SwiftShaderVulkanBufferToBufferComputeTest, testing::Values(

+    ComputeParams{512, 1, 1, 1},

+    ComputeParams{512, 2, 1, 1},

+    ComputeParams{512, 4, 1, 1},

+    ComputeParams{512, 8, 1, 1},

+    ComputeParams{512, 16, 1, 1},

+    ComputeParams{512, 32, 1, 1},

+

+    // Non-multiple of SIMD-lane.

+    ComputeParams{3, 1, 1, 1},

+    ComputeParams{2, 1, 1, 1}

+));

+

+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, Memcpy)

+{

+    std::stringstream src;

+    src <<

+              "OpCapability Shader\n"

+              "OpMemoryModel Logical GLSL450\n"

+              "OpEntryPoint GLCompute %1 \"main\" %2\n"

+              "OpExecutionMode %1 LocalSize " <<

+                GetParam().localSizeX << " " <<

+                GetParam().localSizeY << " " <<

+                GetParam().localSizeZ << "\n" <<

+              "OpDecorate %3 ArrayStride 4\n"

+              "OpMemberDecorate %4 0 Offset 0\n"

+              "OpDecorate %4 BufferBlock\n"

+              "OpDecorate %5 DescriptorSet 0\n"

+              "OpDecorate %5 Binding 1\n"

+              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

+              "OpDecorate %6 DescriptorSet 0\n"

+              "OpDecorate %6 Binding 0\n"

+         "%7 = OpTypeVoid\n"

+         "%8 = OpTypeFunction %7\n"             // void()

+         "%9 = OpTypeInt 32 1\n"                // int32

+        "%10 = OpTypeInt 32 0\n"                // uint32

+         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

+         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

+        "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

+         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in

+        "%12 = OpConstant %9 0\n"               // int32(0)

+        "%13 = OpConstant %10 0\n"              // uint32(0)

+        "%14 = OpTypeVector %10 3\n"            // vec4<int32>

+        "%15 = OpTypePointer Input %14\n"       // vec4<int32>*

+         "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId

+        "%16 = OpTypePointer Input %10\n"       // uint32*

+         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out

+        "%17 = OpTypePointer Uniform %9\n"      // int32*

+         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

+        "%18 = OpLabel\n"

+        "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x

+        "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x

+        "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]

+        "%22 = OpLoad %9 %21\n"                 // out.arr[gl_GlobalInvocationId.x]

+        "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]

+              "OpStore %23 %22\n"               // out.arr[gl_GlobalInvocationId.x] = in[gl_GlobalInvocationId.x]

+              "OpReturn\n"

+              "OpFunctionEnd\n";

+

+    test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });

+}

+

+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchSimple)

+{

+    std::stringstream src;

+    src <<

+              "OpCapability Shader\n"

+              "OpMemoryModel Logical GLSL450\n"

+              "OpEntryPoint GLCompute %1 \"main\" %2\n"

+              "OpExecutionMode %1 LocalSize " <<

+                GetParam().localSizeX << " " <<

+                GetParam().localSizeY << " " <<

+                GetParam().localSizeZ << "\n" <<

+              "OpDecorate %3 ArrayStride 4\n"

+              "OpMemberDecorate %4 0 Offset 0\n"

+              "OpDecorate %4 BufferBlock\n"

+              "OpDecorate %5 DescriptorSet 0\n"

+              "OpDecorate %5 Binding 1\n"

+              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

+              "OpDecorate %6 DescriptorSet 0\n"

+              "OpDecorate %6 Binding 0\n"

+         "%7 = OpTypeVoid\n"

+         "%8 = OpTypeFunction %7\n"             // void()

+         "%9 = OpTypeInt 32 1\n"                // int32

+        "%10 = OpTypeInt 32 0\n"                // uint32

+         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

+         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

+        "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

+         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in

+        "%12 = OpConstant %9 0\n"               // int32(0)

+        "%13 = OpConstant %10 0\n"              // uint32(0)

+        "%14 = OpTypeVector %10 3\n"            // vec4<int32>

+        "%15 = OpTypePointer Input %14\n"       // vec4<int32>*

+         "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId

+        "%16 = OpTypePointer Input %10\n"       // uint32*

+         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out

+        "%17 = OpTypePointer Uniform %9\n"      // int32*

+         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

+        "%18 = OpLabel\n"

+        "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x

+        "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x

+        "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]

+        "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]

+        "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]

+    // Start of branch logic

+    // %22 = in value

+              "OpBranch %24\n"

+        "%24 = OpLabel\n"

+              "OpBranch %25\n"

+        "%25 = OpLabel\n"

+              "OpBranch %26\n"

+        "%26 = OpLabel\n"

+    // %22 = out value

+    // End of branch logic

+              "OpStore %23 %22\n"

+              "OpReturn\n"

+              "OpFunctionEnd\n";

+

+    test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });

+}

+

+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchDeclareSSA)

+{

+    std::stringstream src;

+    src <<

+              "OpCapability Shader\n"

+              "OpMemoryModel Logical GLSL450\n"

+              "OpEntryPoint GLCompute %1 \"main\" %2\n"

+              "OpExecutionMode %1 LocalSize " <<

+                GetParam().localSizeX << " " <<

+                GetParam().localSizeY << " " <<

+                GetParam().localSizeZ << "\n" <<

+              "OpDecorate %3 ArrayStride 4\n"

+              "OpMemberDecorate %4 0 Offset 0\n"

+              "OpDecorate %4 BufferBlock\n"

+              "OpDecorate %5 DescriptorSet 0\n"

+              "OpDecorate %5 Binding 1\n"

+              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

+              "OpDecorate %6 DescriptorSet 0\n"

+              "OpDecorate %6 Binding 0\n"

+         "%7 = OpTypeVoid\n"

+         "%8 = OpTypeFunction %7\n"             // void()

+         "%9 = OpTypeInt 32 1\n"                // int32

+        "%10 = OpTypeInt 32 0\n"                // uint32

+         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

+         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

+        "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

+         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in

+        "%12 = OpConstant %9 0\n"               // int32(0)

+        "%13 = OpConstant %10 0\n"              // uint32(0)

+        "%14 = OpTypeVector %10 3\n"            // vec4<int32>

+        "%15 = OpTypePointer Input %14\n"       // vec4<int32>*

+         "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId

+        "%16 = OpTypePointer Input %10\n"       // uint32*

+         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out

+        "%17 = OpTypePointer Uniform %9\n"      // int32*

+         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

+        "%18 = OpLabel\n"

+        "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x

+        "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x

+        "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]

+        "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]

+        "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]

+    // Start of branch logic

+    // %22 = in value

+              "OpBranch %24\n"

+        "%24 = OpLabel\n"

+        "%25 = OpIAdd %9 %22 %22\n"             // %25 = in*2

+              "OpBranch %26\n"

+        "%26 = OpLabel\n"

+              "OpBranch %27\n"

+        "%27 = OpLabel\n"

+    // %25 = out value

+    // End of branch logic

+              "OpStore %23 %25\n"               // use SSA value from previous block

+              "OpReturn\n"

+              "OpFunctionEnd\n";

+

+    test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i * 2; });

+}
\ No newline at end of file