// blob: ae3ce79f2b670a6c9a12e5901d4f74f3a864f808 [file] [log] [blame]
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// MARK-V is a compression format for SPIR-V binaries. It strips away
// non-essential information (such as result IDs which can be regenerated) and
// uses various bit reduction techniques to reduce the size of the binary.
#include "source/comp/markv_codec.h"
#include "source/comp/markv_logger.h"
#include "source/latest_version_glsl_std_450_header.h"
#include "source/latest_version_opencl_std_header.h"
#include "source/opcode.h"
#include "source/util/make_unique.h"
namespace spvtools {
namespace comp {
namespace {
// Custom hash function used to produce short descriptors.
// Hashes |words| into a short descriptor in the range
// [1, 2^MarkvCodec::kShortDescriptorNumBits - 1]; zero is never returned.
uint32_t ShortHashU32Array(const std::vector<uint32_t>& words) {
  // Multiplier of Knuth's multiplicative hash.
  const uint32_t kKnuthMulHash = 2654435761;
  // Number of distinct descriptor values that can be produced.
  const uint32_t kNumDescriptorValues =
      (1 << MarkvCodec::kShortDescriptorNumBits) - 1;

  // Sum the per-word hashes. Each word is offset by its position (plus a
  // constant) before hashing, so the result depends on the word order.
  uint32_t hash_sum = 0;
  uint32_t position = 0;
  for (const uint32_t word : words) {
    hash_sum += (word + position + 123) * kKnuthMulHash;
    ++position;
  }
  return hash_sum % kNumDescriptorValues + 1;
}
// Returns a set of mtf rank codecs based on a plausible hand-coded
// distribution.
// Builds the Huffman codecs used for move-to-front ranks. The symbol weights
// are a plausible hand-coded distribution. Codecs are keyed by kMtfAll and
// kMtfGenericNonZeroRank.
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
GetMtfHuffmanCodecs() {
  using RankHistogram = std::map<uint32_t, uint32_t>;

  // Wraps a rank -> weight histogram into a heap-allocated codec.
  const auto make_rank_codec = [](RankHistogram histogram) {
    return MakeUnique<HuffmanCodec<uint32_t>>(std::move(histogram));
  };

  std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>> result;

  // Alphabet containing ranks 0 through 9 plus the "encoded by value" signal.
  result.emplace(kMtfAll,
                 make_rank_codec({
                     {0, 5},
                     {1, 40},
                     {2, 10},
                     {3, 5},
                     {4, 5},
                     {5, 5},
                     {6, 3},
                     {7, 3},
                     {8, 3},
                     {9, 3},
                     {MarkvCodec::kMtfRankEncodedByValueSignal, 10},
                 }));

  // Same idea, but the alphabet has no symbol for rank 0.
  result.emplace(kMtfGenericNonZeroRank,
                 make_rank_codec({
                     {1, 50},
                     {2, 20},
                     {3, 5},
                     {4, 5},
                     {5, 2},
                     {6, 1},
                     {7, 1},
                     {8, 1},
                     {9, 1},
                     {MarkvCodec::kMtfRankEncodedByValueSignal, 10},
                 }));

  return result;
}
} // namespace
// Magic number; the MarkvHeader constructor stores it in
// MarkvHeader::magic_number.
const uint32_t MarkvCodec::kMarkvMagicNumber = 0x07230303;
// NOTE(review): presumably MTF ranks at or above this value are encoded by
// value (announced with kMtfRankEncodedByValueSignal) instead of through the
// rank Huffman alphabet - confirm against the encoder/decoder.
const uint32_t MarkvCodec::kMtfSmallestRankEncodedByValue = 10;
// Extra symbol in the MTF rank Huffman alphabets built by
// GetMtfHuffmanCodecs(). Set to the largest uint32_t value.
const uint32_t MarkvCodec::kMtfRankEncodedByValueSignal =
std::numeric_limits<uint32_t>::max();
// Bit width of short descriptors: ShortHashU32Array() returns values in
// [1, 2^kShortDescriptorNumBits - 1].
const uint32_t MarkvCodec::kShortDescriptorNumBits = 8;
// NOTE(review): not referenced in this file; presumably a threshold, in bits,
// for byte-aligning the stream after an instruction - confirm at the use site.
const size_t MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte = 8;
// Constructs a codec.
// - context: stored in context_ and also used to construct grammar_.
// - validator_options: stored as is; the destructor destroys it with
//   spvValidatorOptionsDestroy(), so the codec takes it over.
// - model: stored as a raw pointer; the destructor does not release it.
// Short id descriptors are hashed with ShortHashU32Array and the MTF rank
// Huffman codecs come from GetMtfHuffmanCodecs().
MarkvCodec::MarkvCodec(spv_const_context context,
spv_validator_options validator_options,
const MarkvModel* model)
: validator_options_(validator_options),
grammar_(context),
model_(model),
short_id_descriptors_(ShortHashU32Array),
mtf_huffman_codecs_(GetMtfHuffmanCodecs()),
context_(context) {}
// Destroys the validator options that were handed to the constructor.
MarkvCodec::~MarkvCodec() { spvValidatorOptionsDestroy(validator_options_); }
// Initializes the header with the MARK-V magic number and the version
// returned by GetMarkvVersion(). Other header fields are not set here.
MarkvCodec::MarkvHeader::MarkvHeader()
: magic_number(MarkvCodec::kMarkvMagicNumber),
markv_version(MarkvCodec::GetMarkvVersion()) {}
// Defines the current MARK-V version and returns it packed into one word:
// the major version occupies the bits above bit 15, the minor version the
// low bits.
// static
uint32_t MarkvCodec::GetMarkvVersion() {
  const uint32_t kVersionMajor = 1;
  const uint32_t kVersionMinor = 4;
  const uint32_t kMajorShift = 16;
  return (kVersionMajor << kMajorShift) | kVersionMinor;
}
// Returns the number of bits between |bit_pos| and the next multiple of 8.
// The result is 0 when |bit_pos| is already a multiple of 8, otherwise 1..7.
size_t MarkvCodec::GetNumBitsToNextByte(size_t bit_pos) const {
  const size_t bits_into_byte = bit_pos % 8;
  if (bits_into_byte == 0) return 0;
  return 8 - bits_into_byte;
}
// Returns true if |opcode| is known to have a fixed number of operands.
// The list below is not exhaustive, so false may also be returned for an
// opcode that does have a fixed number of operands (false negative).
bool MarkvCodec::OpcodeHasFixedNumberOfOperands(SpvOp opcode) const {
  switch (opcode) {
    // TODO(atgoo@github.com) This is not a complete list.

    // Miscellaneous, debug, annotation and mode-setting instructions.
    case SpvOpNop:
    case SpvOpName:
    case SpvOpUndef:
    case SpvOpSizeOf:
    case SpvOpLine:
    case SpvOpNoLine:
    case SpvOpDecorationGroup:
    case SpvOpExtension:
    case SpvOpExtInstImport:
    case SpvOpMemoryModel:
    case SpvOpCapability:
    // Type declarations.
    case SpvOpTypeVoid:
    case SpvOpTypeBool:
    case SpvOpTypeInt:
    case SpvOpTypeFloat:
    case SpvOpTypeVector:
    case SpvOpTypeMatrix:
    case SpvOpTypeSampler:
    case SpvOpTypeSampledImage:
    case SpvOpTypeArray:
    case SpvOpTypePointer:
    // Constants.
    case SpvOpConstantTrue:
    case SpvOpConstantFalse:
    // Control flow and functions.
    case SpvOpLabel:
    case SpvOpBranch:
    case SpvOpFunction:
    case SpvOpFunctionParameter:
    case SpvOpFunctionEnd:
    // Conversion and copy.
    case SpvOpBitcast:
    case SpvOpCopyObject:
    case SpvOpTranspose:
    // Arithmetic.
    case SpvOpSNegate:
    case SpvOpFNegate:
    case SpvOpIAdd:
    case SpvOpFAdd:
    case SpvOpISub:
    case SpvOpFSub:
    case SpvOpIMul:
    case SpvOpFMul:
    case SpvOpUDiv:
    case SpvOpSDiv:
    case SpvOpFDiv:
    case SpvOpUMod:
    case SpvOpSRem:
    case SpvOpSMod:
    case SpvOpFRem:
    case SpvOpFMod:
    case SpvOpVectorTimesScalar:
    case SpvOpMatrixTimesScalar:
    case SpvOpVectorTimesMatrix:
    case SpvOpMatrixTimesVector:
    case SpvOpMatrixTimesMatrix:
    case SpvOpOuterProduct:
    case SpvOpDot:
      return true;
    default:
      return false;
  }
}
// Updates the codec's bookkeeping for the current instruction inst_:
// - appends a val::Instruction built from inst_ to instructions_;
// - registers the result id (if any) in id_to_def_instruction_ and
//   id_to_type_id_;
// - tracks the current function: its id, return type, expected parameter
//   types and the ids defined inside it (removed from all MTFs at
//   OpFunctionEnd);
// - inserts the result id into move-to-front (MTF) sequences chosen by the
//   model's id fallback strategy and by the long descriptor coding schemes.
// The order of the statements matters: several HasValue() checks below rely
// on insertions made earlier (for this or for previous instructions).
void MarkvCodec::ProcessCurInstruction() {
instructions_.emplace_back(new val::Instruction(&inst_));
const SpvOp opcode = SpvOp(inst_.opcode);
if (inst_.result_id) {
id_to_def_instruction_.emplace(inst_.result_id, instructions_.back().get());
// Collect ids local to the current function.
// This check runs before cur_function_id_ is assigned below, so while
// cur_function_id_ is still 0 the result id of the OpFunction itself is not
// recorded as local.
if (cur_function_id_) {
ids_local_to_cur_function_.push_back(inst_.result_id);
}
// Starting new function.
if (opcode == SpvOpFunction) {
cur_function_id_ = inst_.result_id;
cur_function_return_type_ = inst_.type_id;
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
multi_mtf_.Insert(GetMtfFunctionWithReturnType(inst_.type_id),
inst_.result_id);
}
// Store function parameter types in a queue, so that we know which types
// to expect in the following OpFunctionParameter instructions.
// Word 4 of OpFunction is looked up as the id of the function type; the
// asserts check that it is defined by an OpTypeFunction.
const val::Instruction* def_inst = FindDef(inst_.words[4]);
assert(def_inst);
assert(def_inst->opcode() == SpvOpTypeFunction);
// Words 3 and onward of the OpTypeFunction are queued as parameter types.
for (uint32_t i = 3; i < def_inst->words().size(); ++i) {
remaining_function_parameter_types_.push_back(def_inst->word(i));
}
}
}
// Remove local ids from MTFs if function end.
if (opcode == SpvOpFunctionEnd) {
cur_function_id_ = 0;
for (uint32_t id : ids_local_to_cur_function_) multi_mtf_.RemoveFromAll(id);
ids_local_to_cur_function_.clear();
// Every queued parameter type is expected to have been consumed by now.
assert(remaining_function_parameter_types_.empty());
}
// Everything below registers the result id; nothing to do without one.
if (!inst_.result_id) return;
{
// Save the result ID to type ID mapping.
// In the grammar, type ID always appears before result ID.
// A regular value maps to its type. Some instructions (e.g. OpLabel)
// have no type Id, and will map to 0. The result Id for a
// type-generating instruction (e.g. OpTypeInt) maps to itself.
auto insertion_result = id_to_type_id_.emplace(
inst_.result_id, spvOpcodeGeneratesType(SpvOp(inst_.opcode))
? inst_.result_id
: inst_.type_id);
// insertion_result is otherwise only read inside assert(); the cast keeps
// it "used" when assertions are compiled out.
(void)insertion_result;
// The result id must not have been mapped before.
assert(insertion_result.second);
}
// Add result_id to MTFs.
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
// Per-opcode sequences for the listed type declarations and for
// OpExtInstImport.
switch (opcode) {
case SpvOpTypeFloat:
case SpvOpTypeInt:
case SpvOpTypeBool:
case SpvOpTypeVector:
case SpvOpTypePointer:
case SpvOpExtInstImport:
case SpvOpTypeSampledImage:
case SpvOpTypeImage:
case SpvOpTypeSampler:
multi_mtf_.Insert(GetMtfIdGeneratedByOpcode(opcode), inst_.result_id);
break;
default:
break;
}
if (spvOpcodeIsComposite(opcode)) {
multi_mtf_.Insert(kMtfTypeComposite, inst_.result_id);
}
// Labels are the only ids added with InsertOrPromote() instead of Insert().
// NOTE(review): presumably because a label id may already be present in the
// sequence - confirm against the label handling elsewhere.
if (opcode == SpvOpLabel) {
multi_mtf_.InsertOrPromote(kMtfLabel, inst_.result_id);
}
// Scalar types go both into the scalar sequence and into the matching
// "scalar or vector" type sequence.
if (opcode == SpvOpTypeInt) {
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
}
if (opcode == SpvOpTypeFloat) {
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
}
if (opcode == SpvOpTypeBool) {
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
}
if (opcode == SpvOpTypeVector) {
// Words 2 and 3 of OpTypeVector are read as the component type id and the
// vector size.
const uint32_t component_type_id = inst_.words[2];
const uint32_t size = inst_.words[3];
// Classify the vector by looking up which scalar type sequence already
// contains its component type.
if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeFloat),
component_type_id)) {
multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
} else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeInt),
component_type_id)) {
multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
} else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeBool),
component_type_id)) {
multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
}
multi_mtf_.Insert(GetMtfTypeVectorOfSize(size), inst_.result_id);
}
if (inst_.opcode == SpvOpTypeFunction) {
// Word 2 of OpTypeFunction is read as the return type. Note that the first
// insertion adds the return type id, not the result id of this instruction.
const uint32_t return_type = inst_.words[2];
multi_mtf_.Insert(kMtfTypeReturnedByFunction, return_type);
multi_mtf_.Insert(GetMtfFunctionTypeWithReturnType(return_type),
inst_.result_id);
}
// Instructions with a type id: register the result id in the sequences
// derived from that type.
if (inst_.type_id) {
const val::Instruction* type_inst = FindDef(inst_.type_id);
assert(type_inst);
multi_mtf_.Insert(kMtfObject, inst_.result_id);
multi_mtf_.Insert(GetMtfIdOfType(inst_.type_id), inst_.result_id);
// The checks below depend on the type id having been inserted into the
// corresponding type sequences when the type was processed.
if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, inst_.type_id)) {
multi_mtf_.Insert(kMtfFloatScalarOrVector, inst_.result_id);
}
if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, inst_.type_id))
multi_mtf_.Insert(kMtfIntScalarOrVector, inst_.result_id);
if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, inst_.type_id))
multi_mtf_.Insert(kMtfBoolScalarOrVector, inst_.result_id);
if (multi_mtf_.HasValue(kMtfTypeComposite, inst_.type_id))
multi_mtf_.Insert(kMtfComposite, inst_.result_id);
// Sequences keyed by the opcode that declared the type.
// NOTE(review): SpvOpTypeFloat is not in this list - confirm the omission
// is intentional.
switch (type_inst->opcode()) {
case SpvOpTypeInt:
case SpvOpTypeBool:
case SpvOpTypePointer:
case SpvOpTypeVector:
case SpvOpTypeImage:
case SpvOpTypeSampledImage:
case SpvOpTypeSampler:
multi_mtf_.Insert(
GetMtfIdWithTypeGeneratedByOpcode(type_inst->opcode()),
inst_.result_id);
break;
default:
break;
}
if (type_inst->opcode() == SpvOpTypeVector) {
// Word 2 of the OpTypeVector is its component type.
const uint32_t component_type = type_inst->word(2);
multi_mtf_.Insert(GetMtfVectorOfComponentType(component_type),
inst_.result_id);
}
if (type_inst->opcode() == SpvOpTypePointer) {
assert(type_inst->operands().size() > 2);
assert(type_inst->words().size() > type_inst->operands()[2].offset);
// Operand 2 of the OpTypePointer is read as the pointed-to type.
const uint32_t data_type =
type_inst->word(type_inst->operands()[2].offset);
multi_mtf_.Insert(GetMtfPointerToType(data_type), inst_.result_id);
// Despite the "Type" in its name, this sequence receives the id of the
// pointer-typed value (inst_.result_id), not the id of the pointer type.
if (multi_mtf_.HasValue(kMtfTypeComposite, data_type))
multi_mtf_.Insert(kMtfTypePointerToComposite, inst_.result_id);
}
}
// All type declarations except OpTypeFunction.
if (spvOpcodeGeneratesType(opcode)) {
if (opcode != SpvOpTypeFunction) {
multi_mtf_.Insert(kMtfTypeNonFunction, inst_.result_id);
}
}
}
// Long descriptors are processed whenever the model has a coding scheme for
// any descriptor; the id is inserted only if this descriptor has one.
if (model_->AnyDescriptorHasCodingScheme()) {
const uint32_t long_descriptor =
long_id_descriptors_.ProcessInstruction(inst_);
if (model_->DescriptorHasCodingScheme(long_descriptor))
multi_mtf_.Insert(GetMtfLongIdDescriptor(long_descriptor),
inst_.result_id);
}
// Short descriptors are used only with the kShortDescriptor fallback
// strategy.
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kShortDescriptor) {
const uint32_t short_descriptor =
short_id_descriptors_.ProcessInstruction(inst_);
multi_mtf_.Insert(GetMtfShortIdDescriptor(short_descriptor),
inst_.result_id);
}
}
uint64_t MarkvCodec::GetRuleBasedMtf() {
// This function is only called for id operands (but not result ids).
assert(spvIsIdType(operand_.type) ||
operand_.type == SPV_OPERAND_TYPE_OPTIONAL_ID);
assert(operand_.type != SPV_OPERAND_TYPE_RESULT_ID);
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
// All operand slots which expect label id.
if ((inst_.opcode == SpvOpLoopMerge && operand_index_ <= 1) ||
(inst_.opcode == SpvOpSelectionMerge && operand_index_ == 0) ||
(inst_.opcode == SpvOpBranch && operand_index_ == 0) ||
(inst_.opcode == SpvOpBranchConditional &&
(operand_index_ == 1 || operand_index_ == 2)) ||
(inst_.opcode == SpvOpPhi && operand_index_ >= 3 &&
operand_index_ % 2 == 1) ||
(inst_.opcode == SpvOpSwitch && operand_index_ > 0)) {
return kMtfLabel;
}
switch (opcode) {
case SpvOpFAdd:
case SpvOpFSub:
case SpvOpFMul:
case SpvOpFDiv:
case SpvOpFRem:
case SpvOpFMod:
case SpvOpFNegate: {
if (operand_index_ == 0) return kMtfTypeFloatScalarOrVector;
return GetMtfIdOfType(inst_.type_id);
}
case SpvOpISub:
case SpvOpIAdd:
case SpvOpIMul:
case SpvOpSDiv:
case SpvOpUDiv:
case SpvOpSMod:
case SpvOpUMod:
case SpvOpSRem:
case SpvOpSNegate: {
if (operand_index_ == 0) return kMtfTypeIntScalarOrVector;
return kMtfIntScalarOrVector;
}
// TODO(atgoo@github.com) Add OpConvertFToU and other opcodes.
case SpvOpFOrdEqual:
case SpvOpFUnordEqual:
case SpvOpFOrdNotEqual:
case SpvOpFUnordNotEqual:
case SpvOpFOrdLessThan:
case SpvOpFUnordLessThan:
case SpvOpFOrdGreaterThan:
case SpvOpFUnordGreaterThan:
case SpvOpFOrdLessThanEqual:
case SpvOpFUnordLessThanEqual:
case SpvOpFOrdGreaterThanEqual:
case SpvOpFUnordGreaterThanEqual: {
if (operand_index_ == 0) return kMtfTypeBoolScalarOrVector;
if (operand_index_ == 2) return kMtfFloatScalarOrVector;
if (operand_index_ == 3) {
const uint32_t first_operand_id = GetInstWords()[3];
const uint32_t first_operand_type = id_to_type_id_.at(first_operand_id);
return GetMtfIdOfType(first_operand_type);
}
break;
}
case SpvOpVectorShuffle: {
if (operand_index_ == 0) {
assert(inst_.num_operands > 4);
return GetMtfTypeVectorOfSize(inst_.num_operands - 4);
}
assert(inst_.type_id);
if (operand_index_ == 2 || operand_index_ == 3)
return GetMtfVectorOfComponentType(
GetVectorComponentType(inst_.type_id));
break;
}
case SpvOpVectorTimesScalar: {
if (operand_index_ == 0) {
// TODO(atgoo@github.com) Could be narrowed to vector of floats.
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
}
assert(inst_.type_id);
if (operand_index_ == 2) return GetMtfIdOfType(inst_.type_id);
if (operand_index_ == 3)
return GetMtfIdOfType(GetVectorComponentType(inst_.type_id));
break;
}
case SpvOpDot: {
if (operand_index_ == 0) return GetMtfIdGeneratedByOpcode(SpvOpTypeFloat);
assert(inst_.type_id);
if (operand_index_ == 2)
return GetMtfVectorOfComponentType(inst_.type_id);
if (operand_index_ == 3) {
const uint32_t vector_id = GetInstWords()[3];
const uint32_t vector_type = id_to_type_id_.at(vector_id);
return GetMtfIdOfType(vector_type);
}
break;
}
case SpvOpTypeVector: {
if (operand_index_ == 1) {
return kMtfTypeScalar;
}
break;
}
case SpvOpTypeMatrix: {
if (operand_index_ == 1) {
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
}
break;
}
case SpvOpTypePointer: {
if (operand_index_ == 2) {
return kMtfTypeNonFunction;
}
break;
}
case SpvOpTypeStruct: {
if (operand_index_ >= 1) {
return kMtfTypeNonFunction;
}
break;
}
case SpvOpTypeFunction: {
if (operand_index_ == 1) {
return kMtfTypeNonFunction;
}
if (operand_index_ >= 2) {
return kMtfTypeNonFunction;
}
break;
}
case SpvOpLoad: {
if (operand_index_ == 0) return kMtfTypeNonFunction;
if (operand_index_ == 2) {
assert(inst_.type_id);
return GetMtfPointerToType(inst_.type_id);
}
break;
}
case SpvOpStore: {
if (operand_index_ == 0)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypePointer);
if (operand_index_ == 1) {
const uint32_t pointer_id = GetInstWords()[1];
const uint32_t pointer_type = id_to_type_id_.at(pointer_id);
const val::Instruction* pointer_inst = FindDef(pointer_type);
assert(pointer_inst);
assert(pointer_inst->opcode() == SpvOpTypePointer);
const uint32_t data_type =
pointer_inst->word(pointer_inst->operands()[2].offset);
return GetMtfIdOfType(data_type);
}
break;
}
case SpvOpVariable: {
if (operand_index_ == 0)
return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
break;
}
case SpvOpAccessChain: {
if (operand_index_ == 0)
return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
if (operand_index_ == 2) return kMtfTypePointerToComposite;
if (operand_index_ >= 3)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeInt);
break;
}
case SpvOpCompositeConstruct: {
if (operand_index_ == 0) return kMtfTypeComposite;
if (operand_index_ >= 2) {
const uint32_t composite_type = GetInstWords()[1];
if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, composite_type))
return kMtfFloatScalarOrVector;
if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, composite_type))
return kMtfIntScalarOrVector;
if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, composite_type))
return kMtfBoolScalarOrVector;
}
break;
}
case SpvOpCompositeExtract: {
if (operand_index_ == 2) return kMtfComposite;
break;
}
case SpvOpConstantComposite: {
if (operand_index_ == 0) return kMtfTypeComposite;
if (operand_index_ >= 2) {
const val::Instruction* composite_type_inst = FindDef(inst_.type_id);
assert(composite_type_inst);
if (composite_type_inst->opcode() == SpvOpTypeVector) {
return GetMtfIdOfType(composite_type_inst->word(2));
}
}
break;
}
case SpvOpExtInst: {
if (operand_index_ == 2)
return GetMtfIdGeneratedByOpcode(SpvOpExtInstImport);
if (operand_index_ >= 4) {
const uint32_t return_type = GetInstWords()[1];
const uint32_t ext_inst_type = inst_.ext_inst_type;
const uint32_t ext_inst_index = GetInstWords()[4];
// TODO(atgoo@github.com) The list of extended instructions is
// incomplete. Only common instructions and low-hanging fruits listed.
if (ext_inst_type == SPV_EXT_INST_TYPE_GLSL_STD_450) {
switch (ext_inst_index) {
case GLSLstd450FAbs:
case GLSLstd450FClamp:
case GLSLstd450FMax:
case GLSLstd450FMin:
case GLSLstd450FMix:
case GLSLstd450Step:
case GLSLstd450SmoothStep:
case GLSLstd450Fma:
case GLSLstd450Pow:
case GLSLstd450Exp:
case GLSLstd450Exp2:
case GLSLstd450Log:
case GLSLstd450Log2:
case GLSLstd450Sqrt:
case GLSLstd450InverseSqrt:
case GLSLstd450Fract:
case GLSLstd450Floor:
case GLSLstd450Ceil:
case GLSLstd450Radians:
case GLSLstd450Degrees:
case GLSLstd450Sin:
case GLSLstd450Cos:
case GLSLstd450Tan:
case GLSLstd450Sinh:
case GLSLstd450Cosh:
case GLSLstd450Tanh:
case GLSLstd450Asin:
case GLSLstd450Acos:
case GLSLstd450Atan:
case GLSLstd450Atan2:
case GLSLstd450Asinh:
case GLSLstd450Acosh:
case GLSLstd450Atanh:
case GLSLstd450MatrixInverse:
case GLSLstd450Cross:
case GLSLstd450Normalize:
case GLSLstd450Reflect:
case GLSLstd450FaceForward:
return GetMtfIdOfType(return_type);
case GLSLstd450Length:
case GLSLstd450Distance:
case GLSLstd450Refract:
return kMtfFloatScalarOrVector;
default:
break;
}
} else if (ext_inst_type == SPV_EXT_INST_TYPE_OPENCL_STD) {
switch (ext_inst_index) {
case OpenCLLIB::Fabs:
case OpenCLLIB::FClamp:
case OpenCLLIB::Fmax:
case OpenCLLIB::Fmin:
case OpenCLLIB::Step:
case OpenCLLIB::Smoothstep:
case OpenCLLIB::Fma:
case OpenCLLIB::Pow:
case OpenCLLIB::Exp:
case OpenCLLIB::Exp2:
case OpenCLLIB::Log:
case OpenCLLIB::Log2:
case OpenCLLIB::Sqrt:
case OpenCLLIB::Rsqrt:
case OpenCLLIB::Fract:
case OpenCLLIB::Floor:
case OpenCLLIB::Ceil:
case OpenCLLIB::Radians:
case OpenCLLIB::Degrees:
case OpenCLLIB::Sin:
case OpenCLLIB::Cos:
case OpenCLLIB::Tan:
case OpenCLLIB::Sinh:
case OpenCLLIB::Cosh:
case OpenCLLIB::Tanh:
case OpenCLLIB::Asin:
case OpenCLLIB::Acos:
case OpenCLLIB::Atan:
case OpenCLLIB::Atan2:
case OpenCLLIB::Asinh:
case OpenCLLIB::Acosh:
case OpenCLLIB::Atanh:
case OpenCLLIB::Cross:
case OpenCLLIB::Normalize:
return GetMtfIdOfType(return_type);
case OpenCLLIB::Length:
case OpenCLLIB::Distance:
return kMtfFloatScalarOrVector;
default:
break;
}
}
}
break;
}
case SpvOpFunction: {
if (operand_index_ == 0) return kMtfTypeReturnedByFunction;
if (operand_index_ == 3) {
const uint32_t return_type = GetInstWords()[1];
return GetMtfFunctionTypeWithReturnType(return_type);
}
break;
}
case SpvOpFunctionCall: {
if (operand_index_ == 0) return kMtfTypeReturnedByFunction;
if (operand_index_ == 2) {
const uint32_t return_type = GetInstWords()[1];
return GetMtfFunctionWithReturnType(return_type);
}
if (operand_index_ >= 3) {
const uint32_t function_id = GetInstWords()[3];
const val::Instruction* function_inst = FindDef(function_id);
if (!function_inst) return kMtfObject;
assert(function_inst->opcode() == SpvOpFunction);
const uint32_t function_type_id = function_inst->word(4);
const val::Instruction* function_type_inst = FindDef(function_type_id);
assert(function_type_inst);
assert(function_type_inst->opcode() == SpvOpTypeFunction);
const uint32_t argument_type = function_type_inst->word(operand_index_);
return GetMtfIdOfType(argument_type);
}
break;
}
case SpvOpReturnValue: {
if (operand_index_ == 0) return GetMtfIdOfType(cur_function_return_type_);
break;
}
case SpvOpBranchConditional: {
if (operand_index_ == 0)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeBool);
break;
}
case SpvOpSampledImage: {
if (operand_index_ == 0)
return GetMtfIdGeneratedByOpcode(SpvOpTypeSampledImage);
if (operand_index_ == 2)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeImage);
if (operand_index_ == 3)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampler);
break;
}
case SpvOpImageSampleImplicitLod: {
if (operand_index_ == 0)
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
if (operand_index_ == 2)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampledImage);
if (operand_index_ == 3)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeVector);
break;
}
default:
break;
}
return kMtfNone;
}
} // namespace comp
} // namespace spvtools