| // Copyright (c) 2022 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "source/opt/interface_var_sroa.h" |
| |
| #include <iostream> |
| |
| #include "source/opt/decoration_manager.h" |
| #include "source/opt/def_use_manager.h" |
| #include "source/opt/function.h" |
| #include "source/opt/log.h" |
| #include "source/opt/type_manager.h" |
| #include "source/util/make_unique.h" |
| |
// In-operand indices (as passed to Instruction::GetSingleWordInOperand) of the
// instruction operands this pass reads, grouped by instruction.

// OpDecorate: the decoration kind and its first literal argument.
static constexpr uint32_t kOpDecorateDecorationInOperandIndex = 1;
static constexpr uint32_t kOpDecorateLiteralInOperandIndex = 2;
// OpEntryPoint: first interface variable id.
static constexpr uint32_t kOpEntryPointInOperandInterface = 3;
// OpVariable: storage class.
static constexpr uint32_t kOpVariableStorageClassInOperandIndex = 0;
// OpTypeArray: element type id and length constant id.
static constexpr uint32_t kOpTypeArrayElemTypeInOperandIndex = 0;
static constexpr uint32_t kOpTypeArrayLengthInOperandIndex = 1;
// OpTypeMatrix: column count and column type id.
static constexpr uint32_t kOpTypeMatrixColCountInOperandIndex = 1;
static constexpr uint32_t kOpTypeMatrixColTypeInOperandIndex = 0;
// OpTypePointer: pointee type id.
static constexpr uint32_t kOpTypePtrTypeInOperandIndex = 1;
// OpConstant: first value word.
static constexpr uint32_t kOpConstantValueInOperandIndex = 0;
| |
| namespace spvtools { |
| namespace opt { |
| namespace { |
| |
| // Get the length of the OpTypeArray |array_type|. |
| uint32_t GetArrayLength(analysis::DefUseManager* def_use_mgr, |
| Instruction* array_type) { |
| assert(array_type->opcode() == SpvOpTypeArray); |
| uint32_t const_int_id = |
| array_type->GetSingleWordInOperand(kOpTypeArrayLengthInOperandIndex); |
| Instruction* array_length_inst = def_use_mgr->GetDef(const_int_id); |
| assert(array_length_inst->opcode() == SpvOpConstant); |
| return array_length_inst->GetSingleWordInOperand( |
| kOpConstantValueInOperandIndex); |
| } |
| |
| // Get the element type instruction of the OpTypeArray |array_type|. |
| Instruction* GetArrayElementType(analysis::DefUseManager* def_use_mgr, |
| Instruction* array_type) { |
| assert(array_type->opcode() == SpvOpTypeArray); |
| uint32_t elem_type_id = |
| array_type->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex); |
| return def_use_mgr->GetDef(elem_type_id); |
| } |
| |
| // Get the column type instruction of the OpTypeMatrix |matrix_type|. |
| Instruction* GetMatrixColumnType(analysis::DefUseManager* def_use_mgr, |
| Instruction* matrix_type) { |
| assert(matrix_type->opcode() == SpvOpTypeMatrix); |
| uint32_t column_type_id = |
| matrix_type->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex); |
| return def_use_mgr->GetDef(column_type_id); |
| } |
| |
| // Traverses the component type of OpTypeArray or OpTypeMatrix. Repeats it |
| // |depth_to_component| times recursively and returns the component type. |
| // |type_id| is the result id of the OpTypeArray or OpTypeMatrix instruction. |
| uint32_t GetComponentTypeOfArrayMatrix(analysis::DefUseManager* def_use_mgr, |
| uint32_t type_id, |
| uint32_t depth_to_component) { |
| if (depth_to_component == 0) return type_id; |
| |
| Instruction* type_inst = def_use_mgr->GetDef(type_id); |
| if (type_inst->opcode() == SpvOpTypeArray) { |
| uint32_t elem_type_id = |
| type_inst->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex); |
| return GetComponentTypeOfArrayMatrix(def_use_mgr, elem_type_id, |
| depth_to_component - 1); |
| } |
| |
| assert(type_inst->opcode() == SpvOpTypeMatrix); |
| uint32_t column_type_id = |
| type_inst->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex); |
| return GetComponentTypeOfArrayMatrix(def_use_mgr, column_type_id, |
| depth_to_component - 1); |
| } |
| |
| // Creates an OpDecorate instruction whose Target is |var_id| and Decoration is |
| // |decoration|. Adds |literal| as an extra operand of the instruction. |
| void CreateDecoration(analysis::DecorationManager* decoration_mgr, |
| uint32_t var_id, SpvDecoration decoration, |
| uint32_t literal) { |
| std::vector<Operand> operands({ |
| {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}}, |
| {spv_operand_type_t::SPV_OPERAND_TYPE_DECORATION, |
| {static_cast<uint32_t>(decoration)}}, |
| {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal}}, |
| }); |
| decoration_mgr->AddDecoration(SpvOpDecorate, std::move(operands)); |
| } |
| |
| // Replaces load instructions with composite construct instructions in all the |
| // users of the loads. |loads_to_composites| is the mapping from each load to |
| // its corresponding OpCompositeConstruct. |
| void ReplaceLoadWithCompositeConstruct( |
| IRContext* context, |
| const std::unordered_map<Instruction*, Instruction*>& loads_to_composites) { |
| for (const auto& load_and_composite : loads_to_composites) { |
| Instruction* load = load_and_composite.first; |
| Instruction* composite_construct = load_and_composite.second; |
| |
| std::vector<Instruction*> users; |
| context->get_def_use_mgr()->ForEachUse( |
| load, [&users, composite_construct](Instruction* user, uint32_t index) { |
| user->GetOperand(index).words[0] = composite_construct->result_id(); |
| users.push_back(user); |
| }); |
| |
| for (Instruction* user : users) |
| context->get_def_use_mgr()->AnalyzeInstUse(user); |
| } |
| } |
| |
| // Returns the storage class of the instruction |var|. |
| SpvStorageClass GetStorageClass(Instruction* var) { |
| return static_cast<SpvStorageClass>( |
| var->GetSingleWordInOperand(kOpVariableStorageClassInOperandIndex)); |
| } |
| |
| } // namespace |
| |
| bool InterfaceVariableScalarReplacement::HasExtraArrayness( |
| Instruction& entry_point, Instruction* var) { |
| SpvExecutionModel execution_model = |
| static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0)); |
| if (execution_model != SpvExecutionModelTessellationEvaluation && |
| execution_model != SpvExecutionModelTessellationControl) { |
| return false; |
| } |
| if (!context()->get_decoration_mgr()->HasDecoration(var->result_id(), |
| SpvDecorationPatch)) { |
| if (execution_model == SpvExecutionModelTessellationControl) return true; |
| return GetStorageClass(var) != SpvStorageClassOutput; |
| } |
| return false; |
| } |
| |
| bool InterfaceVariableScalarReplacement:: |
| CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var, |
| bool has_extra_arrayness) { |
| if (has_extra_arrayness) { |
| return !ReportErrorIfHasNoExtraArraynessForOtherEntry(interface_var); |
| } |
| return !ReportErrorIfHasExtraArraynessForOtherEntry(interface_var); |
| } |
| |
| bool InterfaceVariableScalarReplacement::GetVariableLocation( |
| Instruction* var, uint32_t* location) { |
| return !context()->get_decoration_mgr()->WhileEachDecoration( |
| var->result_id(), SpvDecorationLocation, |
| [location](const Instruction& inst) { |
| *location = |
| inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex); |
| return false; |
| }); |
| } |
| |
| bool InterfaceVariableScalarReplacement::GetVariableComponent( |
| Instruction* var, uint32_t* component) { |
| return !context()->get_decoration_mgr()->WhileEachDecoration( |
| var->result_id(), SpvDecorationComponent, |
| [component](const Instruction& inst) { |
| *component = |
| inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex); |
| return false; |
| }); |
| } |
| |
| std::vector<Instruction*> |
| InterfaceVariableScalarReplacement::CollectInterfaceVariables( |
| Instruction& entry_point) { |
| std::vector<Instruction*> interface_vars; |
| for (uint32_t i = kOpEntryPointInOperandInterface; |
| i < entry_point.NumInOperands(); ++i) { |
| Instruction* interface_var = context()->get_def_use_mgr()->GetDef( |
| entry_point.GetSingleWordInOperand(i)); |
| assert(interface_var->opcode() == SpvOpVariable); |
| |
| SpvStorageClass storage_class = GetStorageClass(interface_var); |
| if (storage_class != SpvStorageClassInput && |
| storage_class != SpvStorageClassOutput) { |
| continue; |
| } |
| |
| interface_vars.push_back(interface_var); |
| } |
| return interface_vars; |
| } |
| |
| void InterfaceVariableScalarReplacement::KillInstructionAndUsers( |
| Instruction* inst) { |
| if (inst->opcode() == SpvOpEntryPoint) { |
| return; |
| } |
| if (inst->opcode() != SpvOpAccessChain) { |
| context()->KillInst(inst); |
| return; |
| } |
| context()->get_def_use_mgr()->ForEachUser( |
| inst, [this](Instruction* user) { KillInstructionAndUsers(user); }); |
| context()->KillInst(inst); |
| } |
| |
| void InterfaceVariableScalarReplacement::KillInstructionsAndUsers( |
| const std::vector<Instruction*>& insts) { |
| for (Instruction* inst : insts) { |
| KillInstructionAndUsers(inst); |
| } |
| } |
| |
| void InterfaceVariableScalarReplacement::KillLocationAndComponentDecorations( |
| uint32_t var_id) { |
| context()->get_decoration_mgr()->RemoveDecorationsFrom( |
| var_id, [](const Instruction& inst) { |
| uint32_t decoration = |
| inst.GetSingleWordInOperand(kOpDecorateDecorationInOperandIndex); |
| return decoration == SpvDecorationLocation || |
| decoration == SpvDecorationComponent; |
| }); |
| } |
| |
| bool InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars( |
| Instruction* interface_var, Instruction* interface_var_type, |
| uint32_t location, uint32_t component, uint32_t extra_array_length) { |
| NestedCompositeComponents scalar_interface_vars = |
| CreateScalarInterfaceVarsForReplacement(interface_var_type, |
| GetStorageClass(interface_var), |
| extra_array_length); |
| |
| AddLocationAndComponentDecorations(scalar_interface_vars, &location, |
| component); |
| KillLocationAndComponentDecorations(interface_var->result_id()); |
| |
| if (!ReplaceInterfaceVarWith(interface_var, extra_array_length, |
| scalar_interface_vars)) { |
| return false; |
| } |
| |
| context()->KillInst(interface_var); |
| return true; |
| } |
| |
// Rewrites every user of |interface_var| in terms of the scalar variables in
// |scalar_interface_vars|, then kills those users. OpEntryPoint users are not
// killed; their operands are rewritten instead. Returns false if some user
// cannot be handled (the failing callee reports the error).
//
// If |extra_array_length| is non-zero, |interface_var| has an extra outermost
// array dimension of that length, and the rewrite is repeated once per index
// of that dimension.
bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
    Instruction* interface_var, uint32_t extra_array_length,
    const NestedCompositeComponents& scalar_interface_vars) {
  // Snapshot the users up front: new instructions are inserted below and the
  // users themselves are killed at the end.
  std::vector<Instruction*> users;
  context()->get_def_use_mgr()->ForEachUser(
      interface_var, [&users](Instruction* user) { users.push_back(user); });

  // Component path from |interface_var| down to the scalar currently being
  // processed; the recursive helpers push/pop entries on it.
  std::vector<uint32_t> interface_var_component_indices;
  // Original load -> OpCompositeConstruct that rebuilds the loaded value.
  std::unordered_map<Instruction*, Instruction*> loads_to_composites;
  // The same mapping for loads made through an OpAccessChain of the variable.
  std::unordered_map<Instruction*, Instruction*>
      loads_for_access_chain_to_composites;
  if (extra_array_length != 0) {
    // Note that the extra arrayness is the first dimension of the array
    // interface variable.
    for (uint32_t index = 0; index < extra_array_length; ++index) {
      // Values produced for this extra-array element. They are appended as
      // operands of the depth-0 composite of each original load.
      std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
      if (!ReplaceComponentsOfInterfaceVarWith(
              interface_var, users, scalar_interface_vars,
              interface_var_component_indices, &index,
              &loads_to_component_values,
              &loads_for_access_chain_to_composites)) {
        return false;
      }
      AddComponentsToCompositesForLoads(loads_to_component_values,
                                        &loads_to_composites, 0);
    }
  } else if (!ReplaceComponentsOfInterfaceVarWith(
                 interface_var, users, scalar_interface_vars,
                 interface_var_component_indices, nullptr, &loads_to_composites,
                 &loads_for_access_chain_to_composites)) {
    return false;
  }

  // Redirect every use of an original load to its reconstructed composite.
  ReplaceLoadWithCompositeConstruct(context(), loads_to_composites);
  ReplaceLoadWithCompositeConstruct(context(),
                                    loads_for_access_chain_to_composites);

  KillInstructionsAndUsers(users);
  return true;
}
| |
| void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations( |
| const NestedCompositeComponents& vars, uint32_t* location, |
| uint32_t component) { |
| if (!vars.HasMultipleComponents()) { |
| uint32_t var_id = vars.GetComponentVariable()->result_id(); |
| CreateDecoration(context()->get_decoration_mgr(), var_id, |
| SpvDecorationLocation, *location); |
| CreateDecoration(context()->get_decoration_mgr(), var_id, |
| SpvDecorationComponent, component); |
| ++(*location); |
| return; |
| } |
| for (const auto& var : vars.GetComponents()) { |
| AddLocationAndComponentDecorations(var, location, component); |
| } |
| } |
| |
| bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith( |
| Instruction* interface_var, |
| const std::vector<Instruction*>& interface_var_users, |
| const NestedCompositeComponents& scalar_interface_vars, |
| std::vector<uint32_t>& interface_var_component_indices, |
| const uint32_t* extra_array_index, |
| std::unordered_map<Instruction*, Instruction*>* loads_to_composites, |
| std::unordered_map<Instruction*, Instruction*>* |
| loads_for_access_chain_to_composites) { |
| if (!scalar_interface_vars.HasMultipleComponents()) { |
| for (Instruction* interface_var_user : interface_var_users) { |
| if (!ReplaceComponentOfInterfaceVarWith( |
| interface_var, interface_var_user, |
| scalar_interface_vars.GetComponentVariable(), |
| interface_var_component_indices, extra_array_index, |
| loads_to_composites, loads_for_access_chain_to_composites)) { |
| return false; |
| } |
| } |
| return true; |
| } |
| return ReplaceMultipleComponentsOfInterfaceVarWith( |
| interface_var, interface_var_users, scalar_interface_vars.GetComponents(), |
| interface_var_component_indices, extra_array_index, loads_to_composites, |
| loads_for_access_chain_to_composites); |
| } |
| |
// Handles one aggregate level of the replacement. For each child i of
// |components|, this pushes i onto |interface_var_component_indices|, rewrites
// the users for that child, and pops the index again. It then appends the
// child's load results to the composites that rebuild the loaded value at this
// level. Returns false as soon as a child cannot be handled; in that case the
// index pushed for that child is not popped.
bool InterfaceVariableScalarReplacement::
    ReplaceMultipleComponentsOfInterfaceVarWith(
        Instruction* interface_var,
        const std::vector<Instruction*>& interface_var_users,
        const std::vector<NestedCompositeComponents>& components,
        std::vector<uint32_t>& interface_var_component_indices,
        const uint32_t* extra_array_index,
        std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
        std::unordered_map<Instruction*, Instruction*>*
            loads_for_access_chain_to_composites) {
  for (uint32_t i = 0; i < components.size(); ++i) {
    interface_var_component_indices.push_back(i);
    // Values (or sub-composites) produced for child |i| only.
    std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
    std::unordered_map<Instruction*, Instruction*>
        loads_for_access_chain_to_component_values;
    if (!ReplaceComponentsOfInterfaceVarWith(
            interface_var, interface_var_users, components[i],
            interface_var_component_indices, extra_array_index,
            &loads_to_component_values,
            &loads_for_access_chain_to_component_values)) {
      return false;
    }
    interface_var_component_indices.pop_back();

    // Depth of the composite at this level = number of indices above it.
    uint32_t depth_to_component =
        static_cast<uint32_t>(interface_var_component_indices.size());
    AddComponentsToCompositesForLoads(
        loads_for_access_chain_to_component_values,
        loads_for_access_chain_to_composites, depth_to_component);
    // Direct loads of the variable still include the extra array dimension
    // (if any) as their outermost level, so their composite sits one level
    // deeper than the one used for loads through an access chain.
    if (extra_array_index) ++depth_to_component;
    AddComponentsToCompositesForLoads(loads_to_component_values,
                                      loads_to_composites, depth_to_component);
  }
  return true;
}
| |
// Rewrites one user |interface_var_user| of |interface_var| for |scalar_var|,
// the scalar that replaces the component selected by
// |interface_var_component_indices| (and by |*extra_array_index| when the
// variable has extra arrayness):
//   - OpStore: the matching component of the stored value is stored to
//     |scalar_var|.
//   - OpLoad: |scalar_var| is loaded and the value is recorded in
//     |loads_to_component_values|.
//   - OpDecorate/OpDecorateId/OpDecorateString and OpName: cloned onto
//     |scalar_var|.
//   - OpEntryPoint: |scalar_var| is put into the interface list.
//   - OpAccessChain: its users are rewritten; loads through it are recorded
//     in |loads_for_access_chain_to_component_values|.
// Every opcode other than OpStore/OpLoad is processed only when there is no
// extra array or when |*extra_array_index| is 0. Otherwise it returns true
// without doing anything. An unhandled opcode is reported as an error and
// false is returned.
bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
    Instruction* interface_var, Instruction* interface_var_user,
    Instruction* scalar_var,
    const std::vector<uint32_t>& interface_var_component_indices,
    const uint32_t* extra_array_index,
    std::unordered_map<Instruction*, Instruction*>* loads_to_component_values,
    std::unordered_map<Instruction*, Instruction*>*
        loads_for_access_chain_to_component_values) {
  SpvOp opcode = interface_var_user->opcode();
  if (opcode == SpvOpStore) {
    // In-operand 1 of OpStore is the stored object.
    uint32_t value_id = interface_var_user->GetSingleWordInOperand(1);
    StoreComponentOfValueToScalarVar(value_id, interface_var_component_indices,
                                     scalar_var, extra_array_index,
                                     interface_var_user);
    return true;
  }
  if (opcode == SpvOpLoad) {
    Instruction* scalar_load =
        LoadScalarVar(scalar_var, extra_array_index, interface_var_user);
    loads_to_component_values->insert({interface_var_user, scalar_load});
    return true;
  }

  // Copy OpName and annotation instructions only once. Therefore, we create
  // them only for the first element of the extra array.
  if (extra_array_index && *extra_array_index != 0) return true;

  if (opcode == SpvOpDecorateId || opcode == SpvOpDecorateString ||
      opcode == SpvOpDecorate) {
    CloneAnnotationForVariable(interface_var_user, scalar_var->result_id());
    return true;
  }

  if (opcode == SpvOpName) {
    // Same name, retargeted (in-operand 0) at the scalar variable.
    std::unique_ptr<Instruction> new_inst(interface_var_user->Clone(context()));
    new_inst->SetInOperand(0, {scalar_var->result_id()});
    context()->AddDebug2Inst(std::move(new_inst));
    return true;
  }

  if (opcode == SpvOpEntryPoint) {
    return ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user,
                                           scalar_var->result_id());
  }

  if (opcode == SpvOpAccessChain) {
    ReplaceAccessChainWith(interface_var_user, interface_var_component_indices,
                           scalar_var,
                           loads_for_access_chain_to_component_values);
    return true;
  }

  std::string message("Unhandled instruction");
  message += "\n " + interface_var_user->PrettyPrint(
                         SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  message +=
      "\nfor interface variable scalar replacement\n " +
      interface_var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
  return false;
}
| |
| void InterfaceVariableScalarReplacement::UseBaseAccessChainForAccessChain( |
| Instruction* access_chain, Instruction* base_access_chain) { |
| assert(base_access_chain->opcode() == SpvOpAccessChain && |
| access_chain->opcode() == SpvOpAccessChain && |
| access_chain->GetSingleWordInOperand(0) == |
| base_access_chain->result_id()); |
| Instruction::OperandList new_operands; |
| for (uint32_t i = 0; i < base_access_chain->NumInOperands(); ++i) { |
| new_operands.emplace_back(base_access_chain->GetInOperand(i)); |
| } |
| for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) { |
| new_operands.emplace_back(access_chain->GetInOperand(i)); |
| } |
| access_chain->SetInOperands(std::move(new_operands)); |
| } |
| |
| Instruction* InterfaceVariableScalarReplacement::CreateAccessChainToVar( |
| uint32_t var_type_id, Instruction* var, |
| const std::vector<uint32_t>& index_ids, Instruction* insert_before, |
| uint32_t* component_type_id) { |
| analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| *component_type_id = GetComponentTypeOfArrayMatrix( |
| def_use_mgr, var_type_id, static_cast<uint32_t>(index_ids.size())); |
| |
| uint32_t ptr_type_id = |
| GetPointerType(*component_type_id, GetStorageClass(var)); |
| |
| std::unique_ptr<Instruction> new_access_chain( |
| new Instruction(context(), SpvOpAccessChain, ptr_type_id, TakeNextId(), |
| std::initializer_list<Operand>{ |
| {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); |
| for (uint32_t index_id : index_ids) { |
| new_access_chain->AddOperand({SPV_OPERAND_TYPE_ID, {index_id}}); |
| } |
| |
| Instruction* inst = new_access_chain.get(); |
| def_use_mgr->AnalyzeInstDefUse(inst); |
| insert_before->InsertBefore(std::move(new_access_chain)); |
| return inst; |
| } |
| |
| Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex( |
| uint32_t component_type_id, Instruction* var, uint32_t index, |
| Instruction* insert_before) { |
| uint32_t ptr_type_id = |
| GetPointerType(component_type_id, GetStorageClass(var)); |
| uint32_t index_id = context()->get_constant_mgr()->GetUIntConst(index); |
| std::unique_ptr<Instruction> new_access_chain( |
| new Instruction(context(), SpvOpAccessChain, ptr_type_id, TakeNextId(), |
| std::initializer_list<Operand>{ |
| {SPV_OPERAND_TYPE_ID, {var->result_id()}}, |
| {SPV_OPERAND_TYPE_ID, {index_id}}, |
| })); |
| Instruction* inst = new_access_chain.get(); |
| context()->get_def_use_mgr()->AnalyzeInstDefUse(inst); |
| insert_before->InsertBefore(std::move(new_access_chain)); |
| return inst; |
| } |
| |
// Rewrites the users of |access_chain| (an OpAccessChain into the interface
// variable) so that they operate on |scalar_var| instead:
//   - nested OpAccessChain users are first folded onto |access_chain|'s base
//     and then handled recursively;
//   - OpStore users store the component selected by
//     |interface_var_component_indices| through a new access chain on
//     |scalar_var| built from the same index ids;
//   - OpLoad users get a matching load from |scalar_var|, recorded in
//     |loads_to_component_values|.
// Other users are left untouched here.
void InterfaceVariableScalarReplacement::ReplaceAccessChainWith(
    Instruction* access_chain,
    const std::vector<uint32_t>& interface_var_component_indices,
    Instruction* scalar_var,
    std::unordered_map<Instruction*, Instruction*>* loads_to_component_values) {
  // Index ids of |access_chain| (in-operand 0 is the base pointer).
  std::vector<uint32_t> indexes;
  for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
    indexes.push_back(access_chain->GetSingleWordInOperand(i));
  }

  // Note that we have a strong assumption that |access_chain| has only a single
  // index that is for the extra arrayness.
  context()->get_def_use_mgr()->ForEachUser(
      access_chain,
      [this, access_chain, &indexes, &interface_var_component_indices,
       scalar_var, loads_to_component_values](Instruction* user) {
        switch (user->opcode()) {
          case SpvOpAccessChain: {
            UseBaseAccessChainForAccessChain(user, access_chain);
            ReplaceAccessChainWith(user, interface_var_component_indices,
                                   scalar_var, loads_to_component_values);
            return;
          }
          case SpvOpStore: {
            // In-operand 1 of OpStore is the stored object.
            uint32_t value_id = user->GetSingleWordInOperand(1);
            StoreComponentOfValueToAccessChainToScalarVar(
                value_id, interface_var_component_indices, scalar_var, indexes,
                user);
            return;
          }
          case SpvOpLoad: {
            Instruction* value =
                LoadAccessChainToVar(scalar_var, indexes, user);
            loads_to_component_values->insert({user, value});
            return;
          }
          default:
            break;
        }
      });
}
| |
| void InterfaceVariableScalarReplacement::CloneAnnotationForVariable( |
| Instruction* annotation_inst, uint32_t var_id) { |
| assert(annotation_inst->opcode() == SpvOpDecorate || |
| annotation_inst->opcode() == SpvOpDecorateId || |
| annotation_inst->opcode() == SpvOpDecorateString); |
| std::unique_ptr<Instruction> new_inst(annotation_inst->Clone(context())); |
| new_inst->SetInOperand(0, {var_id}); |
| context()->AddAnnotationInst(std::move(new_inst)); |
| } |
| |
| bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarInEntryPoint( |
| Instruction* interface_var, Instruction* entry_point, |
| uint32_t scalar_var_id) { |
| analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| uint32_t interface_var_id = interface_var->result_id(); |
| if (interface_vars_removed_from_entry_point_operands_.find( |
| interface_var_id) != |
| interface_vars_removed_from_entry_point_operands_.end()) { |
| entry_point->AddOperand({SPV_OPERAND_TYPE_ID, {scalar_var_id}}); |
| def_use_mgr->AnalyzeInstUse(entry_point); |
| return true; |
| } |
| |
| bool success = !entry_point->WhileEachInId( |
| [&interface_var_id, &scalar_var_id](uint32_t* id) { |
| if (*id == interface_var_id) { |
| *id = scalar_var_id; |
| return false; |
| } |
| return true; |
| }); |
| if (!success) { |
| std::string message( |
| "interface variable is not an operand of the entry point"); |
| message += "\n " + interface_var->PrettyPrint( |
| SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); |
| message += "\n " + entry_point->PrettyPrint( |
| SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); |
| context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); |
| return false; |
| } |
| |
| def_use_mgr->AnalyzeInstUse(entry_point); |
| interface_vars_removed_from_entry_point_operands_.insert(interface_var_id); |
| return true; |
| } |
| |
| uint32_t InterfaceVariableScalarReplacement::GetPointeeTypeIdOfVar( |
| Instruction* var) { |
| assert(var->opcode() == SpvOpVariable); |
| |
| uint32_t ptr_type_id = var->type_id(); |
| analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| Instruction* ptr_type_inst = def_use_mgr->GetDef(ptr_type_id); |
| |
| assert(ptr_type_inst->opcode() == SpvOpTypePointer && |
| "Variable must have a pointer type."); |
| return ptr_type_inst->GetSingleWordInOperand(kOpTypePtrTypeInOperandIndex); |
| } |
| |
| void InterfaceVariableScalarReplacement::StoreComponentOfValueToScalarVar( |
| uint32_t value_id, const std::vector<uint32_t>& component_indices, |
| Instruction* scalar_var, const uint32_t* extra_array_index, |
| Instruction* insert_before) { |
| uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var); |
| Instruction* ptr = scalar_var; |
| if (extra_array_index) { |
| auto* ty_mgr = context()->get_type_mgr(); |
| analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray(); |
| assert(array_type != nullptr); |
| component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type()); |
| ptr = CreateAccessChainWithIndex(component_type_id, scalar_var, |
| *extra_array_index, insert_before); |
| } |
| |
| StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr, |
| extra_array_index, insert_before); |
| } |
| |
| Instruction* InterfaceVariableScalarReplacement::LoadScalarVar( |
| Instruction* scalar_var, const uint32_t* extra_array_index, |
| Instruction* insert_before) { |
| uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var); |
| Instruction* ptr = scalar_var; |
| if (extra_array_index) { |
| auto* ty_mgr = context()->get_type_mgr(); |
| analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray(); |
| assert(array_type != nullptr); |
| component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type()); |
| ptr = CreateAccessChainWithIndex(component_type_id, scalar_var, |
| *extra_array_index, insert_before); |
| } |
| |
| return CreateLoad(component_type_id, ptr, insert_before); |
| } |
| |
| Instruction* InterfaceVariableScalarReplacement::CreateLoad( |
| uint32_t type_id, Instruction* ptr, Instruction* insert_before) { |
| std::unique_ptr<Instruction> load( |
| new Instruction(context(), SpvOpLoad, type_id, TakeNextId(), |
| std::initializer_list<Operand>{ |
| {SPV_OPERAND_TYPE_ID, {ptr->result_id()}}})); |
| Instruction* load_inst = load.get(); |
| context()->get_def_use_mgr()->AnalyzeInstDefUse(load_inst); |
| insert_before->InsertBefore(std::move(load)); |
| return load_inst; |
| } |
| |
| void InterfaceVariableScalarReplacement::StoreComponentOfValueTo( |
| uint32_t component_type_id, uint32_t value_id, |
| const std::vector<uint32_t>& component_indices, Instruction* ptr, |
| const uint32_t* extra_array_index, Instruction* insert_before) { |
| std::unique_ptr<Instruction> composite_extract(CreateCompositeExtract( |
| component_type_id, value_id, component_indices, extra_array_index)); |
| |
| std::unique_ptr<Instruction> new_store( |
| new Instruction(context(), SpvOpStore)); |
| new_store->AddOperand({SPV_OPERAND_TYPE_ID, {ptr->result_id()}}); |
| new_store->AddOperand( |
| {SPV_OPERAND_TYPE_ID, {composite_extract->result_id()}}); |
| |
| analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| def_use_mgr->AnalyzeInstDefUse(composite_extract.get()); |
| def_use_mgr->AnalyzeInstDefUse(new_store.get()); |
| |
| insert_before->InsertBefore(std::move(composite_extract)); |
| insert_before->InsertBefore(std::move(new_store)); |
| } |
| |
| Instruction* InterfaceVariableScalarReplacement::CreateCompositeExtract( |
| uint32_t type_id, uint32_t composite_id, |
| const std::vector<uint32_t>& indexes, const uint32_t* extra_first_index) { |
| uint32_t component_id = TakeNextId(); |
| Instruction* composite_extract = new Instruction( |
| context(), SpvOpCompositeExtract, type_id, component_id, |
| std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {composite_id}}}); |
| if (extra_first_index) { |
| composite_extract->AddOperand( |
| {SPV_OPERAND_TYPE_LITERAL_INTEGER, {*extra_first_index}}); |
| } |
| for (uint32_t index : indexes) { |
| composite_extract->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}); |
| } |
| return composite_extract; |
| } |
| |
| void InterfaceVariableScalarReplacement:: |
| StoreComponentOfValueToAccessChainToScalarVar( |
| uint32_t value_id, const std::vector<uint32_t>& component_indices, |
| Instruction* scalar_var, |
| const std::vector<uint32_t>& access_chain_indices, |
| Instruction* insert_before) { |
| uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var); |
| Instruction* ptr = scalar_var; |
| if (!access_chain_indices.empty()) { |
| ptr = CreateAccessChainToVar(component_type_id, scalar_var, |
| access_chain_indices, insert_before, |
| &component_type_id); |
| } |
| |
| StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr, |
| nullptr, insert_before); |
| } |
| |
| Instruction* InterfaceVariableScalarReplacement::LoadAccessChainToVar( |
| Instruction* var, const std::vector<uint32_t>& indexes, |
| Instruction* insert_before) { |
| uint32_t component_type_id = GetPointeeTypeIdOfVar(var); |
| Instruction* ptr = var; |
| if (!indexes.empty()) { |
| ptr = CreateAccessChainToVar(component_type_id, var, indexes, insert_before, |
| &component_type_id); |
| } |
| |
| return CreateLoad(component_type_id, ptr, insert_before); |
| } |
| |
// Creates an OpCompositeConstruct with no operands yet. It will rebuild the
// part of |load|'s value that sits |depth_to_component| array/matrix levels
// below the loaded type; its result type is that component type (the load's
// own type at depth 0). The instruction is inserted after |load|, and its id
// and depth are recorded in |composite_ids_to_component_depths| so that later
// inserts are ordered correctly. Operands are appended afterwards by
// AddComponentsToCompositesForLoads().
Instruction*
InterfaceVariableScalarReplacement::CreateCompositeConstructForComponentOfLoad(
    Instruction* load, uint32_t depth_to_component) {
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  uint32_t type_id = load->type_id();
  if (depth_to_component != 0) {
    type_id = GetComponentTypeOfArrayMatrix(def_use_mgr, load->type_id(),
                                            depth_to_component);
  }
  uint32_t new_id = context()->TakeNextId();
  std::unique_ptr<Instruction> new_composite_construct(
      new Instruction(context(), SpvOpCompositeConstruct, type_id, new_id, {}));
  Instruction* composite_construct = new_composite_construct.get();
  def_use_mgr->AnalyzeInstDefUse(composite_construct);

  // Insert |new_composite_construct| after |load|. When there are multiple
  // recursive composite construct instructions for a load, we have to place the
  // composite construct with a lower depth later because it constructs the
  // composite that contains other composites with lower depths.
  auto* insert_before = load->NextNode();
  // Skip over already-placed composites that are deeper than this one; stop at
  // the first instruction that is not such a composite.
  while (true) {
    auto itr =
        composite_ids_to_component_depths.find(insert_before->result_id());
    if (itr == composite_ids_to_component_depths.end()) break;
    if (itr->second <= depth_to_component) break;
    insert_before = insert_before->NextNode();
  }
  insert_before->InsertBefore(std::move(new_composite_construct));
  composite_ids_to_component_depths.insert({new_id, depth_to_component});
  return composite_construct;
}
| |
| void InterfaceVariableScalarReplacement::AddComponentsToCompositesForLoads( |
| const std::unordered_map<Instruction*, Instruction*>& |
| loads_to_component_values, |
| std::unordered_map<Instruction*, Instruction*>* loads_to_composites, |
| uint32_t depth_to_component) { |
| analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| for (auto& load_and_component_vale : loads_to_component_values) { |
| Instruction* load = load_and_component_vale.first; |
| Instruction* component_value = load_and_component_vale.second; |
| Instruction* composite_construct = nullptr; |
| auto itr = loads_to_composites->find(load); |
| if (itr == loads_to_composites->end()) { |
| composite_construct = |
| CreateCompositeConstructForComponentOfLoad(load, depth_to_component); |
| loads_to_composites->insert({load, composite_construct}); |
| } else { |
| composite_construct = itr->second; |
| } |
| composite_construct->AddOperand( |
| {SPV_OPERAND_TYPE_ID, {component_value->result_id()}}); |
| def_use_mgr->AnalyzeInstDefUse(composite_construct); |
| } |
| } |
| |
| uint32_t InterfaceVariableScalarReplacement::GetArrayType( |
| uint32_t elem_type_id, uint32_t array_length) { |
| analysis::Type* elem_type = context()->get_type_mgr()->GetType(elem_type_id); |
| uint32_t array_length_id = |
| context()->get_constant_mgr()->GetUIntConst(array_length); |
| analysis::Array array_type( |
| elem_type, |
| analysis::Array::LengthInfo{array_length_id, {0, array_length}}); |
| return context()->get_type_mgr()->GetTypeInstruction(&array_type); |
| } |
| |
| uint32_t InterfaceVariableScalarReplacement::GetPointerType( |
| uint32_t type_id, SpvStorageClass storage_class) { |
| analysis::Type* type = context()->get_type_mgr()->GetType(type_id); |
| analysis::Pointer ptr_type(type, storage_class); |
| return context()->get_type_mgr()->GetTypeInstruction(&ptr_type); |
| } |
| |
| InterfaceVariableScalarReplacement::NestedCompositeComponents |
| InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray( |
| Instruction* interface_var_type, SpvStorageClass storage_class, |
| uint32_t extra_array_length) { |
| assert(interface_var_type->opcode() == SpvOpTypeArray); |
| |
| analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| uint32_t array_length = GetArrayLength(def_use_mgr, interface_var_type); |
| Instruction* elem_type = GetArrayElementType(def_use_mgr, interface_var_type); |
| |
| NestedCompositeComponents scalar_vars; |
| while (array_length > 0) { |
| NestedCompositeComponents scalar_vars_for_element = |
| CreateScalarInterfaceVarsForReplacement(elem_type, storage_class, |
| extra_array_length); |
| scalar_vars.AddComponent(scalar_vars_for_element); |
| --array_length; |
| } |
| return scalar_vars; |
| } |
| |
| InterfaceVariableScalarReplacement::NestedCompositeComponents |
| InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix( |
| Instruction* interface_var_type, SpvStorageClass storage_class, |
| uint32_t extra_array_length) { |
| assert(interface_var_type->opcode() == SpvOpTypeMatrix); |
| |
| analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| uint32_t column_count = interface_var_type->GetSingleWordInOperand( |
| kOpTypeMatrixColCountInOperandIndex); |
| Instruction* column_type = |
| GetMatrixColumnType(def_use_mgr, interface_var_type); |
| |
| NestedCompositeComponents scalar_vars; |
| while (column_count > 0) { |
| NestedCompositeComponents scalar_vars_for_column = |
| CreateScalarInterfaceVarsForReplacement(column_type, storage_class, |
| extra_array_length); |
| scalar_vars.AddComponent(scalar_vars_for_column); |
| --column_count; |
| } |
| return scalar_vars; |
| } |
| |
| InterfaceVariableScalarReplacement::NestedCompositeComponents |
| InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement( |
| Instruction* interface_var_type, SpvStorageClass storage_class, |
| uint32_t extra_array_length) { |
| // Handle array case. |
| if (interface_var_type->opcode() == SpvOpTypeArray) { |
| return CreateScalarInterfaceVarsForArray(interface_var_type, storage_class, |
| extra_array_length); |
| } |
| |
| // Handle matrix case. |
| if (interface_var_type->opcode() == SpvOpTypeMatrix) { |
| return CreateScalarInterfaceVarsForMatrix(interface_var_type, storage_class, |
| extra_array_length); |
| } |
| |
| // Handle scalar or vector case. |
| NestedCompositeComponents scalar_var; |
| uint32_t type_id = interface_var_type->result_id(); |
| if (extra_array_length != 0) { |
| type_id = GetArrayType(type_id, extra_array_length); |
| } |
| uint32_t ptr_type_id = |
| context()->get_type_mgr()->FindPointerToType(type_id, storage_class); |
| uint32_t id = TakeNextId(); |
| std::unique_ptr<Instruction> variable( |
| new Instruction(context(), SpvOpVariable, ptr_type_id, id, |
| std::initializer_list<Operand>{ |
| {SPV_OPERAND_TYPE_STORAGE_CLASS, |
| {static_cast<uint32_t>(storage_class)}}})); |
| scalar_var.SetSingleComponentVariable(variable.get()); |
| context()->AddGlobalValue(std::move(variable)); |
| return scalar_var; |
| } |
| |
| Instruction* InterfaceVariableScalarReplacement::GetTypeOfVariable( |
| Instruction* var) { |
| uint32_t pointee_type_id = GetPointeeTypeIdOfVar(var); |
| analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| return def_use_mgr->GetDef(pointee_type_id); |
| } |
| |
| Pass::Status InterfaceVariableScalarReplacement::Process() { |
| Pass::Status status = Status::SuccessWithoutChange; |
| for (Instruction& entry_point : get_module()->entry_points()) { |
| status = |
| CombineStatus(status, ReplaceInterfaceVarsWithScalars(entry_point)); |
| } |
| return status; |
| } |
| |
| bool InterfaceVariableScalarReplacement:: |
| ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var) { |
| if (vars_with_extra_arrayness.find(var) == vars_with_extra_arrayness.end()) |
| return false; |
| |
| std::string message( |
| "A variable is arrayed for an entry point but it is not " |
| "arrayed for another entry point"); |
| message += |
| "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); |
| context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); |
| return true; |
| } |
| |
| bool InterfaceVariableScalarReplacement:: |
| ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var) { |
| if (vars_without_extra_arrayness.find(var) == |
| vars_without_extra_arrayness.end()) |
| return false; |
| |
| std::string message( |
| "A variable is not arrayed for an entry point but it is " |
| "arrayed for another entry point"); |
| message += |
| "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); |
| context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); |
| return true; |
| } |
| |
| Pass::Status |
| InterfaceVariableScalarReplacement::ReplaceInterfaceVarsWithScalars( |
| Instruction& entry_point) { |
| std::vector<Instruction*> interface_vars = |
| CollectInterfaceVariables(entry_point); |
| |
| Pass::Status status = Status::SuccessWithoutChange; |
| for (Instruction* interface_var : interface_vars) { |
| uint32_t location, component; |
| if (!GetVariableLocation(interface_var, &location)) continue; |
| if (!GetVariableComponent(interface_var, &component)) component = 0; |
| |
| Instruction* interface_var_type = GetTypeOfVariable(interface_var); |
| uint32_t extra_array_length = 0; |
| if (HasExtraArrayness(entry_point, interface_var)) { |
| extra_array_length = |
| GetArrayLength(context()->get_def_use_mgr(), interface_var_type); |
| interface_var_type = |
| GetArrayElementType(context()->get_def_use_mgr(), interface_var_type); |
| vars_with_extra_arrayness.insert(interface_var); |
| } else { |
| vars_without_extra_arrayness.insert(interface_var); |
| } |
| |
| if (!CheckExtraArraynessConflictBetweenEntries(interface_var, |
| extra_array_length != 0)) { |
| return Pass::Status::Failure; |
| } |
| |
| if (interface_var_type->opcode() != SpvOpTypeArray && |
| interface_var_type->opcode() != SpvOpTypeMatrix) { |
| continue; |
| } |
| |
| if (!ReplaceInterfaceVariableWithScalars(interface_var, interface_var_type, |
| location, component, |
| extra_array_length)) { |
| return Pass::Status::Failure; |
| } |
| status = Pass::Status::SuccessWithChange; |
| } |
| |
| return status; |
| } |
| |
| } // namespace opt |
| } // namespace spvtools |