|  | // Copyright (c) 2018 Google LLC | 
|  | // | 
|  | // Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | // you may not use this file except in compliance with the License. | 
|  | // You may obtain a copy of the License at | 
|  | // | 
|  | //     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | // | 
|  | // Unless required by applicable law or agreed to in writing, software | 
|  | // distributed under the License is distributed on an "AS IS" BASIS, | 
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | // See the License for the specific language governing permissions and | 
|  | // limitations under the License. | 
|  |  | 
#include "source/opt/reduce_load_size.h"

#include <limits>
#include <set>
#include <vector>

#include "source/opt/instruction.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
#include "source/util/bit_vector.h"
|  |  | 
namespace {

// In-operand index of the composite id in an OpCompositeExtract.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
// In-operand index of the storage class in an OpVariable.
constexpr uint32_t kVariableStorageClassInIdx = 0;
// In-operand index of the pointer operand in an OpLoad.
constexpr uint32_t kLoadPointerInIdx = 0;
// A load is narrowed only when the fraction of its top-level elements that
// are extracted is strictly below this value.
constexpr double kThreshold = 0.9;

}  // namespace
|  |  | 
|  | namespace spvtools { | 
|  | namespace opt { | 
|  |  | 
|  | Pass::Status ReduceLoadSize::Process() { | 
|  | bool modified = false; | 
|  |  | 
|  | for (auto& func : *get_module()) { | 
|  | func.ForEachInst([&modified, this](Instruction* inst) { | 
|  | if (inst->opcode() == SpvOpCompositeExtract) { | 
|  | if (ShouldReplaceExtract(inst)) { | 
|  | modified |= ReplaceExtract(inst); | 
|  | } | 
|  | } | 
|  | }); | 
|  | } | 
|  |  | 
|  | return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; | 
|  | } | 
|  |  | 
|  | bool ReduceLoadSize::ReplaceExtract(Instruction* inst) { | 
|  | assert(inst->opcode() == SpvOpCompositeExtract && | 
|  | "Wrong opcode.  Should be OpCompositeExtract."); | 
|  | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); | 
|  | analysis::TypeManager* type_mgr = context()->get_type_mgr(); | 
|  | analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); | 
|  |  | 
|  | uint32_t composite_id = | 
|  | inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); | 
|  | Instruction* composite_inst = def_use_mgr->GetDef(composite_id); | 
|  |  | 
|  | if (composite_inst->opcode() != SpvOpLoad) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | analysis::Type* composite_type = type_mgr->GetType(composite_inst->type_id()); | 
|  | if (composite_type->kind() == analysis::Type::kVector || | 
|  | composite_type->kind() == analysis::Type::kMatrix) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | Instruction* var = composite_inst->GetBaseAddress(); | 
|  | if (var == nullptr || var->opcode() != SpvOpVariable) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | SpvStorageClass storage_class = static_cast<SpvStorageClass>( | 
|  | var->GetSingleWordInOperand(kVariableStorageClassInIdx)); | 
|  | switch (storage_class) { | 
|  | case SpvStorageClassUniform: | 
|  | case SpvStorageClassUniformConstant: | 
|  | case SpvStorageClassInput: | 
|  | break; | 
|  | default: | 
|  | return false; | 
|  | } | 
|  |  | 
|  | // Create a new access chain and load just after the old load. | 
|  | // We cannot create the new access chain load in the position of the extract | 
|  | // because the storage may have been written to in between. | 
|  | InstructionBuilder ir_builder( | 
|  | inst->context(), composite_inst, | 
|  | IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDefUse); | 
|  |  | 
|  | uint32_t pointer_to_result_type_id = | 
|  | type_mgr->FindPointerToType(inst->type_id(), storage_class); | 
|  | assert(pointer_to_result_type_id != 0 && | 
|  | "We did not find the pointer type that we need."); | 
|  |  | 
|  | analysis::Integer int_type(32, false); | 
|  | const analysis::Type* uint32_type = type_mgr->GetRegisteredType(&int_type); | 
|  | std::vector<uint32_t> ids; | 
|  | for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { | 
|  | uint32_t index = inst->GetSingleWordInOperand(i); | 
|  | const analysis::Constant* index_const = | 
|  | const_mgr->GetConstant(uint32_type, {index}); | 
|  | ids.push_back(const_mgr->GetDefiningInstruction(index_const)->result_id()); | 
|  | } | 
|  |  | 
|  | Instruction* new_access_chain = ir_builder.AddAccessChain( | 
|  | pointer_to_result_type_id, | 
|  | composite_inst->GetSingleWordInOperand(kLoadPointerInIdx), ids); | 
|  | Instruction* new_laod = | 
|  | ir_builder.AddLoad(inst->type_id(), new_access_chain->result_id()); | 
|  |  | 
|  | context()->ReplaceAllUsesWith(inst->result_id(), new_laod->result_id()); | 
|  | context()->KillInst(inst); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool ReduceLoadSize::ShouldReplaceExtract(Instruction* inst) { | 
|  | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); | 
|  | Instruction* op_inst = def_use_mgr->GetDef( | 
|  | inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)); | 
|  |  | 
|  | if (op_inst->opcode() != SpvOpLoad) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | auto cached_result = should_replace_cache_.find(op_inst->result_id()); | 
|  | if (cached_result != should_replace_cache_.end()) { | 
|  | return cached_result->second; | 
|  | } | 
|  |  | 
|  | bool all_elements_used = false; | 
|  | std::set<uint32_t> elements_used; | 
|  |  | 
|  | all_elements_used = | 
|  | !def_use_mgr->WhileEachUser(op_inst, [&elements_used](Instruction* use) { | 
|  | if (use->opcode() != SpvOpCompositeExtract || | 
|  | use->NumInOperands() == 1) { | 
|  | return false; | 
|  | } | 
|  | elements_used.insert(use->GetSingleWordInOperand(1)); | 
|  | return true; | 
|  | }); | 
|  |  | 
|  | bool should_replace = false; | 
|  | if (all_elements_used) { | 
|  | should_replace = false; | 
|  | } else { | 
|  | analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); | 
|  | analysis::TypeManager* type_mgr = context()->get_type_mgr(); | 
|  | analysis::Type* load_type = type_mgr->GetType(op_inst->type_id()); | 
|  | uint32_t total_size = 1; | 
|  | switch (load_type->kind()) { | 
|  | case analysis::Type::kArray: { | 
|  | const analysis::Constant* size_const = | 
|  | const_mgr->FindDeclaredConstant(load_type->AsArray()->LengthId()); | 
|  | assert(size_const->AsIntConstant()); | 
|  | total_size = size_const->GetU32(); | 
|  | } break; | 
|  | case analysis::Type::kStruct: | 
|  | total_size = static_cast<uint32_t>( | 
|  | load_type->AsStruct()->element_types().size()); | 
|  | break; | 
|  | default: | 
|  | break; | 
|  | } | 
|  | double percent_used = static_cast<double>(elements_used.size()) / | 
|  | static_cast<double>(total_size); | 
|  | should_replace = (percent_used < kThreshold); | 
|  | } | 
|  |  | 
|  | should_replace_cache_[op_inst->result_id()] = should_replace; | 
|  | return should_replace; | 
|  | } | 
|  |  | 
|  | }  // namespace opt | 
|  | }  // namespace spvtools |