| // Copyright (c) 2018 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "source/comp/markv_encoder.h" |
| |
| #include "source/binary.h" |
| #include "source/opcode.h" |
| #include "spirv-tools/libspirv.hpp" |
| |
| namespace spvtools { |
| namespace comp { |
namespace {

// Number of whitespace characters passed to the logger's AppendWhitespaces()
// before each operand annotation or comment in the encoding log.
constexpr size_t kCommentNumWhitespaces = 2;

}  // namespace
| |
| spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) { |
| auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_); |
| |
| if (codec) { |
| uint64_t bits = 0; |
| size_t num_bits = 0; |
| if (codec->Encode(word, &bits, &num_bits)) { |
| // Encoding successful. |
| writer_.WriteBits(bits, num_bits); |
| return SPV_SUCCESS; |
| } else { |
| // Encoding failed, write kMarkvNoneOfTheAbove flag. |
| if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, |
| &num_bits)) |
| return Diag(SPV_ERROR_INTERNAL) |
| << "Non-id word Huffman table for " |
| << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index " |
| << operand_index_ << " is missing kMarkvNoneOfTheAbove"; |
| writer_.WriteBits(bits, num_bits); |
| } |
| } |
| |
| // Fallback encoding. |
| const size_t chunk_length = |
| model_->GetOperandVariableWidthChunkLength(operand_.type); |
| if (chunk_length) { |
| writer_.WriteVariableWidthU32(word, chunk_length); |
| } else { |
| writer_.WriteUnencoded(word); |
| } |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(uint32_t opcode, |
| uint32_t num_operands) { |
| uint64_t bits = 0; |
| size_t num_bits = 0; |
| |
| const uint32_t word = opcode | (num_operands << 16); |
| |
| // First try to use the Markov chain codec. |
| auto* codec = |
| model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode()); |
| if (codec) { |
| if (codec->Encode(word, &bits, &num_bits)) { |
| // The word was successfully encoded into bits/num_bits. |
| writer_.WriteBits(bits, num_bits); |
| return SPV_SUCCESS; |
| } else { |
| // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove |
| // and use fallback encoding. |
| if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, |
| &num_bits)) |
| return Diag(SPV_ERROR_INTERNAL) |
| << "opcode_and_num_operands Huffman table for " |
| << spvOpcodeString(GetPrevOpcode()) |
| << "is missing kMarkvNoneOfTheAbove"; |
| writer_.WriteBits(bits, num_bits); |
| } |
| } |
| |
| // Fallback to base-rate codec. |
| codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop); |
| assert(codec); |
| if (codec->Encode(word, &bits, &num_bits)) { |
| // The word was successfully encoded into bits/num_bits. |
| writer_.WriteBits(bits, num_bits); |
| return SPV_SUCCESS; |
| } else { |
| // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove |
| // and return false. |
| if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, &num_bits)) |
| return Diag(SPV_ERROR_INTERNAL) |
| << "Global opcode_and_num_operands Huffman table is missing " |
| << "kMarkvNoneOfTheAbove"; |
| writer_.WriteBits(bits, num_bits); |
| return SPV_UNSUPPORTED; |
| } |
| } |
| |
| spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf, |
| uint64_t fallback_method) { |
| const auto* codec = GetMtfHuffmanCodec(mtf); |
| if (!codec) { |
| assert(fallback_method != kMtfNone); |
| codec = GetMtfHuffmanCodec(fallback_method); |
| } |
| |
| if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank"; |
| |
| uint64_t bits = 0; |
| size_t num_bits = 0; |
| if (rank < MarkvCodec::kMtfSmallestRankEncodedByValue) { |
| // Encode using Huffman coding. |
| if (!codec->Encode(rank, &bits, &num_bits)) |
| return Diag(SPV_ERROR_INTERNAL) |
| << "Failed to encode MTF rank with Huffman"; |
| |
| writer_.WriteBits(bits, num_bits); |
| } else { |
| // Encode by value. |
| if (!codec->Encode(MarkvCodec::kMtfRankEncodedByValueSignal, &bits, |
| &num_bits)) |
| return Diag(SPV_ERROR_INTERNAL) |
| << "Failed to encode kMtfRankEncodedByValueSignal"; |
| |
| writer_.WriteBits(bits, num_bits); |
| writer_.WriteVariableWidthU32( |
| rank - MarkvCodec::kMtfSmallestRankEncodedByValue, |
| model_->mtf_rank_chunk_length()); |
| } |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) { |
| // Get the descriptor for id. |
| const uint32_t long_descriptor = long_id_descriptors_.GetDescriptor(id); |
| auto* codec = |
| model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_); |
| uint64_t bits = 0; |
| size_t num_bits = 0; |
| uint64_t mtf = kMtfNone; |
| if (long_descriptor && codec && |
| codec->Encode(long_descriptor, &bits, &num_bits)) { |
| // If the descriptor exists and is in the table, write the descriptor and |
| // proceed to encoding the rank. |
| writer_.WriteBits(bits, num_bits); |
| mtf = GetMtfLongIdDescriptor(long_descriptor); |
| } else { |
| if (codec) { |
| // The descriptor doesn't exist or we have no coding for it. Write |
| // kMarkvNoneOfTheAbove and go to fallback method. |
| if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, |
| &num_bits)) |
| return Diag(SPV_ERROR_INTERNAL) |
| << "Descriptor Huffman table for " |
| << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index " |
| << operand_index_ << " is missing kMarkvNoneOfTheAbove"; |
| |
| writer_.WriteBits(bits, num_bits); |
| } |
| |
| if (model_->id_fallback_strategy() != |
| MarkvModel::IdFallbackStrategy::kShortDescriptor) { |
| return SPV_UNSUPPORTED; |
| } |
| |
| const uint32_t short_descriptor = short_id_descriptors_.GetDescriptor(id); |
| writer_.WriteBits(short_descriptor, MarkvCodec::kShortDescriptorNumBits); |
| |
| if (short_descriptor == 0) { |
| // Forward declared id. |
| return SPV_UNSUPPORTED; |
| } |
| |
| mtf = GetMtfShortIdDescriptor(short_descriptor); |
| } |
| |
| // Descriptor has been encoded. Now encode the rank of the id in the |
| // associated mtf sequence. |
| return EncodeExistingId(mtf, id); |
| } |
| |
| spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) { |
| assert(multi_mtf_.GetSize(mtf) > 0); |
| if (multi_mtf_.GetSize(mtf) == 1) { |
| // If the sequence has only one element no need to write rank, the decoder |
| // would make the same decision. |
| return SPV_SUCCESS; |
| } |
| |
| uint32_t rank = 0; |
| if (!multi_mtf_.RankFromValue(mtf, id, &rank)) |
| return Diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence"; |
| |
| return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank); |
| } |
| |
| spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) { |
| { |
| // Try to encode using id descriptor mtfs. |
| const spv_result_t result = EncodeIdWithDescriptor(id); |
| if (result != SPV_UNSUPPORTED) return result; |
| // If can't be done continue with other methods. |
| } |
| |
| const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction( |
| SpvOp(inst_.opcode))(operand_index_); |
| uint32_t rank = 0; |
| |
| if (model_->id_fallback_strategy() == |
| MarkvModel::IdFallbackStrategy::kRuleBased) { |
| // Encode using rule-based mtf. |
| uint64_t mtf = GetRuleBasedMtf(); |
| |
| if (mtf != kMtfNone && !can_forward_declare) { |
| assert(multi_mtf_.HasValue(kMtfAll, id)); |
| return EncodeExistingId(mtf, id); |
| } |
| |
| if (mtf == kMtfNone) mtf = kMtfAll; |
| |
| if (!multi_mtf_.RankFromValue(mtf, id, &rank)) { |
| // This is the first occurrence of a forward declared id. |
| multi_mtf_.Insert(kMtfAll, id); |
| multi_mtf_.Insert(kMtfForwardDeclared, id); |
| if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id); |
| rank = 0; |
| } |
| |
| return EncodeMtfRankHuffman(rank, mtf, kMtfAll); |
| } else { |
| assert(can_forward_declare); |
| |
| if (!multi_mtf_.RankFromValue(kMtfForwardDeclared, id, &rank)) { |
| // This is the first occurrence of a forward declared id. |
| multi_mtf_.Insert(kMtfForwardDeclared, id); |
| rank = 0; |
| } |
| |
| writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length()); |
| return SPV_SUCCESS; |
| } |
| } |
| |
| spv_result_t MarkvEncoder::EncodeTypeId() { |
| if (inst_.opcode == SpvOpFunctionParameter) { |
| assert(!remaining_function_parameter_types_.empty()); |
| assert(inst_.type_id == remaining_function_parameter_types_.front()); |
| remaining_function_parameter_types_.pop_front(); |
| return SPV_SUCCESS; |
| } |
| |
| { |
| // Try to encode using id descriptor mtfs. |
| const spv_result_t result = EncodeIdWithDescriptor(inst_.type_id); |
| if (result != SPV_UNSUPPORTED) return result; |
| // If can't be done continue with other methods. |
| } |
| |
| assert(model_->id_fallback_strategy() == |
| MarkvModel::IdFallbackStrategy::kRuleBased); |
| |
| uint64_t mtf = GetRuleBasedMtf(); |
| assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))( |
| operand_index_)); |
| |
| if (mtf == kMtfNone) { |
| mtf = kMtfTypeNonFunction; |
| // Function types should have been handled by GetRuleBasedMtf. |
| assert(inst_.opcode != SpvOpFunction); |
| } |
| |
| return EncodeExistingId(mtf, inst_.type_id); |
| } |
| |
| spv_result_t MarkvEncoder::EncodeResultId() { |
| uint32_t rank = 0; |
| |
| const uint64_t num_still_forward_declared = |
| multi_mtf_.GetSize(kMtfForwardDeclared); |
| |
| if (num_still_forward_declared) { |
| // We write the rank only if kMtfForwardDeclared is not empty. If it is |
| // empty the decoder knows that there are no forward declared ids to expect. |
| if (multi_mtf_.RankFromValue(kMtfForwardDeclared, inst_.result_id, &rank)) { |
| // This is a definition of a forward declared id. We can remove the id |
| // from kMtfForwardDeclared. |
| if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) |
| return Diag(SPV_ERROR_INTERNAL) |
| << "Failed to remove id from kMtfForwardDeclared"; |
| writer_.WriteBits(1, 1); |
| writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length()); |
| } else { |
| rank = 0; |
| writer_.WriteBits(0, 1); |
| } |
| } |
| |
| if (model_->id_fallback_strategy() == |
| MarkvModel::IdFallbackStrategy::kRuleBased) { |
| if (!rank) { |
| multi_mtf_.Insert(kMtfAll, inst_.result_id); |
| } |
| } |
| |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t MarkvEncoder::EncodeLiteralNumber( |
| const spv_parsed_operand_t& operand) { |
| if (operand.number_bit_width <= 32) { |
| const uint32_t word = inst_.words[operand.offset]; |
| return EncodeNonIdWord(word); |
| } else { |
| assert(operand.number_bit_width <= 64); |
| const uint64_t word = uint64_t(inst_.words[operand.offset]) | |
| (uint64_t(inst_.words[operand.offset + 1]) << 32); |
| if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { |
| writer_.WriteVariableWidthU64(word, model_->u64_chunk_length()); |
| } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { |
| int64_t val = 0; |
| std::memcpy(&val, &word, 8); |
| writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(), |
| model_->s64_block_exponent()); |
| } else if (operand.number_kind == SPV_NUMBER_FLOATING) { |
| writer_.WriteUnencoded(word); |
| } else { |
| return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length"; |
| } |
| } |
| return SPV_SUCCESS; |
| } |
| |
| void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) { |
| const size_t num_bits_to_next_byte = |
| GetNumBitsToNextByte(writer_.GetNumBits()); |
| if (num_bits_to_next_byte == 0 || |
| num_bits_to_next_byte > byte_break_if_less_than) |
| return; |
| |
| if (logger_) { |
| logger_->AppendWhitespaces(kCommentNumWhitespaces); |
| logger_->AppendText("<byte break>"); |
| } |
| |
| writer_.WriteBits(0, num_bits_to_next_byte); |
| } |
| |
| spv_result_t MarkvEncoder::EncodeInstruction( |
| const spv_parsed_instruction_t& inst) { |
| SpvOp opcode = SpvOp(inst.opcode); |
| inst_ = inst; |
| |
| LogDisassemblyInstruction(); |
| |
| const spv_result_t opcode_encodig_result = |
| EncodeOpcodeAndNumOperands(opcode, inst.num_operands); |
| if (opcode_encodig_result < 0) return opcode_encodig_result; |
| |
| if (opcode_encodig_result != SPV_SUCCESS) { |
| // Fallback encoding for opcode and num_operands. |
| writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length()); |
| |
| if (!OpcodeHasFixedNumberOfOperands(opcode)) { |
| // If the opcode has a variable number of operands, encode the number of |
| // operands with the instruction. |
| |
| if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces); |
| |
| writer_.WriteVariableWidthU16(inst.num_operands, |
| model_->num_operands_chunk_length()); |
| } |
| } |
| |
| // Write operands. |
| const uint32_t num_operands = inst_.num_operands; |
| for (operand_index_ = 0; operand_index_ < num_operands; ++operand_index_) { |
| operand_ = inst_.operands[operand_index_]; |
| |
| if (logger_) { |
| logger_->AppendWhitespaces(kCommentNumWhitespaces); |
| logger_->AppendText("<"); |
| logger_->AppendText(spvOperandTypeStr(operand_.type)); |
| logger_->AppendText(">"); |
| } |
| |
| switch (operand_.type) { |
| case SPV_OPERAND_TYPE_RESULT_ID: |
| case SPV_OPERAND_TYPE_TYPE_ID: |
| case SPV_OPERAND_TYPE_ID: |
| case SPV_OPERAND_TYPE_OPTIONAL_ID: |
| case SPV_OPERAND_TYPE_SCOPE_ID: |
| case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { |
| const uint32_t id = inst_.words[operand_.offset]; |
| if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) { |
| const spv_result_t result = EncodeTypeId(); |
| if (result != SPV_SUCCESS) return result; |
| } else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) { |
| const spv_result_t result = EncodeResultId(); |
| if (result != SPV_SUCCESS) return result; |
| } else { |
| const spv_result_t result = EncodeRefId(id); |
| if (result != SPV_SUCCESS) return result; |
| } |
| |
| PromoteIfNeeded(id); |
| break; |
| } |
| |
| case SPV_OPERAND_TYPE_LITERAL_INTEGER: { |
| const spv_result_t result = |
| EncodeNonIdWord(inst_.words[operand_.offset]); |
| if (result != SPV_SUCCESS) return result; |
| break; |
| } |
| |
| case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: { |
| const spv_result_t result = EncodeLiteralNumber(operand_); |
| if (result != SPV_SUCCESS) return result; |
| break; |
| } |
| |
| case SPV_OPERAND_TYPE_LITERAL_STRING: { |
| const char* src = |
| reinterpret_cast<const char*>(&inst_.words[operand_.offset]); |
| |
| auto* codec = model_->GetLiteralStringHuffmanCodec(opcode); |
| if (codec) { |
| uint64_t bits = 0; |
| size_t num_bits = 0; |
| const std::string str = src; |
| if (codec->Encode(str, &bits, &num_bits)) { |
| writer_.WriteBits(bits, num_bits); |
| break; |
| } else { |
| bool result = |
| codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits); |
| (void)result; |
| assert(result); |
| writer_.WriteBits(bits, num_bits); |
| } |
| } |
| |
| const size_t length = spv_strnlen_s(src, operand_.num_words * 4); |
| if (length == operand_.num_words * 4) |
| return Diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to find terminal character of literal string"; |
| for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]); |
| break; |
| } |
| |
| default: { |
| for (int i = 0; i < operand_.num_words; ++i) { |
| const uint32_t word = inst_.words[operand_.offset + i]; |
| const spv_result_t result = EncodeNonIdWord(word); |
| if (result != SPV_SUCCESS) return result; |
| } |
| break; |
| } |
| } |
| } |
| |
| AddByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte); |
| |
| if (logger_) { |
| logger_->NewLine(); |
| logger_->NewLine(); |
| if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION; |
| } |
| |
| ProcessCurInstruction(); |
| |
| return SPV_SUCCESS; |
| } |
| |
| } // namespace comp |
| } // namespace spvtools |