|  | // Copyright (c) 2018 Google LLC | 
|  | // | 
|  | // Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | // you may not use this file except in compliance with the License. | 
|  | // You may obtain a copy of the License at | 
|  | // | 
|  | //     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | // | 
|  | // Unless required by applicable law or agreed to in writing, software | 
|  | // distributed under the License is distributed on an "AS IS" BASIS, | 
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | // See the License for the specific language governing permissions and | 
|  | // limitations under the License. | 
|  |  | 
|  | #include "source/opt/combine_access_chains.h" | 
|  |  | 
|  | #include <utility> | 
|  |  | 
|  | #include "source/opt/constants.h" | 
|  | #include "source/opt/ir_builder.h" | 
|  | #include "source/opt/ir_context.h" | 
|  |  | 
|  | namespace spvtools { | 
|  | namespace opt { | 
|  |  | 
|  | Pass::Status CombineAccessChains::Process() { | 
|  | bool modified = false; | 
|  |  | 
|  | for (auto& function : *get_module()) { | 
|  | modified |= ProcessFunction(function); | 
|  | } | 
|  |  | 
|  | return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); | 
|  | } | 
|  |  | 
|  | bool CombineAccessChains::ProcessFunction(Function& function) { | 
|  | if (function.IsDeclaration()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | bool modified = false; | 
|  |  | 
|  | cfg()->ForEachBlockInReversePostOrder( | 
|  | function.entry().get(), [&modified, this](BasicBlock* block) { | 
|  | block->ForEachInst([&modified, this](Instruction* inst) { | 
|  | switch (inst->opcode()) { | 
|  | case SpvOpAccessChain: | 
|  | case SpvOpInBoundsAccessChain: | 
|  | case SpvOpPtrAccessChain: | 
|  | case SpvOpInBoundsPtrAccessChain: | 
|  | modified |= CombineAccessChain(inst); | 
|  | break; | 
|  | default: | 
|  | break; | 
|  | } | 
|  | }); | 
|  | }); | 
|  |  | 
|  | return modified; | 
|  | } | 
|  |  | 
|  | uint32_t CombineAccessChains::GetConstantValue( | 
|  | const analysis::Constant* constant_inst) { | 
|  | if (constant_inst->type()->AsInteger()->width() <= 32) { | 
|  | if (constant_inst->type()->AsInteger()->IsSigned()) { | 
|  | return static_cast<uint32_t>(constant_inst->GetS32()); | 
|  | } else { | 
|  | return constant_inst->GetU32(); | 
|  | } | 
|  | } else { | 
|  | assert(false); | 
|  | return 0u; | 
|  | } | 
|  | } | 
|  |  | 
|  | uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) { | 
|  | uint32_t array_stride = 0; | 
|  | context()->get_decoration_mgr()->WhileEachDecoration( | 
|  | inst->type_id(), SpvDecorationArrayStride, | 
|  | [&array_stride](const Instruction& decoration) { | 
|  | assert(decoration.opcode() != SpvOpDecorateId); | 
|  | if (decoration.opcode() == SpvOpDecorate) { | 
|  | array_stride = decoration.GetSingleWordInOperand(1); | 
|  | } else { | 
|  | array_stride = decoration.GetSingleWordInOperand(2); | 
|  | } | 
|  | return false; | 
|  | }); | 
|  | return array_stride; | 
|  | } | 
|  |  | 
|  | const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) { | 
|  | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); | 
|  | analysis::TypeManager* type_mgr = context()->get_type_mgr(); | 
|  |  | 
|  | Instruction* base_ptr = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); | 
|  | const analysis::Type* type = type_mgr->GetType(base_ptr->type_id()); | 
|  | assert(type->AsPointer()); | 
|  | type = type->AsPointer()->pointee_type(); | 
|  | std::vector<uint32_t> element_indices; | 
|  | uint32_t starting_index = 1; | 
|  | if (IsPtrAccessChain(inst->opcode())) { | 
|  | // Skip the first index of OpPtrAccessChain as it does not affect type | 
|  | // resolution. | 
|  | starting_index = 2; | 
|  | } | 
|  | for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { | 
|  | Instruction* index_inst = | 
|  | def_use_mgr->GetDef(inst->GetSingleWordInOperand(i)); | 
|  | const analysis::Constant* index_constant = | 
|  | context()->get_constant_mgr()->GetConstantFromInst(index_inst); | 
|  | if (index_constant) { | 
|  | uint32_t index_value = GetConstantValue(index_constant); | 
|  | element_indices.push_back(index_value); | 
|  | } else { | 
|  | // This index must not matter to resolve the type in valid SPIR-V. | 
|  | element_indices.push_back(0); | 
|  | } | 
|  | } | 
|  | type = type_mgr->GetMemberType(type, element_indices); | 
|  | return type; | 
|  | } | 
|  |  | 
|  | bool CombineAccessChains::CombineIndices(Instruction* ptr_input, | 
|  | Instruction* inst, | 
|  | std::vector<Operand>* new_operands) { | 
|  | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); | 
|  | analysis::ConstantManager* constant_mgr = context()->get_constant_mgr(); | 
|  |  | 
|  | Instruction* last_index_inst = def_use_mgr->GetDef( | 
|  | ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1)); | 
|  | const analysis::Constant* last_index_constant = | 
|  | constant_mgr->GetConstantFromInst(last_index_inst); | 
|  |  | 
|  | Instruction* element_inst = | 
|  | def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); | 
|  | const analysis::Constant* element_constant = | 
|  | constant_mgr->GetConstantFromInst(element_inst); | 
|  |  | 
|  | // Combine the last index of the AccessChain (|ptr_inst|) with the element | 
|  | // operand of the PtrAccessChain (|inst|). | 
|  | const bool combining_element_operands = | 
|  | IsPtrAccessChain(inst->opcode()) && | 
|  | IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2; | 
|  | uint32_t new_value_id = 0; | 
|  | const analysis::Type* type = GetIndexedType(ptr_input); | 
|  | if (last_index_constant && element_constant) { | 
|  | // Combine the constants. | 
|  | uint32_t new_value = GetConstantValue(last_index_constant) + | 
|  | GetConstantValue(element_constant); | 
|  | const analysis::Constant* new_value_constant = | 
|  | constant_mgr->GetConstant(last_index_constant->type(), {new_value}); | 
|  | Instruction* new_value_inst = | 
|  | constant_mgr->GetDefiningInstruction(new_value_constant); | 
|  | new_value_id = new_value_inst->result_id(); | 
|  | } else if (!type->AsStruct() || combining_element_operands) { | 
|  | // Generate an addition of the two indices. | 
|  | InstructionBuilder builder( | 
|  | context(), inst, | 
|  | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); | 
|  | Instruction* addition = builder.AddIAdd(last_index_inst->type_id(), | 
|  | last_index_inst->result_id(), | 
|  | element_inst->result_id()); | 
|  | new_value_id = addition->result_id(); | 
|  | } else { | 
|  | // Indexing into structs must be constant, so bail out here. | 
|  | return false; | 
|  | } | 
|  | new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}}); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool CombineAccessChains::CreateNewInputOperands( | 
|  | Instruction* ptr_input, Instruction* inst, | 
|  | std::vector<Operand>* new_operands) { | 
|  | // Start by copying all the input operands of the feeder access chain. | 
|  | for (uint32_t i = 0; i != ptr_input->NumInOperands() - 1; ++i) { | 
|  | new_operands->push_back(ptr_input->GetInOperand(i)); | 
|  | } | 
|  |  | 
|  | // Deal with the last index of the feeder access chain. | 
|  | if (IsPtrAccessChain(inst->opcode())) { | 
|  | // The last index of the feeder should be combined with the element operand | 
|  | // of |inst|. | 
|  | if (!CombineIndices(ptr_input, inst, new_operands)) return false; | 
|  | } else { | 
|  | // The indices aren't being combined so now add the last index operand of | 
|  | // |ptr_input|. | 
|  | new_operands->push_back( | 
|  | ptr_input->GetInOperand(ptr_input->NumInOperands() - 1)); | 
|  | } | 
|  |  | 
|  | // Copy the remaining index operands. | 
|  | uint32_t starting_index = IsPtrAccessChain(inst->opcode()) ? 2 : 1; | 
|  | for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { | 
|  | new_operands->push_back(inst->GetInOperand(i)); | 
|  | } | 
|  |  | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool CombineAccessChains::CombineAccessChain(Instruction* inst) { | 
|  | assert((inst->opcode() == SpvOpPtrAccessChain || | 
|  | inst->opcode() == SpvOpAccessChain || | 
|  | inst->opcode() == SpvOpInBoundsAccessChain || | 
|  | inst->opcode() == SpvOpInBoundsPtrAccessChain) && | 
|  | "Wrong opcode. Expected an access chain."); | 
|  |  | 
|  | Instruction* ptr_input = | 
|  | context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0)); | 
|  | if (ptr_input->opcode() != SpvOpAccessChain && | 
|  | ptr_input->opcode() != SpvOpInBoundsAccessChain && | 
|  | ptr_input->opcode() != SpvOpPtrAccessChain && | 
|  | ptr_input->opcode() != SpvOpInBoundsPtrAccessChain) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false; | 
|  |  | 
|  | // Handles the following cases: | 
|  | // 1. |ptr_input| is an index-less access chain. Replace the pointer | 
|  | //    in |inst| with |ptr_input|'s pointer. | 
|  | // 2. |inst| is a index-less access chain. Change |inst| to an | 
|  | //    OpCopyObject. | 
|  | // 3. |inst| is not a pointer access chain. | 
|  | //    |inst|'s indices are appended to |ptr_input|'s indices. | 
|  | // 4. |ptr_input| is not pointer access chain. | 
|  | //    |inst| is a pointer access chain. | 
|  | //    |inst|'s element operand is combined with the last index in | 
|  | //    |ptr_input| to form a new operand. | 
|  | // 5. |ptr_input| is a pointer access chain. | 
|  | //    Like the above scenario, |inst|'s element operand is combined | 
|  | //    with |ptr_input|'s last index. This results is either a | 
|  | //    combined element operand or combined regular index. | 
|  |  | 
|  | // TODO(alan-baker): Support this properly. Requires analyzing the | 
|  | // size/alignment of the type and converting the stride into an element | 
|  | // index. | 
|  | uint32_t array_stride = GetArrayStride(ptr_input); | 
|  | if (array_stride != 0) return false; | 
|  |  | 
|  | if (ptr_input->NumInOperands() == 1) { | 
|  | // The input is effectively a no-op. | 
|  | inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)}); | 
|  | context()->AnalyzeUses(inst); | 
|  | } else if (inst->NumInOperands() == 1) { | 
|  | // |inst| is a no-op, change it to a copy. Instruction simplification will | 
|  | // clean it up. | 
|  | inst->SetOpcode(SpvOpCopyObject); | 
|  | } else { | 
|  | std::vector<Operand> new_operands; | 
|  | if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false; | 
|  |  | 
|  | // Update the instruction. | 
|  | inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode())); | 
|  | inst->SetInOperands(std::move(new_operands)); | 
|  | context()->AnalyzeUses(inst); | 
|  | } | 
|  | return true; | 
|  | } | 
|  |  | 
|  | SpvOp CombineAccessChains::UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode) { | 
|  | auto IsInBounds = [](SpvOp opcode) { | 
|  | return opcode == SpvOpInBoundsPtrAccessChain || | 
|  | opcode == SpvOpInBoundsAccessChain; | 
|  | }; | 
|  |  | 
|  | if (input_opcode == SpvOpInBoundsPtrAccessChain) { | 
|  | if (!IsInBounds(base_opcode)) return SpvOpPtrAccessChain; | 
|  | } else if (input_opcode == SpvOpInBoundsAccessChain) { | 
|  | if (!IsInBounds(base_opcode)) return SpvOpAccessChain; | 
|  | } | 
|  |  | 
|  | return input_opcode; | 
|  | } | 
|  |  | 
|  | bool CombineAccessChains::IsPtrAccessChain(SpvOp opcode) { | 
|  | return opcode == SpvOpPtrAccessChain || opcode == SpvOpInBoundsPtrAccessChain; | 
|  | } | 
|  |  | 
|  | bool CombineAccessChains::Has64BitIndices(Instruction* inst) { | 
|  | for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { | 
|  | Instruction* index_inst = | 
|  | context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(i)); | 
|  | const analysis::Type* index_type = | 
|  | context()->get_type_mgr()->GetType(index_inst->type_id()); | 
|  | if (!index_type->AsInteger() || index_type->AsInteger()->width() != 32) | 
|  | return true; | 
|  | } | 
|  | return false; | 
|  | } | 
|  |  | 
|  | }  // namespace opt | 
|  | }  // namespace spvtools |