SpirvShader: Fixes for complex loops.

Emit loops in forward direction. While traversing backwards from the loop back-edge means you don't have to worry about flowing down the merge block, it can lead to blocks being generated in orders that can break the visit-once logic.

Don't consider flows passing through the return block as a back edge.

Strip unreachable blocks from ins - nothing should ever consider them.

Tests: dEQP-VK.glsl.loops.*
Bug: b/128527271
Change-Id: I497a06f5ce65d54b39294e4016b2df6d2c70487c
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/28188
Tested-by: Ben Clayton <bclayton@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index a680d99..a7d044c 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -21,8 +21,6 @@
 #include "Vulkan/VkPipelineLayout.hpp"
 #include "Device/Config.hpp"
 
-#include <queue>
-
 #ifdef Bool
 #undef Bool // b/127920555
 #endif
@@ -469,37 +467,38 @@
 			}
 		}
 
-		MarkReachableBlocks(mainBlockId);
 		AssignBlockIns();
 	}
 
-	void SpirvShader::MarkReachableBlocks(Block::ID id)
+	void SpirvShader::TraverseReachableBlocks(Block::ID id, SpirvShader::Block::Set& reachable)
 	{
-		auto it = blocks.find(id);
-		ASSERT_MSG(it != blocks.end(), "Unknown block %d", id.value());
-		auto &block = it->second;
-		if (!block.reachable)
+		if (reachable.count(id) == 0)
 		{
-			block.reachable = true;
-			for (auto out : block.outs)
+			reachable.emplace(id);
+			for (auto out : getBlock(id).outs)
 			{
-				MarkReachableBlocks(out);
+				TraverseReachableBlocks(out, reachable);
 			}
 		}
 	}
 
 	void SpirvShader::AssignBlockIns()
 	{
+		Block::Set reachable;
+		TraverseReachableBlocks(mainBlockId, reachable);
+
 		for (auto &it : blocks)
 		{
 			auto &blockId = it.first;
-			auto &block = it.second;
-			for (auto &outId : block.outs)
+			if (reachable.count(blockId) > 0)
 			{
-				auto outIt = blocks.find(outId);
-				ASSERT_MSG(outIt != blocks.end(), "Block %d has a non-existent out %d", blockId.value(), outId.value());
-				auto &out = outIt->second;
-				out.ins.emplace(blockId);
+				for (auto &outId : it.second.outs)
+				{
+					auto outIt = blocks.find(outId);
+					ASSERT_MSG(outIt != blocks.end(), "Block %d has a non-existent out %d", blockId.value(), outId.value());
+					auto &out = outIt->second;
+					out.ins.emplace(blockId);
+				}
 			}
 		}
 	}
@@ -1183,72 +1182,50 @@
 			EmitInstruction(insn, &state);
 		}
 
-		// Emit all the blocks in BFS order, starting with the main block.
+		// Emit all the blocks starting from mainBlockId.
+		EmitBlocks(mainBlockId, &state);
+	}
+
+	void SpirvShader::EmitBlocks(Block::ID id, EmitState *state, Block::ID ignore /* = 0 */) const
+	{
+		auto oldPending = state->pending;
+
 		std::queue<Block::ID> pending;
-		pending.push(mainBlockId);
+		state->pending = &pending;
+		pending.push(id);
 		while (pending.size() > 0)
 		{
 			auto id = pending.front();
 			pending.pop();
-			if (state.visited.count(id) == 0)
+
+			auto const &block = getBlock(id);
+			if (id == ignore)
 			{
-				EmitBlock(id, &state);
-				for (auto it : getBlock(id).outs)
-				{
-					pending.push(it);
-				}
+				continue;
+			}
+
+			state->currentBlock = id;
+
+			switch (block.kind)
+			{
+				case Block::Simple:
+				case Block::StructuredBranchConditional:
+				case Block::UnstructuredBranchConditional:
+				case Block::StructuredSwitch:
+				case Block::UnstructuredSwitch:
+					EmitNonLoop(state);
+					break;
+
+				case Block::Loop:
+					EmitLoop(state);
+					break;
+
+				default:
+					UNREACHABLE("Unexpected Block Kind: %d", int(block.kind));
 			}
 		}
-	}
 
-	void SpirvShader::EmitBlock(Block::ID id, EmitState *state) const
-	{
-		auto &block = getBlock(id);
-
-		if (!block.reachable)
-		{
-			return;
-		}
-
-		if (state->visited.count(id) > 0)
-		{
-			return; // Already processed this block.
-		}
-
-		state->visited.emplace(id);
-
-		switch (block.kind)
-		{
-			case Block::Simple:
-			case Block::StructuredBranchConditional:
-			case Block::UnstructuredBranchConditional:
-			case Block::StructuredSwitch:
-			case Block::UnstructuredSwitch:
-				if (id != mainBlockId)
-				{
-					// Emit all preceding blocks and set the activeLaneMask.
-					Intermediate activeLaneMask(1);
-					activeLaneMask.move(0, SIMD::Int(0));
-					for (auto in : block.ins)
-					{
-						EmitBlock(in, state);
-						auto inMask = GetActiveLaneMaskEdge(state, in, id);
-						activeLaneMask.replace(0, activeLaneMask.Int(0) | inMask);
-					}
-					state->setActiveLaneMask(activeLaneMask.Int(0));
-				}
-				state->currentBlock = id;
-				EmitInstructions(block.begin(), block.end(), state);
-				break;
-
-			case Block::Loop:
-				state->currentBlock = id;
-				EmitLoop(state);
-				break;
-
-			default:
-				UNREACHABLE("Unexpected Block Kind: %d", int(block.kind));
-		}
+		state->pending = oldPending;
 	}
 
 	void SpirvShader::EmitInstructions(InsnIterator begin, InsnIterator end, EmitState *state) const
@@ -1269,19 +1246,93 @@
 		}
 	}
 
+	void SpirvShader::EmitNonLoop(EmitState *state) const
+	{
+		auto blockId = state->currentBlock;
+		auto block = getBlock(blockId);
+
+		// Ensure all incoming blocks have been generated.
+		auto depsDone = true;
+		for (auto in : block.ins)
+		{
+			if (state->visited.count(in) == 0)
+			{
+				state->pending->emplace(in);
+				depsDone = false;
+			}
+		}
+
+		if (!depsDone)
+		{
+			// come back to this once the dependencies have been generated
+			state->pending->emplace(blockId);
+			return;
+		}
+
+		if (!state->visited.emplace(blockId).second)
+		{
+			return; // Already generated this block.
+		}
+
+		if (blockId != mainBlockId)
+		{
+			// Set the activeLaneMask.
+			Intermediate activeLaneMask(1);
+			activeLaneMask.move(0, SIMD::Int(0));
+			for (auto in : block.ins)
+			{
+				auto inMask = GetActiveLaneMaskEdge(state, in, blockId);
+				activeLaneMask.replace(0, activeLaneMask.Int(0) | inMask);
+			}
+			state->setActiveLaneMask(activeLaneMask.Int(0));
+		}
+
+		EmitInstructions(block.begin(), block.end(), state);
+
+		for (auto out : block.outs)
+		{
+			state->pending->emplace(out);
+		}
+	}
+
 	void SpirvShader::EmitLoop(EmitState *state) const
 	{
 		auto blockId = state->currentBlock;
 		auto block = getBlock(blockId);
 
+		// Ensure all incoming non-back edge blocks have been generated.
+		auto depsDone = true;
+		for (auto in : block.ins)
+		{
+			if (state->visited.count(in) == 0)
+			{
+				if (!existsPath(blockId, in, block.mergeBlock)) // if not a loop back edge
+				{
+					state->pending->emplace(in);
+					depsDone = false;
+				}
+			}
+		}
+
+		if (!depsDone)
+		{
+			// come back to this once the dependencies have been generated
+			state->pending->emplace(blockId);
+			return;
+		}
+
+		if (!state->visited.emplace(blockId).second)
+		{
+			return; // Already emitted this loop.
+		}
+
 		// loopActiveLaneMask is the mask of lanes that are continuing to loop.
 		// This is initialized with the incoming active lane masks.
 		SIMD::Int loopActiveLaneMask = SIMD::Int(0);
 		for (auto in : block.ins)
 		{
-			if (!existsPath(blockId, in)) // if not a loop back edge
+			if (!existsPath(blockId, in, block.mergeBlock)) // if not a loop back edge
 			{
-				EmitBlock(in, state);
 				loopActiveLaneMask |= GetActiveLaneMaskEdge(state, in, blockId);
 			}
 		}
@@ -1323,7 +1374,7 @@
 				{
 					auto varId = Object::ID(insn.word(w + 0));
 					auto blockId = Block::ID(insn.word(w + 1));
-					if (existsPath(state->currentBlock, blockId))
+					if (existsPath(state->currentBlock, blockId, block.mergeBlock))
 					{
 						// This source is from a loop back-edge.
 						ASSERT(phi.continueValue == 0 || phi.continueValue == varId);
@@ -1377,14 +1428,21 @@
 			}
 		}
 
-		// Emit all the back-edge blocks and use their active lane masks to
-		// rebuild the loopActiveLaneMask.
+		// Emit all loop blocks, but don't emit the merge block yet.
+		for (auto out : block.outs)
+		{
+			if (existsPath(out, blockId, block.mergeBlock))
+			{
+				EmitBlocks(out, state, block.mergeBlock);
+			}
+		}
+
+		// Rebuild the loopActiveLaneMask from the loop back edges.
 		loopActiveLaneMask = SIMD::Int(0);
 		for (auto in : block.ins)
 		{
-			if (existsPath(blockId, in))
+			if (existsPath(blockId, in, block.mergeBlock))
 			{
-				EmitBlock(in, state);
 				loopActiveLaneMask |= GetActiveLaneMaskEdge(state, in, blockId);
 			}
 		}
@@ -1408,9 +1466,9 @@
 		// otherwise jump to the merge block.
 		Nucleus::createCondBr(AnyTrue(loopActiveLaneMask).value, headerBasicBlock, mergeBasicBlock);
 
-		// Emit the merge block, and we're done.
+		// Continue emitting from the merge block.
 		Nucleus::setInsertBlock(mergeBasicBlock);
-		EmitBlock(block.mergeBlock, state);
+		state->pending->emplace(block.mergeBlock);
 	}
 
 	SpirvShader::EmitResult SpirvShader::EmitInstruction(InsnIterator insn, EmitState *state) const
@@ -3110,10 +3168,11 @@
 		}
 	}
 
-	bool SpirvShader::existsPath(Block::ID from, Block::ID to) const
+	bool SpirvShader::existsPath(Block::ID from, Block::ID to, Block::ID notPassingThrough) const
 	{
 		// TODO: Optimize: This can be cached on the block.
 		Block::Set seen;
+		seen.emplace(notPassingThrough);
 
 		std::queue<Block::ID> pending;
 		pending.emplace(from);
@@ -3157,11 +3216,6 @@
 
 	RValue<SIMD::Int> SpirvShader::GetActiveLaneMaskEdge(EmitState *state, Block::ID from, Block::ID to) const
 	{
-		if (!getBlock(from).reachable)
-		{
-			return SIMD::Int(0);
-		}
-
 		auto edge = Block::Edge{from, to};
 		auto it = state->edgeActiveLaneMasks.find(edge);
 		ASSERT_MSG(it != state->edgeActiveLaneMasks.end(), "Could not find edge %d -> %d", from.value(), to.value());
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 968ca33..0cbbf59 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -34,6 +34,7 @@
 #include <cstdint>
 #include <type_traits>
 #include <memory>
+#include <queue>
 
 namespace vk
 {
@@ -310,7 +311,6 @@
 			ID continueTarget; // Loop continue block.
 			Set ins; // Blocks that branch into this block.
 			Set outs; // Blocks that this block branches to.
-			bool reachable = false;
 
 		private:
 			InsnIterator begin_;
@@ -481,9 +481,9 @@
 		HandleMap<Block> blocks;
 		Block::ID mainBlockId; // Block of the entry point function.
 
-		// Walks all reachable the blocks starting from id, and sets
-		// Block::reachable to true.
-		void MarkReachableBlocks(Block::ID id);
+		// Walks all reachable the blocks starting from id adding them to
+		// reachable.
+		void TraverseReachableBlocks(Block::ID id, Block::Set& reachable);
 
 		// Assigns Block::ins from Block::outs for every block.
 		void AssignBlockIns();
@@ -572,6 +572,7 @@
 			Block::ID currentBlock; // The current block being built.
 			Block::Set visited; // Blocks already built.
 			std::unordered_map<Block::Edge, RValue<SIMD::Int>, Block::Edge::Hash> edgeActiveLaneMasks;
+			std::queue<Block::ID> *pending;
 		};
 
 		// EmitResult is an enumerator of result values from the Emit functions.
@@ -582,17 +583,22 @@
 		};
 
 		// existsPath returns true if there's a direct or indirect flow from
-		// the 'from' block to the 'to' block.
-		bool existsPath(Block::ID from, Block::ID to) const;
+		// the 'from' block to the 'to' block that does not pass through
+		// notPassingThrough.
+		bool existsPath(Block::ID from, Block::ID to, Block::ID notPassingThrough) const;
 
 		// Lookup the active lane mask for the edge from -> to.
 		// If from is unreachable, then a mask of all zeros is returned.
 		// Asserts if from is reachable and the edge does not exist.
 		RValue<SIMD::Int> GetActiveLaneMaskEdge(EmitState *state, Block::ID from, Block::ID to) const;
 
-		void EmitBlock(Block::ID id, EmitState *state) const;
-		void EmitInstructions(InsnIterator begin, InsnIterator end, EmitState *state) const;
+		// Emit all the unvisited blocks (except for ignore) in BFS order,
+		// starting with id.
+		void EmitBlocks(Block::ID id, EmitState *state, Block::ID ignore = 0) const;
+		void EmitNonLoop(EmitState *state) const;
 		void EmitLoop(EmitState *state) const;
+
+		void EmitInstructions(InsnIterator begin, InsnIterator end, EmitState *state) const;
 		EmitResult EmitInstruction(InsnIterator insn, EmitState *state) const;
 
 		// Emit pass instructions: