| // Copyright (c) 2025 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
#include "source/opt/resolve_binding_conflicts_pass.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <limits>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "source/opt/decoration_manager.h"
#include "source/opt/def_use_manager.h"
#include "source/opt/instruction.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
#include "spirv/unified1/spirv.h"
| |
| namespace spvtools { |
| namespace opt { |
| |
| // A VarBindingInfo contains the binding information for a single resource |
| // variable. |
| // |
| // Exactly one such object is created per resource variable in the |
| // module. In particular, when a resource variable is statically used by |
| // more than one entry point, those entry points share the same VarBindingInfo |
| // object for that variable. |
| struct VarBindingInfo { |
| const Instruction* const var; |
| const uint32_t descriptor_set; |
| Instruction* const binding_decoration; |
| |
| // Returns the binding number. |
| uint32_t binding() const { |
| return binding_decoration->GetSingleWordInOperand(2); |
| } |
| // Sets the binding number to 'b'. |
| void updateBinding(uint32_t b) { binding_decoration->SetOperand(2, {b}); } |
| }; |
| |
| // The bindings in the same descriptor set that are used by an entry point. |
| using BindingList = std::vector<VarBindingInfo*>; |
| // A map from descriptor set number to the list of bindings in that descriptor |
| // set, as used by a particular entry point. |
| using DescriptorSets = std::unordered_map<uint32_t, BindingList>; |
| |
| IRContext::Analysis ResolveBindingConflictsPass::GetPreservedAnalyses() { |
| // All analyses are kept up to date. |
| // At most this modifies the Binding numbers on variables. |
| return IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping | |
| IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | |
| IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | |
| IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap | |
| IRContext::kAnalysisScalarEvolution | |
| IRContext::kAnalysisRegisterPressure | |
| IRContext::kAnalysisValueNumberTable | |
| IRContext::kAnalysisStructuredCFG | IRContext::kAnalysisBuiltinVarId | |
| IRContext::kAnalysisIdToFuncMapping | IRContext::kAnalysisConstants | |
| IRContext::kAnalysisTypes | IRContext::kAnalysisDebugInfo | |
| IRContext::kAnalysisLiveness; |
| } |
| |
| // Orders variable binding info objects. |
| // * The binding number is most signficant; |
| // * Then a sampler-like object compares greater than non-sampler like object. |
| // * Otherwise compare based on variable ID. |
| // This provides a total order among bindings in a descriptor set for a valid |
| // Vulkan module. |
| bool Less(const VarBindingInfo* const lhs, const VarBindingInfo* const rhs) { |
| if (lhs->binding() < rhs->binding()) return true; |
| if (lhs->binding() > rhs->binding()) return false; |
| |
| // Examine types. |
| // In valid Vulkan the only conflict can occur between |
| // images and samplers. We only care about a specific |
| // comparison when one is a image-like thing and the other |
| // is a sampler-like thing of the same shape. So unwrap |
| // types until we hit one of those two. |
| |
| auto* def_use_mgr = lhs->var->context()->get_def_use_mgr(); |
| |
| // Returns the type found by iteratively following pointer pointee type, |
| // or array element type. |
| auto unwrap = [&def_use_mgr](Instruction* ty) { |
| bool keep_going = true; |
| do { |
| switch (ty->opcode()) { |
| case spv::Op::OpTypePointer: |
| ty = def_use_mgr->GetDef(ty->GetSingleWordInOperand(1)); |
| break; |
| case spv::Op::OpTypeArray: |
| case spv::Op::OpTypeRuntimeArray: |
| ty = def_use_mgr->GetDef(ty->GetSingleWordInOperand(0)); |
| break; |
| default: |
| keep_going = false; |
| break; |
| } |
| } while (keep_going); |
| return ty; |
| }; |
| |
| auto* lhs_ty = unwrap(def_use_mgr->GetDef(lhs->var->type_id())); |
| auto* rhs_ty = unwrap(def_use_mgr->GetDef(rhs->var->type_id())); |
| if (lhs_ty->opcode() == rhs_ty->opcode()) { |
| // Pick based on variable ID. |
| return lhs->var->result_id() < rhs->var->result_id(); |
| } |
| // A sampler is always greater than an image. |
| if (lhs_ty->opcode() == spv::Op::OpTypeSampler) { |
| return false; |
| } |
| if (rhs_ty->opcode() == spv::Op::OpTypeSampler) { |
| return true; |
| } |
| // Pick based on variable ID. |
| return lhs->var->result_id() < rhs->var->result_id(); |
| } |
| |
| // Summarizes the caller-callee relationships between functions in a module. |
| class CallGraph { |
| public: |
| // Returns the list of all functions statically reachable from entry points, |
| // where callees precede callers. |
| const std::vector<uint32_t>& CalleesBeforeCallers() const { |
| return visit_order_; |
| } |
| // Returns the list functions called from a given function. |
| const std::unordered_set<uint32_t>& Callees(uint32_t caller) { |
| return calls_[caller]; |
| } |
| |
| CallGraph(IRContext& context) { |
| // Populate calls_. |
| std::queue<uint32_t> callee_queue; |
| for (const auto& fn : *context.module()) { |
| auto& callees = calls_[fn.result_id()]; |
| context.AddCalls(&fn, &callee_queue); |
| while (!callee_queue.empty()) { |
| callees.insert(callee_queue.front()); |
| callee_queue.pop(); |
| } |
| } |
| |
| // Perform depth-first search, starting from each entry point. |
| // Populates visit_order_. |
| for (const auto& ep : context.module()->entry_points()) { |
| Visit(ep.GetSingleWordInOperand(1)); |
| } |
| } |
| |
| private: |
| // Visits a function, recursively visiting its callees. Adds this ID |
| // to the visit_order after all callees have been visited. |
| void Visit(uint32_t func_id) { |
| if (visited_.count(func_id)) { |
| return; |
| } |
| visited_.insert(func_id); |
| for (auto callee_id : calls_[func_id]) { |
| Visit(callee_id); |
| } |
| visit_order_.push_back(func_id); |
| } |
| |
| // Maps the ID of a function to the IDs of functions it calls. |
| std::unordered_map<uint32_t, std::unordered_set<uint32_t>> calls_; |
| |
| // IDs of visited functions; |
| std::unordered_set<uint32_t> visited_; |
| // IDs of functions, where callees precede callers. |
| std::vector<uint32_t> visit_order_; |
| }; |
| |
| // Returns vector binding info for all resource variables in the module. |
| auto GetVarBindings(IRContext& context) { |
| std::vector<VarBindingInfo> vars; |
| auto* deco_mgr = context.get_decoration_mgr(); |
| for (auto& inst : context.module()->types_values()) { |
| if (inst.opcode() == spv::Op::OpVariable) { |
| Instruction* descriptor_set_deco = nullptr; |
| Instruction* binding_deco = nullptr; |
| for (auto* deco : deco_mgr->GetDecorationsFor(inst.result_id(), false)) { |
| switch (static_cast<spv::Decoration>(deco->GetSingleWordInOperand(1))) { |
| case spv::Decoration::DescriptorSet: |
| assert(!descriptor_set_deco); |
| descriptor_set_deco = deco; |
| break; |
| case spv::Decoration::Binding: |
| assert(!binding_deco); |
| binding_deco = deco; |
| break; |
| default: |
| break; |
| } |
| } |
| if (descriptor_set_deco && binding_deco) { |
| vars.push_back({&inst, descriptor_set_deco->GetSingleWordInOperand(2), |
| binding_deco}); |
| } |
| } |
| } |
| return vars; |
| } |
| |
| // Merges the bindings from source into sink. Maintains order and uniqueness |
| // within a list of bindings. |
| void Merge(DescriptorSets& sink, const DescriptorSets& source) { |
| for (auto index_and_bindings : source) { |
| const uint32_t index = index_and_bindings.first; |
| const BindingList& src1 = index_and_bindings.second; |
| const BindingList& src2 = sink[index]; |
| BindingList merged; |
| merged.resize(src1.size() + src2.size()); |
| auto merged_end = std::merge(src1.begin(), src1.end(), src2.begin(), |
| src2.end(), merged.begin(), Less); |
| auto unique_end = std::unique(merged.begin(), merged_end); |
| merged.resize(unique_end - merged.begin()); |
| sink[index] = std::move(merged); |
| } |
| } |
| |
| // Resolves conflicts within this binding list, so the binding number on an |
| // item is at least one more than the binding number on the previous item. |
| // When this does not yet hold, increase the binding number on the second |
| // item in the pair. Returns true if any changes were applied. |
| bool ResolveConflicts(BindingList& bl) { |
| bool changed = false; |
| for (size_t i = 1; i < bl.size(); i++) { |
| const auto prev_num = bl[i - 1]->binding(); |
| if (prev_num >= bl[i]->binding()) { |
| bl[i]->updateBinding(prev_num + 1); |
| changed = true; |
| } |
| } |
| return changed; |
| } |
| |
| Pass::Status ResolveBindingConflictsPass::Process() { |
| // Assumes the descriptor set and binding decorations are not provided |
| // via decoration groups. Decoration groups were deprecated in SPIR-V 1.3 |
| // Revision 6. I have not seen any compiler generate them. --dneto |
| |
| auto vars = GetVarBindings(*context()); |
| |
| // Maps a function ID to the variables used directly or indirectly by the |
| // function, organized into descriptor sets. Each descriptor set |
| // consists of a BindingList of distinct variables. |
| std::unordered_map<uint32_t, DescriptorSets> used_vars; |
| |
| // Determine variables directly used by functions. |
| auto* def_use_mgr = context()->get_def_use_mgr(); |
| for (auto& var : vars) { |
| std::unordered_set<uint32_t> visited_functions_for_var; |
| def_use_mgr->ForEachUser(var.var, [&](Instruction* user) { |
| if (auto* block = context()->get_instr_block(user)) { |
| auto* fn = block->GetParent(); |
| assert(fn); |
| const auto fn_id = fn->result_id(); |
| if (visited_functions_for_var.insert(fn_id).second) { |
| used_vars[fn_id][var.descriptor_set].push_back(&var); |
| } |
| } |
| }); |
| } |
| |
| // Sort within a descriptor set by binding number. |
| for (auto& sets_for_fn : used_vars) { |
| for (auto& ds : sets_for_fn.second) { |
| BindingList& bindings = ds.second; |
| std::stable_sort(bindings.begin(), bindings.end(), Less); |
| } |
| } |
| |
| // Propagate from callees to callers. |
| CallGraph call_graph(*context()); |
| for (const uint32_t caller : call_graph.CalleesBeforeCallers()) { |
| DescriptorSets& caller_ds = used_vars[caller]; |
| for (const uint32_t callee : call_graph.Callees(caller)) { |
| Merge(caller_ds, used_vars[callee]); |
| } |
| } |
| |
| // At this point, the descriptor sets associated with each entry point |
| // capture exactly the set of resource variables statically used |
| // by the static call tree of that entry point. |
| |
| // Resolve conflicts. |
| // VarBindingInfo objects may be shared between the bindings lists. |
| // Updating a binding in one list can require updating another list later. |
| // So repeat updates until settling. |
| |
| // The union of BindingLists across all entry points. |
| std::vector<BindingList*> ep_bindings; |
| |
| for (auto& ep : context()->module()->entry_points()) { |
| for (auto& ds : used_vars[ep.GetSingleWordInOperand(1)]) { |
| BindingList& bindings = ds.second; |
| ep_bindings.push_back(&bindings); |
| } |
| } |
| bool modified = false; |
| bool found_conflict; |
| do { |
| found_conflict = false; |
| for (BindingList* bl : ep_bindings) { |
| found_conflict |= ResolveConflicts(*bl); |
| } |
| modified |= found_conflict; |
| } while (found_conflict); |
| |
| return modified ? Pass::Status::SuccessWithChange |
| : Pass::Status::SuccessWithoutChange; |
| } |
| |
| } // namespace opt |
| } // namespace spvtools |