| // Copyright (c) 2018 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "source/opt/combine_access_chains.h" |
| |
| #include <utility> |
| |
| #include "source/opt/constants.h" |
| #include "source/opt/ir_builder.h" |
| #include "source/opt/ir_context.h" |
| |
| namespace spvtools { |
| namespace opt { |
| |
| Pass::Status CombineAccessChains::Process() { |
| bool modified = false; |
| |
| for (auto& function : *get_module()) { |
| modified |= ProcessFunction(function); |
| } |
| |
| return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); |
| } |
| |
| bool CombineAccessChains::ProcessFunction(Function& function) { |
| bool modified = false; |
| |
| cfg()->ForEachBlockInReversePostOrder( |
| function.entry().get(), [&modified, this](BasicBlock* block) { |
| block->ForEachInst([&modified, this](Instruction* inst) { |
| switch (inst->opcode()) { |
| case SpvOpAccessChain: |
| case SpvOpInBoundsAccessChain: |
| case SpvOpPtrAccessChain: |
| case SpvOpInBoundsPtrAccessChain: |
| modified |= CombineAccessChain(inst); |
| break; |
| default: |
| break; |
| } |
| }); |
| }); |
| |
| return modified; |
| } |
| |
| uint32_t CombineAccessChains::GetConstantValue( |
| const analysis::Constant* constant_inst) { |
| if (constant_inst->type()->AsInteger()->width() <= 32) { |
| if (constant_inst->type()->AsInteger()->IsSigned()) { |
| return static_cast<uint32_t>(constant_inst->GetS32()); |
| } else { |
| return constant_inst->GetU32(); |
| } |
| } else { |
| assert(false); |
| return 0u; |
| } |
| } |
| |
| uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) { |
| uint32_t array_stride = 0; |
| context()->get_decoration_mgr()->WhileEachDecoration( |
| inst->type_id(), SpvDecorationArrayStride, |
| [&array_stride](const Instruction& decoration) { |
| assert(decoration.opcode() != SpvOpDecorateId); |
| if (decoration.opcode() == SpvOpDecorate) { |
| array_stride = decoration.GetSingleWordInOperand(1); |
| } else { |
| array_stride = decoration.GetSingleWordInOperand(2); |
| } |
| return false; |
| }); |
| return array_stride; |
| } |
| |
| const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) { |
| analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
| |
| Instruction* base_ptr = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); |
| const analysis::Type* type = type_mgr->GetType(base_ptr->type_id()); |
| assert(type->AsPointer()); |
| type = type->AsPointer()->pointee_type(); |
| std::vector<uint32_t> element_indices; |
| uint32_t starting_index = 1; |
| if (IsPtrAccessChain(inst->opcode())) { |
| // Skip the first index of OpPtrAccessChain as it does not affect type |
| // resolution. |
| starting_index = 2; |
| } |
| for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { |
| Instruction* index_inst = |
| def_use_mgr->GetDef(inst->GetSingleWordInOperand(i)); |
| const analysis::Constant* index_constant = |
| context()->get_constant_mgr()->GetConstantFromInst(index_inst); |
| if (index_constant) { |
| uint32_t index_value = GetConstantValue(index_constant); |
| element_indices.push_back(index_value); |
| } else { |
| // This index must not matter to resolve the type in valid SPIR-V. |
| element_indices.push_back(0); |
| } |
| } |
| type = type_mgr->GetMemberType(type, element_indices); |
| return type; |
| } |
| |
| bool CombineAccessChains::CombineIndices(Instruction* ptr_input, |
| Instruction* inst, |
| std::vector<Operand>* new_operands) { |
| analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| analysis::ConstantManager* constant_mgr = context()->get_constant_mgr(); |
| |
| Instruction* last_index_inst = def_use_mgr->GetDef( |
| ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1)); |
| const analysis::Constant* last_index_constant = |
| constant_mgr->GetConstantFromInst(last_index_inst); |
| |
| Instruction* element_inst = |
| def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); |
| const analysis::Constant* element_constant = |
| constant_mgr->GetConstantFromInst(element_inst); |
| |
| // Combine the last index of the AccessChain (|ptr_inst|) with the element |
| // operand of the PtrAccessChain (|inst|). |
| const bool combining_element_operands = |
| IsPtrAccessChain(inst->opcode()) && |
| IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2; |
| uint32_t new_value_id = 0; |
| const analysis::Type* type = GetIndexedType(ptr_input); |
| if (last_index_constant && element_constant) { |
| // Combine the constants. |
| uint32_t new_value = GetConstantValue(last_index_constant) + |
| GetConstantValue(element_constant); |
| const analysis::Constant* new_value_constant = |
| constant_mgr->GetConstant(last_index_constant->type(), {new_value}); |
| Instruction* new_value_inst = |
| constant_mgr->GetDefiningInstruction(new_value_constant); |
| new_value_id = new_value_inst->result_id(); |
| } else if (!type->AsStruct() || combining_element_operands) { |
| // Generate an addition of the two indices. |
| InstructionBuilder builder( |
| context(), inst, |
| IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
| Instruction* addition = builder.AddIAdd(last_index_inst->type_id(), |
| last_index_inst->result_id(), |
| element_inst->result_id()); |
| new_value_id = addition->result_id(); |
| } else { |
| // Indexing into structs must be constant, so bail out here. |
| return false; |
| } |
| new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}}); |
| return true; |
| } |
| |
| bool CombineAccessChains::CreateNewInputOperands( |
| Instruction* ptr_input, Instruction* inst, |
| std::vector<Operand>* new_operands) { |
| // Start by copying all the input operands of the feeder access chain. |
| for (uint32_t i = 0; i != ptr_input->NumInOperands() - 1; ++i) { |
| new_operands->push_back(ptr_input->GetInOperand(i)); |
| } |
| |
| // Deal with the last index of the feeder access chain. |
| if (IsPtrAccessChain(inst->opcode())) { |
| // The last index of the feeder should be combined with the element operand |
| // of |inst|. |
| if (!CombineIndices(ptr_input, inst, new_operands)) return false; |
| } else { |
| // The indices aren't being combined so now add the last index operand of |
| // |ptr_input|. |
| new_operands->push_back( |
| ptr_input->GetInOperand(ptr_input->NumInOperands() - 1)); |
| } |
| |
| // Copy the remaining index operands. |
| uint32_t starting_index = IsPtrAccessChain(inst->opcode()) ? 2 : 1; |
| for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { |
| new_operands->push_back(inst->GetInOperand(i)); |
| } |
| |
| return true; |
| } |
| |
// Tries to fold the access chain |inst| into the access chain that produces
// its base pointer (|ptr_input|). On success |inst| is rewritten in place.
// Returns true if |inst| was modified, false if it was left untouched (the
// base is not an access chain, an index is not a 32-bit integer, the feeder's
// result type has an ArrayStride decoration, or the indices cannot be merged).
bool CombineAccessChains::CombineAccessChain(Instruction* inst) {
  assert((inst->opcode() == SpvOpPtrAccessChain ||
          inst->opcode() == SpvOpAccessChain ||
          inst->opcode() == SpvOpInBoundsAccessChain ||
          inst->opcode() == SpvOpInBoundsPtrAccessChain) &&
         "Wrong opcode. Expected an access chain.");

  Instruction* ptr_input =
      context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0));
  if (ptr_input->opcode() != SpvOpAccessChain &&
      ptr_input->opcode() != SpvOpInBoundsAccessChain &&
      ptr_input->opcode() != SpvOpPtrAccessChain &&
      ptr_input->opcode() != SpvOpInBoundsPtrAccessChain) {
    return false;
  }

  // Only 32-bit integer indices are handled (GetConstantValue reads 32 bits).
  if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false;

  // Handles the following cases:
  // 1. |ptr_input| is an index-less access chain. Replace the pointer
  //    in |inst| with |ptr_input|'s pointer.
  // 2. |inst| is an index-less access chain. Change |inst| to an
  //    OpCopyObject.
  // 3. |inst| is not a pointer access chain.
  //    |inst|'s indices are appended to |ptr_input|'s indices.
  // 4. |ptr_input| is not a pointer access chain.
  //    |inst| is a pointer access chain.
  //    |inst|'s element operand is combined with the last index in
  //    |ptr_input| to form a new operand.
  // 5. |ptr_input| is a pointer access chain.
  //    Like the above scenario, |inst|'s element operand is combined
  //    with |ptr_input|'s last index. This results in either a
  //    combined element operand or a combined regular index.

  // TODO(alan-baker): Support this properly. Requires analyzing the
  // size/alignment of the type and converting the stride into an element
  // index.
  uint32_t array_stride = GetArrayStride(ptr_input);
  if (array_stride != 0) return false;

  if (ptr_input->NumInOperands() == 1) {
    // The input is effectively a no-op.
    inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)});
    context()->AnalyzeUses(inst);
  } else if (inst->NumInOperands() == 1) {
    // |inst| is a no-op, change it to a copy. Instruction simplification will
    // clean it up. Only the opcode changes; the single operand stays the same.
    inst->SetOpcode(SpvOpCopyObject);
  } else {
    std::vector<Operand> new_operands;
    if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false;

    // Update the instruction. The opcode is computed from |inst|'s current
    // opcode, so it must be queried before SetOpcode overwrites it.
    inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode()));
    inst->SetInOperands(std::move(new_operands));
    context()->AnalyzeUses(inst);
  }
  return true;
}
| |
| SpvOp CombineAccessChains::UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode) { |
| auto IsInBounds = [](SpvOp opcode) { |
| return opcode == SpvOpInBoundsPtrAccessChain || |
| opcode == SpvOpInBoundsAccessChain; |
| }; |
| |
| if (input_opcode == SpvOpInBoundsPtrAccessChain) { |
| if (!IsInBounds(base_opcode)) return SpvOpPtrAccessChain; |
| } else if (input_opcode == SpvOpInBoundsAccessChain) { |
| if (!IsInBounds(base_opcode)) return SpvOpAccessChain; |
| } |
| |
| return input_opcode; |
| } |
| |
| bool CombineAccessChains::IsPtrAccessChain(SpvOp opcode) { |
| return opcode == SpvOpPtrAccessChain || opcode == SpvOpInBoundsPtrAccessChain; |
| } |
| |
| bool CombineAccessChains::Has64BitIndices(Instruction* inst) { |
| for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { |
| Instruction* index_inst = |
| context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(i)); |
| const analysis::Type* index_type = |
| context()->get_type_mgr()->GetType(index_inst->type_id()); |
| if (!index_type->AsInteger() || index_type->AsInteger()->width() != 32) |
| return true; |
| } |
| return false; |
| } |
| |
| } // namespace opt |
| } // namespace spvtools |