Squashed 'third_party/SPIRV-Tools/' changes from 7014be600c..82d91083cb
82d91083cb spirv-val: Add PerVertexKHR (#4807)
088cb1a5c8 Add more folding for composite instructions (#4802)
c267127846 Add SPV_KHR_fragment_shader_barycentric support (#4805)
98340ec500 Add warning about spurious 'git cl upload' messages (#4800)
f74b85853c Handle 64-bit integers in local access chain convert (#4798)
f7a6e3b9d5 Handle chains of OpAccessChain in replacing variable index access for flattened resources. (#4797)
ad3514b732 spirv-opt: add pass for interface variable scalar replacement (#4779)
ffc8f2d455 Remove deprecated flags from spirv-opt help message (#4788)
c11ea09652 spirv-opt : Add FixFuncCallArgumentsPass (#4775)
9e377b0f97 spirv-val: Add CullMaskKHR support (#4792)
git-subtree-dir: third_party/SPIRV-Tools
git-subtree-split: 82d91083cb56c89d2cb8e9d56d4d69f07ac34fed
Change-Id: Ib474fd28c17584cd9fba94b8f8c863c3faa2ed0a
diff --git a/Android.mk b/Android.mk
index b9fbcc8..6dd1834 100644
--- a/Android.mk
+++ b/Android.mk
@@ -109,6 +109,7 @@
source/opt/eliminate_dead_input_components_pass.cpp \
source/opt/eliminate_dead_members_pass.cpp \
source/opt/feature_manager.cpp \
+ source/opt/fix_func_call_arguments.cpp \
source/opt/fix_storage_class.cpp \
source/opt/flatten_decoration_pass.cpp \
source/opt/fold.cpp \
@@ -127,6 +128,7 @@
source/opt/instruction.cpp \
source/opt/instruction_list.cpp \
source/opt/instrument_pass.cpp \
+ source/opt/interface_var_sroa.cpp \
source/opt/interp_fixup_pass.cpp \
source/opt/ir_context.cpp \
source/opt/ir_loader.cpp \
diff --git a/BUILD.gn b/BUILD.gn
index ba05497..9f96c24 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -629,6 +629,8 @@
"source/opt/empty_pass.h",
"source/opt/feature_manager.cpp",
"source/opt/feature_manager.h",
+ "source/opt/fix_func_call_arguments.cpp",
+ "source/opt/fix_func_call_arguments.h",
"source/opt/fix_storage_class.cpp",
"source/opt/fix_storage_class.h",
"source/opt/flatten_decoration_pass.cpp",
@@ -665,6 +667,8 @@
"source/opt/instruction_list.h",
"source/opt/instrument_pass.cpp",
"source/opt/instrument_pass.h",
+ "source/opt/interface_var_sroa.cpp",
+ "source/opt/interface_var_sroa.h",
"source/opt/interp_fixup_pass.cpp",
"source/opt/interp_fixup_pass.h",
"source/opt/ir_builder.h",
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index fbbd9bc..9497356 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -903,10 +903,20 @@
const std::vector<opt::DescriptorSetAndBinding>&
descriptor_set_binding_pairs);
+// Create an interface-variable-scalar-replacement pass that replaces array or
+// matrix interface variables with a series of scalar or vector interface
+// variables. For example, it replaces `float3 foo[2]` with `float3 foo0, foo1`.
+Optimizer::PassToken CreateInterfaceVariableScalarReplacementPass();
+
// Creates a remove-dont-inline pass to remove the |DontInline| function control
// from every function in the module. This is useful if you want the inliner to
// inline these functions some reason.
Optimizer::PassToken CreateRemoveDontInlinePass();
+// Create a fix-func-call-param pass to fix non-memory-object arguments of
+// function calls, as spirv-validation requires function parameters to be
+// memory objects. Currently the pass replaces access-chain pointer arguments
+// passed to the function with new function-scope variables.
+Optimizer::PassToken CreateFixFuncCallArgumentsPass();
} // namespace spvtools
#endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index 61e7a98..75fe4c0 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
set(SPIRV_TOOLS_OPT_SOURCES
+ fix_func_call_arguments.h
aggressive_dead_code_elim_pass.h
amd_ext_to_khr.h
basic_block.h
@@ -67,6 +68,7 @@
instruction.h
instruction_list.h
instrument_pass.h
+ interface_var_sroa.h
interp_fixup_pass.h
ir_builder.h
ir_context.h
@@ -126,6 +128,7 @@
workaround1209.h
wrap_opkill.h
+ fix_func_call_arguments.cpp
aggressive_dead_code_elim_pass.cpp
amd_ext_to_khr.cpp
basic_block.cpp
@@ -180,6 +183,7 @@
instruction.cpp
instruction_list.cpp
instrument_pass.cpp
+ interface_var_sroa.cpp
interp_fixup_pass.cpp
ir_context.cpp
ir_loader.cpp
diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp
index 0473752..2486242 100644
--- a/source/opt/aggressive_dead_code_elim_pass.cpp
+++ b/source/opt/aggressive_dead_code_elim_pass.cpp
@@ -968,6 +968,7 @@
"SPV_EXT_shader_image_int64",
"SPV_KHR_non_semantic_info",
"SPV_KHR_uniform_group_instructions",
+ "SPV_KHR_fragment_shader_barycentric",
});
}
diff --git a/source/opt/fix_func_call_arguments.cpp b/source/opt/fix_func_call_arguments.cpp
new file mode 100644
index 0000000..d140fb4
--- /dev/null
+++ b/source/opt/fix_func_call_arguments.cpp
@@ -0,0 +1,90 @@
+// Copyright (c) 2022 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fix_func_call_arguments.h"
+
+#include "ir_builder.h"
+
+using namespace spvtools;
+using namespace opt;
+
+bool FixFuncCallArgumentsPass::ModuleHasASingleFunction() {
+ auto funcsNum = get_module()->end() - get_module()->begin();
+ return funcsNum == 1;
+}
+
+Pass::Status FixFuncCallArgumentsPass::Process() {
+ bool modified = false;
+ if (ModuleHasASingleFunction()) return Status::SuccessWithoutChange;
+ for (auto& func : *get_module()) {
+ func.ForEachInst([this, &modified](Instruction* inst) {
+ if (inst->opcode() == SpvOpFunctionCall) {
+ modified |= FixFuncCallArguments(inst);
+ }
+ });
+ }
+ return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+bool FixFuncCallArgumentsPass::FixFuncCallArguments(
+ Instruction* func_call_inst) {
+ bool modified = false;
+ for (uint32_t i = 0; i < func_call_inst->NumInOperands(); ++i) {
+ Operand& op = func_call_inst->GetInOperand(i);
+ if (op.type != SPV_OPERAND_TYPE_ID) continue;
+ Instruction* operand_inst = get_def_use_mgr()->GetDef(op.AsId());
+ if (operand_inst->opcode() == SpvOpAccessChain) {
+ uint32_t var_id =
+ ReplaceAccessChainFuncCallArguments(func_call_inst, operand_inst);
+ func_call_inst->SetInOperand(i, {var_id});
+ modified = true;
+ }
+ }
+ if (modified) {
+ context()->UpdateDefUse(func_call_inst);
+ }
+ return modified;
+}
+
+uint32_t FixFuncCallArgumentsPass::ReplaceAccessChainFuncCallArguments(
+ Instruction* func_call_inst, Instruction* operand_inst) {
+ InstructionBuilder builder(
+ context(), func_call_inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+ Instruction* next_insert_point = func_call_inst->NextNode();
+ // Get Variable insertion point
+ Function* func = context()->get_instr_block(func_call_inst)->GetParent();
+ Instruction* variable_insertion_point = &*(func->begin()->begin());
+ Instruction* op_ptr_type = get_def_use_mgr()->GetDef(operand_inst->type_id());
+ Instruction* op_type =
+ get_def_use_mgr()->GetDef(op_ptr_type->GetSingleWordInOperand(1));
+ uint32_t varType = context()->get_type_mgr()->FindPointerToType(
+ op_type->result_id(), SpvStorageClassFunction);
+ // Create new variable
+ builder.SetInsertPoint(variable_insertion_point);
+ Instruction* var = builder.AddVariable(varType, SpvStorageClassFunction);
+ // Load access chain to the new variable before function call
+ builder.SetInsertPoint(func_call_inst);
+
+ uint32_t operand_id = operand_inst->result_id();
+ Instruction* load = builder.AddLoad(op_type->result_id(), operand_id);
+ builder.AddStore(var->result_id(), load->result_id());
+ // Store the new variable's value back to the access chain after function call
+ builder.SetInsertPoint(next_insert_point);
+ load = builder.AddLoad(op_type->result_id(), var->result_id());
+ builder.AddStore(operand_id, load->result_id());
+
+ return var->result_id();
+}
diff --git a/source/opt/fix_func_call_arguments.h b/source/opt/fix_func_call_arguments.h
new file mode 100644
index 0000000..15781b8
--- /dev/null
+++ b/source/opt/fix_func_call_arguments.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2022 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef _VAR_FUNC_CALL_PASS_H
+#define _VAR_FUNC_CALL_PASS_H
+
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+class FixFuncCallArgumentsPass : public Pass {
+ public:
+ FixFuncCallArgumentsPass() {}
+ const char* name() const override { return "fix-for-funcall-param"; }
+ Status Process() override;
+ // Returns true if the module has only one function.
+ bool ModuleHasASingleFunction();
+ // Copies from the memory pointed to by |operand_inst| to a new function scope
+ // variable created before |func_call_inst|, and
+ // copies the value of the new variable back to the memory pointed to by
+ // |operand_inst| after |func_call_inst|. Returns the id of
+ // the new variable.
+ uint32_t ReplaceAccessChainFuncCallArguments(Instruction* func_call_inst,
+ Instruction* operand_inst);
+
+ // Fixes the non-memory-object arguments of function call |func_call_inst|.
+ bool FixFuncCallArguments(Instruction* func_call_inst);
+
+ IRContext::Analysis GetPreservedAnalyses() override {
+ return IRContext::kAnalysisTypes;
+ }
+};
+} // namespace opt
+} // namespace spvtools
+
+#endif // _VAR_FUNC_CALL_PASS_H
\ No newline at end of file
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index d15ad04..ab7a20e 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -1631,6 +1631,57 @@
return true;
}
+// Walks the indexes chain from |start| to |end| of an OpCompositeInsert or
+// OpCompositeExtract instruction, and returns the type of the final element
+// being accessed.
+const analysis::Type* GetElementType(uint32_t type_id,
+ Instruction::iterator start,
+ Instruction::iterator end,
+ const analysis::TypeManager* type_mgr) {
+ const analysis::Type* type = type_mgr->GetType(type_id);
+ for (auto index : make_range(std::move(start), std::move(end))) {
+ assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
+ index.words.size() == 1);
+ if (auto* array_type = type->AsArray()) {
+ type = array_type->element_type();
+ } else if (auto* matrix_type = type->AsMatrix()) {
+ type = matrix_type->element_type();
+ } else if (auto* struct_type = type->AsStruct()) {
+ type = struct_type->element_types()[index.words[0]];
+ } else {
+ type = nullptr;
+ }
+ }
+ return type;
+}
+
+// Returns true if |inst_1| and |inst_2| have the same indexes that will be used
+// to index into a composite object, excluding the last index. The two
+// instructions must have the same opcode, and be either OpCompositeExtract or
+// OpCompositeInsert instructions.
+bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
+ assert(inst_1->opcode() == inst_2->opcode() &&
+ "Expecting the opcodes to be the same.");
+ assert((inst_1->opcode() == SpvOpCompositeInsert ||
+ inst_1->opcode() == SpvOpCompositeExtract) &&
+ "Instructions must be OpCompositeInsert or OpCompositeExtract.");
+
+ if (inst_1->NumInOperands() != inst_2->NumInOperands()) {
+ return false;
+ }
+
+ uint32_t first_index_position =
+ (inst_1->opcode() == SpvOpCompositeInsert ? 2 : 1);
+ for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1;
+ i++) {
+ if (inst_1->GetSingleWordInOperand(i) !=
+ inst_2->GetSingleWordInOperand(i)) {
+ return false;
+ }
+ }
+ return true;
+}
+
// If the OpCompositeConstruct is simply putting back together elements that
// where extracted from the same source, we can simply reuse the source.
//
@@ -1653,19 +1704,24 @@
// - extractions
// - extracting the same position they are inserting
// - all extract from the same id.
+ Instruction* first_element_inst = nullptr;
for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
const uint32_t element_id = inst->GetSingleWordInOperand(i);
Instruction* element_inst = def_use_mgr->GetDef(element_id);
+ if (first_element_inst == nullptr) {
+ first_element_inst = element_inst;
+ }
if (element_inst->opcode() != SpvOpCompositeExtract) {
return false;
}
- if (element_inst->NumInOperands() != 2) {
+ if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
return false;
}
- if (element_inst->GetSingleWordInOperand(1) != i) {
+ if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
+ 1) != i) {
return false;
}
@@ -1681,13 +1737,31 @@
// The last check it to see that the object being extracted from is the
// correct type.
Instruction* original_inst = def_use_mgr->GetDef(original_id);
- if (original_inst->type_id() != inst->type_id()) {
+ analysis::TypeManager* type_mgr = context->get_type_mgr();
+ const analysis::Type* original_type =
+ GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
+ first_element_inst->end() - 1, type_mgr);
+
+ if (original_type == nullptr) {
return false;
}
- // Simplify by using the original object.
- inst->SetOpcode(SpvOpCopyObject);
- inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
+ if (inst->type_id() != type_mgr->GetId(original_type)) {
+ return false;
+ }
+
+ if (first_element_inst->NumInOperands() == 2) {
+ // Simplify by using the original object.
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
+ return true;
+ }
+
+ // Copies the original id and all indexes except for the last to the new
+ // extract instruction.
+ inst->SetOpcode(SpvOpCompositeExtract);
+ inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
+ first_element_inst->end() - 1));
return true;
}
@@ -1891,6 +1965,139 @@
};
}
+// Returns the number of elements in the composite type |type|. Returns 0 if
+// |type| is a scalar value.
+uint32_t GetNumberOfElements(const analysis::Type* type) {
+ if (auto* vector_type = type->AsVector()) {
+ return vector_type->element_count();
+ }
+ if (auto* matrix_type = type->AsMatrix()) {
+ return matrix_type->element_count();
+ }
+ if (auto* struct_type = type->AsStruct()) {
+ return static_cast<uint32_t>(struct_type->element_types().size());
+ }
+ if (auto* array_type = type->AsArray()) {
+ return array_type->length_info().words[0];
+ }
+ return 0;
+}
+
+// Returns a map with the set of values that were inserted into an object by
+// the chain of OpCompositeInsert instructions starting with |inst|.
+// The map will map the index to the value inserted at that index.
+std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
+ analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
+ std::map<uint32_t, uint32_t> values_inserted;
+ Instruction* current_inst = inst;
+ while (current_inst->opcode() == SpvOpCompositeInsert) {
+ if (current_inst->NumInOperands() > inst->NumInOperands()) {
+ // This is to catch the case
+ // %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
+ // %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
+ // %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
+ // In this case we cannot do a single construct to get the matrix.
+ uint32_t partially_inserted_element_index =
+ current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
+ if (values_inserted.count(partially_inserted_element_index) == 0)
+ return {};
+ }
+ if (HaveSameIndexesExceptForLast(inst, current_inst)) {
+ values_inserted.insert(
+ {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() -
+ 1),
+ current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)});
+ }
+ current_inst = def_use_mgr->GetDef(
+ current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx));
+ }
+ return values_inserted;
+}
+
+// Returns true if there is an entry in |values_inserted| for every element of
+// |type|.
+bool DoInsertedValuesCoverEntireObject(
+ const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) {
+ uint32_t container_size = GetNumberOfElements(type);
+ if (container_size != values_inserted.size()) {
+ return false;
+ }
+
+ if (values_inserted.rbegin()->first >= container_size) {
+ return false;
+ }
+ return true;
+}
+
+// Returns the type of the element that immediately contains the element being
+// inserted by the OpCompositeInsert instruction |inst|.
+const analysis::Type* GetContainerType(Instruction* inst) {
+ assert(inst->opcode() == SpvOpCompositeInsert);
+ analysis::TypeManager* type_mgr = inst->context()->get_type_mgr();
+ return GetElementType(inst->type_id(), inst->begin() + 4, inst->end() - 1,
+ type_mgr);
+}
+
+// Returns an OpCompositeConstruct instruction that builds an object with
+// |type_id| out of the values in |values_inserted|. Each value will be
+// placed at the index corresponding to the value. The new instruction will
+// be placed before |insert_before|.
+Instruction* BuildCompositeConstruct(
+ uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
+ Instruction* insert_before) {
+ InstructionBuilder ir_builder(
+ insert_before->context(), insert_before,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+ std::vector<uint32_t> ids_in_order;
+ for (auto it : values_inserted) {
+ ids_in_order.push_back(it.second);
+ }
+ Instruction* construct =
+ ir_builder.AddCompositeConstruct(type_id, ids_in_order);
+ return construct;
+}
+
+// Replaces the OpCompositeInsert |inst| with one that inserts |construct| into
+// the same object as |inst| with the final index removed. If the resulting
+// OpCompositeInsert instruction would have no remaining indexes, the
+// instruction is replaced with an OpCopyObject instead.
+void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
+ if (inst->NumInOperands() == 3) {
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}});
+ } else {
+ inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()});
+ inst->RemoveOperand(inst->NumOperands() - 1);
+ }
+}
+
+// Replaces a series of |OpCompositeInsert| instructions that cover the entire
+// object with an |OpCompositeConstruct|.
+bool CompositeInsertToCompositeConstruct(
+ IRContext* context, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ assert(inst->opcode() == SpvOpCompositeInsert &&
+ "Wrong opcode. Should be OpCompositeInsert.");
+ if (inst->NumInOperands() < 3) return false;
+
+ std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst);
+ const analysis::Type* container_type = GetContainerType(inst);
+ if (container_type == nullptr) {
+ return false;
+ }
+
+ if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) {
+ return false;
+ }
+
+ analysis::TypeManager* type_mgr = context->get_type_mgr();
+ Instruction* construct = BuildCompositeConstruct(
+ type_mgr->GetId(container_type), values_inserted, inst);
+ InsertConstructedObject(inst, construct);
+ return true;
+}
+
FoldingRule RedundantPhi() {
// An OpPhi instruction where all values are the same or the result of the phi
// itself, can be replaced by the value itself.
@@ -2591,6 +2798,8 @@
rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
+ rules_[SpvOpCompositeInsert].push_back(CompositeInsertToCompositeConstruct);
+
rules_[SpvOpDot].push_back(DotProductDoingExtract());
rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
diff --git a/source/opt/interface_var_sroa.cpp b/source/opt/interface_var_sroa.cpp
new file mode 100644
index 0000000..58ed897
--- /dev/null
+++ b/source/opt/interface_var_sroa.cpp
@@ -0,0 +1,964 @@
+// Copyright (c) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/interface_var_sroa.h"
+
+#include <iostream>
+
+#include "source/opt/decoration_manager.h"
+#include "source/opt/def_use_manager.h"
+#include "source/opt/function.h"
+#include "source/opt/log.h"
+#include "source/opt/type_manager.h"
+#include "source/util/make_unique.h"
+
+const static uint32_t kOpDecorateDecorationInOperandIndex = 1;
+const static uint32_t kOpDecorateLiteralInOperandIndex = 2;
+const static uint32_t kOpEntryPointInOperandInterface = 3;
+const static uint32_t kOpVariableStorageClassInOperandIndex = 0;
+const static uint32_t kOpTypeArrayElemTypeInOperandIndex = 0;
+const static uint32_t kOpTypeArrayLengthInOperandIndex = 1;
+const static uint32_t kOpTypeMatrixColCountInOperandIndex = 1;
+const static uint32_t kOpTypeMatrixColTypeInOperandIndex = 0;
+const static uint32_t kOpTypePtrTypeInOperandIndex = 1;
+const static uint32_t kOpConstantValueInOperandIndex = 0;
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+// Get the length of the OpTypeArray |array_type|.
+uint32_t GetArrayLength(analysis::DefUseManager* def_use_mgr,
+ Instruction* array_type) {
+ assert(array_type->opcode() == SpvOpTypeArray);
+ uint32_t const_int_id =
+ array_type->GetSingleWordInOperand(kOpTypeArrayLengthInOperandIndex);
+ Instruction* array_length_inst = def_use_mgr->GetDef(const_int_id);
+ assert(array_length_inst->opcode() == SpvOpConstant);
+ return array_length_inst->GetSingleWordInOperand(
+ kOpConstantValueInOperandIndex);
+}
+
+// Get the element type instruction of the OpTypeArray |array_type|.
+Instruction* GetArrayElementType(analysis::DefUseManager* def_use_mgr,
+ Instruction* array_type) {
+ assert(array_type->opcode() == SpvOpTypeArray);
+ uint32_t elem_type_id =
+ array_type->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
+ return def_use_mgr->GetDef(elem_type_id);
+}
+
+// Get the column type instruction of the OpTypeMatrix |matrix_type|.
+Instruction* GetMatrixColumnType(analysis::DefUseManager* def_use_mgr,
+ Instruction* matrix_type) {
+ assert(matrix_type->opcode() == SpvOpTypeMatrix);
+ uint32_t column_type_id =
+ matrix_type->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
+ return def_use_mgr->GetDef(column_type_id);
+}
+
+// Traverses the component type of OpTypeArray or OpTypeMatrix. Repeats it
+// |depth_to_component| times recursively and returns the component type.
+// |type_id| is the result id of the OpTypeArray or OpTypeMatrix instruction.
+uint32_t GetComponentTypeOfArrayMatrix(analysis::DefUseManager* def_use_mgr,
+ uint32_t type_id,
+ uint32_t depth_to_component) {
+ if (depth_to_component == 0) return type_id;
+
+ Instruction* type_inst = def_use_mgr->GetDef(type_id);
+ if (type_inst->opcode() == SpvOpTypeArray) {
+ uint32_t elem_type_id =
+ type_inst->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
+ return GetComponentTypeOfArrayMatrix(def_use_mgr, elem_type_id,
+ depth_to_component - 1);
+ }
+
+ assert(type_inst->opcode() == SpvOpTypeMatrix);
+ uint32_t column_type_id =
+ type_inst->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
+ return GetComponentTypeOfArrayMatrix(def_use_mgr, column_type_id,
+ depth_to_component - 1);
+}
+
+// Creates an OpDecorate instruction whose Target is |var_id| and Decoration is
+// |decoration|. Adds |literal| as an extra operand of the instruction.
+void CreateDecoration(analysis::DecorationManager* decoration_mgr,
+ uint32_t var_id, SpvDecoration decoration,
+ uint32_t literal) {
+ std::vector<Operand> operands({
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_DECORATION,
+ {static_cast<uint32_t>(decoration)}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal}},
+ });
+ decoration_mgr->AddDecoration(SpvOpDecorate, std::move(operands));
+}
+
+// Replaces load instructions with composite construct instructions in all the
+// users of the loads. |loads_to_composites| is the mapping from each load to
+// its corresponding OpCompositeConstruct.
+void ReplaceLoadWithCompositeConstruct(
+ IRContext* context,
+ const std::unordered_map<Instruction*, Instruction*>& loads_to_composites) {
+ for (const auto& load_and_composite : loads_to_composites) {
+ Instruction* load = load_and_composite.first;
+ Instruction* composite_construct = load_and_composite.second;
+
+ std::vector<Instruction*> users;
+ context->get_def_use_mgr()->ForEachUse(
+ load, [&users, composite_construct](Instruction* user, uint32_t index) {
+ user->GetOperand(index).words[0] = composite_construct->result_id();
+ users.push_back(user);
+ });
+
+ for (Instruction* user : users)
+ context->get_def_use_mgr()->AnalyzeInstUse(user);
+ }
+}
+
+// Returns the storage class of the instruction |var|.
+SpvStorageClass GetStorageClass(Instruction* var) {
+ return static_cast<SpvStorageClass>(
+ var->GetSingleWordInOperand(kOpVariableStorageClassInOperandIndex));
+}
+
+} // namespace
+
+bool InterfaceVariableScalarReplacement::HasExtraArrayness(
+ Instruction& entry_point, Instruction* var) {
+ SpvExecutionModel execution_model =
+ static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0));
+ if (execution_model != SpvExecutionModelTessellationEvaluation &&
+ execution_model != SpvExecutionModelTessellationControl) {
+ return false;
+ }
+ if (!context()->get_decoration_mgr()->HasDecoration(var->result_id(),
+ SpvDecorationPatch)) {
+ if (execution_model == SpvExecutionModelTessellationControl) return true;
+ return GetStorageClass(var) != SpvStorageClassOutput;
+ }
+ return false;
+}
+
+bool InterfaceVariableScalarReplacement::
+ CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var,
+ bool has_extra_arrayness) {
+ if (has_extra_arrayness) {
+ return !ReportErrorIfHasNoExtraArraynessForOtherEntry(interface_var);
+ }
+ return !ReportErrorIfHasExtraArraynessForOtherEntry(interface_var);
+}
+
+bool InterfaceVariableScalarReplacement::GetVariableLocation(
+ Instruction* var, uint32_t* location) {
+ return !context()->get_decoration_mgr()->WhileEachDecoration(
+ var->result_id(), SpvDecorationLocation,
+ [location](const Instruction& inst) {
+ *location =
+ inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
+ return false;
+ });
+}
+
+bool InterfaceVariableScalarReplacement::GetVariableComponent(
+ Instruction* var, uint32_t* component) {
+ return !context()->get_decoration_mgr()->WhileEachDecoration(
+ var->result_id(), SpvDecorationComponent,
+ [component](const Instruction& inst) {
+ *component =
+ inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
+ return false;
+ });
+}
+
+std::vector<Instruction*>
+InterfaceVariableScalarReplacement::CollectInterfaceVariables(
+ Instruction& entry_point) {
+ std::vector<Instruction*> interface_vars;
+ for (uint32_t i = kOpEntryPointInOperandInterface;
+ i < entry_point.NumInOperands(); ++i) {
+ Instruction* interface_var = context()->get_def_use_mgr()->GetDef(
+ entry_point.GetSingleWordInOperand(i));
+ assert(interface_var->opcode() == SpvOpVariable);
+
+ SpvStorageClass storage_class = GetStorageClass(interface_var);
+ if (storage_class != SpvStorageClassInput &&
+ storage_class != SpvStorageClassOutput) {
+ continue;
+ }
+
+ interface_vars.push_back(interface_var);
+ }
+ return interface_vars;
+}
+
+void InterfaceVariableScalarReplacement::KillInstructionAndUsers(
+ Instruction* inst) {
+ if (inst->opcode() == SpvOpEntryPoint) {
+ return;
+ }
+ if (inst->opcode() != SpvOpAccessChain) {
+ context()->KillInst(inst);
+ return;
+ }
+ context()->get_def_use_mgr()->ForEachUser(
+ inst, [this](Instruction* user) { KillInstructionAndUsers(user); });
+ context()->KillInst(inst);
+}
+
+void InterfaceVariableScalarReplacement::KillInstructionsAndUsers(
+ const std::vector<Instruction*>& insts) {
+ for (Instruction* inst : insts) {
+ KillInstructionAndUsers(inst);
+ }
+}
+
+void InterfaceVariableScalarReplacement::KillLocationAndComponentDecorations(
+ uint32_t var_id) {
+ context()->get_decoration_mgr()->RemoveDecorationsFrom(
+ var_id, [](const Instruction& inst) {
+ uint32_t decoration =
+ inst.GetSingleWordInOperand(kOpDecorateDecorationInOperandIndex);
+ return decoration == SpvDecorationLocation ||
+ decoration == SpvDecorationComponent;
+ });
+}
+
+bool InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars(
+ Instruction* interface_var, Instruction* interface_var_type,
+ uint32_t location, uint32_t component, uint32_t extra_array_length) {
+ NestedCompositeComponents scalar_interface_vars =
+ CreateScalarInterfaceVarsForReplacement(interface_var_type,
+ GetStorageClass(interface_var),
+ extra_array_length);
+
+ AddLocationAndComponentDecorations(scalar_interface_vars, &location,
+ component);
+ KillLocationAndComponentDecorations(interface_var->result_id());
+
+ if (!ReplaceInterfaceVarWith(interface_var, extra_array_length,
+ scalar_interface_vars)) {
+ return false;
+ }
+
+ context()->KillInst(interface_var);
+ return true;
+}
+
+bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
+ Instruction* interface_var, uint32_t extra_array_length,
+ const NestedCompositeComponents& scalar_interface_vars) {
+ std::vector<Instruction*> users;
+ context()->get_def_use_mgr()->ForEachUser(
+ interface_var, [&users](Instruction* user) { users.push_back(user); });
+
+ std::vector<uint32_t> interface_var_component_indices;
+ std::unordered_map<Instruction*, Instruction*> loads_to_composites;
+ std::unordered_map<Instruction*, Instruction*>
+ loads_for_access_chain_to_composites;
+ if (extra_array_length != 0) {
+ // Note that the extra arrayness is the first dimension of the array
+ // interface variable.
+ for (uint32_t index = 0; index < extra_array_length; ++index) {
+ std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
+ if (!ReplaceComponentsOfInterfaceVarWith(
+ interface_var, users, scalar_interface_vars,
+ interface_var_component_indices, &index,
+ &loads_to_component_values,
+ &loads_for_access_chain_to_composites)) {
+ return false;
+ }
+ AddComponentsToCompositesForLoads(loads_to_component_values,
+ &loads_to_composites, 0);
+ }
+ } else if (!ReplaceComponentsOfInterfaceVarWith(
+ interface_var, users, scalar_interface_vars,
+ interface_var_component_indices, nullptr, &loads_to_composites,
+ &loads_for_access_chain_to_composites)) {
+ return false;
+ }
+
+ ReplaceLoadWithCompositeConstruct(context(), loads_to_composites);
+ ReplaceLoadWithCompositeConstruct(context(),
+ loads_for_access_chain_to_composites);
+
+ KillInstructionsAndUsers(users);
+ return true;
+}
+
+// Adds Location and Component decorations to every variable held by |vars|.
+// Each leaf variable is decorated with the current value of |*location|,
+// which is then incremented, and with |component|.
+void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations(
+    const NestedCompositeComponents& vars, uint32_t* location,
+    uint32_t component) {
+  if (vars.HasMultipleComponents()) {
+    // Recurse into the nested components.
+    for (const auto& nested_vars : vars.GetComponents()) {
+      AddLocationAndComponentDecorations(nested_vars, location, component);
+    }
+    return;
+  }
+
+  const uint32_t leaf_var_id = vars.GetComponentVariable()->result_id();
+  CreateDecoration(context()->get_decoration_mgr(), leaf_var_id,
+                   SpvDecorationLocation, *location);
+  CreateDecoration(context()->get_decoration_mgr(), leaf_var_id,
+                   SpvDecorationComponent, component);
+  ++(*location);
+}
+
+// Replaces, in each of |interface_var_users|, the component of
+// |interface_var| identified by |interface_var_component_indices|. If
+// |scalar_interface_vars| holds a single variable it is applied to each user
+// directly; otherwise the work is forwarded to
+// ReplaceMultipleComponentsOfInterfaceVarWith(). Returns false on the first
+// failure.
+bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
+    Instruction* interface_var,
+    const std::vector<Instruction*>& interface_var_users,
+    const NestedCompositeComponents& scalar_interface_vars,
+    std::vector<uint32_t>& interface_var_component_indices,
+    const uint32_t* extra_array_index,
+    std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
+    std::unordered_map<Instruction*, Instruction*>*
+        loads_for_access_chain_to_composites) {
+  if (scalar_interface_vars.HasMultipleComponents()) {
+    return ReplaceMultipleComponentsOfInterfaceVarWith(
+        interface_var, interface_var_users,
+        scalar_interface_vars.GetComponents(),
+        interface_var_component_indices, extra_array_index,
+        loads_to_composites, loads_for_access_chain_to_composites);
+  }
+
+  for (Instruction* user : interface_var_users) {
+    const bool replaced = ReplaceComponentOfInterfaceVarWith(
+        interface_var, user, scalar_interface_vars.GetComponentVariable(),
+        interface_var_component_indices, extra_array_index,
+        loads_to_composites, loads_for_access_chain_to_composites);
+    if (!replaced) return false;
+  }
+  return true;
+}
+
+// Replaces the uses of |interface_var| one entry of |components| at a time.
+// For the i-th entry, |i| is pushed onto |interface_var_component_indices|
+// while the nested replacement runs and popped afterwards. The values the
+// nested replacement produced for loads are then added as operands of the
+// composites tracked in |loads_to_composites| and
+// |loads_for_access_chain_to_composites|. Returns false immediately when a
+// nested replacement fails.
+bool InterfaceVariableScalarReplacement::
+    ReplaceMultipleComponentsOfInterfaceVarWith(
+        Instruction* interface_var,
+        const std::vector<Instruction*>& interface_var_users,
+        const std::vector<NestedCompositeComponents>& components,
+        std::vector<uint32_t>& interface_var_component_indices,
+        const uint32_t* extra_array_index,
+        std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
+        std::unordered_map<Instruction*, Instruction*>*
+            loads_for_access_chain_to_composites) {
+  using LoadMap = std::unordered_map<Instruction*, Instruction*>;
+
+  for (uint32_t component_index = 0; component_index < components.size();
+       ++component_index) {
+    LoadMap component_values_for_loads;
+    LoadMap component_values_for_access_chain_loads;
+
+    interface_var_component_indices.push_back(component_index);
+    const bool replaced = ReplaceComponentsOfInterfaceVarWith(
+        interface_var, interface_var_users, components[component_index],
+        interface_var_component_indices, extra_array_index,
+        &component_values_for_loads,
+        &component_values_for_access_chain_loads);
+    if (!replaced) return false;
+    interface_var_component_indices.pop_back();
+
+    const uint32_t access_chain_depth =
+        static_cast<uint32_t>(interface_var_component_indices.size());
+    AddComponentsToCompositesForLoads(component_values_for_access_chain_loads,
+                                      loads_for_access_chain_to_composites,
+                                      access_chain_depth);
+
+    // With an extra array index, direct loads use a depth one greater than
+    // loads that go through an access chain.
+    const uint32_t load_depth =
+        extra_array_index ? access_chain_depth + 1 : access_chain_depth;
+    AddComponentsToCompositesForLoads(component_values_for_loads,
+                                      loads_to_composites, load_depth);
+  }
+  return true;
+}
+
+// Rewrites a single user |interface_var_user| of |interface_var| so that it
+// uses |scalar_var| for the component identified by
+// |interface_var_component_indices|. Stores and loads are rewritten on every
+// call. Decorations, OpName, OpEntryPoint and OpAccessChain users are
+// handled only when |extra_array_index| is null or points to 0. Any other
+// opcode is reported through the message consumer and makes the function
+// return false.
+bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
+    Instruction* interface_var, Instruction* interface_var_user,
+    Instruction* scalar_var,
+    const std::vector<uint32_t>& interface_var_component_indices,
+    const uint32_t* extra_array_index,
+    std::unordered_map<Instruction*, Instruction*>* loads_to_component_values,
+    std::unordered_map<Instruction*, Instruction*>*
+        loads_for_access_chain_to_component_values) {
+  const SpvOp user_opcode = interface_var_user->opcode();
+
+  switch (user_opcode) {
+    case SpvOpStore: {
+      // In-operand 1 of the store is the id of the value being stored.
+      const uint32_t stored_value_id =
+          interface_var_user->GetSingleWordInOperand(1);
+      StoreComponentOfValueToScalarVar(
+          stored_value_id, interface_var_component_indices, scalar_var,
+          extra_array_index, interface_var_user);
+      return true;
+    }
+    case SpvOpLoad: {
+      Instruction* component_load =
+          LoadScalarVar(scalar_var, extra_array_index, interface_var_user);
+      loads_to_component_values->insert({interface_var_user, component_load});
+      return true;
+    }
+    default:
+      break;
+  }
+
+  // OpName and annotation instructions are copied only once, so the
+  // remaining users are processed only for the first element of the extra
+  // array.
+  const bool is_later_array_element =
+      extra_array_index != nullptr && *extra_array_index != 0;
+  if (is_later_array_element) return true;
+
+  switch (user_opcode) {
+    case SpvOpDecorate:
+    case SpvOpDecorateId:
+    case SpvOpDecorateString:
+      CloneAnnotationForVariable(interface_var_user, scalar_var->result_id());
+      return true;
+    case SpvOpName: {
+      std::unique_ptr<Instruction> name_inst(
+          interface_var_user->Clone(context()));
+      name_inst->SetInOperand(0, {scalar_var->result_id()});
+      context()->AddDebug2Inst(std::move(name_inst));
+      return true;
+    }
+    case SpvOpEntryPoint:
+      return ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user,
+                                             scalar_var->result_id());
+    case SpvOpAccessChain:
+      ReplaceAccessChainWith(interface_var_user,
+                             interface_var_component_indices, scalar_var,
+                             loads_for_access_chain_to_component_values);
+      return true;
+    default:
+      break;
+  }
+
+  // Unsupported user: report it together with the interface variable.
+  const std::string user_text = interface_var_user->PrettyPrint(
+      SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+  const std::string var_text =
+      interface_var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+  std::string message("Unhandled instruction");
+  message += "\n " + user_text;
+  message += "\nfor interface variable scalar replacement\n " + var_text;
+  context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
+  return false;
+}
+
+// Replaces the in-operands of |access_chain|, whose base is
+// |base_access_chain|, with all in-operands of |base_access_chain| followed
+// by the in-operands of |access_chain| from index 1 onwards.
+void InterfaceVariableScalarReplacement::UseBaseAccessChainForAccessChain(
+    Instruction* access_chain, Instruction* base_access_chain) {
+  assert(base_access_chain->opcode() == SpvOpAccessChain &&
+         access_chain->opcode() == SpvOpAccessChain &&
+         access_chain->GetSingleWordInOperand(0) ==
+             base_access_chain->result_id());
+
+  Instruction::OperandList merged_operands;
+  const uint32_t base_operand_count = base_access_chain->NumInOperands();
+  for (uint32_t idx = 0; idx != base_operand_count; ++idx) {
+    merged_operands.push_back(base_access_chain->GetInOperand(idx));
+  }
+  // In-operand 0 of |access_chain| is the id of |base_access_chain| (see the
+  // assertion above), so it is skipped.
+  const uint32_t operand_count = access_chain->NumInOperands();
+  for (uint32_t idx = 1; idx < operand_count; ++idx) {
+    merged_operands.push_back(access_chain->GetInOperand(idx));
+  }
+  access_chain->SetInOperands(std::move(merged_operands));
+}
+
+// Creates an OpAccessChain on |var| with |index_ids| as its indices and
+// inserts it before |insert_before|. |*component_type_id| receives the result
+// of GetComponentTypeOfArrayMatrix() for |var_type_id| at depth
+// |index_ids.size()|; the new instruction's type is a pointer to it in the
+// storage class of |var|. Returns the new instruction.
+Instruction* InterfaceVariableScalarReplacement::CreateAccessChainToVar(
+    uint32_t var_type_id, Instruction* var,
+    const std::vector<uint32_t>& index_ids, Instruction* insert_before,
+    uint32_t* component_type_id) {
+  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
+  const uint32_t depth = static_cast<uint32_t>(index_ids.size());
+  *component_type_id =
+      GetComponentTypeOfArrayMatrix(def_use_mgr, var_type_id, depth);
+
+  const uint32_t result_type_id =
+      GetPointerType(*component_type_id, GetStorageClass(var));
+  const uint32_t result_id = TakeNextId();
+
+  // Start with the base pointer, then append one id operand per index.
+  std::unique_ptr<Instruction> access_chain(new Instruction(
+      context(), SpvOpAccessChain, result_type_id, result_id,
+      std::initializer_list<Operand>{
+          {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
+  for (const uint32_t index_id : index_ids) {
+    access_chain->AddOperand({SPV_OPERAND_TYPE_ID, {index_id}});
+  }
+
+  Instruction* access_chain_ptr = access_chain.get();
+  def_use_mgr->AnalyzeInstDefUse(access_chain_ptr);
+  insert_before->InsertBefore(std::move(access_chain));
+  return access_chain_ptr;
+}
+
+// Creates an OpAccessChain on |var| whose only index is the id returned by
+// GetUIntConst(|index|), and inserts it before |insert_before|. The result
+// type is a pointer to |component_type_id| in the storage class of |var|.
+// Returns the new instruction.
+Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex(
+    uint32_t component_type_id, Instruction* var, uint32_t index,
+    Instruction* insert_before) {
+  const uint32_t result_type_id =
+      GetPointerType(component_type_id, GetStorageClass(var));
+  const uint32_t index_const_id =
+      context()->get_constant_mgr()->GetUIntConst(index);
+  const uint32_t result_id = TakeNextId();
+
+  std::unique_ptr<Instruction> access_chain(new Instruction(
+      context(), SpvOpAccessChain, result_type_id, result_id,
+      std::initializer_list<Operand>{
+          {SPV_OPERAND_TYPE_ID, {var->result_id()}},
+          {SPV_OPERAND_TYPE_ID, {index_const_id}},
+      }));
+  Instruction* access_chain_ptr = access_chain.get();
+  context()->get_def_use_mgr()->AnalyzeInstDefUse(access_chain_ptr);
+  insert_before->InsertBefore(std::move(access_chain));
+  return access_chain_ptr;
+}
+
+// Rewrites the users of |access_chain| so that they operate on |scalar_var|.
+// A user that is itself an OpAccessChain is first merged with |access_chain|
+// and then processed recursively. A store user is replaced by a store of the
+// matching component to an access chain on |scalar_var|. For a load user, a
+// load through an access chain on |scalar_var| is created and recorded in
+// |loads_to_component_values|. Users with any other opcode are left
+// untouched.
+void InterfaceVariableScalarReplacement::ReplaceAccessChainWith(
+    Instruction* access_chain,
+    const std::vector<uint32_t>& interface_var_component_indices,
+    Instruction* scalar_var,
+    std::unordered_map<Instruction*, Instruction*>* loads_to_component_values) {
+  // In-operand 0 is the base of the access chain; the rest are its indices.
+  std::vector<uint32_t> index_ids;
+  const uint32_t operand_count = access_chain->NumInOperands();
+  for (uint32_t idx = 1; idx < operand_count; ++idx) {
+    index_ids.push_back(access_chain->GetSingleWordInOperand(idx));
+  }
+
+  // Note that we have a strong assumption that |access_chain| has only a
+  // single index that is for the extra arrayness.
+  auto replace_user = [this, access_chain, &index_ids,
+                       &interface_var_component_indices, scalar_var,
+                       loads_to_component_values](Instruction* user) {
+    const SpvOp user_opcode = user->opcode();
+    if (user_opcode == SpvOpAccessChain) {
+      UseBaseAccessChainForAccessChain(user, access_chain);
+      ReplaceAccessChainWith(user, interface_var_component_indices,
+                             scalar_var, loads_to_component_values);
+    } else if (user_opcode == SpvOpStore) {
+      // In-operand 1 of the store is the id of the value being stored.
+      const uint32_t stored_value_id = user->GetSingleWordInOperand(1);
+      StoreComponentOfValueToAccessChainToScalarVar(
+          stored_value_id, interface_var_component_indices, scalar_var,
+          index_ids, user);
+    } else if (user_opcode == SpvOpLoad) {
+      Instruction* loaded_value =
+          LoadAccessChainToVar(scalar_var, index_ids, user);
+      loads_to_component_values->insert({user, loaded_value});
+    }
+  };
+  context()->get_def_use_mgr()->ForEachUser(access_chain, replace_user);
+}
+
+// Adds to the module a clone of |annotation_inst| (an OpDecorate,
+// OpDecorateId or OpDecorateString) whose in-operand 0 is set to |var_id|.
+void InterfaceVariableScalarReplacement::CloneAnnotationForVariable(
+    Instruction* annotation_inst, uint32_t var_id) {
+  assert(annotation_inst->opcode() == SpvOpDecorate ||
+         annotation_inst->opcode() == SpvOpDecorateId ||
+         annotation_inst->opcode() == SpvOpDecorateString);
+
+  std::unique_ptr<Instruction> cloned_annotation(
+      annotation_inst->Clone(context()));
+  cloned_annotation->SetInOperand(0, {var_id});
+  context()->AddAnnotationInst(std::move(cloned_annotation));
+}
+
+// Makes |scalar_var_id| an operand of |entry_point| in place of
+// |interface_var|. On the first call for a given interface variable, the
+// first id operand equal to the variable's id is overwritten with
+// |scalar_var_id| and the variable is remembered in
+// |interface_vars_removed_from_entry_point_operands_|. On later calls for
+// the same variable, |scalar_var_id| is appended as a new operand. Returns
+// false, after reporting an error, if no id operand matches the variable on
+// the first call.
+bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarInEntryPoint(
+    Instruction* interface_var, Instruction* entry_point,
+    uint32_t scalar_var_id) {
+  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
+  uint32_t old_var_id = interface_var->result_id();
+
+  const bool already_removed =
+      interface_vars_removed_from_entry_point_operands_.find(old_var_id) !=
+      interface_vars_removed_from_entry_point_operands_.end();
+  if (already_removed) {
+    entry_point->AddOperand({SPV_OPERAND_TYPE_ID, {scalar_var_id}});
+    def_use_mgr->AnalyzeInstUse(entry_point);
+    return true;
+  }
+
+  // The callback returns false at the first id equal to |old_var_id|, after
+  // overwriting it. A true result from WhileEachInId() is therefore treated
+  // as "no operand was replaced".
+  const bool nothing_replaced = entry_point->WhileEachInId(
+      [old_var_id, scalar_var_id](uint32_t* id) {
+        if (*id != old_var_id) return true;
+        *id = scalar_var_id;
+        return false;
+      });
+  if (nothing_replaced) {
+    const std::string var_text =
+        interface_var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+    const std::string entry_text =
+        entry_point->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+    std::string message(
+        "interface variable is not an operand of the entry point");
+    message += "\n " + var_text;
+    message += "\n " + entry_text;
+    context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
+    return false;
+  }
+
+  def_use_mgr->AnalyzeInstUse(entry_point);
+  interface_vars_removed_from_entry_point_operands_.insert(old_var_id);
+  return true;
+}
+
+// Returns the pointee type id of the OpTypePointer that is the type of the
+// OpVariable |var|.
+uint32_t InterfaceVariableScalarReplacement::GetPointeeTypeIdOfVar(
+    Instruction* var) {
+  assert(var->opcode() == SpvOpVariable);
+
+  Instruction* ptr_type_inst =
+      context()->get_def_use_mgr()->GetDef(var->type_id());
+
+  assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
+         "Variable must have a pointer type.");
+  return ptr_type_inst->GetSingleWordInOperand(kOpTypePtrTypeInOperandIndex);
+}
+
+// Stores the component of |value_id| selected by |component_indices| (and by
+// |extra_array_index| when it is non-null) to |scalar_var|. With a non-null
+// |extra_array_index| the pointee type of |scalar_var| must be an array, and
+// the store goes to the element at |*extra_array_index| through a new access
+// chain. New instructions are inserted before |insert_before|.
+void InterfaceVariableScalarReplacement::StoreComponentOfValueToScalarVar(
+    uint32_t value_id, const std::vector<uint32_t>& component_indices,
+    Instruction* scalar_var, const uint32_t* extra_array_index,
+    Instruction* insert_before) {
+  Instruction* store_target = scalar_var;
+  uint32_t stored_type_id = GetPointeeTypeIdOfVar(scalar_var);
+  if (extra_array_index != nullptr) {
+    auto* type_mgr = context()->get_type_mgr();
+    analysis::Array* array_type = type_mgr->GetType(stored_type_id)->AsArray();
+    assert(array_type != nullptr);
+    stored_type_id = type_mgr->GetTypeInstruction(array_type->element_type());
+    store_target = CreateAccessChainWithIndex(
+        stored_type_id, scalar_var, *extra_array_index, insert_before);
+  }
+
+  StoreComponentOfValueTo(stored_type_id, value_id, component_indices,
+                          store_target, extra_array_index, insert_before);
+}
+
+// Creates, before |insert_before|, a load of |scalar_var| and returns it.
+// With a non-null |extra_array_index| the pointee type of |scalar_var| must
+// be an array, and the element at |*extra_array_index| is loaded through a
+// new access chain instead.
+Instruction* InterfaceVariableScalarReplacement::LoadScalarVar(
+    Instruction* scalar_var, const uint32_t* extra_array_index,
+    Instruction* insert_before) {
+  Instruction* load_source = scalar_var;
+  uint32_t loaded_type_id = GetPointeeTypeIdOfVar(scalar_var);
+  if (extra_array_index != nullptr) {
+    auto* type_mgr = context()->get_type_mgr();
+    analysis::Array* array_type = type_mgr->GetType(loaded_type_id)->AsArray();
+    assert(array_type != nullptr);
+    loaded_type_id = type_mgr->GetTypeInstruction(array_type->element_type());
+    load_source = CreateAccessChainWithIndex(
+        loaded_type_id, scalar_var, *extra_array_index, insert_before);
+  }
+
+  return CreateLoad(loaded_type_id, load_source, insert_before);
+}
+
+// Creates an OpLoad of |ptr| with result type |type_id|, inserts it before
+// |insert_before| and returns it.
+Instruction* InterfaceVariableScalarReplacement::CreateLoad(
+    uint32_t type_id, Instruction* ptr, Instruction* insert_before) {
+  const uint32_t result_id = TakeNextId();
+  std::unique_ptr<Instruction> new_load(new Instruction(
+      context(), SpvOpLoad, type_id, result_id,
+      std::initializer_list<Operand>{
+          {SPV_OPERAND_TYPE_ID, {ptr->result_id()}}}));
+
+  Instruction* new_load_ptr = new_load.get();
+  context()->get_def_use_mgr()->AnalyzeInstDefUse(new_load_ptr);
+  insert_before->InsertBefore(std::move(new_load));
+  return new_load_ptr;
+}
+
+// Inserts before |insert_before| an OpCompositeExtract that takes the
+// component of |value_id| selected by |extra_array_index| (if non-null) and
+// |component_indices|, followed by an OpStore of that component to |ptr|.
+// |component_type_id| is the result type of the extract.
+void InterfaceVariableScalarReplacement::StoreComponentOfValueTo(
+    uint32_t component_type_id, uint32_t value_id,
+    const std::vector<uint32_t>& component_indices, Instruction* ptr,
+    const uint32_t* extra_array_index, Instruction* insert_before) {
+  std::unique_ptr<Instruction> extract(CreateCompositeExtract(
+      component_type_id, value_id, component_indices, extra_array_index));
+  const uint32_t extracted_id = extract->result_id();
+
+  std::unique_ptr<Instruction> store(new Instruction(context(), SpvOpStore));
+  store->AddOperand({SPV_OPERAND_TYPE_ID, {ptr->result_id()}});
+  store->AddOperand({SPV_OPERAND_TYPE_ID, {extracted_id}});
+
+  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
+  def_use_mgr->AnalyzeInstDefUse(extract.get());
+  def_use_mgr->AnalyzeInstDefUse(store.get());
+
+  // The extract is inserted first so that it precedes the store.
+  insert_before->InsertBefore(std::move(extract));
+  insert_before->InsertBefore(std::move(store));
+}
+
+// Allocates an OpCompositeExtract with result type |type_id| that extracts
+// from |composite_id|. Its literal indices are |*extra_first_index| (when
+// the pointer is non-null) followed by |indexes|. The instruction is created
+// with new and returned as a raw pointer; this function neither inserts it
+// nor registers it with the def-use manager.
+Instruction* InterfaceVariableScalarReplacement::CreateCompositeExtract(
+    uint32_t type_id, uint32_t composite_id,
+    const std::vector<uint32_t>& indexes, const uint32_t* extra_first_index) {
+  const uint32_t result_id = TakeNextId();
+  Instruction* extract = new Instruction(
+      context(), SpvOpCompositeExtract, type_id, result_id,
+      std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {composite_id}}});
+
+  if (extra_first_index != nullptr) {
+    extract->AddOperand(
+        {SPV_OPERAND_TYPE_LITERAL_INTEGER, {*extra_first_index}});
+  }
+  for (const uint32_t literal_index : indexes) {
+    extract->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal_index}});
+  }
+  return extract;
+}
+
+// Stores the component of |value_id| selected by |component_indices| to
+// |scalar_var|. If |access_chain_indices| is not empty, the store goes
+// through a new access chain on |scalar_var| with those indices. New
+// instructions are inserted before |insert_before|.
+void InterfaceVariableScalarReplacement::
+    StoreComponentOfValueToAccessChainToScalarVar(
+        uint32_t value_id, const std::vector<uint32_t>& component_indices,
+        Instruction* scalar_var,
+        const std::vector<uint32_t>& access_chain_indices,
+        Instruction* insert_before) {
+  Instruction* store_target = scalar_var;
+  uint32_t stored_type_id = GetPointeeTypeIdOfVar(scalar_var);
+  const bool needs_access_chain = !access_chain_indices.empty();
+  if (needs_access_chain) {
+    // |stored_type_id| is passed as the variable type and is overwritten
+    // with the component type by the call.
+    store_target =
+        CreateAccessChainToVar(stored_type_id, scalar_var,
+                               access_chain_indices, insert_before,
+                               &stored_type_id);
+  }
+
+  StoreComponentOfValueTo(stored_type_id, value_id, component_indices,
+                          store_target, nullptr, insert_before);
+}
+
+// Creates, before |insert_before|, a load of |var| and returns it. If
+// |indexes| is not empty, the load goes through a new access chain on |var|
+// with those indices.
+Instruction* InterfaceVariableScalarReplacement::LoadAccessChainToVar(
+    Instruction* var, const std::vector<uint32_t>& indexes,
+    Instruction* insert_before) {
+  Instruction* load_source = var;
+  uint32_t loaded_type_id = GetPointeeTypeIdOfVar(var);
+  const bool needs_access_chain = !indexes.empty();
+  if (needs_access_chain) {
+    // |loaded_type_id| is passed as the variable type and is overwritten
+    // with the component type by the call.
+    load_source = CreateAccessChainToVar(loaded_type_id, var, indexes,
+                                         insert_before, &loaded_type_id);
+  }
+
+  return CreateLoad(loaded_type_id, load_source, insert_before);
+}
+
+// Creates an operand-less OpCompositeConstruct placed after |load| and
+// returns it. Its type is the type of |load| when |depth_to_component| is 0;
+// otherwise it is the result of GetComponentTypeOfArrayMatrix() for the type
+// of |load| at that depth. The new result id is recorded with its depth in
+// |composite_ids_to_component_depths|.
+Instruction*
+InterfaceVariableScalarReplacement::CreateCompositeConstructForComponentOfLoad(
+    Instruction* load, uint32_t depth_to_component) {
+  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
+
+  uint32_t composite_type_id = load->type_id();
+  if (depth_to_component != 0) {
+    composite_type_id = GetComponentTypeOfArrayMatrix(
+        def_use_mgr, load->type_id(), depth_to_component);
+  }
+
+  const uint32_t composite_id = context()->TakeNextId();
+  std::unique_ptr<Instruction> composite_owner(new Instruction(
+      context(), SpvOpCompositeConstruct, composite_type_id, composite_id, {}));
+  Instruction* composite = composite_owner.get();
+  def_use_mgr->AnalyzeInstDefUse(composite);
+
+  // Find the insertion point after |load|. Instructions directly following
+  // |load| that are recorded composite constructs with a depth greater than
+  // |depth_to_component| are skipped, so the new composite is placed after
+  // them. The scan stops at the first instruction that is not a recorded
+  // composite or whose depth is not greater.
+  auto* insert_before = load->NextNode();
+  for (;;) {
+    const auto recorded =
+        composite_ids_to_component_depths.find(insert_before->result_id());
+    if (recorded == composite_ids_to_component_depths.end() ||
+        recorded->second <= depth_to_component) {
+      break;
+    }
+    insert_before = insert_before->NextNode();
+  }
+  insert_before->InsertBefore(std::move(composite_owner));
+  composite_ids_to_component_depths.insert({composite_id, depth_to_component});
+  return composite;
+}
+
+// For every (load, component value) pair in |loads_to_component_values|,
+// appends the component value as an operand of the composite construct that
+// |loads_to_composites| maps the load to. If the load has no composite yet,
+// one is created with CreateCompositeConstructForComponentOfLoad() at
+// |depth_to_component| and added to |loads_to_composites|.
+void InterfaceVariableScalarReplacement::AddComponentsToCompositesForLoads(
+    const std::unordered_map<Instruction*, Instruction*>&
+        loads_to_component_values,
+    std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
+    uint32_t depth_to_component) {
+  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
+  for (const auto& load_and_value : loads_to_component_values) {
+    Instruction* const load = load_and_value.first;
+    Instruction* const component_value = load_and_value.second;
+
+    Instruction* composite = nullptr;
+    const auto existing = loads_to_composites->find(load);
+    if (existing != loads_to_composites->end()) {
+      composite = existing->second;
+    } else {
+      composite =
+          CreateCompositeConstructForComponentOfLoad(load, depth_to_component);
+      loads_to_composites->insert({load, composite});
+    }
+
+    composite->AddOperand(
+        {SPV_OPERAND_TYPE_ID, {component_value->result_id()}});
+    // Re-analyze the composite now that it has a new operand.
+    def_use_mgr->AnalyzeInstDefUse(composite);
+  }
+}
+
+// Returns the id of the type instruction for an array of |array_length|
+// elements of type |elem_type_id|, as given by the type manager.
+uint32_t InterfaceVariableScalarReplacement::GetArrayType(
+    uint32_t elem_type_id, uint32_t array_length) {
+  auto* type_mgr = context()->get_type_mgr();
+  analysis::Type* element_type = type_mgr->GetType(elem_type_id);
+  const uint32_t length_const_id =
+      context()->get_constant_mgr()->GetUIntConst(array_length);
+
+  const analysis::Array::LengthInfo length_info{length_const_id,
+                                                {0, array_length}};
+  analysis::Array array_type(element_type, length_info);
+  return type_mgr->GetTypeInstruction(&array_type);
+}
+
+// Returns the id of the type instruction for a pointer to |type_id| in
+// |storage_class|, as given by the type manager.
+uint32_t InterfaceVariableScalarReplacement::GetPointerType(
+    uint32_t type_id, SpvStorageClass storage_class) {
+  auto* type_mgr = context()->get_type_mgr();
+  analysis::Type* pointee_type = type_mgr->GetType(type_id);
+  analysis::Pointer pointer_type(pointee_type, storage_class);
+  return type_mgr->GetTypeInstruction(&pointer_type);
+}
+
+// Creates the replacement variables for an interface variable whose type
+// |interface_var_type| is an OpTypeArray. CreateScalarInterfaceVarsForReplacement()
+// is called on the element type once per array element, and each result is
+// added as one component of the returned object.
+InterfaceVariableScalarReplacement::NestedCompositeComponents
+InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray(
+    Instruction* interface_var_type, SpvStorageClass storage_class,
+    uint32_t extra_array_length) {
+  assert(interface_var_type->opcode() == SpvOpTypeArray);
+
+  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
+  const uint32_t element_count =
+      GetArrayLength(def_use_mgr, interface_var_type);
+  Instruction* element_type =
+      GetArrayElementType(def_use_mgr, interface_var_type);
+
+  NestedCompositeComponents scalar_vars;
+  for (uint32_t element = 0; element < element_count; ++element) {
+    NestedCompositeComponents element_vars =
+        CreateScalarInterfaceVarsForReplacement(element_type, storage_class,
+                                                extra_array_length);
+    scalar_vars.AddComponent(element_vars);
+  }
+  return scalar_vars;
+}
+
+// Creates the replacement variables for an interface variable whose type
+// |interface_var_type| is an OpTypeMatrix. CreateScalarInterfaceVarsForReplacement()
+// is called on the column type once per column, and each result is added as
+// one component of the returned object.
+InterfaceVariableScalarReplacement::NestedCompositeComponents
+InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix(
+    Instruction* interface_var_type, SpvStorageClass storage_class,
+    uint32_t extra_array_length) {
+  assert(interface_var_type->opcode() == SpvOpTypeMatrix);
+
+  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
+  const uint32_t num_columns = interface_var_type->GetSingleWordInOperand(
+      kOpTypeMatrixColCountInOperandIndex);
+  Instruction* column_type =
+      GetMatrixColumnType(def_use_mgr, interface_var_type);
+
+  NestedCompositeComponents scalar_vars;
+  for (uint32_t column = 0; column < num_columns; ++column) {
+    NestedCompositeComponents column_vars =
+        CreateScalarInterfaceVarsForReplacement(column_type, storage_class,
+                                                extra_array_length);
+    scalar_vars.AddComponent(column_vars);
+  }
+  return scalar_vars;
+}
+
+// Creates the variables that replace an interface variable of type
+// |interface_var_type| in |storage_class|. Array and matrix types are
+// delegated to the dedicated helpers. For any other type a single new global
+// OpVariable is created; if |extra_array_length| is non-zero its type is an
+// array of that many elements of |interface_var_type|.
+InterfaceVariableScalarReplacement::NestedCompositeComponents
+InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement(
+    Instruction* interface_var_type, SpvStorageClass storage_class,
+    uint32_t extra_array_length) {
+  switch (interface_var_type->opcode()) {
+    case SpvOpTypeArray:
+      return CreateScalarInterfaceVarsForArray(
+          interface_var_type, storage_class, extra_array_length);
+    case SpvOpTypeMatrix:
+      return CreateScalarInterfaceVarsForMatrix(
+          interface_var_type, storage_class, extra_array_length);
+    default:
+      break;
+  }
+
+  // Scalar or vector case: create one new variable.
+  uint32_t var_pointee_type_id = interface_var_type->result_id();
+  if (extra_array_length != 0) {
+    var_pointee_type_id = GetArrayType(var_pointee_type_id, extra_array_length);
+  }
+  const uint32_t var_type_id = context()->get_type_mgr()->FindPointerToType(
+      var_pointee_type_id, storage_class);
+  const uint32_t var_id = TakeNextId();
+  std::unique_ptr<Instruction> new_var(new Instruction(
+      context(), SpvOpVariable, var_type_id, var_id,
+      std::initializer_list<Operand>{
+          {SPV_OPERAND_TYPE_STORAGE_CLASS,
+           {static_cast<uint32_t>(storage_class)}}}));
+
+  NestedCompositeComponents scalar_var;
+  scalar_var.SetSingleComponentVariable(new_var.get());
+  context()->AddGlobalValue(std::move(new_var));
+  return scalar_var;
+}
+
+// Returns the instruction that defines the pointee type of |var|.
+Instruction* InterfaceVariableScalarReplacement::GetTypeOfVariable(
+    Instruction* var) {
+  const uint32_t pointee_type_id = GetPointeeTypeIdOfVar(var);
+  return context()->get_def_use_mgr()->GetDef(pointee_type_id);
+}
+
+// Runs ReplaceInterfaceVarsWithScalars() on every entry point of the module
+// and returns the combination of the resulting statuses.
+Pass::Status InterfaceVariableScalarReplacement::Process() {
+  Pass::Status overall_status = Status::SuccessWithoutChange;
+  for (Instruction& entry_point : get_module()->entry_points()) {
+    const Pass::Status entry_status =
+        ReplaceInterfaceVarsWithScalars(entry_point);
+    overall_status = CombineStatus(overall_status, entry_status);
+  }
+  return overall_status;
+}
+
+// Returns false if |var| is not in |vars_with_extra_arrayness|. Otherwise
+// reports an error through the message consumer and returns true.
+bool InterfaceVariableScalarReplacement::
+    ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var) {
+  const bool arrayed_elsewhere =
+      vars_with_extra_arrayness.find(var) != vars_with_extra_arrayness.end();
+  if (!arrayed_elsewhere) return false;
+
+  const std::string var_text =
+      var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+  std::string message(
+      "A variable is arrayed for an entry point but it is not "
+      "arrayed for another entry point");
+  message += "\n " + var_text;
+  context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
+  return true;
+}
+
+// Returns false if |var| is not in |vars_without_extra_arrayness|. Otherwise
+// reports an error through the message consumer and returns true.
+bool InterfaceVariableScalarReplacement::
+    ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var) {
+  const bool not_arrayed_elsewhere =
+      vars_without_extra_arrayness.find(var) !=
+      vars_without_extra_arrayness.end();
+  if (!not_arrayed_elsewhere) return false;
+
+  const std::string var_text =
+      var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+  std::string message(
+      "A variable is not arrayed for an entry point but it is "
+      "arrayed for another entry point");
+  message += "\n " + var_text;
+  context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
+  return true;
+}
+
+// Processes the interface variables of |entry_point|. Variables for which
+// GetVariableLocation() fails are skipped. Each remaining variable is
+// recorded as having or not having extra arrayness and checked for conflicts
+// with other entry points. It is then replaced with scalar variables if its
+// type (or, with extra arrayness, its element type) is an array or a matrix.
+// Returns Failure on a conflict or a failed replacement, SuccessWithChange
+// if at least one variable was replaced, and SuccessWithoutChange otherwise.
+Pass::Status
+InterfaceVariableScalarReplacement::ReplaceInterfaceVarsWithScalars(
+    Instruction& entry_point) {
+  std::vector<Instruction*> candidate_vars =
+      CollectInterfaceVariables(entry_point);
+
+  Pass::Status status = Status::SuccessWithoutChange;
+  for (Instruction* interface_var : candidate_vars) {
+    uint32_t location;
+    uint32_t component;
+    if (!GetVariableLocation(interface_var, &location)) continue;
+    // A variable without a component decoration uses component 0.
+    if (!GetVariableComponent(interface_var, &component)) component = 0;
+
+    Instruction* var_type = GetTypeOfVariable(interface_var);
+    uint32_t extra_array_length = 0;
+    const bool has_extra_arrayness =
+        HasExtraArrayness(entry_point, interface_var);
+    if (has_extra_arrayness) {
+      // Peel off the outermost array: its length is the extra array length
+      // and its element type is the type to replace.
+      auto* def_use_mgr = context()->get_def_use_mgr();
+      extra_array_length = GetArrayLength(def_use_mgr, var_type);
+      var_type = GetArrayElementType(def_use_mgr, var_type);
+      vars_with_extra_arrayness.insert(interface_var);
+    } else {
+      vars_without_extra_arrayness.insert(interface_var);
+    }
+
+    const bool no_conflict = CheckExtraArraynessConflictBetweenEntries(
+        interface_var, extra_array_length != 0);
+    if (!no_conflict) return Pass::Status::Failure;
+
+    const bool is_array_or_matrix = var_type->opcode() == SpvOpTypeArray ||
+                                    var_type->opcode() == SpvOpTypeMatrix;
+    if (!is_array_or_matrix) continue;
+
+    const bool replaced = ReplaceInterfaceVariableWithScalars(
+        interface_var, var_type, location, component, extra_array_length);
+    if (!replaced) return Pass::Status::Failure;
+    status = Pass::Status::SuccessWithChange;
+  }
+
+  return status;
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/interface_var_sroa.h b/source/opt/interface_var_sroa.h
new file mode 100644
index 0000000..23baad0
--- /dev/null
+++ b/source/opt/interface_var_sroa.h
@@ -0,0 +1,401 @@
+// Copyright (c) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_INTERFACE_VAR_SROA_H_
+#define SOURCE_OPT_INTERFACE_VAR_SROA_H_
+
+#include <unordered_set>
+
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+//
+// Note that the current implementation of this pass covers only store, load,
+// and access chain instructions for the interface variables. Supporting other
+// types of instructions is future work.
+class InterfaceVariableScalarReplacement : public Pass {
+ public:
+ InterfaceVariableScalarReplacement() {}
+
+ const char* name() const override {
+ return "interface-variable-scalar-replacement";
+ }
+ Status Process() override;
+
+ // Declares the decoration, def-use, constant and type analyses as preserved.
+ IRContext::Analysis GetPreservedAnalyses() override {
+ return IRContext::kAnalysisDecorations | IRContext::kAnalysisDefUse |
+ IRContext::kAnalysisConstants | IRContext::kAnalysisTypes;
+ }
+
+ private:
+ // A struct containing components of a composite variable. If the composite
+ // consists of multiple or recursive components, |component_variable| is
+ // nullptr and |nested_composite_components| keeps the components. If it has a
+ // single component, |nested_composite_components| is empty and
+ // |component_variable| is the component. Note that each element of
+ // |nested_composite_components| has the NestedCompositeComponents struct as
+ // its type that can recursively keep the components.
+ struct NestedCompositeComponents {
+ NestedCompositeComponents() : component_variable(nullptr) {}
+
+ bool HasMultipleComponents() const {
+ return !nested_composite_components.empty();
+ }
+
+ const std::vector<NestedCompositeComponents>& GetComponents() const {
+ return nested_composite_components;
+ }
+
+ void AddComponent(const NestedCompositeComponents& component) {
+ nested_composite_components.push_back(component);
+ }
+
+ Instruction* GetComponentVariable() const { return component_variable; }
+
+ void SetSingleComponentVariable(Instruction* var) {
+ component_variable = var;
+ }
+
+ private:
+ std::vector<NestedCompositeComponents> nested_composite_components;
+ Instruction* component_variable;
+ };
+
+ // Collects all interface variables used by the |entry_point|.
+ std::vector<Instruction*> CollectInterfaceVariables(Instruction& entry_point);
+
+ // Returns whether |var| has the extra arrayness for the entry point
+ // |entry_point| or not.
+ bool HasExtraArrayness(Instruction& entry_point, Instruction* var);
+
+ // Finds the Location decoration of |var| and returns its value via
+ // |location|. Returns whether the Location decoration exists or not.
+ bool GetVariableLocation(Instruction* var, uint32_t* location);
+
+ // Finds the Component decoration of |var| and returns its value via
+ // |component|. Returns whether the Component decoration exists or not.
+ bool GetVariableComponent(Instruction* var, uint32_t* component);
+
+ // Returns the interface variable instruction whose result id is
+ // |interface_var_id|.
+ Instruction* GetInterfaceVariable(uint32_t interface_var_id);
+
+ // Returns the type of |var| as an instruction.
+ Instruction* GetTypeOfVariable(Instruction* var);
+
+ // Replaces an interface variable |interface_var| whose type is
+ // |interface_var_type| with scalars and returns whether it succeeds or not.
+ // |location| is the value of Location Decoration for |interface_var|.
+ // |component| is the value of Component Decoration for |interface_var|.
+ // If |extra_array_length| is 0, it means |interface_var| has a Patch
+ // decoration. Otherwise, |extra_array_length| denotes the length of the extra
+ // array of |interface_var|.
+ bool ReplaceInterfaceVariableWithScalars(Instruction* interface_var,
+ Instruction* interface_var_type,
+ uint32_t location,
+ uint32_t component,
+ uint32_t extra_array_length);
+
+ // Creates scalar variables with the storage class |storage_class| to replace
+ // an interface variable whose type is |interface_var_type|. If
+ // |extra_array_length| is not zero, adds the extra arrayness to the created
+ // scalar variables.
+ NestedCompositeComponents CreateScalarInterfaceVarsForReplacement(
+ Instruction* interface_var_type, SpvStorageClass storage_class,
+ uint32_t extra_array_length);
+
+ // Creates scalar variables with the storage class |storage_class| to replace
+ // the interface variable whose type is the OpTypeArray |interface_var_type|.
+ // If |extra_array_length| is not zero, adds the extra arrayness to all the
+ // scalar variables.
+ NestedCompositeComponents CreateScalarInterfaceVarsForArray(
+ Instruction* interface_var_type, SpvStorageClass storage_class,
+ uint32_t extra_array_length);
+
+ // Creates scalar variables with the storage class |storage_class| to replace
+ // the interface variable whose type is the OpTypeMatrix
+ // |interface_var_type|. If |extra_array_length| is not zero, adds the extra
+ // arrayness to all the scalar variables.
+ NestedCompositeComponents CreateScalarInterfaceVarsForMatrix(
+ Instruction* interface_var_type, SpvStorageClass storage_class,
+ uint32_t extra_array_length);
+
+ // Recursively adds Location and Component decorations to variables in
+ // |vars| with |location| and |component|. Increases |location| by one after
+ // it actually adds Location and Component decorations for a variable.
+ void AddLocationAndComponentDecorations(const NestedCompositeComponents& vars,
+ uint32_t* location,
+ uint32_t component);
+
+ // Replaces the interface variable |interface_var| with
+ // |scalar_interface_vars| and returns whether it succeeds or not.
+ // |extra_arrayness| is the extra arrayness of the interface variable.
+ // |scalar_interface_vars| contains the nested variables to replace the
+ // interface variable with.
+ bool ReplaceInterfaceVarWith(
+ Instruction* interface_var, uint32_t extra_arrayness,
+ const NestedCompositeComponents& scalar_interface_vars);
+
+ // Replaces |interface_var| in the operands of instructions
+ // |interface_var_users| with |scalar_interface_vars|. This is a recursive
+ // method and |interface_var_component_indices| is used to specify which
+ // recursive component of |interface_var| is replaced. Returns composite
+ // construct instructions to be replaced with load instructions of
+ // |interface_var_users| via |loads_to_composites|. Returns composite
+ // construct instructions to be replaced with load instructions of access
+ // chain instructions in |interface_var_users| via
+ // |loads_for_access_chain_to_composites|.
+ bool ReplaceComponentsOfInterfaceVarWith(
+ Instruction* interface_var,
+ const std::vector<Instruction*>& interface_var_users,
+ const NestedCompositeComponents& scalar_interface_vars,
+ std::vector<uint32_t>& interface_var_component_indices,
+ const uint32_t* extra_array_index,
+ std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
+ std::unordered_map<Instruction*, Instruction*>*
+ loads_for_access_chain_to_composites);
+
+ // Replaces |interface_var| in the operands of instructions
+ // |interface_var_users| with |components| that is a vector of components for
+ // the interface variable |interface_var|. This is a recursive method and
+ // |interface_var_component_indices| is used to specify which recursive
+ // component of |interface_var| is replaced. Returns composite construct
+ // instructions to be replaced with load instructions of |interface_var_users|
+ // via |loads_to_composites|. Returns composite construct instructions to be
+ // replaced with load instructions of access chain instructions in
+ // |interface_var_users| via |loads_for_access_chain_to_composites|.
+ bool ReplaceMultipleComponentsOfInterfaceVarWith(
+ Instruction* interface_var,
+ const std::vector<Instruction*>& interface_var_users,
+ const std::vector<NestedCompositeComponents>& components,
+ std::vector<uint32_t>& interface_var_component_indices,
+ const uint32_t* extra_array_index,
+ std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
+ std::unordered_map<Instruction*, Instruction*>*
+ loads_for_access_chain_to_composites);
+
+ // Replaces a component of |interface_var| that is used as an operand of
+ // instruction |interface_var_user| with |scalar_var|.
+ // |interface_var_component_indices| is a vector of recursive indices for
+ // which recursive component of |interface_var| is replaced. If
+ // |interface_var_user| is a load, returns the component value via
+ // |loads_to_component_values|. If |interface_var_user| is an access chain,
+ // returns the component value for loads of |interface_var_user| via
+ // |loads_for_access_chain_to_component_values|.
+ bool ReplaceComponentOfInterfaceVarWith(
+ Instruction* interface_var, Instruction* interface_var_user,
+ Instruction* scalar_var,
+ const std::vector<uint32_t>& interface_var_component_indices,
+ const uint32_t* extra_array_index,
+ std::unordered_map<Instruction*, Instruction*>* loads_to_component_values,
+ std::unordered_map<Instruction*, Instruction*>*
+ loads_for_access_chain_to_component_values);
+
+ // Creates instructions to load |scalar_var| and inserts them before
+ // |insert_before|. If |extra_array_index| is not null, they load
+ // |extra_array_index| th component of |scalar_var| instead of |scalar_var|
+ // itself.
+ Instruction* LoadScalarVar(Instruction* scalar_var,
+ const uint32_t* extra_array_index,
+ Instruction* insert_before);
+
+ // Creates instructions to load an access chain to |var| and inserts them
+ // before |insert_before|. |indexes| will be the Indexes operands of the
+ // access chain.
+ Instruction* LoadAccessChainToVar(Instruction* var,
+ const std::vector<uint32_t>& indexes,
+ Instruction* insert_before);
+
+ // Creates instructions to store a component of an aggregate whose id is
+ // |value_id| to an access chain to |scalar_var| and inserts the created
+ // instructions before |insert_before|. To get the component, recursively
+ // traverses the aggregate with |component_indices| as indexes.
+ // Numbers in |access_chain_indices| are the Indexes operand of the access
+ // chain to |scalar_var|.
+ void StoreComponentOfValueToAccessChainToScalarVar(
+ uint32_t value_id, const std::vector<uint32_t>& component_indices,
+ Instruction* scalar_var,
+ const std::vector<uint32_t>& access_chain_indices,
+ Instruction* insert_before);
+
+ // Creates instructions to store a component of an aggregate whose id is
+ // |value_id| to |scalar_var| and inserts the created instructions before
+ // |insert_before|. To get the component, recursively traverses the aggregate
+ // using |extra_array_index| and |component_indices| as indexes.
+ void StoreComponentOfValueToScalarVar(
+ uint32_t value_id, const std::vector<uint32_t>& component_indices,
+ Instruction* scalar_var, const uint32_t* extra_array_index,
+ Instruction* insert_before);
+
+ // Creates instructions to store a component of an aggregate whose id is
+ // |value_id| to |ptr| and inserts the created instructions before
+ // |insert_before|. To get the component, recursively traverses the aggregate
+ // using |extra_array_index| and |component_indices| as indexes.
+ // |component_type_id| is the id of the type instruction of the component.
+ void StoreComponentOfValueTo(uint32_t component_type_id, uint32_t value_id,
+ const std::vector<uint32_t>& component_indices,
+ Instruction* ptr,
+ const uint32_t* extra_array_index,
+ Instruction* insert_before);
+
+ // Creates new OpCompositeExtract with |type_id| for Result Type,
+ // |composite_id| for Composite operand, and |indexes| for Indexes operands.
+ // If |extra_first_index| is not nullptr, uses it as the first Indexes
+ // operand.
+ Instruction* CreateCompositeExtract(uint32_t type_id, uint32_t composite_id,
+ const std::vector<uint32_t>& indexes,
+ const uint32_t* extra_first_index);
+
+ // Creates a new OpLoad whose Result Type is |type_id| and Pointer operand is
+ // |ptr|. Inserts the new instruction before |insert_before|.
+ Instruction* CreateLoad(uint32_t type_id, Instruction* ptr,
+ Instruction* insert_before);
+
+ // Clones an annotation instruction |annotation_inst| and sets the target
+ // operand of the new annotation instruction as |var_id|.
+ void CloneAnnotationForVariable(Instruction* annotation_inst,
+ uint32_t var_id);
+
+ // Replaces the interface variable |interface_var| in the operands of the
+ // entry point |entry_point| with |scalar_var_id|. If it cannot find
+ // |interface_var| from the operands of the entry point |entry_point|, adds
+ // |scalar_var_id| as an operand of the entry point |entry_point|.
+ bool ReplaceInterfaceVarInEntryPoint(Instruction* interface_var,
+ Instruction* entry_point,
+ uint32_t scalar_var_id);
+
+ // Creates an access chain instruction whose Base operand is |var| and Indexes
+ // operand is |index|. |component_type_id| is the id of the type instruction
+ // that is the type of component. Inserts the new access chain before
+ // |insert_before|.
+ Instruction* CreateAccessChainWithIndex(uint32_t component_type_id,
+ Instruction* var, uint32_t index,
+ Instruction* insert_before);
+
+ // Returns the pointee type of the type of variable |var|.
+ uint32_t GetPointeeTypeIdOfVar(Instruction* var);
+
+ // Replaces the access chain |access_chain| and its users with a new access
+ // chain that uses |scalar_var| as the Base operand and
+ // |interface_var_component_indices| as Indexes operands, and with users of
+ // the new access chain. When some of the users are load instructions, returns
+ // the original load instruction to the new instruction that loads a component
+ // of the original load value via |loads_to_component_values|.
+ void ReplaceAccessChainWith(
+ Instruction* access_chain,
+ const std::vector<uint32_t>& interface_var_component_indices,
+ Instruction* scalar_var,
+ std::unordered_map<Instruction*, Instruction*>*
+ loads_to_component_values);
+
+ // Assuming that |access_chain| is an access chain instruction whose Base
+ // operand is |base_access_chain|, replaces the operands of |access_chain|
+ // with operands of |base_access_chain| and Indexes operands of
+ // |access_chain|.
+ void UseBaseAccessChainForAccessChain(Instruction* access_chain,
+ Instruction* base_access_chain);
+
+ // Creates composite construct instructions for load instructions that are the
+ // keys of |loads_to_component_values| if no such composite construct
+ // instructions exist. Adds a component of the composite as an operand of the
+ // created composite construct instruction. Each value of
+ // |loads_to_component_values| is the component. Returns the created composite
+ // construct instructions using |loads_to_composites|. |depth_to_component| is
+ // the number of recursive access steps to get the component from the
+ // composite.
+ void AddComponentsToCompositesForLoads(
+ const std::unordered_map<Instruction*, Instruction*>&
+ loads_to_component_values,
+ std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
+ uint32_t depth_to_component);
+
+ // Creates a composite construct instruction for a component of the value of
+ // instruction |load| in |depth_to_component| th recursive depth and inserts
+ // it after |load|.
+ Instruction* CreateCompositeConstructForComponentOfLoad(
+ Instruction* load, uint32_t depth_to_component);
+
+ // Creates a new access chain instruction that points to variable |var| whose
+ // type is the instruction with |var_type_id| and inserts it before
+ // |insert_before|. The new access chain will have |index_ids| for Indexes
+ // operands. Returns the type id of the component that is pointed by the new
+ // access chain via |component_type_id|.
+ Instruction* CreateAccessChainToVar(uint32_t var_type_id, Instruction* var,
+ const std::vector<uint32_t>& index_ids,
+ Instruction* insert_before,
+ uint32_t* component_type_id);
+
+ // Returns the result id of OpTypeArray instruction whose Element Type
+ // operand is |elem_type_id| and Length operand is |array_length|.
+ uint32_t GetArrayType(uint32_t elem_type_id, uint32_t array_length);
+
+ // Returns the result id of OpTypePointer instruction whose Type
+ // operand is |type_id| and Storage Class operand is |storage_class|.
+ uint32_t GetPointerType(uint32_t type_id, SpvStorageClass storage_class);
+
+ // Kills an instruction |inst| and its users.
+ void KillInstructionAndUsers(Instruction* inst);
+
+ // Kills a vector of instructions |insts| and their users.
+ void KillInstructionsAndUsers(const std::vector<Instruction*>& insts);
+
+ // Kills all OpDecorate instructions for Location and Component of the
+ // variable whose id is |var_id|.
+ void KillLocationAndComponentDecorations(uint32_t var_id);
+
+ // If |var| has the extra arrayness for an entry point, reports an error and
+ // returns true. Otherwise, returns false.
+ bool ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var);
+
+ // If |var| does not have the extra arrayness for an entry point, reports an
+ // error and returns true. Otherwise, returns false.
+ bool ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var);
+
+ // If |interface_var| has the extra arrayness for an entry point but it does
+ // not have one for another entry point, reports an error and returns false.
+ // Otherwise, returns true. |has_extra_arrayness| denotes whether it has an
+ // extra arrayness for an entry point or not.
+ bool CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var,
+ bool has_extra_arrayness);
+
+ // Conducts the scalar replacement for the interface variables used by the
+ // |entry_point|.
+ Pass::Status ReplaceInterfaceVarsWithScalars(Instruction& entry_point);
+
+ // A set of interface variable ids that were already removed from operands of
+ // the entry point.
+ std::unordered_set<uint32_t>
+ interface_vars_removed_from_entry_point_operands_;
+
+ // A mapping from ids of new composite construct instructions that load
+ // instructions are replaced with to the recursive depth of the component of
+ // load that the new composite construct instruction is used for.
+ std::unordered_map<uint32_t, uint32_t> composite_ids_to_component_depths;
+
+ // A set of interface variables with the extra arrayness for any of the entry
+ // points.
+ std::unordered_set<Instruction*> vars_with_extra_arrayness;
+
+ // A set of interface variables without the extra arrayness for any of the
+ // entry points.
+ std::unordered_set<Instruction*> vars_without_extra_arrayness;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // SOURCE_OPT_INTERFACE_VAR_SROA_H_
diff --git a/source/opt/ir_builder.h b/source/opt/ir_builder.h
index 4433cf0..9d4fa8f 100644
--- a/source/opt/ir_builder.h
+++ b/source/opt/ir_builder.h
@@ -487,6 +487,15 @@
return AddInstruction(std::move(new_inst));
}
+ // Creates an OpVariable whose Result Type is |type_id| and whose Storage
+ // Class operand is |storage_class|, gives it a fresh result id taken from the
+ // context, and hands it to AddInstruction(). Returns what AddInstruction()
+ // returns.
+ //
+ // The storage class is encoded as SPV_OPERAND_TYPE_STORAGE_CLASS rather than
+ // SPV_OPERAND_TYPE_ID: it is an enumerant, not a result id, so it must not be
+ // seen as an id operand by code that walks the id operands of instructions.
+ Instruction* AddVariable(uint32_t type_id, uint32_t storage_class) {
+   std::vector<Operand> operands;
+   operands.push_back({SPV_OPERAND_TYPE_STORAGE_CLASS, {storage_class}});
+   std::unique_ptr<Instruction> new_inst(
+       new Instruction(GetContext(), SpvOpVariable, type_id,
+                       GetContext()->TakeNextId(), operands));
+   return AddInstruction(std::move(new_inst));
+ }
+
Instruction* AddStore(uint32_t ptr_id, uint32_t obj_id) {
std::vector<Operand> operands;
operands.push_back({SPV_OPERAND_TYPE_ID, {ptr_id}});
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index f9f5153..2f27942 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -1094,6 +1094,9 @@
id_to_name_->insert({d->GetSingleWordInOperand(0), d.get()});
}
}
+ if (AreAnalysesValid(kAnalysisDefUse)) {
+ get_def_use_mgr()->AnalyzeInstDefUse(d.get());
+ }
module()->AddDebug2Inst(std::move(d));
}
diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp
index 0c6d0c2..da4cac3 100644
--- a/source/opt/local_access_chain_convert_pass.cpp
+++ b/source/opt/local_access_chain_convert_pass.cpp
@@ -28,8 +28,6 @@
const uint32_t kStoreValIdInIdx = 1;
const uint32_t kAccessChainPtrIdInIdx = 0;
-const uint32_t kConstantValueInIdx = 0;
-const uint32_t kTypeIntWidthInIdx = 0;
} // anonymous namespace
@@ -67,7 +65,19 @@
ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) {
if (iidIdx > 0) {
const Instruction* cInst = get_def_use_mgr()->GetDef(*iid);
- uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx);
+ const auto* constant_value =
+ context()->get_constant_mgr()->GetConstantFromInst(cInst);
+ assert(constant_value != nullptr &&
+ "Expecting the index to be a constant.");
+
+ // We take the sign extended value because OpAccessChain interprets the
+ // index as signed.
+ int64_t long_value = constant_value->GetSignExtendedValue();
+ assert(long_value <= UINT32_MAX && long_value >= 0 &&
+ "The index value is too large for a composite insert or extract "
+ "instruction.");
+
+ uint32_t val = static_cast<uint32_t>(long_value);
in_opnds->push_back(
{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
}
@@ -169,13 +179,16 @@
return true;
}
-bool LocalAccessChainConvertPass::IsConstantIndexAccessChain(
+bool LocalAccessChainConvertPass::Is32BitConstantIndexAccessChain(
const Instruction* acp) const {
uint32_t inIdx = 0;
return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
if (inIdx > 0) {
Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
if (opInst->opcode() != SpvOpConstant) return false;
+ const auto* index =
+ context()->get_constant_mgr()->GetConstantFromInst(opInst);
+ if (index->GetSignExtendedValue() > UINT32_MAX) return false;
}
++inIdx;
return true;
@@ -231,7 +244,7 @@
break;
}
// Rule out variables accessed with non-constant indices
- if (!IsConstantIndexAccessChain(ptrInst)) {
+ if (!Is32BitConstantIndexAccessChain(ptrInst)) {
seen_non_target_vars_.insert(varId);
seen_target_vars_.erase(varId);
break;
@@ -349,12 +362,6 @@
}
Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
- // If non-32-bit integer type in module, terminate processing
- // TODO(): Handle non-32-bit integer constants in access chains
- for (const Instruction& inst : get_module()->types_values())
- if (inst.opcode() == SpvOpTypeInt &&
- inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
- return Status::SuccessWithoutChange;
// Do not process if module contains OpGroupDecorate. Additional
// support required in KillNamesAndDecorates().
// TODO(greg-lunarg): Add support for OpGroupDecorate
@@ -435,6 +442,7 @@
"SPV_EXT_shader_image_int64",
"SPV_KHR_non_semantic_info",
"SPV_KHR_uniform_group_instructions",
+ "SPV_KHR_fragment_shader_barycentric",
});
}
diff --git a/source/opt/local_access_chain_convert_pass.h b/source/opt/local_access_chain_convert_pass.h
index a51660f..8548e16 100644
--- a/source/opt/local_access_chain_convert_pass.h
+++ b/source/opt/local_access_chain_convert_pass.h
@@ -95,7 +95,8 @@
Instruction* original_load);
// Return true if all indices of access chain |acp| are OpConstant integers
- bool IsConstantIndexAccessChain(const Instruction* acp) const;
+ // whose values can fit into an unsigned 32-bit value.
+ bool Is32BitConstantIndexAccessChain(const Instruction* acp) const;
// Identify all function scope variables of target type which are
// accessed only with loads, stores and access chains with constant
diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp
index 33c8bdf..a58e8e4 100644
--- a/source/opt/local_single_block_elim_pass.cpp
+++ b/source/opt/local_single_block_elim_pass.cpp
@@ -287,6 +287,7 @@
"SPV_EXT_shader_image_int64",
"SPV_KHR_non_semantic_info",
"SPV_KHR_uniform_group_instructions",
+ "SPV_KHR_fragment_shader_barycentric",
});
}
diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp
index f22b191..8cdd0ab 100644
--- a/source/opt/local_single_store_elim_pass.cpp
+++ b/source/opt/local_single_store_elim_pass.cpp
@@ -140,6 +140,7 @@
"SPV_EXT_shader_image_int64",
"SPV_KHR_non_semantic_info",
"SPV_KHR_uniform_group_instructions",
+ "SPV_KHR_fragment_shader_barycentric",
});
}
bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index f28b1ba..2976151 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -525,6 +525,8 @@
RegisterPass(CreateRemoveDontInlinePass());
} else if (pass_name == "eliminate-dead-input-components") {
RegisterPass(CreateEliminateDeadInputComponentsPass());
+ } else if (pass_name == "fix-func-call-param") {
+ RegisterPass(CreateFixFuncCallArgumentsPass());
} else if (pass_name == "convert-to-sampled-image") {
if (pass_args.size() > 0) {
auto descriptor_set_binding_pairs =
@@ -1018,8 +1020,18 @@
MakeUnique<opt::ConvertToSampledImagePass>(descriptor_set_binding_pairs));
}
+// Builds a pass token that owns a new InterfaceVariableScalarReplacement pass.
+Optimizer::PassToken CreateInterfaceVariableScalarReplacementPass() {
+  auto pass = MakeUnique<opt::InterfaceVariableScalarReplacement>();
+  return MakeUnique<Optimizer::PassToken::Impl>(std::move(pass));
+}
+
Optimizer::PassToken CreateRemoveDontInlinePass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::RemoveDontInline>());
}
+
+// Builds a pass token that owns a new FixFuncCallArgumentsPass.
+Optimizer::PassToken CreateFixFuncCallArgumentsPass() {
+  auto pass = MakeUnique<opt::FixFuncCallArgumentsPass>();
+  return MakeUnique<Optimizer::PassToken::Impl>(std::move(pass));
+}
} // namespace spvtools
diff --git a/source/opt/passes.h b/source/opt/passes.h
index a12c76b..21354c7 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -37,6 +37,7 @@
#include "source/opt/eliminate_dead_input_components_pass.h"
#include "source/opt/eliminate_dead_members_pass.h"
#include "source/opt/empty_pass.h"
+#include "source/opt/fix_func_call_arguments.h"
#include "source/opt/fix_storage_class.h"
#include "source/opt/flatten_decoration_pass.h"
#include "source/opt/fold_spec_constant_op_and_composite_pass.h"
@@ -48,6 +49,7 @@
#include "source/opt/inst_bindless_check_pass.h"
#include "source/opt/inst_buff_addr_check_pass.h"
#include "source/opt/inst_debug_printf_pass.h"
+#include "source/opt/interface_var_sroa.h"
#include "source/opt/interp_fixup_pass.h"
#include "source/opt/licm_pass.h"
#include "source/opt/local_access_chain_convert_pass.h"
diff --git a/source/opt/replace_desc_array_access_using_var_index.cpp b/source/opt/replace_desc_array_access_using_var_index.cpp
index 4cadf60..e97593e 100644
--- a/source/opt/replace_desc_array_access_using_var_index.cpp
+++ b/source/opt/replace_desc_array_access_using_var_index.cpp
@@ -95,7 +95,7 @@
CollectRecursiveUsersWithConcreteType(access_chain, &final_users);
for (auto* inst : final_users) {
std::deque<Instruction*> insts_to_be_cloned =
- CollectRequiredImageInsts(inst);
+ CollectRequiredImageAndAccessInsts(inst);
ReplaceNonUniformAccessWithSwitchCase(
inst, access_chain, number_of_elements, insts_to_be_cloned);
}
@@ -121,8 +121,8 @@
}
std::deque<Instruction*>
-ReplaceDescArrayAccessUsingVarIndex::CollectRequiredImageInsts(
- Instruction* user_of_image_insts) const {
+ReplaceDescArrayAccessUsingVarIndex::CollectRequiredImageAndAccessInsts(
+ Instruction* user) const {
std::unordered_set<uint32_t> seen_inst_ids;
std::queue<Instruction*> work_list;
@@ -131,21 +131,23 @@
if (!seen_inst_ids.insert(*idp).second) return;
Instruction* operand = get_def_use_mgr()->GetDef(*idp);
if (context()->get_instr_block(operand) != nullptr &&
- HasImageOrImagePtrType(operand)) {
+ (HasImageOrImagePtrType(operand) ||
+ operand->opcode() == SpvOpAccessChain ||
+ operand->opcode() == SpvOpInBoundsAccessChain)) {
work_list.push(operand);
}
};
- std::deque<Instruction*> required_image_insts;
- required_image_insts.push_front(user_of_image_insts);
- user_of_image_insts->ForEachInId(decision_to_include_operand);
+ std::deque<Instruction*> required_insts;
+ required_insts.push_front(user);
+ user->ForEachInId(decision_to_include_operand);
while (!work_list.empty()) {
auto* inst_from_work_list = work_list.front();
work_list.pop();
- required_image_insts.push_front(inst_from_work_list);
+ required_insts.push_front(inst_from_work_list);
inst_from_work_list->ForEachInId(decision_to_include_operand);
}
- return required_image_insts;
+ return required_insts;
}
bool ReplaceDescArrayAccessUsingVarIndex::HasImageOrImagePtrType(
diff --git a/source/opt/replace_desc_array_access_using_var_index.h b/source/opt/replace_desc_array_access_using_var_index.h
index 0c97f7e..51817c1 100644
--- a/source/opt/replace_desc_array_access_using_var_index.h
+++ b/source/opt/replace_desc_array_access_using_var_index.h
@@ -76,11 +76,12 @@
void CollectRecursiveUsersWithConcreteType(
Instruction* access_chain, std::vector<Instruction*>* final_users) const;
- // Recursively collects the operands of |user_of_image_insts| (and operands
- // of the operands) whose result types are images/samplers or pointers/array/
- // struct of them and returns them.
- std::deque<Instruction*> CollectRequiredImageInsts(
- Instruction* user_of_image_insts) const;
+ // Recursively collects the operands of |user| (and operands of the operands)
+ // whose result types are images/samplers (or pointers/arrays/structs of
+ // them) or that are access chain instructions, and returns them. The returned
+ // collection includes |user|.
+ std::deque<Instruction*> CollectRequiredImageAndAccessInsts(
+ Instruction* user) const;
// Returns whether result type of |inst| is an image/sampler/pointer of image
// or sampler or not.
diff --git a/source/val/validate_annotation.cpp b/source/val/validate_annotation.cpp
index bef7ef9..40f2118 100644
--- a/source/val/validate_annotation.cpp
+++ b/source/val/validate_annotation.cpp
@@ -136,8 +136,8 @@
return "PerViewNV";
case SpvDecorationPerTaskNV:
return "PerTaskNV";
- case SpvDecorationPerVertexNV:
- return "PerVertexNV";
+ case SpvDecorationPerVertexKHR:
+ return "PerVertexKHR";
case SpvDecorationNonUniform:
return "NonUniform";
case SpvDecorationRestrictPointer:
@@ -366,6 +366,11 @@
return fail(4670) << "storage class must be Input or Output";
}
break;
+ case SpvDecorationPerVertexKHR:
+ if (sc != SpvStorageClassInput) {
+ return fail(6777) << "storage class must be Input";
+ }
+ break;
default:
break;
}
diff --git a/source/val/validate_builtins.cpp b/source/val/validate_builtins.cpp
index 6a2e919..379705a 100644
--- a/source/val/validate_builtins.cpp
+++ b/source/val/validate_builtins.cpp
@@ -120,7 +120,7 @@
VUIDErrorMax,
} VUIDError;
-const static uint32_t NumVUIDBuiltins = 33;
+const static uint32_t NumVUIDBuiltins = 36;
typedef struct {
SpvBuiltIn builtIn;
@@ -162,6 +162,9 @@
{SpvBuiltInFragSizeEXT, {4220, 4221, 4222}},
{SpvBuiltInFragStencilRefEXT, {4223, 4224, 4225}},
{SpvBuiltInFullyCoveredEXT, {4232, 4233, 4234}},
+ {SpvBuiltInCullMaskKHR, {6735, 6736, 6737}},
+ {SpvBuiltInBaryCoordKHR, {4154, 4155, 4156}},
+ {SpvBuiltInBaryCoordNoPerspKHR, {4160, 4161, 4162}},
// clang-format off
} };
@@ -208,6 +211,7 @@
case SpvBuiltInRayTmaxKHR:
case SpvBuiltInWorldRayDirectionKHR:
case SpvBuiltInWorldRayOriginKHR:
+ case SpvBuiltInCullMaskKHR:
switch (stage) {
case SpvExecutionModelIntersectionKHR:
case SpvExecutionModelAnyHitKHR:
@@ -331,7 +335,9 @@
const Decoration& decoration, const Instruction& inst);
spv_result_t ValidateSMBuiltinsAtDefinition(const Decoration& decoration,
const Instruction& inst);
-
+ // Used for BaryCoord, BaryCoordNoPersp.
+ spv_result_t ValidateFragmentShaderF32Vec3InputAtDefinition(
+ const Decoration& decoration, const Instruction& inst);
// Used for SubgroupEqMask, SubgroupGeMask, SubgroupGtMask, SubgroupLtMask,
// SubgroupLeMask.
spv_result_t ValidateI32Vec4InputAtDefinition(const Decoration& decoration,
@@ -509,6 +515,13 @@
const Decoration& decoration, const Instruction& built_in_inst,
const Instruction& referenced_inst,
const Instruction& referenced_from_inst);
+
+ // Used for BaryCoord, BaryCoordNoPersp.
+ spv_result_t ValidateFragmentShaderF32Vec3InputAtReference(
+ const Decoration& decoration, const Instruction& built_in_inst,
+ const Instruction& referenced_inst,
+ const Instruction& referenced_from_inst);
+
// Used for SubgroupId and NumSubgroups.
spv_result_t ValidateComputeI32InputAtReference(
const Decoration& decoration, const Instruction& built_in_inst,
@@ -2788,6 +2801,80 @@
return SPV_SUCCESS;
}
+spv_result_t BuiltInsValidator::ValidateFragmentShaderF32Vec3InputAtDefinition(
+ const Decoration& decoration, const Instruction& inst) {
+ if (spvIsVulkanEnv(_.context()->target_env)) {
+ const SpvBuiltIn builtin = SpvBuiltIn(decoration.params()[0]);
+ if (spv_result_t error = ValidateF32Vec(
+ decoration, inst, 3,
+ [this, &inst, builtin](const std::string& message) -> spv_result_t {
+ uint32_t vuid = GetVUIDForBuiltin(builtin, VUIDErrorType);
+ return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+ << _.VkErrorID(vuid) << "According to the "
+ << spvLogStringForEnv(_.context()->target_env)
+ << " spec BuiltIn "
+ << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+ builtin)
+ << " variable needs to be a 3-component 32-bit float "
+ "vector. "
+ << message;
+ })) {
+ return error;
+ }
+ }
+
+ // Seed the at-reference checks, using the definition as the first reference.
+ return ValidateFragmentShaderF32Vec3InputAtReference(decoration, inst, inst,
+ inst);
+}
+
+spv_result_t BuiltInsValidator::ValidateFragmentShaderF32Vec3InputAtReference(
+ const Decoration& decoration, const Instruction& built_in_inst,
+ const Instruction& referenced_inst,
+ const Instruction& referenced_from_inst) {
+
+ if (spvIsVulkanEnv(_.context()->target_env)) {
+ const SpvBuiltIn builtin = SpvBuiltIn(decoration.params()[0]);
+ const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst);
+ if (storage_class != SpvStorageClassMax &&
+ storage_class != SpvStorageClassInput) {
+ uint32_t vuid = GetVUIDForBuiltin(builtin, VUIDErrorStorageClass);
+ return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
+ << _.VkErrorID(vuid) << spvLogStringForEnv(_.context()->target_env)
+ << " spec allows BuiltIn "
+ << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, builtin)
+ << " to be only used for variables with Input storage class. "
+ << GetReferenceDesc(decoration, built_in_inst, referenced_inst,
+ referenced_from_inst)
+ << " " << GetStorageClassDesc(referenced_from_inst);
+ }
+
+ for (const SpvExecutionModel execution_model : execution_models_) {
+ if (execution_model != SpvExecutionModelFragment) {
+ uint32_t vuid = GetVUIDForBuiltin(builtin, VUIDErrorExecutionModel);
+ return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
+ << _.VkErrorID(vuid)
+ << spvLogStringForEnv(_.context()->target_env)
+ << " spec allows BuiltIn "
+ << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, builtin)
+ << " to be used only with Fragment execution model. "
+ << GetReferenceDesc(decoration, built_in_inst, referenced_inst,
+ referenced_from_inst, execution_model);
+ }
+ }
+ }
+
+ if (function_id_ == 0) {
+ // Propagate this rule to all dependent ids in the global scope.
+ id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
+ &BuiltInsValidator::ValidateFragmentShaderF32Vec3InputAtReference, this,
+ decoration, built_in_inst, referenced_from_inst,
+ std::placeholders::_1));
+ }
+
+ return SPV_SUCCESS;
+}
+
spv_result_t BuiltInsValidator::ValidateComputeShaderI32Vec3InputAtDefinition(
const Decoration& decoration, const Instruction& inst) {
if (spvIsVulkanEnv(_.context()->target_env)) {
@@ -3851,6 +3938,7 @@
case SpvBuiltInInstanceId:
case SpvBuiltInRayGeometryIndexKHR:
case SpvBuiltInIncomingRayFlagsKHR:
+ case SpvBuiltInCullMaskKHR:
// i32 scalar
if (spv_result_t error = ValidateI32(
decoration, inst,
@@ -4027,6 +4115,10 @@
case SpvBuiltInWorkgroupId: {
return ValidateComputeShaderI32Vec3InputAtDefinition(decoration, inst);
}
+ case SpvBuiltInBaryCoordKHR:
+ case SpvBuiltInBaryCoordNoPerspKHR: {
+ return ValidateFragmentShaderF32Vec3InputAtDefinition(decoration, inst);
+ }
case SpvBuiltInHelperInvocation: {
return ValidateHelperInvocationAtDefinition(decoration, inst);
}
@@ -4151,7 +4243,8 @@
case SpvBuiltInObjectToWorldKHR: // alias SpvBuiltInObjectToWorldNV
case SpvBuiltInWorldToObjectKHR: // alias SpvBuiltInWorldToObjectNV
case SpvBuiltInIncomingRayFlagsKHR: // alias SpvBuiltInIncomingRayFlagsNV
- case SpvBuiltInRayGeometryIndexKHR: { // NOT present in NV
+ case SpvBuiltInRayGeometryIndexKHR: // NOT present in NV
+ case SpvBuiltInCullMaskKHR: {
return ValidateRayTracingBuiltinsAtDefinition(decoration, inst);
}
case SpvBuiltInWorkDim:
@@ -4182,10 +4275,7 @@
case SpvBuiltInLayerPerViewNV:
case SpvBuiltInMeshViewCountNV:
case SpvBuiltInMeshViewIndicesNV:
- case SpvBuiltInBaryCoordNV:
- case SpvBuiltInBaryCoordNoPerspNV:
case SpvBuiltInCurrentRayTimeNV:
- case SpvBuiltInCullMaskKHR:
// No validation rules (for the moment).
break;
diff --git a/source/val/validate_interfaces.cpp b/source/val/validate_interfaces.cpp
index adf2e47..7f2d648 100644
--- a/source/val/validate_interfaces.cpp
+++ b/source/val/validate_interfaces.cpp
@@ -238,7 +238,7 @@
uint32_t index = 0;
bool has_patch = false;
bool has_per_task_nv = false;
- bool has_per_vertex_nv = false;
+ bool has_per_vertex_khr = false;
for (auto& dec : _.id_decorations(variable->id())) {
if (dec.dec_type() == SpvDecorationLocation) {
if (has_location && dec.params()[0] != location) {
@@ -272,8 +272,20 @@
has_patch = true;
} else if (dec.dec_type() == SpvDecorationPerTaskNV) {
has_per_task_nv = true;
- } else if (dec.dec_type() == SpvDecorationPerVertexNV) {
- has_per_vertex_nv = true;
+ } else if (dec.dec_type() == SpvDecorationPerVertexKHR) {
+ if (!is_fragment) {
+ return _.diag(SPV_ERROR_INVALID_DATA, variable)
+ << _.VkErrorID(6777)
+ << "PerVertexKHR can only be applied to Fragment Execution "
+ "Models";
+ }
+ if (type->opcode() != SpvOpTypeArray &&
+ type->opcode() != SpvOpTypeRuntimeArray) {
+ return _.diag(SPV_ERROR_INVALID_DATA, variable)
+ << _.VkErrorID(6778)
+ << "PerVertexKHR must be declared as arrays";
+ }
+ has_per_vertex_khr = true;
}
}
@@ -298,7 +310,7 @@
}
break;
case SpvExecutionModelFragment:
- if (!is_output && has_per_vertex_nv) {
+ if (!is_output && has_per_vertex_khr) {
is_arrayed = true;
}
break;
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 9aa6c63..d9422b2 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -1411,6 +1411,18 @@
// Clang format adds spaces between hyphens
// clang-format off
switch (id) {
+ case 4154:
+ return VUID_WRAP(VUID-BaryCoordKHR-BaryCoordKHR-04154);
+ case 4155:
+ return VUID_WRAP(VUID-BaryCoordKHR-BaryCoordKHR-04155);
+ case 4156:
+ return VUID_WRAP(VUID-BaryCoordKHR-BaryCoordKHR-04156);
+ case 4160:
+ return VUID_WRAP(VUID-BaryCoordNoPerspKHR-BaryCoordNoPerspKHR-04160);
+ case 4161:
+ return VUID_WRAP(VUID-BaryCoordNoPerspKHR-BaryCoordNoPerspKHR-04161);
+ case 4162:
+ return VUID_WRAP(VUID-BaryCoordNoPerspKHR-BaryCoordNoPerspKHR-04162);
case 4181:
return VUID_WRAP(VUID-BaseInstance-BaseInstance-04181);
case 4182:
@@ -1443,6 +1455,12 @@
return VUID_WRAP(VUID-CullDistance-CullDistance-04199);
case 4200:
return VUID_WRAP(VUID-CullDistance-CullDistance-04200);
+ case 6735:
+ return VUID_WRAP(VUID-CullMaskKHR-CullMaskKHR-06735); // Execution Model
+ case 6736:
+ return VUID_WRAP(VUID-CullMaskKHR-CullMaskKHR-06736); // input storage
+ case 6737:
+ return VUID_WRAP(VUID-CullMaskKHR-CullMaskKHR-06737); // 32 int scalar
case 4205:
return VUID_WRAP(VUID-DeviceIndex-DeviceIndex-04205);
case 4206:
@@ -1919,6 +1937,10 @@
return VUID_WRAP(VUID-StandaloneSpirv-UniformConstant-06677);
case 6678:
return VUID_WRAP(VUID-StandaloneSpirv-InputAttachmentIndex-06678);
+ case 6777:
+ return VUID_WRAP(VUID-StandaloneSpirv-PerVertexKHR-06777);
+ case 6778:
+ return VUID_WRAP(VUID-StandaloneSpirv-Input-06778);
default:
return ""; // unknown id
}
@@ -1926,4 +1948,4 @@
}
} // namespace val
-} // namespace spvtools
\ No newline at end of file
+} // namespace spvtools
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index 6dfb1b7..15966c1 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -45,6 +45,7 @@
eliminate_dead_input_components_test.cpp
eliminate_dead_member_test.cpp
feature_manager_test.cpp
+ fix_func_call_arguments_test.cpp
fix_storage_class_test.cpp
flatten_decoration_test.cpp
fold_spec_const_op_composite_test.cpp
@@ -61,6 +62,7 @@
inst_debug_printf_test.cpp
instruction_list_test.cpp
instruction_test.cpp
+ interface_var_sroa_test.cpp
interp_fixup_test.cpp
ir_builder.cpp
ir_context_test.cpp
@@ -84,7 +86,7 @@
reduce_load_size_test.cpp
redundancy_elimination_test.cpp
remove_dontinline_test.cpp
- remove_unused_interface_variables_test.cpp
+ remove_unused_interface_variables_test.cpp
register_liveness.cpp
relax_float_ops_test.cpp
replace_desc_array_access_using_var_index_test.cpp
@@ -96,7 +98,7 @@
spread_volatile_semantics_test.cpp
strength_reduction_test.cpp
strip_debug_info_test.cpp
- strip_nonsemantic_info_test.cpp
+ strip_nonsemantic_info_test.cpp
struct_cfg_analysis_test.cpp
type_manager_test.cpp
types_test.cpp
diff --git a/test/opt/fix_func_call_arguments_test.cpp b/test/opt/fix_func_call_arguments_test.cpp
new file mode 100644
index 0000000..ecd13a8
--- /dev/null
+++ b/test/opt/fix_func_call_arguments_test.cpp
@@ -0,0 +1,152 @@
+// Copyright (c) 2022 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using FixFuncCallArgumentsTest = PassTest<::testing::Test>;
+TEST_F(FixFuncCallArgumentsTest, Simple) {
+ const std::string text = R"(
+;
+; CHECK: [[v0:%\w+]] = OpVariable %_ptr_Function_float Function
+; CHECK: [[v1:%\w+]] = OpVariable %_ptr_Function_float Function
+; CHECK: [[v2:%\w+]] = OpVariable %_ptr_Function_T Function
+; CHECK: [[ac0:%\w+]] = OpAccessChain %_ptr_Function_float %t %int_0
+; CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_float %r1 %int_0 %uint_0
+; CHECK: [[ld0:%\w+]] = OpLoad %float [[ac0]]
+; CHECK: OpStore [[v1]] [[ld0]]
+; CHECK: [[ld1:%\w+]] = OpLoad %float [[ac1]]
+; CHECK: OpStore [[v0]] [[ld1]]
+; CHECK: [[func:%\w+]] = OpFunctionCall %void %fn [[v1]] [[v0]]
+; CHECK: [[ld2:%\w+]] = OpLoad %float [[v0]]
+; CHECK: OpStore [[ac1]] [[ld2]]
+; CHECK: [[ld3:%\w+]] = OpLoad %float [[v1]]
+; CHECK: OpStore [[ac0]] [[ld3]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpSource HLSL 630
+OpName %type_RWStructuredBuffer_float "type.RWStructuredBuffer.float"
+OpName %r1 "r1"
+OpName %type_ACSBuffer_counter "type.ACSBuffer.counter"
+OpMemberName %type_ACSBuffer_counter 0 "counter"
+OpName %counter_var_r1 "counter.var.r1"
+OpName %main "main"
+OpName %bb_entry "bb.entry"
+OpName %T "T"
+OpMemberName %T 0 "t0"
+OpName %t "t"
+OpName %fn "fn"
+OpName %p0 "p0"
+OpName %p2 "p2"
+OpName %bb_entry_0 "bb.entry"
+OpDecorate %main LinkageAttributes "main" Export
+OpDecorate %r1 DescriptorSet 0
+OpDecorate %r1 Binding 0
+OpDecorate %counter_var_r1 DescriptorSet 0
+OpDecorate %counter_var_r1 Binding 1
+OpDecorate %_runtimearr_float ArrayStride 4
+OpMemberDecorate %type_RWStructuredBuffer_float 0 Offset 0
+OpDecorate %type_RWStructuredBuffer_float BufferBlock
+OpMemberDecorate %type_ACSBuffer_counter 0 Offset 0
+OpDecorate %type_ACSBuffer_counter BufferBlock
+%int = OpTypeInt 32 1
+%int_0 = OpConstant %int 0
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%int_1 = OpConstant %int 1
+%float = OpTypeFloat 32
+%_runtimearr_float = OpTypeRuntimeArray %float
+%type_RWStructuredBuffer_float = OpTypeStruct %_runtimearr_float
+%_ptr_Uniform_type_RWStructuredBuffer_float = OpTypePointer Uniform %type_RWStructuredBuffer_float
+%type_ACSBuffer_counter = OpTypeStruct %int
+%_ptr_Uniform_type_ACSBuffer_counter = OpTypePointer Uniform %type_ACSBuffer_counter
+%15 = OpTypeFunction %int
+%T = OpTypeStruct %float
+%_ptr_Function_T = OpTypePointer Function %T
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Uniform_float = OpTypePointer Uniform %float
+%void = OpTypeVoid
+%27 = OpTypeFunction %void %_ptr_Function_float %_ptr_Function_float
+%r1 = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_float Uniform
+%counter_var_r1 = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
+%main = OpFunction %int None %15
+%bb_entry = OpLabel
+%t = OpVariable %_ptr_Function_T Function
+%21 = OpAccessChain %_ptr_Function_float %t %int_0
+%23 = OpAccessChain %_ptr_Uniform_float %r1 %int_0 %uint_0
+%25 = OpFunctionCall %void %fn %21 %23
+OpReturnValue %int_1
+OpFunctionEnd
+%fn = OpFunction %void DontInline %27
+%p0 = OpFunctionParameter %_ptr_Function_float
+%p2 = OpFunctionParameter %_ptr_Function_float
+%bb_entry_0 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<FixFuncCallArgumentsPass>(text, true);
+}
+
+TEST_F(FixFuncCallArgumentsTest, NotAccessChainInput) {
+ const std::string text = R"(
+;
+; CHECK: [[o:%\w+]] = OpCopyObject %_ptr_Function_float %t
+; CHECK: [[func:%\w+]] = OpFunctionCall %void %fn [[o]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpSource HLSL 630
+OpName %main "main"
+OpName %bb_entry "bb.entry"
+OpName %t "t"
+OpName %fn "fn"
+OpName %p0 "p0"
+OpName %bb_entry_0 "bb.entry"
+OpDecorate %main LinkageAttributes "main" Export
+%int = OpTypeInt 32 1
+%int_1 = OpConstant %int 1
+%4 = OpTypeFunction %int
+%float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%void = OpTypeVoid
+%12 = OpTypeFunction %void %_ptr_Function_float
+%main = OpFunction %int None %4
+%bb_entry = OpLabel
+%t = OpVariable %_ptr_Function_float Function
+%t1 = OpCopyObject %_ptr_Function_float %t
+%10 = OpFunctionCall %void %fn %t1
+OpReturnValue %int_1
+OpFunctionEnd
+%fn = OpFunction %void DontInline %12
+%p0 = OpFunctionParameter %_ptr_Function_float
+%bb_entry_0 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<FixFuncCallArgumentsPass>(text, false);
+}
+
+} // namespace
+} // namespace opt
+} // namespace spvtools
\ No newline at end of file
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index 2ca3256..e2240b8 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -147,6 +147,7 @@
%v2double = OpTypeVector %double 2
%v2half = OpTypeVector %half 2
%v2bool = OpTypeVector %bool 2
+%m2x2int = OpTypeMatrix %v2int 2
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
%_ptr_int = OpTypePointer Function %int
%_ptr_uint = OpTypePointer Function %uint
@@ -218,7 +219,9 @@
%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
%v2int_null = OpConstantNull %v2int
%102 = OpConstantComposite %v2int %103 %103
+%v4int_undef = OpUndef %v4int
%v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
+%m2x2int_undef = OpUndef %m2x2int
%struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0
%float_n1 = OpConstant %float -1
%104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps.
@@ -6862,7 +6865,7 @@
4, true)
));
-INSTANTIATE_TEST_SUITE_P(CompositeExtractMatchingTest, MatchingInstructionFoldingTest,
+INSTANTIATE_TEST_SUITE_P(CompositeExtractOrInsertMatchingTest, MatchingInstructionFoldingTest,
::testing::Values(
// Test case 0: Extracting from result of consecutive shuffles of differing
// size.
@@ -7002,7 +7005,145 @@
"%4 = OpCompositeExtract %int %3 1\n" +
"OpReturn\n" +
"OpFunctionEnd",
- 4, true)
+ 4, true),
+ // Test case 8: Inserting every element of a vector turns into a composite construct.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK-DAG: [[v4:%\\w+]] = OpTypeVector [[int]] 4\n" +
+ "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+ "; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
+ "; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v4]] %100 [[int1]] [[int2]] [[int3]]\n" +
+ "; CHECK: %5 = OpCopyObject [[v4]] [[construct]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeInsert %v4int %100 %v4int_undef 0\n" +
+ "%3 = OpCompositeInsert %v4int %int_1 %2 1\n" +
+ "%4 = OpCompositeInsert %v4int %int_2 %3 2\n" +
+ "%5 = OpCompositeInsert %v4int %int_3 %4 3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 5, true),
+ // Test case 9: Inserting every element of a vector turns into a composite construct in a different order.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK-DAG: [[v4:%\\w+]] = OpTypeVector [[int]] 4\n" +
+ "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+ "; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
+ "; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v4]] %100 [[int1]] [[int2]] [[int3]]\n" +
+ "; CHECK: %5 = OpCopyObject [[v4]] [[construct]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeInsert %v4int %100 %v4int_undef 0\n" +
+ "%4 = OpCompositeInsert %v4int %int_2 %2 2\n" +
+ "%3 = OpCompositeInsert %v4int %int_1 %4 1\n" +
+ "%5 = OpCompositeInsert %v4int %int_3 %3 3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 5, true),
+ // Test case 10: Check multiple inserts to the same position are handled correctly.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK-DAG: [[v4:%\\w+]] = OpTypeVector [[int]] 4\n" +
+ "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+ "; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
+ "; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v4]] %100 [[int1]] [[int2]] [[int3]]\n" +
+ "; CHECK: %6 = OpCopyObject [[v4]] [[construct]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeInsert %v4int %100 %v4int_undef 0\n" +
+ "%3 = OpCompositeInsert %v4int %int_2 %2 2\n" +
+ "%4 = OpCompositeInsert %v4int %int_4 %3 1\n" +
+ "%5 = OpCompositeInsert %v4int %int_1 %4 1\n" +
+ "%6 = OpCompositeInsert %v4int %int_3 %5 3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 6, true),
+ // Test case 11: The last indexes are 0 and 1, but they have different first indexes. This should not be folded.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeInsert %m2x2int %100 %m2x2int_undef 0 0\n" +
+ "%3 = OpCompositeInsert %m2x2int %int_1 %2 1 1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 3, false),
+ // Test case 12: Don't fold when there is a partial insertion.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0\n" +
+ "%3 = OpCompositeInsert %m2x2int %int_4 %2 0 0\n" +
+ "%4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, false),
+ // Test case 13: Insert into a column of a matrix
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK-DAG: [[v2:%\\w+]] = OpTypeVector [[int]] 2\n" +
+ "; CHECK: [[m2x2:%\\w+]] = OpTypeMatrix [[v2]] 2\n" +
+ "; CHECK-DAG: [[m2x2_undef:%\\w+]] = OpUndef [[m2x2]]\n" +
+ "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+// We keep this insert in the chain. DeadInsertElimPass should remove it.
+ "; CHECK: [[insert:%\\w+]] = OpCompositeInsert [[m2x2]] %100 [[m2x2_undef]] 0 0\n" +
+ "; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v2]] %100 [[int1]]\n" +
+ "; CHECK: %3 = OpCompositeInsert [[m2x2]] [[construct]] [[insert]] 0\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeInsert %m2x2int %100 %m2x2int_undef 0 0\n" +
+ "%3 = OpCompositeInsert %m2x2int %int_1 %2 0 1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 3, true),
+ // Test case 14: Insert all elements of the matrix.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK-DAG: [[v2:%\\w+]] = OpTypeVector [[int]] 2\n" +
+ "; CHECK: [[m2x2:%\\w+]] = OpTypeMatrix [[v2]] 2\n" +
+ "; CHECK-DAG: [[m2x2_undef:%\\w+]] = OpUndef [[m2x2]]\n" +
+ "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+ "; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
+ "; CHECK: [[c0:%\\w+]] = OpCompositeConstruct [[v2]] %100 [[int1]]\n" +
+ "; CHECK: [[c1:%\\w+]] = OpCompositeConstruct [[v2]] [[int2]] [[int3]]\n" +
+ "; CHECK: [[matrix:%\\w+]] = OpCompositeConstruct [[m2x2]] [[c0]] [[c1]]\n" +
+ "; CHECK: %5 = OpCopyObject [[m2x2]] [[matrix]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeConstruct %v2int %100 %int_1\n" +
+ "%3 = OpCompositeInsert %m2x2int %2 %m2x2int_undef 0\n" +
+ "%4 = OpCompositeInsert %m2x2int %int_2 %3 1 0\n" +
+ "%5 = OpCompositeInsert %m2x2int %int_3 %4 1 1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 5, true),
+ // Test case 15: Replace construct with extract when reconstructing a member
+ // of another object.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[v2:%\\w+]] = OpTypeVector [[int]] 2\n" +
+ "; CHECK: [[m2x2:%\\w+]] = OpTypeMatrix [[v2]] 2\n" +
+ "; CHECK: [[m2x2_undef:%\\w+]] = OpUndef [[m2x2]]\n" +
+ "; CHECK: %5 = OpCompositeExtract [[v2]] [[m2x2_undef]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%3 = OpCompositeExtract %int %m2x2int_undef 1 0\n" +
+ "%4 = OpCompositeExtract %int %m2x2int_undef 1 1\n" +
+ "%5 = OpCompositeConstruct %v2int %3 %4\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 5, true)
));
INSTANTIATE_TEST_SUITE_P(DotProductMatchingTest, MatchingInstructionFoldingTest,
diff --git a/test/opt/interface_var_sroa_test.cpp b/test/opt/interface_var_sroa_test.cpp
new file mode 100644
index 0000000..7762458
--- /dev/null
+++ b/test/opt/interface_var_sroa_test.cpp
@@ -0,0 +1,410 @@
+// Copyright (c) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+
+#include "gmock/gmock.h"
+#include "test/opt/assembly_builder.h"
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using InterfaceVariableScalarReplacementTest = PassTest<::testing::Test>;
+
+TEST_F(InterfaceVariableScalarReplacementTest,
+ ReplaceInterfaceVarsWithScalars) {
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability Tessellation
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint TessellationControl %func "shader" %x %y %z %w %u %v
+
+; CHECK: OpName [[x:%\w+]] "x"
+; CHECK-NOT: OpName {{%\w+}} "x"
+; CHECK: OpName [[y:%\w+]] "y"
+; CHECK-NOT: OpName {{%\w+}} "y"
+; CHECK: OpName [[z0:%\w+]] "z"
+; CHECK: OpName [[z1:%\w+]] "z"
+; CHECK: OpName [[w0:%\w+]] "w"
+; CHECK: OpName [[w1:%\w+]] "w"
+; CHECK: OpName [[u0:%\w+]] "u"
+; CHECK: OpName [[u1:%\w+]] "u"
+; CHECK: OpName [[v0:%\w+]] "v"
+; CHECK: OpName [[v1:%\w+]] "v"
+; CHECK: OpName [[v2:%\w+]] "v"
+; CHECK: OpName [[v3:%\w+]] "v"
+; CHECK: OpName [[v4:%\w+]] "v"
+; CHECK: OpName [[v5:%\w+]] "v"
+ OpName %x "x"
+ OpName %y "y"
+ OpName %z "z"
+ OpName %w "w"
+ OpName %u "u"
+ OpName %v "v"
+
+; CHECK-DAG: OpDecorate [[x]] Location 2
+; CHECK-DAG: OpDecorate [[y]] Location 0
+; CHECK-DAG: OpDecorate [[z0]] Location 0
+; CHECK-DAG: OpDecorate [[z0]] Component 0
+; CHECK-DAG: OpDecorate [[z1]] Location 1
+; CHECK-DAG: OpDecorate [[z1]] Component 0
+; CHECK-DAG: OpDecorate [[z0]] Patch
+; CHECK-DAG: OpDecorate [[z1]] Patch
+; CHECK-DAG: OpDecorate [[w0]] Location 2
+; CHECK-DAG: OpDecorate [[w0]] Component 0
+; CHECK-DAG: OpDecorate [[w1]] Location 3
+; CHECK-DAG: OpDecorate [[w1]] Component 0
+; CHECK-DAG: OpDecorate [[w0]] Patch
+; CHECK-DAG: OpDecorate [[w1]] Patch
+; CHECK-DAG: OpDecorate [[u0]] Location 3
+; CHECK-DAG: OpDecorate [[u0]] Component 2
+; CHECK-DAG: OpDecorate [[u1]] Location 4
+; CHECK-DAG: OpDecorate [[u1]] Component 2
+; CHECK-DAG: OpDecorate [[v0]] Location 3
+; CHECK-DAG: OpDecorate [[v0]] Component 3
+; CHECK-DAG: OpDecorate [[v1]] Location 4
+; CHECK-DAG: OpDecorate [[v1]] Component 3
+; CHECK-DAG: OpDecorate [[v2]] Location 5
+; CHECK-DAG: OpDecorate [[v2]] Component 3
+; CHECK-DAG: OpDecorate [[v3]] Location 6
+; CHECK-DAG: OpDecorate [[v3]] Component 3
+; CHECK-DAG: OpDecorate [[v4]] Location 7
+; CHECK-DAG: OpDecorate [[v4]] Component 3
+; CHECK-DAG: OpDecorate [[v5]] Location 8
+; CHECK-DAG: OpDecorate [[v5]] Component 3
+ OpDecorate %z Patch
+ OpDecorate %w Patch
+ OpDecorate %z Location 0
+ OpDecorate %x Location 2
+ OpDecorate %v Location 3
+ OpDecorate %v Component 3
+ OpDecorate %y Location 0
+ OpDecorate %w Location 2
+ OpDecorate %u Location 3
+ OpDecorate %u Component 2
+
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_2 = OpConstant %uint 2
+ %uint_3 = OpConstant %uint 3
+ %uint_4 = OpConstant %uint 4
+%_arr_uint_uint_2 = OpTypeArray %uint %uint_2
+%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2
+%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2
+%_ptr_Input_uint = OpTypePointer Input %uint
+%_ptr_Output_uint = OpTypePointer Output %uint
+%_arr_arr_uint_uint_2_3 = OpTypeArray %_arr_uint_uint_2 %uint_3
+%_ptr_Input__arr_arr_uint_uint_2_3 = OpTypePointer Input %_arr_arr_uint_uint_2_3
+%_arr_arr_arr_uint_uint_2_3_4 = OpTypeArray %_arr_arr_uint_uint_2_3 %uint_4
+%_ptr_Output__arr_arr_arr_uint_uint_2_3_4 = OpTypePointer Output %_arr_arr_arr_uint_uint_2_3_4
+%_ptr_Output__arr_arr_uint_uint_2_3 = OpTypePointer Output %_arr_arr_uint_uint_2_3
+ %z = OpVariable %_ptr_Output__arr_uint_uint_2 Output
+ %x = OpVariable %_ptr_Output__arr_uint_uint_2 Output
+ %y = OpVariable %_ptr_Input__arr_uint_uint_2 Input
+ %w = OpVariable %_ptr_Input__arr_uint_uint_2 Input
+ %u = OpVariable %_ptr_Input__arr_arr_uint_uint_2_3 Input
+ %v = OpVariable %_ptr_Output__arr_arr_arr_uint_uint_2_3_4 Output
+
+; CHECK-DAG: [[x]] = OpVariable %_ptr_Output__arr_uint_uint_2 Output
+; CHECK-DAG: [[y]] = OpVariable %_ptr_Input__arr_uint_uint_2 Input
+; CHECK-DAG: [[z0]] = OpVariable %_ptr_Output_uint Output
+; CHECK-DAG: [[z1]] = OpVariable %_ptr_Output_uint Output
+; CHECK-DAG: [[w0]] = OpVariable %_ptr_Input_uint Input
+; CHECK-DAG: [[w1]] = OpVariable %_ptr_Input_uint Input
+; CHECK-DAG: [[u0]] = OpVariable %_ptr_Input__arr_uint_uint_3 Input
+; CHECK-DAG: [[u1]] = OpVariable %_ptr_Input__arr_uint_uint_3 Input
+; CHECK-DAG: [[v0]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
+; CHECK-DAG: [[v1]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
+; CHECK-DAG: [[v2]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
+; CHECK-DAG: [[v3]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
+; CHECK-DAG: [[v4]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
+; CHECK-DAG: [[v5]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
+
+ %void = OpTypeVoid
+ %void_f = OpTypeFunction %void
+ %func = OpFunction %void None %void_f
+ %label = OpLabel
+
+; CHECK: [[w0_value:%\w+]] = OpLoad %uint [[w0]]
+; CHECK: [[w1_value:%\w+]] = OpLoad %uint [[w1]]
+; CHECK: [[w_value:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[w0_value]] [[w1_value]]
+; CHECK: [[w0:%\w+]] = OpCompositeExtract %uint [[w_value]] 0
+; CHECK: OpStore [[z0]] [[w0]]
+; CHECK: [[w1:%\w+]] = OpCompositeExtract %uint [[w_value]] 1
+; CHECK: OpStore [[z1]] [[w1]]
+ %w_value = OpLoad %_arr_uint_uint_2 %w
+ OpStore %z %w_value
+
+; CHECK: [[u00_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u0]] %uint_0
+; CHECK: [[u00:%\w+]] = OpLoad %uint [[u00_ptr]]
+; CHECK: [[u10_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u1]] %uint_0
+; CHECK: [[u10:%\w+]] = OpLoad %uint [[u10_ptr]]
+; CHECK: [[u01_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u0]] %uint_1
+; CHECK: [[u01:%\w+]] = OpLoad %uint [[u01_ptr]]
+; CHECK: [[u11_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u1]] %uint_1
+; CHECK: [[u11:%\w+]] = OpLoad %uint [[u11_ptr]]
+; CHECK: [[u02_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u0]] %uint_2
+; CHECK: [[u02:%\w+]] = OpLoad %uint [[u02_ptr]]
+; CHECK: [[u12_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u1]] %uint_2
+; CHECK: [[u12:%\w+]] = OpLoad %uint [[u12_ptr]]
+
+; CHECK-DAG: [[u0_val:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[u00]] [[u10]]
+; CHECK-DAG: [[u1_val:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[u01]] [[u11]]
+; CHECK-DAG: [[u2_val:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[u02]] [[u12]]
+
+; CHECK: [[u_val:%\w+]] = OpCompositeConstruct %_arr__arr_uint_uint_2_uint_3 [[u0_val]] [[u1_val]] [[u2_val]]
+
+; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v0]] %uint_1
+; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 0 0
+; CHECK: OpStore [[ptr]] [[val]]
+; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v1]] %uint_1
+; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 0 1
+; CHECK: OpStore [[ptr]] [[val]]
+; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v2]] %uint_1
+; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 1 0
+; CHECK: OpStore [[ptr]] [[val]]
+; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v3]] %uint_1
+; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 1 1
+; CHECK: OpStore [[ptr]] [[val]]
+; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v4]] %uint_1
+; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 2 0
+; CHECK: OpStore [[ptr]] [[val]]
+; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v5]] %uint_1
+; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 2 1
+; CHECK: OpStore [[ptr]] [[val]]
+ %v_ptr = OpAccessChain %_ptr_Output__arr_arr_uint_uint_2_3 %v %uint_1
+ %u_val = OpLoad %_arr_arr_uint_uint_2_3 %u
+ OpStore %v_ptr %u_val
+
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<InterfaceVariableScalarReplacement>(spirv, true);
+}
+
+TEST_F(InterfaceVariableScalarReplacementTest,
+ CheckPatchDecorationPreservation) {
+ // Make sure scalars for the variables with the extra arrayness have the extra
+ // arrayness after running the pass while others do not have it.
+ // Only "y" has the extra arrayness in the following SPIR-V.
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability Tessellation
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint TessellationEvaluation %func "shader" %x %y %z %w
+ OpDecorate %z Patch
+ OpDecorate %w Patch
+ OpDecorate %z Location 0
+ OpDecorate %x Location 2
+ OpDecorate %y Location 0
+ OpDecorate %w Location 1
+ OpName %x "x"
+ OpName %y "y"
+ OpName %z "z"
+ OpName %w "w"
+
+ ; CHECK: OpName [[y:%\w+]] "y"
+ ; CHECK-NOT: OpName {{%\w+}} "y"
+ ; CHECK-DAG: OpName [[z0:%\w+]] "z"
+ ; CHECK-DAG: OpName [[z1:%\w+]] "z"
+ ; CHECK-DAG: OpName [[w0:%\w+]] "w"
+ ; CHECK-DAG: OpName [[w1:%\w+]] "w"
+ ; CHECK-DAG: OpName [[x0:%\w+]] "x"
+ ; CHECK-DAG: OpName [[x1:%\w+]] "x"
+
+ %uint = OpTypeInt 32 0
+ %uint_2 = OpConstant %uint 2
+%_arr_uint_uint_2 = OpTypeArray %uint %uint_2
+%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2
+%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2
+ %z = OpVariable %_ptr_Output__arr_uint_uint_2 Output
+ %x = OpVariable %_ptr_Output__arr_uint_uint_2 Output
+ %y = OpVariable %_ptr_Input__arr_uint_uint_2 Input
+ %w = OpVariable %_ptr_Input__arr_uint_uint_2 Input
+
+ ; CHECK-DAG: [[y]] = OpVariable %_ptr_Input__arr_uint_uint_2 Input
+ ; CHECK-DAG: [[z0]] = OpVariable %_ptr_Output_uint Output
+ ; CHECK-DAG: [[z1]] = OpVariable %_ptr_Output_uint Output
+ ; CHECK-DAG: [[w0]] = OpVariable %_ptr_Input_uint Input
+ ; CHECK-DAG: [[w1]] = OpVariable %_ptr_Input_uint Input
+ ; CHECK-DAG: [[x0]] = OpVariable %_ptr_Output_uint Output
+ ; CHECK-DAG: [[x1]] = OpVariable %_ptr_Output_uint Output
+
+ %void = OpTypeVoid
+ %void_f = OpTypeFunction %void
+ %func = OpFunction %void None %void_f
+ %label = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<InterfaceVariableScalarReplacement>(spirv, true);
+}
+
+TEST_F(InterfaceVariableScalarReplacementTest,
+ CheckEntryPointInterfaceOperands) {
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability Tessellation
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint TessellationEvaluation %tess "tess" %x %y
+ OpEntryPoint Vertex %vert "vert" %w
+ OpDecorate %z Location 0
+ OpDecorate %x Location 2
+ OpDecorate %y Location 0
+ OpDecorate %w Location 1
+ OpName %x "x"
+ OpName %y "y"
+ OpName %z "z"
+ OpName %w "w"
+
+ ; CHECK: OpName [[y:%\w+]] "y"
+ ; CHECK-NOT: OpName {{%\w+}} "y"
+ ; CHECK-DAG: OpName [[x0:%\w+]] "x"
+ ; CHECK-DAG: OpName [[x1:%\w+]] "x"
+ ; CHECK-DAG: OpName [[w0:%\w+]] "w"
+ ; CHECK-DAG: OpName [[w1:%\w+]] "w"
+ ; CHECK-DAG: OpName [[z:%\w+]] "z"
+ ; CHECK-NOT: OpName {{%\w+}} "z"
+
+ %uint = OpTypeInt 32 0
+ %uint_2 = OpConstant %uint 2
+%_arr_uint_uint_2 = OpTypeArray %uint %uint_2
+%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2
+%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2
+ %z = OpVariable %_ptr_Output__arr_uint_uint_2 Output
+ %x = OpVariable %_ptr_Output__arr_uint_uint_2 Output
+ %y = OpVariable %_ptr_Input__arr_uint_uint_2 Input
+ %w = OpVariable %_ptr_Input__arr_uint_uint_2 Input
+
+ ; CHECK-DAG: [[y]] = OpVariable %_ptr_Input__arr_uint_uint_2 Input
+ ; CHECK-DAG: [[z]] = OpVariable %_ptr_Output__arr_uint_uint_2 Output
+ ; CHECK-DAG: [[w0]] = OpVariable %_ptr_Input_uint Input
+ ; CHECK-DAG: [[w1]] = OpVariable %_ptr_Input_uint Input
+ ; CHECK-DAG: [[x0]] = OpVariable %_ptr_Output_uint Output
+ ; CHECK-DAG: [[x1]] = OpVariable %_ptr_Output_uint Output
+
+ %void = OpTypeVoid
+ %void_f = OpTypeFunction %void
+ %tess = OpFunction %void None %void_f
+ %bb0 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %vert = OpFunction %void None %void_f
+ %bb1 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<InterfaceVariableScalarReplacement>(spirv, true);
+}
+
+class InterfaceVarSROAErrorTest : public PassTest<::testing::Test> {
+ public:
+ InterfaceVarSROAErrorTest()
+ : consumer_([this](spv_message_level_t level, const char*,
+ const spv_position_t& position, const char* message) {
+ if (!error_message_.empty()) error_message_ += "\n";
+ switch (level) {
+ case SPV_MSG_FATAL:
+ case SPV_MSG_INTERNAL_ERROR:
+ case SPV_MSG_ERROR:
+ error_message_ += "ERROR";
+ break;
+ case SPV_MSG_WARNING:
+ error_message_ += "WARNING";
+ break;
+ case SPV_MSG_INFO:
+ error_message_ += "INFO";
+ break;
+ case SPV_MSG_DEBUG:
+ error_message_ += "DEBUG";
+ break;
+ }
+ error_message_ +=
+ ": " + std::to_string(position.index) + ": " + message;
+ }) {}
+
+ Pass::Status RunPass(const std::string& text) {
+ std::unique_ptr<IRContext> context_ =
+ spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_2, consumer_, text);
+ if (!context_.get()) return Pass::Status::Failure;
+
+ PassManager manager;
+ manager.SetMessageConsumer(consumer_);
+ manager.AddPass<InterfaceVariableScalarReplacement>();
+
+ return manager.Run(context_.get());
+ }
+
+ std::string GetErrorMessage() const { return error_message_; }
+
+ void TearDown() override { error_message_.clear(); }
+
+ private:
+ spvtools::MessageConsumer consumer_;
+ std::string error_message_;
+};
+
+TEST_F(InterfaceVarSROAErrorTest, CheckConflictOfExtraArraynessBetweenEntries) {
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability Tessellation
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint TessellationControl %tess "tess" %x %y %z
+ OpEntryPoint Vertex %vert "vert" %z %w
+ OpDecorate %z Location 0
+ OpDecorate %x Location 2
+ OpDecorate %y Location 0
+ OpDecorate %w Location 1
+ OpName %x "x"
+ OpName %y "y"
+ OpName %z "z"
+ OpName %w "w"
+ %uint = OpTypeInt 32 0
+ %uint_2 = OpConstant %uint 2
+%_arr_uint_uint_2 = OpTypeArray %uint %uint_2
+%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2
+%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2
+ %z = OpVariable %_ptr_Output__arr_uint_uint_2 Output
+ %x = OpVariable %_ptr_Output__arr_uint_uint_2 Output
+ %y = OpVariable %_ptr_Input__arr_uint_uint_2 Input
+ %w = OpVariable %_ptr_Input__arr_uint_uint_2 Input
+ %void = OpTypeVoid
+ %void_f = OpTypeFunction %void
+ %tess = OpFunction %void None %void_f
+ %bb0 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %vert = OpFunction %void None %void_f
+ %bb1 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ EXPECT_EQ(RunPass(spirv), Pass::Status::Failure);
+ const char expected_error[] =
+ "ERROR: 0: A variable is arrayed for an entry point but it is not "
+ "arrayed for another entry point\n"
+ " %z = OpVariable %_ptr_Output__arr_uint_uint_2 Output";
+ EXPECT_STREQ(GetErrorMessage().c_str(), expected_error);
+}
+
+} // namespace
+} // namespace opt
+} // namespace spvtools
diff --git a/test/opt/local_access_chain_convert_test.cpp b/test/opt/local_access_chain_convert_test.cpp
index 6fcf23f..2b3231c 100644
--- a/test/opt/local_access_chain_convert_test.cpp
+++ b/test/opt/local_access_chain_convert_test.cpp
@@ -1156,6 +1156,101 @@
SinglePassRunAndMatch<LocalAccessChainConvertPass>(before, true);
}
+TEST_F(LocalAccessChainConvertTest, AccessChainWithLongIndex) {
+ // The access chain takes a value that is larger than 32-bit. The index cannot
+ // be encoded in an OpCompositeExtract, so nothing should be done.
+ const std::string before =
+ R"(OpCapability Shader
+OpCapability Int64
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main_0004f4d4_85b2f584"
+OpExecutionMode %2 OriginUpperLeft
+%ulong = OpTypeInt 64 0
+%ulong_8589934592 = OpConstant %ulong 8589934592
+%ulong_8589934591 = OpConstant %ulong 8589934591
+%_arr_ulong_ulong_8589934592 = OpTypeArray %ulong %ulong_8589934592
+%_ptr_Function__arr_ulong_ulong_8589934592 = OpTypePointer Function %_arr_ulong_ulong_8589934592
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%void = OpTypeVoid
+%10 = OpTypeFunction %void
+%2 = OpFunction %void None %10
+%11 = OpLabel
+%12 = OpVariable %_ptr_Function__arr_ulong_ulong_8589934592 Function
+%13 = OpAccessChain %_ptr_Function_ulong %12 %ulong_8589934591
+%14 = OpLoad %ulong %13
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<LocalAccessChainConvertPass>(before, before, false,
+ true);
+}
+
+TEST_F(LocalAccessChainConvertTest, AccessChainWith32BitIndexInLong) {
+ // The access chain has a value that is 32-bits, but it is stored in a 64-bit
+ // variable. This access chain can be converted to an extract.
+ const std::string before =
+ R"(
+; CHECK: OpFunction
+; CHECK: [[var:%\w+]] = OpVariable
+; CHECK: [[ld:%\w+]] = OpLoad {{%\w+}} [[var]]
+; CHECK: OpCompositeExtract %ulong [[ld]] 3
+ OpCapability Shader
+ OpCapability Int64
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main_0004f4d4_85b2f584"
+ OpExecutionMode %2 OriginUpperLeft
+ %ulong = OpTypeInt 64 0
+%ulong_8589934592 = OpConstant %ulong 8589934592
+%ulong_3 = OpConstant %ulong 3
+%_arr_ulong_ulong_8589934592 = OpTypeArray %ulong %ulong_8589934592
+%_ptr_Function__arr_ulong_ulong_8589934592 = OpTypePointer Function %_arr_ulong_ulong_8589934592
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %void = OpTypeVoid
+ %10 = OpTypeFunction %void
+ %2 = OpFunction %void None %10
+ %11 = OpLabel
+ %12 = OpVariable %_ptr_Function__arr_ulong_ulong_8589934592 Function
+ %13 = OpAccessChain %_ptr_Function_ulong %12 %ulong_3
+ %14 = OpLoad %ulong %13
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<LocalAccessChainConvertPass>(before, true);
+}
+
+TEST_F(LocalAccessChainConvertTest, AccessChainWithVarIndex) {
+ // The access chain has a value that is not constant, so there should not be
+ // any changes.
+ const std::string before =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main_0004f4d4_85b2f584"
+OpExecutionMode %2 OriginUpperLeft
+%uint = OpTypeInt 32 0
+%uint_5 = OpConstant %uint 5
+%_arr_uint_uint_5 = OpTypeArray %uint %uint_5
+%_ptr_Function__arr_uint_uint_5 = OpTypePointer Function %_arr_uint_uint_5
+%_ptr_Function_uint = OpTypePointer Function %uint
+%8 = OpUndef %uint
+%void = OpTypeVoid
+%10 = OpTypeFunction %void
+%2 = OpFunction %void None %10
+%11 = OpLabel
+%12 = OpVariable %_ptr_Function__arr_uint_uint_5 Function
+%13 = OpAccessChain %_ptr_Function_uint %12 %8
+%14 = OpLoad %uint %13
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<LocalAccessChainConvertPass>(before, before, false,
+ true);
+}
// TODO(greg-lunarg): Add tests to verify handling of these cases:
//
diff --git a/test/opt/replace_desc_array_access_using_var_index_test.cpp b/test/opt/replace_desc_array_access_using_var_index_test.cpp
index 9900304..9ab9eb1 100644
--- a/test/opt/replace_desc_array_access_using_var_index_test.cpp
+++ b/test/opt/replace_desc_array_access_using_var_index_test.cpp
@@ -406,6 +406,83 @@
SinglePassRunAndMatch<ReplaceDescArrayAccessUsingVarIndex>(text, true);
}
+TEST_F(ReplaceDescArrayAccessUsingVarIndexTest, ReplaceMultipleAccessChains) {
+ const std::string text = R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "TestFragment" %2
+ OpExecutionMode %1 OriginUpperLeft
+ OpName %11 "type.ConstantBuffer.TestStruct"
+ OpMemberName %11 0 "val1"
+ OpMemberName %11 1 "val2"
+ OpName %3 "TestResources"
+ OpName %13 "type.2d.image"
+ OpName %4 "OutBuffer"
+ OpName %2 "in.var.SV_INSTANCEID"
+ OpName %1 "TestFragment"
+ OpDecorate %2 Flat
+ OpDecorate %2 Location 0
+ OpDecorate %3 DescriptorSet 0
+ OpDecorate %3 Binding 0
+ OpDecorate %4 DescriptorSet 0
+ OpDecorate %4 Binding 1
+ OpMemberDecorate %11 0 Offset 0
+ OpMemberDecorate %11 1 Offset 4
+ OpDecorate %11 Block
+ %9 = OpTypeInt 32 0
+ %10 = OpConstant %9 2
+ %11 = OpTypeStruct %9 %9
+ %8 = OpTypeArray %11 %10
+ %7 = OpTypePointer Uniform %8
+ %13 = OpTypeImage %9 2D 2 0 0 2 R32ui
+ %12 = OpTypePointer UniformConstant %13
+ %14 = OpTypePointer Input %9
+ %15 = OpTypeVoid
+ %16 = OpTypeFunction %15
+ %40 = OpTypeVector %9 2
+ %3 = OpVariable %7 Uniform
+ %4 = OpVariable %12 UniformConstant
+ %2 = OpVariable %14 Input
+ %57 = OpTypePointer Uniform %11
+ %61 = OpTypePointer Uniform %9
+ %62 = OpConstant %9 0
+ %1 = OpFunction %15 None %16
+ %17 = OpLabel
+ %20 = OpLoad %9 %2
+ %47 = OpAccessChain %57 %3 %20
+ %63 = OpAccessChain %61 %47 %62
+ %64 = OpLoad %9 %63
+
+; CHECK: [[null_value:%\w+]] = OpConstantNull %uint
+
+; CHECK: [[var_index:%\w+]] = OpLoad %uint %in_var_SV_INSTANCEID
+; CHECK: OpSelectionMerge [[merge:%\w+]] None
+; CHECK: OpSwitch [[var_index]] [[default:%\w+]] 0 [[case0:%\w+]] 1 [[case1:%\w+]]
+; CHECK: [[case0]] = OpLabel
+; CHECK: OpAccessChain
+; CHECK: OpAccessChain
+; CHECK: [[result0:%\w+]] = OpLoad
+; CHECK: OpBranch [[merge]]
+; CHECK: [[case1]] = OpLabel
+; CHECK: OpAccessChain
+; CHECK: OpAccessChain
+; CHECK: [[result1:%\w+]] = OpLoad
+; CHECK: OpBranch [[merge]]
+; CHECK: [[default]] = OpLabel
+; CHECK: OpBranch [[merge]]
+; CHECK: [[merge]] = OpLabel
+; CHECK: OpPhi %uint [[result0]] [[case0]] [[result1]] [[case1]] [[null_value]] [[default]]
+
+ %55 = OpCompositeConstruct %40 %20 %20
+ %56 = OpLoad %13 %4
+ OpImageWrite %56 %55 %64 None
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<ReplaceDescArrayAccessUsingVarIndex>(text, true);
+}
+
TEST_F(ReplaceDescArrayAccessUsingVarIndexTest,
ReplaceAccessChainToTextureArrayWithNonUniformIndex) {
const std::string text = R"(
diff --git a/test/val/val_builtins_test.cpp b/test/val/val_builtins_test.cpp
index d749c5a..b76c163 100644
--- a/test/val/val_builtins_test.cpp
+++ b/test/val/val_builtins_test.cpp
@@ -2372,6 +2372,67 @@
"needs to be a 32-bit int scalar",
"is not an int scalar"))));
+// CullMaskKHR is valid
+// in IS, AH, CH, MS shaders as an input i32 scalar
+INSTANTIATE_TEST_SUITE_P(
+ CullMaskSuccess,
+ ValidateGenericCombineBuiltInExecutionModelDataTypeCapabilityExtensionResult,
+ Combine(Values(SPV_ENV_VULKAN_1_2), Values("CullMaskKHR"),
+ Values("AnyHitKHR", "ClosestHitKHR", "IntersectionKHR", "MissKHR"),
+ Values("Input"), Values("%u32"),
+ Values("OpCapability RayTracingKHR\nOpCapability RayCullMaskKHR\n"),
+ Values("OpExtension \"SPV_KHR_ray_tracing\"\nOpExtension "
+ "\"SPV_KHR_ray_cull_mask\"\n"),
+ Values(nullptr), Values(TestResult())));
+
+INSTANTIATE_TEST_SUITE_P(
+ CullMaskNotExecutionMode,
+ ValidateGenericCombineBuiltInExecutionModelDataTypeCapabilityExtensionResult,
+ Combine(Values(SPV_ENV_VULKAN_1_2), Values("CullMaskKHR"),
+ Values("Vertex", "Fragment", "TessellationControl",
+ "TessellationEvaluation", "Geometry", "Fragment",
+ "GLCompute", "RayGenerationKHR", "CallableKHR"),
+ Values("Input"), Values("%u32"),
+ Values("OpCapability RayTracingKHR\nOpCapability RayCullMaskKHR\n"),
+ Values("OpExtension \"SPV_KHR_ray_tracing\"\nOpExtension "
+ "\"SPV_KHR_ray_cull_mask\"\n"),
+ Values("VUID-CullMaskKHR-CullMaskKHR-06735 "
+ "VUID-RayTmaxKHR-RayTmaxKHR-04348 "
+ "VUID-RayTminKHR-RayTminKHR-04351 "),
+ Values(TestResult(SPV_ERROR_INVALID_DATA,
+ "Vulkan spec does not allow BuiltIn",
+ "to be used with the execution model"))));
+
+INSTANTIATE_TEST_SUITE_P(
+ ICullMaskNotInput,
+ ValidateGenericCombineBuiltInExecutionModelDataTypeCapabilityExtensionResult,
+ Combine(Values(SPV_ENV_VULKAN_1_2), Values("CullMaskKHR"),
+ Values("AnyHitKHR", "ClosestHitKHR", "IntersectionKHR", "MissKHR"),
+ Values("Output"), Values("%u32"),
+ Values("OpCapability RayTracingKHR\nOpCapability RayCullMaskKHR\n"),
+ Values("OpExtension \"SPV_KHR_ray_tracing\"\nOpExtension "
+ "\"SPV_KHR_ray_cull_mask\"\n"),
+ Values("VUID-CullMaskKHR-CullMaskKHR-06736 "
+ "VUID-RayTmaxKHR-RayTmaxKHR-04349 "
+ "VUID-RayTminKHR-RayTminKHR-04352 "),
+ Values(TestResult(SPV_ERROR_INVALID_DATA, "Vulkan spec allows",
+ "used for variables with Input storage class"))));
+INSTANTIATE_TEST_SUITE_P(
+ CullMaskNotIntScalar,
+ ValidateGenericCombineBuiltInExecutionModelDataTypeCapabilityExtensionResult,
+ Combine(Values(SPV_ENV_VULKAN_1_2), Values("CullMaskKHR"),
+ Values("AnyHitKHR", "ClosestHitKHR", "IntersectionKHR", "MissKHR"),
+ Values("Input"), Values("%f32", "%u32vec3"),
+ Values("OpCapability RayTracingKHR\nOpCapability RayCullMaskKHR\n"),
+ Values("OpExtension \"SPV_KHR_ray_tracing\"\nOpExtension "
+ "\"SPV_KHR_ray_cull_mask\"\n"),
+ Values("VUID-CullMaskKHR-CullMaskKHR-06737 "
+ "VUID-RayTmaxKHR-RayTmaxKHR-04350 "
+ "VUID-RayTminKHR-RayTminKHR-04353 "),
+ Values(TestResult(SPV_ERROR_INVALID_DATA,
+ "needs to be a 32-bit int scalar",
+ "is not an int scalar"))));
+
// RayTmaxKHR, RayTminKHR are all valid
// in IS, AH, CH, MS shaders as input f32 scalars
INSTANTIATE_TEST_SUITE_P(
@@ -4065,6 +4126,71 @@
"According to the Vulkan spec BuiltIn FullyCoveredEXT variable "
"needs to be a bool scalar."))));
+INSTANTIATE_TEST_SUITE_P(
+ BaryCoordNotFragment,
+ ValidateVulkanCombineBuiltInExecutionModelDataTypeCapabilityExtensionResult,
+ Combine(
+ Values("BaryCoordKHR", "BaryCoordNoPerspKHR"), Values("Vertex"),
+ Values("Input"), Values("%f32vec3"),
+ Values("OpCapability FragmentBarycentricKHR\n"),
+ Values("OpExtension \"SPV_KHR_fragment_shader_barycentric\"\n"),
+ Values("VUID-BaryCoordKHR-BaryCoordKHR-04154 "
+ "VUID-BaryCoordNoPerspKHR-BaryCoordNoPerspKHR-04160 "),
+ Values(TestResult(SPV_ERROR_INVALID_DATA, "Vulkan spec allows BuiltIn",
+ "to be used only with Fragment execution model"))));
+
+INSTANTIATE_TEST_SUITE_P(
+ BaryCoordNotInput,
+ ValidateVulkanCombineBuiltInExecutionModelDataTypeCapabilityExtensionResult,
+ Combine(Values("BaryCoordKHR", "BaryCoordNoPerspKHR"), Values("Fragment"),
+ Values("Output"), Values("%f32vec3"),
+ Values("OpCapability FragmentBarycentricKHR\n"),
+ Values("OpExtension \"SPV_KHR_fragment_shader_barycentric\"\n"),
+ Values("VUID-BaryCoordKHR-BaryCoordKHR-04155 "
+ "VUID-BaryCoordNoPerspKHR-BaryCoordNoPerspKHR-04161 "),
+ Values(TestResult(
+ SPV_ERROR_INVALID_DATA, "Vulkan spec allows BuiltIn",
+ "to be only used for variables with Input storage class"))));
+
+INSTANTIATE_TEST_SUITE_P(
+ BaryCoordNotFloatVector,
+ ValidateVulkanCombineBuiltInExecutionModelDataTypeCapabilityExtensionResult,
+ Combine(
+ Values("BaryCoordKHR", "BaryCoordNoPerspKHR"), Values("Fragment"),
+ Values("Output"), Values("%f32arr3", "%u32vec4"),
+ Values("OpCapability FragmentBarycentricKHR\n"),
+ Values("OpExtension \"SPV_KHR_fragment_shader_barycentric\"\n"),
+ Values("VUID-BaryCoordKHR-BaryCoordKHR-04156 "
+ "VUID-BaryCoordNoPerspKHR-BaryCoordNoPerspKHR-04162 "),
+ Values(TestResult(SPV_ERROR_INVALID_DATA,
+ "needs to be a 3-component 32-bit float vector"))));
+
+INSTANTIATE_TEST_SUITE_P(
+ BaryCoordNotFloatVec3,
+ ValidateVulkanCombineBuiltInExecutionModelDataTypeCapabilityExtensionResult,
+ Combine(
+ Values("BaryCoordKHR", "BaryCoordNoPerspKHR"), Values("Fragment"),
+ Values("Output"), Values("%f32vec2"),
+ Values("OpCapability FragmentBarycentricKHR\n"),
+ Values("OpExtension \"SPV_KHR_fragment_shader_barycentric\"\n"),
+ Values("VUID-BaryCoordKHR-BaryCoordKHR-04156 "
+ "VUID-BaryCoordNoPerspKHR-BaryCoordNoPerspKHR-04162 "),
+ Values(TestResult(SPV_ERROR_INVALID_DATA,
+ "needs to be a 3-component 32-bit float vector"))));
+
+INSTANTIATE_TEST_SUITE_P(
+ BaryCoordNotF32Vec3,
+ ValidateVulkanCombineBuiltInExecutionModelDataTypeCapabilityExtensionResult,
+ Combine(
+ Values("BaryCoordKHR", "BaryCoordNoPerspKHR"), Values("Fragment"),
+ Values("Output"), Values("%f64vec3"),
+ Values("OpCapability FragmentBarycentricKHR\n"),
+ Values("OpExtension \"SPV_KHR_fragment_shader_barycentric\"\n"),
+ Values("VUID-BaryCoordKHR-BaryCoordKHR-04156 "
+ "VUID-BaryCoordNoPerspKHR-BaryCoordNoPerspKHR-04162 "),
+ Values(TestResult(SPV_ERROR_INVALID_DATA,
+ "needs to be a 3-component 32-bit float vector"))));
+
} // namespace
} // namespace val
} // namespace spvtools
diff --git a/test/val/val_decoration_test.cpp b/test/val/val_decoration_test.cpp
index 2db44a4..e7ecb61 100644
--- a/test/val/val_decoration_test.cpp
+++ b/test/val/val_decoration_test.cpp
@@ -8220,6 +8220,149 @@
"Offset decorations"));
}
+TEST_F(ValidateDecorations, PerVertexVulkanGood) {
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability FragmentBarycentricKHR
+ OpExtension "SPV_KHR_fragment_shader_barycentric"
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main" %vertexIDs
+ OpExecutionMode %main OriginUpperLeft
+ OpDecorate %vertexIDs Location 0
+ OpDecorate %vertexIDs PerVertexKHR
+ %void = OpTypeVoid
+ %func = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %uint = OpTypeInt 32 0
+%ptrFloat = OpTypePointer Input %float
+ %uint_3 = OpConstant %uint 3
+%floatArray = OpTypeArray %float %uint_3
+%ptrFloatArray = OpTypePointer Input %floatArray
+ %vertexIDs = OpVariable %ptrFloatArray Input
+ %int = OpTypeInt 32 1
+ %int_0 = OpConstant %int 0
+ %main = OpFunction %void None %func
+ %label = OpLabel
+ %access = OpAccessChain %ptrFloat %vertexIDs %int_0
+ %load = OpLoad %float %access
+ OpReturn
+ OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+}
+
+TEST_F(ValidateDecorations, PerVertexVulkanOutput) {
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability FragmentBarycentricKHR
+ OpExtension "SPV_KHR_fragment_shader_barycentric"
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main" %vertexIDs
+ OpExecutionMode %main OriginUpperLeft
+ OpDecorate %vertexIDs Location 0
+ OpDecorate %vertexIDs PerVertexKHR
+ %void = OpTypeVoid
+ %func = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %uint = OpTypeInt 32 0
+%ptrFloat = OpTypePointer Output %float
+ %uint_3 = OpConstant %uint 3
+%floatArray = OpTypeArray %float %uint_3
+%ptrFloatArray = OpTypePointer Output %floatArray
+ %vertexIDs = OpVariable %ptrFloatArray Output
+ %int = OpTypeInt 32 1
+ %int_0 = OpConstant %int 0
+ %main = OpFunction %void None %func
+ %label = OpLabel
+ %access = OpAccessChain %ptrFloat %vertexIDs %int_0
+ %load = OpLoad %float %access
+ OpReturn
+ OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0);
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+ EXPECT_THAT(getDiagnosticString(),
+ AnyVUID("VUID-StandaloneSpirv-PerVertexKHR-06777"));
+ EXPECT_THAT(getDiagnosticString(), HasSubstr("storage class must be Input"));
+}
+
+TEST_F(ValidateDecorations, PerVertexVulkanNonFragment) {
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability FragmentBarycentricKHR
+ OpExtension "SPV_KHR_fragment_shader_barycentric"
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Vertex %main "main" %vertexIDs
+ OpDecorate %vertexIDs Location 0
+ OpDecorate %vertexIDs PerVertexKHR
+ %void = OpTypeVoid
+ %func = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %uint = OpTypeInt 32 0
+%ptrFloat = OpTypePointer Input %float
+ %uint_3 = OpConstant %uint 3
+%floatArray = OpTypeArray %float %uint_3
+%ptrFloatArray = OpTypePointer Input %floatArray
+ %vertexIDs = OpVariable %ptrFloatArray Input
+ %int = OpTypeInt 32 1
+ %int_0 = OpConstant %int 0
+ %main = OpFunction %void None %func
+ %label = OpLabel
+ %access = OpAccessChain %ptrFloat %vertexIDs %int_0
+ %load = OpLoad %float %access
+ OpReturn
+ OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0);
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+ EXPECT_THAT(getDiagnosticString(),
+ AnyVUID("VUID-StandaloneSpirv-PerVertexKHR-06777"));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "PerVertexKHR can only be applied to Fragment Execution Models"));
+}
+
+TEST_F(ValidateDecorations, PerVertexVulkanNonArray) {
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability FragmentBarycentricKHR
+ OpExtension "SPV_KHR_fragment_shader_barycentric"
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main" %vertexIDs
+ OpExecutionMode %main OriginUpperLeft
+ OpDecorate %vertexIDs Location 0
+ OpDecorate %vertexIDs PerVertexKHR
+ %void = OpTypeVoid
+ %func = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %ptrFloat = OpTypePointer Input %float
+ %vertexIDs = OpVariable %ptrFloat Input
+ %int = OpTypeInt 32 1
+ %int_0 = OpConstant %int 0
+ %main = OpFunction %void None %func
+ %label = OpLabel
+ %load = OpLoad %float %vertexIDs
+ OpReturn
+ OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0);
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+ EXPECT_THAT(getDiagnosticString(),
+ AnyVUID("VUID-StandaloneSpirv-Input-06778"));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("PerVertexKHR must be declared as arrays"));
+}
+
} // namespace
} // namespace val
} // namespace spvtools
diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp
index 0129478..ce2103c 100644
--- a/tools/opt/opt.cpp
+++ b/tools/opt/opt.cpp
@@ -157,12 +157,6 @@
another. It will only propagate an array if the source is never
written to, and the only store to the target is the copy.)");
printf(R"(
- --decompose-initialized-variables
- Decomposes initialized variable declarations into a declaration
- followed by a store of the initial value. This is done to work
- around known issues with some Vulkan drivers for initialize
- variables.)");
- printf(R"(
--replace-desc-array-access-using-var-index
Replaces accesses to descriptor arrays based on a variable index
with a switch that has a case for every possible value of the
@@ -237,6 +231,10 @@
loads and stores. Performed only on entry point call tree
functions.)");
printf(R"(
+ --fix-func-call-param
+ fix non memory argument for the function call, replace
+ accesschain pointer argument with a variable.)");
+ printf(R"(
--flatten-decorations
Replace decoration groups with repeated OpDecorate and
OpMemberDecorate instructions.)");
@@ -487,9 +485,6 @@
--strength-reduction
Replaces instructions with equivalent and less expensive ones.)");
printf(R"(
- --strip-atomic-counter-memory
- Removes AtomicCountMemory bit from memory semantics values.)");
- printf(R"(
--strip-debug
Remove all debug instructions.)");
printf(R"(
diff --git a/utils/roll_deps.sh b/utils/roll_deps.sh
index cef8b52..20c061f 100755
--- a/utils/roll_deps.sh
+++ b/utils/roll_deps.sh
@@ -39,6 +39,8 @@
exit 1
fi
+echo "*** Ignore messages about running 'git cl upload' ***"
+
old_head=$(git rev-parse HEAD)
set +e