SpirvShader: Correctly handle phi values in the loop merge

Yet another horrible phi/loop edge case (pun intended).

Added test.

Bug: b/133440380
Bug: b/133481698
Change-Id: I327842fa2d4314bce938454da81f67f890cf9e12
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/31845
Presubmit-Ready: Ben Clayton <bclayton@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
Tested-by: Ben Clayton <bclayton@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 2dcd7d4..3c48413 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -984,7 +984,7 @@
 		}
 
 		ASSERT_MSG(entryPointFunctionId != 0, "Entry point '%s' not found", createInfo->pName);
-		AssignBlockIns();
+		AssignBlockFields();
 	}
 
 	void SpirvShader::TraverseReachableBlocks(Block::ID id, SpirvShader::Block::Set& reachable)
@@ -999,7 +999,7 @@
 		}
 	}
 
-	void SpirvShader::AssignBlockIns()
+	void SpirvShader::AssignBlockFields()
 	{
 		Block::Set reachable;
 		TraverseReachableBlocks(entryPointBlockId, reachable);
@@ -1007,6 +1007,7 @@
 		for (auto &it : blocks)
 		{
 			auto &blockId = it.first;
+			auto &block = it.second;
 			if (reachable.count(blockId) > 0)
 			{
 				for (auto &outId : it.second.outs)
@@ -1016,6 +1017,12 @@
 					auto &out = outIt->second;
 					out.ins.emplace(blockId);
 				}
+				if (block.kind == Block::Loop)
+				{
+					auto mergeIt = blocks.find(block.mergeBlock);
+					ASSERT_MSG(mergeIt != blocks.end(), "Loop block %d has a non-existent merge block %d", blockId.value(), block.mergeBlock.value());
+					mergeIt->second.isLoopMerge = true;
+				}
 			}
 		}
 	}
@@ -2083,7 +2090,9 @@
 	void SpirvShader::EmitLoop(EmitState *state) const
 	{
 		auto blockId = state->currentBlock;
-		auto block = getBlock(blockId);
+		auto &block = getBlock(blockId);
+		auto mergeBlockId = block.mergeBlock;
+		auto &mergeBlock = getBlock(mergeBlockId);
 
 		// Ensure all incoming non-back edge blocks have been generated.
 		auto depsDone = true;
@@ -2091,7 +2100,7 @@
 		{
 			if (state->visited.count(in) == 0)
 			{
-				if (!existsPath(blockId, in, block.mergeBlock)) // if not a loop back edge
+				if (!existsPath(blockId, in, mergeBlockId)) // if not a loop back edge
 				{
 					state->pending->emplace(in);
 					depsDone = false;
@@ -2115,7 +2124,7 @@
 		std::unordered_set<Block::ID> loopBlocks;
 		for (auto in : block.ins)
 		{
-			if (!existsPath(blockId, in, block.mergeBlock)) // if not a loop back-edge
+			if (!existsPath(blockId, in, mergeBlockId)) // if not a loop back-edge
 			{
 				incomingBlocks.emplace(in);
 			}
@@ -2131,7 +2140,7 @@
 		{
 			if (insn.opcode() == spv::OpPhi)
 			{
-				StorePhi(insn, state, incomingBlocks);
+				StorePhi(blockId, insn, state, incomingBlocks);
 			}
 		}
 
@@ -2146,7 +2155,7 @@
 		// mergeActiveLaneMasks contains edge lane masks for the merge block.
 		// This is the union of all edge masks across all iterations of the loop.
 		std::unordered_map<Block::ID, SIMD::Int> mergeActiveLaneMasks;
-		for (auto in : getBlock(block.mergeBlock).ins)
+		for (auto in : getBlock(mergeBlockId).ins)
 		{
 			mergeActiveLaneMasks.emplace(in, SIMD::Int(0));
 		}
@@ -2179,7 +2188,7 @@
 		// don't emit the merge block yet.
 		for (auto out : block.outs)
 		{
-			EmitBlocks(out, state, block.mergeBlock);
+			EmitBlocks(out, state, mergeBlockId);
 		}
 
 		// Restore current block id after emitting loop blocks.
@@ -2189,16 +2198,16 @@
 		loopActiveLaneMask = SIMD::Int(0);
 		for (auto in : block.ins)
 		{
-			if (existsPath(blockId, in, block.mergeBlock))
+			if (existsPath(blockId, in, mergeBlockId))
 			{
 				loopActiveLaneMask |= GetActiveLaneMaskEdge(state, in, blockId);
 			}
 		}
 
 		// Add active lanes to the merge lane mask.
-		for (auto in : getBlock(block.mergeBlock).ins)
+		for (auto in : getBlock(mergeBlockId).ins)
 		{
-			auto edge = Block::Edge{in, block.mergeBlock};
+			auto edge = Block::Edge{in, mergeBlockId};
 			auto it = state->edgeActiveLaneMasks.find(edge);
 			if (it != state->edgeActiveLaneMasks.end())
 			{
@@ -2211,7 +2220,39 @@
 		{
 			if (insn.opcode() == spv::OpPhi)
 			{
-				StorePhi(insn, state, loopBlocks);
+				StorePhi(blockId, insn, state, loopBlocks);
+			}
+		}
+
+		// Use the [loop -> merge] active lane masks to update the phi values in
+		// the merge block. We need to do this to handle divergent control flow
+		// in the loop.
+		//
+		// Consider the following:
+		//
+		//     int phi_source = 0;
+		//     for (uint i = 0; i < 4; i++)
+		//     {
+		//         phi_source = 0;
+		//         if (gl_GlobalInvocationID.x % 4 == i) // divergent control flow
+		//         {
+		//             phi_source = 42; // single lane assignment.
+		//             break; // activeLaneMask for [loop->merge] is active for a single lane.
+		//         }
+		//         // -- we are here --
+		//     }
+		//     // merge block
+		//     int phi = phi_source; // OpPhi
+		//
+		// In this example, with each iteration of the loop, phi_source will
+		// only have a single lane assigned. However the 'phi' value in the merge
+		// block needs to be assigned the union of all the per-lane assignments
+		// of phi_source when that lane exited the loop.
+		for (auto insn = mergeBlock.begin(); insn != mergeBlock.end(); insn++)
+		{
+			if (insn.opcode() == spv::OpPhi)
+			{
+				StorePhi(mergeBlockId, insn, state, mergeBlock.ins);
 			}
 		}
 
@@ -2222,10 +2263,10 @@
 
 		// Continue emitting from the merge block.
 		Nucleus::setInsertBlock(mergeBasicBlock);
-		state->pending->emplace(block.mergeBlock);
+		state->pending->emplace(mergeBlockId);
 		for (auto it : mergeActiveLaneMasks)
 		{
-			state->addActiveLaneMaskEdge(it.first, block.mergeBlock, it.second);
+			state->addActiveLaneMaskEdge(it.first, mergeBlockId, it.second);
 		}
 	}
 
@@ -4618,7 +4659,13 @@
 	SpirvShader::EmitResult SpirvShader::EmitPhi(InsnIterator insn, EmitState *state) const
 	{
 		auto currentBlock = getBlock(state->currentBlock);
-		StorePhi(insn, state, currentBlock.ins);
+		if (!currentBlock.isLoopMerge)
+		{
+			// Only update the phi values from the ins when this is not a loop
+			// merge block. For loop merge blocks, EmitLoop() has already taken
+			// special care of the phis to correctly deal with divergent lanes.
+			StorePhi(state->currentBlock, insn, state, currentBlock.ins);
+		}
 		LoadPhi(insn, state);
 		return EmitResult::Continue;
 	}
@@ -4641,13 +4688,12 @@
 		}
 	}
 
-	void SpirvShader::StorePhi(InsnIterator insn, EmitState *state, std::unordered_set<SpirvShader::Block::ID> const& filter) const
+	void SpirvShader::StorePhi(Block::ID currentBlock, InsnIterator insn, EmitState *state, std::unordered_set<SpirvShader::Block::ID> const& filter) const
 	{
 		auto routine = state->routine;
 		auto typeId = Type::ID(insn.word(1));
 		auto type = getType(typeId);
 		auto objectId = Object::ID(insn.word(2));
-		auto currentBlock = getBlock(state->currentBlock);
 
 		auto storageIt = state->routine->phis.find(objectId);
 		ASSERT(storageIt != state->routine->phis.end());
@@ -4663,7 +4709,7 @@
 				continue;
 			}
 
-			auto mask = GetActiveLaneMaskEdge(state, blockId, state->currentBlock);
+			auto mask = GetActiveLaneMaskEdge(state, blockId, currentBlock);
 			auto in = GenericValue(this, routine, varId);
 
 			for (uint32_t i = 0; i < type.sizeInComponents; i++)
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index f7eaa1d..09cf15b 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -449,14 +449,14 @@
 				Loop, // OpLoopMerge + [OpBranchConditional | OpBranch]
 			};
 
-			Kind kind;
+			Kind kind = Simple;
 			InsnIterator mergeInstruction; // Structured control flow merge instruction.
 			InsnIterator branchInstruction; // Branch instruction.
 			ID mergeBlock; // Structured flow merge block.
 			ID continueTarget; // Loop continue block.
 			Set ins; // Blocks that branch into this block.
 			Set outs; // Blocks that this block branches to.
-
+			bool isLoopMerge = false;
 		private:
 			InsnIterator begin_;
 			InsnIterator end_;
@@ -743,8 +743,12 @@
 		// reachable.
 		void TraverseReachableBlocks(Block::ID id, Block::Set& reachable);
 
-		// Assigns Block::ins from Block::outs for every block.
-		void AssignBlockIns();
+		// AssignBlockFields() performs the following for all reachable blocks:
+		// * Assigns Block::ins with the identifiers of all blocks that contain
+		//   this block in their Block::outs.
+		// * Sets Block::isLoopMerge to true if the block is the merge block
+		//   of a loop block.
+		void AssignBlockFields();
 
 		// DeclareType creates a Type for the given OpTypeX instruction, storing
 		// it into the types map. It is called from the analysis pass (constructor).
@@ -974,7 +978,7 @@
 		// StorePhi updates the phi's alloca storage value using the incoming
 		// values from blocks that are both in the OpPhi instruction and in
 		// filter.
-		void StorePhi(InsnIterator insn, EmitState *state, std::unordered_set<SpirvShader::Block::ID> const& filter) const;
+		void StorePhi(Block::ID blockID, InsnIterator insn, EmitState *state, std::unordered_set<SpirvShader::Block::ID> const& filter) const;
 
 		// Emits a rr::Fence for the given MemorySemanticsMask.
 		void Fence(spv::MemorySemanticsMask semantics) const;
diff --git a/tests/VulkanUnitTests/unittests.cpp b/tests/VulkanUnitTests/unittests.cpp
index 8a986b1..50ab3b4 100644
--- a/tests/VulkanUnitTests/unittests.cpp
+++ b/tests/VulkanUnitTests/unittests.cpp
@@ -149,7 +149,7 @@
                 printf("%zu: '%s' != '%s'\n", line, srcLine.c_str(), disLine.c_str());

             }

         }

-        printf("\n\n---\n");

+        printf("\n\n---\nExpected:\n\n%s", disassembled.c_str());

     }

 

     return spirv;

@@ -390,6 +390,20 @@
 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, Memcpy)

 {

     std::stringstream src;

+    // #version 450

+    // layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

+    // layout(binding = 0, std430) buffer InBuffer

+    // {

+    //     int Data[];

+    // } In;

+    // layout(binding = 1, std430) buffer OutBuffer

+    // {

+    //     int Data[];

+    // } Out;

+    // void main()

+    // {

+    //     Out.Data[gl_GlobalInvocationID.x] = In.Data[gl_GlobalInvocationID.x];

+    // }

     src <<

               "OpCapability Shader\n"

               "OpMemoryModel Logical GLSL450\n"

@@ -1384,3 +1398,103 @@
 

     test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 2; });

 }

+

+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, LoopDivergentMergePhi)

+{

+    // #version 450

+    // layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

+    // layout(binding = 0, std430) buffer InBuffer

+    // {

+    //     int Data[];

+    // } In;

+    // layout(binding = 1, std430) buffer OutBuffer

+    // {

+    //     int Data[];

+    // } Out;

+    // void main()

+    // {

+    //     int phi = 0;

+    //     uint lane = gl_GlobalInvocationID.x % 4;

+    //     for (uint i = 0; i < 4; i++)

+    //     {

+    //         if (lane == i)

+    //         {

+    //             phi = In.Data[gl_GlobalInvocationID.x];

+    //             break;

+    //         }

+    //     }

+    //     Out.Data[gl_GlobalInvocationID.x] = phi;

+    // }

+    std::stringstream src;

+    src <<

+              "OpCapability Shader\n"

+         "%1 = OpExtInstImport \"GLSL.std.450\"\n"

+              "OpMemoryModel Logical GLSL450\n"

+              "OpEntryPoint GLCompute %2 \"main\" %3\n"

+              "OpExecutionMode %2 LocalSize " <<

+                              GetParam().localSizeX << " " <<

+                              GetParam().localSizeY << " " <<

+                              GetParam().localSizeZ << "\n" <<

+              "OpDecorate %3 BuiltIn GlobalInvocationId\n"

+              "OpDecorate %4 ArrayStride 4\n"

+              "OpMemberDecorate %5 0 Offset 0\n"

+              "OpDecorate %5 BufferBlock\n"

+              "OpDecorate %6 DescriptorSet 0\n"

+              "OpDecorate %6 Binding 0\n"

+              "OpDecorate %7 ArrayStride 4\n"

+              "OpMemberDecorate %8 0 Offset 0\n"

+              "OpDecorate %8 BufferBlock\n"

+              "OpDecorate %9 DescriptorSet 0\n"

+              "OpDecorate %9 Binding 1\n"

+        "%10 = OpTypeVoid\n"

+        "%11 = OpTypeFunction %10\n"

+        "%12 = OpTypeInt 32 1\n"

+        "%13 = OpConstant %12 0\n"

+        "%14 = OpTypeInt 32 0\n"

+        "%15 = OpTypeVector %14 3\n"

+        "%16 = OpTypePointer Input %15\n"

+         "%3 = OpVariable %16 Input\n"

+        "%17 = OpConstant %14 0\n"

+        "%18 = OpTypePointer Input %14\n"

+        "%19 = OpConstant %14 4\n"

+        "%20 = OpTypeBool\n"

+         "%4 = OpTypeRuntimeArray %12\n"

+         "%5 = OpTypeStruct %4\n"

+        "%21 = OpTypePointer Uniform %5\n"

+         "%6 = OpVariable %21 Uniform\n"

+        "%22 = OpTypePointer Uniform %12\n"

+        "%23 = OpConstant %12 1\n"

+         "%7 = OpTypeRuntimeArray %12\n"

+         "%8 = OpTypeStruct %7\n"

+        "%24 = OpTypePointer Uniform %8\n"

+         "%9 = OpVariable %24 Uniform\n"

+         "%2 = OpFunction %10 None %11\n"

+        "%25 = OpLabel\n"

+        "%26 = OpAccessChain %18 %3 %17\n"

+        "%27 = OpLoad %14 %26\n"

+        "%28 = OpUMod %14 %27 %19\n"

+              "OpBranch %29\n"

+        "%29 = OpLabel\n"

+        "%30 = OpPhi %14 %17 %25 %31 %32\n"

+        "%33 = OpULessThan %20 %30 %19\n"

+              "OpLoopMerge %34 %32 None\n"

+              "OpBranchConditional %33 %35 %34\n"

+        "%35 = OpLabel\n"

+        "%36 = OpIEqual %20 %28 %30\n"

+              "OpSelectionMerge %32 None\n"

+              "OpBranchConditional %36 %37 %32\n"

+        "%37 = OpLabel\n"

+        "%38 = OpAccessChain %22 %6 %13 %27\n"

+        "%39 = OpLoad %12 %38\n"

+              "OpBranch %34\n"

+        "%32 = OpLabel\n"

+        "%31 = OpIAdd %14 %30 %23\n"

+              "OpBranch %29\n"

+        "%34 = OpLabel\n"

+        "%40 = OpPhi %12 %13 %29 %39 %37\n" // %39: phi

+        "%41 = OpAccessChain %22 %9 %13 %27\n"

+              "OpStore %41 %40\n"

+              "OpReturn\n"

+              "OpFunctionEnd\n";

+    test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });

+}