SpirvShader: Correctly handle phi values in the loop merge
Yet another horrible phi/loop edge case (pun intended).
Added test.
Bug: b/133440380
Bug: b/133481698
Change-Id: I327842fa2d4314bce938454da81f67f890cf9e12
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/31845
Presubmit-Ready: Ben Clayton <bclayton@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
Tested-by: Ben Clayton <bclayton@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 2dcd7d4..3c48413 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -984,7 +984,7 @@
}
ASSERT_MSG(entryPointFunctionId != 0, "Entry point '%s' not found", createInfo->pName);
- AssignBlockIns();
+ AssignBlockFields();
}
void SpirvShader::TraverseReachableBlocks(Block::ID id, SpirvShader::Block::Set& reachable)
@@ -999,7 +999,7 @@
}
}
- void SpirvShader::AssignBlockIns()
+ void SpirvShader::AssignBlockFields()
{
Block::Set reachable;
TraverseReachableBlocks(entryPointBlockId, reachable);
@@ -1007,6 +1007,7 @@
for (auto &it : blocks)
{
auto &blockId = it.first;
+ auto &block = it.second;
if (reachable.count(blockId) > 0)
{
for (auto &outId : it.second.outs)
@@ -1016,6 +1017,12 @@
auto &out = outIt->second;
out.ins.emplace(blockId);
}
+ if (block.kind == Block::Loop)
+ {
+ auto mergeIt = blocks.find(block.mergeBlock);
+ ASSERT_MSG(mergeIt != blocks.end(), "Loop block %d has a non-existent merge block %d", blockId.value(), block.mergeBlock.value());
+ mergeIt->second.isLoopMerge = true;
+ }
}
}
}
@@ -2083,7 +2090,9 @@
void SpirvShader::EmitLoop(EmitState *state) const
{
auto blockId = state->currentBlock;
- auto block = getBlock(blockId);
+ auto &block = getBlock(blockId);
+ auto mergeBlockId = block.mergeBlock;
+ auto &mergeBlock = getBlock(mergeBlockId);
// Ensure all incoming non-back edge blocks have been generated.
auto depsDone = true;
@@ -2091,7 +2100,7 @@
{
if (state->visited.count(in) == 0)
{
- if (!existsPath(blockId, in, block.mergeBlock)) // if not a loop back edge
+ if (!existsPath(blockId, in, mergeBlockId)) // if not a loop back edge
{
state->pending->emplace(in);
depsDone = false;
@@ -2115,7 +2124,7 @@
std::unordered_set<Block::ID> loopBlocks;
for (auto in : block.ins)
{
- if (!existsPath(blockId, in, block.mergeBlock)) // if not a loop back-edge
+ if (!existsPath(blockId, in, mergeBlockId)) // if not a loop back-edge
{
incomingBlocks.emplace(in);
}
@@ -2131,7 +2140,7 @@
{
if (insn.opcode() == spv::OpPhi)
{
- StorePhi(insn, state, incomingBlocks);
+ StorePhi(blockId, insn, state, incomingBlocks);
}
}
@@ -2146,7 +2155,7 @@
// mergeActiveLaneMasks contains edge lane masks for the merge block.
// This is the union of all edge masks across all iterations of the loop.
std::unordered_map<Block::ID, SIMD::Int> mergeActiveLaneMasks;
- for (auto in : getBlock(block.mergeBlock).ins)
+ for (auto in : getBlock(mergeBlockId).ins)
{
mergeActiveLaneMasks.emplace(in, SIMD::Int(0));
}
@@ -2179,7 +2188,7 @@
// don't emit the merge block yet.
for (auto out : block.outs)
{
- EmitBlocks(out, state, block.mergeBlock);
+ EmitBlocks(out, state, mergeBlockId);
}
// Restore current block id after emitting loop blocks.
@@ -2189,16 +2198,16 @@
loopActiveLaneMask = SIMD::Int(0);
for (auto in : block.ins)
{
- if (existsPath(blockId, in, block.mergeBlock))
+ if (existsPath(blockId, in, mergeBlockId))
{
loopActiveLaneMask |= GetActiveLaneMaskEdge(state, in, blockId);
}
}
// Add active lanes to the merge lane mask.
- for (auto in : getBlock(block.mergeBlock).ins)
+ for (auto in : getBlock(mergeBlockId).ins)
{
- auto edge = Block::Edge{in, block.mergeBlock};
+ auto edge = Block::Edge{in, mergeBlockId};
auto it = state->edgeActiveLaneMasks.find(edge);
if (it != state->edgeActiveLaneMasks.end())
{
@@ -2211,7 +2220,39 @@
{
if (insn.opcode() == spv::OpPhi)
{
- StorePhi(insn, state, loopBlocks);
+ StorePhi(blockId, insn, state, loopBlocks);
+ }
+ }
+
+ // Use the [loop -> merge] active lane masks to update the phi values in
+ // the merge block. We need to do this to handle divergent control flow
+ // in the loop.
+ //
+ // Consider the following:
+ //
+ // int phi_source = 0;
+ // for (uint i = 0; i < 4; i++)
+ // {
+ // phi_source = 0;
+ // if (gl_GlobalInvocationID.x % 4 == i) // divergent control flow
+ // {
+ // phi_source = 42; // single lane assignment.
+ // break; // activeLaneMask for [loop->merge] is active for a single lane.
+ // }
+ // // -- we are here --
+ // }
+ // // merge block
+ // int phi = phi_source; // OpPhi
+ //
+ // In this example, with each iteration of the loop, phi_source will
+ // only have a single lane assigned. However by 'phi' value in the merge
+ // block needs to be assigned the union of all the per-lane assignments
+ // of phi_source when that lane exited the loop.
+ for (auto insn = mergeBlock.begin(); insn != mergeBlock.end(); insn++)
+ {
+ if (insn.opcode() == spv::OpPhi)
+ {
+ StorePhi(mergeBlockId, insn, state, mergeBlock.ins);
}
}
@@ -2222,10 +2263,10 @@
// Continue emitting from the merge block.
Nucleus::setInsertBlock(mergeBasicBlock);
- state->pending->emplace(block.mergeBlock);
+ state->pending->emplace(mergeBlockId);
for (auto it : mergeActiveLaneMasks)
{
- state->addActiveLaneMaskEdge(it.first, block.mergeBlock, it.second);
+ state->addActiveLaneMaskEdge(it.first, mergeBlockId, it.second);
}
}
@@ -4618,7 +4659,13 @@
SpirvShader::EmitResult SpirvShader::EmitPhi(InsnIterator insn, EmitState *state) const
{
auto currentBlock = getBlock(state->currentBlock);
- StorePhi(insn, state, currentBlock.ins);
+ if (!currentBlock.isLoopMerge)
+ {
+ // If this is a loop merge block, then don't attempt to update the
+ // phi values from the ins. EmitLoop() has had to take special care
+ // of this phi in order to correctly deal with divergent lanes.
+ StorePhi(state->currentBlock, insn, state, currentBlock.ins);
+ }
LoadPhi(insn, state);
return EmitResult::Continue;
}
@@ -4641,13 +4688,12 @@
}
}
- void SpirvShader::StorePhi(InsnIterator insn, EmitState *state, std::unordered_set<SpirvShader::Block::ID> const& filter) const
+ void SpirvShader::StorePhi(Block::ID currentBlock, InsnIterator insn, EmitState *state, std::unordered_set<SpirvShader::Block::ID> const& filter) const
{
auto routine = state->routine;
auto typeId = Type::ID(insn.word(1));
auto type = getType(typeId);
auto objectId = Object::ID(insn.word(2));
- auto currentBlock = getBlock(state->currentBlock);
auto storageIt = state->routine->phis.find(objectId);
ASSERT(storageIt != state->routine->phis.end());
@@ -4663,7 +4709,7 @@
continue;
}
- auto mask = GetActiveLaneMaskEdge(state, blockId, state->currentBlock);
+ auto mask = GetActiveLaneMaskEdge(state, blockId, currentBlock);
auto in = GenericValue(this, routine, varId);
for (uint32_t i = 0; i < type.sizeInComponents; i++)
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index f7eaa1d..09cf15b 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -449,14 +449,14 @@
Loop, // OpLoopMerge + [OpBranchConditional | OpBranch]
};
- Kind kind;
+ Kind kind = Simple;
InsnIterator mergeInstruction; // Structured control flow merge instruction.
InsnIterator branchInstruction; // Branch instruction.
ID mergeBlock; // Structured flow merge block.
ID continueTarget; // Loop continue block.
Set ins; // Blocks that branch into this block.
Set outs; // Blocks that this block branches to.
-
+ bool isLoopMerge = false;
private:
InsnIterator begin_;
InsnIterator end_;
@@ -743,8 +743,12 @@
// reachable.
void TraverseReachableBlocks(Block::ID id, Block::Set& reachable);
- // Assigns Block::ins from Block::outs for every block.
- void AssignBlockIns();
+ // AssignBlockFields() performs the following for all reachable blocks:
+ // * Assigns Block::ins with the identifiers of all blocks that contain
+ // this block in their Block::outs.
+ // * Sets Block::isLoopMerge to true if the block is the merge of a
+ // another loop block.
+ void AssignBlockFields();
// DeclareType creates a Type for the given OpTypeX instruction, storing
// it into the types map. It is called from the analysis pass (constructor).
@@ -974,7 +978,7 @@
// StorePhi updates the phi's alloca storage value using the incoming
// values from blocks that are both in the OpPhi instruction and in
// filter.
- void StorePhi(InsnIterator insn, EmitState *state, std::unordered_set<SpirvShader::Block::ID> const& filter) const;
+ void StorePhi(Block::ID blockID, InsnIterator insn, EmitState *state, std::unordered_set<SpirvShader::Block::ID> const& filter) const;
// Emits a rr::Fence for the given MemorySemanticsMask.
void Fence(spv::MemorySemanticsMask semantics) const;
diff --git a/tests/VulkanUnitTests/unittests.cpp b/tests/VulkanUnitTests/unittests.cpp
index 8a986b1..50ab3b4 100644
--- a/tests/VulkanUnitTests/unittests.cpp
+++ b/tests/VulkanUnitTests/unittests.cpp
@@ -149,7 +149,7 @@
printf("%zu: '%s' != '%s'\n", line, srcLine.c_str(), disLine.c_str());
}
}
- printf("\n\n---\n");
+ printf("\n\n---\nExpected:\n\n%s", disassembled.c_str());
}
return spirv;
@@ -390,6 +390,20 @@
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, Memcpy)
{
std::stringstream src;
+ // #version 450
+ // layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+ // layout(binding = 0, std430) buffer InBuffer
+ // {
+ // int Data[];
+ // } In;
+ // layout(binding = 1, std430) buffer OutBuffer
+ // {
+ // int Data[];
+ // } Out;
+ // void main()
+ // {
+ // Out.Data[gl_GlobalInvocationID.x] = In.Data[gl_GlobalInvocationID.x];
+ // }
src <<
"OpCapability Shader\n"
"OpMemoryModel Logical GLSL450\n"
@@ -1384,3 +1398,103 @@
test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 2; });
}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, LoopDivergentMergePhi)
+{
+ // #version 450
+ // layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+ // layout(binding = 0, std430) buffer InBuffer
+ // {
+ // int Data[];
+ // } In;
+ // layout(binding = 1, std430) buffer OutBuffer
+ // {
+ // int Data[];
+ // } Out;
+ // void main()
+ // {
+ // int phi = 0;
+ // uint lane = gl_GlobalInvocationID.x % 4;
+ // for (uint i = 0; i < 4; i++)
+ // {
+ // if (lane == i)
+ // {
+ // phi = In.Data[gl_GlobalInvocationID.x];
+ // break;
+ // }
+ // }
+ // Out.Data[gl_GlobalInvocationID.x] = phi;
+ // }
+ std::stringstream src;
+ src <<
+ "OpCapability Shader\n"
+ "%1 = OpExtInstImport \"GLSL.std.450\"\n"
+ "OpMemoryModel Logical GLSL450\n"
+ "OpEntryPoint GLCompute %2 \"main\" %3\n"
+ "OpExecutionMode %2 LocalSize " <<
+ GetParam().localSizeX << " " <<
+ GetParam().localSizeY << " " <<
+ GetParam().localSizeZ << "\n" <<
+ "OpDecorate %3 BuiltIn GlobalInvocationId\n"
+ "OpDecorate %4 ArrayStride 4\n"
+ "OpMemberDecorate %5 0 Offset 0\n"
+ "OpDecorate %5 BufferBlock\n"
+ "OpDecorate %6 DescriptorSet 0\n"
+ "OpDecorate %6 Binding 0\n"
+ "OpDecorate %7 ArrayStride 4\n"
+ "OpMemberDecorate %8 0 Offset 0\n"
+ "OpDecorate %8 BufferBlock\n"
+ "OpDecorate %9 DescriptorSet 0\n"
+ "OpDecorate %9 Binding 1\n"
+ "%10 = OpTypeVoid\n"
+ "%11 = OpTypeFunction %10\n"
+ "%12 = OpTypeInt 32 1\n"
+ "%13 = OpConstant %12 0\n"
+ "%14 = OpTypeInt 32 0\n"
+ "%15 = OpTypeVector %14 3\n"
+ "%16 = OpTypePointer Input %15\n"
+ "%3 = OpVariable %16 Input\n"
+ "%17 = OpConstant %14 0\n"
+ "%18 = OpTypePointer Input %14\n"
+ "%19 = OpConstant %14 4\n"
+ "%20 = OpTypeBool\n"
+ "%4 = OpTypeRuntimeArray %12\n"
+ "%5 = OpTypeStruct %4\n"
+ "%21 = OpTypePointer Uniform %5\n"
+ "%6 = OpVariable %21 Uniform\n"
+ "%22 = OpTypePointer Uniform %12\n"
+ "%23 = OpConstant %12 1\n"
+ "%7 = OpTypeRuntimeArray %12\n"
+ "%8 = OpTypeStruct %7\n"
+ "%24 = OpTypePointer Uniform %8\n"
+ "%9 = OpVariable %24 Uniform\n"
+ "%2 = OpFunction %10 None %11\n"
+ "%25 = OpLabel\n"
+ "%26 = OpAccessChain %18 %3 %17\n"
+ "%27 = OpLoad %14 %26\n"
+ "%28 = OpUMod %14 %27 %19\n"
+ "OpBranch %29\n"
+ "%29 = OpLabel\n"
+ "%30 = OpPhi %14 %17 %25 %31 %32\n"
+ "%33 = OpULessThan %20 %30 %19\n"
+ "OpLoopMerge %34 %32 None\n"
+ "OpBranchConditional %33 %35 %34\n"
+ "%35 = OpLabel\n"
+ "%36 = OpIEqual %20 %28 %30\n"
+ "OpSelectionMerge %32 None\n"
+ "OpBranchConditional %36 %37 %32\n"
+ "%37 = OpLabel\n"
+ "%38 = OpAccessChain %22 %6 %13 %27\n"
+ "%39 = OpLoad %12 %38\n"
+ "OpBranch %34\n"
+ "%32 = OpLabel\n"
+ "%31 = OpIAdd %14 %30 %23\n"
+ "OpBranch %29\n"
+ "%34 = OpLabel\n"
+ "%40 = OpPhi %12 %13 %29 %39 %37\n" // %39: phi
+ "%41 = OpAccessChain %22 %9 %13 %27\n"
+ "OpStore %41 %40\n"
+ "OpReturn\n"
+ "OpFunctionEnd\n";
+ test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
+}