| // Copyright (c) 2018 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "source/opt/reduce_load_size.h" |
| |
#include <cstdint>
#include <set>
#include <vector>
| |
| #include "source/opt/instruction.h" |
| #include "source/opt/ir_builder.h" |
| #include "source/opt/ir_context.h" |
| #include "source/util/bit_vector.h" |
| |
namespace {

// Positions passed to Instruction::GetSingleWordInOperand for the operands
// this pass reads.
constexpr uint32_t kExtractCompositeIdInIdx = 0;    // OpCompositeExtract: composite id
constexpr uint32_t kVariableStorageClassInIdx = 0;  // OpVariable: storage class
constexpr uint32_t kLoadPointerInIdx = 0;           // OpLoad: pointer id

}  // namespace
| |
| namespace spvtools { |
| namespace opt { |
| |
| Pass::Status ReduceLoadSize::Process() { |
| bool modified = false; |
| |
| for (auto& func : *get_module()) { |
| func.ForEachInst([&modified, this](Instruction* inst) { |
| if (inst->opcode() == SpvOpCompositeExtract) { |
| if (ShouldReplaceExtract(inst)) { |
| modified |= ReplaceExtract(inst); |
| } |
| } |
| }); |
| } |
| |
| return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; |
| } |
| |
| bool ReduceLoadSize::ReplaceExtract(Instruction* inst) { |
| assert(inst->opcode() == SpvOpCompositeExtract && |
| "Wrong opcode. Should be OpCompositeExtract."); |
| analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
| analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); |
| |
| uint32_t composite_id = |
| inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); |
| Instruction* composite_inst = def_use_mgr->GetDef(composite_id); |
| |
| if (composite_inst->opcode() != SpvOpLoad) { |
| return false; |
| } |
| |
| analysis::Type* composite_type = type_mgr->GetType(composite_inst->type_id()); |
| if (composite_type->kind() == analysis::Type::kVector || |
| composite_type->kind() == analysis::Type::kMatrix) { |
| return false; |
| } |
| |
| Instruction* var = composite_inst->GetBaseAddress(); |
| if (var == nullptr || var->opcode() != SpvOpVariable) { |
| return false; |
| } |
| |
| SpvStorageClass storage_class = static_cast<SpvStorageClass>( |
| var->GetSingleWordInOperand(kVariableStorageClassInIdx)); |
| switch (storage_class) { |
| case SpvStorageClassUniform: |
| case SpvStorageClassUniformConstant: |
| case SpvStorageClassInput: |
| break; |
| default: |
| return false; |
| } |
| |
| // Create a new access chain and load just after the old load. |
| // We cannot create the new access chain load in the position of the extract |
| // because the storage may have been written to in between. |
| InstructionBuilder ir_builder( |
| inst->context(), composite_inst, |
| IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDefUse); |
| |
| uint32_t pointer_to_result_type_id = |
| type_mgr->FindPointerToType(inst->type_id(), storage_class); |
| assert(pointer_to_result_type_id != 0 && |
| "We did not find the pointer type that we need."); |
| |
| analysis::Integer int_type(32, false); |
| const analysis::Type* uint32_type = type_mgr->GetRegisteredType(&int_type); |
| std::vector<uint32_t> ids; |
| for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { |
| uint32_t index = inst->GetSingleWordInOperand(i); |
| const analysis::Constant* index_const = |
| const_mgr->GetConstant(uint32_type, {index}); |
| ids.push_back(const_mgr->GetDefiningInstruction(index_const)->result_id()); |
| } |
| |
| Instruction* new_access_chain = ir_builder.AddAccessChain( |
| pointer_to_result_type_id, |
| composite_inst->GetSingleWordInOperand(kLoadPointerInIdx), ids); |
| Instruction* new_load = |
| ir_builder.AddLoad(inst->type_id(), new_access_chain->result_id()); |
| |
| context()->ReplaceAllUsesWith(inst->result_id(), new_load->result_id()); |
| context()->KillInst(inst); |
| return true; |
| } |
| |
| bool ReduceLoadSize::ShouldReplaceExtract(Instruction* inst) { |
| analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| Instruction* op_inst = def_use_mgr->GetDef( |
| inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)); |
| |
| if (op_inst->opcode() != SpvOpLoad) { |
| return false; |
| } |
| |
| auto cached_result = should_replace_cache_.find(op_inst->result_id()); |
| if (cached_result != should_replace_cache_.end()) { |
| return cached_result->second; |
| } |
| |
| bool all_elements_used = false; |
| std::set<uint32_t> elements_used; |
| |
| all_elements_used = |
| !def_use_mgr->WhileEachUser(op_inst, [&elements_used](Instruction* use) { |
| if (use->IsCommonDebugInstr()) return true; |
| if (use->opcode() != SpvOpCompositeExtract || |
| use->NumInOperands() == 1) { |
| return false; |
| } |
| elements_used.insert(use->GetSingleWordInOperand(1)); |
| return true; |
| }); |
| |
| bool should_replace = false; |
| if (all_elements_used) { |
| should_replace = false; |
| } else if (1.0 <= replacement_threshold_) { |
| should_replace = true; |
| } else { |
| analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); |
| analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
| analysis::Type* load_type = type_mgr->GetType(op_inst->type_id()); |
| uint32_t total_size = 1; |
| switch (load_type->kind()) { |
| case analysis::Type::kArray: { |
| const analysis::Constant* size_const = |
| const_mgr->FindDeclaredConstant(load_type->AsArray()->LengthId()); |
| assert(size_const->AsIntConstant()); |
| total_size = size_const->GetU32(); |
| } break; |
| case analysis::Type::kStruct: |
| total_size = static_cast<uint32_t>( |
| load_type->AsStruct()->element_types().size()); |
| break; |
| default: |
| break; |
| } |
| double percent_used = static_cast<double>(elements_used.size()) / |
| static_cast<double>(total_size); |
| should_replace = (percent_used < replacement_threshold_); |
| } |
| |
| should_replace_cache_[op_inst->result_id()] = should_replace; |
| return should_replace; |
| } |
| |
| } // namespace opt |
| } // namespace spvtools |