|  | // Copyright (c) 2017 Google Inc. | 
|  | // | 
|  | // Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | // you may not use this file except in compliance with the License. | 
|  | // You may obtain a copy of the License at | 
|  | // | 
|  | //     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | // | 
|  | // Unless required by applicable law or agreed to in writing, software | 
|  | // distributed under the License is distributed on an "AS IS" BASIS, | 
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | // See the License for the specific language governing permissions and | 
|  | // limitations under the License. | 
|  |  | 
|  | #include "source/opt/scalar_replacement_pass.h" | 
|  |  | 
|  | #include <algorithm> | 
|  | #include <queue> | 
|  | #include <tuple> | 
|  | #include <utility> | 
|  |  | 
|  | #include "source/enum_string_mapping.h" | 
|  | #include "source/extensions.h" | 
|  | #include "source/opt/reflect.h" | 
|  | #include "source/opt/types.h" | 
|  | #include "source/util/make_unique.h" | 
|  |  | 
|  | namespace spvtools { | 
|  | namespace opt { | 
|  |  | 
|  | Pass::Status ScalarReplacementPass::Process() { | 
|  | Status status = Status::SuccessWithoutChange; | 
|  | for (auto& f : *get_module()) { | 
|  | Status functionStatus = ProcessFunction(&f); | 
|  | if (functionStatus == Status::Failure) | 
|  | return functionStatus; | 
|  | else if (functionStatus == Status::SuccessWithChange) | 
|  | status = functionStatus; | 
|  | } | 
|  |  | 
|  | return status; | 
|  | } | 
|  |  | 
|  | Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) { | 
|  | std::queue<Instruction*> worklist; | 
|  | BasicBlock& entry = *function->begin(); | 
|  | for (auto iter = entry.begin(); iter != entry.end(); ++iter) { | 
|  | // Function storage class OpVariables must appear as the first instructions | 
|  | // of the entry block. | 
|  | if (iter->opcode() != SpvOpVariable) break; | 
|  |  | 
|  | Instruction* varInst = &*iter; | 
|  | if (CanReplaceVariable(varInst)) { | 
|  | worklist.push(varInst); | 
|  | } | 
|  | } | 
|  |  | 
|  | Status status = Status::SuccessWithoutChange; | 
|  | while (!worklist.empty()) { | 
|  | Instruction* varInst = worklist.front(); | 
|  | worklist.pop(); | 
|  |  | 
|  | Status var_status = ReplaceVariable(varInst, &worklist); | 
|  | if (var_status == Status::Failure) | 
|  | return var_status; | 
|  | else if (var_status == Status::SuccessWithChange) | 
|  | status = var_status; | 
|  | } | 
|  |  | 
|  | return status; | 
|  | } | 
|  |  | 
|  | Pass::Status ScalarReplacementPass::ReplaceVariable( | 
|  | Instruction* inst, std::queue<Instruction*>* worklist) { | 
|  | std::vector<Instruction*> replacements; | 
|  | if (!CreateReplacementVariables(inst, &replacements)) { | 
|  | return Status::Failure; | 
|  | } | 
|  |  | 
|  | std::vector<Instruction*> dead; | 
|  | bool replaced_all_uses = get_def_use_mgr()->WhileEachUser( | 
|  | inst, [this, &replacements, &dead](Instruction* user) { | 
|  | if (!IsAnnotationInst(user->opcode())) { | 
|  | switch (user->opcode()) { | 
|  | case SpvOpLoad: | 
|  | if (ReplaceWholeLoad(user, replacements)) { | 
|  | dead.push_back(user); | 
|  | } else { | 
|  | return false; | 
|  | } | 
|  | break; | 
|  | case SpvOpStore: | 
|  | if (ReplaceWholeStore(user, replacements)) { | 
|  | dead.push_back(user); | 
|  | } else { | 
|  | return false; | 
|  | } | 
|  | break; | 
|  | case SpvOpAccessChain: | 
|  | case SpvOpInBoundsAccessChain: | 
|  | if (ReplaceAccessChain(user, replacements)) | 
|  | dead.push_back(user); | 
|  | else | 
|  | return false; | 
|  | break; | 
|  | case SpvOpName: | 
|  | case SpvOpMemberName: | 
|  | break; | 
|  | default: | 
|  | assert(false && "Unexpected opcode"); | 
|  | break; | 
|  | } | 
|  | } | 
|  | return true; | 
|  | }); | 
|  |  | 
|  | if (replaced_all_uses) { | 
|  | dead.push_back(inst); | 
|  | } else { | 
|  | return Status::Failure; | 
|  | } | 
|  |  | 
|  | // If there are no dead instructions to clean up, return with no changes. | 
|  | if (dead.empty()) return Status::SuccessWithoutChange; | 
|  |  | 
|  | // Clean up some dead code. | 
|  | while (!dead.empty()) { | 
|  | Instruction* toKill = dead.back(); | 
|  | dead.pop_back(); | 
|  | context()->KillInst(toKill); | 
|  | } | 
|  |  | 
|  | // Attempt to further scalarize. | 
|  | for (auto var : replacements) { | 
|  | if (var->opcode() == SpvOpVariable) { | 
|  | if (get_def_use_mgr()->NumUsers(var) == 0) { | 
|  | context()->KillInst(var); | 
|  | } else if (CanReplaceVariable(var)) { | 
|  | worklist->push(var); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | return Status::SuccessWithChange; | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::ReplaceWholeLoad( | 
|  | Instruction* load, const std::vector<Instruction*>& replacements) { | 
|  | // Replaces the load of the entire composite with a load from each replacement | 
|  | // variable followed by a composite construction. | 
|  | BasicBlock* block = context()->get_instr_block(load); | 
|  | std::vector<Instruction*> loads; | 
|  | loads.reserve(replacements.size()); | 
|  | BasicBlock::iterator where(load); | 
|  | for (auto var : replacements) { | 
|  | // Create a load of each replacement variable. | 
|  | if (var->opcode() != SpvOpVariable) { | 
|  | loads.push_back(var); | 
|  | continue; | 
|  | } | 
|  |  | 
|  | Instruction* type = GetStorageType(var); | 
|  | uint32_t loadId = TakeNextId(); | 
|  | if (loadId == 0) { | 
|  | return false; | 
|  | } | 
|  | std::unique_ptr<Instruction> newLoad( | 
|  | new Instruction(context(), SpvOpLoad, type->result_id(), loadId, | 
|  | std::initializer_list<Operand>{ | 
|  | {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); | 
|  | // Copy memory access attributes which start at index 1. Index 0 is the | 
|  | // pointer to load. | 
|  | for (uint32_t i = 1; i < load->NumInOperands(); ++i) { | 
|  | Operand copy(load->GetInOperand(i)); | 
|  | newLoad->AddOperand(std::move(copy)); | 
|  | } | 
|  | where = where.InsertBefore(std::move(newLoad)); | 
|  | get_def_use_mgr()->AnalyzeInstDefUse(&*where); | 
|  | context()->set_instr_block(&*where, block); | 
|  | loads.push_back(&*where); | 
|  | } | 
|  |  | 
|  | // Construct a new composite. | 
|  | uint32_t compositeId = TakeNextId(); | 
|  | if (compositeId == 0) { | 
|  | return false; | 
|  | } | 
|  | where = load; | 
|  | std::unique_ptr<Instruction> compositeConstruct(new Instruction( | 
|  | context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {})); | 
|  | for (auto l : loads) { | 
|  | Operand op(SPV_OPERAND_TYPE_ID, | 
|  | std::initializer_list<uint32_t>{l->result_id()}); | 
|  | compositeConstruct->AddOperand(std::move(op)); | 
|  | } | 
|  | where = where.InsertBefore(std::move(compositeConstruct)); | 
|  | get_def_use_mgr()->AnalyzeInstDefUse(&*where); | 
|  | context()->set_instr_block(&*where, block); | 
|  | context()->ReplaceAllUsesWith(load->result_id(), compositeId); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::ReplaceWholeStore( | 
|  | Instruction* store, const std::vector<Instruction*>& replacements) { | 
|  | // Replaces a store to the whole composite with a series of extract and stores | 
|  | // to each element. | 
|  | uint32_t storeInput = store->GetSingleWordInOperand(1u); | 
|  | BasicBlock* block = context()->get_instr_block(store); | 
|  | BasicBlock::iterator where(store); | 
|  | uint32_t elementIndex = 0; | 
|  | for (auto var : replacements) { | 
|  | // Create the extract. | 
|  | if (var->opcode() != SpvOpVariable) { | 
|  | elementIndex++; | 
|  | continue; | 
|  | } | 
|  |  | 
|  | Instruction* type = GetStorageType(var); | 
|  | uint32_t extractId = TakeNextId(); | 
|  | if (extractId == 0) { | 
|  | return false; | 
|  | } | 
|  | std::unique_ptr<Instruction> extract(new Instruction( | 
|  | context(), SpvOpCompositeExtract, type->result_id(), extractId, | 
|  | std::initializer_list<Operand>{ | 
|  | {SPV_OPERAND_TYPE_ID, {storeInput}}, | 
|  | {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}})); | 
|  | auto iter = where.InsertBefore(std::move(extract)); | 
|  | get_def_use_mgr()->AnalyzeInstDefUse(&*iter); | 
|  | context()->set_instr_block(&*iter, block); | 
|  |  | 
|  | // Create the store. | 
|  | std::unique_ptr<Instruction> newStore( | 
|  | new Instruction(context(), SpvOpStore, 0, 0, | 
|  | std::initializer_list<Operand>{ | 
|  | {SPV_OPERAND_TYPE_ID, {var->result_id()}}, | 
|  | {SPV_OPERAND_TYPE_ID, {extractId}}})); | 
|  | // Copy memory access attributes which start at index 2. Index 0 is the | 
|  | // pointer and index 1 is the data. | 
|  | for (uint32_t i = 2; i < store->NumInOperands(); ++i) { | 
|  | Operand copy(store->GetInOperand(i)); | 
|  | newStore->AddOperand(std::move(copy)); | 
|  | } | 
|  | iter = where.InsertBefore(std::move(newStore)); | 
|  | get_def_use_mgr()->AnalyzeInstDefUse(&*iter); | 
|  | context()->set_instr_block(&*iter, block); | 
|  | } | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::ReplaceAccessChain( | 
|  | Instruction* chain, const std::vector<Instruction*>& replacements) { | 
|  | // Replaces the access chain with either another access chain (with one fewer | 
|  | // indexes) or a direct use of the replacement variable. | 
|  | uint32_t indexId = chain->GetSingleWordInOperand(1u); | 
|  | const Instruction* index = get_def_use_mgr()->GetDef(indexId); | 
|  | int64_t indexValue = context() | 
|  | ->get_constant_mgr() | 
|  | ->GetConstantFromInst(index) | 
|  | ->GetSignExtendedValue(); | 
|  | if (indexValue < 0 || | 
|  | indexValue >= static_cast<int64_t>(replacements.size())) { | 
|  | // Out of bounds access, this is illegal IR.  Notice that OpAccessChain | 
|  | // indexing is 0-based, so we should also reject index == size-of-array. | 
|  | return false; | 
|  | } else { | 
|  | const Instruction* var = replacements[static_cast<size_t>(indexValue)]; | 
|  | if (chain->NumInOperands() > 2) { | 
|  | // Replace input access chain with another access chain. | 
|  | BasicBlock::iterator chainIter(chain); | 
|  | uint32_t replacementId = TakeNextId(); | 
|  | if (replacementId == 0) { | 
|  | return false; | 
|  | } | 
|  | std::unique_ptr<Instruction> replacementChain(new Instruction( | 
|  | context(), chain->opcode(), chain->type_id(), replacementId, | 
|  | std::initializer_list<Operand>{ | 
|  | {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); | 
|  | // Add the remaining indexes. | 
|  | for (uint32_t i = 2; i < chain->NumInOperands(); ++i) { | 
|  | Operand copy(chain->GetInOperand(i)); | 
|  | replacementChain->AddOperand(std::move(copy)); | 
|  | } | 
|  | auto iter = chainIter.InsertBefore(std::move(replacementChain)); | 
|  | get_def_use_mgr()->AnalyzeInstDefUse(&*iter); | 
|  | context()->set_instr_block(&*iter, context()->get_instr_block(chain)); | 
|  | context()->ReplaceAllUsesWith(chain->result_id(), replacementId); | 
|  | } else { | 
|  | // Replace with a use of the variable. | 
|  | context()->ReplaceAllUsesWith(chain->result_id(), var->result_id()); | 
|  | } | 
|  | } | 
|  |  | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::CreateReplacementVariables( | 
|  | Instruction* inst, std::vector<Instruction*>* replacements) { | 
|  | Instruction* type = GetStorageType(inst); | 
|  |  | 
|  | std::unique_ptr<std::unordered_set<int64_t>> components_used = | 
|  | GetUsedComponents(inst); | 
|  |  | 
|  | uint32_t elem = 0; | 
|  | switch (type->opcode()) { | 
|  | case SpvOpTypeStruct: | 
|  | type->ForEachInOperand( | 
|  | [this, inst, &elem, replacements, &components_used](uint32_t* id) { | 
|  | if (!components_used || components_used->count(elem)) { | 
|  | CreateVariable(*id, inst, elem, replacements); | 
|  | } else { | 
|  | replacements->push_back(CreateNullConstant(*id)); | 
|  | } | 
|  | elem++; | 
|  | }); | 
|  | break; | 
|  | case SpvOpTypeArray: | 
|  | for (uint32_t i = 0; i != GetArrayLength(type); ++i) { | 
|  | if (!components_used || components_used->count(i)) { | 
|  | CreateVariable(type->GetSingleWordInOperand(0u), inst, i, | 
|  | replacements); | 
|  | } else { | 
|  | replacements->push_back( | 
|  | CreateNullConstant(type->GetSingleWordInOperand(0u))); | 
|  | } | 
|  | } | 
|  | break; | 
|  |  | 
|  | case SpvOpTypeMatrix: | 
|  | case SpvOpTypeVector: | 
|  | for (uint32_t i = 0; i != GetNumElements(type); ++i) { | 
|  | CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements); | 
|  | } | 
|  | break; | 
|  |  | 
|  | default: | 
|  | assert(false && "Unexpected type."); | 
|  | break; | 
|  | } | 
|  |  | 
|  | TransferAnnotations(inst, replacements); | 
|  | return std::find(replacements->begin(), replacements->end(), nullptr) == | 
|  | replacements->end(); | 
|  | } | 
|  |  | 
|  | void ScalarReplacementPass::TransferAnnotations( | 
|  | const Instruction* source, std::vector<Instruction*>* replacements) { | 
|  | // Only transfer invariant and restrict decorations on the variable. There are | 
|  | // no type or member decorations that are necessary to transfer. | 
|  | for (auto inst : | 
|  | get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) { | 
|  | assert(inst->opcode() == SpvOpDecorate); | 
|  | uint32_t decoration = inst->GetSingleWordInOperand(1u); | 
|  | if (decoration == SpvDecorationInvariant || | 
|  | decoration == SpvDecorationRestrict) { | 
|  | for (auto var : *replacements) { | 
|  | if (var == nullptr) { | 
|  | continue; | 
|  | } | 
|  |  | 
|  | std::unique_ptr<Instruction> annotation( | 
|  | new Instruction(context(), SpvOpDecorate, 0, 0, | 
|  | std::initializer_list<Operand>{ | 
|  | {SPV_OPERAND_TYPE_ID, {var->result_id()}}, | 
|  | {SPV_OPERAND_TYPE_DECORATION, {decoration}}})); | 
|  | for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { | 
|  | Operand copy(inst->GetInOperand(i)); | 
|  | annotation->AddOperand(std::move(copy)); | 
|  | } | 
|  | context()->AddAnnotationInst(std::move(annotation)); | 
|  | get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end()); | 
|  | } | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | void ScalarReplacementPass::CreateVariable( | 
|  | uint32_t typeId, Instruction* varInst, uint32_t index, | 
|  | std::vector<Instruction*>* replacements) { | 
|  | uint32_t ptrId = GetOrCreatePointerType(typeId); | 
|  | uint32_t id = TakeNextId(); | 
|  |  | 
|  | if (id == 0) { | 
|  | replacements->push_back(nullptr); | 
|  | } | 
|  |  | 
|  | std::unique_ptr<Instruction> variable(new Instruction( | 
|  | context(), SpvOpVariable, ptrId, id, | 
|  | std::initializer_list<Operand>{ | 
|  | {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}})); | 
|  |  | 
|  | BasicBlock* block = context()->get_instr_block(varInst); | 
|  | block->begin().InsertBefore(std::move(variable)); | 
|  | Instruction* inst = &*block->begin(); | 
|  |  | 
|  | // If varInst was initialized, make sure to initialize its replacement. | 
|  | GetOrCreateInitialValue(varInst, index, inst); | 
|  | get_def_use_mgr()->AnalyzeInstDefUse(inst); | 
|  | context()->set_instr_block(inst, block); | 
|  |  | 
|  | // Copy decorations from the member to the new variable. | 
|  | Instruction* typeInst = GetStorageType(varInst); | 
|  | for (auto dec_inst : | 
|  | get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) { | 
|  | uint32_t decoration; | 
|  | if (dec_inst->opcode() != SpvOpMemberDecorate) { | 
|  | continue; | 
|  | } | 
|  |  | 
|  | if (dec_inst->GetSingleWordInOperand(1) != index) { | 
|  | continue; | 
|  | } | 
|  |  | 
|  | decoration = dec_inst->GetSingleWordInOperand(2u); | 
|  | switch (decoration) { | 
|  | case SpvDecorationRelaxedPrecision: { | 
|  | std::unique_ptr<Instruction> new_dec_inst( | 
|  | new Instruction(context(), SpvOpDecorate, 0, 0, {})); | 
|  | new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id})); | 
|  | for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) { | 
|  | new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i))); | 
|  | } | 
|  | context()->AddAnnotationInst(std::move(new_dec_inst)); | 
|  | } break; | 
|  | default: | 
|  | break; | 
|  | } | 
|  | } | 
|  |  | 
|  | replacements->push_back(inst); | 
|  | } | 
|  |  | 
|  | uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) { | 
|  | auto iter = pointee_to_pointer_.find(id); | 
|  | if (iter != pointee_to_pointer_.end()) return iter->second; | 
|  |  | 
|  | analysis::Type* pointeeTy; | 
|  | std::unique_ptr<analysis::Pointer> pointerTy; | 
|  | std::tie(pointeeTy, pointerTy) = | 
|  | context()->get_type_mgr()->GetTypeAndPointerType(id, | 
|  | SpvStorageClassFunction); | 
|  | uint32_t ptrId = 0; | 
|  | if (pointeeTy->IsUniqueType()) { | 
|  | // Non-ambiguous type, just ask the type manager for an id. | 
|  | ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get()); | 
|  | pointee_to_pointer_[id] = ptrId; | 
|  | return ptrId; | 
|  | } | 
|  |  | 
|  | // Ambiguous type. We must perform a linear search to try and find the right | 
|  | // type. | 
|  | for (auto global : context()->types_values()) { | 
|  | if (global.opcode() == SpvOpTypePointer && | 
|  | global.GetSingleWordInOperand(0u) == SpvStorageClassFunction && | 
|  | global.GetSingleWordInOperand(1u) == id) { | 
|  | if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) { | 
|  | // Only reuse a decoration-less pointer of the correct type. | 
|  | ptrId = global.result_id(); | 
|  | break; | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | if (ptrId != 0) { | 
|  | pointee_to_pointer_[id] = ptrId; | 
|  | return ptrId; | 
|  | } | 
|  |  | 
|  | ptrId = TakeNextId(); | 
|  | context()->AddType(MakeUnique<Instruction>( | 
|  | context(), SpvOpTypePointer, 0, ptrId, | 
|  | std::initializer_list<Operand>{ | 
|  | {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}, | 
|  | {SPV_OPERAND_TYPE_ID, {id}}})); | 
|  | Instruction* ptr = &*--context()->types_values_end(); | 
|  | get_def_use_mgr()->AnalyzeInstDefUse(ptr); | 
|  | pointee_to_pointer_[id] = ptrId; | 
|  | // Register with the type manager if necessary. | 
|  | context()->get_type_mgr()->RegisterType(ptrId, *pointerTy); | 
|  |  | 
|  | return ptrId; | 
|  | } | 
|  |  | 
|  | void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source, | 
|  | uint32_t index, | 
|  | Instruction* newVar) { | 
|  | assert(source->opcode() == SpvOpVariable); | 
|  | if (source->NumInOperands() < 2) return; | 
|  |  | 
|  | uint32_t initId = source->GetSingleWordInOperand(1u); | 
|  | uint32_t storageId = GetStorageType(newVar)->result_id(); | 
|  | Instruction* init = get_def_use_mgr()->GetDef(initId); | 
|  | uint32_t newInitId = 0; | 
|  | // TODO(dnovillo): Refactor this with constant propagation. | 
|  | if (init->opcode() == SpvOpConstantNull) { | 
|  | // Initialize to appropriate NULL. | 
|  | auto iter = type_to_null_.find(storageId); | 
|  | if (iter == type_to_null_.end()) { | 
|  | newInitId = TakeNextId(); | 
|  | type_to_null_[storageId] = newInitId; | 
|  | context()->AddGlobalValue( | 
|  | MakeUnique<Instruction>(context(), SpvOpConstantNull, storageId, | 
|  | newInitId, std::initializer_list<Operand>{})); | 
|  | Instruction* newNull = &*--context()->types_values_end(); | 
|  | get_def_use_mgr()->AnalyzeInstDefUse(newNull); | 
|  | } else { | 
|  | newInitId = iter->second; | 
|  | } | 
|  | } else if (IsSpecConstantInst(init->opcode())) { | 
|  | // Create a new constant extract. | 
|  | newInitId = TakeNextId(); | 
|  | context()->AddGlobalValue(MakeUnique<Instruction>( | 
|  | context(), SpvOpSpecConstantOp, storageId, newInitId, | 
|  | std::initializer_list<Operand>{ | 
|  | {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}}, | 
|  | {SPV_OPERAND_TYPE_ID, {init->result_id()}}, | 
|  | {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}})); | 
|  | Instruction* newSpecConst = &*--context()->types_values_end(); | 
|  | get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst); | 
|  | } else if (init->opcode() == SpvOpConstantComposite) { | 
|  | // Get the appropriate index constant. | 
|  | newInitId = init->GetSingleWordInOperand(index); | 
|  | Instruction* element = get_def_use_mgr()->GetDef(newInitId); | 
|  | if (element->opcode() == SpvOpUndef) { | 
|  | // Undef is not a valid initializer for a variable. | 
|  | newInitId = 0; | 
|  | } | 
|  | } else { | 
|  | assert(false); | 
|  | } | 
|  |  | 
|  | if (newInitId != 0) { | 
|  | newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}}); | 
|  | } | 
|  | } | 
|  |  | 
|  | uint64_t ScalarReplacementPass::GetArrayLength( | 
|  | const Instruction* arrayType) const { | 
|  | assert(arrayType->opcode() == SpvOpTypeArray); | 
|  | const Instruction* length = | 
|  | get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u)); | 
|  | return context() | 
|  | ->get_constant_mgr() | 
|  | ->GetConstantFromInst(length) | 
|  | ->GetZeroExtendedValue(); | 
|  | } | 
|  |  | 
|  | uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const { | 
|  | assert(type->opcode() == SpvOpTypeVector || | 
|  | type->opcode() == SpvOpTypeMatrix); | 
|  | const Operand& op = type->GetInOperand(1u); | 
|  | assert(op.words.size() <= 2); | 
|  | uint64_t len = 0; | 
|  | for (size_t i = 0; i != op.words.size(); ++i) { | 
|  | len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i)); | 
|  | } | 
|  | return len; | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const { | 
|  | const Instruction* inst = get_def_use_mgr()->GetDef(id); | 
|  | assert(inst); | 
|  | return spvOpcodeIsSpecConstant(inst->opcode()); | 
|  | } | 
|  |  | 
|  | Instruction* ScalarReplacementPass::GetStorageType( | 
|  | const Instruction* inst) const { | 
|  | assert(inst->opcode() == SpvOpVariable); | 
|  |  | 
|  | uint32_t ptrTypeId = inst->type_id(); | 
|  | uint32_t typeId = | 
|  | get_def_use_mgr()->GetDef(ptrTypeId)->GetSingleWordInOperand(1u); | 
|  | return get_def_use_mgr()->GetDef(typeId); | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::CanReplaceVariable( | 
|  | const Instruction* varInst) const { | 
|  | assert(varInst->opcode() == SpvOpVariable); | 
|  |  | 
|  | // Can only replace function scope variables. | 
|  | if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | const Instruction* typeInst = GetStorageType(varInst); | 
|  | if (!CheckType(typeInst)) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!CheckAnnotations(varInst)) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!CheckUses(varInst)) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const { | 
|  | if (!CheckTypeAnnotations(typeInst)) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | switch (typeInst->opcode()) { | 
|  | case SpvOpTypeStruct: | 
|  | // Don't bother with empty structs or very large structs. | 
|  | if (typeInst->NumInOperands() == 0 || | 
|  | IsLargerThanSizeLimit(typeInst->NumInOperands())) { | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | case SpvOpTypeArray: | 
|  | if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) { | 
|  | return false; | 
|  | } | 
|  | if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) { | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | // TODO(alanbaker): Develop some heuristics for when this should be | 
|  | // re-enabled. | 
|  | //// Specifically including matrix and vector in an attempt to reduce the | 
|  | //// number of vector registers required. | 
|  | // case SpvOpTypeMatrix: | 
|  | // case SpvOpTypeVector: | 
|  | //  if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false; | 
|  | //  return true; | 
|  |  | 
|  | case SpvOpTypeRuntimeArray: | 
|  | default: | 
|  | return false; | 
|  | } | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::CheckTypeAnnotations( | 
|  | const Instruction* typeInst) const { | 
|  | for (auto inst : | 
|  | get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) { | 
|  | uint32_t decoration; | 
|  | if (inst->opcode() == SpvOpDecorate) { | 
|  | decoration = inst->GetSingleWordInOperand(1u); | 
|  | } else { | 
|  | assert(inst->opcode() == SpvOpMemberDecorate); | 
|  | decoration = inst->GetSingleWordInOperand(2u); | 
|  | } | 
|  |  | 
|  | switch (decoration) { | 
|  | case SpvDecorationRowMajor: | 
|  | case SpvDecorationColMajor: | 
|  | case SpvDecorationArrayStride: | 
|  | case SpvDecorationMatrixStride: | 
|  | case SpvDecorationCPacked: | 
|  | case SpvDecorationInvariant: | 
|  | case SpvDecorationRestrict: | 
|  | case SpvDecorationOffset: | 
|  | case SpvDecorationAlignment: | 
|  | case SpvDecorationAlignmentId: | 
|  | case SpvDecorationMaxByteOffset: | 
|  | case SpvDecorationRelaxedPrecision: | 
|  | break; | 
|  | default: | 
|  | return false; | 
|  | } | 
|  | } | 
|  |  | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const { | 
|  | for (auto inst : | 
|  | get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) { | 
|  | assert(inst->opcode() == SpvOpDecorate); | 
|  | uint32_t decoration = inst->GetSingleWordInOperand(1u); | 
|  | switch (decoration) { | 
|  | case SpvDecorationInvariant: | 
|  | case SpvDecorationRestrict: | 
|  | case SpvDecorationAlignment: | 
|  | case SpvDecorationAlignmentId: | 
|  | case SpvDecorationMaxByteOffset: | 
|  | break; | 
|  | default: | 
|  | return false; | 
|  | } | 
|  | } | 
|  |  | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::CheckUses(const Instruction* inst) const { | 
|  | VariableStats stats = {0, 0}; | 
|  | bool ok = CheckUses(inst, &stats); | 
|  |  | 
|  | // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when | 
|  | // SRoA is costly, such as when the structure has many (unaccessed?) | 
|  | // members. | 
|  |  | 
|  | return ok; | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::CheckUses(const Instruction* inst, | 
|  | VariableStats* stats) const { | 
|  | uint64_t max_legal_index = GetMaxLegalIndex(inst); | 
|  |  | 
|  | bool ok = true; | 
|  | get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok]( | 
|  | const Instruction* user, | 
|  | uint32_t index) { | 
|  | // Annotations are check as a group separately. | 
|  | if (!IsAnnotationInst(user->opcode())) { | 
|  | switch (user->opcode()) { | 
|  | case SpvOpAccessChain: | 
|  | case SpvOpInBoundsAccessChain: | 
|  | if (index == 2u && user->NumInOperands() > 1) { | 
|  | uint32_t id = user->GetSingleWordInOperand(1u); | 
|  | const Instruction* opInst = get_def_use_mgr()->GetDef(id); | 
|  | const auto* constant = | 
|  | context()->get_constant_mgr()->GetConstantFromInst(opInst); | 
|  | if (!constant) { | 
|  | ok = false; | 
|  | } else if (constant->GetZeroExtendedValue() >= max_legal_index) { | 
|  | ok = false; | 
|  | } else { | 
|  | if (!CheckUsesRelaxed(user)) ok = false; | 
|  | } | 
|  | stats->num_partial_accesses++; | 
|  | } else { | 
|  | ok = false; | 
|  | } | 
|  | break; | 
|  | case SpvOpLoad: | 
|  | if (!CheckLoad(user, index)) ok = false; | 
|  | stats->num_full_accesses++; | 
|  | break; | 
|  | case SpvOpStore: | 
|  | if (!CheckStore(user, index)) ok = false; | 
|  | stats->num_full_accesses++; | 
|  | break; | 
|  | case SpvOpName: | 
|  | case SpvOpMemberName: | 
|  | break; | 
|  | default: | 
|  | ok = false; | 
|  | break; | 
|  | } | 
|  | } | 
|  | }); | 
|  |  | 
|  | return ok; | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const { | 
|  | bool ok = true; | 
|  | get_def_use_mgr()->ForEachUse( | 
|  | inst, [this, &ok](const Instruction* user, uint32_t index) { | 
|  | switch (user->opcode()) { | 
|  | case SpvOpAccessChain: | 
|  | case SpvOpInBoundsAccessChain: | 
|  | if (index != 2u) { | 
|  | ok = false; | 
|  | } else { | 
|  | if (!CheckUsesRelaxed(user)) ok = false; | 
|  | } | 
|  | break; | 
|  | case SpvOpLoad: | 
|  | if (!CheckLoad(user, index)) ok = false; | 
|  | break; | 
|  | case SpvOpStore: | 
|  | if (!CheckStore(user, index)) ok = false; | 
|  | break; | 
|  | default: | 
|  | ok = false; | 
|  | break; | 
|  | } | 
|  | }); | 
|  |  | 
|  | return ok; | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::CheckLoad(const Instruction* inst, | 
|  | uint32_t index) const { | 
|  | if (index != 2u) return false; | 
|  | if (inst->NumInOperands() >= 2 && | 
|  | inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask) | 
|  | return false; | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool ScalarReplacementPass::CheckStore(const Instruction* inst, | 
|  | uint32_t index) const { | 
|  | if (index != 0u) return false; | 
|  | if (inst->NumInOperands() >= 3 && | 
|  | inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask) | 
|  | return false; | 
|  | return true; | 
|  | } | 
|  | bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const { | 
|  | if (max_num_elements_ == 0) { | 
|  | return false; | 
|  | } | 
|  | return length > max_num_elements_; | 
|  | } | 
|  |  | 
|  | std::unique_ptr<std::unordered_set<int64_t>> | 
|  | ScalarReplacementPass::GetUsedComponents(Instruction* inst) { | 
|  | std::unique_ptr<std::unordered_set<int64_t>> result( | 
|  | new std::unordered_set<int64_t>()); | 
|  |  | 
|  | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); | 
|  |  | 
|  | def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr, | 
|  | this](Instruction* use) { | 
|  | switch (use->opcode()) { | 
|  | case SpvOpLoad: { | 
|  | // Look for extract from the load. | 
|  | std::vector<uint32_t> t; | 
|  | if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) { | 
|  | if (use2->opcode() != SpvOpCompositeExtract || | 
|  | use2->NumInOperands() <= 1) { | 
|  | return false; | 
|  | } | 
|  | t.push_back(use2->GetSingleWordInOperand(1)); | 
|  | return true; | 
|  | })) { | 
|  | result->insert(t.begin(), t.end()); | 
|  | return true; | 
|  | } else { | 
|  | result.reset(nullptr); | 
|  | return false; | 
|  | } | 
|  | } | 
|  | case SpvOpName: | 
|  | case SpvOpMemberName: | 
|  | case SpvOpStore: | 
|  | // No components are used. | 
|  | return true; | 
|  | case SpvOpAccessChain: | 
|  | case SpvOpInBoundsAccessChain: { | 
|  | // Add the first index it if is a constant. | 
|  | // TODO: Could be improved by checking if the address is used in a load. | 
|  | analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); | 
|  | uint32_t index_id = use->GetSingleWordInOperand(1); | 
|  | const analysis::Constant* index_const = | 
|  | const_mgr->FindDeclaredConstant(index_id); | 
|  | if (index_const) { | 
|  | result->insert(index_const->GetSignExtendedValue()); | 
|  | return true; | 
|  | } else { | 
|  | // Could be any element.  Assuming all are used. | 
|  | result.reset(nullptr); | 
|  | return false; | 
|  | } | 
|  | } | 
|  | default: | 
|  | // We do not know what is happening.  Have to assume the worst. | 
|  | result.reset(nullptr); | 
|  | return false; | 
|  | } | 
|  | }); | 
|  |  | 
|  | return result; | 
|  | } | 
|  |  | 
|  | Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) { | 
|  | analysis::TypeManager* type_mgr = context()->get_type_mgr(); | 
|  | analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); | 
|  |  | 
|  | const analysis::Type* type = type_mgr->GetType(type_id); | 
|  | const analysis::Constant* null_const = const_mgr->GetConstant(type, {}); | 
|  | Instruction* null_inst = | 
|  | const_mgr->GetDefiningInstruction(null_const, type_id); | 
|  | if (null_inst != nullptr) { | 
|  | context()->UpdateDefUse(null_inst); | 
|  | } | 
|  | return null_inst; | 
|  | } | 
|  |  | 
|  | uint64_t ScalarReplacementPass::GetMaxLegalIndex( | 
|  | const Instruction* var_inst) const { | 
|  | assert(var_inst->opcode() == SpvOpVariable && | 
|  | "|var_inst| must be a variable instruction."); | 
|  | Instruction* type = GetStorageType(var_inst); | 
|  | switch (type->opcode()) { | 
|  | case SpvOpTypeStruct: | 
|  | return type->NumInOperands(); | 
|  | case SpvOpTypeArray: | 
|  | return GetArrayLength(type); | 
|  | case SpvOpTypeMatrix: | 
|  | case SpvOpTypeVector: | 
|  | return GetNumElements(type); | 
|  | default: | 
|  | return 0; | 
|  | } | 
|  | return 0; | 
|  | } | 
|  |  | 
|  | }  // namespace opt | 
|  | }  // namespace spvtools |