| // Copyright (c) 2017 Google Inc. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "source/opt/constants.h" |
| |
| #include <unordered_map> |
| #include <vector> |
| |
| #include "source/opt/ir_context.h" |
| |
| namespace spvtools { |
| namespace opt { |
| namespace analysis { |
| |
| float Constant::GetFloat() const { |
| assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 32); |
| |
| if (const FloatConstant* fc = AsFloatConstant()) { |
| return fc->GetFloatValue(); |
| } else { |
| assert(AsNullConstant() && "Must be a floating point constant."); |
| return 0.0f; |
| } |
| } |
| |
| double Constant::GetDouble() const { |
| assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 64); |
| |
| if (const FloatConstant* fc = AsFloatConstant()) { |
| return fc->GetDoubleValue(); |
| } else { |
| assert(AsNullConstant() && "Must be a floating point constant."); |
| return 0.0; |
| } |
| } |
| |
| double Constant::GetValueAsDouble() const { |
| assert(type()->AsFloat() != nullptr); |
| if (type()->AsFloat()->width() == 32) { |
| return GetFloat(); |
| } else { |
| assert(type()->AsFloat()->width() == 64); |
| return GetDouble(); |
| } |
| } |
| |
| uint32_t Constant::GetU32() const { |
| assert(type()->AsInteger() != nullptr); |
| assert(type()->AsInteger()->width() == 32); |
| |
| if (const IntConstant* ic = AsIntConstant()) { |
| return ic->GetU32BitValue(); |
| } else { |
| assert(AsNullConstant() && "Must be an integer constant."); |
| return 0u; |
| } |
| } |
| |
| uint64_t Constant::GetU64() const { |
| assert(type()->AsInteger() != nullptr); |
| assert(type()->AsInteger()->width() == 64); |
| |
| if (const IntConstant* ic = AsIntConstant()) { |
| return ic->GetU64BitValue(); |
| } else { |
| assert(AsNullConstant() && "Must be an integer constant."); |
| return 0u; |
| } |
| } |
| |
| int32_t Constant::GetS32() const { |
| assert(type()->AsInteger() != nullptr); |
| assert(type()->AsInteger()->width() == 32); |
| |
| if (const IntConstant* ic = AsIntConstant()) { |
| return ic->GetS32BitValue(); |
| } else { |
| assert(AsNullConstant() && "Must be an integer constant."); |
| return 0; |
| } |
| } |
| |
| int64_t Constant::GetS64() const { |
| assert(type()->AsInteger() != nullptr); |
| assert(type()->AsInteger()->width() == 64); |
| |
| if (const IntConstant* ic = AsIntConstant()) { |
| return ic->GetS64BitValue(); |
| } else { |
| assert(AsNullConstant() && "Must be an integer constant."); |
| return 0; |
| } |
| } |
| |
| uint64_t Constant::GetZeroExtendedValue() const { |
| const auto* int_type = type()->AsInteger(); |
| assert(int_type != nullptr); |
| const auto width = int_type->width(); |
| assert(width <= 64); |
| |
| uint64_t value = 0; |
| if (const IntConstant* ic = AsIntConstant()) { |
| if (width <= 32) { |
| value = ic->GetU32BitValue(); |
| } else { |
| value = ic->GetU64BitValue(); |
| } |
| } else { |
| assert(AsNullConstant() && "Must be an integer constant."); |
| } |
| return value; |
| } |
| |
| int64_t Constant::GetSignExtendedValue() const { |
| const auto* int_type = type()->AsInteger(); |
| assert(int_type != nullptr); |
| const auto width = int_type->width(); |
| assert(width <= 64); |
| |
| int64_t value = 0; |
| if (const IntConstant* ic = AsIntConstant()) { |
| if (width <= 32) { |
| // Let the C++ compiler do the sign extension. |
| value = int64_t(ic->GetS32BitValue()); |
| } else { |
| value = ic->GetS64BitValue(); |
| } |
| } else { |
| assert(AsNullConstant() && "Must be an integer constant."); |
| } |
| return value; |
| } |
| |
| ConstantManager::ConstantManager(IRContext* ctx) : ctx_(ctx) { |
| // Populate the constant table with values from constant declarations in the |
| // module. The values of each OpConstant declaration is the identity |
| // assignment (i.e., each constant is its own value). |
| for (const auto& inst : ctx_->module()->GetConstants()) { |
| MapInst(inst); |
| } |
| } |
| |
| Type* ConstantManager::GetType(const Instruction* inst) const { |
| return context()->get_type_mgr()->GetType(inst->type_id()); |
| } |
| |
| std::vector<const Constant*> ConstantManager::GetOperandConstants( |
| Instruction* inst) const { |
| std::vector<const Constant*> constants; |
| for (uint32_t i = 0; i < inst->NumInOperands(); i++) { |
| const Operand* operand = &inst->GetInOperand(i); |
| if (operand->type != SPV_OPERAND_TYPE_ID) { |
| constants.push_back(nullptr); |
| } else { |
| uint32_t id = operand->words[0]; |
| const analysis::Constant* constant = FindDeclaredConstant(id); |
| constants.push_back(constant); |
| } |
| } |
| return constants; |
| } |
| |
| uint32_t ConstantManager::FindDeclaredConstant(const Constant* c, |
| uint32_t type_id) const { |
| c = FindConstant(c); |
| if (c == nullptr) { |
| return 0; |
| } |
| |
| for (auto range = const_val_to_id_.equal_range(c); |
| range.first != range.second; ++range.first) { |
| Instruction* const_def = |
| context()->get_def_use_mgr()->GetDef(range.first->second); |
| if (type_id == 0 || const_def->type_id() == type_id) { |
| return range.first->second; |
| } |
| } |
| return 0; |
| } |
| |
| std::vector<const Constant*> ConstantManager::GetConstantsFromIds( |
| const std::vector<uint32_t>& ids) const { |
| std::vector<const Constant*> constants; |
| for (uint32_t id : ids) { |
| if (const Constant* c = FindDeclaredConstant(id)) { |
| constants.push_back(c); |
| } else { |
| return {}; |
| } |
| } |
| return constants; |
| } |
| |
| Instruction* ConstantManager::BuildInstructionAndAddToModule( |
| const Constant* new_const, Module::inst_iterator* pos, uint32_t type_id) { |
| // TODO(1841): Handle id overflow. |
| uint32_t new_id = context()->TakeNextId(); |
| if (new_id == 0) { |
| return nullptr; |
| } |
| |
| auto new_inst = CreateInstruction(new_id, new_const, type_id); |
| if (!new_inst) { |
| return nullptr; |
| } |
| auto* new_inst_ptr = new_inst.get(); |
| *pos = pos->InsertBefore(std::move(new_inst)); |
| ++(*pos); |
| context()->get_def_use_mgr()->AnalyzeInstDefUse(new_inst_ptr); |
| MapConstantToInst(new_const, new_inst_ptr); |
| return new_inst_ptr; |
| } |
| |
| Instruction* ConstantManager::GetDefiningInstruction( |
| const Constant* c, uint32_t type_id, Module::inst_iterator* pos) { |
| uint32_t decl_id = FindDeclaredConstant(c, type_id); |
| if (decl_id == 0) { |
| auto iter = context()->types_values_end(); |
| if (pos == nullptr) pos = &iter; |
| return BuildInstructionAndAddToModule(c, pos, type_id); |
| } else { |
| auto def = context()->get_def_use_mgr()->GetDef(decl_id); |
| assert(def != nullptr); |
| assert((type_id == 0 || def->type_id() == type_id) && |
| "This constant already has an instruction with a different type."); |
| return def; |
| } |
| } |
| |
| std::unique_ptr<Constant> ConstantManager::CreateConstant( |
| const Type* type, const std::vector<uint32_t>& literal_words_or_ids) const { |
| if (literal_words_or_ids.size() == 0) { |
| // Constant declared with OpConstantNull |
| return MakeUnique<NullConstant>(type); |
| } else if (auto* bt = type->AsBool()) { |
| assert(literal_words_or_ids.size() == 1 && |
| "Bool constant should be declared with one operand"); |
| return MakeUnique<BoolConstant>(bt, literal_words_or_ids.front()); |
| } else if (auto* it = type->AsInteger()) { |
| return MakeUnique<IntConstant>(it, literal_words_or_ids); |
| } else if (auto* ft = type->AsFloat()) { |
| return MakeUnique<FloatConstant>(ft, literal_words_or_ids); |
| } else if (auto* vt = type->AsVector()) { |
| auto components = GetConstantsFromIds(literal_words_or_ids); |
| if (components.empty()) return nullptr; |
| // All components of VectorConstant must be of type Bool, Integer or Float. |
| if (!std::all_of(components.begin(), components.end(), |
| [](const Constant* c) { |
| if (c->type()->AsBool() || c->type()->AsInteger() || |
| c->type()->AsFloat()) { |
| return true; |
| } else { |
| return false; |
| } |
| })) |
| return nullptr; |
| // All components of VectorConstant must be in the same type. |
| const auto* component_type = components.front()->type(); |
| if (!std::all_of(components.begin(), components.end(), |
| [&component_type](const Constant* c) { |
| if (c->type() == component_type) return true; |
| return false; |
| })) |
| return nullptr; |
| return MakeUnique<VectorConstant>(vt, components); |
| } else if (auto* mt = type->AsMatrix()) { |
| auto components = GetConstantsFromIds(literal_words_or_ids); |
| if (components.empty()) return nullptr; |
| return MakeUnique<MatrixConstant>(mt, components); |
| } else if (auto* st = type->AsStruct()) { |
| auto components = GetConstantsFromIds(literal_words_or_ids); |
| if (components.empty()) return nullptr; |
| return MakeUnique<StructConstant>(st, components); |
| } else if (auto* at = type->AsArray()) { |
| auto components = GetConstantsFromIds(literal_words_or_ids); |
| if (components.empty()) return nullptr; |
| return MakeUnique<ArrayConstant>(at, components); |
| } else { |
| return nullptr; |
| } |
| } |
| |
| const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) { |
| std::vector<uint32_t> literal_words_or_ids; |
| |
| // Collect the constant defining literals or component ids. |
| for (uint32_t i = 0; i < inst->NumInOperands(); i++) { |
| literal_words_or_ids.insert(literal_words_or_ids.end(), |
| inst->GetInOperand(i).words.begin(), |
| inst->GetInOperand(i).words.end()); |
| } |
| |
| switch (inst->opcode()) { |
| // OpConstant{True|False} have the value embedded in the opcode. So they |
| // are not handled by the for-loop above. Here we add the value explicitly. |
| case SpvOp::SpvOpConstantTrue: |
| literal_words_or_ids.push_back(true); |
| break; |
| case SpvOp::SpvOpConstantFalse: |
| literal_words_or_ids.push_back(false); |
| break; |
| case SpvOp::SpvOpConstantNull: |
| case SpvOp::SpvOpConstant: |
| case SpvOp::SpvOpConstantComposite: |
| case SpvOp::SpvOpSpecConstantComposite: |
| break; |
| default: |
| return nullptr; |
| } |
| |
| return GetConstant(GetType(inst), literal_words_or_ids); |
| } |
| |
| std::unique_ptr<Instruction> ConstantManager::CreateInstruction( |
| uint32_t id, const Constant* c, uint32_t type_id) const { |
| uint32_t type = |
| (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id; |
| if (c->AsNullConstant()) { |
| return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantNull, type, |
| id, std::initializer_list<Operand>{}); |
| } else if (const BoolConstant* bc = c->AsBoolConstant()) { |
| return MakeUnique<Instruction>( |
| context(), |
| bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse, |
| type, id, std::initializer_list<Operand>{}); |
| } else if (const IntConstant* ic = c->AsIntConstant()) { |
| return MakeUnique<Instruction>( |
| context(), SpvOp::SpvOpConstant, type, id, |
| std::initializer_list<Operand>{ |
| Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, |
| ic->words())}); |
| } else if (const FloatConstant* fc = c->AsFloatConstant()) { |
| return MakeUnique<Instruction>( |
| context(), SpvOp::SpvOpConstant, type, id, |
| std::initializer_list<Operand>{ |
| Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, |
| fc->words())}); |
| } else if (const CompositeConstant* cc = c->AsCompositeConstant()) { |
| return CreateCompositeInstruction(id, cc, type_id); |
| } else { |
| return nullptr; |
| } |
| } |
| |
| std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction( |
| uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const { |
| std::vector<Operand> operands; |
| Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id); |
| uint32_t component_index = 0; |
| for (const Constant* component_const : cc->GetComponents()) { |
| uint32_t component_type_id = 0; |
| if (type_inst && type_inst->opcode() == SpvOpTypeStruct) { |
| component_type_id = type_inst->GetSingleWordInOperand(component_index); |
| } else if (type_inst && type_inst->opcode() == SpvOpTypeArray) { |
| component_type_id = type_inst->GetSingleWordInOperand(0); |
| } |
| uint32_t id = FindDeclaredConstant(component_const, component_type_id); |
| |
| if (id == 0) { |
| // Cannot get the id of the component constant, while all components |
| // should have been added to the module prior to the composite constant. |
| // Cannot create OpConstantComposite instruction in this case. |
| return nullptr; |
| } |
| operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID, |
| std::initializer_list<uint32_t>{id}); |
| component_index++; |
| } |
| uint32_t type = |
| (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id; |
| return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantComposite, type, |
| result_id, std::move(operands)); |
| } |
| |
| const Constant* ConstantManager::GetConstant( |
| const Type* type, const std::vector<uint32_t>& literal_words_or_ids) { |
| auto cst = CreateConstant(type, literal_words_or_ids); |
| return cst ? RegisterConstant(std::move(cst)) : nullptr; |
| } |
| |
| std::vector<const analysis::Constant*> Constant::GetVectorComponents( |
| analysis::ConstantManager* const_mgr) const { |
| std::vector<const analysis::Constant*> components; |
| const analysis::VectorConstant* a = this->AsVectorConstant(); |
| const analysis::Vector* vector_type = this->type()->AsVector(); |
| assert(vector_type != nullptr); |
| if (a != nullptr) { |
| for (uint32_t i = 0; i < vector_type->element_count(); ++i) { |
| components.push_back(a->GetComponents()[i]); |
| } |
| } else { |
| const analysis::Type* element_type = vector_type->element_type(); |
| const analysis::Constant* element_null_const = |
| const_mgr->GetConstant(element_type, {}); |
| for (uint32_t i = 0; i < vector_type->element_count(); ++i) { |
| components.push_back(element_null_const); |
| } |
| } |
| return components; |
| } |
| |
| } // namespace analysis |
| } // namespace opt |
| } // namespace spvtools |