SpirvShader: Add Function inner class.
Function holds all the blocks used by the function.
Moved various function-specific methods from SpirvShader to this new Function inner class.
This change is currently a pure-refactor (no change in behavior), but is required for function inlining in SpirvShader.
Bug: b/133213304
Change-Id: I50c7ecca8ce518d8df054c2461410c0acd4e1f52
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/33352
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Tested-by: Ben Clayton <bclayton@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 74c0a22..e7845c3 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -475,7 +475,7 @@
// Simplifying assumptions (to be satisfied by earlier transformations)
// - The only input/output OpVariables present are those used by the entrypoint
- Object::ID entryPointFunctionId;
+ Function::ID currentFunction;
Block::ID currentBlock;
InsnIterator blockStart;
@@ -488,13 +488,13 @@
case spv::OpEntryPoint:
{
auto executionModel = spv::ExecutionModel(insn.word(1));
- auto id = Object::ID(insn.word(2));
+ auto id = Function::ID(insn.word(2));
auto name = insn.string(3);
auto stage = executionModelToStage(executionModel);
if (stage == pipelineStage && strcmp(name, entryPointName) == 0)
{
- ASSERT_MSG(entryPointFunctionId == 0, "Duplicate entry point with name '%s' and stage %d", name, int(stage));
- entryPointFunctionId = id;
+ ASSERT_MSG(entryPoint == 0, "Duplicate entry point with name '%s' and stage %d", name, int(stage));
+ entryPoint = id;
}
break;
}
@@ -607,8 +607,10 @@
case spv::OpUnreachable:
{
ASSERT(currentBlock.value() != 0);
+ ASSERT(currentFunction.value() != 0);
+
auto blockEnd = insn; blockEnd++;
- blocks[currentBlock] = Block(blockStart, blockEnd);
+ functions[currentFunction].blocks[currentBlock] = Block(blockStart, blockEnd);
currentBlock = Block::ID(0);
if (opcode == spv::OpKill)
@@ -791,39 +793,35 @@
case spv::OpFunction:
{
- auto functionId = Object::ID(insn.word(2));
- if (functionId == entryPointFunctionId)
+ auto functionId = Function::ID(insn.word(2));
+ ASSERT_MSG(currentFunction == 0, "Functions %d and %d overlap", currentFunction.value(), functionId.value());
+ currentFunction = functionId;
+ auto &function = functions[functionId];
+ function.result = Type::ID(insn.word(1));
+ function.type = Type::ID(insn.word(4));
+ // Scan forward to find the function's label.
+ for (auto it = insn; it != end() && function.entry == 0; it++)
{
- // Scan forward to find the function's label.
- for (auto it = insn; it != end() && entryPointBlockId == 0; it++)
+ switch (it.opcode())
{
- switch (it.opcode())
- {
- case spv::OpFunction:
- case spv::OpFunctionParameter:
- break;
- case spv::OpLabel:
- entryPointBlockId = Block::ID(it.word(1));
- break;
- default:
- WARN("Unexpected opcode '%s' following OpFunction", OpcodeName(it.opcode()).c_str());
- }
+ case spv::OpFunction:
+ case spv::OpFunctionParameter:
+ break;
+ case spv::OpLabel:
+ function.entry = Block::ID(it.word(1));
+ break;
+ default:
+ WARN("Unexpected opcode '%s' following OpFunction", OpcodeName(it.opcode()).c_str());
}
}
- else
- {
- // All non-entry point functions should be inlined into an
- // entry point function.
- // This isn't the target entry point, so must be another
- // entry point that we are not interested in. Just skip it.
- for (; insn != end() && insn.opcode() != spv::OpFunctionEnd; insn++) {}
- }
-
+ ASSERT_MSG(function.entry != 0, "Function<%d> has no label", currentFunction.value());
break;
}
+
case spv::OpFunctionEnd:
- // Due to preprocessing, the entrypoint and its function provide no value.
+ currentFunction = 0;
break;
+
case spv::OpExtInstImport:
{
// We will only support the GLSL 450 extended instruction set, so no point in tracking the ID we assign it.
@@ -1062,47 +1060,10 @@
}
}
- ASSERT_MSG(entryPointFunctionId != 0, "Entry point '%s' not found", entryPointName);
- AssignBlockFields();
- }
-
- void SpirvShader::TraverseReachableBlocks(Block::ID id, SpirvShader::Block::Set& reachable)
- {
- if (reachable.count(id) == 0)
+ ASSERT_MSG(entryPoint != 0, "Entry point '%s' not found", entryPointName);
+ for (auto &it : functions)
{
- reachable.emplace(id);
- for (auto out : getBlock(id).outs)
- {
- TraverseReachableBlocks(out, reachable);
- }
- }
- }
-
- void SpirvShader::AssignBlockFields()
- {
- Block::Set reachable;
- TraverseReachableBlocks(entryPointBlockId, reachable);
-
- for (auto &it : blocks)
- {
- auto &blockId = it.first;
- auto &block = it.second;
- if (reachable.count(blockId) > 0)
- {
- for (auto &outId : it.second.outs)
- {
- auto outIt = blocks.find(outId);
- ASSERT_MSG(outIt != blocks.end(), "Block %d has a non-existent out %d", blockId.value(), outId.value());
- auto &out = outIt->second;
- out.ins.emplace(blockId);
- }
- if (block.kind == Block::Loop)
- {
- auto mergeIt = blocks.find(block.mergeBlock);
- ASSERT_MSG(mergeIt != blocks.end(), "Loop block %d has a non-existent merge block %d", blockId.value(), block.mergeBlock.value());
- mergeIt->second.isLoopMerge = true;
- }
- }
+ it.second.AssignBlockFields();
}
}
@@ -2043,7 +2004,7 @@
void SpirvShader::emit(SpirvRoutine *routine, RValue<SIMD::Int> const &activeLaneMask, const vk::DescriptorSet::Bindings &descriptorSets) const
{
- EmitState state(routine, activeLaneMask, descriptorSets, robustBufferAccess);
+ EmitState state(routine, entryPoint, activeLaneMask, descriptorSets, robustBufferAccess);
// Emit everything up to the first label
// TODO: Separate out dispatch of block from non-block instructions?
@@ -2056,13 +2017,14 @@
EmitInstruction(insn, &state);
}
- // Emit all the blocks starting from entryPointBlockId.
- EmitBlocks(entryPointBlockId, &state);
+ // Emit all the blocks starting from entryPoint.
+ EmitBlocks(getFunction(entryPoint).entry, &state);
}
void SpirvShader::EmitBlocks(Block::ID id, EmitState *state, Block::ID ignore /* = 0 */) const
{
auto oldPending = state->pending;
+ auto &function = getFunction(state->function);
std::deque<Block::ID> pending;
state->pending = &pending;
@@ -2071,7 +2033,7 @@
{
auto id = pending.front();
- auto const &block = getBlock(id);
+ auto const &block = function.getBlock(id);
if (id == ignore)
{
pending.pop_front();
@@ -2080,7 +2042,7 @@
// Ensure all dependency blocks have been generated.
auto depsDone = true;
- ForeachBlockDependency(id, [&](Block::ID dep)
+ function.ForeachBlockDependency(id, [&](Block::ID dep)
{
if (state->visited.count(dep) == 0)
{
@@ -2138,30 +2100,18 @@
}
}
- void SpirvShader::ForeachBlockDependency(Block::ID blockId, std::function<void(Block::ID)> f) const
- {
- auto block = getBlock(blockId);
- for (auto dep : block.ins)
- {
- if (block.kind != Block::Loop || // if not a loop...
- !existsPath(blockId, dep, block.mergeBlock)) // or a loop and not a loop back edge
- {
- f(dep);
- }
- }
- }
-
void SpirvShader::EmitNonLoop(EmitState *state) const
{
+ auto &function = getFunction(state->function);
auto blockId = state->currentBlock;
- auto block = getBlock(blockId);
+ auto block = function.getBlock(blockId);
if (!state->visited.emplace(blockId).second)
{
return; // Already generated this block.
}
- if (blockId != entryPointBlockId)
+ if (blockId != function.entry)
{
// Set the activeLaneMask.
SIMD::Int activeLaneMask(0);
@@ -2186,10 +2136,11 @@
void SpirvShader::EmitLoop(EmitState *state) const
{
+ auto &function = getFunction(state->function);
auto blockId = state->currentBlock;
- auto &block = getBlock(blockId);
+ auto &block = function.getBlock(blockId);
auto mergeBlockId = block.mergeBlock;
- auto &mergeBlock = getBlock(mergeBlockId);
+ auto &mergeBlock = function.getBlock(mergeBlockId);
if (!state->visited.emplace(blockId).second)
{
@@ -2200,7 +2151,7 @@
std::unordered_set<Block::ID> loopBlocks;
for (auto in : block.ins)
{
- if (!existsPath(blockId, in, mergeBlockId)) // if not a loop back-edge
+ if (!function.ExistsPath(blockId, in, mergeBlockId)) // if not a loop back-edge
{
incomingBlocks.emplace(in);
}
@@ -2231,7 +2182,7 @@
// mergeActiveLaneMasks contains edge lane masks for the merge block.
// This is the union of all edge masks across all iterations of the loop.
std::unordered_map<Block::ID, SIMD::Int> mergeActiveLaneMasks;
- for (auto in : getBlock(mergeBlockId).ins)
+ for (auto in : function.getBlock(mergeBlockId).ins)
{
mergeActiveLaneMasks.emplace(in, SIMD::Int(0));
}
@@ -2274,14 +2225,14 @@
loopActiveLaneMask = SIMD::Int(0);
for (auto in : block.ins)
{
- if (existsPath(blockId, in, mergeBlockId))
+ if (function.ExistsPath(blockId, in, mergeBlockId))
{
loopActiveLaneMask |= GetActiveLaneMaskEdge(state, in, blockId);
}
}
// Add active lanes to the merge lane mask.
- for (auto in : getBlock(mergeBlockId).ins)
+ for (auto in : function.getBlock(mergeBlockId).ins)
{
auto edge = Block::Edge{in, mergeBlockId};
auto it = state->edgeActiveLaneMasks.find(edge);
@@ -4638,7 +4589,8 @@
SpirvShader::EmitResult SpirvShader::EmitBranchConditional(InsnIterator insn, EmitState *state) const
{
- auto block = getBlock(state->currentBlock);
+ auto &function = getFunction(state->function);
+ auto block = function.getBlock(state->currentBlock);
ASSERT(block.branchInstruction == insn);
auto condId = Object::ID(block.branchInstruction.word(1));
@@ -4658,7 +4610,8 @@
SpirvShader::EmitResult SpirvShader::EmitSwitch(InsnIterator insn, EmitState *state) const
{
- auto block = getBlock(state->currentBlock);
+ auto &function = getFunction(state->function);
+ auto block = function.getBlock(state->currentBlock);
ASSERT(block.branchInstruction == insn);
auto selId = Object::ID(block.branchInstruction.word(1));
@@ -4712,7 +4665,8 @@
SpirvShader::EmitResult SpirvShader::EmitPhi(InsnIterator insn, EmitState *state) const
{
- auto currentBlock = getBlock(state->currentBlock);
+ auto &function = getFunction(state->function);
+ auto currentBlock = function.getBlock(state->currentBlock);
if (!currentBlock.isLoopMerge)
{
// If this is a loop merge block, then don't attempt to update the
@@ -6356,7 +6310,61 @@
}
}
- bool SpirvShader::existsPath(Block::ID from, Block::ID to, Block::ID notPassingThrough) const
+
+ void SpirvShader::Function::TraverseReachableBlocks(Block::ID id, SpirvShader::Block::Set& reachable)
+ {
+ if (reachable.count(id) == 0)
+ {
+ reachable.emplace(id);
+ for (auto out : getBlock(id).outs)
+ {
+ TraverseReachableBlocks(out, reachable);
+ }
+ }
+ }
+
+ void SpirvShader::Function::AssignBlockFields()
+ {
+ Block::Set reachable;
+ TraverseReachableBlocks(entry, reachable);
+
+ for (auto &it : blocks)
+ {
+ auto &blockId = it.first;
+ auto &block = it.second;
+ if (reachable.count(blockId) > 0)
+ {
+ for (auto &outId : it.second.outs)
+ {
+ auto outIt = blocks.find(outId);
+ ASSERT_MSG(outIt != blocks.end(), "Block %d has a non-existent out %d", blockId.value(), outId.value());
+ auto &out = outIt->second;
+ out.ins.emplace(blockId);
+ }
+ if (block.kind == Block::Loop)
+ {
+ auto mergeIt = blocks.find(block.mergeBlock);
+ ASSERT_MSG(mergeIt != blocks.end(), "Loop block %d has a non-existent merge block %d", blockId.value(), block.mergeBlock.value());
+ mergeIt->second.isLoopMerge = true;
+ }
+ }
+ }
+ }
+
+ void SpirvShader::Function::ForeachBlockDependency(Block::ID blockId, std::function<void(Block::ID)> f) const
+ {
+ auto block = getBlock(blockId);
+ for (auto dep : block.ins)
+ {
+ if (block.kind != Block::Loop || // if not a loop...
+ !ExistsPath(blockId, dep, block.mergeBlock)) // or a loop and not a loop back edge
+ {
+ f(dep);
+ }
+ }
+ }
+
+ bool SpirvShader::Function::ExistsPath(Block::ID from, Block::ID to, Block::ID notPassingThrough) const
{
// TODO: Optimize: This can be cached on the block.
Block::Set seen;
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 995419e..d1c627b 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -541,6 +541,45 @@
InsnIterator end_;
};
+ class Function
+ {
+ public:
+ using ID = SpirvID<Function>;
+
+ // Walks all reachable the blocks starting from id adding them to
+ // reachable.
+ void TraverseReachableBlocks(Block::ID id, Block::Set& reachable);
+
+ // AssignBlockFields() performs the following for all reachable blocks:
+ // * Assigns Block::ins with the identifiers of all blocks that contain
+ // this block in their Block::outs.
+ // * Sets Block::isLoopMerge to true if the block is the merge of a
+ // another loop block.
+ void AssignBlockFields();
+
+ // ForeachBlockDependency calls f with each dependency of the given
+ // block. A dependency is an incoming block that is not a loop-back
+ // edge.
+ void ForeachBlockDependency(Block::ID blockId, std::function<void(Block::ID)> f) const;
+
+ // ExistsPath returns true if there's a direct or indirect flow from
+ // the 'from' block to the 'to' block that does not pass through
+ // notPassingThrough.
+ bool ExistsPath(Block::ID from, Block::ID to, Block::ID notPassingThrough) const;
+
+ Block const &getBlock(Block::ID id) const
+ {
+ auto it = blocks.find(id);
+ ASSERT_MSG(it != blocks.end(), "Unknown block %d", id.value());
+ return it->second;
+ }
+
+ Block::ID entry; // function entry point block.
+ HandleMap<Block> blocks; // blocks belonging to this function.
+ Type::ID type; // type of the function.
+ Type::ID result; // return type.
+ };
+
struct TypeOrObject {}; // Dummy struct to represent a Type or Object.
// TypeOrObjectID is an identifier that represents a Type or an Object,
@@ -623,7 +662,7 @@
// shader entry point represented by this object.
uint64_t getSerialID() const
{
- return ((uint64_t)entryPointBlockId.value() << 32) | codeSerialID;
+ return ((uint64_t)entryPoint.value() << 32) | codeSerialID;
}
SpirvShader(uint32_t codeSerialID,
@@ -816,10 +855,10 @@
return it->second;
}
- Block const &getBlock(Block::ID id) const
+ Function const &getFunction(Function::ID id) const
{
- auto it = blocks.find(id);
- ASSERT_MSG(it != blocks.end(), "Unknown block %d", id.value());
+ auto it = functions.find(id);
+ ASSERT_MSG(it != functions.end(), "Unknown function %d", id.value());
return it->second;
}
@@ -828,22 +867,11 @@
Modes modes;
HandleMap<Type> types;
HandleMap<Object> defs;
- HandleMap<Block> blocks;
- Block::ID entryPointBlockId; // Block of the entry point function.
+ HandleMap<Function> functions;
+ Function::ID entryPoint;
const bool robustBufferAccess = true;
- // Walks all reachable the blocks starting from id adding them to
- // reachable.
- void TraverseReachableBlocks(Block::ID id, Block::Set& reachable);
-
- // AssignBlockFields() performs the following for all reachable blocks:
- // * Assigns Block::ins with the identifiers of all blocks that contain
- // this block in their Block::outs.
- // * Sets Block::isLoopMerge to true if the block is the merge of a
- // another loop block.
- void AssignBlockFields();
-
// DeclareType creates a Type for the given OpTypeX instruction, storing
// it into the types map. It is called from the analysis pass (constructor).
void DeclareType(InsnIterator insn);
@@ -923,8 +951,13 @@
class EmitState
{
public:
- EmitState(SpirvRoutine *routine, RValue<SIMD::Int> activeLaneMask, const vk::DescriptorSet::Bindings &descriptorSets, bool robustBufferAccess)
+ EmitState(SpirvRoutine *routine,
+ Function::ID function,
+ RValue<SIMD::Int> activeLaneMask,
+ const vk::DescriptorSet::Bindings &descriptorSets,
+ bool robustBufferAccess)
: routine(routine),
+ function(function),
activeLaneMaskValue(activeLaneMask.value),
descriptorSets(descriptorSets),
robust(robustBufferAccess)
@@ -954,6 +987,7 @@
void addActiveLaneMaskEdge(Block::ID from, Block::ID to, RValue<SIMD::Int> mask);
SpirvRoutine *routine = nullptr; // The current routine being built.
+ Function::ID function; // The current function being built.
rr::Value *activeLaneMaskValue = nullptr; // The current active lane mask.
Block::ID currentBlock; // The current block being built.
Block::Set visited; // Blocks already built.
@@ -1055,21 +1089,11 @@
// Returns the *component* offset in the literal for the given access chain.
uint32_t WalkLiteralAccessChain(Type::ID id, uint32_t numIndexes, uint32_t const *indexes) const;
- // existsPath returns true if there's a direct or indirect flow from
- // the 'from' block to the 'to' block that does not pass through
- // notPassingThrough.
- bool existsPath(Block::ID from, Block::ID to, Block::ID notPassingThrough) const;
-
// Lookup the active lane mask for the edge from -> to.
// If from is unreachable, then a mask of all zeros is returned.
// Asserts if from is reachable and the edge does not exist.
RValue<SIMD::Int> GetActiveLaneMaskEdge(EmitState *state, Block::ID from, Block::ID to) const;
- // ForeachBlockDependency calls f with each dependency of the given
- // block. A dependency is an incoming block that is not a loop-back
- // edge.
- void ForeachBlockDependency(Block::ID blockId, std::function<void(Block::ID)> f) const;
-
// Emit all the unvisited blocks (except for ignore) in DFS order,
// starting with id.
void EmitBlocks(Block::ID id, EmitState *state, Block::ID ignore = 0) const;
diff --git a/src/Pipeline/SpirvShaderSampling.cpp b/src/Pipeline/SpirvShaderSampling.cpp
index 87bba39..e02c32a 100644
--- a/src/Pipeline/SpirvShaderSampling.cpp
+++ b/src/Pipeline/SpirvShaderSampling.cpp
@@ -117,7 +117,7 @@
SpirvShader::ImageSampler *SpirvShader::emitSamplerFunction(ImageInstruction instruction, const Sampler &samplerState)
{
// TODO(b/129523279): Hold a separate mutex lock for the sampler being built.
- Function<Void(Pointer<Byte>, Pointer<Byte>, Pointer<SIMD::Float>, Pointer<SIMD::Float>, Pointer<Byte>)> function;
+ rr::Function<Void(Pointer<Byte>, Pointer<Byte>, Pointer<SIMD::Float>, Pointer<SIMD::Float>, Pointer<Byte>)> function;
{
Pointer<Byte> texture = function.Arg<0>();
Pointer<Byte> sampler = function.Arg<1>();