Squashed 'third_party/SPIRV-Tools/' changes from 9559cdbdf..65e362b7a
65e362b7a AggressiveDCEPass: Set modified to true when appending to to_kill_ (#2825)
d67130cac Replace SwizzleInvocationsAMD extended instruction. (#2823)
ad71c057c Replace SwizzleInvocationsMaskedAMD extended instruction. (#2822)
4ae9b7165 Fix gn check (#2821)
35d98be3b Amd ext to khr (#2811)
5a581e738 spvtools::Optimizer - don't assume original_binary and optimized_binary are aliased (#2799)
73422a0a5 Check feature mgr in context consistency check (#2818)
15fc19d09 Refactor instruction folders (#2815)
1eb89172a Add missing files to BUILD.gn (#2809)
8336d1925 Extend reducer to remove relaxed precision decorations (#2797)
b00ef0d26 Handle Id overflow in private-to-local (#2807)
aef8f92b2 Even more id overflow in sroa (#2806)
c5d1dab99 Add name for variables in desc sroa (#2805)
0cbdc7a2c Remove unimplemented method declaration (#2804)
bc62722b8 Handle overflow in wrap-opkill (#2801)
9cd07272a More handle overflow in sroa (#2800)
06407250a Instrument: Add support for Buffer Device Address extension (#2792)
7b4e5bd5e Update remquo validation to match the OpenCL Extended Instruction Set Specification (#2791)
dac9210dc Use ascii code based characters (#2796)
ff872dc6b Change the way to include header (#2795)
bbd80462f Fix validation of constant matrices (#2794)
60043edfa Replace OpKill With function call. (#2790)
f701237f2 Remove useless semi-colons (#2789)
95386f9e4 Instrument: Fix version 2 output record write for tess eval shaders. (#2782)
22ce39c8e Start SPIRV-Tools v2019.5
d65513e92 Finalize SPIRV-Tools v2019.4
4b64beb1a Add descriptor array scalar replacement (#2742)
c26c2615f Update CHANGES
29af42df1 Add SPV_EXT_physical_storage_buffer to opt whitelists (#2779)
b029d3697 Handle RelaxedPrecision in SROA (#2788)
370375d23 Add -fextra-semi to Clang builds (#2787)
698b56a8f Add 'copy object' transformation (#2766)
4f14b4c8c fuzz: change output extension and fix usage string (#2778)
0b70972a2 Remove extra ';' after member function definition. (#2780)
5ada98d0b Update WebGPU validation rules of OpAtomic*s (#2777)
3726b500b Treat access chain indexes as signed in SROA (#2776)
31590104e Add pass to inject code for robust-buffer-access semantics (#2771)
4a28259cc Update OpMemoryBarriers rules for WebGPU (#2775)
7621034aa Add opt test fixture method SinglePassRunAndFail (#2770)
ac3d13105 Element type is const for analysis::Vector,Matrix,RuntimeArray (#2765)
49797609b Protect against out-of-bounds references when folding OpCompositeExtract (#2774)
7fd2365b0 Don't move debug or decorations when folding (#2772)
7bafeda28 Update OpControlBarriers rules for WebGPU (#2769)
git-subtree-dir: third_party/SPIRV-Tools
git-subtree-split: 65e362b7ae2acb8aa5bd2ad516fb793961e673ee
diff --git a/Android.mk b/Android.mk
index 82d9776..9428116 100644
--- a/Android.mk
+++ b/Android.mk
@@ -75,6 +75,7 @@
SPVTOOLS_OPT_SRC_FILES := \
source/opt/aggressive_dead_code_elim_pass.cpp \
+ source/opt/amd_ext_to_khr.cpp \
source/opt/basic_block.cpp \
source/opt/block_merge_pass.cpp \
source/opt/block_merge_util.cpp \
@@ -95,6 +96,7 @@
source/opt/decompose_initialized_variables_pass.cpp \
source/opt/decoration_manager.cpp \
source/opt/def_use_manager.cpp \
+ source/opt/desc_sroa.cpp \
source/opt/dominator_analysis.cpp \
source/opt/dominator_tree.cpp \
source/opt/eliminate_dead_constant_pass.cpp \
@@ -110,11 +112,13 @@
source/opt/freeze_spec_constant_value_pass.cpp \
source/opt/function.cpp \
source/opt/generate_webgpu_initializers_pass.cpp \
+ source/opt/graphics_robust_access_pass.cpp \
source/opt/if_conversion.cpp \
source/opt/inline_pass.cpp \
source/opt/inline_exhaustive_pass.cpp \
source/opt/inline_opaque_pass.cpp \
source/opt/inst_bindless_check_pass.cpp \
+ source/opt/inst_buff_addr_check_pass.cpp \
source/opt/instruction.cpp \
source/opt/instruction_list.cpp \
source/opt/instrument_pass.cpp \
@@ -169,7 +173,8 @@
source/opt/upgrade_memory_model.cpp \
source/opt/value_number_table.cpp \
source/opt/vector_dce.cpp \
- source/opt/workaround1209.cpp
+ source/opt/workaround1209.cpp \
+ source/opt/wrap_opkill.cpp
# Locations of grammar files.
#
diff --git a/BUILD.gn b/BUILD.gn
index 70772b9..d62aaab 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -451,6 +451,8 @@
sources = [
"source/opt/aggressive_dead_code_elim_pass.cpp",
"source/opt/aggressive_dead_code_elim_pass.h",
+ "source/opt/amd_ext_to_khr.cpp",
+ "source/opt/amd_ext_to_khr.h",
"source/opt/basic_block.cpp",
"source/opt/basic_block.h",
"source/opt/block_merge_pass.cpp",
@@ -491,6 +493,8 @@
"source/opt/decoration_manager.h",
"source/opt/def_use_manager.cpp",
"source/opt/def_use_manager.h",
+ "source/opt/desc_sroa.cpp",
+ "source/opt/desc_sroa.h",
"source/opt/dominator_analysis.cpp",
"source/opt/dominator_analysis.h",
"source/opt/dominator_tree.cpp",
@@ -521,6 +525,8 @@
"source/opt/function.h",
"source/opt/generate_webgpu_initializers_pass.cpp",
"source/opt/generate_webgpu_initializers_pass.h",
+ "source/opt/graphics_robust_access_pass.cpp",
+ "source/opt/graphics_robust_access_pass.h",
"source/opt/if_conversion.cpp",
"source/opt/if_conversion.h",
"source/opt/inline_exhaustive_pass.cpp",
@@ -531,6 +537,8 @@
"source/opt/inline_pass.h",
"source/opt/inst_bindless_check_pass.cpp",
"source/opt/inst_bindless_check_pass.h",
+ "source/opt/inst_buff_addr_check_pass.cpp",
+ "source/opt/inst_buff_addr_check_pass.h",
"source/opt/instruction.cpp",
"source/opt/instruction.h",
"source/opt/instruction_list.cpp",
@@ -646,10 +654,13 @@
"source/opt/vector_dce.h",
"source/opt/workaround1209.cpp",
"source/opt/workaround1209.h",
+ "source/opt/wrap_opkill.cpp",
+ "source/opt/wrap_opkill.h",
]
deps = [
":spvtools",
+ ":spvtools_vendor_tables_spv-amd-shader-ballot",
]
public_deps = [
":spvtools_headers",
@@ -721,6 +732,8 @@
"source/reduce/remove_instruction_reduction_opportunity.h",
"source/reduce/remove_opname_instruction_reduction_opportunity_finder.cpp",
"source/reduce/remove_opname_instruction_reduction_opportunity_finder.h",
+ "source/reduce/remove_relaxed_precision_decoration_opportunity_finder.cpp",
+ "source/reduce/remove_relaxed_precision_decoration_opportunity_finder.h",
"source/reduce/remove_selection_reduction_opportunity.cpp",
"source/reduce/remove_selection_reduction_opportunity.h",
"source/reduce/remove_selection_reduction_opportunity_finder.cpp",
diff --git a/CHANGES b/CHANGES
index 11ecac5..57afc63 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,7 +1,74 @@
Revision history for SPIRV-Tools
-v2019.4-dev 2019-05-15
- - Start v2019.4-dev
+v2019.5-dev 2019-08-08
+ - Start v2019.5-dev
+
+v2019.4 2019-08-08
+ - General:
+ - Memory model support for SPIR-V 1.4
+ - Add new spirv-fuzz tool
+ - Add option for base branch in check_code_format.sh
+ - Removed MarkV and Stats code. (#2576)
+ - Instrument: Add version 2 of record formats (#2630)
+ - Linker: Better type comparison for OpTypeArray and OpTypeForwardPointer (#2580)
+ - Optimizer
+ - Bindless Validation: Instrument descriptor-based loads and stores (#2583)
+ - Better folding for OpSpecConstantOp (#2585, #2614)
+ - Add in individual flags for Vulkan <-> WebGPU passes (#2615)
+ - Handle nested breaks from switches. (#2624)
+ - Optimizer: Handle array type with OpSpecConstantOp length (#2652)
+ - Perform merge return with single return in loop. (#2714)
+ - Add --preserve-bindings and --preserve-spec-constants (#2693)
+ - Remove Common Uniform Elimination Pass (#2731)
+ - Allow ray tracing shaders in inst bindle check pass. (#2733)
+ - Add pass to inject code for robust-buffer-access semantics (#2771)
+ - Treat access chain indexes as signed in SROA (#2776)
+ - Handle RelaxedPrecision in SROA (#2788)
+ - Add descriptor array scalar replacement (#2742)
+ Fixes:
+ - Handle decorations better in some optimizations (#2716)
+ - Change the order branches are simplified in dead branch elim (#2728)
+ - Fix bug in merge return (#2734)
+ - SSA rewriter: Don't use trivial phis (#2757)
+ - Record correct dominators in merge return (#2760)
+ - Process OpDecorateId in ADCE (#2761)
+ - Fix check for unreachable blocks in merge-return (#2762)
+ - Handle out-of-bounds scalar replacements. (#2767)
+ - Don't move debug or decorations when folding (#2772)
+ - Protect against out-of-bounds references when folding OpCompositeExtract (#2774)
+ - Validator
+ - Validate loop merge (#2579)
+ - Validate construct exits (#2459)
+ - Validate OpenCL memory and addressing model environment rules (#2589)
+ - Validate OpenCL environment rules for OpTypeImage (#2606)
+ - Allow breaks to switch merge from nested construct (#2604)
+ - Validate OpenCL environment rules for OpImageWrite (#2619)
+ - Allow arrays of out per-primitive builtins for mesh shaders (#2617)
+ - Validate OpenCL rules for ImageRead and OpImageSampleExplicitLod (#2643)
+ - Add validation for SPV_EXT_fragment_shader_interlock (#2650)
+ - Add builtin validation for SPV_NV_shader_sm_builtins (#2656)
+ - Add validation for Subgroup builtins (#2637)
+ - Validate variable initializer type (#2668)
+ - Disallow stores to UBOs (#2651)
+ - Validate Volatile memory semantics bit (#2672)
+ - Basic validation for Component decorations (#2679)
+ - Validate that in OpenGL env block variables have Binding (#2685)
+ - Validate usage of 8- and 16-bit types with only storage capabilities (#2704)
+ - Add validation for SPV_EXT_demote_to_helper_invocation (#2707)
+ - Extra small storage validation (#2732)
+ - For Vulkan, disallow structures containing opaque types (#2546)
+ - Validate storage class OpenCL environment rules for atomics (#2750)
+ - Update OpControlBarriers rules for WebGPU (#2769)
+ - Update OpMemoryBarriers rules for WebGPU (#2775)
+ - Update WebGPU validation rules of OpAtomic*s (#2777)
+ Fixes:
+ - Disallow merge targeting block with OpLoopMerge (#2610)
+ - Update vloadn and vstoren validation to match the OpenCL Extended
+ Instruction Set Specification (#2599)
+ - Update memory scope rules for WebGPU (#2725)
+ - Allow LOD ops in compute shaders with derivative group execution modes (#2752)
+ - Reduce
+ Fixes:
v2019.3 2019-05-14
- General:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9f24e38..b95b714 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -235,6 +235,21 @@
add_subdirectory(external)
+# Warning about extra semi-colons.
+#
+# This is not supported on all compilers/versions, so enabling only
+# for clang, since that works for all versions that our bots run.
+#
+# This is intentionally done after adding the external subdirectory,
+# so we don't enforce this flag on our dependencies, some of which do
+# not pass it.
+#
+# If the minimum version of CMake supported is updated to 3.0 or
+# later, then check_cxx_compiler_flag could be used instead.
+if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
+ add_compile_options("-Wextra-semi")
+endif()
+
add_subdirectory(source)
add_subdirectory(tools)
diff --git a/include/spirv-tools/instrument.hpp b/include/spirv-tools/instrument.hpp
index dfd6e35..681d008 100644
--- a/include/spirv-tools/instrument.hpp
+++ b/include/spirv-tools/instrument.hpp
@@ -23,8 +23,9 @@
// communicate with shaders instrumented by passes created by:
//
// CreateInstBindlessCheckPass
+// CreateInstBuffAddrCheckPass
//
-// More detailed documentation of this routine can be found in optimizer.hpp
+// More detailed documentation of these routines can be found in optimizer.hpp
namespace spvtools {
@@ -157,6 +158,12 @@
static const int kInst2BindlessUninitOutUnused = kInst2StageOutCnt + 2;
static const int kInst2BindlessUninitOutCnt = kInst2StageOutCnt + 3;
+// A buffer address unalloc error will output the 64-bit pointer in
+// two 32-bit pieces, lower bits first.
+static const int kInst2BuffAddrUnallocOutDescPtrLo = kInst2StageOutCnt + 1;
+static const int kInst2BuffAddrUnallocOutDescPtrHi = kInst2StageOutCnt + 2;
+static const int kInst2BuffAddrUnallocOutCnt = kInst2StageOutCnt + 3;
+
// DEPRECATED
static const int kInstBindlessOutDescIndex = kInstStageOutCnt + 1;
static const int kInstBindlessOutDescBound = kInstStageOutCnt + 2;
@@ -171,6 +178,7 @@
// These are the possible validation error codes.
static const int kInstErrorBindlessBounds = 0;
static const int kInstErrorBindlessUninit = 1;
+static const int kInstErrorBuffAddrUnallocRef = 2;
// Direct Input Buffer Offsets
//
@@ -187,14 +195,16 @@
// These are the bindings for the different buffers which are
// read or written by the instrumentation passes.
//
-// This is the output buffer written by InstBindlessCheckPass
-// and possibly other future validations.
+// This is the output buffer written by InstBindlessCheckPass,
+// InstBuffAddrCheckPass, and possibly other future validations.
static const int kDebugOutputBindingStream = 0;
-// The binding for the input buffer read by InstBindlessCheckPass and
-// possibly other future validations.
+// The binding for the input buffer read by InstBindlessCheckPass.
static const int kDebugInputBindingBindless = 1;
+// The binding for the input buffer read by InstBuffAddrCheckPass.
+static const int kDebugInputBindingBuffAddr = 2;
+
// Bindless Validation Input Buffer Format
//
// An input buffer for bindless validation consists of a single array of
@@ -216,6 +226,31 @@
// Data[ Data[ s + kDebugInputBindlessOffsetLengths ] + b ]
static const int kDebugInputBindlessOffsetLengths = 1;
+// Buffer Device Address Input Buffer Format
+//
+// An input buffer for buffer device address validation consists of a single
+// array of unsigned 64-bit integers we will call Data[]. This array is
+// formatted as follows:
+//
+// At offset kDebugInputBuffAddrPtrOffset is a list of sorted valid buffer
+// addresses. The list is terminated with the address 0xffffffffffffffff.
+// If 0x0 is not a valid buffer address, this address is inserted at the
+// start of the list.
+//
+static const int kDebugInputBuffAddrPtrOffset = 1;
+//
+// At offset kDebugInputBuffAddrLengthOffset in Data[] is a single uint64 which
+// gives an offset to the start of the buffer length data. More
+// specifically, for a buffer whose pointer is located at input buffer offset
+// i, the length is located at:
+//
+// Data[ i - kDebugInputBuffAddrPtrOffset
+// + Data[ kDebugInputBuffAddrLengthOffset ] ]
+//
+// The length associated with the 0xffffffffffffffff address is zero. If
+// not a valid buffer, the length associated with the 0x0 address is zero.
+static const int kDebugInputBuffAddrLengthOffset = 0;
+
} // namespace spvtools
#endif // INCLUDE_SPIRV_TOOLS_INSTRUMENT_HPP_
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index 4c668b4..4e54b1a 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -729,6 +729,30 @@
uint32_t desc_set, uint32_t shader_id, bool input_length_enable = false,
bool input_init_enable = false, uint32_t version = 1);
+// Create a pass to instrument physical buffer address checking
+// This pass instruments all physical buffer address references to check that
+// all referenced bytes fall in a valid buffer. If the reference is
+// invalid, a record is written to the debug output buffer (if space allows)
+// and a null value is returned. This pass is designed to support buffer
+// address validation in the Vulkan validation layers.
+//
+// Dead code elimination should be run after this pass as the original,
+// potentially invalid code is not removed and could cause undefined behavior,
+// including crashes. Instruction simplification would likely also be
+// beneficial. It is also generally recommended that this pass (and all
+// instrumentation passes) be run after any legalization and optimization
+// passes. This will give better analysis for the instrumentation and avoid
+// potentially de-optimizing the instrument code, for example, inlining
+// the debug record output function throughout the module.
+//
+// The instrumentation will read and write buffers in debug
+// descriptor set |desc_set|. It will write |shader_id| in each output record
+// to identify the shader module which generated the record.
+// |version| specifies the output buffer record format.
+Optimizer::PassToken CreateInstBuffAddrCheckPass(uint32_t desc_set,
+ uint32_t shader_id,
+ uint32_t version = 2);
+
// Create a pass to upgrade to the VulkanKHR memory model.
// This pass upgrades the Logical GLSL450 memory model to Logical VulkanKHR.
// Additionally, it modifies memory, image, atomic and barrier operations to
@@ -763,6 +787,47 @@
// continue-targets to legalize for WebGPU.
Optimizer::PassToken CreateSplitInvalidUnreachablePass();
+// Creates a graphics robust access pass.
+//
+// This pass injects code to clamp indexed accesses to buffers and internal
+// arrays, providing guarantees satisfying Vulkan's robustBufferAccess rules.
+//
+// TODO(dneto): Clamps coordinates and sample index for pointer calculations
+// into storage images (OpImageTexelPointer). For a cube array image, it
+// assumes the maximum layer count times 6 is at most 0xffffffff.
+//
+// NOTE: This pass will fail with a message if:
+// - The module is not a Shader module.
+// - The module declares VariablePointers, VariablePointersStorageBuffer, or
+// RuntimeDescriptorArrayEXT capabilities.
+// - The module uses an addressing model other than Logical
+// - Access chain indices are wider than 64 bits.
+// - Access chain index for a struct is not an OpConstant integer or is out
+// of range. (The module is already invalid if that is the case.)
+// - TODO(dneto): The OpImageTexelPointer coordinate component is not 32-bits
+// wide.
+Optimizer::PassToken CreateGraphicsRobustAccessPass();
+
+// Create descriptor scalar replacement pass.
+// This pass replaces every array variable |desc| that has a DescriptorSet and
+// Binding decorations with a new variable for each element of the array.
+// Suppose |desc| was bound at binding |b|. Then the variable corresponding to
+// |desc[i]| will have binding |b+i|. The descriptor set will be the same. It
+// is assumed that no other variable already has a binding that will be used by one
+// of the new variables. If not, the pass will generate invalid SPIR-V. All
+// accesses to |desc| must be OpAccessChain instructions with a literal index
+// for the first index.
+Optimizer::PassToken CreateDescriptorScalarReplacementPass();
+
+// Create a pass to replace all OpKill instructions with a function call to a
+// function that has a single OpKill. This allows more code to be inlined.
+Optimizer::PassToken CreateWrapOpKillPass();
+
+// Replaces the extensions VK_AMD_shader_ballot, VK_AMD_gcn_shader, and
+// VK_AMD_shader_trinary_minmax with equivalent code using core instructions and
+// capabilities.
+Optimizer::PassToken CreateAmdExtToKhrPass();
+
} // namespace spvtools
#endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_
diff --git a/source/enum_set.h b/source/enum_set.h
index e4ef297..2e7046d 100644
--- a/source/enum_set.h
+++ b/source/enum_set.h
@@ -69,6 +69,26 @@
return *this;
}
+ friend bool operator==(const EnumSet& a, const EnumSet& b) {
+ if (a.mask_ != b.mask_) {
+ return false;
+ }
+
+ if (a.overflow_ == nullptr && b.overflow_ == nullptr) {
+ return true;
+ }
+
+ if (a.overflow_ == nullptr || b.overflow_ == nullptr) {
+ return false;
+ }
+
+ return *a.overflow_ == *b.overflow_;
+ }
+
+ friend bool operator!=(const EnumSet& a, const EnumSet& b) {
+ return !(a == b);
+ }
+
// Adds the given enum value to the set. This has no effect if the
// enum value is already in the set.
void Add(EnumType c) { AddWord(ToWord(c)); }
diff --git a/source/fuzz/CMakeLists.txt b/source/fuzz/CMakeLists.txt
index fbabba1..49ee843 100644
--- a/source/fuzz/CMakeLists.txt
+++ b/source/fuzz/CMakeLists.txt
@@ -26,6 +26,7 @@
)
set(SPIRV_TOOLS_FUZZ_SOURCES
+ data_descriptor.h
fact_manager.h
fuzzer.h
fuzzer_context.h
@@ -52,6 +53,7 @@
transformation_add_type_float.h
transformation_add_type_int.h
transformation_add_type_pointer.h
+ transformation_copy_object.h
transformation_move_block_down.h
transformation_replace_boolean_constant_with_constant_binary.h
transformation_replace_constant_with_uniform.h
@@ -59,6 +61,7 @@
uniform_buffer_element_descriptor.h
${CMAKE_CURRENT_BINARY_DIR}/protobufs/spvtoolsfuzz.pb.h
+ data_descriptor.cpp
fact_manager.cpp
fuzzer.cpp
fuzzer_context.cpp
@@ -84,6 +87,7 @@
transformation_add_type_float.cpp
transformation_add_type_int.cpp
transformation_add_type_pointer.cpp
+ transformation_copy_object.cpp
transformation_move_block_down.cpp
transformation_replace_boolean_constant_with_constant_binary.cpp
transformation_replace_constant_with_uniform.cpp
diff --git a/source/fuzz/data_descriptor.cpp b/source/fuzz/data_descriptor.cpp
new file mode 100644
index 0000000..9cdb2c5
--- /dev/null
+++ b/source/fuzz/data_descriptor.cpp
@@ -0,0 +1,42 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/fuzz/data_descriptor.h"
+
+#include <algorithm>
+
+namespace spvtools {
+namespace fuzz {
+
+protobufs::DataDescriptor MakeDataDescriptor(uint32_t object,
+ std::vector<uint32_t>&& indices) {
+ protobufs::DataDescriptor result;
+ result.set_object(object);
+ for (auto index : indices) {
+ result.add_index(index);
+ }
+ return result;
+}
+
+bool DataDescriptorEquals::operator()(
+ const protobufs::DataDescriptor* first,
+ const protobufs::DataDescriptor* second) const {
+ return first->object() == second->object() &&
+ first->index().size() == second->index().size() &&
+ std::equal(first->index().begin(), first->index().end(),
+ second->index().begin());
+}
+
+} // namespace fuzz
+} // namespace spvtools
diff --git a/source/fuzz/data_descriptor.h b/source/fuzz/data_descriptor.h
new file mode 100644
index 0000000..731bd21
--- /dev/null
+++ b/source/fuzz/data_descriptor.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_FUZZ_DATA_DESCRIPTOR_H_
+#define SOURCE_FUZZ_DATA_DESCRIPTOR_H_
+
+#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
+
+#include <vector>
+
+namespace spvtools {
+namespace fuzz {
+
+// Factory method to create a data descriptor message from an object id and a
+// list of indices.
+protobufs::DataDescriptor MakeDataDescriptor(uint32_t object,
+ std::vector<uint32_t>&& indices);
+
+// Equality function for data descriptors.
+struct DataDescriptorEquals {
+ bool operator()(const protobufs::DataDescriptor* first,
+ const protobufs::DataDescriptor* second) const;
+};
+
+} // namespace fuzz
+} // namespace spvtools
+
+#endif // SOURCE_FUZZ_DATA_DESCRIPTOR_H_
diff --git a/source/fuzz/fact_manager.cpp b/source/fuzz/fact_manager.cpp
index 442ff16..61daa64 100644
--- a/source/fuzz/fact_manager.cpp
+++ b/source/fuzz/fact_manager.cpp
@@ -68,6 +68,9 @@
} // namespace
+//=======================
+// Constant uniform facts
+
// The purpose of this struct is to group the fields and data used to represent
// facts about uniform constants.
struct FactManager::ConstantUniformFacts {
@@ -330,10 +333,44 @@
return true;
}
-FactManager::FactManager() {
- uniform_constant_facts_ = MakeUnique<ConstantUniformFacts>();
+// End of uniform constant facts
+//==============================
+
+//==============================
+// Id synonym facts
+
+// The purpose of this struct is to group the fields and data used to represent
+// facts about id synonyms.
+struct FactManager::IdSynonymFacts {
+ // See method in FactManager which delegates to this method.
+ void AddFact(const protobufs::FactIdSynonym& fact);
+
+ // A record of all the synonyms that are available.
+ std::map<uint32_t, std::vector<protobufs::DataDescriptor>> synonyms;
+
+ // The set of keys to the above map; useful if you just want to know which ids
+ // have synonyms.
+ std::set<uint32_t> ids_with_synonyms;
+};
+
+void FactManager::IdSynonymFacts::AddFact(
+ const protobufs::FactIdSynonym& fact) {
+ if (synonyms.count(fact.id()) == 0) {
+ assert(ids_with_synonyms.count(fact.id()) == 0);
+ ids_with_synonyms.insert(fact.id());
+ synonyms[fact.id()] = std::vector<protobufs::DataDescriptor>();
+ }
+ assert(ids_with_synonyms.count(fact.id()) == 1);
+ synonyms[fact.id()].push_back(fact.data_descriptor());
}
+// End of id synonym facts
+//==============================
+
+FactManager::FactManager()
+ : uniform_constant_facts_(MakeUnique<ConstantUniformFacts>()),
+ id_synonym_facts_(MakeUnique<IdSynonymFacts>()) {}
+
FactManager::~FactManager() = default;
void FactManager::AddFacts(const MessageConsumer& message_consumer,
@@ -350,13 +387,17 @@
bool FactManager::AddFact(const spvtools::fuzz::protobufs::Fact& fact,
spvtools::opt::IRContext* context) {
- assert(fact.fact_case() == protobufs::Fact::kConstantUniformFact &&
- "Right now this is the only fact.");
- if (!uniform_constant_facts_->AddFact(fact.constant_uniform_fact(),
- context)) {
- return false;
+ switch (fact.fact_case()) {
+ case protobufs::Fact::kConstantUniformFact:
+ return uniform_constant_facts_->AddFact(fact.constant_uniform_fact(),
+ context);
+ case protobufs::Fact::kIdSynonymFact:
+ id_synonym_facts_->AddFact(fact.id_synonym_fact());
+ return true;
+ default:
+ assert(false && "Unknown fact type.");
+ return false;
}
- return true;
}
std::vector<uint32_t> FactManager::GetConstantsAvailableFromUniformsForType(
@@ -389,5 +430,14 @@
return uniform_constant_facts_->facts_and_type_ids;
}
+const std::set<uint32_t>& FactManager::GetIdsForWhichSynonymsAreKnown() const {
+ return id_synonym_facts_->ids_with_synonyms;
+}
+
+const std::vector<protobufs::DataDescriptor>& FactManager::GetSynonymsForId(
+ uint32_t id) const {
+ return id_synonym_facts_->synonyms.at(id);
+}
+
} // namespace fuzz
} // namespace spvtools
diff --git a/source/fuzz/fact_manager.h b/source/fuzz/fact_manager.h
index cb4ac58..f6ea247 100644
--- a/source/fuzz/fact_manager.h
+++ b/source/fuzz/fact_manager.h
@@ -16,6 +16,7 @@
#define SOURCE_FUZZ_FACT_MANAGER_H_
#include <memory>
+#include <set>
#include <utility>
#include <vector>
@@ -51,13 +52,12 @@
// fact manager.
bool AddFact(const protobufs::Fact& fact, opt::IRContext* context);
- // The fact manager will ultimately be responsible for managing a few distinct
- // categories of facts. In principle there could be different fact managers
- // for each kind of fact, but in practice providing one 'go to' place for
- // facts will be convenient. To keep some separation, the public methods of
- // the fact manager should be grouped according to the kind of fact to which
- // they relate. At present we only have one kind of fact: facts about
- // uniform variables.
+ // The fact manager is responsible for managing a few distinct categories of
+ // facts. In principle there could be different fact managers for each kind
+ // of fact, but in practice providing one 'go to' place for facts is
+ // convenient. To keep some separation, the public methods of the fact
+ // manager should be grouped according to the kind of fact to which they
+ // relate.
//==============================
// Querying facts about uniform constants
@@ -96,6 +96,21 @@
// End of uniform constant facts
//==============================
+ //==============================
+ // Querying facts about id synonyms
+
+ // Returns every id for which a fact of the form "this id is synonymous
+ // with this piece of data" is known.
+ const std::set<uint32_t>& GetIdsForWhichSynonymsAreKnown() const;
+
+ // Requires that at least one synonym for |id| is known, and returns the
+ // sequence of all known synonyms.
+ const std::vector<protobufs::DataDescriptor>& GetSynonymsForId(
+ uint32_t id) const;
+
+ // End of id synonym facts
+ //==============================
+
private:
// For each distinct kind of fact to be managed, we use a separate opaque
// struct type.
@@ -104,6 +119,10 @@
// buffer elements.
std::unique_ptr<ConstantUniformFacts>
uniform_constant_facts_; // Unique pointer to internal data.
+
+ struct IdSynonymFacts; // Opaque struct for holding data about id synonyms.
+ std::unique_ptr<IdSynonymFacts>
+ id_synonym_facts_; // Unique pointer to internal data.
};
} // namespace fuzz
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index 9a05c74..9972e47 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -170,6 +170,41 @@
return false;
}
+opt::BasicBlock::iterator GetIteratorForBaseInstructionAndOffset(
+ opt::BasicBlock* block, const opt::Instruction* base_inst,
+ uint32_t offset) {
+ // The cases where |base_inst| is the block's label, vs. inside the block,
+ // are dealt with separately.
+ if (base_inst == block->GetLabelInst()) {
+ // |base_inst| is the block's label.
+ if (offset == 0) {
+ // We cannot return an iterator to the block's label.
+ return block->end();
+ }
+ // Conceptually, the first instruction in the block is [label + 1].
+ // We thus start from 1 when applying the offset.
+ auto inst_it = block->begin();
+ for (uint32_t i = 1; i < offset && inst_it != block->end(); i++) {
+ ++inst_it;
+ }
+ // This is either the desired instruction, or the end of the block.
+ return inst_it;
+ }
+ // |base_inst| is inside the block.
+ for (auto inst_it = block->begin(); inst_it != block->end(); ++inst_it) {
+ if (base_inst == &*inst_it) {
+ // We have found the base instruction; we now apply the offset.
+ for (uint32_t i = 0; i < offset && inst_it != block->end(); i++) {
+ ++inst_it;
+ }
+ // This is either the desired instruction, or the end of the block.
+ return inst_it;
+ }
+ }
+ assert(false && "The base instruction was not found.");
+ return nullptr;
+}
+
} // namespace fuzzerutil
} // namespace fuzz
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index 15228de..47588b0 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -18,6 +18,8 @@
#include <vector>
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
+#include "source/opt/basic_block.h"
+#include "source/opt/instruction.h"
#include "source/opt/ir_context.h"
namespace spvtools {
@@ -62,6 +64,16 @@
bool BlockIsInLoopContinueConstruct(opt::IRContext* context, uint32_t block_id,
uint32_t maybe_loop_header_id);
+// Requires that |base_inst| is either the label instruction of |block| or an
+// instruction inside |block|.
+//
+// If the block contains a (non-label, non-terminator) instruction |offset|
+// instructions after |base_inst|, an iterator to this instruction is returned.
+//
+// Otherwise |block|->end() is returned.
+opt::BasicBlock::iterator GetIteratorForBaseInstructionAndOffset(
+ opt::BasicBlock* block, const opt::Instruction* base_inst, uint32_t offset);
+
} // namespace fuzzerutil
} // namespace fuzz
diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto
index 13d8a05..4e8dcac 100644
--- a/source/fuzz/protobufs/spvtoolsfuzz.proto
+++ b/source/fuzz/protobufs/spvtoolsfuzz.proto
@@ -57,6 +57,22 @@
}
+message DataDescriptor {
+
+ // Represents a data element that can be accessed from an id, by walking the
+ // type hierarchy via a sequence of 0 or more indices.
+ //
+ // Very similar to a UniformBufferElementDescriptor, except that a
+ // DataDescriptor is rooted at the id of a scalar or composite.
+
+ // The object being accessed - a scalar or composite
+ uint32 object = 1;
+
+ // 0 or more indices, used to index into a composite object
+ repeated uint32 index = 2;
+
+}
+
message UniformBufferElementDescriptor {
// Represents a data element inside a uniform buffer. The element is
@@ -97,6 +113,7 @@
oneof fact {
// Order the fact options by numeric id (rather than alphabetically).
FactConstantUniform constant_uniform_fact = 1;
+ FactIdSynonym id_synonym_fact = 2;
}
}
@@ -118,6 +135,22 @@
}
+message FactIdSynonym {
+
+ // Records the fact that the data held in an id is guaranteed to be equal to
+ // the data held in a data descriptor. spirv-fuzz can use this to replace
+ // uses of the id with references to the data described by the data
+ // descriptor.
+
+ // An id
+ uint32 id = 1;
+
+ // A data descriptor guaranteed to hold a value identical to that held by the
+ // id
+ DataDescriptor data_descriptor = 2;
+
+}
+
message TransformationSequence {
repeated Transformation transformation = 1;
}
@@ -138,6 +171,7 @@
TransformationAddTypePointer add_type_pointer = 10;
TransformationReplaceConstantWithUniform replace_constant_with_uniform = 11;
TransformationAddDeadContinue add_dead_continue = 12;
+ TransformationCopyObject copy_object = 13;
// Add additional option using the next available number.
}
}
@@ -262,6 +296,26 @@
}
+message TransformationCopyObject {
+
+ // A transformation that introduces an OpCopyObject instruction to make a
+ // copy of an object.
+
+ // Id of the object to be copied
+ uint32 object = 1;
+
+ // The id of an instruction in a block
+ uint32 base_instruction_id = 2;
+
+ // An offset, such that the OpCopyObject instruction should be inserted
+ // right before the instruction |offset| instructions after |base_instruction_id|
+ uint32 offset = 3;
+
+ // A fresh id for the copied object
+ uint32 fresh_id = 4;
+
+}
+
message TransformationMoveBlockDown {
// A transformation that moves a basic block to be one position lower in
@@ -291,6 +345,7 @@
}
message TransformationReplaceBooleanConstantWithConstantBinary {
+
// A transformation to capture replacing a use of a boolean constant with
// binary operation on two constant values
@@ -313,13 +368,14 @@
message TransformationSplitBlock {
- // A transformation that splits a basic block into two basic blocks.
+ // A transformation that splits a basic block into two basic blocks
- // The result id of an instruction.
- uint32 result_id = 1;
+ // The result id of an instruction
+ uint32 base_instruction_id = 1;
- // An offset, such that the block containing |result_id_| should be split
- // right before the instruction |offset_| instructions after |result_id_|.
+ // An offset, such that the block containing |base_instruction_id| should be
+ // split right before the instruction |offset| instructions after
+ // |base_instruction_id|
uint32 offset = 2;
// An id that must not yet be used by the module to which this transformation
diff --git a/source/fuzz/transformation_copy_object.cpp b/source/fuzz/transformation_copy_object.cpp
new file mode 100644
index 0000000..f9ead43
--- /dev/null
+++ b/source/fuzz/transformation_copy_object.cpp
@@ -0,0 +1,158 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/fuzz/transformation_copy_object.h"
+
+#include "source/fuzz/fuzzer_util.h"
+#include "source/opt/instruction.h"
+#include "source/util/make_unique.h"
+
+namespace spvtools {
+namespace fuzz {
+
+TransformationCopyObject::TransformationCopyObject(
+ const protobufs::TransformationCopyObject& message)
+ : message_(message) {}
+
+TransformationCopyObject::TransformationCopyObject(uint32_t object,
+ uint32_t base_instruction_id,
+ uint32_t offset,
+ uint32_t fresh_id) {
+ message_.set_object(object);
+ message_.set_base_instruction_id(base_instruction_id);
+ message_.set_offset(offset);
+ message_.set_fresh_id(fresh_id);
+}
+
+bool TransformationCopyObject::IsApplicable(
+ opt::IRContext* context, const FactManager& /*fact_manager*/) const {
+ if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) {
+ // We require the id for the object copy to be unused.
+ return false;
+ }
+ // The instruction whose result id is |message_.object| must exist.
+ auto object_inst = context->get_def_use_mgr()->GetDef(message_.object());
+ if (!object_inst) {
+ return false;
+ }
+ if (!object_inst->type_id()) {
+ // We can only apply OpCopyObject to instructions that have types.
+ return false;
+ }
+ if (!context->get_decoration_mgr()
+ ->GetDecorationsFor(message_.object(), true)
+ .empty()) {
+ // We do not copy objects that have decorations: if the copy is not
+ // decorated analogously, using the original object vs. its copy may not be
+ // equivalent.
+ // TODO(afd): it would be possible to make the copy but not add an id
+ // synonym.
+ return false;
+ }
+
+ auto base_instruction =
+ context->get_def_use_mgr()->GetDef(message_.base_instruction_id());
+ if (!base_instruction) {
+ // The given id to insert after is not defined.
+ return false;
+ }
+
+ auto destination_block = context->get_instr_block(base_instruction);
+ if (!destination_block) {
+ // The given id to insert after is not in a block.
+ return false;
+ }
+
+ auto insert_before = fuzzerutil::GetIteratorForBaseInstructionAndOffset(
+ destination_block, base_instruction, message_.offset());
+
+ if (insert_before == destination_block->end()) {
+ // The offset was inappropriate.
+ return false;
+ }
+ if (insert_before->PreviousNode() &&
+ (insert_before->PreviousNode()->opcode() == SpvOpLoopMerge ||
+ insert_before->PreviousNode()->opcode() == SpvOpSelectionMerge)) {
+ // We cannot insert a copy directly after a merge instruction.
+ return false;
+ }
+ if (insert_before->opcode() == SpvOpVariable) {
+ // We cannot insert a copy directly before a variable; variables in a
+ // function must be contiguous in the entry block.
+ return false;
+ }
+ // We cannot insert a copy directly before OpPhi, because OpPhi instructions
+ // need to be contiguous at the start of a block.
+ if (insert_before->opcode() == SpvOpPhi) {
+ return false;
+ }
+ // |message_.object| must be available at the point where we want to add the
+ // copy. It is available if it is at global scope (in which case it has no
+ // block), or if it dominates the point of insertion but is different from the
+ // point of insertion.
+ //
+ // The reason why the object needs to be different from the insertion point is
+ // that the copy will be added *before* this point, and we do not want to
+ // insert it before the object's defining instruction.
+ return !context->get_instr_block(object_inst) ||
+ (object_inst != &*insert_before &&
+ context->GetDominatorAnalysis(destination_block->GetParent())
+ ->Dominates(object_inst, &*insert_before));
+}
+
+void TransformationCopyObject::Apply(opt::IRContext* context,
+ FactManager* fact_manager) const {
+ // - A new instruction,
+ // %|message_.fresh_id| = OpCopyObject %ty %|message_.object|
+ // is added directly before the instruction |message_.offset| instructions
+ // after |message_.base_instruction_id| (%ty is the type of |message_.object|).
+ // - The fact that |message_.fresh_id| and |message_.object| are synonyms
+ // is added to the fact manager.
+ // The instruction whose result id is |message_.object| must exist.
+ auto object_inst = context->get_def_use_mgr()->GetDef(message_.object());
+ assert(object_inst && "The object to be copied must exist.");
+ auto base_instruction =
+ context->get_def_use_mgr()->GetDef(message_.base_instruction_id());
+ assert(base_instruction && "The base instruction must exist.");
+ auto destination_block = context->get_instr_block(base_instruction);
+ assert(destination_block && "The base instruction must be in a block.");
+ auto insert_before = fuzzerutil::GetIteratorForBaseInstructionAndOffset(
+ destination_block, base_instruction, message_.offset());
+ assert(insert_before != destination_block->end() &&
+ "There must be an instruction before which the copy can be inserted.");
+
+ opt::Instruction::OperandList operands = {
+ {SPV_OPERAND_TYPE_ID, {message_.object()}}};
+ insert_before->InsertBefore(MakeUnique<opt::Instruction>(
+ context, SpvOp::SpvOpCopyObject, object_inst->type_id(),
+ message_.fresh_id(), operands));
+
+ fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id());
+ context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone);
+
+ protobufs::Fact fact;
+ fact.mutable_id_synonym_fact()->set_id(message_.object());
+ fact.mutable_id_synonym_fact()->mutable_data_descriptor()->set_object(
+ message_.fresh_id());
+ fact_manager->AddFact(fact, context);
+}
+
+protobufs::Transformation TransformationCopyObject::ToMessage() const {
+ protobufs::Transformation result;
+ *result.mutable_copy_object() = message_;
+ return result;
+}
+
+} // namespace fuzz
+} // namespace spvtools
diff --git a/source/fuzz/transformation_copy_object.h b/source/fuzz/transformation_copy_object.h
new file mode 100644
index 0000000..6ce72df
--- /dev/null
+++ b/source/fuzz/transformation_copy_object.h
@@ -0,0 +1,68 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_FUZZ_TRANSFORMATION_COPY_OBJECT_H_
+#define SOURCE_FUZZ_TRANSFORMATION_COPY_OBJECT_H_
+
+#include "source/fuzz/fact_manager.h"
+#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
+#include "source/fuzz/transformation.h"
+#include "source/opt/ir_context.h"
+
+namespace spvtools {
+namespace fuzz {
+
+class TransformationCopyObject : public Transformation {
+ public:
+ explicit TransformationCopyObject(
+ const protobufs::TransformationCopyObject& message);
+
+ TransformationCopyObject(uint32_t object, uint32_t base_instruction_id,
+ uint32_t offset, uint32_t fresh_id);
+
+ // - |message_.fresh_id| must not be used by the module.
+ // - |message_.object| must be a result id that is a legitimate operand for
+ // OpCopyObject. In particular, it must be the id of an instruction that
+ // has a result type
+ // - |message_.object| must not be the target of any decoration.
+ // TODO(afd): consider copying decorations along with objects.
+ // - |message_.base_instruction_id| must be the result id of an instruction
+ // 'base' in some block 'blk'.
+ // - 'blk' must contain an instruction 'inst' located |message_.offset|
+ // instructions after 'base' (if |message_.offset| = 0 then 'inst' =
+ // 'base').
+ // - It must be legal to insert an OpCopyObject instruction directly
+ // before 'inst'.
+ // - |message_.object| must be available directly before 'inst'.
+ bool IsApplicable(opt::IRContext* context,
+ const FactManager& fact_manager) const override;
+
+ // - A new instruction,
+ // %|message_.fresh_id| = OpCopyObject %ty %|message_.object|
+ // is added directly before the instruction |message_.offset| instructions
+ // after |message_.base_instruction_id| (%ty is the type of |message_.object|).
+ // - The fact that |message_.fresh_id| and |message_.object| are synonyms
+ // is added to the fact manager.
+ void Apply(opt::IRContext* context, FactManager* fact_manager) const override;
+
+ protobufs::Transformation ToMessage() const override;
+
+ private:
+ protobufs::TransformationCopyObject message_;
+};
+
+} // namespace fuzz
+} // namespace spvtools
+
+#endif // SOURCE_FUZZ_TRANSFORMATION_COPY_OBJECT_H_
diff --git a/source/fuzz/transformation_split_block.cpp b/source/fuzz/transformation_split_block.cpp
index a8c33de..a2da371 100644
--- a/source/fuzz/transformation_split_block.cpp
+++ b/source/fuzz/transformation_split_block.cpp
@@ -26,147 +26,104 @@
const spvtools::fuzz::protobufs::TransformationSplitBlock& message)
: message_(message) {}
-TransformationSplitBlock::TransformationSplitBlock(uint32_t result_id,
+TransformationSplitBlock::TransformationSplitBlock(uint32_t base_instruction_id,
uint32_t offset,
uint32_t fresh_id) {
- message_.set_result_id(result_id);
+ message_.set_base_instruction_id(base_instruction_id);
message_.set_offset(offset);
message_.set_fresh_id(fresh_id);
}
-std::pair<bool, opt::BasicBlock::iterator>
-TransformationSplitBlock::FindInstToSplitBefore(opt::BasicBlock* block) const {
- // There are three possibilities:
- // (1) the transformation wants to split at some offset from the block's
- // label.
- // (2) the transformation wants to split at some offset from a
- // non-label instruction inside the block.
- // (3) the split assocaiated with this transformation has nothing to do with
- // this block
- if (message_.result_id() == block->id()) {
- // Case (1).
- if (message_.offset() == 0) {
- // The offset is not allowed to be 0: this would mean splitting before the
- // block's label.
- // By returning (true, block->end()), we indicate that we did find the
- // instruction (so that it is not worth searching further for it), but
- // that splitting will not be possible.
- return {true, block->end()};
- }
- // Conceptually, the first instruction in the block is [label + 1].
- // We thus start from 1 when applying the offset.
- auto inst_it = block->begin();
- for (uint32_t i = 1; i < message_.offset() && inst_it != block->end();
- i++) {
- ++inst_it;
- }
- // This is either the desired instruction, or the end of the block.
- return {true, inst_it};
- }
- for (auto inst_it = block->begin(); inst_it != block->end(); ++inst_it) {
- if (message_.result_id() == inst_it->result_id()) {
- // Case (2): we have found the base instruction; we now apply the offset.
- for (uint32_t i = 0; i < message_.offset() && inst_it != block->end();
- i++) {
- ++inst_it;
- }
- // This is either the desired instruction, or the end of the block.
- return {true, inst_it};
- }
- }
- // Case (3).
- return {false, block->end()};
-}
-
bool TransformationSplitBlock::IsApplicable(
opt::IRContext* context, const FactManager& /*unused*/) const {
if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) {
// We require the id for the new block to be unused.
return false;
}
- // Consider every block in every function.
- for (auto& function : *context->module()) {
- for (auto& block : function) {
- auto maybe_split_before = FindInstToSplitBefore(&block);
- if (!maybe_split_before.first) {
- continue;
- }
- if (maybe_split_before.second == block.end()) {
- // The base instruction was found, but the offset was inappropriate.
- return false;
- }
- if (block.IsLoopHeader()) {
- // We cannot split a loop header block: back-edges would become invalid.
- return false;
- }
- auto split_before = maybe_split_before.second;
- if (split_before->PreviousNode() &&
- split_before->PreviousNode()->opcode() == SpvOpSelectionMerge) {
- // We cannot split directly after a selection merge: this would separate
- // the merge from its associated branch or switch operation.
- return false;
- }
- if (split_before->opcode() == SpvOpVariable) {
- // We cannot split directly after a variable; variables in a function
- // must be contiguous in the entry block.
- return false;
- }
- if (split_before->opcode() == SpvOpPhi &&
- split_before->NumInOperands() != 2) {
- // We cannot split before an OpPhi unless the OpPhi has exactly one
- // associated incoming edge.
- return false;
- }
- return true;
- }
+ auto base_instruction =
+ context->get_def_use_mgr()->GetDef(message_.base_instruction_id());
+ if (!base_instruction) {
+ // The instruction describing the block we should split does not exist.
+ return false;
}
- return false;
+ auto block_containing_base_instruction =
+ context->get_instr_block(base_instruction);
+ if (!block_containing_base_instruction) {
+ // The instruction describing the block we should split is not contained in
+ // a block.
+ return false;
+ }
+
+ if (block_containing_base_instruction->IsLoopHeader()) {
+ // We cannot split a loop header block: back-edges would become invalid.
+ return false;
+ }
+
+ auto split_before = fuzzerutil::GetIteratorForBaseInstructionAndOffset(
+ block_containing_base_instruction, base_instruction, message_.offset());
+ if (split_before == block_containing_base_instruction->end()) {
+ // The offset was inappropriate.
+ return false;
+ }
+ if (split_before->PreviousNode() &&
+ split_before->PreviousNode()->opcode() == SpvOpSelectionMerge) {
+ // We cannot split directly after a selection merge: this would separate
+ // the merge from its associated branch or switch operation.
+ return false;
+ }
+ if (split_before->opcode() == SpvOpVariable) {
+ // We cannot split directly after a variable; variables in a function
+ // must be contiguous in the entry block.
+ return false;
+ }
+ // We cannot split before an OpPhi unless the OpPhi has exactly one
+ // associated incoming edge.
+ return !(split_before->opcode() == SpvOpPhi &&
+ split_before->NumInOperands() != 2);
}
void TransformationSplitBlock::Apply(opt::IRContext* context,
FactManager* /*unused*/) const {
- for (auto& function : *context->module()) {
- for (auto& block : function) {
- auto maybe_split_before = FindInstToSplitBefore(&block);
- if (!maybe_split_before.first) {
- continue;
- }
- assert(maybe_split_before.second != block.end() &&
- "If the transformation is applicable, we should have an "
- "instruction to split on.");
- // We need to make sure the module's id bound is large enough to add the
- // fresh id.
- fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id());
- // Split the block.
- auto new_bb = block.SplitBasicBlock(context, message_.fresh_id(),
- maybe_split_before.second);
- // The split does not automatically add a branch between the two parts of
- // the original block, so we add one.
- block.AddInstruction(MakeUnique<opt::Instruction>(
+ auto base_instruction =
+ context->get_def_use_mgr()->GetDef(message_.base_instruction_id());
+ assert(base_instruction && "Base instruction must exist");
+ auto block_containing_base_instruction =
+ context->get_instr_block(base_instruction);
+ assert(block_containing_base_instruction &&
+ "Base instruction must be in a block");
+ auto split_before = fuzzerutil::GetIteratorForBaseInstructionAndOffset(
+ block_containing_base_instruction, base_instruction, message_.offset());
+ assert(split_before != block_containing_base_instruction->end() &&
+ "If the transformation is applicable, we should have an "
+ "instruction to split on.");
+ // We need to make sure the module's id bound is large enough to add the
+ // fresh id.
+ fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id());
+ // Split the block.
+ auto new_bb = block_containing_base_instruction->SplitBasicBlock(
+ context, message_.fresh_id(), split_before);
+ // The split does not automatically add a branch between the two parts of
+ // the original block, so we add one.
+ block_containing_base_instruction->AddInstruction(
+ MakeUnique<opt::Instruction>(
context, SpvOpBranch, 0, 0,
std::initializer_list<opt::Operand>{
opt::Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
{message_.fresh_id()})}));
- // If we split before OpPhi instructions, we need to update their
- // predecessor operand so that the block they used to be inside is now the
- // predecessor.
- new_bb->ForEachPhiInst([&block](opt::Instruction* phi_inst) {
+ // If we split before OpPhi instructions, we need to update their
+ // predecessor operand so that the block they used to be inside is now the
+ // predecessor.
+ new_bb->ForEachPhiInst(
+ [block_containing_base_instruction](opt::Instruction* phi_inst) {
// The following assertion is a sanity check. It is guaranteed to hold
// if IsApplicable holds.
assert(phi_inst->NumInOperands() == 2 &&
"We can only split a block before an OpPhi if block has exactly "
"one predecessor.");
- phi_inst->SetInOperand(1, {block.id()});
+ phi_inst->SetInOperand(1, {block_containing_base_instruction->id()});
});
- // Invalidate all analyses
- context->InvalidateAnalysesExceptFor(
- opt::IRContext::Analysis::kAnalysisNone);
- return;
- }
- }
- assert(0 &&
- "Should be unreachable: it should have been possible to apply this "
- "transformation.");
+ // Invalidate all analyses
+ context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone);
}
protobufs::Transformation TransformationSplitBlock::ToMessage() const {
diff --git a/source/fuzz/transformation_split_block.h b/source/fuzz/transformation_split_block.h
index ef4aa75..4a7095a 100644
--- a/source/fuzz/transformation_split_block.h
+++ b/source/fuzz/transformation_split_block.h
@@ -28,13 +28,13 @@
explicit TransformationSplitBlock(
const protobufs::TransformationSplitBlock& message);
- TransformationSplitBlock(uint32_t result_id, uint32_t offset,
+ TransformationSplitBlock(uint32_t base_instruction_id, uint32_t offset,
uint32_t fresh_id);
- // - |message_.result_id| must be the result id of an instruction 'base' in
- // some block 'blk'.
+ // - |message_.base_instruction_id| must be the result id of an instruction
+ // 'base' in some block 'blk'.
// - 'blk' must contain an instruction 'inst' located |message_.offset|
- // instructions after 'inst' (if |message_.offset| = 0 then 'inst' =
+ // instructions after 'base' (if |message_.offset| = 0 then 'inst' =
// 'base').
// - Splitting 'blk' at 'inst', so that all instructions from 'inst' onwards
// appear in a new block that 'blk' directly jumps to must be valid.
@@ -52,14 +52,6 @@
protobufs::Transformation ToMessage() const override;
private:
- // Returns:
- // - (true, block->end()) if the relevant instruction is in this block
- // but inapplicable
- // - (true, it) if 'it' is an iterator for the relevant instruction
- // - (false, _) otherwise.
- std::pair<bool, opt::BasicBlock::iterator> FindInstToSplitBefore(
- opt::BasicBlock* block) const;
-
protobufs::TransformationSplitBlock message_;
};
diff --git a/source/fuzz/uniform_buffer_element_descriptor.cpp b/source/fuzz/uniform_buffer_element_descriptor.cpp
index 8c758e4..90fd85e 100644
--- a/source/fuzz/uniform_buffer_element_descriptor.cpp
+++ b/source/fuzz/uniform_buffer_element_descriptor.cpp
@@ -14,7 +14,7 @@
#include "source/fuzz/uniform_buffer_element_descriptor.h"
-#include <source/opt/instruction.h>
+#include <algorithm>
namespace spvtools {
namespace fuzz {
diff --git a/source/fuzz/uniform_buffer_element_descriptor.h b/source/fuzz/uniform_buffer_element_descriptor.h
index 23a16f0..d35de57 100644
--- a/source/fuzz/uniform_buffer_element_descriptor.h
+++ b/source/fuzz/uniform_buffer_element_descriptor.h
@@ -15,7 +15,6 @@
#ifndef SOURCE_FUZZ_UNIFORM_BUFFER_ELEMENT_DESCRIPTOR_H_
#define SOURCE_FUZZ_UNIFORM_BUFFER_ELEMENT_DESCRIPTOR_H_
-#include <algorithm>
#include <vector>
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
@@ -25,8 +24,8 @@
namespace spvtools {
namespace fuzz {
-// Factory method to create a uniform buffer element descriptor message from an
-// id and list of indices.
+// Factory method to create a uniform buffer element descriptor message from
+// descriptor set and binding ids and a list of indices.
protobufs::UniformBufferElementDescriptor MakeUniformBufferElementDescriptor(
uint32_t descriptor_set, uint32_t binding, std::vector<uint32_t>&& indices);
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index 3e1280e..2309ca9 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -13,6 +13,7 @@
# limitations under the License.
set(SPIRV_TOOLS_OPT_SOURCES
aggressive_dead_code_elim_pass.h
+ amd_ext_to_khr.h
basic_block.h
block_merge_pass.h
block_merge_util.h
@@ -33,6 +34,7 @@
decompose_initialized_variables_pass.h
decoration_manager.h
def_use_manager.h
+ desc_sroa.h
dominator_analysis.h
dominator_tree.h
eliminate_dead_constant_pass.h
@@ -48,11 +50,13 @@
freeze_spec_constant_value_pass.h
function.h
generate_webgpu_initializers_pass.h
+ graphics_robust_access_pass.h
if_conversion.h
inline_exhaustive_pass.h
inline_opaque_pass.h
inline_pass.h
inst_bindless_check_pass.h
+ inst_buff_addr_check_pass.h
instruction.h
instruction_list.h
instrument_pass.h
@@ -111,8 +115,10 @@
value_number_table.h
vector_dce.h
workaround1209.h
+ wrap_opkill.h
aggressive_dead_code_elim_pass.cpp
+ amd_ext_to_khr.cpp
basic_block.cpp
block_merge_pass.cpp
block_merge_util.cpp
@@ -133,6 +139,7 @@
decompose_initialized_variables_pass.cpp
decoration_manager.cpp
def_use_manager.cpp
+ desc_sroa.cpp
dominator_analysis.cpp
dominator_tree.cpp
eliminate_dead_constant_pass.cpp
@@ -147,12 +154,14 @@
fold_spec_constant_op_and_composite_pass.cpp
freeze_spec_constant_value_pass.cpp
function.cpp
+ graphics_robust_access_pass.cpp
generate_webgpu_initializers_pass.cpp
if_conversion.cpp
inline_exhaustive_pass.cpp
inline_opaque_pass.cpp
inline_pass.cpp
inst_bindless_check_pass.cpp
+ inst_buff_addr_check_pass.cpp
instruction.cpp
instruction_list.cpp
instrument_pass.cpp
@@ -208,6 +217,7 @@
value_number_table.cpp
vector_dce.cpp
workaround1209.cpp
+ wrap_opkill.cpp
)
if(MSVC)
diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp
index 11a9574..761ff7c 100644
--- a/source/opt/aggressive_dead_code_elim_pass.cpp
+++ b/source/opt/aggressive_dead_code_elim_pass.cpp
@@ -664,6 +664,9 @@
// been marked, it is safe to remove dead global values.
modified |= ProcessGlobalValues();
+ // Sanity check.
+ assert(to_kill_.size() == 0 || modified);
+
// Kill all dead instructions.
for (auto inst : to_kill_) {
context()->KillInst(inst);
@@ -836,7 +839,17 @@
// attributes here.
for (auto& val : get_module()->types_values()) {
if (IsDead(&val)) {
+ // Save forwarded pointer if pointer is live since closure does not mark
+ // this live as it does not have a result id. This is a little too
+ // conservative since it is not known if the structure type that needed
+ // it is still live. TODO(greg-lunarg): Only save if needed.
+ if (val.opcode() == SpvOpTypeForwardPointer) {
+ uint32_t ptr_ty_id = val.GetSingleWordInOperand(0);
+ Instruction* ptr_ty_inst = get_def_use_mgr()->GetDef(ptr_ty_id);
+ if (!IsDead(ptr_ty_inst)) continue;
+ }
to_kill_.push_back(&val);
+ modified = true;
}
}
@@ -918,6 +931,7 @@
"SPV_NV_mesh_shader",
"SPV_NV_ray_tracing",
"SPV_EXT_fragment_invocation_density",
+ "SPV_EXT_physical_storage_buffer",
});
}
diff --git a/source/opt/amd_ext_to_khr.cpp b/source/opt/amd_ext_to_khr.cpp
new file mode 100644
index 0000000..1cb5ba5
--- /dev/null
+++ b/source/opt/amd_ext_to_khr.cpp
@@ -0,0 +1,539 @@
+// Copyright (c) 2019 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/amd_ext_to_khr.h"
+
+#include "ir_builder.h"
+#include "source/opt/ir_context.h"
+#include "spv-amd-shader-ballot.insts.inc"
+#include "type_manager.h"
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+
+enum ExtOpcodes {
+ AmdShaderBallotSwizzleInvocationsAMD = 1,
+ AmdShaderBallotSwizzleInvocationsMaskedAMD = 2,
+ AmdShaderBallotWriteInvocationAMD = 3,
+ AmdShaderBallotMbcntAMD = 4
+};
+
+analysis::Type* GetUIntType(IRContext* ctx) {
+ analysis::Integer int_type(32, false);
+ return ctx->get_type_mgr()->GetRegisteredType(&int_type);
+}
+
+// Returns a folding rule that will replace the opcode with |opcode| and add
+// the capabilities required. The folding rule assumes it is folding an
+// OpGroup*NonUniformAMD instruction from the SPV_AMD_shader_ballot extension.
+FoldingRule ReplaceGroupNonuniformOperationOpCode(SpvOp new_opcode) {
+ switch (new_opcode) {
+ case SpvOpGroupNonUniformIAdd:
+ case SpvOpGroupNonUniformFAdd:
+ case SpvOpGroupNonUniformUMin:
+ case SpvOpGroupNonUniformSMin:
+ case SpvOpGroupNonUniformFMin:
+ case SpvOpGroupNonUniformUMax:
+ case SpvOpGroupNonUniformSMax:
+ case SpvOpGroupNonUniformFMax:
+ break;
+ default:
+ assert(
+ false &&
+ "Should be replacing with a group non uniform arithmetic operation.");
+ }
+
+ return [new_opcode](IRContext* ctx, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ switch (inst->opcode()) {
+ case SpvOpGroupIAddNonUniformAMD:
+ case SpvOpGroupFAddNonUniformAMD:
+ case SpvOpGroupUMinNonUniformAMD:
+ case SpvOpGroupSMinNonUniformAMD:
+ case SpvOpGroupFMinNonUniformAMD:
+ case SpvOpGroupUMaxNonUniformAMD:
+ case SpvOpGroupSMaxNonUniformAMD:
+ case SpvOpGroupFMaxNonUniformAMD:
+ break;
+ default:
+ assert(false &&
+ "Should be replacing a group non uniform arithmetic operation.");
+ }
+
+ ctx->AddCapability(SpvCapabilityGroupNonUniformArithmetic);
+ inst->SetOpcode(new_opcode);
+ return true;
+ };
+}
+
+// Returns a folding rule that will replace the SwizzleInvocationsAMD extended
+// instruction in the SPV_AMD_shader_ballot extension.
+//
+// The instruction
+//
+// %offset = OpConstantComposite %v3uint %x %y %z %w
+// %result = OpExtInst %type %1 SwizzleInvocationsAMD %data %offset
+//
+// is replaced with
+//
+// potentially new constants and types
+//
+// clang-format off
+// %uint_max = OpConstant %uint 0xFFFFFFFF
+// %v4uint = OpTypeVector %uint 4
+// %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
+// %null = OpConstantNull %type
+// clang-format on
+//
+// and the following code in the function body
+//
+// clang-format off
+// %id = OpLoad %uint %SubgroupLocalInvocationId
+// %quad_idx = OpBitwiseAnd %uint %id %uint_3
+// %quad_ldr = OpBitwiseXor %uint %id %quad_idx
+// %my_offset = OpVectorExtractDynamic %uint %offset %quad_idx
+// %target_inv = OpIAdd %uint %quad_ldr %my_offset
+// %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
+// %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
+// %result = OpSelect %type %is_active %shuffle %null
+// clang-format on
+//
+// Also adding the capabilities and builtins that are needed.
+FoldingRule ReplaceSwizzleInvocations() {
+ return [](IRContext* ctx, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ analysis::TypeManager* type_mgr = ctx->get_type_mgr();
+ analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
+
+ ctx->AddExtension("SPV_KHR_shader_ballot");
+ ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
+ ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
+
+ InstructionBuilder ir_builder(
+ ctx, inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+ uint32_t data_id = inst->GetSingleWordInOperand(2);
+ uint32_t offset_id = inst->GetSingleWordInOperand(3);
+
+ // Get the subgroup invocation id.
+ uint32_t var_id =
+ ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
+ assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
+ Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
+ Instruction* var_ptr_type =
+ ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
+ uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
+
+ Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
+
+ uint32_t quad_mask = ir_builder.GetUintConstantId(3);
+
+ // This gives the offset in the group of 4 of this invocation.
+ Instruction* quad_idx = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpBitwiseAnd, id->result_id(), quad_mask);
+
+ // Get the invocation id of the first invocation in the group of 4.
+ Instruction* quad_ldr = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpBitwiseXor, id->result_id(), quad_idx->result_id());
+
+ // Get the offset of the target invocation from the offset vector.
+ Instruction* my_offset =
+ ir_builder.AddBinaryOp(uint_type_id, SpvOpVectorExtractDynamic,
+ offset_id, quad_idx->result_id());
+
+ // Determine the index of the invocation to read from.
+ Instruction* target_inv = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpIAdd, quad_ldr->result_id(), my_offset->result_id());
+
+ // Do the group operations
+ uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
+ uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
+ const auto* ballot_value_const = const_mgr->GetConstant(
+ type_mgr->GetUIntVectorType(4),
+ {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
+ Instruction* ballot_value =
+ const_mgr->GetDefiningInstruction(ballot_value_const);
+ Instruction* is_active = ir_builder.AddNaryOp(
+ type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
+ {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
+ Instruction* shuffle = ir_builder.AddNaryOp(
+ inst->type_id(), SpvOpGroupNonUniformShuffle,
+ {subgroup_scope, data_id, target_inv->result_id()});
+
+ // Create the null constant to use in the select.
+ const auto* null = const_mgr->GetConstant(
+ type_mgr->GetType(inst->type_id()), std::vector<uint32_t>());
+ Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
+
+ // Build the select.
+ inst->SetOpcode(SpvOpSelect);
+ Instruction::OperandList new_operands;
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
+
+ inst->SetInOperands(std::move(new_operands));
+ ctx->UpdateDefUse(inst);
+ return true;
+ };
+}
+
+// Returns a folding rule that will replace the SwizzleInvocationsMaskedAMD
+// extended instruction in the SPV_AMD_shader_ballot extension.
+//
+// The instruction
+//
+// %mask = OpConstantComposite %v3uint %uint_x %uint_y %uint_z
+// %result = OpExtInst %uint %1 SwizzleInvocationsMaskedAMD %data %mask
+//
+// is replaced with
+//
+// potentially new constants and types
+//
+// clang-format off
+// %uint_mask_extend = OpConstant %uint 0xFFFFFFE0
+// %uint_max = OpConstant %uint 0xFFFFFFFF
+// %v4uint = OpTypeVector %uint 4
+// %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
+// clang-format on
+//
+// and the following code in the function body
+//
+// clang-format off
+// %id = OpLoad %uint %SubgroupLocalInvocationId
+// %and_mask = OpBitwiseOr %uint %uint_x %uint_mask_extend
+// %and = OpBitwiseAnd %uint %id %and_mask
+// %or = OpBitwiseOr %uint %and %uint_y
+// %target_inv = OpBitwiseXor %uint %or %uint_z
+// %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
+// %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
+// %result = OpSelect %type %is_active %shuffle %uint_0
+// clang-format on
+//
+// Also adding the capabilities and builtins that are needed.
+FoldingRule ReplaceSwizzleInvocationsMasked() {
+ return [](IRContext* ctx, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ analysis::TypeManager* type_mgr = ctx->get_type_mgr();
+ analysis::DefUseManager* def_use_mgr = ctx->get_def_use_mgr();
+ analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
+
+ // ctx->AddCapability(SpvCapabilitySubgroupBallotKHR);
+ ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
+ ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
+
+ InstructionBuilder ir_builder(
+ ctx, inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+ // Get the operands to inst, and the components of the mask
+ uint32_t data_id = inst->GetSingleWordInOperand(2);
+
+ Instruction* mask_inst =
+ def_use_mgr->GetDef(inst->GetSingleWordInOperand(3));
+ assert(mask_inst->opcode() == SpvOpConstantComposite &&
+ "The mask is suppose to be a vector constant.");
+ assert(mask_inst->NumInOperands() == 3 &&
+ "The mask is suppose to have 3 components.");
+
+ uint32_t uint_x = mask_inst->GetSingleWordInOperand(0);
+ uint32_t uint_y = mask_inst->GetSingleWordInOperand(1);
+ uint32_t uint_z = mask_inst->GetSingleWordInOperand(2);
+
+ // Get the subgroup invocation id.
+ uint32_t var_id =
+ ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
+ ctx->AddExtension("SPV_KHR_shader_ballot");
+ assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
+ Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
+ Instruction* var_ptr_type =
+ ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
+ uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
+
+ Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
+
+ // Do the bitwise operations.
+ uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
+ Instruction* and_mask = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseOr,
+ uint_x, mask_extended);
+ Instruction* and_result = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpBitwiseAnd, id->result_id(), and_mask->result_id());
+ Instruction* or_result = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpBitwiseOr, and_result->result_id(), uint_y);
+ Instruction* target_inv = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpBitwiseXor, or_result->result_id(), uint_z);
+
+ // Do the group operations
+ uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
+ uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
+ const auto* ballot_value_const = const_mgr->GetConstant(
+ type_mgr->GetUIntVectorType(4),
+ {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
+ Instruction* ballot_value =
+ const_mgr->GetDefiningInstruction(ballot_value_const);
+ Instruction* is_active = ir_builder.AddNaryOp(
+ type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
+ {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
+ Instruction* shuffle = ir_builder.AddNaryOp(
+ inst->type_id(), SpvOpGroupNonUniformShuffle,
+ {subgroup_scope, data_id, target_inv->result_id()});
+
+ // Create the null constant to use in the select.
+ const auto* null = const_mgr->GetConstant(
+ type_mgr->GetType(inst->type_id()), std::vector<uint32_t>());
+ Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
+
+ // Build the select.
+ inst->SetOpcode(SpvOpSelect);
+ Instruction::OperandList new_operands;
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
+
+ inst->SetInOperands(std::move(new_operands));
+ ctx->UpdateDefUse(inst);
+ return true;
+ };
+}
+
+// Returns a folding rule that will replace the WriteInvocationAMD extended
+// instruction in the SPV_AMD_shader_ballot extension.
+//
+// The instruction
+//
+// clang-format off
+// %result = OpExtInst %type %1 WriteInvocationAMD %input_value %write_value %invocation_index
+// clang-format on
+//
+// with
+//
+// %id = OpLoad %uint %SubgroupLocalInvocationId
+// %cmp = OpIEqual %bool %id %invocation_index
+// %result = OpSelect %type %cmp %write_value %input_value
+//
+// Also adding the capabilities and builtins that are needed.
+FoldingRule ReplaceWriteInvocation() {
+ return [](IRContext* ctx, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ uint32_t var_id =
+ ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
+ ctx->AddCapability(SpvCapabilitySubgroupBallotKHR);
+ ctx->AddExtension("SPV_KHR_shader_ballot");
+ assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
+ Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
+ Instruction* var_ptr_type =
+ ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
+
+ InstructionBuilder ir_builder(
+ ctx, inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+ Instruction* t =
+ ir_builder.AddLoad(var_ptr_type->GetSingleWordInOperand(1), var_id);
+ analysis::Bool bool_type;
+ uint32_t bool_type_id = ctx->get_type_mgr()->GetTypeInstruction(&bool_type);
+ Instruction* cmp =
+ ir_builder.AddBinaryOp(bool_type_id, SpvOpIEqual, t->result_id(),
+ inst->GetSingleWordInOperand(4));
+
+ // Build a select.
+ inst->SetOpcode(SpvOpSelect);
+ Instruction::OperandList new_operands;
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {cmp->result_id()}});
+ new_operands.push_back(inst->GetInOperand(3));
+ new_operands.push_back(inst->GetInOperand(2));
+
+ inst->SetInOperands(std::move(new_operands));
+ ctx->UpdateDefUse(inst);
+ return true;
+ };
+}
+
+// Returns a folding rule that will replace the MbcntAMD extended instruction in
+// the SPV_AMD_shader_ballot extension.
+//
+// The instruction
+//
+// %result = OpExtInst %uint %1 MbcntAMD %mask
+//
+// with
+//
+// Get SubgroupLtMask and convert the first 64-bits into a uint64_t because
+// AMD's shader compiler expects a 64-bit integer mask.
+//
+// %var = OpLoad %v4uint %SubgroupLtMaskKHR
+// %shuffle = OpVectorShuffle %v2uint %var %var 0 1
+// %cast = OpBitcast %ulong %shuffle
+//
+// Perform the mask and count the bits.
+//
+// %and = OpBitwiseAnd %ulong %cast %mask
+// %result = OpBitCount %uint %and
+//
+// Also adding the capabilities and builtins that are needed.
+FoldingRule ReplaceMbcnt() {
+ return [](IRContext* context, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ analysis::TypeManager* type_mgr = context->get_type_mgr();
+ analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+
+ uint32_t var_id = context->GetBuiltinInputVarId(SpvBuiltInSubgroupLtMask);
+ assert(var_id != 0 && "Could not get SubgroupLtMask variable.");
+ context->AddCapability(SpvCapabilityGroupNonUniformBallot);
+ Instruction* var_inst = def_use_mgr->GetDef(var_id);
+ Instruction* var_ptr_type = def_use_mgr->GetDef(var_inst->type_id());
+ Instruction* var_type =
+ def_use_mgr->GetDef(var_ptr_type->GetSingleWordInOperand(1));
+ assert(var_type->opcode() == SpvOpTypeVector &&
+ "Variable is suppose to be a vector of 4 ints");
+
+ // Get the type for the shuffle.
+ analysis::Vector temp_type(GetUIntType(context), 2);
+ const analysis::Type* shuffle_type =
+ context->get_type_mgr()->GetRegisteredType(&temp_type);
+ uint32_t shuffle_type_id = type_mgr->GetTypeInstruction(shuffle_type);
+
+ uint32_t mask_id = inst->GetSingleWordInOperand(2);
+ Instruction* mask_inst = def_use_mgr->GetDef(mask_id);
+
+ // Testing with amd's shader compiler shows that a 64-bit mask is expected.
+ assert(type_mgr->GetType(mask_inst->type_id())->AsInteger() != nullptr);
+ assert(type_mgr->GetType(mask_inst->type_id())->AsInteger()->width() == 64);
+
+ InstructionBuilder ir_builder(
+ context, inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+ Instruction* load = ir_builder.AddLoad(var_type->result_id(), var_id);
+ Instruction* shuffle = ir_builder.AddVectorShuffle(
+ shuffle_type_id, load->result_id(), load->result_id(), {0, 1});
+ Instruction* bitcast = ir_builder.AddUnaryOp(
+ mask_inst->type_id(), SpvOpBitcast, shuffle->result_id());
+ Instruction* t = ir_builder.AddBinaryOp(
+ mask_inst->type_id(), SpvOpBitwiseAnd, bitcast->result_id(), mask_id);
+
+ inst->SetOpcode(SpvOpBitCount);
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}});
+ context->UpdateDefUse(inst);
+ return true;
+ };
+}
+
+class AmdExtFoldingRules : public FoldingRules {
+ public:
+ explicit AmdExtFoldingRules(IRContext* ctx) : FoldingRules(ctx) {}
+
+ protected:
+ virtual void AddFoldingRules() override {
+ rules_[SpvOpGroupIAddNonUniformAMD].push_back(
+ ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformIAdd));
+ rules_[SpvOpGroupFAddNonUniformAMD].push_back(
+ ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformFAdd));
+ rules_[SpvOpGroupUMinNonUniformAMD].push_back(
+ ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformUMin));
+ rules_[SpvOpGroupSMinNonUniformAMD].push_back(
+ ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformSMin));
+ rules_[SpvOpGroupFMinNonUniformAMD].push_back(
+ ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformFMin));
+ rules_[SpvOpGroupUMaxNonUniformAMD].push_back(
+ ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformUMax));
+ rules_[SpvOpGroupSMaxNonUniformAMD].push_back(
+ ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformSMax));
+ rules_[SpvOpGroupFMaxNonUniformAMD].push_back(
+ ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformFMax));
+
+ uint32_t extension_id =
+ context()->module()->GetExtInstImportId("SPV_AMD_shader_ballot");
+
+ ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}].push_back(
+ ReplaceSwizzleInvocations());
+ ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}]
+ .push_back(ReplaceSwizzleInvocationsMasked());
+ ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back(
+ ReplaceWriteInvocation());
+ ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back(
+ ReplaceMbcnt());
+ }
+};
+
+class AmdExtConstFoldingRules : public ConstantFoldingRules {
+ public:
+ AmdExtConstFoldingRules(IRContext* ctx) : ConstantFoldingRules(ctx) {}
+
+ protected:
+ virtual void AddFoldingRules() override {}
+};
+
+} // namespace
+
+Pass::Status AmdExtensionToKhrPass::Process() {
+ bool changed = false;
+
+ // Traverse the body of the functions to replace instructions that require
+ // the extensions.
+ InstructionFolder folder(
+ context(),
+ std::unique_ptr<AmdExtFoldingRules>(new AmdExtFoldingRules(context())),
+ MakeUnique<AmdExtConstFoldingRules>(context()));
+ for (Function& func : *get_module()) {
+ func.ForEachInst([&changed, &folder](Instruction* inst) {
+ if (folder.FoldInstruction(inst)) {
+ changed = true;
+ }
+ });
+ }
+
+  // Now that instructions that require the extensions have been removed, we
+  // can remove the extension instructions.
+ std::vector<Instruction*> to_be_killed;
+ for (Instruction& inst : context()->module()->extensions()) {
+ if (inst.opcode() == SpvOpExtension) {
+ if (!strcmp("SPV_AMD_shader_ballot",
+ reinterpret_cast<const char*>(
+ &(inst.GetInOperand(0).words[0])))) {
+ to_be_killed.push_back(&inst);
+ }
+ }
+ }
+
+ for (Instruction& inst : context()->ext_inst_imports()) {
+ if (inst.opcode() == SpvOpExtInstImport) {
+ if (!strcmp("SPV_AMD_shader_ballot",
+ reinterpret_cast<const char*>(
+ &(inst.GetInOperand(0).words[0])))) {
+ to_be_killed.push_back(&inst);
+ }
+ }
+ }
+
+ for (Instruction* inst : to_be_killed) {
+ context()->KillInst(inst);
+ changed = true;
+ }
+
+  // The replacements that take place use instructions that are missing before
+  // SPIR-V 1.3. If we changed something, we will have to make sure the version
+  // is at least SPIR-V 1.3 to make sure those instructions can be used.
+ if (changed) {
+ uint32_t version = get_module()->version();
+ if (version < 0x00010300 /*1.3*/) {
+ get_module()->set_version(0x00010300);
+ }
+ }
+ return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/amd_ext_to_khr.h b/source/opt/amd_ext_to_khr.h
new file mode 100644
index 0000000..fd3dab4
--- /dev/null
+++ b/source/opt/amd_ext_to_khr.h
@@ -0,0 +1,51 @@
+// Copyright (c) 2019 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_AMD_EXT_TO_KHR_H_
+#define SOURCE_OPT_AMD_EXT_TO_KHR_H_
+
+#include "source/opt/ir_context.h"
+#include "source/opt/module.h"
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// Replaces the extensions VK_AMD_shader_ballot, VK_AMD_gcn_shader, and
+// VK_AMD_shader_trinary_minmax with equivalent code using core instructions and
+// capabilities.
+class AmdExtensionToKhrPass : public Pass {
+ public:
+ const char* name() const override { return "amd-ext-to-khr"; }
+ Status Process() override;
+
+ IRContext::Analysis GetPreservedAnalyses() override {
+ return IRContext::kAnalysisInstrToBlockMapping |
+ IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators |
+ IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis |
+ IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap |
+ IRContext::kAnalysisScalarEvolution |
+ IRContext::kAnalysisRegisterPressure |
+ IRContext::kAnalysisValueNumberTable |
+ IRContext::kAnalysisStructuredCFG |
+ IRContext::kAnalysisBuiltinVarId |
+ IRContext::kAnalysisIdToFuncMapping | IRContext::kAnalysisTypes |
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisConstants;
+ }
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // SOURCE_OPT_AMD_EXT_TO_KHR_H_
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index 10fcde4..06a1a81 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -55,6 +55,9 @@
auto cc = c->AsCompositeConstant();
assert(cc != nullptr);
auto components = cc->GetComponents();
+ // Protect against invalid IR. Refuse to fold if the index is out
+ // of bounds.
+ if (element_index >= components.size()) return nullptr;
c = components[element_index];
}
return c;
@@ -806,9 +809,62 @@
};
}
+ConstantFoldingRule FoldFMix() {
+ return [](IRContext* context, Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants)
+ -> const analysis::Constant* {
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ assert(inst->opcode() == SpvOpExtInst &&
+ "Expecting an extended instruction.");
+ assert(inst->GetSingleWordInOperand(0) ==
+ context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
+ "Expecting a GLSLstd450 extended instruction.");
+ assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
+ "Expecting and FMix instruction.");
+
+ if (!inst->IsFloatingPointFoldingAllowed()) {
+ return nullptr;
+ }
+
+ // Make sure all FMix operands are constants.
+ for (uint32_t i = 1; i < 4; i++) {
+ if (constants[i] == nullptr) {
+ return nullptr;
+ }
+ }
+
+ const analysis::Constant* one;
+ if (constants[1]->type()->AsFloat()->width() == 32) {
+ one = const_mgr->GetConstant(constants[1]->type(),
+ utils::FloatProxy<float>(1.0f).GetWords());
+ } else {
+ one = const_mgr->GetConstant(constants[1]->type(),
+ utils::FloatProxy<double>(1.0).GetWords());
+ }
+
+ const analysis::Constant* temp1 =
+ FOLD_FPARITH_OP(-)(constants[1]->type(), one, constants[3], const_mgr);
+ if (temp1 == nullptr) {
+ return nullptr;
+ }
+
+ const analysis::Constant* temp2 = FOLD_FPARITH_OP(*)(
+ constants[1]->type(), constants[1], temp1, const_mgr);
+ if (temp2 == nullptr) {
+ return nullptr;
+ }
+ const analysis::Constant* temp3 = FOLD_FPARITH_OP(*)(
+ constants[2]->type(), constants[2], constants[3], const_mgr);
+ if (temp3 == nullptr) {
+ return nullptr;
+ }
+ return FOLD_FPARITH_OP(+)(temp2->type(), temp2, temp3, const_mgr);
+ };
+}
+
} // namespace
-ConstantFoldingRules::ConstantFoldingRules() {
+void ConstantFoldingRules::AddFoldingRules() {
// Add all folding rules to the list for the opcodes to which they apply.
// Note that the order in which rules are added to the list matters. If a rule
// applies to the instruction, the rest of the rules will not be attempted.
@@ -874,6 +930,14 @@
rules_[SpvOpFNegate].push_back(FoldFNegate());
rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16());
+
+ // Add rules for GLSLstd450
+ FeatureManager* feature_manager = context_->get_feature_mgr();
+ uint32_t ext_inst_glslstd450_id =
+ feature_manager->GetExtInstImportId_GLSLstd450();
+ if (ext_inst_glslstd450_id != 0) {
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
+ }
}
} // namespace opt
} // namespace spvtools
diff --git a/source/opt/const_folding_rules.h b/source/opt/const_folding_rules.h
index c186579..41ee2aa 100644
--- a/source/opt/const_folding_rules.h
+++ b/source/opt/const_folding_rules.h
@@ -53,24 +53,74 @@
const std::vector<const analysis::Constant*>& constants)>;
class ConstantFoldingRules {
+ protected:
+ // The |Key| and |Value| structs are used to by-pass a "decorated name length
+ // exceeded, name was truncated" warning on VS2013 and VS2015.
+ struct Key {
+ uint32_t instruction_set;
+ uint32_t opcode;
+ };
+
+ friend bool operator<(const Key& a, const Key& b) {
+ if (a.instruction_set < b.instruction_set) {
+ return true;
+ }
+ if (a.instruction_set > b.instruction_set) {
+ return false;
+ }
+ return a.opcode < b.opcode;
+ }
+
+ struct Value {
+ std::vector<ConstantFoldingRule> value;
+ void push_back(ConstantFoldingRule rule) { value.push_back(rule); }
+ };
+
public:
- ConstantFoldingRules();
+ ConstantFoldingRules(IRContext* ctx) : context_(ctx) {}
+ virtual ~ConstantFoldingRules() = default;
// Returns true if there is at least 1 folding rule for |opcode|.
- bool HasFoldingRule(SpvOp opcode) const { return rules_.count(opcode); }
+ bool HasFoldingRule(const Instruction* inst) const {
+ return !GetRulesForInstruction(inst).empty();
+ }
- // Returns an vector of constant folding rules for |opcode|.
- const std::vector<ConstantFoldingRule>& GetRulesForOpcode(
- SpvOp opcode) const {
- auto it = rules_.find(opcode);
- if (it != rules_.end()) {
- return it->second;
+  // Returns the list of folding rules that can be applied to |inst|.
+ const std::vector<ConstantFoldingRule>& GetRulesForInstruction(
+ const Instruction* inst) const {
+ if (inst->opcode() != SpvOpExtInst) {
+ auto it = rules_.find(inst->opcode());
+ if (it != rules_.end()) {
+ return it->second.value;
+ }
+ } else {
+ uint32_t ext_inst_id = inst->GetSingleWordInOperand(0);
+ uint32_t ext_opcode = inst->GetSingleWordInOperand(1);
+ auto it = ext_rules_.find({ext_inst_id, ext_opcode});
+ if (it != ext_rules_.end()) {
+ return it->second.value;
+ }
}
return empty_vector_;
}
+ // Add the folding rules.
+ virtual void AddFoldingRules();
+
+ protected:
+ // |rules[opcode]| is the set of rules that can be applied to instructions
+ // with |opcode| as the opcode.
+ std::unordered_map<uint32_t, Value> rules_;
+
+ // The folding rules for extended instructions.
+ std::map<Key, Value> ext_rules_;
+
private:
- std::unordered_map<uint32_t, std::vector<ConstantFoldingRule>> rules_;
+ // The context that the instruction to be folded will be a part of.
+ IRContext* context_;
+
+ // The empty set of rules to be used as the default return value in
+ // |GetRulesForInstruction|.
std::vector<ConstantFoldingRule> empty_vector_;
};
diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp
index 3c05f9e..5c1468b 100644
--- a/source/opt/constants.cpp
+++ b/source/opt/constants.cpp
@@ -103,6 +103,45 @@
}
}
+uint64_t Constant::GetZeroExtendedValue() const {
+ const auto* int_type = type()->AsInteger();
+ assert(int_type != nullptr);
+ const auto width = int_type->width();
+ assert(width <= 64);
+
+ uint64_t value = 0;
+ if (const IntConstant* ic = AsIntConstant()) {
+ if (width <= 32) {
+ value = ic->GetU32BitValue();
+ } else {
+ value = ic->GetU64BitValue();
+ }
+ } else {
+ assert(AsNullConstant() && "Must be an integer constant.");
+ }
+ return value;
+}
+
+int64_t Constant::GetSignExtendedValue() const {
+ const auto* int_type = type()->AsInteger();
+ assert(int_type != nullptr);
+ const auto width = int_type->width();
+ assert(width <= 64);
+
+ int64_t value = 0;
+ if (const IntConstant* ic = AsIntConstant()) {
+ if (width <= 32) {
+ // Let the C++ compiler do the sign extension.
+ value = int64_t(ic->GetS32BitValue());
+ } else {
+ value = ic->GetS64BitValue();
+ }
+ } else {
+ assert(AsNullConstant() && "Must be an integer constant.");
+ }
+ return value;
+}
+
ConstantManager::ConstantManager(IRContext* ctx) : ctx_(ctx) {
// Populate the constant table with values from constant declarations in the
// module. The values of each OpConstant declaration is the identity
@@ -252,7 +291,7 @@
}
}
-const Constant* ConstantManager::GetConstantFromInst(Instruction* inst) {
+const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
std::vector<uint32_t> literal_words_or_ids;
// Collect the constant defining literals or component ids.
diff --git a/source/opt/constants.h b/source/opt/constants.h
index a8e0fb5..93d0847 100644
--- a/source/opt/constants.h
+++ b/source/opt/constants.h
@@ -116,6 +116,14 @@
// Integer type.
int64_t GetS64() const;
+ // Returns the zero-extended representation of an integer constant. Must
+ // be an integral constant of at most 64 bits.
+ uint64_t GetZeroExtendedValue() const;
+
+ // Returns the sign-extended representation of an integer constant. Must
+ // be an integral constant of at most 64 bits.
+ int64_t GetSignExtendedValue() const;
+
// Returns true if the constant is a zero or a composite containing 0s.
virtual bool IsZero() const { return false; }
@@ -514,7 +522,7 @@
// Gets or creates a Constant instance to hold the constant value of the given
// instruction. It returns a pointer to a Constant instance or nullptr if it
// could not create the constant.
- const Constant* GetConstantFromInst(Instruction* inst);
+ const Constant* GetConstantFromInst(const Instruction* inst);
// Gets or creates a constant defining instruction for the given Constant |c|.
// If |c| had already been defined, it returns a pointer to the existing
diff --git a/source/opt/copy_prop_arrays.cpp b/source/opt/copy_prop_arrays.cpp
index 751786c..00757e3 100644
--- a/source/opt/copy_prop_arrays.cpp
+++ b/source/opt/copy_prop_arrays.cpp
@@ -527,6 +527,9 @@
pointer_type->storage_class());
uint32_t new_pointer_type_id =
context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
+ if (new_pointer_type_id == 0) {
+ return false;
+ }
if (new_pointer_type_id != use->type_id()) {
return CanUpdateUses(use, new_pointer_type_id);
@@ -542,6 +545,9 @@
const analysis::Type* new_type =
type_mgr->GetMemberType(type, access_chain);
uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
+ if (new_type_id == 0) {
+ return false;
+ }
if (new_type_id != use->type_id()) {
return CanUpdateUses(use, new_type_id);
diff --git a/source/opt/desc_sroa.cpp b/source/opt/desc_sroa.cpp
new file mode 100644
index 0000000..1f25b33
--- /dev/null
+++ b/source/opt/desc_sroa.cpp
@@ -0,0 +1,273 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/desc_sroa.h"
+
+#include "source/util/string_utils.h"
+
+namespace spvtools {
+namespace opt {
+
+Pass::Status DescriptorScalarReplacement::Process() {
+ bool modified = false;
+
+ std::vector<Instruction*> vars_to_kill;
+
+ for (Instruction& var : context()->types_values()) {
+ if (IsCandidate(&var)) {
+ modified = true;
+ if (!ReplaceCandidate(&var)) {
+ return Status::Failure;
+ }
+ vars_to_kill.push_back(&var);
+ }
+ }
+
+ for (Instruction* var : vars_to_kill) {
+ context()->KillInst(var);
+ }
+
+ return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
+}
+
+bool DescriptorScalarReplacement::IsCandidate(Instruction* var) {
+ if (var->opcode() != SpvOpVariable) {
+ return false;
+ }
+
+ uint32_t ptr_type_id = var->type_id();
+ Instruction* ptr_type_inst =
+ context()->get_def_use_mgr()->GetDef(ptr_type_id);
+ if (ptr_type_inst->opcode() != SpvOpTypePointer) {
+ return false;
+ }
+
+ uint32_t var_type_id = ptr_type_inst->GetSingleWordInOperand(1);
+ Instruction* var_type_inst =
+ context()->get_def_use_mgr()->GetDef(var_type_id);
+ if (var_type_inst->opcode() != SpvOpTypeArray) {
+ return false;
+ }
+
+ bool has_desc_set_decoration = false;
+ context()->get_decoration_mgr()->ForEachDecoration(
+ var->result_id(), SpvDecorationDescriptorSet,
+ [&has_desc_set_decoration](const Instruction&) {
+ has_desc_set_decoration = true;
+ });
+ if (!has_desc_set_decoration) {
+ return false;
+ }
+
+ bool has_binding_decoration = false;
+ context()->get_decoration_mgr()->ForEachDecoration(
+ var->result_id(), SpvDecorationBinding,
+ [&has_binding_decoration](const Instruction&) {
+ has_binding_decoration = true;
+ });
+ if (!has_binding_decoration) {
+ return false;
+ }
+
+ return true;
+}
+
+bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) {
+ std::vector<Instruction*> work_list;
+ bool failed = !get_def_use_mgr()->WhileEachUser(
+ var->result_id(), [this, &work_list](Instruction* use) {
+ if (use->opcode() == SpvOpName) {
+ return true;
+ }
+
+ if (use->IsDecoration()) {
+ return true;
+ }
+
+ switch (use->opcode()) {
+ case SpvOpAccessChain:
+ case SpvOpInBoundsAccessChain:
+ work_list.push_back(use);
+ return true;
+ default:
+ context()->EmitErrorMessage(
+ "Variable cannot be replaced: invalid instruction", use);
+ return false;
+ }
+ return true;
+ });
+
+ if (failed) {
+ return false;
+ }
+
+ for (Instruction* use : work_list) {
+ if (!ReplaceAccessChain(var, use)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var,
+ Instruction* use) {
+ if (use->NumInOperands() <= 1) {
+ context()->EmitErrorMessage(
+ "Variable cannot be replaced: invalid instruction", use);
+ return false;
+ }
+
+ uint32_t idx_id = use->GetSingleWordInOperand(1);
+ const analysis::Constant* idx_const =
+ context()->get_constant_mgr()->FindDeclaredConstant(idx_id);
+ if (idx_const == nullptr) {
+ context()->EmitErrorMessage("Variable cannot be replaced: invalid index",
+ use);
+ return false;
+ }
+
+ uint32_t idx = idx_const->GetU32();
+ uint32_t replacement_var = GetReplacementVariable(var, idx);
+
+ if (use->NumInOperands() == 2) {
+    // We are not indexing into the replacement variable. We can replace the
+    // access chain with the replacement variable itself.
+ context()->ReplaceAllUsesWith(use->result_id(), replacement_var);
+ context()->KillInst(use);
+ return true;
+ }
+
+ // We need to build a new access chain with the replacement variable as the
+ // base address.
+ Instruction::OperandList new_operands;
+
+ // Same result id and result type.
+ new_operands.emplace_back(use->GetOperand(0));
+ new_operands.emplace_back(use->GetOperand(1));
+
+ // Use the replacement variable as the base address.
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {replacement_var}});
+
+  // Drop the first index because it is consumed by the replacement, and copy the
+ // rest.
+ for (uint32_t i = 4; i < use->NumOperands(); i++) {
+ new_operands.emplace_back(use->GetOperand(i));
+ }
+
+ use->ReplaceOperands(new_operands);
+ context()->UpdateDefUse(use);
+ return true;
+}
+
+uint32_t DescriptorScalarReplacement::GetReplacementVariable(Instruction* var,
+ uint32_t idx) {
+ auto replacement_vars = replacement_variables_.find(var);
+ if (replacement_vars == replacement_variables_.end()) {
+ uint32_t ptr_type_id = var->type_id();
+ Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
+ assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
+ "Variable should be a pointer to an array.");
+ uint32_t arr_type_id = ptr_type_inst->GetSingleWordInOperand(1);
+ Instruction* arr_type_inst = get_def_use_mgr()->GetDef(arr_type_id);
+ assert(arr_type_inst->opcode() == SpvOpTypeArray &&
+ "Variable should be a pointer to an array.");
+
+ uint32_t array_len_id = arr_type_inst->GetSingleWordInOperand(1);
+ const analysis::Constant* array_len_const =
+ context()->get_constant_mgr()->FindDeclaredConstant(array_len_id);
+ assert(array_len_const != nullptr && "Array length must be a constant.");
+ uint32_t array_len = array_len_const->GetU32();
+
+ replacement_vars = replacement_variables_
+ .insert({var, std::vector<uint32_t>(array_len, 0)})
+ .first;
+ }
+
+ if (replacement_vars->second[idx] == 0) {
+ replacement_vars->second[idx] = CreateReplacementVariable(var, idx);
+ }
+
+ return replacement_vars->second[idx];
+}
+
+uint32_t DescriptorScalarReplacement::CreateReplacementVariable(
+ Instruction* var, uint32_t idx) {
+ // The storage class for the new variable is the same as the original.
+ SpvStorageClass storage_class =
+ static_cast<SpvStorageClass>(var->GetSingleWordInOperand(0));
+
+ // The type for the new variable will be a pointer to type of the elements of
+ // the array.
+ uint32_t ptr_type_id = var->type_id();
+ Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
+ assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
+ "Variable should be a pointer to an array.");
+ uint32_t arr_type_id = ptr_type_inst->GetSingleWordInOperand(1);
+ Instruction* arr_type_inst = get_def_use_mgr()->GetDef(arr_type_id);
+ assert(arr_type_inst->opcode() == SpvOpTypeArray &&
+ "Variable should be a pointer to an array.");
+ uint32_t element_type_id = arr_type_inst->GetSingleWordInOperand(0);
+
+ uint32_t ptr_element_type_id = context()->get_type_mgr()->FindPointerToType(
+ element_type_id, storage_class);
+
+ // Create the variable.
+ uint32_t id = TakeNextId();
+ std::unique_ptr<Instruction> variable(
+ new Instruction(context(), SpvOpVariable, ptr_element_type_id, id,
+ std::initializer_list<Operand>{
+ {SPV_OPERAND_TYPE_STORAGE_CLASS,
+ {static_cast<uint32_t>(storage_class)}}}));
+ context()->AddGlobalValue(std::move(variable));
+
+ // Copy all of the decorations to the new variable. The only difference is
+ // the Binding decoration needs to be adjusted.
+ for (auto old_decoration :
+ get_decoration_mgr()->GetDecorationsFor(var->result_id(), true)) {
+ assert(old_decoration->opcode() == SpvOpDecorate);
+ std::unique_ptr<Instruction> new_decoration(
+ old_decoration->Clone(context()));
+ new_decoration->SetInOperand(0, {id});
+
+ uint32_t decoration = new_decoration->GetSingleWordInOperand(1u);
+ if (decoration == SpvDecorationBinding) {
+ uint32_t new_binding = new_decoration->GetSingleWordInOperand(2) + idx;
+ new_decoration->SetInOperand(2, {new_binding});
+ }
+ context()->AddAnnotationInst(std::move(new_decoration));
+ }
+
+ // Create a new OpName for the replacement variable.
+ for (auto p : context()->GetNames(var->result_id())) {
+ Instruction* name_inst = p.second;
+ std::string name_str = utils::MakeString(name_inst->GetOperand(1).words);
+ name_str += "[";
+ name_str += utils::ToString(idx);
+ name_str += "]";
+
+ std::unique_ptr<Instruction> new_name(new Instruction(
+ context(), SpvOpName, 0, 0,
+ std::initializer_list<Operand>{
+ {SPV_OPERAND_TYPE_ID, {id}},
+ {SPV_OPERAND_TYPE_LITERAL_STRING, utils::MakeVector(name_str)}}));
+ Instruction* new_name_inst = new_name.get();
+ context()->AddDebug2Inst(std::move(new_name));
+ get_def_use_mgr()->AnalyzeInstDefUse(new_name_inst);
+ }
+
+ return id;
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/desc_sroa.h b/source/opt/desc_sroa.h
new file mode 100644
index 0000000..a95c6b5
--- /dev/null
+++ b/source/opt/desc_sroa.h
@@ -0,0 +1,84 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_DESC_SROA_H_
+#define SOURCE_OPT_DESC_SROA_H_
+
+#include <cstdio>
+#include <memory>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "source/opt/function.h"
+#include "source/opt/pass.h"
+#include "source/opt/type_manager.h"
+
+namespace spvtools {
+namespace opt {
+
+// Documented in optimizer.hpp
+class DescriptorScalarReplacement : public Pass {
+ public:
+ DescriptorScalarReplacement() {}
+
+ const char* name() const override { return "descriptor-scalar-replacement"; }
+
+ Status Process() override;
+
+ IRContext::Analysis GetPreservedAnalyses() override {
+ return IRContext::kAnalysisDefUse |
+ IRContext::kAnalysisInstrToBlockMapping |
+ IRContext::kAnalysisCombinators | IRContext::kAnalysisCFG |
+ IRContext::kAnalysisConstants | IRContext::kAnalysisTypes;
+ }
+
+ private:
+ // Returns true if |var| is an OpVariable instruction that represents a
+ // descriptor array. These are the variables that we want to replace.
+ bool IsCandidate(Instruction* var);
+
+ // Replaces all references to |var| by new variables, one for each element of
+ // the array |var|. The binding for the new variables corresponding to
+ // element i will be the binding of |var| plus i. Returns true if successful.
+ bool ReplaceCandidate(Instruction* var);
+
+ // Replaces the base address |var| in the OpAccessChain or
+ // OpInBoundsAccessChain instruction |use| by the variable that the access
+ // chain accesses. The first index in |use| must be an |OpConstant|. Returns
+ // |true| if successful.
+ bool ReplaceAccessChain(Instruction* var, Instruction* use);
+
+ // Returns the id of the variable that will be used to replace the |idx|th
+ // element of |var|. The variable is created if it has not already been
+ // created.
+ uint32_t GetReplacementVariable(Instruction* var, uint32_t idx);
+
+ // Returns the id of a new variable that can be used to replace the |idx|th
+ // element of |var|.
+ uint32_t CreateReplacementVariable(Instruction* var, uint32_t idx);
+
+ // A map from an OpVariable instruction to the set of variables that will be
+ // used to replace it. The entry |replacement_variables_[var][i]| is the id of
+  // a variable that will be used in the place of the ith element of the
+ // array |var|. If the entry is |0|, then the variable has not been
+ // created yet.
+ std::map<Instruction*, std::vector<uint32_t>> replacement_variables_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // SOURCE_OPT_DESC_SROA_H_
diff --git a/source/opt/feature_manager.cpp b/source/opt/feature_manager.cpp
index b7fc16a..63d50b6 100644
--- a/source/opt/feature_manager.cpp
+++ b/source/opt/feature_manager.cpp
@@ -31,12 +31,19 @@
void FeatureManager::AddExtensions(Module* module) {
for (auto ext : module->extensions()) {
- const std::string name =
- reinterpret_cast<const char*>(ext.GetInOperand(0u).words.data());
- Extension extension;
- if (GetExtensionFromString(name.c_str(), &extension)) {
- extensions_.Add(extension);
- }
+ AddExtension(&ext);
+ }
+}
+
+void FeatureManager::AddExtension(Instruction* ext) {
+ assert(ext->opcode() == SpvOpExtension &&
+ "Expecting an extension instruction.");
+
+ const std::string name =
+ reinterpret_cast<const char*>(ext->GetInOperand(0u).words.data());
+ Extension extension;
+ if (GetExtensionFromString(name.c_str(), &extension)) {
+ extensions_.Add(extension);
}
}
@@ -63,5 +70,27 @@
extinst_importid_GLSLstd450_ = module->GetExtInstImportId("GLSL.std.450");
}
+bool operator==(const FeatureManager& a, const FeatureManager& b) {
+ // We check that the addresses of the grammars are the same because they
+  // are large objects, and this is faster. It can be changed if needed at a
+ // later time.
+ if (&a.grammar_ != &b.grammar_) {
+ return false;
+ }
+
+ if (a.capabilities_ != b.capabilities_) {
+ return false;
+ }
+
+ if (a.extensions_ != b.extensions_) {
+ return false;
+ }
+
+ if (a.extinst_importid_GLSLstd450_ != b.extinst_importid_GLSLstd450_) {
+ return false;
+ }
+
+ return true;
+}
} // namespace opt
} // namespace spvtools
diff --git a/source/opt/feature_manager.h b/source/opt/feature_manager.h
index 80b2ccc..761a208 100644
--- a/source/opt/feature_manager.h
+++ b/source/opt/feature_manager.h
@@ -45,14 +45,22 @@
return extinst_importid_GLSLstd450_;
}
- private:
- // Analyzes |module| and records enabled extensions.
- void AddExtensions(Module* module);
+ friend bool operator==(const FeatureManager& a, const FeatureManager& b);
+ friend bool operator!=(const FeatureManager& a, const FeatureManager& b) {
+ return !(a == b);
+ }
// Adds the given |capability| and all implied capabilities into the current
// FeatureManager.
void AddCapability(SpvCapability capability);
+ // Add the extension |ext| to the feature manager.
+ void AddExtension(Instruction* ext);
+
+ private:
+ // Analyzes |module| and records enabled extensions.
+ void AddExtensions(Module* module);
+
// Analyzes |module| and records enabled capabilities.
void AddCapabilities(Module* module);
diff --git a/source/opt/fold.cpp b/source/opt/fold.cpp
index 944f438..276e835 100644
--- a/source/opt/fold.cpp
+++ b/source/opt/fold.cpp
@@ -234,13 +234,12 @@
return true;
}
- SpvOp opcode = inst->opcode();
analysis::ConstantManager* const_manager = context_->get_constant_mgr();
-
std::vector<const analysis::Constant*> constants =
const_manager->GetOperandConstants(inst);
- for (const FoldingRule& rule : GetFoldingRules().GetRulesForOpcode(opcode)) {
+ for (const FoldingRule& rule :
+ GetFoldingRules().GetRulesForInstruction(inst)) {
if (rule(context_, inst, constants)) {
return true;
}
@@ -623,7 +622,7 @@
analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
if (!inst->IsFoldableByFoldScalar() &&
- !GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
+ !GetConstantFoldingRules().HasFoldingRule(inst)) {
return nullptr;
}
// Collect the values of the constant parameters.
@@ -641,19 +640,16 @@
}
});
- if (GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
- const analysis::Constant* folded_const = nullptr;
- for (auto rule :
- GetConstantFoldingRules().GetRulesForOpcode(inst->opcode())) {
- folded_const = rule(context_, inst, constants);
- if (folded_const != nullptr) {
- Instruction* const_inst =
- const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
- assert(const_inst->type_id() == inst->type_id());
- // May be a new instruction that needs to be analysed.
- context_->UpdateDefUse(const_inst);
- return const_inst;
- }
+ const analysis::Constant* folded_const = nullptr;
+ for (auto rule : GetConstantFoldingRules().GetRulesForInstruction(inst)) {
+ folded_const = rule(context_, inst, constants);
+ if (folded_const != nullptr) {
+ Instruction* const_inst =
+ const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
+ assert(const_inst->type_id() == inst->type_id());
+ // May be a new instruction that needs to be analysed.
+ context_->UpdateDefUse(const_inst);
+ return const_inst;
}
}
diff --git a/source/opt/fold.h b/source/opt/fold.h
index 0dc7c0e..9e7c470 100644
--- a/source/opt/fold.h
+++ b/source/opt/fold.h
@@ -28,7 +28,23 @@
class InstructionFolder {
public:
- explicit InstructionFolder(IRContext* context) : context_(context) {}
+ explicit InstructionFolder(IRContext* context)
+ : context_(context),
+ const_folding_rules_(new ConstantFoldingRules(context)),
+ folding_rules_(new FoldingRules(context)) {
+ folding_rules_->AddFoldingRules();
+ const_folding_rules_->AddFoldingRules();
+ }
+
+ explicit InstructionFolder(
+ IRContext* context, std::unique_ptr<FoldingRules>&& folding_rules,
+ std::unique_ptr<ConstantFoldingRules>&& constant_folding_rules)
+ : context_(context),
+ const_folding_rules_(std::move(constant_folding_rules)),
+ folding_rules_(std::move(folding_rules)) {
+ folding_rules_->AddFoldingRules();
+ const_folding_rules_->AddFoldingRules();
+ }
// Returns the result of folding a scalar instruction with the given |opcode|
// and |operands|. Each entry in |operands| is a pointer to an
@@ -95,18 +111,18 @@
bool FoldInstruction(Instruction* inst) const;
// Return true if this opcode has a const folding rule associtated with it.
- bool HasConstFoldingRule(SpvOp opcode) const {
- return GetConstantFoldingRules().HasFoldingRule(opcode);
+ bool HasConstFoldingRule(const Instruction* inst) const {
+ return GetConstantFoldingRules().HasFoldingRule(inst);
}
private:
// Returns a reference to the ConstnatFoldingRules instance.
const ConstantFoldingRules& GetConstantFoldingRules() const {
- return const_folding_rules;
+ return *const_folding_rules_;
}
// Returns a reference to the FoldingRules instance.
- const FoldingRules& GetFoldingRules() const { return folding_rules; }
+ const FoldingRules& GetFoldingRules() const { return *folding_rules_; }
// Returns the single-word result from performing the given unary operation on
// the operand value which is passed in as a 32-bit word.
@@ -159,10 +175,10 @@
IRContext* context_;
// Folding rules used by |FoldInstructionToConstant| and |FoldInstruction|.
- ConstantFoldingRules const_folding_rules;
+ std::unique_ptr<ConstantFoldingRules> const_folding_rules_;
// Folding rules used by |FoldInstruction|.
- FoldingRules folding_rules;
+ std::unique_ptr<FoldingRules> folding_rules_;
};
} // namespace opt
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index 18d5149..a125dda 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -2200,7 +2200,7 @@
} // namespace
-FoldingRules::FoldingRules() {
+void FoldingRules::AddFoldingRules() {
// Add all folding rules to the list for the opcodes to which they apply.
// Note that the order in which rules are added to the list matters. If a rule
// applies to the instruction, the rest of the rules will not be attempted.
@@ -2216,8 +2216,6 @@
rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
- rules_[SpvOpExtInst].push_back(RedundantFMix());
-
rules_[SpvOpFAdd].push_back(RedundantFAdd());
rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
@@ -2271,6 +2269,15 @@
rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
+
+ FeatureManager* feature_manager = context_->get_feature_mgr();
+ // Add rules for GLSLstd450
+ uint32_t ext_inst_glslstd450_id =
+ feature_manager->GetExtInstImportId_GLSLstd450();
+ if (ext_inst_glslstd450_id != 0) {
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
+ RedundantFMix());
+ }
}
} // namespace opt
} // namespace spvtools
diff --git a/source/opt/folding_rules.h b/source/opt/folding_rules.h
index 33fdbff..f1a8639 100644
--- a/source/opt/folding_rules.h
+++ b/source/opt/folding_rules.h
@@ -58,19 +58,58 @@
class FoldingRules {
public:
- FoldingRules();
+ using FoldingRuleSet = std::vector<FoldingRule>;
- const std::vector<FoldingRule>& GetRulesForOpcode(SpvOp opcode) const {
- auto it = rules_.find(opcode);
- if (it != rules_.end()) {
- return it->second;
+ explicit FoldingRules(IRContext* ctx) : context_(ctx) {}
+ virtual ~FoldingRules() = default;
+
+ const FoldingRuleSet& GetRulesForInstruction(Instruction* inst) const {
+ if (inst->opcode() != SpvOpExtInst) {
+ auto it = rules_.find(inst->opcode());
+ if (it != rules_.end()) {
+ return it->second;
+ }
+ } else {
+ uint32_t ext_inst_id = inst->GetSingleWordInOperand(0);
+ uint32_t ext_opcode = inst->GetSingleWordInOperand(1);
+ auto it = ext_rules_.find({ext_inst_id, ext_opcode});
+ if (it != ext_rules_.end()) {
+ return it->second;
+ }
}
return empty_vector_;
}
+ IRContext* context() { return context_; }
+
+ // Adds the folding rules for the object.
+ virtual void AddFoldingRules();
+
+ protected:
+ // The folding rules for core instructions.
+ std::unordered_map<uint32_t, FoldingRuleSet> rules_;
+
+ // The folding rules for extended instructions.
+ struct Key {
+ uint32_t instruction_set;
+ uint32_t opcode;
+ };
+
+ friend bool operator<(const Key& a, const Key& b) {
+ if (a.instruction_set < b.instruction_set) {
+ return true;
+ }
+ if (a.instruction_set > b.instruction_set) {
+ return false;
+ }
+ return a.opcode < b.opcode;
+ }
+
+ std::map<Key, FoldingRuleSet> ext_rules_;
+
private:
- std::unordered_map<uint32_t, std::vector<FoldingRule>> rules_;
- std::vector<FoldingRule> empty_vector_;
+ IRContext* context_;
+ FoldingRuleSet empty_vector_;
};
} // namespace opt
diff --git a/source/opt/graphics_robust_access_pass.cpp b/source/opt/graphics_robust_access_pass.cpp
new file mode 100644
index 0000000..dd60e8c
--- /dev/null
+++ b/source/opt/graphics_robust_access_pass.cpp
@@ -0,0 +1,968 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This pass injects code in a graphics shader to implement guarantees
+// satisfying Vulkan's robustBufferAccess rules. Robust access rules permit
+// an out-of-bounds access to be redirected to an access of the same type
+// (load, store, etc.) but within the same root object.
+//
+// We assume baseline functionality in Vulkan, i.e. the module uses
+// logical addressing mode, without VK_KHR_variable_pointers.
+//
+// - Logical addressing mode implies:
+// - Each root pointer (a pointer that exists other than by the
+// execution of a shader instruction) is the result of an OpVariable.
+//
+// - Instructions that result in pointers are:
+// OpVariable
+// OpAccessChain
+// OpInBoundsAccessChain
+// OpFunctionParameter
+// OpImageTexelPointer
+// OpCopyObject
+//
+// - Instructions that use a pointer are:
+// OpLoad
+// OpStore
+// OpAccessChain
+// OpInBoundsAccessChain
+// OpFunctionCall
+// OpImageTexelPointer
+// OpCopyMemory
+// OpCopyObject
+// all OpAtomic* instructions
+//
+// We classify pointer-users into:
+// - Accesses:
+// - OpLoad
+// - OpStore
+// - OpAtomic*
+// - OpCopyMemory
+//
+// - Address calculations:
+// - OpAccessChain
+// - OpInBoundsAccessChain
+//
+// - Pass-through:
+// - OpFunctionCall
+// - OpFunctionParameter
+// - OpCopyObject
+//
+// The strategy is:
+//
+// - Handle only logical addressing mode. In particular, don't handle a module
+// if it uses one of the variable-pointers capabilities.
+//
+// - Don't handle modules using capability RuntimeDescriptorArrayEXT. So the
+// only runtime arrays are those that are the last member in a
+// Block-decorated struct. This allows us to feasibly/easily compute the
+// length of the runtime array. See below.
+//
+// - The memory locations accessed by OpLoad, OpStore, OpCopyMemory, and
+// OpAtomic* are determined by their pointer parameter or parameters.
+// Pointers are always (correctly) typed and so the address and number of
+// consecutive locations are fully determined by the pointer.
+//
+// - A pointer value originates as one of a few cases:
+//
+// - OpVariable for an interface object or an array of them: image,
+// buffer (UBO or SSBO), sampler, sampled-image, push-constant, input
+// variable, output variable. The execution environment is responsible for
+// allocating the correct amount of storage for these, and for ensuring
+// each resource bound to such a variable is big enough to contain the
+// SPIR-V pointee type of the variable.
+//
+// - OpVariable for a non-interface object. These are variables in
+// Workgroup, Private, and Function storage classes. The compiler ensures
+// the underlying allocation is big enough to store the entire SPIR-V
+// pointee type of the variable.
+//
+// - An OpFunctionParameter. This always maps to a pointer parameter to an
+// OpFunctionCall.
+//
+// - In logical addressing mode, these are severely limited:
+// "Any pointer operand to an OpFunctionCall must be:
+// - a memory object declaration, or
+// - a pointer to an element in an array that is a memory object
+// declaration, where the element type is OpTypeSampler or OpTypeImage"
+//
+// - This has an important simplifying consequence:
+//
+// - When looking for a pointer to the structure containing a runtime
+// array, you begin with a pointer to the runtime array and trace
+// backward in the function. You never have to trace back beyond
+// your function call boundary. So you can't take a partial access
+// chain into an SSBO, then pass that pointer into a function. So
+// we don't resort to using fat pointers to compute array length.
+// We can trace back to a pointer to the containing structure,
+// and use that in an OpArrayLength instruction. (The structure type
+// gives us the member index of the runtime array.)
+//
+// - Otherwise, the pointer type fully encodes the range of valid
+// addresses. In particular, the type of a pointer to an aggregate
+// value fully encodes the range of indices when indexing into
+// that aggregate.
+//
+// - The pointer is the result of an access chain instruction. We clamp
+// indices contributing to address calculations. As noted above, the
+// valid ranges are either bound by the length of a runtime array, or
+// by the type of the base pointer. The length of a runtime array is
+// the result of an OpArrayLength instruction acting on the pointer of
+// the containing structure as noted above.
+//
+// - TODO(dneto): OpImageTexelPointer:
+// - Clamp coordinate to the image size returned by OpImageQuerySize
+// - If multi-sampled, clamp the sample index to the count returned by
+// OpImageQuerySamples.
+// - If not multi-sampled, set the sample index to 0.
+//
+// - Rely on the external validator to check that pointers are only
+// used by the instructions as above.
+//
+// - Handles OpTypeRuntimeArray
+// Track pointer back to original resource (pointer to struct), so we can
+// query the runtime array size.
+//
+
+#include "graphics_robust_access_pass.h"
+
+#include <algorithm>
+#include <cstring>
+#include <functional>
+#include <initializer_list>
+#include <utility>
+
+#include "constants.h"
+#include "def_use_manager.h"
+#include "function.h"
+#include "ir_context.h"
+#include "module.h"
+#include "pass.h"
+#include "source/diagnostic.h"
+#include "source/util/make_unique.h"
+#include "spirv-tools/libspirv.h"
+#include "spirv/unified1/GLSL.std.450.h"
+#include "spirv/unified1/spirv.h"
+#include "type_manager.h"
+#include "types.h"
+
+namespace spvtools {
+namespace opt {
+
+using opt::BasicBlock;
+using opt::Instruction;
+using opt::Operand;
+using spvtools::MakeUnique;
+
+GraphicsRobustAccessPass::GraphicsRobustAccessPass() : module_status_() {}
+
+Pass::Status GraphicsRobustAccessPass::Process() {
+ module_status_ = PerModuleState();
+
+ ProcessCurrentModule();
+
+ auto result = module_status_.failed
+ ? Status::Failure
+ : (module_status_.modified ? Status::SuccessWithChange
+ : Status::SuccessWithoutChange);
+
+ return result;
+}
+
+spvtools::DiagnosticStream GraphicsRobustAccessPass::Fail() {
+ module_status_.failed = true;
+ // We don't really have a position, and we'll ignore the result.
+ return std::move(
+ spvtools::DiagnosticStream({}, consumer(), "", SPV_ERROR_INVALID_BINARY)
+ << name() << ": ");
+}
+
+spv_result_t GraphicsRobustAccessPass::IsCompatibleModule() {
+ auto* feature_mgr = context()->get_feature_mgr();
+ if (!feature_mgr->HasCapability(SpvCapabilityShader))
+ return Fail() << "Can only process Shader modules";
+ if (feature_mgr->HasCapability(SpvCapabilityVariablePointers))
+ return Fail() << "Can't process modules with VariablePointers capability";
+ if (feature_mgr->HasCapability(SpvCapabilityVariablePointersStorageBuffer))
+ return Fail() << "Can't process modules with VariablePointersStorageBuffer "
+ "capability";
+ if (feature_mgr->HasCapability(SpvCapabilityRuntimeDescriptorArrayEXT)) {
+ // These have a RuntimeArray outside of Block-decorated struct. There
+ // is no way to compute the array length from within SPIR-V.
+ return Fail() << "Can't process modules with RuntimeDescriptorArrayEXT "
+ "capability";
+ }
+
+ {
+ auto* inst = context()->module()->GetMemoryModel();
+ const auto addressing_model = inst->GetSingleWordOperand(0);
+ if (addressing_model != SpvAddressingModelLogical)
+ return Fail() << "Addressing model must be Logical. Found "
+ << inst->PrettyPrint();
+ }
+ return SPV_SUCCESS;
+}
+
+spv_result_t GraphicsRobustAccessPass::ProcessCurrentModule() {
+ auto err = IsCompatibleModule();
+ if (err != SPV_SUCCESS) return err;
+
+ ProcessFunction fn = [this](opt::Function* f) { return ProcessAFunction(f); };
+ module_status_.modified |= context()->ProcessReachableCallTree(fn);
+
+ // Need something here. It's the price we pay for easier failure paths.
+ return SPV_SUCCESS;
+}
+
+bool GraphicsRobustAccessPass::ProcessAFunction(opt::Function* function) {
+ // Ensure that all pointers computed inside a function are within bounds.
+ // Find the access chains in this block before trying to modify them.
+ std::vector<Instruction*> access_chains;
+ std::vector<Instruction*> image_texel_pointers;
+ for (auto& block : *function) {
+ for (auto& inst : block) {
+ switch (inst.opcode()) {
+ case SpvOpAccessChain:
+ case SpvOpInBoundsAccessChain:
+ access_chains.push_back(&inst);
+ break;
+ case SpvOpImageTexelPointer:
+ image_texel_pointers.push_back(&inst);
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ for (auto* inst : access_chains) {
+ ClampIndicesForAccessChain(inst);
+ }
+
+ for (auto* inst : image_texel_pointers) {
+ if (SPV_SUCCESS != ClampCoordinateForImageTexelPointer(inst)) break;
+ }
+ return module_status_.modified;
+}
+
+void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
+ Instruction* access_chain) {
+ Instruction& inst = *access_chain;
+
+ auto* constant_mgr = context()->get_constant_mgr();
+ auto* def_use_mgr = context()->get_def_use_mgr();
+ auto* type_mgr = context()->get_type_mgr();
+
+ // Replaces one of the OpAccessChain index operands with a new value.
+ // Updates def-use analysis.
+ auto replace_index = [&inst, def_use_mgr](uint32_t operand_index,
+ Instruction* new_value) {
+ inst.SetOperand(operand_index, {new_value->result_id()});
+ def_use_mgr->AnalyzeInstUse(&inst);
+ };
+
+ // Replaces one of the OpAccesssChain index operands with a clamped value.
+ // Replace the operand at |operand_index| with the value computed from
+ // unsigned_clamp(%old_value, %min_value, %max_value). It also analyzes
+ // the new instruction and records that the module is modified.
+ auto clamp_index = [&inst, this, &replace_index](
+ uint32_t operand_index, Instruction* old_value,
+ Instruction* min_value, Instruction* max_value) {
+ auto* clamp_inst = MakeClampInst(old_value, min_value, max_value, &inst);
+ replace_index(operand_index, clamp_inst);
+ };
+
+ // Ensures the specified index of access chain |inst| has a value that is
+ // at most |count| - 1. If the index is already a constant value less than
+ // |count| then no change is made.
+ auto clamp_to_literal_count = [&inst, this, &constant_mgr, &type_mgr,
+ &replace_index, &clamp_index](
+ uint32_t operand_index, uint64_t count) {
+ Instruction* index_inst =
+ this->GetDef(inst.GetSingleWordOperand(operand_index));
+ const auto* index_type =
+ type_mgr->GetType(index_inst->type_id())->AsInteger();
+ assert(index_type);
+ if (count <= 1) {
+ // Replace the index with 0.
+ replace_index(operand_index, GetValueForType(0, index_type));
+ return;
+ }
+
+ const auto index_width = index_type->width();
+
+ // If the index is a constant then |index_constant| will not be a null
+ // pointer. (If index is an |OpConstantNull| then |index_constant| will
+ // not be a null pointer.) Since access chain indices must be scalar
+ // integers, this can't be a spec constant.
+ if (auto* index_constant = constant_mgr->GetConstantFromInst(index_inst)) {
+ auto* int_index_constant = index_constant->AsIntConstant();
+ int64_t value = 0;
+ // OpAccessChain indices are treated as signed. So get the signed
+ // constant value here.
+ if (index_width <= 32) {
+ value = int64_t(int_index_constant->GetS32BitValue());
+ } else if (index_width <= 64) {
+ value = int_index_constant->GetS64BitValue();
+ } else {
+ this->Fail() << "Can't handle indices wider than 64 bits, found "
+ "constant index with "
+ << index_type->width() << "bits";
+ return;
+ }
+ if (value < 0) {
+ replace_index(operand_index, GetValueForType(0, index_type));
+ } else if (uint64_t(value) < count) {
+ // Nothing to do.
+ return;
+ } else {
+ // Replace with count - 1.
+ assert(count > 0); // Already took care of this case above.
+ replace_index(operand_index, GetValueForType(count - 1, index_type));
+ }
+ } else {
+ // Generate a clamp instruction.
+
+ // Compute the bit width of a viable type to hold (count-1).
+ const auto maxval = count - 1;
+ const auto* maxval_type = index_type;
+ // Look for a bit width, up to 64 bits wide, to fit maxval.
+ uint32_t maxval_width = index_width;
+ while ((maxval_width < 64) && (0 != (maxval >> maxval_width))) {
+ maxval_width *= 2;
+ }
+ // Widen the index value if necessary
+ if (maxval_width > index_width) {
+ // Find the wider type. We only need this case if a constant (array)
+ // bound is too big. This never requires us to *add* a capability
+ // declaration for Int64 because the existence of the array bound would
+ // already have required that declaration.
+ index_inst = WidenInteger(index_type->IsSigned(), maxval_width,
+ index_inst, &inst);
+ maxval_type = type_mgr->GetType(index_inst->type_id())->AsInteger();
+ }
+ // Finally, clamp the index.
+ clamp_index(operand_index, index_inst, GetValueForType(0, maxval_type),
+ GetValueForType(maxval, maxval_type));
+ }
+ };
+
+ // Ensures the specified index of access chain |inst| has a value that is at
+ // most the value of |count_inst| minus 1, where |count_inst| is treated as an
+ // unsigned integer.
+ auto clamp_to_count = [&inst, this, &constant_mgr, &clamp_to_literal_count,
+ &clamp_index, &type_mgr](uint32_t operand_index,
+ Instruction* count_inst) {
+ Instruction* index_inst =
+ this->GetDef(inst.GetSingleWordOperand(operand_index));
+ const auto* index_type =
+ type_mgr->GetType(index_inst->type_id())->AsInteger();
+ const auto* count_type =
+ type_mgr->GetType(count_inst->type_id())->AsInteger();
+ assert(index_type);
+ if (const auto* count_constant =
+ constant_mgr->GetConstantFromInst(count_inst)) {
+ uint64_t value = 0;
+ const auto width = count_constant->type()->AsInteger()->width();
+ if (width <= 32) {
+ value = count_constant->AsIntConstant()->GetU32BitValue();
+ } else if (width <= 64) {
+ value = count_constant->AsIntConstant()->GetU64BitValue();
+ } else {
+ this->Fail() << "Can't handle indices wider than 64 bits, found "
+ "constant index with "
+ << index_type->width() << "bits";
+ return;
+ }
+ clamp_to_literal_count(operand_index, value);
+ } else {
+ // Widen them to the same width.
+ const auto index_width = index_type->width();
+ const auto count_width = count_type->width();
+ const auto target_width = std::max(index_width, count_width);
+ // UConvert requires the result type to have 0 signedness. So enforce
+ // that here.
+ auto* wider_type = index_width < count_width ? count_type : index_type;
+ if (index_type->width() < target_width) {
+ // Access chain indices are treated as signed integers.
+ index_inst = WidenInteger(true, target_width, index_inst, &inst);
+ } else if (count_type->width() < target_width) {
+ // Assume type sizes are treated as unsigned.
+ count_inst = WidenInteger(false, target_width, count_inst, &inst);
+ }
+ // Compute count - 1.
+ // It doesn't matter if 1 is signed or unsigned.
+ auto* one = GetValueForType(1, wider_type);
+ auto* count_minus_1 = InsertInst(
+ &inst, SpvOpISub, type_mgr->GetId(wider_type), TakeNextId(),
+ {{SPV_OPERAND_TYPE_ID, {count_inst->result_id()}},
+ {SPV_OPERAND_TYPE_ID, {one->result_id()}}});
+ clamp_index(operand_index, index_inst, GetValueForType(0, wider_type),
+ count_minus_1);
+ }
+ };
+
+ const Instruction* base_inst = GetDef(inst.GetSingleWordInOperand(0));
+ const Instruction* base_type = GetDef(base_inst->type_id());
+ Instruction* pointee_type = GetDef(base_type->GetSingleWordInOperand(1));
+
+ // Walk the indices from earliest to latest, replacing indices with a
+ // clamped value, and updating the pointee_type. The order matters for
+ // the case when we have to compute the length of a runtime array. In
+ // that case the algorithm relies on the fact that the earlier indices
+ // have already been clamped.
+ const uint32_t num_operands = inst.NumOperands();
+ for (uint32_t idx = 3; !module_status_.failed && idx < num_operands; ++idx) {
+ const uint32_t index_id = inst.GetSingleWordOperand(idx);
+ Instruction* index_inst = GetDef(index_id);
+
+ switch (pointee_type->opcode()) {
+ case SpvOpTypeMatrix: // Use column count
+ case SpvOpTypeVector: // Use component count
+ {
+ const uint32_t count = pointee_type->GetSingleWordOperand(2);
+ clamp_to_literal_count(idx, count);
+ pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
+ } break;
+
+ case SpvOpTypeArray: {
+ // The array length can be a spec constant, so go through the general
+ // case.
+ Instruction* array_len = GetDef(pointee_type->GetSingleWordOperand(2));
+ clamp_to_count(idx, array_len);
+ pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
+ } break;
+
+ case SpvOpTypeStruct: {
+ // SPIR-V requires the index to be an OpConstant.
+ // We need to know the index literal value so we can compute the next
+ // pointee type.
+ if (index_inst->opcode() != SpvOpConstant ||
+ !constant_mgr->GetConstantFromInst(index_inst)
+ ->type()
+ ->AsInteger()) {
+ Fail() << "Member index into struct is not a constant integer: "
+ << index_inst->PrettyPrint(
+ SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)
+ << "\nin access chain: "
+ << inst.PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+ return;
+ }
+ const auto num_members = pointee_type->NumInOperands();
+ const auto* index_constant =
+ constant_mgr->GetConstantFromInst(index_inst);
+ // Get the sign-extended value, since access index is always treated as
+ // signed.
+ const auto index_value = index_constant->GetSignExtendedValue();
+ if (index_value < 0 || index_value >= num_members) {
+ Fail() << "Member index " << index_value
+ << " is out of bounds for struct type: "
+ << pointee_type->PrettyPrint(
+ SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)
+ << "\nin access chain: "
+ << inst.PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+ return;
+ }
+ pointee_type = GetDef(pointee_type->GetSingleWordInOperand(
+ static_cast<uint32_t>(index_value)));
+ // No need to clamp this index. We just checked that it's valid.
+ } break;
+
+ case SpvOpTypeRuntimeArray: {
+ auto* array_len = MakeRuntimeArrayLengthInst(&inst, idx);
+ if (!array_len) { // We've already signaled an error.
+ return;
+ }
+ clamp_to_count(idx, array_len);
+ pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
+ } break;
+
+ default:
+ Fail() << " Unhandled pointee type for access chain "
+ << pointee_type->PrettyPrint(
+ SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+ }
+ }
+}
+
+uint32_t GraphicsRobustAccessPass::GetGlslInsts() {
+ if (module_status_.glsl_insts_id == 0) {
+ // This string serves double-duty as raw data for a string and for a vector
+ // of 32-bit words
+ const char glsl[] = "GLSL.std.450\0\0\0\0";
+ const size_t glsl_str_byte_len = 16;
+ // Use an existing import if we can.
+ for (auto& inst : context()->module()->ext_inst_imports()) {
+ const auto& name_words = inst.GetInOperand(0).words;
+ if (0 == std::strncmp(reinterpret_cast<const char*>(name_words.data()),
+ glsl, glsl_str_byte_len)) {
+ module_status_.glsl_insts_id = inst.result_id();
+ }
+ }
+ if (module_status_.glsl_insts_id == 0) {
+ // Make a new import instruction.
+ module_status_.glsl_insts_id = TakeNextId();
+ std::vector<uint32_t> words(glsl_str_byte_len / sizeof(uint32_t));
+ std::memcpy(words.data(), glsl, glsl_str_byte_len);
+ auto import_inst = MakeUnique<Instruction>(
+ context(), SpvOpExtInstImport, 0, module_status_.glsl_insts_id,
+ std::initializer_list<Operand>{
+ Operand{SPV_OPERAND_TYPE_LITERAL_STRING, std::move(words)}});
+ Instruction* inst = import_inst.get();
+ context()->module()->AddExtInstImport(std::move(import_inst));
+ module_status_.modified = true;
+ context()->AnalyzeDefUse(inst);
+ // Reanalyze the feature list, since we added an extended instruction
+ // set import.
+ context()->get_feature_mgr()->Analyze(context()->module());
+ }
+ }
+ return module_status_.glsl_insts_id;
+}
+
+opt::Instruction* opt::GraphicsRobustAccessPass::GetValueForType(
+ uint64_t value, const analysis::Integer* type) {
+ auto* mgr = context()->get_constant_mgr();
+ assert(type->width() <= 64);
+ std::vector<uint32_t> words;
+ words.push_back(uint32_t(value));
+ if (type->width() > 32) {
+ words.push_back(uint32_t(value >> 32u));
+ }
+ const auto* constant = mgr->GetConstant(type, words);
+ return mgr->GetDefiningInstruction(
+ constant, context()->get_type_mgr()->GetTypeInstruction(type));
+}
+
+opt::Instruction* opt::GraphicsRobustAccessPass::WidenInteger(
+ bool sign_extend, uint32_t bit_width, Instruction* value,
+ Instruction* before_inst) {
+ analysis::Integer unsigned_type_for_query(bit_width, false);
+ auto* type_mgr = context()->get_type_mgr();
+ auto* unsigned_type = type_mgr->GetRegisteredType(&unsigned_type_for_query);
+ auto type_id = context()->get_type_mgr()->GetId(unsigned_type);
+ auto conversion_id = TakeNextId();
+ auto* conversion = InsertInst(
+ before_inst, (sign_extend ? SpvOpSConvert : SpvOpUConvert), type_id,
+ conversion_id, {{SPV_OPERAND_TYPE_ID, {value->result_id()}}});
+ return conversion;
+}
+
+Instruction* GraphicsRobustAccessPass::MakeClampInst(Instruction* x,
+ Instruction* min,
+ Instruction* max,
+ Instruction* where) {
+ // Get IDs of instructions we'll be referencing. Evaluate them before calling
+ // the function so we force a deterministic ordering in case both of them need
+ // to take a new ID.
+ const uint32_t glsl_insts_id = GetGlslInsts();
+ uint32_t clamp_id = TakeNextId();
+ assert(x->type_id() == min->type_id());
+ assert(x->type_id() == max->type_id());
+ auto* clamp_inst = InsertInst(
+ where, SpvOpExtInst, x->type_id(), clamp_id,
+ {
+ {SPV_OPERAND_TYPE_ID, {glsl_insts_id}},
+ {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, {GLSLstd450UClamp}},
+ {SPV_OPERAND_TYPE_ID, {x->result_id()}},
+ {SPV_OPERAND_TYPE_ID, {min->result_id()}},
+ {SPV_OPERAND_TYPE_ID, {max->result_id()}},
+ });
+ return clamp_inst;
+}
+
+Instruction* GraphicsRobustAccessPass::MakeRuntimeArrayLengthInst(
+ Instruction* access_chain, uint32_t operand_index) {
+ // The Index parameter to the access chain at |operand_index| is indexing
+ // *into* the runtime-array. To get the number of elements in the runtime
+ // array we need a pointer to the Block-decorated struct that contains the
+ // runtime array. So conceptually we have to go 2 steps backward in the
+ // access chain. The two steps backward might force us to traverse backward
+ // across multiple dominating instructions.
+ auto* type_mgr = context()->get_type_mgr();
+
+ // How many access chain indices do we have to unwind to find the pointer
+ // to the struct containing the runtime array?
+ uint32_t steps_remaining = 2;
+ // Find or create an instruction computing the pointer to the structure
+ // containing the runtime array.
+ // Walk backward through pointer address calculations until we either get
+ // to exactly the right base pointer, or to an access chain instruction
+ // that we can replicate but truncate to compute the address of the right
+ // struct.
+ Instruction* current_access_chain = access_chain;
+ Instruction* pointer_to_containing_struct = nullptr;
+ while (steps_remaining > 0) {
+ switch (current_access_chain->opcode()) {
+ case SpvOpCopyObject:
+ // Whoops. Walk right through this one.
+ current_access_chain =
+ GetDef(current_access_chain->GetSingleWordInOperand(0));
+ break;
+ case SpvOpAccessChain:
+ case SpvOpInBoundsAccessChain: {
+ const int first_index_operand = 3;
+ // How many indices in this access chain contribute to getting us
+ // to an element in the runtime array?
+ const auto num_contributing_indices =
+ current_access_chain == access_chain
+ ? operand_index - (first_index_operand - 1)
+ : current_access_chain->NumInOperands() - 1 /* skip the base */;
+ Instruction* base =
+ GetDef(current_access_chain->GetSingleWordInOperand(0));
+ if (num_contributing_indices == steps_remaining) {
+ // The base pointer points to the structure.
+ pointer_to_containing_struct = base;
+ steps_remaining = 0;
+ break;
+ } else if (num_contributing_indices < steps_remaining) {
+ // Peel off the index and keep going backward.
+ steps_remaining -= num_contributing_indices;
+ current_access_chain = base;
+ } else {
+ // This access chain has more indices than needed. Generate a new
+ // access chain instruction, but truncating the list of indices.
+ const int base_operand = 2;
+ // We'll use the base pointer and the indices up to but not including
+ // the one indexing into the runtime array.
+ Instruction::OperandList ops;
+ // Use the base pointer
+ ops.push_back(current_access_chain->GetOperand(base_operand));
+ const uint32_t num_indices_to_keep =
+ num_contributing_indices - steps_remaining - 1;
+ for (uint32_t i = 0; i <= num_indices_to_keep; i++) {
+ ops.push_back(
+ current_access_chain->GetOperand(first_index_operand + i));
+ }
+ // Compute the type of the result of the new access chain. Start at
+ // the base and walk the indices in a forward direction.
+ auto* constant_mgr = context()->get_constant_mgr();
+ std::vector<uint32_t> indices_for_type;
+ for (uint32_t i = 0; i < ops.size() - 1; i++) {
+ uint32_t index_for_type_calculation = 0;
+ Instruction* index =
+ GetDef(current_access_chain->GetSingleWordOperand(
+ first_index_operand + i));
+ if (auto* index_constant =
+ constant_mgr->GetConstantFromInst(index)) {
+ // We only need 32 bits. For the type calculation, it's sufficient
+ // to take the zero-extended value. It only matters for the struct
+ // case, and struct member indices are unsigned.
+ index_for_type_calculation =
+ uint32_t(index_constant->GetZeroExtendedValue());
+ } else {
+ // Indexing into a variably-sized thing like an array. Use 0.
+ index_for_type_calculation = 0;
+ }
+ indices_for_type.push_back(index_for_type_calculation);
+ }
+ auto* base_ptr_type = type_mgr->GetType(base->type_id())->AsPointer();
+ auto* base_pointee_type = base_ptr_type->pointee_type();
+ auto* new_access_chain_result_pointee_type =
+ type_mgr->GetMemberType(base_pointee_type, indices_for_type);
+ const uint32_t new_access_chain_type_id = type_mgr->FindPointerToType(
+ type_mgr->GetId(new_access_chain_result_pointee_type),
+ base_ptr_type->storage_class());
+
+ // Create the instruction and insert it.
+ const auto new_access_chain_id = TakeNextId();
+ auto* new_access_chain =
+ InsertInst(current_access_chain, current_access_chain->opcode(),
+ new_access_chain_type_id, new_access_chain_id, ops);
+ pointer_to_containing_struct = new_access_chain;
+ steps_remaining = 0;
+ break;
+ }
+ } break;
+ default:
+ Fail() << "Unhandled access chain in logical addressing mode passes "
+ "through "
+ << current_access_chain->PrettyPrint(
+ SPV_BINARY_TO_TEXT_OPTION_SHOW_BYTE_OFFSET |
+ SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+ return nullptr;
+ }
+ }
+ assert(pointer_to_containing_struct);
+ auto* pointee_type =
+ type_mgr->GetType(pointer_to_containing_struct->type_id())
+ ->AsPointer()
+ ->pointee_type();
+
+ auto* struct_type = pointee_type->AsStruct();
+ const uint32_t member_index_of_runtime_array =
+ uint32_t(struct_type->element_types().size() - 1);
+ // Create the length-of-array instruction before the original access chain,
+ // but after the generation of the pointer to the struct.
+ const auto array_len_id = TakeNextId();
+ analysis::Integer uint_type_for_query(32, false);
+ auto* uint_type = type_mgr->GetRegisteredType(&uint_type_for_query);
+ auto* array_len = InsertInst(
+ access_chain, SpvOpArrayLength, type_mgr->GetId(uint_type), array_len_id,
+ {{SPV_OPERAND_TYPE_ID, {pointer_to_containing_struct->result_id()}},
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {member_index_of_runtime_array}}});
+ return array_len;
+}
+
+spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
+ opt::Instruction* image_texel_pointer) {
+ // TODO(dneto): Write tests for this code.
+ return SPV_SUCCESS;
+
+ // Example:
+ // %texel_ptr = OpImageTexelPointer %texel_ptr_type %image_ptr %coord
+ // %sample
+ //
+ // We want to clamp %coord components between vector-0 and the result
+ // of OpImageQuerySize acting on the underlying image. So insert:
+ // %image = OpLoad %image_type %image_ptr
+ // %query_size = OpImageQuerySize %query_size_type %image
+ //
+ // For a multi-sampled image, %sample is the sample index, and we need
+ // to clamp it between zero and the number of samples in the image.
+ // %sample_count = OpImageQuerySamples %uint %image
+ // %max_sample_index = OpISub %uint %sample_count %uint_1
+ // For non-multi-sampled images, the sample index must be constant zero.
+
+ auto* def_use_mgr = context()->get_def_use_mgr();
+ auto* type_mgr = context()->get_type_mgr();
+ auto* constant_mgr = context()->get_constant_mgr();
+
+ auto* image_ptr = GetDef(image_texel_pointer->GetSingleWordInOperand(0));
+ auto* image_ptr_type = GetDef(image_ptr->type_id());
+ auto image_type_id = image_ptr_type->GetSingleWordInOperand(1);
+ auto* image_type = GetDef(image_type_id);
+ auto* coord = GetDef(image_texel_pointer->GetSingleWordInOperand(1));
+ auto* samples = GetDef(image_texel_pointer->GetSingleWordInOperand(2));
+
+ // We will modify the module, at least by adding image query instructions.
+ module_status_.modified = true;
+
+ // Declare the ImageQuery capability if the module doesn't already have it.
+ auto* feature_mgr = context()->get_feature_mgr();
+ if (!feature_mgr->HasCapability(SpvCapabilityImageQuery)) {
+ auto cap = MakeUnique<Instruction>(
+ context(), SpvOpCapability, 0, 0,
+ std::initializer_list<Operand>{
+ {SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityImageQuery}}});
+ def_use_mgr->AnalyzeInstDefUse(cap.get());
+ context()->AddCapability(std::move(cap));
+ feature_mgr->Analyze(context()->module());
+ }
+
+ // OpImageTexelPointer is used to translate a coordinate and sample index
+ // into an address for use with an atomic operation. That is, it may only be
+ // used with what Vulkan calls a "storage image"
+ // (OpTypeImage parameter Sampled=2).
+ // Note: A storage image never has a level-of-detail associated with it.
+
+ // Constraints on the sample id:
+ // - Only 2D images can be multi-sampled: OpTypeImage parameter MS=1
+ // only if Dim=2D.
+ // - Non-multi-sampled images (OpTypeImage parameter MS=0) must use
+ // a sample ID of constant 0.
+
+ // The coordinate is treated as unsigned, and should be clamped against the
+ // image "size", returned by OpImageQuerySize. (Note: OpImageQuerySizeLod
+ // is only usable with a sampled image, i.e. its image type has Sampled=1).
+
+ // Determine the result type for the OpImageQuerySize.
+ // For non-arrayed images:
+ // non-Cube:
+ // - Always the same as the coordinate type
+ // Cube:
+ // - Use all but the last component of the coordinate (which is the face
+ // index from 0 to 5).
+ // For arrayed images (in Vulkan the Dim is 1D, 2D, or Cube):
+ // non-Cube:
+ // - A vector with the components in the coordinate, and one more for
+ // the layer index.
+ // Cube:
+ // - The same as the coordinate type: 3-element integer vector.
+ // - The third component from the size query is the layer count.
+ // - The third component in the texel pointer calculation is
+ // 6 * layer + face, where 0 <= face < 6.
+ // Cube: Use all but the last component of the coordinate (which is the face
+ // index from 0 to 5).
+ const auto dim = SpvDim(image_type->GetSingleWordInOperand(1));
+ const bool arrayed = image_type->GetSingleWordInOperand(3) == 1;
+ const bool multisampled = image_type->GetSingleWordInOperand(4) != 0;
+ const auto query_num_components = [dim, arrayed, this]() -> int {
+ const int arrayness_bonus = arrayed ? 1 : 0;
+ int num_coords = 0;
+ switch (dim) {
+ case SpvDimBuffer:
+ case SpvDim1D:
+ num_coords = 1;
+ break;
+ case SpvDimCube:
+ // For cube, we need bounds for x, y, but not face.
+ case SpvDimRect:
+ case SpvDim2D:
+ num_coords = 2;
+ break;
+ case SpvDim3D:
+ num_coords = 3;
+ break;
+ case SpvDimSubpassData:
+ case SpvDimMax:
+ return Fail() << "Invalid image dimension for OpImageTexelPointer: "
+ << int(dim);
+ break;
+ }
+ return num_coords + arrayness_bonus;
+ }();
+ const auto* coord_component_type = [type_mgr, coord]() {
+ const analysis::Type* coord_type = type_mgr->GetType(coord->type_id());
+ if (auto* vector_type = coord_type->AsVector()) {
+ return vector_type->element_type()->AsInteger();
+ }
+ return coord_type->AsInteger();
+ }();
+ // For now, only handle 32-bit case for coordinates.
+ if (!coord_component_type) {
+ return Fail() << " Coordinates for OpImageTexelPointer are not integral: "
+ << image_texel_pointer->PrettyPrint(
+ SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+ }
+ if (coord_component_type->width() != 32) {
+ return Fail() << " Expected OpImageTexelPointer coordinate components to "
+ "be 32-bits wide. They are "
+ << coord_component_type->width() << " bits. "
+ << image_texel_pointer->PrettyPrint(
+ SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+ }
+ const auto* query_size_type =
+ [type_mgr, coord_component_type,
+ query_num_components]() -> const analysis::Type* {
+ if (query_num_components == 1) return coord_component_type;
+ analysis::Vector proposed(coord_component_type, query_num_components);
+ return type_mgr->GetRegisteredType(&proposed);
+ }();
+
+ const uint32_t image_id = TakeNextId();
+ auto* image =
+ InsertInst(image_texel_pointer, SpvOpLoad, image_type_id, image_id,
+ {{SPV_OPERAND_TYPE_ID, {image_ptr->result_id()}}});
+
+ const uint32_t query_size_id = TakeNextId();
+ auto* query_size =
+ InsertInst(image_texel_pointer, SpvOpImageQuerySize,
+ type_mgr->GetTypeInstruction(query_size_type), query_size_id,
+ {{SPV_OPERAND_TYPE_ID, {image->result_id()}}});
+
+ auto* component_1 = constant_mgr->GetConstant(coord_component_type, {1});
+ const uint32_t component_1_id =
+ constant_mgr->GetDefiningInstruction(component_1)->result_id();
+ auto* component_0 = constant_mgr->GetConstant(coord_component_type, {0});
+ const uint32_t component_0_id =
+ constant_mgr->GetDefiningInstruction(component_0)->result_id();
+
+ // If the image is a cube array, then the last component of the queried
+ // size is the layer count. In the query, we have to accommodate folding
+ // in the face index ranging from 0 through 5. The inclusive upper bound
+ // on the third coordinate therefore is multiplied by 6.
+ auto* query_size_including_faces = query_size;
+ if (arrayed && (dim == SpvDimCube)) {
+ // Multiply the last coordinate by 6.
+ auto* component_6 = constant_mgr->GetConstant(coord_component_type, {6});
+ const uint32_t component_6_id =
+ constant_mgr->GetDefiningInstruction(component_6)->result_id();
+ assert(query_num_components == 3);
+ auto* multiplicand = constant_mgr->GetConstant(
+ query_size_type, {component_1_id, component_1_id, component_6_id});
+ auto* multiplicand_inst =
+ constant_mgr->GetDefiningInstruction(multiplicand);
+ const auto query_size_including_faces_id = TakeNextId();
+ query_size_including_faces = InsertInst(
+ image_texel_pointer, SpvOpIMul,
+ type_mgr->GetTypeInstruction(query_size_type),
+ query_size_including_faces_id,
+ {{SPV_OPERAND_TYPE_ID, {query_size_including_faces->result_id()}},
+ {SPV_OPERAND_TYPE_ID, {multiplicand_inst->result_id()}}});
+ }
+
+ // Make a coordinate-type with all 1 components.
+ auto* coordinate_1 =
+ query_num_components == 1
+ ? component_1
+ : constant_mgr->GetConstant(
+ query_size_type,
+ std::vector<uint32_t>(query_num_components, component_1_id));
+ // Make a coordinate-type with all 0 components.
+ auto* coordinate_0 =
+ query_num_components == 0
+ ? component_0
+ : constant_mgr->GetConstant(
+ query_size_type,
+ std::vector<uint32_t>(query_num_components, component_0_id));
+
+ const uint32_t query_max_including_faces_id = TakeNextId();
+ auto* query_max_including_faces = InsertInst(
+ image_texel_pointer, SpvOpISub,
+ type_mgr->GetTypeInstruction(query_size_type),
+ query_max_including_faces_id,
+ {{SPV_OPERAND_TYPE_ID, {query_size_including_faces->result_id()}},
+ {SPV_OPERAND_TYPE_ID,
+ {constant_mgr->GetDefiningInstruction(coordinate_1)->result_id()}}});
+
+ // Clamp the coordinate
+ auto* clamp_coord =
+ MakeClampInst(coord, constant_mgr->GetDefiningInstruction(coordinate_0),
+ query_max_including_faces, image_texel_pointer);
+ image_texel_pointer->SetInOperand(1, {clamp_coord->result_id()});
+
+ // Clamp the sample index
+ if (multisampled) {
+ // Get the sample count via OpImageQuerySamples
+ const auto query_samples_id = TakeNextId();
+ auto* query_samples = InsertInst(
+ image_texel_pointer, SpvOpImageQuerySamples,
+ constant_mgr->GetDefiningInstruction(component_0)->type_id(),
+ query_samples_id, {{SPV_OPERAND_TYPE_ID, {image->result_id()}}});
+
+ const auto max_samples_id = TakeNextId();
+ auto* max_samples = InsertInst(image_texel_pointer, SpvOpImageQuerySamples,
+ query_samples->type_id(), max_samples_id,
+ {{SPV_OPERAND_TYPE_ID, {query_samples_id}},
+ {SPV_OPERAND_TYPE_ID, {component_1_id}}});
+
+ auto* clamp_samples = MakeClampInst(
+ samples, constant_mgr->GetDefiningInstruction(coordinate_0),
+ max_samples, image_texel_pointer);
+ image_texel_pointer->SetInOperand(2, {clamp_samples->result_id()});
+
+ } else {
+ // Just replace it with 0. Don't even check what was there before.
+ image_texel_pointer->SetInOperand(2, {component_0_id});
+ }
+
+ def_use_mgr->AnalyzeInstUse(image_texel_pointer);
+
+ return SPV_SUCCESS;
+}
+
+opt::Instruction* GraphicsRobustAccessPass::InsertInst(
+ opt::Instruction* where_inst, SpvOp opcode, uint32_t type_id,
+ uint32_t result_id, const Instruction::OperandList& operands) {
+ module_status_.modified = true;
+ auto* result = where_inst->InsertBefore(
+ MakeUnique<Instruction>(context(), opcode, type_id, result_id, operands));
+ context()->get_def_use_mgr()->AnalyzeInstDefUse(result);
+ auto* basic_block = context()->get_instr_block(where_inst);
+ context()->set_instr_block(result, basic_block);
+ return result;
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/graphics_robust_access_pass.h b/source/opt/graphics_robust_access_pass.h
new file mode 100644
index 0000000..215cbf1
--- /dev/null
+++ b/source/opt/graphics_robust_access_pass.h
@@ -0,0 +1,142 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_GRAPHICS_ROBUST_ACCESS_PASS_H_
+#define SOURCE_OPT_GRAPHICS_ROBUST_ACCESS_PASS_H_
+
+#include <map>
+#include <unordered_map>
+
+#include "source/diagnostic.h"
+
+#include "constants.h"
+#include "def_use_manager.h"
+#include "instruction.h"
+#include "module.h"
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class GraphicsRobustAccessPass : public Pass {
+ public:
+ GraphicsRobustAccessPass();
+ const char* name() const override { return "graphics-robust-access"; }
+ Status Process() override;
+
+ IRContext::Analysis GetPreservedAnalyses() override {
+ return IRContext::kAnalysisDefUse |
+ IRContext::kAnalysisInstrToBlockMapping |
+ IRContext::kAnalysisConstants | IRContext::kAnalysisTypes |
+ IRContext::kAnalysisIdToFuncMapping;
+ }
+
+ private:
+ // Records failure for the current module, and returns a stream
+ // that can be used to provide user error information to the message
+ // consumer.
+ spvtools::DiagnosticStream Fail();
+
+ // Returns SPV_SUCCESS if this pass can correctly process the module.
+ // Otherwise logs a message and returns a failure code.
+ spv_result_t IsCompatibleModule();
+
+ // Transform the current module, if possible. Failure and modification
+ // status is recorded in the |module_status_| member. On failure, error information is
+ // posted to the message consumer. The return value has no significance.
+ spv_result_t ProcessCurrentModule();
+
+ // Process the given function. Updates the state value |module_status_|. Returns true
+ // if the module was modified.
+ bool ProcessAFunction(opt::Function*);
+
+ // Clamps indices in the OpAccessChain or OpInBoundsAccessChain instruction
+ // |access_chain|. Inserts instructions before the given instruction. Updates
+ // analyses and records that the module is modified.
+ void ClampIndicesForAccessChain(Instruction* access_chain);
+
+ // Returns the id of the instruction importing the "GLSL.std.450" extended
+ // instruction set. If it does not yet exist, the import instruction is
+ // created and inserted into the module, and updates |module_status_.modified|
+ // and |module_status_.glsl_insts_id|.
+ uint32_t GetGlslInsts();
+
+ // Returns an instruction which is constant with the given value of the given
+ // type. Ignores any value bits beyond the width of the type.
+ Instruction* GetValueForType(uint64_t value, const analysis::Integer* type);
+
+ // Converts an integer value to an unsigned wider integer type, using either
+ // sign extension or zero extension. The new instruction is inserted
+ // immediately before |before_inst|, and is analyzed for definitions and uses.
+ // Returns the newly inserted instruction. Assumes the |value| is an integer
+ // scalar of a narrower type than |bitwidth| bits.
+ Instruction* WidenInteger(bool sign_extend, uint32_t bitwidth,
+ Instruction* value, Instruction* before_inst);
+
+ // Returns a new instruction that invokes the UClamp GLSL.std.450 extended
+ // instruction with the three given operands. That is, the result of the
+ // instruction is:
+ // - |min| if |x| is unsigned-less than |min|
+ // - |max| if |x| is unsigned-more than |max|
+ // - |x| otherwise.
+ // We assume that |min| is unsigned-less-or-equal to |max|, and that the
+ // operands all have the same scalar integer type. The instruction is
+ // inserted before |where|.
+ opt::Instruction* MakeClampInst(Instruction* x, Instruction* min,
+ Instruction* max, Instruction* where);
+
+ // Returns a new instruction which evaluates to the length of the runtime array
+ // referenced by the access chain at the specified index. The instruction is
+ // inserted before the access chain instruction. Returns a null pointer in
+ // some cases if assumptions are violated (rather than asserting out).
+ opt::Instruction* MakeRuntimeArrayLengthInst(Instruction* access_chain,
+ uint32_t operand_index);
+
+ // Clamps the coordinate for an OpImageTexelPointer so it stays within
+ // the bounds of the size of the image. Updates analyses and records that
+ // the module is modified. Returns a status code to indicate success
+ // or failure. If assumptions are not met, returns an error status code
+ // and emits a diagnostic.
+ spv_result_t ClampCoordinateForImageTexelPointer(opt::Instruction* itp);
+
+ // Gets the instruction that defines the given id.
+ opt::Instruction* GetDef(uint32_t id) {
+ return context()->get_def_use_mgr()->GetDef(id);
+ }
+
+ // Returns a new instruction inserted before |where_inst|, and created from
+ // the remaining arguments. Registers the definitions and uses of the new
+ // instruction and also records its block.
+ opt::Instruction* InsertInst(opt::Instruction* where_inst, SpvOp opcode,
+ uint32_t type_id, uint32_t result_id,
+ const Instruction::OperandList& operands);
+
+ // State required for the current module.
+ struct PerModuleState {
+ // This pass modified the module.
+ bool modified = false;
+ // True if there is an error processing the current module, e.g. if
+ // preconditions are not met.
+ bool failed = false;
+ // The id of the GLSL.std.450 extended instruction set. Zero if it does
+ // not exist.
+ uint32_t glsl_insts_id = 0;
+ } module_status_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // SOURCE_OPT_GRAPHICS_ROBUST_ACCESS_PASS_H_
diff --git a/source/opt/inst_buff_addr_check_pass.cpp b/source/opt/inst_buff_addr_check_pass.cpp
new file mode 100644
index 0000000..03221ef
--- /dev/null
+++ b/source/opt/inst_buff_addr_check_pass.cpp
@@ -0,0 +1,427 @@
+// Copyright (c) 2019 The Khronos Group Inc.
+// Copyright (c) 2019 Valve Corporation
+// Copyright (c) 2019 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "inst_buff_addr_check_pass.h"
+
+namespace spvtools {
+namespace opt {
+
+uint32_t InstBuffAddrCheckPass::CloneOriginalReference(
+ Instruction* ref_inst, InstructionBuilder* builder) {
+ // Clone original ref with new result id (if load)
+ assert(
+ (ref_inst->opcode() == SpvOpLoad || ref_inst->opcode() == SpvOpStore) &&
+ "unexpected ref");
+ std::unique_ptr<Instruction> new_ref_inst(ref_inst->Clone(context()));
+ uint32_t ref_result_id = ref_inst->result_id();
+ uint32_t new_ref_id = 0;
+ if (ref_result_id != 0) {
+ new_ref_id = TakeNextId();
+ new_ref_inst->SetResultId(new_ref_id);
+ }
+ // Register new reference and add to new block
+ Instruction* added_inst = builder->AddInstruction(std::move(new_ref_inst));
+ uid2offset_[added_inst->unique_id()] = uid2offset_[ref_inst->unique_id()];
+ if (new_ref_id != 0)
+ get_decoration_mgr()->CloneDecorations(ref_result_id, new_ref_id);
+ return new_ref_id;
+}
+
+bool InstBuffAddrCheckPass::IsPhysicalBuffAddrReference(Instruction* ref_inst) {
+ if (ref_inst->opcode() != SpvOpLoad && ref_inst->opcode() != SpvOpStore)
+ return false;
+ uint32_t ptr_id = ref_inst->GetSingleWordInOperand(0);
+ analysis::DefUseManager* du_mgr = get_def_use_mgr();
+ Instruction* ptr_inst = du_mgr->GetDef(ptr_id);
+ if (ptr_inst->opcode() != SpvOpAccessChain) return false;
+ uint32_t ptr_ty_id = ptr_inst->type_id();
+ Instruction* ptr_ty_inst = du_mgr->GetDef(ptr_ty_id);
+ if (ptr_ty_inst->GetSingleWordInOperand(0) !=
+ SpvStorageClassPhysicalStorageBufferEXT)
+ return false;
+ return true;
+}
+
+// TODO(greg-lunarg): Refactor with InstBindlessCheckPass::GenCheckCode() ??
+void InstBuffAddrCheckPass::GenCheckCode(
+ uint32_t check_id, uint32_t error_id, uint32_t ref_uptr_id,
+ uint32_t stage_idx, Instruction* ref_inst,
+ std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
+ BasicBlock* back_blk_ptr = &*new_blocks->back();
+ InstructionBuilder builder(
+ context(), back_blk_ptr,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+ // Gen conditional branch on check_id. Valid branch generates original
+ // reference. Invalid generates debug output and zero result (if needed).
+ uint32_t merge_blk_id = TakeNextId();
+ uint32_t valid_blk_id = TakeNextId();
+ uint32_t invalid_blk_id = TakeNextId();
+ std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
+ std::unique_ptr<Instruction> valid_label(NewLabel(valid_blk_id));
+ std::unique_ptr<Instruction> invalid_label(NewLabel(invalid_blk_id));
+ (void)builder.AddConditionalBranch(check_id, valid_blk_id, invalid_blk_id,
+ merge_blk_id, SpvSelectionControlMaskNone);
+ // Gen valid branch
+ std::unique_ptr<BasicBlock> new_blk_ptr(
+ new BasicBlock(std::move(valid_label)));
+ builder.SetInsertPoint(&*new_blk_ptr);
+ uint32_t new_ref_id = CloneOriginalReference(ref_inst, &builder);
+ (void)builder.AddBranch(merge_blk_id);
+ new_blocks->push_back(std::move(new_blk_ptr));
+ // Gen invalid block
+ new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
+ builder.SetInsertPoint(&*new_blk_ptr);
+ // Convert uptr from uint64 to 2 uint32
+ Instruction* lo_uptr_inst =
+ builder.AddUnaryOp(GetUintId(), SpvOpUConvert, ref_uptr_id);
+ Instruction* rshift_uptr_inst =
+ builder.AddBinaryOp(GetUint64Id(), SpvOpShiftRightLogical, ref_uptr_id,
+ builder.GetUintConstantId(32));
+ Instruction* hi_uptr_inst = builder.AddUnaryOp(GetUintId(), SpvOpUConvert,
+ rshift_uptr_inst->result_id());
+ GenDebugStreamWrite(
+ uid2offset_[ref_inst->unique_id()], stage_idx,
+ {error_id, lo_uptr_inst->result_id(), hi_uptr_inst->result_id()},
+ &builder);
+ // Remember original reference type; the null value is generated in the merge block
+ uint32_t ref_type_id = ref_inst->type_id();
+ (void)builder.AddBranch(merge_blk_id);
+ new_blocks->push_back(std::move(new_blk_ptr));
+ // Gen merge block
+ new_blk_ptr.reset(new BasicBlock(std::move(merge_label)));
+ builder.SetInsertPoint(&*new_blk_ptr);
+ // Gen phi of new reference and zero, if necessary, and replace the
+ // result id of the original reference with that of the Phi. Kill original
+ // reference.
+ if (new_ref_id != 0) {
+ Instruction* phi_inst = builder.AddPhi(
+ ref_type_id, {new_ref_id, valid_blk_id, builder.GetNullId(ref_type_id),
+ invalid_blk_id});
+ context()->ReplaceAllUsesWith(ref_inst->result_id(), phi_inst->result_id());
+ }
+ new_blocks->push_back(std::move(new_blk_ptr));
+ context()->KillInst(ref_inst);
+}
+
+uint32_t InstBuffAddrCheckPass::GetTypeLength(uint32_t type_id) {
+ Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
+ switch (type_inst->opcode()) {
+ case SpvOpTypeFloat:
+ case SpvOpTypeInt:
+ return type_inst->GetSingleWordInOperand(0) / 8u;
+ case SpvOpTypeVector:
+ case SpvOpTypeMatrix:
+ return type_inst->GetSingleWordInOperand(1) *
+ GetTypeLength(type_inst->GetSingleWordInOperand(0));
+ case SpvOpTypePointer:
+ assert(type_inst->GetSingleWordInOperand(0) ==
+ SpvStorageClassPhysicalStorageBufferEXT &&
+ "unexpected pointer type");
+ return 8u;
+ default:
+ assert(false && "unexpected buffer reference type");
+ return 0;
+ }
+}
+
+void InstBuffAddrCheckPass::AddParam(uint32_t type_id,
+ std::vector<uint32_t>* param_vec,
+ std::unique_ptr<Function>* input_func) {
+ uint32_t pid = TakeNextId();
+ param_vec->push_back(pid);
+ std::unique_ptr<Instruction> param_inst(new Instruction(
+ get_module()->context(), SpvOpFunctionParameter, type_id, pid, {}));
+ get_def_use_mgr()->AnalyzeInstDefUse(&*param_inst);
+ (*input_func)->AddParameter(std::move(param_inst));
+}
+
+uint32_t InstBuffAddrCheckPass::GetSearchAndTestFuncId() {
+ if (search_test_func_id_ == 0) {
+ // Generate function "bool search_and_test(uint64_t ref_ptr, uint32_t len)"
+ // which searches input buffer for buffer which most likely contains the
+ // pointer value |ref_ptr| and verifies that the entire reference of
+ // length |len| bytes is contained in the buffer.
+ search_test_func_id_ = TakeNextId();
+ analysis::TypeManager* type_mgr = context()->get_type_mgr();
+ std::vector<const analysis::Type*> param_types = {
+ type_mgr->GetType(GetUint64Id()), type_mgr->GetType(GetUintId())};
+ analysis::Function func_ty(type_mgr->GetType(GetBoolId()), param_types);
+ analysis::Type* reg_func_ty = type_mgr->GetRegisteredType(&func_ty);
+ std::unique_ptr<Instruction> func_inst(
+ new Instruction(get_module()->context(), SpvOpFunction, GetBoolId(),
+ search_test_func_id_,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
+ {SpvFunctionControlMaskNone}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID,
+ {type_mgr->GetTypeInstruction(reg_func_ty)}}}));
+ get_def_use_mgr()->AnalyzeInstDefUse(&*func_inst);
+ std::unique_ptr<Function> input_func =
+ MakeUnique<Function>(std::move(func_inst));
+ std::vector<uint32_t> param_vec;
+ // Add ref_ptr and length parameters
+ AddParam(GetUint64Id(), ¶m_vec, &input_func);
+ AddParam(GetUintId(), ¶m_vec, &input_func);
+ // Empty first block.
+ uint32_t first_blk_id = TakeNextId();
+ std::unique_ptr<Instruction> first_blk_label(NewLabel(first_blk_id));
+ std::unique_ptr<BasicBlock> first_blk_ptr =
+ MakeUnique<BasicBlock>(std::move(first_blk_label));
+ InstructionBuilder builder(
+ context(), &*first_blk_ptr,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+ uint32_t hdr_blk_id = TakeNextId();
+ // Branch to search loop header
+ std::unique_ptr<Instruction> hdr_blk_label(NewLabel(hdr_blk_id));
+ (void)builder.AddInstruction(MakeUnique<Instruction>(
+ context(), SpvOpBranch, 0, 0,
+ std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {hdr_blk_id}}}));
+ first_blk_ptr->SetParent(&*input_func);
+ input_func->AddBasicBlock(std::move(first_blk_ptr));
+ // Linear search loop header block
+ // TODO(greg-lunarg): Implement binary search
+ std::unique_ptr<BasicBlock> hdr_blk_ptr =
+ MakeUnique<BasicBlock>(std::move(hdr_blk_label));
+ builder.SetInsertPoint(&*hdr_blk_ptr);
+ // Phi for search index. Starts with 1.
+ uint32_t cont_blk_id = TakeNextId();
+ std::unique_ptr<Instruction> cont_blk_label(NewLabel(cont_blk_id));
+ // Deal with def-use cycle caused by search loop index computation.
+ // Create Add and Phi instructions first, then do Def analysis on Add.
+ // Add Phi and Add instructions and do Use analysis later.
+ uint32_t idx_phi_id = TakeNextId();
+ uint32_t idx_inc_id = TakeNextId();
+ std::unique_ptr<Instruction> idx_inc_inst(new Instruction(
+ context(), SpvOpIAdd, GetUintId(), idx_inc_id,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {idx_phi_id}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID,
+ {builder.GetUintConstantId(1u)}}}));
+ std::unique_ptr<Instruction> idx_phi_inst(new Instruction(
+ context(), SpvOpPhi, GetUintId(), idx_phi_id,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
+ {builder.GetUintConstantId(1u)}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {first_blk_id}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {idx_inc_id}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cont_blk_id}}}));
+ get_def_use_mgr()->AnalyzeInstDef(&*idx_inc_inst);
+ // Add (previously created) search index phi
+ (void)builder.AddInstruction(std::move(idx_phi_inst));
+ // LoopMerge
+ uint32_t bound_test_blk_id = TakeNextId();
+ std::unique_ptr<Instruction> bound_test_blk_label(
+ NewLabel(bound_test_blk_id));
+ (void)builder.AddInstruction(MakeUnique<Instruction>(
+ context(), SpvOpLoopMerge, 0, 0,
+ std::initializer_list<Operand>{
+ {SPV_OPERAND_TYPE_ID, {bound_test_blk_id}},
+ {SPV_OPERAND_TYPE_ID, {cont_blk_id}},
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {SpvLoopControlMaskNone}}}));
+ // Branch to continue/work block
+ (void)builder.AddInstruction(MakeUnique<Instruction>(
+ context(), SpvOpBranch, 0, 0,
+ std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {cont_blk_id}}}));
+ hdr_blk_ptr->SetParent(&*input_func);
+ input_func->AddBasicBlock(std::move(hdr_blk_ptr));
+ // Continue/Work Block. Read next buffer pointer and break if greater
+ // than ref_ptr arg.
+ std::unique_ptr<BasicBlock> cont_blk_ptr =
+ MakeUnique<BasicBlock>(std::move(cont_blk_label));
+ builder.SetInsertPoint(&*cont_blk_ptr);
+ // Add (previously created) search index increment now.
+ (void)builder.AddInstruction(std::move(idx_inc_inst));
+ // Load next buffer address from debug input buffer
+ uint32_t ibuf_id = GetInputBufferId();
+ uint32_t ibuf_ptr_id = GetInputBufferPtrId();
+ Instruction* uptr_ac_inst = builder.AddTernaryOp(
+ ibuf_ptr_id, SpvOpAccessChain, ibuf_id,
+ builder.GetUintConstantId(kDebugInputDataOffset), idx_inc_id);
+ uint32_t ibuf_type_id = GetInputBufferTypeId();
+ Instruction* uptr_load_inst =
+ builder.AddUnaryOp(ibuf_type_id, SpvOpLoad, uptr_ac_inst->result_id());
+ // If loaded address greater than ref_ptr arg, break, else branch back to
+ // loop header
+ Instruction* uptr_test_inst =
+ builder.AddBinaryOp(GetBoolId(), SpvOpUGreaterThan,
+ uptr_load_inst->result_id(), param_vec[0]);
+ (void)builder.AddConditionalBranch(uptr_test_inst->result_id(),
+ bound_test_blk_id, hdr_blk_id,
+ kInvalidId, SpvSelectionControlMaskNone);
+ cont_blk_ptr->SetParent(&*input_func);
+ input_func->AddBasicBlock(std::move(cont_blk_ptr));
+ // Bounds test block. Read length of selected buffer and test that
+ // all len arg bytes are in buffer.
+ std::unique_ptr<BasicBlock> bound_test_blk_ptr =
+ MakeUnique<BasicBlock>(std::move(bound_test_blk_label));
+ builder.SetInsertPoint(&*bound_test_blk_ptr);
+ // Decrement index to point to previous/candidate buffer address
+ Instruction* cand_idx_inst = builder.AddBinaryOp(
+ GetUintId(), SpvOpISub, idx_inc_id, builder.GetUintConstantId(1u));
+ // Load candidate buffer address
+ Instruction* cand_ac_inst =
+ builder.AddTernaryOp(ibuf_ptr_id, SpvOpAccessChain, ibuf_id,
+ builder.GetUintConstantId(kDebugInputDataOffset),
+ cand_idx_inst->result_id());
+ Instruction* cand_load_inst =
+ builder.AddUnaryOp(ibuf_type_id, SpvOpLoad, cand_ac_inst->result_id());
+ // Compute offset of ref_ptr from candidate buffer address
+ Instruction* offset_inst = builder.AddBinaryOp(
+ ibuf_type_id, SpvOpISub, param_vec[0], cand_load_inst->result_id());
+ // Convert ref length to uint64
+ Instruction* ref_len_64_inst =
+ builder.AddUnaryOp(ibuf_type_id, SpvOpUConvert, param_vec[1]);
+ // Add ref length to ref offset to compute end of reference
+ Instruction* ref_end_inst =
+ builder.AddBinaryOp(ibuf_type_id, SpvOpIAdd, offset_inst->result_id(),
+ ref_len_64_inst->result_id());
+ // Load starting index of lengths in input buffer and convert to uint32
+ Instruction* len_start_ac_inst =
+ builder.AddTernaryOp(ibuf_ptr_id, SpvOpAccessChain, ibuf_id,
+ builder.GetUintConstantId(kDebugInputDataOffset),
+ builder.GetUintConstantId(0u));
+ Instruction* len_start_load_inst = builder.AddUnaryOp(
+ ibuf_type_id, SpvOpLoad, len_start_ac_inst->result_id());
+ Instruction* len_start_32_inst = builder.AddUnaryOp(
+ GetUintId(), SpvOpUConvert, len_start_load_inst->result_id());
+ // Decrement search index to get candidate buffer length index
+ Instruction* cand_len_idx_inst =
+ builder.AddBinaryOp(GetUintId(), SpvOpISub, cand_idx_inst->result_id(),
+ builder.GetUintConstantId(1u));
+ // Add candidate length index to start index
+ Instruction* len_idx_inst = builder.AddBinaryOp(
+ GetUintId(), SpvOpIAdd, cand_len_idx_inst->result_id(),
+ len_start_32_inst->result_id());
+ // Load candidate buffer length
+ Instruction* len_ac_inst =
+ builder.AddTernaryOp(ibuf_ptr_id, SpvOpAccessChain, ibuf_id,
+ builder.GetUintConstantId(kDebugInputDataOffset),
+ len_idx_inst->result_id());
+ Instruction* len_load_inst =
+ builder.AddUnaryOp(ibuf_type_id, SpvOpLoad, len_ac_inst->result_id());
+ // Test if reference end within candidate buffer length
+ Instruction* len_test_inst = builder.AddBinaryOp(
+ GetBoolId(), SpvOpULessThanEqual, ref_end_inst->result_id(),
+ len_load_inst->result_id());
+ // Return test result
+ (void)builder.AddInstruction(MakeUnique<Instruction>(
+ context(), SpvOpReturnValue, 0, 0,
+ std::initializer_list<Operand>{
+ {SPV_OPERAND_TYPE_ID, {len_test_inst->result_id()}}}));
+ // Close block
+ bound_test_blk_ptr->SetParent(&*input_func);
+ input_func->AddBasicBlock(std::move(bound_test_blk_ptr));
+ // Close function and add function to module
+ std::unique_ptr<Instruction> func_end_inst(
+ new Instruction(get_module()->context(), SpvOpFunctionEnd, 0, 0, {}));
+ get_def_use_mgr()->AnalyzeInstDefUse(&*func_end_inst);
+ input_func->SetFunctionEnd(std::move(func_end_inst));
+ context()->AddFunction(std::move(input_func));
+ }
+ return search_test_func_id_;
+}
+
+uint32_t InstBuffAddrCheckPass::GenSearchAndTest(Instruction* ref_inst,
+ InstructionBuilder* builder,
+ uint32_t* ref_uptr_id) {
+ // Enable Int64 if necessary
+ if (!get_feature_mgr()->HasCapability(SpvCapabilityInt64)) {
+ std::unique_ptr<Instruction> cap_int64_inst(new Instruction(
+ context(), SpvOpCapability, 0, 0,
+ std::initializer_list<Operand>{
+ {SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityInt64}}}));
+ get_def_use_mgr()->AnalyzeInstDefUse(&*cap_int64_inst);
+ context()->AddCapability(std::move(cap_int64_inst));
+ }
+ // Convert reference pointer to uint64
+ uint32_t ref_ptr_id = ref_inst->GetSingleWordInOperand(0);
+ Instruction* ref_uptr_inst =
+ builder->AddUnaryOp(GetUint64Id(), SpvOpConvertPtrToU, ref_ptr_id);
+ *ref_uptr_id = ref_uptr_inst->result_id();
+ // Compute reference length in bytes
+ analysis::DefUseManager* du_mgr = get_def_use_mgr();
+ Instruction* ref_ptr_inst = du_mgr->GetDef(ref_ptr_id);
+ uint32_t ref_ptr_ty_id = ref_ptr_inst->type_id();
+ Instruction* ref_ptr_ty_inst = du_mgr->GetDef(ref_ptr_ty_id);
+ uint32_t ref_len = GetTypeLength(ref_ptr_ty_inst->GetSingleWordInOperand(1));
+ uint32_t ref_len_id = builder->GetUintConstantId(ref_len);
+ // Gen call to search and test function
+ const std::vector<uint32_t> args = {GetSearchAndTestFuncId(), *ref_uptr_id,
+ ref_len_id};
+ Instruction* call_inst =
+ builder->AddNaryOp(GetBoolId(), SpvOpFunctionCall, args);
+ uint32_t retval = call_inst->result_id();
+ return retval;
+}
+
+void InstBuffAddrCheckPass::GenBuffAddrCheckCode(
+ BasicBlock::iterator ref_inst_itr,
+ UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
+ std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
+ // Look for a reference through a physical storage buffer address. If found,
+ // analyze and instrument it. If not, return.
+ Instruction* ref_inst = &*ref_inst_itr;
+ if (!IsPhysicalBuffAddrReference(ref_inst)) return;
+ // Move original block's preceding instructions into first new block
+ std::unique_ptr<BasicBlock> new_blk_ptr;
+ MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
+ InstructionBuilder builder(
+ context(), &*new_blk_ptr,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+ new_blocks->push_back(std::move(new_blk_ptr));
+ uint32_t error_id = builder.GetUintConstantId(kInstErrorBuffAddrUnallocRef);
+ // Generate code to do search and test if all bytes of reference
+ // are within a listed buffer. Return reference pointer converted to uint64.
+ uint32_t ref_uptr_id;
+ uint32_t valid_id = GenSearchAndTest(ref_inst, &builder, &ref_uptr_id);
+ // Generate test of search results with true branch
+ // being full reference and false branch being debug output and zero
+ // for the referenced value.
+ GenCheckCode(valid_id, error_id, ref_uptr_id, stage_idx, ref_inst,
+ new_blocks);
+ // Move original block's remaining code into remainder/merge block and add
+ // to new blocks
+ BasicBlock* back_blk_ptr = &*new_blocks->back();
+ MovePostludeCode(ref_block_itr, back_blk_ptr);
+}
+
+void InstBuffAddrCheckPass::InitInstBuffAddrCheck() {
+ // Initialize base class
+ InitializeInstrument();
+ // Initialize class
+ search_test_func_id_ = 0;
+}
+
+Pass::Status InstBuffAddrCheckPass::ProcessImpl() {
+ // Perform buffer address check on each entry point function in module
+ InstProcessFunction pfn =
+ [this](BasicBlock::iterator ref_inst_itr,
+ UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
+ std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
+ return GenBuffAddrCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
+ new_blocks);
+ };
+ bool modified = InstProcessEntryPointCallTree(pfn);
+ return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+Pass::Status InstBuffAddrCheckPass::Process() {
+ if (!get_feature_mgr()->HasCapability(
+ SpvCapabilityPhysicalStorageBufferAddressesEXT))
+ return Status::SuccessWithoutChange;
+ InitInstBuffAddrCheck();
+ return ProcessImpl();
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/inst_buff_addr_check_pass.h b/source/opt/inst_buff_addr_check_pass.h
new file mode 100644
index 0000000..9ad3528
--- /dev/null
+++ b/source/opt/inst_buff_addr_check_pass.h
@@ -0,0 +1,133 @@
+// Copyright (c) 2019 The Khronos Group Inc.
+// Copyright (c) 2019 Valve Corporation
+// Copyright (c) 2019 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_INST_BUFFER_ADDRESS_PASS_H_
+#define LIBSPIRV_OPT_INST_BUFFER_ADDRESS_PASS_H_
+
+#include "instrument_pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// This class/pass is designed to support the GPU-assisted validation layer of
+// the Buffer Device Address (BDA) extension in
+// https://github.com/KhronosGroup/Vulkan-ValidationLayers. The internal and
+// external design of this class may change as the layer evolves.
+class InstBuffAddrCheckPass : public InstrumentPass {
+ public:
+ // For test harness only
+ InstBuffAddrCheckPass()
+ : InstrumentPass(7, 23, kInstValidationIdBuffAddr, 1) {}
+ // For all other interfaces
+ InstBuffAddrCheckPass(uint32_t desc_set, uint32_t shader_id, uint32_t version)
+ : InstrumentPass(desc_set, shader_id, kInstValidationIdBuffAddr,
+ version) {}
+
+ ~InstBuffAddrCheckPass() override = default;
+
+ // See optimizer.hpp for pass user documentation.
+ Status Process() override;
+
+ const char* name() const override { return "inst-buff-addr-check-pass"; }
+
+ private:
+ // Return byte length of type |type_id|. Must be int, float, vector, matrix
+ // or physical pointer.
+ uint32_t GetTypeLength(uint32_t type_id);
+
+ // Add |type_id| param to |input_func| and add id to |param_vec|.
+ void AddParam(uint32_t type_id, std::vector<uint32_t>* param_vec,
+ std::unique_ptr<Function>* input_func);
+
+ // Return id for search and test function. Generate it if not already gen'd.
+ uint32_t GetSearchAndTestFuncId();
+
+ // Generate code into |builder| to do search of the BDA debug input buffer
+ // for the buffer used by |ref_inst| and test that all bytes of reference
+ // are within the buffer. Returns id of boolean value which is true if
+ // search and test is successful, false otherwise.
+ uint32_t GenSearchAndTest(Instruction* ref_inst, InstructionBuilder* builder,
+ uint32_t* ref_uptr_id);
+
+ // This function does checking instrumentation on a single
+ // instruction which references through a physical storage buffer address.
+ // GenBuffAddrCheckCode generates code that checks that all bytes that
+ // are referenced fall within a buffer that was queried via
+ // the Vulkan API call vkGetBufferDeviceAddressEXT().
+ //
+ // The function is designed to be passed to
+ // InstrumentPass::InstProcessEntryPointCallTree(), which applies the
+ // function to each instruction in a module and replaces the instruction
+ // with instrumented code if warranted.
+ //
+ // If |ref_inst_itr| is a physical storage buffer reference, return in
+ // |new_blocks| the result of instrumenting it with validation code within
+ // its block at |ref_block_itr|. The validation code first executes a check
+ // for the specific condition called for. If the check passes, it executes
+ // the remainder of the reference, otherwise writes a record to the debug
+ // output buffer stream including |function_idx, instruction_idx, stage_idx|
+ // and replaces the reference with the null value of the original type. The
+ // block at |ref_block_itr| can just be replaced with the blocks in
+ // |new_blocks|, which will contain at least two blocks. The last block will
+ // comprise all instructions following |ref_inst_itr|,
+ // preceded by a phi instruction if needed.
+ //
+ // This instrumentation function utilizes GenDebugStreamWrite() to write its
+ // error records. The validation-specific part of the error record will
+ // have the format:
+ //
+ // Validation Error Code (=kInstErrorBuffAddrUnallocRef)
+ // Buffer Address (lowest 32 bits)
+ // Buffer Address (highest 32 bits)
+ //
+ void GenBuffAddrCheckCode(
+ BasicBlock::iterator ref_inst_itr,
+ UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
+ std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
+
+ // Return true if |ref_inst| is a physical buffer address reference, false
+ // otherwise.
+ bool IsPhysicalBuffAddrReference(Instruction* ref_inst);
+
+ // Clone original reference |ref_inst| into |builder| and return id of result
+ uint32_t CloneOriginalReference(Instruction* ref_inst,
+ InstructionBuilder* builder);
+
+ // Generate instrumentation code for boolean test result |check_id|,
+ // adding new blocks to |new_blocks|. Generate conditional branch to valid
+ // or invalid reference blocks. Generate valid reference block which does
+ // original reference |ref_inst|. Then generate invalid reference block which
+ // writes debug error output utilizing |ref_inst|, |error_id| and
+ // |stage_idx|. Generate merge block for valid and invalid reference blocks.
+ // Kill original reference.
+ void GenCheckCode(uint32_t check_id, uint32_t error_id, uint32_t ref_uptr_id,
+ uint32_t stage_idx, Instruction* ref_inst,
+ std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
+
+ // Initialize state for instrumenting physical buffer address checking
+ void InitInstBuffAddrCheck();
+
+ // Apply GenBuffAddrCheckCode to every instruction in module.
+ Pass::Status ProcessImpl();
+
+ // Id of search and test function, if already gen'd, else zero.
+ uint32_t search_test_func_id_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // LIBSPIRV_OPT_INST_BUFFER_ADDRESS_PASS_H_
diff --git a/source/opt/instruction.cpp b/source/opt/instruction.cpp
index ec73640..49f9142 100644
--- a/source/opt/instruction.cpp
+++ b/source/opt/instruction.cpp
@@ -469,7 +469,7 @@
bool Instruction::IsFoldable() const {
return IsFoldableByFoldScalar() ||
- context()->get_instruction_folder().HasConstFoldingRule(opcode());
+ context()->get_instruction_folder().HasConstFoldingRule(this);
}
bool Instruction::IsFoldableByFoldScalar() const {
diff --git a/source/opt/instruction.h b/source/opt/instruction.h
index d507d6c..d1c4ce1 100644
--- a/source/opt/instruction.h
+++ b/source/opt/instruction.h
@@ -399,7 +399,7 @@
inline bool operator<(const Instruction&) const;
// Takes ownership of the instruction owned by |i| and inserts it immediately
- // before |this|. Returns the insterted instruction.
+ // before |this|. Returns the inserted instruction.
Instruction* InsertBefore(std::unique_ptr<Instruction>&& i);
// Takes ownership of the instructions in |list| and inserts them in order
// immediately before |this|. Returns the first inserted instruction.
diff --git a/source/opt/instrument_pass.cpp b/source/opt/instrument_pass.cpp
index 6645a2e..246cdbb 100644
--- a/source/opt/instrument_pass.cpp
+++ b/source/opt/instrument_pass.cpp
@@ -107,7 +107,7 @@
builder->AddBinaryOp(GetUintId(), SpvOpIAdd, base_offset_id,
builder->GetUintConstantId(field_offset));
uint32_t buf_id = GetOutputBufferId();
- uint32_t buf_uint_ptr_id = GetBufferUintPtrId();
+ uint32_t buf_uint_ptr_id = GetOutputBufferPtrId();
Instruction* achain_inst =
builder->AddTernaryOp(buf_uint_ptr_id, SpvOpAccessChain, buf_id,
builder->GetUintConstantId(kDebugOutputDataOffset),
@@ -243,10 +243,13 @@
kInstTessEvalOutPrimitiveId, base_offset_id, builder);
uint32_t load_id = GenVarLoad(
context()->GetBuiltinInputVarId(SpvBuiltInTessCoord), builder);
+ Instruction* uvec3_cast_inst =
+ builder->AddUnaryOp(GetVec3UintId(), SpvOpBitcast, load_id);
+ uint32_t uvec3_cast_id = uvec3_cast_inst->result_id();
Instruction* u_inst = builder->AddIdLiteralOp(
- GetUintId(), SpvOpCompositeExtract, load_id, 0);
+ GetUintId(), SpvOpCompositeExtract, uvec3_cast_id, 0);
Instruction* v_inst = builder->AddIdLiteralOp(
- GetUintId(), SpvOpCompositeExtract, load_id, 1);
+ GetUintId(), SpvOpCompositeExtract, uvec3_cast_id, 1);
GenDebugOutputFieldCode(base_offset_id, kInstTessEvalOutTessCoordU,
u_inst->result_id(), builder);
GenDebugOutputFieldCode(base_offset_id, kInstTessEvalOutTessCoordV,
@@ -370,19 +373,33 @@
});
}
-// Return id for output buffer uint ptr type
-uint32_t InstrumentPass::GetBufferUintPtrId() {
- if (buffer_uint_ptr_id_ == 0) {
- buffer_uint_ptr_id_ = context()->get_type_mgr()->FindPointerToType(
+uint32_t InstrumentPass::GetOutputBufferPtrId() {
+ if (output_buffer_ptr_id_ == 0) {
+ output_buffer_ptr_id_ = context()->get_type_mgr()->FindPointerToType(
GetUintId(), SpvStorageClassStorageBuffer);
}
- return buffer_uint_ptr_id_;
+ return output_buffer_ptr_id_;
+}
+
+uint32_t InstrumentPass::GetInputBufferTypeId() {
+ return (validation_id_ == kInstValidationIdBuffAddr) ? GetUint64Id()
+ : GetUintId();
+}
+
+uint32_t InstrumentPass::GetInputBufferPtrId() {
+ if (input_buffer_ptr_id_ == 0) {
+ input_buffer_ptr_id_ = context()->get_type_mgr()->FindPointerToType(
+ GetInputBufferTypeId(), SpvStorageClassStorageBuffer);
+ }
+ return input_buffer_ptr_id_;
}
uint32_t InstrumentPass::GetOutputBufferBinding() {
switch (validation_id_) {
case kInstValidationIdBindless:
return kDebugOutputBindingStream;
+ case kInstValidationIdBuffAddr:
+ return kDebugOutputBindingStream;
default:
assert(false && "unexpected validation id");
}
@@ -393,20 +410,24 @@
switch (validation_id_) {
case kInstValidationIdBindless:
return kDebugInputBindingBindless;
+ case kInstValidationIdBuffAddr:
+ return kDebugInputBindingBuffAddr;
default:
assert(false && "unexpected validation id");
}
return 0;
}
-analysis::Type* InstrumentPass::GetUintRuntimeArrayType(
- analysis::DecorationManager* deco_mgr, analysis::TypeManager* type_mgr) {
- if (uint_rarr_ty_ == nullptr) {
- analysis::Integer uint_ty(32, false);
+analysis::Type* InstrumentPass::GetUintXRuntimeArrayType(
+ uint32_t width, analysis::Type** rarr_ty) {
+ if (*rarr_ty == nullptr) {
+ analysis::DecorationManager* deco_mgr = get_decoration_mgr();
+ analysis::TypeManager* type_mgr = context()->get_type_mgr();
+ analysis::Integer uint_ty(width, false);
analysis::Type* reg_uint_ty = type_mgr->GetRegisteredType(&uint_ty);
analysis::RuntimeArray uint_rarr_ty_tmp(reg_uint_ty);
- uint_rarr_ty_ = type_mgr->GetRegisteredType(&uint_rarr_ty_tmp);
- uint32_t uint_arr_ty_id = type_mgr->GetTypeInstruction(uint_rarr_ty_);
+ *rarr_ty = type_mgr->GetRegisteredType(&uint_rarr_ty_tmp);
+ uint32_t uint_arr_ty_id = type_mgr->GetTypeInstruction(*rarr_ty);
// By the Vulkan spec, a pre-existing RuntimeArray of uint must be part of
// a block, and will therefore be decorated with an ArrayStride. Therefore
// the undecorated type returned here will not be pre-existing and can
@@ -415,23 +436,22 @@
// invalidated after this pass.
assert(context()->get_def_use_mgr()->NumUses(uint_arr_ty_id) == 0 &&
"used RuntimeArray type returned");
- deco_mgr->AddDecorationVal(uint_arr_ty_id, SpvDecorationArrayStride, 4u);
+ deco_mgr->AddDecorationVal(uint_arr_ty_id, SpvDecorationArrayStride,
+ width / 8u);
}
- return uint_rarr_ty_;
+ return *rarr_ty;
+}
+
+analysis::Type* InstrumentPass::GetUintRuntimeArrayType(uint32_t width) {
+ analysis::Type** rarr_ty =
+ (width == 64) ? &uint64_rarr_ty_ : &uint32_rarr_ty_;
+ return GetUintXRuntimeArrayType(width, rarr_ty);
}
void InstrumentPass::AddStorageBufferExt() {
if (storage_buffer_ext_defined_) return;
if (!get_feature_mgr()->HasExtension(kSPV_KHR_storage_buffer_storage_class)) {
- const std::string ext_name("SPV_KHR_storage_buffer_storage_class");
- const auto num_chars = ext_name.size();
- // Compute num words, accommodate the terminating null character.
- const auto num_words = (num_chars + 1 + 3) / 4;
- std::vector<uint32_t> ext_words(num_words, 0u);
- std::memcpy(ext_words.data(), ext_name.data(), num_chars);
- context()->AddExtension(std::unique_ptr<Instruction>(
- new Instruction(context(), SpvOpExtension, 0u, 0u,
- {{SPV_OPERAND_TYPE_LITERAL_STRING, ext_words}})));
+ context()->AddExtension("SPV_KHR_storage_buffer_storage_class");
}
storage_buffer_ext_defined_ = true;
}
@@ -442,8 +462,7 @@
// If not created yet, create one
analysis::DecorationManager* deco_mgr = get_decoration_mgr();
analysis::TypeManager* type_mgr = context()->get_type_mgr();
- analysis::Type* reg_uint_rarr_ty =
- GetUintRuntimeArrayType(deco_mgr, type_mgr);
+ analysis::Type* reg_uint_rarr_ty = GetUintRuntimeArrayType(32);
analysis::Integer uint_ty(32, false);
analysis::Type* reg_uint_ty = type_mgr->GetRegisteredType(&uint_ty);
analysis::Struct buf_ty({reg_uint_ty, reg_uint_rarr_ty});
@@ -491,8 +510,8 @@
// If not created yet, create one
analysis::DecorationManager* deco_mgr = get_decoration_mgr();
analysis::TypeManager* type_mgr = context()->get_type_mgr();
- analysis::Type* reg_uint_rarr_ty =
- GetUintRuntimeArrayType(deco_mgr, type_mgr);
+ uint32_t width = (validation_id_ == kInstValidationIdBuffAddr) ? 64u : 32u;
+ analysis::Type* reg_uint_rarr_ty = GetUintRuntimeArrayType(width);
analysis::Struct buf_ty({reg_uint_rarr_ty});
analysis::Type* reg_buf_ty = type_mgr->GetRegisteredType(&buf_ty);
uint32_t ibufTyId = type_mgr->GetTypeInstruction(reg_buf_ty);
@@ -552,18 +571,36 @@
return uint_id_;
}
-uint32_t InstrumentPass::GetVec4UintId() {
- if (v4uint_id_ == 0) {
+uint32_t InstrumentPass::GetUint64Id() {
+ if (uint64_id_ == 0) {
analysis::TypeManager* type_mgr = context()->get_type_mgr();
- analysis::Integer uint_ty(32, false);
- analysis::Type* reg_uint_ty = type_mgr->GetRegisteredType(&uint_ty);
- analysis::Vector v4uint_ty(reg_uint_ty, 4);
- analysis::Type* reg_v4uint_ty = type_mgr->GetRegisteredType(&v4uint_ty);
- v4uint_id_ = type_mgr->GetTypeInstruction(reg_v4uint_ty);
+ analysis::Integer uint64_ty(64, false);
+ analysis::Type* reg_uint64_ty = type_mgr->GetRegisteredType(&uint64_ty);
+ uint64_id_ = type_mgr->GetTypeInstruction(reg_uint64_ty);
}
+ return uint64_id_;
+}
+
+uint32_t InstrumentPass::GetVecUintId(uint32_t len) {
+ analysis::TypeManager* type_mgr = context()->get_type_mgr();
+ analysis::Integer uint_ty(32, false);
+ analysis::Type* reg_uint_ty = type_mgr->GetRegisteredType(&uint_ty);
+ analysis::Vector v_uint_ty(reg_uint_ty, len);
+ analysis::Type* reg_v_uint_ty = type_mgr->GetRegisteredType(&v_uint_ty);
+ uint32_t v_uint_id = type_mgr->GetTypeInstruction(reg_v_uint_ty);
+ return v_uint_id;
+}
+
+uint32_t InstrumentPass::GetVec4UintId() {
+ if (v4uint_id_ == 0) v4uint_id_ = GetVecUintId(4u);
return v4uint_id_;
}
+uint32_t InstrumentPass::GetVec3UintId() {
+ if (v3uint_id_ == 0) v3uint_id_ = GetVecUintId(3u);
+ return v3uint_id_;
+}
+
uint32_t InstrumentPass::GetBoolId() {
if (bool_id_ == 0) {
analysis::TypeManager* type_mgr = context()->get_type_mgr();
@@ -631,7 +668,7 @@
(version_ == 1) ? kInstStageOutCnt : kInst2StageOutCnt;
uint32_t obuf_record_sz = val_spec_offset + val_spec_param_cnt;
uint32_t buf_id = GetOutputBufferId();
- uint32_t buf_uint_ptr_id = GetBufferUintPtrId();
+ uint32_t buf_uint_ptr_id = GetOutputBufferPtrId();
Instruction* obuf_curr_sz_ac_inst =
builder.AddBinaryOp(buf_uint_ptr_id, SpvOpAccessChain, buf_id,
builder.GetUintConstantId(kDebugOutputSizeOffset));
@@ -702,16 +739,17 @@
uint32_t InstrumentPass::GetDirectReadFunctionId(uint32_t param_cnt) {
uint32_t func_id = param2input_func_id_[param_cnt];
if (func_id != 0) return func_id;
- // Create input function for param_cnt
+ // Create input function for param_cnt.
func_id = TakeNextId();
analysis::TypeManager* type_mgr = context()->get_type_mgr();
std::vector<const analysis::Type*> param_types;
for (uint32_t c = 0; c < param_cnt; ++c)
param_types.push_back(type_mgr->GetType(GetUintId()));
- analysis::Function func_ty(type_mgr->GetType(GetUintId()), param_types);
+ uint32_t ibuf_type_id = GetInputBufferTypeId();
+ analysis::Function func_ty(type_mgr->GetType(ibuf_type_id), param_types);
analysis::Type* reg_func_ty = type_mgr->GetRegisteredType(&func_ty);
std::unique_ptr<Instruction> func_inst(new Instruction(
- get_module()->context(), SpvOpFunction, GetUintId(), func_id,
+ get_module()->context(), SpvOpFunction, ibuf_type_id, func_id,
{{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
{SpvFunctionControlMaskNone}},
{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
@@ -741,22 +779,27 @@
// loaded value if it exists, and load value from input buffer at new offset.
// Return last loaded value.
uint32_t buf_id = GetInputBufferId();
- uint32_t buf_uint_ptr_id = GetBufferUintPtrId();
+ uint32_t buf_ptr_id = GetInputBufferPtrId();
uint32_t last_value_id = 0;
for (uint32_t p = 0; p < param_cnt; ++p) {
uint32_t offset_id;
if (p == 0) {
offset_id = param_vec[0];
} else {
+ if (ibuf_type_id != GetUintId()) {
+ Instruction* ucvt_inst =
+ builder.AddUnaryOp(GetUintId(), SpvOpUConvert, last_value_id);
+ last_value_id = ucvt_inst->result_id();
+ }
Instruction* offset_inst = builder.AddBinaryOp(
GetUintId(), SpvOpIAdd, last_value_id, param_vec[p]);
offset_id = offset_inst->result_id();
}
Instruction* ac_inst = builder.AddTernaryOp(
- buf_uint_ptr_id, SpvOpAccessChain, buf_id,
+ buf_ptr_id, SpvOpAccessChain, buf_id,
builder.GetUintConstantId(kDebugInputDataOffset), offset_id);
Instruction* load_inst =
- builder.AddUnaryOp(GetUintId(), SpvOpLoad, ac_inst->result_id());
+ builder.AddUnaryOp(ibuf_type_id, SpvOpLoad, ac_inst->result_id());
last_value_id = load_inst->result_id();
}
(void)builder.AddInstruction(MakeUnique<Instruction>(
@@ -883,17 +926,21 @@
void InstrumentPass::InitializeInstrument() {
output_buffer_id_ = 0;
- buffer_uint_ptr_id_ = 0;
+ output_buffer_ptr_id_ = 0;
+ input_buffer_ptr_id_ = 0;
output_func_id_ = 0;
output_func_param_cnt_ = 0;
input_buffer_id_ = 0;
v4float_id_ = 0;
uint_id_ = 0;
+ uint64_id_ = 0;
v4uint_id_ = 0;
+ v3uint_id_ = 0;
bool_id_ = 0;
void_id_ = 0;
storage_buffer_ext_defined_ = false;
- uint_rarr_ty_ = nullptr;
+ uint32_rarr_ty_ = nullptr;
+ uint64_rarr_ty_ = nullptr;
// clear collections
id2function_.clear();
diff --git a/source/opt/instrument_pass.h b/source/opt/instrument_pass.h
index d255698..ead3b73 100644
--- a/source/opt/instrument_pass.h
+++ b/source/opt/instrument_pass.h
@@ -60,6 +60,7 @@
// These are used to identify the general validation being done and map to
// its output buffers.
static const uint32_t kInstValidationIdBindless = 0;
+static const uint32_t kInstValidationIdBuffAddr = 1;
class InstrumentPass : public Pass {
using cbb_ptr = const BasicBlock*;
@@ -218,17 +219,29 @@
uint32_t GetUintId();
// Return id for 32-bit unsigned type
+ uint32_t GetUint64Id();
+
+ // Return id for 32-bit unsigned type
uint32_t GetBoolId();
// Return id for void type
uint32_t GetVoidId();
// Return pointer to type for runtime array of uint
- analysis::Type* GetUintRuntimeArrayType(analysis::DecorationManager* deco_mgr,
- analysis::TypeManager* type_mgr);
+ analysis::Type* GetUintXRuntimeArrayType(uint32_t width,
+ analysis::Type** rarr_ty);
+
+ // Return pointer to type for runtime array of uint
+ analysis::Type* GetUintRuntimeArrayType(uint32_t width);
// Return id for buffer uint type
- uint32_t GetBufferUintPtrId();
+ uint32_t GetOutputBufferPtrId();
+
+ // Return id for buffer uint type
+ uint32_t GetInputBufferTypeId();
+
+ // Return id for buffer uint type
+ uint32_t GetInputBufferPtrId();
// Return binding for output buffer for current validation.
uint32_t GetOutputBufferBinding();
@@ -248,9 +261,15 @@
// Return id for v4float type
uint32_t GetVec4FloatId();
+ // Return id for uint vector type of |length|
+ uint32_t GetVecUintId(uint32_t length);
+
// Return id for v4uint type
uint32_t GetVec4UintId();
+ // Return id for v3uint type
+ uint32_t GetVec3UintId();
+
// Return id for output function. Define if it doesn't exist with
// |val_spec_param_cnt| validation-specific uint32 parameters.
uint32_t GetStreamWriteFunctionId(uint32_t stage_idx,
@@ -348,8 +367,11 @@
// id for output buffer variable
uint32_t output_buffer_id_;
- // type id for output buffer element
- uint32_t buffer_uint_ptr_id_;
+ // ptr type id for output buffer element
+ uint32_t output_buffer_ptr_id_;
+
+ // ptr type id for input buffer element
+ uint32_t input_buffer_ptr_id_;
// id for debug output function
uint32_t output_func_id_;
@@ -366,12 +388,18 @@
// id for v4float type
uint32_t v4float_id_;
- // id for v4float type
+ // id for v4uint type
uint32_t v4uint_id_;
+ // id for v3uint type
+ uint32_t v3uint_id_;
+
// id for 32-bit unsigned type
uint32_t uint_id_;
+ // id for 32-bit unsigned type
+ uint32_t uint64_id_;
+
// id for bool type
uint32_t bool_id_;
@@ -385,7 +413,10 @@
bool storage_buffer_ext_defined_;
// runtime array of uint type
- analysis::Type* uint_rarr_ty_;
+ analysis::Type* uint64_rarr_ty_;
+
+ // runtime array of uint type
+ analysis::Type* uint32_rarr_ty_;
// Pre-instrumentation same-block insts
std::unordered_map<uint32_t, Instruction*> same_block_pre_;
diff --git a/source/opt/ir_builder.h b/source/opt/ir_builder.h
index da74055..a0ca40c 100644
--- a/source/opt/ir_builder.h
+++ b/source/opt/ir_builder.h
@@ -465,6 +465,43 @@
return AddInstruction(std::move(new_inst));
}
+ Instruction* AddFunctionCall(uint32_t result_type, uint32_t function,
+ const std::vector<uint32_t>& parameters) {
+ std::vector<Operand> operands;
+ operands.push_back({SPV_OPERAND_TYPE_ID, {function}});
+ for (uint32_t id : parameters) {
+ operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
+ }
+
+ uint32_t result_id = GetContext()->TakeNextId();
+ if (result_id == 0) {
+ return nullptr;
+ }
+ std::unique_ptr<Instruction> new_inst(new Instruction(
+ GetContext(), SpvOpFunctionCall, result_type, result_id, operands));
+ return AddInstruction(std::move(new_inst));
+ }
+
+ Instruction* AddVectorShuffle(uint32_t result_type, uint32_t vec1,
+ uint32_t vec2,
+ const std::vector<uint32_t>& components) {
+ std::vector<Operand> operands;
+ operands.push_back({SPV_OPERAND_TYPE_ID, {vec1}});
+ operands.push_back({SPV_OPERAND_TYPE_ID, {vec2}});
+ for (uint32_t id : components) {
+ operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {id}});
+ }
+
+ uint32_t result_id = GetContext()->TakeNextId();
+ if (result_id == 0) {
+ return nullptr;
+ }
+
+ std::unique_ptr<Instruction> new_inst(new Instruction(
+ GetContext(), SpvOpVectorShuffle, result_type, result_id, operands));
+ return AddInstruction(std::move(new_inst));
+ }
+
// Inserts the new instruction before the insertion point.
Instruction* AddInstruction(std::unique_ptr<Instruction>&& insn) {
Instruction* insn_ptr = &*insert_before_.InsertBefore(std::move(insn));
@@ -512,6 +549,10 @@
// Returns true if the users requested to update |analysis|.
inline bool IsAnalysisUpdateRequested(IRContext::Analysis analysis) const {
+ if (!GetContext()->AreAnalysesValid(analysis)) {
+ // Do not try to update something that is not built.
+ return false;
+ }
return preserved_analyses_ & analysis;
}
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp
index 20309ca..1c747b7 100644
--- a/source/opt/ir_context.cpp
+++ b/source/opt/ir_context.cpp
@@ -156,14 +156,20 @@
decoration_mgr_->RemoveDecoration(inst);
}
}
-
if (type_mgr_ && IsTypeInst(inst->opcode())) {
type_mgr_->RemoveId(inst->result_id());
}
-
if (constant_mgr_ && IsConstantInst(inst->opcode())) {
constant_mgr_->RemoveId(inst->result_id());
}
+ if (inst->opcode() == SpvOpCapability || inst->opcode() == SpvOpExtension) {
+ // We reset the feature manager, instead of updating it, because it is just
+ // as much work. We would have to remove all capabilities implied by this
+ // capability that are not also implied by the remaining OpCapability
+ // instructions. We could update extensions, but we will see if it is
+ // needed.
+ ResetFeatureManager();
+ }
RemoveFromIdToName(inst);
@@ -190,6 +196,13 @@
}
bool IRContext::ReplaceAllUsesWith(uint32_t before, uint32_t after) {
+ return ReplaceAllUsesWithPredicate(
+ before, after, [](Instruction*, uint32_t) { return true; });
+}
+
+bool IRContext::ReplaceAllUsesWithPredicate(
+ uint32_t before, uint32_t after,
+ const std::function<bool(Instruction*, uint32_t)>& predicate) {
if (before == after) return false;
// Ensure that |after| has been registered as def.
@@ -198,8 +211,10 @@
std::vector<std::pair<Instruction*, uint32_t>> uses_to_update;
get_def_use_mgr()->ForEachUse(
- before, [&uses_to_update](Instruction* user, uint32_t index) {
- uses_to_update.emplace_back(user, index);
+ before, [&predicate, &uses_to_update](Instruction* user, uint32_t index) {
+ if (predicate(user, index)) {
+ uses_to_update.emplace_back(user, index);
+ }
});
Instruction* prev = nullptr;
@@ -243,7 +258,6 @@
#ifndef SPIRV_CHECK_CONTEXT
return true;
#endif
-
if (AreAnalysesValid(kAnalysisDefUse)) {
analysis::DefUseManager new_def_use(module());
if (*get_def_use_mgr() != new_def_use) {
@@ -277,6 +291,15 @@
return false;
}
}
+
+ if (feature_mgr_ != nullptr) {
+ FeatureManager current(grammar_);
+ current.Analyze(module());
+
+ if (current != *feature_mgr_) {
+ return false;
+ }
+ }
return true;
}
@@ -678,7 +701,8 @@
case SpvBuiltInVertexIndex:
case SpvBuiltInInstanceIndex:
case SpvBuiltInPrimitiveId:
- case SpvBuiltInInvocationId: {
+ case SpvBuiltInInvocationId:
+ case SpvBuiltInSubgroupLocalInvocationId: {
analysis::Integer uint_ty(32, false);
reg_type = type_mgr->GetRegisteredType(&uint_ty);
break;
@@ -691,6 +715,20 @@
reg_type = type_mgr->GetRegisteredType(&v3uint_ty);
break;
}
+ case SpvBuiltInTessCoord: {
+ analysis::Float float_ty(32);
+ analysis::Type* reg_float_ty = type_mgr->GetRegisteredType(&float_ty);
+ analysis::Vector v3float_ty(reg_float_ty, 3);
+ reg_type = type_mgr->GetRegisteredType(&v3float_ty);
+ break;
+ }
+ case SpvBuiltInSubgroupLtMask: {
+ analysis::Integer uint_ty(32, false);
+ analysis::Type* reg_uint_ty = type_mgr->GetRegisteredType(&uint_ty);
+ analysis::Vector v4uint_ty(reg_uint_ty, 4);
+ reg_type = type_mgr->GetRegisteredType(&v4uint_ty);
+ break;
+ }
default: {
assert(false && "unhandled builtin");
return 0;
@@ -779,6 +817,42 @@
return modified;
}
+void IRContext::EmitErrorMessage(std::string message, Instruction* inst) {
+ if (!consumer()) {
+ return;
+ }
+
+ Instruction* line_inst = inst;
+ while (line_inst != nullptr) { // Stop at the beginning of the basic block.
+ if (!line_inst->dbg_line_insts().empty()) {
+ line_inst = &line_inst->dbg_line_insts().back();
+ if (line_inst->opcode() == SpvOpNoLine) {
+ line_inst = nullptr;
+ }
+ break;
+ }
+ line_inst = line_inst->PreviousNode();
+ }
+
+ uint32_t line_number = 0;
+ uint32_t col_number = 0;
+ char* source = nullptr;
+ if (line_inst != nullptr) {
+ Instruction* file_name =
+ get_def_use_mgr()->GetDef(line_inst->GetSingleWordInOperand(0));
+ source = reinterpret_cast<char*>(&file_name->GetInOperand(0).words[0]);
+
+ // Get the line number and column number.
+ line_number = line_inst->GetSingleWordInOperand(1);
+ col_number = line_inst->GetSingleWordInOperand(2);
+ }
+
+ message +=
+ "\n " + inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+ consumer()(SPV_MSG_ERROR, source, {line_number, col_number, 0},
+ message.c_str());
+}
+
// Gets the dominator analysis for function |f|.
DominatorAnalysis* IRContext::GetDominatorAnalysis(const Function* f) {
if (!AreAnalysesValid(kAnalysisDominatorAnalysis)) {
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index 37c6449..e297fb1 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -190,9 +190,13 @@
// Clears all debug instructions (excluding OpLine & OpNoLine).
inline void debug_clear();
+ // Add |capability| to the module, if it is not already enabled.
+ inline void AddCapability(SpvCapability capability);
+
// Appends a capability instruction to this module.
inline void AddCapability(std::unique_ptr<Instruction>&& c);
// Appends an extension instruction to this module.
+ inline void AddExtension(const std::string& ext_name);
inline void AddExtension(std::unique_ptr<Instruction>&& e);
// Appends an extended instruction set instruction to this module.
inline void AddExtInstImport(std::unique_ptr<Instruction>&& e);
@@ -382,6 +386,15 @@
// |before| and |after| must be registered definitions in the DefUseManager.
bool ReplaceAllUsesWith(uint32_t before, uint32_t after);
+ // Replace all uses of |before| id with |after| id if those uses
+ // (instruction, operand pair) return true for |predicate|. Returns true if
+ // any replacement happens. This method does not kill the definition of the
+ // |before| id. If |after| is the same as |before|, does nothing and return
+ // false.
+ bool ReplaceAllUsesWithPredicate(
+ uint32_t before, uint32_t after,
+ const std::function<bool(Instruction*, uint32_t)>& predicate);
+
// Returns true if all of the analyses that are suppose to be valid are
// actually valid.
bool IsConsistent();
@@ -478,6 +491,8 @@
return feature_mgr_.get();
}
+ void ResetFeatureManager() { feature_mgr_.reset(nullptr); }
+
// Returns the grammar for this context.
const AssemblyGrammar& grammar() const { return grammar_; }
@@ -547,6 +562,10 @@
bool ProcessCallTreeFromRoots(ProcessFunction& pfn,
std::queue<uint32_t>* roots);
+ // Emmits a error message to the message consumer indicating the error
+ // described by |message| occurred in |inst|.
+ void EmitErrorMessage(std::string message, Instruction* inst);
+
private:
// Builds the def-use manager from scratch, even if it was already valid.
void BuildDefUseManager() {
@@ -910,15 +929,45 @@
void IRContext::debug_clear() { module_->debug_clear(); }
+void IRContext::AddCapability(SpvCapability capability) {
+ if (!get_feature_mgr()->HasCapability(capability)) {
+ std::unique_ptr<Instruction> capability_inst(new Instruction(
+ this, SpvOpCapability, 0, 0,
+ {{SPV_OPERAND_TYPE_CAPABILITY, {static_cast<uint32_t>(capability)}}}));
+ AddCapability(std::move(capability_inst));
+ }
+}
+
void IRContext::AddCapability(std::unique_ptr<Instruction>&& c) {
AddCombinatorsForCapability(c->GetSingleWordInOperand(0));
+ if (feature_mgr_ != nullptr) {
+ feature_mgr_->AddCapability(
+ static_cast<SpvCapability>(c->GetSingleWordInOperand(0)));
+ }
+ if (AreAnalysesValid(kAnalysisDefUse)) {
+ get_def_use_mgr()->AnalyzeInstDefUse(c.get());
+ }
module()->AddCapability(std::move(c));
}
+void IRContext::AddExtension(const std::string& ext_name) {
+ const auto num_chars = ext_name.size();
+ // Compute num words, accommodate the terminating null character.
+ const auto num_words = (num_chars + 1 + 3) / 4;
+ std::vector<uint32_t> ext_words(num_words, 0u);
+ std::memcpy(ext_words.data(), ext_name.data(), num_chars);
+ AddExtension(std::unique_ptr<Instruction>(
+ new Instruction(this, SpvOpExtension, 0u, 0u,
+ {{SPV_OPERAND_TYPE_LITERAL_STRING, ext_words}})));
+}
+
void IRContext::AddExtension(std::unique_ptr<Instruction>&& e) {
if (AreAnalysesValid(kAnalysisDefUse)) {
get_def_use_mgr()->AnalyzeInstDefUse(e.get());
}
+ if (feature_mgr_ != nullptr) {
+ feature_mgr_->AddExtension(&*e);
+ }
module()->AddExtension(std::move(e));
}
diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp
index cc1b837..aebbd00 100644
--- a/source/opt/local_single_block_elim_pass.cpp
+++ b/source/opt/local_single_block_elim_pass.cpp
@@ -256,6 +256,7 @@
"SPV_NV_mesh_shader",
"SPV_NV_ray_tracing",
"SPV_EXT_fragment_invocation_density",
+ "SPV_EXT_physical_storage_buffer",
});
}
diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp
index f47777e..d6beeab 100644
--- a/source/opt/local_single_store_elim_pass.cpp
+++ b/source/opt/local_single_store_elim_pass.cpp
@@ -119,6 +119,7 @@
"SPV_NV_mesh_shader",
"SPV_NV_ray_tracing",
"SPV_EXT_fragment_invocation_density",
+ "SPV_EXT_physical_storage_buffer",
});
}
bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
diff --git a/source/opt/local_ssa_elim_pass.cpp b/source/opt/local_ssa_elim_pass.cpp
index 299bbe0..c3f4ab6 100644
--- a/source/opt/local_ssa_elim_pass.cpp
+++ b/source/opt/local_ssa_elim_pass.cpp
@@ -105,6 +105,7 @@
"SPV_NV_mesh_shader",
"SPV_NV_ray_tracing",
"SPV_EXT_fragment_invocation_density",
+ "SPV_EXT_physical_storage_buffer",
});
}
diff --git a/source/opt/module.h b/source/opt/module.h
index ede0bbb..cf7c274 100644
--- a/source/opt/module.h
+++ b/source/opt/module.h
@@ -133,6 +133,8 @@
inline uint32_t version() const { return header_.version; }
+ inline void set_version(uint32_t v) { header_.version = v; }
+
// Iterators for capabilities instructions contained in this module.
inline inst_iterator capability_begin();
inline inst_iterator capability_end();
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index 5c1e6ca..635b075 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -14,6 +14,7 @@
#include "spirv-tools/optimizer.hpp"
+#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
@@ -22,6 +23,7 @@
#include <source/spirv_optimizer_options.h>
#include "source/opt/build_module.h"
+#include "source/opt/graphics_robust_access_pass.h"
#include "source/opt/log.h"
#include "source/opt/pass_manager.h"
#include "source/opt/passes.h"
@@ -106,8 +108,10 @@
// or enable more copy propagation.
Optimizer& Optimizer::RegisterLegalizationPasses() {
return
- // Remove unreachable block so that merge return works.
- RegisterPass(CreateDeadBranchElimPass())
+ // Wrap OpKill instructions so all other code can be inlined.
+ RegisterPass(CreateWrapOpKillPass())
+ // Remove unreachable block so that merge return works.
+ .RegisterPass(CreateDeadBranchElimPass())
// Merge the returns so we can inline.
.RegisterPass(CreateMergeReturnPass())
// Make sure uses and definitions are in the same function.
@@ -153,7 +157,8 @@
}
Optimizer& Optimizer::RegisterPerformancePasses() {
- return RegisterPass(CreateDeadBranchElimPass())
+ return RegisterPass(CreateWrapOpKillPass())
+ .RegisterPass(CreateDeadBranchElimPass())
.RegisterPass(CreateMergeReturnPass())
.RegisterPass(CreateInlineExhaustivePass())
.RegisterPass(CreateAggressiveDCEPass())
@@ -189,7 +194,8 @@
}
Optimizer& Optimizer::RegisterSizePasses() {
- return RegisterPass(CreateDeadBranchElimPass())
+ return RegisterPass(CreateWrapOpKillPass())
+ .RegisterPass(CreateDeadBranchElimPass())
.RegisterPass(CreateMergeReturnPass())
.RegisterPass(CreateInlineExhaustivePass())
.RegisterPass(CreateAggressiveDCEPass())
@@ -314,6 +320,8 @@
RegisterPass(CreateCombineAccessChainsPass());
} else if (pass_name == "convert-local-access-chains") {
RegisterPass(CreateLocalAccessChainConvertPass());
+ } else if (pass_name == "descriptor-scalar-replacement") {
+ RegisterPass(CreateDescriptorScalarReplacementPass());
} else if (pass_name == "eliminate-dead-code-aggressive") {
RegisterPass(CreateAggressiveDCEPass());
} else if (pass_name == "propagate-line-info") {
@@ -393,11 +401,20 @@
} else if (pass_name == "replace-invalid-opcode") {
RegisterPass(CreateReplaceInvalidOpcodePass());
} else if (pass_name == "inst-bindless-check") {
- RegisterPass(CreateInstBindlessCheckPass(7, 23, true, true, 1));
+ RegisterPass(CreateInstBindlessCheckPass(7, 23, false, false, 2));
RegisterPass(CreateSimplificationPass());
RegisterPass(CreateDeadBranchElimPass());
RegisterPass(CreateBlockMergePass());
RegisterPass(CreateAggressiveDCEPass());
+ } else if (pass_name == "inst-desc-idx-check") {
+ RegisterPass(CreateInstBindlessCheckPass(7, 23, true, true, 2));
+ RegisterPass(CreateSimplificationPass());
+ RegisterPass(CreateDeadBranchElimPass());
+ RegisterPass(CreateBlockMergePass());
+ RegisterPass(CreateAggressiveDCEPass());
+ } else if (pass_name == "inst-buff-addr-check") {
+ RegisterPass(CreateInstBuffAddrCheckPass(7, 23, 2));
+ RegisterPass(CreateAggressiveDCEPass());
} else if (pass_name == "simplify-instructions") {
RegisterPass(CreateSimplificationPass());
} else if (pass_name == "ssa-rewrite") {
@@ -472,6 +489,12 @@
RegisterPass(CreateLegalizeVectorShufflePass());
} else if (pass_name == "decompose-initialized-variables") {
RegisterPass(CreateDecomposeInitializedVariablesPass());
+ } else if (pass_name == "graphics-robust-access") {
+ RegisterPass(CreateGraphicsRobustAccessPass());
+ } else if (pass_name == "wrap-opkill") {
+ RegisterPass(CreateWrapOpKillPass());
+ } else if (pass_name == "amd-ext-to-khr") {
+ RegisterPass(CreateAmdExtToKhrPass());
} else {
Errorf(consumer(), nullptr, {},
"Unknown flag '--%s'. Use --help for a list of valid flags",
@@ -529,26 +552,25 @@
impl_->pass_manager.SetTargetEnv(impl_->target_env);
auto status = impl_->pass_manager.Run(context.get());
- bool binary_changed = false;
- if (status == opt::Pass::Status::SuccessWithChange) {
- binary_changed = true;
- } else if (status == opt::Pass::Status::SuccessWithoutChange) {
- if (optimized_binary->size() != original_binary_size ||
- (memcmp(optimized_binary->data(), original_binary,
- original_binary_size) != 0)) {
- binary_changed = true;
- Log(consumer(), SPV_MSG_WARNING, nullptr, {},
- "Binary unexpectedly changed despite optimizer saying there was no "
- "change");
- }
+ if (status == opt::Pass::Status::Failure) {
+ return false;
}
- if (binary_changed) {
- optimized_binary->clear();
- context->module()->ToBinary(optimized_binary, /* skip_nop = */ true);
- }
+ optimized_binary->clear();
+ context->module()->ToBinary(optimized_binary, /* skip_nop = */ true);
- return status != opt::Pass::Status::Failure;
+#ifndef NDEBUG
+ if (status == opt::Pass::Status::SuccessWithoutChange) {
+ auto changed = optimized_binary->size() != original_binary_size ||
+ memcmp(optimized_binary->data(), original_binary,
+ original_binary_size) != 0;
+ assert(!changed &&
+ "Binary unexpectedly changed despite optimizer saying there was no "
+ "change");
+ }
+#endif // !NDEBUG
+
+ return true;
}
Optimizer& Optimizer::SetPrintAll(std::ostream* out) {
@@ -848,6 +870,13 @@
input_init_enable, version));
}
+Optimizer::PassToken CreateInstBuffAddrCheckPass(uint32_t desc_set,
+ uint32_t shader_id,
+ uint32_t version) {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::InstBuffAddrCheckPass>(desc_set, shader_id, version));
+}
+
Optimizer::PassToken CreateCodeSinkingPass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::CodeSinkingPass>());
@@ -878,4 +907,23 @@
MakeUnique<opt::SplitInvalidUnreachablePass>());
}
+Optimizer::PassToken CreateGraphicsRobustAccessPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::GraphicsRobustAccessPass>());
+}
+
+Optimizer::PassToken CreateDescriptorScalarReplacementPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::DescriptorScalarReplacement>());
+}
+
+Optimizer::PassToken CreateWrapOpKillPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::WrapOpKill>());
+}
+
+Optimizer::PassToken CreateAmdExtToKhrPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::AmdExtensionToKhrPass>());
+}
+
} // namespace spvtools
diff --git a/source/opt/pass.cpp b/source/opt/pass.cpp
index a783f4f..f9e4a5d 100644
--- a/source/opt/pass.cpp
+++ b/source/opt/pass.cpp
@@ -43,7 +43,8 @@
if (status == Status::SuccessWithChange) {
ctx->InvalidateAnalysesExceptFor(GetPreservedAnalyses());
}
- assert(ctx->IsConsistent());
+ assert((status == Status::Failure || ctx->IsConsistent()) &&
+ "An analysis in the context is out of date.");
return status;
}
diff --git a/source/opt/pass.h b/source/opt/pass.h
index 0667c3d..686e9fc 100644
--- a/source/opt/pass.h
+++ b/source/opt/pass.h
@@ -26,6 +26,7 @@
#include "source/opt/ir_context.h"
#include "source/opt/module.h"
#include "spirv-tools/libspirv.hpp"
+#include "types.h"
namespace spvtools {
namespace opt {
diff --git a/source/opt/passes.h b/source/opt/passes.h
index 0a348e4..d53af8f 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -18,6 +18,7 @@
// A single header to include all passes.
#include "source/opt/aggressive_dead_code_elim_pass.h"
+#include "source/opt/amd_ext_to_khr.h"
#include "source/opt/block_merge_pass.h"
#include "source/opt/ccp_pass.h"
#include "source/opt/cfg_cleanup_pass.h"
@@ -29,6 +30,7 @@
#include "source/opt/dead_insert_elim_pass.h"
#include "source/opt/dead_variable_elimination.h"
#include "source/opt/decompose_initialized_variables_pass.h"
+#include "source/opt/desc_sroa.h"
#include "source/opt/eliminate_dead_constant_pass.h"
#include "source/opt/eliminate_dead_functions_pass.h"
#include "source/opt/eliminate_dead_members_pass.h"
@@ -37,10 +39,12 @@
#include "source/opt/fold_spec_constant_op_and_composite_pass.h"
#include "source/opt/freeze_spec_constant_value_pass.h"
#include "source/opt/generate_webgpu_initializers_pass.h"
+#include "source/opt/graphics_robust_access_pass.h"
#include "source/opt/if_conversion.h"
#include "source/opt/inline_exhaustive_pass.h"
#include "source/opt/inline_opaque_pass.h"
#include "source/opt/inst_bindless_check_pass.h"
+#include "source/opt/inst_buff_addr_check_pass.h"
#include "source/opt/legalize_vector_shuffle_pass.h"
#include "source/opt/licm_pass.h"
#include "source/opt/local_access_chain_convert_pass.h"
@@ -74,5 +78,6 @@
#include "source/opt/upgrade_memory_model.h"
#include "source/opt/vector_dce.h"
#include "source/opt/workaround1209.h"
+#include "source/opt/wrap_opkill.h"
#endif // SOURCE_OPT_PASSES_H_
diff --git a/source/opt/private_to_local_pass.cpp b/source/opt/private_to_local_pass.cpp
index d41d8f2..6df690d 100644
--- a/source/opt/private_to_local_pass.cpp
+++ b/source/opt/private_to_local_pass.cpp
@@ -58,7 +58,9 @@
modified = !variables_to_move.empty();
for (auto p : variables_to_move) {
- MoveVariable(p.first, p.second);
+ if (!MoveVariable(p.first, p.second)) {
+ return Status::Failure;
+ }
localized_variables.insert(p.first->result_id());
}
@@ -112,7 +114,7 @@
return target_function;
} // namespace opt
-void PrivateToLocalPass::MoveVariable(Instruction* variable,
+bool PrivateToLocalPass::MoveVariable(Instruction* variable,
Function* function) {
// The variable needs to be removed from the global section, and placed in the
// header of the function. First step remove from the global list.
@@ -125,6 +127,9 @@
// Update the type as well.
uint32_t new_type_id = GetNewType(variable->type_id());
+ if (new_type_id == 0) {
+ return false;
+ }
variable->SetResultType(new_type_id);
// Place the variable at the start of the first basic block.
@@ -133,7 +138,7 @@
function->begin()->begin()->InsertBefore(move(var));
// Update uses where the type may have changed.
- UpdateUses(variable->result_id());
+ return UpdateUses(variable->result_id());
}
uint32_t PrivateToLocalPass::GetNewType(uint32_t old_type_id) {
@@ -143,7 +148,9 @@
old_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx);
uint32_t new_type_id =
type_mgr->FindPointerToType(pointee_type_id, SpvStorageClassFunction);
- context()->UpdateDefUse(context()->get_def_use_mgr()->GetDef(new_type_id));
+ if (new_type_id != 0) {
+ context()->UpdateDefUse(context()->get_def_use_mgr()->GetDef(new_type_id));
+ }
return new_type_id;
}
@@ -168,7 +175,7 @@
}
}
-void PrivateToLocalPass::UpdateUse(Instruction* inst) {
+bool PrivateToLocalPass::UpdateUse(Instruction* inst) {
// The cases in this switch have to match the cases in |IsValidUse|. If we
// don't think it is valid, the optimization will not view the variable as a
// candidate, and therefore the use will not be updated.
@@ -179,14 +186,20 @@
// The type is fine because it is the type pointed to, and that does not
// change.
break;
- case SpvOpAccessChain:
+ case SpvOpAccessChain: {
context()->ForgetUses(inst);
- inst->SetResultType(GetNewType(inst->type_id()));
+ uint32_t new_type_id = GetNewType(inst->type_id());
+ if (new_type_id == 0) {
+ return false;
+ }
+ inst->SetResultType(new_type_id);
context()->AnalyzeUses(inst);
// Update uses where the type may have changed.
- UpdateUses(inst->result_id());
- break;
+ if (!UpdateUses(inst->result_id())) {
+ return false;
+ }
+ } break;
case SpvOpName:
case SpvOpEntryPoint: // entry points will be updated separately.
break;
@@ -195,15 +208,20 @@
"Do not know how to update the type for this instruction.");
break;
}
+ return true;
}
-void PrivateToLocalPass::UpdateUses(uint32_t id) {
+
+bool PrivateToLocalPass::UpdateUses(uint32_t id) {
std::vector<Instruction*> uses;
context()->get_def_use_mgr()->ForEachUser(
id, [&uses](Instruction* use) { uses.push_back(use); });
for (Instruction* use : uses) {
- UpdateUse(use);
+ if (!UpdateUse(use)) {
+ return false;
+ }
}
+ return true;
}
} // namespace opt
diff --git a/source/opt/private_to_local_pass.h b/source/opt/private_to_local_pass.h
index 4678530..3f9135c0 100644
--- a/source/opt/private_to_local_pass.h
+++ b/source/opt/private_to_local_pass.h
@@ -41,8 +41,8 @@
private:
// Moves |variable| from the private storage class to the function storage
- // class of |function|.
- void MoveVariable(Instruction* variable, Function* function);
+ // class of |function|. Returns false if the variable could not be moved.
+ bool MoveVariable(Instruction* variable, Function* function);
// |inst| is an instruction declaring a varible. If that variable is
// referenced in a single function and all of uses are valid as defined by
@@ -58,13 +58,13 @@
// Given the result id of a pointer type, |old_type_id|, this function
// returns the id of a the same pointer type except the storage class has
// been changed to function. If the type does not already exist, it will be
- // created.
+ // created. Returns 0 if the new type could not be found or generated.
uint32_t GetNewType(uint32_t old_type_id);
// Updates |inst|, and any instruction dependent on |inst|, to reflect the
// change of the base pointer now pointing to the function storage class.
- void UpdateUse(Instruction* inst);
- void UpdateUses(uint32_t id);
+ bool UpdateUse(Instruction* inst);
+ bool UpdateUses(uint32_t id);
};
} // namespace opt
diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp
index 9ae1ae8..d748e7f 100644
--- a/source/opt/scalar_replacement_pass.cpp
+++ b/source/opt/scalar_replacement_pass.cpp
@@ -78,36 +78,47 @@
}
std::vector<Instruction*> dead;
- if (get_def_use_mgr()->WhileEachUser(
- inst, [this, &replacements, &dead](Instruction* user) {
- if (!IsAnnotationInst(user->opcode())) {
- switch (user->opcode()) {
- case SpvOpLoad:
- ReplaceWholeLoad(user, replacements);
- dead.push_back(user);
- break;
- case SpvOpStore:
- ReplaceWholeStore(user, replacements);
- dead.push_back(user);
- break;
- case SpvOpAccessChain:
- case SpvOpInBoundsAccessChain:
- if (ReplaceAccessChain(user, replacements))
- dead.push_back(user);
- else
- return false;
- break;
- case SpvOpName:
- case SpvOpMemberName:
- break;
- default:
- assert(false && "Unexpected opcode");
- break;
+ bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
+ inst, [this, &replacements, &dead](Instruction* user) {
+ if (!IsAnnotationInst(user->opcode())) {
+ switch (user->opcode()) {
+ case SpvOpLoad:
+ if (ReplaceWholeLoad(user, replacements)) {
+ dead.push_back(user);
+ } else {
+ return false;
}
- }
- return true;
- }))
+ break;
+ case SpvOpStore:
+ if (ReplaceWholeStore(user, replacements)) {
+ dead.push_back(user);
+ } else {
+ return false;
+ }
+ break;
+ case SpvOpAccessChain:
+ case SpvOpInBoundsAccessChain:
+ if (ReplaceAccessChain(user, replacements))
+ dead.push_back(user);
+ else
+ return false;
+ break;
+ case SpvOpName:
+ case SpvOpMemberName:
+ break;
+ default:
+ assert(false && "Unexpected opcode");
+ break;
+ }
+ }
+ return true;
+ });
+
+ if (replaced_all_uses) {
dead.push_back(inst);
+ } else {
+ return Status::Failure;
+ }
// If there are no dead instructions to clean up, return with no changes.
if (dead.empty()) return Status::SuccessWithoutChange;
@@ -133,7 +144,7 @@
return Status::SuccessWithChange;
}
-void ScalarReplacementPass::ReplaceWholeLoad(
+bool ScalarReplacementPass::ReplaceWholeLoad(
Instruction* load, const std::vector<Instruction*>& replacements) {
// Replaces the load of the entire composite with a load from each replacement
// variable followed by a composite construction.
@@ -150,6 +161,9 @@
Instruction* type = GetStorageType(var);
uint32_t loadId = TakeNextId();
+ if (loadId == 0) {
+ return false;
+ }
std::unique_ptr<Instruction> newLoad(
new Instruction(context(), SpvOpLoad, type->result_id(), loadId,
std::initializer_list<Operand>{
@@ -168,6 +182,9 @@
// Construct a new composite.
uint32_t compositeId = TakeNextId();
+ if (compositeId == 0) {
+ return false;
+ }
where = load;
std::unique_ptr<Instruction> compositeConstruct(new Instruction(
context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
@@ -180,9 +197,10 @@
get_def_use_mgr()->AnalyzeInstDefUse(&*where);
context()->set_instr_block(&*where, block);
context()->ReplaceAllUsesWith(load->result_id(), compositeId);
+ return true;
}
-void ScalarReplacementPass::ReplaceWholeStore(
+bool ScalarReplacementPass::ReplaceWholeStore(
Instruction* store, const std::vector<Instruction*>& replacements) {
// Replaces a store to the whole composite with a series of extract and stores
// to each element.
@@ -199,6 +217,9 @@
Instruction* type = GetStorageType(var);
uint32_t extractId = TakeNextId();
+ if (extractId == 0) {
+ return false;
+ }
std::unique_ptr<Instruction> extract(new Instruction(
context(), SpvOpCompositeExtract, type->result_id(), extractId,
std::initializer_list<Operand>{
@@ -224,6 +245,7 @@
get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
context()->set_instr_block(&*iter, block);
}
+ return true;
}
bool ScalarReplacementPass::ReplaceAccessChain(
@@ -232,8 +254,12 @@
// indexes) or a direct use of the replacement variable.
uint32_t indexId = chain->GetSingleWordInOperand(1u);
const Instruction* index = get_def_use_mgr()->GetDef(indexId);
- uint64_t indexValue = GetConstantInteger(index);
- if (indexValue >= replacements.size()) {
+ int64_t indexValue = context()
+ ->get_constant_mgr()
+ ->GetConstantFromInst(index)
+ ->GetSignExtendedValue();
+ if (indexValue < 0 ||
+ indexValue >= static_cast<int64_t>(replacements.size())) {
// Out of bounds access, this is illegal IR. Notice that OpAccessChain
// indexing is 0-based, so we should also reject index == size-of-array.
return false;
@@ -243,6 +269,9 @@
// Replace input access chain with another access chain.
BasicBlock::iterator chainIter(chain);
uint32_t replacementId = TakeNextId();
+ if (replacementId == 0) {
+ return false;
+ }
std::unique_ptr<Instruction> replacementChain(new Instruction(
context(), chain->opcode(), chain->type_id(), replacementId,
std::initializer_list<Operand>{
@@ -269,7 +298,7 @@
Instruction* inst, std::vector<Instruction*>* replacements) {
Instruction* type = GetStorageType(inst);
- std::unique_ptr<std::unordered_set<uint64_t>> components_used =
+ std::unique_ptr<std::unordered_set<int64_t>> components_used =
GetUsedComponents(inst);
uint32_t elem = 0;
@@ -325,6 +354,10 @@
if (decoration == SpvDecorationInvariant ||
decoration == SpvDecorationRestrict) {
for (auto var : *replacements) {
+ if (var == nullptr) {
+ continue;
+ }
+
std::unique_ptr<Instruction> annotation(
new Instruction(context(), SpvOpDecorate, 0, 0,
std::initializer_list<Operand>{
@@ -346,6 +379,11 @@
std::vector<Instruction*>* replacements) {
uint32_t ptrId = GetOrCreatePointerType(typeId);
uint32_t id = TakeNextId();
+
+ if (id == 0) {
+ replacements->push_back(nullptr);
+ }
+
std::unique_ptr<Instruction> variable(new Instruction(
context(), SpvOpVariable, ptrId, id,
std::initializer_list<Operand>{
@@ -360,6 +398,35 @@
get_def_use_mgr()->AnalyzeInstDefUse(inst);
context()->set_instr_block(inst, block);
+ // Copy decorations from the member to the new variable.
+ Instruction* typeInst = GetStorageType(varInst);
+ for (auto dec_inst :
+ get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
+ uint32_t decoration;
+ if (dec_inst->opcode() != SpvOpMemberDecorate) {
+ continue;
+ }
+
+ if (dec_inst->GetSingleWordInOperand(1) != index) {
+ continue;
+ }
+
+ decoration = dec_inst->GetSingleWordInOperand(2u);
+ switch (decoration) {
+ case SpvDecorationRelaxedPrecision: {
+ std::unique_ptr<Instruction> new_dec_inst(
+ new Instruction(context(), SpvOpDecorate, 0, 0, {}));
+ new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));
+ for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
+ new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
+ }
+ context()->AddAnnotationInst(std::move(new_dec_inst));
+ } break;
+ default:
+ break;
+ }
+ }
+
replacements->push_back(inst);
}
@@ -467,35 +534,15 @@
}
}
-uint64_t ScalarReplacementPass::GetIntegerLiteral(const Operand& op) const {
- assert(op.words.size() <= 2);
- uint64_t len = 0;
- for (uint32_t i = 0; i != op.words.size(); ++i) {
- len |= (op.words[i] << (32 * i));
- }
- return len;
-}
-
-uint64_t ScalarReplacementPass::GetConstantInteger(
- const Instruction* constant) const {
- assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() ==
- SpvOpTypeInt);
- assert(constant->opcode() == SpvOpConstant ||
- constant->opcode() == SpvOpConstantNull);
- if (constant->opcode() == SpvOpConstantNull) {
- return 0;
- }
-
- const Operand& op = constant->GetInOperand(0u);
- return GetIntegerLiteral(op);
-}
-
uint64_t ScalarReplacementPass::GetArrayLength(
const Instruction* arrayType) const {
assert(arrayType->opcode() == SpvOpTypeArray);
const Instruction* length =
get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
- return GetConstantInteger(length);
+ return context()
+ ->get_constant_mgr()
+ ->GetConstantFromInst(length)
+ ->GetZeroExtendedValue();
}
uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
@@ -531,25 +578,42 @@
assert(varInst->opcode() == SpvOpVariable);
// Can only replace function scope variables.
- if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction)
+ if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) {
return false;
+ }
- if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id())))
+ if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
return false;
+ }
const Instruction* typeInst = GetStorageType(varInst);
- return CheckType(typeInst) && CheckAnnotations(varInst) && CheckUses(varInst);
+ if (!CheckType(typeInst)) {
+ return false;
+ }
+
+ if (!CheckAnnotations(varInst)) {
+ return false;
+ }
+
+ if (!CheckUses(varInst)) {
+ return false;
+ }
+
+ return true;
}
bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
- if (!CheckTypeAnnotations(typeInst)) return false;
+ if (!CheckTypeAnnotations(typeInst)) {
+ return false;
+ }
switch (typeInst->opcode()) {
case SpvOpTypeStruct:
// Don't bother with empty structs or very large structs.
if (typeInst->NumInOperands() == 0 ||
- IsLargerThanSizeLimit(typeInst->NumInOperands()))
+ IsLargerThanSizeLimit(typeInst->NumInOperands())) {
return false;
+ }
return true;
case SpvOpTypeArray:
if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
@@ -598,6 +662,7 @@
case SpvDecorationAlignment:
case SpvDecorationAlignmentId:
case SpvDecorationMaxByteOffset:
+ case SpvDecorationRelaxedPrecision:
break;
default:
return false;
@@ -640,44 +705,51 @@
bool ScalarReplacementPass::CheckUses(const Instruction* inst,
VariableStats* stats) const {
+ uint64_t max_legal_index = GetMaxLegalIndex(inst);
+
bool ok = true;
- get_def_use_mgr()->ForEachUse(
- inst, [this, stats, &ok](const Instruction* user, uint32_t index) {
- // Annotations are check as a group separately.
- if (!IsAnnotationInst(user->opcode())) {
- switch (user->opcode()) {
- case SpvOpAccessChain:
- case SpvOpInBoundsAccessChain:
- if (index == 2u && user->NumInOperands() > 1) {
- uint32_t id = user->GetSingleWordInOperand(1u);
- const Instruction* opInst = get_def_use_mgr()->GetDef(id);
- if (!IsCompileTimeConstantInst(opInst->opcode())) {
- ok = false;
- } else {
- if (!CheckUsesRelaxed(user)) ok = false;
- }
- stats->num_partial_accesses++;
- } else {
- ok = false;
- }
- break;
- case SpvOpLoad:
- if (!CheckLoad(user, index)) ok = false;
- stats->num_full_accesses++;
- break;
- case SpvOpStore:
- if (!CheckStore(user, index)) ok = false;
- stats->num_full_accesses++;
- break;
- case SpvOpName:
- case SpvOpMemberName:
- break;
- default:
+ get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
+ const Instruction* user,
+ uint32_t index) {
+ // Annotations are check as a group separately.
+ if (!IsAnnotationInst(user->opcode())) {
+ switch (user->opcode()) {
+ case SpvOpAccessChain:
+ case SpvOpInBoundsAccessChain:
+ if (index == 2u && user->NumInOperands() > 1) {
+ uint32_t id = user->GetSingleWordInOperand(1u);
+ const Instruction* opInst = get_def_use_mgr()->GetDef(id);
+ const auto* constant =
+ context()->get_constant_mgr()->GetConstantFromInst(opInst);
+ if (!constant) {
ok = false;
- break;
+ } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
+ ok = false;
+ } else {
+ if (!CheckUsesRelaxed(user)) ok = false;
+ }
+ stats->num_partial_accesses++;
+ } else {
+ ok = false;
}
- }
- });
+ break;
+ case SpvOpLoad:
+ if (!CheckLoad(user, index)) ok = false;
+ stats->num_full_accesses++;
+ break;
+ case SpvOpStore:
+ if (!CheckStore(user, index)) ok = false;
+ stats->num_full_accesses++;
+ break;
+ case SpvOpName:
+ case SpvOpMemberName:
+ break;
+ default:
+ ok = false;
+ break;
+ }
+ }
+ });
return ok;
}
@@ -734,10 +806,10 @@
return length > max_num_elements_;
}
-std::unique_ptr<std::unordered_set<uint64_t>>
+std::unique_ptr<std::unordered_set<int64_t>>
ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
- std::unique_ptr<std::unordered_set<uint64_t>> result(
- new std::unordered_set<uint64_t>());
+ std::unique_ptr<std::unordered_set<int64_t>> result(
+ new std::unordered_set<int64_t>());
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
@@ -775,18 +847,8 @@
const analysis::Constant* index_const =
const_mgr->FindDeclaredConstant(index_id);
if (index_const) {
- const analysis::Integer* index_type =
- index_const->type()->AsInteger();
- assert(index_type);
- if (index_type->width() == 32) {
- result->insert(index_const->GetU32());
- return true;
- } else if (index_type->width() == 64) {
- result->insert(index_const->GetU64());
- return true;
- }
- result.reset(nullptr);
- return false;
+ result->insert(index_const->GetSignExtendedValue());
+ return true;
} else {
// Could be any element. Assuming all are used.
result.reset(nullptr);
@@ -817,5 +879,24 @@
return null_inst;
}
+uint64_t ScalarReplacementPass::GetMaxLegalIndex(
+ const Instruction* var_inst) const {
+ assert(var_inst->opcode() == SpvOpVariable &&
+ "|var_inst| must be a variable instruction.");
+ Instruction* type = GetStorageType(var_inst);
+ switch (type->opcode()) {
+ case SpvOpTypeStruct:
+ return type->NumInOperands();
+ case SpvOpTypeArray:
+ return GetArrayLength(type);
+ case SpvOpTypeMatrix:
+ case SpvOpTypeVector:
+ return GetNumElements(type);
+ default:
+ return 0;
+ }
+ return 0;
+}
+
} // namespace opt
} // namespace spvtools
diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h
index 3a17045..e20f1f1 100644
--- a/source/opt/scalar_replacement_pass.h
+++ b/source/opt/scalar_replacement_pass.h
@@ -143,7 +143,8 @@
bool CheckStore(const Instruction* inst, uint32_t index) const;
// Creates a variable of type |typeId| from the |index|'th element of
- // |varInst|. The new variable is added to |replacements|.
+ // |varInst|. The new variable is added to |replacements|. If the variable
+ // could not be created, then |nullptr| is appended to |replacements|.
void CreateVariable(uint32_t typeId, Instruction* varInst, uint32_t index,
std::vector<Instruction*>* replacements);
@@ -158,14 +159,6 @@
bool CreateReplacementVariables(Instruction* inst,
std::vector<Instruction*>* replacements);
- // Returns the value of an OpConstant of integer type.
- //
- // |constant| must use two or fewer words to generate the value.
- uint64_t GetConstantInteger(const Instruction* constant) const;
-
- // Returns the integer literal for |op|.
- uint64_t GetIntegerLiteral(const Operand& op) const;
-
// Returns the array length for |arrayInst|.
uint64_t GetArrayLength(const Instruction* arrayInst) const;
@@ -195,28 +188,28 @@
// Generates a load for each replacement variable and then creates a new
// composite by combining all of the loads.
//
- // |load| must be a load.
- void ReplaceWholeLoad(Instruction* load,
+ // |load| must be a load. Returns true if successful.
+ bool ReplaceWholeLoad(Instruction* load,
const std::vector<Instruction*>& replacements);
// Replaces the store to the entire composite.
//
// Generates a composite extract and store for each element in the scalarized
- // variable from the original store data input.
- void ReplaceWholeStore(Instruction* store,
+ // variable from the original store data input. Returns true if successful.
+ bool ReplaceWholeStore(Instruction* store,
const std::vector<Instruction*>& replacements);
// Replaces an access chain to the composite variable with either a direct use
// of the appropriate replacement variable or another access chain with the
- // replacement variable as the base and one fewer indexes. Returns false if
- // the chain has an out of bounds access.
+ // replacement variable as the base and one fewer indexes. Returns true if
+ // successful.
bool ReplaceAccessChain(Instruction* chain,
const std::vector<Instruction*>& replacements);
// Returns a set containing the which components of the result of |inst| are
// potentially used. If the return value is |nullptr|, then every components
// is possibly used.
- std::unique_ptr<std::unordered_set<uint64_t>> GetUsedComponents(
+ std::unique_ptr<std::unordered_set<int64_t>> GetUsedComponents(
Instruction* inst);
// Returns an instruction defining a null constant with type |type_id|. If
@@ -230,10 +223,16 @@
// Maps type id to OpConstantNull for that type.
std::unordered_map<uint32_t, uint32_t> type_to_null_;
+ // Returns the number of elements in the variable |var_inst|.
+ uint64_t GetMaxLegalIndex(const Instruction* var_inst) const;
+
+ // Returns true if |length| is larger than limit on the size of the variable
+ // that we will be willing to split.
+ bool IsLargerThanSizeLimit(uint64_t length) const;
+
// Limit on the number of members in an object that will be replaced.
// 0 means there is no limit.
uint32_t max_num_elements_;
- bool IsLargerThanSizeLimit(uint64_t length) const;
char name_[55];
};
diff --git a/source/opt/simplification_pass.cpp b/source/opt/simplification_pass.cpp
index 6ea4566..7b0887c 100644
--- a/source/opt/simplification_pass.cpp
+++ b/source/opt/simplification_pass.cpp
@@ -49,7 +49,7 @@
cfg()->ForEachBlockInReversePostOrder(
function->entry().get(),
[&modified, &process_phis, &work_list, &in_work_list, &inst_to_kill,
- folder, this](BasicBlock* bb) {
+ &folder, this](BasicBlock* bb) {
for (Instruction* inst = &*bb->begin(); inst; inst = inst->NextNode()) {
if (inst->opcode() == SpvOpPhi) {
process_phis.insert(inst);
@@ -71,8 +71,16 @@
}
});
if (inst->opcode() == SpvOpCopyObject) {
- context()->ReplaceAllUsesWith(inst->result_id(),
- inst->GetSingleWordInOperand(0));
+ context()->ReplaceAllUsesWithPredicate(
+ inst->result_id(), inst->GetSingleWordInOperand(0),
+ [](Instruction* user, uint32_t) {
+ const auto opcode = user->opcode();
+ if (!spvOpcodeIsDebug(opcode) &&
+ !spvOpcodeIsDecoration(opcode)) {
+ return true;
+ }
+ return false;
+ });
inst_to_kill.insert(inst);
in_work_list.insert(inst);
} else if (inst->opcode() == SpvOpNop) {
@@ -107,8 +115,15 @@
});
if (inst->opcode() == SpvOpCopyObject) {
- context()->ReplaceAllUsesWith(inst->result_id(),
- inst->GetSingleWordInOperand(0));
+ context()->ReplaceAllUsesWithPredicate(
+ inst->result_id(), inst->GetSingleWordInOperand(0),
+ [](Instruction* user, uint32_t) {
+ const auto opcode = user->opcode();
+ if (!spvOpcodeIsDebug(opcode) && !spvOpcodeIsDecoration(opcode)) {
+ return true;
+ }
+ return false;
+ });
inst_to_kill.insert(inst);
in_work_list.insert(inst);
} else if (inst->opcode() == SpvOpNop) {
diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp
index 1c27b16..d349481 100644
--- a/source/opt/type_manager.cpp
+++ b/source/opt/type_manager.cpp
@@ -213,6 +213,10 @@
std::unique_ptr<Instruction> typeInst;
// TODO(1841): Handle id overflow.
id = context()->TakeNextId();
+ if (id == 0) {
+ return 0;
+ }
+
RegisterType(id, *type);
switch (type->kind()) {
#define DefineParameterlessCase(kind) \
@@ -247,6 +251,9 @@
break;
case Type::kVector: {
uint32_t subtype = GetTypeInstruction(type->AsVector()->element_type());
+ if (subtype == 0) {
+ return 0;
+ }
typeInst =
MakeUnique<Instruction>(context(), SpvOpTypeVector, 0, id,
std::initializer_list<Operand>{
@@ -257,6 +264,9 @@
}
case Type::kMatrix: {
uint32_t subtype = GetTypeInstruction(type->AsMatrix()->element_type());
+ if (subtype == 0) {
+ return 0;
+ }
typeInst =
MakeUnique<Instruction>(context(), SpvOpTypeMatrix, 0, id,
std::initializer_list<Operand>{
@@ -268,6 +278,9 @@
case Type::kImage: {
const Image* image = type->AsImage();
uint32_t subtype = GetTypeInstruction(image->sampled_type());
+ if (subtype == 0) {
+ return 0;
+ }
typeInst = MakeUnique<Instruction>(
context(), SpvOpTypeImage, 0, id,
std::initializer_list<Operand>{
@@ -289,6 +302,9 @@
case Type::kSampledImage: {
uint32_t subtype =
GetTypeInstruction(type->AsSampledImage()->image_type());
+ if (subtype == 0) {
+ return 0;
+ }
typeInst = MakeUnique<Instruction>(
context(), SpvOpTypeSampledImage, 0, id,
std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {subtype}}});
@@ -296,6 +312,9 @@
}
case Type::kArray: {
uint32_t subtype = GetTypeInstruction(type->AsArray()->element_type());
+ if (subtype == 0) {
+ return 0;
+ }
typeInst = MakeUnique<Instruction>(
context(), SpvOpTypeArray, 0, id,
std::initializer_list<Operand>{
@@ -306,6 +325,9 @@
case Type::kRuntimeArray: {
uint32_t subtype =
GetTypeInstruction(type->AsRuntimeArray()->element_type());
+ if (subtype == 0) {
+ return 0;
+ }
typeInst = MakeUnique<Instruction>(
context(), SpvOpTypeRuntimeArray, 0, id,
std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {subtype}}});
@@ -315,7 +337,11 @@
std::vector<Operand> ops;
const Struct* structTy = type->AsStruct();
for (auto ty : structTy->element_types()) {
- ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {GetTypeInstruction(ty)}));
+ uint32_t member_type_id = GetTypeInstruction(ty);
+ if (member_type_id == 0) {
+ return 0;
+ }
+ ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {member_type_id}));
}
typeInst =
MakeUnique<Instruction>(context(), SpvOpTypeStruct, 0, id, ops);
@@ -337,6 +363,9 @@
case Type::kPointer: {
const Pointer* pointer = type->AsPointer();
uint32_t subtype = GetTypeInstruction(pointer->pointee_type());
+ if (subtype == 0) {
+ return 0;
+ }
typeInst = MakeUnique<Instruction>(
context(), SpvOpTypePointer, 0, id,
std::initializer_list<Operand>{
@@ -348,10 +377,17 @@
case Type::kFunction: {
std::vector<Operand> ops;
const Function* function = type->AsFunction();
- ops.push_back(Operand(SPV_OPERAND_TYPE_ID,
- {GetTypeInstruction(function->return_type())}));
+ uint32_t return_type_id = GetTypeInstruction(function->return_type());
+ if (return_type_id == 0) {
+ return 0;
+ }
+ ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {return_type_id}));
for (auto ty : function->param_types()) {
- ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {GetTypeInstruction(ty)}));
+ uint32_t paramater_type_id = GetTypeInstruction(ty);
+ if (paramater_type_id == 0) {
+ return 0;
+ }
+ ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {paramater_type_id}));
}
typeInst =
MakeUnique<Instruction>(context(), SpvOpTypeFunction, 0, id, ops);
@@ -594,6 +630,9 @@
Type* TypeManager::GetRegisteredType(const Type* type) {
uint32_t id = GetTypeInstruction(type);
+ if (id == 0) {
+ return nullptr;
+ }
return GetType(id);
}
diff --git a/source/opt/type_manager.h b/source/opt/type_manager.h
index ecc7858..8fcf8aa 100644
--- a/source/opt/type_manager.h
+++ b/source/opt/type_manager.h
@@ -101,7 +101,8 @@
std::pair<Type*, std::unique_ptr<Pointer>> GetTypeAndPointerType(
uint32_t id, SpvStorageClass sc) const;
- // Returns an id for a declaration representing |type|.
+ // Returns an id for a declaration representing |type|. Returns 0 if the type
+ // does not exists, and could not be generated.
//
// If |type| is registered, then the registered id is returned. Otherwise,
// this function recursively adds type and annotation instructions as
@@ -109,7 +110,8 @@
uint32_t GetTypeInstruction(const Type* type);
// Find pointer to type and storage in module, return its resultId. If it is
- // not found, a new type is created, and its id is returned.
+ // not found, a new type is created, and its id is returned. Returns 0 if the
+ // type could not be created.
uint32_t FindPointerToType(uint32_t type_id, SpvStorageClass storage_class);
// Registers |id| to |type|.
@@ -118,6 +120,7 @@
// unchanged.
void RegisterType(uint32_t id, const Type& type);
+ // Return the registered type object that is the same as |type|.
Type* GetRegisteredType(const Type* type);
// Removes knowledge of |id| from the manager.
@@ -136,6 +139,61 @@
const Type* GetMemberType(const Type* parent_type,
const std::vector<uint32_t>& access_chain);
+ Type* GetUIntType() {
+ Integer int_type(32, false);
+ return GetRegisteredType(&int_type);
+ }
+
+ uint32_t GetUIntTypeId() { return GetTypeInstruction(GetUIntType()); }
+
+ Type* GetSIntType() {
+ Integer int_type(32, true);
+ return GetRegisteredType(&int_type);
+ }
+
+ uint32_t GetSIntTypeId() { return GetTypeInstruction(GetSIntType()); }
+
+ Type* GetFloatType() {
+ Float float_type(32);
+ return GetRegisteredType(&float_type);
+ }
+
+ uint32_t GetFloatTypeId() { return GetTypeInstruction(GetFloatType()); }
+
+ Type* GetUIntVectorType(uint32_t size) {
+ Vector vec_type(GetUIntType(), size);
+ return GetRegisteredType(&vec_type);
+ }
+
+ uint32_t GetUIntVectorTypeId(uint32_t size) {
+ return GetTypeInstruction(GetUIntVectorType(size));
+ }
+
+ Type* GetSIntVectorType(uint32_t size) {
+ Vector vec_type(GetSIntType(), size);
+ return GetRegisteredType(&vec_type);
+ }
+
+ uint32_t GetSIntVectorTypeId(uint32_t size) {
+ return GetTypeInstruction(GetSIntVectorType(size));
+ }
+
+ Type* GetFloatVectorType(uint32_t size) {
+ Vector vec_type(GetFloatType(), size);
+ return GetRegisteredType(&vec_type);
+ }
+
+ uint32_t GetFloatVectorTypeId(uint32_t size) {
+ return GetTypeInstruction(GetFloatVectorType(size));
+ }
+
+ Type* GetBoolType() {
+ Bool bool_type;
+ return GetRegisteredType(&bool_type);
+ }
+
+ uint32_t GetBoolTypeId() { return GetTypeInstruction(GetBoolType()); }
+
private:
using TypeToIdMap = std::unordered_map<const Type*, uint32_t, HashTypePointer,
CompareTypePointers>;
diff --git a/source/opt/types.cpp b/source/opt/types.cpp
index 3717fd1..4f7150f 100644
--- a/source/opt/types.cpp
+++ b/source/opt/types.cpp
@@ -274,7 +274,7 @@
words->push_back(width_);
}
-Vector::Vector(Type* type, uint32_t count)
+Vector::Vector(const Type* type, uint32_t count)
: Type(kVector), element_type_(type), count_(count) {
assert(type->AsBool() || type->AsInteger() || type->AsFloat());
}
@@ -299,7 +299,7 @@
words->push_back(count_);
}
-Matrix::Matrix(Type* type, uint32_t count)
+Matrix::Matrix(const Type* type, uint32_t count)
: Type(kMatrix), element_type_(type), count_(count) {
assert(type->AsVector());
}
@@ -426,7 +426,7 @@
void Array::ReplaceElementType(const Type* type) { element_type_ = type; }
-RuntimeArray::RuntimeArray(Type* type)
+RuntimeArray::RuntimeArray(const Type* type)
: Type(kRuntimeArray), element_type_(type) {
assert(!type->AsVoid());
}
@@ -571,10 +571,10 @@
void Pointer::SetPointeeType(const Type* type) { pointee_type_ = type; }
-Function::Function(Type* ret_type, const std::vector<const Type*>& params)
+Function::Function(const Type* ret_type, const std::vector<const Type*>& params)
: Type(kFunction), return_type_(ret_type), param_types_(params) {}
-Function::Function(Type* ret_type, std::vector<const Type*>& params)
+Function::Function(const Type* ret_type, std::vector<const Type*>& params)
: Type(kFunction), return_type_(ret_type), param_types_(params) {}
bool Function::IsSameImpl(const Type* that, IsSameCache* seen) const {
diff --git a/source/opt/types.h b/source/opt/types.h
index c997b1f..57920df 100644
--- a/source/opt/types.h
+++ b/source/opt/types.h
@@ -258,7 +258,7 @@
class Vector : public Type {
public:
- Vector(Type* element_type, uint32_t count);
+ Vector(const Type* element_type, uint32_t count);
Vector(const Vector&) = default;
std::string str() const override;
@@ -280,7 +280,7 @@
class Matrix : public Type {
public:
- Matrix(Type* element_type, uint32_t count);
+ Matrix(const Type* element_type, uint32_t count);
Matrix(const Matrix&) = default;
std::string str() const override;
@@ -407,7 +407,7 @@
class RuntimeArray : public Type {
public:
- RuntimeArray(Type* element_type);
+ RuntimeArray(const Type* element_type);
RuntimeArray(const RuntimeArray&) = default;
std::string str() const override;
@@ -520,8 +520,8 @@
class Function : public Type {
public:
- Function(Type* ret_type, const std::vector<const Type*>& params);
- Function(Type* ret_type, std::vector<const Type*>& params);
+ Function(const Type* ret_type, const std::vector<const Type*>& params);
+ Function(const Type* ret_type, std::vector<const Type*>& params);
Function(const Function&) = default;
std::string str() const override;
diff --git a/source/opt/upgrade_memory_model.cpp b/source/opt/upgrade_memory_model.cpp
index ef9f620..f3bee9e 100644
--- a/source/opt/upgrade_memory_model.cpp
+++ b/source/opt/upgrade_memory_model.cpp
@@ -53,7 +53,7 @@
// 2. Add the OpCapability.
// 3. Modify the memory model.
Instruction* memory_model = get_module()->GetMemoryModel();
- get_module()->AddCapability(MakeUnique<Instruction>(
+ context()->AddCapability(MakeUnique<Instruction>(
context(), SpvOpCapability, 0, 0,
std::initializer_list<Operand>{
{SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityVulkanMemoryModelKHR}}}));
@@ -61,7 +61,7 @@
std::vector<uint32_t> words(extension.size() / 4 + 1, 0);
char* dst = reinterpret_cast<char*>(words.data());
strncpy(dst, extension.c_str(), extension.size());
- get_module()->AddExtension(
+ context()->AddExtension(
MakeUnique<Instruction>(context(), SpvOpExtension, 0, 0,
std::initializer_list<Operand>{
{SPV_OPERAND_TYPE_LITERAL_STRING, words}}));
diff --git a/source/opt/wrap_opkill.cpp b/source/opt/wrap_opkill.cpp
new file mode 100644
index 0000000..d10cdd2
--- /dev/null
+++ b/source/opt/wrap_opkill.cpp
@@ -0,0 +1,146 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/wrap_opkill.h"
+
+#include "ir_builder.h"
+
+namespace spvtools {
+namespace opt {
+
+Pass::Status WrapOpKill::Process() {
+ bool modified = false;
+
+ for (auto& func : *get_module()) {
+ bool successful = func.WhileEachInst([this, &modified](Instruction* inst) {
+ if (inst->opcode() == SpvOpKill) {
+ modified = true;
+ if (!ReplaceWithFunctionCall(inst)) {
+ return false;
+ }
+ }
+ return true;
+ });
+
+ if (!successful) {
+ return Status::Failure;
+ }
+ }
+
+ if (opkill_function_ != nullptr) {
+ assert(modified &&
+ "The function should only be generated if something was modified.");
+ context()->AddFunction(std::move(opkill_function_));
+ }
+ return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
+}
+
+bool WrapOpKill::ReplaceWithFunctionCall(Instruction* inst) {
+ assert(inst->opcode() == SpvOpKill &&
+ "|inst| must be an OpKill instruction.");
+ InstructionBuilder ir_builder(
+ context(), inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+ uint32_t func_id = GetOpKillFuncId();
+ if (func_id == 0) {
+ return false;
+ }
+ if (ir_builder.AddFunctionCall(GetVoidTypeId(), func_id, {}) == nullptr) {
+ return false;
+ }
+ ir_builder.AddUnreachable();
+ context()->KillInst(inst);
+ return true;
+}
+
+uint32_t WrapOpKill::GetVoidTypeId() {
+ if (void_type_id_ != 0) {
+ return void_type_id_;
+ }
+
+ analysis::TypeManager* type_mgr = context()->get_type_mgr();
+ analysis::Void void_type;
+ void_type_id_ = type_mgr->GetTypeInstruction(&void_type);
+ return void_type_id_;
+}
+
+uint32_t WrapOpKill::GetVoidFunctionTypeId() {
+ analysis::TypeManager* type_mgr = context()->get_type_mgr();
+ analysis::Void void_type;
+ const analysis::Type* registered_void_type =
+ type_mgr->GetRegisteredType(&void_type);
+
+ analysis::Function func_type(registered_void_type, {});
+ return type_mgr->GetTypeInstruction(&func_type);
+}
+
+uint32_t WrapOpKill::GetOpKillFuncId() {
+ if (opkill_function_ != nullptr) {
+ return opkill_function_->result_id();
+ }
+
+ uint32_t opkill_func_id = TakeNextId();
+ if (opkill_func_id == 0) {
+ return 0;
+ }
+
+ // Generate the function start instruction
+ std::unique_ptr<Instruction> func_start(new Instruction(
+ context(), SpvOpFunction, GetVoidTypeId(), opkill_func_id, {}));
+ func_start->AddOperand({SPV_OPERAND_TYPE_FUNCTION_CONTROL, {0}});
+ func_start->AddOperand({SPV_OPERAND_TYPE_ID, {GetVoidFunctionTypeId()}});
+ opkill_function_.reset(new Function(std::move(func_start)));
+
+ // Generate the function end instruction
+ std::unique_ptr<Instruction> func_end(
+ new Instruction(context(), SpvOpFunctionEnd, 0, 0, {}));
+ opkill_function_->SetFunctionEnd(std::move(func_end));
+
+ // Create the one basic block for the function.
+ uint32_t lab_id = TakeNextId();
+ if (lab_id == 0) {
+ return 0;
+ }
+ std::unique_ptr<Instruction> label_inst(
+ new Instruction(context(), SpvOpLabel, 0, lab_id, {}));
+ std::unique_ptr<BasicBlock> bb(new BasicBlock(std::move(label_inst)));
+
+ // Add the OpKill to the basic block
+ std::unique_ptr<Instruction> kill_inst(
+ new Instruction(context(), SpvOpKill, 0, 0, {}));
+ bb->AddInstruction(std::move(kill_inst));
+
+ // Add the bb to the function
+ opkill_function_->AddBasicBlock(std::move(bb));
+
+ // Add the function to the module.
+ if (context()->AreAnalysesValid(IRContext::kAnalysisDefUse)) {
+ opkill_function_->ForEachInst(
+ [this](Instruction* inst) { context()->AnalyzeDefUse(inst); });
+ }
+
+ if (context()->AreAnalysesValid(IRContext::kAnalysisInstrToBlockMapping)) {
+ for (BasicBlock& basic_block : *opkill_function_) {
+ context()->set_instr_block(basic_block.GetLabelInst(), &basic_block);
+ for (Instruction& inst : basic_block) {
+ context()->set_instr_block(&inst, &basic_block);
+ }
+ }
+ }
+
+ return opkill_function_->result_id();
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/wrap_opkill.h b/source/opt/wrap_opkill.h
new file mode 100644
index 0000000..8b03281
--- /dev/null
+++ b/source/opt/wrap_opkill.h
@@ -0,0 +1,72 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_WRAP_OPKILL_H_
+#define SOURCE_OPT_WRAP_OPKILL_H_
+
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// Documented in optimizer.hpp
+class WrapOpKill : public Pass {
+ public:
+ WrapOpKill() : void_type_id_(0) {}
+
+ const char* name() const override { return "wrap-opkill"; }
+
+ Status Process() override;
+
+ IRContext::Analysis GetPreservedAnalyses() override {
+ return IRContext::kAnalysisDefUse |
+ IRContext::kAnalysisInstrToBlockMapping |
+ IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators |
+ IRContext::kAnalysisNameMap | IRContext::kAnalysisBuiltinVarId |
+ IRContext::kAnalysisIdToFuncMapping | IRContext::kAnalysisConstants |
+ IRContext::kAnalysisTypes;
+ }
+
+ private:
+ // Replaces the OpKill instruction |inst| with a function call to a function
+ // that contains a single instruction, which is OpKill. An OpUnreachable
+ // instruction will be placed after the function call. Return true if
+ // successful.
+ bool ReplaceWithFunctionCall(Instruction* inst);
+
+ // Returns the id of the void type.
+ uint32_t GetVoidTypeId();
+
+ // Returns the id of the function type for a void function with no parameters.
+ uint32_t GetVoidFunctionTypeId();
+
+ // Return the id of a function that has return type void, has no parameters,
+ // and contains a single instruction, which is an OpKill. Returns 0 if the
+ // function could not be generated.
+ uint32_t GetOpKillFuncId();
+
+ // The id of the void type. If its value is 0, then the void type has not
+ // been found or created yet.
+ uint32_t void_type_id_;
+
+ // The function that is a single instruction, which is an OpKill. The
+ // function has a void return type and takes no parameters. If the function is
+ // |nullptr|, then the function has not been generated.
+ std::unique_ptr<Function> opkill_function_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // SOURCE_OPT_WRAP_OPKILL_H_
diff --git a/source/reduce/CMakeLists.txt b/source/reduce/CMakeLists.txt
index def4d21..7651e86 100644
--- a/source/reduce/CMakeLists.txt
+++ b/source/reduce/CMakeLists.txt
@@ -30,6 +30,7 @@
remove_function_reduction_opportunity.h
remove_function_reduction_opportunity_finder.h
remove_opname_instruction_reduction_opportunity_finder.h
+ remove_relaxed_precision_decoration_opportunity_finder.h
remove_selection_reduction_opportunity.h
remove_selection_reduction_opportunity_finder.h
remove_unreferenced_instruction_reduction_opportunity_finder.h
@@ -56,6 +57,7 @@
remove_function_reduction_opportunity.cpp
remove_function_reduction_opportunity_finder.cpp
remove_instruction_reduction_opportunity.cpp
+ remove_relaxed_precision_decoration_opportunity_finder.cpp
remove_selection_reduction_opportunity.cpp
remove_selection_reduction_opportunity_finder.cpp
remove_unreferenced_instruction_reduction_opportunity_finder.cpp
diff --git a/source/reduce/reducer.cpp b/source/reduce/reducer.cpp
index a677be3..ebb5d47 100644
--- a/source/reduce/reducer.cpp
+++ b/source/reduce/reducer.cpp
@@ -25,6 +25,7 @@
#include "source/reduce/remove_block_reduction_opportunity_finder.h"
#include "source/reduce/remove_function_reduction_opportunity_finder.h"
#include "source/reduce/remove_opname_instruction_reduction_opportunity_finder.h"
+#include "source/reduce/remove_relaxed_precision_decoration_opportunity_finder.h"
#include "source/reduce/remove_selection_reduction_opportunity_finder.h"
#include "source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.h"
#include "source/reduce/simple_conditional_branch_to_branch_opportunity_finder.h"
@@ -175,6 +176,8 @@
void Reducer::AddDefaultReductionPasses() {
AddReductionPass(spvtools::MakeUnique<
RemoveOpNameInstructionReductionOpportunityFinder>());
+ AddReductionPass(spvtools::MakeUnique<
+ RemoveRelaxedPrecisionDecorationOpportunityFinder>());
AddReductionPass(
spvtools::MakeUnique<OperandToUndefReductionOpportunityFinder>());
AddReductionPass(
diff --git a/source/reduce/remove_relaxed_precision_decoration_opportunity_finder.cpp b/source/reduce/remove_relaxed_precision_decoration_opportunity_finder.cpp
new file mode 100644
index 0000000..352cefb
--- /dev/null
+++ b/source/reduce/remove_relaxed_precision_decoration_opportunity_finder.cpp
@@ -0,0 +1,49 @@
+// Copyright (c) 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/reduce/remove_relaxed_precision_decoration_opportunity_finder.h"
+
+#include "source/reduce/remove_instruction_reduction_opportunity.h"
+
+namespace spvtools {
+namespace reduce {
+
+std::vector<std::unique_ptr<ReductionOpportunity>>
+RemoveRelaxedPrecisionDecorationOpportunityFinder::GetAvailableOpportunities(
+ opt::IRContext* context) const {
+ std::vector<std::unique_ptr<ReductionOpportunity>> result;
+
+ // Consider all annotation instructions
+ for (auto& inst : context->module()->annotations()) {
+ // We are interested in removing instructions of the form:
+ // SpvOpDecorate %id RelaxedPrecision
+ // and
+ // SpvOpMemberDecorate %id member RelaxedPrecision
+ if ((inst.opcode() == SpvOpDecorate &&
+ inst.GetSingleWordInOperand(1) == SpvDecorationRelaxedPrecision) ||
+ (inst.opcode() == SpvOpMemberDecorate &&
+ inst.GetSingleWordInOperand(2) == SpvDecorationRelaxedPrecision)) {
+ result.push_back(
+ MakeUnique<RemoveInstructionReductionOpportunity>(&inst));
+ }
+ }
+ return result;
+}
+
+std::string RemoveRelaxedPrecisionDecorationOpportunityFinder::GetName() const {
+ return "RemoveRelaxedPrecisionDecorationOpportunityFinder";
+}
+
+} // namespace reduce
+} // namespace spvtools
diff --git a/source/reduce/remove_relaxed_precision_decoration_opportunity_finder.h b/source/reduce/remove_relaxed_precision_decoration_opportunity_finder.h
new file mode 100644
index 0000000..673049c
--- /dev/null
+++ b/source/reduce/remove_relaxed_precision_decoration_opportunity_finder.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_REDUCE_REMOVE_RELAXED_PRECISION_OPPORTUNITY_FINDER_H_
+#define SOURCE_REDUCE_REMOVE_RELAXED_PRECISION_OPPORTUNITY_FINDER_H_
+
+#include "source/reduce/reduction_opportunity_finder.h"
+
+namespace spvtools {
+namespace reduce {
+
+// A finder for opportunities to remove relaxed precision decorations.
+class RemoveRelaxedPrecisionDecorationOpportunityFinder
+ : public ReductionOpportunityFinder {
+ public:
+ std::vector<std::unique_ptr<ReductionOpportunity>> GetAvailableOpportunities(
+ opt::IRContext* context) const override;
+
+ std::string GetName() const override;
+};
+
+} // namespace reduce
+} // namespace spvtools
+
+#endif // SOURCE_REDUCE_REMOVE_RELAXED_PRECISION_OPPORTUNITY_FINDER_H_
diff --git a/source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.cpp b/source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.cpp
index 8f32435..dabee50 100644
--- a/source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.cpp
+++ b/source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.cpp
@@ -21,11 +21,9 @@
namespace spvtools {
namespace reduce {
-using opt::IRContext;
-
std::vector<std::unique_ptr<ReductionOpportunity>>
RemoveUnreferencedInstructionReductionOpportunityFinder::
- GetAvailableOpportunities(IRContext* context) const {
+ GetAvailableOpportunities(opt::IRContext* context) const {
std::vector<std::unique_ptr<ReductionOpportunity>> result;
for (auto& function : *context->module()) {
diff --git a/source/util/string_utils.h b/source/util/string_utils.h
index f1cd179..4282aa9 100644
--- a/source/util/string_utils.h
+++ b/source/util/string_utils.h
@@ -15,8 +15,10 @@
#ifndef SOURCE_UTIL_STRING_UTILS_H_
#define SOURCE_UTIL_STRING_UTILS_H_
+#include <assert.h>
#include <sstream>
#include <string>
+#include <vector>
#include "source/util/string_utils.h"
@@ -42,6 +44,48 @@
// string will be empty.
std::pair<std::string, std::string> SplitFlagArgs(const std::string& flag);
+// Encodes a string as a sequence of words, using the SPIR-V encoding.
+inline std::vector<uint32_t> MakeVector(std::string input) {
+ std::vector<uint32_t> result;
+ uint32_t word = 0;
+ size_t num_bytes = input.size();
+ // SPIR-V strings are null-terminated. The byte_index == num_bytes
+ // case is used to push the terminating null byte.
+ for (size_t byte_index = 0; byte_index <= num_bytes; byte_index++) {
+ const auto new_byte =
+ (byte_index < num_bytes ? uint8_t(input[byte_index]) : uint8_t(0));
+ word |= (new_byte << (8 * (byte_index % sizeof(uint32_t))));
+ if (3 == (byte_index % sizeof(uint32_t))) {
+ result.push_back(word);
+ word = 0;
+ }
+ }
+ // Emit a trailing partial word.
+ if ((num_bytes + 1) % sizeof(uint32_t)) {
+ result.push_back(word);
+ }
+ return result;
+}
+
+// Decode a string from a sequence of words, using the SPIR-V encoding.
+template <class VectorType>
+inline std::string MakeString(const VectorType& words) {
+ std::string result;
+
+ for (uint32_t word : words) {
+ for (int byte_index = 0; byte_index < 4; byte_index++) {
+ uint32_t extracted_word = (word >> (8 * byte_index)) & 0xFF;
+ char c = static_cast<char>(extracted_word);
+ if (c == 0) {
+ return result;
+ }
+ result += c;
+ }
+ }
+ assert(false && "Did not find terminating null for the string.");
+ return result;
+} // namespace utils
+
} // namespace utils
} // namespace spvtools
diff --git a/source/val/validate_constants.cpp b/source/val/validate_constants.cpp
index 04544aa..565518b 100644
--- a/source/val/validate_constants.cpp
+++ b/source/val/validate_constants.cpp
@@ -116,15 +116,13 @@
inst->GetOperandAs<uint32_t>(constituent_index);
const auto constituent = _.FindDef(constituent_id);
if (!constituent ||
- !(SpvOpConstantComposite == constituent->opcode() ||
- SpvOpSpecConstantComposite == constituent->opcode() ||
- SpvOpUndef == constituent->opcode())) {
+ !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
// The message says "... or undef" because the spec does not say
// undef is a constant.
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opcode_name << " Constituent <id> '"
<< _.getIdName(constituent_id)
- << "' is not a constant composite or undef.";
+ << "' is not a constant or undef.";
}
const auto vector = _.FindDef(constituent->type_id());
if (!vector) {
diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp
index ec769db..1a64605 100644
--- a/source/val/validate_extensions.cpp
+++ b/source/val/validate_extensions.cpp
@@ -923,7 +923,58 @@
case OpenCLLIB::Fract:
case OpenCLLIB::Modf:
- case OpenCLLIB::Sincos:
+ case OpenCLLIB::Sincos: {
+ if (!_.IsFloatScalarOrVectorType(result_type)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << ext_inst_name() << ": "
+ << "expected Result Type to be a float scalar or vector type";
+ }
+
+ const uint32_t num_components = _.GetDimension(result_type);
+ if (num_components > 4 && num_components != 8 && num_components != 16) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << ext_inst_name() << ": "
+ << "expected Result Type to be a scalar or a vector with 2, "
+ "3, 4, 8 or 16 components";
+ }
+
+ const uint32_t x_type = _.GetOperandTypeId(inst, 4);
+ if (result_type != x_type) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << ext_inst_name() << ": "
+ << "expected type of operand X to be equal to Result Type";
+ }
+
+ const uint32_t p_type = _.GetOperandTypeId(inst, 5);
+ uint32_t p_storage_class = 0;
+ uint32_t p_data_type = 0;
+ if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << ext_inst_name() << ": "
+ << "expected the last operand to be a pointer";
+ }
+
+ if (p_storage_class != SpvStorageClassGeneric &&
+ p_storage_class != SpvStorageClassCrossWorkgroup &&
+ p_storage_class != SpvStorageClassWorkgroup &&
+ p_storage_class != SpvStorageClassFunction) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << ext_inst_name() << ": "
+ << "expected storage class of the pointer to be Generic, "
+ "CrossWorkgroup, Workgroup or Function";
+ }
+
+ if (result_type != p_data_type) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << ext_inst_name() << ": "
+ << "expected data type of the pointer to be equal to Result "
+ "Type";
+ }
+ break;
+ }
+
+ case OpenCLLIB::Frexp:
+ case OpenCLLIB::Lgamma_r:
case OpenCLLIB::Remquo: {
if (!_.IsFloatScalarOrVectorType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -975,57 +1026,6 @@
"CrossWorkgroup, Workgroup or Function";
}
- if (result_type != p_data_type) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << ext_inst_name() << ": "
- << "expected data type of the pointer to be equal to Result "
- "Type";
- }
- break;
- }
-
- case OpenCLLIB::Frexp:
- case OpenCLLIB::Lgamma_r: {
- if (!_.IsFloatScalarOrVectorType(result_type)) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << ext_inst_name() << ": "
- << "expected Result Type to be a float scalar or vector type";
- }
-
- const uint32_t num_components = _.GetDimension(result_type);
- if (num_components > 4 && num_components != 8 && num_components != 16) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << ext_inst_name() << ": "
- << "expected Result Type to be a scalar or a vector with 2, "
- "3, 4, 8 or 16 components";
- }
-
- const uint32_t x_type = _.GetOperandTypeId(inst, 4);
- if (result_type != x_type) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << ext_inst_name() << ": "
- << "expected type of operand X to be equal to Result Type";
- }
-
- const uint32_t p_type = _.GetOperandTypeId(inst, 5);
- uint32_t p_storage_class = 0;
- uint32_t p_data_type = 0;
- if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << ext_inst_name() << ": "
- << "expected the last operand to be a pointer";
- }
-
- if (p_storage_class != SpvStorageClassGeneric &&
- p_storage_class != SpvStorageClassCrossWorkgroup &&
- p_storage_class != SpvStorageClassWorkgroup &&
- p_storage_class != SpvStorageClassFunction) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << ext_inst_name() << ": "
- << "expected storage class of the pointer to be Generic, "
- "CrossWorkgroup, Workgroup or Function";
- }
-
if (!_.IsIntScalarOrVectorType(p_data_type) ||
_.GetBitWidth(p_data_type) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
diff --git a/source/val/validate_memory_semantics.cpp b/source/val/validate_memory_semantics.cpp
index 0088cdd..4c582f0 100644
--- a/source/val/validate_memory_semantics.cpp
+++ b/source/val/validate_memory_semantics.cpp
@@ -57,36 +57,51 @@
}
if (spvIsWebGPUEnv(_.context()->target_env)) {
- uint32_t valid_bits = SpvMemorySemanticsUniformMemoryMask |
- SpvMemorySemanticsWorkgroupMemoryMask |
- SpvMemorySemanticsImageMemoryMask |
- SpvMemorySemanticsOutputMemoryKHRMask |
- SpvMemorySemanticsMakeAvailableKHRMask |
- SpvMemorySemanticsMakeVisibleKHRMask;
- if (!spvOpcodeIsAtomicOp(inst->opcode())) {
- valid_bits |= SpvMemorySemanticsAcquireReleaseMask;
- }
+ uint32_t valid_bits;
+ switch (inst->opcode()) {
+ case SpvOpControlBarrier:
+ if (!(value & SpvMemorySemanticsAcquireReleaseMask)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "For WebGPU, AcquireRelease must be set for Memory "
+ "Semantics of OpControlBarrier.";
+ }
- if (value & ~valid_bits) {
- if (spvOpcodeIsAtomicOp(inst->opcode())) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "WebGPU spec disallows, for OpAtomic*, any bit masks in "
- "Memory Semantics that are not UniformMemory, "
- "WorkgroupMemory, ImageMemory, or OutputMemoryKHR";
- } else {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "WebGPU spec disallows any bit masks in Memory Semantics "
- "that are not AcquireRelease, UniformMemory, "
- "WorkgroupMemory, ImageMemory, OutputMemoryKHR, "
- "MakeAvailableKHR, or MakeVisibleKHR";
- }
- }
+ if (!(value & SpvMemorySemanticsWorkgroupMemoryMask)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "For WebGPU, WorkgroupMemory must be set for Memory "
+ "Semantics of OpControlBarrier.";
+ }
- if (!spvOpcodeIsAtomicOp(inst->opcode()) &&
- !(value & SpvMemorySemanticsAcquireReleaseMask)) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "WebGPU spec requires AcquireRelease to set in Memory "
- "Semantics.";
+ valid_bits = SpvMemorySemanticsAcquireReleaseMask |
+ SpvMemorySemanticsWorkgroupMemoryMask;
+ if (value & ~valid_bits) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "For WebGPU only WorkgroupMemory and AcquireRelease may be "
+ "set for Memory Semantics of OpControlBarrier.";
+ }
+ break;
+ case SpvOpMemoryBarrier:
+ if (!(value & SpvMemorySemanticsImageMemoryMask)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "For WebGPU, ImageMemory must be set for Memory Semantics "
+ "of OpMemoryBarrier.";
+ }
+ valid_bits = SpvMemorySemanticsImageMemoryMask;
+ if (value & ~valid_bits) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "For WebGPU only ImageMemory may be set for Memory "
+ "Semantics of OpMemoryBarrier.";
+ }
+ break;
+ default:
+ if (spvOpcodeIsAtomicOp(inst->opcode())) {
+ if (value != 0) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "For WebGPU Memory no bits may be set for Memory "
+ "Semantics of OpAtomic* instructions.";
+ }
+ }
+ break;
}
}
diff --git a/source/val/validate_scopes.cpp b/source/val/validate_scopes.cpp
index c607984..70a65f7 100644
--- a/source/val/validate_scopes.cpp
+++ b/source/val/validate_scopes.cpp
@@ -122,7 +122,6 @@
// WebGPU Specific rules
if (spvIsWebGPUEnv(_.context()->target_env)) {
- // Scope for execution must be limited to Workgroup or Subgroup
if (value != SpvScopeWorkgroup) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
@@ -229,12 +228,41 @@
// WebGPU specific rules
if (spvIsWebGPUEnv(_.context()->target_env)) {
- if (value != SpvScopeWorkgroup && value != SpvScopeInvocation &&
- value != SpvScopeQueueFamilyKHR) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << spvOpcodeString(opcode)
- << ": in WebGPU environment Memory Scope is limited to "
- << "Workgroup, Invocation, and QueueFamilyKHR";
+ switch (inst->opcode()) {
+ case SpvOpControlBarrier:
+ if (value != SpvScopeWorkgroup) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode)
+ << ": in WebGPU environment Memory Scope is limited to "
+ << "Workgroup for OpControlBarrier";
+ }
+ break;
+ case SpvOpMemoryBarrier:
+ if (value != SpvScopeWorkgroup) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode)
+ << ": in WebGPU environment Memory Scope is limited to "
+ << "Workgroup for OpMemoryBarrier";
+ }
+ break;
+ default:
+ if (spvOpcodeIsAtomicOp(inst->opcode())) {
+ if (value != SpvScopeQueueFamilyKHR) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode)
+ << ": in WebGPU environment Memory Scope is limited to "
+ << "QueueFamilyKHR for OpAtomic* operations";
+ }
+ }
+
+ if (value != SpvScopeWorkgroup && value != SpvScopeInvocation &&
+ value != SpvScopeQueueFamilyKHR) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode)
+ << ": in WebGPU environment Memory Scope is limited to "
+ << "Workgroup, Invocation, and QueueFamilyKHR";
+ }
+ break;
}
}
diff --git a/test/assembly_context_test.cpp b/test/assembly_context_test.cpp
index ee0bb24..c8aa06b 100644
--- a/test/assembly_context_test.cpp
+++ b/test/assembly_context_test.cpp
@@ -17,6 +17,7 @@
#include "gmock/gmock.h"
#include "source/instruction.h"
+#include "source/util/string_utils.h"
#include "test/unit_spirv.h"
namespace spvtools {
@@ -40,9 +41,8 @@
ASSERT_EQ(SPV_SUCCESS,
context.binaryEncodeString(GetParam().str.c_str(), &inst));
// We already trust MakeVector
- EXPECT_THAT(inst.words,
- Eq(Concatenate({GetParam().initial_contents,
- spvtest::MakeVector(GetParam().str)})));
+ EXPECT_THAT(inst.words, Eq(Concatenate({GetParam().initial_contents,
+ utils::MakeVector(GetParam().str)})));
}
// clang-format off
diff --git a/test/binary_parse_test.cpp b/test/binary_parse_test.cpp
index b966102..54664fc 100644
--- a/test/binary_parse_test.cpp
+++ b/test/binary_parse_test.cpp
@@ -21,6 +21,7 @@
#include "gmock/gmock.h"
#include "source/latest_version_opencl_std_header.h"
#include "source/table.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -39,7 +40,7 @@
using ::spvtest::Concatenate;
using ::spvtest::MakeInstruction;
-using ::spvtest::MakeVector;
+using utils::MakeVector;
using ::spvtest::ScopedContext;
using ::testing::_;
using ::testing::AnyOf;
diff --git a/test/comment_test.cpp b/test/comment_test.cpp
index f46b72a..49f8df6 100644
--- a/test/comment_test.cpp
+++ b/test/comment_test.cpp
@@ -15,6 +15,7 @@
#include <string>
#include "gmock/gmock.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -23,7 +24,7 @@
using spvtest::Concatenate;
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using spvtest::TextToBinaryTest;
using testing::Eq;
diff --git a/test/ext_inst.debuginfo_test.cpp b/test/ext_inst.debuginfo_test.cpp
index ec012e0..9090c24 100644
--- a/test/ext_inst.debuginfo_test.cpp
+++ b/test/ext_inst.debuginfo_test.cpp
@@ -17,6 +17,7 @@
#include "DebugInfo.h"
#include "gmock/gmock.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -31,7 +32,7 @@
using spvtest::Concatenate;
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using testing::Eq;
struct InstructionCase {
diff --git a/test/ext_inst.opencl_test.cpp b/test/ext_inst.opencl_test.cpp
index 7dd903e..7547d92 100644
--- a/test/ext_inst.opencl_test.cpp
+++ b/test/ext_inst.opencl_test.cpp
@@ -17,6 +17,7 @@
#include "gmock/gmock.h"
#include "source/latest_version_opencl_std_header.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -25,7 +26,7 @@
using spvtest::Concatenate;
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using spvtest::TextToBinaryTest;
using testing::Eq;
diff --git a/test/fuzz/CMakeLists.txt b/test/fuzz/CMakeLists.txt
index c0e2925..6a101dd 100644
--- a/test/fuzz/CMakeLists.txt
+++ b/test/fuzz/CMakeLists.txt
@@ -30,6 +30,7 @@
transformation_add_type_float_test.cpp
transformation_add_type_int_test.cpp
transformation_add_type_pointer_test.cpp
+ transformation_copy_object_test.cpp
transformation_move_block_down_test.cpp
transformation_replace_boolean_constant_with_constant_binary_test.cpp
transformation_replace_constant_with_uniform_test.cpp
diff --git a/test/fuzz/transformation_copy_object_test.cpp b/test/fuzz/transformation_copy_object_test.cpp
new file mode 100644
index 0000000..0c214c8
--- /dev/null
+++ b/test/fuzz/transformation_copy_object_test.cpp
@@ -0,0 +1,539 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/fuzz/transformation_copy_object.h"
+#include "source/fuzz/data_descriptor.h"
+#include "test/fuzz/fuzz_test_util.h"
+
+namespace spvtools {
+namespace fuzz {
+namespace {
+
+TEST(TransformationCopyObjectTest, CopyBooleanConstants) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ OpName %4 "main"
+ %2 = OpTypeVoid
+ %6 = OpTypeBool
+ %7 = OpConstantTrue %6
+ %8 = OpConstantFalse %6
+ %3 = OpTypeFunction %2
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_3;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ FactManager fact_manager;
+
+ ASSERT_EQ(0, fact_manager.GetIdsForWhichSynonymsAreKnown().size());
+
+ TransformationCopyObject copy_true(7, 5, 1, 100);
+ ASSERT_TRUE(copy_true.IsApplicable(context.get(), fact_manager));
+ copy_true.Apply(context.get(), &fact_manager);
+
+ const std::set<uint32_t>& ids_for_which_synonyms_are_known =
+ fact_manager.GetIdsForWhichSynonymsAreKnown();
+ ASSERT_EQ(1, ids_for_which_synonyms_are_known.size());
+ ASSERT_TRUE(ids_for_which_synonyms_are_known.find(7) !=
+ ids_for_which_synonyms_are_known.end());
+ ASSERT_EQ(1, fact_manager.GetSynonymsForId(7).size());
+ protobufs::DataDescriptor descriptor_100 = MakeDataDescriptor(100, {});
+ ASSERT_TRUE(DataDescriptorEquals()(&descriptor_100,
+ &fact_manager.GetSynonymsForId(7)[0]));
+
+ TransformationCopyObject copy_false(8, 100, 1, 101);
+ ASSERT_TRUE(copy_false.IsApplicable(context.get(), fact_manager));
+ copy_false.Apply(context.get(), &fact_manager);
+ ASSERT_EQ(2, ids_for_which_synonyms_are_known.size());
+ ASSERT_TRUE(ids_for_which_synonyms_are_known.find(8) !=
+ ids_for_which_synonyms_are_known.end());
+ ASSERT_EQ(1, fact_manager.GetSynonymsForId(8).size());
+ protobufs::DataDescriptor descriptor_101 = MakeDataDescriptor(101, {});
+ ASSERT_TRUE(DataDescriptorEquals()(&descriptor_101,
+ &fact_manager.GetSynonymsForId(8)[0]));
+
+ TransformationCopyObject copy_false_again(101, 5, 3, 102);
+ ASSERT_TRUE(copy_false_again.IsApplicable(context.get(), fact_manager));
+ copy_false_again.Apply(context.get(), &fact_manager);
+ ASSERT_EQ(3, ids_for_which_synonyms_are_known.size());
+ ASSERT_TRUE(ids_for_which_synonyms_are_known.find(101) !=
+ ids_for_which_synonyms_are_known.end());
+ ASSERT_EQ(1, fact_manager.GetSynonymsForId(101).size());
+ protobufs::DataDescriptor descriptor_102 = MakeDataDescriptor(102, {});
+ ASSERT_TRUE(DataDescriptorEquals()(&descriptor_102,
+ &fact_manager.GetSynonymsForId(101)[0]));
+
+ TransformationCopyObject copy_true_again(7, 102, 1, 103);
+ ASSERT_TRUE(copy_true_again.IsApplicable(context.get(), fact_manager));
+ copy_true_again.Apply(context.get(), &fact_manager);
+ // This re-uses an id for which synonyms are already known, so the count
+ // of such ids does not change.
+ ASSERT_EQ(3, ids_for_which_synonyms_are_known.size());
+ ASSERT_TRUE(ids_for_which_synonyms_are_known.find(7) !=
+ ids_for_which_synonyms_are_known.end());
+ ASSERT_EQ(2, fact_manager.GetSynonymsForId(7).size());
+ protobufs::DataDescriptor descriptor_103 = MakeDataDescriptor(103, {});
+ ASSERT_TRUE(DataDescriptorEquals()(&descriptor_103,
+ &fact_manager.GetSynonymsForId(7)[0]) ||
+ DataDescriptorEquals()(&descriptor_103,
+ &fact_manager.GetSynonymsForId(7)[1]));
+
+ std::string after_transformation = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ OpName %4 "main"
+ %2 = OpTypeVoid
+ %6 = OpTypeBool
+ %7 = OpConstantTrue %6
+ %8 = OpConstantFalse %6
+ %3 = OpTypeFunction %2
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %100 = OpCopyObject %6 %7
+ %101 = OpCopyObject %6 %8
+ %102 = OpCopyObject %6 %101
+ %103 = OpCopyObject %6 %7
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
+TEST(TransformationCopyObjectTest, CheckIllegalCases) {
+ // The following SPIR-V comes from this GLSL, pushed through spirv-opt
+ // and then doctored a bit.
+ //
+ // #version 310 es
+ //
+ // precision highp float;
+ //
+ // struct S {
+ // int a;
+ // float b;
+ // };
+ //
+ // layout(set = 0, binding = 2) uniform block {
+ // S s;
+ // lowp float f;
+ // int ii;
+ // } ubuf;
+ //
+ // layout(location = 0) out vec4 color;
+ //
+ // void main() {
+ // float c = 0.0;
+ // lowp float d = 0.0;
+ // S localS = ubuf.s;
+ // for (int i = 0; i < ubuf.s.a; i++) {
+ // switch (ubuf.ii) {
+ // case 0:
+ // c += 0.1;
+ // d += 0.2;
+ // case 1:
+ // c += 0.1;
+ // if (c > d) {
+ // d += 0.2;
+ // } else {
+ // d += c;
+ // }
+ // break;
+ // default:
+ // i += 1;
+ // localS.b += d;
+ // }
+ // }
+ // color = vec4(c, d, localS.b, 1.0);
+ // }
+
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main" %80
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ OpName %4 "main"
+ OpName %12 "S"
+ OpMemberName %12 0 "a"
+ OpMemberName %12 1 "b"
+ OpName %15 "S"
+ OpMemberName %15 0 "a"
+ OpMemberName %15 1 "b"
+ OpName %16 "block"
+ OpMemberName %16 0 "s"
+ OpMemberName %16 1 "f"
+ OpMemberName %16 2 "ii"
+ OpName %18 "ubuf"
+ OpName %80 "color"
+ OpMemberDecorate %12 0 RelaxedPrecision
+ OpMemberDecorate %15 0 RelaxedPrecision
+ OpMemberDecorate %15 0 Offset 0
+ OpMemberDecorate %15 1 Offset 4
+ OpMemberDecorate %16 0 Offset 0
+ OpMemberDecorate %16 1 RelaxedPrecision
+ OpMemberDecorate %16 1 Offset 16
+ OpMemberDecorate %16 2 RelaxedPrecision
+ OpMemberDecorate %16 2 Offset 20
+ OpDecorate %16 Block
+ OpDecorate %18 DescriptorSet 0
+ OpDecorate %18 Binding 2
+ OpDecorate %38 RelaxedPrecision
+ OpDecorate %43 RelaxedPrecision
+ OpDecorate %53 RelaxedPrecision
+ OpDecorate %62 RelaxedPrecision
+ OpDecorate %69 RelaxedPrecision
+ OpDecorate %77 RelaxedPrecision
+ OpDecorate %80 Location 0
+ OpDecorate %101 RelaxedPrecision
+ OpDecorate %102 RelaxedPrecision
+ OpDecorate %96 RelaxedPrecision
+ OpDecorate %108 RelaxedPrecision
+ OpDecorate %107 RelaxedPrecision
+ OpDecorate %98 RelaxedPrecision
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeFloat 32
+ %9 = OpConstant %6 0
+ %11 = OpTypeInt 32 1
+ %12 = OpTypeStruct %11 %6
+ %15 = OpTypeStruct %11 %6
+ %16 = OpTypeStruct %15 %6 %11
+ %17 = OpTypePointer Uniform %16
+ %18 = OpVariable %17 Uniform
+ %19 = OpConstant %11 0
+ %20 = OpTypePointer Uniform %15
+ %27 = OpConstant %11 1
+ %36 = OpTypePointer Uniform %11
+ %39 = OpTypeBool
+ %41 = OpConstant %11 2
+ %48 = OpConstant %6 0.100000001
+ %51 = OpConstant %6 0.200000003
+ %78 = OpTypeVector %6 4
+ %79 = OpTypePointer Output %78
+ %80 = OpVariable %79 Output
+ %85 = OpConstant %6 1
+ %95 = OpUndef %12
+ %112 = OpTypePointer Uniform %6
+ %113 = OpTypeInt 32 0
+ %114 = OpConstant %113 1
+ %179 = OpTypePointer Function %39
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %180 = OpVariable %179 Function
+ %181 = OpVariable %179 Function
+ %182 = OpVariable %179 Function
+ %21 = OpAccessChain %20 %18 %19
+ %115 = OpAccessChain %112 %21 %114
+ %116 = OpLoad %6 %115
+ %90 = OpCompositeInsert %12 %116 %95 1
+ OpBranch %30
+ %30 = OpLabel
+ %99 = OpPhi %12 %90 %5 %109 %47
+ %98 = OpPhi %6 %9 %5 %107 %47
+ %97 = OpPhi %6 %9 %5 %105 %47
+ %96 = OpPhi %11 %19 %5 %77 %47
+ %37 = OpAccessChain %36 %18 %19 %19
+ %38 = OpLoad %11 %37
+ %40 = OpSLessThan %39 %96 %38
+ OpLoopMerge %32 %47 None
+ OpBranchConditional %40 %31 %32
+ %31 = OpLabel
+ %42 = OpAccessChain %36 %18 %41
+ %43 = OpLoad %11 %42
+ OpSelectionMerge %47 None
+ OpSwitch %43 %46 0 %44 1 %45
+ %46 = OpLabel
+ %69 = OpIAdd %11 %96 %27
+ %72 = OpCompositeExtract %6 %99 1
+ %73 = OpFAdd %6 %72 %98
+ %93 = OpCompositeInsert %12 %73 %99 1
+ OpBranch %47
+ %44 = OpLabel
+ %50 = OpFAdd %6 %97 %48
+ %53 = OpFAdd %6 %98 %51
+ OpBranch %45
+ %45 = OpLabel
+ %101 = OpPhi %6 %98 %31 %53 %44
+ %100 = OpPhi %6 %97 %31 %50 %44
+ %55 = OpFAdd %6 %100 %48
+ %58 = OpFOrdGreaterThan %39 %55 %101
+ OpSelectionMerge %60 None
+ OpBranchConditional %58 %59 %63
+ %59 = OpLabel
+ %62 = OpFAdd %6 %101 %51
+ OpBranch %60
+ %63 = OpLabel
+ %66 = OpFAdd %6 %101 %55
+ OpBranch %60
+ %60 = OpLabel
+ %108 = OpPhi %6 %62 %59 %66 %63
+ OpBranch %47
+ %47 = OpLabel
+ %109 = OpPhi %12 %93 %46 %99 %60
+ %107 = OpPhi %6 %98 %46 %108 %60
+ %105 = OpPhi %6 %97 %46 %55 %60
+ %102 = OpPhi %11 %69 %46 %96 %60
+ %77 = OpIAdd %11 %102 %27
+ OpBranch %30
+ %32 = OpLabel
+ %84 = OpCompositeExtract %6 %99 1
+ %86 = OpCompositeConstruct %78 %97 %98 %84 %85
+ OpStore %80 %86
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_3;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ FactManager fact_manager;
+
+ // Inapplicable because %18 is decorated.
+ ASSERT_FALSE(TransformationCopyObject(18, 21, 0, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Inapplicable because %77 is decorated.
+ ASSERT_FALSE(TransformationCopyObject(17, 17, 1, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Inapplicable because %80 is decorated.
+ ASSERT_FALSE(TransformationCopyObject(80, 77, 0, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Inapplicable because %84 is not available at the requested point
+ ASSERT_FALSE(TransformationCopyObject(84, 32, 1, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Fine because %84 is available at the requested point
+ ASSERT_TRUE(TransformationCopyObject(84, 32, 2, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Inapplicable because id %9 is already in use
+ ASSERT_FALSE(TransformationCopyObject(84, 32, 2, 9)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Inapplicable because the requested point is not in a block
+ ASSERT_FALSE(TransformationCopyObject(84, 86, 3, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Inapplicable because %9 is not in a function
+ ASSERT_FALSE(TransformationCopyObject(9, 9, 1, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Inapplicable because %9 is not in a function
+ ASSERT_FALSE(TransformationCopyObject(9, 9, 1, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Inapplicable because the insert point is right before, or inside, a chunk
+ // of OpPhis
+ ASSERT_FALSE(TransformationCopyObject(9, 30, 1, 200)
+ .IsApplicable(context.get(), fact_manager));
+ ASSERT_FALSE(TransformationCopyObject(9, 99, 1, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // OK, because the insert point is just after a chunk of OpPhis.
+ ASSERT_TRUE(TransformationCopyObject(9, 96, 1, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Inapplicable because the insert point is right after an OpSelectionMerge
+ ASSERT_FALSE(TransformationCopyObject(9, 58, 2, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // OK, because the insert point is right before the OpSelectionMerge
+ ASSERT_TRUE(TransformationCopyObject(9, 58, 1, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Inapplicable because the insert point is right after an OpSelectionMerge
+ ASSERT_FALSE(TransformationCopyObject(9, 43, 2, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // OK, because the insert point is right before the OpSelectionMerge
+ ASSERT_TRUE(TransformationCopyObject(9, 43, 1, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Inapplicable because the insert point is right after an OpLoopMerge
+ ASSERT_FALSE(TransformationCopyObject(9, 40, 2, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // OK, because the insert point is right before the OpLoopMerge
+ ASSERT_TRUE(TransformationCopyObject(9, 40, 1, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Inapplicable because id %300 does not exist
+ ASSERT_FALSE(TransformationCopyObject(300, 40, 1, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // Inapplicable because the following instruction is OpVariable
+ ASSERT_FALSE(TransformationCopyObject(9, 180, 0, 200)
+ .IsApplicable(context.get(), fact_manager));
+ ASSERT_FALSE(TransformationCopyObject(9, 181, 0, 200)
+ .IsApplicable(context.get(), fact_manager));
+ ASSERT_FALSE(TransformationCopyObject(9, 182, 0, 200)
+ .IsApplicable(context.get(), fact_manager));
+
+ // OK, because this is just past the group of OpVariable instructions.
+ ASSERT_TRUE(TransformationCopyObject(9, 182, 1, 200)
+ .IsApplicable(context.get(), fact_manager));
+}
+
+TEST(TransformationCopyObjectTest, MiscellaneousCopies) {
+ // The following SPIR-V comes from this GLSL:
+ //
+ // #version 310 es
+ //
+ // precision highp float;
+ //
+ // float g;
+ //
+ // vec4 h;
+ //
+ // void main() {
+ // int a;
+ // int b;
+ // b = int(g);
+ // h.x = float(a);
+ // }
+
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ OpName %4 "main"
+ OpName %8 "b"
+ OpName %11 "g"
+ OpName %16 "h"
+ OpName %17 "a"
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 1
+ %7 = OpTypePointer Function %6
+ %9 = OpTypeFloat 32
+ %10 = OpTypePointer Private %9
+ %11 = OpVariable %10 Private
+ %14 = OpTypeVector %9 4
+ %15 = OpTypePointer Private %14
+ %16 = OpVariable %15 Private
+ %20 = OpTypeInt 32 0
+ %21 = OpConstant %20 0
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %8 = OpVariable %7 Function
+ %17 = OpVariable %7 Function
+ %12 = OpLoad %9 %11
+ %13 = OpConvertFToS %6 %12
+ OpStore %8 %13
+ %18 = OpLoad %6 %17
+ %19 = OpConvertSToF %9 %18
+ %22 = OpAccessChain %10 %16 %21
+ OpStore %22 %19
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_3;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ FactManager fact_manager;
+
+ std::vector<TransformationCopyObject> transformations = {
+ TransformationCopyObject(19, 22, 1, 100),
+ TransformationCopyObject(22, 22, 1, 101),
+ TransformationCopyObject(12, 22, 1, 102),
+ TransformationCopyObject(11, 22, 1, 103),
+ TransformationCopyObject(16, 22, 1, 104),
+ TransformationCopyObject(8, 22, 1, 105),
+ TransformationCopyObject(17, 22, 1, 106)};
+
+ for (auto& transformation : transformations) {
+ ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager));
+ transformation.Apply(context.get(), &fact_manager);
+ }
+
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ std::string after_transformation = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ OpName %4 "main"
+ OpName %8 "b"
+ OpName %11 "g"
+ OpName %16 "h"
+ OpName %17 "a"
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 1
+ %7 = OpTypePointer Function %6
+ %9 = OpTypeFloat 32
+ %10 = OpTypePointer Private %9
+ %11 = OpVariable %10 Private
+ %14 = OpTypeVector %9 4
+ %15 = OpTypePointer Private %14
+ %16 = OpVariable %15 Private
+ %20 = OpTypeInt 32 0
+ %21 = OpConstant %20 0
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %8 = OpVariable %7 Function
+ %17 = OpVariable %7 Function
+ %12 = OpLoad %9 %11
+ %13 = OpConvertFToS %6 %12
+ OpStore %8 %13
+ %18 = OpLoad %6 %17
+ %19 = OpConvertSToF %9 %18
+ %22 = OpAccessChain %10 %16 %21
+ %106 = OpCopyObject %7 %17
+ %105 = OpCopyObject %7 %8
+ %104 = OpCopyObject %15 %16
+ %103 = OpCopyObject %10 %11
+ %102 = OpCopyObject %9 %12
+ %101 = OpCopyObject %10 %22
+ %100 = OpCopyObject %9 %19
+ OpStore %22 %19
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
+} // namespace
+} // namespace fuzz
+} // namespace spvtools
diff --git a/test/link/type_match_test.cpp b/test/link/type_match_test.cpp
index c12ffb3..dae70c1 100644
--- a/test/link/type_match_test.cpp
+++ b/test/link/type_match_test.cpp
@@ -84,44 +84,46 @@
MatchF(T##Of##A##Of##B, \
MatchPart1(B, b) MatchPart2(A, a, b) MatchPart2(T, type, a))
+// clang-format off
// Basic types
-Match1(Int);
-Match1(Float);
-Match1(Opaque);
-Match1(Sampler);
-Match1(Event);
-Match1(DeviceEvent);
-Match1(ReserveId);
-Match1(Queue);
-Match1(Pipe);
-Match1(PipeStorage);
-Match1(NamedBarrier);
+Match1(Int)
+Match1(Float)
+Match1(Opaque)
+Match1(Sampler)
+Match1(Event)
+Match1(DeviceEvent)
+Match1(ReserveId)
+Match1(Queue)
+Match1(Pipe)
+Match1(PipeStorage)
+Match1(NamedBarrier)
// Simpler (restricted) compound types
-Match2(Vector, Float);
-Match3(Matrix, Vector, Float);
-Match2(Image, Float);
+Match2(Vector, Float)
+Match3(Matrix, Vector, Float)
+Match2(Image, Float)
// Unrestricted compound types
#define MatchCompounds1(A) \
- Match2(RuntimeArray, A); \
- Match2(Struct, A); \
- Match2(Pointer, A); \
- Match2(Function, A); \
- Match2(Array, A);
+ Match2(RuntimeArray, A) \
+ Match2(Struct, A) \
+ Match2(Pointer, A) \
+ Match2(Function, A) \
+ Match2(Array, A)
#define MatchCompounds2(A, B) \
- Match3(RuntimeArray, A, B); \
- Match3(Struct, A, B); \
- Match3(Pointer, A, B); \
- Match3(Function, A, B); \
- Match3(Array, A, B);
+ Match3(RuntimeArray, A, B) \
+ Match3(Struct, A, B) \
+ Match3(Pointer, A, B) \
+ Match3(Function, A, B) \
+ Match3(Array, A, B)
-MatchCompounds1(Float);
-MatchCompounds2(Array, Float);
-MatchCompounds2(RuntimeArray, Float);
-MatchCompounds2(Struct, Float);
-MatchCompounds2(Pointer, Float);
-MatchCompounds2(Function, Float);
+MatchCompounds1(Float)
+MatchCompounds2(Array, Float)
+MatchCompounds2(RuntimeArray, Float)
+MatchCompounds2(Struct, Float)
+MatchCompounds2(Pointer, Float)
+MatchCompounds2(Function, Float)
+// clang-format on
// ForwardPointer tests, which don't fit into the previous mold
#define MatchFpF(N, CODE) \
@@ -134,11 +136,13 @@
#define MatchFp2(T, A) \
MatchFpF(ForwardPointerOf##T, MatchPart1(A, a) MatchPart2(T, realtype, a))
-MatchFp1(Float);
-MatchFp2(Array, Float);
-MatchFp2(RuntimeArray, Float);
-MatchFp2(Struct, Float);
-MatchFp2(Function, Float);
+// clang-format off
+MatchFp1(Float)
+MatchFp2(Array, Float)
+MatchFp2(RuntimeArray, Float)
+MatchFp2(Struct, Float)
+MatchFp2(Function, Float)
+// clang-format on
} // namespace
} // namespace spvtools
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index 246c116..47ce41f 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -17,6 +17,7 @@
add_spvtools_unittest(TARGET opt
SRCS aggressive_dead_code_elim_test.cpp
+ amd_ext_to_khr.cpp
assembly_builder_test.cpp
block_merge_test.cpp
ccp_test.cpp
@@ -25,6 +26,7 @@
code_sink_test.cpp
combine_access_chains_test.cpp
compact_ids_test.cpp
+ constants_test.cpp
constant_manager_test.cpp
copy_prop_array_test.cpp
dead_branch_elim_test.cpp
@@ -33,6 +35,7 @@
decompose_initialized_variables_test.cpp
decoration_manager_test.cpp
def_use_test.cpp
+ desc_sroa_test.cpp
eliminate_dead_const_test.cpp
eliminate_dead_functions_test.cpp
eliminate_dead_member_test.cpp
@@ -44,11 +47,13 @@
freeze_spec_const_test.cpp
function_test.cpp
generate_webgpu_initializers_test.cpp
+ graphics_robust_access_test.cpp
if_conversion_test.cpp
inline_opaque_test.cpp
inline_test.cpp
insert_extract_elim_test.cpp
inst_bindless_check_test.cpp
+ inst_buff_addr_check_test.cpp
instruction_list_test.cpp
instruction_test.cpp
ir_builder.cpp
@@ -94,6 +99,7 @@
value_table_test.cpp
vector_dce_test.cpp
workaround1209_test.cpp
+ wrap_opkill_test.cpp
LIBS SPIRV-Tools-opt
PCH_FILE pch_test_opt
)
diff --git a/test/opt/aggressive_dead_code_elim_test.cpp b/test/opt/aggressive_dead_code_elim_test.cpp
index b4ab10d..3a7fc27 100644
--- a/test/opt/aggressive_dead_code_elim_test.cpp
+++ b/test/opt/aggressive_dead_code_elim_test.cpp
@@ -6616,6 +6616,152 @@
SinglePassRunAndCheck<AggressiveDCEPass>(spirv, spirv, true);
}
+TEST_F(AggressiveDCETest, NoEliminateForwardPointer) {
+ // clang-format off
+ //
+ // #version 450
+ // #extension GL_EXT_buffer_reference : enable
+ //
+ // // forward reference
+ // layout(buffer_reference) buffer blockType;
+ //
+ // layout(buffer_reference, std430, buffer_reference_align = 16) buffer blockType {
+ // int x;
+ // blockType next;
+ // };
+ //
+ // layout(std430) buffer rootBlock {
+ // blockType root;
+ // } r;
+ //
+ // void main()
+ // {
+ // blockType b = r.root;
+ // b = b.next;
+ // b.x = 531;
+ // }
+ //
+ // clang-format on
+
+ const std::string predefs1 =
+ R"(OpCapability Shader
+OpCapability PhysicalStorageBufferAddressesEXT
+OpExtension "SPV_EXT_physical_storage_buffer"
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel PhysicalStorageBuffer64EXT GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpSource GLSL 450
+OpSourceExtension "GL_EXT_buffer_reference"
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %blockType "blockType"
+OpMemberName %blockType 0 "x"
+OpMemberName %blockType 1 "next"
+OpName %b "b"
+OpName %rootBlock "rootBlock"
+OpMemberName %rootBlock 0 "root"
+OpName %r "r"
+OpMemberDecorate %blockType 0 Offset 0
+OpMemberDecorate %blockType 1 Offset 8
+OpDecorate %blockType Block
+OpDecorate %b AliasedPointerEXT
+OpMemberDecorate %rootBlock 0 Offset 0
+OpDecorate %rootBlock Block
+OpDecorate %r DescriptorSet 0
+OpDecorate %r Binding 0
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %blockType "blockType"
+OpMemberName %blockType 0 "x"
+OpMemberName %blockType 1 "next"
+OpName %rootBlock "rootBlock"
+OpMemberName %rootBlock 0 "root"
+OpName %r "r"
+OpMemberDecorate %blockType 0 Offset 0
+OpMemberDecorate %blockType 1 Offset 8
+OpDecorate %blockType Block
+OpMemberDecorate %rootBlock 0 Offset 0
+OpDecorate %rootBlock Block
+OpDecorate %r DescriptorSet 0
+OpDecorate %r Binding 0
+)";
+
+ const std::string predefs2_before =
+ R"(%void = OpTypeVoid
+%3 = OpTypeFunction %void
+OpTypeForwardPointer %_ptr_PhysicalStorageBufferEXT_blockType PhysicalStorageBufferEXT
+%int = OpTypeInt 32 1
+%blockType = OpTypeStruct %int %_ptr_PhysicalStorageBufferEXT_blockType
+%_ptr_PhysicalStorageBufferEXT_blockType = OpTypePointer PhysicalStorageBufferEXT %blockType
+%_ptr_Function__ptr_PhysicalStorageBufferEXT_blockType = OpTypePointer Function %_ptr_PhysicalStorageBufferEXT_blockType
+%rootBlock = OpTypeStruct %_ptr_PhysicalStorageBufferEXT_blockType
+%_ptr_StorageBuffer_rootBlock = OpTypePointer StorageBuffer %rootBlock
+%r = OpVariable %_ptr_StorageBuffer_rootBlock StorageBuffer
+%int_0 = OpConstant %int 0
+%_ptr_StorageBuffer__ptr_PhysicalStorageBufferEXT_blockType = OpTypePointer StorageBuffer %_ptr_PhysicalStorageBufferEXT_blockType
+%int_1 = OpConstant %int 1
+%_ptr_PhysicalStorageBufferEXT__ptr_PhysicalStorageBufferEXT_blockType = OpTypePointer PhysicalStorageBufferEXT %_ptr_PhysicalStorageBufferEXT_blockType
+%int_531 = OpConstant %int 531
+%_ptr_PhysicalStorageBufferEXT_int = OpTypePointer PhysicalStorageBufferEXT %int
+)";
+
+ const std::string predefs2_after =
+ R"(%void = OpTypeVoid
+%8 = OpTypeFunction %void
+OpTypeForwardPointer %_ptr_PhysicalStorageBufferEXT_blockType PhysicalStorageBufferEXT
+%int = OpTypeInt 32 1
+%blockType = OpTypeStruct %int %_ptr_PhysicalStorageBufferEXT_blockType
+%_ptr_PhysicalStorageBufferEXT_blockType = OpTypePointer PhysicalStorageBufferEXT %blockType
+%rootBlock = OpTypeStruct %_ptr_PhysicalStorageBufferEXT_blockType
+%_ptr_StorageBuffer_rootBlock = OpTypePointer StorageBuffer %rootBlock
+%r = OpVariable %_ptr_StorageBuffer_rootBlock StorageBuffer
+%int_0 = OpConstant %int 0
+%_ptr_StorageBuffer__ptr_PhysicalStorageBufferEXT_blockType = OpTypePointer StorageBuffer %_ptr_PhysicalStorageBufferEXT_blockType
+%int_1 = OpConstant %int 1
+%_ptr_PhysicalStorageBufferEXT__ptr_PhysicalStorageBufferEXT_blockType = OpTypePointer PhysicalStorageBufferEXT %_ptr_PhysicalStorageBufferEXT_blockType
+%int_531 = OpConstant %int 531
+%_ptr_PhysicalStorageBufferEXT_int = OpTypePointer PhysicalStorageBufferEXT %int
+)";
+
+ const std::string func_before =
+ R"(%main = OpFunction %void None %3
+%5 = OpLabel
+%b = OpVariable %_ptr_Function__ptr_PhysicalStorageBufferEXT_blockType Function
+%16 = OpAccessChain %_ptr_StorageBuffer__ptr_PhysicalStorageBufferEXT_blockType %r %int_0
+%17 = OpLoad %_ptr_PhysicalStorageBufferEXT_blockType %16
+%21 = OpAccessChain %_ptr_PhysicalStorageBufferEXT__ptr_PhysicalStorageBufferEXT_blockType %17 %int_1
+%22 = OpLoad %_ptr_PhysicalStorageBufferEXT_blockType %21 Aligned 8
+OpStore %b %22
+%26 = OpAccessChain %_ptr_PhysicalStorageBufferEXT_int %22 %int_0
+OpStore %26 %int_531 Aligned 16
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string func_after =
+ R"(%main = OpFunction %void None %8
+%19 = OpLabel
+%20 = OpAccessChain %_ptr_StorageBuffer__ptr_PhysicalStorageBufferEXT_blockType %r %int_0
+%21 = OpLoad %_ptr_PhysicalStorageBufferEXT_blockType %20
+%22 = OpAccessChain %_ptr_PhysicalStorageBufferEXT__ptr_PhysicalStorageBufferEXT_blockType %21 %int_1
+%23 = OpLoad %_ptr_PhysicalStorageBufferEXT_blockType %22 Aligned 8
+%24 = OpAccessChain %_ptr_PhysicalStorageBufferEXT_int %23 %int_0
+OpStore %24 %int_531 Aligned 16
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<AggressiveDCEPass>(
+ predefs1 + names_before + predefs2_before + func_before,
+ predefs1 + names_after + predefs2_after + func_after, true, true);
+}
+
// TODO(greg-lunarg): Add tests to verify handling of these cases:
//
// Check that logical addressing required
diff --git a/test/opt/amd_ext_to_khr.cpp b/test/opt/amd_ext_to_khr.cpp
new file mode 100644
index 0000000..7a6d4b4
--- /dev/null
+++ b/test/opt/amd_ext_to_khr.cpp
@@ -0,0 +1,338 @@
+// Copyright (c) 2019 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "gmock/gmock.h"
+
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using AmdExtToKhrTest = PassTest<::testing::Test>;
+
+using ::testing::HasSubstr;
+
+std::string GetTest(std::string op_code, std::string new_op_code) {
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_ballot"
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[undef:%\w+]] = OpUndef %uint
+; CHECK-NEXT: )" + new_op_code +
+ R"( %uint %uint_3 Reduce [[undef]]
+ OpCapability Shader
+ OpCapability Groups
+ OpExtension "SPV_AMD_shader_ballot"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %uint_3 = OpConstant %uint 3
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %7 = OpUndef %uint
+ %8 = )" + op_code +
+ R"( %uint %uint_3 Reduce %7
+ OpReturn
+ OpFunctionEnd
+
+)";
+ return text;
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceGroupIAddNonUniformAMD) {
+ std::string text =
+ GetTest("OpGroupIAddNonUniformAMD", "OpGroupNonUniformIAdd");
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+TEST_F(AmdExtToKhrTest, ReplaceGroupFAddNonUniformAMD) {
+ std::string text =
+ GetTest("OpGroupFAddNonUniformAMD", "OpGroupNonUniformFAdd");
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+TEST_F(AmdExtToKhrTest, ReplaceGroupUMinNonUniformAMD) {
+ std::string text =
+ GetTest("OpGroupUMinNonUniformAMD", "OpGroupNonUniformUMin");
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+TEST_F(AmdExtToKhrTest, ReplaceGroupSMinNonUniformAMD) {
+ std::string text =
+ GetTest("OpGroupSMinNonUniformAMD", "OpGroupNonUniformSMin");
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+TEST_F(AmdExtToKhrTest, ReplaceGroupFMinNonUniformAMD) {
+ std::string text =
+ GetTest("OpGroupFMinNonUniformAMD", "OpGroupNonUniformFMin");
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+TEST_F(AmdExtToKhrTest, ReplaceGroupUMaxNonUniformAMD) {
+ std::string text =
+ GetTest("OpGroupUMaxNonUniformAMD", "OpGroupNonUniformUMax");
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+TEST_F(AmdExtToKhrTest, ReplaceGroupSMaxNonUniformAMD) {
+ std::string text =
+ GetTest("OpGroupSMaxNonUniformAMD", "OpGroupNonUniformSMax");
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+TEST_F(AmdExtToKhrTest, ReplaceGroupFMaxNonUniformAMD) {
+ std::string text =
+ GetTest("OpGroupFMaxNonUniformAMD", "OpGroupNonUniformFMax");
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceMbcntAMD) {
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_ballot"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_ballot"
+; CHECK: OpDecorate [[var:%\w+]] BuiltIn SubgroupLtMask
+; CHECK: [[var]] = OpVariable %_ptr_Input_v4uint Input
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[ld:%\w+]] = OpLoad %v4uint [[var]]
+; CHECK-NEXT: [[shuffle:%\w+]] = OpVectorShuffle %v2uint [[ld]] [[ld]] 0 1
+; CHECK-NEXT: [[bitcast:%\w+]] = OpBitcast %ulong [[shuffle]]
+; CHECK-NEXT: [[and:%\w+]] = OpBitwiseAnd %ulong [[bitcast]] %ulong_0
+; CHECK-NEXT: [[result:%\w+]] = OpBitCount %uint [[and]]
+ OpCapability Shader
+ OpCapability Int64
+ OpExtension "SPV_AMD_shader_ballot"
+ %1 = OpExtInstImport "SPV_AMD_shader_ballot"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "func"
+ OpExecutionMode %2 OriginUpperLeft
+ %void = OpTypeVoid
+ %4 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %ulong = OpTypeInt 64 0
+ %ulong_0 = OpConstant %ulong 0
+ %2 = OpFunction %void None %4
+ %8 = OpLabel
+ %9 = OpExtInst %uint %1 MbcntAMD %ulong_0
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceSwizzleInvocationsAMD) {
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_ballot"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_ballot"
+; CHECK: OpDecorate [[var:%\w+]] BuiltIn SubgroupLocalInvocationId
+; CHECK: [[subgroup:%\w+]] = OpConstant %uint 3
+; CHECK: [[offset:%\w+]] = OpConstantComposite %v4uint
+; CHECK: [[var]] = OpVariable %_ptr_Input_uint Input
+; CHECK: [[uint_max:%\w+]] = OpConstant %uint 4294967295
+; CHECK: [[ballot_value:%\w+]] = OpConstantComposite %v4uint [[uint_max]] [[uint_max]] [[uint_max]] [[uint_max]]
+; CHECK: [[null:%\w+]] = OpConstantNull [[type:%\w+]]
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[data:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[id:%\w+]] = OpLoad %uint [[var]]
+; CHECK-NEXT: [[quad_idx:%\w+]] = OpBitwiseAnd %uint [[id]] %uint_3
+; CHECK-NEXT: [[quad_ldr:%\w+]] = OpBitwiseXor %uint [[id]] [[quad_idx]]
+; CHECK-NEXT: [[my_offset:%\w+]] = OpVectorExtractDynamic %uint [[offset]] [[quad_idx]]
+; CHECK-NEXT: [[target_inv:%\w+]] = OpIAdd %uint [[quad_ldr]] [[my_offset]]
+; CHECK-NEXT: [[is_active:%\w+]] = OpGroupNonUniformBallotBitExtract %bool [[subgroup]] [[ballot_value]] [[target_inv]]
+; CHECK-NEXT: [[shuffle:%\w+]] = OpGroupNonUniformShuffle [[type]] [[subgroup]] [[data]] [[target_inv]]
+; CHECK-NEXT: [[result:%\w+]] = OpSelect [[type]] [[is_active]] [[shuffle]] [[null]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_ballot"
+ %ext = OpExtInstImport "SPV_AMD_shader_ballot"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %uint_x = OpConstant %uint 1
+ %uint_y = OpConstant %uint 2
+ %uint_z = OpConstant %uint 3
+ %uint_w = OpConstant %uint 0
+ %v4uint = OpTypeVector %uint 4
+ %offset = OpConstantComposite %v4uint %uint_x %uint_y %uint_z %uint_x
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %data = OpUndef %uint
+ %9 = OpExtInst %uint %ext SwizzleInvocationsAMD %data %offset
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+TEST_F(AmdExtToKhrTest, ReplaceSwizzleInvocationsMaskedAMD) {
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_ballot"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_ballot"
+; CHECK: OpDecorate [[var:%\w+]] BuiltIn SubgroupLocalInvocationId
+; CHECK: [[x:%\w+]] = OpConstant %uint 19
+; CHECK: [[y:%\w+]] = OpConstant %uint 12
+; CHECK: [[z:%\w+]] = OpConstant %uint 16
+; CHECK: [[var]] = OpVariable %_ptr_Input_uint Input
+; CHECK: [[mask_extend:%\w+]] = OpConstant %uint 4294967264
+; CHECK: [[uint_max:%\w+]] = OpConstant %uint 4294967295
+; CHECK: [[subgroup:%\w+]] = OpConstant %uint 3
+; CHECK: [[ballot_value:%\w+]] = OpConstantComposite %v4uint [[uint_max]] [[uint_max]] [[uint_max]] [[uint_max]]
+; CHECK: [[null:%\w+]] = OpConstantNull [[type:%\w+]]
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[data:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[id:%\w+]] = OpLoad %uint [[var]]
+; CHECK-NEXT: [[and_mask:%\w+]] = OpBitwiseOr %uint [[x]] [[mask_extend]]
+; CHECK-NEXT: [[and:%\w+]] = OpBitwiseAnd %uint [[id]] [[and_mask]]
+; CHECK-NEXT: [[or:%\w+]] = OpBitwiseOr %uint [[and]] [[y]]
+; CHECK-NEXT: [[target_inv:%\w+]] = OpBitwiseXor %uint [[or]] [[z]]
+; CHECK-NEXT: [[is_active:%\w+]] = OpGroupNonUniformBallotBitExtract %bool [[subgroup]] [[ballot_value]] [[target_inv]]
+; CHECK-NEXT: [[shuffle:%\w+]] = OpGroupNonUniformShuffle [[type]] [[subgroup]] [[data]] [[target_inv]]
+; CHECK-NEXT: [[result:%\w+]] = OpSelect [[type]] [[is_active]] [[shuffle]] [[null]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_ballot"
+ %ext = OpExtInstImport "SPV_AMD_shader_ballot"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %uint_x = OpConstant %uint 19
+ %uint_y = OpConstant %uint 12
+ %uint_z = OpConstant %uint 16
+ %v3uint = OpTypeVector %uint 3
+ %mask = OpConstantComposite %v3uint %uint_x %uint_y %uint_z
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %data = OpUndef %uint
+ %9 = OpExtInst %uint %ext SwizzleInvocationsMaskedAMD %data %mask
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+TEST_F(AmdExtToKhrTest, ReplaceWriteInvocationAMD) {
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_ballot"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_ballot"
+; CHECK: OpDecorate [[var:%\w+]] BuiltIn SubgroupLocalInvocationId
+; CHECK: [[var]] = OpVariable %_ptr_Input_uint Input
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[input_val:%\w+]] = OpUndef %uint
+; CHECK-NEXT: [[write_val:%\w+]] = OpUndef %uint
+; CHECK-NEXT: [[ld:%\w+]] = OpLoad %uint [[var]]
+; CHECK-NEXT: [[cmp:%\w+]] = OpIEqual %bool [[ld]] %uint_3
+; CHECK-NEXT: [[result:%\w+]] = OpSelect %uint [[cmp]] [[write_val]] [[input_val]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_ballot"
+ %ext = OpExtInstImport "SPV_AMD_shader_ballot"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %uint_3 = OpConstant %uint 3
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %7 = OpUndef %uint
+ %8 = OpUndef %uint
+ %9 = OpExtInst %uint %ext WriteInvocationAMD %7 %8 %uint_3
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, SetVersion) {
+ const std::string text = R"(
+ OpCapability Shader
+ OpCapability Int64
+ OpExtension "SPV_AMD_shader_ballot"
+ %1 = OpExtInstImport "SPV_AMD_shader_ballot"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "func"
+ OpExecutionMode %2 OriginUpperLeft
+ %void = OpTypeVoid
+ %4 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %ulong = OpTypeInt 64 0
+ %ulong_0 = OpConstant %ulong 0
+ %2 = OpFunction %void None %4
+ %8 = OpLabel
+ %9 = OpExtInst %uint %1 MbcntAMD %ulong_0
+ OpReturn
+ OpFunctionEnd
+)";
+
+ // Set the version to 1.1 and make sure it is upgraded to 1.3.
+ SetTargetEnv(SPV_ENV_UNIVERSAL_1_1);
+ SetDisassembleOptions(0);
+ auto result = SinglePassRunAndDisassemble<AmdExtensionToKhrPass>(
+ text, /* skip_nop = */ true, /* skip_validation = */ false);
+
+ EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
+ const std::string& output = std::get<0>(result);
+ EXPECT_THAT(output, HasSubstr("Version: 1.3"));
+}
+
+TEST_F(AmdExtToKhrTest, SetVersion1) {
+ const std::string text = R"(
+ OpCapability Shader
+ OpCapability Int64
+ OpExtension "SPV_AMD_shader_ballot"
+ %1 = OpExtInstImport "SPV_AMD_shader_ballot"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "func"
+ OpExecutionMode %2 OriginUpperLeft
+ %void = OpTypeVoid
+ %4 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %ulong = OpTypeInt 64 0
+ %ulong_0 = OpConstant %ulong 0
+ %2 = OpFunction %void None %4
+ %8 = OpLabel
+ %9 = OpExtInst %uint %1 MbcntAMD %ulong_0
+ OpReturn
+ OpFunctionEnd
+)";
+
+  // Set the version to 1.4 and make sure it stays the same.
+ SetTargetEnv(SPV_ENV_UNIVERSAL_1_4);
+ SetDisassembleOptions(0);
+ auto result = SinglePassRunAndDisassemble<AmdExtensionToKhrPass>(
+ text, /* skip_nop = */ true, /* skip_validation = */ false);
+
+ EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
+ const std::string& output = std::get<0>(result);
+ EXPECT_THAT(output, HasSubstr("Version: 1.4"));
+}
+
+} // namespace
+} // namespace opt
+} // namespace spvtools
diff --git a/test/opt/constants_test.cpp b/test/opt/constants_test.cpp
new file mode 100644
index 0000000..55c92a5
--- /dev/null
+++ b/test/opt/constants_test.cpp
@@ -0,0 +1,167 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/constants.h"
+
+#include <gtest/gtest-param-test.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "source/opt/types.h"
+
+namespace spvtools {
+namespace opt {
+namespace analysis {
+namespace {
+
+using ConstantTest = ::testing::Test;
+using ::testing::ValuesIn;
+
+template <typename T>
+struct GetExtendedValueCase {
+ bool is_signed;
+ int width;
+ std::vector<uint32_t> words;
+ T expected_value;
+};
+
+using GetSignExtendedValueCase = GetExtendedValueCase<int64_t>;
+using GetZeroExtendedValueCase = GetExtendedValueCase<uint64_t>;
+
+using GetSignExtendedValueTest =
+ ::testing::TestWithParam<GetSignExtendedValueCase>;
+using GetZeroExtendedValueTest =
+ ::testing::TestWithParam<GetZeroExtendedValueCase>;
+
+TEST_P(GetSignExtendedValueTest, Case) {
+ Integer type(GetParam().width, GetParam().is_signed);
+ IntConstant value(&type, GetParam().words);
+
+ EXPECT_EQ(GetParam().expected_value, value.GetSignExtendedValue());
+}
+
+TEST_P(GetZeroExtendedValueTest, Case) {
+ Integer type(GetParam().width, GetParam().is_signed);
+ IntConstant value(&type, GetParam().words);
+
+ EXPECT_EQ(GetParam().expected_value, value.GetZeroExtendedValue());
+}
+
+const uint32_t k32ones = ~uint32_t(0);
+const uint64_t k64ones = ~uint64_t(0);
+const int64_t kSBillion = 1000 * 1000 * 1000;
+const uint64_t kUBillion = 1000 * 1000 * 1000;
+
+INSTANTIATE_TEST_SUITE_P(AtMost32Bits, GetSignExtendedValueTest,
+ ValuesIn(std::vector<GetSignExtendedValueCase>{
+ // 4 bits
+ {false, 4, {0}, 0},
+ {false, 4, {7}, 7},
+ {false, 4, {15}, 15},
+ {true, 4, {0}, 0},
+ {true, 4, {7}, 7},
+ {true, 4, {0xfffffff8}, -8},
+ {true, 4, {k32ones}, -1},
+ // 16 bits
+ {false, 16, {0}, 0},
+ {false, 16, {32767}, 32767},
+ {false, 16, {32768}, 32768},
+ {false, 16, {65000}, 65000},
+ {true, 16, {0}, 0},
+ {true, 16, {32767}, 32767},
+ {true, 16, {0xfffffff8}, -8},
+ {true, 16, {k32ones}, -1},
+ // 32 bits
+ {false, 32, {0}, 0},
+ {false, 32, {1000000}, 1000000},
+ {true, 32, {0xfffffff8}, -8},
+ {true, 32, {k32ones}, -1},
+ }));
+
+INSTANTIATE_TEST_SUITE_P(AtMost64Bits, GetSignExtendedValueTest,
+ ValuesIn(std::vector<GetSignExtendedValueCase>{
+ // 48 bits
+ {false, 48, {0, 0}, 0},
+ {false, 48, {5, 0}, 5},
+ {false, 48, {0xfffffff8, k32ones}, -8},
+ {false, 48, {k32ones, k32ones}, -1},
+ {false, 48, {0xdcd65000, 1}, 8 * kSBillion},
+ {true, 48, {0xfffffff8, k32ones}, -8},
+ {true, 48, {k32ones, k32ones}, -1},
+ {true, 48, {0xdcd65000, 1}, 8 * kSBillion},
+
+ // 64 bits
+ {false, 64, {12, 0}, 12},
+ {false, 64, {0xdcd65000, 1}, 8 * kSBillion},
+ {false, 48, {0xfffffff8, k32ones}, -8},
+ {false, 64, {k32ones, k32ones}, -1},
+ {true, 64, {12, 0}, 12},
+ {true, 64, {0xdcd65000, 1}, 8 * kSBillion},
+ {true, 48, {0xfffffff8, k32ones}, -8},
+ {true, 64, {k32ones, k32ones}, -1},
+ }));
+
+INSTANTIATE_TEST_SUITE_P(AtMost32Bits, GetZeroExtendedValueTest,
+ ValuesIn(std::vector<GetZeroExtendedValueCase>{
+ // 4 bits
+ {false, 4, {0}, 0},
+ {false, 4, {7}, 7},
+ {false, 4, {15}, 15},
+ {true, 4, {0}, 0},
+ {true, 4, {7}, 7},
+ {true, 4, {0xfffffff8}, 0xfffffff8},
+ {true, 4, {k32ones}, k32ones},
+ // 16 bits
+ {false, 16, {0}, 0},
+ {false, 16, {32767}, 32767},
+ {false, 16, {32768}, 32768},
+ {false, 16, {65000}, 65000},
+ {true, 16, {0}, 0},
+ {true, 16, {32767}, 32767},
+ {true, 16, {0xfffffff8}, 0xfffffff8},
+ {true, 16, {k32ones}, k32ones},
+ // 32 bits
+ {false, 32, {0}, 0},
+ {false, 32, {1000000}, 1000000},
+ {true, 32, {0xfffffff8}, 0xfffffff8},
+ {true, 32, {k32ones}, k32ones},
+ }));
+
+INSTANTIATE_TEST_SUITE_P(AtMost64Bits, GetZeroExtendedValueTest,
+ ValuesIn(std::vector<GetZeroExtendedValueCase>{
+ // 48 bits
+ {false, 48, {0, 0}, 0},
+ {false, 48, {5, 0}, 5},
+ {false, 48, {0xfffffff8, k32ones}, uint64_t(-8)},
+ {false, 48, {k32ones, k32ones}, uint64_t(-1)},
+ {false, 48, {0xdcd65000, 1}, 8 * kUBillion},
+ {true, 48, {0xfffffff8, k32ones}, uint64_t(-8)},
+ {true, 48, {k32ones, k32ones}, uint64_t(-1)},
+ {true, 48, {0xdcd65000, 1}, 8 * kUBillion},
+
+ // 64 bits
+ {false, 64, {12, 0}, 12},
+ {false, 64, {0xdcd65000, 1}, 8 * kUBillion},
+ {false, 48, {0xfffffff8, k32ones}, uint64_t(-8)},
+ {false, 64, {k32ones, k32ones}, k64ones},
+ {true, 64, {12, 0}, 12},
+ {true, 64, {0xdcd65000, 1}, 8 * kUBillion},
+ {true, 48, {0xfffffff8, k32ones}, uint64_t(-8)},
+ {true, 64, {k32ones, k32ones}, k64ones},
+ }));
+
+} // namespace
+} // namespace analysis
+} // namespace opt
+} // namespace spvtools
diff --git a/test/opt/decoration_manager_test.cpp b/test/opt/decoration_manager_test.cpp
index 3eb3ef5..fcfbff0 100644
--- a/test/opt/decoration_manager_test.cpp
+++ b/test/opt/decoration_manager_test.cpp
@@ -22,6 +22,7 @@
#include "source/opt/decoration_manager.h"
#include "source/opt/ir_context.h"
#include "source/spirv_constant.h"
+#include "source/util/string_utils.h"
#include "test/unit_spirv.h"
namespace spvtools {
@@ -29,7 +30,7 @@
namespace analysis {
namespace {
-using spvtest::MakeVector;
+using utils::MakeVector;
class DecorationManagerTest : public ::testing::Test {
public:
diff --git a/test/opt/desc_sroa_test.cpp b/test/opt/desc_sroa_test.cpp
new file mode 100644
index 0000000..11074c3
--- /dev/null
+++ b/test/opt/desc_sroa_test.cpp
@@ -0,0 +1,270 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "test/opt/assembly_builder.h"
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using DescriptorScalarReplacementTest = PassTest<::testing::Test>;
+
+TEST_F(DescriptorScalarReplacementTest, ExpandTexture) {
+ const std::string text = R"(
+; CHECK: OpDecorate [[var1:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var1]] Binding 0
+; CHECK: OpDecorate [[var2:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var2]] Binding 1
+; CHECK: OpDecorate [[var3:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var3]] Binding 2
+; CHECK: OpDecorate [[var4:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var4]] Binding 3
+; CHECK: OpDecorate [[var5:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var5]] Binding 4
+; CHECK: [[image_type:%\w+]] = OpTypeImage
+; CHECK: [[ptr_type:%\w+]] = OpTypePointer UniformConstant [[image_type]]
+; CHECK: [[var1]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: [[var2]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: [[var3]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: [[var4]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: [[var5]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: OpLoad [[image_type]] [[var1]]
+; CHECK: OpLoad [[image_type]] [[var2]]
+; CHECK: OpLoad [[image_type]] [[var3]]
+; CHECK: OpLoad [[image_type]] [[var4]]
+; CHECK: OpLoad [[image_type]] [[var5]]
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource HLSL 600
+ OpDecorate %MyTextures DescriptorSet 0
+ OpDecorate %MyTextures Binding 0
+ %int = OpTypeInt 32 1
+ %int_0 = OpConstant %int 0
+ %int_1 = OpConstant %int 1
+ %int_2 = OpConstant %int 2
+ %int_3 = OpConstant %int 3
+ %int_4 = OpConstant %int 4
+ %uint = OpTypeInt 32 0
+ %uint_5 = OpConstant %uint 5
+ %float = OpTypeFloat 32
+%type_2d_image = OpTypeImage %float 2D 2 0 0 1 Unknown
+%_arr_type_2d_image_uint_5 = OpTypeArray %type_2d_image %uint_5
+%_ptr_UniformConstant__arr_type_2d_image_uint_5 = OpTypePointer UniformConstant %_arr_type_2d_image_uint_5
+ %v2float = OpTypeVector %float 2
+ %void = OpTypeVoid
+ %26 = OpTypeFunction %void
+%_ptr_UniformConstant_type_2d_image = OpTypePointer UniformConstant %type_2d_image
+ %MyTextures = OpVariable %_ptr_UniformConstant__arr_type_2d_image_uint_5 UniformConstant
+ %main = OpFunction %void None %26
+ %28 = OpLabel
+ %29 = OpUndef %v2float
+ %30 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_0
+ %31 = OpLoad %type_2d_image %30
+ %35 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_1
+ %36 = OpLoad %type_2d_image %35
+ %40 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_2
+ %41 = OpLoad %type_2d_image %40
+ %45 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_3
+ %46 = OpLoad %type_2d_image %45
+ %50 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_4
+ %51 = OpLoad %type_2d_image %50
+ OpReturn
+ OpFunctionEnd
+
+ )";
+
+ SinglePassRunAndMatch<DescriptorScalarReplacement>(text, true);
+}
+
+TEST_F(DescriptorScalarReplacementTest, ExpandSampler) {
+ const std::string text = R"(
+; CHECK: OpDecorate [[var1:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var1]] Binding 1
+; CHECK: OpDecorate [[var2:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var2]] Binding 2
+; CHECK: OpDecorate [[var3:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var3]] Binding 3
+; CHECK: [[sampler_type:%\w+]] = OpTypeSampler
+; CHECK: [[ptr_type:%\w+]] = OpTypePointer UniformConstant [[sampler_type]]
+; CHECK: [[var1]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: [[var2]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: [[var3]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: OpLoad [[sampler_type]] [[var1]]
+; CHECK: OpLoad [[sampler_type]] [[var2]]
+; CHECK: OpLoad [[sampler_type]] [[var3]]
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource HLSL 600
+ OpDecorate %MySampler DescriptorSet 0
+ OpDecorate %MySampler Binding 1
+ %int = OpTypeInt 32 1
+ %int_0 = OpConstant %int 0
+ %int_1 = OpConstant %int 1
+ %int_2 = OpConstant %int 2
+ %uint = OpTypeInt 32 0
+ %uint_3 = OpConstant %uint 3
+%type_sampler = OpTypeSampler
+%_arr_type_sampler_uint_3 = OpTypeArray %type_sampler %uint_3
+%_ptr_UniformConstant__arr_type_sampler_uint_3 = OpTypePointer UniformConstant %_arr_type_sampler_uint_3
+ %void = OpTypeVoid
+ %26 = OpTypeFunction %void
+%_ptr_UniformConstant_type_sampler = OpTypePointer UniformConstant %type_sampler
+ %MySampler = OpVariable %_ptr_UniformConstant__arr_type_sampler_uint_3 UniformConstant
+ %main = OpFunction %void None %26
+ %28 = OpLabel
+ %31 = OpAccessChain %_ptr_UniformConstant_type_sampler %MySampler %int_0
+ %32 = OpLoad %type_sampler %31
+ %35 = OpAccessChain %_ptr_UniformConstant_type_sampler %MySampler %int_1
+ %36 = OpLoad %type_sampler %35
+ %40 = OpAccessChain %_ptr_UniformConstant_type_sampler %MySampler %int_2
+ %41 = OpLoad %type_sampler %40
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<DescriptorScalarReplacement>(text, true);
+}
+
+TEST_F(DescriptorScalarReplacementTest, ExpandSSBO) {
+  // Tests the expansion of an SSBO. Also checks that an access chain with more
+ // than 1 index is correctly handled.
+ const std::string text = R"(
+; CHECK: OpDecorate [[var1:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var1]] Binding 0
+; CHECK: OpDecorate [[var2:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var2]] Binding 1
+; CHECK: OpTypeStruct
+; CHECK: [[struct_type:%\w+]] = OpTypeStruct
+; CHECK: [[ptr_type:%\w+]] = OpTypePointer Uniform [[struct_type]]
+; CHECK: [[var1]] = OpVariable [[ptr_type]] Uniform
+; CHECK: [[var2]] = OpVariable [[ptr_type]] Uniform
+; CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[var1]] %uint_0 %uint_0 %uint_0
+; CHECK: OpLoad %v4float [[ac1]]
+; CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[var2]] %uint_0 %uint_0 %uint_0
+; CHECK: OpLoad %v4float [[ac2]]
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource HLSL 600
+ OpDecorate %buffers DescriptorSet 0
+ OpDecorate %buffers Binding 0
+ OpMemberDecorate %S 0 Offset 0
+ OpDecorate %_runtimearr_S ArrayStride 16
+ OpMemberDecorate %type_StructuredBuffer_S 0 Offset 0
+ OpMemberDecorate %type_StructuredBuffer_S 0 NonWritable
+ OpDecorate %type_StructuredBuffer_S BufferBlock
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+ %uint_2 = OpConstant %uint 2
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+ %S = OpTypeStruct %v4float
+%_runtimearr_S = OpTypeRuntimeArray %S
+%type_StructuredBuffer_S = OpTypeStruct %_runtimearr_S
+%_arr_type_StructuredBuffer_S_uint_2 = OpTypeArray %type_StructuredBuffer_S %uint_2
+%_ptr_Uniform__arr_type_StructuredBuffer_S_uint_2 = OpTypePointer Uniform %_arr_type_StructuredBuffer_S_uint_2
+%_ptr_Uniform_type_StructuredBuffer_S = OpTypePointer Uniform %type_StructuredBuffer_S
+ %void = OpTypeVoid
+ %19 = OpTypeFunction %void
+%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float
+ %buffers = OpVariable %_ptr_Uniform__arr_type_StructuredBuffer_S_uint_2 Uniform
+ %main = OpFunction %void None %19
+ %21 = OpLabel
+ %22 = OpAccessChain %_ptr_Uniform_v4float %buffers %uint_0 %uint_0 %uint_0 %uint_0
+ %23 = OpLoad %v4float %22
+ %24 = OpAccessChain %_ptr_Uniform_type_StructuredBuffer_S %buffers %uint_1
+ %25 = OpAccessChain %_ptr_Uniform_v4float %24 %uint_0 %uint_0 %uint_0
+ %26 = OpLoad %v4float %25
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<DescriptorScalarReplacement>(text, true);
+}
+
+TEST_F(DescriptorScalarReplacementTest, NameNewVariables) {
+ // Checks that if the original variable has a name, then the new variables
+ // will have a name derived from that name.
+ const std::string text = R"(
+; CHECK: OpName [[var1:%\w+]] "SSBO[0]"
+; CHECK: OpName [[var2:%\w+]] "SSBO[1]"
+; CHECK: OpDecorate [[var1]] DescriptorSet 0
+; CHECK: OpDecorate [[var1]] Binding 0
+; CHECK: OpDecorate [[var2]] DescriptorSet 0
+; CHECK: OpDecorate [[var2]] Binding 1
+; CHECK: OpTypeStruct
+; CHECK: [[struct_type:%\w+]] = OpTypeStruct
+; CHECK: [[ptr_type:%\w+]] = OpTypePointer Uniform [[struct_type]]
+; CHECK: [[var1]] = OpVariable [[ptr_type]] Uniform
+; CHECK: [[var2]] = OpVariable [[ptr_type]] Uniform
+; CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[var1]] %uint_0 %uint_0 %uint_0
+; CHECK: OpLoad %v4float [[ac1]]
+; CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[var2]] %uint_0 %uint_0 %uint_0
+; CHECK: OpLoad %v4float [[ac2]]
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource HLSL 600
+ OpName %buffers "SSBO"
+ OpDecorate %buffers DescriptorSet 0
+ OpDecorate %buffers Binding 0
+ OpMemberDecorate %S 0 Offset 0
+ OpDecorate %_runtimearr_S ArrayStride 16
+ OpMemberDecorate %type_StructuredBuffer_S 0 Offset 0
+ OpMemberDecorate %type_StructuredBuffer_S 0 NonWritable
+ OpDecorate %type_StructuredBuffer_S BufferBlock
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+ %uint_2 = OpConstant %uint 2
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+ %S = OpTypeStruct %v4float
+%_runtimearr_S = OpTypeRuntimeArray %S
+%type_StructuredBuffer_S = OpTypeStruct %_runtimearr_S
+%_arr_type_StructuredBuffer_S_uint_2 = OpTypeArray %type_StructuredBuffer_S %uint_2
+%_ptr_Uniform__arr_type_StructuredBuffer_S_uint_2 = OpTypePointer Uniform %_arr_type_StructuredBuffer_S_uint_2
+%_ptr_Uniform_type_StructuredBuffer_S = OpTypePointer Uniform %type_StructuredBuffer_S
+ %void = OpTypeVoid
+ %19 = OpTypeFunction %void
+%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float
+ %buffers = OpVariable %_ptr_Uniform__arr_type_StructuredBuffer_S_uint_2 Uniform
+ %main = OpFunction %void None %19
+ %21 = OpLabel
+ %22 = OpAccessChain %_ptr_Uniform_v4float %buffers %uint_0 %uint_0 %uint_0 %uint_0
+ %23 = OpLoad %v4float %22
+ %24 = OpAccessChain %_ptr_Uniform_type_StructuredBuffer_S %buffers %uint_1
+ %25 = OpAccessChain %_ptr_Uniform_v4float %24 %uint_0 %uint_0 %uint_0
+ %26 = OpLoad %v4float %25
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<DescriptorScalarReplacement>(text, true);
+}
+} // namespace
+} // namespace opt
+} // namespace spvtools
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index 3ea3204..b5998c7 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -11,6 +11,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "source/opt/fold.h"
+
#include <limits>
#include <memory>
#include <string>
@@ -22,7 +24,6 @@
#include "gtest/gtest.h"
#include "source/opt/build_module.h"
#include "source/opt/def_use_manager.h"
-#include "source/opt/fold.h"
#include "source/opt/ir_context.h"
#include "source/opt/module.h"
#include "spirv-tools/libspirv.hpp"
@@ -209,6 +210,7 @@
%float_2049 = OpConstant %float 2049
%float_n2049 = OpConstant %float -2049
%float_0p5 = OpConstant %float 0.5
+%float_0p2 = OpConstant %float 0.2
%float_pi = OpConstant %float 1.5555
%float_1e16 = OpConstant %float 1e16
%float_n1e16 = OpConstant %float -1e16
@@ -1464,24 +1466,14 @@
"OpReturn\n" +
"OpFunctionEnd",
2, std::numeric_limits<float>::quiet_NaN()),
- // Test case 20: QuantizeToF16 inf
+ // Test case 20: FMix 1.0 4.0 0.2
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
- "%2 = OpFDiv %float %float_1 %float_0\n" +
- "%3 = OpQuantizeToF16 %float %3\n" +
+ "%2 = OpExtInst %float %1 FMix %float_1 %float_4 %float_0p2\n" +
"OpReturn\n" +
"OpFunctionEnd",
- 2, std::numeric_limits<float>::infinity()),
- // Test case 21: QuantizeToF16 -inf
- InstructionFoldingCase<float>(
- Header() + "%main = OpFunction %void None %void_func\n" +
- "%main_lab = OpLabel\n" +
- "%2 = OpFDiv %float %float_n1 %float_0\n" +
- "%3 = OpQuantizeToF16 %float %3\n" +
- "OpReturn\n" +
- "OpFunctionEnd",
- 2, -std::numeric_limits<float>::infinity())
+ 2, 1.6f)
));
// clang-format on
@@ -2980,7 +2972,17 @@
"%4 = OpCompositeExtract %int %3 0\n" +
"OpReturn\n" +
"OpFunctionEnd",
- 4, INT_7_ID)
+ 4, INT_7_ID),
+ // Test case 13: https://github.com/KhronosGroup/SPIRV-Tools/issues/2608
+ // Out of bounds access. Do not fold.
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1\n" +
+ "%3 = OpCompositeExtract %float %2 4\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 3, 0)
));
INSTANTIATE_TEST_SUITE_P(CompositeConstructFoldingTest, GeneralInstructionFoldingTest,
diff --git a/test/opt/graphics_robust_access_test.cpp b/test/opt/graphics_robust_access_test.cpp
new file mode 100644
index 0000000..137d0e8
--- /dev/null
+++ b/test/opt/graphics_robust_access_test.cpp
@@ -0,0 +1,1307 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <array>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "pass_fixture.h"
+#include "pass_utils.h"
+#include "source/opt/graphics_robust_access_pass.h"
+
+namespace {
+
+using namespace spvtools;
+
+using opt::GraphicsRobustAccessPass;
+using GraphicsRobustAccessTest = opt::PassTest<::testing::Test>;
+
+// Test incompatible module, determined at module-level.
+
+TEST_F(GraphicsRobustAccessTest, FailNotShader) {
+ const std::string text = R"(
+; CHECK: Can only process Shader modules
+OpCapability Kernel
+)";
+
+ SinglePassRunAndFail<GraphicsRobustAccessPass>(text);
+}
+
+TEST_F(GraphicsRobustAccessTest, FailCantProcessVariablePointers) {
+ const std::string text = R"(
+; CHECK: Can't process modules with VariablePointers capability
+OpCapability VariablePointers
+)";
+
+ SinglePassRunAndFail<GraphicsRobustAccessPass>(text);
+}
+
+TEST_F(GraphicsRobustAccessTest, FailCantProcessVariablePointersStorageBuffer) {
+ const std::string text = R"(
+; CHECK: Can't process modules with VariablePointersStorageBuffer capability
+OpCapability VariablePointersStorageBuffer
+)";
+
+ SinglePassRunAndFail<GraphicsRobustAccessPass>(text);
+}
+
+TEST_F(GraphicsRobustAccessTest, FailCantProcessRuntimeDescriptorArrayEXT) {
+ const std::string text = R"(
+; CHECK: Can't process modules with RuntimeDescriptorArrayEXT capability
+OpCapability RuntimeDescriptorArrayEXT
+)";
+
+ SinglePassRunAndFail<GraphicsRobustAccessPass>(text);
+}
+
+TEST_F(GraphicsRobustAccessTest, FailCantProcessPhysical32AddressingModel) {
+ const std::string text = R"(
+; CHECK: Addressing model must be Logical. Found OpMemoryModel Physical32 OpenCL
+OpCapability Shader
+OpMemoryModel Physical32 OpenCL
+)";
+
+ SinglePassRunAndFail<GraphicsRobustAccessPass>(text);
+}
+
+TEST_F(GraphicsRobustAccessTest, FailCantProcessPhysical64AddressingModel) {
+ const std::string text = R"(
+; CHECK: Addressing model must be Logical. Found OpMemoryModel Physical64 OpenCL
+OpCapability Shader
+OpMemoryModel Physical64 OpenCL
+)";
+
+ SinglePassRunAndFail<GraphicsRobustAccessPass>(text);
+}
+
+TEST_F(GraphicsRobustAccessTest,
+ FailCantProcessPhysicalStorageBuffer64EXTAddressingModel) {
+ const std::string text = R"(
+; CHECK: Addressing model must be Logical. Found OpMemoryModel PhysicalStorageBuffer64EXT GLSL450
+OpCapability Shader
+OpMemoryModel PhysicalStorageBuffer64EXT GLSL450
+)";
+
+ SinglePassRunAndFail<GraphicsRobustAccessPass>(text);
+}
+
+// Test access chains
+
+// Returns the names of access chain instructions handled by the pass.
+// For the purposes of this pass, regular and in-bounds access chains are the
+// same.)
+std::vector<const char*> AccessChains() {
+ return {"OpAccessChain", "OpInBoundsAccessChain"};
+}
+
+std::string ShaderPreamble() {
+ return R"(
+ OpCapability Shader
+ OpMemoryModel Logical Simple
+ OpEntryPoint GLCompute %main "main"
+)";
+}
+
+std::string ShaderPreamble(const std::vector<std::string>& names) {
+ std::ostringstream os;
+ os << ShaderPreamble();
+ for (auto& name : names) {
+ os << " OpName %" << name << " \"" << name << "\"\n";
+ }
+ return os.str();
+}
+
+std::string ShaderPreambleAC() {
+ return ShaderPreamble({"ac", "ptr_ty", "var"});
+}
+
+std::string ShaderPreambleAC(const std::vector<std::string>& names) {
+ auto names2 = names;
+ names2.push_back("ac");
+ names2.push_back("ptr_ty");
+ names2.push_back("var");
+ return ShaderPreamble(names2);
+}
+
+std::string DecoSSBO() {
+ return R"(
+ OpDecorate %ssbo_s BufferBlock
+ OpMemberDecorate %ssbo_s 0 Offset 0
+ OpMemberDecorate %ssbo_s 1 Offset 4
+ OpMemberDecorate %ssbo_s 2 Offset 16
+ OpDecorate %var DescriptorSet 0
+ OpDecorate %var Binding 0
+)";
+}
+
+std::string TypesVoid() {
+ return R"(
+ %void = OpTypeVoid
+ %void_fn = OpTypeFunction %void
+)";
+}
+
+std::string TypesInt() {
+ return R"(
+ %uint = OpTypeInt 32 0
+ %int = OpTypeInt 32 1
+)";
+}
+
+std::string TypesFloat() {
+ return R"(
+ %float = OpTypeFloat 32
+)";
+}
+
+std::string TypesShort() {
+ return R"(
+ %ushort = OpTypeInt 16 0
+ %short = OpTypeInt 16 1
+)";
+}
+
+std::string TypesLong() {
+ return R"(
+ %ulong = OpTypeInt 64 0
+ %long = OpTypeInt 64 1
+)";
+}
+
+std::string MainPrefix() {
+ return R"(
+ %main = OpFunction %void None %void_fn
+ %entry = OpLabel
+)";
+}
+
+std::string MainSuffix() {
+ return R"(
+ OpReturn
+ OpFunctionEnd
+)";
+}
+
+std::string ACCheck(const std::string& access_chain_inst,
+ const std::string& original,
+ const std::string& transformed) {
+ return "\n ; CHECK: %ac = " + access_chain_inst + " %ptr_ty %var" +
+ (transformed.size() ? " " : "") + transformed +
+ "\n ; CHECK-NOT: " + access_chain_inst +
+ "\n ; CHECK-NEXT: OpReturn"
+ "\n %ac = " +
+ access_chain_inst + " %ptr_ty %var " + (original.size() ? " " : "") +
+ original + "\n";
+}
+
+std::string ACCheckFail(const std::string& access_chain_inst,
+ const std::string& original,
+ const std::string& transformed) {
+ return "\n ; CHECK: %ac = " + access_chain_inst + " %ptr_ty %var" +
+ (transformed.size() ? " " : "") + transformed +
+ "\n ; CHECK-NOT: " + access_chain_inst +
+ "\n ; CHECK-NOT: OpReturn"
+ "\n %ac = " +
+ access_chain_inst + " %ptr_ty %var " + (original.size() ? " " : "") +
+ original + "\n";
+}
+
+// Access chain into:
+// Vector
+// Vector sizes 2, 3, 4
+// Matrix
+// Matrix columns 2, 4
+// Component is vector 2, 4
+// Array
+// Struct
+// TODO(dneto): RuntimeArray
+
+TEST_F(GraphicsRobustAccessTest, ACVectorLeastInboundConstantUntouched) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << TypesVoid() << TypesInt() << R"(
+ %uvec2 = OpTypeVector %uint 2
+ %var_ty = OpTypePointer Function %uvec2
+ %ptr_ty = OpTypePointer Function %uint
+ %uint_0 = OpConstant %uint 0
+ )"
+ << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function)" << ACCheck(ac, "%uint_0", "%uint_0")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACVectorMostInboundConstantUntouched) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << TypesVoid() << TypesInt() << R"(
+ %v4uint = OpTypeVector %uint 4
+ %var_ty = OpTypePointer Function %v4uint
+ %ptr_ty = OpTypePointer Function %uint
+ %uint_3 = OpConstant %uint 3
+ )"
+ << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function)" << ACCheck(ac, "%uint_3", "%uint_3")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACVectorExcessConstantClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << TypesVoid() << TypesInt() << R"(
+ %v4uint = OpTypeVector %uint 4
+ %var_ty = OpTypePointer Function %v4uint
+ %ptr_ty = OpTypePointer Function %uint
+ %uint_4 = OpConstant %uint 4
+ )"
+ << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function)" << ACCheck(ac, "%uint_4", "%uint_3")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACVectorNegativeConstantClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << TypesVoid() << TypesInt() << R"(
+ %v4uint = OpTypeVector %uint 4
+ %var_ty = OpTypePointer Function %v4uint
+ %ptr_ty = OpTypePointer Function %uint
+ %int_n1 = OpConstant %int -1
+ )"
+ << MainPrefix() << R"(
+ ; CHECK: %int_0 = OpConstant %int 0
+ %var = OpVariable %var_ty Function)" << ACCheck(ac, "%int_n1", "%int_0")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+// Like the previous test, but ensures the pass knows how to modify an index
+// which does not come first in the access chain.
+TEST_F(GraphicsRobustAccessTest, ACVectorInArrayNegativeConstantClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << TypesVoid() << TypesInt() << R"(
+ %v4uint = OpTypeVector %uint 4
+ %uint_1 = OpConstant %uint 1
+ %uint_2 = OpConstant %uint 2
+ %arr = OpTypeArray %v4uint %uint_2
+ %var_ty = OpTypePointer Function %arr
+ %ptr_ty = OpTypePointer Function %uint
+ %int_n1 = OpConstant %int -1
+ )"
+ << MainPrefix() << R"(
+ ; CHECK: %int_0 = OpConstant %int 0
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%uint_1 %int_n1", "%uint_1 %int_0") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACVectorGeneralClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"i"}) << TypesVoid() << TypesInt() << R"(
+ %v4uint = OpTypeVector %uint 4
+ %var_ty = OpTypePointer Function %v4uint
+ %ptr_ty = OpTypePointer Function %uint
+ %i = OpUndef %int)"
+ << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %int_0 = OpConstant %int 0
+ ; CHECK-DAG: %int_3 = OpConstant %int 3
+ ; CHECK: OpLabel
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %i %int_0 %int_3
+ %var = OpVariable %var_ty Function)" << ACCheck(ac, "%i", "%[[clamp]]")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACVectorGeneralShortClamped) {
+ // Show that signed 16 bit integers are clamped as well.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int16\n"
+ << ShaderPreambleAC({"i"}) << TypesVoid() << TypesShort() <<
+ R"(
+ %v4short = OpTypeVector %short 4
+ %var_ty = OpTypePointer Function %v4short
+ %ptr_ty = OpTypePointer Function %short
+ %i = OpUndef %short)"
+ << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-NOT: = OpTypeInt 32
+ ; CHECK-DAG: %short_0 = OpConstant %short 0
+ ; CHECK-DAG: %short_3 = OpConstant %short 3
+ ; CHECK-NOT: = OpTypeInt 32
+ ; CHECK: OpLabel
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %short %[[GLSLSTD450]] UClamp %i %short_0 %short_3
+ %var = OpVariable %var_ty Function)" << ACCheck(ac, "%i", "%[[clamp]]")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACVectorGeneralUShortClamped) {
+ // Show that unsigned 16 bit integers are clamped as well.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int16\n"
+ << ShaderPreambleAC({"i"}) << TypesVoid() << TypesShort() <<
+ R"(
+ %v4ushort = OpTypeVector %ushort 4
+ %var_ty = OpTypePointer Function %v4ushort
+ %ptr_ty = OpTypePointer Function %ushort
+ %i = OpUndef %ushort)"
+ << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-NOT: = OpTypeInt 32
+ ; CHECK-DAG: %ushort_0 = OpConstant %ushort 0
+ ; CHECK-DAG: %ushort_3 = OpConstant %ushort 3
+ ; CHECK-NOT: = OpTypeInt 32
+ ; CHECK: OpLabel
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %ushort %[[GLSLSTD450]] UClamp %i %ushort_0 %ushort_3
+ %var = OpVariable %var_ty Function)" << ACCheck(ac, "%i", "%[[clamp]]")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACVectorGeneralLongClamped) {
+ // Show that signed 64 bit integers are clamped as well.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int64\n"
+ << ShaderPreambleAC({"i"}) << TypesVoid() << TypesLong() <<
+ R"(
+ %v4long = OpTypeVector %long 4
+ %var_ty = OpTypePointer Function %v4long
+ %ptr_ty = OpTypePointer Function %long
+ %i = OpUndef %long)"
+ << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-NOT: = OpTypeInt 32
+ ; CHECK-DAG: %long_0 = OpConstant %long 0
+ ; CHECK-DAG: %long_3 = OpConstant %long 3
+ ; CHECK-NOT: = OpTypeInt 32
+ ; CHECK: OpLabel
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %long %[[GLSLSTD450]] UClamp %i %long_0 %long_3
+ %var = OpVariable %var_ty Function)" << ACCheck(ac, "%i", "%[[clamp]]")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACVectorGeneralULongClamped) {
+ // Show that unsigned 64 bit integers are clamped as well.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int64\n"
+ << ShaderPreambleAC({"i"}) << TypesVoid() << TypesLong() <<
+ R"(
+ %v4ulong = OpTypeVector %ulong 4
+ %var_ty = OpTypePointer Function %v4ulong
+ %ptr_ty = OpTypePointer Function %ulong
+ %i = OpUndef %ulong)"
+ << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-NOT: = OpTypeInt 32
+ ; CHECK-DAG: %ulong_0 = OpConstant %ulong 0
+ ; CHECK-DAG: %ulong_3 = OpConstant %ulong 3
+ ; CHECK-NOT: = OpTypeInt 32
+ ; CHECK: OpLabel
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %ulong %[[GLSLSTD450]] UClamp %i %ulong_0 %ulong_3
+ %var = OpVariable %var_ty Function)" << ACCheck(ac, "%i", "%[[clamp]]")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACMatrixLeastInboundConstantUntouched) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << TypesVoid() << TypesInt()
+ << TypesFloat() << R"(
+ %v2float = OpTypeVector %float 2
+ %mat4v2float = OpTypeMatrix %v2float 4
+ %var_ty = OpTypePointer Function %mat4v2float
+ %ptr_ty = OpTypePointer Function %float
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+ )" << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%uint_0 %uint_1", "%uint_0 %uint_1")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACMatrixMostInboundConstantUntouched) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << TypesVoid() << TypesInt()
+ << TypesFloat() << R"(
+ %v2float = OpTypeVector %float 2
+ %mat4v2float = OpTypeMatrix %v2float 4
+ %var_ty = OpTypePointer Function %mat4v2float
+ %ptr_ty = OpTypePointer Function %float
+ %uint_1 = OpConstant %uint 1
+ %uint_3 = OpConstant %uint 3
+ )" << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%uint_3 %uint_1", "%uint_3 %uint_1")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACMatrixExcessConstantClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << TypesVoid() << TypesInt()
+ << TypesFloat() << R"(
+ %v2float = OpTypeVector %float 2
+ %mat4v2float = OpTypeMatrix %v2float 4
+ %var_ty = OpTypePointer Function %mat4v2float
+ %ptr_ty = OpTypePointer Function %float
+ %uint_1 = OpConstant %uint 1
+ %uint_4 = OpConstant %uint 4
+ )" << MainPrefix() << R"(
+ ; CHECK: %uint_3 = OpConstant %uint 3
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%uint_4 %uint_1", "%uint_3 %uint_1")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACMatrixNegativeConstantClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << TypesVoid() << TypesInt()
+ << TypesFloat() << R"(
+ %v2float = OpTypeVector %float 2
+ %mat4v2float = OpTypeMatrix %v2float 4
+ %var_ty = OpTypePointer Function %mat4v2float
+ %ptr_ty = OpTypePointer Function %float
+ %uint_1 = OpConstant %uint 1
+ %int_n1 = OpConstant %int -1
+ )" << MainPrefix() << R"(
+ ; CHECK: %int_0 = OpConstant %int 0
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%int_n1 %uint_1", "%int_0 %uint_1") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACMatrixGeneralClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"i"}) << TypesVoid() << TypesInt()
+ << TypesFloat() << R"(
+ %v2float = OpTypeVector %float 2
+ %mat4v2float = OpTypeMatrix %v2float 4
+ %var_ty = OpTypePointer Function %mat4v2float
+ %ptr_ty = OpTypePointer Function %float
+ %uint_1 = OpConstant %uint 1
+ %i = OpUndef %int
+ )" << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %int_0 = OpConstant %int 0
+ ; CHECK-DAG: %int_3 = OpConstant %int 3
+ ; CHECK: OpLabel
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %i %int_0 %int_3
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%i %uint_1", "%[[clamp]] %uint_1") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACArrayLeastInboundConstantUntouched) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << TypesVoid() << TypesInt()
+ << TypesFloat() << R"(
+ %uint_200 = OpConstant %uint 200
+ %arr = OpTypeArray %float %uint_200
+ %var_ty = OpTypePointer Function %arr
+ %ptr_ty = OpTypePointer Function %float
+ %int_0 = OpConstant %int 0
+ )" << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%int_0", "%int_0") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACArrayMostInboundConstantUntouched) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << TypesVoid() << TypesInt()
+ << TypesFloat() << R"(
+ %uint_200 = OpConstant %uint 200
+ %arr = OpTypeArray %float %uint_200
+ %var_ty = OpTypePointer Function %arr
+ %ptr_ty = OpTypePointer Function %float
+ %int_199 = OpConstant %int 199
+ )" << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%int_199", "%int_199") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACArrayGeneralClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"i"}) << TypesVoid() << TypesInt()
+ << TypesFloat() << R"(
+ %uint_200 = OpConstant %uint 200
+ %arr = OpTypeArray %float %uint_200
+ %var_ty = OpTypePointer Function %arr
+ %ptr_ty = OpTypePointer Function %float
+ %i = OpUndef %int
+ )" << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %int_0 = OpConstant %int 0
+ ; CHECK-DAG: %int_199 = OpConstant %int 199
+ ; CHECK: OpLabel
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %i %int_0 %int_199
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%i", "%[[clamp]]") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACArrayGeneralShortIndexUIntBoundsClamped) {
+ // Index is signed short, array bounds overflows the index type.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int16\n"
+ << ShaderPreambleAC({"i"}) << TypesVoid() << TypesInt()
+ << TypesShort() << TypesFloat() << R"(
+ %uint_70000 = OpConstant %uint 70000 ; overflows 16bits
+ %arr = OpTypeArray %float %uint_70000
+ %var_ty = OpTypePointer Function %arr
+ %ptr_ty = OpTypePointer Function %float
+ %i = OpUndef %short
+ )" << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %uint_0 = OpConstant %uint 0
+ ; CHECK-DAG: %uint_69999 = OpConstant %uint 69999
+ ; CHECK: OpLabel
+ ; CHECK: %[[i_ext:\w+]] = OpSConvert %uint %i
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %uint %[[GLSLSTD450]] UClamp %[[i_ext]] %uint_0 %uint_69999
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%i", "%[[clamp]]") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACArrayGeneralUShortIndexIntBoundsClamped) {
+ // Index is unsigned short, array bounds overflows the index type.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int16\n"
+ << ShaderPreambleAC({"i"}) << TypesVoid() << TypesInt()
+ << TypesShort() << TypesFloat() << R"(
+ %int_70000 = OpConstant %int 70000 ; overflows 16bits
+ %arr = OpTypeArray %float %int_70000
+ %var_ty = OpTypePointer Function %arr
+ %ptr_ty = OpTypePointer Function %float
+ %i = OpUndef %ushort
+ )" << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %uint_0 = OpConstant %uint 0
+ ; CHECK-DAG: %uint_69999 = OpConstant %uint 69999
+ ; CHECK: OpLabel
+ ; CHECK: %[[i_ext:\w+]] = OpUConvert %uint %i
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %uint %[[GLSLSTD450]] UClamp %[[i_ext]] %uint_0 %uint_69999
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%i", "%[[clamp]]") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACArrayGeneralUIntIndexShortBoundsClamped) {
+ // Signed int index i is wider than the array bounds type.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int16\n"
+ << ShaderPreambleAC({"i"}) << TypesVoid() << TypesInt()
+ << TypesShort() << TypesFloat() << R"(
+ %short_200 = OpConstant %short 200
+ %arr = OpTypeArray %float %short_200
+ %var_ty = OpTypePointer Function %arr
+ %ptr_ty = OpTypePointer Function %float
+ %i = OpUndef %uint
+ )" << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %uint_0 = OpConstant %uint 0
+ ; CHECK-DAG: %uint_199 = OpConstant %uint 199
+ ; CHECK: OpLabel
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %uint %[[GLSLSTD450]] UClamp %i %uint_0 %uint_199
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%i", "%[[clamp]]") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACArrayGeneralIntIndexUShortBoundsClamped) {
+ // Unsigned int index i is wider than the array bounds type.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int16\n"
+ << ShaderPreambleAC({"i"}) << TypesVoid() << TypesInt()
+ << TypesShort() << TypesFloat() << R"(
+ %ushort_200 = OpConstant %ushort 200
+ %arr = OpTypeArray %float %ushort_200
+ %var_ty = OpTypePointer Function %arr
+ %ptr_ty = OpTypePointer Function %float
+ %i = OpUndef %int
+ )" << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %int_0 = OpConstant %int 0
+ ; CHECK-DAG: %int_199 = OpConstant %int 199
+ ; CHECK: OpLabel
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %i %int_0 %int_199
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%i", "%[[clamp]]") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACArrayGeneralLongIndexUIntBoundsClamped) {
+ // Signed long index i is wider than the array bounds type.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int64\n"
+ << ShaderPreambleAC({"i"}) << TypesVoid() << TypesInt()
+ << TypesLong() << TypesFloat() << R"(
+ %uint_200 = OpConstant %uint 200
+ %arr = OpTypeArray %float %uint_200
+ %var_ty = OpTypePointer Function %arr
+ %ptr_ty = OpTypePointer Function %float
+ %i = OpUndef %long
+ )" << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %long_0 = OpConstant %long 0
+ ; CHECK-DAG: %long_199 = OpConstant %long 199
+ ; CHECK: OpLabel
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %long %[[GLSLSTD450]] UClamp %i %long_0 %long_199
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%i", "%[[clamp]]") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACArrayGeneralULongIndexIntBoundsClamped) {
+ // Unsigned long index i is wider than the array bounds type.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int64\n"
+ << ShaderPreambleAC({"i"}) << TypesVoid() << TypesInt()
+ << TypesLong() << TypesFloat() << R"(
+ %int_200 = OpConstant %int 200
+ %arr = OpTypeArray %float %int_200
+ %var_ty = OpTypePointer Function %arr
+ %ptr_ty = OpTypePointer Function %float
+ %i = OpUndef %ulong
+ )" << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %ulong_0 = OpConstant %ulong 0
+ ; CHECK-DAG: %ulong_199 = OpConstant %ulong 199
+ ; CHECK: OpLabel
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %ulong %[[GLSLSTD450]] UClamp %i %ulong_0 %ulong_199
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%i", "%[[clamp]]") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACArraySpecIdSizedAlwaysClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"spec200"}) << R"(
+ OpDecorate %spec200 SpecId 0 )" << TypesVoid() << TypesInt()
+ << TypesFloat() << R"(
+ %spec200 = OpSpecConstant %int 200
+ %arr = OpTypeArray %float %spec200
+ %var_ty = OpTypePointer Function %arr
+ %ptr_ty = OpTypePointer Function %float
+ %uint_5 = OpConstant %uint 5
+ )" << MainPrefix() << R"(
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %uint_0 = OpConstant %uint 0
+ ; CHECK-DAG: %uint_1 = OpConstant %uint 1
+ ; CHECK: OpLabel
+ ; CHECK: %[[max:\w+]] = OpISub %uint %spec200 %uint_1
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %uint %[[GLSLSTD450]] UClamp %uint_5 %uint_0 %[[max]]
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%uint_5", "%[[clamp]]") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACStructLeastUntouched) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << TypesVoid() << TypesInt()
+ << TypesFloat() << R"(
+ %struct = OpTypeStruct %float %float %float
+ %var_ty = OpTypePointer Function %struct
+ %ptr_ty = OpTypePointer Function %float
+ %int_0 = OpConstant %int 0
+ )" << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%int_0", "%int_0") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACStructMostUntouched) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << TypesVoid() << TypesInt()
+ << TypesFloat() << R"(
+ %struct = OpTypeStruct %float %float %float
+ %var_ty = OpTypePointer Function %struct
+ %ptr_ty = OpTypePointer Function %float
+ %int_2 = OpConstant %int 2
+ )" << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function)"
+ << ACCheck(ac, "%int_2", "%int_2") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACStructSpecConstantFail) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"struct", "spec200"})
+ << "OpDecorate %spec200 SpecId 0\n"
+ <<
+
+ TypesVoid() << TypesInt() << TypesFloat() << R"(
+ %spec200 = OpSpecConstant %int 200
+ %struct = OpTypeStruct %float %float %float
+ %var_ty = OpTypePointer Function %struct
+ %ptr_ty = OpTypePointer Function %float
+ )" << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function
+ ; CHECK: Member index into struct is not a constant integer
+ ; CHECK-SAME: %spec200 = OpSpecConstant %int 200
+ )"
+ << ACCheckFail(ac, "%spec200", "%spec200") << MainSuffix();
+ SinglePassRunAndFail<GraphicsRobustAccessPass>(shaders.str());
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACStructFloatConstantFail) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"struct"}) <<
+
+ TypesVoid() << TypesInt() << TypesFloat() << R"(
+ %float_2 = OpConstant %float 2
+ %struct = OpTypeStruct %float %float %float
+ %var_ty = OpTypePointer Function %struct
+ %ptr_ty = OpTypePointer Function %float
+ )" << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function
+ ; CHECK: Member index into struct is not a constant integer
+ ; CHECK-SAME: %float_2 = OpConstant %float 2
+ )"
+ << ACCheckFail(ac, "%float_2", "%float_2") << MainSuffix();
+ SinglePassRunAndFail<GraphicsRobustAccessPass>(shaders.str());
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACStructNonConstantFail) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"struct", "i"}) <<
+
+ TypesVoid() << TypesInt() << TypesFloat() << R"(
+ %float_2 = OpConstant %float 2
+ %struct = OpTypeStruct %float %float %float
+ %var_ty = OpTypePointer Function %struct
+ %ptr_ty = OpTypePointer Function %float
+ %i = OpUndef %int
+ )" << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function
+ ; CHECK: Member index into struct is not a constant integer
+ ; CHECK-SAME: %i = OpUndef %int
+ )"
+ << ACCheckFail(ac, "%i", "%i") << MainSuffix();
+ SinglePassRunAndFail<GraphicsRobustAccessPass>(shaders.str());
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACStructExcessFail) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"struct", "i"}) << TypesVoid() << TypesInt()
+ << TypesFloat() << R"(
+ %struct = OpTypeStruct %float %float %float
+ %var_ty = OpTypePointer Function %struct
+ %ptr_ty = OpTypePointer Function %float
+ %i = OpConstant %int 4
+ )" << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function
+ ; CHECK: Member index 4 is out of bounds for struct type:
+ ; CHECK-SAME: %struct = OpTypeStruct %float %float %float
+ )"
+ << ACCheckFail(ac, "%i", "%i") << MainSuffix();
+ SinglePassRunAndFail<GraphicsRobustAccessPass>(shaders.str());
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACStructNegativeFail) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"struct", "i"}) << TypesVoid() << TypesInt()
+ << TypesFloat() << R"(
+ %struct = OpTypeStruct %float %float %float
+ %var_ty = OpTypePointer Function %struct
+ %ptr_ty = OpTypePointer Function %float
+ %i = OpConstant %int -1
+ )" << MainPrefix() << R"(
+ %var = OpVariable %var_ty Function
+ ; CHECK: Member index -1 is out of bounds for struct type:
+ ; CHECK-SAME: %struct = OpTypeStruct %float %float %float
+ )"
+ << ACCheckFail(ac, "%i", "%i") << MainSuffix();
+ SinglePassRunAndFail<GraphicsRobustAccessPass>(shaders.str());
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACRTArrayLeastInboundClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC() << "OpMemberDecorate %ssbo_s 0 ArrayStride 4 "
+ << DecoSSBO() << TypesVoid() << TypesInt() << TypesFloat() << R"(
+ %rtarr = OpTypeRuntimeArray %float
+ %ssbo_s = OpTypeStruct %uint %uint %rtarr
+ %var_ty = OpTypePointer Uniform %ssbo_s
+ %ptr_ty = OpTypePointer Uniform %float
+ %var = OpVariable %var_ty Uniform
+ %int_0 = OpConstant %int 0
+ %int_2 = OpConstant %int 2
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK: %int_1 = OpConstant %int 1
+ ; CHECK: OpLabel
+ ; CHECK: %[[arrlen:\w+]] = OpArrayLength %uint %var 2
+ ; CHECK: %[[max:\w+]] = OpISub %int %[[arrlen]] %int_1
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %int_0 %int_0 %[[max]]
+ )"
+ << MainPrefix() << ACCheck(ac, "%int_2 %int_0", "%int_2 %[[clamp]]")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACRTArrayGeneralShortIndexClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int16\n"
+ << ShaderPreambleAC({"i"})
+ << "OpMemberDecorate %ssbo_s 0 ArrayStride 4 " << DecoSSBO()
+ << TypesVoid() << TypesShort() << TypesFloat() << R"(
+ %rtarr = OpTypeRuntimeArray %float
+ %ssbo_s = OpTypeStruct %short %short %rtarr
+ %var_ty = OpTypePointer Uniform %ssbo_s
+ %ptr_ty = OpTypePointer Uniform %float
+ %var = OpVariable %var_ty Uniform
+ %short_2 = OpConstant %short 2
+ %i = OpUndef %short
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK: %uint = OpTypeInt 32 0
+ ; CHECK-DAG: %uint_1 = OpConstant %uint 1
+ ; CHECK-DAG: %uint_0 = OpConstant %uint 0
+ ; CHECK: OpLabel
+ ; CHECK: %[[arrlen:\w+]] = OpArrayLength %uint %var 2
+ ; CHECK-DAG: %[[max:\w+]] = OpISub %uint %[[arrlen]] %uint_1
+ ; CHECK-DAG: %[[i_ext:\w+]] = OpSConvert %uint %i
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %uint %[[GLSLSTD450]] UClamp %[[i_ext]] %uint_0 %[[max]]
+ )"
+ << MainPrefix() << ACCheck(ac, "%short_2 %i", "%short_2 %[[clamp]]")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACRTArrayGeneralUShortIndexClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int16\n"
+ << ShaderPreambleAC({"i"})
+ << "OpMemberDecorate %ssbo_s 0 ArrayStride 4 " << DecoSSBO()
+ << TypesVoid() << TypesShort() << TypesFloat() << R"(
+ %rtarr = OpTypeRuntimeArray %float
+ %ssbo_s = OpTypeStruct %short %short %rtarr
+ %var_ty = OpTypePointer Uniform %ssbo_s
+ %ptr_ty = OpTypePointer Uniform %float
+ %var = OpVariable %var_ty Uniform
+ %short_2 = OpConstant %short 2
+ %i = OpUndef %ushort
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK: %uint = OpTypeInt 32 0
+ ; CHECK-DAG: %uint_1 = OpConstant %uint 1
+ ; CHECK-DAG: %uint_0 = OpConstant %uint 0
+ ; CHECK: OpLabel
+ ; CHECK: %[[arrlen:\w+]] = OpArrayLength %uint %var 2
+ ; CHECK-DAG: %[[max:\w+]] = OpISub %uint %[[arrlen]] %uint_1
+ ; CHECK-DAG: %[[i_ext:\w+]] = OpSConvert %uint %i
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %uint %[[GLSLSTD450]] UClamp %[[i_ext]] %uint_0 %[[max]]
+ )"
+ << MainPrefix() << ACCheck(ac, "%short_2 %i", "%short_2 %[[clamp]]")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACRTArrayGeneralIntIndexClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"i"})
+ << "OpMemberDecorate %ssbo_s 0 ArrayStride 4 " << DecoSSBO()
+ << TypesVoid() << TypesInt() << TypesFloat() << R"(
+ %rtarr = OpTypeRuntimeArray %float
+ %ssbo_s = OpTypeStruct %int %int %rtarr
+ %var_ty = OpTypePointer Uniform %ssbo_s
+ %ptr_ty = OpTypePointer Uniform %float
+ %var = OpVariable %var_ty Uniform
+ %int_2 = OpConstant %int 2
+ %i = OpUndef %int
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %int_1 = OpConstant %int 1
+ ; CHECK-DAG: %int_0 = OpConstant %int 0
+ ; CHECK: OpLabel
+ ; CHECK: %[[arrlen:\w+]] = OpArrayLength %uint %var 2
+ ; CHECK: %[[max:\w+]] = OpISub %int %[[arrlen]] %int_1
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %i %int_0 %[[max]]
+ )" << MainPrefix()
+ << ACCheck(ac, "%int_2 %i", "%int_2 %[[clamp]]") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACRTArrayGeneralUIntIndexClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"i"})
+ << "OpMemberDecorate %ssbo_s 0 ArrayStride 4 " << DecoSSBO()
+ << TypesVoid() << TypesInt() << TypesFloat() << R"(
+ %rtarr = OpTypeRuntimeArray %float
+ %ssbo_s = OpTypeStruct %int %int %rtarr
+ %var_ty = OpTypePointer Uniform %ssbo_s
+ %ptr_ty = OpTypePointer Uniform %float
+ %var = OpVariable %var_ty Uniform
+ %int_2 = OpConstant %int 2
+ %i = OpUndef %uint
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %uint_1 = OpConstant %uint 1
+ ; CHECK-DAG: %uint_0 = OpConstant %uint 0
+ ; CHECK: OpLabel
+ ; CHECK: %[[arrlen:\w+]] = OpArrayLength %uint %var 2
+ ; CHECK: %[[max:\w+]] = OpISub %uint %[[arrlen]] %uint_1
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %uint %[[GLSLSTD450]] UClamp %i %uint_0 %[[max]]
+ )" << MainPrefix()
+ << ACCheck(ac, "%int_2 %i", "%int_2 %[[clamp]]") << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACRTArrayGeneralLongIndexClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int64" << ShaderPreambleAC({"i"})
+ << "OpMemberDecorate %ssbo_s 0 ArrayStride 4 " << DecoSSBO()
+ << TypesVoid() << TypesInt() << TypesLong() << TypesFloat() << R"(
+ %rtarr = OpTypeRuntimeArray %float
+ %ssbo_s = OpTypeStruct %int %int %rtarr
+ %var_ty = OpTypePointer Uniform %ssbo_s
+ %ptr_ty = OpTypePointer Uniform %float
+ %var = OpVariable %var_ty Uniform
+ %int_2 = OpConstant %int 2
+ %i = OpUndef %long
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %long_0 = OpConstant %long 0
+ ; CHECK-DAG: %long_1 = OpConstant %long 1
+ ; CHECK: OpLabel
+ ; CHECK: %[[arrlen:\w+]] = OpArrayLength %uint %var 2
+ ; CHECK: %[[arrlen_ext:\w+]] = OpUConvert %ulong %[[arrlen]]
+ ; CHECK: %[[max:\w+]] = OpISub %long %[[arrlen_ext]] %long_1
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %long %[[GLSLSTD450]] UClamp %i %long_0 %[[max]]
+ )"
+ << MainPrefix() << ACCheck(ac, "%int_2 %i", "%int_2 %[[clamp]]")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACRTArrayGeneralULongIndexClamped) {
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << "OpCapability Int64" << ShaderPreambleAC({"i"})
+ << "OpMemberDecorate %ssbo_s 0 ArrayStride 4 " << DecoSSBO()
+ << TypesVoid() << TypesInt() << TypesLong() << TypesFloat() << R"(
+ %rtarr = OpTypeRuntimeArray %float
+ %ssbo_s = OpTypeStruct %int %int %rtarr
+ %var_ty = OpTypePointer Uniform %ssbo_s
+ %ptr_ty = OpTypePointer Uniform %float
+ %var = OpVariable %var_ty Uniform
+ %int_2 = OpConstant %int 2
+ %i = OpUndef %ulong
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %ulong_0 = OpConstant %ulong 0
+ ; CHECK-DAG: %ulong_1 = OpConstant %ulong 1
+ ; CHECK: OpLabel
+ ; CHECK: %[[arrlen:\w+]] = OpArrayLength %uint %var 2
+ ; CHECK: %[[arrlen_ext:\w+]] = OpUConvert %ulong %[[arrlen]]
+ ; CHECK: %[[max:\w+]] = OpISub %ulong %[[arrlen_ext]] %ulong_1
+ ; CHECK: %[[clamp:\w+]] = OpExtInst %ulong %[[GLSLSTD450]] UClamp %i %ulong_0 %[[max]]
+ )"
+ << MainPrefix() << ACCheck(ac, "%int_2 %i", "%int_2 %[[clamp]]")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACRTArrayStructVectorElem) {
+ // The point of this test is that the access chain can have indices past the
+ // index into the runtime array. For good measure, the index into the final
+ // struct is out of bounds. We have to clamp that index too.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"i", "j"})
+ << "OpMemberDecorate %ssbo_s 0 ArrayStride 32\n"
+ << DecoSSBO() << "OpMemberDecorate %rtelem 0 Offset 0\n"
+ << "OpMemberDecorate %rtelem 1 Offset 16\n"
+ << TypesVoid() << TypesInt() << TypesFloat() << R"(
+ %v4float = OpTypeVector %float 4
+ %rtelem = OpTypeStruct %v4float %v4float
+ %rtarr = OpTypeRuntimeArray %rtelem
+ %ssbo_s = OpTypeStruct %int %int %rtarr
+ %var_ty = OpTypePointer Uniform %ssbo_s
+ %ptr_ty = OpTypePointer Uniform %float
+ %var = OpVariable %var_ty Uniform
+ %int_1 = OpConstant %int 1
+ %int_2 = OpConstant %int 2
+ %i = OpUndef %int
+ %j = OpUndef %int
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %int_0 = OpConstant %int 0
+ ; CHECK-DAG: %int_3 = OpConstant %int 3
+ ; CHECK: OpLabel
+ ; CHECK: %[[arrlen:\w+]] = OpArrayLength %uint %var 2
+ ; CHECK: %[[max:\w+]] = OpISub %int %[[arrlen]] %int_1
+ ; CHECK: %[[clamp_i:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %i %int_0 %[[max]]
+ ; CHECK: %[[clamp_j:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %j %int_0 %int_3
+ )" << MainPrefix()
+ << ACCheck(ac, "%int_2 %i %int_1 %j",
+ "%int_2 %[[clamp_i]] %int_1 %[[clamp_j]]")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACArrayRTArrayStructVectorElem) {
+ // Now add an additional level of arrays around the Block-decorated struct.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"i", "ssbo_s"})
+ << "OpMemberDecorate %ssbo_s 0 ArrayStride 32\n"
+ << DecoSSBO() << "OpMemberDecorate %rtelem 0 Offset 0\n"
+ << "OpMemberDecorate %rtelem 1 Offset 16\n"
+ << TypesVoid() << TypesInt() << TypesFloat() << R"(
+ %v4float = OpTypeVector %float 4
+ %rtelem = OpTypeStruct %v4float %v4float
+ %rtarr = OpTypeRuntimeArray %rtelem
+ %ssbo_s = OpTypeStruct %int %int %rtarr
+ %arr_size = OpConstant %int 10
+ %arr_ssbo = OpTypeArray %ssbo_s %arr_size
+ %var_ty = OpTypePointer Uniform %arr_ssbo
+ %ptr_ty = OpTypePointer Uniform %float
+ %var = OpVariable %var_ty Uniform
+ %int_1 = OpConstant %int 1
+ %int_2 = OpConstant %int 2
+ %int_17 = OpConstant %int 17
+ %i = OpUndef %int
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %[[ssbo_p:\w+]] = OpTypePointer Uniform %ssbo_s
+ ; CHECK-DAG: %int_0 = OpConstant %int 0
+ ; CHECK-DAG: %int_9 = OpConstant %int 9
+ ; CHECK: OpLabel
+ ; This access chain is manufactured only so we can compute the array length.
+ ; Note that the %int_9 is already clamped
+ ; CHECK: %[[ssbo_base:\w+]] = )" << ac
+ << R"( %[[ssbo_p]] %var %int_9
+ ; CHECK: %[[arrlen:\w+]] = OpArrayLength %uint %[[ssbo_base]] 2
+ ; CHECK: %[[max:\w+]] = OpISub %int %[[arrlen]] %int_1
+ ; CHECK: %[[clamp_i:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %i %int_0 %[[max]]
+ )" << MainPrefix()
+ << ACCheck(ac, "%int_17 %int_2 %i %int_1 %int_2",
+ "%int_9 %int_2 %[[clamp_i]] %int_1 %int_2")
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest, ACSplitACArrayRTArrayStructVectorElem) {
+ // Split the address calculation across two access chains. Force
+ // the transform to walk up the access chains to find the base variable.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"i", "j", "k", "ssbo_s", "ssbo_pty",
+ "rtarr_pty", "ac_ssbo", "ac_rtarr"})
+ << "OpMemberDecorate %ssbo_s 0 ArrayStride 32\n"
+ << DecoSSBO() << "OpMemberDecorate %rtelem 0 Offset 0\n"
+ << "OpMemberDecorate %rtelem 1 Offset 16\n"
+ << TypesVoid() << TypesInt() << TypesFloat() << R"(
+ %v4float = OpTypeVector %float 4
+ %rtelem = OpTypeStruct %v4float %v4float
+ %rtarr = OpTypeRuntimeArray %rtelem
+ %ssbo_s = OpTypeStruct %int %int %rtarr
+ %arr_size = OpConstant %int 10
+ %arr_ssbo = OpTypeArray %ssbo_s %arr_size
+ %var_ty = OpTypePointer Uniform %arr_ssbo
+ %ssbo_pty = OpTypePointer Uniform %ssbo_s
+ %rtarr_pty = OpTypePointer Uniform %rtarr
+ %ptr_ty = OpTypePointer Uniform %float
+ %var = OpVariable %var_ty Uniform
+ %int_1 = OpConstant %int 1
+ %int_2 = OpConstant %int 2
+ %i = OpUndef %int
+ %j = OpUndef %int
+ %k = OpUndef %int
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %int_0 = OpConstant %int 0
+ ; CHECK-DAG: %int_9 = OpConstant %int 9
+ ; CHECK-DAG: %int_3 = OpConstant %int 3
+ ; CHECK: OpLabel
+ ; CHECK: %[[clamp_i:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %i %int_0 %int_9
+ ; CHECK: %ac_ssbo = )" << ac
+ << R"( %ssbo_pty %var %[[clamp_i]]
+ ; CHECK: %ac_rtarr = )"
+ << ac << R"( %rtarr_pty %ac_ssbo %int_2
+
+ ; This is the interesting bit. This array length is needed for an OpAccessChain
+ ; computing %ac, but the algorithm had to track back through %ac_rtarr's
+ ; definition to find the base pointer %ac_ssbo.
+ ; CHECK: %[[arrlen:\w+]] = OpArrayLength %uint %ac_ssbo 2
+ ; CHECK: %[[max:\w+]] = OpISub %int %[[arrlen]] %int_1
+ ; CHECK: %[[clamp_j:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %j %int_0 %[[max]]
+ ; CHECK: %[[clamp_k:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %k %int_0 %int_3
+ ; CHECK: %ac = )" << ac
+ << R"( %ptr_ty %ac_rtarr %[[clamp_j]] %int_1 %[[clamp_k]]
+ ; CHECK-NOT: AccessChain
+ )" << MainPrefix()
+ << "%ac_ssbo = " << ac << " %ssbo_pty %var %i\n"
+ << "%ac_rtarr = " << ac << " %rtarr_pty %ac_ssbo %int_2\n"
+ << "%ac = " << ac << " %ptr_ty %ac_rtarr %j %int_1 %k\n"
+
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+TEST_F(GraphicsRobustAccessTest,
+ ACSplitACArrayRTArrayStructVectorElemAcrossBasicBlocks) {
+ // Split the address calculation across two access chains. Force
+ // the transform to walk up the access chains to find the base variable.
+ // This time, put the different access chains in different basic blocks.
+ // This sanity checks that we keep the instruction-to-block mapping
+ // consistent.
+ for (auto* ac : AccessChains()) {
+ std::ostringstream shaders;
+ shaders << ShaderPreambleAC({"i", "j", "k", "bb1", "bb2", "ssbo_s",
+ "ssbo_pty", "rtarr_pty", "ac_ssbo",
+ "ac_rtarr"})
+ << "OpMemberDecorate %ssbo_s 0 ArrayStride 32\n"
+ << DecoSSBO() << "OpMemberDecorate %rtelem 0 Offset 0\n"
+ << "OpMemberDecorate %rtelem 1 Offset 16\n"
+ << TypesVoid() << TypesInt() << TypesFloat() << R"(
+ %v4float = OpTypeVector %float 4
+ %rtelem = OpTypeStruct %v4float %v4float
+ %rtarr = OpTypeRuntimeArray %rtelem
+ %ssbo_s = OpTypeStruct %int %int %rtarr
+ %arr_size = OpConstant %int 10
+ %arr_ssbo = OpTypeArray %ssbo_s %arr_size
+ %var_ty = OpTypePointer Uniform %arr_ssbo
+ %ssbo_pty = OpTypePointer Uniform %ssbo_s
+ %rtarr_pty = OpTypePointer Uniform %rtarr
+ %ptr_ty = OpTypePointer Uniform %float
+ %var = OpVariable %var_ty Uniform
+ %int_1 = OpConstant %int 1
+ %int_2 = OpConstant %int 2
+ %i = OpUndef %int
+ %j = OpUndef %int
+ %k = OpUndef %int
+ ; CHECK: %[[GLSLSTD450:\w+]] = OpExtInstImport "GLSL.std.450"
+ ; CHECK-DAG: %int_0 = OpConstant %int 0
+ ; CHECK-DAG: %int_9 = OpConstant %int 9
+ ; CHECK-DAG: %int_3 = OpConstant %int 3
+ ; CHECK: OpLabel
+ ; CHECK: %[[clamp_i:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %i %int_0 %int_9
+ ; CHECK: %ac_ssbo = )" << ac
+ << R"( %ssbo_pty %var %[[clamp_i]]
+ ; CHECK: %bb1 = OpLabel
+ ; CHECK: %ac_rtarr = )"
+ << ac << R"( %rtarr_pty %ac_ssbo %int_2
+ ; CHECK: %bb2 = OpLabel
+
+ ; This is the interesting bit. This array length is needed for an OpAccessChain
+ ; computing %ac, but the algorithm had to track back through %ac_rtarr's
+ ; definition to find the base pointer %ac_ssbo.
+ ; CHECK: %[[arrlen:\w+]] = OpArrayLength %uint %ac_ssbo 2
+ ; CHECK: %[[max:\w+]] = OpISub %int %[[arrlen]] %int_1
+ ; CHECK: %[[clamp_j:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %j %int_0 %[[max]]
+ ; CHECK: %[[clamp_k:\w+]] = OpExtInst %int %[[GLSLSTD450]] UClamp %k %int_0 %int_3
+ ; CHECK: %ac = )" << ac
+ << R"( %ptr_ty %ac_rtarr %[[clamp_j]] %int_1 %[[clamp_k]]
+ ; CHECK-NOT: AccessChain
+ )" << MainPrefix()
+ << "%ac_ssbo = " << ac << " %ssbo_pty %var %i\n"
+ << "OpBranch %bb1\n%bb1 = OpLabel\n"
+ << "%ac_rtarr = " << ac << " %rtarr_pty %ac_ssbo %int_2\n"
+ << "OpBranch %bb2\n%bb2 = OpLabel\n"
+ << "%ac = " << ac << " %ptr_ty %ac_rtarr %j %int_1 %k\n"
+
+ << MainSuffix();
+ SinglePassRunAndMatch<GraphicsRobustAccessPass>(shaders.str(), true);
+ }
+}
+
+// TODO(dneto): Test access chain index wider than 64 bits?
+// TODO(dneto): Test struct access chain index wider than 64 bits?
+// TODO(dneto): OpImageTexelPointer
+// - all Dim types: 1D 2D Cube 3D Rect Buffer
+// - all Dim types that can be arrayed: 1D 2D 3D
+// - sample index: set to 0 if not multisampled
+// - Dim (2D, Cube Rect} with multisampling
+// -1 0 max excess
+// TODO(dneto): Test OpImageTexelPointer with coordinate component index other
+// than 32 bits.
+
+} // namespace
diff --git a/test/opt/inst_bindless_check_test.cpp b/test/opt/inst_bindless_check_test.cpp
index 6fe27c6..6e1adaa 100644
--- a/test/opt/inst_bindless_check_test.cpp
+++ b/test/opt/inst_bindless_check_test.cpp
@@ -5706,6 +5706,282 @@
true, 7u, 23u, false, false, 2u);
}
+TEST_F(InstBindlessTest, InstrumentTeseSimpleV2) {
+ // This test verifies that the pass will correctly instrument tessellation
+ // evaluation shader doing bindless buffer load.
+ //
+ // clang-format off
+ //
+ // #version 450
+ // #extension GL_EXT_nonuniform_qualifier : enable
+ //
+ // layout(std140, set = 0, binding = 0) uniform ufoo { uint index; } uniform_index_buffer;
+ //
+ // layout(set = 0, binding = 1) buffer bfoo { vec4 val; } adds[11];
+ //
+ // layout(triangles, equal_spacing, cw) in;
+ //
+ // void main() {
+ // gl_Position = adds[uniform_index_buffer.index].val;
+ // }
+ //
+ // clang-format on
+
+ const std::string defs_before =
+ R"(OpCapability Tessellation
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint TessellationEvaluation %main "main" %_
+OpExecutionMode %main Triangles
+OpExecutionMode %main SpacingEqual
+OpExecutionMode %main VertexOrderCw
+OpSource GLSL 450
+OpSourceExtension "GL_EXT_nonuniform_qualifier"
+OpName %main "main"
+OpName %gl_PerVertex "gl_PerVertex"
+OpMemberName %gl_PerVertex 0 "gl_Position"
+OpMemberName %gl_PerVertex 1 "gl_PointSize"
+OpMemberName %gl_PerVertex 2 "gl_ClipDistance"
+OpMemberName %gl_PerVertex 3 "gl_CullDistance"
+OpName %_ ""
+OpName %bfoo "bfoo"
+OpMemberName %bfoo 0 "val"
+OpName %adds "adds"
+OpName %ufoo "ufoo"
+OpMemberName %ufoo 0 "index"
+OpName %uniform_index_buffer "uniform_index_buffer"
+OpMemberDecorate %gl_PerVertex 0 BuiltIn Position
+OpMemberDecorate %gl_PerVertex 1 BuiltIn PointSize
+OpMemberDecorate %gl_PerVertex 2 BuiltIn ClipDistance
+OpMemberDecorate %gl_PerVertex 3 BuiltIn CullDistance
+OpDecorate %gl_PerVertex Block
+OpMemberDecorate %bfoo 0 Offset 0
+OpDecorate %bfoo Block
+OpDecorate %adds DescriptorSet 0
+OpDecorate %adds Binding 1
+OpMemberDecorate %ufoo 0 Offset 0
+OpDecorate %ufoo Block
+OpDecorate %uniform_index_buffer DescriptorSet 0
+OpDecorate %uniform_index_buffer Binding 0
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%_arr_float_uint_1 = OpTypeArray %float %uint_1
+%gl_PerVertex = OpTypeStruct %v4float %float %_arr_float_uint_1 %_arr_float_uint_1
+%_ptr_Output_gl_PerVertex = OpTypePointer Output %gl_PerVertex
+%_ = OpVariable %_ptr_Output_gl_PerVertex Output
+%int = OpTypeInt 32 1
+%int_0 = OpConstant %int 0
+%bfoo = OpTypeStruct %v4float
+%uint_11 = OpConstant %uint 11
+%_arr_bfoo_uint_11 = OpTypeArray %bfoo %uint_11
+%_ptr_StorageBuffer__arr_bfoo_uint_11 = OpTypePointer StorageBuffer %_arr_bfoo_uint_11
+%adds = OpVariable %_ptr_StorageBuffer__arr_bfoo_uint_11 StorageBuffer
+%ufoo = OpTypeStruct %uint
+%_ptr_Uniform_ufoo = OpTypePointer Uniform %ufoo
+%uniform_index_buffer = OpVariable %_ptr_Uniform_ufoo Uniform
+%_ptr_Uniform_uint = OpTypePointer Uniform %uint
+%_ptr_StorageBuffer_v4float = OpTypePointer StorageBuffer %v4float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+)";
+
+ const std::string defs_after =
+ R"(OpCapability Tessellation
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint TessellationEvaluation %main "main" %_ %gl_PrimitiveID %gl_TessCoord
+OpExecutionMode %main Triangles
+OpExecutionMode %main SpacingEqual
+OpExecutionMode %main VertexOrderCw
+OpSource GLSL 450
+OpSourceExtension "GL_EXT_nonuniform_qualifier"
+OpName %main "main"
+OpName %gl_PerVertex "gl_PerVertex"
+OpMemberName %gl_PerVertex 0 "gl_Position"
+OpMemberName %gl_PerVertex 1 "gl_PointSize"
+OpMemberName %gl_PerVertex 2 "gl_ClipDistance"
+OpMemberName %gl_PerVertex 3 "gl_CullDistance"
+OpName %_ ""
+OpName %bfoo "bfoo"
+OpMemberName %bfoo 0 "val"
+OpName %adds "adds"
+OpName %ufoo "ufoo"
+OpMemberName %ufoo 0 "index"
+OpName %uniform_index_buffer "uniform_index_buffer"
+OpMemberDecorate %gl_PerVertex 0 BuiltIn Position
+OpMemberDecorate %gl_PerVertex 1 BuiltIn PointSize
+OpMemberDecorate %gl_PerVertex 2 BuiltIn ClipDistance
+OpMemberDecorate %gl_PerVertex 3 BuiltIn CullDistance
+OpDecorate %gl_PerVertex Block
+OpMemberDecorate %bfoo 0 Offset 0
+OpDecorate %bfoo Block
+OpDecorate %adds DescriptorSet 0
+OpDecorate %adds Binding 1
+OpMemberDecorate %ufoo 0 Offset 0
+OpDecorate %ufoo Block
+OpDecorate %uniform_index_buffer DescriptorSet 0
+OpDecorate %uniform_index_buffer Binding 0
+OpDecorate %_runtimearr_uint ArrayStride 4
+OpDecorate %_struct_47 Block
+OpMemberDecorate %_struct_47 0 Offset 0
+OpMemberDecorate %_struct_47 1 Offset 4
+OpDecorate %49 DescriptorSet 7
+OpDecorate %49 Binding 0
+OpDecorate %gl_PrimitiveID BuiltIn PrimitiveId
+OpDecorate %gl_TessCoord BuiltIn TessCoord
+%void = OpTypeVoid
+%10 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%_arr_float_uint_1 = OpTypeArray %float %uint_1
+%gl_PerVertex = OpTypeStruct %v4float %float %_arr_float_uint_1 %_arr_float_uint_1
+%_ptr_Output_gl_PerVertex = OpTypePointer Output %gl_PerVertex
+%_ = OpVariable %_ptr_Output_gl_PerVertex Output
+%int = OpTypeInt 32 1
+%int_0 = OpConstant %int 0
+%bfoo = OpTypeStruct %v4float
+%uint_11 = OpConstant %uint 11
+%_arr_bfoo_uint_11 = OpTypeArray %bfoo %uint_11
+%_ptr_StorageBuffer__arr_bfoo_uint_11 = OpTypePointer StorageBuffer %_arr_bfoo_uint_11
+%adds = OpVariable %_ptr_StorageBuffer__arr_bfoo_uint_11 StorageBuffer
+%ufoo = OpTypeStruct %uint
+%_ptr_Uniform_ufoo = OpTypePointer Uniform %ufoo
+%uniform_index_buffer = OpVariable %_ptr_Uniform_ufoo Uniform
+%_ptr_Uniform_uint = OpTypePointer Uniform %uint
+%_ptr_StorageBuffer_v4float = OpTypePointer StorageBuffer %v4float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%uint_0 = OpConstant %uint 0
+%bool = OpTypeBool
+%40 = OpTypeFunction %void %uint %uint %uint %uint
+%_runtimearr_uint = OpTypeRuntimeArray %uint
+%_struct_47 = OpTypeStruct %uint %_runtimearr_uint
+%_ptr_StorageBuffer__struct_47 = OpTypePointer StorageBuffer %_struct_47
+%49 = OpVariable %_ptr_StorageBuffer__struct_47 StorageBuffer
+%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
+%uint_10 = OpConstant %uint 10
+%uint_4 = OpConstant %uint 4
+%uint_23 = OpConstant %uint 23
+%uint_2 = OpConstant %uint 2
+%uint_3 = OpConstant %uint 3
+%_ptr_Input_uint = OpTypePointer Input %uint
+%gl_PrimitiveID = OpVariable %_ptr_Input_uint Input
+%v3float = OpTypeVector %float 3
+%_ptr_Input_v3float = OpTypePointer Input %v3float
+%gl_TessCoord = OpVariable %_ptr_Input_v3float Input
+%v3uint = OpTypeVector %uint 3
+%uint_5 = OpConstant %uint 5
+%uint_6 = OpConstant %uint 6
+%uint_7 = OpConstant %uint 7
+%uint_8 = OpConstant %uint 8
+%uint_9 = OpConstant %uint 9
+%uint_63 = OpConstant %uint 63
+%101 = OpConstantNull %v4float
+)";
+
+ const std::string func_before =
+ R"(%main = OpFunction %void None %3
+%5 = OpLabel
+%25 = OpAccessChain %_ptr_Uniform_uint %uniform_index_buffer %int_0
+%26 = OpLoad %uint %25
+%28 = OpAccessChain %_ptr_StorageBuffer_v4float %adds %26 %int_0
+%29 = OpLoad %v4float %28
+%31 = OpAccessChain %_ptr_Output_v4float %_ %int_0
+OpStore %31 %29
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string func_after =
+ R"(%main = OpFunction %void None %10
+%26 = OpLabel
+%27 = OpAccessChain %_ptr_Uniform_uint %uniform_index_buffer %int_0
+%28 = OpLoad %uint %27
+%29 = OpAccessChain %_ptr_StorageBuffer_v4float %adds %28 %int_0
+%34 = OpULessThan %bool %28 %uint_11
+OpSelectionMerge %35 None
+OpBranchConditional %34 %36 %37
+%36 = OpLabel
+%38 = OpLoad %v4float %29
+OpBranch %35
+%37 = OpLabel
+%100 = OpFunctionCall %void %39 %uint_63 %uint_0 %28 %uint_11
+OpBranch %35
+%35 = OpLabel
+%102 = OpPhi %v4float %38 %36 %101 %37
+%31 = OpAccessChain %_ptr_Output_v4float %_ %int_0
+OpStore %31 %102
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string output_func =
+ R"(%39 = OpFunction %void None %40
+%41 = OpFunctionParameter %uint
+%42 = OpFunctionParameter %uint
+%43 = OpFunctionParameter %uint
+%44 = OpFunctionParameter %uint
+%45 = OpLabel
+%51 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_0
+%54 = OpAtomicIAdd %uint %51 %uint_4 %uint_0 %uint_10
+%55 = OpIAdd %uint %54 %uint_10
+%56 = OpArrayLength %uint %49 1
+%57 = OpULessThanEqual %bool %55 %56
+OpSelectionMerge %58 None
+OpBranchConditional %57 %59 %58
+%59 = OpLabel
+%60 = OpIAdd %uint %54 %uint_0
+%61 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %60
+OpStore %61 %uint_10
+%63 = OpIAdd %uint %54 %uint_1
+%64 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %63
+OpStore %64 %uint_23
+%66 = OpIAdd %uint %54 %uint_2
+%67 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %66
+OpStore %67 %41
+%69 = OpIAdd %uint %54 %uint_3
+%70 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %69
+OpStore %70 %uint_2
+%73 = OpLoad %uint %gl_PrimitiveID
+%74 = OpIAdd %uint %54 %uint_4
+%75 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %74
+OpStore %75 %73
+%79 = OpLoad %v3float %gl_TessCoord
+%81 = OpBitcast %v3uint %79
+%82 = OpCompositeExtract %uint %81 0
+%83 = OpCompositeExtract %uint %81 1
+%85 = OpIAdd %uint %54 %uint_5
+%86 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %85
+OpStore %86 %82
+%88 = OpIAdd %uint %54 %uint_6
+%89 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %88
+OpStore %89 %83
+%91 = OpIAdd %uint %54 %uint_7
+%92 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %91
+OpStore %92 %42
+%94 = OpIAdd %uint %54 %uint_8
+%95 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %94
+OpStore %95 %43
+%97 = OpIAdd %uint %54 %uint_9
+%98 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %97
+OpStore %98 %44
+OpBranch %58
+%58 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ // SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SinglePassRunAndCheck<InstBindlessCheckPass>(
+ defs_before + func_before, defs_after + func_after + output_func, true,
+ true, 7u, 23u, false, false, 2u);
+}
+
TEST_F(InstBindlessTest, MultipleDebugFunctionsV2) {
// Same source as Simple, but compiled -g and not optimized, especially not
// inlined. The OpSource has had the source extracted for the sake of brevity.
diff --git a/test/opt/inst_buff_addr_check_test.cpp b/test/opt/inst_buff_addr_check_test.cpp
new file mode 100644
index 0000000..c31266e
--- /dev/null
+++ b/test/opt/inst_buff_addr_check_test.cpp
@@ -0,0 +1,619 @@
+// Copyright (c) 2019 Valve Corporation
+// Copyright (c) 2019 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Buffer Device Address Check Instrumentation Tests.
+// Exercises InstBuffAddrCheckPass on physical storage buffer accesses.
+
+#include <string>
+#include <vector>
+
+#include "test/opt/assembly_builder.h"
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using InstBuffAddrTest = PassTest<::testing::Test>;
+
+TEST_F(InstBuffAddrTest, InstPhysicalStorageBufferStore) {
+ // #version 450
+ // #extension GL_EXT_buffer_reference : enable
+ //
+ // layout(buffer_reference, buffer_reference_align = 16) buffer bufStruct;
+ //
+ // layout(set = 0, binding = 0) uniform ufoo {
+ // bufStruct data;
+ // uint offset;
+ // } u_info;
+ //
+ // layout(buffer_reference, std140) buffer bufStruct {
+ // layout(offset = 0) int a[2];
+ // layout(offset = 32) int b;
+ // };
+ //
+ // void main() {
+ // u_info.data.b = 0xca7;
+ // }
+
+ const std::string defs_before =
+ R"(OpCapability Shader
+OpCapability PhysicalStorageBufferAddressesEXT
+OpExtension "SPV_EXT_physical_storage_buffer"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel PhysicalStorageBuffer64EXT GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpSource GLSL 450
+OpSourceExtension "GL_EXT_buffer_reference"
+OpName %main "main"
+OpName %ufoo "ufoo"
+OpMemberName %ufoo 0 "data"
+OpMemberName %ufoo 1 "offset"
+OpName %bufStruct "bufStruct"
+OpMemberName %bufStruct 0 "a"
+OpMemberName %bufStruct 1 "b"
+OpName %u_info "u_info"
+OpMemberDecorate %ufoo 0 Offset 0
+OpMemberDecorate %ufoo 1 Offset 8
+OpDecorate %ufoo Block
+OpDecorate %_arr_int_uint_2 ArrayStride 16
+OpMemberDecorate %bufStruct 0 Offset 0
+OpMemberDecorate %bufStruct 1 Offset 32
+OpDecorate %bufStruct Block
+OpDecorate %u_info DescriptorSet 0
+OpDecorate %u_info Binding 0
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+OpTypeForwardPointer %_ptr_PhysicalStorageBufferEXT_bufStruct PhysicalStorageBufferEXT
+%uint = OpTypeInt 32 0
+%ufoo = OpTypeStruct %_ptr_PhysicalStorageBufferEXT_bufStruct %uint
+%int = OpTypeInt 32 1
+%uint_2 = OpConstant %uint 2
+%_arr_int_uint_2 = OpTypeArray %int %uint_2
+%bufStruct = OpTypeStruct %_arr_int_uint_2 %int
+%_ptr_PhysicalStorageBufferEXT_bufStruct = OpTypePointer PhysicalStorageBufferEXT %bufStruct
+%_ptr_Uniform_ufoo = OpTypePointer Uniform %ufoo
+%u_info = OpVariable %_ptr_Uniform_ufoo Uniform
+%int_0 = OpConstant %int 0
+%_ptr_Uniform__ptr_PhysicalStorageBufferEXT_bufStruct = OpTypePointer Uniform %_ptr_PhysicalStorageBufferEXT_bufStruct
+%int_1 = OpConstant %int 1
+%int_3239 = OpConstant %int 3239
+%_ptr_PhysicalStorageBufferEXT_int = OpTypePointer PhysicalStorageBufferEXT %int
+)";
+
+ const std::string defs_after =
+ R"(OpCapability Shader
+OpCapability PhysicalStorageBufferAddressesEXT
+OpCapability Int64
+OpExtension "SPV_EXT_physical_storage_buffer"
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel PhysicalStorageBuffer64EXT GLSL450
+OpEntryPoint GLCompute %main "main" %gl_GlobalInvocationID
+OpExecutionMode %main LocalSize 1 1 1
+OpSource GLSL 450
+OpSourceExtension "GL_EXT_buffer_reference"
+OpName %main "main"
+OpName %ufoo "ufoo"
+OpMemberName %ufoo 0 "data"
+OpMemberName %ufoo 1 "offset"
+OpName %bufStruct "bufStruct"
+OpMemberName %bufStruct 0 "a"
+OpMemberName %bufStruct 1 "b"
+OpName %u_info "u_info"
+OpMemberDecorate %ufoo 0 Offset 0
+OpMemberDecorate %ufoo 1 Offset 8
+OpDecorate %ufoo Block
+OpDecorate %_arr_int_uint_2 ArrayStride 16
+OpMemberDecorate %bufStruct 0 Offset 0
+OpMemberDecorate %bufStruct 1 Offset 32
+OpDecorate %bufStruct Block
+OpDecorate %u_info DescriptorSet 0
+OpDecorate %u_info Binding 0
+OpDecorate %_runtimearr_ulong ArrayStride 8
+OpDecorate %_struct_39 Block
+OpMemberDecorate %_struct_39 0 Offset 0
+OpDecorate %41 DescriptorSet 7
+OpDecorate %41 Binding 2
+OpDecorate %_runtimearr_uint ArrayStride 4
+OpDecorate %_struct_77 Block
+OpMemberDecorate %_struct_77 0 Offset 0
+OpMemberDecorate %_struct_77 1 Offset 4
+OpDecorate %79 DescriptorSet 7
+OpDecorate %79 Binding 0
+OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+OpTypeForwardPointer %_ptr_PhysicalStorageBufferEXT_bufStruct PhysicalStorageBufferEXT
+%uint = OpTypeInt 32 0
+%ufoo = OpTypeStruct %_ptr_PhysicalStorageBufferEXT_bufStruct %uint
+%int = OpTypeInt 32 1
+%uint_2 = OpConstant %uint 2
+%_arr_int_uint_2 = OpTypeArray %int %uint_2
+%bufStruct = OpTypeStruct %_arr_int_uint_2 %int
+%_ptr_PhysicalStorageBufferEXT_bufStruct = OpTypePointer PhysicalStorageBufferEXT %bufStruct
+%_ptr_Uniform_ufoo = OpTypePointer Uniform %ufoo
+%u_info = OpVariable %_ptr_Uniform_ufoo Uniform
+%int_0 = OpConstant %int 0
+%_ptr_Uniform__ptr_PhysicalStorageBufferEXT_bufStruct = OpTypePointer Uniform %_ptr_PhysicalStorageBufferEXT_bufStruct
+%int_1 = OpConstant %int 1
+%int_3239 = OpConstant %int 3239
+%_ptr_PhysicalStorageBufferEXT_int = OpTypePointer PhysicalStorageBufferEXT %int
+%ulong = OpTypeInt 64 0
+%uint_4 = OpConstant %uint 4
+%bool = OpTypeBool
+%28 = OpTypeFunction %bool %ulong %uint
+%uint_1 = OpConstant %uint 1
+%_runtimearr_ulong = OpTypeRuntimeArray %ulong
+%_struct_39 = OpTypeStruct %_runtimearr_ulong
+%_ptr_StorageBuffer__struct_39 = OpTypePointer StorageBuffer %_struct_39
+%41 = OpVariable %_ptr_StorageBuffer__struct_39 StorageBuffer
+%_ptr_StorageBuffer_ulong = OpTypePointer StorageBuffer %ulong
+%uint_0 = OpConstant %uint 0
+%uint_32 = OpConstant %uint 32
+%70 = OpTypeFunction %void %uint %uint %uint %uint
+%_runtimearr_uint = OpTypeRuntimeArray %uint
+%_struct_77 = OpTypeStruct %uint %_runtimearr_uint
+%_ptr_StorageBuffer__struct_77 = OpTypePointer StorageBuffer %_struct_77
+%79 = OpVariable %_ptr_StorageBuffer__struct_77 StorageBuffer
+%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
+%uint_10 = OpConstant %uint 10
+%uint_23 = OpConstant %uint 23
+%uint_5 = OpConstant %uint 5
+%uint_3 = OpConstant %uint 3
+%v3uint = OpTypeVector %uint 3
+%_ptr_Input_v3uint = OpTypePointer Input %v3uint
+%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%uint_6 = OpConstant %uint 6
+%uint_7 = OpConstant %uint 7
+%uint_8 = OpConstant %uint 8
+%uint_9 = OpConstant %uint 9
+%uint_48 = OpConstant %uint 48
+)";
+
+ const std::string func_before =
+ R"(%main = OpFunction %void None %3
+%5 = OpLabel
+%17 = OpAccessChain %_ptr_Uniform__ptr_PhysicalStorageBufferEXT_bufStruct %u_info %int_0
+%18 = OpLoad %_ptr_PhysicalStorageBufferEXT_bufStruct %17
+%22 = OpAccessChain %_ptr_PhysicalStorageBufferEXT_int %18 %int_1
+OpStore %22 %int_3239 Aligned 16
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string func_after =
+ R"(%main = OpFunction %void None %8
+%19 = OpLabel
+%20 = OpAccessChain %_ptr_Uniform__ptr_PhysicalStorageBufferEXT_bufStruct %u_info %int_0
+%21 = OpLoad %_ptr_PhysicalStorageBufferEXT_bufStruct %20
+%22 = OpAccessChain %_ptr_PhysicalStorageBufferEXT_int %21 %int_1
+%24 = OpConvertPtrToU %ulong %22
+%61 = OpFunctionCall %bool %26 %24 %uint_4
+OpSelectionMerge %62 None
+OpBranchConditional %61 %63 %64
+%63 = OpLabel
+OpStore %22 %int_3239 Aligned 16
+OpBranch %62
+%64 = OpLabel
+%65 = OpUConvert %uint %24
+%67 = OpShiftRightLogical %ulong %24 %uint_32
+%68 = OpUConvert %uint %67
+%124 = OpFunctionCall %void %69 %uint_48 %uint_2 %65 %68
+OpBranch %62
+%62 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string new_funcs =
+ R"(%26 = OpFunction %bool None %28
+%29 = OpFunctionParameter %ulong
+%30 = OpFunctionParameter %uint
+%31 = OpLabel
+OpBranch %32
+%32 = OpLabel
+%34 = OpPhi %uint %uint_1 %31 %35 %33
+OpLoopMerge %37 %33 None
+OpBranch %33
+%33 = OpLabel
+%35 = OpIAdd %uint %34 %uint_1
+%44 = OpAccessChain %_ptr_StorageBuffer_ulong %41 %uint_0 %35
+%45 = OpLoad %ulong %44
+%46 = OpUGreaterThan %bool %45 %29
+OpBranchConditional %46 %37 %32
+%37 = OpLabel
+%47 = OpISub %uint %35 %uint_1
+%48 = OpAccessChain %_ptr_StorageBuffer_ulong %41 %uint_0 %47
+%49 = OpLoad %ulong %48
+%50 = OpISub %ulong %29 %49
+%51 = OpUConvert %ulong %30
+%52 = OpIAdd %ulong %50 %51
+%53 = OpAccessChain %_ptr_StorageBuffer_ulong %41 %uint_0 %uint_0
+%54 = OpLoad %ulong %53
+%55 = OpUConvert %uint %54
+%56 = OpISub %uint %47 %uint_1
+%57 = OpIAdd %uint %56 %55
+%58 = OpAccessChain %_ptr_StorageBuffer_ulong %41 %uint_0 %57
+%59 = OpLoad %ulong %58
+%60 = OpULessThanEqual %bool %52 %59
+OpReturnValue %60
+OpFunctionEnd
+%69 = OpFunction %void None %70
+%71 = OpFunctionParameter %uint
+%72 = OpFunctionParameter %uint
+%73 = OpFunctionParameter %uint
+%74 = OpFunctionParameter %uint
+%75 = OpLabel
+%81 = OpAccessChain %_ptr_StorageBuffer_uint %79 %uint_0
+%83 = OpAtomicIAdd %uint %81 %uint_4 %uint_0 %uint_10
+%84 = OpIAdd %uint %83 %uint_10
+%85 = OpArrayLength %uint %79 1
+%86 = OpULessThanEqual %bool %84 %85
+OpSelectionMerge %87 None
+OpBranchConditional %86 %88 %87
+%88 = OpLabel
+%89 = OpIAdd %uint %83 %uint_0
+%90 = OpAccessChain %_ptr_StorageBuffer_uint %79 %uint_1 %89
+OpStore %90 %uint_10
+%92 = OpIAdd %uint %83 %uint_1
+%93 = OpAccessChain %_ptr_StorageBuffer_uint %79 %uint_1 %92
+OpStore %93 %uint_23
+%94 = OpIAdd %uint %83 %uint_2
+%95 = OpAccessChain %_ptr_StorageBuffer_uint %79 %uint_1 %94
+OpStore %95 %71
+%98 = OpIAdd %uint %83 %uint_3
+%99 = OpAccessChain %_ptr_StorageBuffer_uint %79 %uint_1 %98
+OpStore %99 %uint_5
+%103 = OpLoad %v3uint %gl_GlobalInvocationID
+%104 = OpCompositeExtract %uint %103 0
+%105 = OpCompositeExtract %uint %103 1
+%106 = OpCompositeExtract %uint %103 2
+%107 = OpIAdd %uint %83 %uint_4
+%108 = OpAccessChain %_ptr_StorageBuffer_uint %79 %uint_1 %107
+OpStore %108 %104
+%109 = OpIAdd %uint %83 %uint_5
+%110 = OpAccessChain %_ptr_StorageBuffer_uint %79 %uint_1 %109
+OpStore %110 %105
+%112 = OpIAdd %uint %83 %uint_6
+%113 = OpAccessChain %_ptr_StorageBuffer_uint %79 %uint_1 %112
+OpStore %113 %106
+%115 = OpIAdd %uint %83 %uint_7
+%116 = OpAccessChain %_ptr_StorageBuffer_uint %79 %uint_1 %115
+OpStore %116 %72
+%118 = OpIAdd %uint %83 %uint_8
+%119 = OpAccessChain %_ptr_StorageBuffer_uint %79 %uint_1 %118
+OpStore %119 %73
+%121 = OpIAdd %uint %83 %uint_9
+%122 = OpAccessChain %_ptr_StorageBuffer_uint %79 %uint_1 %121
+OpStore %122 %74
+OpBranch %87
+%87 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ // SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SinglePassRunAndCheck<InstBuffAddrCheckPass>(
+ defs_before + func_before, defs_after + func_after + new_funcs, true,
+ true, 7u, 23u, 2u);
+}
+
+TEST_F(InstBuffAddrTest, InstPhysicalStorageBufferLoadAndStore) {
+ // #version 450
+ // #extension GL_EXT_buffer_reference : enable
+
+ // // forward reference
+ // layout(buffer_reference) buffer blockType;
+
+ // layout(buffer_reference, std430, buffer_reference_align = 16) buffer
+ // blockType {
+ // int x;
+ // blockType next;
+ // };
+
+ // layout(std430) buffer rootBlock {
+ // blockType root;
+ // } r;
+
+ // void main()
+ // {
+ // blockType b = r.root;
+ // b = b.next;
+ // b.x = 531;
+ // }
+
+ const std::string defs_before =
+ R"(OpCapability Shader
+OpCapability PhysicalStorageBufferAddressesEXT
+OpExtension "SPV_EXT_physical_storage_buffer"
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel PhysicalStorageBuffer64EXT GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpSource GLSL 450
+OpSourceExtension "GL_EXT_buffer_reference"
+OpName %main "main"
+OpName %blockType "blockType"
+OpMemberName %blockType 0 "x"
+OpMemberName %blockType 1 "next"
+OpName %rootBlock "rootBlock"
+OpMemberName %rootBlock 0 "root"
+OpName %r "r"
+OpMemberDecorate %blockType 0 Offset 0
+OpMemberDecorate %blockType 1 Offset 8
+OpDecorate %blockType Block
+OpMemberDecorate %rootBlock 0 Offset 0
+OpDecorate %rootBlock Block
+OpDecorate %r DescriptorSet 0
+OpDecorate %r Binding 0
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+OpTypeForwardPointer %_ptr_PhysicalStorageBufferEXT_blockType PhysicalStorageBufferEXT
+%int = OpTypeInt 32 1
+%blockType = OpTypeStruct %int %_ptr_PhysicalStorageBufferEXT_blockType
+%_ptr_PhysicalStorageBufferEXT_blockType = OpTypePointer PhysicalStorageBufferEXT %blockType
+%rootBlock = OpTypeStruct %_ptr_PhysicalStorageBufferEXT_blockType
+%_ptr_StorageBuffer_rootBlock = OpTypePointer StorageBuffer %rootBlock
+%r = OpVariable %_ptr_StorageBuffer_rootBlock StorageBuffer
+%int_0 = OpConstant %int 0
+%_ptr_StorageBuffer__ptr_PhysicalStorageBufferEXT_blockType = OpTypePointer StorageBuffer %_ptr_PhysicalStorageBufferEXT_blockType
+%int_1 = OpConstant %int 1
+%_ptr_PhysicalStorageBufferEXT__ptr_PhysicalStorageBufferEXT_blockType = OpTypePointer PhysicalStorageBufferEXT %_ptr_PhysicalStorageBufferEXT_blockType
+%int_531 = OpConstant %int 531
+%_ptr_PhysicalStorageBufferEXT_int = OpTypePointer PhysicalStorageBufferEXT %int
+)";
+
+ const std::string defs_after =
+ R"(OpCapability Shader
+OpCapability PhysicalStorageBufferAddressesEXT
+OpCapability Int64
+OpExtension "SPV_EXT_physical_storage_buffer"
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel PhysicalStorageBuffer64EXT GLSL450
+OpEntryPoint GLCompute %main "main" %gl_GlobalInvocationID
+OpExecutionMode %main LocalSize 1 1 1
+OpSource GLSL 450
+OpSourceExtension "GL_EXT_buffer_reference"
+OpName %main "main"
+OpName %blockType "blockType"
+OpMemberName %blockType 0 "x"
+OpMemberName %blockType 1 "next"
+OpName %rootBlock "rootBlock"
+OpMemberName %rootBlock 0 "root"
+OpName %r "r"
+OpMemberDecorate %blockType 0 Offset 0
+OpMemberDecorate %blockType 1 Offset 8
+OpDecorate %blockType Block
+OpMemberDecorate %rootBlock 0 Offset 0
+OpDecorate %rootBlock Block
+OpDecorate %r DescriptorSet 0
+OpDecorate %r Binding 0
+OpDecorate %_runtimearr_ulong ArrayStride 8
+OpDecorate %_struct_45 Block
+OpMemberDecorate %_struct_45 0 Offset 0
+OpDecorate %47 DescriptorSet 7
+OpDecorate %47 Binding 2
+OpDecorate %_runtimearr_uint ArrayStride 4
+OpDecorate %_struct_84 Block
+OpMemberDecorate %_struct_84 0 Offset 0
+OpMemberDecorate %_struct_84 1 Offset 4
+OpDecorate %86 DescriptorSet 7
+OpDecorate %86 Binding 0
+OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+OpTypeForwardPointer %_ptr_PhysicalStorageBufferEXT_blockType PhysicalStorageBufferEXT
+%int = OpTypeInt 32 1
+%blockType = OpTypeStruct %int %_ptr_PhysicalStorageBufferEXT_blockType
+%_ptr_PhysicalStorageBufferEXT_blockType = OpTypePointer PhysicalStorageBufferEXT %blockType
+%rootBlock = OpTypeStruct %_ptr_PhysicalStorageBufferEXT_blockType
+%_ptr_StorageBuffer_rootBlock = OpTypePointer StorageBuffer %rootBlock
+%r = OpVariable %_ptr_StorageBuffer_rootBlock StorageBuffer
+%int_0 = OpConstant %int 0
+%_ptr_StorageBuffer__ptr_PhysicalStorageBufferEXT_blockType = OpTypePointer StorageBuffer %_ptr_PhysicalStorageBufferEXT_blockType
+%int_1 = OpConstant %int 1
+%_ptr_PhysicalStorageBufferEXT__ptr_PhysicalStorageBufferEXT_blockType = OpTypePointer PhysicalStorageBufferEXT %_ptr_PhysicalStorageBufferEXT_blockType
+%int_531 = OpConstant %int 531
+%_ptr_PhysicalStorageBufferEXT_int = OpTypePointer PhysicalStorageBufferEXT %int
+%uint = OpTypeInt 32 0
+%uint_2 = OpConstant %uint 2
+%ulong = OpTypeInt 64 0
+%uint_8 = OpConstant %uint 8
+%bool = OpTypeBool
+%34 = OpTypeFunction %bool %ulong %uint
+%uint_1 = OpConstant %uint 1
+%_runtimearr_ulong = OpTypeRuntimeArray %ulong
+%_struct_45 = OpTypeStruct %_runtimearr_ulong
+%_ptr_StorageBuffer__struct_45 = OpTypePointer StorageBuffer %_struct_45
+%47 = OpVariable %_ptr_StorageBuffer__struct_45 StorageBuffer
+%_ptr_StorageBuffer_ulong = OpTypePointer StorageBuffer %ulong
+%uint_0 = OpConstant %uint 0
+%uint_32 = OpConstant %uint 32
+%77 = OpTypeFunction %void %uint %uint %uint %uint
+%_runtimearr_uint = OpTypeRuntimeArray %uint
+%_struct_84 = OpTypeStruct %uint %_runtimearr_uint
+%_ptr_StorageBuffer__struct_84 = OpTypePointer StorageBuffer %_struct_84
+%86 = OpVariable %_ptr_StorageBuffer__struct_84 StorageBuffer
+%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
+%uint_10 = OpConstant %uint 10
+%uint_4 = OpConstant %uint 4
+%uint_23 = OpConstant %uint 23
+%uint_5 = OpConstant %uint 5
+%uint_3 = OpConstant %uint 3
+%v3uint = OpTypeVector %uint 3
+%_ptr_Input_v3uint = OpTypePointer Input %v3uint
+%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%uint_6 = OpConstant %uint 6
+%uint_7 = OpConstant %uint 7
+%uint_9 = OpConstant %uint 9
+%uint_44 = OpConstant %uint 44
+%132 = OpConstantNull %_ptr_PhysicalStorageBufferEXT_blockType
+%uint_46 = OpConstant %uint 46
+)";
+
+ const std::string func_before =
+ R"(%main = OpFunction %void None %3
+%5 = OpLabel
+%16 = OpAccessChain %_ptr_StorageBuffer__ptr_PhysicalStorageBufferEXT_blockType %r %int_0
+%17 = OpLoad %_ptr_PhysicalStorageBufferEXT_blockType %16
+%21 = OpAccessChain %_ptr_PhysicalStorageBufferEXT__ptr_PhysicalStorageBufferEXT_blockType %17 %int_1
+%22 = OpLoad %_ptr_PhysicalStorageBufferEXT_blockType %21 Aligned 8
+%26 = OpAccessChain %_ptr_PhysicalStorageBufferEXT_int %22 %int_0
+OpStore %26 %int_531 Aligned 16
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string func_after =
+ R"(%main = OpFunction %void None %3
+%5 = OpLabel
+%16 = OpAccessChain %_ptr_StorageBuffer__ptr_PhysicalStorageBufferEXT_blockType %r %int_0
+%17 = OpLoad %_ptr_PhysicalStorageBufferEXT_blockType %16
+%21 = OpAccessChain %_ptr_PhysicalStorageBufferEXT__ptr_PhysicalStorageBufferEXT_blockType %17 %int_1
+%30 = OpConvertPtrToU %ulong %21
+%67 = OpFunctionCall %bool %32 %30 %uint_8
+OpSelectionMerge %68 None
+OpBranchConditional %67 %69 %70
+%69 = OpLabel
+%71 = OpLoad %_ptr_PhysicalStorageBufferEXT_blockType %21 Aligned 8
+OpBranch %68
+%70 = OpLabel
+%72 = OpUConvert %uint %30
+%74 = OpShiftRightLogical %ulong %30 %uint_32
+%75 = OpUConvert %uint %74
+%131 = OpFunctionCall %void %76 %uint_44 %uint_2 %72 %75
+OpBranch %68
+%68 = OpLabel
+%133 = OpPhi %_ptr_PhysicalStorageBufferEXT_blockType %71 %69 %132 %70
+%26 = OpAccessChain %_ptr_PhysicalStorageBufferEXT_int %133 %int_0
+%134 = OpConvertPtrToU %ulong %26
+%135 = OpFunctionCall %bool %32 %134 %uint_4
+OpSelectionMerge %136 None
+OpBranchConditional %135 %137 %138
+%137 = OpLabel
+OpStore %26 %int_531 Aligned 16
+OpBranch %136
+%138 = OpLabel
+%139 = OpUConvert %uint %134
+%140 = OpShiftRightLogical %ulong %134 %uint_32
+%141 = OpUConvert %uint %140
+%143 = OpFunctionCall %void %76 %uint_46 %uint_2 %139 %141
+OpBranch %136
+%136 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string new_funcs =
+ R"(%32 = OpFunction %bool None %34
+%35 = OpFunctionParameter %ulong
+%36 = OpFunctionParameter %uint
+%37 = OpLabel
+OpBranch %38
+%38 = OpLabel
+%40 = OpPhi %uint %uint_1 %37 %41 %39
+OpLoopMerge %43 %39 None
+OpBranch %39
+%39 = OpLabel
+%41 = OpIAdd %uint %40 %uint_1
+%50 = OpAccessChain %_ptr_StorageBuffer_ulong %47 %uint_0 %41
+%51 = OpLoad %ulong %50
+%52 = OpUGreaterThan %bool %51 %35
+OpBranchConditional %52 %43 %38
+%43 = OpLabel
+%53 = OpISub %uint %41 %uint_1
+%54 = OpAccessChain %_ptr_StorageBuffer_ulong %47 %uint_0 %53
+%55 = OpLoad %ulong %54
+%56 = OpISub %ulong %35 %55
+%57 = OpUConvert %ulong %36
+%58 = OpIAdd %ulong %56 %57
+%59 = OpAccessChain %_ptr_StorageBuffer_ulong %47 %uint_0 %uint_0
+%60 = OpLoad %ulong %59
+%61 = OpUConvert %uint %60
+%62 = OpISub %uint %53 %uint_1
+%63 = OpIAdd %uint %62 %61
+%64 = OpAccessChain %_ptr_StorageBuffer_ulong %47 %uint_0 %63
+%65 = OpLoad %ulong %64
+%66 = OpULessThanEqual %bool %58 %65
+OpReturnValue %66
+OpFunctionEnd
+%76 = OpFunction %void None %77
+%78 = OpFunctionParameter %uint
+%79 = OpFunctionParameter %uint
+%80 = OpFunctionParameter %uint
+%81 = OpFunctionParameter %uint
+%82 = OpLabel
+%88 = OpAccessChain %_ptr_StorageBuffer_uint %86 %uint_0
+%91 = OpAtomicIAdd %uint %88 %uint_4 %uint_0 %uint_10
+%92 = OpIAdd %uint %91 %uint_10
+%93 = OpArrayLength %uint %86 1
+%94 = OpULessThanEqual %bool %92 %93
+OpSelectionMerge %95 None
+OpBranchConditional %94 %96 %95
+%96 = OpLabel
+%97 = OpIAdd %uint %91 %uint_0
+%98 = OpAccessChain %_ptr_StorageBuffer_uint %86 %uint_1 %97
+OpStore %98 %uint_10
+%100 = OpIAdd %uint %91 %uint_1
+%101 = OpAccessChain %_ptr_StorageBuffer_uint %86 %uint_1 %100
+OpStore %101 %uint_23
+%102 = OpIAdd %uint %91 %uint_2
+%103 = OpAccessChain %_ptr_StorageBuffer_uint %86 %uint_1 %102
+OpStore %103 %78
+%106 = OpIAdd %uint %91 %uint_3
+%107 = OpAccessChain %_ptr_StorageBuffer_uint %86 %uint_1 %106
+OpStore %107 %uint_5
+%111 = OpLoad %v3uint %gl_GlobalInvocationID
+%112 = OpCompositeExtract %uint %111 0
+%113 = OpCompositeExtract %uint %111 1
+%114 = OpCompositeExtract %uint %111 2
+%115 = OpIAdd %uint %91 %uint_4
+%116 = OpAccessChain %_ptr_StorageBuffer_uint %86 %uint_1 %115
+OpStore %116 %112
+%117 = OpIAdd %uint %91 %uint_5
+%118 = OpAccessChain %_ptr_StorageBuffer_uint %86 %uint_1 %117
+OpStore %118 %113
+%120 = OpIAdd %uint %91 %uint_6
+%121 = OpAccessChain %_ptr_StorageBuffer_uint %86 %uint_1 %120
+OpStore %121 %114
+%123 = OpIAdd %uint %91 %uint_7
+%124 = OpAccessChain %_ptr_StorageBuffer_uint %86 %uint_1 %123
+OpStore %124 %79
+%125 = OpIAdd %uint %91 %uint_8
+%126 = OpAccessChain %_ptr_StorageBuffer_uint %86 %uint_1 %125
+OpStore %126 %80
+%128 = OpIAdd %uint %91 %uint_9
+%129 = OpAccessChain %_ptr_StorageBuffer_uint %86 %uint_1 %128
+OpStore %129 %81
+OpBranch %95
+%95 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SinglePassRunAndCheck<InstBuffAddrCheckPass>(
+ defs_before + func_before, defs_after + func_after + new_funcs, true,
+ true, 7u, 23u, 2u);
+}
+
+} // namespace
+} // namespace opt
+} // namespace spvtools
diff --git a/test/opt/pass_fixture.h b/test/opt/pass_fixture.h
index b7a0742..53fb206 100644
--- a/test/opt/pass_fixture.h
+++ b/test/opt/pass_fixture.h
@@ -75,7 +75,9 @@
const auto status = pass->Run(context());
std::vector<uint32_t> binary;
- context()->module()->ToBinary(&binary, skip_nop);
+ if (status != Pass::Status::Failure) {
+ context()->module()->ToBinary(&binary, skip_nop);
+ }
return std::make_tuple(binary, status);
}
@@ -188,6 +190,32 @@
<< disassembly;
}
+ // Runs a single pass of class |PassT| on the binary assembled from the
+ // |original| assembly. Check for failure and expect an Effcee matcher
+ // to pass when run on the diagnostic messages. This does *not* involve
+ // pass manager. Callers are suggested to use SCOPED_TRACE() for better
+ // messages.
+ template <typename PassT, typename... Args>
+ void SinglePassRunAndFail(const std::string& original, Args&&... args) {
+ context_ =
+ std::move(BuildModule(env_, consumer_, original, assemble_options_));
+ EXPECT_NE(nullptr, context()) << "Assembling failed for shader:\n"
+ << original << std::endl;
+ std::ostringstream errs;
+ auto error_consumer = [&errs](spv_message_level_t, const char*,
+ const spv_position_t&, const char* message) {
+ errs << message << std::endl;
+ };
+ auto pass = MakeUnique<PassT>(std::forward<Args>(args)...);
+ pass->SetMessageConsumer(error_consumer);
+ const auto status = pass->Run(context());
+ EXPECT_EQ(Pass::Status::Failure, status);
+ auto match_result = effcee::Match(errs.str(), original);
+ EXPECT_EQ(effcee::Result::Status::Ok, match_result.status())
+ << match_result.message() << "\nChecking messages:\n"
+ << errs.str();
+ }
+
// Adds a pass to be run.
template <typename PassT, typename... Args>
void AddPass(Args&&... args) {
@@ -215,15 +243,18 @@
context()->set_preserve_spec_constants(
OptimizerOptions()->preserve_spec_constants_);
- manager_->Run(context());
+ auto status = manager_->Run(context());
+ EXPECT_NE(status, Pass::Status::Failure);
- std::vector<uint32_t> binary;
- context()->module()->ToBinary(&binary, /* skip_nop = */ false);
+ if (status != Pass::Status::Failure) {
+ std::vector<uint32_t> binary;
+ context()->module()->ToBinary(&binary, /* skip_nop = */ false);
- std::string optimized;
- SpirvTools tools(env_);
- EXPECT_TRUE(tools.Disassemble(binary, &optimized, disassemble_options_));
- EXPECT_EQ(expected, optimized);
+ std::string optimized;
+ SpirvTools tools(env_);
+ EXPECT_TRUE(tools.Disassemble(binary, &optimized, disassemble_options_));
+ EXPECT_EQ(expected, optimized);
+ }
}
void SetAssembleOptions(uint32_t assemble_options) {
diff --git a/test/opt/private_to_local_test.cpp b/test/opt/private_to_local_test.cpp
index d154840..1230652 100644
--- a/test/opt/private_to_local_test.cpp
+++ b/test/opt/private_to_local_test.cpp
@@ -419,6 +419,39 @@
SinglePassRunAndMatch<PrivateToLocalPass>(text, true);
}
+TEST_F(PrivateToLocalTest, IdBoundOverflow1) {
+ const std::string text = R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginLowerLeft
+ OpSource HLSL 84
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeFloat 32
+ %7 = OpTypeVector %6 4
+ %8 = OpTypeStruct %7
+ %4194302 = OpTypeStruct %8 %8
+ %9 = OpTypeStruct %8 %8
+ %11 = OpTypePointer Private %7
+ %18 = OpTypeStruct %6 %9
+ %12 = OpVariable %11 Private
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %13 = OpLoad %7 %12
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+
+ std::vector<Message> messages = {
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}};
+ SetMessageConsumer(GetTestMessageConsumer(messages));
+ auto result = SinglePassRunToBinary<PrivateToLocalPass>(text, true);
+ EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
+}
+
} // namespace
} // namespace opt
} // namespace spvtools
diff --git a/test/opt/scalar_replacement_test.cpp b/test/opt/scalar_replacement_test.cpp
index 2ed7b5a..3cf46ca 100644
--- a/test/opt/scalar_replacement_test.cpp
+++ b/test/opt/scalar_replacement_test.cpp
@@ -1621,7 +1621,7 @@
}
// Test that id overflow is handled gracefully.
-TEST_F(ScalarReplacementTest, IdBoundOverflow) {
+TEST_F(ScalarReplacementTest, IdBoundOverflow1) {
const std::string text = R"(
OpCapability ImageQuery
OpMemoryModel Logical GLSL450
@@ -1652,8 +1652,103 @@
{SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
{SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}};
SetMessageConsumer(GetTestMessageConsumer(messages));
- auto result =
- SinglePassRunAndDisassemble<ScalarReplacementPass>(text, true, false);
+ auto result = SinglePassRunToBinary<ScalarReplacementPass>(text, true, false);
+ EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
+}
+
+// Test that id overflow is handled gracefully.
+TEST_F(ScalarReplacementTest, IdBoundOverflow2) {
+ const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main" %17
+OpExecutionMode %4 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%6 = OpTypeFloat 32
+%7 = OpTypeVector %6 4
+%8 = OpTypeStruct %7
+%9 = OpTypePointer Function %8
+%16 = OpTypePointer Output %7
+%21 = OpTypeInt 32 1
+%22 = OpConstant %21 0
+%23 = OpTypePointer Function %7
+%17 = OpVariable %16 Output
+%4 = OpFunction %2 None %3
+%5 = OpLabel
+%4194300 = OpVariable %23 Function
+%10 = OpVariable %9 Function
+%4194301 = OpAccessChain %23 %10 %22
+%4194302 = OpLoad %7 %4194301
+OpStore %4194300 %4194302
+%15 = OpLoad %7 %4194300
+OpStore %17 %15
+OpReturn
+OpFunctionEnd
+ )";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+
+ std::vector<Message> messages = {
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}};
+ SetMessageConsumer(GetTestMessageConsumer(messages));
+ auto result = SinglePassRunToBinary<ScalarReplacementPass>(text, true, false);
+ EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
+}
+
+// Test that id overflow is handled gracefully.
+TEST_F(ScalarReplacementTest, IdBoundOverflow3) {
+ const std::string text = R"(
+OpCapability InterpolationFunction
+OpExtension "z"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main"
+OpExecutionMode %4 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%6 = OpTypeFloat 32
+%7 = OpTypeStruct %6 %6
+%9 = OpTypePointer Function %7
+%18 = OpTypeFunction %7 %9
+%21 = OpTypeInt 32 0
+%22 = OpConstant %21 4293000676
+%4194302 = OpConstantNull %6
+%4 = OpFunction %2 Inline|Pure %3
+%786464 = OpLabel
+%4194298 = OpVariable %9 Function
+%10 = OpVariable %9 Function
+%4194299 = OpUDiv %21 %22 %22
+%4194300 = OpLoad %7 %10
+%50959 = OpLoad %7 %4194298
+OpKill
+OpFunctionEnd
+%1 = OpFunction %7 None %18
+%19 = OpFunctionParameter %9
+%147667 = OpLabel
+%2044391 = OpUDiv %21 %22 %22
+%25 = OpLoad %7 %19
+OpReturnValue %25
+OpFunctionEnd
+%4194295 = OpFunction %2 None %3
+%4194296 = OpLabel
+OpKill
+OpFunctionEnd
+ )";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+
+ std::vector<Message> messages = {
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}};
+ SetMessageConsumer(GetTestMessageConsumer(messages));
+ auto result = SinglePassRunToBinary<ScalarReplacementPass>(text, true, false);
EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
}
@@ -1701,6 +1796,109 @@
EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
}
+TEST_F(ScalarReplacementTest, CharIndex) {
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[ptr:%\w+]] = OpTypePointer Function [[int]]
+; CHECK: OpVariable [[ptr]] Function
+OpCapability Shader
+OpCapability Int8
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%int_1024 = OpConstant %int 1024
+%char = OpTypeInt 8 0
+%char_1 = OpConstant %char 1
+%array = OpTypeArray %int %int_1024
+%ptr_func_array = OpTypePointer Function %array
+%ptr_func_int = OpTypePointer Function %int
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+%var = OpVariable %ptr_func_array Function
+%gep = OpAccessChain %ptr_func_int %var %char_1
+OpStore %gep %int_1024
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<ScalarReplacementPass>(text, true, 0);
+}
+
+TEST_F(ScalarReplacementTest, OutOfBoundsOpAccessChainNegative) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Int8
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%int_1024 = OpConstant %int 1024
+%char = OpTypeInt 8 1
+%char_n1 = OpConstant %char -1
+%array = OpTypeArray %int %int_1024
+%ptr_func_array = OpTypePointer Function %array
+%ptr_func_int = OpTypePointer Function %int
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+%var = OpVariable %ptr_func_array Function
+%gep = OpAccessChain %ptr_func_int %var %char_n1
+OpStore %gep %int_1024
+OpReturn
+OpFunctionEnd
+)";
+
+ auto result =
+ SinglePassRunAndDisassemble<ScalarReplacementPass>(text, true, true, 0);
+ EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
+}
+
+TEST_F(ScalarReplacementTest, RelaxedPrecisionMemberDecoration) {
+ const std::string text = R"(
+; CHECK: OpDecorate {{%\w+}} RelaxedPrecision
+; CHECK: OpDecorate [[new_var:%\w+]] RelaxedPrecision
+; CHECK: [[new_var]] = OpVariable %_ptr_Function_v3float Function
+; CHECK: OpLoad %v3float [[new_var]]
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Vertex %1 "Draw2DTexCol_VS" %2 %3
+ OpSource HLSL 600
+ OpDecorate %2 Location 0
+ OpDecorate %3 Location 1
+ OpDecorate %3 RelaxedPrecision
+ OpMemberDecorate %_struct_4 1 RelaxedPrecision
+ %float = OpTypeFloat 32
+ %int = OpTypeInt 32 1
+ %int_1 = OpConstant %int 1
+ %v3float = OpTypeVector %float 3
+%_ptr_Input_v3float = OpTypePointer Input %v3float
+ %void = OpTypeVoid
+ %11 = OpTypeFunction %void
+ %_struct_4 = OpTypeStruct %v3float %v3float
+%_ptr_Function__struct_4 = OpTypePointer Function %_struct_4
+%_ptr_Function_v3float = OpTypePointer Function %v3float
+ %2 = OpVariable %_ptr_Input_v3float Input
+ %3 = OpVariable %_ptr_Input_v3float Input
+ %1 = OpFunction %void None %11
+ %14 = OpLabel
+ %15 = OpVariable %_ptr_Function__struct_4 Function
+ %16 = OpLoad %v3float %2
+ %17 = OpLoad %v3float %3
+ %18 = OpCompositeConstruct %_struct_4 %16 %17
+ OpStore %15 %18
+ %19 = OpAccessChain %_ptr_Function_v3float %15 %int_1
+ %20 = OpLoad %v3float %19
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<ScalarReplacementPass>(text, true);
+}
+
} // namespace
} // namespace opt
} // namespace spvtools
diff --git a/test/opt/simplification_test.cpp b/test/opt/simplification_test.cpp
index 4dbcfbe..1420498 100644
--- a/test/opt/simplification_test.cpp
+++ b/test/opt/simplification_test.cpp
@@ -279,6 +279,52 @@
SinglePassRunAndCheck<SimplificationPass>(before, after, false);
}
+TEST_F(SimplificationTest, DontMoveDecorations) {
+ const std::string spirv = R"(
+; CHECK-NOT: RelaxedPrecision
+; CHECK: [[sub:%\w+]] = OpFSub
+; CHECK: OpStore {{.*}} [[sub]]
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpDecorate %add RelaxedPrecision
+OpDecorate %block Block
+OpMemberDecorate %block 0 Offset 0
+OpMemberDecorate %block 1 Offset 4
+OpDecorate %in DescriptorSet 0
+OpDecorate %in Binding 0
+OpDecorate %out DescriptorSet 0
+OpDecorate %out Binding 1
+%void = OpTypeVoid
+%float = OpTypeFloat 32
+%void_fn = OpTypeFunction %void
+%block = OpTypeStruct %float %float
+%ptr_ssbo_block = OpTypePointer StorageBuffer %block
+%in = OpVariable %ptr_ssbo_block StorageBuffer
+%out = OpVariable %ptr_ssbo_block StorageBuffer
+%ptr_ssbo_float = OpTypePointer StorageBuffer %float
+%int = OpTypeInt 32 0
+%int_0 = OpConstant %int 0
+%int_1 = OpConstant %int 1
+%float_0 = OpConstant %float 0
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+%in_gep_0 = OpAccessChain %ptr_ssbo_float %in %int_0
+%in_gep_1 = OpAccessChain %ptr_ssbo_float %in %int_1
+%load_0 = OpLoad %float %in_gep_0
+%load_1 = OpLoad %float %in_gep_1
+%sub = OpFSub %float %load_0 %load_1
+%add = OpFAdd %float %float_0 %sub
+%out_gep_0 = OpAccessChain %ptr_ssbo_float %out %int_0
+OpStore %out_gep_0 %add
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<SimplificationPass>(spirv, true);
+}
+
} // namespace
} // namespace opt
} // namespace spvtools
diff --git a/test/opt/types_test.cpp b/test/opt/types_test.cpp
index fd98806..82e4040 100644
--- a/test/opt/types_test.cpp
+++ b/test/opt/types_test.cpp
@@ -65,17 +65,18 @@
#define TestMultipleInstancesOfTheSameType(ty, ...) \
TestMultipleInstancesOfTheSameTypeQualified(ty, Simple, __VA_ARGS__)
-TestMultipleInstancesOfTheSameType(Void);
-TestMultipleInstancesOfTheSameType(Bool);
-TestMultipleInstancesOfTheSameType(Integer, 32, true);
-TestMultipleInstancesOfTheSameType(Float, 64);
-TestMultipleInstancesOfTheSameType(Vector, u32_t_.get(), 3);
-TestMultipleInstancesOfTheSameType(Matrix, v3u32_t_.get(), 4);
+// clang-format off
+TestMultipleInstancesOfTheSameType(Void)
+TestMultipleInstancesOfTheSameType(Bool)
+TestMultipleInstancesOfTheSameType(Integer, 32, true)
+TestMultipleInstancesOfTheSameType(Float, 64)
+TestMultipleInstancesOfTheSameType(Vector, u32_t_.get(), 3)
+TestMultipleInstancesOfTheSameType(Matrix, v3u32_t_.get(), 4)
TestMultipleInstancesOfTheSameType(Image, f64_t_.get(), SpvDimCube, 0, 0, 1, 1,
SpvImageFormatRgb10A2,
- SpvAccessQualifierWriteOnly);
-TestMultipleInstancesOfTheSameType(Sampler);
-TestMultipleInstancesOfTheSameType(SampledImage, image_t_.get());
+ SpvAccessQualifierWriteOnly)
+TestMultipleInstancesOfTheSameType(Sampler)
+TestMultipleInstancesOfTheSameType(SampledImage, image_t_.get())
// There are three classes of arrays, based on the kinds of length information
// they have.
// 1. Array length is a constant or spec constant without spec ID, with literals
@@ -85,34 +86,35 @@
{
0,
9999,
- }});
+ }})
// 2. Array length is a spec constant with a given spec id.
TestMultipleInstancesOfTheSameTypeQualified(Array, LenSpecId, u32_t_.get(),
- Array::LengthInfo{42, {1, 99}});
+ Array::LengthInfo{42, {1, 99}})
// 3. Array length is an OpSpecConstantOp expression
TestMultipleInstancesOfTheSameTypeQualified(Array, LenDefiningId, u32_t_.get(),
- Array::LengthInfo{42, {2, 42}});
+ Array::LengthInfo{42, {2, 42}})
-TestMultipleInstancesOfTheSameType(RuntimeArray, u32_t_.get());
+TestMultipleInstancesOfTheSameType(RuntimeArray, u32_t_.get())
TestMultipleInstancesOfTheSameType(Struct, std::vector<const Type*>{
- u32_t_.get(), f64_t_.get()});
-TestMultipleInstancesOfTheSameType(Opaque, "testing rocks");
-TestMultipleInstancesOfTheSameType(Pointer, u32_t_.get(), SpvStorageClassInput);
+ u32_t_.get(), f64_t_.get()})
+TestMultipleInstancesOfTheSameType(Opaque, "testing rocks")
+TestMultipleInstancesOfTheSameType(Pointer, u32_t_.get(), SpvStorageClassInput)
TestMultipleInstancesOfTheSameType(Function, u32_t_.get(),
- {f64_t_.get(), f64_t_.get()});
-TestMultipleInstancesOfTheSameType(Event);
-TestMultipleInstancesOfTheSameType(DeviceEvent);
-TestMultipleInstancesOfTheSameType(ReserveId);
-TestMultipleInstancesOfTheSameType(Queue);
-TestMultipleInstancesOfTheSameType(Pipe, SpvAccessQualifierReadWrite);
-TestMultipleInstancesOfTheSameType(ForwardPointer, 10, SpvStorageClassUniform);
-TestMultipleInstancesOfTheSameType(PipeStorage);
-TestMultipleInstancesOfTheSameType(NamedBarrier);
-TestMultipleInstancesOfTheSameType(AccelerationStructureNV);
+ {f64_t_.get(), f64_t_.get()})
+TestMultipleInstancesOfTheSameType(Event)
+TestMultipleInstancesOfTheSameType(DeviceEvent)
+TestMultipleInstancesOfTheSameType(ReserveId)
+TestMultipleInstancesOfTheSameType(Queue)
+TestMultipleInstancesOfTheSameType(Pipe, SpvAccessQualifierReadWrite)
+TestMultipleInstancesOfTheSameType(ForwardPointer, 10, SpvStorageClassUniform)
+TestMultipleInstancesOfTheSameType(PipeStorage)
+TestMultipleInstancesOfTheSameType(NamedBarrier)
+TestMultipleInstancesOfTheSameType(AccelerationStructureNV)
#undef TestMultipleInstanceOfTheSameType
#undef TestMultipleInstanceOfTheSameTypeQual
std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
+ // clang-format on
// Types in this test case are only equal to themselves, nothing else.
std::vector<std::unique_ptr<Type>> types;
diff --git a/test/opt/wrap_opkill_test.cpp b/test/opt/wrap_opkill_test.cpp
new file mode 100644
index 0000000..df1b865
--- /dev/null
+++ b/test/opt/wrap_opkill_test.cpp
@@ -0,0 +1,267 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "test/opt/assembly_builder.h"
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using WrapOpKillTest = PassTest<::testing::Test>;
+
+TEST_F(WrapOpKillTest, SingleOpKill) {
+ const std::string text = R"(
+; CHECK: OpEntryPoint Fragment [[main:%\w+]]
+; CHECK: [[main]] = OpFunction
+; CHECK: OpFunctionCall %void [[orig_kill:%\w+]]
+; CHECK: [[orig_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]]
+; CHECK-NEXT: OpUnreachable
+; CHECK: [[new_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpKill
+; CHECK-NEXT: OpFunctionEnd
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource GLSL 330
+ OpName %main "main"
+ %void = OpTypeVoid
+ %5 = OpTypeFunction %void
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %5
+ %8 = OpLabel
+ OpBranch %9
+ %9 = OpLabel
+ OpLoopMerge %10 %11 None
+ OpBranch %12
+ %12 = OpLabel
+ OpBranchConditional %true %13 %10
+ %13 = OpLabel
+ OpBranch %11
+ %11 = OpLabel
+ %14 = OpFunctionCall %void %kill_
+ OpBranch %9
+ %10 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %kill_ = OpFunction %void None %5
+ %15 = OpLabel
+ OpKill
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<WrapOpKill>(text, true);
+}
+
+TEST_F(WrapOpKillTest, MultipleOpKillInSameFunc) {
+ const std::string text = R"(
+; CHECK: OpEntryPoint Fragment [[main:%\w+]]
+; CHECK: [[main]] = OpFunction
+; CHECK: OpFunctionCall %void [[orig_kill:%\w+]]
+; CHECK: [[orig_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpSelectionMerge
+; CHECK-NEXT: OpBranchConditional
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]]
+; CHECK-NEXT: OpUnreachable
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill]]
+; CHECK-NEXT: OpUnreachable
+; CHECK: [[new_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpKill
+; CHECK-NEXT: OpFunctionEnd
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource GLSL 330
+ OpName %main "main"
+ %void = OpTypeVoid
+ %5 = OpTypeFunction %void
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %5
+ %8 = OpLabel
+ OpBranch %9
+ %9 = OpLabel
+ OpLoopMerge %10 %11 None
+ OpBranch %12
+ %12 = OpLabel
+ OpBranchConditional %true %13 %10
+ %13 = OpLabel
+ OpBranch %11
+ %11 = OpLabel
+ %14 = OpFunctionCall %void %kill_
+ OpBranch %9
+ %10 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %kill_ = OpFunction %void None %5
+ %15 = OpLabel
+ OpSelectionMerge %16 None
+ OpBranchConditional %true %17 %18
+ %17 = OpLabel
+ OpKill
+ %18 = OpLabel
+ OpKill
+ %16 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<WrapOpKill>(text, true);
+}
+
+TEST_F(WrapOpKillTest, MultipleOpKillInDifferentFunc) {
+ const std::string text = R"(
+; CHECK: OpEntryPoint Fragment [[main:%\w+]]
+; CHECK: [[main]] = OpFunction
+; CHECK: OpFunctionCall %void [[orig_kill1:%\w+]]
+; CHECK-NEXT: OpFunctionCall %void [[orig_kill2:%\w+]]
+; CHECK: [[orig_kill1]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]]
+; CHECK-NEXT: OpUnreachable
+; CHECK: [[orig_kill2]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill]]
+; CHECK-NEXT: OpUnreachable
+; CHECK: [[new_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpKill
+; CHECK-NEXT: OpFunctionEnd
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource GLSL 330
+ OpName %main "main"
+ %void = OpTypeVoid
+ %4 = OpTypeFunction %void
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %4
+ %7 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ OpLoopMerge %9 %10 None
+ OpBranch %11
+ %11 = OpLabel
+ OpBranchConditional %true %12 %9
+ %12 = OpLabel
+ OpBranch %10
+ %10 = OpLabel
+ %13 = OpFunctionCall %void %14
+ %15 = OpFunctionCall %void %16
+ OpBranch %8
+ %9 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %14 = OpFunction %void None %4
+ %17 = OpLabel
+ OpKill
+ OpFunctionEnd
+ %16 = OpFunction %void None %4
+ %18 = OpLabel
+ OpKill
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<WrapOpKill>(text, true);
+}
+
+TEST_F(WrapOpKillTest, IdBoundOverflow1) {
+ const std::string text = R"(
+OpCapability GeometryStreams
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main"
+OpExecutionMode %4 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpFunction %2 Pure|Const %3
+%4194302 = OpLabel
+OpKill
+OpFunctionEnd
+ )";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+
+ std::vector<Message> messages = {
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}};
+ SetMessageConsumer(GetTestMessageConsumer(messages));
+ auto result = SinglePassRunToBinary<WrapOpKill>(text, true);
+ EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
+}
+
+TEST_F(WrapOpKillTest, IdBoundOverflow2) {
+ const std::string text = R"(
+OpCapability GeometryStreams
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main"
+OpExecutionMode %4 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpFunction %2 Pure|Const %3
+%4194301 = OpLabel
+OpKill
+OpFunctionEnd
+ )";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+
+ std::vector<Message> messages = {
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}};
+ SetMessageConsumer(GetTestMessageConsumer(messages));
+ auto result = SinglePassRunToBinary<WrapOpKill>(text, true);
+ EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
+}
+
+TEST_F(WrapOpKillTest, IdBoundOverflow3) {
+ const std::string text = R"(
+OpCapability GeometryStreams
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main"
+OpExecutionMode %4 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpFunction %2 Pure|Const %3
+%4194300 = OpLabel
+OpKill
+OpFunctionEnd
+ )";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+
+ std::vector<Message> messages = {
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}};
+ SetMessageConsumer(GetTestMessageConsumer(messages));
+ auto result = SinglePassRunToBinary<WrapOpKill>(text, true);
+ EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
+}
+
+} // namespace
+} // namespace opt
+} // namespace spvtools
diff --git a/test/reduce/CMakeLists.txt b/test/reduce/CMakeLists.txt
index 964abdd..2d3b378 100644
--- a/test/reduce/CMakeLists.txt
+++ b/test/reduce/CMakeLists.txt
@@ -24,6 +24,7 @@
remove_block_test.cpp
remove_function_test.cpp
remove_opname_instruction_test.cpp
+ remove_relaxed_precision_decoration_test.cpp
remove_selection_test.cpp
remove_unreferenced_instruction_test.cpp
structured_loop_to_selection_test.cpp
diff --git a/test/reduce/remove_relaxed_precision_decoration_test.cpp b/test/reduce/remove_relaxed_precision_decoration_test.cpp
new file mode 100644
index 0000000..f9ff081
--- /dev/null
+++ b/test/reduce/remove_relaxed_precision_decoration_test.cpp
@@ -0,0 +1,177 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/reduce/remove_relaxed_precision_decoration_opportunity_finder.h"
+
+#include "source/opt/build_module.h"
+#include "source/reduce/reduction_opportunity.h"
+#include "source/reduce/reduction_pass.h"
+#include "test/reduce/reduce_test_util.h"
+
+namespace spvtools {
+namespace reduce {
+namespace {
+
+TEST(RemoveRelaxedPrecisionDecorationTest, NothingToRemove) {
+ const std::string source = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_3;
+ const auto consumer = nullptr;
+ const auto context =
+ BuildModule(env, consumer, source, kReduceAssembleOption);
+ const auto ops = RemoveRelaxedPrecisionDecorationOpportunityFinder()
+ .GetAvailableOpportunities(context.get());
+ ASSERT_EQ(0, ops.size());
+}
+
+TEST(RemoveRelaxedPrecisionDecorationTest, RemoveDecorations) {
+ const std::string source = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ OpName %4 "main"
+ OpName %8 "f"
+ OpName %12 "i"
+ OpName %16 "v"
+ OpName %19 "S"
+ OpMemberName %19 0 "a"
+ OpMemberName %19 1 "b"
+ OpMemberName %19 2 "c"
+ OpName %21 "s"
+ OpDecorate %8 RelaxedPrecision
+ OpDecorate %12 RelaxedPrecision
+ OpDecorate %16 RelaxedPrecision
+ OpDecorate %17 RelaxedPrecision
+ OpDecorate %18 RelaxedPrecision
+ OpMemberDecorate %19 0 RelaxedPrecision
+ OpMemberDecorate %19 1 RelaxedPrecision
+ OpMemberDecorate %19 2 RelaxedPrecision
+ OpDecorate %22 RelaxedPrecision
+ OpDecorate %23 RelaxedPrecision
+ OpDecorate %24 RelaxedPrecision
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeFloat 32
+ %7 = OpTypePointer Function %6
+ %9 = OpConstant %6 2
+ %10 = OpTypeInt 32 1
+ %11 = OpTypePointer Function %10
+ %13 = OpConstant %10 22
+ %14 = OpTypeVector %6 2
+ %15 = OpTypePointer Function %14
+ %19 = OpTypeStruct %10 %6 %14
+ %20 = OpTypePointer Function %19
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %8 = OpVariable %7 Function
+ %12 = OpVariable %11 Function
+ %16 = OpVariable %15 Function
+ %21 = OpVariable %20 Function
+ OpStore %8 %9
+ OpStore %12 %13
+ %17 = OpLoad %6 %8
+ %18 = OpCompositeConstruct %14 %17 %17
+ OpStore %16 %18
+ %22 = OpLoad %10 %12
+ %23 = OpLoad %6 %8
+ %24 = OpLoad %14 %16
+ %25 = OpCompositeConstruct %19 %22 %23 %24
+ OpStore %21 %25
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_3;
+ const auto consumer = nullptr;
+ const auto context =
+ BuildModule(env, consumer, source, kReduceAssembleOption);
+ const auto ops = RemoveRelaxedPrecisionDecorationOpportunityFinder()
+ .GetAvailableOpportunities(context.get());
+ ASSERT_EQ(11, ops.size());
+
+ for (auto& op : ops) {
+ ASSERT_TRUE(op->PreconditionHolds());
+ op->TryToApply();
+ }
+
+ const std::string expected = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ OpName %4 "main"
+ OpName %8 "f"
+ OpName %12 "i"
+ OpName %16 "v"
+ OpName %19 "S"
+ OpMemberName %19 0 "a"
+ OpMemberName %19 1 "b"
+ OpMemberName %19 2 "c"
+ OpName %21 "s"
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeFloat 32
+ %7 = OpTypePointer Function %6
+ %9 = OpConstant %6 2
+ %10 = OpTypeInt 32 1
+ %11 = OpTypePointer Function %10
+ %13 = OpConstant %10 22
+ %14 = OpTypeVector %6 2
+ %15 = OpTypePointer Function %14
+ %19 = OpTypeStruct %10 %6 %14
+ %20 = OpTypePointer Function %19
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %8 = OpVariable %7 Function
+ %12 = OpVariable %11 Function
+ %16 = OpVariable %15 Function
+ %21 = OpVariable %20 Function
+ OpStore %8 %9
+ OpStore %12 %13
+ %17 = OpLoad %6 %8
+ %18 = OpCompositeConstruct %14 %17 %17
+ OpStore %16 %18
+ %22 = OpLoad %10 %12
+ %23 = OpLoad %6 %8
+ %24 = OpLoad %14 %16
+ %25 = OpCompositeConstruct %19 %22 %23 %24
+ OpStore %21 %25
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ CheckEqual(env, expected, context.get());
+}
+
+} // namespace
+} // namespace reduce
+} // namespace spvtools
diff --git a/test/text_to_binary.annotation_test.cpp b/test/text_to_binary.annotation_test.cpp
index 69a4861..61bdf64 100644
--- a/test/text_to_binary.annotation_test.cpp
+++ b/test/text_to_binary.annotation_test.cpp
@@ -21,6 +21,7 @@
#include <vector>
#include "gmock/gmock.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -29,7 +30,7 @@
using spvtest::EnumCase;
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using spvtest::TextToBinaryTest;
using ::testing::Combine;
using ::testing::Eq;
diff --git a/test/text_to_binary.debug_test.cpp b/test/text_to_binary.debug_test.cpp
index f9a4645..39ba5c5 100644
--- a/test/text_to_binary.debug_test.cpp
+++ b/test/text_to_binary.debug_test.cpp
@@ -19,6 +19,7 @@
#include <vector>
#include "gmock/gmock.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -26,7 +27,7 @@
namespace {
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using spvtest::TextToBinaryTest;
using ::testing::Eq;
diff --git a/test/text_to_binary.extension_test.cpp b/test/text_to_binary.extension_test.cpp
index 84552b5..9408e9a 100644
--- a/test/text_to_binary.extension_test.cpp
+++ b/test/text_to_binary.extension_test.cpp
@@ -22,6 +22,7 @@
#include "gmock/gmock.h"
#include "source/latest_version_glsl_std_450_header.h"
#include "source/latest_version_opencl_std_header.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -30,7 +31,7 @@
using spvtest::Concatenate;
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using spvtest::TextToBinaryTest;
using ::testing::Combine;
using ::testing::Eq;
diff --git a/test/text_to_binary.mode_setting_test.cpp b/test/text_to_binary.mode_setting_test.cpp
index d1b69dd..8ddf421 100644
--- a/test/text_to_binary.mode_setting_test.cpp
+++ b/test/text_to_binary.mode_setting_test.cpp
@@ -20,6 +20,7 @@
#include <vector>
#include "gmock/gmock.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -28,7 +29,7 @@
using spvtest::EnumCase;
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using ::testing::Combine;
using ::testing::Eq;
using ::testing::TestWithParam;
diff --git a/test/tools/opt/flags.py b/test/tools/opt/flags.py
index a89477c..49e2cab 100644
--- a/test/tools/opt/flags.py
+++ b/test/tools/opt/flags.py
@@ -57,7 +57,7 @@
"""Tests that spirv-opt accepts all valid optimization flags."""
flags = [
- '--ccp', '--cfg-cleanup', '--combine-access-chains', '--compact-ids',
+ '--wrap-opkill', '--ccp', '--cfg-cleanup', '--combine-access-chains', '--compact-ids',
'--convert-local-access-chains', '--copy-propagate-arrays',
'--eliminate-dead-branches',
'--eliminate-dead-code-aggressive', '--eliminate-dead-const',
@@ -76,6 +76,7 @@
'--unify-const'
]
expected_passes = [
+ 'wrap-opkill',
'ccp',
'cfg-cleanup',
'combine-access-chains',
@@ -134,6 +135,7 @@
flags = ['-O']
expected_passes = [
+ 'wrap-opkill',
'eliminate-dead-branches',
'merge-return',
'inline-entry-points-exhaustive',
@@ -181,6 +183,7 @@
flags = ['-Os']
expected_passes = [
+ 'wrap-opkill',
'eliminate-dead-branches',
'merge-return',
'inline-entry-points-exhaustive',
@@ -221,6 +224,7 @@
flags = ['--legalize-hlsl']
expected_passes = [
+ 'wrap-opkill',
'eliminate-dead-branches',
'merge-return',
'inline-entry-points-exhaustive',
diff --git a/test/unit_spirv.cpp b/test/unit_spirv.cpp
index 84ed87a..0854439 100644
--- a/test/unit_spirv.cpp
+++ b/test/unit_spirv.cpp
@@ -15,12 +15,13 @@
#include "test/unit_spirv.h"
#include "gmock/gmock.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
namespace spvtools {
namespace {
-using spvtest::MakeVector;
+using utils::MakeVector;
using ::testing::Eq;
using Words = std::vector<uint32_t>;
diff --git a/test/unit_spirv.h b/test/unit_spirv.h
index 2244288..3264662 100644
--- a/test/unit_spirv.h
+++ b/test/unit_spirv.h
@@ -133,29 +133,6 @@
return result;
}
-// Encodes a string as a sequence of words, using the SPIR-V encoding.
-inline std::vector<uint32_t> MakeVector(std::string input) {
- std::vector<uint32_t> result;
- uint32_t word = 0;
- size_t num_bytes = input.size();
- // SPIR-V strings are null-terminated. The byte_index == num_bytes
- // case is used to push the terminating null byte.
- for (size_t byte_index = 0; byte_index <= num_bytes; byte_index++) {
- const auto new_byte =
- (byte_index < num_bytes ? uint8_t(input[byte_index]) : uint8_t(0));
- word |= (new_byte << (8 * (byte_index % sizeof(uint32_t))));
- if (3 == (byte_index % sizeof(uint32_t))) {
- result.push_back(word);
- word = 0;
- }
- }
- // Emit a trailing partial word.
- if ((num_bytes + 1) % sizeof(uint32_t)) {
- result.push_back(word);
- }
- return result;
-}
-
// A type for easily creating spv_text_t values, with an implicit conversion to
// spv_text.
struct AutoText {
diff --git a/test/val/val_atomics_test.cpp b/test/val/val_atomics_test.cpp
index 15887eb..57a1187 100644
--- a/test/val/val_atomics_test.cpp
+++ b/test/val/val_atomics_test.cpp
@@ -377,34 +377,34 @@
TEST_F(ValidateAtomics, AtomicLoadWebGPUSuccess) {
const std::string body = R"(
%val1 = OpAtomicLoad %u32 %u32_var %queuefamily %relaxed
-%val2 = OpAtomicLoad %u32 %u32_var %workgroup %relaxed
)";
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0));
}
+TEST_F(ValidateAtomics, AtomicLoadWebGPUNonQueueFamilyFailure) {
+ const std::string body = R"(
+%val3 = OpAtomicLoad %u32 %u32_var %invocation %relaxed
+)";
+
+ CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Memory Scope is limited to QueueFamilyKHR for "
+ "OpAtomic* operations"));
+}
+
TEST_F(ValidateAtomics, AtomicLoadWebGPUNonRelaxedFailure) {
const std::string body = R"(
%val1 = OpAtomicLoad %u32 %u32_var %queuefamily %acquire
-%val2 = OpAtomicLoad %u32 %u32_var %workgroup %release
)";
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
EXPECT_THAT(getDiagnosticString(),
- HasSubstr("WebGPU spec disallows, for OpAtomic*, any bit masks"));
-}
-
-TEST_F(ValidateAtomics, AtomicLoadWebGPUSequentiallyConsistentFailure) {
- const std::string body = R"(
-%val3 = OpAtomicLoad %u32 %u32_var %invocation %sequentially_consistent
-)";
-
- CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
- ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
- EXPECT_THAT(getDiagnosticString(),
- HasSubstr("WebGPU spec disallows, for OpAtomic*, any bit masks"));
+ HasSubstr("no bits may be set for Memory Semantics of OpAtomic* "
+ "instructions"));
}
TEST_F(ValidateAtomics, VK_KHR_shader_atomic_int64Success) {
@@ -592,6 +592,17 @@
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0));
}
+TEST_F(ValidateAtomics, AtomicStoreWebGPUNonQueueFamilyFailure) {
+ const std::string body = R"(
+OpAtomicStore %u32_var %workgroup %relaxed %u32_1
+)";
+
+ CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Memory Scope is limited to QueueFamilyKHR for "
+ "OpAtomic* operations"));
+}
TEST_F(ValidateAtomics, AtomicStoreWebGPUNonRelaxedFailure) {
const std::string body = R"(
@@ -601,18 +612,8 @@
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
EXPECT_THAT(getDiagnosticString(),
- HasSubstr("WebGPU spec disallows, for OpAtomic*, any bit masks"));
-}
-
-TEST_F(ValidateAtomics, AtomicStoreWebGPUSequentiallyConsistent) {
- const std::string body = R"(
-OpAtomicStore %u32_var %queuefamily %sequentially_consistent %u32_1
-)";
-
- CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
- ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
- EXPECT_THAT(getDiagnosticString(),
- HasSubstr("WebGPU spec disallows, for OpAtomic*, any bit masks"));
+ HasSubstr("no bits may be set for Memory Semantics of OpAtomic* "
+ "instructions"));
}
TEST_F(ValidateAtomics, AtomicStoreWrongPointerType) {
@@ -1919,11 +1920,9 @@
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
- EXPECT_THAT(
- getDiagnosticString(),
- HasSubstr("AtomicLoad: in WebGPU environment Memory Scope is limited to "
- "Workgroup, Invocation, and QueueFamilyKHR\n"
- " %34 = OpAtomicLoad %uint %29 %uint_0_0 %uint_0_1\n"));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("in WebGPU environment Memory Scope is limited to "
+ "QueueFamilyKHR for OpAtomic* operations"));
}
TEST_F(ValidateAtomics, WebGPUDeviceMemoryScopeBad) {
@@ -1933,20 +1932,21 @@
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
- EXPECT_THAT(
- getDiagnosticString(),
- HasSubstr("AtomicLoad: in WebGPU environment Memory Scope is limited to "
- "Workgroup, Invocation, and QueueFamilyKHR\n"
- " %34 = OpAtomicLoad %uint %29 %uint_1_0 %uint_0_1\n"));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("in WebGPU environment Memory Scope is limited to "
+ "QueueFamilyKHR for OpAtomic* operations"));
}
-TEST_F(ValidateAtomics, WebGPUWorkgroupMemoryScopeGood) {
+TEST_F(ValidateAtomics, WebGPUWorkgroupMemoryScopeBad) {
const std::string body = R"(
%val1 = OpAtomicLoad %u32 %u32_var %workgroup %relaxed
)";
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
- EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("in WebGPU environment Memory Scope is limited to "
+ "QueueFamilyKHR for OpAtomic* operations"));
}
TEST_F(ValidateAtomics, WebGPUSubgroupMemoryScopeBad) {
@@ -1956,20 +1956,21 @@
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
- EXPECT_THAT(
- getDiagnosticString(),
- HasSubstr("AtomicLoad: in WebGPU environment Memory Scope is limited to "
- "Workgroup, Invocation, and QueueFamilyKHR\n"
- " %34 = OpAtomicLoad %uint %29 %uint_3 %uint_0_1\n"));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("in WebGPU environment Memory Scope is limited to "
+ "QueueFamilyKHR for OpAtomic* operations"));
}
-TEST_F(ValidateAtomics, WebGPUInvocationMemoryScopeGood) {
+TEST_F(ValidateAtomics, WebGPUInvocationMemoryScopeBad) {
const std::string body = R"(
%val1 = OpAtomicLoad %u32 %u32_var %invocation %relaxed
)";
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
- EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("in WebGPU environment Memory Scope is limited to "
+ "QueueFamilyKHR for OpAtomic* operations"));
}
TEST_F(ValidateAtomics, WebGPUQueueFamilyMemoryScopeGood) {
diff --git a/test/val/val_barriers_test.cpp b/test/val/val_barriers_test.cpp
index 2214197..18f57f8 100644
--- a/test/val/val_barriers_test.cpp
+++ b/test/val/val_barriers_test.cpp
@@ -82,9 +82,12 @@
%release_uniform_workgroup = OpConstant %u32 324
%acquire_and_release_uniform = OpConstant %u32 70
%acquire_release_subgroup = OpConstant %u32 136
+%acquire_release_workgroup = OpConstant %u32 264
%uniform = OpConstant %u32 64
%uniform_workgroup = OpConstant %u32 320
-
+%workgroup_memory = OpConstant %u32 256
+%image_memory = OpConstant %u32 2048
+%uniform_image_memory = OpConstant %u32 2112
%main = OpFunction %void None %func
%main_entry = OpLabel
@@ -251,7 +254,7 @@
TEST_F(ValidateBarriers, OpControlBarrierWebGPUAcquireReleaseSuccess) {
const std::string body = R"(
-OpControlBarrier %workgroup %workgroup %acquire_release_uniform_workgroup
+OpControlBarrier %workgroup %workgroup %acquire_release_workgroup
)";
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
@@ -260,25 +263,39 @@
TEST_F(ValidateBarriers, OpControlBarrierWebGPURelaxedFailure) {
const std::string body = R"(
-OpControlBarrier %workgroup %workgroup %uniform_workgroup
+OpControlBarrier %workgroup %workgroup %workgroup
)";
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
EXPECT_THAT(getDiagnosticString(),
- HasSubstr("WebGPU spec requires AcquireRelease to set"));
+ HasSubstr("For WebGPU, AcquireRelease must be set for Memory "
+ "Semantics of OpControlBarrier"));
}
-TEST_F(ValidateBarriers, OpControlBarrierWebGPUAcquireFailure) {
+TEST_F(ValidateBarriers, OpControlBarrierWebGPUMissingWorkgroupFailure) {
const std::string body = R"(
-OpControlBarrier %workgroup %workgroup %acquire_uniform_workgroup
+OpControlBarrier %workgroup %workgroup %acquire_release
+)";
+
+ CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("For WebGPU, WorkgroupMemory must be set for Memory "
+ "Semantics"));
+}
+
+TEST_F(ValidateBarriers, OpControlBarrierWebGPUUniformFailure) {
+ const std::string body = R"(
+OpControlBarrier %workgroup %workgroup %acquire_release_uniform_workgroup
)";
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
EXPECT_THAT(
getDiagnosticString(),
- HasSubstr("WebGPU spec disallows any bit masks in Memory Semantics"));
+ HasSubstr("For WebGPU only WorkgroupMemory and AcquireRelease may be set "
+ "for Memory Semantics of OpControlBarrier."));
}
TEST_F(ValidateBarriers, OpControlBarrierWebGPUReleaseFailure) {
@@ -288,9 +305,9 @@
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
- EXPECT_THAT(
- getDiagnosticString(),
- HasSubstr("WebGPU spec disallows any bit masks in Memory Semantics"));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("For WebGPU, AcquireRelease must be set for Memory "
+ "Semantics of OpControlBarrier"));
}
TEST_F(ValidateBarriers, OpControlBarrierExecutionModelFragmentSpirv12) {
@@ -461,6 +478,18 @@
"cannot be CrossDevice"));
}
+TEST_F(ValidateBarriers, OpControlBarrierWebGPUMemoryScopeNonWorkgroup) {
+ const std::string body = R"(
+OpControlBarrier %workgroup %subgroup %acquire_release_workgroup
+)";
+
+ CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("ControlBarrier: in WebGPU environment Memory Scope is "
+ "limited to Workgroup for OpControlBarrier"));
+}
+
TEST_F(ValidateBarriers, OpControlBarrierAcquireAndRelease) {
const std::string body = R"(
OpControlBarrier %device %device %acquire_and_release_uniform
@@ -680,13 +709,37 @@
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
}
-TEST_F(ValidateBarriers, OpMemoryBarrierWebGPUAcquireReleaseSuccess) {
+TEST_F(ValidateBarriers, OpMemoryBarrierWebGPUImageMemorySuccess) {
+ const std::string body = R"(
+OpMemoryBarrier %workgroup %image_memory
+)";
+
+ CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0));
+}
+
+TEST_F(ValidateBarriers, OpMemoryBarrierWebGPUDeviceFailure) {
+ const std::string body = R"(
+OpMemoryBarrier %subgroup %image_memory
+)";
+
+ CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("in WebGPU environment Memory Scope is limited to "
+ "Workgroup for OpMemoryBarrier"));
+}
+
+TEST_F(ValidateBarriers, OpMemoryBarrierWebGPUAcquireReleaseFailure) {
const std::string body = R"(
OpMemoryBarrier %workgroup %acquire_release_uniform_workgroup
)";
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
- ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("ImageMemory must be set for Memory Semantics of "
+ "OpMemoryBarrier"));
}
TEST_F(ValidateBarriers, OpMemoryBarrierWebGPURelaxedFailure) {
@@ -697,7 +750,8 @@
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
EXPECT_THAT(getDiagnosticString(),
- HasSubstr("WebGPU spec requires AcquireRelease to set"));
+ HasSubstr("ImageMemory must be set for Memory Semantics of "
+ "OpMemoryBarrier"));
}
TEST_F(ValidateBarriers, OpMemoryBarrierWebGPUAcquireFailure) {
@@ -707,9 +761,9 @@
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
- EXPECT_THAT(
- getDiagnosticString(),
- HasSubstr("WebGPU spec disallows any bit masks in Memory Semantics"));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("ImageMemory must be set for Memory Semantics of "
+ "OpMemoryBarrier"));
}
TEST_F(ValidateBarriers, OpMemoryBarrierWebGPUReleaseFailure) {
@@ -719,9 +773,21 @@
CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
- EXPECT_THAT(
- getDiagnosticString(),
- HasSubstr("WebGPU spec disallows any bit masks in Memory Semantics"));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("ImageMemory must be set for Memory Semantics of "
+ "OpMemoryBarrier"));
+}
+
+TEST_F(ValidateBarriers, OpMemoryBarrierWebGPUUniformFailure) {
+ const std::string body = R"(
+OpMemoryBarrier %workgroup %uniform_image_memory
+)";
+
+ CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0);
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("only ImageMemory may be set for Memory Semantics of "
+ "OpMemoryBarrier"));
}
TEST_F(ValidateBarriers, OpMemoryBarrierFloatMemoryScope) {
diff --git a/test/val/val_constants_test.cpp b/test/val/val_constants_test.cpp
index 72ce8df..2499f5c 100644
--- a/test/val/val_constants_test.cpp
+++ b/test/val/val_constants_test.cpp
@@ -442,6 +442,22 @@
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3));
}
+TEST_F(ValidateConstant, NullMatrix) {
+ std::string spirv = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%float = OpTypeFloat 32
+%v2float = OpTypeVector %float 2
+%mat2x2 = OpTypeMatrix %v2float 2
+%null_vector = OpConstantNull %v2float
+%null_matrix = OpConstantComposite %mat2x2 %null_vector %null_vector
+)";
+
+ CompileSuccessfully(spirv);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
} // namespace
} // namespace val
} // namespace spvtools
diff --git a/test/val/val_ext_inst_test.cpp b/test/val/val_ext_inst_test.cpp
index 73cb48f..67df43d 100644
--- a/test/val/val_ext_inst_test.cpp
+++ b/test/val/val_ext_inst_test.cpp
@@ -5315,10 +5315,10 @@
TEST_F(ValidateExtInst, OpenCLStdRemquoSuccess) {
const std::string body = R"(
-%var_f32 = OpVariable %f32_ptr_function Function
-%var_f32vec2 = OpVariable %f32vec2_ptr_function Function
-%val1 = OpExtInst %f32 %extinst remquo %f32_3 %f32_2 %var_f32
-%val2 = OpExtInst %f32vec2 %extinst remquo %f32vec2_01 %f32vec2_12 %var_f32vec2
+%var_u32 = OpVariable %u32_ptr_function Function
+%var_u32vec2 = OpVariable %u32vec2_ptr_function Function
+%val1 = OpExtInst %f32 %extinst remquo %f32_3 %f32_2 %var_u32
+%val2 = OpExtInst %f32vec2 %extinst remquo %f32vec2_01 %f32vec2_12 %var_u32vec2
)";
CompileSuccessfully(GenerateKernelCode(body));
@@ -5327,8 +5327,8 @@
TEST_F(ValidateExtInst, OpenCLStdRemquoIntResultType) {
const std::string body = R"(
-%var_f32 = OpVariable %f32_ptr_function Function
-%val1 = OpExtInst %u32 %extinst remquo %f32_3 %f32_2 %var_f32
+%var_u32 = OpVariable %u32_ptr_function Function
+%val1 = OpExtInst %u32 %extinst remquo %f32_3 %f32_2 %var_u32
)";
CompileSuccessfully(GenerateKernelCode(body));
@@ -5341,8 +5341,8 @@
TEST_F(ValidateExtInst, OpenCLStdRemquoXWrongType) {
const std::string body = R"(
-%var_f32 = OpVariable %f32_ptr_function Function
-%val1 = OpExtInst %f32 %extinst remquo %u32_3 %f32_2 %var_f32
+%var_u32 = OpVariable %f32_ptr_function Function
+%val1 = OpExtInst %f32 %extinst remquo %u32_3 %f32_2 %var_u32
)";
CompileSuccessfully(GenerateKernelCode(body));
@@ -5355,8 +5355,8 @@
TEST_F(ValidateExtInst, OpenCLStdRemquoYWrongType) {
const std::string body = R"(
-%var_f32 = OpVariable %f32_ptr_function Function
-%val1 = OpExtInst %f32 %extinst remquo %f32_3 %u32_2 %var_f32
+%var_u32 = OpVariable %f32_ptr_function Function
+%val1 = OpExtInst %f32 %extinst remquo %f32_3 %u32_2 %var_u32
)";
CompileSuccessfully(GenerateKernelCode(body));
@@ -5395,17 +5395,44 @@
TEST_F(ValidateExtInst, OpenCLStdRemquoPointerWrongDataType) {
const std::string body = R"(
-%var_u32 = OpVariable %u32_ptr_function Function
-%val1 = OpExtInst %f32 %extinst remquo %f32_3 %f32_2 %var_u32
+%var_f32 = OpVariable %f32_ptr_function Function
+%val1 = OpExtInst %f32 %extinst remquo %f32_3 %f32_2 %var_f32
+)";
+
+ CompileSuccessfully(GenerateKernelCode(body));
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("OpenCL.std remquo: "
+ "expected data type of the pointer to be a 32-bit int "
+ "scalar or vector type"));
+}
+
+TEST_F(ValidateExtInst, OpenCLStdRemquoPointerWrongDataTypeWidth) {
+ const std::string body = R"(
+%var_u64 = OpVariable %u64_ptr_function Function
+%val1 = OpExtInst %f32 %extinst remquo %f32_3 %f32_2 %var_u64
+)";
+ CompileSuccessfully(GenerateKernelCode(body));
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("OpenCL.std remquo: "
+ "expected data type of the pointer to be a 32-bit int "
+ "scalar or vector type"));
+}
+
+TEST_F(ValidateExtInst, OpenCLStdRemquoPointerWrongNumberOfComponents) {
+ const std::string body = R"(
+%var_u32vec2 = OpVariable %u32vec2_ptr_function Function
+%val1 = OpExtInst %f32 %extinst remquo %f32_3 %f32_2 %var_u32vec2
)";
CompileSuccessfully(GenerateKernelCode(body));
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
- HasSubstr(
- "OpenCL.std remquo: "
- "expected data type of the pointer to be equal to Result Type"));
+ HasSubstr("OpenCL.std remquo: "
+ "expected data type of the pointer to have the same number "
+ "of components as Result Type"));
}
TEST_P(ValidateOpenCLStdFrexpLike, Success) {
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp
index 299e38e..ec5715c 100644
--- a/test/val/val_id_test.cpp
+++ b/test/val/val_id_test.cpp
@@ -1715,7 +1715,7 @@
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("OpSpecConstantComposite Constituent <id> '7[%7]' is "
- "not a constant composite or undef."));
+ "not a constant or undef."));
}
// Invalid: Composite contains a column that is *not* a vector (it's an array)
diff --git a/tools/fuzz/fuzz.cpp b/tools/fuzz/fuzz.cpp
index 481942f..24e4ac6 100644
--- a/tools/fuzz/fuzz.cpp
+++ b/tools/fuzz/fuzz.cpp
@@ -74,12 +74,12 @@
USAGE: %s [options] <input.spv> -o <output.spv>
The SPIR-V binary is read from <input.spv>, which must have extension .spv. If
-<input.json> is also present, facts about the SPIR-V binary are read from this
+<input.facts> is also present, facts about the SPIR-V binary are read from this
file.
The transformed SPIR-V binary is written to <output.spv>. Human-readable and
-binary representations of the transformations that were applied to obtain this
-binary are written to <output.json> and <output.transformations>, respectively.
+binary representations of the transformations that were applied are written to
+<output.transformations_json> and <output.transformations>, respectively.
NOTE: The fuzzer is a work in progress.
@@ -472,7 +472,8 @@
return 1;
}
- std::ofstream transformations_json_file(output_file_prefix + ".json");
+ std::ofstream transformations_json_file(output_file_prefix +
+ ".transformations_json");
transformations_json_file << json_string;
transformations_json_file.close();
diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp
index cce1053..b229c84 100644
--- a/tools/opt/opt.cpp
+++ b/tools/opt/opt.cpp
@@ -109,6 +109,11 @@
Options (in lexicographical order):)",
program, program);
printf(R"(
+ --amd-ext-to-khr
+ Replaces the extensions VK_AMD_shader_ballot, VK_AMD_gcn_shader,
+               and VK_AMD_shader_trinary_minmax with equivalent code using core
+ instructions and capabilities.)");
+ printf(R"(
--ccp
Apply the conditional constant propagation transform. This will
propagate constant values throughout the program, and simplify
@@ -147,6 +152,15 @@
around known issues with some Vulkan drivers for initialize
variables.)");
printf(R"(
+ --descriptor-scalar-replacement
+ Replaces every array variable |desc| that has a DescriptorSet
+ and Binding decorations with a new variable for each element of
+ the array. Suppose |desc| was bound at binding |b|. Then the
+ variable corresponding to |desc[i]| will have binding |b+i|.
+ The descriptor set will be the same. All accesses to |desc|
+ must be in OpAccessChain instructions with a literal index for
+ the first index.)");
+ printf(R"(
--eliminate-dead-branches
Convert conditional branches with constant condition to the
                 indicated unconditional branch. Delete all resulting dead
@@ -206,6 +220,11 @@
Freeze the values of specialization constants to their default
values.)");
printf(R"(
+ --graphics-robust-access
+ Clamp indices used to access buffers and internal composite
+ values, providing guarantees that satisfy Vulkan's
+ robustBufferAccess rules.)");
+ printf(R"(
--generate-webgpu-initializers
Adds initial values to OpVariable instructions that are missing
them, due to their storage type requiring them for WebGPU.)");