|  | // Copyright (c) 2019 Google LLC. | 
|  | // | 
|  | // Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | // you may not use this file except in compliance with the License. | 
|  | // You may obtain a copy of the License at | 
|  | // | 
|  | //     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | // | 
|  | // Unless required by applicable law or agreed to in writing, software | 
|  | // distributed under the License is distributed on an "AS IS" BASIS, | 
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | // See the License for the specific language governing permissions and | 
|  | // limitations under the License. | 
|  |  | 
|  | #include "source/opt/generate_webgpu_initializers_pass.h" | 
|  | #include "source/opt/ir_context.h" | 
|  |  | 
|  | namespace spvtools { | 
|  | namespace opt { | 
|  |  | 
|  | using inst_iterator = InstructionList::iterator; | 
|  |  | 
|  | namespace { | 
|  |  | 
|  | bool NeedsWebGPUInitializer(Instruction* inst) { | 
|  | if (inst->opcode() != SpvOpVariable) return false; | 
|  |  | 
|  | auto storage_class = inst->GetSingleWordOperand(2); | 
|  | if (storage_class != SpvStorageClassOutput && | 
|  | storage_class != SpvStorageClassPrivate && | 
|  | storage_class != SpvStorageClassFunction) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (inst->NumOperands() > 3) return false; | 
|  |  | 
|  | return true; | 
|  | } | 
|  |  | 
|  | }  // namespace | 
|  |  | 
|  | Pass::Status GenerateWebGPUInitializersPass::Process() { | 
|  | auto* module = context()->module(); | 
|  | bool changed = false; | 
|  |  | 
|  | // Handle global/module scoped variables | 
|  | for (auto iter = module->types_values_begin(); | 
|  | iter != module->types_values_end(); ++iter) { | 
|  | Instruction* inst = &(*iter); | 
|  |  | 
|  | if (inst->opcode() == SpvOpConstantNull) { | 
|  | null_constant_type_map_[inst->type_id()] = inst; | 
|  | seen_null_constants_.insert(inst); | 
|  | continue; | 
|  | } | 
|  |  | 
|  | if (!NeedsWebGPUInitializer(inst)) continue; | 
|  |  | 
|  | changed = true; | 
|  |  | 
|  | auto* constant_inst = GetNullConstantForVariable(inst); | 
|  | if (!constant_inst) return Status::Failure; | 
|  |  | 
|  | if (seen_null_constants_.find(constant_inst) == | 
|  | seen_null_constants_.end()) { | 
|  | constant_inst->InsertBefore(inst); | 
|  | null_constant_type_map_[inst->type_id()] = inst; | 
|  | seen_null_constants_.insert(inst); | 
|  | } | 
|  | AddNullInitializerToVariable(constant_inst, inst); | 
|  | } | 
|  |  | 
|  | // Handle local/function scoped variables | 
|  | for (auto func = module->begin(); func != module->end(); ++func) { | 
|  | auto block = func->entry().get(); | 
|  | for (auto iter = block->begin(); | 
|  | iter != block->end() && iter->opcode() == SpvOpVariable; ++iter) { | 
|  | Instruction* inst = &(*iter); | 
|  | if (!NeedsWebGPUInitializer(inst)) continue; | 
|  |  | 
|  | changed = true; | 
|  | auto* constant_inst = GetNullConstantForVariable(inst); | 
|  | if (!constant_inst) return Status::Failure; | 
|  |  | 
|  | AddNullInitializerToVariable(constant_inst, inst); | 
|  | } | 
|  | } | 
|  |  | 
|  | return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange; | 
|  | } | 
|  |  | 
|  | Instruction* GenerateWebGPUInitializersPass::GetNullConstantForVariable( | 
|  | Instruction* variable_inst) { | 
|  | auto constant_mgr = context()->get_constant_mgr(); | 
|  | auto* def_use_mgr = get_def_use_mgr(); | 
|  |  | 
|  | auto* ptr_inst = def_use_mgr->GetDef(variable_inst->type_id()); | 
|  | auto type_id = ptr_inst->GetInOperand(1).words[0]; | 
|  | if (null_constant_type_map_.find(type_id) == null_constant_type_map_.end()) { | 
|  | auto* constant_type = context()->get_type_mgr()->GetType(type_id); | 
|  | auto* constant = constant_mgr->GetConstant(constant_type, {}); | 
|  | return constant_mgr->GetDefiningInstruction(constant, type_id); | 
|  | } else { | 
|  | return null_constant_type_map_[type_id]; | 
|  | } | 
|  | } | 
|  |  | 
|  | void GenerateWebGPUInitializersPass::AddNullInitializerToVariable( | 
|  | Instruction* constant_inst, Instruction* variable_inst) { | 
|  | auto constant_id = constant_inst->result_id(); | 
|  | variable_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {constant_id})); | 
|  | get_def_use_mgr()->AnalyzeInstUse(variable_inst); | 
|  | } | 
|  |  | 
|  | }  // namespace opt | 
|  | }  // namespace spvtools |