Squashed 'third_party/SPIRV-Tools/' changes from fd773eb50d6..55af3902fc2 55af3902fc2 Fix function use (#3372) 9cb2571a184 spirv-val: allow DebugInfoNone for DebugTypeComposite.Size (#3374) 4386ef4234e Add validation support for ImageGatherBiasLodAMD (#3363) b0264b87ffb Fix validation failure on OpDecorationGroup (#3365) 4410272bdda Remove deprecated interfaces from instrument passes (#3361) 50b15578866 Preserve debug info in inline pass (#3349) 4dbe18b0c86 Reject folding comparisons with unfoldable types. (#3370) 55193b06e5e Improve build instructions for fuzzer (#3364) 3c47dac2820 Add unrolling to performance passes (#3082) 2b987c49a4e Handle OpConstantNull in ssa-rewrite (#3362) 95df4c9643c Add in a bunch of missed files to the BUILD.gn (#3360) 90930cb3115 Remove stale entries from BUILD.gn (#3358) 18ba3d9a353 allow cross compiling for Windows Store, UWP, etc. (#3330) 2f69ea849aa spirv-fuzz: Remove FuzzerPassAddUsefulConstructs (#3341) 522561619a9 Add support for StorageBuffer (#3348) b75dbf82a69 Prevent Effcee install his things when build spirv-tools with testing enabled (#3256) 85c7e7956bf Don't register edges twice in merge return (#3350) bd0a2da946c Revert "Revert "[spirv-opt] refactor inlining pass (#3328)" (#3342)" (#3345) 31182763704 spirv-reduce: Remove unused struct members (#3329) a6b0e132ecc Add adjust branch weights transformation (#3336) d4fac3451b7 Revert "[spirv-opt] refactor inlining pass (#3328)" (#3342) 233246bc9c5 [spirv-opt] refactor inlining pass (#3328) 2992386ebea spirv-reduce: Remove unused uniforms and similar (#3321) a9f2a145e65 spirv-fuzz: Fix to fact manager (#3339) 045a26e6e37 spirv-fuzz: Get rid of unnecessary template method (#3340) 63fa9114a93 Do merge return if the return is not at the end of the function. (#3337) c8590c18bd0 Preserve debug info for wrap-opkill (#3331) d2b48621949 Validate ShaderCallKHR memory scope (#3332) 2e1d208ed9d spirv-fuzz: Do not allow adding stores to read-only pointers (#3316) 54fb17b2d30 reduce: increase default step limit (#3327) 49842b88eec Generalize IsReadOnlyVariable() to apply to pointers (#3325) 49ca250b44c Delete nullptr in function bb list immedietly (#3326) d0a87194f7b Set DebugScope for termination instructions (#3323) f278b467dfd spirv-fuzz: Do not outline regions that end with a loop header (#3312) 23d68608b00 vscode: Handle '|' chains on BitEnum / ValueEnum (#3309) 42268740c95 Add debug information analysis (#3305) eed48ae479d Add spvtools::opt::Operand::AsLiteralUint64 (#3320) 94d6002dc53 spirv-fuzz: Pass on validator options during shrinking (#3317) 88faf63ad3c spirv-fuzz: Clamp statically out-of-bounds accesses in code donation (#3315) b74199a22d4 spirv-fuzz: Fix memory management in the fact manager (#3313) d158ffe5405 spirv-fuzz: Do not replace the Sample argument in OpImageTexelPointer (#3311) 5547553a0c7 Allow various validation options to be passed to spirv-opt (#3314) 30ffe62e257 typo fix: in README.md exectuable->executable (#3306) 67f4838659f spirv-fuzz: Make handling of synonym facts more efficient (#3301) 61b7de3c39f Remove unreachable code. (#3304) ed96301c6c4 spirv-fuzz: Fix to outliner (#3302) c018fc6ae66 spirv-fuzz: Do not outline regions that produce pointer outputs (#3291) f460cca9dca spirv-fuzz: Handle OpRuntimeArray when replacing ids with synonyms (#3292) 2f180468a71 spirv-fuzz: Handle image storage class in donation (#3290) f82d47003e7 spirv-fuzz: Respect rules for OpSampledImage (#3287) 7ce2db1763b spirv-fuzz: Fix comment. (#3300) 7d65bce0bbe Sampled images as read-only storage (#3295) 2a2bdbd5d72 Remove implicit fallthrough (#3298) 49566448944 Add tests for recently added command line option (#3297) ca5751590ed If SPIRV-Headers is in our tree, include it as subproject (#3299) e70d25f6fa5 Struct CFG analysus and single block loop (#3293) 000040e707a Preserve debug info in eliminate-dead-functions (#3251) c531099eb34 Update acorn version (#3294) 34be23373b9 Handle more cases in dead member elim (#3289) d0490ef080c Fix pch macro to ignore clang-cl (#3283) 538512e8e89 spirv-fuzz: Improve the handling of equation facts (#3281) 183e3242a36 spirv-fuzz: Handle more general SPIR-V in donation (#3280) 4af38c49bfe spirv-fuzz: Improve support for compute shaders in donation (#3277) e95fbfb1f50 spirv-fuzz: Transformation to add OpConstantNull (#3273) 5d491a7ed66 spirv-fuzz: Handle isomorphic types property in composite construction (#3262) bfd25ace084 spirv-fuzz: Limit adding of new variables to 'basic' types (#3257) f28cdeff16f spirv-fuzz: Only replace regular ids with synonyms (#3255) 8d4261bc440 spirv-fuzz: Introduce TransformationContext (#3272) 2fdea57d19d spirv-fuzz: Add validator options (#3254) af01d57b5e3 Update dominates to check for null nodes (#3271) f20c0d7971c Set wrapped kill basic block's parent (#3269) c37c94929bf Validate Buffer and BufferBlock apply only to struct types (#3259) git-subtree-dir: third_party/SPIRV-Tools git-subtree-split: 55af3902fc24db889b0ef8010a83efca04a6352c
diff --git a/Android.mk b/Android.mk index eec709a..5c495cd 100644 --- a/Android.mk +++ b/Android.mk
@@ -95,6 +95,7 @@ source/opt/dead_variable_elimination.cpp \ source/opt/decompose_initialized_variables_pass.cpp \ source/opt/decoration_manager.cpp \ + source/opt/debug_info_manager.cpp \ source/opt/def_use_manager.cpp \ source/opt/desc_sroa.cpp \ source/opt/dominator_analysis.cpp \
diff --git a/BUILD.gn b/BUILD.gn index d3107fd..fae7957 100644 --- a/BUILD.gn +++ b/BUILD.gn
@@ -542,6 +542,8 @@ "source/opt/decompose_initialized_variables_pass.h", "source/opt/decoration_manager.cpp", "source/opt/decoration_manager.h", + "source/opt/debug_info_manager.cpp", + "source/opt/debug_info_manager.h", "source/opt/def_use_manager.cpp", "source/opt/def_use_manager.h", "source/opt/desc_sroa.cpp", @@ -790,8 +792,12 @@ "source/reduce/remove_selection_reduction_opportunity.h", "source/reduce/remove_selection_reduction_opportunity_finder.cpp", "source/reduce/remove_selection_reduction_opportunity_finder.h", - "source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.cpp", - "source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.h", + "source/reduce/remove_struct_member_reduction_opportunity.cpp", + "source/reduce/remove_struct_member_reduction_opportunity.h", + "source/reduce/remove_unused_instruction_reduction_opportunity_finder.cpp", + "source/reduce/remove_unused_instruction_reduction_opportunity_finder.h", + "source/reduce/remove_unused_struct_member_reduction_opportunity_finder.cpp", + "source/reduce/remove_unused_struct_member_reduction_opportunity_finder.h", "source/reduce/simple_conditional_branch_to_branch_opportunity_finder.cpp", "source/reduce/simple_conditional_branch_to_branch_opportunity_finder.h", "source/reduce/simple_conditional_branch_to_branch_reduction_opportunity.cpp",
diff --git a/CMakeLists.txt b/CMakeLists.txt index ef9ad11..73248f2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt
@@ -40,7 +40,7 @@ set(SPIRV_TIMER_ENABLED ${SPIRV_ALLOW_TIMERS}) elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Emscripten") add_definitions(-DSPIRV_EMSCRIPTEN) -elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Windows") +elseif("${CMAKE_SYSTEM_NAME}" MATCHES "Windows") add_definitions(-DSPIRV_WINDOWS) elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "CYGWIN") add_definitions(-DSPIRV_WINDOWS) @@ -249,7 +249,7 @@ # Precompiled header macro. Parameters are source file list and filename for pch cpp file. macro(spvtools_pch SRCS PCHPREFIX) - if(MSVC AND CMAKE_GENERATOR MATCHES "^Visual Studio") + if(MSVC AND CMAKE_GENERATOR MATCHES "^Visual Studio" AND NOT "${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") set(PCH_NAME "$(IntDir)\\${PCHPREFIX}.pch") # make source files use/depend on PCH_NAME set_source_files_properties(${${SRCS}} PROPERTIES COMPILE_FLAGS "/Yu${PCHPREFIX}.h /FI${PCHPREFIX}.h /Fp${PCH_NAME} /Zm300" OBJECT_DEPENDS "${PCH_NAME}")
diff --git a/README.md b/README.md index 5714976..c82ca19 100644 --- a/README.md +++ b/README.md
@@ -324,8 +324,7 @@ contents of the repos under `external/` instead of manually maintaining them. ### Build using CMake -You can build The project using [CMake][cmake] to generate platform-specific -build configurations. +You can build the project using [CMake][cmake]: ```sh cd <spirv-dir> @@ -333,8 +332,36 @@ cmake [-G <platform-generator>] <spirv-dir> ``` -Once the build files have been generated, build using your preferred -development environment. +Once the build files have been generated, build using the appropriate build +command (e.g. `ninja`, `make`, `msbuild`, etc.; this depends on the platform +generator used above), or use your IDE, or use CMake to run the appropriate build +command for you: + +```sh +cmake --build . [--config Debug] # runs `make` or `ninja` or `msbuild` etc. +``` + +#### Note about the fuzzer + +The SPIR-V fuzzer, `spirv-fuzz`, can only be built via CMake, and is disabled by +default. To build it, clone protobuf and use the `SPIRV_BUILD_FUZZER` CMake +option, like so: + +```sh +# In <spirv-dir> (the SPIRV-Tools repo root): +git clone https://github.com/protocolbuffers/protobuf external/protobuf +pushd external/protobuf +git checkout v3.7.1 +popd + +# In your build directory: +cmake [-G <platform-generator>] <spirv-dir> -DSPIRV_BUILD_FUZZER=ON +cmake --build . --config Debug +``` + +You can also add `-DSPIRV_ENABLE_LONG_FUZZER_TESTS=ON` to build additional +fuzzer tests. + ### Build using Bazel You can also use [Bazel](https://bazel.build/) to build the project. @@ -480,7 +507,7 @@ The assembler reads the assembly language text, and emits the binary form. -The standalone assembler is the exectuable called `spirv-as`, and is located in +The standalone assembler is the executable called `spirv-as`, and is located in `<spirv-build-dir>/tools/spirv-as`. The functionality of the assembler is implemented by the `spvTextToBinary` library function.
diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 8bde13c..56dd54f 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt
@@ -25,7 +25,17 @@ endif() if (IS_DIRECTORY ${SPIRV_HEADER_DIR}) + # TODO(dneto): We should not be modifying the parent scope. set(SPIRV_HEADER_INCLUDE_DIR ${SPIRV_HEADER_DIR}/include PARENT_SCOPE) + + # Add SPIRV-Headers as a sub-project if it isn't already defined. + # Do this so enclosing projects can use SPIRV-Headers_SOURCE_DIR to find + # headers to include. + if (NOT DEFINED SPIRV-Headers_SOURCE_DIR) + set(SPIRV_HEADERS_SKIP_INSTALL ON) + set(SPIRV_HEADERS_SKIP_EXAMPLES ON) + add_subdirectory(${SPIRV_HEADER_DIR}) + endif() else() message(FATAL_ERROR "SPIRV-Headers was not found - please checkout a copy under external/.") @@ -78,7 +88,7 @@ set(RE2_BUILD_TESTING OFF CACHE STRING "Run RE2 Tests") if (NOT RE2_SOURCE_DIR) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/re2) - set(RE2_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/re2" CACHE STRING "RE2 source dir" ) + set(RE2_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/re2" CACHE STRING "RE2 source dir" ) endif() endif() endif() @@ -88,13 +98,17 @@ if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/effcee) # If we're configuring RE2 (via Effcee), then turn off RE2 testing. if (NOT TARGET re2) - set(RE2_BUILD_TESTING OFF) + set(RE2_BUILD_TESTING OFF) endif() if (MSVC) - # SPIRV-Tools uses the shared CRT with MSVC. Tell Effcee to do the same. - set(EFFCEE_ENABLE_SHARED_CRT ON) + # SPIRV-Tools uses the shared CRT with MSVC. Tell Effcee to do the same. + set(EFFCEE_ENABLE_SHARED_CRT ON) endif() - add_subdirectory(effcee) + set(EFFCEE_BUILD_SAMPLES OFF CACHE BOOL "Do not build Effcee examples") + if (NOT TARGET effcee) + set(EFFCEE_BUILD_TESTING OFF CACHE BOOL "Do not build Effcee test suite") + endif() + add_subdirectory(effcee EXCLUDE_FROM_ALL) set_property(TARGET effcee PROPERTY FOLDER Effcee) # Turn off warnings for effcee and re2 set_property(TARGET effcee APPEND PROPERTY COMPILE_OPTIONS -w)
diff --git a/include/spirv-tools/instrument.hpp b/include/spirv-tools/instrument.hpp index d3180e4..ef5136a 100644 --- a/include/spirv-tools/instrument.hpp +++ b/include/spirv-tools/instrument.hpp
@@ -35,10 +35,6 @@ // generated by InstrumentPass::GenDebugStreamWrite. This method is utilized // by InstBindlessCheckPass. // -// kInst2* values support version 2 of the output record format and were used -// for the transition to this format. These values have now been transferred -// to the original kInst* values. The kInst2* values are therefore DEPRECATED. -// // The first member of the debug output buffer contains the next available word // in the data stream to be written. Shaders will atomically read and update // this value so as not to overwrite each others records. This value must be @@ -94,10 +90,6 @@ static const int kInstCompOutGlobalInvocationIdY = kInstCommonOutCnt + 1; static const int kInstCompOutGlobalInvocationIdZ = kInstCommonOutCnt + 2; -// Compute Shader Output Record Offsets - Version 1 (DEPRECATED) -static const int kInstCompOutGlobalInvocationId = kInstCommonOutCnt; -static const int kInstCompOutUnused = kInstCommonOutCnt + 1; - // Tessellation Control Shader Output Record Offsets static const int kInstTessCtlOutInvocationId = kInstCommonOutCnt; static const int kInstTessCtlOutPrimitiveId = kInstCommonOutCnt + 1; @@ -108,10 +100,6 @@ static const int kInstTessEvalOutTessCoordU = kInstCommonOutCnt + 1; static const int kInstTessEvalOutTessCoordV = kInstCommonOutCnt + 2; -// Tessellation Shader Output Record Offsets - Version 1 (DEPRECATED) -static const int kInstTessOutInvocationId = kInstCommonOutCnt; -static const int kInstTessOutUnused = kInstCommonOutCnt + 1; - // Geometry Shader Output Record Offsets static const int kInstGeomOutPrimitiveId = kInstCommonOutCnt; static const int kInstGeomOutInvocationId = kInstCommonOutCnt + 1; @@ -124,14 +112,12 @@ // Size of Common and Stage-specific Members static const int kInstStageOutCnt = kInstCommonOutCnt + 3; -static const int kInst2StageOutCnt = kInstCommonOutCnt + 3; // Validation Error Code Offset // // This identifies the validation error. It also helps to identify // how many words follow in the record and their meaning. static const int kInstValidationOutError = kInstStageOutCnt; -static const int kInst2ValidationOutError = kInst2StageOutCnt; // Validation-specific Output Record Offsets // @@ -144,37 +130,19 @@ static const int kInstBindlessBoundsOutDescBound = kInstStageOutCnt + 2; static const int kInstBindlessBoundsOutCnt = kInstStageOutCnt + 3; -static const int kInst2BindlessBoundsOutDescIndex = kInst2StageOutCnt + 1; -static const int kInst2BindlessBoundsOutDescBound = kInst2StageOutCnt + 2; -static const int kInst2BindlessBoundsOutCnt = kInst2StageOutCnt + 3; - // A bindless uninitialized error will output the index. static const int kInstBindlessUninitOutDescIndex = kInstStageOutCnt + 1; static const int kInstBindlessUninitOutUnused = kInstStageOutCnt + 2; static const int kInstBindlessUninitOutCnt = kInstStageOutCnt + 3; -static const int kInst2BindlessUninitOutDescIndex = kInst2StageOutCnt + 1; -static const int kInst2BindlessUninitOutUnused = kInst2StageOutCnt + 2; -static const int kInst2BindlessUninitOutCnt = kInst2StageOutCnt + 3; - // A buffer address unalloc error will output the 64-bit pointer in // two 32-bit pieces, lower bits first. static const int kInstBuffAddrUnallocOutDescPtrLo = kInstStageOutCnt + 1; static const int kInstBuffAddrUnallocOutDescPtrHi = kInstStageOutCnt + 2; static const int kInstBuffAddrUnallocOutCnt = kInstStageOutCnt + 3; -static const int kInst2BuffAddrUnallocOutDescPtrLo = kInst2StageOutCnt + 1; -static const int kInst2BuffAddrUnallocOutDescPtrHi = kInst2StageOutCnt + 2; -static const int kInst2BuffAddrUnallocOutCnt = kInst2StageOutCnt + 3; - -// DEPRECATED -static const int kInstBindlessOutDescIndex = kInstStageOutCnt + 1; -static const int kInstBindlessOutDescBound = kInstStageOutCnt + 2; -static const int kInstBindlessOutCnt = kInstStageOutCnt + 3; - // Maximum Output Record Member Count static const int kInstMaxOutCnt = kInstStageOutCnt + 3; -static const int kInst2MaxOutCnt = kInst2StageOutCnt + 3; // Validation Error Codes // @@ -223,9 +191,6 @@ // Data[ i + Data[ b + Data[ s + Data[ kDebugInputBindlessInitOffset ] ] ] ] static const int kDebugInputBindlessInitOffset = 0; -// DEPRECATED -static const int kDebugInputBindlessOffsetReserved = 0; - // At offset kDebugInputBindlessOffsetLengths is some number of uints which // provide the bindless length data. More specifically, the number of // descriptors at (set=s, binding=b) is:
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp index b904923..d393495 100644 --- a/include/spirv-tools/optimizer.hpp +++ b/include/spirv-tools/optimizer.hpp
@@ -762,10 +762,9 @@ // |input_length_enable| controls instrumentation of runtime descriptor array // references, and |input_init_enable| controls instrumentation of descriptor // initialization checking, both of which require input buffer support. -// |version| specifies the buffer record format. Optimizer::PassToken CreateInstBindlessCheckPass( uint32_t desc_set, uint32_t shader_id, bool input_length_enable = false, - bool input_init_enable = false, uint32_t version = 2); + bool input_init_enable = false); // Create a pass to instrument physical buffer address checking // This pass instruments all physical buffer address references to check that @@ -786,10 +785,8 @@ // The instrumentation will read and write buffers in debug // descriptor set |desc_set|. It will write |shader_id| in each output record // to identify the shader module which generated the record. -// |version| specifies the output buffer record format. Optimizer::PassToken CreateInstBuffAddrCheckPass(uint32_t desc_set, - uint32_t shader_id, - uint32_t version = 2); + uint32_t shader_id); // Create a pass to instrument OpDebugPrintf instructions. // This pass replaces all OpDebugPrintf instructions with instructions to write
diff --git a/kokoro/scripts/linux/build.sh b/kokoro/scripts/linux/build.sh index f7b0fe1..8fb7bdd 100644 --- a/kokoro/scripts/linux/build.sh +++ b/kokoro/scripts/linux/build.sh
@@ -46,7 +46,7 @@ ADDITIONAL_CMAKE_FLAGS="" if [ $CONFIG = "ASAN" ] then - ADDITIONAL_CMAKE_FLAGS="SPIRV_USE_SANITIZER=address" + ADDITIONAL_CMAKE_FLAGS="SPIRV_USE_SANITIZER=address,bounds,null" [ $COMPILER = "clang" ] || { echo "$CONFIG requires clang"; exit 1; } elif [ $CONFIG = "COVERAGE" ] then
diff --git a/source/fuzz/CMakeLists.txt b/source/fuzz/CMakeLists.txt index 3a9d604..6582927 100644 --- a/source/fuzz/CMakeLists.txt +++ b/source/fuzz/CMakeLists.txt
@@ -49,7 +49,7 @@ fuzzer_pass_add_local_variables.h fuzzer_pass_add_no_contraction_decorations.h fuzzer_pass_add_stores.h - fuzzer_pass_add_useful_constructs.h + fuzzer_pass_adjust_branch_weights.h fuzzer_pass_adjust_function_controls.h fuzzer_pass_adjust_loop_controls.h fuzzer_pass_adjust_memory_operands_masks.h @@ -79,6 +79,7 @@ transformation_access_chain.h transformation_add_constant_boolean.h transformation_add_constant_composite.h + transformation_add_constant_null.h transformation_add_constant_scalar.h transformation_add_dead_block.h transformation_add_dead_break.h @@ -97,8 +98,11 @@ transformation_add_type_pointer.h transformation_add_type_struct.h transformation_add_type_vector.h + transformation_adjust_branch_weights.h transformation_composite_construct.h transformation_composite_extract.h + transformation_compute_data_synonym_fact_closure.h + transformation_context.h transformation_copy_object.h transformation_equation_instruction.h transformation_function_call.h @@ -141,7 +145,7 @@ fuzzer_pass_add_local_variables.cpp fuzzer_pass_add_no_contraction_decorations.cpp fuzzer_pass_add_stores.cpp - fuzzer_pass_add_useful_constructs.cpp + fuzzer_pass_adjust_branch_weights.cpp fuzzer_pass_adjust_function_controls.cpp fuzzer_pass_adjust_loop_controls.cpp fuzzer_pass_adjust_memory_operands_masks.cpp @@ -170,6 +174,7 @@ transformation_access_chain.cpp transformation_add_constant_boolean.cpp transformation_add_constant_composite.cpp + transformation_add_constant_null.cpp transformation_add_constant_scalar.cpp transformation_add_dead_block.cpp transformation_add_dead_break.cpp @@ -188,8 +193,11 @@ transformation_add_type_pointer.cpp transformation_add_type_struct.cpp transformation_add_type_vector.cpp + transformation_adjust_branch_weights.cpp transformation_composite_construct.cpp transformation_composite_extract.cpp + transformation_compute_data_synonym_fact_closure.cpp + transformation_context.cpp transformation_copy_object.cpp transformation_equation_instruction.cpp transformation_function_call.cpp
diff --git a/source/fuzz/equivalence_relation.h b/source/fuzz/equivalence_relation.h index 7bb8b66..6d0b63e 100644 --- a/source/fuzz/equivalence_relation.h +++ b/source/fuzz/equivalence_relation.h
@@ -68,17 +68,14 @@ template <typename T, typename PointerHashT, typename PointerEqualsT> class EquivalenceRelation { public: - // Merges the equivalence classes associated with |value1| and |value2|. - // If any of these values was not previously in the equivalence relation, it - // is added to the pool of values known to be in the relation. + // Requires that |value1| and |value2| are already registered in the + // equivalence relation. Merges the equivalence classes associated with + // |value1| and |value2|. void MakeEquivalent(const T& value1, const T& value2) { - // Register each value if necessary. - for (auto value : {value1, value2}) { - if (!Exists(value)) { - // Register the value in the equivalence relation. - Register(value); - } - } + assert(Exists(value1) && + "Precondition: value1 must already be registered."); + assert(Exists(value2) && + "Precondition: value2 must already be registered."); // Look up canonical pointers to each of the values in the value pool. const T* value1_ptr = *value_set_.find(&value1); @@ -105,7 +102,7 @@ // Requires that |value| is not known to the equivalence relation. Registers // it in its own equivalence class and returns a pointer to the equivalence // class representative. - const T* Register(T& value) { + const T* Register(const T& value) { assert(!Exists(value)); // This relies on T having a copy constructor.
diff --git a/source/fuzz/fact_manager.cpp b/source/fuzz/fact_manager.cpp index 31d3b94..0b41eeb 100644 --- a/source/fuzz/fact_manager.cpp +++ b/source/fuzz/fact_manager.cpp
@@ -159,9 +159,26 @@ uint32_t type_id) const { auto type = context->get_type_mgr()->GetType(type_id); assert(type != nullptr && "Unknown type id."); - auto constant = context->get_constant_mgr()->GetConstant( - type, GetConstantWords(constant_uniform_fact)); - return context->get_constant_mgr()->FindDeclaredConstant(constant, type_id); + const opt::analysis::Constant* known_constant; + if (type->AsInteger()) { + opt::analysis::IntConstant candidate_constant( + type->AsInteger(), GetConstantWords(constant_uniform_fact)); + known_constant = + context->get_constant_mgr()->FindConstant(&candidate_constant); + } else { + assert( + type->AsFloat() && + "Uniform constant facts are only supported for int and float types."); + opt::analysis::FloatConstant candidate_constant( + type->AsFloat(), GetConstantWords(constant_uniform_fact)); + known_constant = + context->get_constant_mgr()->FindConstant(&candidate_constant); + } + if (!known_constant) { + return 0; + } + return context->get_constant_mgr()->FindDeclaredConstant(known_constant, + type_id); } std::vector<uint32_t> FactManager::ConstantUniformFacts::GetConstantWords( @@ -416,19 +433,23 @@ // See method in FactManager which delegates to this method. std::vector<const protobufs::DataDescriptor*> GetSynonymsForDataDescriptor( - const protobufs::DataDescriptor& data_descriptor, - opt::IRContext* context) const; + const protobufs::DataDescriptor& data_descriptor) const; // See method in FactManager which delegates to this method. - std::vector<uint32_t> GetIdsForWhichSynonymsAreKnown( - opt::IRContext* context) const; + std::vector<uint32_t> GetIdsForWhichSynonymsAreKnown() const; // See method in FactManager which delegates to this method. bool IsSynonymous(const protobufs::DataDescriptor& data_descriptor1, - const protobufs::DataDescriptor& data_descriptor2, - opt::IRContext* context) const; + const protobufs::DataDescriptor& data_descriptor2) const; + + // See method in FactManager which delegates to this method. + void ComputeClosureOfFacts(opt::IRContext* context, + uint32_t maximum_equivalence_class_size); private: + using OperationSet = + std::unordered_set<Operation, OperationHash, OperationEquals>; + // Adds the synonym |dd1| = |dd2| to the set of managed facts, and recurses // into sub-components of the data descriptors, if they are composites, to // record that their components are pairwise-synonymous. @@ -436,14 +457,10 @@ const protobufs::DataDescriptor& dd2, opt::IRContext* context); - // Inspects all known facts and adds corollary facts; e.g. if we know that - // a.x == b.x and a.y == b.y, where a and b have vec2 type, we can record - // that a == b holds. - // - // This method is expensive, and is thus called on demand: rather than - // computing the closure of facts each time a data synonym fact is added, we - // compute the closure only when a data synonym fact is *queried*. - void ComputeClosureOfFacts(opt::IRContext* context) const; + // Records the fact that |dd1| and |dd2| are equivalent, and merges the sets + // of equations that are known about them. + void MakeEquivalent(const protobufs::DataDescriptor& dd1, + const protobufs::DataDescriptor& dd2); // Returns true if and only if |dd1| and |dd2| are valid data descriptors // whose associated data have the same type (modulo integer signedness). @@ -451,11 +468,14 @@ opt::IRContext* context, const protobufs::DataDescriptor& dd1, const protobufs::DataDescriptor& dd2) const; + OperationSet GetEquations(const protobufs::DataDescriptor* lhs) const; + // Requires that |lhs_dd| and every element of |rhs_dds| is present in the - // |synonymous_| equivalence relation and is its own representative. Records - // the fact that the equation "|lhs_dd| |opcode| |rhs_dds|" holds, and adds - // any corollaries, in the form of data synonym or equation facts, that - // follow from this and other known facts. + // |synonymous_| equivalence relation, but is not necessarily its own + // representative. Records the fact that the equation + // "|lhs_dd| |opcode| |rhs_dds_non_canonical|" holds, and adds any + // corollaries, in the form of data synonym or equation facts, that follow + // from this and other known facts. void AddEquationFactRecursive( const protobufs::DataDescriptor& lhs_dd, SpvOp opcode, const std::vector<const protobufs::DataDescriptor*>& rhs_dds, @@ -463,28 +483,17 @@ // The data descriptors that are known to be synonymous with one another are // captured by this equivalence relation. - // - // This member is mutable in order to allow the closure of facts captured by - // the relation to be computed lazily when a question about data synonym - // facts is asked. - mutable EquivalenceRelation<protobufs::DataDescriptor, DataDescriptorHash, - DataDescriptorEquals> + EquivalenceRelation<protobufs::DataDescriptor, DataDescriptorHash, + DataDescriptorEquals> synonymous_; // When a new synonym fact is added, it may be possible to deduce further - // synonym facts by computing a closure of all known facts. However, there is - // no point computing this closure until a question regarding synonym facts is - // actually asked: if several facts are added in succession with no questions - // asked in between, we can avoid computing fact closures multiple times. - // - // This boolean tracks whether a closure computation is required - i.e., - // whether a new fact has been added since the last time such a computation - // was performed. - // - // It is mutable to facilitate having const methods, that provide answers to - // questions about data synonym facts, triggering closure computation on - // demand. - mutable bool closure_computation_required_ = false; + // synonym facts by computing a closure of all known facts. However, this is + // an expensive operation, so it should be performed sparingly and only there + // is some chance of new facts being deduced. This boolean tracks whether a + // closure computation is required - i.e., whether a new fact has been added + // since the last time such a computation was performed. + bool closure_computation_required_ = false; // Represents a set of equations on data descriptors as a map indexed by // left-hand-side, mapping a left-hand-side to a set of operations, each of @@ -493,9 +502,7 @@ // All data descriptors occurring in equations are required to be present in // the |synonymous_| equivalence relation, and to be their own representatives // in that relation. - std::unordered_map< - const protobufs::DataDescriptor*, - std::unordered_set<Operation, OperationHash, OperationEquals>> + std::unordered_map<const protobufs::DataDescriptor*, OperationSet> id_equations_; }; @@ -510,12 +517,10 @@ const protobufs::FactIdEquation& fact, opt::IRContext* context) { protobufs::DataDescriptor lhs_dd = MakeDataDescriptor(fact.lhs_id(), {}); - // Register the LHS in the equivalence relation if needed, and get a pointer - // to its representative. + // Register the LHS in the equivalence relation if needed. if (!synonymous_.Exists(lhs_dd)) { synonymous_.Register(lhs_dd); } - const protobufs::DataDescriptor* lhs_dd_ptr = synonymous_.Find(&lhs_dd); // Get equivalence class representatives for all ids used on the RHS of the // equation. @@ -529,38 +534,45 @@ } rhs_dd_ptrs.push_back(synonymous_.Find(&rhs_dd)); } - // We now have the equation in a form where it refers exclusively to - // equivalence class representatives. Add it to our set of facts and work - // out any follow-on facts. - AddEquationFactRecursive(*lhs_dd_ptr, static_cast<SpvOp>(fact.opcode()), + + // Now add the fact. + AddEquationFactRecursive(lhs_dd, static_cast<SpvOp>(fact.opcode()), rhs_dd_ptrs, context); } +FactManager::DataSynonymAndIdEquationFacts::OperationSet +FactManager::DataSynonymAndIdEquationFacts::GetEquations( + const protobufs::DataDescriptor* lhs) const { + auto existing = id_equations_.find(lhs); + if (existing == id_equations_.end()) { + return OperationSet(); + } + return existing->second; +} + void FactManager::DataSynonymAndIdEquationFacts::AddEquationFactRecursive( const protobufs::DataDescriptor& lhs_dd, SpvOp opcode, const std::vector<const protobufs::DataDescriptor*>& rhs_dds, opt::IRContext* context) { - // Precondition: all data descriptors referenced in this equation must be - // equivalence class representatives - i.e. the equation must be in canonical - // form. - assert(synonymous_.Exists(lhs_dd)); - assert(synonymous_.Find(&lhs_dd) == &lhs_dd); + assert(synonymous_.Exists(lhs_dd) && + "The LHS must be known to the equivalence relation."); for (auto rhs_dd : rhs_dds) { - (void)(rhs_dd); // Keep compilers happy in release mode. - assert(synonymous_.Exists(*rhs_dd)); - assert(synonymous_.Find(rhs_dd) == rhs_dd); + // Keep release compilers happy. + (void)(rhs_dd); + assert(synonymous_.Exists(*rhs_dd) && + "The RHS operands must be known to the equivalence relation."); } - if (id_equations_.count(&lhs_dd) == 0) { + auto lhs_dd_representative = synonymous_.Find(&lhs_dd); + + if (id_equations_.count(lhs_dd_representative) == 0) { // We have not seen an equation with this LHS before, so associate the LHS // with an initially empty set. - id_equations_.insert( - {&lhs_dd, - std::unordered_set<Operation, OperationHash, OperationEquals>()}); + id_equations_.insert({lhs_dd_representative, OperationSet()}); } { - auto existing_equations = id_equations_.find(&lhs_dd); + auto existing_equations = id_equations_.find(lhs_dd_representative); assert(existing_equations != id_equations_.end() && "A set of operations should be present, even if empty."); @@ -578,41 +590,29 @@ switch (opcode) { case SpvOpIAdd: { // Equation form: "a = b + c" - { - auto existing_first_operand_equations = id_equations_.find(rhs_dds[0]); - if (existing_first_operand_equations != id_equations_.end()) { - for (auto equation : existing_first_operand_equations->second) { - if (equation.opcode == SpvOpISub) { - // Equation form: "a = (d - e) + c" - if (equation.operands[1] == rhs_dds[1]) { - // Equation form: "a = (d - c) + c" - // We can thus infer "a = d" - AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], - context); - } - if (equation.operands[0] == rhs_dds[1]) { - // Equation form: "a = (c - e) + c" - // We can thus infer "a = -e" - AddEquationFactRecursive(lhs_dd, SpvOpSNegate, - {equation.operands[1]}, context); - } - } + for (auto equation : GetEquations(rhs_dds[0])) { + if (equation.opcode == SpvOpISub) { + // Equation form: "a = (d - e) + c" + if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[1])) { + // Equation form: "a = (d - c) + c" + // We can thus infer "a = d" + AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context); + } + if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[1])) { + // Equation form: "a = (c - e) + c" + // We can thus infer "a = -e" + AddEquationFactRecursive(lhs_dd, SpvOpSNegate, + {equation.operands[1]}, context); } } } - { - auto existing_second_operand_equations = id_equations_.find(rhs_dds[1]); - if (existing_second_operand_equations != id_equations_.end()) { - for (auto equation : existing_second_operand_equations->second) { - if (equation.opcode == SpvOpISub) { - // Equation form: "a = b + (d - e)" - if (equation.operands[1] == rhs_dds[0]) { - // Equation form: "a = b + (d - b)" - // We can thus infer "a = d" - AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], - context); - } - } + for (auto equation : GetEquations(rhs_dds[1])) { + if (equation.opcode == SpvOpISub) { + // Equation form: "a = b + (d - e)" + if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[0])) { + // Equation form: "a = b + (d - b)" + // We can thus infer "a = d" + AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context); } } } @@ -620,67 +620,54 @@ } case SpvOpISub: { // Equation form: "a = b - c" - { - auto existing_first_operand_equations = id_equations_.find(rhs_dds[0]); - if (existing_first_operand_equations != id_equations_.end()) { - for (auto equation : existing_first_operand_equations->second) { - if (equation.opcode == SpvOpIAdd) { - // Equation form: "a = (d + e) - c" - if (equation.operands[0] == rhs_dds[1]) { - // Equation form: "a = (c + e) - c" - // We can thus infer "a = e" - AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1], - context); - } - if (equation.operands[1] == rhs_dds[1]) { - // Equation form: "a = (d + c) - c" - // We can thus infer "a = d" - AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], - context); - } - } + for (auto equation : GetEquations(rhs_dds[0])) { + if (equation.opcode == SpvOpIAdd) { + // Equation form: "a = (d + e) - c" + if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[1])) { + // Equation form: "a = (c + e) - c" + // We can thus infer "a = e" + AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1], context); + } + if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[1])) { + // Equation form: "a = (d + c) - c" + // We can thus infer "a = d" + AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context); + } + } - if (equation.opcode == SpvOpISub) { - // Equation form: "a = (d - e) - c" - if (equation.operands[0] == rhs_dds[1]) { - // Equation form: "a = (c - e) - c" - // We can thus infer "a = -e" - AddEquationFactRecursive(lhs_dd, SpvOpSNegate, - {equation.operands[1]}, context); - } - } + if (equation.opcode == SpvOpISub) { + // Equation form: "a = (d - e) - c" + if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[1])) { + // Equation form: "a = (c - e) - c" + // We can thus infer "a = -e" + AddEquationFactRecursive(lhs_dd, SpvOpSNegate, + {equation.operands[1]}, context); } } } - { - auto existing_second_operand_equations = id_equations_.find(rhs_dds[1]); - if (existing_second_operand_equations != id_equations_.end()) { - for (auto equation : existing_second_operand_equations->second) { - if (equation.opcode == SpvOpIAdd) { - // Equation form: "a = b - (d + e)" - if (equation.operands[0] == rhs_dds[0]) { - // Equation form: "a = b - (b + e)" - // We can thus infer "a = -e" - AddEquationFactRecursive(lhs_dd, SpvOpSNegate, - {equation.operands[1]}, context); - } - if (equation.operands[1] == rhs_dds[0]) { - // Equation form: "a = b - (d + b)" - // We can thus infer "a = -d" - AddEquationFactRecursive(lhs_dd, SpvOpSNegate, - {equation.operands[0]}, context); - } - } - if (equation.opcode == SpvOpISub) { - // Equation form: "a = b - (d - e)" - if (equation.operands[0] == rhs_dds[0]) { - // Equation form: "a = b - (b - e)" - // We can thus infer "a = e" - AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1], - context); - } - } + for (auto equation : GetEquations(rhs_dds[1])) { + if (equation.opcode == SpvOpIAdd) { + // Equation form: "a = b - (d + e)" + if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[0])) { + // Equation form: "a = b - (b + e)" + // We can thus infer "a = -e" + AddEquationFactRecursive(lhs_dd, SpvOpSNegate, + {equation.operands[1]}, context); + } + if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[0])) { + // Equation form: "a = b - (d + b)" + // We can thus infer "a = -d" + AddEquationFactRecursive(lhs_dd, SpvOpSNegate, + {equation.operands[0]}, context); + } + } + if (equation.opcode == SpvOpISub) { + // Equation form: "a = b - (d - e)" + if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[0])) { + // Equation form: "a = b - (b - e)" + // We can thus infer "a = e" + AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1], context); } } } @@ -689,14 +676,11 @@ case SpvOpLogicalNot: case SpvOpSNegate: { // Equation form: "a = !b" or "a = -b" - auto existing_equations = id_equations_.find(rhs_dds[0]); - if (existing_equations != id_equations_.end()) { - for (auto equation : existing_equations->second) { - if (equation.opcode == opcode) { - // Equation form: "a = !!b" or "a = -(-b)" - // We can thus infer "a = b" - AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context); - } + for (auto equation : GetEquations(rhs_dds[0])) { + if (equation.opcode == opcode) { + // Equation form: "a = !!b" or "a = -(-b)" + // We can thus infer "a = b" + AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context); } } break; @@ -712,12 +696,7 @@ assert(DataDescriptorsAreWellFormedAndComparable(context, dd1, dd2)); // Record that the data descriptors provided in the fact are equivalent. - synonymous_.MakeEquivalent(dd1, dd2); - // As we have updated the equivalence relation, we might be able to deduce - // more facts by performing a closure computation, so we record that such a - // computation is required; it will be performed next time a method answering - // a data synonym fact-related question is invoked. - closure_computation_required_ = true; + MakeEquivalent(dd1, dd2); // We now check whether this is a synonym about composite objects. If it is, // we can recursively add synonym facts about their associated sub-components. @@ -754,7 +733,17 @@ // obj_1[a_1, ..., a_m] == obj_2[b_1, ..., b_n] // then for each composite index i, we add a fact of the form: // obj_1[a_1, ..., a_m, i] == obj_2[b_1, ..., b_n, i] - for (uint32_t i = 0; i < num_composite_elements; i++) { + // + // However, to avoid adding a large number of synonym facts e.g. in the case + // of arrays, we bound the number of composite elements to which this is + // applied. Nevertheless, we always add a synonym fact for the final + // components, as this may be an interesting edge case. + + // The bound on the number of indices of the composite pair to note as being + // synonymous. + const uint32_t kCompositeElementBound = 10; + + for (uint32_t i = 0; i < num_composite_elements;) { std::vector<uint32_t> extended_indices1 = fuzzerutil::RepeatedFieldToVector(dd1.index()); extended_indices1.push_back(i); @@ -765,11 +754,21 @@ MakeDataDescriptor(dd1.object(), std::move(extended_indices1)), MakeDataDescriptor(dd2.object(), std::move(extended_indices2)), context); + + if (i < kCompositeElementBound - 1 || i == num_composite_elements - 1) { + // We have not reached the bound yet, or have already skipped ahead to the + // last element, so increment the loop counter as standard. + i++; + } else { + // We have reached the bound, so skip ahead to the last element. + assert(i == kCompositeElementBound - 1); + i = num_composite_elements - 1; + } } } void FactManager::DataSynonymAndIdEquationFacts::ComputeClosureOfFacts( - opt::IRContext* context) const { + opt::IRContext* context, uint32_t maximum_equivalence_class_size) { // Suppose that obj_1[a_1, ..., a_m] and obj_2[b_1, ..., b_n] are distinct // data descriptors that describe objects of the same composite type, and that // the composite type is comprised of k components. @@ -855,6 +854,13 @@ synonymous_.GetEquivalenceClassRepresentatives()) { auto equivalence_class = synonymous_.GetEquivalenceClass(*representative); + if (equivalence_class.size() > maximum_equivalence_class_size) { + // This equivalence class is larger than the maximum size we are willing + // to consider, so we skip it. This potentially leads to missed fact + // deductions, but avoids excessive runtime for closure computation. + continue; + } + // Consider every data descriptor in the equivalence class. for (auto dd1_it = equivalence_class.begin(); dd1_it != equivalence_class.end(); ++dd1_it) { @@ -1029,10 +1035,7 @@ // synonymous. assert(DataDescriptorsAreWellFormedAndComparable( context, dd1_prefix, dd2_prefix)); - synonymous_.MakeEquivalent(dd1_prefix, dd2_prefix); - // As we have added a new synonym fact, we might benefit from doing - // another pass over the equivalence relation. - closure_computation_required_ = true; + MakeEquivalent(dd1_prefix, dd2_prefix); // Now that we know this pair of data descriptors are synonymous, // there is no point recording how close they are to being // synonymous. @@ -1044,6 +1047,82 @@ } } +void FactManager::DataSynonymAndIdEquationFacts::MakeEquivalent( + const protobufs::DataDescriptor& dd1, + const protobufs::DataDescriptor& dd2) { + // Register the data descriptors if they are not already known to the + // equivalence relation. + for (const auto& dd : {dd1, dd2}) { + if (!synonymous_.Exists(dd)) { + synonymous_.Register(dd); + } + } + + if (synonymous_.IsEquivalent(dd1, dd2)) { + // The data descriptors are already known to be equivalent, so there is + // nothing to do. + return; + } + + // We must make the data descriptors equivalent, and also make sure any + // equation facts known about their representatives are merged. + + // Record the original equivalence class representatives of the data + // descriptors. + auto dd1_original_representative = synonymous_.Find(&dd1); + auto dd2_original_representative = synonymous_.Find(&dd2); + + // Make the data descriptors equivalent. + synonymous_.MakeEquivalent(dd1, dd2); + // As we have updated the equivalence relation, we might be able to deduce + // more facts by performing a closure computation, so we record that such a + // computation is required. + closure_computation_required_ = true; + + // At this point, exactly one of |dd1_original_representative| and + // |dd2_original_representative| will be the representative of the combined + // equivalence class. We work out which one of them is still the class + // representative and which one is no longer the class representative. + + auto still_representative = synonymous_.Find(dd1_original_representative) == + dd1_original_representative + ? dd1_original_representative + : dd2_original_representative; + auto no_longer_representative = + still_representative == dd1_original_representative + ? dd2_original_representative + : dd1_original_representative; + + assert(no_longer_representative != still_representative && + "The current and former representatives cannot be the same."); + + // We now need to add all equations about |no_longer_representative| to the + // set of equations known about |still_representative|. + + // Get the equations associated with |no_longer_representative|. + auto no_longer_representative_id_equations = + id_equations_.find(no_longer_representative); + if (no_longer_representative_id_equations != id_equations_.end()) { + // There are some equations to transfer. There might not yet be any + // equations about |still_representative|; create an empty set of equations + // if this is the case. + if (!id_equations_.count(still_representative)) { + id_equations_.insert({still_representative, OperationSet()}); + } + auto still_representative_id_equations = + id_equations_.find(still_representative); + assert(still_representative_id_equations != id_equations_.end() && + "At this point there must be a set of equations."); + // Add all the equations known about |no_longer_representative| to the set + // of equations known about |still_representative|. + still_representative_id_equations->second.insert( + no_longer_representative_id_equations->second.begin(), + no_longer_representative_id_equations->second.end()); + } + // Delete the no longer-relevant equations about |no_longer_representative|. + id_equations_.erase(no_longer_representative); +} + bool FactManager::DataSynonymAndIdEquationFacts:: DataDescriptorsAreWellFormedAndComparable( opt::IRContext* context, const protobufs::DataDescriptor& dd1, @@ -1094,9 +1173,7 @@ std::vector<const protobufs::DataDescriptor*> FactManager::DataSynonymAndIdEquationFacts::GetSynonymsForDataDescriptor( - const protobufs::DataDescriptor& data_descriptor, - opt::IRContext* context) const { - ComputeClosureOfFacts(context); + const protobufs::DataDescriptor& data_descriptor) const { if (synonymous_.Exists(data_descriptor)) { return synonymous_.GetEquivalenceClass(data_descriptor); } @@ -1104,9 +1181,8 @@ } std::vector<uint32_t> -FactManager::DataSynonymAndIdEquationFacts ::GetIdsForWhichSynonymsAreKnown( - opt::IRContext* context) const { - ComputeClosureOfFacts(context); +FactManager::DataSynonymAndIdEquationFacts::GetIdsForWhichSynonymsAreKnown() + const { std::vector<uint32_t> result; for (auto& data_descriptor : synonymous_.GetAllKnownValues()) { if (data_descriptor->index().empty()) { @@ -1118,10 +1194,7 @@ bool FactManager::DataSynonymAndIdEquationFacts::IsSynonymous( const protobufs::DataDescriptor& data_descriptor1, - const protobufs::DataDescriptor& data_descriptor2, - opt::IRContext* context) const { - const_cast<FactManager::DataSynonymAndIdEquationFacts*>(this) - ->ComputeClosureOfFacts(context); + const protobufs::DataDescriptor& data_descriptor2) const { return synonymous_.Exists(data_descriptor1) && synonymous_.Exists(data_descriptor2) && synonymous_.IsEquivalent(data_descriptor1, data_descriptor2); @@ -1303,31 +1376,27 @@ return uniform_constant_facts_->GetConstantUniformFactsAndTypes(); } -std::vector<uint32_t> FactManager::GetIdsForWhichSynonymsAreKnown( - opt::IRContext* context) const { - return data_synonym_and_id_equation_facts_->GetIdsForWhichSynonymsAreKnown( - context); +std::vector<uint32_t> FactManager::GetIdsForWhichSynonymsAreKnown() const { + return data_synonym_and_id_equation_facts_->GetIdsForWhichSynonymsAreKnown(); } std::vector<const protobufs::DataDescriptor*> FactManager::GetSynonymsForDataDescriptor( - const protobufs::DataDescriptor& data_descriptor, - opt::IRContext* context) const { + const protobufs::DataDescriptor& data_descriptor) const { return data_synonym_and_id_equation_facts_->GetSynonymsForDataDescriptor( - data_descriptor, context); + data_descriptor); } std::vector<const protobufs::DataDescriptor*> FactManager::GetSynonymsForId( - uint32_t id, opt::IRContext* context) const { - return GetSynonymsForDataDescriptor(MakeDataDescriptor(id, {}), context); + uint32_t id) const { + return GetSynonymsForDataDescriptor(MakeDataDescriptor(id, {})); } bool FactManager::IsSynonymous( const protobufs::DataDescriptor& data_descriptor1, - const protobufs::DataDescriptor& data_descriptor2, - opt::IRContext* context) const { - return data_synonym_and_id_equation_facts_->IsSynonymous( - data_descriptor1, data_descriptor2, context); + const protobufs::DataDescriptor& data_descriptor2) const { + return data_synonym_and_id_equation_facts_->IsSynonymous(data_descriptor1, + data_descriptor2); } bool FactManager::BlockIsDead(uint32_t block_id) const { @@ -1372,5 +1441,11 @@ data_synonym_and_id_equation_facts_->AddFact(fact, context); } +void FactManager::ComputeClosureOfFacts( + opt::IRContext* ir_context, uint32_t maximum_equivalence_class_size) { + data_synonym_and_id_equation_facts_->ComputeClosureOfFacts( + ir_context, maximum_equivalence_class_size); +} + } // namespace fuzz } // namespace spvtools
diff --git a/source/fuzz/fact_manager.h b/source/fuzz/fact_manager.h index f80d677..f520e42 100644 --- a/source/fuzz/fact_manager.h +++ b/source/fuzz/fact_manager.h
@@ -76,6 +76,21 @@ const std::vector<uint32_t>& rhs_id, opt::IRContext* context); + // Inspects all known facts and adds corollary facts; e.g. if we know that + // a.x == b.x and a.y == b.y, where a and b have vec2 type, we can record + // that a == b holds. + // + // This method is expensive, and should only be called (by applying a + // transformation) at the start of a fuzzer pass that depends on data + // synonym facts, rather than calling it every time a new data synonym fact + // is added. + // + // The parameter |maximum_equivalence_class_size| specifies the size beyond + // which equivalence classes should not be mined for new facts, to avoid + // excessively-long closure computations. + void ComputeClosureOfFacts(opt::IRContext* ir_context, + uint32_t maximum_equivalence_class_size); + // The fact manager is responsible for managing a few distinct categories of // facts. In principle there could be different fact managers for each kind // of fact, but in practice providing one 'go to' place for facts is @@ -125,25 +140,22 @@ // Returns every id for which a fact of the form "this id is synonymous with // this piece of data" is known. - std::vector<uint32_t> GetIdsForWhichSynonymsAreKnown( - opt::IRContext* context) const; + std::vector<uint32_t> GetIdsForWhichSynonymsAreKnown() const; // Returns the equivalence class of all known synonyms of |id|, or an empty // set if no synonyms are known. std::vector<const protobufs::DataDescriptor*> GetSynonymsForId( - uint32_t id, opt::IRContext* context) const; + uint32_t id) const; // Returns the equivalence class of all known synonyms of |data_descriptor|, // or empty if no synonyms are known. std::vector<const protobufs::DataDescriptor*> GetSynonymsForDataDescriptor( - const protobufs::DataDescriptor& data_descriptor, - opt::IRContext* context) const; + const protobufs::DataDescriptor& data_descriptor) const; // Returns true if and ony if |data_descriptor1| and |data_descriptor2| are // known to be synonymous. bool IsSynonymous(const protobufs::DataDescriptor& data_descriptor1, - const protobufs::DataDescriptor& data_descriptor2, - opt::IRContext* context) const; + const protobufs::DataDescriptor& data_descriptor2) const; // End of id synonym facts //==============================
diff --git a/source/fuzz/force_render_red.cpp b/source/fuzz/force_render_red.cpp index 46e23e8..5bf2879 100644 --- a/source/fuzz/force_render_red.cpp +++ b/source/fuzz/force_render_red.cpp
@@ -17,6 +17,7 @@ #include "source/fuzz/fact_manager.h" #include "source/fuzz/instruction_descriptor.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" +#include "source/fuzz/transformation_context.h" #include "source/fuzz/transformation_replace_constant_with_uniform.h" #include "source/fuzz/uniform_buffer_element_descriptor.h" #include "source/opt/build_module.h" @@ -159,7 +160,8 @@ } // namespace bool ForceRenderRed( - const spv_target_env& target_env, const std::vector<uint32_t>& binary_in, + const spv_target_env& target_env, spv_validator_options validator_options, + const std::vector<uint32_t>& binary_in, const spvtools::fuzz::protobufs::FactSequence& initial_facts, std::vector<uint32_t>* binary_out) { auto message_consumer = spvtools::utils::CLIMessageConsumer; @@ -171,7 +173,7 @@ } // Initial binary should be valid. - if (!tools.Validate(&binary_in[0], binary_in.size())) { + if (!tools.Validate(&binary_in[0], binary_in.size(), validator_options)) { message_consumer(SPV_MSG_ERROR, nullptr, {}, "Initial binary is invalid; stopping."); return false; @@ -187,6 +189,8 @@ for (auto& fact : initial_facts.fact()) { fact_manager.AddFact(fact, ir_context.get()); } + TransformationContext transformation_context(&fact_manager, + validator_options); auto entry_point_function = FindFragmentShaderEntryPoint(ir_context.get(), message_consumer); @@ -355,8 +359,9 @@ for (auto& replacement : {first_greater_then_operand_replacement.get(), second_greater_then_operand_replacement.get()}) { if (replacement) { - assert(replacement->IsApplicable(ir_context.get(), fact_manager)); - replacement->Apply(ir_context.get(), &fact_manager); + assert(replacement->IsApplicable(ir_context.get(), + transformation_context)); + replacement->Apply(ir_context.get(), &transformation_context); } } }
diff --git a/source/fuzz/force_render_red.h b/source/fuzz/force_render_red.h index 2484d27..b51c72b 100644 --- a/source/fuzz/force_render_red.h +++ b/source/fuzz/force_render_red.h
@@ -38,7 +38,8 @@ // instead become: 'u > v', where 'u' and 'v' are pieces of uniform data for // which it is known that 'u < v' holds. bool ForceRenderRed( - const spv_target_env& target_env, const std::vector<uint32_t>& binary_in, + const spv_target_env& target_env, spv_validator_options validator_options, + const std::vector<uint32_t>& binary_in, const spvtools::fuzz::protobufs::FactSequence& initial_facts, std::vector<uint32_t>* binary_out);
diff --git a/source/fuzz/fuzzer.cpp b/source/fuzz/fuzzer.cpp index 119bd3c..3343abc 100644 --- a/source/fuzz/fuzzer.cpp +++ b/source/fuzz/fuzzer.cpp
@@ -33,7 +33,7 @@ #include "source/fuzz/fuzzer_pass_add_local_variables.h" #include "source/fuzz/fuzzer_pass_add_no_contraction_decorations.h" #include "source/fuzz/fuzzer_pass_add_stores.h" -#include "source/fuzz/fuzzer_pass_add_useful_constructs.h" +#include "source/fuzz/fuzzer_pass_adjust_branch_weights.h" #include "source/fuzz/fuzzer_pass_adjust_function_controls.h" #include "source/fuzz/fuzzer_pass_adjust_loop_controls.h" #include "source/fuzz/fuzzer_pass_adjust_selection_controls.h" @@ -51,6 +51,7 @@ #include "source/fuzz/fuzzer_pass_toggle_access_chain_instruction.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/pseudo_random_generator.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/build_module.h" #include "source/spirv_fuzzer_options.h" #include "source/util/make_unique.h" @@ -66,19 +67,19 @@ const uint32_t kChanceOfApplyingAnotherPass = 85; // A convenience method to add a fuzzer pass to |passes| with probability 0.5. -// All fuzzer passes take |ir_context|, |fact_manager|, |fuzzer_context| and -// |transformation_sequence_out| as parameters. Extra arguments can be provided -// via |extra_args|. +// All fuzzer passes take |ir_context|, |transformation_context|, +// |fuzzer_context| and |transformation_sequence_out| as parameters. Extra +// arguments can be provided via |extra_args|. template <typename T, typename... Args> void MaybeAddPass( std::vector<std::unique_ptr<FuzzerPass>>* passes, - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformation_sequence_out, Args&&... extra_args) { if (fuzzer_context->ChooseEven()) { - passes->push_back(MakeUnique<T>(ir_context, fact_manager, fuzzer_context, - transformation_sequence_out, + passes->push_back(MakeUnique<T>(ir_context, transformation_context, + fuzzer_context, transformation_sequence_out, std::forward<Args>(extra_args)...)); } } @@ -86,26 +87,31 @@ } // namespace struct Fuzzer::Impl { - explicit Impl(spv_target_env env, uint32_t random_seed, - bool validate_after_each_pass) + Impl(spv_target_env env, uint32_t random_seed, bool validate_after_each_pass, + spv_validator_options options) : target_env(env), seed(random_seed), - validate_after_each_fuzzer_pass(validate_after_each_pass) {} + validate_after_each_fuzzer_pass(validate_after_each_pass), + validator_options(options) {} bool ApplyPassAndCheckValidity(FuzzerPass* pass, const opt::IRContext& ir_context, const spvtools::SpirvTools& tools) const; const spv_target_env target_env; // Target environment. + MessageConsumer consumer; // Message consumer. const uint32_t seed; // Seed for random number generator. bool validate_after_each_fuzzer_pass; // Determines whether the validator - // should be invoked after every fuzzer pass. - MessageConsumer consumer; // Message consumer. + // should be invoked after every fuzzer + // pass. + spv_validator_options validator_options; // Options to control validation. }; Fuzzer::Fuzzer(spv_target_env env, uint32_t seed, - bool validate_after_each_fuzzer_pass) - : impl_(MakeUnique<Impl>(env, seed, validate_after_each_fuzzer_pass)) {} + bool validate_after_each_fuzzer_pass, + spv_validator_options validator_options) + : impl_(MakeUnique<Impl>(env, seed, validate_after_each_fuzzer_pass, + validator_options)) {} Fuzzer::~Fuzzer() = default; @@ -120,7 +126,8 @@ if (validate_after_each_fuzzer_pass) { std::vector<uint32_t> binary_to_validate; ir_context.module()->ToBinary(&binary_to_validate, false); - if (!tools.Validate(&binary_to_validate[0], binary_to_validate.size())) { + if (!tools.Validate(&binary_to_validate[0], binary_to_validate.size(), + validator_options)) { consumer(SPV_MSG_INFO, nullptr, {}, "Binary became invalid during fuzzing (set a breakpoint to " "inspect); stopping."); @@ -149,7 +156,8 @@ } // Initial binary should be valid. - if (!tools.Validate(&binary_in[0], binary_in.size())) { + if (!tools.Validate(&binary_in[0], binary_in.size(), + impl_->validator_options)) { impl_->consumer(SPV_MSG_ERROR, nullptr, {}, "Initial binary is invalid; stopping."); return Fuzzer::FuzzerResultStatus::kInitialBinaryInvalid; @@ -175,83 +183,75 @@ FactManager fact_manager; fact_manager.AddFacts(impl_->consumer, initial_facts, ir_context.get()); - - // Add some essential ingredients to the module if they are not already - // present, such as boolean constants. - FuzzerPassAddUsefulConstructs add_useful_constructs( - ir_context.get(), &fact_manager, &fuzzer_context, - transformation_sequence_out); - if (!impl_->ApplyPassAndCheckValidity(&add_useful_constructs, *ir_context, - tools)) { - return Fuzzer::FuzzerResultStatus::kFuzzerPassLedToInvalidModule; - } + TransformationContext transformation_context(&fact_manager, + impl_->validator_options); // Apply some semantics-preserving passes. std::vector<std::unique_ptr<FuzzerPass>> passes; while (passes.empty()) { - MaybeAddPass<FuzzerPassAddAccessChains>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassAddCompositeTypes>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassAddDeadBlocks>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassAddDeadBreaks>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassAddDeadContinues>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); + MaybeAddPass<FuzzerPassAddAccessChains>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassAddCompositeTypes>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassAddDeadBlocks>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassAddDeadBreaks>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassAddDeadContinues>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); MaybeAddPass<FuzzerPassAddEquationInstructions>( - &passes, ir_context.get(), &fact_manager, &fuzzer_context, + &passes, ir_context.get(), &transformation_context, &fuzzer_context, transformation_sequence_out); - MaybeAddPass<FuzzerPassAddFunctionCalls>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassAddGlobalVariables>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassAddLoads>(&passes, ir_context.get(), &fact_manager, - &fuzzer_context, + MaybeAddPass<FuzzerPassAddFunctionCalls>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassAddGlobalVariables>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassAddLoads>(&passes, ir_context.get(), + &transformation_context, &fuzzer_context, transformation_sequence_out); - MaybeAddPass<FuzzerPassAddLocalVariables>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassAddStores>(&passes, ir_context.get(), &fact_manager, - &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassApplyIdSynonyms>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassConstructComposites>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassCopyObjects>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassDonateModules>( - &passes, ir_context.get(), &fact_manager, &fuzzer_context, - transformation_sequence_out, donor_suppliers); - MaybeAddPass<FuzzerPassMergeBlocks>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassObfuscateConstants>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassOutlineFunctions>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassPermuteBlocks>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); - MaybeAddPass<FuzzerPassPermuteFunctionParameters>( - &passes, ir_context.get(), &fact_manager, &fuzzer_context, + MaybeAddPass<FuzzerPassAddLocalVariables>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, transformation_sequence_out); - MaybeAddPass<FuzzerPassSplitBlocks>(&passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); + MaybeAddPass<FuzzerPassAddStores>(&passes, ir_context.get(), + &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassApplyIdSynonyms>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassConstructComposites>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassCopyObjects>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassDonateModules>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out, donor_suppliers); + MaybeAddPass<FuzzerPassMergeBlocks>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassObfuscateConstants>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassOutlineFunctions>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassPermuteBlocks>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassPermuteFunctionParameters>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassSplitBlocks>( + &passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); } bool is_first = true; @@ -271,26 +271,29 @@ // Now apply some passes that it does not make sense to apply repeatedly, // as they do not unlock other passes. std::vector<std::unique_ptr<FuzzerPass>> final_passes; - MaybeAddPass<FuzzerPassAdjustFunctionControls>( - &final_passes, ir_context.get(), &fact_manager, &fuzzer_context, + MaybeAddPass<FuzzerPassAdjustBranchWeights>( + &final_passes, ir_context.get(), &transformation_context, &fuzzer_context, transformation_sequence_out); - MaybeAddPass<FuzzerPassAdjustLoopControls>(&final_passes, ir_context.get(), - &fact_manager, &fuzzer_context, - transformation_sequence_out); + MaybeAddPass<FuzzerPassAdjustFunctionControls>( + &final_passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); + MaybeAddPass<FuzzerPassAdjustLoopControls>( + &final_passes, ir_context.get(), &transformation_context, &fuzzer_context, + transformation_sequence_out); MaybeAddPass<FuzzerPassAdjustMemoryOperandsMasks>( - &final_passes, ir_context.get(), &fact_manager, &fuzzer_context, + &final_passes, ir_context.get(), &transformation_context, &fuzzer_context, transformation_sequence_out); MaybeAddPass<FuzzerPassAdjustSelectionControls>( - &final_passes, ir_context.get(), &fact_manager, &fuzzer_context, + &final_passes, ir_context.get(), &transformation_context, &fuzzer_context, transformation_sequence_out); MaybeAddPass<FuzzerPassAddNoContractionDecorations>( - &final_passes, ir_context.get(), &fact_manager, &fuzzer_context, + &final_passes, ir_context.get(), &transformation_context, &fuzzer_context, transformation_sequence_out); MaybeAddPass<FuzzerPassSwapCommutableOperands>( - &final_passes, ir_context.get(), &fact_manager, &fuzzer_context, + &final_passes, ir_context.get(), &transformation_context, &fuzzer_context, transformation_sequence_out); MaybeAddPass<FuzzerPassToggleAccessChainInstruction>( - &final_passes, ir_context.get(), &fact_manager, &fuzzer_context, + &final_passes, ir_context.get(), &transformation_context, &fuzzer_context, transformation_sequence_out); for (auto& pass : final_passes) { if (!impl_->ApplyPassAndCheckValidity(pass.get(), *ir_context, tools)) {
diff --git a/source/fuzz/fuzzer.h b/source/fuzz/fuzzer.h index 3ac73a1..6c3ef71 100644 --- a/source/fuzz/fuzzer.h +++ b/source/fuzz/fuzzer.h
@@ -41,8 +41,9 @@ // seed for pseudo-random number generation. // |validate_after_each_fuzzer_pass| controls whether the validator will be // invoked after every fuzzer pass is applied. - explicit Fuzzer(spv_target_env env, uint32_t seed, - bool validate_after_each_fuzzer_pass); + Fuzzer(spv_target_env env, uint32_t seed, + bool validate_after_each_fuzzer_pass, + spv_validator_options validator_options); // Disables copy/move constructor/assignment operations. Fuzzer(const Fuzzer&) = delete;
diff --git a/source/fuzz/fuzzer_context.cpp b/source/fuzz/fuzzer_context.cpp index 2f9fc5a..1779709 100644 --- a/source/fuzz/fuzzer_context.cpp +++ b/source/fuzz/fuzzer_context.cpp
@@ -40,6 +40,7 @@ 5, 70}; const std::pair<uint32_t, uint32_t> kChanceOfAddingStore = {5, 50}; const std::pair<uint32_t, uint32_t> kChanceOfAddingVectorType = {20, 70}; +const std::pair<uint32_t, uint32_t> kChanceOfAdjustingBranchWeights = {20, 90}; const std::pair<uint32_t, uint32_t> kChanceOfAdjustingFunctionControl = {20, 70}; const std::pair<uint32_t, uint32_t> kChanceOfAdjustingLoopControl = {20, 90}; @@ -68,6 +69,7 @@ // Default limits for various quantities that are chosen during fuzzing. // Keep them in alphabetical order. +const uint32_t kDefaultMaxEquivalenceClassSizeForDataSynonymFactClosure = 1000; const uint32_t kDefaultMaxLoopControlPartialCount = 100; const uint32_t kDefaultMaxLoopControlPeelCount = 100; const uint32_t kDefaultMaxLoopLimit = 20; @@ -89,6 +91,8 @@ uint32_t min_fresh_id) : random_generator_(random_generator), next_fresh_id_(min_fresh_id), + max_equivalence_class_size_for_data_synonym_fact_closure_( + kDefaultMaxEquivalenceClassSizeForDataSynonymFactClosure), max_loop_control_partial_count_(kDefaultMaxLoopControlPartialCount), max_loop_control_peel_count_(kDefaultMaxLoopControlPeelCount), max_loop_limit_(kDefaultMaxLoopLimit), @@ -121,6 +125,8 @@ chance_of_adding_store_ = ChooseBetweenMinAndMax(kChanceOfAddingStore); chance_of_adding_vector_type_ = ChooseBetweenMinAndMax(kChanceOfAddingVectorType); + chance_of_adjusting_branch_weights_ = + ChooseBetweenMinAndMax(kChanceOfAdjustingBranchWeights); chance_of_adjusting_function_control_ = ChooseBetweenMinAndMax(kChanceOfAdjustingFunctionControl); chance_of_adjusting_loop_control_ =
diff --git a/source/fuzz/fuzzer_context.h b/source/fuzz/fuzzer_context.h index 1529705..dd19d9a 100644 --- a/source/fuzz/fuzzer_context.h +++ b/source/fuzz/fuzzer_context.h
@@ -136,6 +136,9 @@ uint32_t GetChanceOfAddingVectorType() { return chance_of_adding_vector_type_; } + uint32_t GetChanceOfAdjustingBranchWeights() { + return chance_of_adjusting_branch_weights_; + } uint32_t GetChanceOfAdjustingFunctionControl() { return chance_of_adjusting_function_control_; } @@ -183,25 +186,40 @@ uint32_t GetChanceOfTogglingAccessChainInstruction() { return chance_of_toggling_access_chain_instruction_; } - uint32_t GetRandomLoopControlPeelCount() { - return random_generator_->RandomUint32(max_loop_control_peel_count_); + + // Other functions to control transformations. Keep them in alphabetical + // order. + uint32_t GetMaximumEquivalenceClassSizeForDataSynonymFactClosure() { + return max_equivalence_class_size_for_data_synonym_fact_closure_; + } + uint32_t GetRandomIndexForAccessChain(uint32_t composite_size_bound) { + return random_generator_->RandomUint32(composite_size_bound); } uint32_t GetRandomLoopControlPartialCount() { return random_generator_->RandomUint32(max_loop_control_partial_count_); } + uint32_t GetRandomLoopControlPeelCount() { + return random_generator_->RandomUint32(max_loop_control_peel_count_); + } uint32_t GetRandomLoopLimit() { return random_generator_->RandomUint32(max_loop_limit_); } + std::pair<uint32_t, uint32_t> GetRandomBranchWeights() { + std::pair<uint32_t, uint32_t> branch_weights = {0, 0}; + + while (branch_weights.first == 0 && branch_weights.second == 0) { + // Using INT32_MAX to do not overflow UINT32_MAX when the branch weights + // are added together. + branch_weights.first = random_generator_->RandomUint32(INT32_MAX); + branch_weights.second = random_generator_->RandomUint32(INT32_MAX); + } + + return branch_weights; + } uint32_t GetRandomSizeForNewArray() { // Ensure that the array size is non-zero. return random_generator_->RandomUint32(max_new_array_size_limit_ - 1) + 1; } - - // Other functions to control transformations. Keep them in alphabetical - // order. - uint32_t GetRandomIndexForAccessChain(uint32_t composite_size_bound) { - return random_generator_->RandomUint32(composite_size_bound); - } bool GoDeeperInConstantObfuscation(uint32_t depth) { return go_deeper_in_constant_obfuscation_(depth, random_generator_); } @@ -228,6 +246,7 @@ uint32_t chance_of_adding_no_contraction_decoration_; uint32_t chance_of_adding_store_; uint32_t chance_of_adding_vector_type_; + uint32_t chance_of_adjusting_branch_weights_; uint32_t chance_of_adjusting_function_control_; uint32_t chance_of_adjusting_loop_control_; uint32_t chance_of_adjusting_memory_operands_mask_; @@ -251,6 +270,7 @@ // Limits associated with various quantities for which random values are // chosen during fuzzing. // Keep them in alphabetical order. + uint32_t max_equivalence_class_size_for_data_synonym_fact_closure_; uint32_t max_loop_control_partial_count_; uint32_t max_loop_control_peel_count_; uint32_t max_loop_limit_;
diff --git a/source/fuzz/fuzzer_pass.cpp b/source/fuzz/fuzzer_pass.cpp index a76f10d..cd94e4e 100644 --- a/source/fuzz/fuzzer_pass.cpp +++ b/source/fuzz/fuzzer_pass.cpp
@@ -14,6 +14,8 @@ #include "source/fuzz/fuzzer_pass.h" +#include <set> + #include "source/fuzz/fuzzer_util.h" #include "source/fuzz/instruction_descriptor.h" #include "source/fuzz/transformation_add_constant_boolean.h" @@ -31,11 +33,12 @@ namespace spvtools { namespace fuzz { -FuzzerPass::FuzzerPass(opt::IRContext* ir_context, FactManager* fact_manager, +FuzzerPass::FuzzerPass(opt::IRContext* ir_context, + TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) : ir_context_(ir_context), - fact_manager_(fact_manager), + transformation_context_(transformation_context), fuzzer_context_(fuzzer_context), transformations_(transformations) {} @@ -316,6 +319,35 @@ return result; } +uint32_t FuzzerPass::FindOrCreateConstant(const std::vector<uint32_t>& words, + uint32_t type_id) { + assert(type_id && "Constant's type id can't be 0."); + + const auto* type = GetIRContext()->get_type_mgr()->GetType(type_id); + assert(type && "Type does not exist."); + + if (type->AsBool()) { + assert(words.size() == 1); + return FindOrCreateBoolConstant(words[0]); + } else if (const auto* integer = type->AsInteger()) { + assert(integer->width() == 32 && words.size() == 1 && + "Integer must have 32-bit width"); + return FindOrCreate32BitIntegerConstant(words[0], integer->IsSigned()); + } else if (const auto* floating = type->AsFloat()) { + // Assertions are not evaluated in release builds so |floating| + // variable will be unused. + (void)floating; + assert(floating->width() == 32 && words.size() == 1 && + "Floating point number must have 32-bit width"); + return FindOrCreate32BitFloatConstant(words[0]); + } + + // This assertion will fail in debug build but not in release build + // so we return 0 to make compiler happy. + assert(false && "Constant type is not supported"); + return 0; +} + uint32_t FuzzerPass::FindOrCreateGlobalUndef(uint32_t type_id) { for (auto& inst : GetIRContext()->types_values()) { if (inst.opcode() == SpvOpUndef && inst.type_id() == type_id) { @@ -328,43 +360,72 @@ } std::pair<std::vector<uint32_t>, std::map<uint32_t, std::vector<uint32_t>>> -FuzzerPass::GetAvailableBaseTypesAndPointers( +FuzzerPass::GetAvailableBasicTypesAndPointers( SpvStorageClass storage_class) const { - // Records all of the base types available in the module. - std::vector<uint32_t> base_types; + // Records all of the basic types available in the module. + std::set<uint32_t> basic_types; - // For each base type, records all the associated pointer types that target - // that base type and that have |storage_class| as their storage class. - std::map<uint32_t, std::vector<uint32_t>> base_type_to_pointers; + // For each basic type, records all the associated pointer types that target + // the basic type and that have |storage_class| as their storage class. + std::map<uint32_t, std::vector<uint32_t>> basic_type_to_pointers; for (auto& inst : GetIRContext()->types_values()) { + // For each basic type that we come across, record type, and the fact that + // we cannot yet have seen any pointers that use the basic type as its + // pointee type. + // + // For pointer types with basic pointee types, associate the pointer type + // with the basic type. switch (inst.opcode()) { - case SpvOpTypeArray: case SpvOpTypeBool: case SpvOpTypeFloat: case SpvOpTypeInt: case SpvOpTypeMatrix: - case SpvOpTypeStruct: case SpvOpTypeVector: - // These types are suitable as pointer base types. Record the type, - // and the fact that we cannot yet have seen any pointers that use this - // as its base type. - base_types.push_back(inst.result_id()); - base_type_to_pointers.insert({inst.result_id(), {}}); + // These are all basic types. + basic_types.insert(inst.result_id()); + basic_type_to_pointers.insert({inst.result_id(), {}}); break; - case SpvOpTypePointer: - if (inst.GetSingleWordInOperand(0) == storage_class) { - // The pointer has the desired storage class, so we are interested in - // it. Associate it with its base type. - base_type_to_pointers.at(inst.GetSingleWordInOperand(1)) - .push_back(inst.result_id()); + case SpvOpTypeArray: + // An array type is basic if its base type is basic. + if (basic_types.count(inst.GetSingleWordInOperand(0))) { + basic_types.insert(inst.result_id()); + basic_type_to_pointers.insert({inst.result_id(), {}}); } break; + case SpvOpTypeStruct: { + // A struct type is basic if all of its members are basic. + bool all_members_are_basic_types = true; + for (uint32_t i = 0; i < inst.NumInOperands(); i++) { + if (!basic_types.count(inst.GetSingleWordInOperand(i))) { + all_members_are_basic_types = false; + break; + } + } + if (all_members_are_basic_types) { + basic_types.insert(inst.result_id()); + basic_type_to_pointers.insert({inst.result_id(), {}}); + } + break; + } + case SpvOpTypePointer: { + // We are interested in the pointer if its pointee type is basic and it + // has the right storage class. + auto pointee_type = inst.GetSingleWordInOperand(1); + if (inst.GetSingleWordInOperand(0) == storage_class && + basic_types.count(pointee_type)) { + // The pointer has the desired storage class, and its pointee type is + // a basic type, so we are interested in it. Associate it with its + // basic type. + basic_type_to_pointers.at(pointee_type).push_back(inst.result_id()); + } + break; + } default: break; } } - return {base_types, base_type_to_pointers}; + return {{basic_types.begin(), basic_types.end()}, basic_type_to_pointers}; } uint32_t FuzzerPass::FindOrCreateZeroConstant(
diff --git a/source/fuzz/fuzzer_pass.h b/source/fuzz/fuzzer_pass.h index 46ee408..800b888 100644 --- a/source/fuzz/fuzzer_pass.h +++ b/source/fuzz/fuzzer_pass.h
@@ -18,9 +18,10 @@ #include <functional> #include <vector> -#include "source/fuzz/fact_manager.h" #include "source/fuzz/fuzzer_context.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" +#include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -29,22 +30,25 @@ // Interface for applying a pass of transformations to a module. class FuzzerPass { public: - FuzzerPass(opt::IRContext* ir_context, FactManager* fact_manager, + FuzzerPass(opt::IRContext* ir_context, + TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations); virtual ~FuzzerPass(); // Applies the pass to the module |ir_context_|, assuming and updating - // facts from |fact_manager_|, and using |fuzzer_context_| to guide the - // process. Appends to |transformations_| all transformations that were - // applied during the pass. + // information from |transformation_context_|, and using |fuzzer_context_| to + // guide the process. Appends to |transformations_| all transformations that + // were applied during the pass. virtual void Apply() = 0; protected: opt::IRContext* GetIRContext() const { return ir_context_; } - FactManager* GetFactManager() const { return fact_manager_; } + TransformationContext* GetTransformationContext() const { + return transformation_context_; + } FuzzerContext* GetFuzzerContext() const { return fuzzer_context_; } @@ -91,11 +95,11 @@ // A generic helper for applying a transformation that should be applicable // by construction, and adding it to the sequence of applied transformations. - template <typename TransformationType> - void ApplyTransformation(const TransformationType& transformation) { - assert(transformation.IsApplicable(GetIRContext(), *GetFactManager()) && + void ApplyTransformation(const Transformation& transformation) { + assert(transformation.IsApplicable(GetIRContext(), + *GetTransformationContext()) && "Transformation should be applicable by construction."); - transformation.Apply(GetIRContext(), GetFactManager()); + transformation.Apply(GetIRContext(), GetTransformationContext()); *GetTransformations()->add_transformation() = transformation.ToMessage(); } @@ -160,23 +164,34 @@ // type do not exist, transformations are applied to add them. uint32_t FindOrCreateBoolConstant(bool value); + // Returns the id of an OpConstant instruction of type with |type_id| + // that consists of |words|. If that instruction doesn't exist, + // transformations are applied to add it. |type_id| must be a valid + // result id of either scalar or boolean OpType* instruction that exists + // in the module. + uint32_t FindOrCreateConstant(const std::vector<uint32_t>& words, + uint32_t type_id); + // Returns the result id of an instruction of the form: // %id = OpUndef %|type_id| // If no such instruction exists, a transformation is applied to add it. uint32_t FindOrCreateGlobalUndef(uint32_t type_id); - // Yields a pair, (base_type_ids, base_type_ids_to_pointers), such that: - // - base_type_ids captures every scalar or composite type declared in the - // module (i.e., all int, bool, float, vector, matrix, struct and array - // types - // - base_type_ids_to_pointers maps every such base type to the sequence + // Define a *basic type* to be an integer, boolean or floating-point type, + // or a matrix, vector, struct or fixed-size array built from basic types. In + // particular, a basic type cannot contain an opaque type (such as an image), + // or a runtime-sized array. + // + // Yields a pair, (basic_type_ids, basic_type_ids_to_pointers), such that: + // - basic_type_ids captures every basic type declared in the module. + // - basic_type_ids_to_pointers maps every such basic type to the sequence // of all pointer types that have storage class |storage_class| and the - // given base type as their pointee type. The sequence may be empty for - // some base types if no pointers to those types are defined for the given + // given basic type as their pointee type. The sequence may be empty for + // some basic types if no pointers to those types are defined for the given // storage class, and the sequence will have multiple elements if there are - // repeated pointer declarations for the same base type and storage class. + // repeated pointer declarations for the same basic type and storage class. std::pair<std::vector<uint32_t>, std::map<uint32_t, std::vector<uint32_t>>> - GetAvailableBaseTypesAndPointers(SpvStorageClass storage_class) const; + GetAvailableBasicTypesAndPointers(SpvStorageClass storage_class) const; // Given a type id, |scalar_or_composite_type_id|, which must correspond to // some scalar or composite type, returns the result id of an instruction @@ -230,7 +245,7 @@ const std::vector<uint32_t>& constant_ids); opt::IRContext* ir_context_; - FactManager* fact_manager_; + TransformationContext* transformation_context_; FuzzerContext* fuzzer_context_; protobufs::TransformationSequence* transformations_; };
diff --git a/source/fuzz/fuzzer_pass_add_access_chains.cpp b/source/fuzz/fuzzer_pass_add_access_chains.cpp index cfc2812..b9c1eed 100644 --- a/source/fuzz/fuzzer_pass_add_access_chains.cpp +++ b/source/fuzz/fuzzer_pass_add_access_chains.cpp
@@ -21,10 +21,11 @@ namespace fuzz { FuzzerPassAddAccessChains::FuzzerPassAddAccessChains( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAddAccessChains::~FuzzerPassAddAccessChains() = default;
diff --git a/source/fuzz/fuzzer_pass_add_access_chains.h b/source/fuzz/fuzzer_pass_add_access_chains.h index 7e8ed61..8649296 100644 --- a/source/fuzz/fuzzer_pass_add_access_chains.h +++ b/source/fuzz/fuzzer_pass_add_access_chains.h
@@ -26,7 +26,7 @@ class FuzzerPassAddAccessChains : public FuzzerPass { public: FuzzerPassAddAccessChains(opt::IRContext* ir_context, - FactManager* fact_manager, + TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_add_composite_types.cpp b/source/fuzz/fuzzer_pass_add_composite_types.cpp index 32c720e..9b0dda8 100644 --- a/source/fuzz/fuzzer_pass_add_composite_types.cpp +++ b/source/fuzz/fuzzer_pass_add_composite_types.cpp
@@ -22,10 +22,11 @@ namespace fuzz { FuzzerPassAddCompositeTypes::FuzzerPassAddCompositeTypes( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAddCompositeTypes::~FuzzerPassAddCompositeTypes() = default;
diff --git a/source/fuzz/fuzzer_pass_add_composite_types.h b/source/fuzz/fuzzer_pass_add_composite_types.h index 29d4bb8..87bc0ff 100644 --- a/source/fuzz/fuzzer_pass_add_composite_types.h +++ b/source/fuzz/fuzzer_pass_add_composite_types.h
@@ -25,7 +25,7 @@ class FuzzerPassAddCompositeTypes : public FuzzerPass { public: FuzzerPassAddCompositeTypes( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_add_dead_blocks.cpp b/source/fuzz/fuzzer_pass_add_dead_blocks.cpp index c9bc9c4..30a4145 100644 --- a/source/fuzz/fuzzer_pass_add_dead_blocks.cpp +++ b/source/fuzz/fuzzer_pass_add_dead_blocks.cpp
@@ -21,10 +21,11 @@ namespace fuzz { FuzzerPassAddDeadBlocks::FuzzerPassAddDeadBlocks( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAddDeadBlocks::~FuzzerPassAddDeadBlocks() = default; @@ -40,6 +41,12 @@ GetFuzzerContext()->GetChanceOfAddingDeadBlock())) { continue; } + + // Make sure the module contains a boolean constant equal to + // |condition_value|. + bool condition_value = GetFuzzerContext()->ChooseEven(); + FindOrCreateBoolConstant(condition_value); + // We speculatively create a transformation, and then apply it (below) if // it turns out to be applicable. This avoids duplicating the logic for // applicability checking. @@ -47,14 +54,14 @@ // It means that fresh ids for transformations that turn out not to be // applicable end up being unused. candidate_transformations.emplace_back(TransformationAddDeadBlock( - GetFuzzerContext()->GetFreshId(), block.id(), - GetFuzzerContext()->ChooseEven())); + GetFuzzerContext()->GetFreshId(), block.id(), condition_value)); } } // Apply all those transformations that are in fact applicable. for (auto& transformation : candidate_transformations) { - if (transformation.IsApplicable(GetIRContext(), *GetFactManager())) { - transformation.Apply(GetIRContext(), GetFactManager()); + if (transformation.IsApplicable(GetIRContext(), + *GetTransformationContext())) { + transformation.Apply(GetIRContext(), GetTransformationContext()); *GetTransformations()->add_transformation() = transformation.ToMessage(); } }
diff --git a/source/fuzz/fuzzer_pass_add_dead_blocks.h b/source/fuzz/fuzzer_pass_add_dead_blocks.h index 01e3843..d78f088 100644 --- a/source/fuzz/fuzzer_pass_add_dead_blocks.h +++ b/source/fuzz/fuzzer_pass_add_dead_blocks.h
@@ -24,7 +24,8 @@ // passes can then manipulate such blocks. class FuzzerPassAddDeadBlocks : public FuzzerPass { public: - FuzzerPassAddDeadBlocks(opt::IRContext* ir_context, FactManager* fact_manager, + FuzzerPassAddDeadBlocks(opt::IRContext* ir_context, + TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_add_dead_breaks.cpp b/source/fuzz/fuzzer_pass_add_dead_breaks.cpp index aefc2fc..f3900aa 100644 --- a/source/fuzz/fuzzer_pass_add_dead_breaks.cpp +++ b/source/fuzz/fuzzer_pass_add_dead_breaks.cpp
@@ -21,10 +21,11 @@ namespace fuzz { FuzzerPassAddDeadBreaks::FuzzerPassAddDeadBreaks( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAddDeadBreaks::~FuzzerPassAddDeadBreaks() = default; @@ -76,11 +77,15 @@ }); } + // Make sure the module has a required boolean constant to be used in + // OpBranchConditional instruction. + auto break_condition = GetFuzzerContext()->ChooseEven(); + FindOrCreateBoolConstant(break_condition); + auto candidate_transformation = TransformationAddDeadBreak( - block.id(), merge_block->id(), GetFuzzerContext()->ChooseEven(), - std::move(phi_ids)); - if (candidate_transformation.IsApplicable(GetIRContext(), - *GetFactManager())) { + block.id(), merge_block->id(), break_condition, std::move(phi_ids)); + if (candidate_transformation.IsApplicable( + GetIRContext(), *GetTransformationContext())) { // Only consider a transformation as a candidate if it is applicable. candidate_transformations.push_back( std::move(candidate_transformation)); @@ -109,10 +114,11 @@ candidate_transformations.erase(candidate_transformations.begin() + index); // Probabilistically decide whether to try to apply it vs. ignore it, in the // case that it is applicable. - if (transformation.IsApplicable(GetIRContext(), *GetFactManager()) && + if (transformation.IsApplicable(GetIRContext(), + *GetTransformationContext()) && GetFuzzerContext()->ChoosePercentage( GetFuzzerContext()->GetChanceOfAddingDeadBreak())) { - transformation.Apply(GetIRContext(), GetFactManager()); + transformation.Apply(GetIRContext(), GetTransformationContext()); *GetTransformations()->add_transformation() = transformation.ToMessage(); } }
diff --git a/source/fuzz/fuzzer_pass_add_dead_breaks.h b/source/fuzz/fuzzer_pass_add_dead_breaks.h index 12a5095..c379eed 100644 --- a/source/fuzz/fuzzer_pass_add_dead_breaks.h +++ b/source/fuzz/fuzzer_pass_add_dead_breaks.h
@@ -23,7 +23,8 @@ // A fuzzer pass for adding dead break edges to the module. class FuzzerPassAddDeadBreaks : public FuzzerPass { public: - FuzzerPassAddDeadBreaks(opt::IRContext* ir_context, FactManager* fact_manager, + FuzzerPassAddDeadBreaks(opt::IRContext* ir_context, + TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_add_dead_continues.cpp b/source/fuzz/fuzzer_pass_add_dead_continues.cpp index 852df3d..56a7fd1 100644 --- a/source/fuzz/fuzzer_pass_add_dead_continues.cpp +++ b/source/fuzz/fuzzer_pass_add_dead_continues.cpp
@@ -21,10 +21,11 @@ namespace fuzz { FuzzerPassAddDeadContinues::FuzzerPassAddDeadContinues( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAddDeadContinues::~FuzzerPassAddDeadContinues() = default; @@ -67,18 +68,24 @@ }); } + // Make sure the module contains a boolean constant equal to + // |condition_value|. + bool condition_value = GetFuzzerContext()->ChooseEven(); + FindOrCreateBoolConstant(condition_value); + // Make a transformation to add a dead continue from this node; if the // node turns out to be inappropriate (e.g. by not being in a loop) the // precondition for the transformation will fail and it will be ignored. auto candidate_transformation = TransformationAddDeadContinue( - block.id(), GetFuzzerContext()->ChooseEven(), std::move(phi_ids)); + block.id(), condition_value, std::move(phi_ids)); // Probabilistically decide whether to apply the transformation in the // case that it is applicable. if (candidate_transformation.IsApplicable(GetIRContext(), - *GetFactManager()) && + *GetTransformationContext()) && GetFuzzerContext()->ChoosePercentage( GetFuzzerContext()->GetChanceOfAddingDeadContinue())) { - candidate_transformation.Apply(GetIRContext(), GetFactManager()); + candidate_transformation.Apply(GetIRContext(), + GetTransformationContext()); *GetTransformations()->add_transformation() = candidate_transformation.ToMessage(); }
diff --git a/source/fuzz/fuzzer_pass_add_dead_continues.h b/source/fuzz/fuzzer_pass_add_dead_continues.h index d067f1c..b2acb93 100644 --- a/source/fuzz/fuzzer_pass_add_dead_continues.h +++ b/source/fuzz/fuzzer_pass_add_dead_continues.h
@@ -24,7 +24,7 @@ class FuzzerPassAddDeadContinues : public FuzzerPass { public: FuzzerPassAddDeadContinues( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_add_equation_instructions.cpp b/source/fuzz/fuzzer_pass_add_equation_instructions.cpp index 7f34344..49c4a8a 100644 --- a/source/fuzz/fuzzer_pass_add_equation_instructions.cpp +++ b/source/fuzz/fuzzer_pass_add_equation_instructions.cpp
@@ -23,10 +23,11 @@ namespace fuzz { FuzzerPassAddEquationInstructions::FuzzerPassAddEquationInstructions( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAddEquationInstructions::~FuzzerPassAddEquationInstructions() = default;
diff --git a/source/fuzz/fuzzer_pass_add_equation_instructions.h b/source/fuzz/fuzzer_pass_add_equation_instructions.h index 84229c0..6e64977 100644 --- a/source/fuzz/fuzzer_pass_add_equation_instructions.h +++ b/source/fuzz/fuzzer_pass_add_equation_instructions.h
@@ -27,7 +27,7 @@ class FuzzerPassAddEquationInstructions : public FuzzerPass { public: FuzzerPassAddEquationInstructions( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_add_function_calls.cpp b/source/fuzz/fuzzer_pass_add_function_calls.cpp index 545aa16..569df10 100644 --- a/source/fuzz/fuzzer_pass_add_function_calls.cpp +++ b/source/fuzz/fuzzer_pass_add_function_calls.cpp
@@ -24,10 +24,11 @@ namespace fuzz { FuzzerPassAddFunctionCalls::FuzzerPassAddFunctionCalls( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAddFunctionCalls::~FuzzerPassAddFunctionCalls() = default; @@ -74,8 +75,9 @@ while (!candidate_functions.empty()) { opt::Function* candidate_function = GetFuzzerContext()->RemoveAtRandomIndex(&candidate_functions); - if (!GetFactManager()->BlockIsDead(block->id()) && - !GetFactManager()->FunctionIsLivesafe( + if (!GetTransformationContext()->GetFactManager()->BlockIsDead( + block->id()) && + !GetTransformationContext()->GetFactManager()->FunctionIsLivesafe( candidate_function->result_id())) { // Unless in a dead block, only livesafe functions can be invoked continue; @@ -132,9 +134,11 @@ default: return false; } - if (!GetFactManager()->BlockIsDead(block->id()) && - !GetFactManager()->PointeeValueIsIrrelevant( - inst->result_id())) { + if (!GetTransformationContext()->GetFactManager()->BlockIsDead( + block->id()) && + !GetTransformationContext() + ->GetFactManager() + ->PointeeValueIsIrrelevant(inst->result_id())) { // We can only pass a pointer as an actual parameter // if the pointee value for the pointer is irrelevant, // or if the block from which we would make the @@ -210,8 +214,9 @@ result.push_back(fresh_variable_id); // Now bring the variable into existence. - if (type_instruction->GetSingleWordInOperand(0) == - SpvStorageClassFunction) { + auto storage_class = static_cast<SpvStorageClass>( + type_instruction->GetSingleWordInOperand(0)); + if (storage_class == SpvStorageClassFunction) { // Add a new zero-initialized local variable to the current // function, noting that its pointee value is irrelevant. ApplyTransformation(TransformationAddLocalVariable( @@ -220,16 +225,19 @@ type_instruction->GetSingleWordInOperand(1)), true)); } else { - assert(type_instruction->GetSingleWordInOperand(0) == - SpvStorageClassPrivate && - "Only Function and Private storage classes are " + assert((storage_class == SpvStorageClassPrivate || + storage_class == SpvStorageClassWorkgroup) && + "Only Function, Private and Workgroup storage classes are " "supported at present."); - // Add a new zero-initialized global variable to the module, - // noting that its pointee value is irrelevant. + // Add a new global variable to the module, zero-initializing it if + // it has Private storage class, and noting that its pointee value is + // irrelevant. ApplyTransformation(TransformationAddGlobalVariable( - fresh_variable_id, arg_type_id, - FindOrCreateZeroConstant( - type_instruction->GetSingleWordInOperand(1)), + fresh_variable_id, arg_type_id, storage_class, + storage_class == SpvStorageClassPrivate + ? FindOrCreateZeroConstant( + type_instruction->GetSingleWordInOperand(1)) + : 0, true)); } } else {
diff --git a/source/fuzz/fuzzer_pass_add_function_calls.h b/source/fuzz/fuzzer_pass_add_function_calls.h index 5d184fd..8f75e8c 100644 --- a/source/fuzz/fuzzer_pass_add_function_calls.h +++ b/source/fuzz/fuzzer_pass_add_function_calls.h
@@ -25,7 +25,7 @@ class FuzzerPassAddFunctionCalls : public FuzzerPass { public: FuzzerPassAddFunctionCalls( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_add_global_variables.cpp b/source/fuzz/fuzzer_pass_add_global_variables.cpp index 1371f46..4023b22 100644 --- a/source/fuzz/fuzzer_pass_add_global_variables.cpp +++ b/source/fuzz/fuzzer_pass_add_global_variables.cpp
@@ -21,53 +21,56 @@ namespace fuzz { FuzzerPassAddGlobalVariables::FuzzerPassAddGlobalVariables( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAddGlobalVariables::~FuzzerPassAddGlobalVariables() = default; void FuzzerPassAddGlobalVariables::Apply() { - auto base_type_ids_and_pointers = - GetAvailableBaseTypesAndPointers(SpvStorageClassPrivate); + auto basic_type_ids_and_pointers = + GetAvailableBasicTypesAndPointers(SpvStorageClassPrivate); - // These are the base types that are available to this fuzzer pass. - auto& base_types = base_type_ids_and_pointers.first; + // These are the basic types that are available to this fuzzer pass. + auto& basic_types = basic_type_ids_and_pointers.first; - // These are the pointers to those base types that are *initially* available + // These are the pointers to those basic types that are *initially* available // to the fuzzer pass. The fuzzer pass might add pointer types in cases where - // none are available for a given base type. - auto& base_type_to_pointers = base_type_ids_and_pointers.second; + // none are available for a given basic type. + auto& basic_type_to_pointers = basic_type_ids_and_pointers.second; // Probabilistically keep adding global variables. while (GetFuzzerContext()->ChoosePercentage( GetFuzzerContext()->GetChanceOfAddingGlobalVariable())) { - // Choose a random base type; the new variable's type will be a pointer to - // this base type. - uint32_t base_type = - base_types[GetFuzzerContext()->RandomIndex(base_types)]; + // Choose a random basic type; the new variable's type will be a pointer to + // this basic type. + uint32_t basic_type = + basic_types[GetFuzzerContext()->RandomIndex(basic_types)]; uint32_t pointer_type_id; - std::vector<uint32_t>& available_pointers_to_base_type = - base_type_to_pointers.at(base_type); - // Determine whether there is at least one pointer to this base type. - if (available_pointers_to_base_type.empty()) { + std::vector<uint32_t>& available_pointers_to_basic_type = + basic_type_to_pointers.at(basic_type); + // Determine whether there is at least one pointer to this basic type. + if (available_pointers_to_basic_type.empty()) { // There is not. Make one, to use here, and add it to the available - // pointers for the base type so that future variables can potentially + // pointers for the basic type so that future variables can potentially // use it. pointer_type_id = GetFuzzerContext()->GetFreshId(); - available_pointers_to_base_type.push_back(pointer_type_id); + available_pointers_to_basic_type.push_back(pointer_type_id); ApplyTransformation(TransformationAddTypePointer( - pointer_type_id, SpvStorageClassPrivate, base_type)); + pointer_type_id, SpvStorageClassPrivate, basic_type)); } else { // There is - grab one. pointer_type_id = - available_pointers_to_base_type[GetFuzzerContext()->RandomIndex( - available_pointers_to_base_type)]; + available_pointers_to_basic_type[GetFuzzerContext()->RandomIndex( + available_pointers_to_basic_type)]; } + // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3274): We could + // add new variables with Workgroup storage class in compute shaders. ApplyTransformation(TransformationAddGlobalVariable( GetFuzzerContext()->GetFreshId(), pointer_type_id, - FindOrCreateZeroConstant(base_type), true)); + SpvStorageClassPrivate, FindOrCreateZeroConstant(basic_type), true)); } }
diff --git a/source/fuzz/fuzzer_pass_add_global_variables.h b/source/fuzz/fuzzer_pass_add_global_variables.h index c71d147..a907d36 100644 --- a/source/fuzz/fuzzer_pass_add_global_variables.h +++ b/source/fuzz/fuzzer_pass_add_global_variables.h
@@ -25,7 +25,7 @@ class FuzzerPassAddGlobalVariables : public FuzzerPass { public: FuzzerPassAddGlobalVariables( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_add_loads.cpp b/source/fuzz/fuzzer_pass_add_loads.cpp index 851787f..256255b 100644 --- a/source/fuzz/fuzzer_pass_add_loads.cpp +++ b/source/fuzz/fuzzer_pass_add_loads.cpp
@@ -21,10 +21,11 @@ namespace fuzz { FuzzerPassAddLoads::FuzzerPassAddLoads( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAddLoads::~FuzzerPassAddLoads() = default; @@ -59,7 +60,7 @@ if (!instruction->result_id() || !instruction->type_id()) { return false; } - switch (instruction->result_id()) { + switch (instruction->opcode()) { case SpvOpConstantNull: case SpvOpUndef: // Do not allow loading from a null or undefined pointer;
diff --git a/source/fuzz/fuzzer_pass_add_loads.h b/source/fuzz/fuzzer_pass_add_loads.h index 125bc5d..c4d5b27 100644 --- a/source/fuzz/fuzzer_pass_add_loads.h +++ b/source/fuzz/fuzzer_pass_add_loads.h
@@ -23,7 +23,8 @@ // Fuzzer pass that adds stores, at random, from pointers in the module. class FuzzerPassAddLoads : public FuzzerPass { public: - FuzzerPassAddLoads(opt::IRContext* ir_context, FactManager* fact_manager, + FuzzerPassAddLoads(opt::IRContext* ir_context, + TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_add_local_variables.cpp b/source/fuzz/fuzzer_pass_add_local_variables.cpp index 8d6d80d..661159e 100644 --- a/source/fuzz/fuzzer_pass_add_local_variables.cpp +++ b/source/fuzz/fuzzer_pass_add_local_variables.cpp
@@ -22,55 +22,56 @@ namespace fuzz { FuzzerPassAddLocalVariables::FuzzerPassAddLocalVariables( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAddLocalVariables::~FuzzerPassAddLocalVariables() = default; void FuzzerPassAddLocalVariables::Apply() { - auto base_type_ids_and_pointers = - GetAvailableBaseTypesAndPointers(SpvStorageClassFunction); + auto basic_type_ids_and_pointers = + GetAvailableBasicTypesAndPointers(SpvStorageClassFunction); - // These are the base types that are available to this fuzzer pass. - auto& base_types = base_type_ids_and_pointers.first; + // These are the basic types that are available to this fuzzer pass. + auto& basic_types = basic_type_ids_and_pointers.first; - // These are the pointers to those base types that are *initially* available + // These are the pointers to those basic types that are *initially* available // to the fuzzer pass. The fuzzer pass might add pointer types in cases where - // none are available for a given base type. - auto& base_type_to_pointers = base_type_ids_and_pointers.second; + // none are available for a given basic type. + auto& basic_type_to_pointers = basic_type_ids_and_pointers.second; // Consider every function in the module. for (auto& function : *GetIRContext()->module()) { // Probabilistically keep adding random variables to this function. while (GetFuzzerContext()->ChoosePercentage( GetFuzzerContext()->GetChanceOfAddingLocalVariable())) { - // Choose a random base type; the new variable's type will be a pointer to - // this base type. - uint32_t base_type = - base_types[GetFuzzerContext()->RandomIndex(base_types)]; + // Choose a random basic type; the new variable's type will be a pointer + // to this basic type. + uint32_t basic_type = + basic_types[GetFuzzerContext()->RandomIndex(basic_types)]; uint32_t pointer_type; - std::vector<uint32_t>& available_pointers_to_base_type = - base_type_to_pointers.at(base_type); - // Determine whether there is at least one pointer to this base type. - if (available_pointers_to_base_type.empty()) { + std::vector<uint32_t>& available_pointers_to_basic_type = + basic_type_to_pointers.at(basic_type); + // Determine whether there is at least one pointer to this basic type. + if (available_pointers_to_basic_type.empty()) { // There is not. Make one, to use here, and add it to the available - // pointers for the base type so that future variables can potentially + // pointers for the basic type so that future variables can potentially // use it. pointer_type = GetFuzzerContext()->GetFreshId(); ApplyTransformation(TransformationAddTypePointer( - pointer_type, SpvStorageClassFunction, base_type)); - available_pointers_to_base_type.push_back(pointer_type); + pointer_type, SpvStorageClassFunction, basic_type)); + available_pointers_to_basic_type.push_back(pointer_type); } else { // There is - grab one. pointer_type = - available_pointers_to_base_type[GetFuzzerContext()->RandomIndex( - available_pointers_to_base_type)]; + available_pointers_to_basic_type[GetFuzzerContext()->RandomIndex( + available_pointers_to_basic_type)]; } ApplyTransformation(TransformationAddLocalVariable( GetFuzzerContext()->GetFreshId(), pointer_type, function.result_id(), - FindOrCreateZeroConstant(base_type), true)); + FindOrCreateZeroConstant(basic_type), true)); } } }
diff --git a/source/fuzz/fuzzer_pass_add_local_variables.h b/source/fuzz/fuzzer_pass_add_local_variables.h index eed3665..08d26d8 100644 --- a/source/fuzz/fuzzer_pass_add_local_variables.h +++ b/source/fuzz/fuzzer_pass_add_local_variables.h
@@ -25,7 +25,7 @@ class FuzzerPassAddLocalVariables : public FuzzerPass { public: FuzzerPassAddLocalVariables( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_add_no_contraction_decorations.cpp b/source/fuzz/fuzzer_pass_add_no_contraction_decorations.cpp index 82fb539..09627d0 100644 --- a/source/fuzz/fuzzer_pass_add_no_contraction_decorations.cpp +++ b/source/fuzz/fuzzer_pass_add_no_contraction_decorations.cpp
@@ -20,10 +20,11 @@ namespace fuzz { FuzzerPassAddNoContractionDecorations::FuzzerPassAddNoContractionDecorations( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAddNoContractionDecorations:: ~FuzzerPassAddNoContractionDecorations() = default;
diff --git a/source/fuzz/fuzzer_pass_add_no_contraction_decorations.h b/source/fuzz/fuzzer_pass_add_no_contraction_decorations.h index abe5bd7..f32e5bc 100644 --- a/source/fuzz/fuzzer_pass_add_no_contraction_decorations.h +++ b/source/fuzz/fuzzer_pass_add_no_contraction_decorations.h
@@ -24,7 +24,7 @@ class FuzzerPassAddNoContractionDecorations : public FuzzerPass { public: FuzzerPassAddNoContractionDecorations( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_add_stores.cpp b/source/fuzz/fuzzer_pass_add_stores.cpp index 794ddc3..46efc64 100644 --- a/source/fuzz/fuzzer_pass_add_stores.cpp +++ b/source/fuzz/fuzzer_pass_add_stores.cpp
@@ -21,10 +21,11 @@ namespace fuzz { FuzzerPassAddStores::FuzzerPassAddStores( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAddStores::~FuzzerPassAddStores() = default; @@ -67,12 +68,11 @@ // Not a pointer. return false; } - if (type_inst->GetSingleWordInOperand(0) == - SpvStorageClassInput) { - // Read-only: cannot store to it. + if (instruction->IsReadOnlyPointer()) { + // Read only: cannot store to it. return false; } - switch (instruction->result_id()) { + switch (instruction->opcode()) { case SpvOpConstantNull: case SpvOpUndef: // Do not allow storing to a null or undefined pointer; @@ -82,9 +82,13 @@ default: break; } - return GetFactManager()->BlockIsDead(block->id()) || - GetFactManager()->PointeeValueIsIrrelevant( - instruction->result_id()); + return GetTransformationContext() + ->GetFactManager() + ->BlockIsDead(block->id()) || + GetTransformationContext() + ->GetFactManager() + ->PointeeValueIsIrrelevant( + instruction->result_id()); }); // At this point, |relevant_pointers| contains all the pointers we might
diff --git a/source/fuzz/fuzzer_pass_add_stores.h b/source/fuzz/fuzzer_pass_add_stores.h index 9daa9e0..55ec67f 100644 --- a/source/fuzz/fuzzer_pass_add_stores.h +++ b/source/fuzz/fuzzer_pass_add_stores.h
@@ -25,7 +25,8 @@ // are known not to affect the module's overall behaviour. class FuzzerPassAddStores : public FuzzerPass { public: - FuzzerPassAddStores(opt::IRContext* ir_context, FactManager* fact_manager, + FuzzerPassAddStores(opt::IRContext* ir_context, + TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_add_useful_constructs.cpp b/source/fuzz/fuzzer_pass_add_useful_constructs.cpp deleted file mode 100644 index 8552dfd..0000000 --- a/source/fuzz/fuzzer_pass_add_useful_constructs.cpp +++ /dev/null
@@ -1,214 +0,0 @@ -// Copyright (c) 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "source/fuzz/fuzzer_pass_add_useful_constructs.h" - -#include "source/fuzz/transformation_add_constant_boolean.h" -#include "source/fuzz/transformation_add_constant_scalar.h" -#include "source/fuzz/transformation_add_type_boolean.h" -#include "source/fuzz/transformation_add_type_float.h" -#include "source/fuzz/transformation_add_type_int.h" -#include "source/fuzz/transformation_add_type_pointer.h" - -namespace spvtools { -namespace fuzz { - -FuzzerPassAddUsefulConstructs::FuzzerPassAddUsefulConstructs( - opt::IRContext* ir_context, FactManager* fact_manager, - FuzzerContext* fuzzer_context, - protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} - -FuzzerPassAddUsefulConstructs::~FuzzerPassAddUsefulConstructs() = default; - -void FuzzerPassAddUsefulConstructs::MaybeAddIntConstant( - uint32_t width, bool is_signed, std::vector<uint32_t> data) const { - opt::analysis::Integer temp_int_type(width, is_signed); - assert(GetIRContext()->get_type_mgr()->GetId(&temp_int_type) && - "int type should already be registered."); - auto registered_int_type = GetIRContext() - ->get_type_mgr() - ->GetRegisteredType(&temp_int_type) - ->AsInteger(); - auto int_type_id = GetIRContext()->get_type_mgr()->GetId(registered_int_type); - assert(int_type_id && - "The relevant int type should have been added to the module already."); - opt::analysis::IntConstant int_constant(registered_int_type, data); - if (!GetIRContext()->get_constant_mgr()->FindConstant(&int_constant)) { - TransformationAddConstantScalar add_constant_int = - TransformationAddConstantScalar(GetFuzzerContext()->GetFreshId(), - int_type_id, data); - assert(add_constant_int.IsApplicable(GetIRContext(), *GetFactManager()) && - "Should be applicable by construction."); - add_constant_int.Apply(GetIRContext(), GetFactManager()); - *GetTransformations()->add_transformation() = add_constant_int.ToMessage(); - } -} - -void FuzzerPassAddUsefulConstructs::MaybeAddFloatConstant( - uint32_t width, std::vector<uint32_t> data) const { - opt::analysis::Float temp_float_type(width); - assert(GetIRContext()->get_type_mgr()->GetId(&temp_float_type) && - "float type should already be registered."); - auto registered_float_type = GetIRContext() - ->get_type_mgr() - ->GetRegisteredType(&temp_float_type) - ->AsFloat(); - auto float_type_id = - GetIRContext()->get_type_mgr()->GetId(registered_float_type); - assert( - float_type_id && - "The relevant float type should have been added to the module already."); - opt::analysis::FloatConstant float_constant(registered_float_type, data); - if (!GetIRContext()->get_constant_mgr()->FindConstant(&float_constant)) { - TransformationAddConstantScalar add_constant_float = - TransformationAddConstantScalar(GetFuzzerContext()->GetFreshId(), - float_type_id, data); - assert(add_constant_float.IsApplicable(GetIRContext(), *GetFactManager()) && - "Should be applicable by construction."); - add_constant_float.Apply(GetIRContext(), GetFactManager()); - *GetTransformations()->add_transformation() = - add_constant_float.ToMessage(); - } -} - -void FuzzerPassAddUsefulConstructs::Apply() { - { - // Add boolean type if not present. - opt::analysis::Bool temp_bool_type; - if (!GetIRContext()->get_type_mgr()->GetId(&temp_bool_type)) { - auto add_type_boolean = - TransformationAddTypeBoolean(GetFuzzerContext()->GetFreshId()); - assert(add_type_boolean.IsApplicable(GetIRContext(), *GetFactManager()) && - "Should be applicable by construction."); - add_type_boolean.Apply(GetIRContext(), GetFactManager()); - *GetTransformations()->add_transformation() = - add_type_boolean.ToMessage(); - } - } - - { - // Add signed and unsigned 32-bit integer types if not present. - for (auto is_signed : {true, false}) { - opt::analysis::Integer temp_int_type(32, is_signed); - if (!GetIRContext()->get_type_mgr()->GetId(&temp_int_type)) { - TransformationAddTypeInt add_type_int = TransformationAddTypeInt( - GetFuzzerContext()->GetFreshId(), 32, is_signed); - assert(add_type_int.IsApplicable(GetIRContext(), *GetFactManager()) && - "Should be applicable by construction."); - add_type_int.Apply(GetIRContext(), GetFactManager()); - *GetTransformations()->add_transformation() = add_type_int.ToMessage(); - } - } - } - - { - // Add 32-bit float type if not present. - opt::analysis::Float temp_float_type(32); - if (!GetIRContext()->get_type_mgr()->GetId(&temp_float_type)) { - TransformationAddTypeFloat add_type_float = - TransformationAddTypeFloat(GetFuzzerContext()->GetFreshId(), 32); - assert(add_type_float.IsApplicable(GetIRContext(), *GetFactManager()) && - "Should be applicable by construction."); - add_type_float.Apply(GetIRContext(), GetFactManager()); - *GetTransformations()->add_transformation() = add_type_float.ToMessage(); - } - } - - // Add boolean constants true and false if not present. - opt::analysis::Bool temp_bool_type; - auto bool_type = GetIRContext() - ->get_type_mgr() - ->GetRegisteredType(&temp_bool_type) - ->AsBool(); - for (auto boolean_value : {true, false}) { - // Add OpConstantTrue/False if not already there. - opt::analysis::BoolConstant bool_constant(bool_type, boolean_value); - if (!GetIRContext()->get_constant_mgr()->FindConstant(&bool_constant)) { - TransformationAddConstantBoolean add_constant_boolean( - GetFuzzerContext()->GetFreshId(), boolean_value); - assert(add_constant_boolean.IsApplicable(GetIRContext(), - *GetFactManager()) && - "Should be applicable by construction."); - add_constant_boolean.Apply(GetIRContext(), GetFactManager()); - *GetTransformations()->add_transformation() = - add_constant_boolean.ToMessage(); - } - } - - // Add signed and unsigned 32-bit integer constants 0 and 1 if not present. - for (auto is_signed : {true, false}) { - for (auto value : {0u, 1u}) { - MaybeAddIntConstant(32, is_signed, {value}); - } - } - - // Add 32-bit float constants 0.0 and 1.0 if not present. - uint32_t uint_data[2]; - float float_data[2] = {0.0, 1.0}; - memcpy(uint_data, float_data, sizeof(float_data)); - for (unsigned int& datum : uint_data) { - MaybeAddFloatConstant(32, {datum}); - } - - // For every known-to-be-constant uniform, make sure we have instructions - // declaring: - // - a pointer type with uniform storage class, whose pointee type is the type - // of the element - // - a signed integer constant for each index required to access the element - // - a constant for the constant value itself - for (auto& fact_and_type_id : - GetFactManager()->GetConstantUniformFactsAndTypes()) { - uint32_t element_type_id = fact_and_type_id.second; - assert(element_type_id); - auto element_type = - GetIRContext()->get_type_mgr()->GetType(element_type_id); - assert(element_type && - "If the constant uniform fact is well-formed, the module must " - "already have a declaration of the type for the uniform element."); - opt::analysis::Pointer uniform_pointer(element_type, - SpvStorageClassUniform); - if (!GetIRContext()->get_type_mgr()->GetId(&uniform_pointer)) { - auto add_pointer = - TransformationAddTypePointer(GetFuzzerContext()->GetFreshId(), - SpvStorageClassUniform, element_type_id); - assert(add_pointer.IsApplicable(GetIRContext(), *GetFactManager()) && - "Should be applicable by construction."); - add_pointer.Apply(GetIRContext(), GetFactManager()); - *GetTransformations()->add_transformation() = add_pointer.ToMessage(); - } - std::vector<uint32_t> words; - for (auto word : fact_and_type_id.first.constant_word()) { - words.push_back(word); - } - // We get the element type again as the type manager may have been - // invalidated since we last retrieved it. - element_type = GetIRContext()->get_type_mgr()->GetType(element_type_id); - if (element_type->AsInteger()) { - MaybeAddIntConstant(element_type->AsInteger()->width(), - element_type->AsInteger()->IsSigned(), words); - } else { - assert(element_type->AsFloat() && - "Known uniform values must be integer or floating-point."); - MaybeAddFloatConstant(element_type->AsFloat()->width(), words); - } - for (auto index : - fact_and_type_id.first.uniform_buffer_element_descriptor().index()) { - MaybeAddIntConstant(32, true, {index}); - } - } -} - -} // namespace fuzz -} // namespace spvtools
diff --git a/source/fuzz/fuzzer_pass_add_useful_constructs.h b/source/fuzz/fuzzer_pass_add_useful_constructs.h deleted file mode 100644 index 7dc00f1..0000000 --- a/source/fuzz/fuzzer_pass_add_useful_constructs.h +++ /dev/null
@@ -1,46 +0,0 @@ -// Copyright (c) 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef SOURCE_FUZZ_FUZZER_PASS_ADD_USEFUL_CONSTRUCTS_ -#define SOURCE_FUZZ_FUZZER_PASS_ADD_USEFUL_CONSTRUCTS_ - -#include "source/fuzz/fuzzer_pass.h" - -namespace spvtools { -namespace fuzz { - -// An initial pass for adding useful ingredients to the module, such as boolean -// constants, if they are not present. -class FuzzerPassAddUsefulConstructs : public FuzzerPass { - public: - FuzzerPassAddUsefulConstructs( - opt::IRContext* ir_context, FactManager* fact_manager, - FuzzerContext* fuzzer_context, - protobufs::TransformationSequence* transformations); - - ~FuzzerPassAddUsefulConstructs() override; - - void Apply() override; - - private: - void MaybeAddIntConstant(uint32_t width, bool is_signed, - std::vector<uint32_t> data) const; - - void MaybeAddFloatConstant(uint32_t width, std::vector<uint32_t> data) const; -}; - -} // namespace fuzz -} // namespace spvtools - -#endif // SOURCE_FUZZ_FUZZER_PASS_ADD_USEFUL_CONSTRUCTS_
diff --git a/source/fuzz/fuzzer_pass_adjust_branch_weights.cpp b/source/fuzz/fuzzer_pass_adjust_branch_weights.cpp new file mode 100644 index 0000000..1d6d434 --- /dev/null +++ b/source/fuzz/fuzzer_pass_adjust_branch_weights.cpp
@@ -0,0 +1,48 @@ +// Copyright (c) 2020 André Perez Maselco +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/fuzzer_pass_adjust_branch_weights.h" + +#include "source/fuzz/fuzzer_util.h" +#include "source/fuzz/instruction_descriptor.h" +#include "source/fuzz/transformation_adjust_branch_weights.h" + +namespace spvtools { +namespace fuzz { + +FuzzerPassAdjustBranchWeights::FuzzerPassAdjustBranchWeights( + opt::IRContext* ir_context, TransformationContext* transformation_context, + FuzzerContext* fuzzer_context, + protobufs::TransformationSequence* transformations) + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} + +FuzzerPassAdjustBranchWeights::~FuzzerPassAdjustBranchWeights() = default; + +void FuzzerPassAdjustBranchWeights::Apply() { + // For all OpBranchConditional instructions, + // randomly applies the transformation. + GetIRContext()->module()->ForEachInst([this](opt::Instruction* instruction) { + if (instruction->opcode() == SpvOpBranchConditional && + GetFuzzerContext()->ChoosePercentage( + GetFuzzerContext()->GetChanceOfAdjustingBranchWeights())) { + ApplyTransformation(TransformationAdjustBranchWeights( + MakeInstructionDescriptor(GetIRContext(), instruction), + GetFuzzerContext()->GetRandomBranchWeights())); + } + }); +} + +} // namespace fuzz +} // namespace spvtools
diff --git a/source/fuzz/fuzzer_pass_adjust_branch_weights.h b/source/fuzz/fuzzer_pass_adjust_branch_weights.h new file mode 100644 index 0000000..5b2b33f --- /dev/null +++ b/source/fuzz/fuzzer_pass_adjust_branch_weights.h
@@ -0,0 +1,41 @@ +// Copyright (c) 2020 André Perez Maselco +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_FUZZ_FUZZER_PASS_ADJUST_BRANCH_WEIGHTS_H_ +#define SOURCE_FUZZ_FUZZER_PASS_ADJUST_BRANCH_WEIGHTS_H_ + +#include "source/fuzz/fuzzer_pass.h" + +namespace spvtools { +namespace fuzz { + +// This fuzzer pass searches for branch conditional instructions +// and randomly chooses which of these instructions will have their weights +// adjusted. +class FuzzerPassAdjustBranchWeights : public FuzzerPass { + public: + FuzzerPassAdjustBranchWeights( + opt::IRContext* ir_context, TransformationContext* transformation_context, + FuzzerContext* fuzzer_context, + protobufs::TransformationSequence* transformations); + + ~FuzzerPassAdjustBranchWeights(); + + void Apply() override; +}; + +} // namespace fuzz +} // namespace spvtools + +#endif // SOURCE_FUZZ_FUZZER_PASS_ADJUST_BRANCH_WEIGHTS_H_
diff --git a/source/fuzz/fuzzer_pass_adjust_function_controls.cpp b/source/fuzz/fuzzer_pass_adjust_function_controls.cpp index fe229bc..aa62d2f 100644 --- a/source/fuzz/fuzzer_pass_adjust_function_controls.cpp +++ b/source/fuzz/fuzzer_pass_adjust_function_controls.cpp
@@ -20,10 +20,11 @@ namespace fuzz { FuzzerPassAdjustFunctionControls::FuzzerPassAdjustFunctionControls( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAdjustFunctionControls::~FuzzerPassAdjustFunctionControls() = default;
diff --git a/source/fuzz/fuzzer_pass_adjust_function_controls.h b/source/fuzz/fuzzer_pass_adjust_function_controls.h index 02d3600..e20541b 100644 --- a/source/fuzz/fuzzer_pass_adjust_function_controls.h +++ b/source/fuzz/fuzzer_pass_adjust_function_controls.h
@@ -24,7 +24,7 @@ class FuzzerPassAdjustFunctionControls : public FuzzerPass { public: FuzzerPassAdjustFunctionControls( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_adjust_loop_controls.cpp b/source/fuzz/fuzzer_pass_adjust_loop_controls.cpp index c9843d0..f7addff 100644 --- a/source/fuzz/fuzzer_pass_adjust_loop_controls.cpp +++ b/source/fuzz/fuzzer_pass_adjust_loop_controls.cpp
@@ -20,10 +20,11 @@ namespace fuzz { FuzzerPassAdjustLoopControls::FuzzerPassAdjustLoopControls( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAdjustLoopControls::~FuzzerPassAdjustLoopControls() = default;
diff --git a/source/fuzz/fuzzer_pass_adjust_loop_controls.h b/source/fuzz/fuzzer_pass_adjust_loop_controls.h index e945606..ee5cd48 100644 --- a/source/fuzz/fuzzer_pass_adjust_loop_controls.h +++ b/source/fuzz/fuzzer_pass_adjust_loop_controls.h
@@ -24,7 +24,7 @@ class FuzzerPassAdjustLoopControls : public FuzzerPass { public: FuzzerPassAdjustLoopControls( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_adjust_memory_operands_masks.cpp b/source/fuzz/fuzzer_pass_adjust_memory_operands_masks.cpp index 2d3d676..32f5ea5 100644 --- a/source/fuzz/fuzzer_pass_adjust_memory_operands_masks.cpp +++ b/source/fuzz/fuzzer_pass_adjust_memory_operands_masks.cpp
@@ -21,10 +21,11 @@ namespace fuzz { FuzzerPassAdjustMemoryOperandsMasks::FuzzerPassAdjustMemoryOperandsMasks( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAdjustMemoryOperandsMasks::~FuzzerPassAdjustMemoryOperandsMasks() = default;
diff --git a/source/fuzz/fuzzer_pass_adjust_memory_operands_masks.h b/source/fuzz/fuzzer_pass_adjust_memory_operands_masks.h index c3d7118..699dcb5 100644 --- a/source/fuzz/fuzzer_pass_adjust_memory_operands_masks.h +++ b/source/fuzz/fuzzer_pass_adjust_memory_operands_masks.h
@@ -25,7 +25,7 @@ class FuzzerPassAdjustMemoryOperandsMasks : public FuzzerPass { public: FuzzerPassAdjustMemoryOperandsMasks( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_adjust_selection_controls.cpp b/source/fuzz/fuzzer_pass_adjust_selection_controls.cpp index 397dfed..83b1854 100644 --- a/source/fuzz/fuzzer_pass_adjust_selection_controls.cpp +++ b/source/fuzz/fuzzer_pass_adjust_selection_controls.cpp
@@ -20,10 +20,11 @@ namespace fuzz { FuzzerPassAdjustSelectionControls::FuzzerPassAdjustSelectionControls( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassAdjustSelectionControls::~FuzzerPassAdjustSelectionControls() = default;
diff --git a/source/fuzz/fuzzer_pass_adjust_selection_controls.h b/source/fuzz/fuzzer_pass_adjust_selection_controls.h index b5b255c..820b30d 100644 --- a/source/fuzz/fuzzer_pass_adjust_selection_controls.h +++ b/source/fuzz/fuzzer_pass_adjust_selection_controls.h
@@ -24,7 +24,7 @@ class FuzzerPassAdjustSelectionControls : public FuzzerPass { public: FuzzerPassAdjustSelectionControls( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_apply_id_synonyms.cpp b/source/fuzz/fuzzer_pass_apply_id_synonyms.cpp index 5711f35..0ec93e1 100644 --- a/source/fuzz/fuzzer_pass_apply_id_synonyms.cpp +++ b/source/fuzz/fuzzer_pass_apply_id_synonyms.cpp
@@ -19,32 +19,45 @@ #include "source/fuzz/id_use_descriptor.h" #include "source/fuzz/instruction_descriptor.h" #include "source/fuzz/transformation_composite_extract.h" +#include "source/fuzz/transformation_compute_data_synonym_fact_closure.h" #include "source/fuzz/transformation_replace_id_with_synonym.h" namespace spvtools { namespace fuzz { FuzzerPassApplyIdSynonyms::FuzzerPassApplyIdSynonyms( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassApplyIdSynonyms::~FuzzerPassApplyIdSynonyms() = default; void FuzzerPassApplyIdSynonyms::Apply() { - for (auto id_with_known_synonyms : - GetFactManager()->GetIdsForWhichSynonymsAreKnown(GetIRContext())) { - // Gather up all uses of |id_with_known_synonym|, and then subsequently - // iterate over these uses. We use this separation because, when - // considering a given use, we might apply a transformation that will + // Compute a closure of data synonym facts, to enrich the pool of synonyms + // that are available. + ApplyTransformation(TransformationComputeDataSynonymFactClosure( + GetFuzzerContext() + ->GetMaximumEquivalenceClassSizeForDataSynonymFactClosure())); + + for (auto id_with_known_synonyms : GetTransformationContext() + ->GetFactManager() + ->GetIdsForWhichSynonymsAreKnown()) { + // Gather up all uses of |id_with_known_synonym| as a regular id, and + // subsequently iterate over these uses. We use this separation because, + // when considering a given use, we might apply a transformation that will // invalidate the def-use manager. std::vector<std::pair<opt::Instruction*, uint32_t>> uses; GetIRContext()->get_def_use_mgr()->ForEachUse( id_with_known_synonyms, [&uses](opt::Instruction* use_inst, uint32_t use_index) -> void { - uses.emplace_back( - std::pair<opt::Instruction*, uint32_t>(use_inst, use_index)); + // We only gather up regular id uses; e.g. we do not include a use of + // the id as the scope for an atomic operation. + if (use_inst->GetOperand(use_index).type == SPV_OPERAND_TYPE_ID) { + uses.emplace_back( + std::pair<opt::Instruction*, uint32_t>(use_inst, use_index)); + } }); for (auto& use : uses) { @@ -70,8 +83,9 @@ } std::vector<const protobufs::DataDescriptor*> synonyms_to_try; - for (auto& data_descriptor : GetFactManager()->GetSynonymsForId( - id_with_known_synonyms, GetIRContext())) { + for (auto& data_descriptor : + GetTransformationContext()->GetFactManager()->GetSynonymsForId( + id_with_known_synonyms)) { protobufs::DataDescriptor descriptor_for_this_id = MakeDataDescriptor(id_with_known_synonyms, {}); if (DataDescriptorEquals()(data_descriptor, &descriptor_for_this_id)) {
diff --git a/source/fuzz/fuzzer_pass_apply_id_synonyms.h b/source/fuzz/fuzzer_pass_apply_id_synonyms.h index 1a0748e..1a9213d 100644 --- a/source/fuzz/fuzzer_pass_apply_id_synonyms.h +++ b/source/fuzz/fuzzer_pass_apply_id_synonyms.h
@@ -27,7 +27,7 @@ class FuzzerPassApplyIdSynonyms : public FuzzerPass { public: FuzzerPassApplyIdSynonyms(opt::IRContext* ir_context, - FactManager* fact_manager, + TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_construct_composites.cpp b/source/fuzz/fuzzer_pass_construct_composites.cpp index 330b9cf..e78f8ec 100644 --- a/source/fuzz/fuzzer_pass_construct_composites.cpp +++ b/source/fuzz/fuzzer_pass_construct_composites.cpp
@@ -25,10 +25,11 @@ namespace fuzz { FuzzerPassConstructComposites::FuzzerPassConstructComposites( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassConstructComposites::~FuzzerPassConstructComposites() = default; @@ -87,7 +88,7 @@ // for constructing a composite of that type. Otherwise these variables // will remain 0 and null respectively. uint32_t chosen_composite_type = 0; - std::unique_ptr<std::vector<uint32_t>> constructor_arguments = nullptr; + std::vector<uint32_t> constructor_arguments; // Initially, all composite type ids are available for us to try. Keep // trying until we run out of options. @@ -95,35 +96,38 @@ while (!composites_to_try_constructing.empty()) { // Remove a composite type from the composite types left for us to // try. - auto index = - GetFuzzerContext()->RandomIndex(composites_to_try_constructing); auto next_composite_to_try_constructing = - composites_to_try_constructing[index]; - composites_to_try_constructing.erase( - composites_to_try_constructing.begin() + index); + GetFuzzerContext()->RemoveAtRandomIndex( + &composites_to_try_constructing); // Now try to construct a composite of this type, using an appropriate // helper method depending on the kind of composite type. - auto composite_type = GetIRContext()->get_type_mgr()->GetType( + auto composite_type_inst = GetIRContext()->get_def_use_mgr()->GetDef( next_composite_to_try_constructing); - if (auto array_type = composite_type->AsArray()) { - constructor_arguments = TryConstructingArrayComposite( - *array_type, type_id_to_available_instructions); - } else if (auto matrix_type = composite_type->AsMatrix()) { - constructor_arguments = TryConstructingMatrixComposite( - *matrix_type, type_id_to_available_instructions); - } else if (auto struct_type = composite_type->AsStruct()) { - constructor_arguments = TryConstructingStructComposite( - *struct_type, type_id_to_available_instructions); - } else { - auto vector_type = composite_type->AsVector(); - assert(vector_type && - "The space of possible composite types should be covered by " - "the above cases."); - constructor_arguments = TryConstructingVectorComposite( - *vector_type, type_id_to_available_instructions); + switch (composite_type_inst->opcode()) { + case SpvOpTypeArray: + constructor_arguments = FindComponentsToConstructArray( + *composite_type_inst, type_id_to_available_instructions); + break; + case SpvOpTypeMatrix: + constructor_arguments = FindComponentsToConstructMatrix( + *composite_type_inst, type_id_to_available_instructions); + break; + case SpvOpTypeStruct: + constructor_arguments = FindComponentsToConstructStruct( + *composite_type_inst, type_id_to_available_instructions); + break; + case SpvOpTypeVector: + constructor_arguments = FindComponentsToConstructVector( + *composite_type_inst, type_id_to_available_instructions); + break; + default: + assert(false && + "The space of possible composite types should be covered " + "by the above cases."); + break; } - if (constructor_arguments != nullptr) { + if (!constructor_arguments.empty()) { // We succeeded! Note the composite type we finally settled on, and // exit from the loop. chosen_composite_type = next_composite_to_try_constructing; @@ -134,20 +138,15 @@ if (!chosen_composite_type) { // We did not manage to make a composite; return 0 to indicate that no // instructions were added. - assert(constructor_arguments == nullptr); + assert(constructor_arguments.empty()); return; } - assert(constructor_arguments != nullptr); + assert(!constructor_arguments.empty()); // Make and apply a transformation. - TransformationCompositeConstruct transformation( - chosen_composite_type, *constructor_arguments, - instruction_descriptor, GetFuzzerContext()->GetFreshId()); - assert(transformation.IsApplicable(GetIRContext(), *GetFactManager()) && - "This transformation should be applicable by construction."); - transformation.Apply(GetIRContext(), GetFactManager()); - *GetTransformations()->add_transformation() = - transformation.ToMessage(); + ApplyTransformation(TransformationCompositeConstruct( + chosen_composite_type, constructor_arguments, + instruction_descriptor, GetFuzzerContext()->GetFreshId())); }); } @@ -160,20 +159,15 @@ type_id_to_available_instructions->at(inst->type_id()).push_back(inst); } -std::unique_ptr<std::vector<uint32_t>> -FuzzerPassConstructComposites::TryConstructingArrayComposite( - const opt::analysis::Array& array_type, +std::vector<uint32_t> +FuzzerPassConstructComposites::FindComponentsToConstructArray( + const opt::Instruction& array_type_instruction, const TypeIdToInstructions& type_id_to_available_instructions) { - // At present we assume arrays have a constant size. - assert(array_type.length_info().words.size() == 2); - assert(array_type.length_info().words[0] == - opt::analysis::Array::LengthInfo::kConstant); - - auto result = MakeUnique<std::vector<uint32_t>>(); + assert(array_type_instruction.opcode() == SpvOpTypeArray && + "Precondition: instruction must be an array type."); // Get the element type for the array. - auto element_type_id = - GetIRContext()->get_type_mgr()->GetId(array_type.element_type()); + auto element_type_id = array_type_instruction.GetSingleWordInOperand(0); // Get all instructions at our disposal that compute something of this element // type. @@ -184,26 +178,34 @@ // If there are not any instructions available that compute the element type // of the array then we are not in a position to construct a composite with // this array type. - return nullptr; + return {}; } - for (uint32_t index = 0; index < array_type.length_info().words[1]; index++) { - result->push_back(available_instructions - ->second[GetFuzzerContext()->RandomIndex( - available_instructions->second)] - ->result_id()); + + uint32_t array_length = + GetIRContext() + ->get_def_use_mgr() + ->GetDef(array_type_instruction.GetSingleWordInOperand(1)) + ->GetSingleWordInOperand(0); + + std::vector<uint32_t> result; + for (uint32_t index = 0; index < array_length; index++) { + result.push_back(available_instructions + ->second[GetFuzzerContext()->RandomIndex( + available_instructions->second)] + ->result_id()); } return result; } -std::unique_ptr<std::vector<uint32_t>> -FuzzerPassConstructComposites::TryConstructingMatrixComposite( - const opt::analysis::Matrix& matrix_type, +std::vector<uint32_t> +FuzzerPassConstructComposites::FindComponentsToConstructMatrix( + const opt::Instruction& matrix_type_instruction, const TypeIdToInstructions& type_id_to_available_instructions) { - auto result = MakeUnique<std::vector<uint32_t>>(); + assert(matrix_type_instruction.opcode() == SpvOpTypeMatrix && + "Precondition: instruction must be a matrix type."); // Get the element type for the matrix. - auto element_type_id = - GetIRContext()->get_type_mgr()->GetId(matrix_type.element_type()); + auto element_type_id = matrix_type_instruction.GetSingleWordInOperand(0); // Get all instructions at our disposal that compute something of this element // type. @@ -214,25 +216,32 @@ // If there are not any instructions available that compute the element type // of the matrix then we are not in a position to construct a composite with // this matrix type. - return nullptr; + return {}; } - for (uint32_t index = 0; index < matrix_type.element_count(); index++) { - result->push_back(available_instructions - ->second[GetFuzzerContext()->RandomIndex( - available_instructions->second)] - ->result_id()); + std::vector<uint32_t> result; + for (uint32_t index = 0; + index < matrix_type_instruction.GetSingleWordInOperand(1); index++) { + result.push_back(available_instructions + ->second[GetFuzzerContext()->RandomIndex( + available_instructions->second)] + ->result_id()); } return result; } -std::unique_ptr<std::vector<uint32_t>> -FuzzerPassConstructComposites::TryConstructingStructComposite( - const opt::analysis::Struct& struct_type, +std::vector<uint32_t> +FuzzerPassConstructComposites::FindComponentsToConstructStruct( + const opt::Instruction& struct_type_instruction, const TypeIdToInstructions& type_id_to_available_instructions) { - auto result = MakeUnique<std::vector<uint32_t>>(); + assert(struct_type_instruction.opcode() == SpvOpTypeStruct && + "Precondition: instruction must be a struct type."); + std::vector<uint32_t> result; // Consider the type of each field of the struct. - for (auto element_type : struct_type.element_types()) { - auto element_type_id = GetIRContext()->get_type_mgr()->GetId(element_type); + for (uint32_t in_operand_index = 0; + in_operand_index < struct_type_instruction.NumInOperands(); + in_operand_index++) { + auto element_type_id = + struct_type_instruction.GetSingleWordInOperand(in_operand_index); // Find the instructions at our disposal that compute something of the field // type. auto available_instructions = @@ -240,24 +249,28 @@ if (available_instructions == type_id_to_available_instructions.cend()) { // If there are no such instructions, we cannot construct a composite of // this struct type. - return nullptr; + return {}; } - result->push_back(available_instructions - ->second[GetFuzzerContext()->RandomIndex( - available_instructions->second)] - ->result_id()); + result.push_back(available_instructions + ->second[GetFuzzerContext()->RandomIndex( + available_instructions->second)] + ->result_id()); } return result; } -std::unique_ptr<std::vector<uint32_t>> -FuzzerPassConstructComposites::TryConstructingVectorComposite( - const opt::analysis::Vector& vector_type, +std::vector<uint32_t> +FuzzerPassConstructComposites::FindComponentsToConstructVector( + const opt::Instruction& vector_type_instruction, const TypeIdToInstructions& type_id_to_available_instructions) { + assert(vector_type_instruction.opcode() == SpvOpTypeVector && + "Precondition: instruction must be a vector type."); + // Get details of the type underlying the vector, and the width of the vector, // for convenience. - auto element_type = vector_type.element_type(); - auto element_count = vector_type.element_count(); + auto element_type_id = vector_type_instruction.GetSingleWordInOperand(0); + auto element_type = GetIRContext()->get_type_mgr()->GetType(element_type_id); + auto element_count = vector_type_instruction.GetSingleWordInOperand(1); // Collect a mapping, from type id to width, for scalar/vector types that are // smaller in width than |vector_type|, but that have the same underlying @@ -268,14 +281,12 @@ std::map<uint32_t, uint32_t> smaller_vector_type_id_to_width; // Add the underlying type. This id must exist, in order for |vector_type| to // exist. - auto scalar_type_id = GetIRContext()->get_type_mgr()->GetId(element_type); - smaller_vector_type_id_to_width[scalar_type_id] = 1; + smaller_vector_type_id_to_width[element_type_id] = 1; // Now add every vector type with width at least 2, and less than the width of // |vector_type|. for (uint32_t width = 2; width < element_count; width++) { - opt::analysis::Vector smaller_vector_type(vector_type.element_type(), - width); + opt::analysis::Vector smaller_vector_type(element_type, width); auto smaller_vector_type_id = GetIRContext()->get_type_mgr()->GetId(&smaller_vector_type); // We might find that there is no declared type of this smaller width. @@ -302,12 +313,11 @@ // order at this stage. std::vector<opt::Instruction*> instructions_to_use; - while (vector_slots_used < vector_type.element_count()) { + while (vector_slots_used < element_count) { std::vector<opt::Instruction*> instructions_to_choose_from; for (auto& entry : smaller_vector_type_id_to_width) { if (entry.second > - std::min(vector_type.element_count() - 1, - vector_type.element_count() - vector_slots_used)) { + std::min(element_count - 1, element_count - vector_slots_used)) { continue; } auto available_instructions = @@ -326,7 +336,7 @@ // another manner, so we could opt to retry a few times here, but it is // simpler to just give up on the basis that this will not happen // frequently. - return nullptr; + return {}; } auto instruction_to_use = instructions_to_choose_from[GetFuzzerContext()->RandomIndex( @@ -345,16 +355,16 @@ vector_slots_used += 1; } } - assert(vector_slots_used == vector_type.element_count()); + assert(vector_slots_used == element_count); - auto result = MakeUnique<std::vector<uint32_t>>(); + std::vector<uint32_t> result; std::vector<uint32_t> operands; while (!instructions_to_use.empty()) { auto index = GetFuzzerContext()->RandomIndex(instructions_to_use); - result->push_back(instructions_to_use[index]->result_id()); + result.push_back(instructions_to_use[index]->result_id()); instructions_to_use.erase(instructions_to_use.begin() + index); } - assert(result->size() > 1); + assert(result.size() > 1); return result; }
diff --git a/source/fuzz/fuzzer_pass_construct_composites.h b/source/fuzz/fuzzer_pass_construct_composites.h index 99ef31f..9853fad 100644 --- a/source/fuzz/fuzzer_pass_construct_composites.h +++ b/source/fuzz/fuzzer_pass_construct_composites.h
@@ -27,7 +27,7 @@ class FuzzerPassConstructComposites : public FuzzerPass { public: FuzzerPassConstructComposites( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations); @@ -49,27 +49,28 @@ opt::Instruction* inst, TypeIdToInstructions* type_id_to_available_instructions); + // Requires that |array_type_instruction| has opcode OpTypeArray. // Attempts to find suitable instruction result ids from the values of // |type_id_to_available_instructions| that would allow a composite of type - // |array_type| to be constructed. Returns said ids if they can be found. - // Returns |nullptr| otherwise. - std::unique_ptr<std::vector<uint32_t>> TryConstructingArrayComposite( - const opt::analysis::Array& array_type, + // |array_type_instruction| to be constructed. Returns said ids if they can + // be found and an empty vector otherwise. + std::vector<uint32_t> FindComponentsToConstructArray( + const opt::Instruction& array_type_instruction, const TypeIdToInstructions& type_id_to_available_instructions); - // Similar to TryConstructingArrayComposite, but for matrices. - std::unique_ptr<std::vector<uint32_t>> TryConstructingMatrixComposite( - const opt::analysis::Matrix& matrix_type, + // Similar to FindComponentsToConstructArray, but for matrices. + std::vector<uint32_t> FindComponentsToConstructMatrix( + const opt::Instruction& matrix_type_instruction, const TypeIdToInstructions& type_id_to_available_instructions); - // Similar to TryConstructingArrayComposite, but for structs. - std::unique_ptr<std::vector<uint32_t>> TryConstructingStructComposite( - const opt::analysis::Struct& struct_type, + // Similar to FindComponentsToConstructArray, but for structs. + std::vector<uint32_t> FindComponentsToConstructStruct( + const opt::Instruction& struct_type_instruction, const TypeIdToInstructions& type_id_to_available_instructions); - // Similar to TryConstructingArrayComposite, but for vectors. - std::unique_ptr<std::vector<uint32_t>> TryConstructingVectorComposite( - const opt::analysis::Vector& vector_type, + // Similar to FindComponentsToConstructArray, but for vectors. + std::vector<uint32_t> FindComponentsToConstructVector( + const opt::Instruction& vector_type_instruction, const TypeIdToInstructions& type_id_to_available_instructions); };
diff --git a/source/fuzz/fuzzer_pass_copy_objects.cpp b/source/fuzz/fuzzer_pass_copy_objects.cpp index 588cfb6..f055b59 100644 --- a/source/fuzz/fuzzer_pass_copy_objects.cpp +++ b/source/fuzz/fuzzer_pass_copy_objects.cpp
@@ -21,10 +21,11 @@ namespace fuzz { FuzzerPassCopyObjects::FuzzerPassCopyObjects( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassCopyObjects::~FuzzerPassCopyObjects() = default;
diff --git a/source/fuzz/fuzzer_pass_copy_objects.h b/source/fuzz/fuzzer_pass_copy_objects.h index 5419459..8de382e 100644 --- a/source/fuzz/fuzzer_pass_copy_objects.h +++ b/source/fuzz/fuzzer_pass_copy_objects.h
@@ -23,7 +23,8 @@ // A fuzzer pass for adding adding copies of objects to the module. class FuzzerPassCopyObjects : public FuzzerPass { public: - FuzzerPassCopyObjects(opt::IRContext* ir_context, FactManager* fact_manager, + FuzzerPassCopyObjects(opt::IRContext* ir_context, + TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_donate_modules.cpp b/source/fuzz/fuzzer_pass_donate_modules.cpp index 27d8a6e..b907376 100644 --- a/source/fuzz/fuzzer_pass_donate_modules.cpp +++ b/source/fuzz/fuzzer_pass_donate_modules.cpp
@@ -22,6 +22,7 @@ #include "source/fuzz/instruction_message.h" #include "source/fuzz/transformation_add_constant_boolean.h" #include "source/fuzz/transformation_add_constant_composite.h" +#include "source/fuzz/transformation_add_constant_null.h" #include "source/fuzz/transformation_add_constant_scalar.h" #include "source/fuzz/transformation_add_function.h" #include "source/fuzz/transformation_add_global_undef.h" @@ -40,11 +41,12 @@ namespace fuzz { FuzzerPassDonateModules::FuzzerPassDonateModules( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations, const std::vector<fuzzerutil::ModuleSupplier>& donor_suppliers) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations), + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations), donor_suppliers_(donor_suppliers) {} FuzzerPassDonateModules::~FuzzerPassDonateModules() = default; @@ -62,7 +64,9 @@ std::unique_ptr<opt::IRContext> donor_ir_context = donor_suppliers_.at( GetFuzzerContext()->RandomIndex(donor_suppliers_))(); assert(donor_ir_context != nullptr && "Supplying of donor failed"); - assert(fuzzerutil::IsValid(donor_ir_context.get()) && + assert(fuzzerutil::IsValid( + donor_ir_context.get(), + GetTransformationContext()->GetValidatorOptions()) && "The donor module must be valid"); // Donate the supplied module. // @@ -112,6 +116,7 @@ switch (donor_storage_class) { case SpvStorageClassFunction: case SpvStorageClassPrivate: + case SpvStorageClassWorkgroup: // We leave these alone return donor_storage_class; case SpvStorageClassInput: @@ -119,6 +124,8 @@ case SpvStorageClassUniform: case SpvStorageClassUniformConstant: case SpvStorageClassPushConstant: + case SpvStorageClassImage: + case SpvStorageClassStorageBuffer: // We change these to Private return SpvStorageClassPrivate; default: @@ -162,284 +169,385 @@ std::map<uint32_t, uint32_t>* original_id_to_donated_id) { // Consider every type/global/constant/undef in the module. for (auto& type_or_value : donor_ir_context->module()->types_values()) { - // Each such instruction generates a result id, and as part of donation we - // need to associate the donor's result id with a new result id. That new - // result id will either be the id of some existing instruction, or a fresh - // id. This variable captures it. - uint32_t new_result_id; + HandleTypeOrValue(type_or_value, original_id_to_donated_id); + } +} - // Decide how to handle each kind of instruction on a case-by-case basis. - // - // Because the donor module is required to be valid, when we encounter a - // type comprised of component types (e.g. an aggregate or pointer), we know - // that its component types will have been considered previously, and that - // |original_id_to_donated_id| will already contain an entry for them. - switch (type_or_value.opcode()) { - case SpvOpTypeVoid: { - // Void has to exist already in order for us to have an entry point. - // Get the existing id of void. - opt::analysis::Void void_type; - new_result_id = GetIRContext()->get_type_mgr()->GetId(&void_type); - assert(new_result_id && - "The module being transformed will always have 'void' type " - "declared."); - } break; - case SpvOpTypeBool: { - // Bool cannot be declared multiple times, so use its existing id if - // present, or add a declaration of Bool with a fresh id if not. - opt::analysis::Bool bool_type; - auto bool_type_id = GetIRContext()->get_type_mgr()->GetId(&bool_type); - if (bool_type_id) { - new_result_id = bool_type_id; - } else { - new_result_id = GetFuzzerContext()->GetFreshId(); - ApplyTransformation(TransformationAddTypeBoolean(new_result_id)); - } - } break; - case SpvOpTypeInt: { - // Int cannot be declared multiple times with the same width and - // signedness, so check whether an existing identical Int type is - // present and use its id if so. Otherwise add a declaration of the - // Int type used by the donor, with a fresh id. - const uint32_t width = type_or_value.GetSingleWordInOperand(0); - const bool is_signed = - static_cast<bool>(type_or_value.GetSingleWordInOperand(1)); - opt::analysis::Integer int_type(width, is_signed); - auto int_type_id = GetIRContext()->get_type_mgr()->GetId(&int_type); - if (int_type_id) { - new_result_id = int_type_id; - } else { - new_result_id = GetFuzzerContext()->GetFreshId(); - ApplyTransformation( - TransformationAddTypeInt(new_result_id, width, is_signed)); - } - } break; - case SpvOpTypeFloat: { - // Similar to SpvOpTypeInt. - const uint32_t width = type_or_value.GetSingleWordInOperand(0); - opt::analysis::Float float_type(width); - auto float_type_id = GetIRContext()->get_type_mgr()->GetId(&float_type); - if (float_type_id) { - new_result_id = float_type_id; - } else { - new_result_id = GetFuzzerContext()->GetFreshId(); - ApplyTransformation(TransformationAddTypeFloat(new_result_id, width)); - } - } break; - case SpvOpTypeVector: { - // It is not legal to have two Vector type declarations with identical - // element types and element counts, so check whether an existing - // identical Vector type is present and use its id if so. Otherwise add - // a declaration of the Vector type used by the donor, with a fresh id. +void FuzzerPassDonateModules::HandleTypeOrValue( + const opt::Instruction& type_or_value, + std::map<uint32_t, uint32_t>* original_id_to_donated_id) { + // The type/value instruction generates a result id, and we need to associate + // the donor's result id with a new result id. That new result id will either + // be the id of some existing instruction, or a fresh id. This variable + // captures it. + uint32_t new_result_id; - // When considering the vector's component type id, we look up the id - // use in the donor to find the id to which this has been remapped. - uint32_t component_type_id = original_id_to_donated_id->at( - type_or_value.GetSingleWordInOperand(0)); - auto component_type = - GetIRContext()->get_type_mgr()->GetType(component_type_id); - assert(component_type && "The base type should be registered."); - auto component_count = type_or_value.GetSingleWordInOperand(1); - opt::analysis::Vector vector_type(component_type, component_count); - auto vector_type_id = - GetIRContext()->get_type_mgr()->GetId(&vector_type); - if (vector_type_id) { - new_result_id = vector_type_id; - } else { - new_result_id = GetFuzzerContext()->GetFreshId(); - ApplyTransformation(TransformationAddTypeVector( - new_result_id, component_type_id, component_count)); - } - } break; - case SpvOpTypeMatrix: { - // Similar to SpvOpTypeVector. - uint32_t column_type_id = original_id_to_donated_id->at( - type_or_value.GetSingleWordInOperand(0)); - auto column_type = - GetIRContext()->get_type_mgr()->GetType(column_type_id); - assert(column_type && column_type->AsVector() && - "The column type should be a registered vector type."); - auto column_count = type_or_value.GetSingleWordInOperand(1); - opt::analysis::Matrix matrix_type(column_type, column_count); - auto matrix_type_id = - GetIRContext()->get_type_mgr()->GetId(&matrix_type); - if (matrix_type_id) { - new_result_id = matrix_type_id; - } else { - new_result_id = GetFuzzerContext()->GetFreshId(); - ApplyTransformation(TransformationAddTypeMatrix( - new_result_id, column_type_id, column_count)); - } - - } break; - case SpvOpTypeArray: { - // It is OK to have multiple structurally identical array types, so - // we go ahead and add a remapped version of the type declared by the - // donor. + // Decide how to handle each kind of instruction on a case-by-case basis. + // + // Because the donor module is required to be valid, when we encounter a + // type comprised of component types (e.g. an aggregate or pointer), we know + // that its component types will have been considered previously, and that + // |original_id_to_donated_id| will already contain an entry for them. + switch (type_or_value.opcode()) { + case SpvOpTypeImage: + case SpvOpTypeSampledImage: + case SpvOpTypeSampler: + // We do not donate types and variables that relate to images and + // samplers, so we skip these types and subsequently skip anything that + // depends on them. + return; + case SpvOpTypeVoid: { + // Void has to exist already in order for us to have an entry point. + // Get the existing id of void. + opt::analysis::Void void_type; + new_result_id = GetIRContext()->get_type_mgr()->GetId(&void_type); + assert(new_result_id && + "The module being transformed will always have 'void' type " + "declared."); + } break; + case SpvOpTypeBool: { + // Bool cannot be declared multiple times, so use its existing id if + // present, or add a declaration of Bool with a fresh id if not. + opt::analysis::Bool bool_type; + auto bool_type_id = GetIRContext()->get_type_mgr()->GetId(&bool_type); + if (bool_type_id) { + new_result_id = bool_type_id; + } else { new_result_id = GetFuzzerContext()->GetFreshId(); - ApplyTransformation(TransformationAddTypeArray( + ApplyTransformation(TransformationAddTypeBoolean(new_result_id)); + } + } break; + case SpvOpTypeInt: { + // Int cannot be declared multiple times with the same width and + // signedness, so check whether an existing identical Int type is + // present and use its id if so. Otherwise add a declaration of the + // Int type used by the donor, with a fresh id. + const uint32_t width = type_or_value.GetSingleWordInOperand(0); + const bool is_signed = + static_cast<bool>(type_or_value.GetSingleWordInOperand(1)); + opt::analysis::Integer int_type(width, is_signed); + auto int_type_id = GetIRContext()->get_type_mgr()->GetId(&int_type); + if (int_type_id) { + new_result_id = int_type_id; + } else { + new_result_id = GetFuzzerContext()->GetFreshId(); + ApplyTransformation( + TransformationAddTypeInt(new_result_id, width, is_signed)); + } + } break; + case SpvOpTypeFloat: { + // Similar to SpvOpTypeInt. + const uint32_t width = type_or_value.GetSingleWordInOperand(0); + opt::analysis::Float float_type(width); + auto float_type_id = GetIRContext()->get_type_mgr()->GetId(&float_type); + if (float_type_id) { + new_result_id = float_type_id; + } else { + new_result_id = GetFuzzerContext()->GetFreshId(); + ApplyTransformation(TransformationAddTypeFloat(new_result_id, width)); + } + } break; + case SpvOpTypeVector: { + // It is not legal to have two Vector type declarations with identical + // element types and element counts, so check whether an existing + // identical Vector type is present and use its id if so. Otherwise add + // a declaration of the Vector type used by the donor, with a fresh id. + + // When considering the vector's component type id, we look up the id + // use in the donor to find the id to which this has been remapped. + uint32_t component_type_id = original_id_to_donated_id->at( + type_or_value.GetSingleWordInOperand(0)); + auto component_type = + GetIRContext()->get_type_mgr()->GetType(component_type_id); + assert(component_type && "The base type should be registered."); + auto component_count = type_or_value.GetSingleWordInOperand(1); + opt::analysis::Vector vector_type(component_type, component_count); + auto vector_type_id = GetIRContext()->get_type_mgr()->GetId(&vector_type); + if (vector_type_id) { + new_result_id = vector_type_id; + } else { + new_result_id = GetFuzzerContext()->GetFreshId(); + ApplyTransformation(TransformationAddTypeVector( + new_result_id, component_type_id, component_count)); + } + } break; + case SpvOpTypeMatrix: { + // Similar to SpvOpTypeVector. + uint32_t column_type_id = original_id_to_donated_id->at( + type_or_value.GetSingleWordInOperand(0)); + auto column_type = + GetIRContext()->get_type_mgr()->GetType(column_type_id); + assert(column_type && column_type->AsVector() && + "The column type should be a registered vector type."); + auto column_count = type_or_value.GetSingleWordInOperand(1); + opt::analysis::Matrix matrix_type(column_type, column_count); + auto matrix_type_id = GetIRContext()->get_type_mgr()->GetId(&matrix_type); + if (matrix_type_id) { + new_result_id = matrix_type_id; + } else { + new_result_id = GetFuzzerContext()->GetFreshId(); + ApplyTransformation(TransformationAddTypeMatrix( + new_result_id, column_type_id, column_count)); + } + + } break; + case SpvOpTypeArray: { + // It is OK to have multiple structurally identical array types, so + // we go ahead and add a remapped version of the type declared by the + // donor. + uint32_t component_type_id = type_or_value.GetSingleWordInOperand(0); + if (!original_id_to_donated_id->count(component_type_id)) { + // We did not donate the component type of this array type, so we + // cannot donate the array type. + return; + } + new_result_id = GetFuzzerContext()->GetFreshId(); + ApplyTransformation(TransformationAddTypeArray( + new_result_id, original_id_to_donated_id->at(component_type_id), + original_id_to_donated_id->at( + type_or_value.GetSingleWordInOperand(1)))); + } break; + case SpvOpTypeRuntimeArray: { + // A runtime array is allowed as the final member of an SSBO. During + // donation we turn runtime arrays into fixed-size arrays. For dead + // code donations this is OK because the array is never indexed into at + // runtime, so it does not matter what its size is. For live-safe code, + // all accesses are made in-bounds, so this is also OK. + // + // The special OpArrayLength instruction, which works on runtime arrays, + // is rewritten to yield the fixed length that is used for the array. + + uint32_t component_type_id = type_or_value.GetSingleWordInOperand(0); + if (!original_id_to_donated_id->count(component_type_id)) { + // We did not donate the component type of this runtime array type, so + // we cannot donate it as a fixed-size array. + return; + } + new_result_id = GetFuzzerContext()->GetFreshId(); + ApplyTransformation(TransformationAddTypeArray( + new_result_id, original_id_to_donated_id->at(component_type_id), + FindOrCreate32BitIntegerConstant( + GetFuzzerContext()->GetRandomSizeForNewArray(), false))); + } break; + case SpvOpTypeStruct: { + // Similar to SpvOpTypeArray. + std::vector<uint32_t> member_type_ids; + for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) { + auto component_type_id = type_or_value.GetSingleWordInOperand(i); + if (!original_id_to_donated_id->count(component_type_id)) { + // We did not donate every member type for this struct type, so we + // cannot donate the struct type. + return; + } + member_type_ids.push_back( + original_id_to_donated_id->at(component_type_id)); + } + new_result_id = GetFuzzerContext()->GetFreshId(); + ApplyTransformation( + TransformationAddTypeStruct(new_result_id, member_type_ids)); + } break; + case SpvOpTypePointer: { + // Similar to SpvOpTypeArray. + uint32_t pointee_type_id = type_or_value.GetSingleWordInOperand(1); + if (!original_id_to_donated_id->count(pointee_type_id)) { + // We did not donate the pointee type for this pointer type, so we + // cannot donate the pointer type. + return; + } + new_result_id = GetFuzzerContext()->GetFreshId(); + ApplyTransformation(TransformationAddTypePointer( + new_result_id, + AdaptStorageClass(static_cast<SpvStorageClass>( + type_or_value.GetSingleWordInOperand(0))), + original_id_to_donated_id->at(pointee_type_id))); + } break; + case SpvOpTypeFunction: { + // It is not OK to have multiple function types that use identical ids + // for their return and parameter types. We thus go through all + // existing function types to look for a match. We do not use the + // type manager here because we want to regard two function types that + // are structurally identical but that differ with respect to the + // actual ids used for pointer types as different. + // + // Example: + // + // %1 = OpTypeVoid + // %2 = OpTypeInt 32 0 + // %3 = OpTypePointer Function %2 + // %4 = OpTypePointer Function %2 + // %5 = OpTypeFunction %1 %3 + // %6 = OpTypeFunction %1 %4 + // + // We regard %5 and %6 as distinct function types here, even though + // they both have the form "uint32* -> void" + + std::vector<uint32_t> return_and_parameter_types; + for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) { + uint32_t return_or_parameter_type = + type_or_value.GetSingleWordInOperand(i); + if (!original_id_to_donated_id->count(return_or_parameter_type)) { + // We did not donate every return/parameter type for this function + // type, so we cannot donate the function type. + return; + } + return_and_parameter_types.push_back( + original_id_to_donated_id->at(return_or_parameter_type)); + } + uint32_t existing_function_id = fuzzerutil::FindFunctionType( + GetIRContext(), return_and_parameter_types); + if (existing_function_id) { + new_result_id = existing_function_id; + } else { + // No match was found, so add a remapped version of the function type + // to the module, with a fresh id. + new_result_id = GetFuzzerContext()->GetFreshId(); + std::vector<uint32_t> argument_type_ids; + for (uint32_t i = 1; i < type_or_value.NumInOperands(); i++) { + argument_type_ids.push_back(original_id_to_donated_id->at( + type_or_value.GetSingleWordInOperand(i))); + } + ApplyTransformation(TransformationAddTypeFunction( new_result_id, original_id_to_donated_id->at( type_or_value.GetSingleWordInOperand(0)), - original_id_to_donated_id->at( - type_or_value.GetSingleWordInOperand(1)))); - } break; - case SpvOpTypeStruct: { - // Similar to SpvOpTypeArray. - new_result_id = GetFuzzerContext()->GetFreshId(); - std::vector<uint32_t> member_type_ids; - type_or_value.ForEachInId( - [&member_type_ids, - &original_id_to_donated_id](const uint32_t* component_type_id) { - member_type_ids.push_back( - original_id_to_donated_id->at(*component_type_id)); - }); - ApplyTransformation( - TransformationAddTypeStruct(new_result_id, member_type_ids)); - } break; - case SpvOpTypePointer: { - // Similar to SpvOpTypeArray. - new_result_id = GetFuzzerContext()->GetFreshId(); - ApplyTransformation(TransformationAddTypePointer( - new_result_id, - AdaptStorageClass(static_cast<SpvStorageClass>( - type_or_value.GetSingleWordInOperand(0))), - original_id_to_donated_id->at( - type_or_value.GetSingleWordInOperand(1)))); - } break; - case SpvOpTypeFunction: { - // It is not OK to have multiple function types that use identical ids - // for their return and parameter types. We thus go through all - // existing function types to look for a match. We do not use the - // type manager here because we want to regard two function types that - // are structurally identical but that differ with respect to the - // actual ids used for pointer types as different. - // - // Example: - // - // %1 = OpTypeVoid - // %2 = OpTypeInt 32 0 - // %3 = OpTypePointer Function %2 - // %4 = OpTypePointer Function %2 - // %5 = OpTypeFunction %1 %3 - // %6 = OpTypeFunction %1 %4 - // - // We regard %5 and %6 as distinct function types here, even though - // they both have the form "uint32* -> void" + argument_type_ids)); + } + } break; + case SpvOpConstantTrue: + case SpvOpConstantFalse: { + // It is OK to have duplicate definitions of True and False, so add + // these to the module, using a remapped Bool type. + new_result_id = GetFuzzerContext()->GetFreshId(); + ApplyTransformation(TransformationAddConstantBoolean( + new_result_id, type_or_value.opcode() == SpvOpConstantTrue)); + } break; + case SpvOpConstant: { + // It is OK to have duplicate constant definitions, so add this to the + // module using a remapped result type. + new_result_id = GetFuzzerContext()->GetFreshId(); + std::vector<uint32_t> data_words; + type_or_value.ForEachInOperand([&data_words](const uint32_t* in_operand) { + data_words.push_back(*in_operand); + }); + ApplyTransformation(TransformationAddConstantScalar( + new_result_id, original_id_to_donated_id->at(type_or_value.type_id()), + data_words)); + } break; + case SpvOpConstantComposite: { + assert(original_id_to_donated_id->count(type_or_value.type_id()) && + "Composite types for which it is possible to create a constant " + "should have been donated."); - std::vector<uint32_t> return_and_parameter_types; - for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) { - return_and_parameter_types.push_back(original_id_to_donated_id->at( - type_or_value.GetSingleWordInOperand(i))); - } - uint32_t existing_function_id = fuzzerutil::FindFunctionType( - GetIRContext(), return_and_parameter_types); - if (existing_function_id) { - new_result_id = existing_function_id; - } else { - // No match was found, so add a remapped version of the function type - // to the module, with a fresh id. - new_result_id = GetFuzzerContext()->GetFreshId(); - std::vector<uint32_t> argument_type_ids; - for (uint32_t i = 1; i < type_or_value.NumInOperands(); i++) { - argument_type_ids.push_back(original_id_to_donated_id->at( - type_or_value.GetSingleWordInOperand(i))); - } - ApplyTransformation(TransformationAddTypeFunction( - new_result_id, - original_id_to_donated_id->at( - type_or_value.GetSingleWordInOperand(0)), - argument_type_ids)); - } - } break; - case SpvOpConstantTrue: - case SpvOpConstantFalse: { - // It is OK to have duplicate definitions of True and False, so add - // these to the module, using a remapped Bool type. - new_result_id = GetFuzzerContext()->GetFreshId(); - ApplyTransformation(TransformationAddConstantBoolean( - new_result_id, type_or_value.opcode() == SpvOpConstantTrue)); - } break; - case SpvOpConstant: { - // It is OK to have duplicate constant definitions, so add this to the - // module using a remapped result type. - new_result_id = GetFuzzerContext()->GetFreshId(); - std::vector<uint32_t> data_words; - type_or_value.ForEachInOperand( - [&data_words](const uint32_t* in_operand) { - data_words.push_back(*in_operand); - }); - ApplyTransformation(TransformationAddConstantScalar( - new_result_id, - original_id_to_donated_id->at(type_or_value.type_id()), - data_words)); - } break; - case SpvOpConstantComposite: { - // It is OK to have duplicate constant composite definitions, so add - // this to the module using remapped versions of all consituent ids and - // the result type. - new_result_id = GetFuzzerContext()->GetFreshId(); - std::vector<uint32_t> constituent_ids; - type_or_value.ForEachInId( - [&constituent_ids, - &original_id_to_donated_id](const uint32_t* constituent_id) { - constituent_ids.push_back( - original_id_to_donated_id->at(*constituent_id)); - }); - ApplyTransformation(TransformationAddConstantComposite( - new_result_id, - original_id_to_donated_id->at(type_or_value.type_id()), - constituent_ids)); - } break; - case SpvOpVariable: { - // This is a global variable that could have one of various storage - // classes. However, we change all global variable pointer storage - // classes (such as Uniform, Input and Output) to private when donating - // pointer types. Thus this variable's pointer type is guaranteed to - // have storage class private. As a result, we simply add a Private - // storage class global variable, using remapped versions of the result - // type and initializer ids for the global variable in the donor. - // - // We regard the added variable as having an irrelevant value. This - // means that future passes can add stores to the variable in any - // way they wish, and pass them as pointer parameters to functions - // without worrying about whether their data might get modified. - new_result_id = GetFuzzerContext()->GetFreshId(); - uint32_t remapped_pointer_type = - original_id_to_donated_id->at(type_or_value.type_id()); - uint32_t initializer_id; - if (type_or_value.NumInOperands() == 1) { - // The variable did not have an initializer; initialize it to zero. - // This is to limit problems associated with uninitialized data. - initializer_id = FindOrCreateZeroConstant( - fuzzerutil::GetPointeeTypeIdFromPointerType( - GetIRContext(), remapped_pointer_type)); - } else { - // The variable already had an initializer; use its remapped id. - initializer_id = original_id_to_donated_id->at( - type_or_value.GetSingleWordInOperand(1)); - } - ApplyTransformation(TransformationAddGlobalVariable( - new_result_id, remapped_pointer_type, initializer_id, true)); - } break; - case SpvOpUndef: { - // It is fine to have multiple Undef instructions of the same type, so - // we just add this to the recipient module. - new_result_id = GetFuzzerContext()->GetFreshId(); - ApplyTransformation(TransformationAddGlobalUndef( - new_result_id, - original_id_to_donated_id->at(type_or_value.type_id()))); - } break; - default: { - assert(0 && "Unknown type/value."); - new_result_id = 0; - } break; - } - // Update the id mapping to associate the instruction's result id with its - // corresponding id in the recipient. - original_id_to_donated_id->insert( - {type_or_value.result_id(), new_result_id}); + // It is OK to have duplicate constant composite definitions, so add + // this to the module using remapped versions of all consituent ids and + // the result type. + new_result_id = GetFuzzerContext()->GetFreshId(); + std::vector<uint32_t> constituent_ids; + type_or_value.ForEachInId([&constituent_ids, &original_id_to_donated_id]( + const uint32_t* constituent_id) { + assert(original_id_to_donated_id->count(*constituent_id) && + "The constants used to construct this composite should " + "have been donated."); + constituent_ids.push_back( + original_id_to_donated_id->at(*constituent_id)); + }); + ApplyTransformation(TransformationAddConstantComposite( + new_result_id, original_id_to_donated_id->at(type_or_value.type_id()), + constituent_ids)); + } break; + case SpvOpConstantNull: { + if (!original_id_to_donated_id->count(type_or_value.type_id())) { + // We did not donate the type associated with this null constant, so + // we cannot donate the null constant. + return; + } + + // It is fine to have multiple OpConstantNull instructions of the same + // type, so we just add this to the recipient module. + new_result_id = GetFuzzerContext()->GetFreshId(); + ApplyTransformation(TransformationAddConstantNull( + new_result_id, + original_id_to_donated_id->at(type_or_value.type_id()))); + } break; + case SpvOpVariable: { + if (!original_id_to_donated_id->count(type_or_value.type_id())) { + // We did not donate the pointer type associated with this variable, + // so we cannot donate the variable. + return; + } + + // This is a global variable that could have one of various storage + // classes. However, we change all global variable pointer storage + // classes (such as Uniform, Input and Output) to private when donating + // pointer types, with the exception of the Workgroup storage class. + // + // Thus this variable's pointer type is guaranteed to have storage class + // Private or Workgroup. + // + // We add a global variable with either Private or Workgroup storage + // class, using remapped versions of the result type and initializer ids + // for the global variable in the donor. + // + // We regard the added variable as having an irrelevant value. This + // means that future passes can add stores to the variable in any + // way they wish, and pass them as pointer parameters to functions + // without worrying about whether their data might get modified. + new_result_id = GetFuzzerContext()->GetFreshId(); + uint32_t remapped_pointer_type = + original_id_to_donated_id->at(type_or_value.type_id()); + uint32_t initializer_id; + SpvStorageClass storage_class = + static_cast<SpvStorageClass>(type_or_value.GetSingleWordInOperand( + 0)) == SpvStorageClassWorkgroup + ? SpvStorageClassWorkgroup + : SpvStorageClassPrivate; + if (type_or_value.NumInOperands() == 1) { + // The variable did not have an initializer. Initialize it to zero + // if it has Private storage class (to limit problems associated with + // uninitialized data), and leave it uninitialized if it has Workgroup + // storage class (as Workgroup variables cannot have initializers). + + // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3275): we + // could initialize Workgroup variables at the start of an entry + // point, and should do so if their uninitialized nature proves + // problematic. + initializer_id = storage_class == SpvStorageClassWorkgroup + ? 0 + : FindOrCreateZeroConstant( + fuzzerutil::GetPointeeTypeIdFromPointerType( + GetIRContext(), remapped_pointer_type)); + } else { + // The variable already had an initializer; use its remapped id. + initializer_id = original_id_to_donated_id->at( + type_or_value.GetSingleWordInOperand(1)); + } + ApplyTransformation( + TransformationAddGlobalVariable(new_result_id, remapped_pointer_type, + storage_class, initializer_id, true)); + } break; + case SpvOpUndef: { + if (!original_id_to_donated_id->count(type_or_value.type_id())) { + // We did not donate the type associated with this undef, so we cannot + // donate the undef. + return; + } + + // It is fine to have multiple Undef instructions of the same type, so + // we just add this to the recipient module. + new_result_id = GetFuzzerContext()->GetFreshId(); + ApplyTransformation(TransformationAddGlobalUndef( + new_result_id, + original_id_to_donated_id->at(type_or_value.type_id()))); + } break; + default: { + assert(0 && "Unknown type/value."); + new_result_id = 0; + } break; } + + // Update the id mapping to associate the instruction's result id with its + // corresponding id in the recipient. + original_id_to_donated_id->insert({type_or_value.result_id(), new_result_id}); } void FuzzerPassDonateModules::HandleFunctions( @@ -468,220 +576,49 @@ } assert(function_to_donate && "Function to be donated was not found."); + if (!original_id_to_donated_id->count( + function_to_donate->DefInst().GetSingleWordInOperand(1))) { + // We were not able to donate this function's type, so we cannot donate + // the function. + continue; + } + // We will collect up protobuf messages representing the donor function's // instructions here, and use them to create an AddFunction transformation. std::vector<protobufs::Instruction> donated_instructions; - // Scan through the function, remapping each result id that it generates to - // a fresh id. This is necessary because functions include forward - // references, e.g. to labels. - function_to_donate->ForEachInst([this, &original_id_to_donated_id]( - const opt::Instruction* instruction) { - if (instruction->result_id()) { - original_id_to_donated_id->insert( - {instruction->result_id(), GetFuzzerContext()->GetFreshId()}); - } - }); + // This set tracks the ids of those instructions for which donation was + // completely skipped: neither the instruction nor a substitute for it was + // donated. + std::set<uint32_t> skipped_instructions; // Consider every instruction of the donor function. - function_to_donate->ForEachInst([this, &donated_instructions, - &original_id_to_donated_id]( - const opt::Instruction* instruction) { - // Get the instruction's input operands into donation-ready form, - // remapping any id uses in the process. - opt::Instruction::OperandList input_operands; - - // Consider each input operand in turn. - for (uint32_t in_operand_index = 0; - in_operand_index < instruction->NumInOperands(); - in_operand_index++) { - std::vector<uint32_t> operand_data; - const opt::Operand& in_operand = - instruction->GetInOperand(in_operand_index); - switch (in_operand.type) { - case SPV_OPERAND_TYPE_ID: - case SPV_OPERAND_TYPE_TYPE_ID: - case SPV_OPERAND_TYPE_RESULT_ID: - case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: - case SPV_OPERAND_TYPE_SCOPE_ID: - // This is an id operand - it consists of a single word of data, - // which needs to be remapped so that it is replaced with the - // donated form of the id. - operand_data.push_back( - original_id_to_donated_id->at(in_operand.words[0])); - break; - default: - // For non-id operands, we just add each of the data words. - for (auto word : in_operand.words) { - operand_data.push_back(word); - } - break; - } - input_operands.push_back({in_operand.type, operand_data}); - } - - if (instruction->opcode() == SpvOpVariable && - instruction->NumInOperands() == 1) { - // This is an uninitialized local variable. Initialize it to zero. - input_operands.push_back( - {SPV_OPERAND_TYPE_ID, - {FindOrCreateZeroConstant( - fuzzerutil::GetPointeeTypeIdFromPointerType( - GetIRContext(), - original_id_to_donated_id->at(instruction->type_id())))}}); - } - - // Remap the result type and result id (if present) of the - // instruction, and turn it into a protobuf message. - donated_instructions.push_back(MakeInstructionMessage( - instruction->opcode(), - instruction->type_id() - ? original_id_to_donated_id->at(instruction->type_id()) - : 0, - instruction->result_id() - ? original_id_to_donated_id->at(instruction->result_id()) - : 0, - input_operands)); - }); + function_to_donate->ForEachInst( + [this, &donated_instructions, donor_ir_context, + &original_id_to_donated_id, + &skipped_instructions](const opt::Instruction* instruction) { + if (instruction->opcode() == SpvOpArrayLength) { + // We treat OpArrayLength specially. + HandleOpArrayLength(*instruction, original_id_to_donated_id, + &donated_instructions); + } else if (!CanDonateInstruction(donor_ir_context, *instruction, + *original_id_to_donated_id, + skipped_instructions)) { + // This is an instruction that we cannot directly donate. + HandleDifficultInstruction(*instruction, original_id_to_donated_id, + &donated_instructions, + &skipped_instructions); + } else { + PrepareInstructionForDonation(*instruction, donor_ir_context, + original_id_to_donated_id, + &donated_instructions); + } + }); if (make_livesafe) { - // Various types and constants must be in place for a function to be made - // live-safe. Add them if not already present. - FindOrCreateBoolType(); // Needed for comparisons - FindOrCreatePointerTo32BitIntegerType( - false, SpvStorageClassFunction); // Needed for adding loop limiters - FindOrCreate32BitIntegerConstant( - 0, false); // Needed for initializing loop limiters - FindOrCreate32BitIntegerConstant( - 1, false); // Needed for incrementing loop limiters - - // Get a fresh id for the variable that will be used as a loop limiter. - const uint32_t loop_limiter_variable_id = - GetFuzzerContext()->GetFreshId(); - // Choose a random loop limit, and add the required constant to the - // module if not already there. - const uint32_t loop_limit = FindOrCreate32BitIntegerConstant( - GetFuzzerContext()->GetRandomLoopLimit(), false); - - // Consider every loop header in the function to donate, and create a - // structure capturing the ids to be used for manipulating the loop - // limiter each time the loop is iterated. - std::vector<protobufs::LoopLimiterInfo> loop_limiters; - for (auto& block : *function_to_donate) { - if (block.IsLoopHeader()) { - protobufs::LoopLimiterInfo loop_limiter; - // Grab the loop header's id, mapped to its donated value. - loop_limiter.set_loop_header_id( - original_id_to_donated_id->at(block.id())); - // Get fresh ids that will be used to load the loop limiter, increment - // it, compare it with the loop limit, and an id for a new block that - // will contain the loop's original terminator. - loop_limiter.set_load_id(GetFuzzerContext()->GetFreshId()); - loop_limiter.set_increment_id(GetFuzzerContext()->GetFreshId()); - loop_limiter.set_compare_id(GetFuzzerContext()->GetFreshId()); - loop_limiter.set_logical_op_id(GetFuzzerContext()->GetFreshId()); - loop_limiters.emplace_back(loop_limiter); - } - } - - // Consider every access chain in the function to donate, and create a - // structure containing the ids necessary to clamp the access chain - // indices to be in-bounds. - std::vector<protobufs::AccessChainClampingInfo> - access_chain_clamping_info; - for (auto& block : *function_to_donate) { - for (auto& inst : block) { - switch (inst.opcode()) { - case SpvOpAccessChain: - case SpvOpInBoundsAccessChain: { - protobufs::AccessChainClampingInfo clamping_info; - clamping_info.set_access_chain_id( - original_id_to_donated_id->at(inst.result_id())); - - auto base_object = donor_ir_context->get_def_use_mgr()->GetDef( - inst.GetSingleWordInOperand(0)); - assert(base_object && "The base object must exist."); - auto pointer_type = donor_ir_context->get_def_use_mgr()->GetDef( - base_object->type_id()); - assert(pointer_type && - pointer_type->opcode() == SpvOpTypePointer && - "The base object must have pointer type."); - - auto should_be_composite_type = - donor_ir_context->get_def_use_mgr()->GetDef( - pointer_type->GetSingleWordInOperand(1)); - - // Walk the access chain, creating fresh ids to facilitate - // clamping each index. For simplicity we do this for every - // index, even though constant indices will not end up being - // clamped. - for (uint32_t index = 1; index < inst.NumInOperands(); index++) { - auto compare_and_select_ids = - clamping_info.add_compare_and_select_ids(); - compare_and_select_ids->set_first( - GetFuzzerContext()->GetFreshId()); - compare_and_select_ids->set_second( - GetFuzzerContext()->GetFreshId()); - - // Get the bound for the component being indexed into. - uint32_t bound = - TransformationAddFunction::GetBoundForCompositeIndex( - donor_ir_context, *should_be_composite_type); - const uint32_t index_id = inst.GetSingleWordInOperand(index); - auto index_inst = - donor_ir_context->get_def_use_mgr()->GetDef(index_id); - auto index_type_inst = - donor_ir_context->get_def_use_mgr()->GetDef( - index_inst->type_id()); - assert(index_type_inst->opcode() == SpvOpTypeInt); - assert(index_type_inst->GetSingleWordInOperand(0) == 32); - opt::analysis::Integer* index_int_type = - donor_ir_context->get_type_mgr() - ->GetType(index_type_inst->result_id()) - ->AsInteger(); - if (index_inst->opcode() != SpvOpConstant) { - // We will have to clamp this index, so we need a constant - // whose value is one less than the bound, to compare - // against and to use as the clamped value. - FindOrCreate32BitIntegerConstant(bound - 1, - index_int_type->IsSigned()); - } - should_be_composite_type = - TransformationAddFunction::FollowCompositeIndex( - donor_ir_context, *should_be_composite_type, index_id); - } - access_chain_clamping_info.push_back(clamping_info); - break; - } - default: - break; - } - } - } - - // If the function contains OpKill or OpUnreachable instructions, and has - // non-void return type, then we need a value %v to use in order to turn - // these into instructions of the form OpReturn %v. - uint32_t kill_unreachable_return_value_id; - auto function_return_type_inst = - donor_ir_context->get_def_use_mgr()->GetDef( - function_to_donate->type_id()); - if (function_return_type_inst->opcode() == SpvOpTypeVoid) { - // The return type is void, so we don't need a return value. - kill_unreachable_return_value_id = 0; - } else { - // We do need a return value; we use zero. - assert(function_return_type_inst->opcode() != SpvOpTypePointer && - "Function return type must not be a pointer."); - kill_unreachable_return_value_id = - FindOrCreateZeroConstant(original_id_to_donated_id->at( - function_return_type_inst->result_id())); - } - // Add the function in a livesafe manner. - ApplyTransformation(TransformationAddFunction( - donated_instructions, loop_limiter_variable_id, loop_limit, - loop_limiters, kill_unreachable_return_value_id, - access_chain_clamping_info)); + // Make the function livesafe and then add it. + AddLivesafeFunction(*function_to_donate, donor_ir_context, + *original_id_to_donated_id, donated_instructions); } else { // Add the function in a non-livesafe manner. ApplyTransformation(TransformationAddFunction(donated_instructions)); @@ -689,6 +626,133 @@ } } +bool FuzzerPassDonateModules::CanDonateInstruction( + opt::IRContext* donor_ir_context, const opt::Instruction& instruction, + const std::map<uint32_t, uint32_t>& original_id_to_donated_id, + const std::set<uint32_t>& skipped_instructions) const { + if (instruction.type_id() && + !original_id_to_donated_id.count(instruction.type_id())) { + // We could not donate the result type of this instruction, so we cannot + // donate the instruction. + return false; + } + + // Now consider instructions we specifically want to skip because we do not + // yet support them. + switch (instruction.opcode()) { + case SpvOpAtomicLoad: + case SpvOpAtomicStore: + case SpvOpAtomicExchange: + case SpvOpAtomicCompareExchange: + case SpvOpAtomicCompareExchangeWeak: + case SpvOpAtomicIIncrement: + case SpvOpAtomicIDecrement: + case SpvOpAtomicIAdd: + case SpvOpAtomicISub: + case SpvOpAtomicSMin: + case SpvOpAtomicUMin: + case SpvOpAtomicSMax: + case SpvOpAtomicUMax: + case SpvOpAtomicAnd: + case SpvOpAtomicOr: + case SpvOpAtomicXor: + // We conservatively ignore all atomic instructions at present. + // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3276): Consider + // being less conservative here. + case SpvOpImageSampleImplicitLod: + case SpvOpImageSampleExplicitLod: + case SpvOpImageSampleDrefImplicitLod: + case SpvOpImageSampleDrefExplicitLod: + case SpvOpImageSampleProjImplicitLod: + case SpvOpImageSampleProjExplicitLod: + case SpvOpImageSampleProjDrefImplicitLod: + case SpvOpImageSampleProjDrefExplicitLod: + case SpvOpImageFetch: + case SpvOpImageGather: + case SpvOpImageDrefGather: + case SpvOpImageRead: + case SpvOpImageWrite: + case SpvOpImageSparseSampleImplicitLod: + case SpvOpImageSparseSampleExplicitLod: + case SpvOpImageSparseSampleDrefImplicitLod: + case SpvOpImageSparseSampleDrefExplicitLod: + case SpvOpImageSparseSampleProjImplicitLod: + case SpvOpImageSparseSampleProjExplicitLod: + case SpvOpImageSparseSampleProjDrefImplicitLod: + case SpvOpImageSparseSampleProjDrefExplicitLod: + case SpvOpImageSparseFetch: + case SpvOpImageSparseGather: + case SpvOpImageSparseDrefGather: + case SpvOpImageSparseRead: + case SpvOpImageSampleFootprintNV: + case SpvOpImage: + case SpvOpImageQueryFormat: + case SpvOpImageQueryLevels: + case SpvOpImageQueryLod: + case SpvOpImageQueryOrder: + case SpvOpImageQuerySamples: + case SpvOpImageQuerySize: + case SpvOpImageQuerySizeLod: + case SpvOpSampledImage: + // We ignore all instructions related to accessing images, since we do not + // donate images. + return false; + case SpvOpLoad: + switch (donor_ir_context->get_def_use_mgr() + ->GetDef(instruction.type_id()) + ->opcode()) { + case SpvOpTypeImage: + case SpvOpTypeSampledImage: + case SpvOpTypeSampler: + // Again, we ignore instructions that relate to accessing images. + return false; + default: + break; + } + default: + break; + } + + // Examine each id input operand to the instruction. If it turns out that we + // have skipped any of these operands then we cannot donate the instruction. + bool result = true; + instruction.WhileEachInId( + [donor_ir_context, &original_id_to_donated_id, &result, + &skipped_instructions](const uint32_t* in_id) -> bool { + if (!original_id_to_donated_id.count(*in_id)) { + // We do not have a mapped result id for this id operand. That either + // means that it is a forward reference (which is OK), that we skipped + // the instruction that generated it (which is not OK), or that it is + // the id of a function or global value that we did not donate (which + // is not OK). We check for the latter two cases. + if (skipped_instructions.count(*in_id) || + // A function or global value does not have an associated basic + // block. + !donor_ir_context->get_instr_block(*in_id)) { + result = false; + return false; + } + } + return true; + }); + return result; +} + +bool FuzzerPassDonateModules::IsBasicType( + const opt::Instruction& instruction) const { + switch (instruction.opcode()) { + case SpvOpTypeArray: + case SpvOpTypeFloat: + case SpvOpTypeInt: + case SpvOpTypeMatrix: + case SpvOpTypeStruct: + case SpvOpTypeVector: + return true; + default: + return false; + } +} + std::vector<uint32_t> FuzzerPassDonateModules::GetFunctionsInCallGraphTopologicalOrder( opt::IRContext* context) { @@ -735,5 +799,333 @@ return result; } +void FuzzerPassDonateModules::HandleOpArrayLength( + const opt::Instruction& instruction, + std::map<uint32_t, uint32_t>* original_id_to_donated_id, + std::vector<protobufs::Instruction>* donated_instructions) const { + assert(instruction.opcode() == SpvOpArrayLength && + "Precondition: instruction must be OpArrayLength."); + uint32_t donated_variable_id = + original_id_to_donated_id->at(instruction.GetSingleWordInOperand(0)); + auto donated_variable_instruction = + GetIRContext()->get_def_use_mgr()->GetDef(donated_variable_id); + auto pointer_to_struct_instruction = + GetIRContext()->get_def_use_mgr()->GetDef( + donated_variable_instruction->type_id()); + assert(pointer_to_struct_instruction->opcode() == SpvOpTypePointer && + "Type of variable must be pointer."); + auto donated_struct_type_instruction = + GetIRContext()->get_def_use_mgr()->GetDef( + pointer_to_struct_instruction->GetSingleWordInOperand(1)); + assert(donated_struct_type_instruction->opcode() == SpvOpTypeStruct && + "Pointee type of pointer used by OpArrayLength must be struct."); + assert(donated_struct_type_instruction->NumInOperands() == + instruction.GetSingleWordInOperand(1) + 1 && + "OpArrayLength must refer to the final member of the given " + "struct."); + uint32_t fixed_size_array_type_id = + donated_struct_type_instruction->GetSingleWordInOperand( + donated_struct_type_instruction->NumInOperands() - 1); + auto fixed_size_array_type_instruction = + GetIRContext()->get_def_use_mgr()->GetDef(fixed_size_array_type_id); + assert(fixed_size_array_type_instruction->opcode() == SpvOpTypeArray && + "The donated array type must be fixed-size."); + auto array_size_id = + fixed_size_array_type_instruction->GetSingleWordInOperand(1); + + if (instruction.result_id() && + !original_id_to_donated_id->count(instruction.result_id())) { + original_id_to_donated_id->insert( + {instruction.result_id(), GetFuzzerContext()->GetFreshId()}); + } + + donated_instructions->push_back(MakeInstructionMessage( + SpvOpCopyObject, original_id_to_donated_id->at(instruction.type_id()), + original_id_to_donated_id->at(instruction.result_id()), + opt::Instruction::OperandList({{SPV_OPERAND_TYPE_ID, {array_size_id}}}))); +} + +void FuzzerPassDonateModules::HandleDifficultInstruction( + const opt::Instruction& instruction, + std::map<uint32_t, uint32_t>* original_id_to_donated_id, + std::vector<protobufs::Instruction>* donated_instructions, + std::set<uint32_t>* skipped_instructions) { + if (!instruction.result_id()) { + // It does not generate a result id, so it can be ignored. + return; + } + if (!original_id_to_donated_id->count(instruction.type_id())) { + // We cannot handle this instruction's result type, so we need to skip it + // all together. + skipped_instructions->insert(instruction.result_id()); + return; + } + + // We now attempt to replace the instruction with an OpCopyObject. + // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3278): We could do + // something more refined here - we could check which operands to the + // instruction could not be donated and replace those operands with + // references to other ids (such as constants), so that we still get an + // instruction with the opcode and easy-to-handle operands of the donor + // instruction. + auto remapped_type_id = original_id_to_donated_id->at(instruction.type_id()); + if (!IsBasicType( + *GetIRContext()->get_def_use_mgr()->GetDef(remapped_type_id))) { + // The instruction has a non-basic result type, so we cannot replace it with + // an object copy of a constant. We thus skip it completely. + // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3279): We could + // instead look for an available id of the right type and generate an + // OpCopyObject of that id. + skipped_instructions->insert(instruction.result_id()); + return; + } + + // We are going to add an OpCopyObject instruction. Add a mapping for the + // result id of the original instruction if does not already exist (it may + // exist in the case that it has been forward-referenced). + if (!original_id_to_donated_id->count(instruction.result_id())) { + original_id_to_donated_id->insert( + {instruction.result_id(), GetFuzzerContext()->GetFreshId()}); + } + + // We find or add a zero constant to the receiving module for the type in + // question, and add an OpCopyObject instruction that copies this zero. + // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3177): + // Using this particular constant is arbitrary, so if we have a + // mechanism for noting that an id use is arbitrary and could be + // fuzzed we should use it here. + auto zero_constant = FindOrCreateZeroConstant(remapped_type_id); + donated_instructions->push_back(MakeInstructionMessage( + SpvOpCopyObject, remapped_type_id, + original_id_to_donated_id->at(instruction.result_id()), + opt::Instruction::OperandList({{SPV_OPERAND_TYPE_ID, {zero_constant}}}))); +} + +void FuzzerPassDonateModules::PrepareInstructionForDonation( + const opt::Instruction& instruction, opt::IRContext* donor_ir_context, + std::map<uint32_t, uint32_t>* original_id_to_donated_id, + std::vector<protobufs::Instruction>* donated_instructions) { + // Get the instruction's input operands into donation-ready form, + // remapping any id uses in the process. + opt::Instruction::OperandList input_operands; + + // Consider each input operand in turn. + for (uint32_t in_operand_index = 0; + in_operand_index < instruction.NumInOperands(); in_operand_index++) { + std::vector<uint32_t> operand_data; + const opt::Operand& in_operand = instruction.GetInOperand(in_operand_index); + // Check whether this operand is an id. + if (spvIsIdType(in_operand.type)) { + // This is an id operand - it consists of a single word of data, + // which needs to be remapped so that it is replaced with the + // donated form of the id. + auto operand_id = in_operand.words[0]; + if (!original_id_to_donated_id->count(operand_id)) { + // This is a forward reference. We will choose a corresponding + // donor id for the referenced id and update the mapping to + // reflect it. + + // Keep release compilers happy because |donor_ir_context| is only used + // in this assertion. + (void)(donor_ir_context); + assert((donor_ir_context->get_def_use_mgr() + ->GetDef(operand_id) + ->opcode() == SpvOpLabel || + instruction.opcode() == SpvOpPhi) && + "Unsupported forward reference."); + original_id_to_donated_id->insert( + {operand_id, GetFuzzerContext()->GetFreshId()}); + } + operand_data.push_back(original_id_to_donated_id->at(operand_id)); + } else { + // For non-id operands, we just add each of the data words. + for (auto word : in_operand.words) { + operand_data.push_back(word); + } + } + input_operands.push_back({in_operand.type, operand_data}); + } + + if (instruction.opcode() == SpvOpVariable && + instruction.NumInOperands() == 1) { + // This is an uninitialized local variable. Initialize it to zero. + input_operands.push_back( + {SPV_OPERAND_TYPE_ID, + {FindOrCreateZeroConstant(fuzzerutil::GetPointeeTypeIdFromPointerType( + GetIRContext(), + original_id_to_donated_id->at(instruction.type_id())))}}); + } + + if (instruction.result_id() && + !original_id_to_donated_id->count(instruction.result_id())) { + original_id_to_donated_id->insert( + {instruction.result_id(), GetFuzzerContext()->GetFreshId()}); + } + + // Remap the result type and result id (if present) of the + // instruction, and turn it into a protobuf message. + donated_instructions->push_back(MakeInstructionMessage( + instruction.opcode(), + instruction.type_id() + ? original_id_to_donated_id->at(instruction.type_id()) + : 0, + instruction.result_id() + ? original_id_to_donated_id->at(instruction.result_id()) + : 0, + input_operands)); +} + +void FuzzerPassDonateModules::AddLivesafeFunction( + const opt::Function& function_to_donate, opt::IRContext* donor_ir_context, + const std::map<uint32_t, uint32_t>& original_id_to_donated_id, + const std::vector<protobufs::Instruction>& donated_instructions) { + // Various types and constants must be in place for a function to be made + // live-safe. Add them if not already present. + FindOrCreateBoolType(); // Needed for comparisons + FindOrCreatePointerTo32BitIntegerType( + false, SpvStorageClassFunction); // Needed for adding loop limiters + FindOrCreate32BitIntegerConstant( + 0, false); // Needed for initializing loop limiters + FindOrCreate32BitIntegerConstant( + 1, false); // Needed for incrementing loop limiters + + // Get a fresh id for the variable that will be used as a loop limiter. + const uint32_t loop_limiter_variable_id = GetFuzzerContext()->GetFreshId(); + // Choose a random loop limit, and add the required constant to the + // module if not already there. + const uint32_t loop_limit = FindOrCreate32BitIntegerConstant( + GetFuzzerContext()->GetRandomLoopLimit(), false); + + // Consider every loop header in the function to donate, and create a + // structure capturing the ids to be used for manipulating the loop + // limiter each time the loop is iterated. + std::vector<protobufs::LoopLimiterInfo> loop_limiters; + for (auto& block : function_to_donate) { + if (block.IsLoopHeader()) { + protobufs::LoopLimiterInfo loop_limiter; + // Grab the loop header's id, mapped to its donated value. + loop_limiter.set_loop_header_id(original_id_to_donated_id.at(block.id())); + // Get fresh ids that will be used to load the loop limiter, increment + // it, compare it with the loop limit, and an id for a new block that + // will contain the loop's original terminator. + loop_limiter.set_load_id(GetFuzzerContext()->GetFreshId()); + loop_limiter.set_increment_id(GetFuzzerContext()->GetFreshId()); + loop_limiter.set_compare_id(GetFuzzerContext()->GetFreshId()); + loop_limiter.set_logical_op_id(GetFuzzerContext()->GetFreshId()); + loop_limiters.emplace_back(loop_limiter); + } + } + + // Consider every access chain in the function to donate, and create a + // structure containing the ids necessary to clamp the access chain + // indices to be in-bounds. + std::vector<protobufs::AccessChainClampingInfo> access_chain_clamping_info; + for (auto& block : function_to_donate) { + for (auto& inst : block) { + switch (inst.opcode()) { + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: { + protobufs::AccessChainClampingInfo clamping_info; + clamping_info.set_access_chain_id( + original_id_to_donated_id.at(inst.result_id())); + + auto base_object = donor_ir_context->get_def_use_mgr()->GetDef( + inst.GetSingleWordInOperand(0)); + assert(base_object && "The base object must exist."); + auto pointer_type = donor_ir_context->get_def_use_mgr()->GetDef( + base_object->type_id()); + assert(pointer_type && pointer_type->opcode() == SpvOpTypePointer && + "The base object must have pointer type."); + + auto should_be_composite_type = + donor_ir_context->get_def_use_mgr()->GetDef( + pointer_type->GetSingleWordInOperand(1)); + + // Walk the access chain, creating fresh ids to facilitate + // clamping each index. For simplicity we do this for every + // index, even though constant indices will not end up being + // clamped. + for (uint32_t index = 1; index < inst.NumInOperands(); index++) { + auto compare_and_select_ids = + clamping_info.add_compare_and_select_ids(); + compare_and_select_ids->set_first(GetFuzzerContext()->GetFreshId()); + compare_and_select_ids->set_second( + GetFuzzerContext()->GetFreshId()); + + // Get the bound for the component being indexed into. + uint32_t bound; + if (should_be_composite_type->opcode() == SpvOpTypeRuntimeArray) { + // The donor is indexing into a runtime array. We do not + // donate runtime arrays. Instead, we donate a corresponding + // fixed-size array for every runtime array. We should thus + // find that donor composite type's result id maps to a fixed- + // size array. + auto fixed_size_array_type = + GetIRContext()->get_def_use_mgr()->GetDef( + original_id_to_donated_id.at( + should_be_composite_type->result_id())); + assert(fixed_size_array_type->opcode() == SpvOpTypeArray && + "A runtime array type in the donor should have been " + "replaced by a fixed-sized array in the recipient."); + // The size of this fixed-size array is a suitable bound. + bound = TransformationAddFunction::GetBoundForCompositeIndex( + GetIRContext(), *fixed_size_array_type); + } else { + bound = TransformationAddFunction::GetBoundForCompositeIndex( + donor_ir_context, *should_be_composite_type); + } + const uint32_t index_id = inst.GetSingleWordInOperand(index); + auto index_inst = + donor_ir_context->get_def_use_mgr()->GetDef(index_id); + auto index_type_inst = donor_ir_context->get_def_use_mgr()->GetDef( + index_inst->type_id()); + assert(index_type_inst->opcode() == SpvOpTypeInt); + assert(index_type_inst->GetSingleWordInOperand(0) == 32); + opt::analysis::Integer* index_int_type = + donor_ir_context->get_type_mgr() + ->GetType(index_type_inst->result_id()) + ->AsInteger(); + if (index_inst->opcode() != SpvOpConstant) { + // We will have to clamp this index, so we need a constant + // whose value is one less than the bound, to compare + // against and to use as the clamped value. + FindOrCreate32BitIntegerConstant(bound - 1, + index_int_type->IsSigned()); + } + should_be_composite_type = + TransformationAddFunction::FollowCompositeIndex( + donor_ir_context, *should_be_composite_type, index_id); + } + access_chain_clamping_info.push_back(clamping_info); + break; + } + default: + break; + } + } + } + + // If the function contains OpKill or OpUnreachable instructions, and has + // non-void return type, then we need a value %v to use in order to turn + // these into instructions of the form OpReturn %v. + uint32_t kill_unreachable_return_value_id; + auto function_return_type_inst = + donor_ir_context->get_def_use_mgr()->GetDef(function_to_donate.type_id()); + if (function_return_type_inst->opcode() == SpvOpTypeVoid) { + // The return type is void, so we don't need a return value. + kill_unreachable_return_value_id = 0; + } else { + // We do need a return value; we use zero. + assert(function_return_type_inst->opcode() != SpvOpTypePointer && + "Function return type must not be a pointer."); + kill_unreachable_return_value_id = FindOrCreateZeroConstant( + original_id_to_donated_id.at(function_return_type_inst->result_id())); + } + // Add the function in a livesafe manner. + ApplyTransformation(TransformationAddFunction( + donated_instructions, loop_limiter_variable_id, loop_limit, loop_limiters, + kill_unreachable_return_value_id, access_chain_clamping_info)); +} + } // namespace fuzz } // namespace spvtools
diff --git a/source/fuzz/fuzzer_pass_donate_modules.h b/source/fuzz/fuzzer_pass_donate_modules.h index ef529db..c59ad71 100644 --- a/source/fuzz/fuzzer_pass_donate_modules.h +++ b/source/fuzz/fuzzer_pass_donate_modules.h
@@ -28,7 +28,7 @@ class FuzzerPassDonateModules : public FuzzerPass { public: FuzzerPassDonateModules( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations, const std::vector<fuzzerutil::ModuleSupplier>& donor_suppliers); @@ -66,6 +66,11 @@ opt::IRContext* donor_ir_context, std::map<uint32_t, uint32_t>* original_id_to_donated_id); + // Helper method for HandleTypesAndValues, to handle a single type/value. + void HandleTypeOrValue( + const opt::Instruction& type_or_value, + std::map<uint32_t, uint32_t>* original_id_to_donated_id); + // Assumes that |donor_ir_context| does not exhibit recursion. Considers the // functions in |donor_ir_context|'s call graph in a reverse-topologically- // sorted order (leaves-to-root), adding each function to the recipient @@ -77,6 +82,68 @@ std::map<uint32_t, uint32_t>* original_id_to_donated_id, bool make_livesafe); + // During donation we will have to ignore some instructions, e.g. because they + // use opcodes that we cannot support or because they reference the ids of + // instructions that have not been donated. This function encapsulates the + // logic for deciding which whether instruction |instruction| from + // |donor_ir_context| can be donated. + bool CanDonateInstruction( + opt::IRContext* donor_ir_context, const opt::Instruction& instruction, + const std::map<uint32_t, uint32_t>& original_id_to_donated_id, + const std::set<uint32_t>& skipped_instructions) const; + + // We treat the OpArrayLength instruction specially. In the donor shader this + // instruction yields the length of a runtime array that is the final member + // of a struct. During donation, we will have converted the runtime array + // type, and the associated struct field, into a fixed-size array. + // + // Instead of donating this instruction, we turn it into an OpCopyObject + // instruction that copies the size of the fixed-size array. + void HandleOpArrayLength( + const opt::Instruction& instruction, + std::map<uint32_t, uint32_t>* original_id_to_donated_id, + std::vector<protobufs::Instruction>* donated_instructions) const; + + // The instruction |instruction| is required to be an instruction that cannot + // be easily donated, either because it uses an unsupported opcode, has an + // unsupported result type, or uses id operands that could not be donated. + // + // If |instruction| generates a result id, the function attempts to add a + // substitute for |instruction| to |donated_instructions| that has the correct + // result type. If this cannot be done, the instruction's result id is added + // to |skipped_instructions|. The mapping from donor ids to recipient ids is + // managed by |original_id_to_donated_id|. + void HandleDifficultInstruction( + const opt::Instruction& instruction, + std::map<uint32_t, uint32_t>* original_id_to_donated_id, + std::vector<protobufs::Instruction>* donated_instructions, + std::set<uint32_t>* skipped_instructions); + + // Adds an instruction based in |instruction| to |donated_instructions| in a + // form ready for donation. The original instruction comes from + // |donor_ir_context|, and |original_id_to_donated_id| maps ids from + // |donor_ir_context| to corresponding ids in the recipient module. + void PrepareInstructionForDonation( + const opt::Instruction& instruction, opt::IRContext* donor_ir_context, + std::map<uint32_t, uint32_t>* original_id_to_donated_id, + std::vector<protobufs::Instruction>* donated_instructions); + + // Requires that |donated_instructions| represents a prepared version of the + // instructions of |function_to_donate| (which comes from |donor_ir_context|) + // ready for donation, and |original_id_to_donated_id| maps ids from + // |donor_ir_context| to their corresponding ids in the recipient module. + // + // Adds a livesafe version of the function, based on |donated_instructions|, + // to the recipient module. + void AddLivesafeFunction( + const opt::Function& function_to_donate, opt::IRContext* donor_ir_context, + const std::map<uint32_t, uint32_t>& original_id_to_donated_id, + const std::vector<protobufs::Instruction>& donated_instructions); + + // Returns true if and only if |instruction| is a scalar, vector, matrix, + // array or struct; i.e. it is not an opaque type. + bool IsBasicType(const opt::Instruction& instruction) const; + // Returns the ids of all functions in |context| in a topological order in // relation to the call graph of |context|, which is assumed to be recursion- // free.
diff --git a/source/fuzz/fuzzer_pass_merge_blocks.cpp b/source/fuzz/fuzzer_pass_merge_blocks.cpp index ca1bfb3..49778ae 100644 --- a/source/fuzz/fuzzer_pass_merge_blocks.cpp +++ b/source/fuzz/fuzzer_pass_merge_blocks.cpp
@@ -22,10 +22,11 @@ namespace fuzz { FuzzerPassMergeBlocks::FuzzerPassMergeBlocks( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassMergeBlocks::~FuzzerPassMergeBlocks() = default; @@ -44,7 +45,8 @@ // For other blocks, we add a transformation to merge the block into its // predecessor if that transformation would be applicable. TransformationMergeBlocks transformation(block.id()); - if (transformation.IsApplicable(GetIRContext(), *GetFactManager())) { + if (transformation.IsApplicable(GetIRContext(), + *GetTransformationContext())) { potential_transformations.push_back(transformation); } } @@ -54,8 +56,9 @@ uint32_t index = GetFuzzerContext()->RandomIndex(potential_transformations); auto transformation = potential_transformations.at(index); potential_transformations.erase(potential_transformations.begin() + index); - if (transformation.IsApplicable(GetIRContext(), *GetFactManager())) { - transformation.Apply(GetIRContext(), GetFactManager()); + if (transformation.IsApplicable(GetIRContext(), + *GetTransformationContext())) { + transformation.Apply(GetIRContext(), GetTransformationContext()); *GetTransformations()->add_transformation() = transformation.ToMessage(); } }
diff --git a/source/fuzz/fuzzer_pass_merge_blocks.h b/source/fuzz/fuzzer_pass_merge_blocks.h index 457e591..1a6c2c2 100644 --- a/source/fuzz/fuzzer_pass_merge_blocks.h +++ b/source/fuzz/fuzzer_pass_merge_blocks.h
@@ -23,7 +23,8 @@ // A fuzzer pass for merging blocks in the module. class FuzzerPassMergeBlocks : public FuzzerPass { public: - FuzzerPassMergeBlocks(opt::IRContext* ir_context, FactManager* fact_manager, + FuzzerPassMergeBlocks(opt::IRContext* ir_context, + TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_obfuscate_constants.cpp b/source/fuzz/fuzzer_pass_obfuscate_constants.cpp index 2caf0c6..543c0d7 100644 --- a/source/fuzz/fuzzer_pass_obfuscate_constants.cpp +++ b/source/fuzz/fuzzer_pass_obfuscate_constants.cpp
@@ -14,21 +14,25 @@ #include "source/fuzz/fuzzer_pass_obfuscate_constants.h" +#include <algorithm> #include <cmath> +#include "source/fuzz/fuzzer_util.h" #include "source/fuzz/instruction_descriptor.h" #include "source/fuzz/transformation_replace_boolean_constant_with_constant_binary.h" #include "source/fuzz/transformation_replace_constant_with_uniform.h" +#include "source/fuzz/uniform_buffer_element_descriptor.h" #include "source/opt/ir_context.h" namespace spvtools { namespace fuzz { FuzzerPassObfuscateConstants::FuzzerPassObfuscateConstants( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassObfuscateConstants::~FuzzerPassObfuscateConstants() = default; @@ -83,12 +87,13 @@ bool_constant_use, lhs_id, rhs_id, comparison_opcode, GetFuzzerContext()->GetFreshId()); // The transformation should be applicable by construction. - assert(transformation.IsApplicable(GetIRContext(), *GetFactManager())); + assert( + transformation.IsApplicable(GetIRContext(), *GetTransformationContext())); // Applying this transformation yields a pointer to the new instruction that // computes the result of the binary expression. - auto binary_operator_instruction = - transformation.ApplyWithResult(GetIRContext(), GetFactManager()); + auto binary_operator_instruction = transformation.ApplyWithResult( + GetIRContext(), GetTransformationContext()); // Add this transformation to the sequence of transformations that have been // applied. @@ -238,6 +243,29 @@ first_constant_is_larger); } +std::vector<std::vector<uint32_t>> +FuzzerPassObfuscateConstants::GetConstantWordsFromUniformsForType( + uint32_t type_id) { + assert(type_id && "Type id can't be 0"); + std::vector<std::vector<uint32_t>> result; + + for (const auto& facts_and_types : GetTransformationContext() + ->GetFactManager() + ->GetConstantUniformFactsAndTypes()) { + if (facts_and_types.second != type_id) { + continue; + } + + std::vector<uint32_t> words(facts_and_types.first.constant_word().begin(), + facts_and_types.first.constant_word().end()); + if (std::find(result.begin(), result.end(), words) == result.end()) { + result.push_back(std::move(words)); + } + } + + return result; +} + void FuzzerPassObfuscateConstants::ObfuscateBoolConstant( uint32_t depth, const protobufs::IdUseDescriptor& constant_use) { // We want to replace the boolean constant use with a binary expression over @@ -245,7 +273,9 @@ // with uniforms of the same value. auto available_types_with_uniforms = - GetFactManager()->GetTypesForWhichUniformValuesAreKnown(); + GetTransformationContext() + ->GetFactManager() + ->GetTypesForWhichUniformValuesAreKnown(); if (available_types_with_uniforms.empty()) { // Do not try to obfuscate if we do not have access to any uniform // elements with known values. @@ -254,10 +284,9 @@ auto chosen_type_id = available_types_with_uniforms[GetFuzzerContext()->RandomIndex( available_types_with_uniforms)]; - auto available_constants = - GetFactManager()->GetConstantsAvailableFromUniformsForType( - GetIRContext(), chosen_type_id); - if (available_constants.size() == 1) { + auto available_constant_words = + GetConstantWordsFromUniformsForType(chosen_type_id); + if (available_constant_words.size() == 1) { // TODO(afd): for now we only obfuscate a boolean if there are at least // two constants available from uniforms, so that we can do a // comparison between them. It would be good to be able to do the @@ -266,18 +295,25 @@ return; } + assert(!available_constant_words.empty() && + "There exists a fact but no constants - impossible"); + // We know we have at least two known-to-be-constant uniforms of the chosen // type. Pick one of them at random. - auto constant_index_1 = GetFuzzerContext()->RandomIndex(available_constants); + auto constant_index_1 = + GetFuzzerContext()->RandomIndex(available_constant_words); uint32_t constant_index_2; // Now choose another one distinct from the first one. do { - constant_index_2 = GetFuzzerContext()->RandomIndex(available_constants); + constant_index_2 = + GetFuzzerContext()->RandomIndex(available_constant_words); } while (constant_index_1 == constant_index_2); - auto constant_id_1 = available_constants[constant_index_1]; - auto constant_id_2 = available_constants[constant_index_2]; + auto constant_id_1 = FindOrCreateConstant( + available_constant_words[constant_index_1], chosen_type_id); + auto constant_id_2 = FindOrCreateConstant( + available_constant_words[constant_index_2], chosen_type_id); assert(constant_id_1 != 0 && constant_id_2 != 0 && "We should not find an available constant with an id of 0."); @@ -308,25 +344,50 @@ // Check whether we know that any uniforms are guaranteed to be equal to the // scalar constant associated with |constant_use|. - auto uniform_descriptors = GetFactManager()->GetUniformDescriptorsForConstant( - GetIRContext(), constant_use.id_of_interest()); + auto uniform_descriptors = + GetTransformationContext() + ->GetFactManager() + ->GetUniformDescriptorsForConstant(GetIRContext(), + constant_use.id_of_interest()); if (uniform_descriptors.empty()) { // No relevant uniforms, so do not obfuscate. return; } // Choose a random available uniform known to be equal to the constant. - protobufs::UniformBufferElementDescriptor uniform_descriptor = + const auto& uniform_descriptor = uniform_descriptors[GetFuzzerContext()->RandomIndex(uniform_descriptors)]; + + // Make sure the module has OpConstant instructions for each index used to + // access a uniform. + for (auto index : uniform_descriptor.index()) { + FindOrCreate32BitIntegerConstant(index, true); + } + + // Make sure the module has OpTypePointer that points to the element type of + // the uniform. + const auto* uniform_variable_instr = + FindUniformVariable(uniform_descriptor, GetIRContext(), true); + assert(uniform_variable_instr && + "Uniform variable does not exist or not unique."); + + const auto* uniform_variable_type_intr = + GetIRContext()->get_def_use_mgr()->GetDef( + uniform_variable_instr->type_id()); + assert(uniform_variable_type_intr && "Uniform variable has invalid type"); + + auto element_type_id = fuzzerutil::WalkCompositeTypeIndices( + GetIRContext(), uniform_variable_type_intr->GetSingleWordInOperand(1), + uniform_descriptor.index()); + assert(element_type_id && "Type of uniform variable is invalid"); + + FindOrCreatePointerType(element_type_id, SpvStorageClassUniform); + // Create, apply and record a transformation to replace the constant use with // the result of a load from the chosen uniform. - auto transformation = TransformationReplaceConstantWithUniform( + ApplyTransformation(TransformationReplaceConstantWithUniform( constant_use, uniform_descriptor, GetFuzzerContext()->GetFreshId(), - GetFuzzerContext()->GetFreshId()); - // Transformation should be applicable by construction. - assert(transformation.IsApplicable(GetIRContext(), *GetFactManager())); - transformation.Apply(GetIRContext(), GetFactManager()); - *GetTransformations()->add_transformation() = transformation.ToMessage(); + GetFuzzerContext()->GetFreshId())); } void FuzzerPassObfuscateConstants::ObfuscateConstant(
diff --git a/source/fuzz/fuzzer_pass_obfuscate_constants.h b/source/fuzz/fuzzer_pass_obfuscate_constants.h index f34717b..52d8efe 100644 --- a/source/fuzz/fuzzer_pass_obfuscate_constants.h +++ b/source/fuzz/fuzzer_pass_obfuscate_constants.h
@@ -28,7 +28,7 @@ class FuzzerPassObfuscateConstants : public FuzzerPass { public: FuzzerPassObfuscateConstants( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations); @@ -99,6 +99,11 @@ uint32_t base_instruction_result_id, const std::map<SpvOp, uint32_t>& skipped_opcode_count, std::vector<protobufs::IdUseDescriptor>* constant_uses); + + // Returns a vector of unique words that denote constants. Every such constant + // is used in |FactConstantUniform| and has type with id equal to |type_id|. + std::vector<std::vector<uint32_t>> GetConstantWordsFromUniformsForType( + uint32_t type_id); }; } // namespace fuzz
diff --git a/source/fuzz/fuzzer_pass_outline_functions.cpp b/source/fuzz/fuzzer_pass_outline_functions.cpp index d59c195..1665d05 100644 --- a/source/fuzz/fuzzer_pass_outline_functions.cpp +++ b/source/fuzz/fuzzer_pass_outline_functions.cpp
@@ -23,10 +23,11 @@ namespace fuzz { FuzzerPassOutlineFunctions::FuzzerPassOutlineFunctions( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassOutlineFunctions::~FuzzerPassOutlineFunctions() = default; @@ -88,8 +89,9 @@ /*new_callee_result_id*/ GetFuzzerContext()->GetFreshId(), /*input_id_to_fresh_id*/ std::move(input_id_to_fresh_id), /*output_id_to_fresh_id*/ std::move(output_id_to_fresh_id)); - if (transformation.IsApplicable(GetIRContext(), *GetFactManager())) { - transformation.Apply(GetIRContext(), GetFactManager()); + if (transformation.IsApplicable(GetIRContext(), + *GetTransformationContext())) { + transformation.Apply(GetIRContext(), GetTransformationContext()); *GetTransformations()->add_transformation() = transformation.ToMessage(); } }
diff --git a/source/fuzz/fuzzer_pass_outline_functions.h b/source/fuzz/fuzzer_pass_outline_functions.h index 5448e7d..6532ed9 100644 --- a/source/fuzz/fuzzer_pass_outline_functions.h +++ b/source/fuzz/fuzzer_pass_outline_functions.h
@@ -25,7 +25,7 @@ class FuzzerPassOutlineFunctions : public FuzzerPass { public: FuzzerPassOutlineFunctions( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_permute_blocks.cpp b/source/fuzz/fuzzer_pass_permute_blocks.cpp index af6d2a5..27a2d67 100644 --- a/source/fuzz/fuzzer_pass_permute_blocks.cpp +++ b/source/fuzz/fuzzer_pass_permute_blocks.cpp
@@ -20,10 +20,11 @@ namespace fuzz { FuzzerPassPermuteBlocks::FuzzerPassPermuteBlocks( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassPermuteBlocks::~FuzzerPassPermuteBlocks() = default; @@ -66,8 +67,9 @@ // down indefinitely. while (true) { TransformationMoveBlockDown transformation(*id); - if (transformation.IsApplicable(GetIRContext(), *GetFactManager())) { - transformation.Apply(GetIRContext(), GetFactManager()); + if (transformation.IsApplicable(GetIRContext(), + *GetTransformationContext())) { + transformation.Apply(GetIRContext(), GetTransformationContext()); *GetTransformations()->add_transformation() = transformation.ToMessage(); } else {
diff --git a/source/fuzz/fuzzer_pass_permute_blocks.h b/source/fuzz/fuzzer_pass_permute_blocks.h index 6735e95..f2d3b39 100644 --- a/source/fuzz/fuzzer_pass_permute_blocks.h +++ b/source/fuzz/fuzzer_pass_permute_blocks.h
@@ -24,7 +24,8 @@ // manner. class FuzzerPassPermuteBlocks : public FuzzerPass { public: - FuzzerPassPermuteBlocks(opt::IRContext* ir_context, FactManager* fact_manager, + FuzzerPassPermuteBlocks(opt::IRContext* ir_context, + TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_permute_function_parameters.cpp b/source/fuzz/fuzzer_pass_permute_function_parameters.cpp index 2c49860..57d9cab 100644 --- a/source/fuzz/fuzzer_pass_permute_function_parameters.cpp +++ b/source/fuzz/fuzzer_pass_permute_function_parameters.cpp
@@ -25,10 +25,11 @@ namespace fuzz { FuzzerPassPermuteFunctionParameters::FuzzerPassPermuteFunctionParameters( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassPermuteFunctionParameters::~FuzzerPassPermuteFunctionParameters() = default;
diff --git a/source/fuzz/fuzzer_pass_permute_function_parameters.h b/source/fuzz/fuzzer_pass_permute_function_parameters.h index bc79804..3f32864 100644 --- a/source/fuzz/fuzzer_pass_permute_function_parameters.h +++ b/source/fuzz/fuzzer_pass_permute_function_parameters.h
@@ -30,7 +30,7 @@ class FuzzerPassPermuteFunctionParameters : public FuzzerPass { public: FuzzerPassPermuteFunctionParameters( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_split_blocks.cpp b/source/fuzz/fuzzer_pass_split_blocks.cpp index 6a2ea4d..15c6790 100644 --- a/source/fuzz/fuzzer_pass_split_blocks.cpp +++ b/source/fuzz/fuzzer_pass_split_blocks.cpp
@@ -23,10 +23,11 @@ namespace fuzz { FuzzerPassSplitBlocks::FuzzerPassSplitBlocks( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassSplitBlocks::~FuzzerPassSplitBlocks() = default; @@ -95,8 +96,9 @@ // If the position we have chosen turns out to be a valid place to split // the block, we apply the split. Otherwise the block just doesn't get // split. - if (transformation.IsApplicable(GetIRContext(), *GetFactManager())) { - transformation.Apply(GetIRContext(), GetFactManager()); + if (transformation.IsApplicable(GetIRContext(), + *GetTransformationContext())) { + transformation.Apply(GetIRContext(), GetTransformationContext()); *GetTransformations()->add_transformation() = transformation.ToMessage(); } }
diff --git a/source/fuzz/fuzzer_pass_split_blocks.h b/source/fuzz/fuzzer_pass_split_blocks.h index 6e56dde..278ec6d 100644 --- a/source/fuzz/fuzzer_pass_split_blocks.h +++ b/source/fuzz/fuzzer_pass_split_blocks.h
@@ -24,7 +24,8 @@ // can be very useful for giving other passes a chance to apply. class FuzzerPassSplitBlocks : public FuzzerPass { public: - FuzzerPassSplitBlocks(opt::IRContext* ir_context, FactManager* fact_manager, + FuzzerPassSplitBlocks(opt::IRContext* ir_context, + TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_swap_commutable_operands.cpp b/source/fuzz/fuzzer_pass_swap_commutable_operands.cpp index 4df97c9..321e8ef 100644 --- a/source/fuzz/fuzzer_pass_swap_commutable_operands.cpp +++ b/source/fuzz/fuzzer_pass_swap_commutable_operands.cpp
@@ -22,10 +22,11 @@ namespace fuzz { FuzzerPassSwapCommutableOperands::FuzzerPassSwapCommutableOperands( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassSwapCommutableOperands::~FuzzerPassSwapCommutableOperands() = default;
diff --git a/source/fuzz/fuzzer_pass_swap_commutable_operands.h b/source/fuzz/fuzzer_pass_swap_commutable_operands.h index b0206de..74d937d 100644 --- a/source/fuzz/fuzzer_pass_swap_commutable_operands.h +++ b/source/fuzz/fuzzer_pass_swap_commutable_operands.h
@@ -26,7 +26,7 @@ class FuzzerPassSwapCommutableOperands : public FuzzerPass { public: FuzzerPassSwapCommutableOperands( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_pass_toggle_access_chain_instruction.cpp b/source/fuzz/fuzzer_pass_toggle_access_chain_instruction.cpp index 9fb175b..4f26cba 100644 --- a/source/fuzz/fuzzer_pass_toggle_access_chain_instruction.cpp +++ b/source/fuzz/fuzzer_pass_toggle_access_chain_instruction.cpp
@@ -22,10 +22,11 @@ namespace fuzz { FuzzerPassToggleAccessChainInstruction::FuzzerPassToggleAccessChainInstruction( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations) - : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + : FuzzerPass(ir_context, transformation_context, fuzzer_context, + transformations) {} FuzzerPassToggleAccessChainInstruction:: ~FuzzerPassToggleAccessChainInstruction() = default;
diff --git a/source/fuzz/fuzzer_pass_toggle_access_chain_instruction.h b/source/fuzz/fuzzer_pass_toggle_access_chain_instruction.h index ec8c3f7..d77c7cb 100644 --- a/source/fuzz/fuzzer_pass_toggle_access_chain_instruction.h +++ b/source/fuzz/fuzzer_pass_toggle_access_chain_instruction.h
@@ -25,7 +25,7 @@ class FuzzerPassToggleAccessChainInstruction : public FuzzerPass { public: FuzzerPassToggleAccessChainInstruction( - opt::IRContext* ir_context, FactManager* fact_manager, + opt::IRContext* ir_context, TransformationContext* transformation_context, FuzzerContext* fuzzer_context, protobufs::TransformationSequence* transformations);
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp index 4bfa195..f09943f 100644 --- a/source/fuzz/fuzzer_util.cpp +++ b/source/fuzz/fuzzer_util.cpp
@@ -218,6 +218,12 @@ } bool CanMakeSynonymOf(opt::IRContext* ir_context, opt::Instruction* inst) { + if (inst->opcode() == SpvOpSampledImage) { + // The SPIR-V data rules say that only very specific instructions may + // may consume the result id of an OpSampledImage, and this excludes the + // instructions that are used for making synonyms. + return false; + } if (!inst->HasResultId()) { // We can only make a synonym of an instruction that generates an id. return false; @@ -329,11 +335,11 @@ return array_length_constant->GetU32(); } -bool IsValid(opt::IRContext* context) { +bool IsValid(opt::IRContext* context, spv_validator_options validator_options) { std::vector<uint32_t> binary; context->module()->ToBinary(&binary, false); SpirvTools tools(context->grammar().target_env()); - return tools.Validate(binary); + return tools.Validate(binary.data(), binary.size(), validator_options); } std::unique_ptr<opt::IRContext> CloneIRContext(opt::IRContext* context) { @@ -537,6 +543,13 @@ return 0; } +bool IsNullConstantSupported(const opt::analysis::Type& type) { + return type.AsBool() || type.AsInteger() || type.AsFloat() || + type.AsMatrix() || type.AsVector() || type.AsArray() || + type.AsStruct() || type.AsPointer() || type.AsEvent() || + type.AsDeviceEvent() || type.AsReserveId() || type.AsQueue(); +} + } // namespace fuzzerutil } // namespace fuzz
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h index 7be0d59..bccd1d0 100644 --- a/source/fuzz/fuzzer_util.h +++ b/source/fuzz/fuzzer_util.h
@@ -18,6 +18,7 @@ #include <vector> #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/basic_block.h" #include "source/opt/instruction.h" #include "source/opt/ir_context.h" @@ -132,8 +133,9 @@ uint32_t GetArraySize(const opt::Instruction& array_type_instruction, opt::IRContext* context); -// Returns true if and only if |context| is valid, according to the validator. -bool IsValid(opt::IRContext* context); +// Returns true if and only if |context| is valid, according to the validator +// instantiated with |validator_options|. +bool IsValid(opt::IRContext* context, spv_validator_options validator_options); // Returns a clone of |context|, by writing |context| to a binary and then // parsing it again. @@ -209,6 +211,10 @@ uint32_t MaybeGetPointerType(opt::IRContext* context, uint32_t pointee_type_id, SpvStorageClass storage_class); +// Returns true if and only if |type| is one of the types for which it is legal +// to have an OpConstantNull value. +bool IsNullConstantSupported(const opt::analysis::Type& type); + } // namespace fuzzerutil } // namespace fuzz
diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto index b816e3b..775b2ad 100644 --- a/source/fuzz/protobufs/spvtoolsfuzz.proto +++ b/source/fuzz/protobufs/spvtoolsfuzz.proto
@@ -372,6 +372,9 @@ TransformationSwapCommutableOperands swap_commutable_operands = 41; TransformationPermuteFunctionParameters permute_function_parameters = 42; TransformationToggleAccessChainInstruction toggle_access_chain_instruction = 43; + TransformationAddConstantNull add_constant_null = 44; + TransformationComputeDataSynonymFactClosure compute_data_synonym_fact_closure = 45; + TransformationAdjustBranchWeights adjust_branch_weights = 46; // Add additional option using the next available number. } } @@ -422,6 +425,18 @@ } +message TransformationAddConstantNull { + + // Adds a null constant. + + // Id for the constant + uint32 fresh_id = 1; + + // Type of the constant + uint32 type_id = 2; + +} + message TransformationAddConstantScalar { // Adds a constant of the given scalar type. @@ -547,8 +562,9 @@ message TransformationAddGlobalVariable { - // Adds a global variable of the given type to the module, with Private - // storage class and optionally with an initializer. + // Adds a global variable of the given type to the module, with Private or + // Workgroup storage class, and optionally (for the Private case) with an + // initializer. // Fresh id for the global variable uint32 fresh_id = 1; @@ -556,13 +572,15 @@ // The type of the global variable uint32 type_id = 2; + uint32 storage_class = 3; + // Initial value of the variable - uint32 initializer_id = 3; + uint32 initializer_id = 4; // True if and only if the behaviour of the module should not depend on the // value of the variable, in which case stores to the variable can be // performed in an arbitrary fashion. - bool value_is_irrelevant = 4; + bool value_is_irrelevant = 5; } @@ -725,6 +743,19 @@ } +message TransformationAdjustBranchWeights { + + // A transformation that adjusts the branch weights + // of a branch conditional instruction. + + // A descriptor for a branch conditional instruction. + InstructionDescriptor instruction_descriptor = 1; + + // Branch weights of a branch conditional instruction. + UInt32Pair branch_weights = 2; + +} + message TransformationCompositeConstruct { // A transformation that introduces an OpCompositeConstruct instruction to @@ -765,6 +796,19 @@ } +message TransformationComputeDataSynonymFactClosure { + + // A transformation that impacts the fact manager only, forcing a computation + // of the closure of data synonym facts, so that e.g. if the components of + // vectors v and w are known to be pairwise synonymous, it is deduced that v + // and w are themselves synonymous. + + // When searching equivalence classes for implied facts, equivalence classes + // larger than this size will be skipped. + uint32 maximum_equivalence_class_size = 1; + +} + message TransformationCopyObject { // A transformation that introduces an OpCopyObject instruction to make a
diff --git a/source/fuzz/replayer.cpp b/source/fuzz/replayer.cpp index 398ce59..6312cba 100644 --- a/source/fuzz/replayer.cpp +++ b/source/fuzz/replayer.cpp
@@ -26,6 +26,7 @@ #include "source/fuzz/transformation_add_type_float.h" #include "source/fuzz/transformation_add_type_int.h" #include "source/fuzz/transformation_add_type_pointer.h" +#include "source/fuzz/transformation_context.h" #include "source/fuzz/transformation_move_block_down.h" #include "source/fuzz/transformation_replace_boolean_constant_with_constant_binary.h" #include "source/fuzz/transformation_replace_constant_with_uniform.h" @@ -37,18 +38,22 @@ namespace fuzz { struct Replayer::Impl { - explicit Impl(spv_target_env env, bool validate) - : target_env(env), validate_during_replay(validate) {} + Impl(spv_target_env env, bool validate, spv_validator_options options) + : target_env(env), + validate_during_replay(validate), + validator_options(options) {} - const spv_target_env target_env; // Target environment. - MessageConsumer consumer; // Message consumer. - + const spv_target_env target_env; // Target environment. + MessageConsumer consumer; // Message consumer. const bool validate_during_replay; // Controls whether the validator should // be run after every replay step. + spv_validator_options validator_options; // Options to control + // validation }; -Replayer::Replayer(spv_target_env env, bool validate_during_replay) - : impl_(MakeUnique<Impl>(env, validate_during_replay)) {} +Replayer::Replayer(spv_target_env env, bool validate_during_replay, + spv_validator_options validator_options) + : impl_(MakeUnique<Impl>(env, validate_during_replay, validator_options)) {} Replayer::~Replayer() = default; @@ -74,7 +79,8 @@ } // Initial binary should be valid. - if (!tools.Validate(&binary_in[0], binary_in.size())) { + if (!tools.Validate(&binary_in[0], binary_in.size(), + impl_->validator_options)) { impl_->consumer(SPV_MSG_INFO, nullptr, {}, "Initial binary is invalid; stopping."); return Replayer::ReplayerResultStatus::kInitialBinaryInvalid; @@ -94,16 +100,19 @@ FactManager fact_manager; fact_manager.AddFacts(impl_->consumer, initial_facts, ir_context.get()); + TransformationContext transformation_context(&fact_manager, + impl_->validator_options); // Consider the transformation proto messages in turn. for (auto& message : transformation_sequence_in.transformation()) { auto transformation = Transformation::FromMessage(message); // Check whether the transformation can be applied. - if (transformation->IsApplicable(ir_context.get(), fact_manager)) { + if (transformation->IsApplicable(ir_context.get(), + transformation_context)) { // The transformation is applicable, so apply it, and copy it to the // sequence of transformations that were applied. - transformation->Apply(ir_context.get(), &fact_manager); + transformation->Apply(ir_context.get(), &transformation_context); *transformation_sequence_out->add_transformation() = message; if (impl_->validate_during_replay) { @@ -111,8 +120,8 @@ ir_context->module()->ToBinary(&binary_to_validate, false); // Check whether the latest transformation led to a valid binary. - if (!tools.Validate(&binary_to_validate[0], - binary_to_validate.size())) { + if (!tools.Validate(&binary_to_validate[0], binary_to_validate.size(), + impl_->validator_options)) { impl_->consumer(SPV_MSG_INFO, nullptr, {}, "Binary became invalid during replay (set a " "breakpoint to inspect); stopping.");
diff --git a/source/fuzz/replayer.h b/source/fuzz/replayer.h index 1d58bae..e77d840 100644 --- a/source/fuzz/replayer.h +++ b/source/fuzz/replayer.h
@@ -37,7 +37,8 @@ }; // Constructs a replayer from the given target environment. - explicit Replayer(spv_target_env env, bool validate_during_replay); + Replayer(spv_target_env env, bool validate_during_replay, + spv_validator_options validator_options); // Disables copy/move constructor/assignment operations. Replayer(const Replayer&) = delete;
diff --git a/source/fuzz/shrinker.cpp b/source/fuzz/shrinker.cpp index 1bb92f1..002e8a7 100644 --- a/source/fuzz/shrinker.cpp +++ b/source/fuzz/shrinker.cpp
@@ -60,20 +60,27 @@ } // namespace struct Shrinker::Impl { - explicit Impl(spv_target_env env, uint32_t limit, bool validate) - : target_env(env), step_limit(limit), validate_during_replay(validate) {} + Impl(spv_target_env env, uint32_t limit, bool validate, + spv_validator_options options) + : target_env(env), + step_limit(limit), + validate_during_replay(validate), + validator_options(options) {} - const spv_target_env target_env; // Target environment. - MessageConsumer consumer; // Message consumer. - const uint32_t step_limit; // Step limit for reductions. - const bool validate_during_replay; // Determines whether to check for - // validity during the replaying of - // transformations. + const spv_target_env target_env; // Target environment. + MessageConsumer consumer; // Message consumer. + const uint32_t step_limit; // Step limit for reductions. + const bool validate_during_replay; // Determines whether to check for + // validity during the replaying of + // transformations. + spv_validator_options validator_options; // Options to control validation. }; Shrinker::Shrinker(spv_target_env env, uint32_t step_limit, - bool validate_during_replay) - : impl_(MakeUnique<Impl>(env, step_limit, validate_during_replay)) {} + bool validate_during_replay, + spv_validator_options validator_options) + : impl_(MakeUnique<Impl>(env, step_limit, validate_during_replay, + validator_options)) {} Shrinker::~Shrinker() = default; @@ -100,7 +107,8 @@ } // Initial binary should be valid. - if (!tools.Validate(&binary_in[0], binary_in.size())) { + if (!tools.Validate(&binary_in[0], binary_in.size(), + impl_->validator_options)) { impl_->consumer(SPV_MSG_INFO, nullptr, {}, "Initial binary is invalid; stopping."); return Shrinker::ShrinkerResultStatus::kInitialBinaryInvalid; @@ -113,7 +121,8 @@ // succeeds, (b) get the binary that results from running these // transformations, and (c) get the subsequence of the initial transformations // that actually apply (in principle this could be a strict subsequence). - if (Replayer(impl_->target_env, impl_->validate_during_replay) + if (Replayer(impl_->target_env, impl_->validate_during_replay, + impl_->validator_options) .Run(binary_in, initial_facts, transformation_sequence_in, ¤t_best_binary, ¤t_best_transformations) != Replayer::ReplayerResultStatus::kComplete) { @@ -184,7 +193,8 @@ // transformations inapplicable. std::vector<uint32_t> next_binary; protobufs::TransformationSequence next_transformation_sequence; - if (Replayer(impl_->target_env, false) + if (Replayer(impl_->target_env, impl_->validate_during_replay, + impl_->validator_options) .Run(binary_in, initial_facts, transformations_with_chunk_removed, &next_binary, &next_transformation_sequence) != Replayer::ReplayerResultStatus::kComplete) {
diff --git a/source/fuzz/shrinker.h b/source/fuzz/shrinker.h index 0163a53..17b15bf 100644 --- a/source/fuzz/shrinker.h +++ b/source/fuzz/shrinker.h
@@ -50,8 +50,8 @@ const std::vector<uint32_t>& binary, uint32_t counter)>; // Constructs a shrinker from the given target environment. - Shrinker(spv_target_env env, uint32_t step_limit, - bool validate_during_replay); + Shrinker(spv_target_env env, uint32_t step_limit, bool validate_during_replay, + spv_validator_options validator_options); // Disables copy/move constructor/assignment operations. Shrinker(const Shrinker&) = delete;
diff --git a/source/fuzz/transformation.cpp b/source/fuzz/transformation.cpp index f18c86b..8b84169 100644 --- a/source/fuzz/transformation.cpp +++ b/source/fuzz/transformation.cpp
@@ -20,6 +20,7 @@ #include "source/fuzz/transformation_access_chain.h" #include "source/fuzz/transformation_add_constant_boolean.h" #include "source/fuzz/transformation_add_constant_composite.h" +#include "source/fuzz/transformation_add_constant_null.h" #include "source/fuzz/transformation_add_constant_scalar.h" #include "source/fuzz/transformation_add_dead_block.h" #include "source/fuzz/transformation_add_dead_break.h" @@ -38,8 +39,10 @@ #include "source/fuzz/transformation_add_type_pointer.h" #include "source/fuzz/transformation_add_type_struct.h" #include "source/fuzz/transformation_add_type_vector.h" +#include "source/fuzz/transformation_adjust_branch_weights.h" #include "source/fuzz/transformation_composite_construct.h" #include "source/fuzz/transformation_composite_extract.h" +#include "source/fuzz/transformation_compute_data_synonym_fact_closure.h" #include "source/fuzz/transformation_copy_object.h" #include "source/fuzz/transformation_equation_instruction.h" #include "source/fuzz/transformation_function_call.h" @@ -78,6 +81,9 @@ case protobufs::Transformation::TransformationCase::kAddConstantComposite: return MakeUnique<TransformationAddConstantComposite>( message.add_constant_composite()); + case protobufs::Transformation::TransformationCase::kAddConstantNull: + return MakeUnique<TransformationAddConstantNull>( + message.add_constant_null()); case protobufs::Transformation::TransformationCase::kAddConstantScalar: return MakeUnique<TransformationAddConstantScalar>( message.add_constant_scalar()); @@ -124,12 +130,19 @@ return MakeUnique<TransformationAddTypeStruct>(message.add_type_struct()); case protobufs::Transformation::TransformationCase::kAddTypeVector: return MakeUnique<TransformationAddTypeVector>(message.add_type_vector()); + case protobufs::Transformation::TransformationCase::kAdjustBranchWeights: + return MakeUnique<TransformationAdjustBranchWeights>( + message.adjust_branch_weights()); case protobufs::Transformation::TransformationCase::kCompositeConstruct: return MakeUnique<TransformationCompositeConstruct>( message.composite_construct()); case protobufs::Transformation::TransformationCase::kCompositeExtract: return MakeUnique<TransformationCompositeExtract>( message.composite_extract()); + case protobufs::Transformation::TransformationCase:: + kComputeDataSynonymFactClosure: + return MakeUnique<TransformationComputeDataSynonymFactClosure>( + message.compute_data_synonym_fact_closure()); case protobufs::Transformation::TransformationCase::kCopyObject: return MakeUnique<TransformationCopyObject>(message.copy_object()); case protobufs::Transformation::TransformationCase::kEquationInstruction: @@ -195,9 +208,9 @@ } bool Transformation::CheckIdIsFreshAndNotUsedByThisTransformation( - uint32_t id, opt::IRContext* context, + uint32_t id, opt::IRContext* ir_context, std::set<uint32_t>* ids_used_by_this_transformation) { - if (!fuzzerutil::IsFreshId(context, id)) { + if (!fuzzerutil::IsFreshId(ir_context, id)) { return false; } if (ids_used_by_this_transformation->count(id) != 0) {
diff --git a/source/fuzz/transformation.h b/source/fuzz/transformation.h index dbe803f..dbd0fe2 100644 --- a/source/fuzz/transformation.h +++ b/source/fuzz/transformation.h
@@ -17,8 +17,8 @@ #include <memory> -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -60,19 +60,22 @@ public: // A precondition that determines whether the transformation can be cleanly // applied in a semantics-preserving manner to the SPIR-V module given by - // |context|, in the presence of facts captured by |fact_manager|. + // |ir_context|, in the presence of facts and other contextual information + // captured by |transformation_context|. + // // Preconditions for individual transformations must be documented in the - // associated header file using precise English. The fact manager is used to - // provide access to facts about the module that are known to be true, on + // associated header file using precise English. The transformation context + // provides access to facts about the module that are known to be true, on // which the precondition may depend. - virtual bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const = 0; + virtual bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const = 0; - // Requires that IsApplicable(context, fact_manager) holds. Applies the - // transformation, mutating |context| and possibly updating |fact_manager| - // with new facts established by the transformation. - virtual void Apply(opt::IRContext* context, - FactManager* fact_manager) const = 0; + // Requires that IsApplicable(ir_context, *transformation_context) holds. + // Applies the transformation, mutating |ir_context| and possibly updating + // |transformation_context| with new facts established by the transformation. + virtual void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const = 0; // Turns the transformation into a protobuf message for serialization. virtual protobufs::Transformation ToMessage() const = 0; @@ -90,7 +93,7 @@ // checking id freshness for a transformation that uses many ids, all of which // must be distinct. static bool CheckIdIsFreshAndNotUsedByThisTransformation( - uint32_t id, opt::IRContext* context, + uint32_t id, opt::IRContext* ir_context, std::set<uint32_t>* ids_used_by_this_transformation); };
diff --git a/source/fuzz/transformation_access_chain.cpp b/source/fuzz/transformation_access_chain.cpp index 8c31006..ff17c36 100644 --- a/source/fuzz/transformation_access_chain.cpp +++ b/source/fuzz/transformation_access_chain.cpp
@@ -40,19 +40,18 @@ } bool TransformationAccessChain::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The result id must be fresh - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // The pointer id must exist and have a type. - auto pointer = context->get_def_use_mgr()->GetDef(message_.pointer_id()); + auto pointer = ir_context->get_def_use_mgr()->GetDef(message_.pointer_id()); if (!pointer || !pointer->type_id()) { return false; } // The type must indeed be a pointer - auto pointer_type = context->get_def_use_mgr()->GetDef(pointer->type_id()); + auto pointer_type = ir_context->get_def_use_mgr()->GetDef(pointer->type_id()); if (pointer_type->opcode() != SpvOpTypePointer) { return false; } @@ -60,7 +59,7 @@ // The described instruction to insert before must exist and be a suitable // point where an OpAccessChain instruction could be inserted. auto instruction_to_insert_before = - FindInstruction(message_.instruction_to_insert_before(), context); + FindInstruction(message_.instruction_to_insert_before(), ir_context); if (!instruction_to_insert_before) { return false; } @@ -86,7 +85,7 @@ // The pointer on which the access chain is to be based needs to be available // (according to dominance rules) at the insertion point. if (!fuzzerutil::IdIsAvailableBeforeInstruction( - context, instruction_to_insert_before, message_.pointer_id())) { + ir_context, instruction_to_insert_before, message_.pointer_id())) { return false; } @@ -104,7 +103,7 @@ // integer. Otherwise, the integer with which the id is associated is the // second component. std::pair<bool, uint32_t> maybe_index_value = - GetIndexValue(context, index_id); + GetIndexValue(ir_context, index_id); if (!maybe_index_value.first) { // There was no integer: this index is no good. return false; @@ -113,7 +112,7 @@ // type is not a composite or the index is out of bounds, and the id of // the next type otherwise. subobject_type_id = fuzzerutil::WalkOneCompositeTypeIndex( - context, subobject_type_id, maybe_index_value.second); + ir_context, subobject_type_id, maybe_index_value.second); if (!subobject_type_id) { // Either the type was not a composite (so that too many indices were // provided), or the index was out of bounds. @@ -128,13 +127,14 @@ // We do not use the type manager to look up this type, due to problems // associated with pointers to isomorphic structs being regarded as the same. return fuzzerutil::MaybeGetPointerType( - context, subobject_type_id, + ir_context, subobject_type_id, static_cast<SpvStorageClass>( pointer_type->GetSingleWordInOperand(0))) != 0; } void TransformationAccessChain::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const { + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { // The operands to the access chain are the pointer followed by the indices. // The result type of the access chain is determined by where the indices // lead. We thus push the pointer to a sequence of operands, and then follow @@ -148,8 +148,8 @@ operands.push_back({SPV_OPERAND_TYPE_ID, {message_.pointer_id()}}); // Start walking the indices, starting with the pointer's base type. - auto pointer_type = context->get_def_use_mgr()->GetDef( - context->get_def_use_mgr()->GetDef(message_.pointer_id())->type_id()); + auto pointer_type = ir_context->get_def_use_mgr()->GetDef( + ir_context->get_def_use_mgr()->GetDef(message_.pointer_id())->type_id()); uint32_t subobject_type_id = pointer_type->GetSingleWordInOperand(1); // Go through the index ids in turn. @@ -157,33 +157,35 @@ // Add the index id to the operands. operands.push_back({SPV_OPERAND_TYPE_ID, {index_id}}); // Get the integer value associated with the index id. - uint32_t index_value = GetIndexValue(context, index_id).second; + uint32_t index_value = GetIndexValue(ir_context, index_id).second; // Walk to the next type in the composite object using this index. subobject_type_id = fuzzerutil::WalkOneCompositeTypeIndex( - context, subobject_type_id, index_value); + ir_context, subobject_type_id, index_value); } // The access chain's result type is a pointer to the composite component that // was reached after following all indices. The storage class is that of the // original pointer. uint32_t result_type = fuzzerutil::MaybeGetPointerType( - context, subobject_type_id, + ir_context, subobject_type_id, static_cast<SpvStorageClass>(pointer_type->GetSingleWordInOperand(0))); // Add the access chain instruction to the module, and update the module's id // bound. - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); - FindInstruction(message_.instruction_to_insert_before(), context) - ->InsertBefore( - MakeUnique<opt::Instruction>(context, SpvOpAccessChain, result_type, - message_.fresh_id(), operands)); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); + FindInstruction(message_.instruction_to_insert_before(), ir_context) + ->InsertBefore(MakeUnique<opt::Instruction>( + ir_context, SpvOpAccessChain, result_type, message_.fresh_id(), + operands)); // Conservatively invalidate all analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); // If the base pointer's pointee value was irrelevant, the same is true of the // pointee value of the result of this access chain. - if (fact_manager->PointeeValueIsIrrelevant(message_.pointer_id())) { - fact_manager->AddFactValueOfPointeeIsIrrelevant(message_.fresh_id()); + if (transformation_context->GetFactManager()->PointeeValueIsIrrelevant( + message_.pointer_id())) { + transformation_context->GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + message_.fresh_id()); } } @@ -194,8 +196,8 @@ } std::pair<bool, uint32_t> TransformationAccessChain::GetIndexValue( - opt::IRContext* context, uint32_t index_id) const { - auto index_instruction = context->get_def_use_mgr()->GetDef(index_id); + opt::IRContext* ir_context, uint32_t index_id) const { + auto index_instruction = ir_context->get_def_use_mgr()->GetDef(index_id); if (!index_instruction || !spvOpcodeIsConstant(index_instruction->opcode())) { // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3179) We could // allow non-constant indices when looking up non-structs, using clamping @@ -203,7 +205,7 @@ return {false, 0}; } auto index_type = - context->get_def_use_mgr()->GetDef(index_instruction->type_id()); + ir_context->get_def_use_mgr()->GetDef(index_instruction->type_id()); if (index_type->opcode() != SpvOpTypeInt || index_type->GetSingleWordInOperand(0) != 32) { return {false, 0};
diff --git a/source/fuzz/transformation_access_chain.h b/source/fuzz/transformation_access_chain.h index 92d9e6a..9306a59 100644 --- a/source/fuzz/transformation_access_chain.h +++ b/source/fuzz/transformation_access_chain.h
@@ -17,9 +17,9 @@ #include <utility> -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -47,8 +47,9 @@ // - If type t is the final type reached by walking these indices, the module // must include an instruction "OpTypePointer SC %t" where SC is the storage // class associated with |message_.pointer_id| - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an instruction of the form: // |message_.fresh_id| = OpAccessChain %ptr |message_.index_id| @@ -57,10 +58,12 @@ // the indices in |message_.index_id|, and with the same storage class as // |message_.pointer_id|. // - // If |fact_manager| reports that |message_.pointer_id| has an irrelevant - // pointee value, then the fact that |message_.fresh_id| (the result of the - // access chain) also has an irrelevant pointee value is also recorded. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + // If the fact manager in |transformation_context| reports that + // |message_.pointer_id| has an irrelevant pointee value, then the fact that + // |message_.fresh_id| (the result of the access chain) also has an irrelevant + // pointee value is also recorded. + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; @@ -68,7 +71,7 @@ // Returns {false, 0} if |index_id| does not correspond to a 32-bit integer // constant. Otherwise, returns {true, value}, where value is the value of // the 32-bit integer constant to which |index_id| corresponds. - std::pair<bool, uint32_t> GetIndexValue(opt::IRContext* context, + std::pair<bool, uint32_t> GetIndexValue(opt::IRContext* ir_context, uint32_t index_id) const; protobufs::TransformationAccessChain message_;
diff --git a/source/fuzz/transformation_add_constant_boolean.cpp b/source/fuzz/transformation_add_constant_boolean.cpp index 21c8ed3..1930f7e 100644 --- a/source/fuzz/transformation_add_constant_boolean.cpp +++ b/source/fuzz/transformation_add_constant_boolean.cpp
@@ -31,27 +31,28 @@ } bool TransformationAddConstantBoolean::IsApplicable( - opt::IRContext* context, const FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { opt::analysis::Bool bool_type; - if (!context->get_type_mgr()->GetId(&bool_type)) { + if (!ir_context->get_type_mgr()->GetId(&bool_type)) { // No OpTypeBool is present. return false; } - return fuzzerutil::IsFreshId(context, message_.fresh_id()); + return fuzzerutil::IsFreshId(ir_context, message_.fresh_id()); } -void TransformationAddConstantBoolean::Apply(opt::IRContext* context, - FactManager* /*unused*/) const { +void TransformationAddConstantBoolean::Apply( + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { opt::analysis::Bool bool_type; // Add the boolean constant to the module, ensuring the module's id bound is // high enough. - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); - context->module()->AddGlobalValue( + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); + ir_context->module()->AddGlobalValue( message_.is_true() ? SpvOpConstantTrue : SpvOpConstantFalse, - message_.fresh_id(), context->get_type_mgr()->GetId(&bool_type)); + message_.fresh_id(), ir_context->get_type_mgr()->GetId(&bool_type)); // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddConstantBoolean::ToMessage() const {
diff --git a/source/fuzz/transformation_add_constant_boolean.h b/source/fuzz/transformation_add_constant_boolean.h index 79df1cd..5d876cf 100644 --- a/source/fuzz/transformation_add_constant_boolean.h +++ b/source/fuzz/transformation_add_constant_boolean.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_BOOLEAN_CONSTANT_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_BOOLEAN_CONSTANT_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -32,12 +32,14 @@ // - |message_.fresh_id| must not be used by the module. // - The module must already contain OpTypeBool. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // - Adds OpConstantTrue (OpConstantFalse) to the module with id // |message_.fresh_id| if |message_.is_true| holds (does not hold). - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_constant_composite.cpp b/source/fuzz/transformation_add_constant_composite.cpp index 7ba1ea4..ae34b26 100644 --- a/source/fuzz/transformation_add_constant_composite.cpp +++ b/source/fuzz/transformation_add_constant_composite.cpp
@@ -37,15 +37,14 @@ } bool TransformationAddConstantComposite::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // Check that the given id is fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // Check that the composite type id is an instruction id. auto composite_type_instruction = - context->get_def_use_mgr()->GetDef(message_.type_id()); + ir_context->get_def_use_mgr()->GetDef(message_.type_id()); if (!composite_type_instruction) { return false; } @@ -56,7 +55,7 @@ case SpvOpTypeArray: for (uint32_t index = 0; index < - fuzzerutil::GetArraySize(*composite_type_instruction, context); + fuzzerutil::GetArraySize(*composite_type_instruction, ir_context); index++) { constituent_type_ids.push_back( composite_type_instruction->GetSingleWordInOperand(0)); @@ -93,7 +92,7 @@ // corresponding constituent type. for (uint32_t index = 0; index < constituent_type_ids.size(); index++) { auto constituent_instruction = - context->get_def_use_mgr()->GetDef(message_.constituent_id(index)); + ir_context->get_def_use_mgr()->GetDef(message_.constituent_id(index)); if (!constituent_instruction) { return false; } @@ -105,18 +104,19 @@ } void TransformationAddConstantComposite::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { opt::Instruction::OperandList in_operands; for (auto constituent_id : message_.constituent_id()) { in_operands.push_back({SPV_OPERAND_TYPE_ID, {constituent_id}}); } - context->module()->AddGlobalValue(MakeUnique<opt::Instruction>( - context, SpvOpConstantComposite, message_.type_id(), message_.fresh_id(), - in_operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + ir_context->module()->AddGlobalValue(MakeUnique<opt::Instruction>( + ir_context, SpvOpConstantComposite, message_.type_id(), + message_.fresh_id(), in_operands)); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddConstantComposite::ToMessage()
diff --git a/source/fuzz/transformation_add_constant_composite.h b/source/fuzz/transformation_add_constant_composite.h index 9a824a0..4fec561 100644 --- a/source/fuzz/transformation_add_constant_composite.h +++ b/source/fuzz/transformation_add_constant_composite.h
@@ -17,9 +17,9 @@ #include <vector> -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -38,13 +38,15 @@ // - |message_.type_id| must be the id of a composite type // - |message_.constituent_id| must refer to ids that match the constituent // types of this composite type - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an OpConstantComposite instruction defining a constant of type // |message_.type_id|, using |message_.constituent_id| as constituents, with // result id |message_.fresh_id|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_constant_null.cpp b/source/fuzz/transformation_add_constant_null.cpp new file mode 100644 index 0000000..dedbc21 --- /dev/null +++ b/source/fuzz/transformation_add_constant_null.cpp
@@ -0,0 +1,66 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/transformation_add_constant_null.h" + +#include "source/fuzz/fuzzer_util.h" + +namespace spvtools { +namespace fuzz { + +TransformationAddConstantNull::TransformationAddConstantNull( + const spvtools::fuzz::protobufs::TransformationAddConstantNull& message) + : message_(message) {} + +TransformationAddConstantNull::TransformationAddConstantNull(uint32_t fresh_id, + uint32_t type_id) { + message_.set_fresh_id(fresh_id); + message_.set_type_id(type_id); +} + +bool TransformationAddConstantNull::IsApplicable( + opt::IRContext* context, const TransformationContext& /*unused*/) const { + // A fresh id is required. + if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + return false; + } + auto type = context->get_type_mgr()->GetType(message_.type_id()); + // The type must exist. + if (!type) { + return false; + } + // The type must be one of the types for which null constants are allowed, + // according to the SPIR-V spec. + return fuzzerutil::IsNullConstantSupported(*type); +} + +void TransformationAddConstantNull::Apply( + opt::IRContext* context, TransformationContext* /*unused*/) const { + context->module()->AddGlobalValue(MakeUnique<opt::Instruction>( + context, SpvOpConstantNull, message_.type_id(), message_.fresh_id(), + opt::Instruction::OperandList())); + fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + // We have added an instruction to the module, so need to be careful about the + // validity of existing analyses. + context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); +} + +protobufs::Transformation TransformationAddConstantNull::ToMessage() const { + protobufs::Transformation result; + *result.mutable_add_constant_null() = message_; + return result; +} + +} // namespace fuzz +} // namespace spvtools
diff --git a/source/fuzz/transformation_add_constant_null.h b/source/fuzz/transformation_add_constant_null.h new file mode 100644 index 0000000..590fc0d --- /dev/null +++ b/source/fuzz/transformation_add_constant_null.h
@@ -0,0 +1,54 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_CONSTANT_NULL_H_ +#define SOURCE_FUZZ_TRANSFORMATION_ADD_CONSTANT_NULL_H_ + +#include "source/fuzz/protobufs/spirvfuzz_protobufs.h" +#include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace fuzz { + +class TransformationAddConstantNull : public Transformation { + public: + explicit TransformationAddConstantNull( + const protobufs::TransformationAddConstantNull& message); + + TransformationAddConstantNull(uint32_t fresh_id, uint32_t type_id); + + // - |message_.fresh_id| must be fresh + // - |message_.type_id| must be the id of a type for which it is acceptable + // to create a null constant + bool IsApplicable( + opt::IRContext* context, + const TransformationContext& transformation_context) const override; + + // Adds an OpConstantNull instruction to the module, with |message_.type_id| + // as its type. The instruction has result id |message_.fresh_id|. + void Apply(opt::IRContext* context, + TransformationContext* transformation_context) const override; + + protobufs::Transformation ToMessage() const override; + + private: + protobufs::TransformationAddConstantNull message_; +}; + +} // namespace fuzz +} // namespace spvtools + +#endif // SOURCE_FUZZ_TRANSFORMATION_ADD_CONSTANT_NULL_H_
diff --git a/source/fuzz/transformation_add_constant_scalar.cpp b/source/fuzz/transformation_add_constant_scalar.cpp index 36af5e0..e13d08f 100644 --- a/source/fuzz/transformation_add_constant_scalar.cpp +++ b/source/fuzz/transformation_add_constant_scalar.cpp
@@ -33,14 +33,13 @@ } bool TransformationAddConstantScalar::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The id needs to be fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // The type id for the scalar must exist and be a type. - auto type = context->get_type_mgr()->GetType(message_.type_id()); + auto type = ir_context->get_type_mgr()->GetType(message_.type_id()); if (!type) { return false; } @@ -61,20 +60,21 @@ } void TransformationAddConstantScalar::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { opt::Instruction::OperandList operand_list; for (auto word : message_.word()) { operand_list.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {word}}); } - context->module()->AddGlobalValue( - MakeUnique<opt::Instruction>(context, SpvOpConstant, message_.type_id(), - message_.fresh_id(), operand_list)); + ir_context->module()->AddGlobalValue(MakeUnique<opt::Instruction>( + ir_context, SpvOpConstant, message_.type_id(), message_.fresh_id(), + operand_list)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddConstantScalar::ToMessage() const {
diff --git a/source/fuzz/transformation_add_constant_scalar.h b/source/fuzz/transformation_add_constant_scalar.h index 914cfe6..e0ed39f 100644 --- a/source/fuzz/transformation_add_constant_scalar.h +++ b/source/fuzz/transformation_add_constant_scalar.h
@@ -17,9 +17,9 @@ #include <vector> -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -37,11 +37,13 @@ // - |message_.type_id| must be the id of a floating-point or integer type // - The size of |message_.word| must be compatible with the width of this // type - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds a new OpConstant instruction with the given type and words. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_dead_block.cpp b/source/fuzz/transformation_add_dead_block.cpp index b58f75e..b246c3f 100644 --- a/source/fuzz/transformation_add_dead_block.cpp +++ b/source/fuzz/transformation_add_dead_block.cpp
@@ -32,16 +32,15 @@ } bool TransformationAddDeadBlock::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The new block's id must be fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // First, we check that a constant with the same value as // |message_.condition_value| is present. - if (!fuzzerutil::MaybeGetBoolConstantId(context, + if (!fuzzerutil::MaybeGetBoolConstantId(ir_context, message_.condition_value())) { // The required constant is not present, so the transformation cannot be // applied. @@ -50,7 +49,7 @@ // The existing block must indeed exist. auto existing_block = - fuzzerutil::MaybeFindBlock(context, message_.existing_block()); + fuzzerutil::MaybeFindBlock(ir_context, message_.existing_block()); if (!existing_block) { return false; } @@ -68,13 +67,13 @@ // Its successor must not be a merge block nor continue target. auto successor_block_id = existing_block->terminator()->GetSingleWordInOperand(0); - if (fuzzerutil::IsMergeOrContinue(context, successor_block_id)) { + if (fuzzerutil::IsMergeOrContinue(ir_context, successor_block_id)) { return false; } // The successor must not be a loop header (i.e., |message_.existing_block| // must not be a back-edge block. - if (context->cfg()->block(successor_block_id)->IsLoopHeader()) { + if (ir_context->cfg()->block(successor_block_id)->IsLoopHeader()) { return false; } @@ -82,34 +81,36 @@ } void TransformationAddDeadBlock::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const { + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { // Update the module id bound so that it is at least the id of the new block. - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // Get the existing block and its successor. - auto existing_block = context->cfg()->block(message_.existing_block()); + auto existing_block = ir_context->cfg()->block(message_.existing_block()); auto successor_block_id = existing_block->terminator()->GetSingleWordInOperand(0); // Get the id of the boolean value that will be used as the branch condition. - auto bool_id = - fuzzerutil::MaybeGetBoolConstantId(context, message_.condition_value()); + auto bool_id = fuzzerutil::MaybeGetBoolConstantId(ir_context, + message_.condition_value()); // Make a new block that unconditionally branches to the original successor // block. auto enclosing_function = existing_block->GetParent(); - std::unique_ptr<opt::BasicBlock> new_block = MakeUnique<opt::BasicBlock>( - MakeUnique<opt::Instruction>(context, SpvOpLabel, 0, message_.fresh_id(), - opt::Instruction::OperandList())); + std::unique_ptr<opt::BasicBlock> new_block = + MakeUnique<opt::BasicBlock>(MakeUnique<opt::Instruction>( + ir_context, SpvOpLabel, 0, message_.fresh_id(), + opt::Instruction::OperandList())); new_block->AddInstruction(MakeUnique<opt::Instruction>( - context, SpvOpBranch, 0, 0, + ir_context, SpvOpBranch, 0, 0, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {successor_block_id}}}))); // Turn the original block into a selection merge, with its original successor // as the merge block. existing_block->terminator()->InsertBefore(MakeUnique<opt::Instruction>( - context, SpvOpSelectionMerge, 0, 0, + ir_context, SpvOpSelectionMerge, 0, 0, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {successor_block_id}}, {SPV_OPERAND_TYPE_SELECTION_CONTROL, @@ -135,7 +136,8 @@ existing_block); // Record the fact that the new block is dead. - fact_manager->AddFactBlockIsDead(message_.fresh_id()); + transformation_context->GetFactManager()->AddFactBlockIsDead( + message_.fresh_id()); // Fix up OpPhi instructions in the successor block, so that the values they // yield when control has transferred from the new block are the same as if @@ -143,7 +145,7 @@ // to be valid since |message_.existing_block| dominates the new block by // construction. Other transformations can change these phi operands to more // interesting values. - context->cfg() + ir_context->cfg() ->block(successor_block_id) ->ForEachPhiInst([this](opt::Instruction* phi_inst) { // Copy the operand that provides the phi value for the first of any @@ -156,7 +158,7 @@ // Do not rely on any existing analysis results since the control flow graph // of the module has changed. - context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); } protobufs::Transformation TransformationAddDeadBlock::ToMessage() const {
diff --git a/source/fuzz/transformation_add_dead_block.h b/source/fuzz/transformation_add_dead_block.h index 059daca9..7d07616 100644 --- a/source/fuzz/transformation_add_dead_block.h +++ b/source/fuzz/transformation_add_dead_block.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_DEAD_BLOCK_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_DEAD_BLOCK_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -41,15 +41,17 @@ // - |message_.existing_block| must not be a back-edge block, since in this // case the newly-added block would lead to another back-edge to the // associated loop header - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Changes the OpBranch from |message_.existing_block| to its successor 's' // to an OpBranchConditional to either 's' or a new block, // |message_.fresh_id|, which itself unconditionally branches to 's'. The // conditional branch uses |message.condition_value| as its condition, and is // arranged so that control will pass to 's' at runtime. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_dead_break.cpp b/source/fuzz/transformation_add_dead_break.cpp index 43847fa..db9de7d 100644 --- a/source/fuzz/transformation_add_dead_break.cpp +++ b/source/fuzz/transformation_add_dead_break.cpp
@@ -14,8 +14,8 @@ #include "source/fuzz/transformation_add_dead_break.h" -#include "source/fuzz/fact_manager.h" #include "source/fuzz/fuzzer_util.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/basic_block.h" #include "source/opt/ir_context.h" #include "source/opt/struct_cfg_analysis.h" @@ -39,7 +39,7 @@ } bool TransformationAddDeadBreak::AddingBreakRespectsStructuredControlFlow( - opt::IRContext* context, opt::BasicBlock* bb_from) const { + opt::IRContext* ir_context, opt::BasicBlock* bb_from) const { // Look at the structured control flow associated with |from_block| and // check whether it is contained in an appropriate construct with merge id // |to_block| such that a break from |from_block| to |to_block| is legal. @@ -70,7 +70,7 @@ // structured control flow construct. auto containing_construct = - context->GetStructuredCFGAnalysis()->ContainingConstruct( + ir_context->GetStructuredCFGAnalysis()->ContainingConstruct( message_.from_block()); if (!containing_construct) { // |from_block| is not in a construct from which we can break. @@ -79,7 +79,7 @@ // Consider case (2) if (message_.to_block() == - context->cfg()->block(containing_construct)->MergeBlockId()) { + ir_context->cfg()->block(containing_construct)->MergeBlockId()) { // This looks like an instance of case (2). // However, the structured CFG analysis regards the continue construct of a // loop as part of the loop, but it is not legal to jump from a loop's @@ -90,28 +90,29 @@ // currently allow a dead break from a back edge block, but we could and // ultimately should. return !fuzzerutil::BlockIsInLoopContinueConstruct( - context, message_.from_block(), containing_construct); + ir_context, message_.from_block(), containing_construct); } // Case (3) holds if and only if |to_block| is the merge block for this // innermost loop that contains |from_block| auto containing_loop_header = - context->GetStructuredCFGAnalysis()->ContainingLoop( + ir_context->GetStructuredCFGAnalysis()->ContainingLoop( message_.from_block()); if (containing_loop_header && message_.to_block() == - context->cfg()->block(containing_loop_header)->MergeBlockId()) { + ir_context->cfg()->block(containing_loop_header)->MergeBlockId()) { return !fuzzerutil::BlockIsInLoopContinueConstruct( - context, message_.from_block(), containing_loop_header); + ir_context, message_.from_block(), containing_loop_header); } return false; } bool TransformationAddDeadBreak::IsApplicable( - opt::IRContext* context, const FactManager& /*unused*/) const { + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const { // First, we check that a constant with the same value as // |message_.break_condition_value| is present. - if (!fuzzerutil::MaybeGetBoolConstantId(context, + if (!fuzzerutil::MaybeGetBoolConstantId(ir_context, message_.break_condition_value())) { // The required constant is not present, so the transformation cannot be // applied. @@ -121,17 +122,17 @@ // Check that |message_.from_block| and |message_.to_block| really are block // ids opt::BasicBlock* bb_from = - fuzzerutil::MaybeFindBlock(context, message_.from_block()); + fuzzerutil::MaybeFindBlock(ir_context, message_.from_block()); if (bb_from == nullptr) { return false; } opt::BasicBlock* bb_to = - fuzzerutil::MaybeFindBlock(context, message_.to_block()); + fuzzerutil::MaybeFindBlock(ir_context, message_.to_block()); if (bb_to == nullptr) { return false; } - if (!fuzzerutil::BlockIsReachableInItsFunction(context, bb_to)) { + if (!fuzzerutil::BlockIsReachableInItsFunction(ir_context, bb_to)) { // If the target of the break is unreachable, we conservatively do not // allow adding a dead break, to avoid the compilations that arise due to // the lack of sensible dominance information for unreachable blocks. @@ -157,14 +158,14 @@ "The id of the block we found should match the target id for the break."); // Check whether the data passed to extend OpPhi instructions is appropriate. - if (!fuzzerutil::PhiIdsOkForNewEdge(context, bb_from, bb_to, + if (!fuzzerutil::PhiIdsOkForNewEdge(ir_context, bb_from, bb_to, message_.phi_id())) { return false; } // Check that adding the break would respect the rules of structured // control flow. - if (!AddingBreakRespectsStructuredControlFlow(context, bb_from)) { + if (!AddingBreakRespectsStructuredControlFlow(ir_context, bb_from)) { return false; } @@ -177,16 +178,18 @@ // being places on the validator. This should be revisited if we are sure // the validator is complete with respect to checking structured control flow // rules. - auto cloned_context = fuzzerutil::CloneIRContext(context); + auto cloned_context = fuzzerutil::CloneIRContext(ir_context); ApplyImpl(cloned_context.get()); - return fuzzerutil::IsValid(cloned_context.get()); + return fuzzerutil::IsValid(cloned_context.get(), + transformation_context.GetValidatorOptions()); } -void TransformationAddDeadBreak::Apply(opt::IRContext* context, - FactManager* /*unused*/) const { - ApplyImpl(context); +void TransformationAddDeadBreak::Apply( + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { + ApplyImpl(ir_context); // Invalidate all analyses - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddDeadBreak::ToMessage() const { @@ -196,10 +199,10 @@ } void TransformationAddDeadBreak::ApplyImpl( - spvtools::opt::IRContext* context) const { + spvtools::opt::IRContext* ir_context) const { fuzzerutil::AddUnreachableEdgeAndUpdateOpPhis( - context, context->cfg()->block(message_.from_block()), - context->cfg()->block(message_.to_block()), + ir_context, ir_context->cfg()->block(message_.from_block()), + ir_context->cfg()->block(message_.to_block()), message_.break_condition_value(), message_.phi_id()); }
diff --git a/source/fuzz/transformation_add_dead_break.h b/source/fuzz/transformation_add_dead_break.h index 81a2c99..0ea9210 100644 --- a/source/fuzz/transformation_add_dead_break.h +++ b/source/fuzz/transformation_add_dead_break.h
@@ -17,9 +17,9 @@ #include <vector> -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -50,21 +50,23 @@ // maintain validity of the module. // In particular, the new branch must not lead to violations of the rule // that a use must be dominated by its definition. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Replaces the terminator of a with a conditional branch to b or c. // The boolean constant associated with |message_.break_condition_value| is // used as the condition, and the order of b and c is arranged such that // control is guaranteed to jump to c. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; private: // Returns true if and only if adding an edge from |bb_from| to // |message_.to_block| respects structured control flow. - bool AddingBreakRespectsStructuredControlFlow(opt::IRContext* context, + bool AddingBreakRespectsStructuredControlFlow(opt::IRContext* ir_context, opt::BasicBlock* bb_from) const; // Used by 'Apply' to actually apply the transformation to the module of @@ -73,7 +75,7 @@ // module. This is only invoked by 'IsApplicable' after certain basic // applicability checks have been made, ensuring that the invocation of this // method is legal. - void ApplyImpl(opt::IRContext* context) const; + void ApplyImpl(opt::IRContext* ir_context) const; protobufs::TransformationAddDeadBreak message_; };
diff --git a/source/fuzz/transformation_add_dead_continue.cpp b/source/fuzz/transformation_add_dead_continue.cpp index 3a4875e..1fc6d67 100644 --- a/source/fuzz/transformation_add_dead_continue.cpp +++ b/source/fuzz/transformation_add_dead_continue.cpp
@@ -34,11 +34,12 @@ } bool TransformationAddDeadContinue::IsApplicable( - opt::IRContext* context, const FactManager& /*unused*/) const { + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const { // First, we check that a constant with the same value as // |message_.continue_condition_value| is present. if (!fuzzerutil::MaybeGetBoolConstantId( - context, message_.continue_condition_value())) { + ir_context, message_.continue_condition_value())) { // The required constant is not present, so the transformation cannot be // applied. return false; @@ -46,7 +47,7 @@ // Check that |message_.from_block| really is a block id. opt::BasicBlock* bb_from = - fuzzerutil::MaybeFindBlock(context, message_.from_block()); + fuzzerutil::MaybeFindBlock(ir_context, message_.from_block()); if (bb_from == nullptr) { return false; } @@ -68,31 +69,33 @@ // Because the structured CFG analysis does not regard a loop header as part // of the loop it heads, we check first whether bb_from is a loop header // before using the structured CFG analysis. - auto loop_header = bb_from->IsLoopHeader() - ? message_.from_block() - : context->GetStructuredCFGAnalysis()->ContainingLoop( - message_.from_block()); + auto loop_header = + bb_from->IsLoopHeader() + ? message_.from_block() + : ir_context->GetStructuredCFGAnalysis()->ContainingLoop( + message_.from_block()); if (!loop_header) { return false; } - auto continue_block = context->cfg()->block(loop_header)->ContinueBlockId(); + auto continue_block = + ir_context->cfg()->block(loop_header)->ContinueBlockId(); if (!fuzzerutil::BlockIsReachableInItsFunction( - context, context->cfg()->block(continue_block))) { + ir_context, ir_context->cfg()->block(continue_block))) { // If the loop's continue block is unreachable, we conservatively do not // allow adding a dead continue, to avoid the compilations that arise due to // the lack of sensible dominance information for unreachable blocks. return false; } - if (fuzzerutil::BlockIsInLoopContinueConstruct(context, message_.from_block(), - loop_header)) { + if (fuzzerutil::BlockIsInLoopContinueConstruct( + ir_context, message_.from_block(), loop_header)) { // We cannot jump to the continue target from the continue construct. return false; } - if (context->GetStructuredCFGAnalysis()->IsMergeBlock(continue_block)) { + if (ir_context->GetStructuredCFGAnalysis()->IsMergeBlock(continue_block)) { // A branch straight to the continue target that is also a merge block might // break the property that a construct header must dominate its merge block // (if the merge block is reachable). @@ -100,8 +103,8 @@ } // Check whether the data passed to extend OpPhi instructions is appropriate. - if (!fuzzerutil::PhiIdsOkForNewEdge(context, bb_from, - context->cfg()->block(continue_block), + if (!fuzzerutil::PhiIdsOkForNewEdge(ir_context, bb_from, + ir_context->cfg()->block(continue_block), message_.phi_id())) { return false; } @@ -115,16 +118,18 @@ // being placed on the validator. This should be revisited if we are sure // the validator is complete with respect to checking structured control flow // rules. - auto cloned_context = fuzzerutil::CloneIRContext(context); + auto cloned_context = fuzzerutil::CloneIRContext(ir_context); ApplyImpl(cloned_context.get()); - return fuzzerutil::IsValid(cloned_context.get()); + return fuzzerutil::IsValid(cloned_context.get(), + transformation_context.GetValidatorOptions()); } -void TransformationAddDeadContinue::Apply(opt::IRContext* context, - FactManager* /*unused*/) const { - ApplyImpl(context); +void TransformationAddDeadContinue::Apply( + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { + ApplyImpl(ir_context); // Invalidate all analyses - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddDeadContinue::ToMessage() const { @@ -134,16 +139,16 @@ } void TransformationAddDeadContinue::ApplyImpl( - spvtools::opt::IRContext* context) const { - auto bb_from = context->cfg()->block(message_.from_block()); + spvtools::opt::IRContext* ir_context) const { + auto bb_from = ir_context->cfg()->block(message_.from_block()); auto continue_block = bb_from->IsLoopHeader() ? bb_from->ContinueBlockId() - : context->GetStructuredCFGAnalysis()->LoopContinueBlock( + : ir_context->GetStructuredCFGAnalysis()->LoopContinueBlock( message_.from_block()); assert(continue_block && "message_.from_block must be in a loop."); fuzzerutil::AddUnreachableEdgeAndUpdateOpPhis( - context, bb_from, context->cfg()->block(continue_block), + ir_context, bb_from, ir_context->cfg()->block(continue_block), message_.continue_condition_value(), message_.phi_id()); }
diff --git a/source/fuzz/transformation_add_dead_continue.h b/source/fuzz/transformation_add_dead_continue.h index 86b4c93..1053c16 100644 --- a/source/fuzz/transformation_add_dead_continue.h +++ b/source/fuzz/transformation_add_dead_continue.h
@@ -17,9 +17,9 @@ #include <vector> -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -52,14 +52,16 @@ // In particular, adding an edge from somewhere in the loop to the continue // target must not prevent uses of ids in the continue target from being // dominated by the definitions of those ids. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Replaces the terminator of a with a conditional branch to b or c. // The boolean constant associated with |message_.continue_condition_value| is // used as the condition, and the order of b and c is arranged such that // control is guaranteed to jump to c. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; @@ -70,7 +72,7 @@ // module. This is only invoked by 'IsApplicable' after certain basic // applicability checks have been made, ensuring that the invocation of this // method is legal. - void ApplyImpl(opt::IRContext* context) const; + void ApplyImpl(opt::IRContext* ir_context) const; protobufs::TransformationAddDeadContinue message_; };
diff --git a/source/fuzz/transformation_add_function.cpp b/source/fuzz/transformation_add_function.cpp index 8f0d3c9..90276ed 100644 --- a/source/fuzz/transformation_add_function.cpp +++ b/source/fuzz/transformation_add_function.cpp
@@ -56,8 +56,8 @@ } bool TransformationAddFunction::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& fact_manager) const { + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const { // This transformation may use a lot of ids, all of which need to be fresh // and distinct. This set tracks them. std::set<uint32_t> ids_used_by_this_transformation; @@ -66,7 +66,7 @@ for (auto& instruction : message_.instruction()) { if (instruction.result_id()) { if (!CheckIdIsFreshAndNotUsedByThisTransformation( - instruction.result_id(), context, + instruction.result_id(), ir_context, &ids_used_by_this_transformation)) { return false; } @@ -77,28 +77,28 @@ // Ensure that all ids provided for making the function livesafe are fresh // and distinct. if (!CheckIdIsFreshAndNotUsedByThisTransformation( - message_.loop_limiter_variable_id(), context, + message_.loop_limiter_variable_id(), ir_context, &ids_used_by_this_transformation)) { return false; } for (auto& loop_limiter_info : message_.loop_limiter_info()) { if (!CheckIdIsFreshAndNotUsedByThisTransformation( - loop_limiter_info.load_id(), context, + loop_limiter_info.load_id(), ir_context, &ids_used_by_this_transformation)) { return false; } if (!CheckIdIsFreshAndNotUsedByThisTransformation( - loop_limiter_info.increment_id(), context, + loop_limiter_info.increment_id(), ir_context, &ids_used_by_this_transformation)) { return false; } if (!CheckIdIsFreshAndNotUsedByThisTransformation( - loop_limiter_info.compare_id(), context, + loop_limiter_info.compare_id(), ir_context, &ids_used_by_this_transformation)) { return false; } if (!CheckIdIsFreshAndNotUsedByThisTransformation( - loop_limiter_info.logical_op_id(), context, + loop_limiter_info.logical_op_id(), ir_context, &ids_used_by_this_transformation)) { return false; } @@ -107,11 +107,11 @@ message_.access_chain_clamping_info()) { for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) { if (!CheckIdIsFreshAndNotUsedByThisTransformation( - pair.first(), context, &ids_used_by_this_transformation)) { + pair.first(), ir_context, &ids_used_by_this_transformation)) { return false; } if (!CheckIdIsFreshAndNotUsedByThisTransformation( - pair.second(), context, &ids_used_by_this_transformation)) { + pair.second(), ir_context, &ids_used_by_this_transformation)) { return false; } } @@ -123,8 +123,8 @@ // is taken here. // We first clone the current module, so that we can try adding the new - // function without risking wrecking |context|. - auto cloned_module = fuzzerutil::CloneIRContext(context); + // function without risking wrecking |ir_context|. + auto cloned_module = fuzzerutil::CloneIRContext(ir_context); // We try to add a function to the cloned module, which may fail if // |message_.instruction| is not sufficiently well-formed. @@ -134,12 +134,14 @@ // Check whether the cloned module is still valid after adding the function. // If it is not, the transformation is not applicable. - if (!fuzzerutil::IsValid(cloned_module.get())) { + if (!fuzzerutil::IsValid(cloned_module.get(), + transformation_context.GetValidatorOptions())) { return false; } if (message_.is_livesafe()) { - if (!TryToMakeFunctionLivesafe(cloned_module.get(), fact_manager)) { + if (!TryToMakeFunctionLivesafe(cloned_module.get(), + transformation_context)) { return false; } // After making the function livesafe, we check validity of the module @@ -148,7 +150,8 @@ // has the potential to make the module invalid when it was otherwise valid. // It is simpler to rely on the validator to guard against this than to // consider all scenarios when making a function livesafe. - if (!fuzzerutil::IsValid(cloned_module.get())) { + if (!fuzzerutil::IsValid(cloned_module.get(), + transformation_context.GetValidatorOptions())) { return false; } } @@ -156,10 +159,11 @@ } void TransformationAddFunction::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const { + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { // Add the function to the module. As the transformation is applicable, this // should succeed. - bool success = TryToAddFunction(context); + bool success = TryToAddFunction(ir_context); assert(success && "The function should be successfully added."); (void)(success); // Keep release builds happy (otherwise they may complain // that |success| is not used). @@ -172,16 +176,16 @@ for (auto& instruction : message_.instruction()) { switch (instruction.opcode()) { case SpvOpFunctionParameter: - if (context->get_def_use_mgr() + if (ir_context->get_def_use_mgr() ->GetDef(instruction.result_type_id()) ->opcode() == SpvOpTypePointer) { - fact_manager->AddFactValueOfPointeeIsIrrelevant( - instruction.result_id()); + transformation_context->GetFactManager() + ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id()); } break; case SpvOpVariable: - fact_manager->AddFactValueOfPointeeIsIrrelevant( - instruction.result_id()); + transformation_context->GetFactManager() + ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id()); break; default: break; @@ -190,7 +194,7 @@ if (message_.is_livesafe()) { // Make the function livesafe, which also should succeed. - success = TryToMakeFunctionLivesafe(context, *fact_manager); + success = TryToMakeFunctionLivesafe(ir_context, *transformation_context); assert(success && "It should be possible to make the function livesafe."); (void)(success); // Keep release builds happy. @@ -198,17 +202,18 @@ assert(message_.instruction(0).opcode() == SpvOpFunction && "The first instruction of an 'add function' transformation must be " "OpFunction."); - fact_manager->AddFactFunctionIsLivesafe( + transformation_context->GetFactManager()->AddFactFunctionIsLivesafe( message_.instruction(0).result_id()); } else { // Inform the fact manager that all blocks in the function are dead. for (auto& inst : message_.instruction()) { if (inst.opcode() == SpvOpLabel) { - fact_manager->AddFactBlockIsDead(inst.result_id()); + transformation_context->GetFactManager()->AddFactBlockIsDead( + inst.result_id()); } } } - context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); } protobufs::Transformation TransformationAddFunction::ToMessage() const { @@ -218,9 +223,9 @@ } bool TransformationAddFunction::TryToAddFunction( - opt::IRContext* context) const { + opt::IRContext* ir_context) const { // This function returns false if |message_.instruction| was not well-formed - // enough to actually create a function and add it to |context|. + // enough to actually create a function and add it to |ir_context|. // A function must have at least some instructions. if (message_.instruction().empty()) { @@ -235,7 +240,7 @@ // Make a function, headed by the OpFunction instruction. std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>( - InstructionFromMessage(context, function_begin)); + InstructionFromMessage(ir_context, function_begin)); // Keeps track of which instruction protobuf message we are currently // considering. @@ -249,7 +254,7 @@ message_.instruction(instruction_index).opcode() == SpvOpFunctionParameter) { new_function->AddParameter(InstructionFromMessage( - context, message_.instruction(instruction_index))); + ir_context, message_.instruction(instruction_index))); instruction_index++; } @@ -270,7 +275,7 @@ // as its parent. std::unique_ptr<opt::BasicBlock> block = MakeUnique<opt::BasicBlock>(InstructionFromMessage( - context, message_.instruction(instruction_index))); + ir_context, message_.instruction(instruction_index))); block->SetParent(new_function.get()); // Consider successive instructions until we hit another label or the end @@ -281,7 +286,7 @@ SpvOpFunctionEnd && message_.instruction(instruction_index).opcode() != SpvOpLabel) { block->AddInstruction(InstructionFromMessage( - context, message_.instruction(instruction_index))); + ir_context, message_.instruction(instruction_index))); instruction_index++; } // Add the block to the new function. @@ -295,22 +300,23 @@ } // Set the function's final instruction, add the function to the module and // report success. - new_function->SetFunctionEnd( - InstructionFromMessage(context, message_.instruction(instruction_index))); - context->AddFunction(std::move(new_function)); + new_function->SetFunctionEnd(InstructionFromMessage( + ir_context, message_.instruction(instruction_index))); + ir_context->AddFunction(std::move(new_function)); - context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); return true; } bool TransformationAddFunction::TryToMakeFunctionLivesafe( - opt::IRContext* context, const FactManager& fact_manager) const { + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const { assert(message_.is_livesafe() && "Precondition: is_livesafe must hold."); // Get a pointer to the added function. opt::Function* added_function = nullptr; - for (auto& function : *context->module()) { + for (auto& function : *ir_context->module()) { if (function.result_id() == message_.instruction(0).result_id()) { added_function = &function; break; @@ -318,7 +324,7 @@ } assert(added_function && "The added function should have been found."); - if (!TryToAddLoopLimiters(context, added_function)) { + if (!TryToAddLoopLimiters(ir_context, added_function)) { // Adding loop limiters did not work; bail out. return false; } @@ -332,20 +338,20 @@ switch (inst.opcode()) { case SpvOpKill: case SpvOpUnreachable: - if (!TryToTurnKillOrUnreachableIntoReturn(context, added_function, + if (!TryToTurnKillOrUnreachableIntoReturn(ir_context, added_function, &inst)) { return false; } break; case SpvOpAccessChain: case SpvOpInBoundsAccessChain: - if (!TryToClampAccessChainIndices(context, &inst)) { + if (!TryToClampAccessChainIndices(ir_context, &inst)) { return false; } break; case SpvOpFunctionCall: // A livesafe function my only call other livesafe functions. - if (!fact_manager.FunctionIsLivesafe( + if (!transformation_context.GetFactManager()->FunctionIsLivesafe( inst.GetSingleWordInOperand(0))) { return false; } @@ -358,7 +364,7 @@ } bool TransformationAddFunction::TryToAddLoopLimiters( - opt::IRContext* context, opt::Function* added_function) const { + opt::IRContext* ir_context, opt::Function* added_function) const { // Collect up all the loop headers so that we can subsequently add loop // limiting logic. std::vector<opt::BasicBlock*> loop_headers; @@ -377,7 +383,7 @@ // manipulating a loop limiter. auto loop_limit_constant_id_instr = - context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id()); + ir_context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id()); if (!loop_limit_constant_id_instr || loop_limit_constant_id_instr->opcode() != SpvOpConstant) { // The loop limit constant id instruction must exist and have an @@ -385,7 +391,7 @@ return false; } - auto loop_limit_type = context->get_def_use_mgr()->GetDef( + auto loop_limit_type = ir_context->get_def_use_mgr()->GetDef( loop_limit_constant_id_instr->type_id()); if (loop_limit_type->opcode() != SpvOpTypeInt || loop_limit_type->GetSingleWordInOperand(0) != 32) { @@ -397,36 +403,36 @@ // Find the id of the "unsigned int" type. opt::analysis::Integer unsigned_int_type(32, false); uint32_t unsigned_int_type_id = - context->get_type_mgr()->GetId(&unsigned_int_type); + ir_context->get_type_mgr()->GetId(&unsigned_int_type); if (!unsigned_int_type_id) { // Unsigned int is not available; we need this type in order to add loop // limiters. return false; } auto registered_unsigned_int_type = - context->get_type_mgr()->GetRegisteredType(&unsigned_int_type); + ir_context->get_type_mgr()->GetRegisteredType(&unsigned_int_type); // Look for 0 of type unsigned int. opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(), {0}); - auto registered_zero = context->get_constant_mgr()->FindConstant(&zero); + auto registered_zero = ir_context->get_constant_mgr()->FindConstant(&zero); if (!registered_zero) { // We need 0 in order to be able to initialize loop limiters. return false; } - uint32_t zero_id = context->get_constant_mgr() + uint32_t zero_id = ir_context->get_constant_mgr() ->GetDefiningInstruction(registered_zero) ->result_id(); // Look for 1 of type unsigned int. opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(), {1}); - auto registered_one = context->get_constant_mgr()->FindConstant(&one); + auto registered_one = ir_context->get_constant_mgr()->FindConstant(&one); if (!registered_one) { // We need 1 in order to be able to increment loop limiters. return false; } - uint32_t one_id = context->get_constant_mgr() + uint32_t one_id = ir_context->get_constant_mgr() ->GetDefiningInstruction(registered_one) ->result_id(); @@ -434,7 +440,7 @@ opt::analysis::Pointer pointer_to_unsigned_int_type( registered_unsigned_int_type, SpvStorageClassFunction); uint32_t pointer_to_unsigned_int_type_id = - context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type); + ir_context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type); if (!pointer_to_unsigned_int_type_id) { // We need pointer-to-unsigned int in order to declare the loop limiter // variable. @@ -443,7 +449,7 @@ // Look for bool type. opt::analysis::Bool bool_type; - uint32_t bool_type_id = context->get_type_mgr()->GetId(&bool_type); + uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type); if (!bool_type_id) { // We need bool in order to compare the loop limiter's value with the loop // limit constant. @@ -454,22 +460,23 @@ // block, via an instruction of the form: // %loop_limiter_var = SpvOpVariable %ptr_to_uint Function %zero added_function->begin()->begin()->InsertBefore(MakeUnique<opt::Instruction>( - context, SpvOpVariable, pointer_to_unsigned_int_type_id, + ir_context, SpvOpVariable, pointer_to_unsigned_int_type_id, message_.loop_limiter_variable_id(), opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}, {SPV_OPERAND_TYPE_ID, {zero_id}}}))); // Update the module's id bound since we have added the loop limiter // variable id. - fuzzerutil::UpdateModuleIdBound(context, message_.loop_limiter_variable_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, + message_.loop_limiter_variable_id()); // Consider each loop in turn. for (auto loop_header : loop_headers) { // Look for the loop's back-edge block. This is a predecessor of the loop // header that is dominated by the loop header. uint32_t back_edge_block_id = 0; - for (auto pred : context->cfg()->preds(loop_header->id())) { - if (context->GetDominatorAnalysis(added_function) + for (auto pred : ir_context->cfg()->preds(loop_header->id())) { + if (ir_context->GetDominatorAnalysis(added_function) ->Dominates(loop_header->id(), pred)) { back_edge_block_id = pred; break; @@ -481,7 +488,7 @@ // move on from this loop. continue; } - auto back_edge_block = context->cfg()->block(back_edge_block_id); + auto back_edge_block = ir_context->cfg()->block(back_edge_block_id); // Go through the sequence of loop limiter infos and find the one // corresponding to this loop. @@ -579,14 +586,15 @@ // Add a load from the loop limiter variable, of the form: // %t1 = OpLoad %uint32 %loop_limiter new_instructions.push_back(MakeUnique<opt::Instruction>( - context, SpvOpLoad, unsigned_int_type_id, loop_limiter_info.load_id(), + ir_context, SpvOpLoad, unsigned_int_type_id, + loop_limiter_info.load_id(), opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}}))); // Increment the loaded value: // %t2 = OpIAdd %uint32 %t1 %one new_instructions.push_back(MakeUnique<opt::Instruction>( - context, SpvOpIAdd, unsigned_int_type_id, + ir_context, SpvOpIAdd, unsigned_int_type_id, loop_limiter_info.increment_id(), opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}}, @@ -595,7 +603,7 @@ // Store the incremented value back to the loop limiter variable: // OpStore %loop_limiter %t2 new_instructions.push_back(MakeUnique<opt::Instruction>( - context, SpvOpStore, 0, 0, + ir_context, SpvOpStore, 0, 0, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}, {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}}))); @@ -605,7 +613,7 @@ // or // %t3 = OpULessThan %bool %t1 %loop_limit new_instructions.push_back(MakeUnique<opt::Instruction>( - context, + ir_context, compare_using_greater_than_equal ? SpvOpUGreaterThanEqual : SpvOpULessThan, bool_type_id, loop_limiter_info.compare_id(), @@ -615,7 +623,7 @@ if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) { new_instructions.push_back(MakeUnique<opt::Instruction>( - context, + ir_context, compare_using_greater_than_equal ? SpvOpLogicalOr : SpvOpLogicalAnd, bool_type_id, loop_limiter_info.logical_op_id(), opt::Instruction::OperandList( @@ -644,8 +652,9 @@ // Check that, if the merge block starts with OpPhi instructions, suitable // ids have been provided to give these instructions a value corresponding // to the new incoming edge from the back edge block. - auto merge_block = context->cfg()->block(loop_header->MergeBlockId()); - if (!fuzzerutil::PhiIdsOkForNewEdge(context, back_edge_block, merge_block, + auto merge_block = ir_context->cfg()->block(loop_header->MergeBlockId()); + if (!fuzzerutil::PhiIdsOkForNewEdge(ir_context, back_edge_block, + merge_block, loop_limiter_info.phi_id())) { return false; } @@ -681,16 +690,18 @@ // Update the module's id bound with respect to the various ids that // have been used for loop limiter manipulation. - fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.load_id()); - fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.increment_id()); - fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.compare_id()); - fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.logical_op_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.load_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, + loop_limiter_info.increment_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.compare_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, + loop_limiter_info.logical_op_id()); } return true; } bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn( - opt::IRContext* context, opt::Function* added_function, + opt::IRContext* ir_context, opt::Function* added_function, opt::Instruction* kill_or_unreachable_inst) const { assert((kill_or_unreachable_inst->opcode() == SpvOpKill || kill_or_unreachable_inst->opcode() == SpvOpUnreachable) && @@ -698,7 +709,7 @@ // Get the function's return type. auto function_return_type_inst = - context->get_def_use_mgr()->GetDef(added_function->type_id()); + ir_context->get_def_use_mgr()->GetDef(added_function->type_id()); if (function_return_type_inst->opcode() == SpvOpTypeVoid) { // The function has void return type, so change this instruction to @@ -712,7 +723,7 @@ // We first check that the id, %id, provided with the transformation // specifically to turn OpKill and OpUnreachable instructions into // OpReturnValue %id has the same type as the function's return type. - if (context->get_def_use_mgr() + if (ir_context->get_def_use_mgr() ->GetDef(message_.kill_unreachable_return_value_id()) ->type_id() != function_return_type_inst->result_id()) { return false; @@ -725,7 +736,7 @@ } bool TransformationAddFunction::TryToClampAccessChainIndices( - opt::IRContext* context, opt::Instruction* access_chain_inst) const { + opt::IRContext* ir_context, opt::Instruction* access_chain_inst) const { assert((access_chain_inst->opcode() == SpvOpAccessChain || access_chain_inst->opcode() == SpvOpInBoundsAccessChain) && "Precondition: instruction must be OpAccessChain or " @@ -756,14 +767,14 @@ // Walk the access chain, clamping each index to be within bounds if it is // not a constant. - auto base_object = context->get_def_use_mgr()->GetDef( + auto base_object = ir_context->get_def_use_mgr()->GetDef( access_chain_inst->GetSingleWordInOperand(0)); assert(base_object && "The base object must exist."); auto pointer_type = - context->get_def_use_mgr()->GetDef(base_object->type_id()); + ir_context->get_def_use_mgr()->GetDef(base_object->type_id()); assert(pointer_type && pointer_type->opcode() == SpvOpTypePointer && "The base object must have pointer type."); - auto should_be_composite_type = context->get_def_use_mgr()->GetDef( + auto should_be_composite_type = ir_context->get_def_use_mgr()->GetDef( pointer_type->GetSingleWordInOperand(1)); // Consider each index input operand in turn (operand 0 is the base object). @@ -784,41 +795,43 @@ // Get the bound for the composite being indexed into; e.g. the number of // columns of matrix or the size of an array. uint32_t bound = - GetBoundForCompositeIndex(context, *should_be_composite_type); + GetBoundForCompositeIndex(ir_context, *should_be_composite_type); // Get the instruction associated with the index and figure out its integer // type. const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index); - auto index_inst = context->get_def_use_mgr()->GetDef(index_id); + auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id); auto index_type_inst = - context->get_def_use_mgr()->GetDef(index_inst->type_id()); + ir_context->get_def_use_mgr()->GetDef(index_inst->type_id()); assert(index_type_inst->opcode() == SpvOpTypeInt); assert(index_type_inst->GetSingleWordInOperand(0) == 32); opt::analysis::Integer* index_int_type = - context->get_type_mgr() + ir_context->get_type_mgr() ->GetType(index_type_inst->result_id()) ->AsInteger(); - if (index_inst->opcode() != SpvOpConstant) { - // The index is non-constant so we need to clamp it. + if (index_inst->opcode() != SpvOpConstant || + index_inst->GetSingleWordInOperand(0) >= bound) { + // The index is either non-constant or an out-of-bounds constant, so we + // need to clamp it. assert(should_be_composite_type->opcode() != SpvOpTypeStruct && "Access chain indices into structures are required to be " "constants."); opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1}); - if (!context->get_constant_mgr()->FindConstant(&bound_minus_one)) { + if (!ir_context->get_constant_mgr()->FindConstant(&bound_minus_one)) { // We do not have an integer constant whose value is |bound| -1. return false; } opt::analysis::Bool bool_type; - uint32_t bool_type_id = context->get_type_mgr()->GetId(&bool_type); + uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type); if (!bool_type_id) { // Bool type is not declared; we cannot do a comparison. return false; } uint32_t bound_minus_one_id = - context->get_constant_mgr() + ir_context->get_constant_mgr() ->GetDefiningInstruction(&bound_minus_one) ->result_id(); @@ -832,7 +845,7 @@ // Compare the index with the bound via an instruction of the form: // %t1 = OpULessThanEqual %bool %index %bound_minus_one new_instructions.push_back(MakeUnique<opt::Instruction>( - context, SpvOpULessThanEqual, bool_type_id, compare_id, + ir_context, SpvOpULessThanEqual, bool_type_id, compare_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}}, {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}}))); @@ -840,7 +853,7 @@ // Select the index if in-bounds, otherwise one less than the bound: // %t2 = OpSelect %int_type %t1 %index %bound_minus_one new_instructions.push_back(MakeUnique<opt::Instruction>( - context, SpvOpSelect, index_type_inst->result_id(), select_id, + ir_context, SpvOpSelect, index_type_inst->result_id(), select_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {compare_id}}, {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}}, @@ -851,41 +864,31 @@ // Replace %index with %t2. access_chain_inst->SetInOperand(index, {select_id}); - fuzzerutil::UpdateModuleIdBound(context, compare_id); - fuzzerutil::UpdateModuleIdBound(context, select_id); - } else { - // TODO(afd): At present the SPIR-V spec is not clear on whether - // statically out-of-bounds indices mean that a module is invalid (so - // that it should be rejected by the validator), or that such accesses - // yield undefined results. Via the following assertion, we assume that - // functions added to the module do not feature statically out-of-bounds - // accesses. - // Assert that the index is smaller (unsigned) than this value. - // Return false if it is not (to keep compilers happy). - if (index_inst->GetSingleWordInOperand(0) >= bound) { - assert(false && - "The function has a statically out-of-bounds access; " - "this should not occur."); - return false; - } + fuzzerutil::UpdateModuleIdBound(ir_context, compare_id); + fuzzerutil::UpdateModuleIdBound(ir_context, select_id); } should_be_composite_type = - FollowCompositeIndex(context, *should_be_composite_type, index_id); + FollowCompositeIndex(ir_context, *should_be_composite_type, index_id); } return true; } uint32_t TransformationAddFunction::GetBoundForCompositeIndex( - opt::IRContext* context, const opt::Instruction& composite_type_inst) { + opt::IRContext* ir_context, const opt::Instruction& composite_type_inst) { switch (composite_type_inst.opcode()) { case SpvOpTypeArray: - return fuzzerutil::GetArraySize(composite_type_inst, context); + return fuzzerutil::GetArraySize(composite_type_inst, ir_context); case SpvOpTypeMatrix: case SpvOpTypeVector: return composite_type_inst.GetSingleWordInOperand(1); case SpvOpTypeStruct: { return fuzzerutil::GetNumberOfStructMembers(composite_type_inst); } + case SpvOpTypeRuntimeArray: + assert(false && + "GetBoundForCompositeIndex should not be invoked with an " + "OpTypeRuntimeArray, which does not have a static bound."); + return 0; default: assert(false && "Unknown composite type."); return 0; @@ -893,11 +896,12 @@ } opt::Instruction* TransformationAddFunction::FollowCompositeIndex( - opt::IRContext* context, const opt::Instruction& composite_type_inst, + opt::IRContext* ir_context, const opt::Instruction& composite_type_inst, uint32_t index_id) { uint32_t sub_object_type_id; switch (composite_type_inst.opcode()) { case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0); break; case SpvOpTypeMatrix: @@ -905,12 +909,12 @@ sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0); break; case SpvOpTypeStruct: { - auto index_inst = context->get_def_use_mgr()->GetDef(index_id); + auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id); assert(index_inst->opcode() == SpvOpConstant); - assert( - context->get_def_use_mgr()->GetDef(index_inst->type_id())->opcode() == - SpvOpTypeInt); - assert(context->get_def_use_mgr() + assert(ir_context->get_def_use_mgr() + ->GetDef(index_inst->type_id()) + ->opcode() == SpvOpTypeInt); + assert(ir_context->get_def_use_mgr() ->GetDef(index_inst->type_id()) ->GetSingleWordInOperand(0) == 32); uint32_t index_value = index_inst->GetSingleWordInOperand(0); @@ -924,7 +928,7 @@ break; } assert(sub_object_type_id && "No sub-object found."); - return context->get_def_use_mgr()->GetDef(sub_object_type_id); + return ir_context->get_def_use_mgr()->GetDef(sub_object_type_id); } } // namespace fuzz
diff --git a/source/fuzz/transformation_add_function.h b/source/fuzz/transformation_add_function.h index 848b799..5af197b 100644 --- a/source/fuzz/transformation_add_function.h +++ b/source/fuzz/transformation_add_function.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_FUNCTION_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_FUNCTION_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -47,12 +47,14 @@ // ingredients to make the function livesafe, and the function must only // invoke other livesafe functions // - Adding the created function to the module must lead to a valid module. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds the function defined by |message_.instruction| to the module, making // it livesafe if |message_.is_livesafe| holds. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; @@ -61,26 +63,26 @@ // an array, the number of components of a vector, or the number of columns of // a matrix. static uint32_t GetBoundForCompositeIndex( - opt::IRContext* context, const opt::Instruction& composite_type_inst); + opt::IRContext* ir_context, const opt::Instruction& composite_type_inst); // Helper method that, given composite type |composite_type_inst|, returns the // type of the sub-object at index |index_id|, which is required to be in- // bounds. static opt::Instruction* FollowCompositeIndex( - opt::IRContext* context, const opt::Instruction& composite_type_inst, + opt::IRContext* ir_context, const opt::Instruction& composite_type_inst, uint32_t index_id); private: // Attempts to create a function from the series of instructions in - // |message_.instruction| and add it to |context|. + // |message_.instruction| and add it to |ir_context|. // // Returns false if adding the function is not possible due to the messages // not respecting the basic structure of a function, e.g. if there is no - // OpFunction instruction or no blocks; in this case |context| is left in an - // indeterminate state. + // OpFunction instruction or no blocks; in this case |ir_context| is left in + // an indeterminate state. // - // Otherwise returns true. Whether |context| is valid after addition of the - // function depends on the contents of |message_.instruction|. + // Otherwise returns true. Whether |ir_context| is valid after addition of + // the function depends on the contents of |message_.instruction|. // // Intended usage: // - Perform a dry run of this method on a clone of a module, and use @@ -89,30 +91,31 @@ // added, or leads to an invalid module. // - If the dry run succeeds, run the method on the real module of interest, // to add the function. - bool TryToAddFunction(opt::IRContext* context) const; + bool TryToAddFunction(opt::IRContext* ir_context) const; // Should only be called if |message_.is_livesafe| holds. Attempts to make // the function livesafe (see FactFunctionIsLivesafe for a definition). - // Returns false if this is not possible, due to |message_| or |context| not - // containing sufficient ingredients (such as types and fresh ids) to add + // Returns false if this is not possible, due to |message_| or |ir_context| + // not containing sufficient ingredients (such as types and fresh ids) to add // the instrumentation necessary to make the function livesafe. - bool TryToMakeFunctionLivesafe(opt::IRContext* context, - const FactManager& fact_manager) const; + bool TryToMakeFunctionLivesafe( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const; // A helper for TryToMakeFunctionLivesafe that tries to add loop-limiting // logic. - bool TryToAddLoopLimiters(opt::IRContext* context, + bool TryToAddLoopLimiters(opt::IRContext* ir_context, opt::Function* added_function) const; // A helper for TryToMakeFunctionLivesafe that tries to replace OpKill and // OpUnreachable instructions into return instructions. bool TryToTurnKillOrUnreachableIntoReturn( - opt::IRContext* context, opt::Function* added_function, + opt::IRContext* ir_context, opt::Function* added_function, opt::Instruction* kill_or_unreachable_inst) const; // A helper for TryToMakeFunctionLivesafe that tries to clamp access chain // indices so that they are guaranteed to be in-bounds. - bool TryToClampAccessChainIndices(opt::IRContext* context, + bool TryToClampAccessChainIndices(opt::IRContext* ir_context, opt::Instruction* access_chain_inst) const; protobufs::TransformationAddFunction message_;
diff --git a/source/fuzz/transformation_add_global_undef.cpp b/source/fuzz/transformation_add_global_undef.cpp index f9585b3..ba45f22 100644 --- a/source/fuzz/transformation_add_global_undef.cpp +++ b/source/fuzz/transformation_add_global_undef.cpp
@@ -30,26 +30,26 @@ } bool TransformationAddGlobalUndef::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // A fresh id is required. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } - auto type = context->get_type_mgr()->GetType(message_.type_id()); + auto type = ir_context->get_type_mgr()->GetType(message_.type_id()); // The type must exist, and must not be a function type. return type && !type->AsFunction(); } void TransformationAddGlobalUndef::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { - context->module()->AddGlobalValue(MakeUnique<opt::Instruction>( - context, SpvOpUndef, message_.type_id(), message_.fresh_id(), + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { + ir_context->module()->AddGlobalValue(MakeUnique<opt::Instruction>( + ir_context, SpvOpUndef, message_.type_id(), message_.fresh_id(), opt::Instruction::OperandList())); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddGlobalUndef::ToMessage() const {
diff --git a/source/fuzz/transformation_add_global_undef.h b/source/fuzz/transformation_add_global_undef.h index 550d9f6..c89fe9d 100644 --- a/source/fuzz/transformation_add_global_undef.h +++ b/source/fuzz/transformation_add_global_undef.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_GLOBAL_UNDEF_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_GLOBAL_UNDEF_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -32,12 +32,14 @@ // - |message_.fresh_id| must be fresh // - |message_.type_id| must be the id of a non-function type - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an OpUndef instruction to the module, with |message_.type_id| as its // type. The instruction has result id |message_.fresh_id|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_global_variable.cpp b/source/fuzz/transformation_add_global_variable.cpp index e4f9f7a..6464bfb 100644 --- a/source/fuzz/transformation_add_global_variable.cpp +++ b/source/fuzz/transformation_add_global_variable.cpp
@@ -24,23 +24,34 @@ : message_(message) {} TransformationAddGlobalVariable::TransformationAddGlobalVariable( - uint32_t fresh_id, uint32_t type_id, uint32_t initializer_id, - bool value_is_irrelevant) { + uint32_t fresh_id, uint32_t type_id, SpvStorageClass storage_class, + uint32_t initializer_id, bool value_is_irrelevant) { message_.set_fresh_id(fresh_id); message_.set_type_id(type_id); + message_.set_storage_class(storage_class); message_.set_initializer_id(initializer_id); message_.set_value_is_irrelevant(value_is_irrelevant); } bool TransformationAddGlobalVariable::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The result id must be fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } + + // The storage class must be Private or Workgroup. + auto storage_class = static_cast<SpvStorageClass>(message_.storage_class()); + switch (storage_class) { + case SpvStorageClassPrivate: + case SpvStorageClassWorkgroup: + break; + default: + assert(false && "Unsupported storage class."); + return false; + } // The type id must correspond to a type. - auto type = context->get_type_mgr()->GetType(message_.type_id()); + auto type = ir_context->get_type_mgr()->GetType(message_.type_id()); if (!type) { return false; } @@ -49,42 +60,52 @@ if (!pointer_type) { return false; } - // ... with Private storage class. - if (pointer_type->storage_class() != SpvStorageClassPrivate) { + // ... with the right storage class. + if (pointer_type->storage_class() != storage_class) { return false; } - // The initializer id must be the id of a constant. Check this with the - // constant manager. - auto constant_id = context->get_constant_mgr()->GetConstantsFromIds( - {message_.initializer_id()}); - if (constant_id.empty()) { - return false; - } - assert(constant_id.size() == 1 && - "We asked for the constant associated with a single id; we should " - "get a single constant."); - // The type of the constant must match the pointee type of the pointer. - if (pointer_type->pointee_type() != constant_id[0]->type()) { - return false; + if (message_.initializer_id()) { + // An initializer is not allowed if the storage class is Workgroup. + if (storage_class == SpvStorageClassWorkgroup) { + assert(false && + "By construction this transformation should not have an " + "initializer when Workgroup storage class is used."); + return false; + } + // The initializer id must be the id of a constant. Check this with the + // constant manager. + auto constant_id = ir_context->get_constant_mgr()->GetConstantsFromIds( + {message_.initializer_id()}); + if (constant_id.empty()) { + return false; + } + assert(constant_id.size() == 1 && + "We asked for the constant associated with a single id; we should " + "get a single constant."); + // The type of the constant must match the pointee type of the pointer. + if (pointer_type->pointee_type() != constant_id[0]->type()) { + return false; + } } return true; } void TransformationAddGlobalVariable::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const { + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { opt::Instruction::OperandList input_operands; input_operands.push_back( - {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassPrivate}}); + {SPV_OPERAND_TYPE_STORAGE_CLASS, {message_.storage_class()}}); if (message_.initializer_id()) { input_operands.push_back( {SPV_OPERAND_TYPE_ID, {message_.initializer_id()}}); } - context->module()->AddGlobalValue( - MakeUnique<opt::Instruction>(context, SpvOpVariable, message_.type_id(), - message_.fresh_id(), input_operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + ir_context->module()->AddGlobalValue(MakeUnique<opt::Instruction>( + ir_context, SpvOpVariable, message_.type_id(), message_.fresh_id(), + input_operands)); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); - if (PrivateGlobalsMustBeDeclaredInEntryPointInterfaces(context)) { + if (GlobalVariablesMustBeDeclaredInEntryPointInterfaces(ir_context)) { // Conservatively add this global to the interface of every entry point in // the module. This means that the global is available for other // transformations to use. @@ -94,18 +115,20 @@ // // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3111) revisit // this if a more thorough approach to entry point interfaces is taken. - for (auto& entry_point : context->module()->entry_points()) { + for (auto& entry_point : ir_context->module()->entry_points()) { entry_point.AddOperand({SPV_OPERAND_TYPE_ID, {message_.fresh_id()}}); } } if (message_.value_is_irrelevant()) { - fact_manager->AddFactValueOfPointeeIsIrrelevant(message_.fresh_id()); + transformation_context->GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + message_.fresh_id()); } // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddGlobalVariable::ToMessage() const { @@ -115,12 +138,12 @@ } bool TransformationAddGlobalVariable:: - PrivateGlobalsMustBeDeclaredInEntryPointInterfaces( - opt::IRContext* context) { + GlobalVariablesMustBeDeclaredInEntryPointInterfaces( + opt::IRContext* ir_context) { // TODO(afd): We capture the universal environments for which this requirement // holds. The check should be refined on demand for other target // environments. - switch (context->grammar().target_env()) { + switch (ir_context->grammar().target_env()) { case SPV_ENV_UNIVERSAL_1_0: case SPV_ENV_UNIVERSAL_1_1: case SPV_ENV_UNIVERSAL_1_2:
diff --git a/source/fuzz/transformation_add_global_variable.h b/source/fuzz/transformation_add_global_variable.h index 920ac45..289af9e 100644 --- a/source/fuzz/transformation_add_global_variable.h +++ b/source/fuzz/transformation_add_global_variable.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_GLOBAL_VARIABLE_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_GLOBAL_VARIABLE_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -29,31 +29,40 @@ const protobufs::TransformationAddGlobalVariable& message); TransformationAddGlobalVariable(uint32_t fresh_id, uint32_t type_id, + SpvStorageClass storage_class, uint32_t initializer_id, bool value_is_irrelevant); // - |message_.fresh_id| must be fresh - // - |message_.type_id| must be the id of a pointer type with Private storage - // class - // - |message_.initializer_id| must either be 0 or the id of a constant whose + // - |message_.type_id| must be the id of a pointer type with the same storage + // class as |message_.storage_class| + // - |message_.storage_class| must be Private or Workgroup + // - |message_.initializer_id| must be 0 if |message_.storage_class| is + // Workgroup, and otherwise may either be 0 or the id of a constant whose // type is the pointee type of |message_.type_id| - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; - // Adds a global variable with Private storage class to the module, with type - // |message_.type_id| and either no initializer or |message_.initializer_id| - // as an initializer, depending on whether |message_.initializer_id| is 0. - // The global variable has result id |message_.fresh_id|. + // Adds a global variable with storage class |message_.storage_class| to the + // module, with type |message_.type_id| and either no initializer or + // |message_.initializer_id| as an initializer, depending on whether + // |message_.initializer_id| is 0. The global variable has result id + // |message_.fresh_id|. // - // If |message_.value_is_irrelevant| holds, adds a corresponding fact to - // |fact_manager|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + // If |message_.value_is_irrelevant| holds, adds a corresponding fact to the + // fact manager in |transformation_context|. + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; private: - static bool PrivateGlobalsMustBeDeclaredInEntryPointInterfaces( - opt::IRContext* context); + // Returns true if and only if the SPIR-V version being used requires that + // global variables accessed in the static call graph of an entry point need + // to be listed in that entry point's interface. + static bool GlobalVariablesMustBeDeclaredInEntryPointInterfaces( + opt::IRContext* ir_context); protobufs::TransformationAddGlobalVariable message_; };
diff --git a/source/fuzz/transformation_add_local_variable.cpp b/source/fuzz/transformation_add_local_variable.cpp index 69e536d..5136249 100644 --- a/source/fuzz/transformation_add_local_variable.cpp +++ b/source/fuzz/transformation_add_local_variable.cpp
@@ -34,23 +34,22 @@ } bool TransformationAddLocalVariable::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The provided id must be fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // The pointer type id must indeed correspond to a pointer, and it must have // function storage class. auto type_instruction = - context->get_def_use_mgr()->GetDef(message_.type_id()); + ir_context->get_def_use_mgr()->GetDef(message_.type_id()); if (!type_instruction || type_instruction->opcode() != SpvOpTypePointer || type_instruction->GetSingleWordInOperand(0) != SpvStorageClassFunction) { return false; } // The initializer must... auto initializer_instruction = - context->get_def_use_mgr()->GetDef(message_.initializer_id()); + ir_context->get_def_use_mgr()->GetDef(message_.initializer_id()); // ... exist, ... if (!initializer_instruction) { return false; @@ -65,17 +64,18 @@ return false; } // The function to which the local variable is to be added must exist. - return fuzzerutil::FindFunction(context, message_.function_id()); + return fuzzerutil::FindFunction(ir_context, message_.function_id()); } void TransformationAddLocalVariable::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const { - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); - fuzzerutil::FindFunction(context, message_.function_id()) + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); + fuzzerutil::FindFunction(ir_context, message_.function_id()) ->begin() ->begin() ->InsertBefore(MakeUnique<opt::Instruction>( - context, SpvOpVariable, message_.type_id(), message_.fresh_id(), + ir_context, SpvOpVariable, message_.type_id(), message_.fresh_id(), opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_STORAGE_CLASS, { @@ -83,9 +83,10 @@ SpvStorageClassFunction}}, {SPV_OPERAND_TYPE_ID, {message_.initializer_id()}}}))); if (message_.value_is_irrelevant()) { - fact_manager->AddFactValueOfPointeeIsIrrelevant(message_.fresh_id()); + transformation_context->GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + message_.fresh_id()); } - context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); } protobufs::Transformation TransformationAddLocalVariable::ToMessage() const {
diff --git a/source/fuzz/transformation_add_local_variable.h b/source/fuzz/transformation_add_local_variable.h index b8e00dd..6460904 100644 --- a/source/fuzz/transformation_add_local_variable.h +++ b/source/fuzz/transformation_add_local_variable.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_LOCAL_VARIABLE_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_LOCAL_VARIABLE_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -38,15 +38,17 @@ // - |message_.initializer_id| must be the id of a constant with the same // type as the pointer's pointee type // - |message_.function_id| must be the id of a function - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an instruction to the start of |message_.function_id|, of the form: // |message_.fresh_id| = OpVariable |message_.type_id| Function // |message_.initializer_id| - // If |message_.value_is_irrelevant| holds, adds a corresponding fact to - // |fact_manager|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + // If |message_.value_is_irrelevant| holds, adds a corresponding fact to the + // fact manager in |transformation_context|. + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_no_contraction_decoration.cpp b/source/fuzz/transformation_add_no_contraction_decoration.cpp index 7f22cc2..4668534 100644 --- a/source/fuzz/transformation_add_no_contraction_decoration.cpp +++ b/source/fuzz/transformation_add_no_contraction_decoration.cpp
@@ -31,10 +31,9 @@ } bool TransformationAddNoContractionDecoration::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // |message_.result_id| must be the id of an instruction. - auto instr = context->get_def_use_mgr()->GetDef(message_.result_id()); + auto instr = ir_context->get_def_use_mgr()->GetDef(message_.result_id()); if (!instr) { return false; } @@ -43,10 +42,10 @@ } void TransformationAddNoContractionDecoration::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { // Add a NoContraction decoration targeting |message_.result_id|. - context->get_decoration_mgr()->AddDecoration(message_.result_id(), - SpvDecorationNoContraction); + ir_context->get_decoration_mgr()->AddDecoration(message_.result_id(), + SpvDecorationNoContraction); } protobufs::Transformation TransformationAddNoContractionDecoration::ToMessage()
diff --git a/source/fuzz/transformation_add_no_contraction_decoration.h b/source/fuzz/transformation_add_no_contraction_decoration.h index cec1b2c..27c3a80 100644 --- a/source/fuzz/transformation_add_no_contraction_decoration.h +++ b/source/fuzz/transformation_add_no_contraction_decoration.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_NO_CONTRACTION_DECORATION_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_NO_CONTRACTION_DECORATION_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -34,13 +34,15 @@ // as defined by the SPIR-V specification. // - It does not matter whether this instruction is already annotated with the // NoContraction decoration. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds a decoration of the form: // 'OpDecoration |message_.result_id| NoContraction' // to the module. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_type_array.cpp b/source/fuzz/transformation_add_type_array.cpp index 2074e98..8f5af07 100644 --- a/source/fuzz/transformation_add_type_array.cpp +++ b/source/fuzz/transformation_add_type_array.cpp
@@ -32,21 +32,20 @@ } bool TransformationAddTypeArray::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // A fresh id is required. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } auto element_type = - context->get_type_mgr()->GetType(message_.element_type_id()); + ir_context->get_type_mgr()->GetType(message_.element_type_id()); if (!element_type || element_type->AsFunction()) { // The element type id either does not refer to a type, or refers to a // function type; both are illegal. return false; } auto constant = - context->get_constant_mgr()->GetConstantsFromIds({message_.size_id()}); + ir_context->get_constant_mgr()->GetConstantsFromIds({message_.size_id()}); if (constant.empty()) { // The size id does not refer to a constant. return false; @@ -66,16 +65,17 @@ } void TransformationAddTypeArray::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { opt::Instruction::OperandList in_operands; in_operands.push_back({SPV_OPERAND_TYPE_ID, {message_.element_type_id()}}); in_operands.push_back({SPV_OPERAND_TYPE_ID, {message_.size_id()}}); - context->module()->AddType(MakeUnique<opt::Instruction>( - context, SpvOpTypeArray, 0, message_.fresh_id(), in_operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + ir_context->module()->AddType(MakeUnique<opt::Instruction>( + ir_context, SpvOpTypeArray, 0, message_.fresh_id(), in_operands)); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddTypeArray::ToMessage() const {
diff --git a/source/fuzz/transformation_add_type_array.h b/source/fuzz/transformation_add_type_array.h index b6e0718..5e9b8aa 100644 --- a/source/fuzz/transformation_add_type_array.h +++ b/source/fuzz/transformation_add_type_array.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_ARRAY_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_ARRAY_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -35,13 +35,15 @@ // - |message_.element_type_id| must be the id of a non-function type // - |message_.size_id| must be the id of a 32-bit integer constant that is // positive when interpreted as signed. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an OpTypeArray instruction to the module, with element type given by // |message_.element_type_id| and size given by |message_.size_id|. The // result id of the instruction is |message_.fresh_id|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_type_boolean.cpp b/source/fuzz/transformation_add_type_boolean.cpp index b55028a..77409a8 100644 --- a/source/fuzz/transformation_add_type_boolean.cpp +++ b/source/fuzz/transformation_add_type_boolean.cpp
@@ -28,27 +28,27 @@ } bool TransformationAddTypeBoolean::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The id must be fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // Applicable if there is no bool type already declared in the module. opt::analysis::Bool bool_type; - return context->get_type_mgr()->GetId(&bool_type) == 0; + return ir_context->get_type_mgr()->GetId(&bool_type) == 0; } void TransformationAddTypeBoolean::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { opt::Instruction::OperandList empty_operands; - context->module()->AddType(MakeUnique<opt::Instruction>( - context, SpvOpTypeBool, 0, message_.fresh_id(), empty_operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + ir_context->module()->AddType(MakeUnique<opt::Instruction>( + ir_context, SpvOpTypeBool, 0, message_.fresh_id(), empty_operands)); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddTypeBoolean::ToMessage() const {
diff --git a/source/fuzz/transformation_add_type_boolean.h b/source/fuzz/transformation_add_type_boolean.h index 98c1e63..5ce5b9a 100644 --- a/source/fuzz/transformation_add_type_boolean.h +++ b/source/fuzz/transformation_add_type_boolean.h
@@ -15,7 +15,6 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_BOOLEAN_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_BOOLEAN_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" #include "source/opt/ir_context.h" @@ -32,11 +31,13 @@ // - |message_.fresh_id| must not be used by the module. // - The module must not yet declare OpTypeBoolean - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds OpTypeBoolean with |message_.fresh_id| as result id. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_type_float.cpp b/source/fuzz/transformation_add_type_float.cpp index d2af5f8..80716e1 100644 --- a/source/fuzz/transformation_add_type_float.cpp +++ b/source/fuzz/transformation_add_type_float.cpp
@@ -30,29 +30,29 @@ : message_(message) {} bool TransformationAddTypeFloat::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The id must be fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // Applicable if there is no float type with this width already declared in // the module. opt::analysis::Float float_type(message_.width()); - return context->get_type_mgr()->GetId(&float_type) == 0; + return ir_context->get_type_mgr()->GetId(&float_type) == 0; } void TransformationAddTypeFloat::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { opt::Instruction::OperandList width = { {SPV_OPERAND_TYPE_LITERAL_INTEGER, {message_.width()}}}; - context->module()->AddType(MakeUnique<opt::Instruction>( - context, SpvOpTypeFloat, 0, message_.fresh_id(), width)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + ir_context->module()->AddType(MakeUnique<opt::Instruction>( + ir_context, SpvOpTypeFloat, 0, message_.fresh_id(), width)); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddTypeFloat::ToMessage() const {
diff --git a/source/fuzz/transformation_add_type_float.h b/source/fuzz/transformation_add_type_float.h index 0fdc831..a8fa0e1 100644 --- a/source/fuzz/transformation_add_type_float.h +++ b/source/fuzz/transformation_add_type_float.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_FLOAT_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_FLOAT_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -33,11 +33,13 @@ // - |message_.fresh_id| must not be used by the module // - The module must not contain an OpTypeFloat instruction with width // |message_.width| - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an OpTypeFloat instruction to the module with the given width - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_type_function.cpp b/source/fuzz/transformation_add_type_function.cpp index 4b6717b..991a28b 100644 --- a/source/fuzz/transformation_add_type_function.cpp +++ b/source/fuzz/transformation_add_type_function.cpp
@@ -36,19 +36,18 @@ } bool TransformationAddTypeFunction::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The result id must be fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // The return and argument types must be type ids but not not be function // type ids. - if (!fuzzerutil::IsNonFunctionTypeId(context, message_.return_type_id())) { + if (!fuzzerutil::IsNonFunctionTypeId(ir_context, message_.return_type_id())) { return false; } for (auto argument_type_id : message_.argument_type_id()) { - if (!fuzzerutil::IsNonFunctionTypeId(context, argument_type_id)) { + if (!fuzzerutil::IsNonFunctionTypeId(ir_context, argument_type_id)) { return false; } } @@ -56,7 +55,7 @@ // exactly the same return and argument type ids. (Note that the type manager // does not allow us to check this, as it does not distinguish between // function types with different but isomorphic pointer argument types.) - for (auto& inst : context->module()->types_values()) { + for (auto& inst : ir_context->module()->types_values()) { if (inst.opcode() != SpvOpTypeFunction) { // Consider only OpTypeFunction instructions. continue; @@ -89,18 +88,19 @@ } void TransformationAddTypeFunction::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { opt::Instruction::OperandList in_operands; in_operands.push_back({SPV_OPERAND_TYPE_ID, {message_.return_type_id()}}); for (auto argument_type_id : message_.argument_type_id()) { in_operands.push_back({SPV_OPERAND_TYPE_ID, {argument_type_id}}); } - context->module()->AddType(MakeUnique<opt::Instruction>( - context, SpvOpTypeFunction, 0, message_.fresh_id(), in_operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + ir_context->module()->AddType(MakeUnique<opt::Instruction>( + ir_context, SpvOpTypeFunction, 0, message_.fresh_id(), in_operands)); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddTypeFunction::ToMessage() const {
diff --git a/source/fuzz/transformation_add_type_function.h b/source/fuzz/transformation_add_type_function.h index 3880963..f26b250 100644 --- a/source/fuzz/transformation_add_type_function.h +++ b/source/fuzz/transformation_add_type_function.h
@@ -17,9 +17,9 @@ #include <vector> -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -39,13 +39,15 @@ // - The module must not contain an OpTypeFunction instruction defining a // function type with the signature provided by the given return and // argument types - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an OpTypeFunction instruction to the module, with signature given by // |message_.return_type_id| and |message_.argument_type_id|. The result id // for the instruction is |message_.fresh_id|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_type_int.cpp b/source/fuzz/transformation_add_type_int.cpp index 6f59270..a932a5f 100644 --- a/source/fuzz/transformation_add_type_int.cpp +++ b/source/fuzz/transformation_add_type_int.cpp
@@ -32,30 +32,30 @@ } bool TransformationAddTypeInt::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The id must be fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // Applicable if there is no int type with this width and signedness already // declared in the module. opt::analysis::Integer int_type(message_.width(), message_.is_signed()); - return context->get_type_mgr()->GetId(&int_type) == 0; + return ir_context->get_type_mgr()->GetId(&int_type) == 0; } -void TransformationAddTypeInt::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { +void TransformationAddTypeInt::Apply(opt::IRContext* ir_context, + TransformationContext* /*unused*/) const { opt::Instruction::OperandList in_operands = { {SPV_OPERAND_TYPE_LITERAL_INTEGER, {message_.width()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {message_.is_signed() ? 1u : 0u}}}; - context->module()->AddType(MakeUnique<opt::Instruction>( - context, SpvOpTypeInt, 0, message_.fresh_id(), in_operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + ir_context->module()->AddType(MakeUnique<opt::Instruction>( + ir_context, SpvOpTypeInt, 0, message_.fresh_id(), in_operands)); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddTypeInt::ToMessage() const {
diff --git a/source/fuzz/transformation_add_type_int.h b/source/fuzz/transformation_add_type_int.h index 86342d0..5c3c959 100644 --- a/source/fuzz/transformation_add_type_int.h +++ b/source/fuzz/transformation_add_type_int.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_INT_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_INT_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -33,12 +33,14 @@ // - |message_.fresh_id| must not be used by the module // - The module must not contain an OpTypeInt instruction with width // |message_.width| and signedness |message.is_signed| - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an OpTypeInt instruction to the module with the given width and // signedness. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_type_matrix.cpp b/source/fuzz/transformation_add_type_matrix.cpp index 07ab705..2c24eaa 100644 --- a/source/fuzz/transformation_add_type_matrix.cpp +++ b/source/fuzz/transformation_add_type_matrix.cpp
@@ -31,15 +31,14 @@ } bool TransformationAddTypeMatrix::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The result id must be fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // The column type must be a floating-point vector. auto column_type = - context->get_type_mgr()->GetType(message_.column_type_id()); + ir_context->get_type_mgr()->GetType(message_.column_type_id()); if (!column_type) { return false; } @@ -48,17 +47,18 @@ } void TransformationAddTypeMatrix::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { opt::Instruction::OperandList in_operands; in_operands.push_back({SPV_OPERAND_TYPE_ID, {message_.column_type_id()}}); in_operands.push_back( {SPV_OPERAND_TYPE_LITERAL_INTEGER, {message_.column_count()}}); - context->module()->AddType(MakeUnique<opt::Instruction>( - context, SpvOpTypeMatrix, 0, message_.fresh_id(), in_operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + ir_context->module()->AddType(MakeUnique<opt::Instruction>( + ir_context, SpvOpTypeMatrix, 0, message_.fresh_id(), in_operands)); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddTypeMatrix::ToMessage() const {
diff --git a/source/fuzz/transformation_add_type_matrix.h b/source/fuzz/transformation_add_type_matrix.h index 69d6389..6d0724e 100644 --- a/source/fuzz/transformation_add_type_matrix.h +++ b/source/fuzz/transformation_add_type_matrix.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_MATRIX_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_MATRIX_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -33,13 +33,15 @@ // - |message_.fresh_id| must be a fresh id // - |message_.column_type_id| must be the id of a floating-point vector type - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an OpTypeMatrix instruction to the module, with column type // |message_.column_type_id| and |message_.column_count| columns, with result // id |message_.fresh_id|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_type_pointer.cpp b/source/fuzz/transformation_add_type_pointer.cpp index 426985a..6cc8171 100644 --- a/source/fuzz/transformation_add_type_pointer.cpp +++ b/source/fuzz/transformation_add_type_pointer.cpp
@@ -31,28 +31,29 @@ } bool TransformationAddTypePointer::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The id must be fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // The base type must be known. - return context->get_type_mgr()->GetType(message_.base_type_id()) != nullptr; + return ir_context->get_type_mgr()->GetType(message_.base_type_id()) != + nullptr; } void TransformationAddTypePointer::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { // Add the pointer type. opt::Instruction::OperandList in_operands = { {SPV_OPERAND_TYPE_STORAGE_CLASS, {message_.storage_class()}}, {SPV_OPERAND_TYPE_ID, {message_.base_type_id()}}}; - context->module()->AddType(MakeUnique<opt::Instruction>( - context, SpvOpTypePointer, 0, message_.fresh_id(), in_operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + ir_context->module()->AddType(MakeUnique<opt::Instruction>( + ir_context, SpvOpTypePointer, 0, message_.fresh_id(), in_operands)); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddTypePointer::ToMessage() const {
diff --git a/source/fuzz/transformation_add_type_pointer.h b/source/fuzz/transformation_add_type_pointer.h index 2b9ff77..3b50a29 100644 --- a/source/fuzz/transformation_add_type_pointer.h +++ b/source/fuzz/transformation_add_type_pointer.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_POINTER_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_POINTER_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -34,12 +34,14 @@ // - |message_.fresh_id| must not be used by the module // - |message_.base_type_id| must be the result id of an OpType[...] // instruction - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an OpTypePointer instruction with the given storage class and base // type to the module. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_type_struct.cpp b/source/fuzz/transformation_add_type_struct.cpp index 1ae8372..6ce5ea1 100644 --- a/source/fuzz/transformation_add_type_struct.cpp +++ b/source/fuzz/transformation_add_type_struct.cpp
@@ -32,14 +32,13 @@ } bool TransformationAddTypeStruct::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // A fresh id is required. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } for (auto member_type : message_.member_type_id()) { - auto type = context->get_type_mgr()->GetType(member_type); + auto type = ir_context->get_type_mgr()->GetType(member_type); if (!type || type->AsFunction()) { // The member type id either does not refer to a type, or refers to a // function type; both are illegal. @@ -50,17 +49,18 @@ } void TransformationAddTypeStruct::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { opt::Instruction::OperandList in_operands; for (auto member_type : message_.member_type_id()) { in_operands.push_back({SPV_OPERAND_TYPE_ID, {member_type}}); } - context->module()->AddType(MakeUnique<opt::Instruction>( - context, SpvOpTypeStruct, 0, message_.fresh_id(), in_operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + ir_context->module()->AddType(MakeUnique<opt::Instruction>( + ir_context, SpvOpTypeStruct, 0, message_.fresh_id(), in_operands)); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddTypeStruct::ToMessage() const {
diff --git a/source/fuzz/transformation_add_type_struct.h b/source/fuzz/transformation_add_type_struct.h index edf3ec6..86a532d 100644 --- a/source/fuzz/transformation_add_type_struct.h +++ b/source/fuzz/transformation_add_type_struct.h
@@ -17,9 +17,9 @@ #include <vector> -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -35,12 +35,14 @@ // - |message_.fresh_id| must be a fresh id // - |message_.member_type_id| must be a sequence of non-function type ids - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an OpTypeStruct instruction whose field types are given by // |message_.member_type_id|, with result id |message_.fresh_id|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_add_type_vector.cpp b/source/fuzz/transformation_add_type_vector.cpp index 3fdf50b..f7b2fb5 100644 --- a/source/fuzz/transformation_add_type_vector.cpp +++ b/source/fuzz/transformation_add_type_vector.cpp
@@ -31,13 +31,12 @@ } bool TransformationAddTypeVector::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } auto component_type = - context->get_type_mgr()->GetType(message_.component_type_id()); + ir_context->get_type_mgr()->GetType(message_.component_type_id()); if (!component_type) { return false; } @@ -46,17 +45,18 @@ } void TransformationAddTypeVector::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { opt::Instruction::OperandList in_operands; in_operands.push_back({SPV_OPERAND_TYPE_ID, {message_.component_type_id()}}); in_operands.push_back( {SPV_OPERAND_TYPE_LITERAL_INTEGER, {message_.component_count()}}); - context->module()->AddType(MakeUnique<opt::Instruction>( - context, SpvOpTypeVector, 0, message_.fresh_id(), in_operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + ir_context->module()->AddType(MakeUnique<opt::Instruction>( + ir_context, SpvOpTypeVector, 0, message_.fresh_id(), in_operands)); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // We have added an instruction to the module, so need to be careful about the // validity of existing analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationAddTypeVector::ToMessage() const {
diff --git a/source/fuzz/transformation_add_type_vector.h b/source/fuzz/transformation_add_type_vector.h index af840f5..240f7cc 100644 --- a/source/fuzz/transformation_add_type_vector.h +++ b/source/fuzz/transformation_add_type_vector.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_VECTOR_H_ #define SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_VECTOR_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -33,13 +33,15 @@ // - |message_.fresh_id| must be a fresh id // - |message_.component_type_id| must be the id of a scalar type - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an OpTypeVector instruction to the module, with component type // |message_.component_type_id| and |message_.component_count| components, // with result id |message_.fresh_id|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_adjust_branch_weights.cpp b/source/fuzz/transformation_adjust_branch_weights.cpp new file mode 100644 index 0000000..ed68134 --- /dev/null +++ b/source/fuzz/transformation_adjust_branch_weights.cpp
@@ -0,0 +1,97 @@ +// Copyright (c) 2020 André Perez Maselco +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/transformation_adjust_branch_weights.h" + +#include "source/fuzz/fuzzer_util.h" +#include "source/fuzz/instruction_descriptor.h" + +namespace spvtools { +namespace fuzz { + +namespace { + +const uint32_t kBranchWeightForTrueLabelIndex = 3; +const uint32_t kBranchWeightForFalseLabelIndex = 4; + +} // namespace + +TransformationAdjustBranchWeights::TransformationAdjustBranchWeights( + const spvtools::fuzz::protobufs::TransformationAdjustBranchWeights& message) + : message_(message) {} + +TransformationAdjustBranchWeights::TransformationAdjustBranchWeights( + const protobufs::InstructionDescriptor& instruction_descriptor, + const std::pair<uint32_t, uint32_t>& branch_weights) { + *message_.mutable_instruction_descriptor() = instruction_descriptor; + message_.mutable_branch_weights()->set_first(branch_weights.first); + message_.mutable_branch_weights()->set_second(branch_weights.second); +} + +bool TransformationAdjustBranchWeights::IsApplicable( + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { + auto instruction = + FindInstruction(message_.instruction_descriptor(), ir_context); + if (instruction == nullptr) { + return false; + } + + SpvOp opcode = static_cast<SpvOp>( + message_.instruction_descriptor().target_instruction_opcode()); + + assert(instruction->opcode() == opcode && + "The located instruction must have the same opcode as in the " + "descriptor."); + + // Must be an OpBranchConditional instruction. + if (opcode != SpvOpBranchConditional) { + return false; + } + + assert((message_.branch_weights().first() != 0 || + message_.branch_weights().second() != 0) && + "At least one weight must be non-zero."); + + assert(message_.branch_weights().first() <= + UINT32_MAX - message_.branch_weights().second() && + "The sum of the two weights must not be greater than UINT32_MAX."); + + return true; +} + +void TransformationAdjustBranchWeights::Apply( + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { + auto instruction = + FindInstruction(message_.instruction_descriptor(), ir_context); + if (instruction->HasBranchWeights()) { + instruction->SetOperand(kBranchWeightForTrueLabelIndex, + {message_.branch_weights().first()}); + instruction->SetOperand(kBranchWeightForFalseLabelIndex, + {message_.branch_weights().second()}); + } else { + instruction->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER, + {message_.branch_weights().first()}}); + instruction->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER, + {message_.branch_weights().second()}}); + } +} + +protobufs::Transformation TransformationAdjustBranchWeights::ToMessage() const { + protobufs::Transformation result; + *result.mutable_adjust_branch_weights() = message_; + return result; +} + +} // namespace fuzz +} // namespace spvtools
diff --git a/source/fuzz/transformation_adjust_branch_weights.h b/source/fuzz/transformation_adjust_branch_weights.h new file mode 100644 index 0000000..638b0a9 --- /dev/null +++ b/source/fuzz/transformation_adjust_branch_weights.h
@@ -0,0 +1,57 @@ +// Copyright (c) 2020 André Perez Maselco +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_FUZZ_TRANSFORMATION_ADJUST_BRANCH_WEIGHTS_H_ +#define SOURCE_FUZZ_TRANSFORMATION_ADJUST_BRANCH_WEIGHTS_H_ + +#include "source/fuzz/protobufs/spirvfuzz_protobufs.h" +#include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace fuzz { + +class TransformationAdjustBranchWeights : public Transformation { + public: + explicit TransformationAdjustBranchWeights( + const protobufs::TransformationAdjustBranchWeights& message); + + TransformationAdjustBranchWeights( + const protobufs::InstructionDescriptor& instruction_descriptor, + const std::pair<uint32_t, uint32_t>& branch_weights); + + // - |message_.instruction_descriptor| must identify an existing + // branch conditional instruction + // - At least one of |branch_weights| must be non-zero and + // the two weights must not overflow a 32-bit unsigned integer when added + // together + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; + + // Adjust the branch weights of a branch conditional instruction. + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; + + protobufs::Transformation ToMessage() const override; + + private: + protobufs::TransformationAdjustBranchWeights message_; +}; + +} // namespace fuzz +} // namespace spvtools + +#endif // SOURCE_FUZZ_TRANSFORMATION_ADJUST_BRANCH_WEIGHTS_H_
diff --git a/source/fuzz/transformation_composite_construct.cpp b/source/fuzz/transformation_composite_construct.cpp index 9c63c1d..cd4f22f 100644 --- a/source/fuzz/transformation_composite_construct.cpp +++ b/source/fuzz/transformation_composite_construct.cpp
@@ -40,14 +40,14 @@ } bool TransformationCompositeConstruct::IsApplicable( - opt::IRContext* context, const FactManager& /*fact_manager*/) const { - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { // We require the id for the composite constructor to be unused. return false; } auto insert_before = - FindInstruction(message_.instruction_to_insert_before(), context); + FindInstruction(message_.instruction_to_insert_before(), ir_context); if (!insert_before) { // The instruction before which the composite should be inserted was not // found. @@ -55,7 +55,7 @@ } auto composite_type = - context->get_type_mgr()->GetType(message_.composite_type_id()); + ir_context->get_type_mgr()->GetType(message_.composite_type_id()); if (!fuzzerutil::IsCompositeType(composite_type)) { // The type must actually be a composite. @@ -64,27 +64,31 @@ // If the type is an array, matrix, struct or vector, the components need to // be suitable for constructing something of that type. - if (composite_type->AsArray() && !ComponentsForArrayConstructionAreOK( - context, *composite_type->AsArray())) { + if (composite_type->AsArray() && + !ComponentsForArrayConstructionAreOK(ir_context, + *composite_type->AsArray())) { return false; } - if (composite_type->AsMatrix() && !ComponentsForMatrixConstructionAreOK( - context, *composite_type->AsMatrix())) { + if (composite_type->AsMatrix() && + !ComponentsForMatrixConstructionAreOK(ir_context, + *composite_type->AsMatrix())) { return false; } - if (composite_type->AsStruct() && !ComponentsForStructConstructionAreOK( - context, *composite_type->AsStruct())) { + if (composite_type->AsStruct() && + !ComponentsForStructConstructionAreOK(ir_context, + *composite_type->AsStruct())) { return false; } - if (composite_type->AsVector() && !ComponentsForVectorConstructionAreOK( - context, *composite_type->AsVector())) { + if (composite_type->AsVector() && + !ComponentsForVectorConstructionAreOK(ir_context, + *composite_type->AsVector())) { return false; } // Now check whether every component being used to initialize the composite is // available at the desired program point. for (auto& component : message_.component()) { - if (!fuzzerutil::IdIsAvailableBeforeInstruction(context, insert_before, + if (!fuzzerutil::IdIsAvailableBeforeInstruction(ir_context, insert_before, component)) { return false; } @@ -93,13 +97,14 @@ return true; } -void TransformationCompositeConstruct::Apply(opt::IRContext* context, - FactManager* fact_manager) const { +void TransformationCompositeConstruct::Apply( + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { // Use the base and offset information from the transformation to determine // where in the module a new instruction should be inserted. auto insert_before_inst = - FindInstruction(message_.instruction_to_insert_before(), context); - auto destination_block = context->get_instr_block(insert_before_inst); + FindInstruction(message_.instruction_to_insert_before(), ir_context); + auto destination_block = ir_context->get_instr_block(insert_before_inst); auto insert_before = fuzzerutil::GetIteratorForInstruction( destination_block, insert_before_inst); @@ -111,22 +116,22 @@ // Insert an OpCompositeConstruct instruction. insert_before.InsertBefore(MakeUnique<opt::Instruction>( - context, SpvOpCompositeConstruct, message_.composite_type_id(), + ir_context, SpvOpCompositeConstruct, message_.composite_type_id(), message_.fresh_id(), in_operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); - context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); + ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); // Inform the fact manager that we now have new synonyms: every component of // the composite is synonymous with the id used to construct that component, // except in the case of a vector where a single vector id can span multiple // components. auto composite_type = - context->get_type_mgr()->GetType(message_.composite_type_id()); + ir_context->get_type_mgr()->GetType(message_.composite_type_id()); uint32_t index = 0; for (auto component : message_.component()) { - auto component_type = context->get_type_mgr()->GetType( - context->get_def_use_mgr()->GetDef(component)->type_id()); + auto component_type = ir_context->get_type_mgr()->GetType( + ir_context->get_def_use_mgr()->GetDef(component)->type_id()); if (composite_type->AsVector() && component_type->AsVector()) { // The case where the composite being constructed is a vector and the // component provided for construction is also a vector is special. It @@ -139,24 +144,24 @@ for (uint32_t subvector_index = 0; subvector_index < component_type->AsVector()->element_count(); subvector_index++) { - fact_manager->AddFactDataSynonym( + transformation_context->GetFactManager()->AddFactDataSynonym( MakeDataDescriptor(component, {subvector_index}), - MakeDataDescriptor(message_.fresh_id(), {index}), context); + MakeDataDescriptor(message_.fresh_id(), {index}), ir_context); index++; } } else { // The other cases are simple: the component is made directly synonymous // with the element of the composite being constructed. - fact_manager->AddFactDataSynonym( + transformation_context->GetFactManager()->AddFactDataSynonym( MakeDataDescriptor(component, {}), - MakeDataDescriptor(message_.fresh_id(), {index}), context); + MakeDataDescriptor(message_.fresh_id(), {index}), ir_context); index++; } } } bool TransformationCompositeConstruct::ComponentsForArrayConstructionAreOK( - opt::IRContext* context, const opt::analysis::Array& array_type) const { + opt::IRContext* ir_context, const opt::analysis::Array& array_type) const { if (array_type.length_info().words[0] != opt::analysis::Array::LengthInfo::kConstant) { // We only handle constant-sized arrays. @@ -176,13 +181,13 @@ // Check that each component is the result id of an instruction whose type is // the array's element type. for (auto component_id : message_.component()) { - auto inst = context->get_def_use_mgr()->GetDef(component_id); + auto inst = ir_context->get_def_use_mgr()->GetDef(component_id); if (inst == nullptr || !inst->type_id()) { // The component does not correspond to an instruction with a result // type. return false; } - auto component_type = context->get_type_mgr()->GetType(inst->type_id()); + auto component_type = ir_context->get_type_mgr()->GetType(inst->type_id()); assert(component_type); if (component_type != array_type.element_type()) { // The component's type does not match the array's element type. @@ -193,7 +198,8 @@ } bool TransformationCompositeConstruct::ComponentsForMatrixConstructionAreOK( - opt::IRContext* context, const opt::analysis::Matrix& matrix_type) const { + opt::IRContext* ir_context, + const opt::analysis::Matrix& matrix_type) const { if (static_cast<uint32_t>(message_.component().size()) != matrix_type.element_count()) { // The number of components must match the number of columns of the matrix. @@ -202,13 +208,13 @@ // Check that each component is the result id of an instruction whose type is // the matrix's column type. for (auto component_id : message_.component()) { - auto inst = context->get_def_use_mgr()->GetDef(component_id); + auto inst = ir_context->get_def_use_mgr()->GetDef(component_id); if (inst == nullptr || !inst->type_id()) { // The component does not correspond to an instruction with a result // type. return false; } - auto component_type = context->get_type_mgr()->GetType(inst->type_id()); + auto component_type = ir_context->get_type_mgr()->GetType(inst->type_id()); assert(component_type); if (component_type != matrix_type.element_type()) { // The component's type does not match the matrix's column type. @@ -219,7 +225,8 @@ } bool TransformationCompositeConstruct::ComponentsForStructConstructionAreOK( - opt::IRContext* context, const opt::analysis::Struct& struct_type) const { + opt::IRContext* ir_context, + const opt::analysis::Struct& struct_type) const { if (static_cast<uint32_t>(message_.component().size()) != struct_type.element_types().size()) { // The number of components must match the number of fields of the struct. @@ -229,14 +236,14 @@ // matches the associated field type. for (uint32_t field_index = 0; field_index < struct_type.element_types().size(); field_index++) { - auto inst = - context->get_def_use_mgr()->GetDef(message_.component()[field_index]); + auto inst = ir_context->get_def_use_mgr()->GetDef( + message_.component()[field_index]); if (inst == nullptr || !inst->type_id()) { // The component does not correspond to an instruction with a result // type. return false; } - auto component_type = context->get_type_mgr()->GetType(inst->type_id()); + auto component_type = ir_context->get_type_mgr()->GetType(inst->type_id()); assert(component_type); if (component_type != struct_type.element_types()[field_index]) { // The component's type does not match the corresponding field type. @@ -247,17 +254,18 @@ } bool TransformationCompositeConstruct::ComponentsForVectorConstructionAreOK( - opt::IRContext* context, const opt::analysis::Vector& vector_type) const { + opt::IRContext* ir_context, + const opt::analysis::Vector& vector_type) const { uint32_t base_element_count = 0; auto element_type = vector_type.element_type(); for (auto& component_id : message_.component()) { - auto inst = context->get_def_use_mgr()->GetDef(component_id); + auto inst = ir_context->get_def_use_mgr()->GetDef(component_id); if (inst == nullptr || !inst->type_id()) { // The component does not correspond to an instruction with a result // type. return false; } - auto component_type = context->get_type_mgr()->GetType(inst->type_id()); + auto component_type = ir_context->get_type_mgr()->GetType(inst->type_id()); assert(component_type); if (component_type == element_type) { base_element_count++;
diff --git a/source/fuzz/transformation_composite_construct.h b/source/fuzz/transformation_composite_construct.h index 5369c4c..2e55e70 100644 --- a/source/fuzz/transformation_composite_construct.h +++ b/source/fuzz/transformation_composite_construct.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_COMPOSITE_CONSTRUCT_H_ #define SOURCE_FUZZ_TRANSFORMATION_COMPOSITE_CONSTRUCT_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -49,15 +49,17 @@ // before 'inst'. // - Each element of |message_.component| must be available directly before // 'inst'. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Inserts a new OpCompositeConstruct instruction, with id // |message_.fresh_id|, directly before the instruction identified by // |message_.base_instruction_id| and |message_.offset|. The instruction // creates a composite of type |message_.composite_type_id| using the ids of // |message_.component|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; @@ -65,19 +67,22 @@ // Helper to decide whether the components of the transformation are suitable // for constructing an array of the given type. bool ComponentsForArrayConstructionAreOK( - opt::IRContext* context, const opt::analysis::Array& array_type) const; + opt::IRContext* ir_context, const opt::analysis::Array& array_type) const; // Similar, but for matrices. bool ComponentsForMatrixConstructionAreOK( - opt::IRContext* context, const opt::analysis::Matrix& matrix_type) const; + opt::IRContext* ir_context, + const opt::analysis::Matrix& matrix_type) const; // Similar, but for structs. bool ComponentsForStructConstructionAreOK( - opt::IRContext* context, const opt::analysis::Struct& struct_type) const; + opt::IRContext* ir_context, + const opt::analysis::Struct& struct_type) const; // Similar, but for vectors. bool ComponentsForVectorConstructionAreOK( - opt::IRContext* context, const opt::analysis::Vector& vector_type) const; + opt::IRContext* ir_context, + const opt::analysis::Vector& vector_type) const; protobufs::TransformationCompositeConstruct message_; };
diff --git a/source/fuzz/transformation_composite_extract.cpp b/source/fuzz/transformation_composite_extract.cpp index 5d3a386..3dc3953 100644 --- a/source/fuzz/transformation_composite_extract.cpp +++ b/source/fuzz/transformation_composite_extract.cpp
@@ -40,24 +40,23 @@ } bool TransformationCompositeExtract::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } auto instruction_to_insert_before = - FindInstruction(message_.instruction_to_insert_before(), context); + FindInstruction(message_.instruction_to_insert_before(), ir_context); if (!instruction_to_insert_before) { return false; } auto composite_instruction = - context->get_def_use_mgr()->GetDef(message_.composite_id()); + ir_context->get_def_use_mgr()->GetDef(message_.composite_id()); if (!composite_instruction) { return false; } - if (auto block = context->get_instr_block(composite_instruction)) { + if (auto block = ir_context->get_instr_block(composite_instruction)) { if (composite_instruction == instruction_to_insert_before || - !context->GetDominatorAnalysis(block->GetParent()) + !ir_context->GetDominatorAnalysis(block->GetParent()) ->Dominates(composite_instruction, instruction_to_insert_before)) { return false; } @@ -66,7 +65,7 @@ "An instruction in a block cannot have a result id but no type id."); auto composite_type = - context->get_type_mgr()->GetType(composite_instruction->type_id()); + ir_context->get_type_mgr()->GetType(composite_instruction->type_id()); if (!composite_type) { return false; } @@ -76,30 +75,33 @@ return false; } - return fuzzerutil::WalkCompositeTypeIndices( - context, composite_instruction->type_id(), message_.index()) != 0; + return fuzzerutil::WalkCompositeTypeIndices(ir_context, + composite_instruction->type_id(), + message_.index()) != 0; } void TransformationCompositeExtract::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const { + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { opt::Instruction::OperandList extract_operands; extract_operands.push_back({SPV_OPERAND_TYPE_ID, {message_.composite_id()}}); for (auto an_index : message_.index()) { extract_operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {an_index}}); } auto composite_instruction = - context->get_def_use_mgr()->GetDef(message_.composite_id()); + ir_context->get_def_use_mgr()->GetDef(message_.composite_id()); auto extracted_type = fuzzerutil::WalkCompositeTypeIndices( - context, composite_instruction->type_id(), message_.index()); + ir_context, composite_instruction->type_id(), message_.index()); - FindInstruction(message_.instruction_to_insert_before(), context) + FindInstruction(message_.instruction_to_insert_before(), ir_context) ->InsertBefore(MakeUnique<opt::Instruction>( - context, SpvOpCompositeExtract, extracted_type, message_.fresh_id(), - extract_operands)); + ir_context, SpvOpCompositeExtract, extracted_type, + message_.fresh_id(), extract_operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); // Add the fact that the id storing the extracted element is synonymous with // the index into the structure. @@ -111,8 +113,9 @@ MakeDataDescriptor(message_.composite_id(), std::move(indices)); protobufs::DataDescriptor data_descriptor_for_result_id = MakeDataDescriptor(message_.fresh_id(), {}); - fact_manager->AddFactDataSynonym(data_descriptor_for_extracted_element, - data_descriptor_for_result_id, context); + transformation_context->GetFactManager()->AddFactDataSynonym( + data_descriptor_for_extracted_element, data_descriptor_for_result_id, + ir_context); } protobufs::Transformation TransformationCompositeExtract::ToMessage() const {
diff --git a/source/fuzz/transformation_composite_extract.h b/source/fuzz/transformation_composite_extract.h index c4c9278..8f52d22 100644 --- a/source/fuzz/transformation_composite_extract.h +++ b/source/fuzz/transformation_composite_extract.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_COMPOSITE_EXTRACT_H_ #define SOURCE_FUZZ_TRANSFORMATION_COMPOSITE_EXTRACT_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -41,15 +41,17 @@ // - |message_.index| must be a suitable set of indices for // |message_.composite_id|, i.e. it must be possible to follow this chain // of indices to reach a sub-object of |message_.composite_id| - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an OpCompositeConstruct instruction before the instruction identified // by |message_.instruction_to_insert_before|, that extracts from // |message_.composite_id| via indices |message_.index| into // |message_.fresh_id|. Generates a data synonym fact relating // |message_.fresh_id| to the extracted element. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_compute_data_synonym_fact_closure.cpp b/source/fuzz/transformation_compute_data_synonym_fact_closure.cpp new file mode 100644 index 0000000..ff3ba3c --- /dev/null +++ b/source/fuzz/transformation_compute_data_synonym_fact_closure.cpp
@@ -0,0 +1,52 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/transformation_compute_data_synonym_fact_closure.h" + +namespace spvtools { +namespace fuzz { + +TransformationComputeDataSynonymFactClosure:: + TransformationComputeDataSynonymFactClosure( + const spvtools::fuzz::protobufs:: + TransformationComputeDataSynonymFactClosure& message) + : message_(message) {} + +TransformationComputeDataSynonymFactClosure:: + TransformationComputeDataSynonymFactClosure( + uint32_t maximum_equivalence_class_size) { + message_.set_maximum_equivalence_class_size(maximum_equivalence_class_size); +} + +bool TransformationComputeDataSynonymFactClosure::IsApplicable( + opt::IRContext* /*unused*/, const TransformationContext& /*unused*/) const { + return true; +} + +void TransformationComputeDataSynonymFactClosure::Apply( + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { + transformation_context->GetFactManager()->ComputeClosureOfFacts( + ir_context, message_.maximum_equivalence_class_size()); +} + +protobufs::Transformation +TransformationComputeDataSynonymFactClosure::ToMessage() const { + protobufs::Transformation result; + *result.mutable_compute_data_synonym_fact_closure() = message_; + return result; +} + +} // namespace fuzz +} // namespace spvtools
diff --git a/source/fuzz/transformation_compute_data_synonym_fact_closure.h b/source/fuzz/transformation_compute_data_synonym_fact_closure.h new file mode 100644 index 0000000..eab43ff --- /dev/null +++ b/source/fuzz/transformation_compute_data_synonym_fact_closure.h
@@ -0,0 +1,53 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_FUZZ_TRANSFORMATION_COMPUTE_DATA_SYNONYM_FACT_CLOSURE_H_ +#define SOURCE_FUZZ_TRANSFORMATION_COMPUTE_DATA_SYNONYM_FACT_CLOSURE_H_ + +#include "source/fuzz/protobufs/spirvfuzz_protobufs.h" +#include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace fuzz { + +class TransformationComputeDataSynonymFactClosure : public Transformation { + public: + explicit TransformationComputeDataSynonymFactClosure( + const protobufs::TransformationComputeDataSynonymFactClosure& message); + + explicit TransformationComputeDataSynonymFactClosure( + uint32_t maximum_equivalence_class_size); + + // This transformation is trivially applicable. + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; + + // Forces the fact manager to compute a closure of data synonym facts, so that + // facts implied by existing facts are deduced. + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; + + protobufs::Transformation ToMessage() const override; + + private: + protobufs::TransformationComputeDataSynonymFactClosure message_; +}; + +} // namespace fuzz +} // namespace spvtools + +#endif // SOURCE_FUZZ_TRANSFORMATION_COMPUTE_DATA_SYNONYM_FACT_CLOSURE_H_
diff --git a/source/fuzz/transformation_context.cpp b/source/fuzz/transformation_context.cpp new file mode 100644 index 0000000..9c8a90f --- /dev/null +++ b/source/fuzz/transformation_context.cpp
@@ -0,0 +1,29 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/transformation_context.h" + +namespace spvtools { +namespace fuzz { + +TransformationContext::TransformationContext( + FactManager* transformation_context, + spv_validator_options validator_options) + : fact_manager_(transformation_context), + validator_options_(validator_options) {} + +TransformationContext::~TransformationContext() = default; + +} // namespace fuzz +} // namespace spvtools
diff --git a/source/fuzz/transformation_context.h b/source/fuzz/transformation_context.h new file mode 100644 index 0000000..37e15a2 --- /dev/null +++ b/source/fuzz/transformation_context.h
@@ -0,0 +1,56 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_FUZZ_TRANSFORMATION_CONTEXT_H_ +#define SOURCE_FUZZ_TRANSFORMATION_CONTEXT_H_ + +#include "source/fuzz/fact_manager.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace fuzz { + +// Encapsulates all information that is required to inform how to apply a +// transformation to a module. +class TransformationContext { + public: + // Constructs a transformation context with a given fact manager and validator + // options. + TransformationContext(FactManager* fact_manager, + spv_validator_options validator_options); + + ~TransformationContext(); + + FactManager* GetFactManager() { return fact_manager_; } + + const FactManager* GetFactManager() const { return fact_manager_; } + + spv_validator_options GetValidatorOptions() const { + return validator_options_; + } + + private: + // Manages facts that inform whether transformations can be applied, and that + // are produced by applying transformations. + FactManager* fact_manager_; + + // Options to control validation when deciding whether transformations can be + // applied. + spv_validator_options validator_options_; +}; + +} // namespace fuzz +} // namespace spvtools + +#endif // SOURCE_FUZZ_TRANSFORMATION_CONTEXT_H_
diff --git a/source/fuzz/transformation_copy_object.cpp b/source/fuzz/transformation_copy_object.cpp index bfdced3..7b5b5c9 100644 --- a/source/fuzz/transformation_copy_object.cpp +++ b/source/fuzz/transformation_copy_object.cpp
@@ -38,22 +38,22 @@ } bool TransformationCopyObject::IsApplicable( - opt::IRContext* context, const FactManager& /*fact_manager*/) const { - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { // We require the id for the object copy to be unused. return false; } // The id of the object to be copied must exist - auto object_inst = context->get_def_use_mgr()->GetDef(message_.object()); + auto object_inst = ir_context->get_def_use_mgr()->GetDef(message_.object()); if (!object_inst) { return false; } - if (!fuzzerutil::CanMakeSynonymOf(context, object_inst)) { + if (!fuzzerutil::CanMakeSynonymOf(ir_context, object_inst)) { return false; } auto insert_before = - FindInstruction(message_.instruction_to_insert_before(), context); + FindInstruction(message_.instruction_to_insert_before(), ir_context); if (!insert_before) { // The instruction before which the copy should be inserted was not found. return false; @@ -66,17 +66,18 @@ // |message_object| must be available directly before the point where we want // to add the copy. - return fuzzerutil::IdIsAvailableBeforeInstruction(context, insert_before, + return fuzzerutil::IdIsAvailableBeforeInstruction(ir_context, insert_before, message_.object()); } -void TransformationCopyObject::Apply(opt::IRContext* context, - FactManager* fact_manager) const { - auto object_inst = context->get_def_use_mgr()->GetDef(message_.object()); +void TransformationCopyObject::Apply( + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { + auto object_inst = ir_context->get_def_use_mgr()->GetDef(message_.object()); assert(object_inst && "The object to be copied must exist."); auto insert_before_inst = - FindInstruction(message_.instruction_to_insert_before(), context); - auto destination_block = context->get_instr_block(insert_before_inst); + FindInstruction(message_.instruction_to_insert_before(), ir_context); + auto destination_block = ir_context->get_instr_block(insert_before_inst); assert(destination_block && "The base instruction must be in a block."); auto insert_before = fuzzerutil::GetIteratorForInstruction( destination_block, insert_before_inst); @@ -86,18 +87,21 @@ opt::Instruction::OperandList operands = { {SPV_OPERAND_TYPE_ID, {message_.object()}}}; insert_before->InsertBefore(MakeUnique<opt::Instruction>( - context, SpvOp::SpvOpCopyObject, object_inst->type_id(), + ir_context, SpvOp::SpvOpCopyObject, object_inst->type_id(), message_.fresh_id(), operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); - fact_manager->AddFactDataSynonym(MakeDataDescriptor(message_.object(), {}), - MakeDataDescriptor(message_.fresh_id(), {}), - context); + transformation_context->GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(message_.object(), {}), + MakeDataDescriptor(message_.fresh_id(), {}), ir_context); - if (fact_manager->PointeeValueIsIrrelevant(message_.object())) { - fact_manager->AddFactValueOfPointeeIsIrrelevant(message_.fresh_id()); + if (transformation_context->GetFactManager()->PointeeValueIsIrrelevant( + message_.object())) { + transformation_context->GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + message_.fresh_id()); } }
diff --git a/source/fuzz/transformation_copy_object.h b/source/fuzz/transformation_copy_object.h index 9e9c26a..80d57ae 100644 --- a/source/fuzz/transformation_copy_object.h +++ b/source/fuzz/transformation_copy_object.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_COPY_OBJECT_H_ #define SOURCE_FUZZ_TRANSFORMATION_COPY_OBJECT_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -49,19 +49,21 @@ // - |message_.object| must be available directly before 'inst'. // - |message_.object| must not be a null pointer or undefined pointer (so as // to make it legal to load from copied pointers). - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // - A new instruction, // %|message_.fresh_id| = OpCopyObject %ty %|message_.object| // is added directly before the instruction at |message_.insert_after_id| + // |message_|.offset, where %ty is the type of |message_.object|. // - The fact that |message_.fresh_id| and |message_.object| are synonyms - // is added to |fact_manager|. + // is added to the fact manager in |transformation_context|. // - If |message_.object| is a pointer whose pointee value is known to be - // irrelevant, the analogous fact is added to |fact_manager| about - // |message_.fresh_id|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + // irrelevant, the analogous fact is added to the fact manager in + // |transformation_context| about |message_.fresh_id|. + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_equation_instruction.cpp b/source/fuzz/transformation_equation_instruction.cpp index 21b67f6..5c31417 100644 --- a/source/fuzz/transformation_equation_instruction.cpp +++ b/source/fuzz/transformation_equation_instruction.cpp
@@ -37,40 +37,40 @@ } bool TransformationEquationInstruction::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The result id must be fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // The instruction to insert before must exist. auto insert_before = - FindInstruction(message_.instruction_to_insert_before(), context); + FindInstruction(message_.instruction_to_insert_before(), ir_context); if (!insert_before) { return false; } // The input ids must all exist, not be OpUndef, and be available before this // instruction. for (auto id : message_.in_operand_id()) { - auto inst = context->get_def_use_mgr()->GetDef(id); + auto inst = ir_context->get_def_use_mgr()->GetDef(id); if (!inst) { return false; } if (inst->opcode() == SpvOpUndef) { return false; } - if (!fuzzerutil::IdIsAvailableBeforeInstruction(context, insert_before, + if (!fuzzerutil::IdIsAvailableBeforeInstruction(ir_context, insert_before, id)) { return false; } } - return MaybeGetResultType(context) != 0; + return MaybeGetResultType(ir_context) != 0; } void TransformationEquationInstruction::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const { - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); opt::Instruction::OperandList in_operands; std::vector<uint32_t> rhs_id; @@ -79,16 +79,16 @@ rhs_id.push_back(id); } - FindInstruction(message_.instruction_to_insert_before(), context) + FindInstruction(message_.instruction_to_insert_before(), ir_context) ->InsertBefore(MakeUnique<opt::Instruction>( - context, static_cast<SpvOp>(message_.opcode()), - MaybeGetResultType(context), message_.fresh_id(), in_operands)); + ir_context, static_cast<SpvOp>(message_.opcode()), + MaybeGetResultType(ir_context), message_.fresh_id(), in_operands)); - context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); - fact_manager->AddFactIdEquation(message_.fresh_id(), - static_cast<SpvOp>(message_.opcode()), rhs_id, - context); + transformation_context->GetFactManager()->AddFactIdEquation( + message_.fresh_id(), static_cast<SpvOp>(message_.opcode()), rhs_id, + ir_context); } protobufs::Transformation TransformationEquationInstruction::ToMessage() const { @@ -98,7 +98,7 @@ } uint32_t TransformationEquationInstruction::MaybeGetResultType( - opt::IRContext* context) const { + opt::IRContext* ir_context) const { switch (static_cast<SpvOp>(message_.opcode())) { case SpvOpIAdd: case SpvOpISub: { @@ -108,13 +108,13 @@ uint32_t first_operand_width = 0; uint32_t first_operand_type_id = 0; for (uint32_t index = 0; index < 2; index++) { - auto operand_inst = - context->get_def_use_mgr()->GetDef(message_.in_operand_id(index)); + auto operand_inst = ir_context->get_def_use_mgr()->GetDef( + message_.in_operand_id(index)); if (!operand_inst || !operand_inst->type_id()) { return 0; } auto operand_type = - context->get_type_mgr()->GetType(operand_inst->type_id()); + ir_context->get_type_mgr()->GetType(operand_inst->type_id()); if (!(operand_type->AsInteger() || (operand_type->AsVector() && operand_type->AsVector()->element_type()->AsInteger()))) { @@ -144,12 +144,12 @@ return 0; } auto operand_inst = - context->get_def_use_mgr()->GetDef(message_.in_operand_id(0)); + ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0)); if (!operand_inst || !operand_inst->type_id()) { return 0; } auto operand_type = - context->get_type_mgr()->GetType(operand_inst->type_id()); + ir_context->get_type_mgr()->GetType(operand_inst->type_id()); if (!(operand_type->AsBool() || (operand_type->AsVector() && operand_type->AsVector()->element_type()->AsBool()))) { @@ -162,12 +162,12 @@ return 0; } auto operand_inst = - context->get_def_use_mgr()->GetDef(message_.in_operand_id(0)); + ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0)); if (!operand_inst || !operand_inst->type_id()) { return 0; } auto operand_type = - context->get_type_mgr()->GetType(operand_inst->type_id()); + ir_context->get_type_mgr()->GetType(operand_inst->type_id()); if (!(operand_type->AsInteger() || (operand_type->AsVector() && operand_type->AsVector()->element_type()->AsInteger()))) {
diff --git a/source/fuzz/transformation_equation_instruction.h b/source/fuzz/transformation_equation_instruction.h index 2456ba5..7eec9c6 100644 --- a/source/fuzz/transformation_equation_instruction.h +++ b/source/fuzz/transformation_equation_instruction.h
@@ -17,9 +17,9 @@ #include <vector> -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -44,8 +44,9 @@ // equations, the types of the ids in |message_.in_operand_id| must be // suitable for use with this opcode, and the module must contain an // appropriate result type id. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an instruction to the module, right before // |message_.instruction_to_insert_before|, of the form: @@ -56,7 +57,8 @@ // compatible with the opcode and input operands. // // The fact manager is also updated to inform it of this equation fact. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; @@ -65,7 +67,7 @@ // in |message_.in_operand_id| are compatible, and that the module contains // an appropriate result type id. If all is well, the result type id is // returned. Otherwise, 0 is returned. - uint32_t MaybeGetResultType(opt::IRContext* context) const; + uint32_t MaybeGetResultType(opt::IRContext* ir_context) const; protobufs::TransformationEquationInstruction message_; };
diff --git a/source/fuzz/transformation_function_call.cpp b/source/fuzz/transformation_function_call.cpp index cea8537..432634d 100644 --- a/source/fuzz/transformation_function_call.cpp +++ b/source/fuzz/transformation_function_call.cpp
@@ -39,25 +39,26 @@ } bool TransformationFunctionCall::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& fact_manager) const { + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const { // The result id must be fresh - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // The function must exist - auto callee_inst = context->get_def_use_mgr()->GetDef(message_.callee_id()); + auto callee_inst = + ir_context->get_def_use_mgr()->GetDef(message_.callee_id()); if (!callee_inst || callee_inst->opcode() != SpvOpFunction) { return false; } // The function must not be an entry point - if (fuzzerutil::FunctionIsEntryPoint(context, message_.callee_id())) { + if (fuzzerutil::FunctionIsEntryPoint(ir_context, message_.callee_id())) { return false; } - auto callee_type_inst = context->get_def_use_mgr()->GetDef( + auto callee_type_inst = ir_context->get_def_use_mgr()->GetDef( callee_inst->GetSingleWordInOperand(1)); assert(callee_type_inst->opcode() == SpvOpTypeFunction && "Bad function type."); @@ -73,7 +74,7 @@ // The instruction descriptor must refer to a position where it is valid to // insert the call auto insert_before = - FindInstruction(message_.instruction_to_insert_before(), context); + FindInstruction(message_.instruction_to_insert_before(), ir_context); if (!insert_before) { return false; } @@ -82,13 +83,15 @@ return false; } - auto block = context->get_instr_block(insert_before); + auto block = ir_context->get_instr_block(insert_before); auto enclosing_function = block->GetParent(); // If the block is not dead, the function must be livesafe - bool block_is_dead = fact_manager.BlockIsDead(block->id()); + bool block_is_dead = + transformation_context.GetFactManager()->BlockIsDead(block->id()); if (!block_is_dead && - !fact_manager.FunctionIsLivesafe(message_.callee_id())) { + !transformation_context.GetFactManager()->FunctionIsLivesafe( + message_.callee_id())) { return false; } @@ -98,7 +101,7 @@ arg_index < static_cast<uint32_t>(message_.argument_id().size()); arg_index++) { opt::Instruction* arg_inst = - context->get_def_use_mgr()->GetDef(message_.argument_id(arg_index)); + ir_context->get_def_use_mgr()->GetDef(message_.argument_id(arg_index)); if (!arg_inst) { // The given argument does not correspond to an instruction. return false; @@ -112,7 +115,7 @@ return false; } opt::Instruction* arg_type_inst = - context->get_def_use_mgr()->GetDef(arg_inst->type_id()); + ir_context->get_def_use_mgr()->GetDef(arg_inst->type_id()); if (arg_type_inst->opcode() == SpvOpTypePointer) { switch (arg_inst->opcode()) { case SpvOpFunctionParameter: @@ -124,7 +127,8 @@ return false; } if (!block_is_dead && - !fact_manager.PointeeValueIsIrrelevant(arg_inst->result_id())) { + !transformation_context.GetFactManager()->PointeeValueIsIrrelevant( + arg_inst->result_id())) { // This is not a dead block, so pointer parameters passed to the called // function might really have their contents modified. We thus require // such pointers to be to arbitrary-valued variables, which this is not. @@ -134,7 +138,7 @@ // The argument id needs to be available (according to dominance rules) at // the point where the call will occur. - if (!fuzzerutil::IdIsAvailableBeforeInstruction(context, insert_before, + if (!fuzzerutil::IdIsAvailableBeforeInstruction(ir_context, insert_before, arg_inst->result_id())) { return false; } @@ -146,19 +150,19 @@ return false; } // Ensure the call would not lead to indirect recursion. - return !CallGraph(context) + return !CallGraph(ir_context) .GetIndirectCallees(message_.callee_id()) .count(block->GetParent()->result_id()); } void TransformationFunctionCall::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { // Update the module's bound to reflect the fresh id for the result of the // function call. - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // Get the return type of the function being called. uint32_t return_type = - context->get_def_use_mgr()->GetDef(message_.callee_id())->type_id(); + ir_context->get_def_use_mgr()->GetDef(message_.callee_id())->type_id(); // Populate the operands to the call instruction, with the function id and the // arguments. opt::Instruction::OperandList operands; @@ -167,12 +171,12 @@ operands.push_back({SPV_OPERAND_TYPE_ID, {arg}}); } // Insert the function call before the instruction specified in the message. - FindInstruction(message_.instruction_to_insert_before(), context) - ->InsertBefore( - MakeUnique<opt::Instruction>(context, SpvOpFunctionCall, return_type, - message_.fresh_id(), operands)); + FindInstruction(message_.instruction_to_insert_before(), ir_context) + ->InsertBefore(MakeUnique<opt::Instruction>( + ir_context, SpvOpFunctionCall, return_type, message_.fresh_id(), + operands)); // Invalidate all analyses since we have changed the module. - context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); } protobufs::Transformation TransformationFunctionCall::ToMessage() const {
diff --git a/source/fuzz/transformation_function_call.h b/source/fuzz/transformation_function_call.h index a9ae5be..4ad7db1 100644 --- a/source/fuzz/transformation_function_call.h +++ b/source/fuzz/transformation_function_call.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_FUNCTION_CALL_H_ #define SOURCE_FUZZ_TRANSFORMATION_FUNCTION_CALL_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -44,14 +44,16 @@ // - If the insertion point is not in a dead block then |message_function_id| // must refer to a livesafe function, and every pointer argument in // |message_.arg_id| must refer to an arbitrary-valued variable - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an instruction of the form: // |fresh_id| = OpFunctionCall %type |callee_id| |arg_id...| // before |instruction_to_insert_before|, where %type is the return type of // |callee_id|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_load.cpp b/source/fuzz/transformation_load.cpp index 4cba37d..a260c33 100644 --- a/source/fuzz/transformation_load.cpp +++ b/source/fuzz/transformation_load.cpp
@@ -34,20 +34,19 @@ } bool TransformationLoad::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The result id must be fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // The pointer must exist and have a type. - auto pointer = context->get_def_use_mgr()->GetDef(message_.pointer_id()); + auto pointer = ir_context->get_def_use_mgr()->GetDef(message_.pointer_id()); if (!pointer || !pointer->type_id()) { return false; } // The type must indeed be a pointer type. - auto pointer_type = context->get_def_use_mgr()->GetDef(pointer->type_id()); + auto pointer_type = ir_context->get_def_use_mgr()->GetDef(pointer->type_id()); assert(pointer_type && "Type id must be defined."); if (pointer_type->opcode() != SpvOpTypePointer) { return false; @@ -65,7 +64,7 @@ // Determine which instruction we should be inserting before. auto insert_before = - FindInstruction(message_.instruction_to_insert_before(), context); + FindInstruction(message_.instruction_to_insert_before(), ir_context); // It must exist, ... if (!insert_before) { return false; @@ -76,21 +75,21 @@ } // The pointer needs to be available at the insertion point. - return fuzzerutil::IdIsAvailableBeforeInstruction(context, insert_before, + return fuzzerutil::IdIsAvailableBeforeInstruction(ir_context, insert_before, message_.pointer_id()); } -void TransformationLoad::Apply(opt::IRContext* context, - spvtools::fuzz::FactManager* /*unused*/) const { +void TransformationLoad::Apply(opt::IRContext* ir_context, + TransformationContext* /*unused*/) const { uint32_t result_type = fuzzerutil::GetPointeeTypeIdFromPointerType( - context, fuzzerutil::GetTypeId(context, message_.pointer_id())); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); - FindInstruction(message_.instruction_to_insert_before(), context) + ir_context, fuzzerutil::GetTypeId(ir_context, message_.pointer_id())); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); + FindInstruction(message_.instruction_to_insert_before(), ir_context) ->InsertBefore(MakeUnique<opt::Instruction>( - context, SpvOpLoad, result_type, message_.fresh_id(), + ir_context, SpvOpLoad, result_type, message_.fresh_id(), opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {message_.pointer_id()}}}))); - context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); } protobufs::Transformation TransformationLoad::ToMessage() const {
diff --git a/source/fuzz/transformation_load.h b/source/fuzz/transformation_load.h index ff99016..4c7c00b 100644 --- a/source/fuzz/transformation_load.h +++ b/source/fuzz/transformation_load.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_LOAD_H_ #define SOURCE_FUZZ_TRANSFORMATION_LOAD_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -37,15 +37,17 @@ // - |message_.instruction_to_insert_before| must identify an instruction // before which it is valid to insert an OpLoad, and where // |message_.pointer_id| is available (according to dominance rules) - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an instruction of the form: // |message_.fresh_id| = OpLoad %type |message_.pointer_id| // before the instruction identified by // |message_.instruction_to_insert_before|, where %type is the pointer's // pointee type. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_merge_blocks.cpp b/source/fuzz/transformation_merge_blocks.cpp index 316e80d..68ac092 100644 --- a/source/fuzz/transformation_merge_blocks.cpp +++ b/source/fuzz/transformation_merge_blocks.cpp
@@ -29,40 +29,41 @@ } bool TransformationMergeBlocks::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { - auto second_block = fuzzerutil::MaybeFindBlock(context, message_.block_id()); + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { + auto second_block = + fuzzerutil::MaybeFindBlock(ir_context, message_.block_id()); // The given block must exist. if (!second_block) { return false; } // The block must have just one predecessor. - auto predecessors = context->cfg()->preds(second_block->id()); + auto predecessors = ir_context->cfg()->preds(second_block->id()); if (predecessors.size() != 1) { return false; } - auto first_block = context->cfg()->block(predecessors.at(0)); + auto first_block = ir_context->cfg()->block(predecessors.at(0)); - return opt::blockmergeutil::CanMergeWithSuccessor(context, first_block); + return opt::blockmergeutil::CanMergeWithSuccessor(ir_context, first_block); } -void TransformationMergeBlocks::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { - auto second_block = fuzzerutil::MaybeFindBlock(context, message_.block_id()); - auto first_block = - context->cfg()->block(context->cfg()->preds(second_block->id()).at(0)); +void TransformationMergeBlocks::Apply(opt::IRContext* ir_context, + TransformationContext* /*unused*/) const { + auto second_block = + fuzzerutil::MaybeFindBlock(ir_context, message_.block_id()); + auto first_block = ir_context->cfg()->block( + ir_context->cfg()->preds(second_block->id()).at(0)); auto function = first_block->GetParent(); // We need an iterator pointing to the predecessor, hence the loop. for (auto bi = function->begin(); bi != function->end(); ++bi) { if (bi->id() == first_block->id()) { - assert(opt::blockmergeutil::CanMergeWithSuccessor(context, &*bi) && + assert(opt::blockmergeutil::CanMergeWithSuccessor(ir_context, &*bi) && "Because 'Apply' should only be invoked if 'IsApplicable' holds, " "it must be possible to merge |bi| with its successor."); - opt::blockmergeutil::MergeWithSuccessor(context, function, bi); + opt::blockmergeutil::MergeWithSuccessor(ir_context, function, bi); // Invalidate all analyses, since we have changed the module // significantly. - context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); return; } }
diff --git a/source/fuzz/transformation_merge_blocks.h b/source/fuzz/transformation_merge_blocks.h index 86216db..1dc16d2 100644 --- a/source/fuzz/transformation_merge_blocks.h +++ b/source/fuzz/transformation_merge_blocks.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_MERGE_BLOCKS_H_ #define SOURCE_FUZZ_TRANSFORMATION_MERGE_BLOCKS_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -35,12 +35,14 @@ // - b must be the sole successor of a // - Replacing a with the merge of a and b (and removing b) must lead to a // valid module - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // The contents of b are merged into a, and a's terminator is replaced with // the terminator of b. Block b is removed from the module. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_move_block_down.cpp b/source/fuzz/transformation_move_block_down.cpp index f181855..6c71ab7 100644 --- a/source/fuzz/transformation_move_block_down.cpp +++ b/source/fuzz/transformation_move_block_down.cpp
@@ -28,10 +28,10 @@ } bool TransformationMoveBlockDown::IsApplicable( - opt::IRContext* context, const FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // Go through every block in every function, looking for a block whose id // matches that of the block we want to consider moving down. - for (auto& function : *context->module()) { + for (auto& function : *ir_context->module()) { for (auto block_it = function.begin(); block_it != function.end(); ++block_it) { if (block_it->id() == message_.block_id()) { @@ -43,7 +43,7 @@ } // Record the block we would like to consider moving down. opt::BasicBlock* block_matching_id = &*block_it; - if (!context->GetDominatorAnalysis(&function)->IsReachable( + if (!ir_context->GetDominatorAnalysis(&function)->IsReachable( block_matching_id)) { // The block is not reachable. We are not allowed to move it down. return false; @@ -60,7 +60,7 @@ opt::BasicBlock* next_block_in_program_order = &*block_it; // We can move the block of interest down if and only if it does not // dominate the block that comes next. - return !context->GetDominatorAnalysis(&function)->Dominates( + return !ir_context->GetDominatorAnalysis(&function)->Dominates( block_matching_id, next_block_in_program_order); } } @@ -71,11 +71,11 @@ return false; } -void TransformationMoveBlockDown::Apply(opt::IRContext* context, - FactManager* /*unused*/) const { +void TransformationMoveBlockDown::Apply( + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { // Go through every block in every function, looking for a block whose id // matches that of the block we want to move down. - for (auto& function : *context->module()) { + for (auto& function : *ir_context->module()) { for (auto block_it = function.begin(); block_it != function.end(); ++block_it) { if (block_it->id() == message_.block_id()) { @@ -87,7 +87,7 @@ // For performance, it is vital to keep the dominator analysis valid // (which due to https://github.com/KhronosGroup/SPIRV-Tools/issues/2889 // requires keeping the CFG analysis valid). - context->InvalidateAnalysesExceptFor( + ir_context->InvalidateAnalysesExceptFor( opt::IRContext::Analysis::kAnalysisDefUse | opt::IRContext::Analysis::kAnalysisCFG | opt::IRContext::Analysis::kAnalysisDominatorAnalysis);
diff --git a/source/fuzz/transformation_move_block_down.h b/source/fuzz/transformation_move_block_down.h index fd1584a..7551c38 100644 --- a/source/fuzz/transformation_move_block_down.h +++ b/source/fuzz/transformation_move_block_down.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_MOVE_BLOCK_DOWN_H_ #define SOURCE_FUZZ_TRANSFORMATION_MOVE_BLOCK_DOWN_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -35,12 +35,14 @@ // in a function. // - b must not dominate the block that follows it in program order. // - b must be reachable. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // The block with id |message_.block_id| is moved down; i.e. the program order // between it and the block that follows it is swapped. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_outline_function.cpp b/source/fuzz/transformation_outline_function.cpp index 01d1c45..05fd923 100644 --- a/source/fuzz/transformation_outline_function.cpp +++ b/source/fuzz/transformation_outline_function.cpp
@@ -70,72 +70,71 @@ } bool TransformationOutlineFunction::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { std::set<uint32_t> ids_used_by_this_transformation; // The various new ids used by the transformation must be fresh and distinct. if (!CheckIdIsFreshAndNotUsedByThisTransformation( - message_.new_function_struct_return_type_id(), context, + message_.new_function_struct_return_type_id(), ir_context, &ids_used_by_this_transformation)) { return false; } if (!CheckIdIsFreshAndNotUsedByThisTransformation( - message_.new_function_type_id(), context, + message_.new_function_type_id(), ir_context, &ids_used_by_this_transformation)) { return false; } if (!CheckIdIsFreshAndNotUsedByThisTransformation( - message_.new_function_id(), context, + message_.new_function_id(), ir_context, &ids_used_by_this_transformation)) { return false; } if (!CheckIdIsFreshAndNotUsedByThisTransformation( - message_.new_function_region_entry_block(), context, + message_.new_function_region_entry_block(), ir_context, &ids_used_by_this_transformation)) { return false; } if (!CheckIdIsFreshAndNotUsedByThisTransformation( - message_.new_caller_result_id(), context, + message_.new_caller_result_id(), ir_context, &ids_used_by_this_transformation)) { return false; } if (!CheckIdIsFreshAndNotUsedByThisTransformation( - message_.new_callee_result_id(), context, + message_.new_callee_result_id(), ir_context, &ids_used_by_this_transformation)) { return false; } for (auto& pair : message_.input_id_to_fresh_id()) { if (!CheckIdIsFreshAndNotUsedByThisTransformation( - pair.second(), context, &ids_used_by_this_transformation)) { + pair.second(), ir_context, &ids_used_by_this_transformation)) { return false; } } for (auto& pair : message_.output_id_to_fresh_id()) { if (!CheckIdIsFreshAndNotUsedByThisTransformation( - pair.second(), context, &ids_used_by_this_transformation)) { + pair.second(), ir_context, &ids_used_by_this_transformation)) { return false; } } // The entry and exit block ids must indeed refer to blocks. for (auto block_id : {message_.entry_block(), message_.exit_block()}) { - auto block_label = context->get_def_use_mgr()->GetDef(block_id); + auto block_label = ir_context->get_def_use_mgr()->GetDef(block_id); if (!block_label || block_label->opcode() != SpvOpLabel) { return false; } } - auto entry_block = context->cfg()->block(message_.entry_block()); - auto exit_block = context->cfg()->block(message_.exit_block()); + auto entry_block = ir_context->cfg()->block(message_.entry_block()); + auto exit_block = ir_context->cfg()->block(message_.exit_block()); // The entry block cannot start with OpVariable - this would mean that // outlining would remove a variable from the function containing the region @@ -151,7 +150,7 @@ // For simplicity, we do not allow the exit block to be a merge block or // continue target. - if (fuzzerutil::IsMergeOrContinue(context, exit_block->id())) { + if (fuzzerutil::IsMergeOrContinue(ir_context, exit_block->id())) { return false; } @@ -169,14 +168,14 @@ // The entry block must dominate the exit block. auto dominator_analysis = - context->GetDominatorAnalysis(entry_block->GetParent()); + ir_context->GetDominatorAnalysis(entry_block->GetParent()); if (!dominator_analysis->Dominates(entry_block, exit_block)) { return false; } // The exit block must post-dominate the entry block. auto postdominator_analysis = - context->GetPostDominatorAnalysis(entry_block->GetParent()); + ir_context->GetPostDominatorAnalysis(entry_block->GetParent()); if (!postdominator_analysis->Dominates(exit_block, entry_block)) { return false; } @@ -184,8 +183,9 @@ // Find all the blocks dominated by |message_.entry_block| and post-dominated // by |message_.exit_block|. auto region_set = GetRegionBlocks( - context, entry_block = context->cfg()->block(message_.entry_block()), - exit_block = context->cfg()->block(message_.exit_block())); + ir_context, + entry_block = ir_context->cfg()->block(message_.entry_block()), + exit_block = ir_context->cfg()->block(message_.exit_block())); // Check whether |region_set| really is a single-entry single-exit region, and // also check whether structured control flow constructs and their merge @@ -198,10 +198,15 @@ for (auto& block : *entry_block->GetParent()) { if (&block == exit_block) { // It is OK (and typically expected) for the exit block of the region to - // have successors outside the region. It is also OK for the exit block - // to head a structured control flow construct - the block containing the - // call to the outlined function will end up heading this construct if - // outlining takes place. + // have successors outside the region. + // + // It is also OK for the exit block to head a selection construct: the + // block containing the call to the outlined function will end up heading + // this construct if outlining takes place. However, it is not OK for + // the exit block to head a loop construct. + if (block.GetLoopMergeInst()) { + return false; + } continue; } @@ -210,9 +215,9 @@ // see whether all of the block's successors are in the region. If they // are not, the region is not single-entry single-exit. bool all_successors_in_region = true; - block.WhileEachSuccessorLabel([&all_successors_in_region, context, + block.WhileEachSuccessorLabel([&all_successors_in_region, ir_context, ®ion_set](uint32_t successor) -> bool { - if (region_set.count(context->cfg()->block(successor)) == 0) { + if (region_set.count(ir_context->cfg()->block(successor)) == 0) { all_successors_in_region = false; return false; } @@ -227,7 +232,8 @@ // The block is a loop or selection header -- the header and its // associated merge block had better both be in the region or both be // outside the region. - auto merge_block = context->cfg()->block(merge->GetSingleWordOperand(0)); + auto merge_block = + ir_context->cfg()->block(merge->GetSingleWordOperand(0)); if (region_set.count(&block) != region_set.count(merge_block)) { return false; } @@ -236,7 +242,7 @@ if (auto loop_merge = block.GetLoopMergeInst()) { // Similar to the above, but for the continue target of a loop. auto continue_target = - context->cfg()->block(loop_merge->GetSingleWordOperand(1)); + ir_context->cfg()->block(loop_merge->GetSingleWordOperand(1)); if (continue_target != exit_block && region_set.count(&block) != region_set.count(continue_target)) { return false; @@ -248,7 +254,7 @@ // used inside the region, ... std::map<uint32_t, uint32_t> input_id_to_fresh_id_map = PairSequenceToMap(message_.input_id_to_fresh_id()); - for (auto id : GetRegionInputIds(context, region_set, exit_block)) { + for (auto id : GetRegionInputIds(ir_context, region_set, exit_block)) { // There needs to be a corresponding fresh id to be used as a function // parameter. if (input_id_to_fresh_id_map.count(id) == 0) { @@ -256,8 +262,8 @@ } // Furthermore, if the input id has pointer type it must be an OpVariable // or OpFunctionParameter. - auto input_id_inst = context->get_def_use_mgr()->GetDef(id); - if (context->get_def_use_mgr() + auto input_id_inst = ir_context->get_def_use_mgr()->GetDef(id); + if (ir_context->get_def_use_mgr() ->GetDef(input_id_inst->type_id()) ->opcode() == SpvOpTypePointer) { switch (input_id_inst->opcode()) { @@ -273,12 +279,20 @@ } // For each region output id -- i.e. every id defined inside the region but - // used outside the region -- there needs to be a corresponding fresh id that - // can hold the value for this id computed in the outlined function. + // used outside the region, ... std::map<uint32_t, uint32_t> output_id_to_fresh_id_map = PairSequenceToMap(message_.output_id_to_fresh_id()); - for (auto id : GetRegionOutputIds(context, region_set, exit_block)) { - if (output_id_to_fresh_id_map.count(id) == 0) { + for (auto id : GetRegionOutputIds(ir_context, region_set, exit_block)) { + if ( + // ... there needs to be a corresponding fresh id that can hold the + // value for this id computed in the outlined function, and ... + output_id_to_fresh_id_map.count(id) == 0 + // ... the output id must not have pointer type (to avoid creating a + // struct with pointer members to pass data out of the outlined + // function) + || ir_context->get_def_use_mgr() + ->GetDef(fuzzerutil::GetTypeId(ir_context, id)) + ->opcode() == SpvOpTypePointer) { return false; } } @@ -287,25 +301,26 @@ } void TransformationOutlineFunction::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const { + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { // The entry block for the region before outlining. auto original_region_entry_block = - context->cfg()->block(message_.entry_block()); + ir_context->cfg()->block(message_.entry_block()); // The exit block for the region before outlining. auto original_region_exit_block = - context->cfg()->block(message_.exit_block()); + ir_context->cfg()->block(message_.exit_block()); // The single-entry single-exit region defined by |message_.entry_block| and // |message_.exit_block|. std::set<opt::BasicBlock*> region_blocks = GetRegionBlocks( - context, original_region_entry_block, original_region_exit_block); + ir_context, original_region_entry_block, original_region_exit_block); // Input and output ids for the region being outlined. std::vector<uint32_t> region_input_ids = - GetRegionInputIds(context, region_blocks, original_region_exit_block); + GetRegionInputIds(ir_context, region_blocks, original_region_exit_block); std::vector<uint32_t> region_output_ids = - GetRegionOutputIds(context, region_blocks, original_region_exit_block); + GetRegionOutputIds(ir_context, region_blocks, original_region_exit_block); // Maps from input and output ids to fresh ids. std::map<uint32_t, uint32_t> input_id_to_fresh_id_map = @@ -313,14 +328,14 @@ std::map<uint32_t, uint32_t> output_id_to_fresh_id_map = PairSequenceToMap(message_.output_id_to_fresh_id()); - UpdateModuleIdBoundForFreshIds(context, input_id_to_fresh_id_map, + UpdateModuleIdBoundForFreshIds(ir_context, input_id_to_fresh_id_map, output_id_to_fresh_id_map); // Construct a map that associates each output id with its type id. std::map<uint32_t, uint32_t> output_id_to_type_id; for (uint32_t output_id : region_output_ids) { output_id_to_type_id[output_id] = - context->get_def_use_mgr()->GetDef(output_id)->type_id(); + ir_context->get_def_use_mgr()->GetDef(output_id)->type_id(); } // The region will be collapsed to a single block that calls a function @@ -331,53 +346,55 @@ // collapsed block later. std::unique_ptr<opt::Instruction> cloned_exit_block_terminator = std::unique_ptr<opt::Instruction>( - original_region_exit_block->terminator()->Clone(context)); + original_region_exit_block->terminator()->Clone(ir_context)); std::unique_ptr<opt::Instruction> cloned_exit_block_merge = original_region_exit_block->GetMergeInst() ? std::unique_ptr<opt::Instruction>( - original_region_exit_block->GetMergeInst()->Clone(context)) + original_region_exit_block->GetMergeInst()->Clone(ir_context)) : nullptr; // Make a function prototype for the outlined function, which involves // figuring out its required type. - std::unique_ptr<opt::Function> outlined_function = - PrepareFunctionPrototype(region_input_ids, region_output_ids, - input_id_to_fresh_id_map, context, fact_manager); + std::unique_ptr<opt::Function> outlined_function = PrepareFunctionPrototype( + region_input_ids, region_output_ids, input_id_to_fresh_id_map, ir_context, + transformation_context); // If the original function was livesafe, the new function should also be // livesafe. - if (fact_manager->FunctionIsLivesafe( + if (transformation_context->GetFactManager()->FunctionIsLivesafe( original_region_entry_block->GetParent()->result_id())) { - fact_manager->AddFactFunctionIsLivesafe(message_.new_function_id()); + transformation_context->GetFactManager()->AddFactFunctionIsLivesafe( + message_.new_function_id()); } // Adapt the region to be outlined so that its input ids are replaced with the // ids of the outlined function's input parameters, and so that output ids // are similarly remapped. RemapInputAndOutputIdsInRegion( - context, *original_region_exit_block, region_blocks, region_input_ids, + ir_context, *original_region_exit_block, region_blocks, region_input_ids, region_output_ids, input_id_to_fresh_id_map, output_id_to_fresh_id_map); // Fill out the body of the outlined function according to the region that is // being outlined. - PopulateOutlinedFunction(*original_region_entry_block, - *original_region_exit_block, region_blocks, - region_output_ids, output_id_to_fresh_id_map, - context, outlined_function.get(), fact_manager); + PopulateOutlinedFunction( + *original_region_entry_block, *original_region_exit_block, region_blocks, + region_output_ids, output_id_to_fresh_id_map, ir_context, + outlined_function.get(), transformation_context); // Collapse the region that has been outlined into a function down to a single // block that calls said function. ShrinkOriginalRegion( - context, region_blocks, region_input_ids, region_output_ids, + ir_context, region_blocks, region_input_ids, region_output_ids, output_id_to_type_id, outlined_function->type_id(), std::move(cloned_exit_block_merge), std::move(cloned_exit_block_terminator), original_region_entry_block); // Add the outlined function to the module. - context->module()->AddFunction(std::move(outlined_function)); + ir_context->module()->AddFunction(std::move(outlined_function)); // Major surgery has been conducted on the module, so invalidate all analyses. - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationOutlineFunction::ToMessage() const { @@ -387,30 +404,31 @@ } std::vector<uint32_t> TransformationOutlineFunction::GetRegionInputIds( - opt::IRContext* context, const std::set<opt::BasicBlock*>& region_set, + opt::IRContext* ir_context, const std::set<opt::BasicBlock*>& region_set, opt::BasicBlock* region_exit_block) { std::vector<uint32_t> result; auto enclosing_function = region_exit_block->GetParent(); // Consider each parameter of the function containing the region. - enclosing_function->ForEachParam([context, ®ion_set, &result]( - opt::Instruction* function_parameter) { - // Consider every use of the parameter. - context->get_def_use_mgr()->WhileEachUse( - function_parameter, [context, function_parameter, ®ion_set, &result]( - opt::Instruction* use, uint32_t /*unused*/) { - // Get the block, if any, in which the parameter is used. - auto use_block = context->get_instr_block(use); - // If the use is in a block that lies within the region, the - // parameter is an input id for the region. - if (use_block && region_set.count(use_block) != 0) { - result.push_back(function_parameter->result_id()); - return false; - } - return true; - }); - }); + enclosing_function->ForEachParam( + [ir_context, ®ion_set, &result](opt::Instruction* function_parameter) { + // Consider every use of the parameter. + ir_context->get_def_use_mgr()->WhileEachUse( + function_parameter, + [ir_context, function_parameter, ®ion_set, &result]( + opt::Instruction* use, uint32_t /*unused*/) { + // Get the block, if any, in which the parameter is used. + auto use_block = ir_context->get_instr_block(use); + // If the use is in a block that lies within the region, the + // parameter is an input id for the region. + if (use_block && region_set.count(use_block) != 0) { + result.push_back(function_parameter->result_id()); + return false; + } + return true; + }); + }); // Consider all definitions in the function that might turn out to be input // ids. @@ -430,15 +448,15 @@ // Consider each candidate input id to check whether it is used in the // region. for (auto& inst : candidate_input_ids_for_block) { - context->get_def_use_mgr()->WhileEachUse( + ir_context->get_def_use_mgr()->WhileEachUse( inst, - [context, &inst, region_exit_block, ®ion_set, &result]( + [ir_context, &inst, region_exit_block, ®ion_set, &result]( opt::Instruction* use, uint32_t /*unused*/) -> bool { // Find the block in which this id use occurs, recording the id as // an input id if the block is outside the region, with some // exceptions detailed below. - auto use_block = context->get_instr_block(use); + auto use_block = ir_context->get_instr_block(use); if (!use_block) { // There might be no containing block, e.g. if the use is in a @@ -467,7 +485,7 @@ } std::vector<uint32_t> TransformationOutlineFunction::GetRegionOutputIds( - opt::IRContext* context, const std::set<opt::BasicBlock*>& region_set, + opt::IRContext* ir_context, const std::set<opt::BasicBlock*>& region_set, opt::BasicBlock* region_exit_block) { std::vector<uint32_t> result; @@ -479,15 +497,15 @@ } // Consider each use of each instruction defined in the block. for (auto& inst : block) { - context->get_def_use_mgr()->WhileEachUse( + ir_context->get_def_use_mgr()->WhileEachUse( &inst, - [®ion_set, context, &inst, region_exit_block, &result]( + [®ion_set, ir_context, &inst, region_exit_block, &result]( opt::Instruction* use, uint32_t /*unused*/) -> bool { // Find the block in which this id use occurs, recording the id as // an output id if the block is outside the region, with some // exceptions detailed below. - auto use_block = context->get_instr_block(use); + auto use_block = ir_context->get_instr_block(use); if (!use_block) { // There might be no containing block, e.g. if the use is in a @@ -513,12 +531,13 @@ } std::set<opt::BasicBlock*> TransformationOutlineFunction::GetRegionBlocks( - opt::IRContext* context, opt::BasicBlock* entry_block, + opt::IRContext* ir_context, opt::BasicBlock* entry_block, opt::BasicBlock* exit_block) { auto enclosing_function = entry_block->GetParent(); - auto dominator_analysis = context->GetDominatorAnalysis(enclosing_function); + auto dominator_analysis = + ir_context->GetDominatorAnalysis(enclosing_function); auto postdominator_analysis = - context->GetPostDominatorAnalysis(enclosing_function); + ir_context->GetPostDominatorAnalysis(enclosing_function); std::set<opt::BasicBlock*> result; for (auto& block : *enclosing_function) { @@ -535,7 +554,8 @@ const std::vector<uint32_t>& region_input_ids, const std::vector<uint32_t>& region_output_ids, const std::map<uint32_t, uint32_t>& input_id_to_fresh_id_map, - opt::IRContext* context, FactManager* fact_manager) const { + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { uint32_t return_type_id = 0; uint32_t function_type_id = 0; @@ -547,14 +567,14 @@ if (region_output_ids.empty()) { std::vector<uint32_t> return_and_parameter_types; opt::analysis::Void void_type; - return_type_id = context->get_type_mgr()->GetId(&void_type); + return_type_id = ir_context->get_type_mgr()->GetId(&void_type); return_and_parameter_types.push_back(return_type_id); for (auto id : region_input_ids) { return_and_parameter_types.push_back( - context->get_def_use_mgr()->GetDef(id)->type_id()); + ir_context->get_def_use_mgr()->GetDef(id)->type_id()); } function_type_id = - fuzzerutil::FindFunctionType(context, return_and_parameter_types); + fuzzerutil::FindFunctionType(ir_context, return_and_parameter_types); } // If no existing function type was found, we need to create one. @@ -568,12 +588,12 @@ opt::Instruction::OperandList struct_member_types; for (uint32_t output_id : region_output_ids) { auto output_id_type = - context->get_def_use_mgr()->GetDef(output_id)->type_id(); + ir_context->get_def_use_mgr()->GetDef(output_id)->type_id(); struct_member_types.push_back({SPV_OPERAND_TYPE_ID, {output_id_type}}); } // Add a new struct type to the module. - context->module()->AddType(MakeUnique<opt::Instruction>( - context, SpvOpTypeStruct, 0, + ir_context->module()->AddType(MakeUnique<opt::Instruction>( + ir_context, SpvOpTypeStruct, 0, message_.new_function_struct_return_type_id(), std::move(struct_member_types))); // The return type for the function is the newly-created struct. @@ -589,12 +609,12 @@ for (auto id : region_input_ids) { function_type_operands.push_back( {SPV_OPERAND_TYPE_ID, - {context->get_def_use_mgr()->GetDef(id)->type_id()}}); + {ir_context->get_def_use_mgr()->GetDef(id)->type_id()}}); } // Add a new function type to the module, and record that this is the type // id for the new function. - context->module()->AddType(MakeUnique<opt::Instruction>( - context, SpvOpTypeFunction, 0, message_.new_function_type_id(), + ir_context->module()->AddType(MakeUnique<opt::Instruction>( + ir_context, SpvOpTypeFunction, 0, message_.new_function_type_id(), function_type_operands)); function_type_id = message_.new_function_type_id(); } @@ -603,7 +623,7 @@ // and the return type and function type prepared above. std::unique_ptr<opt::Function> outlined_function = MakeUnique<opt::Function>(MakeUnique<opt::Instruction>( - context, SpvOpFunction, return_type_id, message_.new_function_id(), + ir_context, SpvOpFunction, return_type_id, message_.new_function_id(), opt::Instruction::OperandList( {{spv_operand_type_t ::SPV_OPERAND_TYPE_LITERAL_INTEGER, {SpvFunctionControlMaskNone}}, @@ -614,14 +634,15 @@ // provided in |input_id_to_fresh_id_map|. for (auto id : region_input_ids) { outlined_function->AddParameter(MakeUnique<opt::Instruction>( - context, SpvOpFunctionParameter, - context->get_def_use_mgr()->GetDef(id)->type_id(), + ir_context, SpvOpFunctionParameter, + ir_context->get_def_use_mgr()->GetDef(id)->type_id(), input_id_to_fresh_id_map.at(id), opt::Instruction::OperandList())); // If the input id is an irrelevant-valued variable, the same should be true // of the corresponding parameter. - if (fact_manager->PointeeValueIsIrrelevant(id)) { - fact_manager->AddFactValueOfPointeeIsIrrelevant( - input_id_to_fresh_id_map.at(id)); + if (transformation_context->GetFactManager()->PointeeValueIsIrrelevant( + id)) { + transformation_context->GetFactManager() + ->AddFactValueOfPointeeIsIrrelevant(input_id_to_fresh_id_map.at(id)); } } @@ -629,31 +650,32 @@ } void TransformationOutlineFunction::UpdateModuleIdBoundForFreshIds( - opt::IRContext* context, + opt::IRContext* ir_context, const std::map<uint32_t, uint32_t>& input_id_to_fresh_id_map, const std::map<uint32_t, uint32_t>& output_id_to_fresh_id_map) const { // Enlarge the module's id bound as needed to accommodate the various fresh // ids associated with the transformation. fuzzerutil::UpdateModuleIdBound( - context, message_.new_function_struct_return_type_id()); - fuzzerutil::UpdateModuleIdBound(context, message_.new_function_type_id()); - fuzzerutil::UpdateModuleIdBound(context, message_.new_function_id()); - fuzzerutil::UpdateModuleIdBound(context, + ir_context, message_.new_function_struct_return_type_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.new_function_type_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.new_function_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.new_function_region_entry_block()); - fuzzerutil::UpdateModuleIdBound(context, message_.new_caller_result_id()); - fuzzerutil::UpdateModuleIdBound(context, message_.new_callee_result_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.new_caller_result_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.new_callee_result_id()); for (auto& entry : input_id_to_fresh_id_map) { - fuzzerutil::UpdateModuleIdBound(context, entry.second); + fuzzerutil::UpdateModuleIdBound(ir_context, entry.second); } for (auto& entry : output_id_to_fresh_id_map) { - fuzzerutil::UpdateModuleIdBound(context, entry.second); + fuzzerutil::UpdateModuleIdBound(ir_context, entry.second); } } void TransformationOutlineFunction::RemapInputAndOutputIdsInRegion( - opt::IRContext* context, const opt::BasicBlock& original_region_exit_block, + opt::IRContext* ir_context, + const opt::BasicBlock& original_region_exit_block, const std::set<opt::BasicBlock*>& region_blocks, const std::vector<uint32_t>& region_input_ids, const std::vector<uint32_t>& region_output_ids, @@ -664,11 +686,11 @@ // This is done by considering each region input id in turn. for (uint32_t id : region_input_ids) { // We then consider each use of the input id. - context->get_def_use_mgr()->ForEachUse( - id, [context, id, &input_id_to_fresh_id_map, region_blocks]( + ir_context->get_def_use_mgr()->ForEachUse( + id, [ir_context, id, &input_id_to_fresh_id_map, region_blocks]( opt::Instruction* use, uint32_t operand_index) { // Find the block in which this use of the input id occurs. - opt::BasicBlock* use_block = context->get_instr_block(use); + opt::BasicBlock* use_block = ir_context->get_instr_block(use); // We want to rewrite the use id if its block occurs in the outlined // region. if (region_blocks.count(use_block) != 0) { @@ -684,12 +706,12 @@ // This is done by considering each region output id in turn. for (uint32_t id : region_output_ids) { // First consider each use of the output id and update the relevant uses. - context->get_def_use_mgr()->ForEachUse( - id, - [context, &original_region_exit_block, id, &output_id_to_fresh_id_map, - region_blocks](opt::Instruction* use, uint32_t operand_index) { + ir_context->get_def_use_mgr()->ForEachUse( + id, [ir_context, &original_region_exit_block, id, + &output_id_to_fresh_id_map, + region_blocks](opt::Instruction* use, uint32_t operand_index) { // Find the block in which this use of the output id occurs. - auto use_block = context->get_instr_block(use); + auto use_block = ir_context->get_instr_block(use); // We want to rewrite the use id if its block occurs in the outlined // region, with one exception: the terminator of the exit block of // the region is going to remain in the original function, so if the @@ -710,7 +732,7 @@ // defines the corresponding fresh id. We do this after changing all the // uses so that the definition of the original id is still registered when // we analyse its uses. - context->get_def_use_mgr()->GetDef(id)->SetResultId( + ir_context->get_def_use_mgr()->GetDef(id)->SetResultId( output_id_to_fresh_id_map.at(id)); } } @@ -721,8 +743,8 @@ const std::set<opt::BasicBlock*>& region_blocks, const std::vector<uint32_t>& region_output_ids, const std::map<uint32_t, uint32_t>& output_id_to_fresh_id_map, - opt::IRContext* context, opt::Function* outlined_function, - FactManager* fact_manager) const { + opt::IRContext* ir_context, opt::Function* outlined_function, + TransformationContext* transformation_context) const { // When we create the exit block for the outlined region, we use this pointer // to track of it so that we can manipulate it later. opt::BasicBlock* outlined_region_exit_block = nullptr; @@ -732,14 +754,16 @@ // |message_.new_function_region_entry_block| as its id. std::unique_ptr<opt::BasicBlock> outlined_region_entry_block = MakeUnique<opt::BasicBlock>(MakeUnique<opt::Instruction>( - context, SpvOpLabel, 0, message_.new_function_region_entry_block(), + ir_context, SpvOpLabel, 0, message_.new_function_region_entry_block(), opt::Instruction::OperandList())); outlined_region_entry_block->SetParent(outlined_function); // If the original region's entry block was dead, the outlined region's entry // block is also dead. - if (fact_manager->BlockIsDead(original_region_entry_block.id())) { - fact_manager->AddFactBlockIsDead(outlined_region_entry_block->id()); + if (transformation_context->GetFactManager()->BlockIsDead( + original_region_entry_block.id())) { + transformation_context->GetFactManager()->AddFactBlockIsDead( + outlined_region_entry_block->id()); } if (&original_region_entry_block == &original_region_exit_block) { @@ -748,7 +772,7 @@ for (auto& inst : original_region_entry_block) { outlined_region_entry_block->AddInstruction( - std::unique_ptr<opt::Instruction>(inst.Clone(context))); + std::unique_ptr<opt::Instruction>(inst.Clone(ir_context))); } outlined_function->AddBasicBlock(std::move(outlined_region_entry_block)); @@ -767,7 +791,7 @@ } // Clone the block so that it can be added to the new function. auto cloned_block = - std::unique_ptr<opt::BasicBlock>(block_it->Clone(context)); + std::unique_ptr<opt::BasicBlock>(block_it->Clone(ir_context)); // If this is the region's exit block, then the cloned block is the outlined // region's exit block. @@ -823,7 +847,7 @@ // The case where there are no region output ids is simple: we just add // OpReturn. outlined_region_exit_block->AddInstruction(MakeUnique<opt::Instruction>( - context, SpvOpReturn, 0, 0, opt::Instruction::OperandList())); + ir_context, SpvOpReturn, 0, 0, opt::Instruction::OperandList())); } else { // In the case where there are output ids, we add an OpCompositeConstruct // instruction to pack all the output values into a struct, and then an @@ -834,21 +858,21 @@ {SPV_OPERAND_TYPE_ID, {output_id_to_fresh_id_map.at(id)}}); } outlined_region_exit_block->AddInstruction(MakeUnique<opt::Instruction>( - context, SpvOpCompositeConstruct, + ir_context, SpvOpCompositeConstruct, message_.new_function_struct_return_type_id(), message_.new_callee_result_id(), struct_member_operands)); outlined_region_exit_block->AddInstruction(MakeUnique<opt::Instruction>( - context, SpvOpReturnValue, 0, 0, + ir_context, SpvOpReturnValue, 0, 0, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {message_.new_callee_result_id()}}}))); } outlined_function->SetFunctionEnd(MakeUnique<opt::Instruction>( - context, SpvOpFunctionEnd, 0, 0, opt::Instruction::OperandList())); + ir_context, SpvOpFunctionEnd, 0, 0, opt::Instruction::OperandList())); } void TransformationOutlineFunction::ShrinkOriginalRegion( - opt::IRContext* context, std::set<opt::BasicBlock*>& region_blocks, + opt::IRContext* ir_context, std::set<opt::BasicBlock*>& region_blocks, const std::vector<uint32_t>& region_input_ids, const std::vector<uint32_t>& region_output_ids, const std::map<uint32_t, uint32_t>& output_id_to_type_id, @@ -912,7 +936,7 @@ } original_region_entry_block->AddInstruction(MakeUnique<opt::Instruction>( - context, SpvOpFunctionCall, return_type_id, + ir_context, SpvOpFunctionCall, return_type_id, message_.new_caller_result_id(), function_call_operands)); // If there are output ids, the function call will return a struct. For each @@ -921,7 +945,7 @@ for (uint32_t index = 0; index < region_output_ids.size(); ++index) { uint32_t output_id = region_output_ids[index]; original_region_entry_block->AddInstruction(MakeUnique<opt::Instruction>( - context, SpvOpCompositeExtract, output_id_to_type_id.at(output_id), + ir_context, SpvOpCompositeExtract, output_id_to_type_id.at(output_id), output_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {message_.new_caller_result_id()}},
diff --git a/source/fuzz/transformation_outline_function.h b/source/fuzz/transformation_outline_function.h index 5711790..ba439c8 100644 --- a/source/fuzz/transformation_outline_function.h +++ b/source/fuzz/transformation_outline_function.h
@@ -19,9 +19,9 @@ #include <set> #include <vector> -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -58,8 +58,9 @@ // defined outside the region but used in the region // - |message_.output_id_to_fresh_id| must contain an entry for every id // defined in the region but used outside the region - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // - A new function with id |message_.new_function_id| is added to the module. // - If the region generates output ids, the return type of this function is @@ -95,14 +96,15 @@ // |message_.new_function_struct_return_type| comprised of all the fresh // output ids (unless the return type is void, in which case no value is // returned. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; // Returns the set of blocks dominated by |entry_block| and post-dominated // by |exit_block|. static std::set<opt::BasicBlock*> GetRegionBlocks( - opt::IRContext* context, opt::BasicBlock* entry_block, + opt::IRContext* ir_context, opt::BasicBlock* entry_block, opt::BasicBlock* exit_block); // Yields ids that are used in |region_set| and that are either parameters @@ -114,7 +116,7 @@ // - id uses in OpPhi instructions in |region_entry_block| are ignored // - id uses in the terminator instruction of |region_exit_block| are ignored static std::vector<uint32_t> GetRegionInputIds( - opt::IRContext* context, const std::set<opt::BasicBlock*>& region_set, + opt::IRContext* ir_context, const std::set<opt::BasicBlock*>& region_set, opt::BasicBlock* region_exit_block); // Yields all ids that are defined in |region_set| and used outside @@ -124,14 +126,14 @@ // - ids defined in the region and used in the terminator of // |region_exit_block| count as output ids static std::vector<uint32_t> GetRegionOutputIds( - opt::IRContext* context, const std::set<opt::BasicBlock*>& region_set, + opt::IRContext* ir_context, const std::set<opt::BasicBlock*>& region_set, opt::BasicBlock* region_exit_block); private: // Ensures that the module's id bound is at least the maximum of any fresh id // associated with the transformation. void UpdateModuleIdBoundForFreshIds( - opt::IRContext* context, + opt::IRContext* ir_context, const std::map<uint32_t, uint32_t>& input_id_to_fresh_id_map, const std::map<uint32_t, uint32_t>& output_id_to_fresh_id_map) const; @@ -142,7 +144,7 @@ // modified, and |original_region_exit_block| allows for some special cases // where ids should not be remapped. void RemapInputAndOutputIdsInRegion( - opt::IRContext* context, + opt::IRContext* ir_context, const opt::BasicBlock& original_region_exit_block, const std::set<opt::BasicBlock*>& region_blocks, const std::vector<uint32_t>& region_input_ids, @@ -160,12 +162,14 @@ // are already present). // // Facts about the function containing the outlined region that are relevant - // to the new function are propagated via |fact_manager|. + // to the new function are propagated via the vact manager in + // |transformation_context|. std::unique_ptr<opt::Function> PrepareFunctionPrototype( const std::vector<uint32_t>& region_input_ids, const std::vector<uint32_t>& region_output_ids, const std::map<uint32_t, uint32_t>& input_id_to_fresh_id_map, - opt::IRContext* context, FactManager* fact_manager) const; + opt::IRContext* ir_context, + TransformationContext* transformation_context) const; // Creates the body of the outlined function by cloning blocks from the // original region, given by |region_blocks|, adapting the cloned version @@ -174,17 +178,17 @@ // clone. Parameters |region_output_ids| and |output_id_to_fresh_id_map| are // used to determine what the function should return. // - // The |fact_manager| argument allow facts about blocks being outlined, e.g. - // whether they are dead blocks, to be asserted about blocks that get created - // during outlining. + // The |transformation_context| argument allow facts about blocks being + // outlined, e.g. whether they are dead blocks, to be asserted about blocks + // that get created during outlining. void PopulateOutlinedFunction( const opt::BasicBlock& original_region_entry_block, const opt::BasicBlock& original_region_exit_block, const std::set<opt::BasicBlock*>& region_blocks, const std::vector<uint32_t>& region_output_ids, const std::map<uint32_t, uint32_t>& output_id_to_fresh_id_map, - opt::IRContext* context, opt::Function* outlined_function, - FactManager* fact_manager) const; + opt::IRContext* ir_context, opt::Function* outlined_function, + TransformationContext* transformation_context) const; // Shrinks the outlined region, given by |region_blocks|, down to the single // block |original_region_entry_block|. This block is itself shrunk to just @@ -203,7 +207,7 @@ // function is called, this information cannot be gotten from the def-use // manager. void ShrinkOriginalRegion( - opt::IRContext* context, std::set<opt::BasicBlock*>& region_blocks, + opt::IRContext* ir_context, std::set<opt::BasicBlock*>& region_blocks, const std::vector<uint32_t>& region_input_ids, const std::vector<uint32_t>& region_output_ids, const std::map<uint32_t, uint32_t>& output_id_to_type_id,
diff --git a/source/fuzz/transformation_permute_function_parameters.cpp b/source/fuzz/transformation_permute_function_parameters.cpp index 2141533..0f1220e 100644 --- a/source/fuzz/transformation_permute_function_parameters.cpp +++ b/source/fuzz/transformation_permute_function_parameters.cpp
@@ -40,17 +40,17 @@ } bool TransformationPermuteFunctionParameters::IsApplicable( - opt::IRContext* context, const FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // Check that function exists const auto* function = - fuzzerutil::FindFunction(context, message_.function_id()); + fuzzerutil::FindFunction(ir_context, message_.function_id()); if (!function || function->DefInst().opcode() != SpvOpFunction || - fuzzerutil::FunctionIsEntryPoint(context, function->result_id())) { + fuzzerutil::FunctionIsEntryPoint(ir_context, function->result_id())) { return false; } // Check that permutation has valid indices - const auto* function_type = fuzzerutil::GetFunctionType(context, function); + const auto* function_type = fuzzerutil::GetFunctionType(ir_context, function); assert(function_type && "Function type is null"); const auto& permutation = message_.permutation(); @@ -83,7 +83,7 @@ // - Has the same result type as the old one // - Order of arguments is permuted auto new_type_id = message_.new_type_id(); - const auto* new_type = context->get_def_use_mgr()->GetDef(new_type_id); + const auto* new_type = ir_context->get_def_use_mgr()->GetDef(new_type_id); if (!new_type || new_type->opcode() != SpvOpTypeFunction || new_type->NumInOperands() != function_type->NumInOperands()) { @@ -109,14 +109,14 @@ } void TransformationPermuteFunctionParameters::Apply( - opt::IRContext* context, FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { // Retrieve all data from the message uint32_t function_id = message_.function_id(); uint32_t new_type_id = message_.new_type_id(); const auto& permutation = message_.permutation(); // Find the function that will be transformed - auto* function = fuzzerutil::FindFunction(context, function_id); + auto* function = fuzzerutil::FindFunction(ir_context, function_id); assert(function && "Can't find the function"); // Change function's type @@ -149,7 +149,7 @@ }); // Fix all OpFunctionCall instructions - context->get_def_use_mgr()->ForEachUser( + ir_context->get_def_use_mgr()->ForEachUser( &function->DefInst(), [function_id, &permutation](opt::Instruction* call) { if (call->opcode() != SpvOpFunctionCall || @@ -170,7 +170,8 @@ }); // Make sure our changes are analyzed - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationPermuteFunctionParameters::ToMessage()
diff --git a/source/fuzz/transformation_permute_function_parameters.h b/source/fuzz/transformation_permute_function_parameters.h index c67a735..994e4c2 100644 --- a/source/fuzz/transformation_permute_function_parameters.h +++ b/source/fuzz/transformation_permute_function_parameters.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_PERMUTE_FUNCTION_PARAMETERS_H_ #define SOURCE_FUZZ_TRANSFORMATION_PERMUTE_FUNCTION_PARAMETERS_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -40,14 +40,16 @@ // - function's arguments are permuted according to |permutation| vector // - |permutation| is a set of [0..(n - 1)], where n is a number of arguments // to the function - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // - OpFunction instruction with |result_id == function_id| is changed. // Its arguments are permuted according to the |permutation| vector // - Changed function gets a new type specified by |type_id| // - Calls to the function are adjusted accordingly - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_replace_boolean_constant_with_constant_binary.cpp b/source/fuzz/transformation_replace_boolean_constant_with_constant_binary.cpp index 72d9b22..d6f17fc 100644 --- a/source/fuzz/transformation_replace_boolean_constant_with_constant_binary.cpp +++ b/source/fuzz/transformation_replace_boolean_constant_with_constant_binary.cpp
@@ -128,15 +128,15 @@ } bool TransformationReplaceBooleanConstantWithConstantBinary::IsApplicable( - opt::IRContext* context, const FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The id for the binary result must be fresh - if (!fuzzerutil::IsFreshId(context, + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id_for_binary_operation())) { return false; } // The used id must be for a boolean constant - auto boolean_constant = context->get_def_use_mgr()->GetDef( + auto boolean_constant = ir_context->get_def_use_mgr()->GetDef( message_.id_use_descriptor().id_of_interest()); if (!boolean_constant) { return false; @@ -148,7 +148,7 @@ // The left-hand-side id must correspond to a constant instruction. auto lhs_constant_inst = - context->get_def_use_mgr()->GetDef(message_.lhs_id()); + ir_context->get_def_use_mgr()->GetDef(message_.lhs_id()); if (!lhs_constant_inst) { return false; } @@ -158,7 +158,7 @@ // The right-hand-side id must correspond to a constant instruction. auto rhs_constant_inst = - context->get_def_use_mgr()->GetDef(message_.rhs_id()); + ir_context->get_def_use_mgr()->GetDef(message_.rhs_id()); if (!rhs_constant_inst) { return false; } @@ -173,9 +173,9 @@ // The expression 'LHS opcode RHS' must evaluate to the boolean constant. auto lhs_constant = - context->get_constant_mgr()->FindDeclaredConstant(message_.lhs_id()); + ir_context->get_constant_mgr()->FindDeclaredConstant(message_.lhs_id()); auto rhs_constant = - context->get_constant_mgr()->FindDeclaredConstant(message_.rhs_id()); + ir_context->get_constant_mgr()->FindDeclaredConstant(message_.rhs_id()); bool expected_result = (boolean_constant->opcode() == SpvOpConstantTrue); const auto binary_opcode = static_cast<SpvOp>(message_.opcode()); @@ -238,7 +238,7 @@ // The id use descriptor must identify some instruction auto instruction = - FindInstructionContainingUse(message_.id_use_descriptor(), context); + FindInstructionContainingUse(message_.id_use_descriptor(), ir_context); if (instruction == nullptr) { return false; } @@ -262,24 +262,25 @@ } void TransformationReplaceBooleanConstantWithConstantBinary::Apply( - opt::IRContext* context, FactManager* fact_manager) const { - ApplyWithResult(context, fact_manager); + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { + ApplyWithResult(ir_context, transformation_context); } opt::Instruction* TransformationReplaceBooleanConstantWithConstantBinary::ApplyWithResult( - opt::IRContext* context, FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { opt::analysis::Bool bool_type; opt::Instruction::OperandList operands = { {SPV_OPERAND_TYPE_ID, {message_.lhs_id()}}, {SPV_OPERAND_TYPE_ID, {message_.rhs_id()}}}; auto binary_instruction = MakeUnique<opt::Instruction>( - context, static_cast<SpvOp>(message_.opcode()), - context->get_type_mgr()->GetId(&bool_type), + ir_context, static_cast<SpvOp>(message_.opcode()), + ir_context->get_type_mgr()->GetId(&bool_type), message_.fresh_id_for_binary_operation(), operands); opt::Instruction* result = binary_instruction.get(); auto instruction_containing_constant_use = - FindInstructionContainingUse(message_.id_use_descriptor(), context); + FindInstructionContainingUse(message_.id_use_descriptor(), ir_context); // We want to insert the new instruction before the instruction that contains // the use of the boolean, but we need to go backwards one more instruction if @@ -298,9 +299,10 @@ instruction_containing_constant_use->SetInOperand( message_.id_use_descriptor().in_operand_index(), {message_.fresh_id_for_binary_operation()}); - fuzzerutil::UpdateModuleIdBound(context, + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id_for_binary_operation()); - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); return result; }
diff --git a/source/fuzz/transformation_replace_boolean_constant_with_constant_binary.h b/source/fuzz/transformation_replace_boolean_constant_with_constant_binary.h index f74cd8d..3abb485 100644 --- a/source/fuzz/transformation_replace_boolean_constant_with_constant_binary.h +++ b/source/fuzz/transformation_replace_boolean_constant_with_constant_binary.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_REPLACE_BOOLEAN_CONSTANT_WITH_CONSTANT_BINARY_H_ #define SOURCE_FUZZ_TRANSFORMATION_REPLACE_BOOLEAN_CONSTANT_WITH_CONSTANT_BINARY_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -49,20 +49,23 @@ // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2902): consider // replacing a boolean in an OpPhi by adding a binary operator instruction // to the parent block for the OpPhi. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // A new instruction is added before the boolean constant usage that computes // the result of applying |message_.opcode| to |message_.lhs_id| and // |message_.rhs_id| is added, with result id // |message_.fresh_id_for_binary_operation|. The boolean constant usage is // replaced with this result id. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; // The same as Apply, except that the newly-added binary instruction is // returned. - opt::Instruction* ApplyWithResult(opt::IRContext* context, - FactManager* fact_manager) const; + opt::Instruction* ApplyWithResult( + opt::IRContext* ir_context, + TransformationContext* transformation_context) const; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_replace_constant_with_uniform.cpp b/source/fuzz/transformation_replace_constant_with_uniform.cpp index 8e0e4e5..a8f9495 100644 --- a/source/fuzz/transformation_replace_constant_with_uniform.cpp +++ b/source/fuzz/transformation_replace_constant_with_uniform.cpp
@@ -39,12 +39,12 @@ std::unique_ptr<opt::Instruction> TransformationReplaceConstantWithUniform::MakeAccessChainInstruction( - spvtools::opt::IRContext* context, uint32_t constant_type_id) const { + spvtools::opt::IRContext* ir_context, uint32_t constant_type_id) const { // The input operands for the access chain. opt::Instruction::OperandList operands_for_access_chain; opt::Instruction* uniform_variable = - FindUniformVariable(message_.uniform_descriptor(), context, false); + FindUniformVariable(message_.uniform_descriptor(), ir_context, false); // The first input operand is the id of the uniform variable. operands_for_access_chain.push_back( @@ -56,42 +56,43 @@ // instruction ids as operands. opt::analysis::Integer int_type(32, true); auto registered_int_type = - context->get_type_mgr()->GetRegisteredType(&int_type)->AsInteger(); - auto int_type_id = context->get_type_mgr()->GetId(&int_type); + ir_context->get_type_mgr()->GetRegisteredType(&int_type)->AsInteger(); + auto int_type_id = ir_context->get_type_mgr()->GetId(&int_type); for (auto index : message_.uniform_descriptor().index()) { opt::analysis::IntConstant int_constant(registered_int_type, {index}); - auto constant_id = context->get_constant_mgr()->FindDeclaredConstant( + auto constant_id = ir_context->get_constant_mgr()->FindDeclaredConstant( &int_constant, int_type_id); operands_for_access_chain.push_back({SPV_OPERAND_TYPE_ID, {constant_id}}); } // The type id for the access chain is a uniform pointer with base type // matching the given constant id type. - auto type_and_pointer_type = context->get_type_mgr()->GetTypeAndPointerType( - constant_type_id, SpvStorageClassUniform); + auto type_and_pointer_type = + ir_context->get_type_mgr()->GetTypeAndPointerType(constant_type_id, + SpvStorageClassUniform); assert(type_and_pointer_type.first != nullptr); assert(type_and_pointer_type.second != nullptr); auto pointer_to_uniform_constant_type_id = - context->get_type_mgr()->GetId(type_and_pointer_type.second.get()); + ir_context->get_type_mgr()->GetId(type_and_pointer_type.second.get()); return MakeUnique<opt::Instruction>( - context, SpvOpAccessChain, pointer_to_uniform_constant_type_id, + ir_context, SpvOpAccessChain, pointer_to_uniform_constant_type_id, message_.fresh_id_for_access_chain(), operands_for_access_chain); } std::unique_ptr<opt::Instruction> TransformationReplaceConstantWithUniform::MakeLoadInstruction( - spvtools::opt::IRContext* context, uint32_t constant_type_id) const { + spvtools::opt::IRContext* ir_context, uint32_t constant_type_id) const { opt::Instruction::OperandList operands_for_load = { {SPV_OPERAND_TYPE_ID, {message_.fresh_id_for_access_chain()}}}; - return MakeUnique<opt::Instruction>(context, SpvOpLoad, constant_type_id, + return MakeUnique<opt::Instruction>(ir_context, SpvOpLoad, constant_type_id, message_.fresh_id_for_load(), operands_for_load); } bool TransformationReplaceConstantWithUniform::IsApplicable( - spvtools::opt::IRContext* context, - const spvtools::fuzz::FactManager& fact_manager) const { + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const { // The following is really an invariant of the transformation rather than // merely a requirement of the precondition. We check it here since we cannot // check it in the message_ constructor. @@ -99,16 +100,17 @@ "Fresh ids for access chain and load result cannot be the same."); // The ids for the access chain and load instructions must both be fresh. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id_for_access_chain())) { + if (!fuzzerutil::IsFreshId(ir_context, + message_.fresh_id_for_access_chain())) { return false; } - if (!fuzzerutil::IsFreshId(context, message_.fresh_id_for_load())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id_for_load())) { return false; } // The id specified in the id use descriptor must be that of a declared scalar // constant. - auto declared_constant = context->get_constant_mgr()->FindDeclaredConstant( + auto declared_constant = ir_context->get_constant_mgr()->FindDeclaredConstant( message_.id_use_descriptor().id_of_interest()); if (!declared_constant) { return false; @@ -120,13 +122,13 @@ // The fact manager needs to believe that the uniform data element described // by the uniform buffer element descriptor will hold a scalar value. auto constant_id_associated_with_uniform = - fact_manager.GetConstantFromUniformDescriptor( - context, message_.uniform_descriptor()); + transformation_context.GetFactManager()->GetConstantFromUniformDescriptor( + ir_context, message_.uniform_descriptor()); if (!constant_id_associated_with_uniform) { return false; } auto constant_associated_with_uniform = - context->get_constant_mgr()->FindDeclaredConstant( + ir_context->get_constant_mgr()->FindDeclaredConstant( constant_id_associated_with_uniform); assert(constant_associated_with_uniform && "The constant should be present in the module."); @@ -149,7 +151,7 @@ // The id use descriptor must identify some instruction with respect to the // module. auto instruction_using_constant = - FindInstructionContainingUse(message_.id_use_descriptor(), context); + FindInstructionContainingUse(message_.id_use_descriptor(), ir_context); if (!instruction_using_constant) { return false; } @@ -165,23 +167,23 @@ // replace with a uniform. opt::analysis::Pointer pointer_to_type_of_constant(declared_constant->type(), SpvStorageClassUniform); - if (!context->get_type_mgr()->GetId(&pointer_to_type_of_constant)) { + if (!ir_context->get_type_mgr()->GetId(&pointer_to_type_of_constant)) { return false; } // In order to index into the uniform, the module has got to contain the int32 // type, plus an OpConstant for each of the indices of interest. opt::analysis::Integer int_type(32, true); - if (!context->get_type_mgr()->GetId(&int_type)) { + if (!ir_context->get_type_mgr()->GetId(&int_type)) { return false; } auto registered_int_type = - context->get_type_mgr()->GetRegisteredType(&int_type)->AsInteger(); - auto int_type_id = context->get_type_mgr()->GetId(&int_type); + ir_context->get_type_mgr()->GetRegisteredType(&int_type)->AsInteger(); + auto int_type_id = ir_context->get_type_mgr()->GetId(&int_type); for (auto index : message_.uniform_descriptor().index()) { opt::analysis::IntConstant int_constant(registered_int_type, {index}); - if (!context->get_constant_mgr()->FindDeclaredConstant(&int_constant, - int_type_id)) { + if (!ir_context->get_constant_mgr()->FindDeclaredConstant(&int_constant, + int_type_id)) { return false; } } @@ -190,11 +192,11 @@ } void TransformationReplaceConstantWithUniform::Apply( - spvtools::opt::IRContext* context, - spvtools::fuzz::FactManager* /*unused*/) const { + spvtools::opt::IRContext* ir_context, + TransformationContext* /*unused*/) const { // Get the instruction that contains the id use we wish to replace. auto instruction_containing_constant_use = - FindInstructionContainingUse(message_.id_use_descriptor(), context); + FindInstructionContainingUse(message_.id_use_descriptor(), ir_context); assert(instruction_containing_constant_use && "Precondition requires that the id use can be found."); assert(instruction_containing_constant_use->GetSingleWordInOperand( @@ -204,17 +206,17 @@ // The id of the type for the constant whose use we wish to replace. auto constant_type_id = - context->get_def_use_mgr() + ir_context->get_def_use_mgr() ->GetDef(message_.id_use_descriptor().id_of_interest()) ->type_id(); // Add an access chain instruction to target the uniform element. instruction_containing_constant_use->InsertBefore( - MakeAccessChainInstruction(context, constant_type_id)); + MakeAccessChainInstruction(ir_context, constant_type_id)); // Add a load from this access chain. instruction_containing_constant_use->InsertBefore( - MakeLoadInstruction(context, constant_type_id)); + MakeLoadInstruction(ir_context, constant_type_id)); // Adjust the instruction containing the usage of the constant so that this // usage refers instead to the result of the load. @@ -223,11 +225,12 @@ {message_.fresh_id_for_load()}); // Update the module id bound to reflect the new instructions. - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id_for_load()); - fuzzerutil::UpdateModuleIdBound(context, + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id_for_load()); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id_for_access_chain()); - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationReplaceConstantWithUniform::ToMessage()
diff --git a/source/fuzz/transformation_replace_constant_with_uniform.h b/source/fuzz/transformation_replace_constant_with_uniform.h index ed354b1..b72407c 100644 --- a/source/fuzz/transformation_replace_constant_with_uniform.h +++ b/source/fuzz/transformation_replace_constant_with_uniform.h
@@ -58,8 +58,9 @@ // - According to the fact manager, the uniform data element specified by // |message_.uniform_descriptor| holds a value with the same type and // value as %C - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // - Introduces two new instructions: // - An access chain targeting the uniform data element specified by @@ -68,7 +69,8 @@ // - A load from this access chain, with id |message_.fresh_id_for_load| // - Replaces the id use specified by |message_.id_use_descriptor| with // |message_.fresh_id_for_load| - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; @@ -76,11 +78,11 @@ // Helper method to create an access chain for the uniform element associated // with the transformation. std::unique_ptr<opt::Instruction> MakeAccessChainInstruction( - spvtools::opt::IRContext* context, uint32_t constant_type_id) const; + spvtools::opt::IRContext* ir_context, uint32_t constant_type_id) const; // Helper to create a load instruction. std::unique_ptr<opt::Instruction> MakeLoadInstruction( - spvtools::opt::IRContext* context, uint32_t constant_type_id) const; + spvtools::opt::IRContext* ir_context, uint32_t constant_type_id) const; protobufs::TransformationReplaceConstantWithUniform message_; };
diff --git a/source/fuzz/transformation_replace_id_with_synonym.cpp b/source/fuzz/transformation_replace_id_with_synonym.cpp index 88c977a..e427f3c 100644 --- a/source/fuzz/transformation_replace_id_with_synonym.cpp +++ b/source/fuzz/transformation_replace_id_with_synonym.cpp
@@ -37,28 +37,29 @@ } bool TransformationReplaceIdWithSynonym::IsApplicable( - spvtools::opt::IRContext* context, - const spvtools::fuzz::FactManager& fact_manager) const { + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const { auto id_of_interest = message_.id_use_descriptor().id_of_interest(); // Does the fact manager know about the synonym? auto data_descriptor_for_synonymous_id = MakeDataDescriptor(message_.synonymous_id(), {}); - if (!fact_manager.IsSynonymous(MakeDataDescriptor(id_of_interest, {}), - data_descriptor_for_synonymous_id, context)) { + if (!transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(id_of_interest, {}), + data_descriptor_for_synonymous_id)) { return false; } // Does the id use descriptor in the transformation identify an instruction? auto use_instruction = - FindInstructionContainingUse(message_.id_use_descriptor(), context); + FindInstructionContainingUse(message_.id_use_descriptor(), ir_context); if (!use_instruction) { return false; } // Is the use suitable for being replaced in principle? if (!UseCanBeReplacedWithSynonym( - context, use_instruction, + ir_context, use_instruction, message_.id_use_descriptor().in_operand_index())) { return false; } @@ -66,19 +67,21 @@ // The transformation is applicable if the synonymous id is available at the // use point. return fuzzerutil::IdIsAvailableAtUse( - context, use_instruction, message_.id_use_descriptor().in_operand_index(), + ir_context, use_instruction, + message_.id_use_descriptor().in_operand_index(), message_.synonymous_id()); } void TransformationReplaceIdWithSynonym::Apply( - spvtools::opt::IRContext* context, - spvtools::fuzz::FactManager* /*unused*/) const { + spvtools::opt::IRContext* ir_context, + TransformationContext* /*unused*/) const { auto instruction_to_change = - FindInstructionContainingUse(message_.id_use_descriptor(), context); + FindInstructionContainingUse(message_.id_use_descriptor(), ir_context); instruction_to_change->SetInOperand( message_.id_use_descriptor().in_operand_index(), {message_.synonymous_id()}); - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationReplaceIdWithSynonym::ToMessage() @@ -89,7 +92,7 @@ } bool TransformationReplaceIdWithSynonym::UseCanBeReplacedWithSynonym( - opt::IRContext* context, opt::Instruction* use_instruction, + opt::IRContext* ir_context, opt::Instruction* use_instruction, uint32_t use_in_operand_index) { if (use_instruction->opcode() == SpvOpAccessChain && use_in_operand_index > 0) { @@ -98,10 +101,10 @@ // synonym, as the use needs to be an OpConstant. // Get the top-level composite type that is being accessed. - auto object_being_accessed = context->get_def_use_mgr()->GetDef( + auto object_being_accessed = ir_context->get_def_use_mgr()->GetDef( use_instruction->GetSingleWordInOperand(0)); auto pointer_type = - context->get_type_mgr()->GetType(object_being_accessed->type_id()); + ir_context->get_type_mgr()->GetType(object_being_accessed->type_id()); assert(pointer_type->AsPointer()); auto composite_type_being_accessed = pointer_type->AsPointer()->pointee_type(); @@ -122,9 +125,12 @@ } else if (composite_type_being_accessed->AsArray()) { composite_type_being_accessed = composite_type_being_accessed->AsArray()->element_type(); + } else if (composite_type_being_accessed->AsRuntimeArray()) { + composite_type_being_accessed = + composite_type_being_accessed->AsRuntimeArray()->element_type(); } else { assert(composite_type_being_accessed->AsStruct()); - auto constant_index_instruction = context->get_def_use_mgr()->GetDef( + auto constant_index_instruction = ir_context->get_def_use_mgr()->GetDef( use_instruction->GetSingleWordInOperand(index_in_operand)); assert(constant_index_instruction->opcode() == SpvOpConstant); uint32_t member_index = @@ -149,21 +155,30 @@ // type. // Get the definition of the function being called. - auto function = context->get_def_use_mgr()->GetDef( + auto function = ir_context->get_def_use_mgr()->GetDef( use_instruction->GetSingleWordInOperand(0)); // From the function definition, get the function type. - auto function_type = - context->get_def_use_mgr()->GetDef(function->GetSingleWordInOperand(1)); + auto function_type = ir_context->get_def_use_mgr()->GetDef( + function->GetSingleWordInOperand(1)); // OpTypeFunction's 0-th input operand is the function return type, and the // function argument types follow. Because the arguments to OpFunctionCall // start from input operand 1, we can use |use_in_operand_index| to get the // type associated with this function argument. - auto parameter_type = context->get_type_mgr()->GetType( + auto parameter_type = ir_context->get_type_mgr()->GetType( function_type->GetSingleWordInOperand(use_in_operand_index)); if (parameter_type->AsPointer()) { return false; } } + + if (use_instruction->opcode() == SpvOpImageTexelPointer && + use_in_operand_index == 2) { + // The OpImageTexelPointer instruction has a Sample parameter that in some + // situations must be an id for the value 0. To guard against disrupting + // that requirement, we do not replace this argument to that instruction. + return false; + } + return true; }
diff --git a/source/fuzz/transformation_replace_id_with_synonym.h b/source/fuzz/transformation_replace_id_with_synonym.h index 48132c1..a5a9dfd 100644 --- a/source/fuzz/transformation_replace_id_with_synonym.h +++ b/source/fuzz/transformation_replace_id_with_synonym.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_REPLACE_ID_WITH_SYNONYM_H_ #define SOURCE_FUZZ_TRANSFORMATION_REPLACE_ID_WITH_SYNONYM_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -42,12 +42,14 @@ // - The id must not be a pointer argument to a function call (because the // synonym might not be a memory object declaration). // - |fresh_id_for_temporary| must be 0. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Replaces the use identified by |message_.id_use_descriptor| with the // synonymous id identified by |message_.synonymous_id|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; @@ -58,7 +60,7 @@ // indices must be constants, so it is dangerous to replace them. // - the id use is not a pointer function call argument, on which there are // restrictions that make replacement problematic. - static bool UseCanBeReplacedWithSynonym(opt::IRContext* context, + static bool UseCanBeReplacedWithSynonym(opt::IRContext* ir_context, opt::Instruction* use_instruction, uint32_t use_in_operand_index);
diff --git a/source/fuzz/transformation_set_function_control.cpp b/source/fuzz/transformation_set_function_control.cpp index d2b61f1..d01e743 100644 --- a/source/fuzz/transformation_set_function_control.cpp +++ b/source/fuzz/transformation_set_function_control.cpp
@@ -28,9 +28,9 @@ } bool TransformationSetFunctionControl::IsApplicable( - opt::IRContext* context, const FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { opt::Instruction* function_def_instruction = - FindFunctionDefInstruction(context); + FindFunctionDefInstruction(ir_context); if (!function_def_instruction) { // The given function id does not correspond to any function. return false; @@ -69,10 +69,10 @@ return true; } -void TransformationSetFunctionControl::Apply(opt::IRContext* context, - FactManager* /*unused*/) const { +void TransformationSetFunctionControl::Apply( + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { opt::Instruction* function_def_instruction = - FindFunctionDefInstruction(context); + FindFunctionDefInstruction(ir_context); function_def_instruction->SetInOperand(0, {message_.function_control()}); } @@ -83,11 +83,11 @@ } opt::Instruction* TransformationSetFunctionControl ::FindFunctionDefInstruction( - opt::IRContext* context) const { + opt::IRContext* ir_context) const { // Look through all functions for a function whose defining instruction's // result id matches |message_.function_id|, returning the defining // instruction if found. - for (auto& function : *context->module()) { + for (auto& function : *ir_context->module()) { if (function.DefInst().result_id() == message_.function_id()) { return &function.DefInst(); }
diff --git a/source/fuzz/transformation_set_function_control.h b/source/fuzz/transformation_set_function_control.h index 0526bb9..5109f74 100644 --- a/source/fuzz/transformation_set_function_control.h +++ b/source/fuzz/transformation_set_function_control.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_SET_FUNCTION_CONTROL_H_ #define SOURCE_FUZZ_TRANSFORMATION_SET_FUNCTION_CONTROL_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -37,17 +37,20 @@ // at most one of 'Inline' or 'DontInline', and that may not contain 'Pure' // (respectively 'Const') unless the existing function control mask contains // 'Pure' (respectively 'Const'). - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // The function control operand of instruction |message_.function_id| is // over-written with |message_.function_control|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; private: - opt::Instruction* FindFunctionDefInstruction(opt::IRContext* context) const; + opt::Instruction* FindFunctionDefInstruction( + opt::IRContext* ir_context) const; protobufs::TransformationSetFunctionControl message_; };
diff --git a/source/fuzz/transformation_set_loop_control.cpp b/source/fuzz/transformation_set_loop_control.cpp index 9062f17..845ac69 100644 --- a/source/fuzz/transformation_set_loop_control.cpp +++ b/source/fuzz/transformation_set_loop_control.cpp
@@ -31,9 +31,9 @@ } bool TransformationSetLoopControl::IsApplicable( - opt::IRContext* context, const FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // |message_.block_id| must identify a block that ends with OpLoopMerge. - auto block = context->get_instr_block(message_.block_id()); + auto block = ir_context->get_instr_block(message_.block_id()); if (!block) { return false; } @@ -79,7 +79,8 @@ if ((message_.loop_control() & (SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask)) && - !(PeelCountIsSupported(context) && PartialCountIsSupported(context))) { + !(PeelCountIsSupported(ir_context) && + PartialCountIsSupported(ir_context))) { // At least one of PeelCount or PartialCount is used, but the SPIR-V version // in question does not support these loop controls. return false; @@ -104,11 +105,11 @@ (SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask))); } -void TransformationSetLoopControl::Apply(opt::IRContext* context, - FactManager* /*unused*/) const { +void TransformationSetLoopControl::Apply( + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { // Grab the loop merge instruction and its associated loop control mask. auto merge_inst = - context->get_instr_block(message_.block_id())->GetMergeInst(); + ir_context->get_instr_block(message_.block_id())->GetMergeInst(); auto existing_loop_control_mask = merge_inst->GetSingleWordInOperand(kLoopControlMaskInOperandIndex); @@ -181,11 +182,11 @@ } bool TransformationSetLoopControl::PartialCountIsSupported( - opt::IRContext* context) { + opt::IRContext* ir_context) { // TODO(afd): We capture the universal environments for which this loop // control is definitely not supported. The check should be refined on // demand for other target environments. - switch (context->grammar().target_env()) { + switch (ir_context->grammar().target_env()) { case SPV_ENV_UNIVERSAL_1_0: case SPV_ENV_UNIVERSAL_1_1: case SPV_ENV_UNIVERSAL_1_2: @@ -197,11 +198,11 @@ } bool TransformationSetLoopControl::PeelCountIsSupported( - opt::IRContext* context) { + opt::IRContext* ir_context) { // TODO(afd): We capture the universal environments for which this loop // control is definitely not supported. The check should be refined on // demand for other target environments. - switch (context->grammar().target_env()) { + switch (ir_context->grammar().target_env()) { case SPV_ENV_UNIVERSAL_1_0: case SPV_ENV_UNIVERSAL_1_1: case SPV_ENV_UNIVERSAL_1_2:
diff --git a/source/fuzz/transformation_set_loop_control.h b/source/fuzz/transformation_set_loop_control.h index 28b148c..f0c364f 100644 --- a/source/fuzz/transformation_set_loop_control.h +++ b/source/fuzz/transformation_set_loop_control.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_SET_LOOP_CONTROL_H_ #define SOURCE_FUZZ_TRANSFORMATION_SET_LOOP_CONTROL_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -38,13 +38,14 @@ // instruction. // - |message_.loop_control| must be a legal loop control mask that // only uses controls available in the SPIR-V version associated with - // |context|, and must not add loop controls that are only valid in the + // |ir_context|, and must not add loop controls that are only valid in the // presence of guarantees about what the loop does (e.g. MinIterations). // - |message_.peel_count| (respectively |message_.partial_count|) must be // zero PeelCount (respectively PartialCount) is set in // |message_.loop_control|. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // - The loop control operand of the OpLoopMergeInstruction in // |message_.block_id| is overwritten with |message_.loop_control|. @@ -52,16 +53,17 @@ // controls with associated literals that have been removed (e.g. // MinIterations), and any that have been added (PeelCount and/or // PartialCount). - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; // Does the version of SPIR-V being used support the PartialCount loop // control? - static bool PartialCountIsSupported(opt::IRContext* context); + static bool PartialCountIsSupported(opt::IRContext* ir_context); // Does the version of SPIR-V being used support the PeelCount loop control? - static bool PeelCountIsSupported(opt::IRContext* context); + static bool PeelCountIsSupported(opt::IRContext* ir_context); private: // Returns true if and only if |loop_single_bit_mask| is *not* set in
diff --git a/source/fuzz/transformation_set_memory_operands_mask.cpp b/source/fuzz/transformation_set_memory_operands_mask.cpp index a14e1a6..131a499 100644 --- a/source/fuzz/transformation_set_memory_operands_mask.cpp +++ b/source/fuzz/transformation_set_memory_operands_mask.cpp
@@ -42,8 +42,7 @@ } bool TransformationSetMemoryOperandsMask::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { if (message_.memory_operands_mask_index() != 0) { // The following conditions should never be violated, even if // transformations end up being replayed in a different way to the manner in @@ -54,11 +53,11 @@ SpvOpCopyMemory || message_.memory_access_instruction().target_instruction_opcode() == SpvOpCopyMemorySized); - assert(MultipleMemoryOperandMasksAreSupported(context)); + assert(MultipleMemoryOperandMasksAreSupported(ir_context)); } auto instruction = - FindInstruction(message_.memory_access_instruction(), context); + FindInstruction(message_.memory_access_instruction(), ir_context); if (!instruction) { return false; } @@ -94,9 +93,9 @@ } void TransformationSetMemoryOperandsMask::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { auto instruction = - FindInstruction(message_.memory_access_instruction(), context); + FindInstruction(message_.memory_access_instruction(), ir_context); auto original_mask_in_operand_index = GetInOperandIndexForMask( *instruction, message_.memory_operands_mask_index()); // Either add a new operand, if no mask operand was already present, or @@ -182,11 +181,11 @@ } bool TransformationSetMemoryOperandsMask:: - MultipleMemoryOperandMasksAreSupported(opt::IRContext* context) { + MultipleMemoryOperandMasksAreSupported(opt::IRContext* ir_context) { // TODO(afd): We capture the universal environments for which this loop // control is definitely not supported. The check should be refined on // demand for other target environments. - switch (context->grammar().target_env()) { + switch (ir_context->grammar().target_env()) { case SPV_ENV_UNIVERSAL_1_0: case SPV_ENV_UNIVERSAL_1_1: case SPV_ENV_UNIVERSAL_1_2:
diff --git a/source/fuzz/transformation_set_memory_operands_mask.h b/source/fuzz/transformation_set_memory_operands_mask.h index 20ae145..9f5081b 100644 --- a/source/fuzz/transformation_set_memory_operands_mask.h +++ b/source/fuzz/transformation_set_memory_operands_mask.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_SET_MEMORY_OPERANDS_MASK_H_ #define SOURCE_FUZZ_TRANSFORMATION_SET_MEMORY_OPERANDS_MASK_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -40,14 +40,16 @@ // - |message_.memory_operands_mask| must be identical to the original memory // operands mask, except that Volatile may be added, and Nontemporal may be // toggled. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Replaces the operands mask identified by // |message_.memory_operands_mask_index| in the instruction described by // |message_.memory_access_instruction| with |message_.memory_operands_mask|, // creating an input operand for the mask if no such operand was present. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; @@ -57,7 +59,8 @@ // Does the version of SPIR-V being used support multiple memory operand // masks on relevant memory access instructions? - static bool MultipleMemoryOperandMasksAreSupported(opt::IRContext* context); + static bool MultipleMemoryOperandMasksAreSupported( + opt::IRContext* ir_context); // Helper function to get the input operand index associated with mask number // |mask_index|. This is a bit tricky if there are multiple masks, because the
diff --git a/source/fuzz/transformation_set_selection_control.cpp b/source/fuzz/transformation_set_selection_control.cpp index ebabdef..bee1e35 100644 --- a/source/fuzz/transformation_set_selection_control.cpp +++ b/source/fuzz/transformation_set_selection_control.cpp
@@ -28,13 +28,13 @@ } bool TransformationSetSelectionControl::IsApplicable( - opt::IRContext* context, const FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { assert((message_.selection_control() == SpvSelectionControlMaskNone || message_.selection_control() == SpvSelectionControlFlattenMask || message_.selection_control() == SpvSelectionControlDontFlattenMask) && "Selection control should never be set to something other than " "'None', 'Flatten' or 'DontFlatten'"); - if (auto block = context->get_instr_block(message_.block_id())) { + if (auto block = ir_context->get_instr_block(message_.block_id())) { if (auto merge_inst = block->GetMergeInst()) { return merge_inst->opcode() == SpvOpSelectionMerge; } @@ -43,9 +43,9 @@ return false; } -void TransformationSetSelectionControl::Apply(opt::IRContext* context, - FactManager* /*unused*/) const { - context->get_instr_block(message_.block_id()) +void TransformationSetSelectionControl::Apply( + opt::IRContext* ir_context, TransformationContext* /*unused*/) const { + ir_context->get_instr_block(message_.block_id()) ->GetMergeInst() ->SetInOperand(1, {message_.selection_control()}); }
diff --git a/source/fuzz/transformation_set_selection_control.h b/source/fuzz/transformation_set_selection_control.h index 19e0c3c..21fbdda 100644 --- a/source/fuzz/transformation_set_selection_control.h +++ b/source/fuzz/transformation_set_selection_control.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_SET_SELECTION_CONTROL_H_ #define SOURCE_FUZZ_TRANSFORMATION_SET_SELECTION_CONTROL_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -35,12 +35,14 @@ // instruction. // - |message_.selection_control| must be one of None, Flatten or // DontFlatten. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // - The selection control operand of the OpSelectionMergeInstruction in // |message_.block_id| is overwritten with |message_.selection_control|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_split_block.cpp b/source/fuzz/transformation_split_block.cpp index fc5229e..b020d98 100644 --- a/source/fuzz/transformation_split_block.cpp +++ b/source/fuzz/transformation_split_block.cpp
@@ -35,18 +35,19 @@ } bool TransformationSplitBlock::IsApplicable( - opt::IRContext* context, const FactManager& /*unused*/) const { - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { // We require the id for the new block to be unused. return false; } auto instruction_to_split_before = - FindInstruction(message_.instruction_to_split_before(), context); + FindInstruction(message_.instruction_to_split_before(), ir_context); if (!instruction_to_split_before) { // The instruction describing the block we should split does not exist. return false; } - auto block_to_split = context->get_instr_block(instruction_to_split_before); + auto block_to_split = + ir_context->get_instr_block(instruction_to_split_before); assert(block_to_split && "We should not have managed to find the " "instruction if it was not contained in a block."); @@ -75,16 +76,43 @@ } // We cannot split before an OpPhi unless the OpPhi has exactly one // associated incoming edge. - return !(split_before->opcode() == SpvOpPhi && - split_before->NumInOperands() != 2); + if (split_before->opcode() == SpvOpPhi && + split_before->NumInOperands() != 2) { + return false; + } + + // Splitting the block must not separate the definition of an OpSampledImage + // from its use: the SPIR-V data rules require them to be in the same block. + std::set<uint32_t> sampled_image_result_ids; + bool before_split = true; + for (auto& instruction : *block_to_split) { + if (&instruction == &*split_before) { + before_split = false; + } + if (before_split) { + if (instruction.opcode() == SpvOpSampledImage) { + sampled_image_result_ids.insert(instruction.result_id()); + } + } else { + if (!instruction.WhileEachInId( + [&sampled_image_result_ids](uint32_t* id) -> bool { + return !sampled_image_result_ids.count(*id); + })) { + return false; + } + } + } + + return true; } -void TransformationSplitBlock::Apply(opt::IRContext* context, - FactManager* fact_manager) const { +void TransformationSplitBlock::Apply( + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { opt::Instruction* instruction_to_split_before = - FindInstruction(message_.instruction_to_split_before(), context); + FindInstruction(message_.instruction_to_split_before(), ir_context); opt::BasicBlock* block_to_split = - context->get_instr_block(instruction_to_split_before); + ir_context->get_instr_block(instruction_to_split_before); auto split_before = fuzzerutil::GetIteratorForInstruction( block_to_split, instruction_to_split_before); assert(split_before != block_to_split->end() && @@ -93,14 +121,14 @@ // We need to make sure the module's id bound is large enough to add the // fresh id. - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); // Split the block. - auto new_bb = block_to_split->SplitBasicBlock(context, message_.fresh_id(), + auto new_bb = block_to_split->SplitBasicBlock(ir_context, message_.fresh_id(), split_before); // The split does not automatically add a branch between the two parts of // the original block, so we add one. block_to_split->AddInstruction(MakeUnique<opt::Instruction>( - context, SpvOpBranch, 0, 0, + ir_context, SpvOpBranch, 0, 0, std::initializer_list<opt::Operand>{opt::Operand( spv_operand_type_t::SPV_OPERAND_TYPE_ID, {message_.fresh_id()})})); // If we split before OpPhi instructions, we need to update their @@ -117,12 +145,15 @@ // If the block being split was dead, the new block arising from the split is // also dead. - if (fact_manager->BlockIsDead(block_to_split->id())) { - fact_manager->AddFactBlockIsDead(message_.fresh_id()); + if (transformation_context->GetFactManager()->BlockIsDead( + block_to_split->id())) { + transformation_context->GetFactManager()->AddFactBlockIsDead( + message_.fresh_id()); } // Invalidate all analyses - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); } protobufs::Transformation TransformationSplitBlock::ToMessage() const {
diff --git a/source/fuzz/transformation_split_block.h b/source/fuzz/transformation_split_block.h index a193fc7..3bf6dfd 100644 --- a/source/fuzz/transformation_split_block.h +++ b/source/fuzz/transformation_split_block.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_SPLIT_BLOCK_H_ #define SOURCE_FUZZ_TRANSFORMATION_SPLIT_BLOCK_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -40,8 +40,9 @@ // - Splitting 'blk' at 'inst', so that all instructions from 'inst' onwards // appear in a new block that 'blk' directly jumps to must be valid. // - |message_.fresh_id| must not be used by the module. - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // - A new block with label |message_.fresh_id| is inserted right after 'blk' // in program order. @@ -49,7 +50,8 @@ // block. // - 'blk' is made to jump unconditionally to the new block. // - If 'blk' was dead, the new block is also dead. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_store.cpp b/source/fuzz/transformation_store.cpp index 7cb7611..f77afe3 100644 --- a/source/fuzz/transformation_store.cpp +++ b/source/fuzz/transformation_store.cpp
@@ -34,23 +34,23 @@ } bool TransformationStore::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& fact_manager) const { + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const { // The pointer must exist and have a type. - auto pointer = context->get_def_use_mgr()->GetDef(message_.pointer_id()); + auto pointer = ir_context->get_def_use_mgr()->GetDef(message_.pointer_id()); if (!pointer || !pointer->type_id()) { return false; } // The pointer type must indeed be a pointer. - auto pointer_type = context->get_def_use_mgr()->GetDef(pointer->type_id()); + auto pointer_type = ir_context->get_def_use_mgr()->GetDef(pointer->type_id()); assert(pointer_type && "Type id must be defined."); if (pointer_type->opcode() != SpvOpTypePointer) { return false; } // The pointer must not be read only. - if (pointer_type->GetSingleWordInOperand(0) == SpvStorageClassInput) { + if (pointer->IsReadOnlyPointer()) { return false; } @@ -65,7 +65,7 @@ // Determine which instruction we should be inserting before. auto insert_before = - FindInstruction(message_.instruction_to_insert_before(), context); + FindInstruction(message_.instruction_to_insert_before(), ir_context); // It must exist, ... if (!insert_before) { return false; @@ -79,14 +79,15 @@ // The block we are inserting into needs to be dead, or else the pointee type // of the pointer we are storing to needs to be irrelevant (otherwise the // store could impact on the observable behaviour of the module). - if (!fact_manager.BlockIsDead( - context->get_instr_block(insert_before)->id()) && - !fact_manager.PointeeValueIsIrrelevant(message_.pointer_id())) { + if (!transformation_context.GetFactManager()->BlockIsDead( + ir_context->get_instr_block(insert_before)->id()) && + !transformation_context.GetFactManager()->PointeeValueIsIrrelevant( + message_.pointer_id())) { return false; } // The value being stored needs to exist and have a type. - auto value = context->get_def_use_mgr()->GetDef(message_.value_id()); + auto value = ir_context->get_def_use_mgr()->GetDef(message_.value_id()); if (!value || !value->type_id()) { return false; } @@ -97,25 +98,25 @@ } // The pointer needs to be available at the insertion point. - if (!fuzzerutil::IdIsAvailableBeforeInstruction(context, insert_before, + if (!fuzzerutil::IdIsAvailableBeforeInstruction(ir_context, insert_before, message_.pointer_id())) { return false; } // The value needs to be available at the insertion point. - return fuzzerutil::IdIsAvailableBeforeInstruction(context, insert_before, + return fuzzerutil::IdIsAvailableBeforeInstruction(ir_context, insert_before, message_.value_id()); } -void TransformationStore::Apply(opt::IRContext* context, - spvtools::fuzz::FactManager* /*unused*/) const { - FindInstruction(message_.instruction_to_insert_before(), context) +void TransformationStore::Apply(opt::IRContext* ir_context, + TransformationContext* /*unused*/) const { + FindInstruction(message_.instruction_to_insert_before(), ir_context) ->InsertBefore(MakeUnique<opt::Instruction>( - context, SpvOpStore, 0, 0, + ir_context, SpvOpStore, 0, 0, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {message_.pointer_id()}}, {SPV_OPERAND_TYPE_ID, {message_.value_id()}}}))); - context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); + ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); } protobufs::Transformation TransformationStore::ToMessage() const {
diff --git a/source/fuzz/transformation_store.h b/source/fuzz/transformation_store.h index 699afdd..6746aab 100644 --- a/source/fuzz/transformation_store.h +++ b/source/fuzz/transformation_store.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_STORE_H_ #define SOURCE_FUZZ_TRANSFORMATION_STORE_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -42,14 +42,16 @@ // to dominance rules) // - Either the insertion point must be in a dead block, or it must be known // that the pointee value of |message_.pointer_id| is irrelevant - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Adds an instruction of the form: // OpStore |pointer_id| |value_id| // before the instruction identified by // |message_.instruction_to_insert_before|. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_swap_commutable_operands.cpp b/source/fuzz/transformation_swap_commutable_operands.cpp index 49d9de8..b7622a2 100644 --- a/source/fuzz/transformation_swap_commutable_operands.cpp +++ b/source/fuzz/transformation_swap_commutable_operands.cpp
@@ -31,10 +31,10 @@ } bool TransformationSwapCommutableOperands::IsApplicable( - opt::IRContext* context, const spvtools::fuzz::FactManager& /*unused*/ + opt::IRContext* ir_context, const TransformationContext& /*unused*/ ) const { auto instruction = - FindInstruction(message_.instruction_descriptor(), context); + FindInstruction(message_.instruction_descriptor(), ir_context); if (instruction == nullptr) return false; SpvOp opcode = static_cast<SpvOp>( @@ -46,10 +46,10 @@ } void TransformationSwapCommutableOperands::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/ + opt::IRContext* ir_context, TransformationContext* /*unused*/ ) const { auto instruction = - FindInstruction(message_.instruction_descriptor(), context); + FindInstruction(message_.instruction_descriptor(), ir_context); // By design, the instructions defined to be commutative have exactly two // input parameters. std::swap(instruction->GetInOperand(0), instruction->GetInOperand(1));
diff --git a/source/fuzz/transformation_swap_commutable_operands.h b/source/fuzz/transformation_swap_commutable_operands.h index 061e92d..7fe5b70 100644 --- a/source/fuzz/transformation_swap_commutable_operands.h +++ b/source/fuzz/transformation_swap_commutable_operands.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_SWAP_COMMUTABLE_OPERANDS_H_ #define SOURCE_FUZZ_TRANSFORMATION_SWAP_COMMUTABLE_OPERANDS_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -33,11 +33,13 @@ // - |message_.instruction_descriptor| must identify an existing // commutative instruction - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Swaps the commutable operands. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_toggle_access_chain_instruction.cpp b/source/fuzz/transformation_toggle_access_chain_instruction.cpp index ace331a..ca24a18 100644 --- a/source/fuzz/transformation_toggle_access_chain_instruction.cpp +++ b/source/fuzz/transformation_toggle_access_chain_instruction.cpp
@@ -33,10 +33,10 @@ } bool TransformationToggleAccessChainInstruction::IsApplicable( - opt::IRContext* context, const spvtools::fuzz::FactManager& /*unused*/ + opt::IRContext* ir_context, const TransformationContext& /*unused*/ ) const { auto instruction = - FindInstruction(message_.instruction_descriptor(), context); + FindInstruction(message_.instruction_descriptor(), ir_context); if (instruction == nullptr) { return false; } @@ -56,10 +56,10 @@ } void TransformationToggleAccessChainInstruction::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/ + opt::IRContext* ir_context, TransformationContext* /*unused*/ ) const { auto instruction = - FindInstruction(message_.instruction_descriptor(), context); + FindInstruction(message_.instruction_descriptor(), ir_context); SpvOp opcode = instruction->opcode(); if (opcode == SpvOpAccessChain) {
diff --git a/source/fuzz/transformation_toggle_access_chain_instruction.h b/source/fuzz/transformation_toggle_access_chain_instruction.h index 125e1ab..9cd8fd6 100644 --- a/source/fuzz/transformation_toggle_access_chain_instruction.h +++ b/source/fuzz/transformation_toggle_access_chain_instruction.h
@@ -15,9 +15,9 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_TOGGLE_ACCESS_CHAIN_INSTRUCTION_H_ #define SOURCE_FUZZ_TRANSFORMATION_TOGGLE_ACCESS_CHAIN_INSTRUCTION_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" namespace spvtools { @@ -33,11 +33,13 @@ // - |message_.instruction_descriptor| must identify an existing // access chain instruction - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Toggles the access chain instruction. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override;
diff --git a/source/fuzz/transformation_vector_shuffle.cpp b/source/fuzz/transformation_vector_shuffle.cpp index e2d889d..ee64292 100644 --- a/source/fuzz/transformation_vector_shuffle.cpp +++ b/source/fuzz/transformation_vector_shuffle.cpp
@@ -39,38 +39,37 @@ } bool TransformationVectorShuffle::IsApplicable( - opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { // The fresh id must not already be in use. - if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) { + if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) { return false; } // The instruction before which the shuffle will be inserted must exist. auto instruction_to_insert_before = - FindInstruction(message_.instruction_to_insert_before(), context); + FindInstruction(message_.instruction_to_insert_before(), ir_context); if (!instruction_to_insert_before) { return false; } // The first vector must be an instruction with a type id auto vector1_instruction = - context->get_def_use_mgr()->GetDef(message_.vector1()); + ir_context->get_def_use_mgr()->GetDef(message_.vector1()); if (!vector1_instruction || !vector1_instruction->type_id()) { return false; } // The second vector must be an instruction with a type id auto vector2_instruction = - context->get_def_use_mgr()->GetDef(message_.vector2()); + ir_context->get_def_use_mgr()->GetDef(message_.vector2()); if (!vector2_instruction || !vector2_instruction->type_id()) { return false; } auto vector1_type = - context->get_type_mgr()->GetType(vector1_instruction->type_id()); + ir_context->get_type_mgr()->GetType(vector1_instruction->type_id()); // The first vector instruction's type must actually be a vector type. if (!vector1_type->AsVector()) { return false; } auto vector2_type = - context->get_type_mgr()->GetType(vector2_instruction->type_id()); + ir_context->get_type_mgr()->GetType(vector2_instruction->type_id()); // The second vector instruction's type must actually be a vector type. if (!vector2_type->AsVector()) { return false; @@ -92,14 +91,14 @@ } // The module must already declare an appropriate type in which to store the // result of the shuffle. - if (!GetResultTypeId(context, *vector1_type->AsVector()->element_type())) { + if (!GetResultTypeId(ir_context, *vector1_type->AsVector()->element_type())) { return false; } // Each of the vectors used in the shuffle must be available at the insertion // point. for (auto used_instruction : {vector1_instruction, vector2_instruction}) { - if (auto block = context->get_instr_block(used_instruction)) { - if (!context->GetDominatorAnalysis(block->GetParent()) + if (auto block = ir_context->get_instr_block(used_instruction)) { + if (!ir_context->GetDominatorAnalysis(block->GetParent()) ->Dominates(used_instruction, instruction_to_insert_before)) { return false; } @@ -113,7 +112,8 @@ } void TransformationVectorShuffle::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const { + opt::IRContext* ir_context, + TransformationContext* transformation_context) const { // Make input operands for a shuffle instruction - these comprise the two // vectors being shuffled, followed by the integer literal components. opt::Instruction::OperandList shuffle_operands = { @@ -125,16 +125,18 @@ } uint32_t result_type_id = GetResultTypeId( - context, *GetVectorType(context, message_.vector1())->element_type()); + ir_context, + *GetVectorType(ir_context, message_.vector1())->element_type()); // Add a shuffle instruction right before the instruction identified by // |message_.instruction_to_insert_before|. - FindInstruction(message_.instruction_to_insert_before(), context) + FindInstruction(message_.instruction_to_insert_before(), ir_context) ->InsertBefore(MakeUnique<opt::Instruction>( - context, SpvOpVectorShuffle, result_type_id, message_.fresh_id(), + ir_context, SpvOpVectorShuffle, result_type_id, message_.fresh_id(), shuffle_operands)); - fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id()); - context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); + fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id()); + ir_context->InvalidateAnalysesExceptFor( + opt::IRContext::Analysis::kAnalysisNone); // Add synonym facts relating the defined elements of the shuffle result to // the vector components that they come from. @@ -158,24 +160,26 @@ // Get a data descriptor for the component of the input vector to which // |component| refers. if (component < - GetVectorType(context, message_.vector1())->element_count()) { + GetVectorType(ir_context, message_.vector1())->element_count()) { descriptor_for_source_component = MakeDataDescriptor(message_.vector1(), {component}); } else { auto index_into_vector_2 = component - - GetVectorType(context, message_.vector1())->element_count(); - assert(index_into_vector_2 < - GetVectorType(context, message_.vector2())->element_count() && - "Vector shuffle index is out of bounds."); + GetVectorType(ir_context, message_.vector1())->element_count(); + assert( + index_into_vector_2 < + GetVectorType(ir_context, message_.vector2())->element_count() && + "Vector shuffle index is out of bounds."); descriptor_for_source_component = MakeDataDescriptor(message_.vector2(), {index_into_vector_2}); } // Add a fact relating this input vector component with the associated // result component. - fact_manager->AddFactDataSynonym(descriptor_for_result_component, - descriptor_for_source_component, context); + transformation_context->GetFactManager()->AddFactDataSynonym( + descriptor_for_result_component, descriptor_for_source_component, + ir_context); } } @@ -186,16 +190,16 @@ } uint32_t TransformationVectorShuffle::GetResultTypeId( - opt::IRContext* context, const opt::analysis::Type& element_type) const { + opt::IRContext* ir_context, const opt::analysis::Type& element_type) const { opt::analysis::Vector result_type( &element_type, static_cast<uint32_t>(message_.component_size())); - return context->get_type_mgr()->GetId(&result_type); + return ir_context->get_type_mgr()->GetId(&result_type); } opt::analysis::Vector* TransformationVectorShuffle::GetVectorType( - opt::IRContext* context, uint32_t id_of_vector) { - return context->get_type_mgr() - ->GetType(context->get_def_use_mgr()->GetDef(id_of_vector)->type_id()) + opt::IRContext* ir_context, uint32_t id_of_vector) { + return ir_context->get_type_mgr() + ->GetType(ir_context->get_def_use_mgr()->GetDef(id_of_vector)->type_id()) ->AsVector(); }
diff --git a/source/fuzz/transformation_vector_shuffle.h b/source/fuzz/transformation_vector_shuffle.h index 81ed227..f73fc31 100644 --- a/source/fuzz/transformation_vector_shuffle.h +++ b/source/fuzz/transformation_vector_shuffle.h
@@ -15,10 +15,11 @@ #ifndef SOURCE_FUZZ_TRANSFORMATION_VECTOR_SHUFFLE_H_ #define SOURCE_FUZZ_TRANSFORMATION_VECTOR_SHUFFLE_H_ -#include "source/fuzz/fact_manager.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" #include "source/fuzz/transformation.h" +#include "source/fuzz/transformation_context.h" #include "source/opt/ir_context.h" + #include "source/opt/types.h" namespace spvtools { @@ -45,8 +46,9 @@ // - The module must already contain a vector type with the same element type // as |message_.vector1| and |message_.vector2|, and with the size of // |message_component| as its element count - bool IsApplicable(opt::IRContext* context, - const FactManager& fact_manager) const override; + bool IsApplicable( + opt::IRContext* ir_context, + const TransformationContext& transformation_context) const override; // Inserts an OpVectorShuffle instruction before // |message_.instruction_to_insert_before|, shuffles vectors @@ -58,19 +60,20 @@ // result vector is a contiguous sub-range of one of the input vectors, a // fact is added to record that |message_.fresh_id| is synonymous with this // sub-range. - void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + void Apply(opt::IRContext* ir_context, + TransformationContext* transformation_context) const override; protobufs::Transformation ToMessage() const override; private: - // Returns a type id that already exists in |context| suitable for + // Returns a type id that already exists in |ir_context| suitable for // representing the result of the shuffle, where |element_type| is known to // be the common element type of the vectors to which the shuffle is being // applied. Returns 0 if no such id exists. - uint32_t GetResultTypeId(opt::IRContext* context, + uint32_t GetResultTypeId(opt::IRContext* ir_context, const opt::analysis::Type& element_type) const; - static opt::analysis::Vector* GetVectorType(opt::IRContext* context, + static opt::analysis::Vector* GetVectorType(opt::IRContext* ir_context, uint32_t id_of_vector); protobufs::TransformationVectorShuffle message_;
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index 1428c74..0047c34 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt
@@ -34,6 +34,7 @@ dead_variable_elimination.h decompose_initialized_variables_pass.h decoration_manager.h + debug_info_manager.h def_use_manager.h desc_sroa.h dominator_analysis.h @@ -141,6 +142,7 @@ dead_variable_elimination.cpp decompose_initialized_variables_pass.cpp decoration_manager.cpp + debug_info_manager.cpp def_use_manager.cpp desc_sroa.cpp dominator_analysis.cpp
diff --git a/source/opt/code_sink.cpp b/source/opt/code_sink.cpp index 9d54ee5..4c88cd4 100644 --- a/source/opt/code_sink.cpp +++ b/source/opt/code_sink.cpp
@@ -177,7 +177,7 @@ return true; } - if (base_ptr->IsReadOnlyVariable()) { + if (base_ptr->IsReadOnlyPointer()) { return false; }
diff --git a/source/opt/debug_info_manager.cpp b/source/opt/debug_info_manager.cpp new file mode 100644 index 0000000..9d98584 --- /dev/null +++ b/source/opt/debug_info_manager.cpp
@@ -0,0 +1,297 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/debug_info_manager.h" + +#include <cassert> + +#include "source/opt/ir_context.h" + +// Constants for OpenCL.DebugInfo.100 extension instructions. + +static const uint32_t kOpLineOperandLineIndex = 1; +static const uint32_t kLineOperandIndexDebugFunction = 7; +static const uint32_t kLineOperandIndexDebugLexicalBlock = 5; +static const uint32_t kDebugFunctionOperandFunctionIndex = 13; +static const uint32_t kDebugInlinedAtOperandInlinedIndex = 6; + +namespace spvtools { +namespace opt { +namespace analysis { +namespace { + +void SetInlinedOperand(Instruction* dbg_inlined_at, uint32_t inlined_operand) { + assert(dbg_inlined_at); + assert(dbg_inlined_at->GetOpenCL100DebugOpcode() == + OpenCLDebugInfo100DebugInlinedAt); + if (dbg_inlined_at->NumOperands() <= kDebugInlinedAtOperandInlinedIndex) { + dbg_inlined_at->AddOperand({SPV_OPERAND_TYPE_RESULT_ID, {inlined_operand}}); + } else { + dbg_inlined_at->SetOperand(kDebugInlinedAtOperandInlinedIndex, + {inlined_operand}); + } +} + +uint32_t GetInlinedOperand(Instruction* dbg_inlined_at) { + assert(dbg_inlined_at); + assert(dbg_inlined_at->GetOpenCL100DebugOpcode() == + OpenCLDebugInfo100DebugInlinedAt); + if (dbg_inlined_at->NumOperands() <= kDebugInlinedAtOperandInlinedIndex) + return kNoInlinedAt; + return dbg_inlined_at->GetSingleWordOperand( + kDebugInlinedAtOperandInlinedIndex); +} + +} // namespace + +DebugInfoManager::DebugInfoManager(IRContext* c) : context_(c) { + AnalyzeDebugInsts(*c->module()); +} + +Instruction* DebugInfoManager::GetDbgInst(uint32_t id) { + auto dbg_inst_it = id_to_dbg_inst_.find(id); + return dbg_inst_it == id_to_dbg_inst_.end() ? nullptr : dbg_inst_it->second; +} + +void DebugInfoManager::RegisterDbgInst(Instruction* inst) { + assert( + inst->NumInOperands() != 0 && + context()->get_feature_mgr()->GetExtInstImportId_OpenCL100DebugInfo() == + inst->GetInOperand(0).words[0] && + "Given instruction is not a debug instruction"); + id_to_dbg_inst_[inst->result_id()] = inst; +} + +void DebugInfoManager::RegisterDbgFunction(Instruction* inst) { + assert(inst->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugFunction && + "inst is not a DebugFunction"); + auto fn_id = inst->GetSingleWordOperand(kDebugFunctionOperandFunctionIndex); + assert( + fn_id_to_dbg_fn_.find(fn_id) == fn_id_to_dbg_fn_.end() && + "Register DebugFunction for a function that already has DebugFunction"); + fn_id_to_dbg_fn_[fn_id] = inst; +} + +uint32_t DebugInfoManager::CreateDebugInlinedAt(const Instruction* line, + const DebugScope& scope) { + if (context()->get_feature_mgr()->GetExtInstImportId_OpenCL100DebugInfo() == + 0) + return kNoInlinedAt; + + uint32_t line_number = 0; + if (line == nullptr) { + auto* lexical_scope_inst = GetDbgInst(scope.GetLexicalScope()); + if (lexical_scope_inst == nullptr) return kNoInlinedAt; + OpenCLDebugInfo100Instructions debug_opcode = + lexical_scope_inst->GetOpenCL100DebugOpcode(); + switch (debug_opcode) { + case OpenCLDebugInfo100DebugFunction: + line_number = lexical_scope_inst->GetSingleWordOperand( + kLineOperandIndexDebugFunction); + break; + case OpenCLDebugInfo100DebugLexicalBlock: + line_number = lexical_scope_inst->GetSingleWordOperand( + kLineOperandIndexDebugLexicalBlock); + break; + case OpenCLDebugInfo100DebugTypeComposite: + case OpenCLDebugInfo100DebugCompilationUnit: + assert(false && + "DebugTypeComposite and DebugCompilationUnit are lexical " + "scopes, but we inline functions into a function or a block " + "of a function, not into a struct/class or a global scope."); + break; + default: + assert(false && + "Unreachable. a debug extension instruction for a " + "lexical scope must be DebugFunction, DebugTypeComposite, " + "DebugLexicalBlock, or DebugCompilationUnit."); + break; + } + } else { + line_number = line->GetSingleWordOperand(kOpLineOperandLineIndex); + } + + uint32_t result_id = context()->TakeNextId(); + std::unique_ptr<Instruction> inlined_at(new Instruction( + context(), SpvOpExtInst, context()->get_type_mgr()->GetVoidTypeId(), + result_id, + { + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, + {context() + ->get_feature_mgr() + ->GetExtInstImportId_OpenCL100DebugInfo()}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, + {static_cast<uint32_t>(OpenCLDebugInfo100DebugInlinedAt)}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {line_number}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {scope.GetLexicalScope()}}, + })); + // |scope| already has DebugInlinedAt. We put the existing DebugInlinedAt + // into the Inlined operand of this new DebugInlinedAt. + if (scope.GetInlinedAt() != kNoInlinedAt) { + inlined_at->AddOperand({spv_operand_type_t::SPV_OPERAND_TYPE_RESULT_ID, + {scope.GetInlinedAt()}}); + } + RegisterDbgInst(inlined_at.get()); + context()->module()->AddExtInstDebugInfo(std::move(inlined_at)); + return result_id; +} + +DebugScope DebugInfoManager::BuildDebugScope( + const DebugScope& callee_instr_scope, + DebugInlinedAtContext* inlined_at_ctx) { + return DebugScope(callee_instr_scope.GetLexicalScope(), + BuildDebugInlinedAtChain(callee_instr_scope.GetInlinedAt(), + inlined_at_ctx)); +} + +uint32_t DebugInfoManager::BuildDebugInlinedAtChain( + uint32_t callee_inlined_at, DebugInlinedAtContext* inlined_at_ctx) { + if (inlined_at_ctx->GetScopeOfCallInstruction().GetLexicalScope() == + kNoDebugScope) + return kNoInlinedAt; + + // Reuse the already generated DebugInlinedAt chain if exists. + uint32_t already_generated_chain_head_id = + inlined_at_ctx->GetDebugInlinedAtChain(callee_inlined_at); + if (already_generated_chain_head_id != kNoInlinedAt) { + return already_generated_chain_head_id; + } + + const uint32_t new_dbg_inlined_at_id = + CreateDebugInlinedAt(inlined_at_ctx->GetLineOfCallInstruction(), + inlined_at_ctx->GetScopeOfCallInstruction()); + if (new_dbg_inlined_at_id == kNoInlinedAt) return kNoInlinedAt; + + if (callee_inlined_at == kNoInlinedAt) { + inlined_at_ctx->SetDebugInlinedAtChain(kNoInlinedAt, new_dbg_inlined_at_id); + return new_dbg_inlined_at_id; + } + + uint32_t chain_head_id = kNoInlinedAt; + uint32_t chain_iter_id = callee_inlined_at; + Instruction* last_inlined_at_in_chain = nullptr; + do { + Instruction* new_inlined_at_in_chain = CloneDebugInlinedAt( + chain_iter_id, /* insert_before */ last_inlined_at_in_chain); + assert(new_inlined_at_in_chain != nullptr); + + // Set DebugInlinedAt of the new scope as the head of the chain. + if (chain_head_id == kNoInlinedAt) + chain_head_id = new_inlined_at_in_chain->result_id(); + + // Previous DebugInlinedAt of the chain must point to the new + // DebugInlinedAt as its Inlined operand to build a recursive + // chain. + if (last_inlined_at_in_chain != nullptr) { + SetInlinedOperand(last_inlined_at_in_chain, + new_inlined_at_in_chain->result_id()); + } + last_inlined_at_in_chain = new_inlined_at_in_chain; + + chain_iter_id = GetInlinedOperand(new_inlined_at_in_chain); + } while (chain_iter_id != kNoInlinedAt); + + // Put |new_dbg_inlined_at_id| into the end of the chain. + SetInlinedOperand(last_inlined_at_in_chain, new_dbg_inlined_at_id); + + // Keep the new chain information that will be reused it. + inlined_at_ctx->SetDebugInlinedAtChain(callee_inlined_at, chain_head_id); + return chain_head_id; +} + +Instruction* DebugInfoManager::GetDebugInfoNone() { + if (debug_info_none_inst_ != nullptr) return debug_info_none_inst_; + + uint32_t result_id = context()->TakeNextId(); + std::unique_ptr<Instruction> dbg_info_none_inst(new Instruction( + context(), SpvOpExtInst, context()->get_type_mgr()->GetVoidTypeId(), + result_id, + { + {SPV_OPERAND_TYPE_RESULT_ID, + {context() + ->get_feature_mgr() + ->GetExtInstImportId_OpenCL100DebugInfo()}}, + {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, + {static_cast<uint32_t>(OpenCLDebugInfo100DebugInfoNone)}}, + })); + + // Add to the front of |ext_inst_debuginfo_|. + debug_info_none_inst_ = + context()->module()->ext_inst_debuginfo_begin()->InsertBefore( + std::move(dbg_info_none_inst)); + + RegisterDbgInst(debug_info_none_inst_); + return debug_info_none_inst_; +} + +Instruction* DebugInfoManager::GetDebugInlinedAt(uint32_t dbg_inlined_at_id) { + auto* inlined_at = GetDbgInst(dbg_inlined_at_id); + if (inlined_at == nullptr) return nullptr; + if (inlined_at->GetOpenCL100DebugOpcode() != + OpenCLDebugInfo100DebugInlinedAt) { + return nullptr; + } + return inlined_at; +} + +Instruction* DebugInfoManager::CloneDebugInlinedAt(uint32_t clone_inlined_at_id, + Instruction* insert_before) { + auto* inlined_at = GetDebugInlinedAt(clone_inlined_at_id); + if (inlined_at == nullptr) return nullptr; + std::unique_ptr<Instruction> new_inlined_at(inlined_at->Clone(context())); + new_inlined_at->SetResultId(context()->TakeNextId()); + RegisterDbgInst(new_inlined_at.get()); + if (insert_before != nullptr) + return insert_before->InsertBefore(std::move(new_inlined_at)); + return context()->module()->ext_inst_debuginfo_end()->InsertBefore( + std::move(new_inlined_at)); +} + +void DebugInfoManager::AnalyzeDebugInst(Instruction* dbg_inst) { + if (dbg_inst->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100InstructionsMax) + return; + + RegisterDbgInst(dbg_inst); + + if (dbg_inst->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugFunction) { + assert(GetDebugFunction(dbg_inst->GetSingleWordOperand( + kDebugFunctionOperandFunctionIndex)) == nullptr && + "Two DebugFunction instruction exists for a single OpFunction."); + RegisterDbgFunction(dbg_inst); + } + + if (debug_info_none_inst_ == nullptr && + dbg_inst->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugInfoNone) { + debug_info_none_inst_ = dbg_inst; + } +} + +void DebugInfoManager::AnalyzeDebugInsts(Module& module) { + debug_info_none_inst_ = nullptr; + module.ForEachInst([this](Instruction* cpi) { AnalyzeDebugInst(cpi); }); + + // Move |debug_info_none_inst_| to the beginning of the debug instruction + // list. + if (debug_info_none_inst_ != nullptr && + debug_info_none_inst_->PreviousNode() != nullptr && + debug_info_none_inst_->PreviousNode()->GetOpenCL100DebugOpcode() != + OpenCLDebugInfo100InstructionsMax) { + debug_info_none_inst_->InsertBefore( + &*context()->module()->ext_inst_debuginfo_begin()); + } +} + +} // namespace analysis +} // namespace opt +} // namespace spvtools
diff --git a/source/opt/debug_info_manager.h b/source/opt/debug_info_manager.h new file mode 100644 index 0000000..0c7186e --- /dev/null +++ b/source/opt/debug_info_manager.h
@@ -0,0 +1,169 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_DEBUG_INFO_MANAGER_H_ +#define SOURCE_OPT_DEBUG_INFO_MANAGER_H_ + +#include <unordered_map> + +#include "source/opt/instruction.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { +namespace analysis { + +// When an instruction of a callee function is inlined to its caller function, +// we need the line and the scope information of the function call instruction +// to generate DebugInlinedAt. This class keeps the data. For multiple inlining +// of a single instruction, we have to create multiple DebugInlinedAt +// instructions as a chain. This class keeps the information of the generated +// DebugInlinedAt chains to reduce the number of chains. +class DebugInlinedAtContext { + public: + explicit DebugInlinedAtContext(Instruction* call_inst) + : call_inst_line_(call_inst->dbg_line_inst()), + call_inst_scope_(call_inst->GetDebugScope()) {} + + const Instruction* GetLineOfCallInstruction() { return call_inst_line_; } + const DebugScope& GetScopeOfCallInstruction() { return call_inst_scope_; } + // Puts the DebugInlinedAt chain that is generated for the callee instruction + // whose DebugInlinedAt of DebugScope is |callee_instr_inlined_at| into + // |callee_inlined_at2chain_|. + void SetDebugInlinedAtChain(uint32_t callee_instr_inlined_at, + uint32_t chain_head_id) { + callee_inlined_at2chain_[callee_instr_inlined_at] = chain_head_id; + } + // Gets the DebugInlinedAt chain from |callee_inlined_at2chain_|. + uint32_t GetDebugInlinedAtChain(uint32_t callee_instr_inlined_at) { + auto chain_itr = callee_inlined_at2chain_.find(callee_instr_inlined_at); + if (chain_itr != callee_inlined_at2chain_.end()) return chain_itr->second; + return kNoInlinedAt; + } + + private: + // The line information of the function call instruction that will be + // replaced by the callee function. + const Instruction* call_inst_line_; + + // The scope information of the function call instruction that will be + // replaced by the callee function. + const DebugScope call_inst_scope_; + + // Map from DebugInlinedAt ids of callee to head ids of new generated + // DebugInlinedAt chain. + std::unordered_map<uint32_t, uint32_t> callee_inlined_at2chain_; +}; + +// A class for analyzing, managing, and creating OpenCL.DebugInfo.100 extension +// instructions. +class DebugInfoManager { + public: + // Constructs a debug information manager from the given |context|. + DebugInfoManager(IRContext* context); + + DebugInfoManager(const DebugInfoManager&) = delete; + DebugInfoManager(DebugInfoManager&&) = delete; + DebugInfoManager& operator=(const DebugInfoManager&) = delete; + DebugInfoManager& operator=(DebugInfoManager&&) = delete; + + friend bool operator==(const DebugInfoManager&, const DebugInfoManager&); + friend bool operator!=(const DebugInfoManager& lhs, + const DebugInfoManager& rhs) { + return !(lhs == rhs); + } + + // Analyzes OpenCL.DebugInfo.100 instruction |dbg_inst|. + void AnalyzeDebugInst(Instruction* dbg_inst); + + // Creates new DebugInlinedAt and returns its id. Its line operand is the + // line number of |line| if |line| is not nullptr. Otherwise, its line operand + // is the line number of lexical scope of |scope|. Its Scope and Inlined + // operands are Scope and Inlined of |scope|. + uint32_t CreateDebugInlinedAt(const Instruction* line, + const DebugScope& scope); + + // Returns a DebugInfoNone instruction. + Instruction* GetDebugInfoNone(); + + // Returns DebugInlinedAt whose id is |dbg_inlined_at_id|. If it does not + // exist or it is not a DebugInlinedAt instruction, return nullptr. + Instruction* GetDebugInlinedAt(uint32_t dbg_inlined_at_id); + + // Returns DebugFunction whose Function operand is |fn_id|. If it does not + // exist, return nullptr. + Instruction* GetDebugFunction(uint32_t fn_id) { + auto dbg_fn_it = fn_id_to_dbg_fn_.find(fn_id); + return dbg_fn_it == fn_id_to_dbg_fn_.end() ? nullptr : dbg_fn_it->second; + } + + // Clones DebugInlinedAt whose id is |clone_inlined_at_id|. If + // |clone_inlined_at_id| is not an id of DebugInlinedAt, returns nullptr. + // If |insert_before| is given, inserts the new DebugInlinedAt before it. + // Otherwise, inserts the new DebugInlinedAt into the debug instruction + // section of the module. + Instruction* CloneDebugInlinedAt(uint32_t clone_inlined_at_id, + Instruction* insert_before = nullptr); + + // Returns the debug scope corresponding to an inlining instruction in the + // scope |callee_instr_scope| into |inlined_at_ctx|. Generates all new + // debug instructions needed to represent the scope. + DebugScope BuildDebugScope(const DebugScope& callee_instr_scope, + DebugInlinedAtContext* inlined_at_ctx); + + // Returns DebugInlinedAt corresponding to inlining an instruction, which + // was inlined at |callee_inlined_at|, into |inlined_at_ctx|. Generates all + // new debug instructions needed to represent the DebugInlinedAt. + uint32_t BuildDebugInlinedAtChain(uint32_t callee_inlined_at, + DebugInlinedAtContext* inlined_at_ctx); + + private: + IRContext* context() { return context_; } + + // Analyzes OpenCL.DebugInfo.100 instructions in the given |module| and + // populates data structures in this class. + void AnalyzeDebugInsts(Module& module); + + // Returns the debug instruction whose id is |id|. Returns |nullptr| if one + // does not exists. + Instruction* GetDbgInst(uint32_t id); + + // Registers the debug instruction |inst| into |id_to_dbg_inst_| using id of + // |inst| as a key. + void RegisterDbgInst(Instruction* inst); + + // Register the DebugFunction instruction |inst|. The function referenced + // in |inst| must not already be registered. + void RegisterDbgFunction(Instruction* inst); + + IRContext* context_; + + // Mapping from ids of OpenCL.DebugInfo.100 extension instructions + // to their Instruction instances. + std::unordered_map<uint32_t, Instruction*> id_to_dbg_inst_; + + // Mapping from function's ids to DebugFunction instructions whose + // operand is the function. + std::unordered_map<uint32_t, Instruction*> fn_id_to_dbg_fn_; + + // DebugInfoNone instruction. We need only a single DebugInfoNone. + // To reuse the existing one, we keep it using this member variable. + Instruction* debug_info_none_inst_; +}; + +} // namespace analysis +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_DEBUG_INFO_MANAGER_H_
diff --git a/source/opt/dominator_tree.cpp b/source/opt/dominator_tree.cpp index c9346e1..da5073a 100644 --- a/source/opt/dominator_tree.cpp +++ b/source/opt/dominator_tree.cpp
@@ -241,6 +241,7 @@ bool DominatorTree::Dominates(const DominatorTreeNode* a, const DominatorTreeNode* b) const { + if (!a || !b) return false; // Node A dominates node B if they are the same. if (a == b) return true;
diff --git a/source/opt/eliminate_dead_members_pass.cpp b/source/opt/eliminate_dead_members_pass.cpp index 0b73b2d..5b8f4ec 100644 --- a/source/opt/eliminate_dead_members_pass.cpp +++ b/source/opt/eliminate_dead_members_pass.cpp
@@ -19,6 +19,7 @@ namespace { const uint32_t kRemovedMember = 0xFFFFFFFF; +const uint32_t kSpecConstOpOpcodeIdx = 0; } namespace spvtools { @@ -40,7 +41,22 @@ // we have to mark them as fully used just to be safe. for (auto& inst : get_module()->types_values()) { if (inst.opcode() == SpvOpSpecConstantOp) { - MarkTypeAsFullyUsed(inst.type_id()); + switch (inst.GetSingleWordInOperand(kSpecConstOpOpcodeIdx)) { + case SpvOpCompositeExtract: + MarkMembersAsLiveForExtract(&inst); + break; + case SpvOpCompositeInsert: + // Nothing specific to do. + break; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpPtrAccessChain: + case SpvOpInBoundsPtrAccessChain: + assert(false && "Not implemented yet."); + break; + default: + break; + } } else if (inst.opcode() == SpvOpVariable) { switch (inst.GetSingleWordInOperand(0)) { case SpvStorageClassInput: @@ -153,13 +169,17 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForExtract( const Instruction* inst) { - assert(inst->opcode() == SpvOpCompositeExtract); + assert(inst->opcode() == SpvOpCompositeExtract || + (inst->opcode() == SpvOpSpecConstantOp && + inst->GetSingleWordInOperand(kSpecConstOpOpcodeIdx) == + SpvOpCompositeExtract)); - uint32_t composite_id = inst->GetSingleWordInOperand(0); + uint32_t first_operand = (inst->opcode() == SpvOpSpecConstantOp ? 1 : 0); + uint32_t composite_id = inst->GetSingleWordInOperand(first_operand); Instruction* composite_inst = get_def_use_mgr()->GetDef(composite_id); uint32_t type_id = composite_inst->type_id(); - for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { + for (uint32_t i = first_operand + 1; i < inst->NumInOperands(); ++i) { Instruction* type_inst = get_def_use_mgr()->GetDef(type_id); uint32_t member_idx = inst->GetSingleWordInOperand(i); switch (type_inst->opcode()) { @@ -295,10 +315,22 @@ modified |= UpdateOpArrayLength(inst); break; case SpvOpSpecConstantOp: - assert(false && "Not yet implemented."); - // with OpCompositeExtract, OpCompositeInsert - // For kernels: OpAccessChain, OpInBoundsAccessChain, OpPtrAccessChain, - // OpInBoundsPtrAccessChain + switch (inst->GetSingleWordInOperand(kSpecConstOpOpcodeIdx)) { + case SpvOpCompositeExtract: + modified |= UpdateCompsiteExtract(inst); + break; + case SpvOpCompositeInsert: + modified |= UpdateCompositeInsert(inst); + break; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpPtrAccessChain: + case SpvOpInBoundsPtrAccessChain: + assert(false && "Not implemented yet."); + break; + default: + break; + } break; default: break; @@ -393,7 +425,8 @@ } bool EliminateDeadMembersPass::UpdateConstantComposite(Instruction* inst) { - assert(inst->opcode() == SpvOpConstantComposite || + assert(inst->opcode() == SpvOpSpecConstantComposite || + inst->opcode() == SpvOpConstantComposite || inst->opcode() == SpvOpCompositeConstruct); uint32_t type_id = inst->type_id(); @@ -506,14 +539,25 @@ } bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) { - uint32_t object_id = inst->GetSingleWordInOperand(0); + assert(inst->opcode() == SpvOpCompositeExtract || + (inst->opcode() == SpvOpSpecConstantOp && + inst->GetSingleWordInOperand(kSpecConstOpOpcodeIdx) == + SpvOpCompositeExtract)); + + uint32_t first_operand = 0; + if (inst->opcode() == SpvOpSpecConstantOp) { + first_operand = 1; + } + uint32_t object_id = inst->GetSingleWordInOperand(first_operand); Instruction* object_inst = get_def_use_mgr()->GetDef(object_id); uint32_t type_id = object_inst->type_id(); Instruction::OperandList new_operands; bool modified = false; - new_operands.emplace_back(inst->GetInOperand(0)); - for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { + for (uint32_t i = 0; i < first_operand + 1; i++) { + new_operands.emplace_back(inst->GetInOperand(i)); + } + for (uint32_t i = first_operand + 1; i < inst->NumInOperands(); ++i) { uint32_t member_idx = inst->GetSingleWordInOperand(i); uint32_t new_member_idx = GetNewMemberIndex(type_id, member_idx); assert(new_member_idx != kRemovedMember); @@ -526,8 +570,6 @@ Instruction* type_inst = get_def_use_mgr()->GetDef(type_id); switch (type_inst->opcode()) { case SpvOpTypeStruct: - assert(i != 1 || (inst->opcode() != SpvOpPtrAccessChain && - inst->opcode() != SpvOpInBoundsPtrAccessChain)); // The type will have already been rewriten, so use the new member // index. type_id = type_inst->GetSingleWordInOperand(new_member_idx); @@ -552,15 +594,27 @@ } bool EliminateDeadMembersPass::UpdateCompositeInsert(Instruction* inst) { - uint32_t composite_id = inst->GetSingleWordInOperand(1); + assert(inst->opcode() == SpvOpCompositeInsert || + (inst->opcode() == SpvOpSpecConstantOp && + inst->GetSingleWordInOperand(kSpecConstOpOpcodeIdx) == + SpvOpCompositeInsert)); + + uint32_t first_operand = 0; + if (inst->opcode() == SpvOpSpecConstantOp) { + first_operand = 1; + } + + uint32_t composite_id = inst->GetSingleWordInOperand(first_operand + 1); Instruction* composite_inst = get_def_use_mgr()->GetDef(composite_id); uint32_t type_id = composite_inst->type_id(); Instruction::OperandList new_operands; bool modified = false; - new_operands.emplace_back(inst->GetInOperand(0)); - new_operands.emplace_back(inst->GetInOperand(1)); - for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { + + for (uint32_t i = 0; i < first_operand + 2; ++i) { + new_operands.emplace_back(inst->GetInOperand(i)); + } + for (uint32_t i = first_operand + 2; i < inst->NumInOperands(); ++i) { uint32_t member_idx = inst->GetSingleWordInOperand(i); uint32_t new_member_idx = GetNewMemberIndex(type_id, member_idx); if (new_member_idx == kRemovedMember) {
diff --git a/source/opt/feature_manager.cpp b/source/opt/feature_manager.cpp index b4d6f1b..ad70c1e 100644 --- a/source/opt/feature_manager.cpp +++ b/source/opt/feature_manager.cpp
@@ -78,6 +78,8 @@ void FeatureManager::AddExtInstImportIds(Module* module) { extinst_importid_GLSLstd450_ = module->GetExtInstImportId("GLSL.std.450"); + extinst_importid_OpenCL100DebugInfo_ = + module->GetExtInstImportId("OpenCL.DebugInfo.100"); } bool operator==(const FeatureManager& a, const FeatureManager& b) { @@ -100,6 +102,11 @@ return false; } + if (a.extinst_importid_OpenCL100DebugInfo_ != + b.extinst_importid_OpenCL100DebugInfo_) { + return false; + } + return true; } } // namespace opt
diff --git a/source/opt/feature_manager.h b/source/opt/feature_manager.h index 881d5e6..66d1cba 100644 --- a/source/opt/feature_manager.h +++ b/source/opt/feature_manager.h
@@ -51,6 +51,10 @@ return extinst_importid_GLSLstd450_; } + uint32_t GetExtInstImportId_OpenCL100DebugInfo() const { + return extinst_importid_OpenCL100DebugInfo_; + } + friend bool operator==(const FeatureManager& a, const FeatureManager& b); friend bool operator!=(const FeatureManager& a, const FeatureManager& b) { return !(a == b); @@ -84,6 +88,10 @@ // Common external instruction import ids, cached for performance. uint32_t extinst_importid_GLSLstd450_ = 0; + + // Common OpenCL100DebugInfo external instruction import ids, cached + // for performance. + uint32_t extinst_importid_OpenCL100DebugInfo_ = 0; }; } // namespace opt
diff --git a/source/opt/function.cpp b/source/opt/function.cpp index 5d50f37..320f8ca 100644 --- a/source/opt/function.cpp +++ b/source/opt/function.cpp
@@ -84,9 +84,12 @@ } } - for (auto& di : debug_insts_in_header_) { - if (!di.WhileEachInst(f, run_on_debug_line_insts)) { - return false; + if (!debug_insts_in_header_.empty()) { + Instruction* di = &debug_insts_in_header_.front(); + while (di != nullptr) { + Instruction* next_instruction = di->NextNode(); + if (!di->WhileEachInst(f, run_on_debug_line_insts)) return false; + di = next_instruction; } } @@ -118,9 +121,9 @@ } for (const auto& di : debug_insts_in_header_) { - if (!di.WhileEachInst(f, run_on_debug_line_insts)) { + if (!static_cast<const Instruction*>(&di)->WhileEachInst( + f, run_on_debug_line_insts)) return false; - } } for (const auto& bb : blocks_) { @@ -151,6 +154,18 @@ ->ForEachInst(f, run_on_debug_line_insts); } +void Function::ForEachDebugInstructionsInHeader( + const std::function<void(Instruction*)>& f) { + if (debug_insts_in_header_.empty()) return; + + Instruction* di = &debug_insts_in_header_.front(); + while (di != nullptr) { + Instruction* next_instruction = di->NextNode(); + di->ForEachInst(f); + di = next_instruction; + } +} + BasicBlock* Function::InsertBasicBlockAfter( std::unique_ptr<BasicBlock>&& new_block, BasicBlock* position) { for (auto bb_iter = begin(); bb_iter != end(); ++bb_iter) {
diff --git a/source/opt/function.h b/source/opt/function.h index f208d8e..d569bf9 100644 --- a/source/opt/function.h +++ b/source/opt/function.h
@@ -88,6 +88,10 @@ // Returns the entry basic block for this function. const std::unique_ptr<BasicBlock>& entry() const { return blocks_.front(); } + // Returns the last basic block in this function. + BasicBlock* tail() { return blocks_.back().get(); } + const BasicBlock* tail() const { return blocks_.back().get(); } + iterator begin() { return iterator(&blocks_, blocks_.begin()); } iterator end() { return iterator(&blocks_, blocks_.end()); } const_iterator begin() const { return cbegin(); } @@ -129,6 +133,11 @@ void ForEachParam(const std::function<void(Instruction*)>& f, bool run_on_debug_line_insts = false); + // Runs the given function |f| on each debug instruction in this function's + // header in order. + void ForEachDebugInstructionsInHeader( + const std::function<void(Instruction*)>& f); + BasicBlock* InsertBasicBlockAfter(std::unique_ptr<BasicBlock>&& new_block, BasicBlock* position); @@ -192,13 +201,13 @@ } inline void Function::MoveBasicBlockToAfter(uint32_t id, BasicBlock* ip) { - auto block_to_move = std::move(*FindBlock(id).Get()); + std::unique_ptr<BasicBlock> block_to_move = std::move(*FindBlock(id).Get()); + blocks_.erase(std::find(std::begin(blocks_), std::end(blocks_), nullptr)); assert(block_to_move->GetParent() == ip->GetParent() && "Both blocks have to be in the same function."); InsertBasicBlockAfter(std::move(block_to_move), ip); - blocks_.erase(std::find(std::begin(blocks_), std::end(blocks_), nullptr)); } inline void Function::RemoveEmptyBlocks() {
diff --git a/source/opt/graphics_robust_access_pass.cpp b/source/opt/graphics_robust_access_pass.cpp index 22c979c..db14020 100644 --- a/source/opt/graphics_robust_access_pass.cpp +++ b/source/opt/graphics_robust_access_pass.cpp
@@ -802,8 +802,11 @@ opt::Instruction* image_texel_pointer) { // TODO(dneto): Write tests for this code. // TODO(dneto): Use signed-clamp + (void)(image_texel_pointer); return SPV_SUCCESS; + // Do not compile this code until it is ready to be used. +#if 0 // Example: // %texel_ptr = OpImageTexelPointer %texel_ptr_type %image_ptr %coord // %sample @@ -1035,6 +1038,7 @@ def_use_mgr->AnalyzeInstUse(image_texel_pointer); return SPV_SUCCESS; +#endif } opt::Instruction* GraphicsRobustAccessPass::InsertInst(
diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp index 3c874a7..cb5a126 100644 --- a/source/opt/inline_pass.cpp +++ b/source/opt/inline_pass.cpp
@@ -20,6 +20,7 @@ #include <utility> #include "source/cfa.h" +#include "source/opt/reflect.h" #include "source/util/make_unique.h" // Indices of operands in SPIR-V instructions @@ -83,19 +84,31 @@ } void InlinePass::AddStore(uint32_t ptr_id, uint32_t val_id, - std::unique_ptr<BasicBlock>* block_ptr) { + std::unique_ptr<BasicBlock>* block_ptr, + const Instruction* line_inst, + const DebugScope& dbg_scope) { std::unique_ptr<Instruction> newStore( new Instruction(context(), SpvOpStore, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {val_id}}})); + if (line_inst != nullptr) { + newStore->dbg_line_insts().push_back(*line_inst); + } + newStore->SetDebugScope(dbg_scope); (*block_ptr)->AddInstruction(std::move(newStore)); } void InlinePass::AddLoad(uint32_t type_id, uint32_t resultId, uint32_t ptr_id, - std::unique_ptr<BasicBlock>* block_ptr) { + std::unique_ptr<BasicBlock>* block_ptr, + const Instruction* line_inst, + const DebugScope& dbg_scope) { std::unique_ptr<Instruction> newLoad( new Instruction(context(), SpvOpLoad, type_id, resultId, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}})); + if (line_inst != nullptr) { + newLoad->dbg_line_insts().push_back(*line_inst); + } + newLoad->SetDebugScope(dbg_scope); (*block_ptr)->AddInstruction(std::move(newLoad)); } @@ -140,10 +153,18 @@ bool InlinePass::CloneAndMapLocals( Function* calleeFn, std::vector<std::unique_ptr<Instruction>>* new_vars, - std::unordered_map<uint32_t, uint32_t>* callee2caller) { + std::unordered_map<uint32_t, uint32_t>* callee2caller, + analysis::DebugInlinedAtContext* inlined_at_ctx) { auto callee_block_itr = calleeFn->begin(); auto callee_var_itr = callee_block_itr->begin(); - while (callee_var_itr->opcode() == SpvOp::SpvOpVariable) { + while (callee_var_itr->opcode() == SpvOp::SpvOpVariable || + callee_var_itr->GetOpenCL100DebugOpcode() == + OpenCLDebugInfo100DebugDeclare) { + if (callee_var_itr->opcode() != SpvOp::SpvOpVariable) { + ++callee_var_itr; + continue; + } + std::unique_ptr<Instruction> var_inst(callee_var_itr->Clone(context())); uint32_t newId = context()->TakeNextId(); if (newId == 0) { @@ -151,6 +172,9 @@ } get_decoration_mgr()->CloneDecorations(callee_var_itr->result_id(), newId); var_inst->SetResultId(newId); + var_inst->UpdateDebugInlinedAt( + context()->get_debug_info_mgr()->BuildDebugInlinedAtChain( + callee_var_itr->GetDebugInlinedAt(), inlined_at_ctx)); (*callee2caller)[callee_var_itr->result_id()] = newId; new_vars->push_back(std::move(var_inst)); ++callee_var_itr; @@ -232,6 +256,248 @@ }); } +void InlinePass::MoveInstsBeforeEntryBlock( + std::unordered_map<uint32_t, Instruction*>* preCallSB, + BasicBlock* new_blk_ptr, BasicBlock::iterator call_inst_itr, + UptrVectorIterator<BasicBlock> call_block_itr) { + for (auto cii = call_block_itr->begin(); cii != call_inst_itr; + cii = call_block_itr->begin()) { + Instruction* inst = &*cii; + inst->RemoveFromList(); + std::unique_ptr<Instruction> cp_inst(inst); + // Remember same-block ops for possible regeneration. + if (IsSameBlockOp(&*cp_inst)) { + auto* sb_inst_ptr = cp_inst.get(); + (*preCallSB)[cp_inst->result_id()] = sb_inst_ptr; + } + new_blk_ptr->AddInstruction(std::move(cp_inst)); + } +} + +std::unique_ptr<BasicBlock> InlinePass::AddGuardBlock( + std::vector<std::unique_ptr<BasicBlock>>* new_blocks, + std::unordered_map<uint32_t, uint32_t>* callee2caller, + std::unique_ptr<BasicBlock> new_blk_ptr, uint32_t entry_blk_label_id) { + const auto guard_block_id = context()->TakeNextId(); + if (guard_block_id == 0) { + return nullptr; + } + AddBranch(guard_block_id, &new_blk_ptr); + new_blocks->push_back(std::move(new_blk_ptr)); + // Start the next block. + new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(guard_block_id)); + // Reset the mapping of the callee's entry block to point to + // the guard block. Do this so we can fix up phis later on to + // satisfy dominance. + (*callee2caller)[entry_blk_label_id] = guard_block_id; + return new_blk_ptr; +} + +InstructionList::iterator InlinePass::AddStoresForVariableInitializers( + const std::unordered_map<uint32_t, uint32_t>& callee2caller, + analysis::DebugInlinedAtContext* inlined_at_ctx, + std::unique_ptr<BasicBlock>* new_blk_ptr, + UptrVectorIterator<BasicBlock> callee_first_block_itr) { + auto callee_itr = callee_first_block_itr->begin(); + while (callee_itr->opcode() == SpvOp::SpvOpVariable || + callee_itr->GetOpenCL100DebugOpcode() == + OpenCLDebugInfo100DebugDeclare) { + if (callee_itr->opcode() == SpvOp::SpvOpVariable && + callee_itr->NumInOperands() == 2) { + assert(callee2caller.count(callee_itr->result_id()) && + "Expected the variable to have already been mapped."); + uint32_t new_var_id = callee2caller.at(callee_itr->result_id()); + + // The initializer must be a constant or global value. No mapped + // should be used. + uint32_t val_id = callee_itr->GetSingleWordInOperand(1); + AddStore(new_var_id, val_id, new_blk_ptr, callee_itr->dbg_line_inst(), + context()->get_debug_info_mgr()->BuildDebugScope( + callee_itr->GetDebugScope(), inlined_at_ctx)); + } + if (callee_itr->GetOpenCL100DebugOpcode() == + OpenCLDebugInfo100DebugDeclare) { + InlineSingleInstruction( + callee2caller, new_blk_ptr->get(), &*callee_itr, + context()->get_debug_info_mgr()->BuildDebugInlinedAtChain( + callee_itr->GetDebugScope().GetInlinedAt(), inlined_at_ctx)); + } + ++callee_itr; + } + return callee_itr; +} + +bool InlinePass::InlineSingleInstruction( + const std::unordered_map<uint32_t, uint32_t>& callee2caller, + BasicBlock* new_blk_ptr, const Instruction* inst, uint32_t dbg_inlined_at) { + // If we have return, it must be at the end of the callee. We will handle + // it at the end. + if (inst->opcode() == SpvOpReturnValue || inst->opcode() == SpvOpReturn) + return true; + + // Copy callee instruction and remap all input Ids. + std::unique_ptr<Instruction> cp_inst(inst->Clone(context())); + cp_inst->ForEachInId([&callee2caller](uint32_t* iid) { + const auto mapItr = callee2caller.find(*iid); + if (mapItr != callee2caller.end()) { + *iid = mapItr->second; + } + }); + + // If result id is non-zero, remap it. + const uint32_t rid = cp_inst->result_id(); + if (rid != 0) { + const auto mapItr = callee2caller.find(rid); + if (mapItr == callee2caller.end()) { + return false; + } + uint32_t nid = mapItr->second; + cp_inst->SetResultId(nid); + get_decoration_mgr()->CloneDecorations(rid, nid); + } + + cp_inst->UpdateDebugInlinedAt(dbg_inlined_at); + new_blk_ptr->AddInstruction(std::move(cp_inst)); + return true; +} + +std::unique_ptr<BasicBlock> InlinePass::InlineReturn( + const std::unordered_map<uint32_t, uint32_t>& callee2caller, + std::vector<std::unique_ptr<BasicBlock>>* new_blocks, + std::unique_ptr<BasicBlock> new_blk_ptr, + analysis::DebugInlinedAtContext* inlined_at_ctx, Function* calleeFn, + const Instruction* inst, uint32_t returnVarId) { + // Store return value to return variable. + if (inst->opcode() == SpvOpReturnValue) { + assert(returnVarId != 0); + uint32_t valId = inst->GetInOperand(kSpvReturnValueId).words[0]; + const auto mapItr = callee2caller.find(valId); + if (mapItr != callee2caller.end()) { + valId = mapItr->second; + } + AddStore(returnVarId, valId, &new_blk_ptr, inst->dbg_line_inst(), + context()->get_debug_info_mgr()->BuildDebugScope( + inst->GetDebugScope(), inlined_at_ctx)); + } + + uint32_t returnLabelId = 0; + for (auto callee_block_itr = calleeFn->begin(); + callee_block_itr != calleeFn->end(); ++callee_block_itr) { + if (callee_block_itr->tail()->opcode() == SpvOpUnreachable || + callee_block_itr->tail()->opcode() == SpvOpKill) { + returnLabelId = context()->TakeNextId(); + break; + } + } + if (returnLabelId == 0) return new_blk_ptr; + + if (inst->opcode() == SpvOpReturn || inst->opcode() == SpvOpReturnValue) + AddBranch(returnLabelId, &new_blk_ptr); + new_blocks->push_back(std::move(new_blk_ptr)); + return MakeUnique<BasicBlock>(NewLabel(returnLabelId)); +} + +bool InlinePass::InlineEntryBlock( + const std::unordered_map<uint32_t, uint32_t>& callee2caller, + std::unique_ptr<BasicBlock>* new_blk_ptr, + UptrVectorIterator<BasicBlock> callee_first_block, + analysis::DebugInlinedAtContext* inlined_at_ctx) { + auto callee_inst_itr = AddStoresForVariableInitializers( + callee2caller, inlined_at_ctx, new_blk_ptr, callee_first_block); + + while (callee_inst_itr != callee_first_block->end()) { + if (!InlineSingleInstruction( + callee2caller, new_blk_ptr->get(), &*callee_inst_itr, + context()->get_debug_info_mgr()->BuildDebugInlinedAtChain( + callee_inst_itr->GetDebugScope().GetInlinedAt(), + inlined_at_ctx))) { + return false; + } + ++callee_inst_itr; + } + return true; +} + +std::unique_ptr<BasicBlock> InlinePass::InlineBasicBlocks( + std::vector<std::unique_ptr<BasicBlock>>* new_blocks, + const std::unordered_map<uint32_t, uint32_t>& callee2caller, + std::unique_ptr<BasicBlock> new_blk_ptr, + analysis::DebugInlinedAtContext* inlined_at_ctx, Function* calleeFn) { + auto callee_block_itr = calleeFn->begin(); + ++callee_block_itr; + + while (callee_block_itr != calleeFn->end()) { + new_blocks->push_back(std::move(new_blk_ptr)); + const auto mapItr = + callee2caller.find(callee_block_itr->GetLabelInst()->result_id()); + if (mapItr == callee2caller.end()) return nullptr; + new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(mapItr->second)); + + auto tail_inst_itr = callee_block_itr->end(); + for (auto inst_itr = callee_block_itr->begin(); inst_itr != tail_inst_itr; + ++inst_itr) { + if (!InlineSingleInstruction( + callee2caller, new_blk_ptr.get(), &*inst_itr, + context()->get_debug_info_mgr()->BuildDebugInlinedAtChain( + inst_itr->GetDebugScope().GetInlinedAt(), inlined_at_ctx))) { + return nullptr; + } + } + + ++callee_block_itr; + } + return new_blk_ptr; +} + +bool InlinePass::MoveCallerInstsAfterFunctionCall( + std::unordered_map<uint32_t, Instruction*>* preCallSB, + std::unordered_map<uint32_t, uint32_t>* postCallSB, + std::unique_ptr<BasicBlock>* new_blk_ptr, + BasicBlock::iterator call_inst_itr, bool multiBlocks) { + // Copy remaining instructions from caller block. + for (Instruction* inst = call_inst_itr->NextNode(); inst; + inst = call_inst_itr->NextNode()) { + inst->RemoveFromList(); + std::unique_ptr<Instruction> cp_inst(inst); + // If multiple blocks generated, regenerate any same-block + // instruction that has not been seen in this last block. + if (multiBlocks) { + if (!CloneSameBlockOps(&cp_inst, postCallSB, preCallSB, new_blk_ptr)) { + return false; + } + + // Remember same-block ops in this block. + if (IsSameBlockOp(&*cp_inst)) { + const uint32_t rid = cp_inst->result_id(); + (*postCallSB)[rid] = rid; + } + } + new_blk_ptr->get()->AddInstruction(std::move(cp_inst)); + } + + return true; +} + +void InlinePass::MoveLoopMergeInstToFirstBlock( + std::vector<std::unique_ptr<BasicBlock>>* new_blocks) { + // Move the OpLoopMerge from the last block back to the first, where + // it belongs. + auto& first = new_blocks->front(); + auto& last = new_blocks->back(); + assert(first != last); + + // Insert a modified copy of the loop merge into the first block. + auto loop_merge_itr = last->tail(); + --loop_merge_itr; + assert(loop_merge_itr->opcode() == SpvOpLoopMerge); + std::unique_ptr<Instruction> cp_inst(loop_merge_itr->Clone(context())); + first->tail().InsertBefore(std::move(cp_inst)); + + // Remove the loop merge from the last block. + loop_merge_itr->RemoveFromList(); + delete &*loop_merge_itr; +} + bool InlinePass::GenInlineCode( std::vector<std::unique_ptr<BasicBlock>>* new_blocks, std::vector<std::unique_ptr<Instruction>>* new_vars, @@ -245,27 +511,60 @@ // Post-call same-block op ids std::unordered_map<uint32_t, uint32_t> postCallSB; + analysis::DebugInlinedAtContext inlined_at_ctx(&*call_inst_itr); + // Invalidate the def-use chains. They are not kept up to date while // inlining. However, certain calls try to keep them up-to-date if they are // valid. These operations can fail. context()->InvalidateAnalyses(IRContext::kAnalysisDefUse); + // If the caller is a loop header and the callee has multiple blocks, then the + // normal inlining logic will place the OpLoopMerge in the last of several + // blocks in the loop. Instead, it should be placed at the end of the first + // block. We'll wait to move the OpLoopMerge until the end of the regular + // inlining logic, and only if necessary. + bool caller_is_loop_header = call_block_itr->GetLoopMergeInst() != nullptr; + + // Single-trip loop continue block + std::unique_ptr<BasicBlock> single_trip_loop_cont_blk; + Function* calleeFn = id2function_[call_inst_itr->GetSingleWordOperand( kSpvFunctionCallFunctionId)]; - // Check for multiple returns in the callee. - auto fi = early_return_funcs_.find(calleeFn->result_id()); - const bool earlyReturn = fi != early_return_funcs_.end(); - // Map parameters to actual arguments. MapParams(calleeFn, call_inst_itr, &callee2caller); // Define caller local variables for all callee variables and create map to // them. - if (!CloneAndMapLocals(calleeFn, new_vars, &callee2caller)) { + if (!CloneAndMapLocals(calleeFn, new_vars, &callee2caller, &inlined_at_ctx)) { return false; } + // First block needs to use label of original block + // but map callee label in case of phi reference. + uint32_t entry_blk_label_id = calleeFn->begin()->GetLabelInst()->result_id(); + callee2caller[entry_blk_label_id] = call_block_itr->id(); + std::unique_ptr<BasicBlock> new_blk_ptr = + MakeUnique<BasicBlock>(NewLabel(call_block_itr->id())); + + // Move instructions of original caller block up to call instruction. + MoveInstsBeforeEntryBlock(&preCallSB, new_blk_ptr.get(), call_inst_itr, + call_block_itr); + + if (caller_is_loop_header && + (*(calleeFn->begin())).GetMergeInst() != nullptr) { + // We can't place both the caller's merge instruction and + // another merge instruction in the same block. So split the + // calling block. Insert an unconditional branch to a new guard + // block. Later, once we know the ID of the last block, we + // will move the caller's OpLoopMerge from the last generated + // block into the first block. We also wait to avoid + // invalidating various iterators. + new_blk_ptr = AddGuardBlock(new_blocks, &callee2caller, + std::move(new_blk_ptr), entry_blk_label_id); + if (new_blk_ptr == nullptr) return false; + } + // Create return var if needed. const uint32_t calleeTypeId = calleeFn->type_id(); uint32_t returnVarId = 0; @@ -277,341 +576,62 @@ } } - // Create set of callee result ids. Used to detect forward references - std::unordered_set<uint32_t> callee_result_ids; - calleeFn->ForEachInst([&callee_result_ids](const Instruction* cpi) { + calleeFn->WhileEachInst([&callee2caller, this](const Instruction* cpi) { + // Create set of callee result ids. Used to detect forward references const uint32_t rid = cpi->result_id(); - if (rid != 0) callee_result_ids.insert(rid); + if (rid != 0 && callee2caller.find(rid) == callee2caller.end()) { + const uint32_t nid = context()->TakeNextId(); + if (nid == 0) return false; + callee2caller[rid] = nid; + } + return true; }); - // If the caller is a loop header and the callee has multiple blocks, then the - // normal inlining logic will place the OpLoopMerge in the last of several - // blocks in the loop. Instead, it should be placed at the end of the first - // block. We'll wait to move the OpLoopMerge until the end of the regular - // inlining logic, and only if necessary. - bool caller_is_loop_header = false; - if (call_block_itr->GetLoopMergeInst()) { - caller_is_loop_header = true; - } - - bool callee_begins_with_structured_header = - (*(calleeFn->begin())).GetMergeInst() != nullptr; - - // Clone and map callee code. Copy caller block code to beginning of - // first block and end of last block. - bool prevInstWasReturn = false; - uint32_t singleTripLoopHeaderId = 0; - uint32_t singleTripLoopContinueId = 0; - uint32_t returnLabelId = 0; - bool multiBlocks = false; - // new_blk_ptr is a new basic block in the caller. New instructions are - // written to it. It is created when we encounter the OpLabel - // of the first callee block. It is appended to new_blocks only when - // it is complete. - std::unique_ptr<BasicBlock> new_blk_ptr; - bool successful = calleeFn->WhileEachInst( - [&new_blocks, &callee2caller, &call_block_itr, &call_inst_itr, - &new_blk_ptr, &prevInstWasReturn, &returnLabelId, &returnVarId, - caller_is_loop_header, callee_begins_with_structured_header, - &calleeTypeId, &multiBlocks, &postCallSB, &preCallSB, earlyReturn, - &singleTripLoopHeaderId, &singleTripLoopContinueId, &callee_result_ids, - this](const Instruction* cpi) { - switch (cpi->opcode()) { - case SpvOpFunction: - case SpvOpFunctionParameter: - // Already processed - break; - case SpvOpVariable: - if (cpi->NumInOperands() == 2) { - assert(callee2caller.count(cpi->result_id()) && - "Expected the variable to have already been mapped."); - uint32_t new_var_id = callee2caller.at(cpi->result_id()); - - // The initializer must be a constant or global value. No mapped - // should be used. - uint32_t val_id = cpi->GetSingleWordInOperand(1); - AddStore(new_var_id, val_id, &new_blk_ptr); - } - break; - case SpvOpUnreachable: - case SpvOpKill: { - // Generate a return label so that we split the block with the - // function call. Copy the terminator into the new block. - if (returnLabelId == 0) { - returnLabelId = context()->TakeNextId(); - if (returnLabelId == 0) { - return false; - } - } - std::unique_ptr<Instruction> terminator( - new Instruction(context(), cpi->opcode(), 0, 0, {})); - new_blk_ptr->AddInstruction(std::move(terminator)); - break; - } - case SpvOpLabel: { - // If previous instruction was early return, insert branch - // instruction to return block. - if (prevInstWasReturn) { - if (returnLabelId == 0) { - returnLabelId = context()->TakeNextId(); - if (returnLabelId == 0) { - return false; - } - } - AddBranch(returnLabelId, &new_blk_ptr); - prevInstWasReturn = false; - } - // Finish current block (if it exists) and get label for next block. - uint32_t labelId; - bool firstBlock = false; - if (new_blk_ptr != nullptr) { - new_blocks->push_back(std::move(new_blk_ptr)); - // If result id is already mapped, use it, otherwise get a new - // one. - const uint32_t rid = cpi->result_id(); - const auto mapItr = callee2caller.find(rid); - labelId = (mapItr != callee2caller.end()) - ? mapItr->second - : context()->TakeNextId(); - if (labelId == 0) { - return false; - } - } else { - // First block needs to use label of original block - // but map callee label in case of phi reference. - labelId = call_block_itr->id(); - callee2caller[cpi->result_id()] = labelId; - firstBlock = true; - } - // Create first/next block. - new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(labelId)); - if (firstBlock) { - // Copy contents of original caller block up to call instruction. - for (auto cii = call_block_itr->begin(); cii != call_inst_itr; - cii = call_block_itr->begin()) { - Instruction* inst = &*cii; - inst->RemoveFromList(); - std::unique_ptr<Instruction> cp_inst(inst); - // Remember same-block ops for possible regeneration. - if (IsSameBlockOp(&*cp_inst)) { - auto* sb_inst_ptr = cp_inst.get(); - preCallSB[cp_inst->result_id()] = sb_inst_ptr; - } - new_blk_ptr->AddInstruction(std::move(cp_inst)); - } - if (caller_is_loop_header && - callee_begins_with_structured_header) { - // We can't place both the caller's merge instruction and - // another merge instruction in the same block. So split the - // calling block. Insert an unconditional branch to a new guard - // block. Later, once we know the ID of the last block, we - // will move the caller's OpLoopMerge from the last generated - // block into the first block. We also wait to avoid - // invalidating various iterators. - const auto guard_block_id = context()->TakeNextId(); - if (guard_block_id == 0) { - return false; - } - AddBranch(guard_block_id, &new_blk_ptr); - new_blocks->push_back(std::move(new_blk_ptr)); - // Start the next block. - new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(guard_block_id)); - // Reset the mapping of the callee's entry block to point to - // the guard block. Do this so we can fix up phis later on to - // satisfy dominance. - callee2caller[cpi->result_id()] = guard_block_id; - } - // If callee has early return, insert a header block for - // single-trip loop that will encompass callee code. Start - // postheader block. - // - // Note: Consider the following combination: - // - the caller is a single block loop - // - the callee does not begin with a structure header - // - the callee has multiple returns. - // We still need to split the caller block and insert a guard - // block. But we only need to do it once. We haven't done it yet, - // but the single-trip loop header will serve the same purpose. - if (earlyReturn) { - singleTripLoopHeaderId = context()->TakeNextId(); - if (singleTripLoopHeaderId == 0) { - return false; - } - AddBranch(singleTripLoopHeaderId, &new_blk_ptr); - new_blocks->push_back(std::move(new_blk_ptr)); - new_blk_ptr = - MakeUnique<BasicBlock>(NewLabel(singleTripLoopHeaderId)); - returnLabelId = context()->TakeNextId(); - singleTripLoopContinueId = context()->TakeNextId(); - if (returnLabelId == 0 || singleTripLoopContinueId == 0) { - return false; - } - AddLoopMerge(returnLabelId, singleTripLoopContinueId, - &new_blk_ptr); - uint32_t postHeaderId = context()->TakeNextId(); - if (postHeaderId == 0) { - return false; - } - AddBranch(postHeaderId, &new_blk_ptr); - new_blocks->push_back(std::move(new_blk_ptr)); - new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(postHeaderId)); - multiBlocks = true; - // Reset the mapping of the callee's entry block to point to - // the post-header block. Do this so we can fix up phis later - // on to satisfy dominance. - callee2caller[cpi->result_id()] = postHeaderId; - } - } else { - multiBlocks = true; - } - } break; - case SpvOpReturnValue: { - // Store return value to return variable. - assert(returnVarId != 0); - uint32_t valId = cpi->GetInOperand(kSpvReturnValueId).words[0]; - const auto mapItr = callee2caller.find(valId); - if (mapItr != callee2caller.end()) { - valId = mapItr->second; - } - AddStore(returnVarId, valId, &new_blk_ptr); - - // Remember we saw a return; if followed by a label, will need to - // insert branch. - prevInstWasReturn = true; - } break; - case SpvOpReturn: { - // Remember we saw a return; if followed by a label, will need to - // insert branch. - prevInstWasReturn = true; - } break; - case SpvOpFunctionEnd: { - // If there was an early return, we generated a return label id - // for it. Now we have to generate the return block with that Id. - if (returnLabelId != 0) { - // If previous instruction was return, insert branch instruction - // to return block. - if (prevInstWasReturn) AddBranch(returnLabelId, &new_blk_ptr); - if (earlyReturn) { - // If we generated a loop header for the single-trip loop - // to accommodate early returns, insert the continue - // target block now, with a false branch back to the loop - // header. - new_blocks->push_back(std::move(new_blk_ptr)); - new_blk_ptr = - MakeUnique<BasicBlock>(NewLabel(singleTripLoopContinueId)); - uint32_t false_id = GetFalseId(); - if (false_id == 0) { - return false; - } - AddBranchCond(false_id, singleTripLoopHeaderId, returnLabelId, - &new_blk_ptr); - } - // Generate the return block. - new_blocks->push_back(std::move(new_blk_ptr)); - new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(returnLabelId)); - multiBlocks = true; - } - // Load return value into result id of call, if it exists. - if (returnVarId != 0) { - const uint32_t resId = call_inst_itr->result_id(); - assert(resId != 0); - AddLoad(calleeTypeId, resId, returnVarId, &new_blk_ptr); - } - // Copy remaining instructions from caller block. - for (Instruction* inst = call_inst_itr->NextNode(); inst; - inst = call_inst_itr->NextNode()) { - inst->RemoveFromList(); - std::unique_ptr<Instruction> cp_inst(inst); - // If multiple blocks generated, regenerate any same-block - // instruction that has not been seen in this last block. - if (multiBlocks) { - if (!CloneSameBlockOps(&cp_inst, &postCallSB, &preCallSB, - &new_blk_ptr)) { - return false; - } - - // Remember same-block ops in this block. - if (IsSameBlockOp(&*cp_inst)) { - const uint32_t rid = cp_inst->result_id(); - postCallSB[rid] = rid; - } - } - new_blk_ptr->AddInstruction(std::move(cp_inst)); - } - // Finalize inline code. - new_blocks->push_back(std::move(new_blk_ptr)); - } break; - default: { - // Copy callee instruction and remap all input Ids. - std::unique_ptr<Instruction> cp_inst(cpi->Clone(context())); - bool succeeded = cp_inst->WhileEachInId( - [&callee2caller, &callee_result_ids, this](uint32_t* iid) { - const auto mapItr = callee2caller.find(*iid); - if (mapItr != callee2caller.end()) { - *iid = mapItr->second; - } else if (callee_result_ids.find(*iid) != - callee_result_ids.end()) { - // Forward reference. Allocate a new id, map it, - // use it and check for it when remapping result ids - const uint32_t nid = context()->TakeNextId(); - if (nid == 0) { - return false; - } - callee2caller[*iid] = nid; - *iid = nid; - } - return true; - }); - if (!succeeded) { - return false; - } - // If result id is non-zero, remap it. If already mapped, use mapped - // value, else use next id. - const uint32_t rid = cp_inst->result_id(); - if (rid != 0) { - const auto mapItr = callee2caller.find(rid); - uint32_t nid; - if (mapItr != callee2caller.end()) { - nid = mapItr->second; - } else { - nid = context()->TakeNextId(); - if (nid == 0) { - return false; - } - callee2caller[rid] = nid; - } - cp_inst->SetResultId(nid); - get_decoration_mgr()->CloneDecorations(rid, nid); - } - new_blk_ptr->AddInstruction(std::move(cp_inst)); - } break; - } - return true; + // Inline DebugClare instructions in the callee's header. + calleeFn->ForEachDebugInstructionsInHeader( + [&new_blk_ptr, &callee2caller, &inlined_at_ctx, this](Instruction* inst) { + InlineSingleInstruction( + callee2caller, new_blk_ptr.get(), inst, + context()->get_debug_info_mgr()->BuildDebugInlinedAtChain( + inst->GetDebugScope().GetInlinedAt(), &inlined_at_ctx)); }); - if (!successful) { + // Inline the entry block of the callee function. + if (!InlineEntryBlock(callee2caller, &new_blk_ptr, calleeFn->begin(), + &inlined_at_ctx)) { return false; } - if (caller_is_loop_header && (new_blocks->size() > 1)) { - // Move the OpLoopMerge from the last block back to the first, where - // it belongs. - auto& first = new_blocks->front(); - auto& last = new_blocks->back(); - assert(first != last); + // Inline blocks of the callee function other than the entry block. + new_blk_ptr = + InlineBasicBlocks(new_blocks, callee2caller, std::move(new_blk_ptr), + &inlined_at_ctx, calleeFn); + if (new_blk_ptr == nullptr) return false; - // Insert a modified copy of the loop merge into the first block. - auto loop_merge_itr = last->tail(); - --loop_merge_itr; - assert(loop_merge_itr->opcode() == SpvOpLoopMerge); - std::unique_ptr<Instruction> cp_inst(loop_merge_itr->Clone(context())); - first->tail().InsertBefore(std::move(cp_inst)); + new_blk_ptr = InlineReturn(callee2caller, new_blocks, std::move(new_blk_ptr), + &inlined_at_ctx, calleeFn, + &*(calleeFn->tail()->tail()), returnVarId); - // Remove the loop merge from the last block. - loop_merge_itr->RemoveFromList(); - delete &*loop_merge_itr; + // Load return value into result id of call, if it exists. + if (returnVarId != 0) { + const uint32_t resId = call_inst_itr->result_id(); + assert(resId != 0); + AddLoad(calleeTypeId, resId, returnVarId, &new_blk_ptr, + call_inst_itr->dbg_line_inst(), call_inst_itr->GetDebugScope()); } + // Move instructions of original caller block after call instruction. + if (!MoveCallerInstsAfterFunctionCall(&preCallSB, &postCallSB, &new_blk_ptr, + call_inst_itr, + calleeFn->begin() != calleeFn->end())) + return false; + + // Finalize inline code. + new_blocks->push_back(std::move(new_blk_ptr)); + + if (caller_is_loop_header && (new_blocks->size() > 1)) + MoveLoopMergeInstToFirstBlock(new_blocks); + // Update block map given replacement blocks. for (auto& blk : *new_blocks) { id2block_[blk->id()] = &*blk; @@ -624,7 +644,21 @@ const uint32_t calleeFnId = inst->GetSingleWordOperand(kSpvFunctionCallFunctionId); const auto ci = inlinable_.find(calleeFnId); - return ci != inlinable_.cend(); + if (ci == inlinable_.cend()) return false; + + if (early_return_funcs_.find(calleeFnId) != early_return_funcs_.end()) { + // We rely on the merge-return pass to handle the early return case + // in advance. + std::string message = + "The function '" + id2function_[calleeFnId]->DefInst().PrettyPrint() + + "' could not be inlined because the return instruction " + "is not at the end of the function. This could be fixed by " + "running merge-return before inlining."; + consumer()(SPV_MSG_WARNING, "", {0, 0, 0}, message.c_str()); + return false; + } + + return true; } void InlinePass::UpdateSucceedingPhis( @@ -645,26 +679,6 @@ }); } -bool InlinePass::HasNoReturnInStructuredConstruct(Function* func) { - // If control not structured, do not do loop/return analysis - // TODO: Analyze returns in non-structured control flow - if (!context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) - return false; - const auto structured_analysis = context()->GetStructuredCFGAnalysis(); - // Search for returns in structured construct. - bool return_in_construct = false; - for (auto& blk : *func) { - auto terminal_ii = blk.cend(); - --terminal_ii; - if (spvOpcodeIsReturn(terminal_ii->opcode()) && - structured_analysis->ContainingConstruct(blk.id()) != 0) { - return_in_construct = true; - break; - } - } - return !return_in_construct; -} - bool InlinePass::HasNoReturnInLoop(Function* func) { // If control not structured, do not do loop/return analysis // TODO: Analyze returns in non-structured control flow @@ -686,10 +700,18 @@ } void InlinePass::AnalyzeReturns(Function* func) { + // Analyze functions without a return in loop. if (HasNoReturnInLoop(func)) { no_return_in_loop_.insert(func->result_id()); - if (!HasNoReturnInStructuredConstruct(func)) + } + // Analyze functions with a return before its tail basic block. + for (auto& blk : *func) { + auto terminal_ii = blk.cend(); + --terminal_ii; + if (spvOpcodeIsReturn(terminal_ii->opcode()) && &blk != func->tail()) { early_return_funcs_.insert(func->result_id()); + break; + } } }
diff --git a/source/opt/inline_pass.h b/source/opt/inline_pass.h index bc5f781..202bc97 100644 --- a/source/opt/inline_pass.h +++ b/source/opt/inline_pass.h
@@ -24,6 +24,7 @@ #include <unordered_map> #include <vector> +#include "source/opt/debug_info_manager.h" #include "source/opt/decoration_manager.h" #include "source/opt/module.h" #include "source/opt/pass.h" @@ -58,11 +59,13 @@ // Add store of valId to ptrId to end of block block_ptr. void AddStore(uint32_t ptrId, uint32_t valId, - std::unique_ptr<BasicBlock>* block_ptr); + std::unique_ptr<BasicBlock>* block_ptr, + const Instruction* line_inst, const DebugScope& dbg_scope); // Add load of ptrId into resultId to end of block block_ptr. void AddLoad(uint32_t typeId, uint32_t resultId, uint32_t ptrId, - std::unique_ptr<BasicBlock>* block_ptr); + std::unique_ptr<BasicBlock>* block_ptr, + const Instruction* line_inst, const DebugScope& dbg_scope); // Return new label. std::unique_ptr<Instruction> NewLabel(uint32_t label_id); @@ -79,7 +82,8 @@ // Clone and map callee locals. Return true if successful. bool CloneAndMapLocals(Function* calleeFn, std::vector<std::unique_ptr<Instruction>>* new_vars, - std::unordered_map<uint32_t, uint32_t>* callee2caller); + std::unordered_map<uint32_t, uint32_t>* callee2caller, + analysis::DebugInlinedAtContext* inlined_at_ctx); // Create return variable for callee clone code. The return type of // |calleeFn| must not be void. Returns the id of the return variable if @@ -124,10 +128,6 @@ // Return true if |inst| is a function call that can be inlined. bool IsInlinableFunctionCall(const Instruction* inst); - // Return true if |func| does not have a return that is - // nested in a structured if, switch or loop. - bool HasNoReturnInStructuredConstruct(Function* func); - // Return true if |func| has no return in a loop. The current analysis // requires structured control flow, so return false if control flow not // structured ie. module is not a shader. @@ -171,6 +171,69 @@ // Set of functions that are originally called directly or indirectly from a // continue construct. std::unordered_set<uint32_t> funcs_called_from_continue_; + + private: + // Moves instructions of the caller function up to the call instruction + // to |new_blk_ptr|. + void MoveInstsBeforeEntryBlock( + std::unordered_map<uint32_t, Instruction*>* preCallSB, + BasicBlock* new_blk_ptr, BasicBlock::iterator call_inst_itr, + UptrVectorIterator<BasicBlock> call_block_itr); + + // Returns a new guard block after adding a branch to the end of + // |new_blocks|. + std::unique_ptr<BasicBlock> AddGuardBlock( + std::vector<std::unique_ptr<BasicBlock>>* new_blocks, + std::unordered_map<uint32_t, uint32_t>* callee2caller, + std::unique_ptr<BasicBlock> new_blk_ptr, uint32_t entry_blk_label_id); + + // Add store instructions for initializers of variables. + InstructionList::iterator AddStoresForVariableInitializers( + const std::unordered_map<uint32_t, uint32_t>& callee2caller, + analysis::DebugInlinedAtContext* inlined_at_ctx, + std::unique_ptr<BasicBlock>* new_blk_ptr, + UptrVectorIterator<BasicBlock> callee_block_itr); + + // Inlines a single instruction of the callee function. + bool InlineSingleInstruction( + const std::unordered_map<uint32_t, uint32_t>& callee2caller, + BasicBlock* new_blk_ptr, const Instruction* inst, + uint32_t dbg_inlined_at); + + // Inlines the return instruction of the callee function. + std::unique_ptr<BasicBlock> InlineReturn( + const std::unordered_map<uint32_t, uint32_t>& callee2caller, + std::vector<std::unique_ptr<BasicBlock>>* new_blocks, + std::unique_ptr<BasicBlock> new_blk_ptr, + analysis::DebugInlinedAtContext* inlined_at_ctx, Function* calleeFn, + const Instruction* inst, uint32_t returnVarId); + + // Inlines the entry block of the callee function. + bool InlineEntryBlock( + const std::unordered_map<uint32_t, uint32_t>& callee2caller, + std::unique_ptr<BasicBlock>* new_blk_ptr, + UptrVectorIterator<BasicBlock> callee_first_block, + analysis::DebugInlinedAtContext* inlined_at_ctx); + + // Inlines basic blocks of the callee function other than the entry basic + // block. + std::unique_ptr<BasicBlock> InlineBasicBlocks( + std::vector<std::unique_ptr<BasicBlock>>* new_blocks, + const std::unordered_map<uint32_t, uint32_t>& callee2caller, + std::unique_ptr<BasicBlock> new_blk_ptr, + analysis::DebugInlinedAtContext* inlined_at_ctx, Function* calleeFn); + + // Moves instructions of the caller function after the call instruction + // to |new_blk_ptr|. + bool MoveCallerInstsAfterFunctionCall( + std::unordered_map<uint32_t, Instruction*>* preCallSB, + std::unordered_map<uint32_t, uint32_t>* postCallSB, + std::unique_ptr<BasicBlock>* new_blk_ptr, + BasicBlock::iterator call_inst_itr, bool multiBlocks); + + // Move the OpLoopMerge from the last block back to the first. + void MoveLoopMergeInstToFirstBlock( + std::vector<std::unique_ptr<BasicBlock>>* new_blocks); }; } // namespace opt
diff --git a/source/opt/inst_bindless_check_pass.h b/source/opt/inst_bindless_check_pass.h index 447871b..9335fa5 100644 --- a/source/opt/inst_bindless_check_pass.h +++ b/source/opt/inst_bindless_check_pass.h
@@ -28,13 +28,6 @@ // external design may change as the layer evolves. class InstBindlessCheckPass : public InstrumentPass { public: - // Deprecated interface - InstBindlessCheckPass(uint32_t desc_set, uint32_t shader_id, - bool input_length_enable, bool input_init_enable, - uint32_t version) - : InstrumentPass(desc_set, shader_id, kInstValidationIdBindless, version), - input_length_enabled_(input_length_enable), - input_init_enabled_(input_init_enable) {} // Preferred Interface InstBindlessCheckPass(uint32_t desc_set, uint32_t shader_id, bool input_length_enable, bool input_init_enable)
diff --git a/source/opt/inst_buff_addr_check_pass.h b/source/opt/inst_buff_addr_check_pass.h index 67ffcc3..ec7bb68 100644 --- a/source/opt/inst_buff_addr_check_pass.h +++ b/source/opt/inst_buff_addr_check_pass.h
@@ -28,10 +28,6 @@ // external design of this class may change as the layer evolves. class InstBuffAddrCheckPass : public InstrumentPass { public: - // Deprecated interface - InstBuffAddrCheckPass(uint32_t desc_set, uint32_t shader_id, uint32_t version) - : InstrumentPass(desc_set, shader_id, kInstValidationIdBuffAddr, - version) {} // Preferred interface InstBuffAddrCheckPass(uint32_t desc_set, uint32_t shader_id) : InstrumentPass(desc_set, shader_id, kInstValidationIdBuffAddr) {}
diff --git a/source/opt/inst_debug_printf_pass.h b/source/opt/inst_debug_printf_pass.h index 2968a20..70b0a72 100644 --- a/source/opt/inst_debug_printf_pass.h +++ b/source/opt/inst_debug_printf_pass.h
@@ -28,11 +28,10 @@ class InstDebugPrintfPass : public InstrumentPass { public: // For test harness only - InstDebugPrintfPass() - : InstrumentPass(7, 23, kInstValidationIdDebugPrintf, 2) {} + InstDebugPrintfPass() : InstrumentPass(7, 23, kInstValidationIdDebugPrintf) {} // For all other interfaces InstDebugPrintfPass(uint32_t desc_set, uint32_t shader_id) - : InstrumentPass(desc_set, shader_id, kInstValidationIdDebugPrintf, 2) {} + : InstrumentPass(desc_set, shader_id, kInstValidationIdDebugPrintf) {} ~InstDebugPrintfPass() override = default;
diff --git a/source/opt/instruction.cpp b/source/opt/instruction.cpp index 3ce38a9..126848e 100644 --- a/source/opt/instruction.cpp +++ b/source/opt/instruction.cpp
@@ -29,13 +29,19 @@ // Indices used to get particular operands out of instructions using InOperand. const uint32_t kTypeImageDimIndex = 1; const uint32_t kLoadBaseIndex = 0; -const uint32_t kVariableStorageClassIndex = 0; +const uint32_t kPointerTypeStorageClassIndex = 0; const uint32_t kTypeImageSampledIndex = 5; // Constants for OpenCL.DebugInfo.100 extension instructions. +const uint32_t kExtInstSetIdInIdx = 0; +const uint32_t kExtInstInstructionInIdx = 1; const uint32_t kDebugScopeNumWords = 7; const uint32_t kDebugScopeNumWordsWithoutInlinedAt = 6; const uint32_t kDebugNoScopeNumWords = 5; + +// Number of operands of an OpBranchConditional instruction +// with weights. +const uint32_t kOpBranchConditionalWithWeightsNumOperands = 5; } // namespace Instruction::Instruction(IRContext* c) @@ -164,6 +170,15 @@ return size; } +bool Instruction::HasBranchWeights() const { + if (opcode_ == SpvOpBranchConditional && + NumOperands() == kOpBranchConditionalWithWeightsNumOperands) { + return true; + } + + return false; +} + void Instruction::ToBinaryWithoutAttachedDebugInsts( std::vector<uint32_t>* binary) const { const uint32_t num_words = 1 + NumOperandWords(); @@ -180,10 +195,27 @@ bool Instruction::IsReadOnlyLoad() const { if (IsLoad()) { Instruction* address_def = GetBaseAddress(); - if (!address_def || address_def->opcode() != SpvOpVariable) { + if (!address_def) { return false; } - return address_def->IsReadOnlyVariable(); + + if (address_def->opcode() == SpvOpVariable) { + if (address_def->IsReadOnlyPointer()) { + return true; + } + } + + if (address_def->opcode() == SpvOpLoad) { + const analysis::Type* address_type = + context()->get_type_mgr()->GetType(address_def->type_id()); + if (address_type->AsSampledImage() != nullptr) { + const auto* image_type = + address_type->AsSampledImage()->image_type()->AsImage(); + if (image_type->sampled() == 1) { + return true; + } + } + } } return false; } @@ -213,11 +245,11 @@ return base_inst; } -bool Instruction::IsReadOnlyVariable() const { +bool Instruction::IsReadOnlyPointer() const { if (context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) - return IsReadOnlyVariableShaders(); + return IsReadOnlyPointerShaders(); else - return IsReadOnlyVariableKernel(); + return IsReadOnlyPointerKernel(); } bool Instruction::IsVulkanStorageImage() const { @@ -225,7 +257,8 @@ return false; } - uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex); + uint32_t storage_class = + GetSingleWordInOperand(kPointerTypeStorageClassIndex); if (storage_class != SpvStorageClassUniformConstant) { return false; } @@ -259,7 +292,8 @@ return false; } - uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex); + uint32_t storage_class = + GetSingleWordInOperand(kPointerTypeStorageClassIndex); if (storage_class != SpvStorageClassUniformConstant) { return false; } @@ -293,7 +327,8 @@ return false; } - uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex); + uint32_t storage_class = + GetSingleWordInOperand(kPointerTypeStorageClassIndex); if (storage_class != SpvStorageClassUniformConstant) { return false; } @@ -342,7 +377,8 @@ return false; } - uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex); + uint32_t storage_class = + GetSingleWordInOperand(kPointerTypeStorageClassIndex); if (storage_class == SpvStorageClassUniform) { bool is_buffer_block = false; context()->get_decoration_mgr()->ForEachDecoration( @@ -364,7 +400,8 @@ return false; } - uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex); + uint32_t storage_class = + GetSingleWordInOperand(kPointerTypeStorageClassIndex); if (storage_class != SpvStorageClassUniform) { return false; } @@ -390,9 +427,18 @@ return is_block; } -bool Instruction::IsReadOnlyVariableShaders() const { - uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex); +bool Instruction::IsReadOnlyPointerShaders() const { + if (type_id() == 0) { + return false; + } + Instruction* type_def = context()->get_def_use_mgr()->GetDef(type_id()); + if (type_def->opcode() != SpvOpTypePointer) { + return false; + } + + uint32_t storage_class = + type_def->GetSingleWordInOperand(kPointerTypeStorageClassIndex); switch (storage_class) { case SpvStorageClassUniformConstant: @@ -420,8 +466,19 @@ return is_nonwritable; } -bool Instruction::IsReadOnlyVariableKernel() const { - uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex); +bool Instruction::IsReadOnlyPointerKernel() const { + if (type_id() == 0) { + return false; + } + + Instruction* type_def = context()->get_def_use_mgr()->GetDef(type_id()); + if (type_def->opcode() != SpvOpTypePointer) { + return false; + } + + uint32_t storage_class = + type_def->GetSingleWordInOperand(kPointerTypeStorageClassIndex); + return storage_class == SpvStorageClassUniformConstant; } @@ -510,6 +567,21 @@ return false; } +OpenCLDebugInfo100Instructions Instruction::GetOpenCL100DebugOpcode() const { + if (opcode() != SpvOpExtInst) return OpenCLDebugInfo100InstructionsMax; + + if (!context()->get_feature_mgr()->GetExtInstImportId_OpenCL100DebugInfo()) + return OpenCLDebugInfo100InstructionsMax; + + if (GetSingleWordInOperand(kExtInstSetIdInIdx) != + context()->get_feature_mgr()->GetExtInstImportId_OpenCL100DebugInfo()) { + return OpenCLDebugInfo100InstructionsMax; + } + + return OpenCLDebugInfo100Instructions( + GetSingleWordInOperand(kExtInstInstructionInIdx)); +} + bool Instruction::IsValidBaseImage() const { uint32_t tid = type_id(); if (tid == 0) { @@ -551,7 +623,19 @@ return false; } Instruction* type = context()->get_def_use_mgr()->GetDef(type_id()); - return folder.IsFoldableType(type); + if (!folder.IsFoldableType(type)) { + return false; + } + + // Even if the type of the instruction is foldable, its operands may not be + // foldable (e.g., comparisons of 64bit types). Check that all operand types + // are foldable before accepting the instruction. + return WhileEachInOperand([&folder, this](const uint32_t* op_id) { + Instruction* def_inst = context()->get_def_use_mgr()->GetDef(*op_id); + Instruction* def_inst_type = + context()->get_def_use_mgr()->GetDef(def_inst->type_id()); + return folder.IsFoldableType(def_inst_type); + }); } bool Instruction::IsFloatingPointFoldingAllowed() const { @@ -714,9 +798,6 @@ return true; } - const uint32_t kExtInstSetIdInIdx = 0; - const uint32_t kExtInstInstructionInIdx = 1; - if (opcode() == SpvOpExtInst) { uint32_t instSetId = context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
diff --git a/source/opt/instruction.h b/source/opt/instruction.h index a3342c6..7d8fed8 100644 --- a/source/opt/instruction.h +++ b/source/opt/instruction.h
@@ -22,14 +22,14 @@ #include <utility> #include <vector> -#include "source/opcode.h" -#include "source/operand.h" -#include "source/util/ilist_node.h" -#include "source/util/small_vector.h" - +#include "OpenCLDebugInfo100.h" #include "source/latest_version_glsl_std_450_header.h" #include "source/latest_version_spirv_header.h" +#include "source/opcode.h" +#include "source/operand.h" #include "source/opt/reflect.h" +#include "source/util/ilist_node.h" +#include "source/util/small_vector.h" #include "spirv-tools/libspirv.h" const uint32_t kNoDebugScope = 0; @@ -92,6 +92,19 @@ // Returns a string operand as a std::string. std::string AsString() const { return AsCString(); } + // Returns a literal integer operand as a uint64_t + uint64_t AsLiteralUint64() const { + assert(type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER); + assert(1 <= words.size()); + assert(words.size() <= 2); + // Load the low word. + uint64_t result = uint64_t(words[0]); + if (words.size() > 1) { + result = result | (uint64_t(words[1]) << 32); + } + return result; + } + friend bool operator==(const Operand& o1, const Operand& o2) { return o1.type == o2.type && o1.words == o2.words; } @@ -226,6 +239,10 @@ return dbg_line_insts_; } + const Instruction* dbg_line_inst() const { + return dbg_line_insts_.empty() ? nullptr : &dbg_line_insts_[0]; + } + // Clear line-related debug instructions attached to this instruction. void clear_dbg_line_insts() { dbg_line_insts_.clear(); } @@ -278,6 +295,13 @@ // Sets DebugScope. inline void SetDebugScope(const DebugScope& scope); inline const DebugScope& GetDebugScope() const { return dbg_scope_; } + // Updates DebugInlinedAt of DebugScope and OpLine. + inline void UpdateDebugInlinedAt(uint32_t new_inlined_at); + inline uint32_t GetDebugInlinedAt() const { + return dbg_scope_.GetInlinedAt(); + } + // Updates OpLine and DebugScope based on the information of |from|. + inline void UpdateDebugInfo(const Instruction* from); // Remove the |index|-th operand void RemoveOperand(uint32_t index) { operands_.erase(operands_.begin() + index); @@ -351,6 +375,10 @@ inline bool WhileEachInOperand( const std::function<bool(const uint32_t*)>& f) const; + // Returns true if it's an OpBranchConditional instruction + // with branch weights. + bool HasBranchWeights() const; + // Returns true if any operands can be labels inline bool HasLabels() const; @@ -383,8 +411,14 @@ // Memory-to-memory instructions are not considered loads. inline bool IsLoad() const; - // Returns true if the instruction declares a variable that is read-only. - bool IsReadOnlyVariable() const; + // Returns true if the instruction generates a pointer that is definitely + // read-only. This is determined by analysing the pointer type's storage + // class and decorations that target the pointer's id. It does not analyse + // other instructions that the pointer may be derived from. Thus if 'true' is + // returned, the pointer is definitely read-only, while if 'false' is returned + // it is possible that the pointer may actually be read-only if it is derived + // from another pointer that is decorated as read-only. + bool IsReadOnlyPointer() const; // The following functions check for the various descriptor types defined in // the Vulkan specification section 13.1. @@ -496,6 +530,11 @@ // rules for physical addressing. bool IsValidBasePointer() const; + // Returns debug opcode of an OpenCL.100.DebugInfo instruction. If + // it is not an OpenCL.100.DebugInfo instruction, just returns + // OpenCLDebugInfo100InstructionsMax. + OpenCLDebugInfo100Instructions GetOpenCL100DebugOpcode() const; + // Dump this instruction on stderr. Useful when running interactive // debuggers. void Dump() const; @@ -508,11 +547,12 @@ return 0; } - // Returns true if the instruction declares a variable that is read-only. The - // first version assumes the module is a shader module. The second assumes a + // Returns true if the instruction generates a read-only pointer, with the + // same caveats documented in the comment for IsReadOnlyPointer. The first + // version assumes the module is a shader module. The second assumes a // kernel. - bool IsReadOnlyVariableShaders() const; - bool IsReadOnlyVariableKernel() const; + bool IsReadOnlyPointerShaders() const; + bool IsReadOnlyPointerKernel() const; // Returns true if the result of |inst| can be used as the base image for an // instruction that samples a image, reads an image, or writes to an image. @@ -611,6 +651,21 @@ } } +inline void Instruction::UpdateDebugInlinedAt(uint32_t new_inlined_at) { + dbg_scope_.SetInlinedAt(new_inlined_at); + for (auto& i : dbg_line_insts_) { + i.dbg_scope_.SetInlinedAt(new_inlined_at); + } +} + +inline void Instruction::UpdateDebugInfo(const Instruction* from) { + if (from == nullptr) return; + clear_dbg_line_insts(); + if (!from->dbg_line_insts().empty()) + dbg_line_insts().push_back(from->dbg_line_insts()[0]); + SetDebugScope(from->GetDebugScope()); +} + inline void Instruction::SetResultType(uint32_t ty_id) { // TODO(dsinclair): Allow setting a type id if there wasn't one // previously. Need to make room in the operands_ array to place the result,
diff --git a/source/opt/instrument_pass.cpp b/source/opt/instrument_pass.cpp index c8c6c21..4210ad5 100644 --- a/source/opt/instrument_pass.cpp +++ b/source/opt/instrument_pass.cpp
@@ -885,14 +885,6 @@ } bool InstrumentPass::InstProcessEntryPointCallTree(InstProcessFunction& pfn) { - // Check that format version 2 requested - if (version_ != 2u) { - if (consumer()) { - std::string message = "Unsupported instrumentation format requested"; - consumer()(SPV_MSG_ERROR, 0, {0, 0, 0}, message.c_str()); - } - return false; - } // Make sure all entry points have the same execution model. Do not // instrument if they do not. // TODO(greg-lunarg): Handle mixed stages. Technically, a shader module
diff --git a/source/opt/instrument_pass.h b/source/opt/instrument_pass.h index 11afdce..f6884d2 100644 --- a/source/opt/instrument_pass.h +++ b/source/opt/instrument_pass.h
@@ -87,18 +87,7 @@ : Pass(), desc_set_(desc_set), shader_id_(shader_id), - validation_id_(validation_id), - version_(2u) {} - // Create instrumentation pass for |validation_id| which utilizes descriptor - // set |desc_set| for debug input and output buffers and writes |shader_id| - // into debug output records with format |version|. Deprecated. - InstrumentPass(uint32_t desc_set, uint32_t shader_id, uint32_t validation_id, - uint32_t version) - : Pass(), - desc_set_(desc_set), - shader_id_(shader_id), - validation_id_(validation_id), - version_(version) {} + validation_id_(validation_id) {} // Initialize state for instrumentation of module. void InitializeInstrument(); @@ -425,9 +414,6 @@ // id for void type uint32_t void_id_; - // Record format version - uint32_t version_; - // boolean to remember storage buffer extension bool storage_buffer_ext_defined_;
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp index 72993fd..df04066 100644 --- a/source/opt/ir_context.cpp +++ b/source/opt/ir_context.cpp
@@ -16,6 +16,7 @@ #include <cstring> +#include "OpenCLDebugInfo100.h" #include "source/latest_version_glsl_std_450_header.h" #include "source/opt/log.h" #include "source/opt/mem_pass.h" @@ -29,6 +30,10 @@ static const int kEntryPointInterfaceInIdx = 3; static const int kEntryPointFunctionIdInIdx = 1; +// Constants for OpenCL.DebugInfo.100 extension instructions. +static const uint32_t kDebugFunctionOperandFunctionIndex = 13; +static const uint32_t kDebugGlobalVariableOperandVariableIndex = 11; + } // anonymous namespace namespace spvtools { @@ -80,6 +85,9 @@ if (set & kAnalysisTypes) { BuildTypeManager(); } + if (set & kAnalysisDebugInfo) { + BuildDebugInfoManager(); + } } void IRContext::InvalidateAnalysesExceptFor( @@ -93,6 +101,7 @@ // away, the ConstantManager has to go away. if (analyses_to_invalidate & kAnalysisTypes) { analyses_to_invalidate |= kAnalysisConstants; + analyses_to_invalidate |= kAnalysisDebugInfo; } // The dominator analysis hold the psuedo entry and exit nodes from the CFG. @@ -143,6 +152,10 @@ type_mgr_.reset(nullptr); } + if (analyses_to_invalidate & kAnalysisDebugInfo) { + debug_info_mgr_.reset(nullptr); + } + valid_analyses_ = Analysis(valid_analyses_ & ~analyses_to_invalidate); } @@ -153,6 +166,8 @@ KillNamesAndDecorates(inst); + KillOperandFromDebugInstructions(inst); + if (AreAnalysesValid(kAnalysisDefUse)) { get_def_use_mgr()->ClearInst(inst); } @@ -265,7 +280,7 @@ bool IRContext::IsConsistent() { #ifndef SPIRV_CHECK_CONTEXT return true; -#endif +#else if (AreAnalysesValid(kAnalysisDefUse)) { analysis::DefUseManager new_def_use(module()); if (*get_def_use_mgr() != new_def_use) { @@ -317,6 +332,7 @@ } } return true; +#endif } void IRContext::ForgetUses(Instruction* inst) { @@ -365,6 +381,42 @@ KillNamesAndDecorates(rId); } +void IRContext::KillOperandFromDebugInstructions(Instruction* inst) { + const auto opcode = inst->opcode(); + const uint32_t id = inst->result_id(); + // Kill id of OpFunction from DebugFunction. + if (opcode == SpvOpFunction) { + for (auto it = module()->ext_inst_debuginfo_begin(); + it != module()->ext_inst_debuginfo_end(); ++it) { + if (it->GetOpenCL100DebugOpcode() != OpenCLDebugInfo100DebugFunction) + continue; + auto& operand = it->GetOperand(kDebugFunctionOperandFunctionIndex); + if (operand.words[0] == id) { + operand.words[0] = + get_debug_info_mgr()->GetDebugInfoNone()->result_id(); + } + } + } + // Kill id of OpVariable for global variable from DebugGlobalVariable. + if (opcode == SpvOpVariable || IsConstantInst(opcode)) { + for (auto it = module()->ext_inst_debuginfo_begin(); + it != module()->ext_inst_debuginfo_end(); ++it) { + if (it->GetOpenCL100DebugOpcode() != + OpenCLDebugInfo100DebugGlobalVariable) + continue; + auto& operand = it->GetOperand(kDebugGlobalVariableOperandVariableIndex); + if (operand.words[0] == id) { + operand.words[0] = + get_debug_info_mgr()->GetDebugInfoNone()->result_id(); + } + } + } + // Notice that we do not need anythings to do for local variables. + // DebugLocalVariable does not have an OpVariable operand. Instead, + // DebugDeclare/DebugValue has an OpVariable operand for a local + // variable. The function inlining pass handles it properly. +} + void IRContext::AddCombinatorsForCapability(uint32_t capability) { if (capability == SpvCapabilityShader) { combinator_ops_[0].insert({SpvOpNop,
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h index 723a2bb..a1b63ff 100644 --- a/source/opt/ir_context.h +++ b/source/opt/ir_context.h
@@ -29,6 +29,7 @@ #include "source/assembly_grammar.h" #include "source/opt/cfg.h" #include "source/opt/constants.h" +#include "source/opt/debug_info_manager.h" #include "source/opt/decoration_manager.h" #include "source/opt/def_use_manager.h" #include "source/opt/dominator_analysis.h" @@ -78,7 +79,8 @@ kAnalysisIdToFuncMapping = 1 << 13, kAnalysisConstants = 1 << 14, kAnalysisTypes = 1 << 15, - kAnalysisEnd = 1 << 16 + kAnalysisDebugInfo = 1 << 16, + kAnalysisEnd = 1 << 17 }; using ProcessFunction = std::function<bool(Function*)>; @@ -326,6 +328,17 @@ return type_mgr_.get(); } + // Returns a pointer to the debug information manager. If no debug + // information manager has been created yet, it creates one. + // NOTE: Once created, the debug information manager remains active + // it is never re-built. + analysis::DebugInfoManager* get_debug_info_mgr() { + if (!AreAnalysesValid(kAnalysisDebugInfo)) { + BuildDebugInfoManager(); + } + return debug_info_mgr_.get(); + } + // Returns a pointer to the scalar evolution analysis. If it is invalid it // will be rebuilt first. ScalarEvolutionAnalysis* GetScalarEvolutionAnalysis() { @@ -426,6 +439,9 @@ // Kill all name and decorate ops targeting the result id of |inst|. void KillNamesAndDecorates(Instruction* inst); + // Change operands of debug instruction to DebugInfoNone. + void KillOperandFromDebugInstructions(Instruction* inst); + // Returns the next unique id for use by an instruction. inline uint32_t TakeNextUniqueId() { assert(unique_id_ != std::numeric_limits<uint32_t>::max()); @@ -652,6 +668,13 @@ valid_analyses_ = valid_analyses_ | kAnalysisTypes; } + // Builds the debug information manager from scratch, even if it was + // already valid. + void BuildDebugInfoManager() { + debug_info_mgr_ = MakeUnique<analysis::DebugInfoManager>(this); + valid_analyses_ = valid_analyses_ | kAnalysisDebugInfo; + } + // Removes all computed dominator and post-dominator trees. This will force // the context to rebuild the trees on demand. void ResetDominatorAnalysis() { @@ -774,6 +797,9 @@ // Type manager for |module_|. std::unique_ptr<analysis::TypeManager> type_mgr_; + // Debug information manager for |module_|. + std::unique_ptr<analysis::DebugInfoManager> debug_info_mgr_; + // A map from an id to its corresponding OpName and OpMemberName instructions. std::unique_ptr<std::multimap<uint32_t, Instruction*>> id_to_name_;
diff --git a/source/opt/ir_loader.cpp b/source/opt/ir_loader.cpp index fcde079..acd41cd 100644 --- a/source/opt/ir_loader.cpp +++ b/source/opt/ir_loader.cpp
@@ -135,6 +135,8 @@ Error(consumer_, src, loc, "terminator instruction outside basic block"); return false; } + if (last_dbg_scope_.GetLexicalScope() != kNoDebugScope) + spv_inst->SetDebugScope(last_dbg_scope_); block_->AddInstruction(std::move(spv_inst)); function_->AddBasicBlock(std::move(block_)); block_ = nullptr;
diff --git a/source/opt/mem_pass.cpp b/source/opt/mem_pass.cpp index 04e2e8a..d23d679 100644 --- a/source/opt/mem_pass.cpp +++ b/source/opt/mem_pass.cpp
@@ -97,6 +97,11 @@ Instruction* ptrInst = get_def_use_mgr()->GetDef(*varId); Instruction* varInst; + if (ptrInst->opcode() == SpvOpConstantNull) { + *varId = 0; + return ptrInst; + } + if (ptrInst->opcode() != SpvOpVariable && ptrInst->opcode() != SpvOpFunctionParameter) { varInst = ptrInst->GetBaseAddress();
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp index bbac4bb..8cb4299 100644 --- a/source/opt/merge_return_pass.cpp +++ b/source/opt/merge_return_pass.cpp
@@ -39,8 +39,11 @@ if (!is_shader || return_blocks.size() == 0) { return false; } - if (context()->GetStructuredCFGAnalysis()->ContainingConstruct( - return_blocks[0]->id()) == 0) { + bool isInConstruct = + context()->GetStructuredCFGAnalysis()->ContainingConstruct( + return_blocks[0]->id()) != 0; + bool endsWithReturn = return_blocks[0] == function->tail(); + if (!isInConstruct && endsWithReturn) { return false; } } @@ -421,7 +424,6 @@ auto old_body_id = TakeNextId(); BasicBlock* old_body = block->SplitBasicBlock(context(), old_body_id, iter); predicated->insert(old_body); - cfg()->AddEdges(old_body); // If a return block is being split, mark the new body block also as a return // block.
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index 0a937e8..25adee9 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp
@@ -175,9 +175,18 @@ .RegisterPass(CreateAggressiveDCEPass()) .RegisterPass(CreateCCPPass()) .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateLoopUnrollPass(true)) + .RegisterPass(CreateDeadBranchElimPass()) .RegisterPass(CreateRedundancyEliminationPass()) .RegisterPass(CreateCombineAccessChainsPass()) .RegisterPass(CreateSimplificationPass()) + .RegisterPass(CreateScalarReplacementPass()) + .RegisterPass(CreateLocalAccessChainConvertPass()) + .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) + .RegisterPass(CreateLocalSingleStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateSSARewritePass()) + .RegisterPass(CreateAggressiveDCEPass()) .RegisterPass(CreateVectorDCEPass()) .RegisterPass(CreateDeadInsertElimPass()) .RegisterPass(CreateDeadBranchElimPass()) @@ -407,19 +416,19 @@ } else if (pass_name == "replace-invalid-opcode") { RegisterPass(CreateReplaceInvalidOpcodePass()); } else if (pass_name == "inst-bindless-check") { - RegisterPass(CreateInstBindlessCheckPass(7, 23, false, false, 2)); + RegisterPass(CreateInstBindlessCheckPass(7, 23, false, false)); RegisterPass(CreateSimplificationPass()); RegisterPass(CreateDeadBranchElimPass()); RegisterPass(CreateBlockMergePass()); RegisterPass(CreateAggressiveDCEPass()); } else if (pass_name == "inst-desc-idx-check") { - RegisterPass(CreateInstBindlessCheckPass(7, 23, true, true, 2)); + RegisterPass(CreateInstBindlessCheckPass(7, 23, true, true)); RegisterPass(CreateSimplificationPass()); RegisterPass(CreateDeadBranchElimPass()); RegisterPass(CreateBlockMergePass()); RegisterPass(CreateAggressiveDCEPass()); } else if (pass_name == "inst-buff-addr-check") { - RegisterPass(CreateInstBuffAddrCheckPass(7, 23, 2)); + RegisterPass(CreateInstBuffAddrCheckPass(7, 23)); RegisterPass(CreateAggressiveDCEPass()); } else if (pass_name == "convert-relaxed-to-half") { RegisterPass(CreateConvertRelaxedToHalfPass()); @@ -498,7 +507,7 @@ } else if (pass_name == "legalize-vector-shuffle") { RegisterPass(CreateLegalizeVectorShufflePass()); } else if (pass_name == "split-invalid-unreachable") { - RegisterPass(CreateLegalizeVectorShufflePass()); + RegisterPass(CreateSplitInvalidUnreachablePass()); } else if (pass_name == "decompose-initialized-variables") { RegisterPass(CreateDecomposeInitializedVariablesPass()); } else if (pass_name == "graphics-robust-access") { @@ -885,12 +894,10 @@ Optimizer::PassToken CreateInstBindlessCheckPass(uint32_t desc_set, uint32_t shader_id, bool input_length_enable, - bool input_init_enable, - uint32_t version) { + bool input_init_enable) { return MakeUnique<Optimizer::PassToken::Impl>( - MakeUnique<opt::InstBindlessCheckPass>(desc_set, shader_id, - input_length_enable, - input_init_enable, version)); + MakeUnique<opt::InstBindlessCheckPass>( + desc_set, shader_id, input_length_enable, input_init_enable)); } Optimizer::PassToken CreateInstDebugPrintfPass(uint32_t desc_set, @@ -900,10 +907,9 @@ } Optimizer::PassToken CreateInstBuffAddrCheckPass(uint32_t desc_set, - uint32_t shader_id, - uint32_t version) { + uint32_t shader_id) { return MakeUnique<Optimizer::PassToken::Impl>( - MakeUnique<opt::InstBuffAddrCheckPass>(desc_set, shader_id, version)); + MakeUnique<opt::InstBuffAddrCheckPass>(desc_set, shader_id)); } Optimizer::PassToken CreateConvertRelaxedToHalfPass() {
diff --git a/source/opt/struct_cfg_analysis.cpp b/source/opt/struct_cfg_analysis.cpp index b16322c..57fc49c 100644 --- a/source/opt/struct_cfg_analysis.cpp +++ b/source/opt/struct_cfg_analysis.cpp
@@ -85,9 +85,14 @@ if (merge_inst->opcode() == SpvOpLoopMerge) { new_state.cinfo.containing_loop = block->id(); new_state.cinfo.containing_switch = 0; - new_state.cinfo.in_continue = false; new_state.continue_node = merge_inst->GetSingleWordInOperand(kContinueNodeIndex); + if (block->id() == new_state.continue_node) { + new_state.cinfo.in_continue = true; + bb_to_construct_[block->id()].in_continue = true; + } else { + new_state.cinfo.in_continue = false; + } } else { new_state.cinfo.containing_loop = state.back().cinfo.containing_loop; new_state.cinfo.in_continue = state.back().cinfo.in_continue;
diff --git a/source/opt/type_manager.h b/source/opt/type_manager.h index 8fcf8aa..ce9d83d 100644 --- a/source/opt/type_manager.h +++ b/source/opt/type_manager.h
@@ -194,6 +194,13 @@ uint32_t GetBoolTypeId() { return GetTypeInstruction(GetBoolType()); } + Type* GetVoidType() { + Void void_type; + return GetRegisteredType(&void_type); + } + + uint32_t GetVoidTypeId() { return GetTypeInstruction(GetVoidType()); } + private: using TypeToIdMap = std::unordered_map<const Type*, uint32_t, HashTypePointer, CompareTypePointers>;
diff --git a/source/opt/wrap_opkill.cpp b/source/opt/wrap_opkill.cpp index ffd7a10..3c8bae6 100644 --- a/source/opt/wrap_opkill.cpp +++ b/source/opt/wrap_opkill.cpp
@@ -59,9 +59,12 @@ if (func_id == 0) { return false; } - if (ir_builder.AddFunctionCall(GetVoidTypeId(), func_id, {}) == nullptr) { + Instruction* call_inst = + ir_builder.AddFunctionCall(GetVoidTypeId(), func_id, {}); + if (call_inst == nullptr) { return false; } + call_inst->UpdateDebugInfo(inst); Instruction* return_inst = nullptr; uint32_t return_type_id = GetOwningFunctionsReturnType(inst); @@ -147,6 +150,7 @@ bb->AddInstruction(std::move(kill_inst)); // Add the bb to the function + bb->SetParent(opkill_function_.get()); opkill_function_->AddBasicBlock(std::move(bb)); // Add the function to the module.
diff --git a/source/reduce/CMakeLists.txt b/source/reduce/CMakeLists.txt index 51e9b1d..d945bd2 100644 --- a/source/reduce/CMakeLists.txt +++ b/source/reduce/CMakeLists.txt
@@ -26,12 +26,14 @@ reduction_util.h remove_block_reduction_opportunity.h remove_block_reduction_opportunity_finder.h - remove_instruction_reduction_opportunity.h remove_function_reduction_opportunity.h remove_function_reduction_opportunity_finder.h + remove_instruction_reduction_opportunity.h remove_selection_reduction_opportunity.h remove_selection_reduction_opportunity_finder.h - remove_unreferenced_instruction_reduction_opportunity_finder.h + remove_struct_member_reduction_opportunity.h + remove_unused_instruction_reduction_opportunity_finder.h + remove_unused_struct_member_reduction_opportunity_finder.h structured_loop_to_selection_reduction_opportunity.h structured_loop_to_selection_reduction_opportunity_finder.h conditional_branch_to_simple_conditional_branch_opportunity_finder.h @@ -57,7 +59,9 @@ remove_instruction_reduction_opportunity.cpp remove_selection_reduction_opportunity.cpp remove_selection_reduction_opportunity_finder.cpp - remove_unreferenced_instruction_reduction_opportunity_finder.cpp + remove_struct_member_reduction_opportunity.cpp + remove_unused_instruction_reduction_opportunity_finder.cpp + remove_unused_struct_member_reduction_opportunity_finder.cpp structured_loop_to_selection_reduction_opportunity.cpp structured_loop_to_selection_reduction_opportunity_finder.cpp conditional_branch_to_simple_conditional_branch_opportunity_finder.cpp
diff --git a/source/reduce/pch_source_reduce.h b/source/reduce/pch_source_reduce.h index 6c0da0c..81bed20 100644 --- a/source/reduce/pch_source_reduce.h +++ b/source/reduce/pch_source_reduce.h
@@ -20,4 +20,4 @@ #include "source/reduce/reduction_opportunity.h" #include "source/reduce/reduction_pass.h" #include "source/reduce/remove_instruction_reduction_opportunity.h" -#include "source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.h" +#include "source/reduce/remove_unused_instruction_reduction_opportunity_finder.h"
diff --git a/source/reduce/reducer.cpp b/source/reduce/reducer.cpp index bda41ce..092d409 100644 --- a/source/reduce/reducer.cpp +++ b/source/reduce/reducer.cpp
@@ -25,7 +25,8 @@ #include "source/reduce/remove_block_reduction_opportunity_finder.h" #include "source/reduce/remove_function_reduction_opportunity_finder.h" #include "source/reduce/remove_selection_reduction_opportunity_finder.h" -#include "source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.h" +#include "source/reduce/remove_unused_instruction_reduction_opportunity_finder.h" +#include "source/reduce/remove_unused_struct_member_reduction_opportunity_finder.h" #include "source/reduce/simple_conditional_branch_to_branch_opportunity_finder.h" #include "source/reduce/structured_loop_to_selection_reduction_opportunity_finder.h" #include "source/spirv_reducer_options.h" @@ -103,8 +104,8 @@ void Reducer::AddDefaultReductionPasses() { AddReductionPass( - spvtools::MakeUnique< - RemoveUnreferencedInstructionReductionOpportunityFinder>(false)); + spvtools::MakeUnique<RemoveUnusedInstructionReductionOpportunityFinder>( + false)); AddReductionPass( spvtools::MakeUnique<OperandToUndefReductionOpportunityFinder>()); AddReductionPass( @@ -126,12 +127,14 @@ ConditionalBranchToSimpleConditionalBranchOpportunityFinder>()); AddReductionPass( spvtools::MakeUnique<SimpleConditionalBranchToBranchOpportunityFinder>()); + AddReductionPass(spvtools::MakeUnique< + RemoveUnusedStructMemberReductionOpportunityFinder>()); // Cleanup passes. AddCleanupReductionPass( - spvtools::MakeUnique< - RemoveUnreferencedInstructionReductionOpportunityFinder>(true)); + spvtools::MakeUnique<RemoveUnusedInstructionReductionOpportunityFinder>( + true)); } void Reducer::AddReductionPass(
diff --git a/source/reduce/remove_instruction_reduction_opportunity.cpp b/source/reduce/remove_instruction_reduction_opportunity.cpp index 9ca093b..8026204 100644 --- a/source/reduce/remove_instruction_reduction_opportunity.cpp +++ b/source/reduce/remove_instruction_reduction_opportunity.cpp
@@ -22,6 +22,18 @@ bool RemoveInstructionReductionOpportunity::PreconditionHolds() { return true; } void RemoveInstructionReductionOpportunity::Apply() { + const uint32_t kNumEntryPointInOperandsBeforeInterfaceIds = 3; + for (auto& entry_point : inst_->context()->module()->entry_points()) { + opt::Instruction::OperandList new_entry_point_in_operands; + for (uint32_t index = 0; index < entry_point.NumInOperands(); index++) { + if (index >= kNumEntryPointInOperandsBeforeInterfaceIds && + entry_point.GetSingleWordInOperand(index) == inst_->result_id()) { + continue; + } + new_entry_point_in_operands.push_back(entry_point.GetInOperand(index)); + } + entry_point.SetInOperands(std::move(new_entry_point_in_operands)); + } inst_->context()->KillInst(inst_); }
diff --git a/source/reduce/remove_struct_member_reduction_opportunity.cpp b/source/reduce/remove_struct_member_reduction_opportunity.cpp new file mode 100644 index 0000000..787c629 --- /dev/null +++ b/source/reduce/remove_struct_member_reduction_opportunity.cpp
@@ -0,0 +1,208 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/reduce/remove_struct_member_reduction_opportunity.h" + +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace reduce { + +bool RemoveStructMemberReductionOpportunity::PreconditionHolds() { + return struct_type_->NumInOperands() == original_number_of_members_; +} + +void RemoveStructMemberReductionOpportunity::Apply() { + std::set<opt::Instruction*> decorations_to_kill; + + // We need to remove decorations that target the removed struct member, and + // adapt decorations that target later struct members by decrementing the + // member identifier. We also need to adapt composite construction + // instructions so that no id is provided for the member being removed. + // + // To do this, we consider every use of the struct type. + struct_type_->context()->get_def_use_mgr()->ForEachUse( + struct_type_, [this, &decorations_to_kill](opt::Instruction* user, + uint32_t /*operand_index*/) { + switch (user->opcode()) { + case SpvOpCompositeConstruct: + case SpvOpConstantComposite: + // This use is constructing a composite of the struct type, so we + // must remove the id that was provided for the member we are + // removing. + user->RemoveInOperand(member_index_); + break; + case SpvOpMemberDecorate: + // This use is decorating a member of the struct. + if (user->GetSingleWordInOperand(1) == member_index_) { + // The member we are removing is being decorated, so we record + // that we need to get rid of the decoration. + decorations_to_kill.insert(user); + } else if (user->GetSingleWordInOperand(1) > member_index_) { + // A member beyond the one we are removing is being decorated, so + // we adjust the index that identifies the member. + user->SetInOperand(1, {user->GetSingleWordInOperand(1) - 1}); + } + break; + default: + break; + } + }); + + // Get rid of all the decorations that were found to target the member being + // removed. + for (auto decoration_to_kill : decorations_to_kill) { + decoration_to_kill->context()->KillInst(decoration_to_kill); + } + + // We now look through all instructions that access composites via sequences + // of indices. Every time we find an index into the struct whose member is + // being removed, and if the member being accessed comes after the member + // being removed, we need to adjust the index accordingly. + // + // We go through every relevant instruction in every block of every function, + // and invoke a helper to adjust it. + auto context = struct_type_->context(); + for (auto& function : *context->module()) { + for (auto& block : function) { + for (auto& inst : block) { + switch (inst.opcode()) { + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: { + // These access chain instructions take sequences of ids for + // indexing, starting from input operand 1. + auto composite_type_id = + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(inst.GetSingleWordInOperand(0)) + ->type_id()) + ->GetSingleWordInOperand(1); + AdjustAccessedIndices(composite_type_id, 1, false, context, &inst); + } break; + case SpvOpPtrAccessChain: + case SpvOpInBoundsPtrAccessChain: { + // These access chain instructions take sequences of ids for + // indexing, starting from input operand 2. + auto composite_type_id = + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(inst.GetSingleWordInOperand(1)) + ->type_id()) + ->GetSingleWordInOperand(1); + AdjustAccessedIndices(composite_type_id, 2, false, context, &inst); + } break; + case SpvOpCompositeExtract: { + // OpCompositeExtract uses literals for indexing, starting at input + // operand 1. + auto composite_type_id = + context->get_def_use_mgr() + ->GetDef(inst.GetSingleWordInOperand(0)) + ->type_id(); + AdjustAccessedIndices(composite_type_id, 1, true, context, &inst); + } break; + case SpvOpCompositeInsert: { + // OpCompositeInsert uses literals for indexing, starting at input + // operand 2. + auto composite_type_id = + context->get_def_use_mgr() + ->GetDef(inst.GetSingleWordInOperand(1)) + ->type_id(); + AdjustAccessedIndices(composite_type_id, 2, true, context, &inst); + } break; + default: + break; + } + } + } + } + + // Remove the member from the struct type. + struct_type_->RemoveInOperand(member_index_); +} + +void RemoveStructMemberReductionOpportunity::AdjustAccessedIndices( + uint32_t composite_type_id, uint32_t first_index_input_operand, + bool literal_indices, opt::IRContext* context, + opt::Instruction* composite_access_instruction) const { + // Walk the series of types that are encountered by following the + // instruction's sequence of indices. For all types except structs, this is + // routine: the type of the composite dictates what the next type will be + // regardless of the specific index value. + uint32_t next_type = composite_type_id; + for (uint32_t i = first_index_input_operand; + i < composite_access_instruction->NumInOperands(); i++) { + auto type_inst = context->get_def_use_mgr()->GetDef(next_type); + switch (type_inst->opcode()) { + case SpvOpTypeArray: + case SpvOpTypeMatrix: + case SpvOpTypeRuntimeArray: + case SpvOpTypeVector: + next_type = type_inst->GetSingleWordInOperand(0); + break; + case SpvOpTypeStruct: { + // Struct types are special becuase (a) we may need to adjust the index + // being used, if the struct type is the one from which we are removing + // a member, and (b) the type encountered by following the current index + // is dependent on the value of the index. + + // Work out the member being accessed. If literal indexing is used this + // is simple; otherwise we need to look up the id of the constant + // instruction being used as an index and get the value of the constant. + uint32_t index_operand = + composite_access_instruction->GetSingleWordInOperand(i); + uint32_t member = literal_indices ? index_operand + : context->get_def_use_mgr() + ->GetDef(index_operand) + ->GetSingleWordInOperand(0); + + // The next type we will consider is obtained by looking up the struct + // type at |member|. + next_type = type_inst->GetSingleWordInOperand(member); + + if (type_inst == struct_type_ && member > member_index_) { + // The struct type is the struct from which we are removing a member, + // and the member being accessed is beyond the member we are removing. + // We thus need to decrement the index by 1. + uint32_t new_in_operand; + if (literal_indices) { + // With literal indexing this is straightforward. + new_in_operand = member - 1; + } else { + // With id-based indexing this is more tricky: we need to find or + // create a constant instruction whose value is one less than + // |member|, and use the id of this constant as the replacement + // input operand. + auto constant_inst = + context->get_def_use_mgr()->GetDef(index_operand); + auto int_type = context->get_type_mgr() + ->GetType(constant_inst->type_id()) + ->AsInteger(); + auto new_index_constant = + opt::analysis::IntConstant(int_type, {member - 1}); + new_in_operand = context->get_constant_mgr() + ->GetDefiningInstruction(&new_index_constant) + ->result_id(); + } + composite_access_instruction->SetInOperand(i, {new_in_operand}); + } + } break; + default: + assert(0 && "Unknown composite type."); + break; + } + } +} + +} // namespace reduce +} // namespace spvtools
diff --git a/source/reduce/remove_struct_member_reduction_opportunity.h b/source/reduce/remove_struct_member_reduction_opportunity.h new file mode 100644 index 0000000..899e5ea --- /dev/null +++ b/source/reduce/remove_struct_member_reduction_opportunity.h
@@ -0,0 +1,84 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_REDUCE_REMOVE_STRUCT_MEMBER_REDUCTION_OPPORTUNITY_H_ +#define SOURCE_REDUCE_REMOVE_STRUCT_MEMBER_REDUCTION_OPPORTUNITY_H_ + +#include "source/reduce/reduction_opportunity.h" + +#include "source/opt/instruction.h" + +namespace spvtools { +namespace reduce { + +// An opportunity for removing a member from a struct type, adjusting all uses +// of the struct accordingly. +class RemoveStructMemberReductionOpportunity : public ReductionOpportunity { + public: + // Constructs a reduction opportunity from the struct type |struct_type|, for + // removal of member |member_index|. + RemoveStructMemberReductionOpportunity(opt::Instruction* struct_type, + uint32_t member_index) + : struct_type_(struct_type), + member_index_(member_index), + original_number_of_members_(struct_type->NumInOperands()) {} + + // Opportunities to remove fields from a common struct type mutually + // invalidate each other. We guard against this by requiring that the struct + // still has the number of members it had when the opportunity was created. + bool PreconditionHolds() override; + + protected: + void Apply() override; + + private: + // |composite_access_instruction| is an instruction that accesses a composite + // id using either a series of literal indices (e.g. in the case of + // OpCompositeInsert) or a series of index ids (e.g. in the case of + // OpAccessChain). + // + // This function adjusts the indices that are used by + // |composite_access_instruction| to that whenever an index is accessing a + // member of |struct_type_|, it is decremented if the member is beyond + // |member_index_|, to account for the removal of the |member_index_|-th + // member. + // + // |composite_type_id| is the id of the composite type that the series of + // indices is to be applied to. + // + // |first_index_input_operand| specifies the first input operand that is an + // index. + // + // |literal_indices| specifies whether indices are given as literals (true), + // or as ids (false). + // + // If id-based indexing is used, this function will add a constant for + // |member_index_| - 1 to the module if needed. + void AdjustAccessedIndices( + uint32_t composite_type_id, uint32_t first_index_input_operand, + bool literal_indices, opt::IRContext* context, + opt::Instruction* composite_access_instruction) const; + + // The struct type from which a member is to be removed. + opt::Instruction* struct_type_; + + uint32_t member_index_; + + uint32_t original_number_of_members_; +}; + +} // namespace reduce +} // namespace spvtools + +#endif // SOURCE_REDUCE_REMOVE_STRUCT_MEMBER_REDUCTION_OPPORTUNITY_H_
diff --git a/source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.h b/source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.h deleted file mode 100644 index bc4f137..0000000 --- a/source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.h +++ /dev/null
@@ -1,48 +0,0 @@ -// Copyright (c) 2018 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef SOURCE_REDUCE_REMOVE_UNREFERENCED_INSTRUCTION_REDUCTION_OPPORTUNITY_FINDER_H_ -#define SOURCE_REDUCE_REMOVE_UNREFERENCED_INSTRUCTION_REDUCTION_OPPORTUNITY_FINDER_H_ - -#include "source/reduce/reduction_opportunity_finder.h" - -namespace spvtools { -namespace reduce { - -// A finder for opportunities to remove non-control-flow instructions in blocks -// in cases where the instruction's id is not referenced. As well as making the -// module smaller, removing an instruction that references particular ids may -// create opportunities for subsequently removing the instructions that -// generated those ids. -class RemoveUnreferencedInstructionReductionOpportunityFinder - : public ReductionOpportunityFinder { - public: - explicit RemoveUnreferencedInstructionReductionOpportunityFinder( - bool remove_constants_and_undefs); - - ~RemoveUnreferencedInstructionReductionOpportunityFinder() override = default; - - std::string GetName() const final; - - std::vector<std::unique_ptr<ReductionOpportunity>> GetAvailableOpportunities( - opt::IRContext* context) const final; - - private: - bool remove_constants_and_undefs_; -}; - -} // namespace reduce -} // namespace spvtools - -#endif // SOURCE_REDUCE_REMOVE_UNREFERENCED_INSTRUCTION_REDUCTION_OPPORTUNITY_FINDER_H_
diff --git a/source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.cpp b/source/reduce/remove_unused_instruction_reduction_opportunity_finder.cpp similarity index 60% rename from source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.cpp rename to source/reduce/remove_unused_instruction_reduction_opportunity_finder.cpp index ce66691..91ec542 100644 --- a/source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.cpp +++ b/source/reduce/remove_unused_instruction_reduction_opportunity_finder.cpp
@@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.h" +#include "source/reduce/remove_unused_instruction_reduction_opportunity_finder.h" #include "source/opcode.h" #include "source/opt/instruction.h" @@ -21,14 +21,14 @@ namespace spvtools { namespace reduce { -RemoveUnreferencedInstructionReductionOpportunityFinder:: - RemoveUnreferencedInstructionReductionOpportunityFinder( +RemoveUnusedInstructionReductionOpportunityFinder:: + RemoveUnusedInstructionReductionOpportunityFinder( bool remove_constants_and_undefs) : remove_constants_and_undefs_(remove_constants_and_undefs) {} std::vector<std::unique_ptr<ReductionOpportunity>> -RemoveUnreferencedInstructionReductionOpportunityFinder:: - GetAvailableOpportunities(opt::IRContext* context) const { +RemoveUnusedInstructionReductionOpportunityFinder::GetAvailableOpportunities( + opt::IRContext* context) const { std::vector<std::unique_ptr<ReductionOpportunity>> result; for (auto& inst : context->module()->debugs1()) { @@ -60,13 +60,14 @@ } for (auto& inst : context->module()->types_values()) { - if (context->get_def_use_mgr()->NumUsers(&inst) > 0) { - continue; - } if (!remove_constants_and_undefs_ && spvOpcodeIsConstantOrUndef(inst.opcode())) { continue; } + if (!OnlyReferencedByIntimateDecorationOrEntryPointInterface(context, + inst)) { + continue; + } result.push_back(MakeUnique<RemoveInstructionReductionOpportunity>(&inst)); } @@ -74,38 +75,9 @@ if (context->get_def_use_mgr()->NumUsers(&inst) > 0) { continue; } - - uint32_t decoration = SpvDecorationMax; - switch (inst.opcode()) { - case SpvOpDecorate: - case SpvOpDecorateId: - case SpvOpDecorateString: - decoration = inst.GetSingleWordInOperand(1u); - break; - case SpvOpMemberDecorate: - case SpvOpMemberDecorateString: - decoration = inst.GetSingleWordInOperand(2u); - break; - default: - break; + if (!IsIndependentlyRemovableDecoration(inst)) { + continue; } - - // We conservatively only remove specific decorations that we believe will - // not change the shader interface, will not make the shader invalid, will - // actually be found in practice, etc. - - switch (decoration) { - case SpvDecorationRelaxedPrecision: - case SpvDecorationNoSignedWrap: - case SpvDecorationNoContraction: - case SpvDecorationNoUnsignedWrap: - case SpvDecorationUserSemantic: - break; - default: - // Give up. - continue; - } - result.push_back(MakeUnique<RemoveInstructionReductionOpportunity>(&inst)); } @@ -139,9 +111,54 @@ return result; } -std::string RemoveUnreferencedInstructionReductionOpportunityFinder::GetName() - const { - return "RemoveUnreferencedInstructionReductionOpportunityFinder"; +std::string RemoveUnusedInstructionReductionOpportunityFinder::GetName() const { + return "RemoveUnusedInstructionReductionOpportunityFinder"; +} + +bool RemoveUnusedInstructionReductionOpportunityFinder:: + OnlyReferencedByIntimateDecorationOrEntryPointInterface( + opt::IRContext* context, const opt::Instruction& inst) const { + return context->get_def_use_mgr()->WhileEachUse( + &inst, [this](opt::Instruction* user, uint32_t use_index) -> bool { + return (user->IsDecoration() && + !IsIndependentlyRemovableDecoration(*user)) || + (user->opcode() == SpvOpEntryPoint && use_index > 2); + }); +} + +bool RemoveUnusedInstructionReductionOpportunityFinder:: + IsIndependentlyRemovableDecoration(const opt::Instruction& inst) const { + uint32_t decoration; + switch (inst.opcode()) { + case SpvOpDecorate: + case SpvOpDecorateId: + case SpvOpDecorateString: + decoration = inst.GetSingleWordInOperand(1u); + break; + case SpvOpMemberDecorate: + case SpvOpMemberDecorateString: + decoration = inst.GetSingleWordInOperand(2u); + break; + default: + // The instruction is not a decoration. It is legitimate for this to be + // reached: it allows the method to be invoked on arbitrary instructions. + return false; + } + + // We conservatively only remove specific decorations that we believe will + // not change the shader interface, will not make the shader invalid, will + // actually be found in practice, etc. + + switch (decoration) { + case SpvDecorationRelaxedPrecision: + case SpvDecorationNoSignedWrap: + case SpvDecorationNoContraction: + case SpvDecorationNoUnsignedWrap: + case SpvDecorationUserSemantic: + return true; + default: + return false; + } } } // namespace reduce
diff --git a/source/reduce/remove_unused_instruction_reduction_opportunity_finder.h b/source/reduce/remove_unused_instruction_reduction_opportunity_finder.h new file mode 100644 index 0000000..cbf6a5b --- /dev/null +++ b/source/reduce/remove_unused_instruction_reduction_opportunity_finder.h
@@ -0,0 +1,61 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_REDUCE_REMOVE_UNREFERENCED_INSTRUCTION_REDUCTION_OPPORTUNITY_FINDER_H_ +#define SOURCE_REDUCE_REMOVE_UNREFERENCED_INSTRUCTION_REDUCTION_OPPORTUNITY_FINDER_H_ + +#include "source/reduce/reduction_opportunity_finder.h" + +namespace spvtools { +namespace reduce { + +// A finder for opportunities to remove non-control-flow instructions in blocks +// in cases where the instruction's id is either not referenced at all, or +// referenced only in a trivial manner (for example, we regard a struct type as +// unused if it is referenced only by struct layout decorations). As well as +// making the module smaller, removing an instruction that references particular +// ids may create opportunities for subsequently removing the instructions that +// generated those ids. +class RemoveUnusedInstructionReductionOpportunityFinder + : public ReductionOpportunityFinder { + public: + explicit RemoveUnusedInstructionReductionOpportunityFinder( + bool remove_constants_and_undefs); + + ~RemoveUnusedInstructionReductionOpportunityFinder() override = default; + + std::string GetName() const final; + + std::vector<std::unique_ptr<ReductionOpportunity>> GetAvailableOpportunities( + opt::IRContext* context) const final; + + private: + // Returns true if and only if the only uses of |inst| are by decorations that + // relate intimately to the instruction (as opposed to decorations that could + // be removed independently), or by interface ids in OpEntryPoint. + bool OnlyReferencedByIntimateDecorationOrEntryPointInterface( + opt::IRContext* context, const opt::Instruction& inst) const; + + // Returns true if and only if |inst| is a decoration instruction that can + // legitimately be removed on its own (rather than one that has to be removed + // simultaneously with other instructions). + bool IsIndependentlyRemovableDecoration(const opt::Instruction& inst) const; + + bool remove_constants_and_undefs_; +}; + +} // namespace reduce +} // namespace spvtools + +#endif // SOURCE_REDUCE_REMOVE_UNREFERENCED_INSTRUCTION_REDUCTION_OPPORTUNITY_FINDER_H_
diff --git a/source/reduce/remove_unused_struct_member_reduction_opportunity_finder.cpp b/source/reduce/remove_unused_struct_member_reduction_opportunity_finder.cpp new file mode 100644 index 0000000..39ce47f --- /dev/null +++ b/source/reduce/remove_unused_struct_member_reduction_opportunity_finder.cpp
@@ -0,0 +1,193 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/reduce/remove_unused_struct_member_reduction_opportunity_finder.h" + +#include <map> +#include <set> + +#include "source/reduce/remove_struct_member_reduction_opportunity.h" + +namespace spvtools { +namespace reduce { + +std::vector<std::unique_ptr<ReductionOpportunity>> +RemoveUnusedStructMemberReductionOpportunityFinder::GetAvailableOpportunities( + opt::IRContext* context) const { + std::vector<std::unique_ptr<ReductionOpportunity>> result; + + // We track those struct members that are never accessed. We do this by + // associating a member index to all the structs that have this member index + // but do not use it. This representation is designed to allow reduction + // opportunities to be provided in a useful manner, so that opportunities + // associated with the same struct are unlikely to be adjacent. + std::map<uint32_t, std::set<opt::Instruction*>> unused_member_to_structs; + + // Consider every struct type in the module. + for (auto& type_or_value : context->types_values()) { + if (type_or_value.opcode() != SpvOpTypeStruct) { + continue; + } + + // Initially, we assume that *every* member of the struct is unused. We + // then refine this based on observed uses. + std::set<uint32_t> unused_members; + for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) { + unused_members.insert(i); + } + + // A separate reduction pass deals with removal of names. If a struct + // member is still named, we treat it as being used. + context->get_def_use_mgr()->ForEachUse( + &type_or_value, + [&unused_members](opt::Instruction* user, uint32_t /*operand_index*/) { + switch (user->opcode()) { + case SpvOpMemberName: + unused_members.erase(user->GetSingleWordInOperand(1)); + break; + default: + break; + } + }); + + for (uint32_t member : unused_members) { + if (!unused_member_to_structs.count(member)) { + unused_member_to_structs.insert( + {member, std::set<opt::Instruction*>()}); + } + unused_member_to_structs.at(member).insert(&type_or_value); + } + } + + // We now go through every instruction that might index into a struct, and + // refine our tracking of which struct members are used based on the struct + // indexing we observe. We cannot just go through all uses of a struct type + // because the type is not necessarily even referenced, e.g. when walking + // arrays of structs. + for (auto& function : *context->module()) { + for (auto& block : function) { + for (auto& inst : block) { + switch (inst.opcode()) { + // For each indexing operation we observe, we invoke a helper to + // remove from our map those struct indices that are found to be used. + // The way the helper is invoked depends on whether the instruction + // uses literal or id indices, and the offset into the instruction's + // input operands from which index operands are provided. + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: { + auto composite_type_id = + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(inst.GetSingleWordInOperand(0)) + ->type_id()) + ->GetSingleWordInOperand(1); + MarkAccessedMembersAsUsed(context, composite_type_id, 1, false, + inst, &unused_member_to_structs); + } break; + case SpvOpPtrAccessChain: + case SpvOpInBoundsPtrAccessChain: { + auto composite_type_id = + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(inst.GetSingleWordInOperand(1)) + ->type_id()) + ->GetSingleWordInOperand(1); + MarkAccessedMembersAsUsed(context, composite_type_id, 2, false, + inst, &unused_member_to_structs); + } break; + case SpvOpCompositeExtract: { + auto composite_type_id = + context->get_def_use_mgr() + ->GetDef(inst.GetSingleWordInOperand(0)) + ->type_id(); + MarkAccessedMembersAsUsed(context, composite_type_id, 1, true, inst, + &unused_member_to_structs); + } break; + case SpvOpCompositeInsert: { + auto composite_type_id = + context->get_def_use_mgr() + ->GetDef(inst.GetSingleWordInOperand(1)) + ->type_id(); + MarkAccessedMembersAsUsed(context, composite_type_id, 2, true, inst, + &unused_member_to_structs); + } break; + default: + break; + } + } + } + } + + // We now know those struct indices that are unsed, and we make a reduction + // opportunity for each of them. By mapping each relevant member index to the + // structs in which it is unsed, we will group all opportunities to remove + // member k of a struct (for some k) together. This reduces the likelihood + // that opportunities to remove members from the same struct will be adjacent, + // which is good because such opportunities mutually disable one another. + for (auto& entry : unused_member_to_structs) { + for (auto struct_type : entry.second) { + result.push_back(MakeUnique<RemoveStructMemberReductionOpportunity>( + struct_type, entry.first)); + } + } + return result; +} + +void RemoveUnusedStructMemberReductionOpportunityFinder:: + MarkAccessedMembersAsUsed( + opt::IRContext* context, uint32_t composite_type_id, + uint32_t first_index_in_operand, bool literal_indices, + const opt::Instruction& composite_access_instruction, + std::map<uint32_t, std::set<opt::Instruction*>>* + unused_member_to_structs) const { + uint32_t next_type = composite_type_id; + for (uint32_t i = first_index_in_operand; + i < composite_access_instruction.NumInOperands(); i++) { + auto type_inst = context->get_def_use_mgr()->GetDef(next_type); + switch (type_inst->opcode()) { + case SpvOpTypeArray: + case SpvOpTypeMatrix: + case SpvOpTypeRuntimeArray: + case SpvOpTypeVector: + next_type = type_inst->GetSingleWordInOperand(0); + break; + case SpvOpTypeStruct: { + uint32_t index_operand = + composite_access_instruction.GetSingleWordInOperand(i); + uint32_t member = literal_indices ? index_operand + : context->get_def_use_mgr() + ->GetDef(index_operand) + ->GetSingleWordInOperand(0); + // Remove the struct type from the struct types associated with this + // member index, but only if a set of struct types is known to be + // associated with this member index. + if (unused_member_to_structs->count(member)) { + unused_member_to_structs->at(member).erase(type_inst); + } + next_type = type_inst->GetSingleWordInOperand(member); + } break; + default: + assert(0 && "Unknown composite type."); + break; + } + } +} + +std::string RemoveUnusedStructMemberReductionOpportunityFinder::GetName() + const { + return "RemoveUnusedStructMemberReductionOpportunityFinder"; +} + +} // namespace reduce +} // namespace spvtools
diff --git a/source/reduce/remove_unused_struct_member_reduction_opportunity_finder.h b/source/reduce/remove_unused_struct_member_reduction_opportunity_finder.h new file mode 100644 index 0000000..13f4017 --- /dev/null +++ b/source/reduce/remove_unused_struct_member_reduction_opportunity_finder.h
@@ -0,0 +1,61 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_REDUCE_REMOVE_UNUSED_STRUCT_MEMBER_REDUCTION_OPPORTUNITY_FINDER_H_ +#define SOURCE_REDUCE_REMOVE_UNUSED_STRUCT_MEMBER_REDUCTION_OPPORTUNITY_FINDER_H_ + +#include "source/reduce/reduction_opportunity_finder.h" + +namespace spvtools { +namespace reduce { + +// A finder for opportunities to remove struct members that are not explicitly +// used by extract, insert or access chain instructions. +class RemoveUnusedStructMemberReductionOpportunityFinder + : public ReductionOpportunityFinder { + public: + RemoveUnusedStructMemberReductionOpportunityFinder() = default; + + ~RemoveUnusedStructMemberReductionOpportunityFinder() override = default; + + std::string GetName() const final; + + std::vector<std::unique_ptr<ReductionOpportunity>> GetAvailableOpportunities( + opt::IRContext* context) const final; + + private: + // A helper method to update |unused_members_to_structs| by removing from it + // all struct member accesses that take place in + // |composite_access_instruction|. + // + // |composite_type_id| is the type of the root object indexed into by the + // instruction. + // + // |first_index_in_operand| provides indicates where in the input operands the + // sequence of indices begins. + // + // |literal_indices| indicates whether indices are literals (true) or ids + // (false). + void MarkAccessedMembersAsUsed( + opt::IRContext* context, uint32_t composite_type_id, + uint32_t first_index_in_operand, bool literal_indices, + const opt::Instruction& composite_access_instruction, + std::map<uint32_t, std::set<opt::Instruction*>>* unused_member_to_structs) + const; +}; + +} // namespace reduce +} // namespace spvtools + +#endif // SOURCE_REDUCE_REMOVE_UNUSED_STRUCT_MEMBER_REDUCTION_OPPORTUNITY_FINDER_H_
diff --git a/source/spirv_reducer_options.cpp b/source/spirv_reducer_options.cpp index 5801d0a..e807875 100644 --- a/source/spirv_reducer_options.cpp +++ b/source/spirv_reducer_options.cpp
@@ -19,7 +19,7 @@ namespace { // The default maximum number of steps the reducer will take before giving up. -const uint32_t kDefaultStepLimit = 250; +const uint32_t kDefaultStepLimit = 2500; } // namespace spv_reducer_options_t::spv_reducer_options_t()
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp index 1c279f6..1e33e51 100644 --- a/source/val/validate_cfg.cpp +++ b/source/val/validate_cfg.cpp
@@ -1090,8 +1090,9 @@ return _.diag(SPV_ERROR_INVALID_CFG, inst) << "OpReturn can only be called from a function with void " << "return type."; + _.current_function().RegisterBlockEnd(std::vector<uint32_t>(), opcode); + break; } - // Fallthrough. case SpvOpKill: case SpvOpReturnValue: case SpvOpUnreachable:
diff --git a/source/val/validate_decorations.cpp b/source/val/validate_decorations.cpp index 3b44833..ce09e18 100644 --- a/source/val/validate_decorations.cpp +++ b/source/val/validate_decorations.cpp
@@ -1524,6 +1524,22 @@ return SPV_SUCCESS; } +// Returns SPV_SUCCESS if validation rules are satisfied for the Block +// decoration. Otherwise emits a diagnostic and returns something other than +// SPV_SUCCESS. +spv_result_t CheckBlockDecoration(ValidationState_t& vstate, + const Instruction& inst, + const Decoration& decoration) { + assert(inst.id() && "Parser ensures the target of the decoration has an ID"); + if (inst.opcode() != SpvOpTypeStruct) { + const char* const dec_name = + decoration.dec_type() == SpvDecorationBlock ? "Block" : "BufferBlock"; + return vstate.diag(SPV_ERROR_INVALID_ID, &inst) + << dec_name << " decoration on a non-struct type."; + } + return SPV_SUCCESS; +} + #define PASS_OR_BAIL_AT_LINE(X, LINE) \ { \ spv_result_t e##LINE = (X); \ @@ -1570,6 +1586,10 @@ case SpvDecorationNoUnsignedWrap: PASS_OR_BAIL(CheckIntegerWrapDecoration(vstate, *inst, decoration)); break; + case SpvDecorationBlock: + case SpvDecorationBufferBlock: + PASS_OR_BAIL(CheckBlockDecoration(vstate, *inst, decoration)); + break; default: break; }
diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp index 1e311c1..7ce681c 100644 --- a/source/val/validate_extensions.cpp +++ b/source/val/validate_extensions.cpp
@@ -2300,7 +2300,14 @@ ValidateOperandLexicalScope(_, "Parent", inst, 10, ext_inst_name); if (validate_parent != SPV_SUCCESS) return validate_parent; CHECK_OPERAND("Linkage Name", SpvOpString, 11); - CHECK_OPERAND("Size", SpvOpConstant, 12); + if (!DoesDebugInfoOperandMatchExpectation( + _, + [](OpenCLDebugInfo100Instructions dbg_inst) { + return dbg_inst == OpenCLDebugInfo100DebugInfoNone; + }, + inst, 12)) { + CHECK_OPERAND("Size", SpvOpConstant, 12); + } for (uint32_t word_index = 14; word_index < num_words; ++word_index) { if (!DoesDebugInfoOperandMatchExpectation( _,
diff --git a/source/val/validate_function.cpp b/source/val/validate_function.cpp index f130eac..596186b 100644 --- a/source/val/validate_function.cpp +++ b/source/val/validate_function.cpp
@@ -71,6 +71,7 @@ } const std::vector<SpvOp> acceptable = { + SpvOpGroupDecorate, SpvOpDecorate, SpvOpEnqueueKernel, SpvOpEntryPoint,
diff --git a/source/val/validate_image.cpp b/source/val/validate_image.cpp index 5b77058..9ce74a3 100644 --- a/source/val/validate_image.cpp +++ b/source/val/validate_image.cpp
@@ -160,6 +160,17 @@ } } +bool IsValidGatherLodBiasAMD(const ValidationState_t& _, SpvOp opcode) { + switch (opcode) { + case SpvOpImageGather: + case SpvOpImageSparseGather: + return _.HasCapability(SpvCapabilityImageGatherBiasLodAMD); + default: + break; + } + return false; +} + // Returns true if the opcode is a Image instruction which applies // homogenous projection to the coordinates. bool IsProj(SpvOp opcode) { @@ -260,11 +271,12 @@ const bool is_implicit_lod = IsImplicitLod(opcode); const bool is_explicit_lod = IsExplicitLod(opcode); const bool is_valid_lod_operand = IsValidLodOperand(_, opcode); + const bool is_valid_gather_lod_bias_amd = IsValidGatherLodBiasAMD(_, opcode); // The checks should be done in the order of definition of OperandImage. if (mask & SpvImageOperandsBiasMask) { - if (!is_implicit_lod) { + if (!is_implicit_lod && !is_valid_gather_lod_bias_amd) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Image Operand Bias can only be used with ImplicitLod opcodes"; } @@ -290,7 +302,7 @@ if (mask & SpvImageOperandsLodMask) { if (!is_valid_lod_operand && opcode != SpvOpImageFetch && - opcode != SpvOpImageSparseFetch) { + opcode != SpvOpImageSparseFetch && !is_valid_gather_lod_bias_amd) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Image Operand Lod can only be used with ExplicitLod opcodes " << "and OpImageFetch"; @@ -303,7 +315,7 @@ } const uint32_t type_id = _.GetTypeId(inst->word(word_index++)); - if (is_explicit_lod) { + if (is_explicit_lod || is_valid_gather_lod_bias_amd) { if (!_.IsFloatScalarType(type_id)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Image Operand Lod to be float scalar when used "
diff --git a/source/val/validate_scopes.cpp b/source/val/validate_scopes.cpp index ea3ebcb..a6fb26d 100644 --- a/source/val/validate_scopes.cpp +++ b/source/val/validate_scopes.cpp
@@ -230,11 +230,33 @@ if ((_.context()->target_env == SPV_ENV_VULKAN_1_1 || _.context()->target_env == SPV_ENV_VULKAN_1_2) && value != SpvScopeDevice && value != SpvScopeWorkgroup && - value != SpvScopeSubgroup && value != SpvScopeInvocation) { + value != SpvScopeSubgroup && value != SpvScopeInvocation && + value != SpvScopeShaderCallKHR) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": in Vulkan 1.1 and 1.2 environment Memory Scope is limited " - << "to Device, Workgroup and Invocation"; + << "to Device, Workgroup, Invocation, and ShaderCall"; + } + + if (value == SpvScopeShaderCallKHR) { + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + [](SpvExecutionModel model, std::string* message) { + if (model != SpvExecutionModelRayGenerationKHR && + model != SpvExecutionModelIntersectionKHR && + model != SpvExecutionModelAnyHitKHR && + model != SpvExecutionModelClosestHitKHR && + model != SpvExecutionModelMissKHR && + model != SpvExecutionModelCallableKHR) { + if (message) { + *message = + "ShaderCallKHR Memory Scope requires a ray tracing " + "execution model"; + } + return false; + } + return true; + }); } }
diff --git a/test/fuzz/CMakeLists.txt b/test/fuzz/CMakeLists.txt index 99a78fd..dca142a 100644 --- a/test/fuzz/CMakeLists.txt +++ b/test/fuzz/CMakeLists.txt
@@ -21,12 +21,13 @@ equivalence_relation_test.cpp fact_manager_test.cpp fuzz_test_util.cpp - fuzzer_pass_add_useful_constructs_test.cpp + fuzzer_pass_construct_composites_test.cpp fuzzer_pass_donate_modules_test.cpp instruction_descriptor_test.cpp transformation_access_chain_test.cpp transformation_add_constant_boolean_test.cpp transformation_add_constant_composite_test.cpp + transformation_add_constant_null_test.cpp transformation_add_constant_scalar_test.cpp transformation_add_dead_block_test.cpp transformation_add_dead_break_test.cpp @@ -45,8 +46,10 @@ transformation_add_type_pointer_test.cpp transformation_add_type_struct_test.cpp transformation_add_type_vector_test.cpp + transformation_adjust_branch_weights_test.cpp transformation_composite_construct_test.cpp transformation_composite_extract_test.cpp + transformation_compute_data_synonym_fact_closure_test.cpp transformation_copy_object_test.cpp transformation_equation_instruction_test.cpp transformation_function_call_test.cpp
diff --git a/test/fuzz/data_synonym_transformation_test.cpp b/test/fuzz/data_synonym_transformation_test.cpp index 21ea068..66ce769 100644 --- a/test/fuzz/data_synonym_transformation_test.cpp +++ b/test/fuzz/data_synonym_transformation_test.cpp
@@ -123,13 +123,24 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; - fact_manager.AddFact(MakeSynonymFact(12, {}, 100, {0}), context.get()); - fact_manager.AddFact(MakeSynonymFact(13, {}, 100, {1}), context.get()); - fact_manager.AddFact(MakeSynonymFact(22, {}, 100, {2}), context.get()); - fact_manager.AddFact(MakeSynonymFact(28, {}, 101, {0}), context.get()); - fact_manager.AddFact(MakeSynonymFact(23, {}, 101, {1}), context.get()); - fact_manager.AddFact(MakeSynonymFact(32, {}, 101, {2}), context.get()); - fact_manager.AddFact(MakeSynonymFact(23, {}, 101, {3}), context.get()); + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(12, {}, 100, {0}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(13, {}, 100, {1}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(22, {}, 100, {2}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(28, {}, 101, {0}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(23, {}, 101, {1}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(32, {}, 101, {2}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(23, {}, 101, {3}), context.get()); // Replace %12 with %100[0] in '%25 = OpAccessChain %24 %20 %12' auto instruction_descriptor_1 = @@ -139,13 +150,16 @@ // Bad: id already in use auto bad_extract_1 = TransformationCompositeExtract( MakeInstructionDescriptor(25, SpvOpAccessChain, 0), 25, 100, {0}); - ASSERT_TRUE(good_extract_1.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(bad_extract_1.IsApplicable(context.get(), fact_manager)); - good_extract_1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + good_extract_1.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + bad_extract_1.IsApplicable(context.get(), transformation_context)); + good_extract_1.Apply(context.get(), &transformation_context); auto replacement_1 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(12, instruction_descriptor_1, 1), 102); - ASSERT_TRUE(replacement_1.IsApplicable(context.get(), fact_manager)); - replacement_1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_1.IsApplicable(context.get(), transformation_context)); + replacement_1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %13 with %100[1] in 'OpStore %15 %13' @@ -153,12 +167,14 @@ auto good_extract_2 = TransformationCompositeExtract(instruction_descriptor_2, 103, 100, {1}); // No bad example provided here. - ASSERT_TRUE(good_extract_2.IsApplicable(context.get(), fact_manager)); - good_extract_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + good_extract_2.IsApplicable(context.get(), transformation_context)); + good_extract_2.Apply(context.get(), &transformation_context); auto replacement_2 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(13, instruction_descriptor_2, 1), 103); - ASSERT_TRUE(replacement_2.IsApplicable(context.get(), fact_manager)); - replacement_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_2.IsApplicable(context.get(), transformation_context)); + replacement_2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %22 with %100[2] in '%23 = OpConvertSToF %16 %22' @@ -166,16 +182,19 @@ MakeInstructionDescriptor(23, SpvOpConvertSToF, 0); auto good_extract_3 = TransformationCompositeExtract(instruction_descriptor_3, 104, 100, {2}); - ASSERT_TRUE(good_extract_3.IsApplicable(context.get(), fact_manager)); - good_extract_3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + good_extract_3.IsApplicable(context.get(), transformation_context)); + good_extract_3.Apply(context.get(), &transformation_context); auto replacement_3 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(22, instruction_descriptor_3, 0), 104); // Bad: wrong input operand index auto bad_replacement_3 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(22, instruction_descriptor_3, 1), 104); - ASSERT_TRUE(replacement_3.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(bad_replacement_3.IsApplicable(context.get(), fact_manager)); - replacement_3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_3.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + bad_replacement_3.IsApplicable(context.get(), transformation_context)); + replacement_3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %28 with %101[0] in 'OpStore %33 %28' @@ -185,13 +204,16 @@ // Bad: instruction descriptor does not identify an appropriate instruction auto bad_extract_4 = TransformationCompositeExtract( MakeInstructionDescriptor(33, SpvOpCopyObject, 0), 105, 101, {0}); - ASSERT_TRUE(good_extract_4.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(bad_extract_4.IsApplicable(context.get(), fact_manager)); - good_extract_4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + good_extract_4.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + bad_extract_4.IsApplicable(context.get(), transformation_context)); + good_extract_4.Apply(context.get(), &transformation_context); auto replacement_4 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(28, instruction_descriptor_4, 1), 105); - ASSERT_TRUE(replacement_4.IsApplicable(context.get(), fact_manager)); - replacement_4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_4.IsApplicable(context.get(), transformation_context)); + replacement_4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %23 with %101[1] in '%50 = OpCopyObject %16 %23' @@ -199,16 +221,19 @@ MakeInstructionDescriptor(50, SpvOpCopyObject, 0); auto good_extract_5 = TransformationCompositeExtract(instruction_descriptor_5, 106, 101, {1}); - ASSERT_TRUE(good_extract_5.IsApplicable(context.get(), fact_manager)); - good_extract_5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + good_extract_5.IsApplicable(context.get(), transformation_context)); + good_extract_5.Apply(context.get(), &transformation_context); auto replacement_5 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(23, instruction_descriptor_5, 0), 106); // Bad: wrong synonym fact being used auto bad_replacement_5 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(23, instruction_descriptor_5, 0), 105); - ASSERT_TRUE(replacement_5.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(bad_replacement_5.IsApplicable(context.get(), fact_manager)); - replacement_5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_5.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + bad_replacement_5.IsApplicable(context.get(), transformation_context)); + replacement_5.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %32 with %101[2] in 'OpStore %33 %32' @@ -218,13 +243,16 @@ // Bad: id 1001 does not exist auto bad_extract_6 = TransformationCompositeExtract(instruction_descriptor_6, 107, 1001, {2}); - ASSERT_TRUE(good_extract_6.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(bad_extract_6.IsApplicable(context.get(), fact_manager)); - good_extract_6.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + good_extract_6.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + bad_extract_6.IsApplicable(context.get(), transformation_context)); + good_extract_6.Apply(context.get(), &transformation_context); auto replacement_6 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(32, instruction_descriptor_6, 1), 107); - ASSERT_TRUE(replacement_6.IsApplicable(context.get(), fact_manager)); - replacement_6.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_6.IsApplicable(context.get(), transformation_context)); + replacement_6.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %23 with %101[3] in '%51 = OpCopyObject %16 %23' @@ -232,16 +260,19 @@ MakeInstructionDescriptor(51, SpvOpCopyObject, 0); auto good_extract_7 = TransformationCompositeExtract(instruction_descriptor_7, 108, 101, {3}); - ASSERT_TRUE(good_extract_7.IsApplicable(context.get(), fact_manager)); - good_extract_7.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + good_extract_7.IsApplicable(context.get(), transformation_context)); + good_extract_7.Apply(context.get(), &transformation_context); auto replacement_7 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(23, instruction_descriptor_7, 0), 108); // Bad: use id 0 is invalid auto bad_replacement_7 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(0, instruction_descriptor_7, 0), 108); - ASSERT_TRUE(replacement_7.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(bad_replacement_7.IsApplicable(context.get(), fact_manager)); - replacement_7.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_7.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + bad_replacement_7.IsApplicable(context.get(), transformation_context)); + replacement_7.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); const std::string after_transformation = R"( @@ -380,32 +411,41 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; - fact_manager.AddFact(MakeSynonymFact(23, {}, 100, {0}), context.get()); - fact_manager.AddFact(MakeSynonymFact(25, {}, 100, {1}), context.get()); - fact_manager.AddFact(MakeSynonymFact(50, {}, 100, {2}), context.get()); + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(23, {}, 100, {0}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(25, {}, 100, {1}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(50, {}, 100, {2}), context.get()); // Replace %23 with %100[0] in '%26 = OpFAdd %7 %23 %25' auto instruction_descriptor_1 = MakeInstructionDescriptor(26, SpvOpFAdd, 0); auto extract_1 = TransformationCompositeExtract(instruction_descriptor_1, 101, 100, {0}); - ASSERT_TRUE(extract_1.IsApplicable(context.get(), fact_manager)); - extract_1.Apply(context.get(), &fact_manager); + ASSERT_TRUE(extract_1.IsApplicable(context.get(), transformation_context)); + extract_1.Apply(context.get(), &transformation_context); auto replacement_1 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(23, instruction_descriptor_1, 0), 101); - ASSERT_TRUE(replacement_1.IsApplicable(context.get(), fact_manager)); - replacement_1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_1.IsApplicable(context.get(), transformation_context)); + replacement_1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %25 with %100[1] in '%26 = OpFAdd %7 %23 %25' auto instruction_descriptor_2 = MakeInstructionDescriptor(26, SpvOpFAdd, 0); auto extract_2 = TransformationCompositeExtract(instruction_descriptor_2, 102, 100, {1}); - ASSERT_TRUE(extract_2.IsApplicable(context.get(), fact_manager)); - extract_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE(extract_2.IsApplicable(context.get(), transformation_context)); + extract_2.Apply(context.get(), &transformation_context); auto replacement_2 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(25, instruction_descriptor_2, 1), 102); - ASSERT_TRUE(replacement_2.IsApplicable(context.get(), fact_manager)); - replacement_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_2.IsApplicable(context.get(), transformation_context)); + replacement_2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); const std::string after_transformation = R"( @@ -541,26 +581,37 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - fact_manager.AddFact(MakeSynonymFact(16, {}, 100, {0}), context.get()); - fact_manager.AddFact(MakeSynonymFact(45, {}, 100, {1}), context.get()); - fact_manager.AddFact(MakeSynonymFact(27, {}, 101, {0}), context.get()); - fact_manager.AddFact(MakeSynonymFact(36, {}, 101, {1}), context.get()); - fact_manager.AddFact(MakeSynonymFact(27, {}, 101, {2}), context.get()); - fact_manager.AddFact(MakeSynonymFact(22, {}, 102, {0}), context.get()); - fact_manager.AddFact(MakeSynonymFact(15, {}, 102, {1}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(16, {}, 100, {0}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(45, {}, 100, {1}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(27, {}, 101, {0}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(36, {}, 101, {1}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(27, {}, 101, {2}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(22, {}, 102, {0}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(15, {}, 102, {1}), context.get()); // Replace %45 with %100[1] in '%46 = OpCompositeConstruct %32 %35 %45' auto instruction_descriptor_1 = MakeInstructionDescriptor(46, SpvOpCompositeConstruct, 0); auto extract_1 = TransformationCompositeExtract(instruction_descriptor_1, 201, 100, {1}); - ASSERT_TRUE(extract_1.IsApplicable(context.get(), fact_manager)); - extract_1.Apply(context.get(), &fact_manager); + ASSERT_TRUE(extract_1.IsApplicable(context.get(), transformation_context)); + extract_1.Apply(context.get(), &transformation_context); auto replacement_1 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(45, instruction_descriptor_1, 1), 201); - ASSERT_TRUE(replacement_1.IsApplicable(context.get(), fact_manager)); - replacement_1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_1.IsApplicable(context.get(), transformation_context)); + replacement_1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace second occurrence of %27 with %101[0] in '%28 = @@ -569,12 +620,13 @@ MakeInstructionDescriptor(28, SpvOpCompositeConstruct, 0); auto extract_2 = TransformationCompositeExtract(instruction_descriptor_2, 202, 101, {0}); - ASSERT_TRUE(extract_2.IsApplicable(context.get(), fact_manager)); - extract_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE(extract_2.IsApplicable(context.get(), transformation_context)); + extract_2.Apply(context.get(), &transformation_context); auto replacement_2 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(27, instruction_descriptor_2, 1), 202); - ASSERT_TRUE(replacement_2.IsApplicable(context.get(), fact_manager)); - replacement_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_2.IsApplicable(context.get(), transformation_context)); + replacement_2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %36 with %101[1] in '%45 = OpCompositeConstruct %31 %36 %41 %44' @@ -582,12 +634,13 @@ MakeInstructionDescriptor(45, SpvOpCompositeConstruct, 0); auto extract_3 = TransformationCompositeExtract(instruction_descriptor_3, 203, 101, {1}); - ASSERT_TRUE(extract_3.IsApplicable(context.get(), fact_manager)); - extract_3.Apply(context.get(), &fact_manager); + ASSERT_TRUE(extract_3.IsApplicable(context.get(), transformation_context)); + extract_3.Apply(context.get(), &transformation_context); auto replacement_3 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(36, instruction_descriptor_3, 0), 203); - ASSERT_TRUE(replacement_3.IsApplicable(context.get(), fact_manager)); - replacement_3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_3.IsApplicable(context.get(), transformation_context)); + replacement_3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace first occurrence of %27 with %101[2] in '%28 = OpCompositeConstruct @@ -596,24 +649,26 @@ MakeInstructionDescriptor(28, SpvOpCompositeConstruct, 0); auto extract_4 = TransformationCompositeExtract(instruction_descriptor_4, 204, 101, {2}); - ASSERT_TRUE(extract_4.IsApplicable(context.get(), fact_manager)); - extract_4.Apply(context.get(), &fact_manager); + ASSERT_TRUE(extract_4.IsApplicable(context.get(), transformation_context)); + extract_4.Apply(context.get(), &transformation_context); auto replacement_4 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(27, instruction_descriptor_4, 0), 204); - ASSERT_TRUE(replacement_4.IsApplicable(context.get(), fact_manager)); - replacement_4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_4.IsApplicable(context.get(), transformation_context)); + replacement_4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %22 with %102[0] in 'OpStore %23 %22' auto instruction_descriptor_5 = MakeInstructionDescriptor(23, SpvOpStore, 0); auto extract_5 = TransformationCompositeExtract(instruction_descriptor_5, 205, 102, {0}); - ASSERT_TRUE(extract_5.IsApplicable(context.get(), fact_manager)); - extract_5.Apply(context.get(), &fact_manager); + ASSERT_TRUE(extract_5.IsApplicable(context.get(), transformation_context)); + extract_5.Apply(context.get(), &transformation_context); auto replacement_5 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(22, instruction_descriptor_5, 1), 205); - ASSERT_TRUE(replacement_5.IsApplicable(context.get(), fact_manager)); - replacement_5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_5.IsApplicable(context.get(), transformation_context)); + replacement_5.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); const std::string after_transformation = R"( @@ -816,38 +871,65 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; - fact_manager.AddFact(MakeSynonymFact(20, {0}, 100, {0}), context.get()); - fact_manager.AddFact(MakeSynonymFact(20, {1}, 100, {1}), context.get()); - fact_manager.AddFact(MakeSynonymFact(20, {2}, 100, {2}), context.get()); - fact_manager.AddFact(MakeSynonymFact(54, {}, 100, {3}), context.get()); - fact_manager.AddFact(MakeSynonymFact(15, {0}, 101, {0}), context.get()); - fact_manager.AddFact(MakeSynonymFact(15, {1}, 101, {1}), context.get()); - fact_manager.AddFact(MakeSynonymFact(19, {0}, 101, {2}), context.get()); - fact_manager.AddFact(MakeSynonymFact(19, {1}, 101, {3}), context.get()); - fact_manager.AddFact(MakeSynonymFact(27, {}, 102, {0}), context.get()); - fact_manager.AddFact(MakeSynonymFact(15, {0}, 102, {1}), context.get()); - fact_manager.AddFact(MakeSynonymFact(15, {1}, 102, {2}), context.get()); - fact_manager.AddFact(MakeSynonymFact(33, {}, 103, {0}), context.get()); - fact_manager.AddFact(MakeSynonymFact(47, {0}, 103, {1}), context.get()); - fact_manager.AddFact(MakeSynonymFact(47, {1}, 103, {2}), context.get()); - fact_manager.AddFact(MakeSynonymFact(47, {2}, 103, {3}), context.get()); - fact_manager.AddFact(MakeSynonymFact(42, {}, 104, {0}), context.get()); - fact_manager.AddFact(MakeSynonymFact(45, {}, 104, {1}), context.get()); - fact_manager.AddFact(MakeSynonymFact(38, {0}, 105, {0}), context.get()); - fact_manager.AddFact(MakeSynonymFact(38, {1}, 105, {1}), context.get()); - fact_manager.AddFact(MakeSynonymFact(46, {}, 105, {2}), context.get()); + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(20, {0}, 100, {0}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(20, {1}, 100, {1}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(20, {2}, 100, {2}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(54, {}, 100, {3}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(15, {0}, 101, {0}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(15, {1}, 101, {1}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(19, {0}, 101, {2}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(19, {1}, 101, {3}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(27, {}, 102, {0}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(15, {0}, 102, {1}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(15, {1}, 102, {2}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(33, {}, 103, {0}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(47, {0}, 103, {1}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(47, {1}, 103, {2}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(47, {2}, 103, {3}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(42, {}, 104, {0}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(45, {}, 104, {1}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(38, {0}, 105, {0}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(38, {1}, 105, {1}), context.get()); + transformation_context.GetFactManager()->AddFact( + MakeSynonymFact(46, {}, 105, {2}), context.get()); // Replace %20 with %100[0:2] in '%80 = OpCopyObject %16 %20' auto instruction_descriptor_1 = MakeInstructionDescriptor(80, SpvOpCopyObject, 0); auto shuffle_1 = TransformationVectorShuffle(instruction_descriptor_1, 200, 100, 100, {0, 1, 2}); - ASSERT_TRUE(shuffle_1.IsApplicable(context.get(), fact_manager)); - shuffle_1.Apply(context.get(), &fact_manager); + ASSERT_TRUE(shuffle_1.IsApplicable(context.get(), transformation_context)); + shuffle_1.Apply(context.get(), &transformation_context); + fact_manager.ComputeClosureOfFacts(context.get(), 100); + auto replacement_1 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(20, instruction_descriptor_1, 0), 200); - ASSERT_TRUE(replacement_1.IsApplicable(context.get(), fact_manager)); - replacement_1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_1.IsApplicable(context.get(), transformation_context)); + replacement_1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %54 with %100[3] in '%56 = OpFOrdNotEqual %30 %54 %55' @@ -856,24 +938,28 @@ auto extract_2 = TransformationCompositeExtract(instruction_descriptor_2, 201, 100, {3}); - ASSERT_TRUE(extract_2.IsApplicable(context.get(), fact_manager)); - extract_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE(extract_2.IsApplicable(context.get(), transformation_context)); + extract_2.Apply(context.get(), &transformation_context); auto replacement_2 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(54, instruction_descriptor_2, 0), 201); - ASSERT_TRUE(replacement_2.IsApplicable(context.get(), fact_manager)); - replacement_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_2.IsApplicable(context.get(), transformation_context)); + replacement_2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %15 with %101[0:1] in 'OpStore %12 %15' auto instruction_descriptor_3 = MakeInstructionDescriptor(64, SpvOpStore, 0); auto shuffle_3 = TransformationVectorShuffle(instruction_descriptor_3, 202, 101, 101, {0, 1}); - ASSERT_TRUE(shuffle_3.IsApplicable(context.get(), fact_manager)); - shuffle_3.Apply(context.get(), &fact_manager); + ASSERT_TRUE(shuffle_3.IsApplicable(context.get(), transformation_context)); + shuffle_3.Apply(context.get(), &transformation_context); + fact_manager.ComputeClosureOfFacts(context.get(), 100); + auto replacement_3 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(15, instruction_descriptor_3, 1), 202); - ASSERT_TRUE(replacement_3.IsApplicable(context.get(), fact_manager)); - replacement_3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_3.IsApplicable(context.get(), transformation_context)); + replacement_3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %19 with %101[2:3] in '%81 = OpVectorShuffle %16 %19 %19 0 0 1' @@ -881,12 +967,15 @@ MakeInstructionDescriptor(81, SpvOpVectorShuffle, 0); auto shuffle_4 = TransformationVectorShuffle(instruction_descriptor_4, 203, 101, 101, {2, 3}); - ASSERT_TRUE(shuffle_4.IsApplicable(context.get(), fact_manager)); - shuffle_4.Apply(context.get(), &fact_manager); + ASSERT_TRUE(shuffle_4.IsApplicable(context.get(), transformation_context)); + shuffle_4.Apply(context.get(), &transformation_context); + fact_manager.ComputeClosureOfFacts(context.get(), 100); + auto replacement_4 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(19, instruction_descriptor_4, 0), 203); - ASSERT_TRUE(replacement_4.IsApplicable(context.get(), fact_manager)); - replacement_4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_4.IsApplicable(context.get(), transformation_context)); + replacement_4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %27 with %102[0] in '%82 = OpCompositeConstruct %21 %26 %27 %28 @@ -896,12 +985,13 @@ auto extract_5 = TransformationCompositeExtract(instruction_descriptor_5, 204, 102, {0}); - ASSERT_TRUE(extract_5.IsApplicable(context.get(), fact_manager)); - extract_5.Apply(context.get(), &fact_manager); + ASSERT_TRUE(extract_5.IsApplicable(context.get(), transformation_context)); + extract_5.Apply(context.get(), &transformation_context); auto replacement_5 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(27, instruction_descriptor_5, 1), 204); - ASSERT_TRUE(replacement_5.IsApplicable(context.get(), fact_manager)); - replacement_5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_5.IsApplicable(context.get(), transformation_context)); + replacement_5.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %15 with %102[1:2] in '%83 = OpCopyObject %10 %15' @@ -909,12 +999,15 @@ MakeInstructionDescriptor(83, SpvOpCopyObject, 0); auto shuffle_6 = TransformationVectorShuffle(instruction_descriptor_6, 205, 102, 102, {1, 2}); - ASSERT_TRUE(shuffle_6.IsApplicable(context.get(), fact_manager)); - shuffle_6.Apply(context.get(), &fact_manager); + ASSERT_TRUE(shuffle_6.IsApplicable(context.get(), transformation_context)); + shuffle_6.Apply(context.get(), &transformation_context); + fact_manager.ComputeClosureOfFacts(context.get(), 100); + auto replacement_6 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(15, instruction_descriptor_6, 0), 205); - ASSERT_TRUE(replacement_6.IsApplicable(context.get(), fact_manager)); - replacement_6.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_6.IsApplicable(context.get(), transformation_context)); + replacement_6.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %33 with %103[0] in '%86 = OpCopyObject %30 %33' @@ -922,12 +1015,13 @@ MakeInstructionDescriptor(86, SpvOpCopyObject, 0); auto extract_7 = TransformationCompositeExtract(instruction_descriptor_7, 206, 103, {0}); - ASSERT_TRUE(extract_7.IsApplicable(context.get(), fact_manager)); - extract_7.Apply(context.get(), &fact_manager); + ASSERT_TRUE(extract_7.IsApplicable(context.get(), transformation_context)); + extract_7.Apply(context.get(), &transformation_context); auto replacement_7 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(33, instruction_descriptor_7, 0), 206); - ASSERT_TRUE(replacement_7.IsApplicable(context.get(), fact_manager)); - replacement_7.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_7.IsApplicable(context.get(), transformation_context)); + replacement_7.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %47 with %103[1:3] in '%84 = OpCopyObject %39 %47' @@ -935,12 +1029,15 @@ MakeInstructionDescriptor(84, SpvOpCopyObject, 0); auto shuffle_8 = TransformationVectorShuffle(instruction_descriptor_8, 207, 103, 103, {1, 2, 3}); - ASSERT_TRUE(shuffle_8.IsApplicable(context.get(), fact_manager)); - shuffle_8.Apply(context.get(), &fact_manager); + ASSERT_TRUE(shuffle_8.IsApplicable(context.get(), transformation_context)); + shuffle_8.Apply(context.get(), &transformation_context); + fact_manager.ComputeClosureOfFacts(context.get(), 100); + auto replacement_8 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(47, instruction_descriptor_8, 0), 207); - ASSERT_TRUE(replacement_8.IsApplicable(context.get(), fact_manager)); - replacement_8.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_8.IsApplicable(context.get(), transformation_context)); + replacement_8.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %42 with %104[0] in '%85 = OpCopyObject %30 %42' @@ -948,12 +1045,13 @@ MakeInstructionDescriptor(85, SpvOpCopyObject, 0); auto extract_9 = TransformationCompositeExtract(instruction_descriptor_9, 208, 104, {0}); - ASSERT_TRUE(extract_9.IsApplicable(context.get(), fact_manager)); - extract_9.Apply(context.get(), &fact_manager); + ASSERT_TRUE(extract_9.IsApplicable(context.get(), transformation_context)); + extract_9.Apply(context.get(), &transformation_context); auto replacement_9 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(42, instruction_descriptor_9, 0), 208); - ASSERT_TRUE(replacement_9.IsApplicable(context.get(), fact_manager)); - replacement_9.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_9.IsApplicable(context.get(), transformation_context)); + replacement_9.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %45 with %104[1] in '%63 = OpLogicalOr %30 %45 %46' @@ -961,24 +1059,28 @@ MakeInstructionDescriptor(63, SpvOpLogicalOr, 0); auto extract_10 = TransformationCompositeExtract(instruction_descriptor_10, 209, 104, {1}); - ASSERT_TRUE(extract_10.IsApplicable(context.get(), fact_manager)); - extract_10.Apply(context.get(), &fact_manager); + ASSERT_TRUE(extract_10.IsApplicable(context.get(), transformation_context)); + extract_10.Apply(context.get(), &transformation_context); auto replacement_10 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(45, instruction_descriptor_10, 0), 209); - ASSERT_TRUE(replacement_10.IsApplicable(context.get(), fact_manager)); - replacement_10.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_10.IsApplicable(context.get(), transformation_context)); + replacement_10.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %38 with %105[0:1] in 'OpStore %36 %38' auto instruction_descriptor_11 = MakeInstructionDescriptor(85, SpvOpStore, 0); auto shuffle_11 = TransformationVectorShuffle(instruction_descriptor_11, 210, 105, 105, {0, 1}); - ASSERT_TRUE(shuffle_11.IsApplicable(context.get(), fact_manager)); - shuffle_11.Apply(context.get(), &fact_manager); + ASSERT_TRUE(shuffle_11.IsApplicable(context.get(), transformation_context)); + shuffle_11.Apply(context.get(), &transformation_context); + fact_manager.ComputeClosureOfFacts(context.get(), 100); + auto replacement_11 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(38, instruction_descriptor_11, 1), 210); - ASSERT_TRUE(replacement_11.IsApplicable(context.get(), fact_manager)); - replacement_11.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_11.IsApplicable(context.get(), transformation_context)); + replacement_11.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %46 with %105[2] in '%62 = OpLogicalAnd %30 %45 %46' @@ -986,12 +1088,13 @@ MakeInstructionDescriptor(62, SpvOpLogicalAnd, 0); auto extract_12 = TransformationCompositeExtract(instruction_descriptor_12, 211, 105, {2}); - ASSERT_TRUE(extract_12.IsApplicable(context.get(), fact_manager)); - extract_12.Apply(context.get(), &fact_manager); + ASSERT_TRUE(extract_12.IsApplicable(context.get(), transformation_context)); + extract_12.Apply(context.get(), &transformation_context); auto replacement_12 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(46, instruction_descriptor_12, 1), 211); - ASSERT_TRUE(replacement_12.IsApplicable(context.get(), fact_manager)); - replacement_12.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_12.IsApplicable(context.get(), transformation_context)); + replacement_12.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); const std::string after_transformation = R"(
diff --git a/test/fuzz/equivalence_relation_test.cpp b/test/fuzz/equivalence_relation_test.cpp index 3f2ea58..280aa3a 100644 --- a/test/fuzz/equivalence_relation_test.cpp +++ b/test/fuzz/equivalence_relation_test.cpp
@@ -47,6 +47,10 @@ EquivalenceRelation<uint32_t, UInt32Hash, UInt32Equals> relation; ASSERT_TRUE(relation.GetAllKnownValues().empty()); + for (uint32_t element = 0; element < 100; element++) { + relation.Register(element); + } + for (uint32_t element = 2; element < 80; element += 2) { relation.MakeEquivalent(0, element); relation.MakeEquivalent(element - 1, element + 1); @@ -123,6 +127,11 @@ EquivalenceRelation<uint32_t, UInt32Hash, UInt32Equals> relation2; for (uint32_t i = 0; i < 1000; ++i) { + relation1.Register(i); + relation2.Register(i); + } + + for (uint32_t i = 0; i < 1000; ++i) { if (i >= 10) { relation1.MakeEquivalent(i, i - 10); relation2.MakeEquivalent(i, i - 10);
diff --git a/test/fuzz/fact_manager_test.cpp b/test/fuzz/fact_manager_test.cpp index 2c79f12..8b1e0c4 100644 --- a/test/fuzz/fact_manager_test.cpp +++ b/test/fuzz/fact_manager_test.cpp
@@ -738,393 +738,6 @@ uniform_buffer_element_descriptor)); } -TEST(FactManagerTest, DataSynonymFacts) { - // The SPIR-V types and constants come from the following code. The body of - // the SPIR-V function then constructs a composite that is synonymous with - // myT. - // - // #version 310 es - // - // precision highp float; - // - // struct S { - // int a; - // uvec2 b; - // }; - // - // struct T { - // bool c[5]; - // mat4x2 d; - // S e; - // }; - // - // void main() { - // T myT = T(bool[5](true, false, true, false, true), - // mat4x2(vec2(1.0, 2.0), vec2(3.0, 4.0), - // vec2(5.0, 6.0), vec2(7.0, 8.0)), - // S(10, uvec2(100u, 200u))); - // } - - std::string shader = R"( - OpCapability Shader - %1 = OpExtInstImport "GLSL.std.450" - OpMemoryModel Logical GLSL450 - OpEntryPoint Fragment %4 "main" - OpExecutionMode %4 OriginUpperLeft - OpSource ESSL 310 - OpName %4 "main" - OpName %15 "S" - OpMemberName %15 0 "a" - OpMemberName %15 1 "b" - OpName %16 "T" - OpMemberName %16 0 "c" - OpMemberName %16 1 "d" - OpMemberName %16 2 "e" - OpName %18 "myT" - OpMemberDecorate %15 0 RelaxedPrecision - OpMemberDecorate %15 1 RelaxedPrecision - %2 = OpTypeVoid - %3 = OpTypeFunction %2 - %6 = OpTypeBool - %7 = OpTypeInt 32 0 - %8 = OpConstant %7 5 - %9 = OpTypeArray %6 %8 - %10 = OpTypeFloat 32 - %11 = OpTypeVector %10 2 - %12 = OpTypeMatrix %11 4 - %13 = OpTypeInt 32 1 - %14 = OpTypeVector %7 2 - %15 = OpTypeStruct %13 %14 - %16 = OpTypeStruct %9 %12 %15 - %17 = OpTypePointer Function %16 - %19 = OpConstantTrue %6 - %20 = OpConstantFalse %6 - %21 = OpConstantComposite %9 %19 %20 %19 %20 %19 - %22 = OpConstant %10 1 - %23 = OpConstant %10 2 - %24 = OpConstantComposite %11 %22 %23 - %25 = OpConstant %10 3 - %26 = OpConstant %10 4 - %27 = OpConstantComposite %11 %25 %26 - %28 = OpConstant %10 5 - %29 = OpConstant %10 6 - %30 = OpConstantComposite %11 %28 %29 - %31 = OpConstant %10 7 - %32 = OpConstant %10 8 - %33 = OpConstantComposite %11 %31 %32 - %34 = OpConstantComposite %12 %24 %27 %30 %33 - %35 = OpConstant %13 10 - %36 = OpConstant %7 100 - %37 = OpConstant %7 200 - %38 = OpConstantComposite %14 %36 %37 - %39 = OpConstantComposite %15 %35 %38 - %40 = OpConstantComposite %16 %21 %34 %39 - %4 = OpFunction %2 None %3 - %5 = OpLabel - %18 = OpVariable %17 Function - OpStore %18 %40 - %100 = OpCompositeConstruct %9 %19 %20 %19 %20 %19 - %101 = OpCompositeConstruct %11 %22 %23 - %102 = OpCompositeConstruct %11 %25 %26 - %103 = OpCompositeConstruct %11 %28 %29 - %104 = OpCompositeConstruct %11 %31 %32 - %105 = OpCompositeConstruct %12 %101 %102 %103 %104 - %106 = OpCompositeConstruct %14 %36 %37 - %107 = OpCompositeConstruct %15 %35 %106 - %108 = OpCompositeConstruct %16 %100 %105 %107 - OpReturn - OpFunctionEnd - )"; - - const auto env = SPV_ENV_UNIVERSAL_1_3; - const auto consumer = nullptr; - const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); - ASSERT_TRUE(IsValid(env, context.get())); - - FactManager fact_manager; - - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(24, {}), MakeDataDescriptor(101, {}), context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {0}), - MakeDataDescriptor(101, {0}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {1}), - MakeDataDescriptor(101, {1}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {0}), - MakeDataDescriptor(101, {1}), - context.get())); - - fact_manager.AddFactDataSynonym(MakeDataDescriptor(24, {}), - MakeDataDescriptor(101, {}), context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(24, {}), MakeDataDescriptor(101, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {0}), - MakeDataDescriptor(101, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {1}), - MakeDataDescriptor(101, {1}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {0}), - MakeDataDescriptor(101, {1}), - context.get())); - - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(27, {}), MakeDataDescriptor(102, {}), context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {0}), - MakeDataDescriptor(102, {0}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {1}), - MakeDataDescriptor(102, {1}), - context.get())); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(27, {0}), - MakeDataDescriptor(102, {0}), context.get()); - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(27, {}), MakeDataDescriptor(102, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {0}), - MakeDataDescriptor(102, {0}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {1}), - MakeDataDescriptor(102, {1}), - context.get())); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(27, {1}), - MakeDataDescriptor(102, {1}), context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(27, {}), MakeDataDescriptor(102, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {0}), - MakeDataDescriptor(102, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {1}), - MakeDataDescriptor(102, {1}), - context.get())); - - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(30, {}), MakeDataDescriptor(103, {}), context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(30, {0}), - MakeDataDescriptor(103, {0}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(30, {1}), - MakeDataDescriptor(103, {1}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(33, {}), MakeDataDescriptor(104, {}), context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(33, {0}), - MakeDataDescriptor(104, {0}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(33, {1}), - MakeDataDescriptor(104, {1}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(34, {}), MakeDataDescriptor(105, {}), context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {0}), - MakeDataDescriptor(105, {0}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {1}), - MakeDataDescriptor(105, {1}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {2}), - MakeDataDescriptor(105, {2}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {3}), - MakeDataDescriptor(105, {3}), - context.get())); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(30, {}), - MakeDataDescriptor(103, {}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(33, {}), - MakeDataDescriptor(104, {}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(34, {0}), - MakeDataDescriptor(105, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(34, {1}), - MakeDataDescriptor(105, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(34, {2}), - MakeDataDescriptor(105, {2}), context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(30, {}), MakeDataDescriptor(103, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(30, {0}), - MakeDataDescriptor(103, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(30, {1}), - MakeDataDescriptor(103, {1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(33, {}), MakeDataDescriptor(104, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(33, {0}), - MakeDataDescriptor(104, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(33, {1}), - MakeDataDescriptor(104, {1}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(34, {}), MakeDataDescriptor(105, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {0}), - MakeDataDescriptor(105, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {1}), - MakeDataDescriptor(105, {1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {2}), - MakeDataDescriptor(105, {2}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {3}), - MakeDataDescriptor(105, {3}), - context.get())); - - fact_manager.AddFactDataSynonym(MakeDataDescriptor(34, {3}), - MakeDataDescriptor(105, {3}), context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(33, {0}), - MakeDataDescriptor(104, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {3}), - MakeDataDescriptor(105, {3}), - context.get())); - - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(21, {}), MakeDataDescriptor(100, {}), context.get())); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(21, {0}), - MakeDataDescriptor(100, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(21, {1}), - MakeDataDescriptor(100, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(21, {2}), - MakeDataDescriptor(100, {2}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(21, {3}), - MakeDataDescriptor(100, {3}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(21, {4}), - MakeDataDescriptor(100, {4}), context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(21, {}), MakeDataDescriptor(100, {}), context.get())); - - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(39, {0}), - MakeDataDescriptor(107, {0}), - context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(35, {}), MakeDataDescriptor(39, {0}), context.get())); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(39, {0}), - MakeDataDescriptor(35, {}), context.get()); - ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(39, {0}), - MakeDataDescriptor(107, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(35, {}), MakeDataDescriptor(39, {0}), context.get())); - - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(38, {0}), MakeDataDescriptor(36, {}), context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(38, {1}), MakeDataDescriptor(37, {}), context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(106, {0}), MakeDataDescriptor(36, {}), context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(106, {1}), MakeDataDescriptor(37, {}), context.get())); - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(38, {}), MakeDataDescriptor(106, {}), context.get())); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(38, {0}), - MakeDataDescriptor(36, {}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(106, {0}), - MakeDataDescriptor(36, {}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(38, {1}), - MakeDataDescriptor(37, {}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(106, {1}), - MakeDataDescriptor(37, {}), context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(38, {0}), MakeDataDescriptor(36, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(38, {1}), MakeDataDescriptor(37, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(106, {0}), MakeDataDescriptor(36, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(106, {1}), MakeDataDescriptor(37, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(38, {}), MakeDataDescriptor(106, {}), context.get())); - - ASSERT_FALSE(fact_manager.IsSynonymous( - MakeDataDescriptor(40, {}), MakeDataDescriptor(108, {}), context.get())); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(107, {0}), - MakeDataDescriptor(35, {}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(40, {0}), - MakeDataDescriptor(108, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(40, {1}), - MakeDataDescriptor(108, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(40, {2}), - MakeDataDescriptor(108, {2}), context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(40, {}), MakeDataDescriptor(108, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {0}), - MakeDataDescriptor(108, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1}), - MakeDataDescriptor(108, {1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {2}), - MakeDataDescriptor(108, {2}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {0, 0}), - MakeDataDescriptor(108, {0, 0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {0, 1}), - MakeDataDescriptor(108, {0, 1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {0, 2}), - MakeDataDescriptor(108, {0, 2}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {0, 3}), - MakeDataDescriptor(108, {0, 3}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {0, 4}), - MakeDataDescriptor(108, {0, 4}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 0}), - MakeDataDescriptor(108, {1, 0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 1}), - MakeDataDescriptor(108, {1, 1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 2}), - MakeDataDescriptor(108, {1, 2}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 3}), - MakeDataDescriptor(108, {1, 3}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 0, 0}), - MakeDataDescriptor(108, {1, 0, 0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 1, 0}), - MakeDataDescriptor(108, {1, 1, 0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 2, 0}), - MakeDataDescriptor(108, {1, 2, 0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 3, 0}), - MakeDataDescriptor(108, {1, 3, 0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 0, 1}), - MakeDataDescriptor(108, {1, 0, 1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 1, 1}), - MakeDataDescriptor(108, {1, 1, 1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 2, 1}), - MakeDataDescriptor(108, {1, 2, 1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 3, 1}), - MakeDataDescriptor(108, {1, 3, 1}), - context.get())); - - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {2, 0}), - MakeDataDescriptor(108, {2, 0}), - context.get())); - - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {2, 1}), - MakeDataDescriptor(108, {2, 1}), - context.get())); - - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {2, 1, 0}), - MakeDataDescriptor(108, {2, 1, 0}), - context.get())); - - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {2, 1, 1}), - MakeDataDescriptor(108, {2, 1, 1}), - context.get())); -} - TEST(FactManagerTest, RecursiveAdditionOfFacts) { std::string shader = R"( OpCapability Shader @@ -1157,20 +770,16 @@ fact_manager.AddFactDataSynonym(MakeDataDescriptor(10, {}), MakeDataDescriptor(11, {2}), context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(10, {}), MakeDataDescriptor(11, {2}), context.get())); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(10, {}), + MakeDataDescriptor(11, {2}))); ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(10, {0}), - MakeDataDescriptor(11, {2, 0}), - context.get())); + MakeDataDescriptor(11, {2, 0}))); ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(10, {1}), - MakeDataDescriptor(11, {2, 1}), - context.get())); + MakeDataDescriptor(11, {2, 1}))); ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(10, {2}), - MakeDataDescriptor(11, {2, 2}), - context.get())); + MakeDataDescriptor(11, {2, 2}))); ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(10, {3}), - MakeDataDescriptor(11, {2, 3}), - context.get())); + MakeDataDescriptor(11, {2, 3}))); } TEST(FactManagerTest, LogicalNotEquationFacts) { @@ -1209,14 +818,14 @@ fact_manager.AddFactIdEquation(14, SpvOpLogicalNot, {7}, context.get()); fact_manager.AddFactIdEquation(17, SpvOpLogicalNot, {16}, context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(15, {}), MakeDataDescriptor(7, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(17, {}), MakeDataDescriptor(7, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(15, {}), MakeDataDescriptor(17, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(16, {}), MakeDataDescriptor(14, {}), context.get())); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(15, {}), + MakeDataDescriptor(7, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(17, {}), + MakeDataDescriptor(7, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(15, {}), + MakeDataDescriptor(17, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(16, {}), + MakeDataDescriptor(14, {}))); } TEST(FactManagerTest, SignedNegateEquationFacts) { @@ -1249,8 +858,8 @@ fact_manager.AddFactIdEquation(14, SpvOpSNegate, {7}, context.get()); fact_manager.AddFactIdEquation(15, SpvOpSNegate, {14}, context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(7, {}), MakeDataDescriptor(15, {}), context.get())); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(7, {}), + MakeDataDescriptor(15, {}))); } TEST(FactManagerTest, AddSubNegateFacts1) { @@ -1302,12 +911,12 @@ MakeDataDescriptor(22, {}), context.get()); fact_manager.AddFactIdEquation(24, SpvOpSNegate, {23}, context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(19, {}), MakeDataDescriptor(15, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(20, {}), MakeDataDescriptor(16, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(24, {}), MakeDataDescriptor(15, {}), context.get())); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(19, {}), + MakeDataDescriptor(15, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(20, {}), + MakeDataDescriptor(16, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {}), + MakeDataDescriptor(15, {}))); } TEST(FactManagerTest, AddSubNegateFacts2) { @@ -1347,30 +956,158 @@ fact_manager.AddFactIdEquation(14, SpvOpISub, {15, 16}, context.get()); fact_manager.AddFactIdEquation(17, SpvOpIAdd, {14, 16}, context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(17, {}), MakeDataDescriptor(15, {}), context.get())); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(17, {}), + MakeDataDescriptor(15, {}))); fact_manager.AddFactIdEquation(18, SpvOpIAdd, {16, 14}, context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(18, {}), MakeDataDescriptor(15, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(17, {}), MakeDataDescriptor(18, {}), context.get())); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(18, {}), + MakeDataDescriptor(15, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(17, {}), + MakeDataDescriptor(18, {}))); fact_manager.AddFactIdEquation(19, SpvOpISub, {14, 15}, context.get()); fact_manager.AddFactIdEquation(20, SpvOpSNegate, {19}, context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(20, {}), MakeDataDescriptor(16, {}), context.get())); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(20, {}), + MakeDataDescriptor(16, {}))); fact_manager.AddFactIdEquation(21, SpvOpISub, {14, 19}, context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(21, {}), MakeDataDescriptor(15, {}), context.get())); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(21, {}), + MakeDataDescriptor(15, {}))); fact_manager.AddFactIdEquation(22, SpvOpISub, {14, 18}, context.get()); fact_manager.AddFactIdEquation(23, SpvOpSNegate, {22}, context.get()); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(23, {}), MakeDataDescriptor(16, {}), context.get())); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(23, {}), + MakeDataDescriptor(16, {}))); +} + +TEST(FactManagerTest, EquationAndEquivalenceFacts) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %12 "main" + OpExecutionMode %12 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %15 = OpConstant %6 24 + %16 = OpConstant %6 37 + %12 = OpFunction %2 None %3 + %13 = OpLabel + %14 = OpISub %6 %15 %16 + %114 = OpCopyObject %6 %14 + %17 = OpIAdd %6 %114 %16 ; ==> synonymous(%17, %15) + %18 = OpIAdd %6 %16 %114 ; ==> synonymous(%17, %18, %15) + %19 = OpISub %6 %114 %15 + %119 = OpCopyObject %6 %19 + %20 = OpSNegate %6 %119 ; ==> synonymous(%20, %16) + %21 = OpISub %6 %14 %19 ; ==> synonymous(%21, %15) + %22 = OpISub %6 %14 %18 + %220 = OpCopyObject %6 %22 + %23 = OpSNegate %6 %220 ; ==> synonymous(%23, %16) + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + fact_manager.AddFactIdEquation(14, SpvOpISub, {15, 16}, context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(114, {}), + MakeDataDescriptor(14, {}), context.get()); + fact_manager.AddFactIdEquation(17, SpvOpIAdd, {114, 16}, context.get()); + + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(17, {}), + MakeDataDescriptor(15, {}))); + + fact_manager.AddFactIdEquation(18, SpvOpIAdd, {16, 114}, context.get()); + + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(18, {}), + MakeDataDescriptor(15, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(17, {}), + MakeDataDescriptor(18, {}))); + + fact_manager.AddFactIdEquation(19, SpvOpISub, {14, 15}, context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(119, {}), + MakeDataDescriptor(19, {}), context.get()); + fact_manager.AddFactIdEquation(20, SpvOpSNegate, {119}, context.get()); + + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(20, {}), + MakeDataDescriptor(16, {}))); + + fact_manager.AddFactIdEquation(21, SpvOpISub, {14, 19}, context.get()); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(21, {}), + MakeDataDescriptor(15, {}))); + + fact_manager.AddFactIdEquation(22, SpvOpISub, {14, 18}, context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(22, {}), + MakeDataDescriptor(220, {}), context.get()); + fact_manager.AddFactIdEquation(23, SpvOpSNegate, {220}, context.get()); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(23, {}), + MakeDataDescriptor(16, {}))); +} + +TEST(FactManagerTest, CheckingFactsDoesNotAddConstants) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + OpMemberDecorate %9 0 Offset 0 + OpDecorate %9 Block + OpDecorate %11 DescriptorSet 0 + OpDecorate %11 Binding 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpTypeStruct %6 + %10 = OpTypePointer Uniform %9 + %11 = OpVariable %10 Uniform + %12 = OpConstant %6 0 + %13 = OpTypePointer Uniform %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %14 = OpAccessChain %13 %11 %12 + %15 = OpLoad %6 %14 + OpStore %8 %15 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + // 8[0] == int(1) + ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), {1}, + MakeUniformBufferElementDescriptor(0, 0, {0}))); + + // Although 8[0] has the value 1, we do not have the constant 1 in the module. + // We thus should not find any constants available from uniforms for int type. + // Furthermore, the act of looking for appropriate constants should not change + // which constants are known to the constant manager. + auto int_type = context->get_type_mgr()->GetType(6)->AsInteger(); + opt::analysis::IntConstant constant_one(int_type, {1}); + ASSERT_FALSE(context->get_constant_mgr()->FindConstant(&constant_one)); + auto available_constants = + fact_manager.GetConstantsAvailableFromUniformsForType(context.get(), 6); + ASSERT_EQ(0, available_constants.size()); + ASSERT_TRUE(IsEqual(env, shader, context.get())); + ASSERT_FALSE(context->get_constant_mgr()->FindConstant(&constant_one)); } } // namespace
diff --git a/test/fuzz/fuzzer_pass_add_useful_constructs_test.cpp b/test/fuzz/fuzzer_pass_add_useful_constructs_test.cpp deleted file mode 100644 index 89f006e..0000000 --- a/test/fuzz/fuzzer_pass_add_useful_constructs_test.cpp +++ /dev/null
@@ -1,393 +0,0 @@ -// Copyright (c) 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "source/fuzz/fuzzer_pass_add_useful_constructs.h" -#include "source/fuzz/pseudo_random_generator.h" -#include "source/fuzz/uniform_buffer_element_descriptor.h" -#include "test/fuzz/fuzz_test_util.h" - -namespace spvtools { -namespace fuzz { -namespace { - -bool AddFactHelper( - FactManager* fact_manager, opt::IRContext* context, uint32_t word, - const protobufs::UniformBufferElementDescriptor& descriptor) { - protobufs::FactConstantUniform constant_uniform_fact; - constant_uniform_fact.add_constant_word(word); - *constant_uniform_fact.mutable_uniform_buffer_element_descriptor() = - descriptor; - protobufs::Fact fact; - *fact.mutable_constant_uniform_fact() = constant_uniform_fact; - return fact_manager->AddFact(fact, context); -} - -TEST(FuzzerPassAddUsefulConstructsTest, CheckBasicStuffIsAdded) { - // The SPIR-V came from the following empty GLSL shader: - // - // #version 450 - // - // void main() - // { - // } - - std::string shader = R"( - OpCapability Shader - %1 = OpExtInstImport "GLSL.std.450" - OpMemoryModel Logical GLSL450 - OpEntryPoint Fragment %4 "main" - OpExecutionMode %4 OriginUpperLeft - OpSource GLSL 450 - OpName %4 "main" - %2 = OpTypeVoid - %3 = OpTypeFunction %2 - %4 = OpFunction %2 None %3 - %5 = OpLabel - OpReturn - OpFunctionEnd - )"; - - const auto env = SPV_ENV_UNIVERSAL_1_3; - const auto consumer = nullptr; - const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); - ASSERT_TRUE(IsValid(env, context.get())); - - FactManager fact_manager; - FuzzerContext fuzzer_context(MakeUnique<PseudoRandomGenerator>(0).get(), 100); - protobufs::TransformationSequence transformation_sequence; - - FuzzerPassAddUsefulConstructs pass(context.get(), &fact_manager, - &fuzzer_context, &transformation_sequence); - pass.Apply(); - ASSERT_TRUE(IsValid(env, context.get())); - - std::string after = R"( - OpCapability Shader - %1 = OpExtInstImport "GLSL.std.450" - OpMemoryModel Logical GLSL450 - OpEntryPoint Fragment %4 "main" - OpExecutionMode %4 OriginUpperLeft - OpSource GLSL 450 - OpName %4 "main" - %2 = OpTypeVoid - %3 = OpTypeFunction %2 - %100 = OpTypeBool - %101 = OpTypeInt 32 1 - %102 = OpTypeInt 32 0 - %103 = OpTypeFloat 32 - %104 = OpConstantTrue %100 - %105 = OpConstantFalse %100 - %106 = OpConstant %101 0 - %107 = OpConstant %101 1 - %108 = OpConstant %102 0 - %109 = OpConstant %102 1 - %110 = OpConstant %103 0 - %111 = OpConstant %103 1 - %4 = OpFunction %2 None %3 - %5 = OpLabel - OpReturn - OpFunctionEnd - )"; - ASSERT_TRUE(IsEqual(env, after, context.get())); -} - -TEST(FuzzerPassAddUsefulConstructsTest, - CheckTypesIndicesAndConstantsAddedForUniformFacts) { - // The SPIR-V came from the following GLSL shader: - // - // #version 450 - // - // struct S { - // int x; - // float y; - // int z; - // int w; - // }; - // - // uniform buf { - // S s; - // uint w[10]; - // }; - // - // void main() { - // } - - std::string shader = R"( - OpCapability Shader - %1 = OpExtInstImport "GLSL.std.450" - OpMemoryModel Logical GLSL450 - OpEntryPoint Fragment %4 "main" - OpExecutionMode %4 OriginUpperLeft - OpSource GLSL 450 - OpName %4 "main" - OpName %8 "S" - OpMemberName %8 0 "x" - OpMemberName %8 1 "y" - OpMemberName %8 2 "z" - OpMemberName %8 3 "w" - OpName %12 "buf" - OpMemberName %12 0 "s" - OpMemberName %12 1 "w" - OpName %14 "" - OpMemberDecorate %8 0 Offset 0 - OpMemberDecorate %8 1 Offset 4 - OpMemberDecorate %8 2 Offset 8 - OpMemberDecorate %8 3 Offset 12 - OpDecorate %11 ArrayStride 16 - OpMemberDecorate %12 0 Offset 0 - OpMemberDecorate %12 1 Offset 16 - OpDecorate %12 Block - OpDecorate %14 DescriptorSet 0 - OpDecorate %14 Binding 0 - %2 = OpTypeVoid - %3 = OpTypeFunction %2 - %6 = OpTypeInt 32 1 - %7 = OpTypeFloat 32 - %8 = OpTypeStruct %6 %7 %6 %6 - %9 = OpTypeInt 32 0 - %10 = OpConstant %9 10 - %11 = OpTypeArray %9 %10 - %12 = OpTypeStruct %8 %11 - %13 = OpTypePointer Uniform %12 - %14 = OpVariable %13 Uniform - %4 = OpFunction %2 None %3 - %5 = OpLabel - OpReturn - OpFunctionEnd - )"; - - const auto env = SPV_ENV_UNIVERSAL_1_3; - const auto consumer = nullptr; - const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); - ASSERT_TRUE(IsValid(env, context.get())); - - FactManager fact_manager; - FuzzerContext fuzzer_context(MakeUnique<PseudoRandomGenerator>(0).get(), 100); - protobufs::TransformationSequence transformation_sequence; - - // Add some uniform facts. - - // buf.s.x == 200 - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 200, - MakeUniformBufferElementDescriptor(0, 0, {0, 0}))); - - // buf.s.y == 0.5 - const float float_value = 0.5; - uint32_t float_value_as_uint; - memcpy(&float_value_as_uint, &float_value, sizeof(float_value)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), float_value_as_uint, - MakeUniformBufferElementDescriptor(0, 0, {0, 1}))); - - // buf.s.z == 300 - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 300, - MakeUniformBufferElementDescriptor(0, 0, {0, 2}))); - - // buf.s.w == 400 - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 400, - MakeUniformBufferElementDescriptor(0, 0, {0, 3}))); - - // buf.w[6] = 22 - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 22, - MakeUniformBufferElementDescriptor(0, 0, {1, 6}))); - - // buf.w[8] = 23 - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 23, - MakeUniformBufferElementDescriptor(0, 0, {1, 8}))); - - // Assert some things about the module that are not true prior to adding the - // pass - - { - // No uniform int pointer - opt::analysis::Integer temp_type_signed_int(32, true); - opt::analysis::Integer* registered_type_signed_int = - context->get_type_mgr() - ->GetRegisteredType(&temp_type_signed_int) - ->AsInteger(); - opt::analysis::Pointer type_pointer_uniform_signed_int( - registered_type_signed_int, SpvStorageClassUniform); - ASSERT_EQ(0, - context->get_type_mgr()->GetId(&type_pointer_uniform_signed_int)); - - // No uniform uint pointer - opt::analysis::Integer temp_type_unsigned_int(32, false); - opt::analysis::Integer* registered_type_unsigned_int = - context->get_type_mgr() - ->GetRegisteredType(&temp_type_unsigned_int) - ->AsInteger(); - opt::analysis::Pointer type_pointer_uniform_unsigned_int( - registered_type_unsigned_int, SpvStorageClassUniform); - ASSERT_EQ( - 0, context->get_type_mgr()->GetId(&type_pointer_uniform_unsigned_int)); - - // No uniform float pointer - opt::analysis::Float temp_type_float(32); - opt::analysis::Float* registered_type_float = - context->get_type_mgr()->GetRegisteredType(&temp_type_float)->AsFloat(); - opt::analysis::Pointer type_pointer_uniform_float(registered_type_float, - SpvStorageClassUniform); - ASSERT_EQ(0, context->get_type_mgr()->GetId(&type_pointer_uniform_float)); - - // No int constants 200, 300 nor 400 - opt::analysis::IntConstant int_constant_200(registered_type_signed_int, - {200}); - opt::analysis::IntConstant int_constant_300(registered_type_signed_int, - {300}); - opt::analysis::IntConstant int_constant_400(registered_type_signed_int, - {400}); - ASSERT_EQ(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_200)); - ASSERT_EQ(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_300)); - ASSERT_EQ(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_400)); - - // No float constant 0.5 - opt::analysis::FloatConstant float_constant_zero_point_five( - registered_type_float, {float_value_as_uint}); - ASSERT_EQ(nullptr, context->get_constant_mgr()->FindConstant( - &float_constant_zero_point_five)); - - // No uint constant 22 - opt::analysis::IntConstant uint_constant_22(registered_type_unsigned_int, - {22}); - ASSERT_EQ(nullptr, - context->get_constant_mgr()->FindConstant(&uint_constant_22)); - - // No uint constant 23 - opt::analysis::IntConstant uint_constant_23(registered_type_unsigned_int, - {23}); - ASSERT_EQ(nullptr, - context->get_constant_mgr()->FindConstant(&uint_constant_23)); - - // No int constants 0, 1, 2, 3, 6, 8 - opt::analysis::IntConstant int_constant_0(registered_type_signed_int, {0}); - opt::analysis::IntConstant int_constant_1(registered_type_signed_int, {1}); - opt::analysis::IntConstant int_constant_2(registered_type_signed_int, {2}); - opt::analysis::IntConstant int_constant_3(registered_type_signed_int, {3}); - opt::analysis::IntConstant int_constant_6(registered_type_signed_int, {6}); - opt::analysis::IntConstant int_constant_8(registered_type_signed_int, {8}); - ASSERT_EQ(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_0)); - ASSERT_EQ(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_1)); - ASSERT_EQ(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_2)); - ASSERT_EQ(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_3)); - ASSERT_EQ(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_6)); - ASSERT_EQ(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_8)); - } - - FuzzerPassAddUsefulConstructs pass(context.get(), &fact_manager, - &fuzzer_context, &transformation_sequence); - pass.Apply(); - ASSERT_TRUE(IsValid(env, context.get())); - - // Now assert some things about the module that should be true following the - // pass. - - // We reconstruct all necessary types and constants to guard against the type - // and constant managers for the module having been invalidated. - - { - // Uniform int pointer now present - opt::analysis::Integer temp_type_signed_int(32, true); - opt::analysis::Integer* registered_type_signed_int = - context->get_type_mgr() - ->GetRegisteredType(&temp_type_signed_int) - ->AsInteger(); - opt::analysis::Pointer type_pointer_uniform_signed_int( - registered_type_signed_int, SpvStorageClassUniform); - ASSERT_NE(0, - context->get_type_mgr()->GetId(&type_pointer_uniform_signed_int)); - - // Uniform uint pointer now present - opt::analysis::Integer temp_type_unsigned_int(32, false); - opt::analysis::Integer* registered_type_unsigned_int = - context->get_type_mgr() - ->GetRegisteredType(&temp_type_unsigned_int) - ->AsInteger(); - opt::analysis::Pointer type_pointer_uniform_unsigned_int( - registered_type_unsigned_int, SpvStorageClassUniform); - ASSERT_NE( - 0, context->get_type_mgr()->GetId(&type_pointer_uniform_unsigned_int)); - - // Uniform float pointer now present - opt::analysis::Float temp_type_float(32); - opt::analysis::Float* registered_type_float = - context->get_type_mgr()->GetRegisteredType(&temp_type_float)->AsFloat(); - opt::analysis::Pointer type_pointer_uniform_float(registered_type_float, - SpvStorageClassUniform); - ASSERT_NE(0, context->get_type_mgr()->GetId(&type_pointer_uniform_float)); - - // int constants 200, 300, 400 now present - opt::analysis::IntConstant int_constant_200(registered_type_signed_int, - {200}); - opt::analysis::IntConstant int_constant_300(registered_type_signed_int, - {300}); - opt::analysis::IntConstant int_constant_400(registered_type_signed_int, - {400}); - ASSERT_NE(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_200)); - ASSERT_NE(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_300)); - ASSERT_NE(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_400)); - - // float constant 0.5 now present - opt::analysis::FloatConstant float_constant_zero_point_five( - registered_type_float, {float_value_as_uint}); - ASSERT_NE(nullptr, context->get_constant_mgr()->FindConstant( - &float_constant_zero_point_five)); - - // uint constant 22 now present - opt::analysis::IntConstant uint_constant_22(registered_type_unsigned_int, - {22}); - ASSERT_NE(nullptr, - context->get_constant_mgr()->FindConstant(&uint_constant_22)); - - // uint constant 23 now present - opt::analysis::IntConstant uint_constant_23(registered_type_unsigned_int, - {23}); - ASSERT_NE(nullptr, - context->get_constant_mgr()->FindConstant(&uint_constant_23)); - - // int constants 0, 1, 2, 3, 6, 8 now present - opt::analysis::IntConstant int_constant_0(registered_type_signed_int, {0}); - opt::analysis::IntConstant int_constant_1(registered_type_signed_int, {1}); - opt::analysis::IntConstant int_constant_2(registered_type_signed_int, {2}); - opt::analysis::IntConstant int_constant_3(registered_type_signed_int, {3}); - opt::analysis::IntConstant int_constant_6(registered_type_signed_int, {6}); - opt::analysis::IntConstant int_constant_8(registered_type_signed_int, {8}); - ASSERT_NE(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_0)); - ASSERT_NE(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_1)); - ASSERT_NE(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_2)); - ASSERT_NE(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_3)); - ASSERT_NE(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_6)); - ASSERT_NE(nullptr, - context->get_constant_mgr()->FindConstant(&int_constant_8)); - } -} - -} // namespace -} // namespace fuzz -} // namespace spvtools
diff --git a/test/fuzz/fuzzer_pass_construct_composites_test.cpp b/test/fuzz/fuzzer_pass_construct_composites_test.cpp new file mode 100644 index 0000000..cc21f74 --- /dev/null +++ b/test/fuzz/fuzzer_pass_construct_composites_test.cpp
@@ -0,0 +1,187 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/fuzzer_pass_construct_composites.h" +#include "source/fuzz/pseudo_random_generator.h" +#include "test/fuzz/fuzz_test_util.h" + +namespace spvtools { +namespace fuzz { +namespace { + +TEST(FuzzerPassConstructCompositesTest, IsomorphicStructs) { + // This test declares various isomorphic structs, and a struct that is made up + // of these isomorphic structs. The pass to construct composites is then + // applied several times to check that no issues arise related to using a + // value of one struct type when a value of an isomorphic struct type is + // required. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpConstant %6 0 + %8 = OpTypeStruct %6 %6 %6 + %9 = OpTypeStruct %6 %6 %6 + %10 = OpTypeStruct %6 %6 %6 + %11 = OpTypeStruct %6 %6 %6 + %12 = OpTypeStruct %6 %6 %6 + %13 = OpTypeStruct %8 %9 %10 %11 %12 + %14 = OpConstantComposite %8 %7 %7 %7 + %15 = OpConstantComposite %9 %7 %7 %7 + %16 = OpConstantComposite %10 %7 %7 %7 + %17 = OpConstantComposite %11 %7 %7 %7 + %18 = OpConstantComposite %12 %7 %7 %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + + auto prng = MakeUnique<PseudoRandomGenerator>(0); + + for (uint32_t i = 0; i < 10; i++) { + const auto context = + BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + FuzzerContext fuzzer_context(prng.get(), 100); + protobufs::TransformationSequence transformation_sequence; + + FuzzerPassConstructComposites fuzzer_pass( + context.get(), &transformation_context, &fuzzer_context, + &transformation_sequence); + + fuzzer_pass.Apply(); + + // We just check that the result is valid. + ASSERT_TRUE(IsValid(env, context.get())); + } +} + +TEST(FuzzerPassConstructCompositesTest, IsomorphicArrays) { + // This test declares various isomorphic arrays, and a struct that is made up + // of these isomorphic arrays. The pass to construct composites is then + // applied several times to check that no issues arise related to using a + // value of one array type when a value of an isomorphic array type is + // required. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %50 = OpTypeInt 32 0 + %51 = OpConstant %50 3 + %7 = OpConstant %6 0 + %8 = OpTypeArray %6 %51 + %9 = OpTypeArray %6 %51 + %10 = OpTypeArray %6 %51 + %11 = OpTypeArray %6 %51 + %12 = OpTypeArray %6 %51 + %13 = OpTypeStruct %8 %9 %10 %11 %12 + %14 = OpConstantComposite %8 %7 %7 %7 + %15 = OpConstantComposite %9 %7 %7 %7 + %16 = OpConstantComposite %10 %7 %7 %7 + %17 = OpConstantComposite %11 %7 %7 %7 + %18 = OpConstantComposite %12 %7 %7 %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpNop + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + + auto prng = MakeUnique<PseudoRandomGenerator>(0); + + for (uint32_t i = 0; i < 10; i++) { + const auto context = + BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + FuzzerContext fuzzer_context(prng.get(), 100); + protobufs::TransformationSequence transformation_sequence; + + FuzzerPassConstructComposites fuzzer_pass( + context.get(), &transformation_context, &fuzzer_context, + &transformation_sequence); + + fuzzer_pass.Apply(); + + // We just check that the result is valid. + ASSERT_TRUE(IsValid(env, context.get())); + } +} + +} // namespace +} // namespace fuzz +} // namespace spvtools
diff --git a/test/fuzz/fuzzer_pass_donate_modules_test.cpp b/test/fuzz/fuzzer_pass_donate_modules_test.cpp index dc7ba3a..0833c1d 100644 --- a/test/fuzz/fuzzer_pass_donate_modules_test.cpp +++ b/test/fuzz/fuzzer_pass_donate_modules_test.cpp
@@ -194,14 +194,17 @@ ASSERT_TRUE(IsValid(env, donor_context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - auto prng = MakeUnique<PseudoRandomGenerator>(0); - FuzzerContext fuzzer_context(prng.get(), 100); + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); protobufs::TransformationSequence transformation_sequence; - FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), &fact_manager, - &fuzzer_context, &transformation_sequence, - {}); + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); fuzzer_pass.DonateSingleModule(donor_context.get(), false); @@ -269,13 +272,17 @@ ASSERT_TRUE(IsValid(env, donor_context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - FuzzerContext fuzzer_context(MakeUnique<PseudoRandomGenerator>(0).get(), 100); + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); protobufs::TransformationSequence transformation_sequence; - FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), &fact_manager, - &fuzzer_context, &transformation_sequence, - {}); + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); fuzzer_pass.DonateSingleModule(donor_context.get(), false); @@ -393,13 +400,17 @@ ASSERT_TRUE(IsValid(env, donor_context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - FuzzerContext fuzzer_context(MakeUnique<PseudoRandomGenerator>(0).get(), 100); + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); protobufs::TransformationSequence transformation_sequence; - FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), &fact_manager, - &fuzzer_context, &transformation_sequence, - {}); + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); fuzzer_pass.DonateSingleModule(donor_context.get(), false); @@ -481,13 +492,17 @@ ASSERT_TRUE(IsValid(env, donor_context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - FuzzerContext fuzzer_context(MakeUnique<PseudoRandomGenerator>(0).get(), 100); + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); protobufs::TransformationSequence transformation_sequence; - FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), &fact_manager, - &fuzzer_context, &transformation_sequence, - {}); + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); fuzzer_pass.DonateSingleModule(donor_context.get(), false); @@ -496,6 +511,993 @@ ASSERT_TRUE(IsValid(env, recipient_context.get())); } +TEST(FuzzerPassDonateModulesTest, DonateOpConstantNull) { + std::string recipient_shader = R"( + OpCapability Shader + OpCapability ImageQuery + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + OpSourceExtension "GL_EXT_samplerless_texture_functions" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor_shader = R"( + OpCapability Shader + OpCapability ImageQuery + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypePointer Private %6 + %8 = OpConstantNull %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto recipient_context = + BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, recipient_context.get())); + + const auto donor_context = + BuildModule(env, consumer, donor_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, donor_context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); + protobufs::TransformationSequence transformation_sequence; + + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); + + fuzzer_pass.DonateSingleModule(donor_context.get(), false); + + // We just check that the result is valid. Checking to what it should be + // exactly equal to would be very fragile. + ASSERT_TRUE(IsValid(env, recipient_context.get())); +} + +TEST(FuzzerPassDonateModulesTest, DonateCodeThatUsesImages) { + std::string recipient_shader = R"( + OpCapability Shader + OpCapability ImageQuery + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + OpSourceExtension "GL_EXT_samplerless_texture_functions" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor_shader = R"( + OpCapability Shader + OpCapability ImageQuery + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + OpSourceExtension "GL_EXT_samplerless_texture_functions" + OpName %4 "main" + OpName %10 "mySampler" + OpName %21 "myTexture" + OpName %33 "v" + OpDecorate %10 RelaxedPrecision + OpDecorate %10 DescriptorSet 0 + OpDecorate %10 Binding 0 + OpDecorate %11 RelaxedPrecision + OpDecorate %21 RelaxedPrecision + OpDecorate %21 DescriptorSet 0 + OpDecorate %21 Binding 1 + OpDecorate %22 RelaxedPrecision + OpDecorate %34 RelaxedPrecision + OpDecorate %40 RelaxedPrecision + OpDecorate %42 RelaxedPrecision + OpDecorate %43 RelaxedPrecision + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeImage %6 2D 0 0 0 1 Unknown + %8 = OpTypeSampledImage %7 + %9 = OpTypePointer UniformConstant %8 + %10 = OpVariable %9 UniformConstant + %12 = OpTypeInt 32 1 + %13 = OpConstant %12 2 + %15 = OpTypeVector %12 2 + %17 = OpTypeInt 32 0 + %18 = OpConstant %17 0 + %20 = OpTypePointer UniformConstant %7 + %21 = OpVariable %20 UniformConstant + %23 = OpConstant %12 1 + %25 = OpConstant %17 1 + %27 = OpTypeBool + %31 = OpTypeVector %6 4 + %32 = OpTypePointer Function %31 + %35 = OpConstantComposite %15 %23 %23 + %36 = OpConstant %12 3 + %37 = OpConstant %12 4 + %38 = OpConstantComposite %15 %36 %37 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %33 = OpVariable %32 Function + %11 = OpLoad %8 %10 + %14 = OpImage %7 %11 + %16 = OpImageQuerySizeLod %15 %14 %13 + %19 = OpCompositeExtract %12 %16 0 + %22 = OpLoad %7 %21 + %24 = OpImageQuerySizeLod %15 %22 %23 + %26 = OpCompositeExtract %12 %24 1 + %28 = OpSGreaterThan %27 %19 %26 + OpSelectionMerge %30 None + OpBranchConditional %28 %29 %41 + %29 = OpLabel + %34 = OpLoad %8 %10 + %39 = OpImage %7 %34 + %40 = OpImageFetch %31 %39 %35 Lod|ConstOffset %13 %38 + OpStore %33 %40 + OpBranch %30 + %41 = OpLabel + %42 = OpLoad %7 %21 + %43 = OpImageFetch %31 %42 %35 Lod|ConstOffset %13 %38 + OpStore %33 %43 + OpBranch %30 + %30 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto recipient_context = + BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, recipient_context.get())); + + const auto donor_context = + BuildModule(env, consumer, donor_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, donor_context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); + protobufs::TransformationSequence transformation_sequence; + + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); + + fuzzer_pass.DonateSingleModule(donor_context.get(), false); + + // We just check that the result is valid. Checking to what it should be + // exactly equal to would be very fragile. + ASSERT_TRUE(IsValid(env, recipient_context.get())); +} + +TEST(FuzzerPassDonateModulesTest, DonateCodeThatUsesSampler) { + std::string recipient_shader = R"( + OpCapability Shader + OpCapability ImageQuery + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + OpSourceExtension "GL_EXT_samplerless_texture_functions" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor_shader = R"( + OpCapability Shader + OpCapability ImageQuery + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + OpDecorate %16 DescriptorSet 0 + OpDecorate %16 Binding 0 + OpDecorate %12 DescriptorSet 0 + OpDecorate %12 Binding 64 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %23 = OpTypeFloat 32 + %6 = OpTypeImage %23 2D 2 0 0 1 Unknown + %47 = OpTypePointer UniformConstant %6 + %12 = OpVariable %47 UniformConstant + %15 = OpTypeSampler + %55 = OpTypePointer UniformConstant %15 + %17 = OpTypeSampledImage %6 + %16 = OpVariable %55 UniformConstant + %37 = OpTypeVector %23 4 + %109 = OpConstant %23 0 + %66 = OpConstantComposite %37 %109 %109 %109 %109 + %56 = OpTypeBool + %54 = OpConstantTrue %56 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %50 + %50 = OpLabel + %51 = OpPhi %37 %66 %5 %111 %53 + OpLoopMerge %52 %53 None + OpBranchConditional %54 %53 %52 + %53 = OpLabel + %106 = OpLoad %6 %12 + %107 = OpLoad %15 %16 + %110 = OpSampledImage %17 %106 %107 + %111 = OpImageSampleImplicitLod %37 %110 %66 Bias %109 + OpBranch %50 + %52 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto recipient_context = + BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, recipient_context.get())); + + const auto donor_context = + BuildModule(env, consumer, donor_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, donor_context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); + protobufs::TransformationSequence transformation_sequence; + + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); + + fuzzer_pass.DonateSingleModule(donor_context.get(), false); + + // We just check that the result is valid. Checking to what it should be + // exactly equal to would be very fragile. + ASSERT_TRUE(IsValid(env, recipient_context.get())); +} + +TEST(FuzzerPassDonateModulesTest, DonateCodeThatUsesImageStructField) { + std::string recipient_shader = R"( + OpCapability Shader + OpCapability ImageQuery + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + OpSourceExtension "GL_EXT_samplerless_texture_functions" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor_shader = R"( + OpCapability Shader + OpCapability ImageQuery + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + OpSourceExtension "GL_EXT_samplerless_texture_functions" + OpName %4 "main" + OpName %10 "mySampler" + OpName %21 "myTexture" + OpName %33 "v" + OpDecorate %10 RelaxedPrecision + OpDecorate %10 DescriptorSet 0 + OpDecorate %10 Binding 0 + OpDecorate %11 RelaxedPrecision + OpDecorate %21 RelaxedPrecision + OpDecorate %21 DescriptorSet 0 + OpDecorate %21 Binding 1 + OpDecorate %22 RelaxedPrecision + OpDecorate %34 RelaxedPrecision + OpDecorate %40 RelaxedPrecision + OpDecorate %42 RelaxedPrecision + OpDecorate %43 RelaxedPrecision + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeImage %6 2D 0 0 0 1 Unknown + %8 = OpTypeSampledImage %7 + %9 = OpTypePointer UniformConstant %8 + %10 = OpVariable %9 UniformConstant + %12 = OpTypeInt 32 1 + %13 = OpConstant %12 2 + %15 = OpTypeVector %12 2 + %17 = OpTypeInt 32 0 + %18 = OpConstant %17 0 + %20 = OpTypePointer UniformConstant %7 + %21 = OpVariable %20 UniformConstant + %23 = OpConstant %12 1 + %25 = OpConstant %17 1 + %27 = OpTypeBool + %31 = OpTypeVector %6 4 + %32 = OpTypePointer Function %31 + %35 = OpConstantComposite %15 %23 %23 + %36 = OpConstant %12 3 + %37 = OpConstant %12 4 + %38 = OpConstantComposite %15 %36 %37 + %201 = OpTypeStruct %7 %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %33 = OpVariable %32 Function + %11 = OpLoad %8 %10 + %14 = OpImage %7 %11 + %22 = OpLoad %7 %21 + %200 = OpCompositeConstruct %201 %14 %22 + %202 = OpCompositeExtract %7 %200 0 + %203 = OpCompositeExtract %7 %200 1 + %24 = OpImageQuerySizeLod %15 %203 %23 + %16 = OpImageQuerySizeLod %15 %202 %13 + %26 = OpCompositeExtract %12 %24 1 + %19 = OpCompositeExtract %12 %16 0 + %28 = OpSGreaterThan %27 %19 %26 + OpSelectionMerge %30 None + OpBranchConditional %28 %29 %41 + %29 = OpLabel + %34 = OpLoad %8 %10 + %39 = OpImage %7 %34 + %40 = OpImageFetch %31 %39 %35 Lod|ConstOffset %13 %38 + OpStore %33 %40 + OpBranch %30 + %41 = OpLabel + %42 = OpLoad %7 %21 + %43 = OpImageFetch %31 %42 %35 Lod|ConstOffset %13 %38 + OpStore %33 %43 + OpBranch %30 + %30 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto recipient_context = + BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, recipient_context.get())); + + const auto donor_context = + BuildModule(env, consumer, donor_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, donor_context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); + protobufs::TransformationSequence transformation_sequence; + + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); + + fuzzer_pass.DonateSingleModule(donor_context.get(), false); + + // We just check that the result is valid. Checking to what it should be + // exactly equal to would be very fragile. + ASSERT_TRUE(IsValid(env, recipient_context.get())); +} + +TEST(FuzzerPassDonateModulesTest, DonateCodeThatUsesImageFunctionParameter) { + std::string recipient_shader = R"( + OpCapability Shader + OpCapability ImageQuery + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + OpSourceExtension "GL_EXT_samplerless_texture_functions" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor_shader = R"( + OpCapability Shader + OpCapability ImageQuery + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + OpSourceExtension "GL_EXT_samplerless_texture_functions" + OpName %4 "main" + OpName %10 "mySampler" + OpName %21 "myTexture" + OpName %33 "v" + OpDecorate %10 RelaxedPrecision + OpDecorate %10 DescriptorSet 0 + OpDecorate %10 Binding 0 + OpDecorate %11 RelaxedPrecision + OpDecorate %21 RelaxedPrecision + OpDecorate %21 DescriptorSet 0 + OpDecorate %21 Binding 1 + OpDecorate %22 RelaxedPrecision + OpDecorate %34 RelaxedPrecision + OpDecorate %40 RelaxedPrecision + OpDecorate %42 RelaxedPrecision + OpDecorate %43 RelaxedPrecision + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeImage %6 2D 0 0 0 1 Unknown + %8 = OpTypeSampledImage %7 + %9 = OpTypePointer UniformConstant %8 + %10 = OpVariable %9 UniformConstant + %12 = OpTypeInt 32 1 + %13 = OpConstant %12 2 + %15 = OpTypeVector %12 2 + %17 = OpTypeInt 32 0 + %18 = OpConstant %17 0 + %20 = OpTypePointer UniformConstant %7 + %21 = OpVariable %20 UniformConstant + %23 = OpConstant %12 1 + %25 = OpConstant %17 1 + %27 = OpTypeBool + %31 = OpTypeVector %6 4 + %32 = OpTypePointer Function %31 + %35 = OpConstantComposite %15 %23 %23 + %36 = OpConstant %12 3 + %37 = OpConstant %12 4 + %38 = OpConstantComposite %15 %36 %37 + %201 = OpTypeFunction %15 %7 %12 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %33 = OpVariable %32 Function + %11 = OpLoad %8 %10 + %14 = OpImage %7 %11 + %16 = OpFunctionCall %15 %200 %14 %13 + %19 = OpCompositeExtract %12 %16 0 + %22 = OpLoad %7 %21 + %24 = OpImageQuerySizeLod %15 %22 %23 + %26 = OpCompositeExtract %12 %24 1 + %28 = OpSGreaterThan %27 %19 %26 + OpSelectionMerge %30 None + OpBranchConditional %28 %29 %41 + %29 = OpLabel + %34 = OpLoad %8 %10 + %39 = OpImage %7 %34 + %40 = OpImageFetch %31 %39 %35 Lod|ConstOffset %13 %38 + OpStore %33 %40 + OpBranch %30 + %41 = OpLabel + %42 = OpLoad %7 %21 + %43 = OpImageFetch %31 %42 %35 Lod|ConstOffset %13 %38 + OpStore %33 %43 + OpBranch %30 + %30 = OpLabel + OpReturn + OpFunctionEnd + %200 = OpFunction %15 None %201 + %202 = OpFunctionParameter %7 + %203 = OpFunctionParameter %12 + %204 = OpLabel + %205 = OpImageQuerySizeLod %15 %202 %203 + OpReturnValue %205 + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto recipient_context = + BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, recipient_context.get())); + + const auto donor_context = + BuildModule(env, consumer, donor_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, donor_context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); + protobufs::TransformationSequence transformation_sequence; + + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); + + fuzzer_pass.DonateSingleModule(donor_context.get(), false); + + // We just check that the result is valid. Checking to what it should be + // exactly equal to would be very fragile. + ASSERT_TRUE(IsValid(env, recipient_context.get())); +} + +TEST(FuzzerPassDonateModulesTest, DonateShaderWithImageStorageClass) { + std::string recipient_shader = R"( + OpCapability Shader + OpCapability ImageQuery + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + OpSourceExtension "GL_EXT_samplerless_texture_functions" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor_shader = R"( + OpCapability Shader + OpCapability SampledBuffer + OpCapability ImageBuffer + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "MainPSPacked" + OpExecutionMode %2 OriginUpperLeft + OpDecorate %18 DescriptorSet 0 + OpDecorate %18 Binding 128 + %49 = OpTypeInt 32 0 + %50 = OpTypeFloat 32 + %58 = OpConstant %50 1 + %66 = OpConstant %49 0 + %87 = OpTypeVector %50 2 + %88 = OpConstantComposite %87 %58 %58 + %17 = OpTypeImage %49 2D 2 0 0 2 R32ui + %118 = OpTypePointer UniformConstant %17 + %123 = OpTypeVector %49 2 + %132 = OpTypeVoid + %133 = OpTypeFunction %132 + %142 = OpTypePointer Image %49 + %18 = OpVariable %118 UniformConstant + %2 = OpFunction %132 None %133 + %153 = OpLabel + %495 = OpConvertFToU %123 %88 + %501 = OpImageTexelPointer %142 %18 %495 %66 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto recipient_context = + BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, recipient_context.get())); + + const auto donor_context = + BuildModule(env, consumer, donor_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, donor_context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); + protobufs::TransformationSequence transformation_sequence; + + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); + + fuzzer_pass.DonateSingleModule(donor_context.get(), true); + + // We just check that the result is valid. Checking to what it should be + // exactly equal to would be very fragile. + ASSERT_TRUE(IsValid(env, recipient_context.get())); +} + +TEST(FuzzerPassDonateModulesTest, DonateComputeShaderWithRuntimeArray) { + std::string recipient_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" + OpExecutionMode %4 LocalSize 1 1 1 + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" + OpExecutionMode %4 LocalSize 1 1 1 + OpSource ESSL 310 + OpDecorate %9 ArrayStride 4 + OpMemberDecorate %10 0 Offset 0 + OpDecorate %10 BufferBlock + OpDecorate %12 DescriptorSet 0 + OpDecorate %12 Binding 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpTypeRuntimeArray %6 + %10 = OpTypeStruct %9 + %11 = OpTypePointer Uniform %10 + %12 = OpVariable %11 Uniform + %13 = OpTypeInt 32 0 + %16 = OpConstant %6 0 + %18 = OpConstant %6 1 + %20 = OpTypePointer Uniform %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %14 = OpArrayLength %13 %12 0 + %15 = OpBitcast %6 %14 + OpStore %8 %15 + %17 = OpLoad %6 %8 + %19 = OpISub %6 %17 %18 + %21 = OpAccessChain %20 %12 %16 %19 + OpStore %21 %16 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto recipient_context = + BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, recipient_context.get())); + + const auto donor_context = + BuildModule(env, consumer, donor_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, donor_context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); + protobufs::TransformationSequence transformation_sequence; + + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); + + fuzzer_pass.DonateSingleModule(donor_context.get(), false); + + // We just check that the result is valid. Checking to what it should be + // exactly equal to would be very fragile. + ASSERT_TRUE(IsValid(env, recipient_context.get())); +} + +TEST(FuzzerPassDonateModulesTest, DonateComputeShaderWithRuntimeArrayLivesafe) { + std::string recipient_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" + OpExecutionMode %4 LocalSize 1 1 1 + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" + OpExecutionMode %4 LocalSize 1 1 1 + OpSource ESSL 310 + OpDecorate %16 ArrayStride 4 + OpMemberDecorate %17 0 Offset 0 + OpDecorate %17 BufferBlock + OpDecorate %19 DescriptorSet 0 + OpDecorate %19 Binding 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpTypeRuntimeArray %6 + %17 = OpTypeStruct %16 + %18 = OpTypePointer Uniform %17 + %19 = OpVariable %18 Uniform + %20 = OpTypeInt 32 0 + %23 = OpTypeBool + %26 = OpConstant %6 32 + %27 = OpTypePointer Uniform %6 + %30 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %15 = OpLoad %6 %8 + %21 = OpArrayLength %20 %19 0 + %22 = OpBitcast %6 %21 + %24 = OpSLessThan %23 %15 %22 + OpBranchConditional %24 %11 %12 + %11 = OpLabel + %25 = OpLoad %6 %8 + %28 = OpAccessChain %27 %19 %9 %25 + OpStore %28 %26 + OpBranch %13 + %13 = OpLabel + %29 = OpLoad %6 %8 + %31 = OpIAdd %6 %29 %30 + OpStore %8 %31 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto recipient_context = + BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, recipient_context.get())); + + const auto donor_context = + BuildModule(env, consumer, donor_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, donor_context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); + protobufs::TransformationSequence transformation_sequence; + + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); + + fuzzer_pass.DonateSingleModule(donor_context.get(), true); + + // We just check that the result is valid. Checking to what it should be + // exactly equal to would be very fragile. + ASSERT_TRUE(IsValid(env, recipient_context.get())); +} + +TEST(FuzzerPassDonateModulesTest, DonateComputeShaderWithWorkgroupVariables) { + std::string recipient_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" + OpExecutionMode %4 LocalSize 1 1 1 + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" + OpExecutionMode %4 LocalSize 1 1 1 + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Workgroup %6 + %8 = OpVariable %7 Workgroup + %9 = OpConstant %6 2 + %10 = OpVariable %7 Workgroup + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpStore %8 %9 + %11 = OpLoad %6 %8 + OpStore %10 %11 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto recipient_context = + BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, recipient_context.get())); + + const auto donor_context = + BuildModule(env, consumer, donor_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, donor_context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); + protobufs::TransformationSequence transformation_sequence; + + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); + + fuzzer_pass.DonateSingleModule(donor_context.get(), true); + + // We just check that the result is valid. Checking to what it should be + // exactly equal to would be very fragile. + ASSERT_TRUE(IsValid(env, recipient_context.get())); +} + +TEST(FuzzerPassDonateModulesTest, DonateComputeShaderWithAtomics) { + std::string recipient_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" + OpExecutionMode %4 LocalSize 1 1 1 + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" + OpExecutionMode %4 LocalSize 1 1 1 + OpSource ESSL 310 + OpMemberDecorate %9 0 Offset 0 + OpDecorate %9 BufferBlock + OpDecorate %11 DescriptorSet 0 + OpDecorate %11 Binding 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 0 + %7 = OpTypePointer Function %6 + %9 = OpTypeStruct %6 + %10 = OpTypePointer Uniform %9 + %11 = OpVariable %10 Uniform + %12 = OpTypeInt 32 1 + %13 = OpConstant %12 0 + %14 = OpTypePointer Uniform %6 + %16 = OpConstant %6 1 + %17 = OpConstant %6 0 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %15 = OpAccessChain %14 %11 %13 + %18 = OpAtomicIAdd %6 %15 %16 %17 %16 + OpStore %8 %18 + %19 = OpAccessChain %14 %11 %13 + %20 = OpLoad %6 %8 + %21 = OpAtomicUMin %6 %19 %16 %17 %20 + OpStore %8 %21 + %22 = OpAccessChain %14 %11 %13 + %23 = OpLoad %6 %8 + %24 = OpAtomicUMax %6 %22 %16 %17 %23 + OpStore %8 %24 + %25 = OpAccessChain %14 %11 %13 + %26 = OpLoad %6 %8 + %27 = OpAtomicAnd %6 %25 %16 %17 %26 + OpStore %8 %27 + %28 = OpAccessChain %14 %11 %13 + %29 = OpLoad %6 %8 + %30 = OpAtomicOr %6 %28 %16 %17 %29 + OpStore %8 %30 + %31 = OpAccessChain %14 %11 %13 + %32 = OpLoad %6 %8 + %33 = OpAtomicXor %6 %31 %16 %17 %32 + OpStore %8 %33 + %34 = OpAccessChain %14 %11 %13 + %35 = OpLoad %6 %8 + %36 = OpAtomicExchange %6 %34 %16 %17 %35 + OpStore %8 %36 + %37 = OpAccessChain %14 %11 %13 + %38 = OpLoad %6 %8 + %39 = OpAtomicCompareExchange %6 %37 %16 %17 %17 %16 %38 + OpStore %8 %39 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto recipient_context = + BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, recipient_context.get())); + + const auto donor_context = + BuildModule(env, consumer, donor_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, donor_context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + PseudoRandomGenerator prng(0); + FuzzerContext fuzzer_context(&prng, 100); + protobufs::TransformationSequence transformation_sequence; + + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); + + fuzzer_pass.DonateSingleModule(donor_context.get(), true); + + // We just check that the result is valid. Checking to what it should be + // exactly equal to would be very fragile. + ASSERT_TRUE(IsValid(env, recipient_context.get())); +} + TEST(FuzzerPassDonateModulesTest, Miscellaneous1) { std::string recipient_shader = R"( OpCapability Shader @@ -658,13 +1660,16 @@ ASSERT_TRUE(IsValid(env, donor_context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); FuzzerContext fuzzer_context(MakeUnique<PseudoRandomGenerator>(0).get(), 100); protobufs::TransformationSequence transformation_sequence; - FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), &fact_manager, - &fuzzer_context, &transformation_sequence, - {}); + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), + &transformation_context, &fuzzer_context, + &transformation_sequence, {}); fuzzer_pass.DonateSingleModule(donor_context.get(), false);
diff --git a/test/fuzz/fuzzer_replayer_test.cpp b/test/fuzz/fuzzer_replayer_test.cpp index b91393e..1e7c643 100644 --- a/test/fuzz/fuzzer_replayer_test.cpp +++ b/test/fuzz/fuzzer_replayer_test.cpp
@@ -1553,6 +1553,47 @@ OpFunctionEnd )"; +// Some miscellaneous SPIR-V. + +const std::string kTestShader6 = R"( + OpCapability Shader + OpCapability SampledBuffer + OpCapability ImageBuffer + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %40 %41 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 450 + OpDecorate %40 DescriptorSet 0 + OpDecorate %40 Binding 69 + OpDecorate %41 DescriptorSet 0 + OpDecorate %41 Binding 1 + %54 = OpTypeFloat 32 + %76 = OpTypeVector %54 4 + %55 = OpConstant %54 0 + %56 = OpTypeVector %54 3 + %94 = OpTypeVector %54 2 + %112 = OpConstantComposite %94 %55 %55 + %57 = OpConstantComposite %56 %55 %55 %55 + %15 = OpTypeImage %54 2D 2 0 0 1 Unknown + %114 = OpTypePointer UniformConstant %15 + %38 = OpTypeSampler + %125 = OpTypePointer UniformConstant %38 + %132 = OpTypeVoid + %133 = OpTypeFunction %132 + %45 = OpTypeSampledImage %15 + %40 = OpVariable %114 UniformConstant + %41 = OpVariable %125 UniformConstant + %2 = OpFunction %132 None %133 + %164 = OpLabel + %184 = OpLoad %15 %40 + %213 = OpLoad %38 %41 + %216 = OpSampledImage %45 %184 %213 + %217 = OpImageSampleImplicitLod %76 %216 %112 Bias %55 + OpReturn + OpFunctionEnd + )"; + void AddConstantUniformFact(protobufs::FactSequence* facts, uint32_t descriptor_set, uint32_t binding, std::vector<uint32_t>&& indices, uint32_t value) { @@ -1591,7 +1632,7 @@ std::vector<fuzzerutil::ModuleSupplier> donor_suppliers; for (auto donor : {&kTestShader1, &kTestShader2, &kTestShader3, &kTestShader4, - &kTestShader5}) { + &kTestShader5, &kTestShader6}) { donor_suppliers.emplace_back([donor]() { return BuildModule(env, kConsoleMessageConsumer, *donor, kFuzzAssembleOption); @@ -1602,8 +1643,9 @@ std::vector<uint32_t> fuzzer_binary_out; protobufs::TransformationSequence fuzzer_transformation_sequence_out; - Fuzzer fuzzer(env, seed, true); - fuzzer.SetMessageConsumer(kSilentConsumer); + spvtools::ValidatorOptions validator_options; + Fuzzer fuzzer(env, seed, true, validator_options); + fuzzer.SetMessageConsumer(kConsoleMessageConsumer); auto fuzzer_result_status = fuzzer.Run(binary_in, initial_facts, donor_suppliers, &fuzzer_binary_out, &fuzzer_transformation_sequence_out); @@ -1613,8 +1655,8 @@ std::vector<uint32_t> replayer_binary_out; protobufs::TransformationSequence replayer_transformation_sequence_out; - Replayer replayer(env, false); - replayer.SetMessageConsumer(kSilentConsumer); + Replayer replayer(env, false, validator_options); + replayer.SetMessageConsumer(kConsoleMessageConsumer); auto replayer_result_status = replayer.Run( binary_in, initial_facts, fuzzer_transformation_sequence_out, &replayer_binary_out, &replayer_transformation_sequence_out); @@ -1681,6 +1723,13 @@ kNumFuzzerRuns); } +TEST(FuzzerReplayerTest, Miscellaneous6) { + // Do some fuzzer runs, starting from an initial seed of 57 (seed value chosen + // arbitrarily). + RunFuzzerAndReplayer(kTestShader6, protobufs::FactSequence(), 57, + kNumFuzzerRuns); +} + } // namespace } // namespace fuzz } // namespace spvtools
diff --git a/test/fuzz/fuzzer_shrinker_test.cpp b/test/fuzz/fuzzer_shrinker_test.cpp index c906a1e..24b4460 100644 --- a/test/fuzz/fuzzer_shrinker_test.cpp +++ b/test/fuzz/fuzzer_shrinker_test.cpp
@@ -979,15 +979,19 @@ // The |step_limit| parameter restricts the number of steps that the shrinker // will try; it can be set to something small for a faster (but less thorough) // test. +// +// The |validator_options| parameter provides validator options that should be +// used during shrinking. void RunAndCheckShrinker( const spv_target_env& target_env, const std::vector<uint32_t>& binary_in, const protobufs::FactSequence& initial_facts, const protobufs::TransformationSequence& transformation_sequence_in, const Shrinker::InterestingnessFunction& interestingness_function, const std::vector<uint32_t>& expected_binary_out, - uint32_t expected_transformations_out_size, uint32_t step_limit) { + uint32_t expected_transformations_out_size, uint32_t step_limit, + spv_validator_options validator_options) { // Run the shrinker. - Shrinker shrinker(target_env, step_limit, false); + Shrinker shrinker(target_env, step_limit, false, validator_options); shrinker.SetMessageConsumer(kSilentConsumer); std::vector<uint32_t> binary_out; @@ -1035,7 +1039,8 @@ // Run the fuzzer and check that it successfully yields a valid binary. std::vector<uint32_t> fuzzer_binary_out; protobufs::TransformationSequence fuzzer_transformation_sequence_out; - Fuzzer fuzzer(env, seed, true); + spvtools::ValidatorOptions validator_options; + Fuzzer fuzzer(env, seed, true, validator_options); fuzzer.SetMessageConsumer(kSilentConsumer); auto fuzzer_result_status = fuzzer.Run(binary_in, initial_facts, donor_suppliers, &fuzzer_binary_out, @@ -1048,9 +1053,10 @@ // With the AlwaysInteresting test, we should quickly shrink to the original // binary with no transformations remaining. - RunAndCheckShrinker( - env, binary_in, initial_facts, fuzzer_transformation_sequence_out, - AlwaysInteresting().AsFunction(), binary_in, 0, kReasonableStepLimit); + RunAndCheckShrinker(env, binary_in, initial_facts, + fuzzer_transformation_sequence_out, + AlwaysInteresting().AsFunction(), binary_in, 0, + kReasonableStepLimit, validator_options); // With the OnlyInterestingFirstTime test, no shrinking should be achieved. RunAndCheckShrinker( @@ -1058,14 +1064,14 @@ OnlyInterestingFirstTime().AsFunction(), fuzzer_binary_out, static_cast<uint32_t>( fuzzer_transformation_sequence_out.transformation_size()), - kReasonableStepLimit); + kReasonableStepLimit, validator_options); // The PingPong test is unpredictable; passing an empty expected binary // means that we don't check anything beyond that shrinking completes // successfully. - RunAndCheckShrinker(env, binary_in, initial_facts, - fuzzer_transformation_sequence_out, - PingPong().AsFunction(), {}, 0, kSmallStepLimit); + RunAndCheckShrinker( + env, binary_in, initial_facts, fuzzer_transformation_sequence_out, + PingPong().AsFunction(), {}, 0, kSmallStepLimit, validator_options); // The InterestingThenRandom test is unpredictable; passing an empty // expected binary means that we do not check anything about shrinking @@ -1073,7 +1079,7 @@ RunAndCheckShrinker( env, binary_in, initial_facts, fuzzer_transformation_sequence_out, InterestingThenRandom(PseudoRandomGenerator(seed)).AsFunction(), {}, 0, - kSmallStepLimit); + kSmallStepLimit, validator_options); } TEST(FuzzerShrinkerTest, Miscellaneous1) {
diff --git a/test/fuzz/transformation_access_chain_test.cpp b/test/fuzz/transformation_access_chain_test.cpp index 516d371..443c31c 100644 --- a/test/fuzz/transformation_access_chain_test.cpp +++ b/test/fuzz/transformation_access_chain_test.cpp
@@ -118,169 +118,194 @@ // Indices 0-5 are in ids 80-85 FactManager fact_manager; - fact_manager.AddFactValueOfPointeeIsIrrelevant(54); + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 54); // Bad: id is not fresh ASSERT_FALSE(TransformationAccessChain( 43, 43, {80}, MakeInstructionDescriptor(24, SpvOpLoad, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer id does not exist ASSERT_FALSE(TransformationAccessChain( 100, 1000, {80}, MakeInstructionDescriptor(24, SpvOpLoad, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer id is not a type ASSERT_FALSE(TransformationAccessChain( 100, 5, {80}, MakeInstructionDescriptor(24, SpvOpLoad, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer id is not a pointer ASSERT_FALSE(TransformationAccessChain( 100, 23, {80}, MakeInstructionDescriptor(24, SpvOpLoad, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: index id does not exist ASSERT_FALSE(TransformationAccessChain( 100, 43, {1000}, MakeInstructionDescriptor(24, SpvOpLoad, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: index id is not a constant ASSERT_FALSE(TransformationAccessChain( 100, 43, {24}, MakeInstructionDescriptor(25, SpvOpIAdd, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: too many indices ASSERT_FALSE( TransformationAccessChain(100, 43, {80, 80, 80}, MakeInstructionDescriptor(24, SpvOpLoad, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: index id is out of bounds ASSERT_FALSE( TransformationAccessChain(100, 43, {80, 83}, MakeInstructionDescriptor(24, SpvOpLoad, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: attempt to insert before variable ASSERT_FALSE(TransformationAccessChain( 100, 34, {}, MakeInstructionDescriptor(36, SpvOpVariable, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer not available ASSERT_FALSE( TransformationAccessChain( 100, 43, {80}, MakeInstructionDescriptor(21, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: instruction descriptor does not identify anything ASSERT_FALSE(TransformationAccessChain( 100, 43, {80}, MakeInstructionDescriptor(24, SpvOpLoad, 100)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer is null ASSERT_FALSE(TransformationAccessChain( 100, 45, {80}, MakeInstructionDescriptor(24, SpvOpLoad, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer is undef ASSERT_FALSE(TransformationAccessChain( 100, 46, {80}, MakeInstructionDescriptor(24, SpvOpLoad, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer to result type does not exist ASSERT_FALSE(TransformationAccessChain( 100, 52, {0}, MakeInstructionDescriptor(24, SpvOpLoad, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); { TransformationAccessChain transformation( 100, 43, {80}, MakeInstructionDescriptor(24, SpvOpLoad, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(100)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(100)); } { TransformationAccessChain transformation( 101, 28, {81}, MakeInstructionDescriptor(42, SpvOpReturn, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(101)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(101)); } { TransformationAccessChain transformation( 102, 36, {80, 81}, MakeInstructionDescriptor(37, SpvOpStore, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(102)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(102)); } { TransformationAccessChain transformation( 103, 44, {}, MakeInstructionDescriptor(44, SpvOpStore, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(103)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(103)); } { TransformationAccessChain transformation( 104, 13, {80}, MakeInstructionDescriptor(21, SpvOpAccessChain, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(104)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(104)); } { TransformationAccessChain transformation( 105, 34, {}, MakeInstructionDescriptor(44, SpvOpStore, 1)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(105)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(105)); } { TransformationAccessChain transformation( 106, 38, {}, MakeInstructionDescriptor(40, SpvOpFunctionCall, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(106)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(106)); } { TransformationAccessChain transformation( 107, 14, {}, MakeInstructionDescriptor(24, SpvOpLoad, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(107)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(107)); } { TransformationAccessChain transformation( 108, 54, {85, 81, 81}, MakeInstructionDescriptor(24, SpvOpLoad, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(108)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(108)); } { TransformationAccessChain transformation( 109, 48, {80, 80}, MakeInstructionDescriptor(24, SpvOpLoad, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(109)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(109)); } std::string after_transformation = R"( @@ -401,19 +426,24 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); { TransformationAccessChain transformation( 100, 11, {}, MakeInstructionDescriptor(5, SpvOpReturn, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } { TransformationAccessChain transformation( 101, 12, {}, MakeInstructionDescriptor(5, SpvOpReturn, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); }
diff --git a/test/fuzz/transformation_add_constant_boolean_test.cpp b/test/fuzz/transformation_add_constant_boolean_test.cpp index f51c46b..c603333 100644 --- a/test/fuzz/transformation_add_constant_boolean_test.cpp +++ b/test/fuzz/transformation_add_constant_boolean_test.cpp
@@ -43,42 +43,47 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // True and false can both be added as neither is present. ASSERT_TRUE(TransformationAddConstantBoolean(7, true).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); ASSERT_TRUE(TransformationAddConstantBoolean(7, false).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); // Id 5 is already taken. ASSERT_FALSE(TransformationAddConstantBoolean(5, true).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); auto add_true = TransformationAddConstantBoolean(7, true); auto add_false = TransformationAddConstantBoolean(8, false); - ASSERT_TRUE(add_true.IsApplicable(context.get(), fact_manager)); - add_true.Apply(context.get(), &fact_manager); + ASSERT_TRUE(add_true.IsApplicable(context.get(), transformation_context)); + add_true.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Having added true, we cannot add it again with the same id. - ASSERT_FALSE(add_true.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(add_true.IsApplicable(context.get(), transformation_context)); // But we can add it with a different id. auto add_true_again = TransformationAddConstantBoolean(100, true); - ASSERT_TRUE(add_true_again.IsApplicable(context.get(), fact_manager)); - add_true_again.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + add_true_again.IsApplicable(context.get(), transformation_context)); + add_true_again.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(add_false.IsApplicable(context.get(), fact_manager)); - add_false.Apply(context.get(), &fact_manager); + ASSERT_TRUE(add_false.IsApplicable(context.get(), transformation_context)); + add_false.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Having added false, we cannot add it again with the same id. - ASSERT_FALSE(add_false.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(add_false.IsApplicable(context.get(), transformation_context)); // But we can add it with a different id. auto add_false_again = TransformationAddConstantBoolean(101, false); - ASSERT_TRUE(add_false_again.IsApplicable(context.get(), fact_manager)); - add_false_again.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + add_false_again.IsApplicable(context.get(), transformation_context)); + add_false_again.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -128,12 +133,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Neither true nor false can be added as OpTypeBool is not present. ASSERT_FALSE(TransformationAddConstantBoolean(6, true).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); ASSERT_FALSE(TransformationAddConstantBoolean(6, false).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); } } // namespace
diff --git a/test/fuzz/transformation_add_constant_composite_test.cpp b/test/fuzz/transformation_add_constant_composite_test.cpp index 5ce171b..021bf58 100644 --- a/test/fuzz/transformation_add_constant_composite_test.cpp +++ b/test/fuzz/transformation_add_constant_composite_test.cpp
@@ -64,19 +64,22 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Too few ids ASSERT_FALSE(TransformationAddConstantComposite(103, 8, {100, 101}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Too many ids ASSERT_FALSE(TransformationAddConstantComposite(101, 7, {14, 15, 14}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Id already in use ASSERT_FALSE(TransformationAddConstantComposite(40, 7, {11, 12}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // %39 is not a type ASSERT_FALSE(TransformationAddConstantComposite(100, 39, {11, 12}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); TransformationAddConstantComposite transformations[] = { // %100 = OpConstantComposite %7 %11 %12 @@ -101,8 +104,9 @@ TransformationAddConstantComposite(106, 35, {38, 39, 40})}; for (auto& transformation : transformations) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } ASSERT_TRUE(IsValid(env, context.get()));
diff --git a/test/fuzz/transformation_add_constant_null_test.cpp b/test/fuzz/transformation_add_constant_null_test.cpp new file mode 100644 index 0000000..0bfee34 --- /dev/null +++ b/test/fuzz/transformation_add_constant_null_test.cpp
@@ -0,0 +1,140 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/transformation_add_constant_null.h" +#include "test/fuzz/fuzz_test_util.h" + +namespace spvtools { +namespace fuzz { +namespace { + +TEST(TransformationAddConstantNullTest, BasicTest) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeInt 32 1 + %8 = OpTypeVector %6 2 + %9 = OpTypeVector %6 3 + %10 = OpTypeVector %6 4 + %11 = OpTypeVector %7 2 + %20 = OpTypeSampler + %21 = OpTypeImage %6 2D 0 0 0 0 Rgba32f + %22 = OpTypeSampledImage %21 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + // Id already in use + ASSERT_FALSE(TransformationAddConstantNull(4, 11).IsApplicable( + context.get(), transformation_context)); + // %1 is not a type + ASSERT_FALSE(TransformationAddConstantNull(100, 1).IsApplicable( + context.get(), transformation_context)); + + // %3 is a function type + ASSERT_FALSE(TransformationAddConstantNull(100, 3).IsApplicable( + context.get(), transformation_context)); + + // %20 is a sampler type + ASSERT_FALSE(TransformationAddConstantNull(100, 20).IsApplicable( + context.get(), transformation_context)); + + // %21 is an image type + ASSERT_FALSE(TransformationAddConstantNull(100, 21).IsApplicable( + context.get(), transformation_context)); + + // %22 is a sampled image type + ASSERT_FALSE(TransformationAddConstantNull(100, 22).IsApplicable( + context.get(), transformation_context)); + + TransformationAddConstantNull transformations[] = { + // %100 = OpConstantNull %6 + TransformationAddConstantNull(100, 6), + + // %101 = OpConstantNull %7 + TransformationAddConstantNull(101, 7), + + // %102 = OpConstantNull %8 + TransformationAddConstantNull(102, 8), + + // %103 = OpConstantNull %9 + TransformationAddConstantNull(103, 9), + + // %104 = OpConstantNull %10 + TransformationAddConstantNull(104, 10), + + // %105 = OpConstantNull %11 + TransformationAddConstantNull(105, 11)}; + + for (auto& transformation : transformations) { + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); + } + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeInt 32 1 + %8 = OpTypeVector %6 2 + %9 = OpTypeVector %6 3 + %10 = OpTypeVector %6 4 + %11 = OpTypeVector %7 2 + %20 = OpTypeSampler + %21 = OpTypeImage %6 2D 0 0 0 0 Rgba32f + %22 = OpTypeSampledImage %21 + %100 = OpConstantNull %6 + %101 = OpConstantNull %7 + %102 = OpConstantNull %8 + %103 = OpConstantNull %9 + %104 = OpConstantNull %10 + %105 = OpConstantNull %11 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +} // namespace +} // namespace fuzz +} // namespace spvtools
diff --git a/test/fuzz/transformation_add_constant_scalar_test.cpp b/test/fuzz/transformation_add_constant_scalar_test.cpp index b156111..5124b7d 100644 --- a/test/fuzz/transformation_add_constant_scalar_test.cpp +++ b/test/fuzz/transformation_add_constant_scalar_test.cpp
@@ -62,6 +62,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); const float float_values[2] = {3.0, 30.0}; uint32_t uint_for_float[2]; @@ -87,55 +90,62 @@ auto bad_type_id_is_pointer = TransformationAddConstantScalar(111, 11, {0}); // Id is already in use. - ASSERT_FALSE(bad_id_already_used.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_id_already_used.IsApplicable(context.get(), transformation_context)); // At least one word of data must be provided. - ASSERT_FALSE(bad_no_data.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(bad_no_data.IsApplicable(context.get(), transformation_context)); // Cannot give two data words for a 32-bit type. - ASSERT_FALSE(bad_too_much_data.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_too_much_data.IsApplicable(context.get(), transformation_context)); // Type id does not exist - ASSERT_FALSE( - bad_type_id_does_not_exist.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(bad_type_id_does_not_exist.IsApplicable(context.get(), + transformation_context)); // Type id is not a type - ASSERT_FALSE( - bad_type_id_is_not_a_type.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(bad_type_id_is_not_a_type.IsApplicable(context.get(), + transformation_context)); // Type id is void - ASSERT_FALSE(bad_type_id_is_void.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_type_id_is_void.IsApplicable(context.get(), transformation_context)); // Type id is pointer - ASSERT_FALSE( - bad_type_id_is_pointer.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(bad_type_id_is_pointer.IsApplicable(context.get(), + transformation_context)); - ASSERT_TRUE(add_signed_int_1.IsApplicable(context.get(), fact_manager)); - add_signed_int_1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + add_signed_int_1.IsApplicable(context.get(), transformation_context)); + add_signed_int_1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(add_signed_int_10.IsApplicable(context.get(), fact_manager)); - add_signed_int_10.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + add_signed_int_10.IsApplicable(context.get(), transformation_context)); + add_signed_int_10.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(add_unsigned_int_2.IsApplicable(context.get(), fact_manager)); - add_unsigned_int_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + add_unsigned_int_2.IsApplicable(context.get(), transformation_context)); + add_unsigned_int_2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(add_unsigned_int_20.IsApplicable(context.get(), fact_manager)); - add_unsigned_int_20.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + add_unsigned_int_20.IsApplicable(context.get(), transformation_context)); + add_unsigned_int_20.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(add_float_3.IsApplicable(context.get(), fact_manager)); - add_float_3.Apply(context.get(), &fact_manager); + ASSERT_TRUE(add_float_3.IsApplicable(context.get(), transformation_context)); + add_float_3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(add_float_30.IsApplicable(context.get(), fact_manager)); - add_float_30.Apply(context.get(), &fact_manager); + ASSERT_TRUE(add_float_30.IsApplicable(context.get(), transformation_context)); + add_float_30.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(bad_add_float_30_id_already_used.IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(bad_add_float_30_id_already_used.IsApplicable( + context.get(), transformation_context)); std::string after_transformation = R"( OpCapability Shader
diff --git a/test/fuzz/transformation_add_dead_block_test.cpp b/test/fuzz/transformation_add_dead_block_test.cpp index f89140f..c9be520 100644 --- a/test/fuzz/transformation_add_dead_block_test.cpp +++ b/test/fuzz/transformation_add_dead_block_test.cpp
@@ -46,21 +46,25 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Id 4 is already in use ASSERT_FALSE(TransformationAddDeadBlock(4, 5, true) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Id 7 is not a block ASSERT_FALSE(TransformationAddDeadBlock(100, 7, true) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); TransformationAddDeadBlock transformation(100, 5, true); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.BlockIsDead(100)); + ASSERT_TRUE(transformation_context.GetFactManager()->BlockIsDead(100)); std::string after_transformation = R"( OpCapability Shader @@ -119,9 +123,12 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); ASSERT_FALSE(TransformationAddDeadBlock(100, 9, true) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadBlockTest, TargetBlockMustNotBeLoopMergeOrContinue) { @@ -160,13 +167,16 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Bad because 9's successor is the loop continue target. ASSERT_FALSE(TransformationAddDeadBlock(100, 9, true) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad because 10's successor is the loop merge. ASSERT_FALSE(TransformationAddDeadBlock(100, 10, true) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadBlockTest, SourceBlockMustNotBeLoopHead) { @@ -203,10 +213,13 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Bad because 8 is a loop head. ASSERT_FALSE(TransformationAddDeadBlock(100, 8, true) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadBlockTest, OpPhiInTarget) { @@ -240,13 +253,17 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationAddDeadBlock transformation(100, 5, true); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.BlockIsDead(100)); + ASSERT_TRUE(transformation_context.GetFactManager()->BlockIsDead(100)); std::string after_transformation = R"( OpCapability Shader @@ -309,11 +326,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // 9 is a back edge block, so it would not be OK to add a dead block here, // as then both 9 and the dead block would branch to the loop header, 8. ASSERT_FALSE(TransformationAddDeadBlock(100, 9, true) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } } // namespace
diff --git a/test/fuzz/transformation_add_dead_break_test.cpp b/test/fuzz/transformation_add_dead_break_test.cpp index d60fc1f..8400b0c 100644 --- a/test/fuzz/transformation_add_dead_break_test.cpp +++ b/test/fuzz/transformation_add_dead_break_test.cpp
@@ -100,44 +100,47 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); const uint32_t merge_block = 16; // These are all possibilities. ASSERT_TRUE(TransformationAddDeadBreak(15, merge_block, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(15, merge_block, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(21, merge_block, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(21, merge_block, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(22, merge_block, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(22, merge_block, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(19, merge_block, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(19, merge_block, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(23, merge_block, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(23, merge_block, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(24, merge_block, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(24, merge_block, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable: 100 is not a block id. ASSERT_FALSE(TransformationAddDeadBreak(100, merge_block, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(15, 100, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable: 24 is not a merge block. ASSERT_FALSE(TransformationAddDeadBreak(15, 24, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // These are the transformations we will apply. auto transformation1 = TransformationAddDeadBreak(15, merge_block, true, {}); @@ -147,28 +150,34 @@ auto transformation5 = TransformationAddDeadBreak(23, merge_block, true, {}); auto transformation6 = TransformationAddDeadBreak(24, merge_block, false, {}); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation4.IsApplicable(context.get(), fact_manager)); - transformation4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation4.IsApplicable(context.get(), transformation_context)); + transformation4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation5.IsApplicable(context.get(), fact_manager)); - transformation5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation5.IsApplicable(context.get(), transformation_context)); + transformation5.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation6.IsApplicable(context.get(), fact_manager)); - transformation6.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation6.IsApplicable(context.get(), transformation_context)); + transformation6.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -333,6 +342,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // The header and merge blocks const uint32_t header_inner = 34; @@ -354,53 +366,53 @@ // Fine to break from a construct to its merge ASSERT_TRUE(TransformationAddDeadBreak(inner_block_1, merge_inner, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(inner_block_2, merge_inner, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(outer_block_1, merge_outer, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(outer_block_2, merge_outer, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(outer_block_3, merge_outer, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(outer_block_4, merge_outer, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(after_block_1, merge_after, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(after_block_2, merge_after, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not OK to break to the wrong merge (whether enclosing or not) ASSERT_FALSE(TransformationAddDeadBreak(inner_block_1, merge_outer, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(inner_block_2, merge_after, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(outer_block_1, merge_inner, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(outer_block_2, merge_after, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(after_block_1, merge_inner, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(after_block_2, merge_outer, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not OK to break from header (as it does not branch unconditionally) ASSERT_FALSE(TransformationAddDeadBreak(header_inner, merge_inner, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(header_outer, merge_outer, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(header_after, merge_after, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not OK to break to non-merge ASSERT_FALSE( TransformationAddDeadBreak(inner_block_1, inner_block_2, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationAddDeadBreak(outer_block_2, after_block_1, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(outer_block_1, header_after, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); auto transformation1 = TransformationAddDeadBreak(inner_block_1, merge_inner, true, {}); @@ -419,36 +431,44 @@ auto transformation8 = TransformationAddDeadBreak(after_block_2, merge_after, false, {}); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation4.IsApplicable(context.get(), fact_manager)); - transformation4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation4.IsApplicable(context.get(), transformation_context)); + transformation4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation5.IsApplicable(context.get(), fact_manager)); - transformation5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation5.IsApplicable(context.get(), transformation_context)); + transformation5.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation6.IsApplicable(context.get(), fact_manager)); - transformation6.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation6.IsApplicable(context.get(), transformation_context)); + transformation6.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation7.IsApplicable(context.get(), fact_manager)); - transformation7.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation7.IsApplicable(context.get(), transformation_context)); + transformation7.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation8.IsApplicable(context.get(), fact_manager)); - transformation8.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation8.IsApplicable(context.get(), transformation_context)); + transformation8.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -685,6 +705,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // The header and merge blocks const uint32_t header_outer_if = 5; @@ -715,63 +738,63 @@ // Fine to branch straight to direct merge block for a construct ASSERT_TRUE(TransformationAddDeadBreak(then_outer_switch_block_1, merge_then_outer_switch, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(then_inner_switch_block_1, merge_then_inner_switch, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(then_inner_switch_block_2, merge_then_inner_switch, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(then_inner_switch_block_3, merge_then_inner_switch, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(else_switch_block_1, merge_else_switch, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(else_switch_block_2, merge_else_switch, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(else_switch_block_3, merge_else_switch, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationAddDeadBreak(inner_if_1_block_1, merge_inner_if_1, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(inner_if_1_block_2, merge_inner_if_1, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationAddDeadBreak(inner_if_2_block_1, merge_inner_if_2, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not OK to break out of a switch from a selection construct inside the // switch. ASSERT_FALSE(TransformationAddDeadBreak(inner_if_1_block_1, merge_then_outer_switch, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(inner_if_1_block_2, merge_then_outer_switch, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(inner_if_2_block_1, merge_then_outer_switch, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Some miscellaneous inapplicable cases. ASSERT_FALSE( TransformationAddDeadBreak(header_outer_if, merge_outer_if, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(header_inner_if_1, inner_if_1_block_2, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(header_then_inner_switch, header_then_outer_switch, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(header_else_switch, then_inner_switch_block_3, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(header_inner_if_2, header_inner_if_2, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); auto transformation1 = TransformationAddDeadBreak( then_outer_switch_block_1, merge_then_outer_switch, true, {}); @@ -794,44 +817,54 @@ auto transformation10 = TransformationAddDeadBreak( inner_if_2_block_1, merge_inner_if_2, true, {}); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation4.IsApplicable(context.get(), fact_manager)); - transformation4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation4.IsApplicable(context.get(), transformation_context)); + transformation4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation5.IsApplicable(context.get(), fact_manager)); - transformation5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation5.IsApplicable(context.get(), transformation_context)); + transformation5.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation6.IsApplicable(context.get(), fact_manager)); - transformation6.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation6.IsApplicable(context.get(), transformation_context)); + transformation6.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation7.IsApplicable(context.get(), fact_manager)); - transformation7.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation7.IsApplicable(context.get(), transformation_context)); + transformation7.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation8.IsApplicable(context.get(), fact_manager)); - transformation8.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation8.IsApplicable(context.get(), transformation_context)); + transformation8.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation9.IsApplicable(context.get(), fact_manager)); - transformation9.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation9.IsApplicable(context.get(), transformation_context)); + transformation9.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation10.IsApplicable(context.get(), fact_manager)); - transformation10.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation10.IsApplicable(context.get(), transformation_context)); + transformation10.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -1094,6 +1127,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // The header and merge blocks const uint32_t header_do_while = 6; @@ -1123,75 +1159,75 @@ // Fine to break from any loop header to its merge ASSERT_TRUE( TransformationAddDeadBreak(header_do_while, merge_do_while, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(header_for_i, merge_for_i, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadBreak(header_for_j, merge_for_j, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Fine to break from any of the blocks in constructs in the "for j" loop to // that loop's merge ASSERT_TRUE( TransformationAddDeadBreak(block_in_inner_if, merge_for_j, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationAddDeadBreak(block_switch_case, merge_for_j, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationAddDeadBreak(block_switch_default, merge_for_j, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Fine to break from the body of the "for i" loop to that loop's merge ASSERT_TRUE( TransformationAddDeadBreak(block_in_for_i_loop, merge_for_i, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not OK to break from multiple loops ASSERT_FALSE( TransformationAddDeadBreak(block_in_inner_if, merge_do_while, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationAddDeadBreak(block_switch_case, merge_do_while, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(block_switch_default, merge_do_while, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationAddDeadBreak(header_for_j, merge_do_while, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not OK to break loop from its continue construct ASSERT_FALSE( TransformationAddDeadBreak(continue_do_while, merge_do_while, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationAddDeadBreak(continue_for_j, merge_for_j, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(continue_for_i, merge_for_i, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not OK to break out of multiple non-loop constructs if not breaking to a // loop merge ASSERT_FALSE( TransformationAddDeadBreak(block_in_inner_if, merge_if_x_eq_y, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationAddDeadBreak(block_switch_case, merge_if_x_eq_y, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(block_switch_default, merge_if_x_eq_y, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Some miscellaneous inapplicable transformations ASSERT_FALSE( TransformationAddDeadBreak(header_if_x_eq_2, header_if_x_eq_y, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationAddDeadBreak(merge_if_x_eq_2, merge_switch, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationAddDeadBreak(header_switch, header_switch, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); auto transformation1 = TransformationAddDeadBreak(header_do_while, merge_do_while, true, {}); @@ -1208,32 +1244,39 @@ auto transformation7 = TransformationAddDeadBreak(block_in_for_i_loop, merge_for_i, true, {}); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation4.IsApplicable(context.get(), fact_manager)); - transformation4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation4.IsApplicable(context.get(), transformation_context)); + transformation4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation5.IsApplicable(context.get(), fact_manager)); - transformation5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation5.IsApplicable(context.get(), transformation_context)); + transformation5.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation6.IsApplicable(context.get(), fact_manager)); - transformation6.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation6.IsApplicable(context.get(), transformation_context)); + transformation6.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation7.IsApplicable(context.get(), fact_manager)); - transformation7.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation7.IsApplicable(context.get(), transformation_context)); + transformation7.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -1421,12 +1464,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Not OK to break loop from its continue construct ASSERT_FALSE(TransformationAddDeadBreak(13, 12, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(23, 12, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadBreakTest, SelectionInContinueConstruct) { @@ -1509,6 +1555,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); const uint32_t loop_merge = 12; const uint32_t selection_merge = 24; @@ -1520,13 +1569,13 @@ // Not OK to jump from the selection to the loop merge, as this would break // from the loop's continue construct. ASSERT_FALSE(TransformationAddDeadBreak(in_selection_1, loop_merge, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(in_selection_2, loop_merge, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(in_selection_3, loop_merge, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationAddDeadBreak(in_selection_4, loop_merge, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // But fine to jump from the selection to its merge. @@ -1539,20 +1588,24 @@ auto transformation4 = TransformationAddDeadBreak(in_selection_4, selection_merge, true, {}); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation4.IsApplicable(context.get(), fact_manager)); - transformation4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation4.IsApplicable(context.get(), transformation_context)); + transformation4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -1720,6 +1773,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); const uint32_t outer_loop_merge = 34; const uint32_t outer_loop_block = 33; @@ -1729,22 +1785,24 @@ // Some inapplicable cases ASSERT_FALSE( TransformationAddDeadBreak(inner_loop_block, outer_loop_merge, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationAddDeadBreak(outer_loop_block, inner_loop_merge, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); auto transformation1 = TransformationAddDeadBreak(inner_loop_block, inner_loop_merge, true, {}); auto transformation2 = TransformationAddDeadBreak(outer_loop_block, outer_loop_merge, true, {}); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -1936,36 +1994,39 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Some inapplicable transformations // Not applicable because there is already an edge 19->20, so the OpPhis at 20 // do not need to be updated ASSERT_FALSE(TransformationAddDeadBreak(19, 20, true, {13, 21}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not applicable because two OpPhis (not zero) need to be updated at 20 ASSERT_FALSE(TransformationAddDeadBreak(23, 20, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not applicable because two OpPhis (not just one) need to be updated at 20 ASSERT_FALSE(TransformationAddDeadBreak(23, 20, true, {13}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not applicable because the given ids do not have types that match the // OpPhis at 20, in order ASSERT_FALSE(TransformationAddDeadBreak(23, 20, true, {21, 13}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not applicable because id 23 is a label ASSERT_FALSE(TransformationAddDeadBreak(23, 20, true, {21, 23}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not applicable because 101 is not an id ASSERT_FALSE(TransformationAddDeadBreak(23, 20, true, {21, 101}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not applicable because ids 51 and 47 are not available at the end of block // 23 ASSERT_FALSE(TransformationAddDeadBreak(23, 20, true, {51, 47}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not applicable because OpConstantFalse is not present in the module ASSERT_FALSE(TransformationAddDeadBreak(19, 20, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); auto transformation1 = TransformationAddDeadBreak(19, 20, true, {}); auto transformation2 = TransformationAddDeadBreak(23, 20, true, {13, 21}); @@ -1973,24 +2034,29 @@ auto transformation4 = TransformationAddDeadBreak(30, 31, true, {21, 13}); auto transformation5 = TransformationAddDeadBreak(75, 31, true, {47, 51}); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation4.IsApplicable(context.get(), fact_manager)); - transformation4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation4.IsApplicable(context.get(), transformation_context)); + transformation4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation5.IsApplicable(context.get(), fact_manager)); - transformation5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation5.IsApplicable(context.get(), transformation_context)); + transformation5.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -2119,9 +2185,13 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto bad_transformation = TransformationAddDeadBreak(100, 101, false, {}); - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadBreakTest, RespectDominanceRules2) { @@ -2172,9 +2242,13 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto bad_transformation = TransformationAddDeadBreak(102, 101, false, {}); - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadBreakTest, RespectDominanceRules3) { @@ -2219,11 +2293,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto good_transformation = TransformationAddDeadBreak(100, 101, false, {11}); - ASSERT_TRUE(good_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_TRUE( + good_transformation.IsApplicable(context.get(), transformation_context)); - good_transformation.Apply(context.get(), &fact_manager); + good_transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -2307,11 +2385,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto good_transformation = TransformationAddDeadBreak(102, 101, false, {11}); - ASSERT_TRUE(good_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_TRUE( + good_transformation.IsApplicable(context.get(), transformation_context)); - good_transformation.Apply(context.get(), &fact_manager); + good_transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -2389,9 +2471,13 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto bad_transformation = TransformationAddDeadBreak(100, 101, false, {}); - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadBreakTest, RespectDominanceRules6) { @@ -2446,9 +2532,13 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto bad_transformation = TransformationAddDeadBreak(102, 101, false, {}); - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadBreakTest, RespectDominanceRules7) { @@ -2505,9 +2595,13 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto bad_transformation = TransformationAddDeadBreak(102, 101, false, {}); - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadBreakTest, RespectDominanceRules8) { @@ -2551,9 +2645,13 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto bad_transformation = TransformationAddDeadBreak(102, 101, false, {}); - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadBreakTest, @@ -2597,12 +2695,16 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Bad because 14 comes before 12 in the module, and 14 has no predecessors. // This means that an edge from 12 to 14 will lead to 12 dominating 14, which // is illegal if 12 appears after 14. auto bad_transformation = TransformationAddDeadBreak(12, 14, true, {}); - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } } // namespace
diff --git a/test/fuzz/transformation_add_dead_continue_test.cpp b/test/fuzz/transformation_add_dead_continue_test.cpp index ff93da8..07ee3b1 100644 --- a/test/fuzz/transformation_add_dead_continue_test.cpp +++ b/test/fuzz/transformation_add_dead_continue_test.cpp
@@ -97,57 +97,63 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // These are all possibilities. ASSERT_TRUE(TransformationAddDeadContinue(11, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadContinue(11, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadContinue(12, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadContinue(12, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadContinue(40, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationAddDeadContinue(40, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable: 100 is not a block id. ASSERT_FALSE(TransformationAddDeadContinue(100, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable: 10 is not in a loop. ASSERT_FALSE(TransformationAddDeadContinue(10, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable: 15 does not branch unconditionally to a single successor. ASSERT_FALSE(TransformationAddDeadContinue(15, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable: 13 is not in a loop and has no successor. ASSERT_FALSE(TransformationAddDeadContinue(13, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable: 14 is the loop continue target, so it's not OK to jump to // the loop continue from there. ASSERT_FALSE(TransformationAddDeadContinue(14, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // These are the transformations we will apply. auto transformation1 = TransformationAddDeadContinue(11, true, {}); auto transformation2 = TransformationAddDeadContinue(12, false, {}); auto transformation3 = TransformationAddDeadContinue(40, true, {}); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -365,19 +371,24 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); std::vector<uint32_t> good = {6, 7, 18, 20, 34, 40, 45, 46, 47, 56, 57}; std::vector<uint32_t> bad = {5, 8, 9, 19, 21, 22, 33, 41, 58, 59, 60}; for (uint32_t from_block : bad) { ASSERT_FALSE(TransformationAddDeadContinue(from_block, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } for (uint32_t from_block : good) { const TransformationAddDeadContinue transformation(from_block, true, {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } std::string after_transformation = R"( @@ -600,19 +611,24 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); std::vector<uint32_t> good = {32, 33, 46, 52, 101}; std::vector<uint32_t> bad = {5, 34, 36, 35, 47, 49, 48}; for (uint32_t from_block : bad) { ASSERT_FALSE(TransformationAddDeadContinue(from_block, false, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } for (uint32_t from_block : good) { const TransformationAddDeadContinue transformation(from_block, false, {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } std::string after_transformation = R"( @@ -806,6 +822,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); std::vector<uint32_t> bad = {5, 19, 20, 23, 31, 32, 33, 70}; @@ -813,24 +832,28 @@ for (uint32_t from_block : bad) { ASSERT_FALSE(TransformationAddDeadContinue(from_block, true, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } auto transformation1 = TransformationAddDeadContinue(29, true, {13, 21}); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); auto transformation2 = TransformationAddDeadContinue(30, true, {22, 46}); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); // 75 already has the continue block as a successor, so we should not provide // phi ids. auto transformationBad = TransformationAddDeadContinue(75, true, {27, 46}); - ASSERT_FALSE(transformationBad.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformationBad.IsApplicable(context.get(), transformation_context)); auto transformation3 = TransformationAddDeadContinue(75, true, {}); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); std::string after_transformation = R"( OpCapability Shader @@ -974,26 +997,33 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // This transformation is not applicable because the dead continue from the // loop body prevents the definition of %23 later in the loop body from // dominating its use in the loop's continue target. auto bad_transformation = TransformationAddDeadContinue(13, false, {}); - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); auto good_transformation_1 = TransformationAddDeadContinue(7, false, {}); - ASSERT_TRUE(good_transformation_1.IsApplicable(context.get(), fact_manager)); - good_transformation_1.Apply(context.get(), &fact_manager); + ASSERT_TRUE(good_transformation_1.IsApplicable(context.get(), + transformation_context)); + good_transformation_1.Apply(context.get(), &transformation_context); auto good_transformation_2 = TransformationAddDeadContinue(22, false, {}); - ASSERT_TRUE(good_transformation_2.IsApplicable(context.get(), fact_manager)); - good_transformation_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE(good_transformation_2.IsApplicable(context.get(), + transformation_context)); + good_transformation_2.Apply(context.get(), &transformation_context); // This transformation is OK, because the definition of %21 in the loop body // is only used in an OpPhi in the loop's continue target. auto good_transformation_3 = TransformationAddDeadContinue(6, false, {11}); - ASSERT_TRUE(good_transformation_3.IsApplicable(context.get(), fact_manager)); - good_transformation_3.Apply(context.get(), &fact_manager); + ASSERT_TRUE(good_transformation_3.IsApplicable(context.get(), + transformation_context)); + good_transformation_3.Apply(context.get(), &transformation_context); std::string after_transformations = R"( OpCapability Shader @@ -1083,11 +1113,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // This transformation would shortcut the part of the loop body that defines // an id used after the loop. auto bad_transformation = TransformationAddDeadContinue(100, false, {}); - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadContinueTest, RespectDominanceRules3) { @@ -1131,11 +1165,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // This transformation would shortcut the part of the loop body that defines // an id used after the loop. auto bad_transformation = TransformationAddDeadContinue(100, false, {}); - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadContinueTest, Miscellaneous1) { @@ -1270,11 +1308,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // This transformation would shortcut the part of the loop body that defines // an id used in the continue target. auto bad_transformation = TransformationAddDeadContinue(165, false, {}); - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadContinueTest, Miscellaneous2) { @@ -1336,11 +1378,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // This transformation would introduce a branch from a continue target to // itself. auto bad_transformation = TransformationAddDeadContinue(1554, true, {}); - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadContinueTest, Miscellaneous3) { @@ -1394,13 +1440,17 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto bad_transformation = TransformationAddDeadContinue(299, false, {}); // The continue edge would connect %299 to the previously-unreachable %236, // making %299 dominate %236, and breaking the rule that block ordering must // respect dominance. - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadContinueTest, Miscellaneous4) { @@ -1454,13 +1504,17 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto bad_transformation = TransformationAddDeadContinue(10, false, {}); // The continue edge would connect %10 to the previously-unreachable %13, // making %10 dominate %13, and breaking the rule that block ordering must // respect dominance. - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadContinueTest, Miscellaneous5) { @@ -1506,12 +1560,16 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto bad_transformation = TransformationAddDeadContinue(110, true, {}); // The continue edge would lead to the use of %200 in block %101 no longer // being dominated by its definition in block %111. - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddDeadContinueTest, Miscellaneous6) { @@ -1551,10 +1609,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto bad_transformation = TransformationAddDeadContinue(10, true, {}); - ASSERT_FALSE(bad_transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_transformation.IsApplicable(context.get(), transformation_context)); } } // namespace
diff --git a/test/fuzz/transformation_add_function_test.cpp b/test/fuzz/transformation_add_function_test.cpp index aed12dc..bbd915b 100644 --- a/test/fuzz/transformation_add_function_test.cpp +++ b/test/fuzz/transformation_add_function_test.cpp
@@ -59,13 +59,14 @@ } // Returns true if and only if every pointer parameter and variable associated -// with |function_id| in |context| is known by |fact_manager| to be irrelevant, -// with the exception of |loop_limiter_id|, which must not be irrelevant. (It -// can be 0 if no loop limiter is expected, and 0 should not be deemed -// irrelevant). +// with |function_id| in |context| is known by |transformation_context| to be +// irrelevant, with the exception of |loop_limiter_id|, which must not be +// irrelevant. (It can be 0 if no loop limiter is expected, and 0 should not be +// deemed irrelevant). bool AllVariablesAndParametersExceptLoopLimiterAreIrrelevant( - opt::IRContext* context, const FactManager& fact_manager, - uint32_t function_id, uint32_t loop_limiter_id) { + opt::IRContext* context, + const TransformationContext& transformation_context, uint32_t function_id, + uint32_t loop_limiter_id) { // Look at all the functions until the function of interest is found. for (auto& function : *context->module()) { if (function.result_id() != function_id) { @@ -73,15 +74,16 @@ } // Check that the parameters are all irrelevant. bool found_non_irrelevant_parameter = false; - function.ForEachParam( - [context, &fact_manager, - &found_non_irrelevant_parameter](opt::Instruction* inst) { - if (context->get_def_use_mgr()->GetDef(inst->type_id())->opcode() == - SpvOpTypePointer && - !fact_manager.PointeeValueIsIrrelevant(inst->result_id())) { - found_non_irrelevant_parameter = true; - } - }); + function.ForEachParam([context, &transformation_context, + &found_non_irrelevant_parameter]( + opt::Instruction* inst) { + if (context->get_def_use_mgr()->GetDef(inst->type_id())->opcode() == + SpvOpTypePointer && + !transformation_context.GetFactManager()->PointeeValueIsIrrelevant( + inst->result_id())) { + found_non_irrelevant_parameter = true; + } + }); if (found_non_irrelevant_parameter) { // A non-irrelevant parameter was found. return false; @@ -96,7 +98,8 @@ // The variable should be irrelevant if and only if it is not the loop // limiter. if ((inst.result_id() == loop_limiter_id) == - fact_manager.PointeeValueIsIrrelevant(inst.result_id())) { + transformation_context.GetFactManager()->PointeeValueIsIrrelevant( + inst.result_id())) { return false; } } @@ -142,6 +145,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationAddFunction transformation1(std::vector<protobufs::Instruction>( {MakeInstructionMessage( @@ -212,8 +218,9 @@ {{SPV_OPERAND_TYPE_ID, {39}}}), MakeInstructionMessage(SpvOpFunctionEnd, 0, 0, {})})); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation1 = R"( @@ -278,12 +285,12 @@ OpFunctionEnd )"; ASSERT_TRUE(IsEqual(env, after_transformation1, context.get())); - ASSERT_TRUE(fact_manager.BlockIsDead(14)); - ASSERT_TRUE(fact_manager.BlockIsDead(21)); - ASSERT_TRUE(fact_manager.BlockIsDead(22)); - ASSERT_TRUE(fact_manager.BlockIsDead(23)); - ASSERT_TRUE(fact_manager.BlockIsDead(24)); - ASSERT_TRUE(fact_manager.BlockIsDead(25)); + ASSERT_TRUE(transformation_context.GetFactManager()->BlockIsDead(14)); + ASSERT_TRUE(transformation_context.GetFactManager()->BlockIsDead(21)); + ASSERT_TRUE(transformation_context.GetFactManager()->BlockIsDead(22)); + ASSERT_TRUE(transformation_context.GetFactManager()->BlockIsDead(23)); + ASSERT_TRUE(transformation_context.GetFactManager()->BlockIsDead(24)); + ASSERT_TRUE(transformation_context.GetFactManager()->BlockIsDead(25)); TransformationAddFunction transformation2(std::vector<protobufs::Instruction>( {MakeInstructionMessage( @@ -332,8 +339,9 @@ MakeInstructionMessage(SpvOpReturn, 0, 0, {}), MakeInstructionMessage(SpvOpFunctionEnd, 0, 0, {})})); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation2 = R"( @@ -414,7 +422,7 @@ OpFunctionEnd )"; ASSERT_TRUE(IsEqual(env, after_transformation2, context.get())); - ASSERT_TRUE(fact_manager.BlockIsDead(16)); + ASSERT_TRUE(transformation_context.GetFactManager()->BlockIsDead(16)); } TEST(TransformationAddFunctionTest, InapplicableTransformations) { @@ -486,11 +494,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // No instructions ASSERT_FALSE( TransformationAddFunction(std::vector<protobufs::Instruction>({})) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // No function begin ASSERT_FALSE( @@ -499,7 +510,7 @@ {MakeInstructionMessage(SpvOpFunctionParameter, 7, 11, {}), MakeInstructionMessage(SpvOpFunctionParameter, 9, 12, {}), MakeInstructionMessage(SpvOpLabel, 0, 14, {})})) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // No OpLabel ASSERT_FALSE( @@ -512,7 +523,7 @@ MakeInstructionMessage(SpvOpReturnValue, 0, 0, {{SPV_OPERAND_TYPE_ID, {39}}}), MakeInstructionMessage(SpvOpFunctionEnd, 0, 0, {})})) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Abrupt end of instructions ASSERT_FALSE(TransformationAddFunction( @@ -521,7 +532,7 @@ {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, {SpvFunctionControlMaskNone}}, {SPV_OPERAND_TYPE_ID, {10}}})})) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // No function end ASSERT_FALSE( @@ -534,7 +545,7 @@ MakeInstructionMessage(SpvOpLabel, 0, 14, {}), MakeInstructionMessage(SpvOpReturnValue, 0, 0, {{SPV_OPERAND_TYPE_ID, {39}}})})) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } TEST(TransformationAddFunctionTest, LoopLimiters) { @@ -622,20 +633,27 @@ FactManager fact_manager1; FactManager fact_manager2; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context1(&fact_manager1, + validator_options); + TransformationContext transformation_context2(&fact_manager2, + validator_options); const auto context1 = BuildModule(env, consumer, shader, kFuzzAssembleOption); const auto context2 = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context1.get())); TransformationAddFunction add_dead_function(instructions); - ASSERT_TRUE(add_dead_function.IsApplicable(context1.get(), fact_manager1)); - add_dead_function.Apply(context1.get(), &fact_manager1); + ASSERT_TRUE( + add_dead_function.IsApplicable(context1.get(), transformation_context1)); + add_dead_function.Apply(context1.get(), &transformation_context1); ASSERT_TRUE(IsValid(env, context1.get())); // The added function should not be deemed livesafe. - ASSERT_FALSE(fact_manager1.FunctionIsLivesafe(30)); + ASSERT_FALSE( + transformation_context1.GetFactManager()->FunctionIsLivesafe(30)); // All variables/parameters in the function should be deemed irrelevant. ASSERT_TRUE(AllVariablesAndParametersExceptLoopLimiterAreIrrelevant( - context1.get(), fact_manager1, 30, 0)); + context1.get(), transformation_context1, 30, 0)); std::string added_as_dead_code = R"( OpCapability Shader @@ -711,16 +729,16 @@ TransformationAddFunction add_livesafe_function(instructions, 100, 10, loop_limiters, 0, {}); - ASSERT_TRUE( - add_livesafe_function.IsApplicable(context2.get(), fact_manager2)); - add_livesafe_function.Apply(context2.get(), &fact_manager2); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context2.get(), + transformation_context2)); + add_livesafe_function.Apply(context2.get(), &transformation_context2); ASSERT_TRUE(IsValid(env, context2.get())); // The added function should indeed be deemed livesafe. - ASSERT_TRUE(fact_manager2.FunctionIsLivesafe(30)); + ASSERT_TRUE(transformation_context2.GetFactManager()->FunctionIsLivesafe(30)); // All variables/parameters in the function should be deemed irrelevant, // except the loop limiter. ASSERT_TRUE(AllVariablesAndParametersExceptLoopLimiterAreIrrelevant( - context2.get(), fact_manager2, 30, 100)); + context2.get(), transformation_context2, 30, 100)); std::string added_as_livesafe_code = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -837,20 +855,27 @@ FactManager fact_manager1; FactManager fact_manager2; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context1(&fact_manager1, + validator_options); + TransformationContext transformation_context2(&fact_manager2, + validator_options); const auto context1 = BuildModule(env, consumer, shader, kFuzzAssembleOption); const auto context2 = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context1.get())); TransformationAddFunction add_dead_function(instructions); - ASSERT_TRUE(add_dead_function.IsApplicable(context1.get(), fact_manager1)); - add_dead_function.Apply(context1.get(), &fact_manager1); + ASSERT_TRUE( + add_dead_function.IsApplicable(context1.get(), transformation_context1)); + add_dead_function.Apply(context1.get(), &transformation_context1); ASSERT_TRUE(IsValid(env, context1.get())); // The added function should not be deemed livesafe. - ASSERT_FALSE(fact_manager1.FunctionIsLivesafe(10)); + ASSERT_FALSE( + transformation_context1.GetFactManager()->FunctionIsLivesafe(10)); // All variables/parameters in the function should be deemed irrelevant. ASSERT_TRUE(AllVariablesAndParametersExceptLoopLimiterAreIrrelevant( - context1.get(), fact_manager1, 10, 0)); + context1.get(), transformation_context1, 10, 0)); std::string added_as_dead_code = R"( OpCapability Shader @@ -887,15 +912,15 @@ TransformationAddFunction add_livesafe_function(instructions, 0, 0, {}, 0, {}); - ASSERT_TRUE( - add_livesafe_function.IsApplicable(context2.get(), fact_manager2)); - add_livesafe_function.Apply(context2.get(), &fact_manager2); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context2.get(), + transformation_context2)); + add_livesafe_function.Apply(context2.get(), &transformation_context2); ASSERT_TRUE(IsValid(env, context2.get())); // The added function should indeed be deemed livesafe. - ASSERT_TRUE(fact_manager2.FunctionIsLivesafe(10)); + ASSERT_TRUE(transformation_context2.GetFactManager()->FunctionIsLivesafe(10)); // All variables/parameters in the function should be deemed irrelevant. ASSERT_TRUE(AllVariablesAndParametersExceptLoopLimiterAreIrrelevant( - context2.get(), fact_manager2, 10, 0)); + context2.get(), transformation_context2, 10, 0)); std::string added_as_livesafe_code = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -985,20 +1010,27 @@ FactManager fact_manager1; FactManager fact_manager2; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context1(&fact_manager1, + validator_options); + TransformationContext transformation_context2(&fact_manager2, + validator_options); const auto context1 = BuildModule(env, consumer, shader, kFuzzAssembleOption); const auto context2 = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context1.get())); TransformationAddFunction add_dead_function(instructions); - ASSERT_TRUE(add_dead_function.IsApplicable(context1.get(), fact_manager1)); - add_dead_function.Apply(context1.get(), &fact_manager1); + ASSERT_TRUE( + add_dead_function.IsApplicable(context1.get(), transformation_context1)); + add_dead_function.Apply(context1.get(), &transformation_context1); ASSERT_TRUE(IsValid(env, context1.get())); // The added function should not be deemed livesafe. - ASSERT_FALSE(fact_manager1.FunctionIsLivesafe(10)); + ASSERT_FALSE( + transformation_context1.GetFactManager()->FunctionIsLivesafe(10)); // All variables/parameters in the function should be deemed irrelevant. ASSERT_TRUE(AllVariablesAndParametersExceptLoopLimiterAreIrrelevant( - context1.get(), fact_manager1, 10, 0)); + context1.get(), transformation_context1, 10, 0)); std::string added_as_dead_code = R"( OpCapability Shader @@ -1036,15 +1068,15 @@ TransformationAddFunction add_livesafe_function(instructions, 0, 0, {}, 13, {}); - ASSERT_TRUE( - add_livesafe_function.IsApplicable(context2.get(), fact_manager2)); - add_livesafe_function.Apply(context2.get(), &fact_manager2); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context2.get(), + transformation_context2)); + add_livesafe_function.Apply(context2.get(), &transformation_context2); ASSERT_TRUE(IsValid(env, context2.get())); // The added function should indeed be deemed livesafe. - ASSERT_TRUE(fact_manager2.FunctionIsLivesafe(10)); + ASSERT_TRUE(transformation_context2.GetFactManager()->FunctionIsLivesafe(10)); // All variables/parameters in the function should be deemed irrelevant. ASSERT_TRUE(AllVariablesAndParametersExceptLoopLimiterAreIrrelevant( - context2.get(), fact_manager2, 10, 0)); + context2.get(), transformation_context2, 10, 0)); std::string added_as_livesafe_code = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -1265,20 +1297,27 @@ FactManager fact_manager1; FactManager fact_manager2; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context1(&fact_manager1, + validator_options); + TransformationContext transformation_context2(&fact_manager2, + validator_options); const auto context1 = BuildModule(env, consumer, shader, kFuzzAssembleOption); const auto context2 = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context1.get())); TransformationAddFunction add_dead_function(instructions); - ASSERT_TRUE(add_dead_function.IsApplicable(context1.get(), fact_manager1)); - add_dead_function.Apply(context1.get(), &fact_manager1); + ASSERT_TRUE( + add_dead_function.IsApplicable(context1.get(), transformation_context1)); + add_dead_function.Apply(context1.get(), &transformation_context1); ASSERT_TRUE(IsValid(env, context1.get())); // The function should not be deemed livesafe - ASSERT_FALSE(fact_manager1.FunctionIsLivesafe(12)); + ASSERT_FALSE( + transformation_context1.GetFactManager()->FunctionIsLivesafe(12)); // All variables/parameters in the function should be deemed irrelevant. ASSERT_TRUE(AllVariablesAndParametersExceptLoopLimiterAreIrrelevant( - context1.get(), fact_manager1, 12, 0)); + context1.get(), transformation_context1, 12, 0)); std::string added_as_dead_code = R"( OpCapability Shader @@ -1409,15 +1448,15 @@ TransformationAddFunction add_livesafe_function(instructions, 0, 0, {}, 13, access_chain_clamping_info); - ASSERT_TRUE( - add_livesafe_function.IsApplicable(context2.get(), fact_manager2)); - add_livesafe_function.Apply(context2.get(), &fact_manager2); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context2.get(), + transformation_context2)); + add_livesafe_function.Apply(context2.get(), &transformation_context2); ASSERT_TRUE(IsValid(env, context2.get())); // The function should be deemed livesafe - ASSERT_TRUE(fact_manager2.FunctionIsLivesafe(12)); + ASSERT_TRUE(transformation_context2.GetFactManager()->FunctionIsLivesafe(12)); // All variables/parameters in the function should be deemed irrelevant. ASSERT_TRUE(AllVariablesAndParametersExceptLoopLimiterAreIrrelevant( - context2.get(), fact_manager2, 12, 0)); + context2.get(), transformation_context2, 12, 0)); std::string added_as_livesafe_code = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -1585,23 +1624,29 @@ FactManager fact_manager1; FactManager fact_manager2; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context1(&fact_manager1, + validator_options); + TransformationContext transformation_context2(&fact_manager2, + validator_options); // Mark function 6 as livesafe. - fact_manager2.AddFactFunctionIsLivesafe(6); + transformation_context2.GetFactManager()->AddFactFunctionIsLivesafe(6); const auto context1 = BuildModule(env, consumer, shader, kFuzzAssembleOption); const auto context2 = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context1.get())); TransformationAddFunction add_dead_function(instructions); - ASSERT_TRUE(add_dead_function.IsApplicable(context1.get(), fact_manager1)); - add_dead_function.Apply(context1.get(), &fact_manager1); + ASSERT_TRUE( + add_dead_function.IsApplicable(context1.get(), transformation_context1)); + add_dead_function.Apply(context1.get(), &transformation_context1); ASSERT_TRUE(IsValid(env, context1.get())); // The function should not be deemed livesafe - ASSERT_FALSE(fact_manager1.FunctionIsLivesafe(8)); + ASSERT_FALSE(transformation_context1.GetFactManager()->FunctionIsLivesafe(8)); // All variables/parameters in the function should be deemed irrelevant. ASSERT_TRUE(AllVariablesAndParametersExceptLoopLimiterAreIrrelevant( - context1.get(), fact_manager1, 8, 0)); + context1.get(), transformation_context1, 8, 0)); std::string added_as_live_or_dead_code = R"( OpCapability Shader @@ -1630,15 +1675,15 @@ TransformationAddFunction add_livesafe_function(instructions, 0, 0, {}, 0, {}); - ASSERT_TRUE( - add_livesafe_function.IsApplicable(context2.get(), fact_manager2)); - add_livesafe_function.Apply(context2.get(), &fact_manager2); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context2.get(), + transformation_context2)); + add_livesafe_function.Apply(context2.get(), &transformation_context2); ASSERT_TRUE(IsValid(env, context2.get())); // The function should be deemed livesafe - ASSERT_TRUE(fact_manager2.FunctionIsLivesafe(8)); + ASSERT_TRUE(transformation_context2.GetFactManager()->FunctionIsLivesafe(8)); // All variables/parameters in the function should be deemed irrelevant. ASSERT_TRUE(AllVariablesAndParametersExceptLoopLimiterAreIrrelevant( - context2.get(), fact_manager2, 8, 0)); + context2.get(), transformation_context2, 8, 0)); ASSERT_TRUE(IsEqual(env, added_as_live_or_dead_code, context2.get())); } @@ -1679,20 +1724,26 @@ FactManager fact_manager1; FactManager fact_manager2; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context1(&fact_manager1, + validator_options); + TransformationContext transformation_context2(&fact_manager2, + validator_options); const auto context1 = BuildModule(env, consumer, shader, kFuzzAssembleOption); const auto context2 = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context1.get())); TransformationAddFunction add_dead_function(instructions); - ASSERT_TRUE(add_dead_function.IsApplicable(context1.get(), fact_manager1)); - add_dead_function.Apply(context1.get(), &fact_manager1); + ASSERT_TRUE( + add_dead_function.IsApplicable(context1.get(), transformation_context1)); + add_dead_function.Apply(context1.get(), &transformation_context1); ASSERT_TRUE(IsValid(env, context1.get())); // The function should not be deemed livesafe - ASSERT_FALSE(fact_manager1.FunctionIsLivesafe(8)); + ASSERT_FALSE(transformation_context1.GetFactManager()->FunctionIsLivesafe(8)); // All variables/parameters in the function should be deemed irrelevant. ASSERT_TRUE(AllVariablesAndParametersExceptLoopLimiterAreIrrelevant( - context1.get(), fact_manager1, 8, 0)); + context1.get(), transformation_context1, 8, 0)); std::string added_as_dead_code = R"( OpCapability Shader @@ -1721,8 +1772,8 @@ TransformationAddFunction add_livesafe_function(instructions, 0, 0, {}, 0, {}); - ASSERT_FALSE( - add_livesafe_function.IsApplicable(context2.get(), fact_manager2)); + ASSERT_FALSE(add_livesafe_function.IsApplicable(context2.get(), + transformation_context2)); } TEST(TransformationAddFunctionTest, @@ -1804,11 +1855,14 @@ const auto consumer = nullptr; FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); - // Make a sequence of instruction messages corresponding to function %8 in + // Make a sequence of instruction messages corresponding to function %6 in // |donor|. std::vector<protobufs::Instruction> instructions = GetInstructionsForFunction(env, consumer, donor, 6); @@ -1821,8 +1875,9 @@ loop_limiter_info.set_logical_op_id(105); TransformationAddFunction add_livesafe_function(instructions, 100, 32, {loop_limiter_info}, 0, {}); - ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), fact_manager)); - add_livesafe_function.Apply(context.get(), &fact_manager); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), + transformation_context)); + add_livesafe_function.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string expected = R"( OpCapability Shader @@ -1958,11 +2013,14 @@ const auto consumer = nullptr; FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); - // Make a sequence of instruction messages corresponding to function %8 in + // Make a sequence of instruction messages corresponding to function %6 in // |donor|. std::vector<protobufs::Instruction> instructions = GetInstructionsForFunction(env, consumer, donor, 6); @@ -1975,8 +2033,9 @@ loop_limiter_info.set_logical_op_id(105); TransformationAddFunction add_livesafe_function(instructions, 100, 32, {loop_limiter_info}, 0, {}); - ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), fact_manager)); - add_livesafe_function.Apply(context.get(), &fact_manager); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), + transformation_context)); + add_livesafe_function.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string expected = R"( OpCapability Shader @@ -2110,11 +2169,14 @@ const auto consumer = nullptr; FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); - // Make a sequence of instruction messages corresponding to function %8 in + // Make a sequence of instruction messages corresponding to function %6 in // |donor|. std::vector<protobufs::Instruction> instructions = GetInstructionsForFunction(env, consumer, donor, 6); @@ -2127,8 +2189,9 @@ loop_limiter_info.set_logical_op_id(105); TransformationAddFunction add_livesafe_function(instructions, 100, 32, {loop_limiter_info}, 0, {}); - ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), fact_manager)); - add_livesafe_function.Apply(context.get(), &fact_manager); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), + transformation_context)); + add_livesafe_function.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string expected = R"( OpCapability Shader @@ -2254,11 +2317,14 @@ const auto consumer = nullptr; FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); - // Make a sequence of instruction messages corresponding to function %8 in + // Make a sequence of instruction messages corresponding to function %6 in // |donor|. std::vector<protobufs::Instruction> instructions = GetInstructionsForFunction(env, consumer, donor, 6); @@ -2271,8 +2337,9 @@ loop_limiter_info.set_logical_op_id(105); TransformationAddFunction add_livesafe_function(instructions, 100, 32, {loop_limiter_info}, 0, {}); - ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), fact_manager)); - add_livesafe_function.Apply(context.get(), &fact_manager); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), + transformation_context)); + add_livesafe_function.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string expected = R"( OpCapability Shader @@ -2408,11 +2475,14 @@ const auto consumer = nullptr; FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); - // Make a sequence of instruction messages corresponding to function %8 in + // Make a sequence of instruction messages corresponding to function %6 in // |donor|. std::vector<protobufs::Instruction> instructions = GetInstructionsForFunction(env, consumer, donor, 6); @@ -2425,8 +2495,9 @@ loop_limiter_info.set_logical_op_id(105); TransformationAddFunction add_livesafe_function(instructions, 100, 32, {loop_limiter_info}, 0, {}); - ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), fact_manager)); - add_livesafe_function.Apply(context.get(), &fact_manager); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), + transformation_context)); + add_livesafe_function.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string expected = R"( OpCapability Shader @@ -2570,6 +2641,9 @@ const auto consumer = nullptr; FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); @@ -2590,15 +2664,17 @@ {loop_limiter_info}, 0, {}); // The loop limiter info is not good enough; it does not include ids to patch // up the OpPhi at the loop merge. - ASSERT_FALSE(no_op_phi_data.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + no_op_phi_data.IsApplicable(context.get(), transformation_context)); // Add a phi id for the new edge from the loop back edge block to the loop // merge. loop_limiter_info.add_phi_id(28); TransformationAddFunction with_op_phi_data(instructions, 100, 28, {loop_limiter_info}, 0, {}); - ASSERT_TRUE(with_op_phi_data.IsApplicable(context.get(), fact_manager)); - with_op_phi_data.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + with_op_phi_data.IsApplicable(context.get(), transformation_context)); + with_op_phi_data.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string expected = R"( OpCapability Shader @@ -2758,6 +2834,9 @@ const auto consumer = nullptr; FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); @@ -2776,8 +2855,9 @@ TransformationAddFunction transformation(instructions, 100, 28, {loop_limiter_info}, 0, {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string expected = R"( OpCapability Shader @@ -2845,6 +2925,120 @@ ASSERT_TRUE(IsEqual(env, expected, context.get())); } +TEST(TransformationAddFunctionTest, StaticallyOutOfBoundsArrayAccess) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypeInt 32 0 + %10 = OpConstant %9 3 + %11 = OpTypeArray %8 %10 + %12 = OpTypePointer Private %11 + %13 = OpVariable %12 Private + %14 = OpConstant %8 3 + %20 = OpConstant %8 2 + %15 = OpConstant %8 1 + %21 = OpTypeBool + %16 = OpTypePointer Private %8 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypeInt 32 0 + %10 = OpConstant %9 3 + %11 = OpTypeArray %8 %10 + %12 = OpTypePointer Private %11 + %13 = OpVariable %12 Private + %14 = OpConstant %8 3 + %15 = OpConstant %8 1 + %16 = OpTypePointer Private %8 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %17 = OpAccessChain %16 %13 %14 + OpStore %17 %15 + OpReturn + OpFunctionEnd + )"; + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + // Make a sequence of instruction messages corresponding to function %6 in + // |donor|. + std::vector<protobufs::Instruction> instructions = + GetInstructionsForFunction(env, consumer, donor, 6); + + TransformationAddFunction add_livesafe_function( + instructions, 0, 0, {}, 0, {MakeAccessClampingInfo(17, {{100, 101}})}); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), + transformation_context)); + add_livesafe_function.Apply(context.get(), &transformation_context); + ASSERT_TRUE(IsValid(env, context.get())); + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypeInt 32 0 + %10 = OpConstant %9 3 + %11 = OpTypeArray %8 %10 + %12 = OpTypePointer Private %11 + %13 = OpVariable %12 Private + %14 = OpConstant %8 3 + %20 = OpConstant %8 2 + %15 = OpConstant %8 1 + %21 = OpTypeBool + %16 = OpTypePointer Private %8 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %100 = OpULessThanEqual %21 %14 %20 + %101 = OpSelect %8 %100 %14 %20 + %17 = OpAccessChain %16 %13 %101 + OpStore %17 %15 + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, expected, context.get())); +} + } // namespace } // namespace fuzz } // namespace spvtools
diff --git a/test/fuzz/transformation_add_global_undef_test.cpp b/test/fuzz/transformation_add_global_undef_test.cpp index c14f7e9..8c06db0 100644 --- a/test/fuzz/transformation_add_global_undef_test.cpp +++ b/test/fuzz/transformation_add_global_undef_test.cpp
@@ -47,17 +47,20 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Id already in use - ASSERT_FALSE(TransformationAddGlobalUndef(4, 11).IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(TransformationAddGlobalUndef(4, 11).IsApplicable( + context.get(), transformation_context)); // %1 is not a type - ASSERT_FALSE(TransformationAddGlobalUndef(100, 1).IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(TransformationAddGlobalUndef(100, 1).IsApplicable( + context.get(), transformation_context)); // %3 is a function type - ASSERT_FALSE(TransformationAddGlobalUndef(100, 3).IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(TransformationAddGlobalUndef(100, 3).IsApplicable( + context.get(), transformation_context)); TransformationAddGlobalUndef transformations[] = { // %100 = OpUndef %6 @@ -79,8 +82,9 @@ TransformationAddGlobalUndef(105, 11)}; for (auto& transformation : transformations) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } ASSERT_TRUE(IsValid(env, context.get()));
diff --git a/test/fuzz/transformation_add_global_variable_test.cpp b/test/fuzz/transformation_add_global_variable_test.cpp index 619f068..5c74ca0 100644 --- a/test/fuzz/transformation_add_global_variable_test.cpp +++ b/test/fuzz/transformation_add_global_variable_test.cpp
@@ -60,79 +60,105 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Id already in use - ASSERT_FALSE(TransformationAddGlobalVariable(4, 10, 0, true) - .IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + TransformationAddGlobalVariable(4, 10, SpvStorageClassPrivate, 0, true) + .IsApplicable(context.get(), transformation_context)); // %1 is not a type - ASSERT_FALSE(TransformationAddGlobalVariable(100, 1, 0, false) - .IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + TransformationAddGlobalVariable(100, 1, SpvStorageClassPrivate, 0, false) + .IsApplicable(context.get(), transformation_context)); // %7 is not a pointer type - ASSERT_FALSE(TransformationAddGlobalVariable(100, 7, 0, true) - .IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + TransformationAddGlobalVariable(100, 7, SpvStorageClassPrivate, 0, true) + .IsApplicable(context.get(), transformation_context)); // %9 does not have Private storage class - ASSERT_FALSE(TransformationAddGlobalVariable(100, 9, 0, false) - .IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + TransformationAddGlobalVariable(100, 9, SpvStorageClassPrivate, 0, false) + .IsApplicable(context.get(), transformation_context)); // %15 does not have Private storage class - ASSERT_FALSE(TransformationAddGlobalVariable(100, 15, 0, true) - .IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + TransformationAddGlobalVariable(100, 15, SpvStorageClassPrivate, 0, true) + .IsApplicable(context.get(), transformation_context)); // %10 is a pointer to float, while %16 is an int constant - ASSERT_FALSE(TransformationAddGlobalVariable(100, 10, 16, false) - .IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(TransformationAddGlobalVariable(100, 10, SpvStorageClassPrivate, + 16, false) + .IsApplicable(context.get(), transformation_context)); // %10 is a Private pointer to float, while %15 is a variable with type // Uniform float pointer - ASSERT_FALSE(TransformationAddGlobalVariable(100, 10, 15, true) - .IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + TransformationAddGlobalVariable(100, 10, SpvStorageClassPrivate, 15, true) + .IsApplicable(context.get(), transformation_context)); // %12 is a Private pointer to int, while %10 is a variable with type // Private float pointer - ASSERT_FALSE(TransformationAddGlobalVariable(100, 12, 10, false) - .IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(TransformationAddGlobalVariable(100, 12, SpvStorageClassPrivate, + 10, false) + .IsApplicable(context.get(), transformation_context)); // %10 is pointer-to-float, and %14 has type pointer-to-float; that's not OK // since the initializer's type should be the *pointee* type. - ASSERT_FALSE(TransformationAddGlobalVariable(104, 10, 14, true) - .IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + TransformationAddGlobalVariable(104, 10, SpvStorageClassPrivate, 14, true) + .IsApplicable(context.get(), transformation_context)); // This would work in principle, but logical addressing does not allow // a pointer to a pointer. - ASSERT_FALSE(TransformationAddGlobalVariable(104, 17, 14, false) - .IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(TransformationAddGlobalVariable(104, 17, SpvStorageClassPrivate, + 14, false) + .IsApplicable(context.get(), transformation_context)); TransformationAddGlobalVariable transformations[] = { // %100 = OpVariable %12 Private - TransformationAddGlobalVariable(100, 12, 16, true), + TransformationAddGlobalVariable(100, 12, SpvStorageClassPrivate, 16, + true), // %101 = OpVariable %10 Private - TransformationAddGlobalVariable(101, 10, 40, false), + TransformationAddGlobalVariable(101, 10, SpvStorageClassPrivate, 40, + false), // %102 = OpVariable %13 Private - TransformationAddGlobalVariable(102, 13, 41, true), + TransformationAddGlobalVariable(102, 13, SpvStorageClassPrivate, 41, + true), // %103 = OpVariable %12 Private %16 - TransformationAddGlobalVariable(103, 12, 16, false), + TransformationAddGlobalVariable(103, 12, SpvStorageClassPrivate, 16, + false), // %104 = OpVariable %19 Private %21 - TransformationAddGlobalVariable(104, 19, 21, true), + TransformationAddGlobalVariable(104, 19, SpvStorageClassPrivate, 21, + true), // %105 = OpVariable %19 Private %22 - TransformationAddGlobalVariable(105, 19, 22, false)}; + TransformationAddGlobalVariable(105, 19, SpvStorageClassPrivate, 22, + false)}; for (auto& transformation : transformations) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(100)); - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(102)); - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(104)); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(101)); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(103)); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(105)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(100)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(102)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(104)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(101)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(103)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(105)); ASSERT_TRUE(IsValid(env, context.get())); @@ -223,24 +249,34 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationAddGlobalVariable transformations[] = { // %100 = OpVariable %12 Private - TransformationAddGlobalVariable(100, 12, 16, true), + TransformationAddGlobalVariable(100, 12, SpvStorageClassPrivate, 16, + true), // %101 = OpVariable %12 Private %16 - TransformationAddGlobalVariable(101, 12, 16, false), + TransformationAddGlobalVariable(101, 12, SpvStorageClassPrivate, 16, + false), // %102 = OpVariable %19 Private %21 - TransformationAddGlobalVariable(102, 19, 21, true)}; + TransformationAddGlobalVariable(102, 19, SpvStorageClassPrivate, 21, + true)}; for (auto& transformation : transformations) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(100)); - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(102)); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(101)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(100)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(102)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(101)); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -284,6 +320,85 @@ ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); } +TEST(TransformationAddGlobalVariableTest, TestAddingWorkgroupGlobals) { + // This checks that workgroup globals can be added to a compute shader. + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" + OpExecutionMode %4 LocalSize 1 1 1 + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Workgroup %6 + %50 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + +#ifndef NDEBUG + ASSERT_DEATH( + TransformationAddGlobalVariable(8, 7, SpvStorageClassWorkgroup, 50, true) + .IsApplicable(context.get(), transformation_context), + "By construction this transformation should not have an.*initializer " + "when Workgroup storage class is used"); +#endif + + TransformationAddGlobalVariable transformations[] = { + // %8 = OpVariable %7 Workgroup + TransformationAddGlobalVariable(8, 7, SpvStorageClassWorkgroup, 0, true), + + // %10 = OpVariable %7 Workgroup + TransformationAddGlobalVariable(10, 7, SpvStorageClassWorkgroup, 0, + false)}; + + for (auto& transformation : transformations) { + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); + } + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(8)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(10)); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" %8 %10 + OpExecutionMode %4 LocalSize 1 1 1 + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Workgroup %6 + %50 = OpConstant %6 2 + %8 = OpVariable %7 Workgroup + %10 = OpVariable %7 Workgroup + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + } // namespace } // namespace fuzz } // namespace spvtools
diff --git a/test/fuzz/transformation_add_local_variable_test.cpp b/test/fuzz/transformation_add_local_variable_test.cpp index fd7047f..e989b33 100644 --- a/test/fuzz/transformation_add_local_variable_test.cpp +++ b/test/fuzz/transformation_add_local_variable_test.cpp
@@ -79,66 +79,81 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // A few cases of inapplicable transformations: // Id 4 is already in use ASSERT_FALSE(TransformationAddLocalVariable(4, 50, 4, 51, true) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Type mismatch between initializer and pointer ASSERT_FALSE(TransformationAddLocalVariable(105, 46, 4, 51, true) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Id 5 is not a function ASSERT_FALSE(TransformationAddLocalVariable(105, 50, 5, 51, true) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // %105 = OpVariable %50 Function %51 { TransformationAddLocalVariable transformation(105, 50, 4, 51, true); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } // %104 = OpVariable %41 Function %46 { TransformationAddLocalVariable transformation(104, 41, 4, 46, false); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } // %103 = OpVariable %35 Function %38 { TransformationAddLocalVariable transformation(103, 35, 4, 38, true); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } // %102 = OpVariable %31 Function %33 { TransformationAddLocalVariable transformation(102, 31, 4, 33, false); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } // %101 = OpVariable %19 Function %29 { TransformationAddLocalVariable transformation(101, 19, 4, 29, true); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } // %100 = OpVariable %8 Function %12 { TransformationAddLocalVariable transformation(100, 8, 4, 12, false); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(100)); - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(101)); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(102)); - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(103)); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(104)); - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(105)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(100)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(101)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(102)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(103)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(104)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(105)); std::string after_transformation = R"( OpCapability Shader
diff --git a/test/fuzz/transformation_add_no_contraction_decoration_test.cpp b/test/fuzz/transformation_add_no_contraction_decoration_test.cpp index b1a87ea..46841a5 100644 --- a/test/fuzz/transformation_add_no_contraction_decoration_test.cpp +++ b/test/fuzz/transformation_add_no_contraction_decoration_test.cpp
@@ -94,23 +94,27 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Invalid: 200 is not an id ASSERT_FALSE(TransformationAddNoContractionDecoration(200).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); // Invalid: 17 is a block id ASSERT_FALSE(TransformationAddNoContractionDecoration(17).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); // Invalid: 24 is not arithmetic ASSERT_FALSE(TransformationAddNoContractionDecoration(24).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); // It is valid to add NoContraction to each of these ids (and it's fine to // have duplicates of the decoration, in the case of 32). for (uint32_t result_id : {32u, 32u, 27u, 29u, 39u}) { TransformationAddNoContractionDecoration transformation(result_id); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); }
diff --git a/test/fuzz/transformation_add_type_array_test.cpp b/test/fuzz/transformation_add_type_array_test.cpp index 2bcbe73..4392f99 100644 --- a/test/fuzz/transformation_add_type_array_test.cpp +++ b/test/fuzz/transformation_add_type_array_test.cpp
@@ -54,37 +54,40 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Id already in use ASSERT_FALSE(TransformationAddTypeArray(4, 10, 16).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); // %1 is not a type ASSERT_FALSE(TransformationAddTypeArray(100, 1, 16) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // %3 is a function type ASSERT_FALSE(TransformationAddTypeArray(100, 3, 16) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // %2 is not a constant ASSERT_FALSE(TransformationAddTypeArray(100, 11, 2) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // %18 is not an integer ASSERT_FALSE(TransformationAddTypeArray(100, 11, 18) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // %13 is signed 0 ASSERT_FALSE(TransformationAddTypeArray(100, 11, 13) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // %14 is negative ASSERT_FALSE(TransformationAddTypeArray(100, 11, 14) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // %17 is unsigned 0 ASSERT_FALSE(TransformationAddTypeArray(100, 11, 17) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); TransformationAddTypeArray transformations[] = { // %100 = OpTypeArray %10 %16 @@ -94,8 +97,9 @@ TransformationAddTypeArray(101, 7, 12)}; for (auto& transformation : transformations) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } ASSERT_TRUE(IsValid(env, context.get()));
diff --git a/test/fuzz/transformation_add_type_boolean_test.cpp b/test/fuzz/transformation_add_type_boolean_test.cpp index 9975953..60eabd9 100644 --- a/test/fuzz/transformation_add_type_boolean_test.cpp +++ b/test/fuzz/transformation_add_type_boolean_test.cpp
@@ -42,19 +42,23 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Not applicable because id 1 is already in use. - ASSERT_FALSE(TransformationAddTypeBoolean(1).IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(TransformationAddTypeBoolean(1).IsApplicable( + context.get(), transformation_context)); auto add_type_bool = TransformationAddTypeBoolean(100); - ASSERT_TRUE(add_type_bool.IsApplicable(context.get(), fact_manager)); - add_type_bool.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + add_type_bool.IsApplicable(context.get(), transformation_context)); + add_type_bool.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Not applicable as we already have this type now. - ASSERT_FALSE(TransformationAddTypeBoolean(101).IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(TransformationAddTypeBoolean(101).IsApplicable( + context.get(), transformation_context)); std::string after_transformation = R"( OpCapability Shader
diff --git a/test/fuzz/transformation_add_type_float_test.cpp b/test/fuzz/transformation_add_type_float_test.cpp index 67408da..7d17266 100644 --- a/test/fuzz/transformation_add_type_float_test.cpp +++ b/test/fuzz/transformation_add_type_float_test.cpp
@@ -42,19 +42,23 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Not applicable because id 1 is already in use. - ASSERT_FALSE(TransformationAddTypeFloat(1, 32).IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(TransformationAddTypeFloat(1, 32).IsApplicable( + context.get(), transformation_context)); auto add_type_float_32 = TransformationAddTypeFloat(100, 32); - ASSERT_TRUE(add_type_float_32.IsApplicable(context.get(), fact_manager)); - add_type_float_32.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + add_type_float_32.IsApplicable(context.get(), transformation_context)); + add_type_float_32.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Not applicable as we already have this type now. - ASSERT_FALSE(TransformationAddTypeFloat(101, 32).IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(TransformationAddTypeFloat(101, 32).IsApplicable( + context.get(), transformation_context)); std::string after_transformation = R"( OpCapability Shader
diff --git a/test/fuzz/transformation_add_type_function_test.cpp b/test/fuzz/transformation_add_type_function_test.cpp index 46bd436..1557bb8 100644 --- a/test/fuzz/transformation_add_type_function_test.cpp +++ b/test/fuzz/transformation_add_type_function_test.cpp
@@ -59,21 +59,24 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Id already in use ASSERT_FALSE(TransformationAddTypeFunction(4, 12, {12, 16, 14}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // %1 is not a type ASSERT_FALSE(TransformationAddTypeFunction(100, 1, {12, 16, 14}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // %18 is a function type ASSERT_FALSE(TransformationAddTypeFunction(100, 12, {18}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // A function of this signature already exists ASSERT_FALSE(TransformationAddTypeFunction(100, 17, {14, 16}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); TransformationAddTypeFunction transformations[] = { // %100 = OpTypeFunction %12 %12 %16 %14 @@ -86,8 +89,9 @@ TransformationAddTypeFunction(102, 17, {200, 16})}; for (auto& transformation : transformations) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } ASSERT_TRUE(IsValid(env, context.get()));
diff --git a/test/fuzz/transformation_add_type_int_test.cpp b/test/fuzz/transformation_add_type_int_test.cpp index c6f884c..63b17c2 100644 --- a/test/fuzz/transformation_add_type_int_test.cpp +++ b/test/fuzz/transformation_add_type_int_test.cpp
@@ -42,10 +42,13 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Not applicable because id 1 is already in use. ASSERT_FALSE(TransformationAddTypeInt(1, 32, false) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); auto add_type_signed_int_32 = TransformationAddTypeInt(100, 32, true); auto add_type_unsigned_int_32 = TransformationAddTypeInt(101, 32, false); @@ -53,20 +56,21 @@ auto add_type_unsigned_int_32_again = TransformationAddTypeInt(103, 32, false); - ASSERT_TRUE(add_type_signed_int_32.IsApplicable(context.get(), fact_manager)); - add_type_signed_int_32.Apply(context.get(), &fact_manager); + ASSERT_TRUE(add_type_signed_int_32.IsApplicable(context.get(), + transformation_context)); + add_type_signed_int_32.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE( - add_type_unsigned_int_32.IsApplicable(context.get(), fact_manager)); - add_type_unsigned_int_32.Apply(context.get(), &fact_manager); + ASSERT_TRUE(add_type_unsigned_int_32.IsApplicable(context.get(), + transformation_context)); + add_type_unsigned_int_32.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Not applicable as we already have these types now. - ASSERT_FALSE( - add_type_signed_int_32_again.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE( - add_type_unsigned_int_32_again.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(add_type_signed_int_32_again.IsApplicable( + context.get(), transformation_context)); + ASSERT_FALSE(add_type_unsigned_int_32_again.IsApplicable( + context.get(), transformation_context)); std::string after_transformation = R"( OpCapability Shader
diff --git a/test/fuzz/transformation_add_type_matrix_test.cpp b/test/fuzz/transformation_add_type_matrix_test.cpp index 84f27e9..e925012 100644 --- a/test/fuzz/transformation_add_type_matrix_test.cpp +++ b/test/fuzz/transformation_add_type_matrix_test.cpp
@@ -47,17 +47,20 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Id already in use - ASSERT_FALSE(TransformationAddTypeMatrix(4, 9, 2).IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(TransformationAddTypeMatrix(4, 9, 2).IsApplicable( + context.get(), transformation_context)); // %1 is not a type ASSERT_FALSE(TransformationAddTypeMatrix(100, 1, 2).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); // %11 is not a floating-point vector ASSERT_FALSE(TransformationAddTypeMatrix(100, 11, 2) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); TransformationAddTypeMatrix transformations[] = { // %100 = OpTypeMatrix %8 2 @@ -88,8 +91,9 @@ TransformationAddTypeMatrix(108, 10, 4)}; for (auto& transformation : transformations) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } ASSERT_TRUE(IsValid(env, context.get()));
diff --git a/test/fuzz/transformation_add_type_pointer_test.cpp b/test/fuzz/transformation_add_type_pointer_test.cpp index e36707f..35303e4 100644 --- a/test/fuzz/transformation_add_type_pointer_test.cpp +++ b/test/fuzz/transformation_add_type_pointer_test.cpp
@@ -97,6 +97,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto bad_type_id_does_not_exist = TransformationAddTypePointer(100, SpvStorageClassFunction, 101); @@ -122,12 +125,12 @@ auto good_new_private_pointer_to_uniform_pointer_to_vec2 = TransformationAddTypePointer(108, SpvStorageClassPrivate, 107); - ASSERT_FALSE( - bad_type_id_does_not_exist.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE( - bad_type_id_is_not_type.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE( - bad_result_id_is_not_fresh.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(bad_type_id_does_not_exist.IsApplicable(context.get(), + transformation_context)); + ASSERT_FALSE(bad_type_id_is_not_type.IsApplicable(context.get(), + transformation_context)); + ASSERT_FALSE(bad_result_id_is_not_fresh.IsApplicable(context.get(), + transformation_context)); for (auto& transformation : {good_new_private_pointer_to_t, good_new_uniform_pointer_to_t, @@ -136,8 +139,9 @@ good_new_private_pointer_to_private_pointer_to_float, good_new_uniform_pointer_to_vec2, good_new_private_pointer_to_uniform_pointer_to_vec2}) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); }
diff --git a/test/fuzz/transformation_add_type_struct_test.cpp b/test/fuzz/transformation_add_type_struct_test.cpp index ae68c9a..06f78cd 100644 --- a/test/fuzz/transformation_add_type_struct_test.cpp +++ b/test/fuzz/transformation_add_type_struct_test.cpp
@@ -47,17 +47,20 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Id already in use - ASSERT_FALSE(TransformationAddTypeStruct(4, {}).IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(TransformationAddTypeStruct(4, {}).IsApplicable( + context.get(), transformation_context)); // %1 is not a type ASSERT_FALSE(TransformationAddTypeStruct(100, {1}).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); // %3 is a function type ASSERT_FALSE(TransformationAddTypeStruct(100, {3}).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); TransformationAddTypeStruct transformations[] = { // %100 = OpTypeStruct %6 %7 %8 %9 %10 %11 @@ -73,8 +76,9 @@ TransformationAddTypeStruct(103, {6, 6})}; for (auto& transformation : transformations) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } ASSERT_TRUE(IsValid(env, context.get()));
diff --git a/test/fuzz/transformation_add_type_vector_test.cpp b/test/fuzz/transformation_add_type_vector_test.cpp index 6ac4498..f1252a3 100644 --- a/test/fuzz/transformation_add_type_vector_test.cpp +++ b/test/fuzz/transformation_add_type_vector_test.cpp
@@ -45,13 +45,16 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Id already in use - ASSERT_FALSE(TransformationAddTypeVector(4, 6, 2).IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(TransformationAddTypeVector(4, 6, 2).IsApplicable( + context.get(), transformation_context)); // %1 is not a type ASSERT_FALSE(TransformationAddTypeVector(100, 1, 2).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); TransformationAddTypeVector transformations[] = { // %100 = OpTypeVector %6 2 @@ -67,8 +70,9 @@ TransformationAddTypeVector(103, 9, 2)}; for (auto& transformation : transformations) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } ASSERT_TRUE(IsValid(env, context.get()));
diff --git a/test/fuzz/transformation_adjust_branch_weights_test.cpp b/test/fuzz/transformation_adjust_branch_weights_test.cpp new file mode 100644 index 0000000..7f8ba31 --- /dev/null +++ b/test/fuzz/transformation_adjust_branch_weights_test.cpp
@@ -0,0 +1,349 @@ +// Copyright (c) 2020 André Perez Maselco +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/transformation_adjust_branch_weights.h" +#include "source/fuzz/instruction_descriptor.h" +#include "test/fuzz/fuzz_test_util.h" + +namespace spvtools { +namespace fuzz { +namespace { + +TEST(TransformationAdjustBranchWeightsTest, IsApplicableTest) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %51 %27 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %25 "buf" + OpMemberName %25 0 "value" + OpName %27 "" + OpName %51 "color" + OpMemberDecorate %25 0 Offset 0 + OpDecorate %25 Block + OpDecorate %27 DescriptorSet 0 + OpDecorate %27 Binding 0 + OpDecorate %51 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %150 = OpTypeVector %6 2 + %10 = OpConstant %6 0.300000012 + %11 = OpConstant %6 0.400000006 + %12 = OpConstant %6 0.5 + %13 = OpConstant %6 1 + %14 = OpConstantComposite %7 %10 %11 %12 %13 + %15 = OpTypeInt 32 1 + %18 = OpConstant %15 0 + %25 = OpTypeStruct %6 + %26 = OpTypePointer Uniform %25 + %27 = OpVariable %26 Uniform + %28 = OpTypePointer Uniform %6 + %32 = OpTypeBool + %103 = OpConstantTrue %32 + %34 = OpConstant %6 0.100000001 + %48 = OpConstant %15 1 + %50 = OpTypePointer Output %7 + %51 = OpVariable %50 Output + %100 = OpTypePointer Function %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %101 = OpVariable %100 Function + %102 = OpVariable %100 Function + OpBranch %19 + %19 = OpLabel + %60 = OpPhi %7 %14 %5 %58 %20 + %59 = OpPhi %15 %18 %5 %49 %20 + %29 = OpAccessChain %28 %27 %18 + %30 = OpLoad %6 %29 + %31 = OpConvertFToS %15 %30 + %33 = OpSLessThan %32 %59 %31 + OpLoopMerge %21 %20 None + OpBranchConditional %33 %20 %21 1 2 + %20 = OpLabel + %39 = OpCompositeExtract %6 %60 0 + %40 = OpFAdd %6 %39 %34 + %55 = OpCompositeInsert %7 %40 %60 0 + %44 = OpCompositeExtract %6 %60 1 + %45 = OpFSub %6 %44 %34 + %58 = OpCompositeInsert %7 %45 %55 1 + %49 = OpIAdd %15 %59 %48 + OpBranch %19 + %21 = OpLabel + OpStore %51 %60 + OpSelectionMerge %105 None + OpBranchConditional %103 %104 %105 + %104 = OpLabel + OpBranch %105 + %105 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + // Tests OpBranchConditional instruction with weigths. + auto instruction_descriptor = + MakeInstructionDescriptor(33, SpvOpBranchConditional, 0); + auto transformation = + TransformationAdjustBranchWeights(instruction_descriptor, {0, 1}); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + + // Tests the two branch weights equal to 0. + instruction_descriptor = + MakeInstructionDescriptor(33, SpvOpBranchConditional, 0); + transformation = + TransformationAdjustBranchWeights(instruction_descriptor, {0, 0}); +#ifndef NDEBUG + ASSERT_DEATH( + transformation.IsApplicable(context.get(), transformation_context), + "At least one weight must be non-zero"); +#endif + + // Tests 32-bit unsigned integer overflow. + instruction_descriptor = + MakeInstructionDescriptor(33, SpvOpBranchConditional, 0); + transformation = TransformationAdjustBranchWeights(instruction_descriptor, + {UINT32_MAX, 0}); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + + instruction_descriptor = + MakeInstructionDescriptor(33, SpvOpBranchConditional, 0); + transformation = TransformationAdjustBranchWeights(instruction_descriptor, + {1, UINT32_MAX}); +#ifndef NDEBUG + ASSERT_DEATH( + transformation.IsApplicable(context.get(), transformation_context), + "The sum of the two weights must not be greater than UINT32_MAX"); +#endif + + // Tests OpBranchConditional instruction with no weights. + instruction_descriptor = + MakeInstructionDescriptor(21, SpvOpBranchConditional, 0); + transformation = + TransformationAdjustBranchWeights(instruction_descriptor, {0, 1}); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + + // Tests non-OpBranchConditional instructions. + instruction_descriptor = MakeInstructionDescriptor(2, SpvOpTypeVoid, 0); + transformation = + TransformationAdjustBranchWeights(instruction_descriptor, {5, 6}); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); + + instruction_descriptor = MakeInstructionDescriptor(20, SpvOpLabel, 0); + transformation = + TransformationAdjustBranchWeights(instruction_descriptor, {1, 2}); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); + + instruction_descriptor = MakeInstructionDescriptor(49, SpvOpIAdd, 0); + transformation = + TransformationAdjustBranchWeights(instruction_descriptor, {1, 2}); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); +} + +TEST(TransformationAdjustBranchWeightsTest, ApplyTest) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %51 %27 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %25 "buf" + OpMemberName %25 0 "value" + OpName %27 "" + OpName %51 "color" + OpMemberDecorate %25 0 Offset 0 + OpDecorate %25 Block + OpDecorate %27 DescriptorSet 0 + OpDecorate %27 Binding 0 + OpDecorate %51 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %150 = OpTypeVector %6 2 + %10 = OpConstant %6 0.300000012 + %11 = OpConstant %6 0.400000006 + %12 = OpConstant %6 0.5 + %13 = OpConstant %6 1 + %14 = OpConstantComposite %7 %10 %11 %12 %13 + %15 = OpTypeInt 32 1 + %18 = OpConstant %15 0 + %25 = OpTypeStruct %6 + %26 = OpTypePointer Uniform %25 + %27 = OpVariable %26 Uniform + %28 = OpTypePointer Uniform %6 + %32 = OpTypeBool + %103 = OpConstantTrue %32 + %34 = OpConstant %6 0.100000001 + %48 = OpConstant %15 1 + %50 = OpTypePointer Output %7 + %51 = OpVariable %50 Output + %100 = OpTypePointer Function %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %101 = OpVariable %100 Function + %102 = OpVariable %100 Function + OpBranch %19 + %19 = OpLabel + %60 = OpPhi %7 %14 %5 %58 %20 + %59 = OpPhi %15 %18 %5 %49 %20 + %29 = OpAccessChain %28 %27 %18 + %30 = OpLoad %6 %29 + %31 = OpConvertFToS %15 %30 + %33 = OpSLessThan %32 %59 %31 + OpLoopMerge %21 %20 None + OpBranchConditional %33 %20 %21 1 2 + %20 = OpLabel + %39 = OpCompositeExtract %6 %60 0 + %40 = OpFAdd %6 %39 %34 + %55 = OpCompositeInsert %7 %40 %60 0 + %44 = OpCompositeExtract %6 %60 1 + %45 = OpFSub %6 %44 %34 + %58 = OpCompositeInsert %7 %45 %55 1 + %49 = OpIAdd %15 %59 %48 + OpBranch %19 + %21 = OpLabel + OpStore %51 %60 + OpSelectionMerge %105 None + OpBranchConditional %103 %104 %105 + %104 = OpLabel + OpBranch %105 + %105 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + auto instruction_descriptor = + MakeInstructionDescriptor(33, SpvOpBranchConditional, 0); + auto transformation = + TransformationAdjustBranchWeights(instruction_descriptor, {5, 6}); + transformation.Apply(context.get(), &transformation_context); + + instruction_descriptor = + MakeInstructionDescriptor(21, SpvOpBranchConditional, 0); + transformation = + TransformationAdjustBranchWeights(instruction_descriptor, {7, 8}); + transformation.Apply(context.get(), &transformation_context); + + std::string variant_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %51 %27 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %25 "buf" + OpMemberName %25 0 "value" + OpName %27 "" + OpName %51 "color" + OpMemberDecorate %25 0 Offset 0 + OpDecorate %25 Block + OpDecorate %27 DescriptorSet 0 + OpDecorate %27 Binding 0 + OpDecorate %51 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %150 = OpTypeVector %6 2 + %10 = OpConstant %6 0.300000012 + %11 = OpConstant %6 0.400000006 + %12 = OpConstant %6 0.5 + %13 = OpConstant %6 1 + %14 = OpConstantComposite %7 %10 %11 %12 %13 + %15 = OpTypeInt 32 1 + %18 = OpConstant %15 0 + %25 = OpTypeStruct %6 + %26 = OpTypePointer Uniform %25 + %27 = OpVariable %26 Uniform + %28 = OpTypePointer Uniform %6 + %32 = OpTypeBool + %103 = OpConstantTrue %32 + %34 = OpConstant %6 0.100000001 + %48 = OpConstant %15 1 + %50 = OpTypePointer Output %7 + %51 = OpVariable %50 Output + %100 = OpTypePointer Function %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %101 = OpVariable %100 Function + %102 = OpVariable %100 Function + OpBranch %19 + %19 = OpLabel + %60 = OpPhi %7 %14 %5 %58 %20 + %59 = OpPhi %15 %18 %5 %49 %20 + %29 = OpAccessChain %28 %27 %18 + %30 = OpLoad %6 %29 + %31 = OpConvertFToS %15 %30 + %33 = OpSLessThan %32 %59 %31 + OpLoopMerge %21 %20 None + OpBranchConditional %33 %20 %21 5 6 + %20 = OpLabel + %39 = OpCompositeExtract %6 %60 0 + %40 = OpFAdd %6 %39 %34 + %55 = OpCompositeInsert %7 %40 %60 0 + %44 = OpCompositeExtract %6 %60 1 + %45 = OpFSub %6 %44 %34 + %58 = OpCompositeInsert %7 %45 %55 1 + %49 = OpIAdd %15 %59 %48 + OpBranch %19 + %21 = OpLabel + OpStore %51 %60 + OpSelectionMerge %105 None + OpBranchConditional %103 %104 %105 7 8 + %104 = OpLabel + OpBranch %105 + %105 = OpLabel + OpReturn + OpFunctionEnd + )"; + + ASSERT_TRUE(IsEqual(env, variant_shader, context.get())); +} + +} // namespace +} // namespace fuzz +} // namespace spvtools
diff --git a/test/fuzz/transformation_composite_construct_test.cpp b/test/fuzz/transformation_composite_construct_test.cpp index d303368..b663866 100644 --- a/test/fuzz/transformation_composite_construct_test.cpp +++ b/test/fuzz/transformation_composite_construct_test.cpp
@@ -129,6 +129,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Make a vec2[3] TransformationCompositeConstruct make_vec2_array_length_3( @@ -138,18 +141,18 @@ TransformationCompositeConstruct make_vec2_array_length_3_bad( 37, {41, 45, 27, 27}, MakeInstructionDescriptor(46, SpvOpAccessChain, 0), 200); - ASSERT_TRUE( - make_vec2_array_length_3.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE( - make_vec2_array_length_3_bad.IsApplicable(context.get(), fact_manager)); - make_vec2_array_length_3.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_vec2_array_length_3.IsApplicable(context.get(), + transformation_context)); + ASSERT_FALSE(make_vec2_array_length_3_bad.IsApplicable( + context.get(), transformation_context)); + make_vec2_array_length_3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(41, {}), MakeDataDescriptor(200, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(45, {}), MakeDataDescriptor(200, {1}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(27, {}), MakeDataDescriptor(200, {2}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(41, {}), MakeDataDescriptor(200, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(45, {}), MakeDataDescriptor(200, {1}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(27, {}), MakeDataDescriptor(200, {2}))); // Make a float[2] TransformationCompositeConstruct make_float_array_length_2( @@ -157,16 +160,16 @@ // Bad: %41 does not have type float TransformationCompositeConstruct make_float_array_length_2_bad( 9, {41, 40}, MakeInstructionDescriptor(71, SpvOpStore, 0), 201); - ASSERT_TRUE( - make_float_array_length_2.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE( - make_float_array_length_2_bad.IsApplicable(context.get(), fact_manager)); - make_float_array_length_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_float_array_length_2.IsApplicable(context.get(), + transformation_context)); + ASSERT_FALSE(make_float_array_length_2_bad.IsApplicable( + context.get(), transformation_context)); + make_float_array_length_2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(24, {}), MakeDataDescriptor(201, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(40, {}), MakeDataDescriptor(201, {1}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(24, {}), MakeDataDescriptor(201, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(40, {}), MakeDataDescriptor(201, {1}))); // Make a bool[3] TransformationCompositeConstruct make_bool_array_length_3( @@ -176,18 +179,18 @@ TransformationCompositeConstruct make_bool_array_length_3_bad( 47, {33, 54, 50}, MakeInstructionDescriptor(33, SpvOpSelectionMerge, 0), 202); - ASSERT_TRUE( - make_bool_array_length_3.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE( - make_bool_array_length_3_bad.IsApplicable(context.get(), fact_manager)); - make_bool_array_length_3.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_bool_array_length_3.IsApplicable(context.get(), + transformation_context)); + ASSERT_FALSE(make_bool_array_length_3_bad.IsApplicable( + context.get(), transformation_context)); + make_bool_array_length_3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(33, {}), MakeDataDescriptor(202, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(50, {}), MakeDataDescriptor(202, {1}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(50, {}), MakeDataDescriptor(202, {2}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(33, {}), MakeDataDescriptor(202, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(50, {}), MakeDataDescriptor(202, {1}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(50, {}), MakeDataDescriptor(202, {2}))); // make a uvec3[2][2] TransformationCompositeConstruct make_uvec3_array_length_2_2( @@ -195,17 +198,16 @@ // Bad: Skip count 100 is too large. TransformationCompositeConstruct make_uvec3_array_length_2_2_bad( 58, {33, 54}, MakeInstructionDescriptor(64, SpvOpStore, 100), 203); - ASSERT_TRUE( - make_uvec3_array_length_2_2.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_uvec3_array_length_2_2_bad.IsApplicable(context.get(), - fact_manager)); - make_uvec3_array_length_2_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_uvec3_array_length_2_2.IsApplicable(context.get(), + transformation_context)); + ASSERT_FALSE(make_uvec3_array_length_2_2_bad.IsApplicable( + context.get(), transformation_context)); + make_uvec3_array_length_2_2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(69, {}), MakeDataDescriptor(203, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(100, {}), - MakeDataDescriptor(203, {1}), - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(69, {}), MakeDataDescriptor(203, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(100, {}), MakeDataDescriptor(203, {1}))); std::string after_transformation = R"( OpCapability Shader @@ -393,6 +395,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // make a mat3x4 TransformationCompositeConstruct make_mat34( @@ -400,16 +405,17 @@ // Bad: %35 is mat4x3, not mat3x4. TransformationCompositeConstruct make_mat34_bad( 35, {25, 28, 31}, MakeInstructionDescriptor(31, SpvOpReturn, 0), 200); - ASSERT_TRUE(make_mat34.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_mat34_bad.IsApplicable(context.get(), fact_manager)); - make_mat34.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_mat34.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_mat34_bad.IsApplicable(context.get(), transformation_context)); + make_mat34.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(25, {}), MakeDataDescriptor(200, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(28, {}), MakeDataDescriptor(200, {1}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(31, {}), MakeDataDescriptor(200, {2}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(25, {}), MakeDataDescriptor(200, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(28, {}), MakeDataDescriptor(200, {1}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(31, {}), MakeDataDescriptor(200, {2}))); // make a mat4x3 TransformationCompositeConstruct make_mat43( @@ -417,19 +423,19 @@ // Bad: %25 does not match the matrix's column type. TransformationCompositeConstruct make_mat43_bad( 35, {25, 13, 16, 100}, MakeInstructionDescriptor(31, SpvOpStore, 0), 201); - ASSERT_TRUE(make_mat43.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_mat43_bad.IsApplicable(context.get(), fact_manager)); - make_mat43.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_mat43.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_mat43_bad.IsApplicable(context.get(), transformation_context)); + make_mat43.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(11, {}), MakeDataDescriptor(201, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(13, {}), MakeDataDescriptor(201, {1}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(16, {}), MakeDataDescriptor(201, {2}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(100, {}), - MakeDataDescriptor(201, {3}), - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(11, {}), MakeDataDescriptor(201, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(13, {}), MakeDataDescriptor(201, {1}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(16, {}), MakeDataDescriptor(201, {2}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(100, {}), MakeDataDescriptor(201, {3}))); std::string after_transformation = R"( OpCapability Shader @@ -602,6 +608,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // make an Inner TransformationCompositeConstruct make_inner( @@ -609,14 +618,15 @@ // Bad: Too few fields to make the struct. TransformationCompositeConstruct make_inner_bad( 9, {25}, MakeInstructionDescriptor(57, SpvOpAccessChain, 0), 200); - ASSERT_TRUE(make_inner.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_inner_bad.IsApplicable(context.get(), fact_manager)); - make_inner.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_inner.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_inner_bad.IsApplicable(context.get(), transformation_context)); + make_inner.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(25, {}), MakeDataDescriptor(200, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(19, {}), MakeDataDescriptor(200, {1}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(25, {}), MakeDataDescriptor(200, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(19, {}), MakeDataDescriptor(200, {1}))); // make an Outer TransformationCompositeConstruct make_outer( @@ -626,17 +636,17 @@ TransformationCompositeConstruct make_outer_bad( 33, {46, 200, 56}, MakeInstructionDescriptor(200, SpvOpCompositeConstruct, 0), 201); - ASSERT_TRUE(make_outer.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_outer_bad.IsApplicable(context.get(), fact_manager)); - make_outer.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_outer.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_outer_bad.IsApplicable(context.get(), transformation_context)); + make_outer.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(46, {}), MakeDataDescriptor(201, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(200, {}), - MakeDataDescriptor(201, {1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(56, {}), MakeDataDescriptor(201, {2}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(46, {}), MakeDataDescriptor(201, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(200, {}), MakeDataDescriptor(201, {1}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(56, {}), MakeDataDescriptor(201, {2}))); std::string after_transformation = R"( OpCapability Shader @@ -922,20 +932,24 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationCompositeConstruct make_vec2( 7, {17, 11}, MakeInstructionDescriptor(100, SpvOpStore, 0), 200); // Bad: not enough data for a vec2 TransformationCompositeConstruct make_vec2_bad( 7, {11}, MakeInstructionDescriptor(100, SpvOpStore, 0), 200); - ASSERT_TRUE(make_vec2.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_vec2_bad.IsApplicable(context.get(), fact_manager)); - make_vec2.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_vec2.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_vec2_bad.IsApplicable(context.get(), transformation_context)); + make_vec2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(17, {}), MakeDataDescriptor(200, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(11, {}), MakeDataDescriptor(200, {1}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(17, {}), MakeDataDescriptor(200, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(11, {}), MakeDataDescriptor(200, {1}))); TransformationCompositeConstruct make_vec3( 25, {12, 32}, MakeInstructionDescriptor(35, SpvOpCompositeConstruct, 0), @@ -944,18 +958,17 @@ TransformationCompositeConstruct make_vec3_bad( 25, {12, 32, 32}, MakeInstructionDescriptor(35, SpvOpCompositeConstruct, 0), 201); - ASSERT_TRUE(make_vec3.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_vec3_bad.IsApplicable(context.get(), fact_manager)); - make_vec3.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_vec3.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_vec3_bad.IsApplicable(context.get(), transformation_context)); + make_vec3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(12, {0}), - MakeDataDescriptor(201, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(12, {1}), - MakeDataDescriptor(201, {1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(32, {}), MakeDataDescriptor(201, {2}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(12, {0}), MakeDataDescriptor(201, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(12, {1}), MakeDataDescriptor(201, {1}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(32, {}), MakeDataDescriptor(201, {2}))); TransformationCompositeConstruct make_vec4( 44, {32, 32, 10, 11}, MakeInstructionDescriptor(75, SpvOpAccessChain, 0), @@ -964,34 +977,34 @@ TransformationCompositeConstruct make_vec4_bad( 44, {48, 32, 10, 11}, MakeInstructionDescriptor(75, SpvOpAccessChain, 0), 202); - ASSERT_TRUE(make_vec4.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_vec4_bad.IsApplicable(context.get(), fact_manager)); - make_vec4.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_vec4.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_vec4_bad.IsApplicable(context.get(), transformation_context)); + make_vec4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(32, {}), MakeDataDescriptor(202, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(32, {}), MakeDataDescriptor(202, {1}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(10, {}), MakeDataDescriptor(202, {2}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(11, {}), MakeDataDescriptor(202, {3}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(32, {}), MakeDataDescriptor(202, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(32, {}), MakeDataDescriptor(202, {1}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(10, {}), MakeDataDescriptor(202, {2}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(11, {}), MakeDataDescriptor(202, {3}))); TransformationCompositeConstruct make_ivec2( 51, {126, 120}, MakeInstructionDescriptor(128, SpvOpLoad, 0), 203); // Bad: if 128 is not available at the instruction that defines 128 TransformationCompositeConstruct make_ivec2_bad( 51, {128, 120}, MakeInstructionDescriptor(128, SpvOpLoad, 0), 203); - ASSERT_TRUE(make_ivec2.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_ivec2_bad.IsApplicable(context.get(), fact_manager)); - make_ivec2.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_ivec2.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_ivec2_bad.IsApplicable(context.get(), transformation_context)); + make_ivec2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(126, {}), - MakeDataDescriptor(203, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(120, {}), - MakeDataDescriptor(203, {1}), - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(126, {}), MakeDataDescriptor(203, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(120, {}), MakeDataDescriptor(203, {1}))); TransformationCompositeConstruct make_ivec3( 114, {56, 117, 56}, MakeInstructionDescriptor(66, SpvOpAccessChain, 0), @@ -1000,17 +1013,17 @@ TransformationCompositeConstruct make_ivec3_bad( 114, {56, 117, 1300}, MakeInstructionDescriptor(66, SpvOpAccessChain, 0), 204); - ASSERT_TRUE(make_ivec3.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_ivec3_bad.IsApplicable(context.get(), fact_manager)); - make_ivec3.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_ivec3.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_ivec3_bad.IsApplicable(context.get(), transformation_context)); + make_ivec3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(56, {}), MakeDataDescriptor(204, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(117, {}), - MakeDataDescriptor(204, {1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(56, {}), MakeDataDescriptor(204, {2}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(56, {}), MakeDataDescriptor(204, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(117, {}), MakeDataDescriptor(204, {1}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(56, {}), MakeDataDescriptor(204, {2}))); TransformationCompositeConstruct make_ivec4( 122, {56, 117, 117, 117}, MakeInstructionDescriptor(66, SpvOpIAdd, 0), @@ -1019,51 +1032,50 @@ TransformationCompositeConstruct make_ivec4_bad( 86, {56, 117, 117, 117}, MakeInstructionDescriptor(66, SpvOpIAdd, 0), 205); - ASSERT_TRUE(make_ivec4.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_ivec4_bad.IsApplicable(context.get(), fact_manager)); - make_ivec4.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_ivec4.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_ivec4_bad.IsApplicable(context.get(), transformation_context)); + make_ivec4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(56, {}), MakeDataDescriptor(205, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(117, {}), - MakeDataDescriptor(205, {1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(117, {}), - MakeDataDescriptor(205, {2}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(117, {}), - MakeDataDescriptor(205, {3}), - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(56, {}), MakeDataDescriptor(205, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(117, {}), MakeDataDescriptor(205, {1}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(117, {}), MakeDataDescriptor(205, {2}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(117, {}), MakeDataDescriptor(205, {3}))); TransformationCompositeConstruct make_uvec2( 86, {18, 38}, MakeInstructionDescriptor(133, SpvOpAccessChain, 0), 206); TransformationCompositeConstruct make_uvec2_bad( 86, {18, 38}, MakeInstructionDescriptor(133, SpvOpAccessChain, 200), 206); - ASSERT_TRUE(make_uvec2.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_uvec2_bad.IsApplicable(context.get(), fact_manager)); - make_uvec2.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_uvec2.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_uvec2_bad.IsApplicable(context.get(), transformation_context)); + make_uvec2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(18, {}), MakeDataDescriptor(206, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(38, {}), MakeDataDescriptor(206, {1}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(18, {}), MakeDataDescriptor(206, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(38, {}), MakeDataDescriptor(206, {1}))); TransformationCompositeConstruct make_uvec3( 59, {14, 18, 136}, MakeInstructionDescriptor(137, SpvOpReturn, 0), 207); // Bad because 1300 is not an id TransformationCompositeConstruct make_uvec3_bad( 59, {14, 18, 1300}, MakeInstructionDescriptor(137, SpvOpReturn, 0), 207); - ASSERT_TRUE(make_uvec3.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_uvec3_bad.IsApplicable(context.get(), fact_manager)); - make_uvec3.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_uvec3.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_uvec3_bad.IsApplicable(context.get(), transformation_context)); + make_uvec3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(14, {}), MakeDataDescriptor(207, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(18, {}), MakeDataDescriptor(207, {1}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(136, {}), - MakeDataDescriptor(207, {2}), - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(14, {}), MakeDataDescriptor(207, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(18, {}), MakeDataDescriptor(207, {1}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(136, {}), MakeDataDescriptor(207, {2}))); TransformationCompositeConstruct make_uvec4( 131, {14, 18, 136, 136}, @@ -1072,20 +1084,19 @@ TransformationCompositeConstruct make_uvec4_bad( 86, {14, 18, 136, 136}, MakeInstructionDescriptor(137, SpvOpAccessChain, 0), 208); - ASSERT_TRUE(make_uvec4.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_uvec4_bad.IsApplicable(context.get(), fact_manager)); - make_uvec4.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_uvec4.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_uvec4_bad.IsApplicable(context.get(), transformation_context)); + make_uvec4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(14, {}), MakeDataDescriptor(208, {0}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(18, {}), MakeDataDescriptor(208, {1}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(136, {}), - MakeDataDescriptor(208, {2}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(136, {}), - MakeDataDescriptor(208, {3}), - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(14, {}), MakeDataDescriptor(208, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(18, {}), MakeDataDescriptor(208, {1}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(136, {}), MakeDataDescriptor(208, {2}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(136, {}), MakeDataDescriptor(208, {3}))); TransformationCompositeConstruct make_bvec2( 102, @@ -1102,55 +1113,51 @@ 41, }, MakeInstructionDescriptor(0, SpvOpExtInstImport, 0), 209); - ASSERT_TRUE(make_bvec2.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_bvec2_bad.IsApplicable(context.get(), fact_manager)); - make_bvec2.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_bvec2.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_bvec2_bad.IsApplicable(context.get(), transformation_context)); + make_bvec2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(111, {}), - MakeDataDescriptor(209, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(41, {}), MakeDataDescriptor(209, {1}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(111, {}), MakeDataDescriptor(209, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(41, {}), MakeDataDescriptor(209, {1}))); TransformationCompositeConstruct make_bvec3( 93, {108, 73}, MakeInstructionDescriptor(108, SpvOpStore, 0), 210); // Bad because there are too many components for a bvec3 TransformationCompositeConstruct make_bvec3_bad( 93, {108, 108}, MakeInstructionDescriptor(108, SpvOpStore, 0), 210); - ASSERT_TRUE(make_bvec3.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_bvec3_bad.IsApplicable(context.get(), fact_manager)); - make_bvec3.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_bvec3.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_bvec3_bad.IsApplicable(context.get(), transformation_context)); + make_bvec3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(108, {0}), - MakeDataDescriptor(210, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(108, {1}), - MakeDataDescriptor(210, {1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(73, {}), MakeDataDescriptor(210, {2}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(108, {0}), MakeDataDescriptor(210, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(108, {1}), MakeDataDescriptor(210, {1}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(73, {}), MakeDataDescriptor(210, {2}))); TransformationCompositeConstruct make_bvec4( 70, {108, 108}, MakeInstructionDescriptor(108, SpvOpBranch, 0), 211); // Bad because 21 is a type, not a result id TransformationCompositeConstruct make_bvec4_bad( 70, {21, 108}, MakeInstructionDescriptor(108, SpvOpBranch, 0), 211); - ASSERT_TRUE(make_bvec4.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(make_bvec4_bad.IsApplicable(context.get(), fact_manager)); - make_bvec4.Apply(context.get(), &fact_manager); + ASSERT_TRUE(make_bvec4.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + make_bvec4_bad.IsApplicable(context.get(), transformation_context)); + make_bvec4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(108, {0}), - MakeDataDescriptor(211, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(108, {1}), - MakeDataDescriptor(211, {1}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(108, {0}), - MakeDataDescriptor(211, {2}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(108, {1}), - MakeDataDescriptor(211, {3}), - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(108, {0}), MakeDataDescriptor(211, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(108, {1}), MakeDataDescriptor(211, {1}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(108, {0}), MakeDataDescriptor(211, {2}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(108, {1}), MakeDataDescriptor(211, {3}))); std::string after_transformation = R"( OpCapability Shader
diff --git a/test/fuzz/transformation_composite_extract_test.cpp b/test/fuzz/transformation_composite_extract_test.cpp index 5cc2115..a7674a6 100644 --- a/test/fuzz/transformation_composite_extract_test.cpp +++ b/test/fuzz/transformation_composite_extract_test.cpp
@@ -96,100 +96,103 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Instruction does not exist. ASSERT_FALSE(TransformationCompositeExtract( MakeInstructionDescriptor(36, SpvOpIAdd, 0), 200, 101, {0}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Id for composite is not a composite. ASSERT_FALSE(TransformationCompositeExtract( MakeInstructionDescriptor(36, SpvOpIAdd, 0), 200, 27, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Composite does not dominate instruction being inserted before. ASSERT_FALSE( TransformationCompositeExtract( MakeInstructionDescriptor(37, SpvOpAccessChain, 0), 200, 101, {0}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Too many indices for extraction from struct composite. ASSERT_FALSE( TransformationCompositeExtract( MakeInstructionDescriptor(24, SpvOpAccessChain, 0), 200, 101, {0, 0}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Too many indices for extraction from struct composite. ASSERT_FALSE( TransformationCompositeExtract( MakeInstructionDescriptor(13, SpvOpIEqual, 0), 200, 104, {0, 0, 0}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Out of bounds index for extraction from struct composite. ASSERT_FALSE( TransformationCompositeExtract( MakeInstructionDescriptor(13, SpvOpIEqual, 0), 200, 104, {0, 3}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Result id already used. ASSERT_FALSE(TransformationCompositeExtract( MakeInstructionDescriptor(35, SpvOpFAdd, 0), 80, 103, {0}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); TransformationCompositeExtract transformation_1( MakeInstructionDescriptor(36, SpvOpConvertFToS, 0), 201, 100, {2}); - ASSERT_TRUE(transformation_1.IsApplicable(context.get(), fact_manager)); - transformation_1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation_1.IsApplicable(context.get(), transformation_context)); + transformation_1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); TransformationCompositeExtract transformation_2( MakeInstructionDescriptor(37, SpvOpAccessChain, 0), 202, 104, {0, 2}); - ASSERT_TRUE(transformation_2.IsApplicable(context.get(), fact_manager)); - transformation_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation_2.IsApplicable(context.get(), transformation_context)); + transformation_2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); TransformationCompositeExtract transformation_3( MakeInstructionDescriptor(29, SpvOpAccessChain, 0), 203, 104, {0}); - ASSERT_TRUE(transformation_3.IsApplicable(context.get(), fact_manager)); - transformation_3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation_3.IsApplicable(context.get(), transformation_context)); + transformation_3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); TransformationCompositeExtract transformation_4( MakeInstructionDescriptor(24, SpvOpStore, 0), 204, 101, {0}); - ASSERT_TRUE(transformation_4.IsApplicable(context.get(), fact_manager)); - transformation_4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation_4.IsApplicable(context.get(), transformation_context)); + transformation_4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); TransformationCompositeExtract transformation_5( MakeInstructionDescriptor(29, SpvOpBranch, 0), 205, 102, {2}); - ASSERT_TRUE(transformation_5.IsApplicable(context.get(), fact_manager)); - transformation_5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation_5.IsApplicable(context.get(), transformation_context)); + transformation_5.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); TransformationCompositeExtract transformation_6( MakeInstructionDescriptor(37, SpvOpReturn, 0), 206, 103, {1}); - ASSERT_TRUE(transformation_6.IsApplicable(context.get(), fact_manager)); - transformation_6.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation_6.IsApplicable(context.get(), transformation_context)); + transformation_6.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(201, {}), - MakeDataDescriptor(100, {2}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(202, {}), - MakeDataDescriptor(104, {0, 2}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(203, {}), - MakeDataDescriptor(104, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(204, {}), - MakeDataDescriptor(101, {0}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(205, {}), - MakeDataDescriptor(102, {2}), - context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(206, {}), - MakeDataDescriptor(103, {1}), - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(201, {}), MakeDataDescriptor(100, {2}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(202, {}), MakeDataDescriptor(104, {0, 2}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(203, {}), MakeDataDescriptor(104, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(204, {}), MakeDataDescriptor(101, {0}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(205, {}), MakeDataDescriptor(102, {2}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(206, {}), MakeDataDescriptor(103, {1}))); std::string after_transformation = R"( OpCapability Shader @@ -348,49 +351,52 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Cannot insert before the OpVariables of a function. ASSERT_FALSE( TransformationCompositeExtract( MakeInstructionDescriptor(101, SpvOpVariable, 0), 200, 14, {0}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationCompositeExtract( MakeInstructionDescriptor(101, SpvOpVariable, 1), 200, 14, {1}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationCompositeExtract( MakeInstructionDescriptor(102, SpvOpVariable, 0), 200, 14, {1}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // OK to insert right after the OpVariables. ASSERT_FALSE(TransformationCompositeExtract( MakeInstructionDescriptor(102, SpvOpBranch, 1), 200, 14, {1}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Cannot insert before the OpPhis of a block. ASSERT_FALSE(TransformationCompositeExtract( MakeInstructionDescriptor(60, SpvOpPhi, 0), 200, 14, {2}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationCompositeExtract( MakeInstructionDescriptor(59, SpvOpPhi, 0), 200, 14, {3}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // OK to insert after the OpPhis. ASSERT_TRUE( TransformationCompositeExtract( MakeInstructionDescriptor(59, SpvOpAccessChain, 0), 200, 14, {3}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Cannot insert before OpLoopMerge ASSERT_FALSE(TransformationCompositeExtract( MakeInstructionDescriptor(33, SpvOpBranchConditional, 0), 200, 14, {3}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Cannot insert before OpSelectionMerge ASSERT_FALSE(TransformationCompositeExtract( MakeInstructionDescriptor(21, SpvOpBranchConditional, 0), 200, 14, {2}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } } // namespace
diff --git a/test/fuzz/transformation_compute_data_synonym_fact_closure_test.cpp b/test/fuzz/transformation_compute_data_synonym_fact_closure_test.cpp new file mode 100644 index 0000000..5fa74b7 --- /dev/null +++ b/test/fuzz/transformation_compute_data_synonym_fact_closure_test.cpp
@@ -0,0 +1,377 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/transformation_compute_data_synonym_fact_closure.h" +#include "test/fuzz/fuzz_test_util.h" + +namespace spvtools { +namespace fuzz { +namespace { + +TEST(TransformationComputeDataSynonymFactClosureTest, DataSynonymFacts) { + // The SPIR-V types and constants come from the following code. The body of + // the SPIR-V function then constructs a composite that is synonymous with + // myT. + // + // #version 310 es + // + // precision highp float; + // + // struct S { + // int a; + // uvec2 b; + // }; + // + // struct T { + // bool c[5]; + // mat4x2 d; + // S e; + // }; + // + // void main() { + // T myT = T(bool[5](true, false, true, false, true), + // mat4x2(vec2(1.0, 2.0), vec2(3.0, 4.0), + // vec2(5.0, 6.0), vec2(7.0, 8.0)), + // S(10, uvec2(100u, 200u))); + // } + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %15 "S" + OpMemberName %15 0 "a" + OpMemberName %15 1 "b" + OpName %16 "T" + OpMemberName %16 0 "c" + OpMemberName %16 1 "d" + OpMemberName %16 2 "e" + OpName %18 "myT" + OpMemberDecorate %15 0 RelaxedPrecision + OpMemberDecorate %15 1 RelaxedPrecision + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %7 = OpTypeInt 32 0 + %8 = OpConstant %7 5 + %9 = OpTypeArray %6 %8 + %10 = OpTypeFloat 32 + %11 = OpTypeVector %10 2 + %12 = OpTypeMatrix %11 4 + %13 = OpTypeInt 32 1 + %14 = OpTypeVector %7 2 + %15 = OpTypeStruct %13 %14 + %16 = OpTypeStruct %9 %12 %15 + %17 = OpTypePointer Function %16 + %19 = OpConstantTrue %6 + %20 = OpConstantFalse %6 + %21 = OpConstantComposite %9 %19 %20 %19 %20 %19 + %22 = OpConstant %10 1 + %23 = OpConstant %10 2 + %24 = OpConstantComposite %11 %22 %23 + %25 = OpConstant %10 3 + %26 = OpConstant %10 4 + %27 = OpConstantComposite %11 %25 %26 + %28 = OpConstant %10 5 + %29 = OpConstant %10 6 + %30 = OpConstantComposite %11 %28 %29 + %31 = OpConstant %10 7 + %32 = OpConstant %10 8 + %33 = OpConstantComposite %11 %31 %32 + %34 = OpConstantComposite %12 %24 %27 %30 %33 + %35 = OpConstant %13 10 + %36 = OpConstant %7 100 + %37 = OpConstant %7 200 + %38 = OpConstantComposite %14 %36 %37 + %39 = OpConstantComposite %15 %35 %38 + %40 = OpConstantComposite %16 %21 %34 %39 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %18 = OpVariable %17 Function + OpStore %18 %40 + %100 = OpCompositeConstruct %9 %19 %20 %19 %20 %19 + %101 = OpCompositeConstruct %11 %22 %23 + %102 = OpCompositeConstruct %11 %25 %26 + %103 = OpCompositeConstruct %11 %28 %29 + %104 = OpCompositeConstruct %11 %31 %32 + %105 = OpCompositeConstruct %12 %101 %102 %103 %104 + %106 = OpCompositeConstruct %14 %36 %37 + %107 = OpCompositeConstruct %15 %35 %106 + %108 = OpCompositeConstruct %16 %100 %105 %107 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + ASSERT_TRUE(TransformationComputeDataSynonymFactClosure(100).IsApplicable( + context.get(), transformation_context)); + + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {}), + MakeDataDescriptor(101, {}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {0}), + MakeDataDescriptor(101, {0}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {1}), + MakeDataDescriptor(101, {1}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {0}), + MakeDataDescriptor(101, {1}))); + + fact_manager.AddFactDataSynonym(MakeDataDescriptor(24, {}), + MakeDataDescriptor(101, {}), context.get()); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {}), + MakeDataDescriptor(101, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {0}), + MakeDataDescriptor(101, {0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {1}), + MakeDataDescriptor(101, {1}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(24, {0}), + MakeDataDescriptor(101, {1}))); + + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {}), + MakeDataDescriptor(102, {}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {0}), + MakeDataDescriptor(102, {0}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {1}), + MakeDataDescriptor(102, {1}))); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(27, {0}), + MakeDataDescriptor(102, {0}), context.get()); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {}), + MakeDataDescriptor(102, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {0}), + MakeDataDescriptor(102, {0}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {1}), + MakeDataDescriptor(102, {1}))); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(27, {1}), + MakeDataDescriptor(102, {1}), context.get()); + + TransformationComputeDataSynonymFactClosure(100).Apply( + context.get(), &transformation_context); + + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {}), + MakeDataDescriptor(102, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {0}), + MakeDataDescriptor(102, {0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(27, {1}), + MakeDataDescriptor(102, {1}))); + + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(30, {}), + MakeDataDescriptor(103, {}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(30, {0}), + MakeDataDescriptor(103, {0}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(30, {1}), + MakeDataDescriptor(103, {1}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(33, {}), + MakeDataDescriptor(104, {}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(33, {0}), + MakeDataDescriptor(104, {0}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(33, {1}), + MakeDataDescriptor(104, {1}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {}), + MakeDataDescriptor(105, {}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {0}), + MakeDataDescriptor(105, {0}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {1}), + MakeDataDescriptor(105, {1}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {2}), + MakeDataDescriptor(105, {2}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {3}), + MakeDataDescriptor(105, {3}))); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(30, {}), + MakeDataDescriptor(103, {}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(33, {}), + MakeDataDescriptor(104, {}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(34, {0}), + MakeDataDescriptor(105, {0}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(34, {1}), + MakeDataDescriptor(105, {1}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(34, {2}), + MakeDataDescriptor(105, {2}), context.get()); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(30, {}), + MakeDataDescriptor(103, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(30, {0}), + MakeDataDescriptor(103, {0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(30, {1}), + MakeDataDescriptor(103, {1}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(33, {}), + MakeDataDescriptor(104, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(33, {0}), + MakeDataDescriptor(104, {0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(33, {1}), + MakeDataDescriptor(104, {1}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {}), + MakeDataDescriptor(105, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {0}), + MakeDataDescriptor(105, {0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {1}), + MakeDataDescriptor(105, {1}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {2}), + MakeDataDescriptor(105, {2}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {3}), + MakeDataDescriptor(105, {3}))); + + fact_manager.AddFactDataSynonym(MakeDataDescriptor(34, {3}), + MakeDataDescriptor(105, {3}), context.get()); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(33, {0}), + MakeDataDescriptor(104, {0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(34, {3}), + MakeDataDescriptor(105, {3}))); + + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(21, {}), + MakeDataDescriptor(100, {}))); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(21, {0}), + MakeDataDescriptor(100, {0}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(21, {1}), + MakeDataDescriptor(100, {1}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(21, {2}), + MakeDataDescriptor(100, {2}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(21, {3}), + MakeDataDescriptor(100, {3}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(21, {4}), + MakeDataDescriptor(100, {4}), context.get()); + + TransformationComputeDataSynonymFactClosure(100).Apply( + context.get(), &transformation_context); + + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(21, {}), + MakeDataDescriptor(100, {}))); + + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(39, {0}), + MakeDataDescriptor(107, {0}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(35, {}), + MakeDataDescriptor(39, {0}))); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(39, {0}), + MakeDataDescriptor(35, {}), context.get()); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(39, {0}), + MakeDataDescriptor(107, {0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(35, {}), + MakeDataDescriptor(39, {0}))); + + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(38, {0}), + MakeDataDescriptor(36, {}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(38, {1}), + MakeDataDescriptor(37, {}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(106, {0}), + MakeDataDescriptor(36, {}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(106, {1}), + MakeDataDescriptor(37, {}))); + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(38, {}), + MakeDataDescriptor(106, {}))); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(38, {0}), + MakeDataDescriptor(36, {}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(106, {0}), + MakeDataDescriptor(36, {}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(38, {1}), + MakeDataDescriptor(37, {}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(106, {1}), + MakeDataDescriptor(37, {}), context.get()); + + TransformationComputeDataSynonymFactClosure(100).Apply( + context.get(), &transformation_context); + + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(38, {0}), + MakeDataDescriptor(36, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(38, {1}), + MakeDataDescriptor(37, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(106, {0}), + MakeDataDescriptor(36, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(106, {1}), + MakeDataDescriptor(37, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(38, {}), + MakeDataDescriptor(106, {}))); + + ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {}), + MakeDataDescriptor(108, {}))); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(107, {0}), + MakeDataDescriptor(35, {}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(40, {0}), + MakeDataDescriptor(108, {0}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(40, {1}), + MakeDataDescriptor(108, {1}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(40, {2}), + MakeDataDescriptor(108, {2}), context.get()); + + TransformationComputeDataSynonymFactClosure(100).Apply( + context.get(), &transformation_context); + + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {}), + MakeDataDescriptor(108, {}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {0}), + MakeDataDescriptor(108, {0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1}), + MakeDataDescriptor(108, {1}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {2}), + MakeDataDescriptor(108, {2}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {0, 0}), + MakeDataDescriptor(108, {0, 0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {0, 1}), + MakeDataDescriptor(108, {0, 1}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {0, 2}), + MakeDataDescriptor(108, {0, 2}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {0, 3}), + MakeDataDescriptor(108, {0, 3}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {0, 4}), + MakeDataDescriptor(108, {0, 4}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 0}), + MakeDataDescriptor(108, {1, 0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 1}), + MakeDataDescriptor(108, {1, 1}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 2}), + MakeDataDescriptor(108, {1, 2}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 3}), + MakeDataDescriptor(108, {1, 3}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 0, 0}), + MakeDataDescriptor(108, {1, 0, 0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 1, 0}), + MakeDataDescriptor(108, {1, 1, 0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 2, 0}), + MakeDataDescriptor(108, {1, 2, 0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 3, 0}), + MakeDataDescriptor(108, {1, 3, 0}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 0, 1}), + MakeDataDescriptor(108, {1, 0, 1}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 1, 1}), + MakeDataDescriptor(108, {1, 1, 1}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 2, 1}), + MakeDataDescriptor(108, {1, 2, 1}))); + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {1, 3, 1}), + MakeDataDescriptor(108, {1, 3, 1}))); + + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {2, 0}), + MakeDataDescriptor(108, {2, 0}))); + + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {2, 1}), + MakeDataDescriptor(108, {2, 1}))); + + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {2, 1, 0}), + MakeDataDescriptor(108, {2, 1, 0}))); + + ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {2, 1, 1}), + MakeDataDescriptor(108, {2, 1, 1}))); +} + +} // namespace +} // namespace fuzz +} // namespace spvtools
diff --git a/test/fuzz/transformation_copy_object_test.cpp b/test/fuzz/transformation_copy_object_test.cpp index b85f75b..cf9d135 100644 --- a/test/fuzz/transformation_copy_object_test.cpp +++ b/test/fuzz/transformation_copy_object_test.cpp
@@ -51,77 +51,92 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - ASSERT_EQ(0, - fact_manager.GetIdsForWhichSynonymsAreKnown(context.get()).size()); + ASSERT_EQ(0, transformation_context.GetFactManager() + ->GetIdsForWhichSynonymsAreKnown() + .size()); { TransformationCopyObject copy_true( 7, MakeInstructionDescriptor(5, SpvOpReturn, 0), 100); - ASSERT_TRUE(copy_true.IsApplicable(context.get(), fact_manager)); - copy_true.Apply(context.get(), &fact_manager); + ASSERT_TRUE(copy_true.IsApplicable(context.get(), transformation_context)); + copy_true.Apply(context.get(), &transformation_context); std::vector<uint32_t> ids_for_which_synonyms_are_known = - fact_manager.GetIdsForWhichSynonymsAreKnown(context.get()); + transformation_context.GetFactManager() + ->GetIdsForWhichSynonymsAreKnown(); ASSERT_EQ(2, ids_for_which_synonyms_are_known.size()); ASSERT_TRUE(std::find(ids_for_which_synonyms_are_known.begin(), ids_for_which_synonyms_are_known.end(), 7) != ids_for_which_synonyms_are_known.end()); - ASSERT_EQ(2, fact_manager.GetSynonymsForId(7, context.get()).size()); + ASSERT_EQ( + 2, transformation_context.GetFactManager()->GetSynonymsForId(7).size()); protobufs::DataDescriptor descriptor_100 = MakeDataDescriptor(100, {}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(7, {}), - descriptor_100, context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(7, {}), descriptor_100)); } { TransformationCopyObject copy_false( 8, MakeInstructionDescriptor(100, SpvOpReturn, 0), 101); - ASSERT_TRUE(copy_false.IsApplicable(context.get(), fact_manager)); - copy_false.Apply(context.get(), &fact_manager); + ASSERT_TRUE(copy_false.IsApplicable(context.get(), transformation_context)); + copy_false.Apply(context.get(), &transformation_context); std::vector<uint32_t> ids_for_which_synonyms_are_known = - fact_manager.GetIdsForWhichSynonymsAreKnown(context.get()); + transformation_context.GetFactManager() + ->GetIdsForWhichSynonymsAreKnown(); ASSERT_EQ(4, ids_for_which_synonyms_are_known.size()); ASSERT_TRUE(std::find(ids_for_which_synonyms_are_known.begin(), ids_for_which_synonyms_are_known.end(), 8) != ids_for_which_synonyms_are_known.end()); - ASSERT_EQ(2, fact_manager.GetSynonymsForId(8, context.get()).size()); + ASSERT_EQ( + 2, transformation_context.GetFactManager()->GetSynonymsForId(8).size()); protobufs::DataDescriptor descriptor_101 = MakeDataDescriptor(101, {}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(8, {}), - descriptor_101, context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(8, {}), descriptor_101)); } { TransformationCopyObject copy_false_again( 101, MakeInstructionDescriptor(5, SpvOpReturn, 0), 102); - ASSERT_TRUE(copy_false_again.IsApplicable(context.get(), fact_manager)); - copy_false_again.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + copy_false_again.IsApplicable(context.get(), transformation_context)); + copy_false_again.Apply(context.get(), &transformation_context); std::vector<uint32_t> ids_for_which_synonyms_are_known = - fact_manager.GetIdsForWhichSynonymsAreKnown(context.get()); + transformation_context.GetFactManager() + ->GetIdsForWhichSynonymsAreKnown(); ASSERT_EQ(5, ids_for_which_synonyms_are_known.size()); ASSERT_TRUE(std::find(ids_for_which_synonyms_are_known.begin(), ids_for_which_synonyms_are_known.end(), 101) != ids_for_which_synonyms_are_known.end()); - ASSERT_EQ(3, fact_manager.GetSynonymsForId(101, context.get()).size()); + ASSERT_EQ( + 3, + transformation_context.GetFactManager()->GetSynonymsForId(101).size()); protobufs::DataDescriptor descriptor_102 = MakeDataDescriptor(102, {}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(101, {}), - descriptor_102, context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(101, {}), descriptor_102)); } { TransformationCopyObject copy_true_again( 7, MakeInstructionDescriptor(102, SpvOpReturn, 0), 103); - ASSERT_TRUE(copy_true_again.IsApplicable(context.get(), fact_manager)); - copy_true_again.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + copy_true_again.IsApplicable(context.get(), transformation_context)); + copy_true_again.Apply(context.get(), &transformation_context); std::vector<uint32_t> ids_for_which_synonyms_are_known = - fact_manager.GetIdsForWhichSynonymsAreKnown(context.get()); + transformation_context.GetFactManager() + ->GetIdsForWhichSynonymsAreKnown(); ASSERT_EQ(6, ids_for_which_synonyms_are_known.size()); ASSERT_TRUE(std::find(ids_for_which_synonyms_are_known.begin(), ids_for_which_synonyms_are_known.end(), 7) != ids_for_which_synonyms_are_known.end()); - ASSERT_EQ(3, fact_manager.GetSynonymsForId(7, context.get()).size()); + ASSERT_EQ( + 3, transformation_context.GetFactManager()->GetSynonymsForId(7).size()); protobufs::DataDescriptor descriptor_103 = MakeDataDescriptor(103, {}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(7, {}), - descriptor_103, context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(7, {}), descriptor_103)); } std::string after_transformation = R"( @@ -340,116 +355,119 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Inapplicable because %18 is decorated. ASSERT_FALSE(TransformationCopyObject( 18, MakeInstructionDescriptor(21, SpvOpAccessChain, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable because %77 is decorated. ASSERT_FALSE(TransformationCopyObject( 77, MakeInstructionDescriptor(77, SpvOpBranch, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable because %80 is decorated. ASSERT_FALSE(TransformationCopyObject( 80, MakeInstructionDescriptor(77, SpvOpIAdd, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable because %84 is not available at the requested point ASSERT_FALSE( TransformationCopyObject( 84, MakeInstructionDescriptor(32, SpvOpCompositeExtract, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Fine because %84 is available at the requested point ASSERT_TRUE( TransformationCopyObject( 84, MakeInstructionDescriptor(32, SpvOpCompositeConstruct, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable because id %9 is already in use ASSERT_FALSE( TransformationCopyObject( 84, MakeInstructionDescriptor(32, SpvOpCompositeConstruct, 0), 9) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable because the requested point does not exist ASSERT_FALSE(TransformationCopyObject( 84, MakeInstructionDescriptor(86, SpvOpReturn, 2), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable because %9 is not in a function ASSERT_FALSE(TransformationCopyObject( 9, MakeInstructionDescriptor(9, SpvOpTypeInt, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable because the insert point is right before, or inside, a chunk // of OpPhis ASSERT_FALSE(TransformationCopyObject( 9, MakeInstructionDescriptor(30, SpvOpPhi, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationCopyObject( 9, MakeInstructionDescriptor(99, SpvOpPhi, 1), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // OK, because the insert point is just after a chunk of OpPhis. ASSERT_TRUE(TransformationCopyObject( 9, MakeInstructionDescriptor(96, SpvOpAccessChain, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable because the insert point is right after an OpSelectionMerge ASSERT_FALSE( TransformationCopyObject( 9, MakeInstructionDescriptor(58, SpvOpBranchConditional, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // OK, because the insert point is right before the OpSelectionMerge ASSERT_TRUE(TransformationCopyObject( 9, MakeInstructionDescriptor(58, SpvOpSelectionMerge, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable because the insert point is right after an OpSelectionMerge ASSERT_FALSE(TransformationCopyObject( 9, MakeInstructionDescriptor(43, SpvOpSwitch, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // OK, because the insert point is right before the OpSelectionMerge ASSERT_TRUE(TransformationCopyObject( 9, MakeInstructionDescriptor(43, SpvOpSelectionMerge, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable because the insert point is right after an OpLoopMerge ASSERT_FALSE( TransformationCopyObject( 9, MakeInstructionDescriptor(40, SpvOpBranchConditional, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // OK, because the insert point is right before the OpLoopMerge ASSERT_TRUE(TransformationCopyObject( 9, MakeInstructionDescriptor(40, SpvOpLoopMerge, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable because id %300 does not exist ASSERT_FALSE(TransformationCopyObject( 300, MakeInstructionDescriptor(40, SpvOpLoopMerge, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Inapplicable because the following instruction is OpVariable ASSERT_FALSE(TransformationCopyObject( 9, MakeInstructionDescriptor(180, SpvOpVariable, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationCopyObject( 9, MakeInstructionDescriptor(181, SpvOpVariable, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationCopyObject( 9, MakeInstructionDescriptor(182, SpvOpVariable, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // OK, because this is just past the group of OpVariable instructions. ASSERT_TRUE(TransformationCopyObject( 9, MakeInstructionDescriptor(182, SpvOpAccessChain, 0), 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } TEST(TransformationCopyObjectTest, MiscellaneousCopies) { @@ -515,6 +533,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); std::vector<TransformationCopyObject> transformations = { TransformationCopyObject(19, MakeInstructionDescriptor(22, SpvOpStore, 0), @@ -533,8 +554,9 @@ 17, MakeInstructionDescriptor(22, SpvOpCopyObject, 0), 106)}; for (auto& transformation : transformations) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); } ASSERT_TRUE(IsValid(env, context.get())); @@ -614,16 +636,19 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Illegal to copy null. ASSERT_FALSE(TransformationCopyObject( 8, MakeInstructionDescriptor(5, SpvOpReturn, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Illegal to copy an OpUndef of pointer type. ASSERT_FALSE(TransformationCopyObject( 9, MakeInstructionDescriptor(5, SpvOpReturn, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } TEST(TransformationCopyObjectTest, PropagateIrrelevantPointeeFact) { @@ -655,7 +680,11 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; - fact_manager.AddFactValueOfPointeeIsIrrelevant(8); + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant(8); TransformationCopyObject transformation1( 8, MakeInstructionDescriptor(9, SpvOpReturn, 0), 100); @@ -664,18 +693,84 @@ TransformationCopyObject transformation3( 100, MakeInstructionDescriptor(9, SpvOpReturn, 0), 102); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(8)); - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(100)); - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(102)); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(9)); - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(101)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(8)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(100)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(102)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(9)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(101)); +} + +TEST(TransformationCopyObject, DoNotCopyOpSampledImage) { + // This checks that we do not try to copy the result id of an OpSampledImage + // instruction. + std::string shader = R"( + OpCapability Shader + OpCapability SampledBuffer + OpCapability ImageBuffer + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %40 %41 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 450 + OpDecorate %40 DescriptorSet 0 + OpDecorate %40 Binding 69 + OpDecorate %41 DescriptorSet 0 + OpDecorate %41 Binding 1 + %54 = OpTypeFloat 32 + %76 = OpTypeVector %54 4 + %55 = OpConstant %54 0 + %56 = OpTypeVector %54 3 + %94 = OpTypeVector %54 2 + %112 = OpConstantComposite %94 %55 %55 + %57 = OpConstantComposite %56 %55 %55 %55 + %15 = OpTypeImage %54 2D 2 0 0 1 Unknown + %114 = OpTypePointer UniformConstant %15 + %38 = OpTypeSampler + %125 = OpTypePointer UniformConstant %38 + %132 = OpTypeVoid + %133 = OpTypeFunction %132 + %45 = OpTypeSampledImage %15 + %40 = OpVariable %114 UniformConstant + %41 = OpVariable %125 UniformConstant + %2 = OpFunction %132 None %133 + %164 = OpLabel + %184 = OpLoad %15 %40 + %213 = OpLoad %38 %41 + %216 = OpSampledImage %45 %184 %213 + %217 = OpImageSampleImplicitLod %76 %216 %112 Bias %55 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + ASSERT_FALSE( + TransformationCopyObject( + 216, MakeInstructionDescriptor(217, SpvOpImageSampleImplicitLod, 0), + 500) + .IsApplicable(context.get(), transformation_context)); } } // namespace
diff --git a/test/fuzz/transformation_equation_instruction_test.cpp b/test/fuzz/transformation_equation_instruction_test.cpp index 81d849b..1e8aa7e 100644 --- a/test/fuzz/transformation_equation_instruction_test.cpp +++ b/test/fuzz/transformation_equation_instruction_test.cpp
@@ -48,6 +48,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); protobufs::InstructionDescriptor return_instruction = MakeInstructionDescriptor(13, SpvOpReturn, 0); @@ -55,59 +58,61 @@ // Bad: id already in use. ASSERT_FALSE(TransformationEquationInstruction(7, SpvOpSNegate, {7}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: identified instruction does not exist. ASSERT_FALSE( TransformationEquationInstruction( 14, SpvOpSNegate, {7}, MakeInstructionDescriptor(13, SpvOpLoad, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: id 100 does not exist ASSERT_FALSE(TransformationEquationInstruction(14, SpvOpSNegate, {100}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: id 20 is an OpUndef ASSERT_FALSE(TransformationEquationInstruction(14, SpvOpSNegate, {20}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: id 30 is not available right before its definition ASSERT_FALSE(TransformationEquationInstruction( 14, SpvOpSNegate, {30}, MakeInstructionDescriptor(30, SpvOpCopyObject, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: too many arguments to OpSNegate. ASSERT_FALSE(TransformationEquationInstruction(14, SpvOpSNegate, {7, 7}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: 40 is a type id. ASSERT_FALSE(TransformationEquationInstruction(14, SpvOpSNegate, {40}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: wrong type of argument to OpSNegate. ASSERT_FALSE(TransformationEquationInstruction(14, SpvOpSNegate, {41}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); auto transformation1 = TransformationEquationInstruction( 14, SpvOpSNegate, {7}, return_instruction); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); auto transformation2 = TransformationEquationInstruction( 15, SpvOpSNegate, {14}, return_instruction); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(15, {}), MakeDataDescriptor(7, {}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(15, {}), MakeDataDescriptor(7, {}))); std::string after_transformation = R"( OpCapability Shader @@ -161,6 +166,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); protobufs::InstructionDescriptor return_instruction = MakeInstructionDescriptor(13, SpvOpReturn, 0); @@ -168,32 +176,34 @@ // Bad: too few arguments to OpLogicalNot. ASSERT_FALSE(TransformationEquationInstruction(14, SpvOpLogicalNot, {}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: 6 is a type id. ASSERT_FALSE(TransformationEquationInstruction(14, SpvOpLogicalNot, {6}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: wrong type of argument to OpLogicalNot. ASSERT_FALSE(TransformationEquationInstruction(14, SpvOpLogicalNot, {21}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); auto transformation1 = TransformationEquationInstruction( 14, SpvOpLogicalNot, {7}, return_instruction); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); auto transformation2 = TransformationEquationInstruction( 15, SpvOpLogicalNot, {14}, return_instruction); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(15, {}), MakeDataDescriptor(7, {}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(15, {}), MakeDataDescriptor(7, {}))); std::string after_transformation = R"( OpCapability Shader @@ -248,6 +258,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); protobufs::InstructionDescriptor return_instruction = MakeInstructionDescriptor(13, SpvOpReturn, 0); @@ -255,59 +268,64 @@ // Bad: too many arguments to OpIAdd. ASSERT_FALSE(TransformationEquationInstruction(14, SpvOpIAdd, {15, 16, 16}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: boolean argument to OpIAdd. ASSERT_FALSE(TransformationEquationInstruction(14, SpvOpIAdd, {15, 32}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: type as argument to OpIAdd. ASSERT_FALSE(TransformationEquationInstruction(14, SpvOpIAdd, {33, 16}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: arguments of mismatched widths ASSERT_FALSE(TransformationEquationInstruction(14, SpvOpIAdd, {15, 31}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: arguments of mismatched widths ASSERT_FALSE(TransformationEquationInstruction(14, SpvOpIAdd, {31, 15}, return_instruction) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); auto transformation1 = TransformationEquationInstruction( 14, SpvOpIAdd, {15, 16}, return_instruction); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); auto transformation2 = TransformationEquationInstruction( 19, SpvOpISub, {14, 16}, return_instruction); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(15, {}), MakeDataDescriptor(19, {}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(15, {}), MakeDataDescriptor(19, {}))); auto transformation3 = TransformationEquationInstruction( 20, SpvOpISub, {14, 15}, return_instruction); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(20, {}), MakeDataDescriptor(16, {}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(20, {}), MakeDataDescriptor(16, {}))); auto transformation4 = TransformationEquationInstruction( 22, SpvOpISub, {16, 14}, return_instruction); - ASSERT_TRUE(transformation4.IsApplicable(context.get(), fact_manager)); - transformation4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation4.IsApplicable(context.get(), transformation_context)); + transformation4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); auto transformation5 = TransformationEquationInstruction( 24, SpvOpSNegate, {22}, return_instruction); - ASSERT_TRUE(transformation5.IsApplicable(context.get(), fact_manager)); - transformation5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation5.IsApplicable(context.get(), transformation_context)); + transformation5.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(24, {}), MakeDataDescriptor(15, {}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(24, {}), MakeDataDescriptor(15, {}))); std::string after_transformation = R"( OpCapability Shader @@ -364,69 +382,80 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); protobufs::InstructionDescriptor return_instruction = MakeInstructionDescriptor(13, SpvOpReturn, 0); auto transformation1 = TransformationEquationInstruction( 14, SpvOpISub, {15, 16}, return_instruction); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); auto transformation2 = TransformationEquationInstruction( 17, SpvOpIAdd, {14, 16}, return_instruction); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(17, {}), MakeDataDescriptor(15, {}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(17, {}), MakeDataDescriptor(15, {}))); auto transformation3 = TransformationEquationInstruction( 18, SpvOpIAdd, {16, 14}, return_instruction); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(17, {}), MakeDataDescriptor(18, {}), context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(18, {}), MakeDataDescriptor(15, {}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(17, {}), MakeDataDescriptor(18, {}))); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(18, {}), MakeDataDescriptor(15, {}))); auto transformation4 = TransformationEquationInstruction( 19, SpvOpISub, {14, 15}, return_instruction); - ASSERT_TRUE(transformation4.IsApplicable(context.get(), fact_manager)); - transformation4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation4.IsApplicable(context.get(), transformation_context)); + transformation4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); auto transformation5 = TransformationEquationInstruction( 20, SpvOpSNegate, {19}, return_instruction); - ASSERT_TRUE(transformation5.IsApplicable(context.get(), fact_manager)); - transformation5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation5.IsApplicable(context.get(), transformation_context)); + transformation5.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(20, {}), MakeDataDescriptor(16, {}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(20, {}), MakeDataDescriptor(16, {}))); auto transformation6 = TransformationEquationInstruction( 21, SpvOpISub, {14, 19}, return_instruction); - ASSERT_TRUE(transformation6.IsApplicable(context.get(), fact_manager)); - transformation6.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation6.IsApplicable(context.get(), transformation_context)); + transformation6.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(21, {}), MakeDataDescriptor(15, {}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(21, {}), MakeDataDescriptor(15, {}))); auto transformation7 = TransformationEquationInstruction( 22, SpvOpISub, {14, 18}, return_instruction); - ASSERT_TRUE(transformation7.IsApplicable(context.get(), fact_manager)); - transformation7.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation7.IsApplicable(context.get(), transformation_context)); + transformation7.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); auto transformation8 = TransformationEquationInstruction( 23, SpvOpSNegate, {22}, return_instruction); - ASSERT_TRUE(transformation8.IsApplicable(context.get(), fact_manager)); - transformation8.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation8.IsApplicable(context.get(), transformation_context)); + transformation8.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.IsSynonymous( - MakeDataDescriptor(23, {}), MakeDataDescriptor(16, {}), context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(23, {}), MakeDataDescriptor(16, {}))); std::string after_transformation = R"( OpCapability Shader @@ -457,6 +486,146 @@ ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); } +TEST(TransformationEquationInstructionTest, Miscellaneous1) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %12 "main" + OpExecutionMode %12 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %113 = OpConstant %6 24 + %12 = OpFunction %2 None %3 + %13 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + protobufs::InstructionDescriptor return_instruction = + MakeInstructionDescriptor(13, SpvOpReturn, 0); + + auto transformation1 = TransformationEquationInstruction( + 522, SpvOpISub, {113, 113}, return_instruction); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); + ASSERT_TRUE(IsValid(env, context.get())); + + auto transformation2 = TransformationEquationInstruction( + 570, SpvOpIAdd, {522, 113}, return_instruction); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %12 "main" + OpExecutionMode %12 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %113 = OpConstant %6 24 + %12 = OpFunction %2 None %3 + %13 = OpLabel + %522 = OpISub %6 %113 %113 + %570 = OpIAdd %6 %522 %113 + OpReturn + OpFunctionEnd + )"; + + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(570, {}), MakeDataDescriptor(113, {}))); + + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationEquationInstructionTest, Miscellaneous2) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %12 "main" + OpExecutionMode %12 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %113 = OpConstant %6 24 + %12 = OpFunction %2 None %3 + %13 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + protobufs::InstructionDescriptor return_instruction = + MakeInstructionDescriptor(13, SpvOpReturn, 0); + + auto transformation1 = TransformationEquationInstruction( + 522, SpvOpISub, {113, 113}, return_instruction); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); + ASSERT_TRUE(IsValid(env, context.get())); + + auto transformation2 = TransformationEquationInstruction( + 570, SpvOpIAdd, {522, 113}, return_instruction); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %12 "main" + OpExecutionMode %12 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %113 = OpConstant %6 24 + %12 = OpFunction %2 None %3 + %13 = OpLabel + %522 = OpISub %6 %113 %113 + %570 = OpIAdd %6 %522 %113 + OpReturn + OpFunctionEnd + )"; + + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(570, {}), MakeDataDescriptor(113, {}))); + + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + } // namespace } // namespace fuzz } // namespace spvtools
diff --git a/test/fuzz/transformation_function_call_test.cpp b/test/fuzz/transformation_function_call_test.cpp index 9bd971e..d7305f8 100644 --- a/test/fuzz/transformation_function_call_test.cpp +++ b/test/fuzz/transformation_function_call_test.cpp
@@ -134,24 +134,36 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - fact_manager.AddFactBlockIsDead(59); - fact_manager.AddFactBlockIsDead(11); - fact_manager.AddFactBlockIsDead(18); - fact_manager.AddFactBlockIsDead(25); - fact_manager.AddFactBlockIsDead(96); - fact_manager.AddFactBlockIsDead(205); - fact_manager.AddFactFunctionIsLivesafe(21); - fact_manager.AddFactFunctionIsLivesafe(200); - fact_manager.AddFactValueOfPointeeIsIrrelevant(71); - fact_manager.AddFactValueOfPointeeIsIrrelevant(72); - fact_manager.AddFactValueOfPointeeIsIrrelevant(19); - fact_manager.AddFactValueOfPointeeIsIrrelevant(20); - fact_manager.AddFactValueOfPointeeIsIrrelevant(23); - fact_manager.AddFactValueOfPointeeIsIrrelevant(44); - fact_manager.AddFactValueOfPointeeIsIrrelevant(46); - fact_manager.AddFactValueOfPointeeIsIrrelevant(51); - fact_manager.AddFactValueOfPointeeIsIrrelevant(52); + transformation_context.GetFactManager()->AddFactBlockIsDead(59); + transformation_context.GetFactManager()->AddFactBlockIsDead(11); + transformation_context.GetFactManager()->AddFactBlockIsDead(18); + transformation_context.GetFactManager()->AddFactBlockIsDead(25); + transformation_context.GetFactManager()->AddFactBlockIsDead(96); + transformation_context.GetFactManager()->AddFactBlockIsDead(205); + transformation_context.GetFactManager()->AddFactFunctionIsLivesafe(21); + transformation_context.GetFactManager()->AddFactFunctionIsLivesafe(200); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 71); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 72); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 19); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 20); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 23); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 44); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 46); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 51); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 52); // Livesafe functions with argument types: 21(7, 13), 200(7, 13) // Non-livesafe functions with argument types: 4(), 10(7), 17(7, 13), 24(7) @@ -164,127 +176,133 @@ ASSERT_FALSE( TransformationFunctionCall(100, 21, {71, 72, 71}, MakeInstructionDescriptor(59, SpvOpBranch, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Too few arguments ASSERT_FALSE(TransformationFunctionCall( 100, 21, {71}, MakeInstructionDescriptor(59, SpvOpBranch, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Arguments are the wrong way around (types do not match) ASSERT_FALSE( TransformationFunctionCall(100, 21, {72, 71}, MakeInstructionDescriptor(59, SpvOpBranch, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // 21 is not an appropriate argument ASSERT_FALSE( TransformationFunctionCall(100, 21, {21, 72}, MakeInstructionDescriptor(59, SpvOpBranch, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // 300 does not exist ASSERT_FALSE( TransformationFunctionCall(100, 21, {300, 72}, MakeInstructionDescriptor(59, SpvOpBranch, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // 71 is not a function ASSERT_FALSE( TransformationFunctionCall(100, 71, {71, 72}, MakeInstructionDescriptor(59, SpvOpBranch, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // 500 does not exist ASSERT_FALSE( TransformationFunctionCall(100, 500, {71, 72}, MakeInstructionDescriptor(59, SpvOpBranch, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Id is not fresh ASSERT_FALSE( TransformationFunctionCall(21, 21, {71, 72}, MakeInstructionDescriptor(59, SpvOpBranch, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Access chain as pointer parameter ASSERT_FALSE( TransformationFunctionCall(100, 21, {98, 72}, MakeInstructionDescriptor(59, SpvOpBranch, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Copied object as pointer parameter ASSERT_FALSE( TransformationFunctionCall(100, 21, {99, 72}, MakeInstructionDescriptor(59, SpvOpBranch, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Non-livesafe called from original live block ASSERT_FALSE( TransformationFunctionCall( 100, 10, {71}, MakeInstructionDescriptor(99, SpvOpSelectionMerge, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Non-livesafe called from livesafe function ASSERT_FALSE( TransformationFunctionCall( 100, 10, {19}, MakeInstructionDescriptor(38, SpvOpConvertFToS, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Livesafe function called with pointer to non-arbitrary local variable ASSERT_FALSE( TransformationFunctionCall( 100, 21, {61, 72}, MakeInstructionDescriptor(38, SpvOpConvertFToS, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Direct recursion ASSERT_FALSE(TransformationFunctionCall( 100, 4, {}, MakeInstructionDescriptor(59, SpvOpBranch, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Indirect recursion ASSERT_FALSE(TransformationFunctionCall( 100, 24, {9}, MakeInstructionDescriptor(96, SpvOpBranch, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Parameter 23 is not available at the call site ASSERT_FALSE( TransformationFunctionCall(104, 10, {23}, MakeInstructionDescriptor(205, SpvOpBranch, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Good transformations { // Livesafe called from dead block: fine TransformationFunctionCall transformation( 100, 21, {71, 72}, MakeInstructionDescriptor(59, SpvOpBranch, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } { // Livesafe called from original live block: fine TransformationFunctionCall transformation( 101, 21, {71, 72}, MakeInstructionDescriptor(98, SpvOpAccessChain, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } { // Livesafe called from livesafe function: fine TransformationFunctionCall transformation( 102, 200, {19, 20}, MakeInstructionDescriptor(36, SpvOpLoad, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } { // Dead called from dead block in injected function: fine TransformationFunctionCall transformation( 103, 10, {23}, MakeInstructionDescriptor(45, SpvOpLoad, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } { // Non-livesafe called from dead block in livesafe function: OK TransformationFunctionCall transformation( 104, 10, {201}, MakeInstructionDescriptor(205, SpvOpBranch, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } { // Livesafe called from dead block with non-arbitrary parameter TransformationFunctionCall transformation( 105, 21, {62, 65}, MakeInstructionDescriptor(59, SpvOpBranch, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } @@ -429,13 +447,16 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - fact_manager.AddFactBlockIsDead(11); + transformation_context.GetFactManager()->AddFactBlockIsDead(11); // 4 is an entry point, so it is not legal for it to be the target of a call. ASSERT_FALSE(TransformationFunctionCall( 100, 4, {}, MakeInstructionDescriptor(11, SpvOpReturn, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } } // namespace
diff --git a/test/fuzz/transformation_load_test.cpp b/test/fuzz/transformation_load_test.cpp index 1f728ff..18ca195 100644 --- a/test/fuzz/transformation_load_test.cpp +++ b/test/fuzz/transformation_load_test.cpp
@@ -85,14 +85,22 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - fact_manager.AddFactValueOfPointeeIsIrrelevant(27); - fact_manager.AddFactValueOfPointeeIsIrrelevant(11); - fact_manager.AddFactValueOfPointeeIsIrrelevant(46); - fact_manager.AddFactValueOfPointeeIsIrrelevant(16); - fact_manager.AddFactValueOfPointeeIsIrrelevant(52); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 27); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 11); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 46); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 16); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 52); - fact_manager.AddFactBlockIsDead(36); + transformation_context.GetFactManager()->AddFactBlockIsDead(36); // Variables with pointee types: // 52 - ptr_to(7) @@ -125,86 +133,90 @@ // Bad: id is not fresh ASSERT_FALSE(TransformationLoad( 33, 33, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: attempt to load from 11 from outside its function ASSERT_FALSE(TransformationLoad( 100, 11, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer is not available ASSERT_FALSE(TransformationLoad( 100, 33, MakeInstructionDescriptor(45, SpvOpCopyObject, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: attempt to insert before OpVariable ASSERT_FALSE(TransformationLoad( 100, 27, MakeInstructionDescriptor(27, SpvOpVariable, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer id does not exist ASSERT_FALSE( TransformationLoad(100, 1000, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer id exists but does not have a type ASSERT_FALSE(TransformationLoad( 100, 5, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer id exists and has a type, but is not a pointer ASSERT_FALSE(TransformationLoad( 100, 24, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: attempt to load from null pointer ASSERT_FALSE(TransformationLoad( 100, 60, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: attempt to load from undefined pointer ASSERT_FALSE(TransformationLoad( 100, 61, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: %40 is not available at the program point ASSERT_FALSE( TransformationLoad(100, 40, MakeInstructionDescriptor(37, SpvOpReturn, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: The described instruction does not exist ASSERT_FALSE(TransformationLoad( 100, 33, MakeInstructionDescriptor(1000, SpvOpReturn, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); { TransformationLoad transformation( 100, 33, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } { TransformationLoad transformation( 101, 46, MakeInstructionDescriptor(16, SpvOpReturnValue, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } { TransformationLoad transformation( 102, 16, MakeInstructionDescriptor(16, SpvOpReturnValue, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } { TransformationLoad transformation( 103, 40, MakeInstructionDescriptor(43, SpvOpAccessChain, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); }
diff --git a/test/fuzz/transformation_merge_blocks_test.cpp b/test/fuzz/transformation_merge_blocks_test.cpp index e2b4aa6..4500445 100644 --- a/test/fuzz/transformation_merge_blocks_test.cpp +++ b/test/fuzz/transformation_merge_blocks_test.cpp
@@ -45,11 +45,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - ASSERT_FALSE( - TransformationMergeBlocks(3).IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE( - TransformationMergeBlocks(7).IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(TransformationMergeBlocks(3).IsApplicable( + context.get(), transformation_context)); + ASSERT_FALSE(TransformationMergeBlocks(7).IsApplicable( + context.get(), transformation_context)); } TEST(TransformationMergeBlocksTest, DoNotMergeFirstBlockHasMultipleSuccessors) { @@ -84,9 +87,12 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - ASSERT_FALSE( - TransformationMergeBlocks(6).IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(TransformationMergeBlocks(6).IsApplicable( + context.get(), transformation_context)); } TEST(TransformationMergeBlocksTest, @@ -122,9 +128,12 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - ASSERT_FALSE( - TransformationMergeBlocks(10).IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(TransformationMergeBlocks(10).IsApplicable( + context.get(), transformation_context)); } TEST(TransformationMergeBlocksTest, MergeWhenSecondBlockIsSelectionMerge) { @@ -161,10 +170,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationMergeBlocks transformation(10); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -231,10 +244,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationMergeBlocks transformation(10); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -306,10 +323,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationMergeBlocks transformation(11); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -377,10 +398,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationMergeBlocks transformation(6); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -454,12 +479,16 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); for (auto& transformation : {TransformationMergeBlocks(100), TransformationMergeBlocks(101), TransformationMergeBlocks(102), TransformationMergeBlocks(103)}) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } @@ -542,11 +571,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); for (auto& transformation : {TransformationMergeBlocks(101), TransformationMergeBlocks(100)}) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } @@ -629,10 +662,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationMergeBlocks transformation(101); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"(
diff --git a/test/fuzz/transformation_move_block_down_test.cpp b/test/fuzz/transformation_move_block_down_test.cpp index 02761a2..662e88c 100644 --- a/test/fuzz/transformation_move_block_down_test.cpp +++ b/test/fuzz/transformation_move_block_down_test.cpp
@@ -53,9 +53,13 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto transformation = TransformationMoveBlockDown(11); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationMoveBlockDownTest, NoMovePossible2) { @@ -90,9 +94,13 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto transformation = TransformationMoveBlockDown(5); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationMoveBlockDownTest, NoMovePossible3) { @@ -129,9 +137,13 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto transformation = TransformationMoveBlockDown(100); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationMoveBlockDownTest, NoMovePossible4) { @@ -172,9 +184,13 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto transformation = TransformationMoveBlockDown(12); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationMoveBlockDownTest, ManyMovesPossible) { @@ -277,6 +293,9 @@ BuildModule(env, consumer, before_transformation, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // The block ids are: 5 14 20 23 21 25 29 32 30 15 // We make a transformation to move each of them down, plus a transformation @@ -306,110 +325,130 @@ // 15 dominates nothing // Current ordering: 5 14 20 23 21 25 29 32 30 15 - ASSERT_FALSE(move_down_5.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_14.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_20.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_23.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_21.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_25.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_29.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_32.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_30.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_15.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(move_down_5.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_14.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_20.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_23.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_21.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_25.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_29.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_32.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_30.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_15.IsApplicable(context.get(), transformation_context)); // Let's bubble 20 all the way down. - move_down_20.Apply(context.get(), &fact_manager); + move_down_20.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Current ordering: 5 14 23 20 21 25 29 32 30 15 - ASSERT_FALSE(move_down_5.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_14.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_23.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_20.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_21.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_25.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_29.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_32.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_30.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_15.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(move_down_5.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_14.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_23.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_20.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_21.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_25.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_29.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_32.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_30.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_15.IsApplicable(context.get(), transformation_context)); - move_down_20.Apply(context.get(), &fact_manager); + move_down_20.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Current ordering: 5 14 23 21 20 25 29 32 30 15 - ASSERT_FALSE(move_down_5.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_14.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_23.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_21.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_20.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_25.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_29.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_32.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_30.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_15.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(move_down_5.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_14.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_23.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_21.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_20.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_25.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_29.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_32.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_30.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_15.IsApplicable(context.get(), transformation_context)); - move_down_20.Apply(context.get(), &fact_manager); + move_down_20.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Current ordering: 5 14 23 21 25 20 29 32 30 15 - ASSERT_FALSE(move_down_5.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_14.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_23.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_21.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_25.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_20.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_29.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_32.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_30.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_15.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(move_down_5.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_14.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_23.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_21.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_25.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_20.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_29.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_32.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_30.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_15.IsApplicable(context.get(), transformation_context)); - move_down_20.Apply(context.get(), &fact_manager); + move_down_20.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Current ordering: 5 14 23 21 25 29 20 32 30 15 - ASSERT_FALSE(move_down_5.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_14.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_23.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_21.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_25.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_29.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_20.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_32.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_30.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_15.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(move_down_5.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_14.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_23.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_21.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_25.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_29.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_20.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_32.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_30.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_15.IsApplicable(context.get(), transformation_context)); - move_down_20.Apply(context.get(), &fact_manager); + move_down_20.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Current ordering: 5 14 23 21 25 29 32 20 30 15 - ASSERT_FALSE(move_down_5.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_14.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_23.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_21.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_25.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_29.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_32.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_20.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_30.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_15.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(move_down_5.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_14.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_23.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_21.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_25.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_29.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_32.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_20.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_30.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_15.IsApplicable(context.get(), transformation_context)); - move_down_20.Apply(context.get(), &fact_manager); + move_down_20.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Current ordering: 5 14 23 21 25 29 32 30 20 15 - ASSERT_FALSE(move_down_5.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_14.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_23.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_21.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_25.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_29.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_32.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_30.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_20.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_15.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(move_down_5.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_14.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_23.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_21.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_25.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_29.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_32.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_30.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_20.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_15.IsApplicable(context.get(), transformation_context)); - move_down_20.Apply(context.get(), &fact_manager); + move_down_20.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_bubbling_20_down = R"( @@ -485,63 +524,72 @@ ASSERT_TRUE(IsEqual(env, after_bubbling_20_down, context.get())); // Current ordering: 5 14 23 21 25 29 32 30 15 20 - ASSERT_FALSE(move_down_5.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_14.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_23.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_21.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_25.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_29.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_32.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_30.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_15.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_20.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(move_down_5.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_14.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_23.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_21.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_25.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_29.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_32.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_30.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_15.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_20.IsApplicable(context.get(), transformation_context)); - move_down_23.Apply(context.get(), &fact_manager); + move_down_23.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Current ordering: 5 14 21 23 25 29 32 30 15 20 - ASSERT_FALSE(move_down_5.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_14.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_21.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_23.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_25.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_29.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_32.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_30.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_15.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_20.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(move_down_5.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_14.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_21.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_23.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_25.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_29.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_32.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_30.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_15.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_20.IsApplicable(context.get(), transformation_context)); - move_down_23.Apply(context.get(), &fact_manager); + move_down_23.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Current ordering: 5 14 21 25 23 29 32 30 15 20 - ASSERT_FALSE(move_down_5.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_14.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_21.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_25.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_23.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_29.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_32.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_30.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_15.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_20.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(move_down_5.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_14.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_21.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_25.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_23.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_29.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_32.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_30.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_15.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_20.IsApplicable(context.get(), transformation_context)); - move_down_21.Apply(context.get(), &fact_manager); + move_down_21.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Current ordering: 5 14 25 21 23 29 32 30 15 20 - ASSERT_FALSE(move_down_5.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_14.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_21.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_25.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_23.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_29.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_32.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_30.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_15.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_20.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(move_down_5.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_14.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_21.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_25.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_23.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_29.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_32.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_30.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_15.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_20.IsApplicable(context.get(), transformation_context)); - move_down_14.Apply(context.get(), &fact_manager); + move_down_14.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_more_shuffling = R"( @@ -617,16 +665,18 @@ ASSERT_TRUE(IsEqual(env, after_more_shuffling, context.get())); // Final ordering: 5 25 14 21 23 29 32 30 15 20 - ASSERT_FALSE(move_down_5.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_25.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_14.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_21.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_23.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_29.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_32.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_30.IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(move_down_15.IsApplicable(context.get(), fact_manager)); - ASSERT_FALSE(move_down_20.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(move_down_5.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_25.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_14.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_21.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_23.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_29.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_32.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_30.IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE(move_down_15.IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + move_down_20.IsApplicable(context.get(), transformation_context)); } TEST(TransformationMoveBlockDownTest, DoNotMoveUnreachable) { @@ -660,9 +710,13 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto transformation = TransformationMoveBlockDown(6); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } } // namespace
diff --git a/test/fuzz/transformation_outline_function_test.cpp b/test/fuzz/transformation_outline_function_test.cpp index 40aaebc..ed4fd15 100644 --- a/test/fuzz/transformation_outline_function_test.cpp +++ b/test/fuzz/transformation_outline_function_test.cpp
@@ -44,12 +44,16 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(5, 5, /* not relevant */ 200, 100, 101, 102, 103, /* not relevant */ 201, {}, {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -105,11 +109,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(5, 5, /* not relevant */ 200, 100, 101, 102, 103, /* not relevant */ 201, {}, {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, OutlineInterestingControlFlowNoState) { @@ -158,12 +166,16 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(6, 13, /* not relevant */ 200, 100, 101, 102, 103, /* not relevant */ 201, {}, {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -243,12 +255,16 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(6, 6, /* not relevant */ 200, 100, 101, 102, 103, /* not relevant */ 201, {}, {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -317,11 +333,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(6, 6, 99, 100, 101, 102, 103, 105, {}, {{9, 104}}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -412,12 +432,16 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( 6, 80, 100, 101, 102, 103, 104, 105, {}, {{15, 106}, {9, 107}, {7, 108}, {8, 109}}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -508,11 +532,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(6, 6, 100, 101, 102, 103, 104, 105, {{7, 106}}, {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -582,11 +610,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(6, 6, 100, 101, 102, 103, 104, 105, {{13, 106}}, {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -666,11 +698,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(11, 11, 100, 101, 102, 103, 104, 105, {{9, 106}}, {{14, 107}}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -752,10 +788,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(6, 8, 100, 101, 102, 103, 104, 105, {}, {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, DoNotOutlineIfRegionInvolvesReturn) { @@ -798,11 +838,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(6, 11, /* not relevant */ 200, 100, 101, 102, 103, /* not relevant */ 201, {}, {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, DoNotOutlineIfRegionInvolvesKill) { @@ -845,11 +889,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(6, 11, /* not relevant */ 200, 100, 101, 102, 103, /* not relevant */ 201, {}, {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, @@ -893,11 +941,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(6, 11, /* not relevant */ 200, 100, 101, 102, 103, /* not relevant */ 201, {}, {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, @@ -933,10 +985,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(6, 8, 100, 101, 102, 103, 104, 105, {}, {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, DoNotOutlineIfLoopHeadIsOutsideRegion) { @@ -973,10 +1029,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(7, 8, 100, 101, 102, 103, 104, 105, {}, {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, @@ -1012,10 +1072,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(6, 7, 100, 101, 102, 103, 104, 105, {}, {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, @@ -1053,10 +1117,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(6, 7, 100, 101, 102, 103, 104, 105, {}, {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, @@ -1094,10 +1162,14 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation(8, 11, 100, 101, 102, 103, 104, 105, {}, {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, OutlineRegionEndingWithReturnVoid) { @@ -1132,6 +1204,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 54, @@ -1145,8 +1220,9 @@ /*input_id_to_fresh_id*/ {{22, 206}}, /*output_id_to_fresh_id*/ {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -1219,6 +1295,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 9, @@ -1232,8 +1311,9 @@ /*input_id_to_fresh_id*/ {{31, 206}}, /*output_id_to_fresh_id*/ {{32, 207}}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -1310,6 +1390,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 54, @@ -1323,8 +1406,9 @@ /*input_id_to_fresh_id*/ {{}}, /*output_id_to_fresh_id*/ {{6, 206}}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -1396,6 +1480,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 54, @@ -1409,8 +1496,9 @@ /*input_id_to_fresh_id*/ {}, /*output_id_to_fresh_id*/ {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -1478,6 +1566,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 21, @@ -1491,7 +1582,8 @@ /*input_id_to_fresh_id*/ {{22, 207}}, /*output_id_to_fresh_id*/ {{23, 208}}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, @@ -1531,6 +1623,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 21, @@ -1544,7 +1639,8 @@ /*input_id_to_fresh_id*/ {}, /*output_id_to_fresh_id*/ {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, @@ -1584,6 +1680,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 5, @@ -1597,7 +1696,8 @@ /*input_id_to_fresh_id*/ {}, /*output_id_to_fresh_id*/ {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, DoNotOutlineRegionThatUsesAccessChain) { @@ -1640,6 +1740,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 13, @@ -1653,7 +1756,8 @@ /*input_id_to_fresh_id*/ {{12, 207}}, /*output_id_to_fresh_id*/ {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, @@ -1698,6 +1802,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 13, @@ -1711,7 +1818,8 @@ /*input_id_to_fresh_id*/ {{20, 207}}, /*output_id_to_fresh_id*/ {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, @@ -1761,6 +1869,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 11, @@ -1774,8 +1885,9 @@ /*input_id_to_fresh_id*/ {{9, 207}}, /*output_id_to_fresh_id*/ {{14, 208}}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -1913,9 +2025,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; - fact_manager.AddFactFunctionIsLivesafe(30); - fact_manager.AddFactValueOfPointeeIsIrrelevant(200); - fact_manager.AddFactValueOfPointeeIsIrrelevant(201); + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + transformation_context.GetFactManager()->AddFactFunctionIsLivesafe(30); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 200); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 201); TransformationOutlineFunction transformation( /*entry_block*/ 198, @@ -1929,24 +2047,31 @@ /*input_id_to_fresh_id*/ {{100, 407}, {200, 408}, {201, 409}}, /*output_id_to_fresh_id*/ {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // The original function should still be livesafe. - ASSERT_TRUE(fact_manager.FunctionIsLivesafe(30)); + ASSERT_TRUE(transformation_context.GetFactManager()->FunctionIsLivesafe(30)); // The outlined function should be livesafe. - ASSERT_TRUE(fact_manager.FunctionIsLivesafe(402)); + ASSERT_TRUE(transformation_context.GetFactManager()->FunctionIsLivesafe(402)); // The variable and parameter that were originally irrelevant should still be. - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(200)); - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(201)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(200)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(201)); // The loop limiter should still be non-irrelevant. - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(100)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(100)); // The parameters for the original irrelevant variables should be irrelevant. - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(408)); - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(409)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(408)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(409)); // The parameter for the loop limiter should not be irrelevant. - ASSERT_FALSE(fact_manager.PointeeValueIsIrrelevant(407)); + ASSERT_FALSE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant(407)); std::string after_transformation = R"( OpCapability Shader @@ -2129,8 +2254,12 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + for (uint32_t block_id : {16u, 23u, 24u, 26u, 27u, 34u, 35u, 50u}) { - fact_manager.AddFactBlockIsDead(block_id); + transformation_context.GetFactManager()->AddFactBlockIsDead(block_id); } TransformationOutlineFunction transformation( @@ -2145,12 +2274,13 @@ /*input_id_to_fresh_id*/ {{9, 206}, {12, 207}, {21, 208}}, /*output_id_to_fresh_id*/ {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // All the original blocks, plus the new function entry block, should be dead. for (uint32_t block_id : {16u, 23u, 24u, 26u, 27u, 34u, 35u, 50u, 203u}) { - ASSERT_TRUE(fact_manager.BlockIsDead(block_id)); + ASSERT_TRUE(transformation_context.GetFactManager()->BlockIsDead(block_id)); } } @@ -2208,8 +2338,12 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + for (uint32_t block_id : {32u, 34u, 35u}) { - fact_manager.AddFactBlockIsDead(block_id); + transformation_context.GetFactManager()->AddFactBlockIsDead(block_id); } TransformationOutlineFunction transformation( @@ -2224,15 +2358,17 @@ /*input_id_to_fresh_id*/ {{11, 206}}, /*output_id_to_fresh_id*/ {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // The blocks that were originally dead, but not others, should be dead. for (uint32_t block_id : {32u, 34u, 35u}) { - ASSERT_TRUE(fact_manager.BlockIsDead(block_id)); + ASSERT_TRUE(transformation_context.GetFactManager()->BlockIsDead(block_id)); } for (uint32_t block_id : {5u, 30u, 31u, 33u, 36u, 37u, 203u}) { - ASSERT_FALSE(fact_manager.BlockIsDead(block_id)); + ASSERT_FALSE( + transformation_context.GetFactManager()->BlockIsDead(block_id)); } } @@ -2287,8 +2423,13 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; - fact_manager.AddFactValueOfPointeeIsIrrelevant(9); - fact_manager.AddFactValueOfPointeeIsIrrelevant(14); + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant(9); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 14); TransformationOutlineFunction transformation( /*entry_block*/ 50, @@ -2302,19 +2443,141 @@ /*input_id_to_fresh_id*/ {{9, 206}, {10, 207}, {14, 208}, {20, 209}}, /*output_id_to_fresh_id*/ {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // The variables that were originally irrelevant, plus input parameters // corresponding to them, should be irrelevant. The rest should not be. for (uint32_t variable_id : {9u, 14u, 206u, 208u}) { - ASSERT_TRUE(fact_manager.PointeeValueIsIrrelevant(variable_id)); + ASSERT_TRUE( + transformation_context.GetFactManager()->PointeeValueIsIrrelevant( + variable_id)); } for (uint32_t variable_id : {10u, 20u, 207u, 209u}) { - ASSERT_FALSE(fact_manager.BlockIsDead(variable_id)); + ASSERT_FALSE( + transformation_context.GetFactManager()->BlockIsDead(variable_id)); } } +TEST(TransformationOutlineFunctionTest, + DoNotOutlineCodeThatProducesUsedPointer) { + // This checks that we cannot outline a region of code if it produces a + // pointer result id that gets used outside the region. This avoids creating + // a struct with a pointer member. + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %6 "main" + OpExecutionMode %6 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %21 = OpTypeBool + %100 = OpTypeInt 32 0 + %99 = OpConstant %100 0 + %101 = OpTypeVector %100 2 + %102 = OpTypePointer Function %100 + %103 = OpTypePointer Function %101 + %6 = OpFunction %2 None %3 + %7 = OpLabel + %104 = OpVariable %103 Function + OpBranch %80 + %80 = OpLabel + %105 = OpAccessChain %102 %104 %99 + OpBranch %106 + %106 = OpLabel + OpStore %105 %99 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + TransformationOutlineFunction transformation( + /*entry_block*/ 80, + /*exit_block*/ 80, + /*new_function_struct_return_type_id*/ 300, + /*new_function_type_id*/ 301, + /*new_function_id*/ 302, + /*new_function_region_entry_block*/ 304, + /*new_caller_result_id*/ 305, + /*new_callee_result_id*/ 306, + /*input_id_to_fresh_id*/ {{104, 307}}, + /*output_id_to_fresh_id*/ {{105, 308}}); + + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); +} + +TEST(TransformationOutlineFunctionTest, ExitBlockHeadsLoop) { + // This checks that it is not possible outline a region that ends in a loop + // head. + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %15 = OpTypeInt 32 1 + %35 = OpTypeBool + %39 = OpConstant %15 1 + %40 = OpConstantTrue %35 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %22 + %22 = OpLabel + OpBranch %23 + %23 = OpLabel + %24 = OpPhi %15 %39 %22 %39 %25 + OpLoopMerge %26 %25 None + OpBranchConditional %40 %25 %26 + %25 = OpLabel + OpBranch %23 + %26 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + TransformationOutlineFunction transformation( + /*entry_block*/ 22, + /*exit_block*/ 23, + /*new_function_struct_return_type_id*/ 200, + /*new_function_type_id*/ 201, + /*new_function_id*/ 202, + /*new_function_region_entry_block*/ 203, + /*new_caller_result_id*/ 204, + /*new_callee_result_id*/ 205, + /*input_id_to_fresh_id*/ {}, + /*output_id_to_fresh_id*/ {}); + + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); +} + TEST(TransformationOutlineFunctionTest, Miscellaneous1) { // This tests outlining of some non-trivial code. @@ -2423,6 +2686,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 150, @@ -2436,8 +2702,9 @@ /*input_id_to_fresh_id*/ {{102, 300}, {103, 301}, {40, 302}}, /*output_id_to_fresh_id*/ {{106, 400}, {107, 401}}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -2588,6 +2855,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 38, @@ -2601,7 +2871,8 @@ /*input_id_to_fresh_id*/ {}, /*output_id_to_fresh_id*/ {}); - ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationOutlineFunctionTest, Miscellaneous3) { @@ -2643,6 +2914,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 80, @@ -2656,8 +2930,9 @@ /*input_id_to_fresh_id*/ {}, /*output_id_to_fresh_id*/ {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"( @@ -2732,6 +3007,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationOutlineFunction transformation( /*entry_block*/ 80, @@ -2745,8 +3023,9 @@ /*input_id_to_fresh_id*/ {{104, 307}}, /*output_id_to_fresh_id*/ {}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_transformation = R"(
diff --git a/test/fuzz/transformation_permute_function_parameters_test.cpp b/test/fuzz/transformation_permute_function_parameters_test.cpp index 1af4699..a4a7c00 100644 --- a/test/fuzz/transformation_permute_function_parameters_test.cpp +++ b/test/fuzz/transformation_permute_function_parameters_test.cpp
@@ -200,52 +200,57 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Can't permute main function ASSERT_FALSE(TransformationPermuteFunctionParameters(4, 0, {}).IsApplicable( - context.get(), fact_manager)); + context.get(), transformation_context)); // Can't permute invalid instruction ASSERT_FALSE(TransformationPermuteFunctionParameters(101, 0, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Permutation has too many values ASSERT_FALSE(TransformationPermuteFunctionParameters(22, 0, {2, 1, 0, 3}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Permutation has too few values ASSERT_FALSE(TransformationPermuteFunctionParameters(22, 0, {0, 1}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Permutation has invalid values ASSERT_FALSE(TransformationPermuteFunctionParameters(22, 0, {3, 1, 0}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Type id is not an OpTypeFunction instruction ASSERT_FALSE(TransformationPermuteFunctionParameters(22, 42, {2, 1, 0}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Type id has incorrect number of operands ASSERT_FALSE(TransformationPermuteFunctionParameters(22, 9, {2, 1, 0}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // OpTypeFunction has operands out of order ASSERT_FALSE(TransformationPermuteFunctionParameters(22, 18, {2, 1, 0}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Successful transformations { // Function has two operands of the same type: // initial OpTypeFunction should be enough TransformationPermuteFunctionParameters transformation(12, 9, {1, 0}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } { TransformationPermuteFunctionParameters transformation(28, 105, {1, 0}); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); }
diff --git a/test/fuzz/transformation_replace_boolean_constant_with_constant_binary_test.cpp b/test/fuzz/transformation_replace_boolean_constant_with_constant_binary_test.cpp index 527a7b7..b320308 100644 --- a/test/fuzz/transformation_replace_boolean_constant_with_constant_binary_test.cpp +++ b/test/fuzz/transformation_replace_boolean_constant_with_constant_binary_test.cpp
@@ -163,6 +163,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); std::vector<protobufs::IdUseDescriptor> uses_of_true = { MakeIdUseDescriptor(41, MakeInstructionDescriptor(44, SpvOpStore, 12), 1), @@ -197,10 +200,10 @@ #define CHECK_OPERATOR(USE_DESCRIPTOR, LHS_ID, RHS_ID, OPCODE, FRESH_ID) \ ASSERT_TRUE(TransformationReplaceBooleanConstantWithConstantBinary( \ USE_DESCRIPTOR, LHS_ID, RHS_ID, OPCODE, FRESH_ID) \ - .IsApplicable(context.get(), fact_manager)); \ + .IsApplicable(context.get(), transformation_context)); \ ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary( \ USE_DESCRIPTOR, RHS_ID, LHS_ID, OPCODE, FRESH_ID) \ - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); #define CHECK_TRANSFORMATION_APPLICABILITY(GT_OPCODES, LT_OPCODES, SMALL_ID, \ LARGE_ID) \ @@ -252,27 +255,27 @@ // Target id is not fresh ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary( uses_of_true[0], 15, 17, SpvOpFOrdLessThan, 15) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // LHS id does not exist ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary( uses_of_true[0], 300, 17, SpvOpFOrdLessThan, 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // RHS id does not exist ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary( uses_of_true[0], 15, 300, SpvOpFOrdLessThan, 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // LHS and RHS ids do not match type ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary( uses_of_true[0], 11, 17, SpvOpFOrdLessThan, 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Opcode not appropriate ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary( uses_of_true[0], 15, 17, SpvOpFDiv, 200) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); auto replace_true_with_double_comparison = TransformationReplaceBooleanConstantWithConstantBinary( @@ -287,21 +290,25 @@ TransformationReplaceBooleanConstantWithConstantBinary( uses_of_false[1], 33, 31, SpvOpSLessThan, 103); - ASSERT_TRUE(replace_true_with_double_comparison.IsApplicable(context.get(), - fact_manager)); - replace_true_with_double_comparison.Apply(context.get(), &fact_manager); + ASSERT_TRUE(replace_true_with_double_comparison.IsApplicable( + context.get(), transformation_context)); + replace_true_with_double_comparison.Apply(context.get(), + &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(replace_true_with_uint32_comparison.IsApplicable(context.get(), - fact_manager)); - replace_true_with_uint32_comparison.Apply(context.get(), &fact_manager); + ASSERT_TRUE(replace_true_with_uint32_comparison.IsApplicable( + context.get(), transformation_context)); + replace_true_with_uint32_comparison.Apply(context.get(), + &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(replace_false_with_float_comparison.IsApplicable(context.get(), - fact_manager)); - replace_false_with_float_comparison.Apply(context.get(), &fact_manager); + ASSERT_TRUE(replace_false_with_float_comparison.IsApplicable( + context.get(), transformation_context)); + replace_false_with_float_comparison.Apply(context.get(), + &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(replace_false_with_sint64_comparison.IsApplicable(context.get(), - fact_manager)); - replace_false_with_sint64_comparison.Apply(context.get(), &fact_manager); + ASSERT_TRUE(replace_false_with_sint64_comparison.IsApplicable( + context.get(), transformation_context)); + replace_false_with_sint64_comparison.Apply(context.get(), + &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after = R"( @@ -419,7 +426,7 @@ // The transformation is not applicable because %200 is NaN. ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary( uses_of_true[0], 11, 200, SpvOpFOrdLessThan, 300) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } if (std::numeric_limits<double>::has_infinity) { double positive_infinity_double = std::numeric_limits<double>::infinity(); @@ -436,7 +443,7 @@ // transformation is restricted to only apply to finite values. ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary( uses_of_true[0], 11, 201, SpvOpFOrdLessThan, 300) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } if (std::numeric_limits<float>::has_infinity) { float positive_infinity_float = std::numeric_limits<float>::infinity(); @@ -461,7 +468,7 @@ // values. ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary( uses_of_true[0], 203, 202, SpvOpFOrdLessThan, 300) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } } @@ -531,6 +538,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto use_of_true_in_if = MakeIdUseDescriptor( 13, MakeInstructionDescriptor(10, SpvOpBranchConditional, 0), 0); @@ -542,12 +552,14 @@ auto replacement_2 = TransformationReplaceBooleanConstantWithConstantBinary( use_of_false_in_while, 9, 11, SpvOpSGreaterThanEqual, 101); - ASSERT_TRUE(replacement_1.IsApplicable(context.get(), fact_manager)); - replacement_1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_1.IsApplicable(context.get(), transformation_context)); + replacement_1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(replacement_2.IsApplicable(context.get(), fact_manager)); - replacement_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement_2.IsApplicable(context.get(), transformation_context)); + replacement_2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after = R"( @@ -642,12 +654,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto replacement = TransformationReplaceBooleanConstantWithConstantBinary( MakeIdUseDescriptor(9, MakeInstructionDescriptor(23, SpvOpPhi, 0), 0), 13, 15, SpvOpSLessThan, 100); - ASSERT_FALSE(replacement.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(replacement.IsApplicable(context.get(), transformation_context)); } TEST(TransformationReplaceBooleanConstantWithConstantBinaryTest, @@ -681,12 +696,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary( MakeIdUseDescriptor( 9, MakeInstructionDescriptor(50, SpvOpVariable, 0), 1), 13, 15, SpvOpSLessThan, 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } } // namespace
diff --git a/test/fuzz/transformation_replace_constant_with_uniform_test.cpp b/test/fuzz/transformation_replace_constant_with_uniform_test.cpp index 58d4a89..8cbba46 100644 --- a/test/fuzz/transformation_replace_constant_with_uniform_test.cpp +++ b/test/fuzz/transformation_replace_constant_with_uniform_test.cpp
@@ -22,7 +22,8 @@ namespace { bool AddFactHelper( - FactManager* fact_manager, opt::IRContext* context, uint32_t word, + TransformationContext* transformation_context, opt::IRContext* context, + uint32_t word, const protobufs::UniformBufferElementDescriptor& descriptor) { protobufs::FactConstantUniform constant_uniform_fact; constant_uniform_fact.add_constant_word(word); @@ -30,7 +31,7 @@ descriptor; protobufs::Fact fact; *fact.mutable_constant_uniform_fact() = constant_uniform_fact; - return fact_manager->AddFact(fact, context); + return transformation_context->GetFactManager()->AddFact(fact, context); } TEST(TransformationReplaceConstantWithUniformTest, BasicReplacements) { @@ -104,6 +105,10 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + protobufs::UniformBufferElementDescriptor blockname_a = MakeUniformBufferElementDescriptor(0, 0, {0}); protobufs::UniformBufferElementDescriptor blockname_b = @@ -111,9 +116,12 @@ protobufs::UniformBufferElementDescriptor blockname_c = MakeUniformBufferElementDescriptor(0, 0, {2}); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 1, blockname_a)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 2, blockname_b)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 3, blockname_c)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 1, blockname_a)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 2, blockname_b)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 3, blockname_c)); // The constant ids are 9, 11 and 14, for 1, 2 and 3 respectively. protobufs::IdUseDescriptor use_of_9_in_store = @@ -127,30 +135,30 @@ auto transformation_use_of_9_in_store = TransformationReplaceConstantWithUniform(use_of_9_in_store, blockname_a, 100, 101); - ASSERT_TRUE(transformation_use_of_9_in_store.IsApplicable(context.get(), - fact_manager)); + ASSERT_TRUE(transformation_use_of_9_in_store.IsApplicable( + context.get(), transformation_context)); auto transformation_use_of_11_in_add = TransformationReplaceConstantWithUniform(use_of_11_in_add, blockname_b, 102, 103); - ASSERT_TRUE(transformation_use_of_11_in_add.IsApplicable(context.get(), - fact_manager)); + ASSERT_TRUE(transformation_use_of_11_in_add.IsApplicable( + context.get(), transformation_context)); auto transformation_use_of_14_in_add = TransformationReplaceConstantWithUniform(use_of_14_in_add, blockname_c, 104, 105); - ASSERT_TRUE(transformation_use_of_14_in_add.IsApplicable(context.get(), - fact_manager)); + ASSERT_TRUE(transformation_use_of_14_in_add.IsApplicable( + context.get(), transformation_context)); // The transformations are not applicable if we change which uniforms are // applied to which constants. ASSERT_FALSE(TransformationReplaceConstantWithUniform(use_of_9_in_store, blockname_b, 101, 102) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationReplaceConstantWithUniform(use_of_11_in_add, blockname_c, 101, 102) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationReplaceConstantWithUniform(use_of_14_in_add, blockname_a, 101, 102) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // The following transformations do not apply because the uniform descriptors // are not sensible. @@ -160,10 +168,10 @@ MakeUniformBufferElementDescriptor(0, 0, {5}); ASSERT_FALSE(TransformationReplaceConstantWithUniform( use_of_9_in_store, nonsense_uniform_descriptor1, 101, 102) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationReplaceConstantWithUniform( use_of_9_in_store, nonsense_uniform_descriptor2, 101, 102) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // The following transformation does not apply because the id descriptor is // not sensible. @@ -171,18 +179,19 @@ MakeIdUseDescriptor(9, MakeInstructionDescriptor(15, SpvOpIAdd, 0), 0); ASSERT_FALSE(TransformationReplaceConstantWithUniform( nonsense_id_use_descriptor, blockname_a, 101, 102) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // The following transformations do not apply because the ids are not fresh. ASSERT_FALSE(TransformationReplaceConstantWithUniform(use_of_11_in_add, blockname_b, 15, 103) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationReplaceConstantWithUniform(use_of_11_in_add, blockname_b, 102, 15) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Apply the use of 9 in a store. - transformation_use_of_9_in_store.Apply(context.get(), &fact_manager); + transformation_use_of_9_in_store.Apply(context.get(), + &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_replacing_use_of_9_in_store = R"( OpCapability Shader @@ -233,10 +242,10 @@ )"; ASSERT_TRUE(IsEqual(env, after_replacing_use_of_9_in_store, context.get())); - ASSERT_TRUE(transformation_use_of_11_in_add.IsApplicable(context.get(), - fact_manager)); + ASSERT_TRUE(transformation_use_of_11_in_add.IsApplicable( + context.get(), transformation_context)); // Apply the use of 11 in an add. - transformation_use_of_11_in_add.Apply(context.get(), &fact_manager); + transformation_use_of_11_in_add.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_replacing_use_of_11_in_add = R"( OpCapability Shader @@ -289,10 +298,10 @@ )"; ASSERT_TRUE(IsEqual(env, after_replacing_use_of_11_in_add, context.get())); - ASSERT_TRUE(transformation_use_of_14_in_add.IsApplicable(context.get(), - fact_manager)); + ASSERT_TRUE(transformation_use_of_14_in_add.IsApplicable( + context.get(), transformation_context)); // Apply the use of 15 in an add. - transformation_use_of_14_in_add.Apply(context.get(), &fact_manager); + transformation_use_of_14_in_add.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_replacing_use_of_14_in_add = R"( OpCapability Shader @@ -462,6 +471,10 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + protobufs::UniformBufferElementDescriptor blockname_1 = MakeUniformBufferElementDescriptor(0, 0, {0}); protobufs::UniformBufferElementDescriptor blockname_2 = @@ -471,10 +484,14 @@ protobufs::UniformBufferElementDescriptor blockname_4 = MakeUniformBufferElementDescriptor(0, 0, {1, 0, 1, 0}); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 1, blockname_1)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 2, blockname_2)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 3, blockname_3)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 4, blockname_4)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 1, blockname_1)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 2, blockname_2)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 3, blockname_3)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 4, blockname_4)); // The constant ids are 13, 15, 17 and 20, for 1, 2, 3 and 4 respectively. protobufs::IdUseDescriptor use_of_13_in_store = @@ -490,76 +507,78 @@ auto transformation_use_of_13_in_store = TransformationReplaceConstantWithUniform(use_of_13_in_store, blockname_1, 100, 101); - ASSERT_TRUE(transformation_use_of_13_in_store.IsApplicable(context.get(), - fact_manager)); + ASSERT_TRUE(transformation_use_of_13_in_store.IsApplicable( + context.get(), transformation_context)); auto transformation_use_of_15_in_add = TransformationReplaceConstantWithUniform(use_of_15_in_add, blockname_2, 102, 103); - ASSERT_TRUE(transformation_use_of_15_in_add.IsApplicable(context.get(), - fact_manager)); + ASSERT_TRUE(transformation_use_of_15_in_add.IsApplicable( + context.get(), transformation_context)); auto transformation_use_of_17_in_add = TransformationReplaceConstantWithUniform(use_of_17_in_add, blockname_3, 104, 105); - ASSERT_TRUE(transformation_use_of_17_in_add.IsApplicable(context.get(), - fact_manager)); + ASSERT_TRUE(transformation_use_of_17_in_add.IsApplicable( + context.get(), transformation_context)); auto transformation_use_of_20_in_store = TransformationReplaceConstantWithUniform(use_of_20_in_store, blockname_4, 106, 107); - ASSERT_TRUE(transformation_use_of_20_in_store.IsApplicable(context.get(), - fact_manager)); + ASSERT_TRUE(transformation_use_of_20_in_store.IsApplicable( + context.get(), transformation_context)); - ASSERT_TRUE(transformation_use_of_13_in_store.IsApplicable(context.get(), - fact_manager)); - ASSERT_TRUE(transformation_use_of_15_in_add.IsApplicable(context.get(), - fact_manager)); - ASSERT_TRUE(transformation_use_of_17_in_add.IsApplicable(context.get(), - fact_manager)); - ASSERT_TRUE(transformation_use_of_20_in_store.IsApplicable(context.get(), - fact_manager)); + ASSERT_TRUE(transformation_use_of_13_in_store.IsApplicable( + context.get(), transformation_context)); + ASSERT_TRUE(transformation_use_of_15_in_add.IsApplicable( + context.get(), transformation_context)); + ASSERT_TRUE(transformation_use_of_17_in_add.IsApplicable( + context.get(), transformation_context)); + ASSERT_TRUE(transformation_use_of_20_in_store.IsApplicable( + context.get(), transformation_context)); - transformation_use_of_13_in_store.Apply(context.get(), &fact_manager); + transformation_use_of_13_in_store.Apply(context.get(), + &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(transformation_use_of_13_in_store.IsApplicable(context.get(), - fact_manager)); - ASSERT_TRUE(transformation_use_of_15_in_add.IsApplicable(context.get(), - fact_manager)); - ASSERT_TRUE(transformation_use_of_17_in_add.IsApplicable(context.get(), - fact_manager)); - ASSERT_TRUE(transformation_use_of_20_in_store.IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(transformation_use_of_13_in_store.IsApplicable( + context.get(), transformation_context)); + ASSERT_TRUE(transformation_use_of_15_in_add.IsApplicable( + context.get(), transformation_context)); + ASSERT_TRUE(transformation_use_of_17_in_add.IsApplicable( + context.get(), transformation_context)); + ASSERT_TRUE(transformation_use_of_20_in_store.IsApplicable( + context.get(), transformation_context)); - transformation_use_of_15_in_add.Apply(context.get(), &fact_manager); + transformation_use_of_15_in_add.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(transformation_use_of_13_in_store.IsApplicable(context.get(), - fact_manager)); - ASSERT_FALSE(transformation_use_of_15_in_add.IsApplicable(context.get(), - fact_manager)); - ASSERT_TRUE(transformation_use_of_17_in_add.IsApplicable(context.get(), - fact_manager)); - ASSERT_TRUE(transformation_use_of_20_in_store.IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(transformation_use_of_13_in_store.IsApplicable( + context.get(), transformation_context)); + ASSERT_FALSE(transformation_use_of_15_in_add.IsApplicable( + context.get(), transformation_context)); + ASSERT_TRUE(transformation_use_of_17_in_add.IsApplicable( + context.get(), transformation_context)); + ASSERT_TRUE(transformation_use_of_20_in_store.IsApplicable( + context.get(), transformation_context)); - transformation_use_of_17_in_add.Apply(context.get(), &fact_manager); + transformation_use_of_17_in_add.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(transformation_use_of_13_in_store.IsApplicable(context.get(), - fact_manager)); - ASSERT_FALSE(transformation_use_of_15_in_add.IsApplicable(context.get(), - fact_manager)); - ASSERT_FALSE(transformation_use_of_17_in_add.IsApplicable(context.get(), - fact_manager)); - ASSERT_TRUE(transformation_use_of_20_in_store.IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(transformation_use_of_13_in_store.IsApplicable( + context.get(), transformation_context)); + ASSERT_FALSE(transformation_use_of_15_in_add.IsApplicable( + context.get(), transformation_context)); + ASSERT_FALSE(transformation_use_of_17_in_add.IsApplicable( + context.get(), transformation_context)); + ASSERT_TRUE(transformation_use_of_20_in_store.IsApplicable( + context.get(), transformation_context)); - transformation_use_of_20_in_store.Apply(context.get(), &fact_manager); + transformation_use_of_20_in_store.Apply(context.get(), + &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_FALSE(transformation_use_of_13_in_store.IsApplicable(context.get(), - fact_manager)); - ASSERT_FALSE(transformation_use_of_15_in_add.IsApplicable(context.get(), - fact_manager)); - ASSERT_FALSE(transformation_use_of_17_in_add.IsApplicable(context.get(), - fact_manager)); - ASSERT_FALSE(transformation_use_of_20_in_store.IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(transformation_use_of_13_in_store.IsApplicable( + context.get(), transformation_context)); + ASSERT_FALSE(transformation_use_of_15_in_add.IsApplicable( + context.get(), transformation_context)); + ASSERT_FALSE(transformation_use_of_17_in_add.IsApplicable( + context.get(), transformation_context)); + ASSERT_FALSE(transformation_use_of_20_in_store.IsApplicable( + context.get(), transformation_context)); std::string after = R"( OpCapability Shader @@ -697,10 +716,15 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + protobufs::UniformBufferElementDescriptor blockname_0 = MakeUniformBufferElementDescriptor(0, 0, {0}); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 0, blockname_0)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 0, blockname_0)); // The constant id is 9 for 0. protobufs::IdUseDescriptor use_of_9_in_store = @@ -710,7 +734,7 @@ // type is present: ASSERT_FALSE(TransformationReplaceConstantWithUniform(use_of_9_in_store, blockname_0, 100, 101) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } TEST(TransformationReplaceConstantWithUniformTest, NoConstantPresentForIndex) { @@ -770,12 +794,17 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + protobufs::UniformBufferElementDescriptor blockname_0 = MakeUniformBufferElementDescriptor(0, 0, {0}); protobufs::UniformBufferElementDescriptor blockname_9 = MakeUniformBufferElementDescriptor(0, 0, {1}); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 9, blockname_9)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 9, blockname_9)); // The constant id is 9 for 9. protobufs::IdUseDescriptor use_of_9_in_store = @@ -785,7 +814,7 @@ // index 1 required to index into the uniform buffer: ASSERT_FALSE(TransformationReplaceConstantWithUniform(use_of_9_in_store, blockname_9, 100, 101) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } TEST(TransformationReplaceConstantWithUniformTest, @@ -842,14 +871,18 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + protobufs::UniformBufferElementDescriptor blockname_3 = MakeUniformBufferElementDescriptor(0, 0, {0}); uint32_t float_data[1]; float temp = 3.0; memcpy(&float_data[0], &temp, sizeof(float)); - ASSERT_TRUE( - AddFactHelper(&fact_manager, context.get(), float_data[0], blockname_3)); + ASSERT_TRUE(AddFactHelper(&transformation_context, context.get(), + float_data[0], blockname_3)); // The constant id is 9 for 3.0. protobufs::IdUseDescriptor use_of_9_in_store = @@ -859,7 +892,7 @@ // allow a constant index to be expressed: ASSERT_FALSE(TransformationReplaceConstantWithUniform(use_of_9_in_store, blockname_3, 100, 101) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } TEST(TransformationReplaceConstantWithUniformTest, @@ -928,13 +961,19 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + protobufs::UniformBufferElementDescriptor blockname_9 = MakeUniformBufferElementDescriptor(0, 0, {0}); protobufs::UniformBufferElementDescriptor blockname_10 = MakeUniformBufferElementDescriptor(0, 0, {1}); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 9, blockname_9)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 10, blockname_10)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 9, blockname_9)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 10, blockname_10)); // The constant ids for 9 and 10 are 9 and 11 respectively protobufs::IdUseDescriptor use_of_9_in_store = @@ -945,19 +984,19 @@ // These are right: ASSERT_TRUE(TransformationReplaceConstantWithUniform(use_of_9_in_store, blockname_9, 100, 101) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationReplaceConstantWithUniform(use_of_11_in_store, blockname_10, 102, 103) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // These are wrong because the constants do not match the facts about // uniforms. ASSERT_FALSE(TransformationReplaceConstantWithUniform(use_of_11_in_store, blockname_9, 100, 101) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationReplaceConstantWithUniform(use_of_9_in_store, blockname_10, 102, 103) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } TEST(TransformationReplaceConstantWithUniformTest, ComplexReplacements) { @@ -1141,6 +1180,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); const float float_array_values[5] = {1.0, 1.5, 1.75, 1.875, 1.9375}; uint32_t float_array_data[5]; @@ -1188,35 +1230,43 @@ protobufs::UniformBufferElementDescriptor uniform_h_y = MakeUniformBufferElementDescriptor(0, 0, {2, 1}); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), float_array_data[0], - uniform_f_a_0)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), float_array_data[1], - uniform_f_a_1)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), float_array_data[2], - uniform_f_a_2)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), float_array_data[3], - uniform_f_a_3)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), float_array_data[4], - uniform_f_a_4)); + ASSERT_TRUE(AddFactHelper(&transformation_context, context.get(), + float_array_data[0], uniform_f_a_0)); + ASSERT_TRUE(AddFactHelper(&transformation_context, context.get(), + float_array_data[1], uniform_f_a_1)); + ASSERT_TRUE(AddFactHelper(&transformation_context, context.get(), + float_array_data[2], uniform_f_a_2)); + ASSERT_TRUE(AddFactHelper(&transformation_context, context.get(), + float_array_data[3], uniform_f_a_3)); + ASSERT_TRUE(AddFactHelper(&transformation_context, context.get(), + float_array_data[4], uniform_f_a_4)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 1, uniform_f_b_x)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 2, uniform_f_b_y)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 3, uniform_f_b_z)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 4, uniform_f_b_w)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 1, uniform_f_b_x)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 2, uniform_f_b_y)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 3, uniform_f_b_z)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 4, uniform_f_b_w)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), float_vector_data[0], - uniform_f_c_x)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), float_vector_data[1], - uniform_f_c_y)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), float_vector_data[2], - uniform_f_c_z)); + ASSERT_TRUE(AddFactHelper(&transformation_context, context.get(), + float_vector_data[0], uniform_f_c_x)); + ASSERT_TRUE(AddFactHelper(&transformation_context, context.get(), + float_vector_data[1], uniform_f_c_y)); + ASSERT_TRUE(AddFactHelper(&transformation_context, context.get(), + float_vector_data[2], uniform_f_c_z)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 42, uniform_f_d)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 42, uniform_f_d)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 22, uniform_g)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 22, uniform_g)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 100, uniform_h_x)); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 200, uniform_h_y)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 100, uniform_h_x)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 200, uniform_h_y)); std::vector<TransformationReplaceConstantWithUniform> transformations; @@ -1275,8 +1325,9 @@ uniform_g, 218, 219)); for (auto& transformation : transformations) { - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } @@ -1480,16 +1531,21 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + protobufs::UniformBufferElementDescriptor blockname_a = MakeUniformBufferElementDescriptor(0, 0, {0}); - ASSERT_TRUE(AddFactHelper(&fact_manager, context.get(), 0, blockname_a)); + ASSERT_TRUE( + AddFactHelper(&transformation_context, context.get(), 0, blockname_a)); ASSERT_FALSE(TransformationReplaceConstantWithUniform( MakeIdUseDescriptor( 50, MakeInstructionDescriptor(8, SpvOpVariable, 0), 1), blockname_a, 100, 101) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } } // namespace
diff --git a/test/fuzz/transformation_replace_id_with_synonym_test.cpp b/test/fuzz/transformation_replace_id_with_synonym_test.cpp index 41b6116..37e9510 100644 --- a/test/fuzz/transformation_replace_id_with_synonym_test.cpp +++ b/test/fuzz/transformation_replace_id_with_synonym_test.cpp
@@ -220,15 +220,19 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; - SetUpIdSynonyms(&fact_manager, context.get()); + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + SetUpIdSynonyms(transformation_context.GetFactManager(), context.get()); // %202 cannot replace %15 as in-operand 0 of %300, since %202 does not // dominate %300. auto synonym_does_not_dominate_use = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(15, MakeInstructionDescriptor(300, SpvOpIAdd, 0), 0), 202); - ASSERT_FALSE( - synonym_does_not_dominate_use.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(synonym_does_not_dominate_use.IsApplicable( + context.get(), transformation_context)); // %202 cannot replace %15 as in-operand 2 of %301, since this is the OpPhi's // incoming value for block %72, and %202 does not dominate %72. @@ -237,22 +241,23 @@ MakeIdUseDescriptor(15, MakeInstructionDescriptor(301, SpvOpPhi, 0), 2), 202); - ASSERT_FALSE(synonym_does_not_dominate_use_op_phi.IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(synonym_does_not_dominate_use_op_phi.IsApplicable( + context.get(), transformation_context)); // %200 is not a synonym for %84 auto id_in_use_is_not_synonymous = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor( 84, MakeInstructionDescriptor(67, SpvOpSGreaterThan, 0), 0), 200); - ASSERT_FALSE( - id_in_use_is_not_synonymous.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(id_in_use_is_not_synonymous.IsApplicable( + context.get(), transformation_context)); // %86 is not a synonym for anything (and in particular not for %74) auto id_has_no_synonyms = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(86, MakeInstructionDescriptor(84, SpvOpPhi, 0), 2), 74); - ASSERT_FALSE(id_has_no_synonyms.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + id_has_no_synonyms.IsApplicable(context.get(), transformation_context)); // This would lead to %207 = 'OpCopyObject %type %207' if it were allowed auto synonym_use_is_in_synonym_definition = @@ -260,8 +265,8 @@ MakeIdUseDescriptor( 84, MakeInstructionDescriptor(207, SpvOpCopyObject, 0), 0), 207); - ASSERT_FALSE(synonym_use_is_in_synonym_definition.IsApplicable(context.get(), - fact_manager)); + ASSERT_FALSE(synonym_use_is_in_synonym_definition.IsApplicable( + context.get(), transformation_context)); // The id use descriptor does not lead to a use (%84 is not used in the // definition of %207) @@ -269,7 +274,8 @@ MakeIdUseDescriptor( 84, MakeInstructionDescriptor(200, SpvOpCopyObject, 0), 0), 207); - ASSERT_FALSE(bad_id_use_descriptor.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(bad_id_use_descriptor.IsApplicable(context.get(), + transformation_context)); // This replacement would lead to an access chain into a struct using a // non-constant index. @@ -277,7 +283,8 @@ MakeIdUseDescriptor( 12, MakeInstructionDescriptor(14, SpvOpAccessChain, 0), 1), 209); - ASSERT_FALSE(bad_access_chain.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + bad_access_chain.IsApplicable(context.get(), transformation_context)); } TEST(TransformationReplaceIdWithSynonymTest, LegalTransformations) { @@ -288,23 +295,28 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; - SetUpIdSynonyms(&fact_manager, context.get()); + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + SetUpIdSynonyms(transformation_context.GetFactManager(), context.get()); auto global_constant_synonym = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(19, MakeInstructionDescriptor(47, SpvOpStore, 0), 1), 210); - ASSERT_TRUE( - global_constant_synonym.IsApplicable(context.get(), fact_manager)); - global_constant_synonym.Apply(context.get(), &fact_manager); + ASSERT_TRUE(global_constant_synonym.IsApplicable(context.get(), + transformation_context)); + global_constant_synonym.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); auto replace_vector_access_chain_index = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor( 54, MakeInstructionDescriptor(55, SpvOpAccessChain, 0), 1), 204); - ASSERT_TRUE(replace_vector_access_chain_index.IsApplicable(context.get(), - fact_manager)); - replace_vector_access_chain_index.Apply(context.get(), &fact_manager); + ASSERT_TRUE(replace_vector_access_chain_index.IsApplicable( + context.get(), transformation_context)); + replace_vector_access_chain_index.Apply(context.get(), + &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // This is an interesting case because it replaces something that is being @@ -313,22 +325,24 @@ MakeIdUseDescriptor( 15, MakeInstructionDescriptor(202, SpvOpCopyObject, 0), 0), 201); - ASSERT_TRUE(regular_replacement.IsApplicable(context.get(), fact_manager)); - regular_replacement.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + regular_replacement.IsApplicable(context.get(), transformation_context)); + regular_replacement.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); auto regular_replacement2 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(55, MakeInstructionDescriptor(203, SpvOpStore, 0), 0), 203); - ASSERT_TRUE(regular_replacement2.IsApplicable(context.get(), fact_manager)); - regular_replacement2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + regular_replacement2.IsApplicable(context.get(), transformation_context)); + regular_replacement2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); auto good_op_phi = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(74, MakeInstructionDescriptor(86, SpvOpPhi, 0), 2), 205); - ASSERT_TRUE(good_op_phi.IsApplicable(context.get(), fact_manager)); - good_op_phi.Apply(context.get(), &fact_manager); + ASSERT_TRUE(good_op_phi.IsApplicable(context.get(), transformation_context)); + good_op_phi.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); const std::string after_transformation = R"( @@ -504,17 +518,22 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - fact_manager.AddFact(MakeSynonymFact(10, 100), context.get()); - fact_manager.AddFact(MakeSynonymFact(8, 101), context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(10, 100), + context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(8, 101), + context.get()); // Replace %10 with %100 in: // %11 = OpLoad %6 %10 auto replacement1 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(10, MakeInstructionDescriptor(11, SpvOpLoad, 0), 0), 100); - ASSERT_TRUE(replacement1.IsApplicable(context.get(), fact_manager)); - replacement1.Apply(context.get(), &fact_manager); + ASSERT_TRUE(replacement1.IsApplicable(context.get(), transformation_context)); + replacement1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %8 with %101 in: @@ -522,8 +541,8 @@ auto replacement2 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(8, MakeInstructionDescriptor(11, SpvOpStore, 0), 0), 101); - ASSERT_TRUE(replacement2.IsApplicable(context.get(), fact_manager)); - replacement2.Apply(context.get(), &fact_manager); + ASSERT_TRUE(replacement2.IsApplicable(context.get(), transformation_context)); + replacement2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %8 with %101 in: @@ -531,8 +550,8 @@ auto replacement3 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(8, MakeInstructionDescriptor(12, SpvOpLoad, 0), 0), 101); - ASSERT_TRUE(replacement3.IsApplicable(context.get(), fact_manager)); - replacement3.Apply(context.get(), &fact_manager); + ASSERT_TRUE(replacement3.IsApplicable(context.get(), transformation_context)); + replacement3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replace %10 with %100 in: @@ -540,8 +559,8 @@ auto replacement4 = TransformationReplaceIdWithSynonym( MakeIdUseDescriptor(10, MakeInstructionDescriptor(12, SpvOpStore, 0), 0), 100); - ASSERT_TRUE(replacement4.IsApplicable(context.get(), fact_manager)); - replacement4.Apply(context.get(), &fact_manager); + ASSERT_TRUE(replacement4.IsApplicable(context.get(), transformation_context)); + replacement4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); const std::string after_transformation = R"( @@ -633,8 +652,12 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - fact_manager.AddFact(MakeSynonymFact(14, 100), context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(14, 100), + context.get()); // Replace %14 with %100 in: // %16 = OpFunctionCall %2 %10 %14 @@ -642,7 +665,7 @@ MakeIdUseDescriptor( 14, MakeInstructionDescriptor(16, SpvOpFunctionCall, 0), 1), 100); - ASSERT_FALSE(replacement.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE(replacement.IsApplicable(context.get(), transformation_context)); } TEST(TransformationReplaceIdWithSynonymTest, SynonymsOfAccessChainIndices) { @@ -795,22 +818,38 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Add synonym facts corresponding to the OpCopyObject operations that have // been applied to all constants in the module. - fact_manager.AddFact(MakeSynonymFact(16, 100), context.get()); - fact_manager.AddFact(MakeSynonymFact(21, 101), context.get()); - fact_manager.AddFact(MakeSynonymFact(17, 102), context.get()); - fact_manager.AddFact(MakeSynonymFact(57, 103), context.get()); - fact_manager.AddFact(MakeSynonymFact(18, 104), context.get()); - fact_manager.AddFact(MakeSynonymFact(40, 105), context.get()); - fact_manager.AddFact(MakeSynonymFact(32, 106), context.get()); - fact_manager.AddFact(MakeSynonymFact(43, 107), context.get()); - fact_manager.AddFact(MakeSynonymFact(55, 108), context.get()); - fact_manager.AddFact(MakeSynonymFact(8, 109), context.get()); - fact_manager.AddFact(MakeSynonymFact(47, 110), context.get()); - fact_manager.AddFact(MakeSynonymFact(28, 111), context.get()); - fact_manager.AddFact(MakeSynonymFact(45, 112), context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(16, 100), + context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(21, 101), + context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(17, 102), + context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(57, 103), + context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(18, 104), + context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(40, 105), + context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(32, 106), + context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(43, 107), + context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(55, 108), + context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(8, 109), + context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(47, 110), + context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(28, 111), + context.get()); + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(45, 112), + context.get()); // Replacements of the form %16 -> %100 @@ -821,7 +860,8 @@ MakeIdUseDescriptor( 16, MakeInstructionDescriptor(20, SpvOpAccessChain, 0), 1), 100); - ASSERT_FALSE(replacement1.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement1.IsApplicable(context.get(), transformation_context)); // %39 = OpAccessChain %23 %37 *%16* // Corresponds to h.*f* @@ -830,7 +870,8 @@ MakeIdUseDescriptor( 16, MakeInstructionDescriptor(39, SpvOpAccessChain, 0), 1), 100); - ASSERT_FALSE(replacement2.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement2.IsApplicable(context.get(), transformation_context)); // %41 = OpAccessChain %19 %37 %21 *%16* %21 // Corresponds to h.g.*a*[1] @@ -839,7 +880,8 @@ MakeIdUseDescriptor( 16, MakeInstructionDescriptor(41, SpvOpAccessChain, 0), 2), 100); - ASSERT_FALSE(replacement3.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement3.IsApplicable(context.get(), transformation_context)); // %52 = OpAccessChain %23 %50 *%16* %16 // Corresponds to i[*0*].f @@ -848,8 +890,8 @@ MakeIdUseDescriptor( 16, MakeInstructionDescriptor(52, SpvOpAccessChain, 0), 1), 100); - ASSERT_TRUE(replacement4.IsApplicable(context.get(), fact_manager)); - replacement4.Apply(context.get(), &fact_manager); + ASSERT_TRUE(replacement4.IsApplicable(context.get(), transformation_context)); + replacement4.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // %52 = OpAccessChain %23 %50 %16 *%16* @@ -859,7 +901,8 @@ MakeIdUseDescriptor( 16, MakeInstructionDescriptor(52, SpvOpAccessChain, 0), 2), 100); - ASSERT_FALSE(replacement5.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement5.IsApplicable(context.get(), transformation_context)); // %53 = OpAccessChain %19 %50 %21 %21 *%16* %16 // Corresponds to i[1].g.*a*[0] @@ -868,7 +911,8 @@ MakeIdUseDescriptor( 16, MakeInstructionDescriptor(53, SpvOpAccessChain, 0), 3), 100); - ASSERT_FALSE(replacement6.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement6.IsApplicable(context.get(), transformation_context)); // %53 = OpAccessChain %19 %50 %21 %21 %16 *%16* // Corresponds to i[1].g.a[*0*] @@ -877,8 +921,8 @@ MakeIdUseDescriptor( 16, MakeInstructionDescriptor(53, SpvOpAccessChain, 0), 4), 100); - ASSERT_TRUE(replacement7.IsApplicable(context.get(), fact_manager)); - replacement7.Apply(context.get(), &fact_manager); + ASSERT_TRUE(replacement7.IsApplicable(context.get(), transformation_context)); + replacement7.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replacements of the form %21 -> %101 @@ -890,7 +934,8 @@ MakeIdUseDescriptor( 21, MakeInstructionDescriptor(24, SpvOpAccessChain, 0), 1), 101); - ASSERT_FALSE(replacement8.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement8.IsApplicable(context.get(), transformation_context)); // %41 = OpAccessChain %19 %37 *%21* %16 %21 // Corresponds to h.*g*.a[1] @@ -899,7 +944,8 @@ MakeIdUseDescriptor( 21, MakeInstructionDescriptor(41, SpvOpAccessChain, 0), 1), 101); - ASSERT_FALSE(replacement9.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement9.IsApplicable(context.get(), transformation_context)); // %41 = OpAccessChain %19 %37 %21 %16 *%21* // Corresponds to h.g.a[*1*] @@ -908,8 +954,9 @@ MakeIdUseDescriptor( 21, MakeInstructionDescriptor(41, SpvOpAccessChain, 0), 3), 101); - ASSERT_TRUE(replacement10.IsApplicable(context.get(), fact_manager)); - replacement10.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement10.IsApplicable(context.get(), transformation_context)); + replacement10.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // %44 = OpAccessChain %23 %37 *%21* %21 %43 @@ -919,7 +966,8 @@ MakeIdUseDescriptor( 21, MakeInstructionDescriptor(44, SpvOpAccessChain, 0), 1), 101); - ASSERT_FALSE(replacement11.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement11.IsApplicable(context.get(), transformation_context)); // %44 = OpAccessChain %23 %37 %21 *%21* %43 // Corresponds to h.g.*b*[0] @@ -928,7 +976,8 @@ MakeIdUseDescriptor( 21, MakeInstructionDescriptor(44, SpvOpAccessChain, 0), 2), 101); - ASSERT_FALSE(replacement12.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement12.IsApplicable(context.get(), transformation_context)); // %46 = OpAccessChain %26 %37 *%21* %17 // Corresponds to h.*g*.c @@ -937,7 +986,8 @@ MakeIdUseDescriptor( 21, MakeInstructionDescriptor(46, SpvOpAccessChain, 0), 1), 101); - ASSERT_FALSE(replacement13.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement13.IsApplicable(context.get(), transformation_context)); // %53 = OpAccessChain %19 %50 *%21* %21 %16 %16 // Corresponds to i[*1*].g.a[0] @@ -946,8 +996,9 @@ MakeIdUseDescriptor( 21, MakeInstructionDescriptor(53, SpvOpAccessChain, 0), 1), 101); - ASSERT_TRUE(replacement14.IsApplicable(context.get(), fact_manager)); - replacement14.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement14.IsApplicable(context.get(), transformation_context)); + replacement14.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // %53 = OpAccessChain %19 %50 %21 *%21* %16 %16 @@ -957,7 +1008,8 @@ MakeIdUseDescriptor( 21, MakeInstructionDescriptor(53, SpvOpAccessChain, 0), 2), 101); - ASSERT_FALSE(replacement15.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement15.IsApplicable(context.get(), transformation_context)); // %56 = OpAccessChain %23 %50 %17 *%21* %21 %55 // Corresponds to i[2].*g*.b[1] @@ -966,7 +1018,8 @@ MakeIdUseDescriptor( 21, MakeInstructionDescriptor(56, SpvOpAccessChain, 0), 2), 101); - ASSERT_FALSE(replacement16.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement16.IsApplicable(context.get(), transformation_context)); // %56 = OpAccessChain %23 %50 %17 %21 *%21* %55 // Corresponds to i[2].g.*b*[1] @@ -975,7 +1028,8 @@ MakeIdUseDescriptor( 21, MakeInstructionDescriptor(56, SpvOpAccessChain, 0), 3), 101); - ASSERT_FALSE(replacement17.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement17.IsApplicable(context.get(), transformation_context)); // %58 = OpAccessChain %26 %50 %57 *%21* %17 // Corresponds to i[3].*g*.c @@ -984,7 +1038,8 @@ MakeIdUseDescriptor( 21, MakeInstructionDescriptor(58, SpvOpAccessChain, 0), 2), 101); - ASSERT_FALSE(replacement18.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement18.IsApplicable(context.get(), transformation_context)); // Replacements of the form %17 -> %102 @@ -995,8 +1050,9 @@ MakeIdUseDescriptor( 17, MakeInstructionDescriptor(20, SpvOpAccessChain, 0), 2), 102); - ASSERT_TRUE(replacement19.IsApplicable(context.get(), fact_manager)); - replacement19.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement19.IsApplicable(context.get(), transformation_context)); + replacement19.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // %27 = OpAccessChain %26 %15 %17 @@ -1006,7 +1062,8 @@ MakeIdUseDescriptor( 17, MakeInstructionDescriptor(27, SpvOpAccessChain, 0), 1), 102); - ASSERT_FALSE(replacement20.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement20.IsApplicable(context.get(), transformation_context)); // %46 = OpAccessChain %26 %37 %21 %17 // Corresponds to h.g.*c* @@ -1015,7 +1072,8 @@ MakeIdUseDescriptor( 17, MakeInstructionDescriptor(46, SpvOpAccessChain, 0), 2), 102); - ASSERT_FALSE(replacement21.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement21.IsApplicable(context.get(), transformation_context)); // %56 = OpAccessChain %23 %50 %17 %21 %21 %55 // Corresponds to i[*2*].g.b[1] @@ -1024,8 +1082,9 @@ MakeIdUseDescriptor( 17, MakeInstructionDescriptor(56, SpvOpAccessChain, 0), 1), 102); - ASSERT_TRUE(replacement22.IsApplicable(context.get(), fact_manager)); - replacement22.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement22.IsApplicable(context.get(), transformation_context)); + replacement22.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // %58 = OpAccessChain %26 %50 %57 %21 %17 @@ -1035,7 +1094,8 @@ MakeIdUseDescriptor( 17, MakeInstructionDescriptor(58, SpvOpAccessChain, 0), 3), 102); - ASSERT_FALSE(replacement23.IsApplicable(context.get(), fact_manager)); + ASSERT_FALSE( + replacement23.IsApplicable(context.get(), transformation_context)); // Replacements of the form %57 -> %103 @@ -1046,8 +1106,9 @@ MakeIdUseDescriptor( 57, MakeInstructionDescriptor(58, SpvOpAccessChain, 0), 1), 103); - ASSERT_TRUE(replacement24.IsApplicable(context.get(), fact_manager)); - replacement24.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement24.IsApplicable(context.get(), transformation_context)); + replacement24.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replacements of the form %32 -> %106 @@ -1059,8 +1120,9 @@ MakeIdUseDescriptor( 32, MakeInstructionDescriptor(34, SpvOpAccessChain, 0), 1), 106); - ASSERT_TRUE(replacement25.IsApplicable(context.get(), fact_manager)); - replacement25.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement25.IsApplicable(context.get(), transformation_context)); + replacement25.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replacements of the form %43 -> %107 @@ -1072,8 +1134,9 @@ MakeIdUseDescriptor( 43, MakeInstructionDescriptor(44, SpvOpAccessChain, 0), 3), 107); - ASSERT_TRUE(replacement26.IsApplicable(context.get(), fact_manager)); - replacement26.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement26.IsApplicable(context.get(), transformation_context)); + replacement26.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replacements of the form %55 -> %108 @@ -1085,8 +1148,9 @@ MakeIdUseDescriptor( 55, MakeInstructionDescriptor(56, SpvOpAccessChain, 0), 4), 108); - ASSERT_TRUE(replacement27.IsApplicable(context.get(), fact_manager)); - replacement27.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement27.IsApplicable(context.get(), transformation_context)); + replacement27.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); // Replacements of the form %8 -> %109 @@ -1098,8 +1162,9 @@ MakeIdUseDescriptor(8, MakeInstructionDescriptor(24, SpvOpAccessChain, 0), 2), 109); - ASSERT_TRUE(replacement28.IsApplicable(context.get(), fact_manager)); - replacement28.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + replacement28.IsApplicable(context.get(), transformation_context)); + replacement28.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); const std::string after_transformation = R"( @@ -1212,6 +1277,179 @@ ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); } +TEST(TransformationReplaceIdWithSynonymTest, RuntimeArrayTest) { + // This checks that OpRuntimeArray is correctly handled. + const std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpDecorate %8 ArrayStride 8 + OpMemberDecorate %9 0 Offset 0 + OpDecorate %9 BufferBlock + OpDecorate %11 DescriptorSet 0 + OpDecorate %11 Binding 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypeVector %6 2 + %8 = OpTypeRuntimeArray %7 + %9 = OpTypeStruct %8 + %10 = OpTypePointer Uniform %9 + %11 = OpVariable %10 Uniform + %12 = OpConstant %6 0 + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 0 + %15 = OpTypePointer Uniform %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %50 = OpCopyObject %6 %12 + %51 = OpCopyObject %13 %14 + %16 = OpAccessChain %15 %11 %12 %12 %14 + OpStore %16 %12 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + // Add synonym fact relating %50 and %12. + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(50, 12), + context.get()); + // Add synonym fact relating %51 and %14. + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(51, 14), + context.get()); + + // Not legal because the index being replaced is a struct index. + ASSERT_FALSE( + TransformationReplaceIdWithSynonym( + MakeIdUseDescriptor( + 12, MakeInstructionDescriptor(16, SpvOpAccessChain, 0), 1), + 50) + .IsApplicable(context.get(), transformation_context)); + + // Fine to replace an index into a runtime array. + auto replacement1 = TransformationReplaceIdWithSynonym( + MakeIdUseDescriptor( + 12, MakeInstructionDescriptor(16, SpvOpAccessChain, 0), 2), + 50); + ASSERT_TRUE(replacement1.IsApplicable(context.get(), transformation_context)); + replacement1.Apply(context.get(), &transformation_context); + + // Fine to replace an index into a vector inside the runtime array. + auto replacement2 = TransformationReplaceIdWithSynonym( + MakeIdUseDescriptor( + 14, MakeInstructionDescriptor(16, SpvOpAccessChain, 0), 3), + 51); + ASSERT_TRUE(replacement2.IsApplicable(context.get(), transformation_context)); + replacement2.Apply(context.get(), &transformation_context); + + ASSERT_TRUE(IsValid(env, context.get())); + + const std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpDecorate %8 ArrayStride 8 + OpMemberDecorate %9 0 Offset 0 + OpDecorate %9 BufferBlock + OpDecorate %11 DescriptorSet 0 + OpDecorate %11 Binding 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypeVector %6 2 + %8 = OpTypeRuntimeArray %7 + %9 = OpTypeStruct %8 + %10 = OpTypePointer Uniform %9 + %11 = OpVariable %10 Uniform + %12 = OpConstant %6 0 + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 0 + %15 = OpTypePointer Uniform %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %50 = OpCopyObject %6 %12 + %51 = OpCopyObject %13 %14 + %16 = OpAccessChain %15 %11 %12 %50 %51 + OpStore %16 %12 + OpReturn + OpFunctionEnd + )"; + + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationReplaceIdWithSynonymTest, + DoNotReplaceSampleParameterOfOpImageTexelPointer) { + const std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpConstant %6 2 + %9 = OpConstant %6 0 + %10 = OpConstant %6 10 + %11 = OpTypeBool + %12 = OpConstant %6 1 + %13 = OpTypeFloat 32 + %14 = OpTypePointer Image %13 + %15 = OpTypeImage %13 2D 0 0 0 0 Rgba8 + %16 = OpTypePointer Private %15 + %3 = OpVariable %16 Private + %17 = OpTypeVector %6 2 + %18 = OpConstantComposite %17 %9 %9 + %2 = OpFunction %4 None %5 + %19 = OpLabel + %100 = OpCopyObject %6 %9 + %20 = OpImageTexelPointer %14 %3 %18 %9 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + // Add synonym fact relating %100 and %9. + transformation_context.GetFactManager()->AddFact(MakeSynonymFact(100, 9), + context.get()); + + // Not legal the Sample argument of OpImageTexelPointer needs to be a zero + // constant. + ASSERT_FALSE( + TransformationReplaceIdWithSynonym( + MakeIdUseDescriptor( + 9, MakeInstructionDescriptor(20, SpvOpImageTexelPointer, 0), 2), + 100) + .IsApplicable(context.get(), transformation_context)); +} + } // namespace } // namespace fuzz } // namespace spvtools
diff --git a/test/fuzz/transformation_set_function_control_test.cpp b/test/fuzz/transformation_set_function_control_test.cpp index 536e965..be7f2be 100644 --- a/test/fuzz/transformation_set_function_control_test.cpp +++ b/test/fuzz/transformation_set_function_control_test.cpp
@@ -118,41 +118,48 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // %36 is not a function ASSERT_FALSE(TransformationSetFunctionControl(36, SpvFunctionControlMaskNone) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Cannot add the Pure function control to %4 as it did not already have it ASSERT_FALSE(TransformationSetFunctionControl(4, SpvFunctionControlPureMask) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Cannot add the Const function control to %21 as it did not already // have it ASSERT_FALSE(TransformationSetFunctionControl(21, SpvFunctionControlConstMask) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Set to None, removing Const TransformationSetFunctionControl transformation1(11, SpvFunctionControlMaskNone); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); // Set to Inline; silly to do it on an entry point, but it is allowed TransformationSetFunctionControl transformation2( 4, SpvFunctionControlInlineMask); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); // Set to Pure, removing DontInline TransformationSetFunctionControl transformation3(17, SpvFunctionControlPureMask); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); // Change from Inline to DontInline TransformationSetFunctionControl transformation4( 13, SpvFunctionControlDontInlineMask); - ASSERT_TRUE(transformation4.IsApplicable(context.get(), fact_manager)); - transformation4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation4.IsApplicable(context.get(), transformation_context)); + transformation4.Apply(context.get(), &transformation_context); std::string after_transformation = R"( OpCapability Shader
diff --git a/test/fuzz/transformation_set_loop_control_test.cpp b/test/fuzz/transformation_set_loop_control_test.cpp index 83953ec..531aa7a 100644 --- a/test/fuzz/transformation_set_loop_control_test.cpp +++ b/test/fuzz/transformation_set_loop_control_test.cpp
@@ -256,6 +256,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // These are the loop headers together with the selection controls of their // merge instructions: @@ -275,310 +278,310 @@ // 2 5 90 4 7 14 ASSERT_TRUE(TransformationSetLoopControl(10, SpvLoopControlMaskNone, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(10, SpvLoopControlUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(10, SpvLoopControlDontUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSetLoopControl( 10, SpvLoopControlDependencyInfiniteMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl(10, SpvLoopControlDependencyLengthMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl(10, SpvLoopControlMinIterationsMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl(10, SpvLoopControlMaxIterationsMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSetLoopControl( 10, SpvLoopControlIterationMultipleMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(10, SpvLoopControlPeelCountMask, 3, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl(10, SpvLoopControlPeelCountMask, 3, 3) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(10, SpvLoopControlPartialCountMask, 0, 3) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl(10, SpvLoopControlPartialCountMask, 3, 3) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl( 10, SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask, 3, 3) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(10, SpvLoopControlUnrollMask | SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask, 3, 3) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSetLoopControl(10, SpvLoopControlDontUnrollMask | SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask, 3, 3) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(23, SpvLoopControlMaskNone, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(23, SpvLoopControlUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(23, SpvLoopControlDontUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl( 23, SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask, 3, 3) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl(23, SpvLoopControlMaxIterationsMask, 2, 3) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(33, SpvLoopControlMaskNone, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(33, SpvLoopControlUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(33, SpvLoopControlDontUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl(33, SpvLoopControlMinIterationsMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl( 33, SpvLoopControlUnrollMask | SpvLoopControlPeelCountMask, 5, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSetLoopControl(33, SpvLoopControlDontUnrollMask | SpvLoopControlPartialCountMask, 0, 10) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(43, SpvLoopControlMaskNone, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(43, SpvLoopControlUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(43, SpvLoopControlDontUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl( 43, SpvLoopControlMaskNone | SpvLoopControlDependencyInfiniteMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl( 43, SpvLoopControlUnrollMask | SpvLoopControlDependencyInfiniteMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl( 43, SpvLoopControlDontUnrollMask | SpvLoopControlDependencyInfiniteMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl(43, SpvLoopControlDependencyInfiniteMask | SpvLoopControlDependencyLengthMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl( 43, SpvLoopControlUnrollMask | SpvLoopControlPeelCountMask, 5, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(53, SpvLoopControlMaskNone, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(53, SpvLoopControlUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(53, SpvLoopControlDontUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl(53, SpvLoopControlMaxIterationsMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl( 53, SpvLoopControlMaskNone | SpvLoopControlDependencyLengthMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl( 53, SpvLoopControlUnrollMask | SpvLoopControlDependencyInfiniteMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl( 53, SpvLoopControlDontUnrollMask | SpvLoopControlDependencyLengthMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl(53, SpvLoopControlDependencyInfiniteMask | SpvLoopControlDependencyLengthMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl( 53, SpvLoopControlUnrollMask | SpvLoopControlDependencyLengthMask | SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask, 5, 3) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(63, SpvLoopControlMaskNone, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(63, SpvLoopControlUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(63, SpvLoopControlDontUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(63, SpvLoopControlUnrollMask | SpvLoopControlMinIterationsMask | SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask, 5, 3) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(63, SpvLoopControlUnrollMask | SpvLoopControlMinIterationsMask | SpvLoopControlPeelCountMask, 23, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSetLoopControl( 63, SpvLoopControlUnrollMask | SpvLoopControlMinIterationsMask | SpvLoopControlPeelCountMask, 2, 23) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(73, SpvLoopControlMaskNone, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(73, SpvLoopControlUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(73, SpvLoopControlDontUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSetLoopControl( 73, SpvLoopControlUnrollMask | SpvLoopControlMinIterationsMask | SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask, 5, 3) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(73, SpvLoopControlUnrollMask | SpvLoopControlMaxIterationsMask | SpvLoopControlPeelCountMask, 23, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSetLoopControl( 73, SpvLoopControlUnrollMask | SpvLoopControlMaxIterationsMask | SpvLoopControlPeelCountMask, 2, 23) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(83, SpvLoopControlMaskNone, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(83, SpvLoopControlUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(83, SpvLoopControlDontUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSetLoopControl( 83, SpvLoopControlUnrollMask | SpvLoopControlMinIterationsMask | SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask, 5, 3) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(83, SpvLoopControlUnrollMask | SpvLoopControlIterationMultipleMask | SpvLoopControlPeelCountMask, 23, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl(83, SpvLoopControlUnrollMask | SpvLoopControlIterationMultipleMask | SpvLoopControlPeelCountMask, 2, 23) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(93, SpvLoopControlMaskNone, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(93, SpvLoopControlUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(93, SpvLoopControlDontUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(93, SpvLoopControlPeelCountMask, 8, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl(93, SpvLoopControlPeelCountMask, 8, 8) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(93, SpvLoopControlPartialCountMask, 0, 8) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl( 93, SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask, 16, 8) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(103, SpvLoopControlMaskNone, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(103, SpvLoopControlUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(103, SpvLoopControlDontUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(103, SpvLoopControlPartialCountMask, 0, 60) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSetLoopControl(103, SpvLoopControlDontUnrollMask | SpvLoopControlPartialCountMask, 0, 60) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(113, SpvLoopControlMaskNone, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(113, SpvLoopControlUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(113, SpvLoopControlDontUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(113, SpvLoopControlPeelCountMask, 12, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl( 113, SpvLoopControlIterationMultipleMask | SpvLoopControlPeelCountMask, 12, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(123, SpvLoopControlMaskNone, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(123, SpvLoopControlUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl(123, SpvLoopControlDontUnrollMask, 0, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE( TransformationSetLoopControl( 123, @@ -586,72 +589,72 @@ SpvLoopControlIterationMultipleMask | SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask, 7, 8) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_TRUE(TransformationSetLoopControl(123, SpvLoopControlUnrollMask | SpvLoopControlMinIterationsMask | SpvLoopControlMaxIterationsMask | SpvLoopControlPartialCountMask, 0, 9) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSetLoopControl( 123, SpvLoopControlUnrollMask | SpvLoopControlMinIterationsMask | SpvLoopControlMaxIterationsMask | SpvLoopControlPartialCountMask, 7, 9) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSetLoopControl( 123, SpvLoopControlDontUnrollMask | SpvLoopControlMinIterationsMask | SpvLoopControlMaxIterationsMask | SpvLoopControlPartialCountMask, 7, 9) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); TransformationSetLoopControl(10, SpvLoopControlUnrollMask | SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask, 3, 3) - .Apply(context.get(), &fact_manager); + .Apply(context.get(), &transformation_context); TransformationSetLoopControl(23, SpvLoopControlDontUnrollMask, 0, 0) - .Apply(context.get(), &fact_manager); + .Apply(context.get(), &transformation_context); TransformationSetLoopControl(33, SpvLoopControlUnrollMask, 0, 0) - .Apply(context.get(), &fact_manager); + .Apply(context.get(), &transformation_context); TransformationSetLoopControl( 43, SpvLoopControlDontUnrollMask | SpvLoopControlDependencyInfiniteMask, 0, 0) - .Apply(context.get(), &fact_manager); + .Apply(context.get(), &transformation_context); TransformationSetLoopControl(53, SpvLoopControlMaskNone, 0, 0) - .Apply(context.get(), &fact_manager); + .Apply(context.get(), &transformation_context); TransformationSetLoopControl(63, SpvLoopControlUnrollMask | SpvLoopControlMinIterationsMask | SpvLoopControlPeelCountMask, 23, 0) - .Apply(context.get(), &fact_manager); + .Apply(context.get(), &transformation_context); TransformationSetLoopControl(73, SpvLoopControlUnrollMask | SpvLoopControlMaxIterationsMask | SpvLoopControlPeelCountMask, 23, 0) - .Apply(context.get(), &fact_manager); + .Apply(context.get(), &transformation_context); TransformationSetLoopControl(83, SpvLoopControlDontUnrollMask, 0, 0) - .Apply(context.get(), &fact_manager); + .Apply(context.get(), &transformation_context); TransformationSetLoopControl( 93, SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask, 16, 8) - .Apply(context.get(), &fact_manager); + .Apply(context.get(), &transformation_context); TransformationSetLoopControl(103, SpvLoopControlPartialCountMask, 0, 60) - .Apply(context.get(), &fact_manager); + .Apply(context.get(), &transformation_context); TransformationSetLoopControl(113, SpvLoopControlPeelCountMask, 12, 0) - .Apply(context.get(), &fact_manager); + .Apply(context.get(), &transformation_context); TransformationSetLoopControl( 123, SpvLoopControlUnrollMask | SpvLoopControlMinIterationsMask | SpvLoopControlMaxIterationsMask | SpvLoopControlPartialCountMask, 0, 9) - .Apply(context.get(), &fact_manager); + .Apply(context.get(), &transformation_context); std::string after_transformation = R"( OpCapability Shader @@ -942,25 +945,28 @@ BuildModule(SPV_ENV_UNIVERSAL_1_5, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationSetLoopControl set_peel_and_partial( 10, SpvLoopControlPeelCountMask | SpvLoopControlPartialCountMask, 4, 4); // PeelCount and PartialCount were introduced in SPIRV 1.4, so are not valid // in the context of older versions. - ASSERT_FALSE( - set_peel_and_partial.IsApplicable(context_1_0.get(), fact_manager)); - ASSERT_FALSE( - set_peel_and_partial.IsApplicable(context_1_1.get(), fact_manager)); - ASSERT_FALSE( - set_peel_and_partial.IsApplicable(context_1_2.get(), fact_manager)); - ASSERT_FALSE( - set_peel_and_partial.IsApplicable(context_1_3.get(), fact_manager)); + ASSERT_FALSE(set_peel_and_partial.IsApplicable(context_1_0.get(), + transformation_context)); + ASSERT_FALSE(set_peel_and_partial.IsApplicable(context_1_1.get(), + transformation_context)); + ASSERT_FALSE(set_peel_and_partial.IsApplicable(context_1_2.get(), + transformation_context)); + ASSERT_FALSE(set_peel_and_partial.IsApplicable(context_1_3.get(), + transformation_context)); - ASSERT_TRUE( - set_peel_and_partial.IsApplicable(context_1_4.get(), fact_manager)); - ASSERT_TRUE( - set_peel_and_partial.IsApplicable(context_1_5.get(), fact_manager)); + ASSERT_TRUE(set_peel_and_partial.IsApplicable(context_1_4.get(), + transformation_context)); + ASSERT_TRUE(set_peel_and_partial.IsApplicable(context_1_5.get(), + transformation_context)); } } // namespace
diff --git a/test/fuzz/transformation_set_memory_operands_mask_test.cpp b/test/fuzz/transformation_set_memory_operands_mask_test.cpp index ad4dc25..c02d8d4 100644 --- a/test/fuzz/transformation_set_memory_operands_mask_test.cpp +++ b/test/fuzz/transformation_set_memory_operands_mask_test.cpp
@@ -92,37 +92,41 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Not OK: the instruction is not a memory access. ASSERT_FALSE(TransformationSetMemoryOperandsMask( MakeInstructionDescriptor(21, SpvOpAccessChain, 0), SpvMemoryAccessMaskNone, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not OK to remove Aligned ASSERT_FALSE(TransformationSetMemoryOperandsMask( MakeInstructionDescriptor(147, SpvOpLoad, 0), SpvMemoryAccessVolatileMask | SpvMemoryAccessNontemporalMask, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); TransformationSetMemoryOperandsMask transformation1( MakeInstructionDescriptor(147, SpvOpLoad, 0), SpvMemoryAccessAlignedMask | SpvMemoryAccessVolatileMask, 0); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); // Not OK to remove Aligned ASSERT_FALSE(TransformationSetMemoryOperandsMask( MakeInstructionDescriptor(21, SpvOpCopyMemory, 0), SpvMemoryAccessMaskNone, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // OK: leaves the mask as is ASSERT_TRUE(TransformationSetMemoryOperandsMask( MakeInstructionDescriptor(21, SpvOpCopyMemory, 0), SpvMemoryAccessAlignedMask, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // OK: adds Nontemporal and Volatile TransformationSetMemoryOperandsMask transformation2( @@ -130,41 +134,45 @@ SpvMemoryAccessAlignedMask | SpvMemoryAccessNontemporalMask | SpvMemoryAccessVolatileMask, 0); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); // Not OK to remove Volatile ASSERT_FALSE(TransformationSetMemoryOperandsMask( MakeInstructionDescriptor(21, SpvOpCopyMemory, 1), SpvMemoryAccessNontemporalMask, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Not OK to add Aligned ASSERT_FALSE(TransformationSetMemoryOperandsMask( MakeInstructionDescriptor(21, SpvOpCopyMemory, 1), SpvMemoryAccessAlignedMask | SpvMemoryAccessVolatileMask, 0) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // OK: adds Nontemporal TransformationSetMemoryOperandsMask transformation3( MakeInstructionDescriptor(21, SpvOpCopyMemory, 1), SpvMemoryAccessNontemporalMask | SpvMemoryAccessVolatileMask, 0); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); // OK: adds Nontemporal and Volatile TransformationSetMemoryOperandsMask transformation4( MakeInstructionDescriptor(138, SpvOpCopyMemory, 0), SpvMemoryAccessNontemporalMask | SpvMemoryAccessVolatileMask, 0); - ASSERT_TRUE(transformation4.IsApplicable(context.get(), fact_manager)); - transformation4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation4.IsApplicable(context.get(), transformation_context)); + transformation4.Apply(context.get(), &transformation_context); // OK: removes Nontemporal, adds Volatile TransformationSetMemoryOperandsMask transformation5( MakeInstructionDescriptor(148, SpvOpStore, 0), SpvMemoryAccessVolatileMask, 0); - ASSERT_TRUE(transformation5.IsApplicable(context.get(), fact_manager)); - transformation5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation5.IsApplicable(context.get(), transformation_context)); + transformation5.Apply(context.get(), &transformation_context); std::string after_transformation = R"( OpCapability Shader @@ -306,6 +314,9 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); TransformationSetMemoryOperandsMask transformation1( MakeInstructionDescriptor(21, SpvOpCopyMemory, 0), @@ -314,9 +325,10 @@ ASSERT_FALSE(TransformationSetMemoryOperandsMask( MakeInstructionDescriptor(21, SpvOpCopyMemory, 0), SpvMemoryAccessVolatileMask, 1) - .IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + .IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); TransformationSetMemoryOperandsMask transformation2( MakeInstructionDescriptor(21, SpvOpCopyMemory, 1), @@ -325,9 +337,10 @@ ASSERT_FALSE(TransformationSetMemoryOperandsMask( MakeInstructionDescriptor(21, SpvOpCopyMemory, 1), SpvMemoryAccessNontemporalMask, 0) - .IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + .IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); TransformationSetMemoryOperandsMask transformation3( MakeInstructionDescriptor(138, SpvOpCopyMemory, 0), @@ -337,27 +350,31 @@ MakeInstructionDescriptor(138, SpvOpCopyMemory, 0), SpvMemoryAccessAlignedMask | SpvMemoryAccessNontemporalMask, 0) - .IsApplicable(context.get(), fact_manager)); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + .IsApplicable(context.get(), transformation_context)); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); TransformationSetMemoryOperandsMask transformation4( MakeInstructionDescriptor(138, SpvOpCopyMemory, 1), SpvMemoryAccessVolatileMask, 1); - ASSERT_TRUE(transformation4.IsApplicable(context.get(), fact_manager)); - transformation4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation4.IsApplicable(context.get(), transformation_context)); + transformation4.Apply(context.get(), &transformation_context); TransformationSetMemoryOperandsMask transformation5( MakeInstructionDescriptor(147, SpvOpLoad, 0), SpvMemoryAccessVolatileMask | SpvMemoryAccessAlignedMask, 0); - ASSERT_TRUE(transformation5.IsApplicable(context.get(), fact_manager)); - transformation5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation5.IsApplicable(context.get(), transformation_context)); + transformation5.Apply(context.get(), &transformation_context); TransformationSetMemoryOperandsMask transformation6( MakeInstructionDescriptor(148, SpvOpStore, 0), SpvMemoryAccessMaskNone, 0); - ASSERT_TRUE(transformation6.IsApplicable(context.get(), fact_manager)); - transformation6.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation6.IsApplicable(context.get(), transformation_context)); + transformation6.Apply(context.get(), &transformation_context); std::string after_transformation = R"( OpCapability Shader
diff --git a/test/fuzz/transformation_set_selection_control_test.cpp b/test/fuzz/transformation_set_selection_control_test.cpp index 9696417..9afb89d 100644 --- a/test/fuzz/transformation_set_selection_control_test.cpp +++ b/test/fuzz/transformation_set_selection_control_test.cpp
@@ -103,39 +103,46 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // %44 is not a block ASSERT_FALSE( TransformationSetSelectionControl(44, SpvSelectionControlFlattenMask) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // %13 does not end with OpSelectionMerge ASSERT_FALSE( TransformationSetSelectionControl(13, SpvSelectionControlMaskNone) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // %10 ends in OpLoopMerge, not OpSelectionMerge ASSERT_FALSE( TransformationSetSelectionControl(10, SpvSelectionControlMaskNone) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); TransformationSetSelectionControl transformation1( 11, SpvSelectionControlDontFlattenMask); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); TransformationSetSelectionControl transformation2( 23, SpvSelectionControlFlattenMask); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); TransformationSetSelectionControl transformation3( 31, SpvSelectionControlMaskNone); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); TransformationSetSelectionControl transformation4( 31, SpvSelectionControlFlattenMask); - ASSERT_TRUE(transformation4.IsApplicable(context.get(), fact_manager)); - transformation4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation4.IsApplicable(context.get(), transformation_context)); + transformation4.Apply(context.get(), &transformation_context); std::string after_transformation = R"( OpCapability Shader
diff --git a/test/fuzz/transformation_split_block_test.cpp b/test/fuzz/transformation_split_block_test.cpp index 09007a5..30bac02 100644 --- a/test/fuzz/transformation_split_block_test.cpp +++ b/test/fuzz/transformation_split_block_test.cpp
@@ -89,57 +89,60 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // No split before OpVariable ASSERT_FALSE(TransformationSplitBlock( MakeInstructionDescriptor(8, SpvOpVariable, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSplitBlock( MakeInstructionDescriptor(8, SpvOpVariable, 1), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // No split before OpLabel ASSERT_FALSE(TransformationSplitBlock( MakeInstructionDescriptor(14, SpvOpLabel, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // No split if base instruction is outside a function ASSERT_FALSE( TransformationSplitBlock(MakeInstructionDescriptor(1, SpvOpLabel, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSplitBlock( MakeInstructionDescriptor(1, SpvOpExecutionMode, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // No split if block is loop header ASSERT_FALSE( TransformationSplitBlock(MakeInstructionDescriptor(27, SpvOpPhi, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSplitBlock(MakeInstructionDescriptor(27, SpvOpPhi, 1), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // No split if base instruction does not exist ASSERT_FALSE( TransformationSplitBlock(MakeInstructionDescriptor(88, SpvOpIAdd, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSplitBlock( MakeInstructionDescriptor(88, SpvOpIMul, 22), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // No split if too many instructions with the desired opcode are skipped ASSERT_FALSE( TransformationSplitBlock( MakeInstructionDescriptor(18, SpvOpBranchConditional, 1), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // No split if id in use ASSERT_FALSE(TransformationSplitBlock( MakeInstructionDescriptor(18, SpvOpSLessThan, 0), 27) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSplitBlock( MakeInstructionDescriptor(18, SpvOpSLessThan, 0), 14) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } TEST(TransformationSplitBlockTest, SplitBlockSeveralTimes) { @@ -199,11 +202,14 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto split_1 = TransformationSplitBlock( MakeInstructionDescriptor(5, SpvOpStore, 0), 100); - ASSERT_TRUE(split_1.IsApplicable(context.get(), fact_manager)); - split_1.Apply(context.get(), &fact_manager); + ASSERT_TRUE(split_1.IsApplicable(context.get(), transformation_context)); + split_1.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_split_1 = R"( @@ -250,8 +256,8 @@ auto split_2 = TransformationSplitBlock( MakeInstructionDescriptor(11, SpvOpStore, 0), 101); - ASSERT_TRUE(split_2.IsApplicable(context.get(), fact_manager)); - split_2.Apply(context.get(), &fact_manager); + ASSERT_TRUE(split_2.IsApplicable(context.get(), transformation_context)); + split_2.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_split_2 = R"( @@ -300,8 +306,8 @@ auto split_3 = TransformationSplitBlock( MakeInstructionDescriptor(14, SpvOpLoad, 0), 102); - ASSERT_TRUE(split_3.IsApplicable(context.get(), fact_manager)); - split_3.Apply(context.get(), &fact_manager); + ASSERT_TRUE(split_3.IsApplicable(context.get(), transformation_context)); + split_3.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_split_3 = R"( @@ -412,21 +418,24 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Illegal to split between the merge and the conditional branch. ASSERT_FALSE( TransformationSplitBlock( MakeInstructionDescriptor(14, SpvOpBranchConditional, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSplitBlock( MakeInstructionDescriptor(12, SpvOpBranchConditional, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); auto split = TransformationSplitBlock( MakeInstructionDescriptor(14, SpvOpSelectionMerge, 0), 100); - ASSERT_TRUE(split.IsApplicable(context.get(), fact_manager)); - split.Apply(context.get(), &fact_manager); + ASSERT_TRUE(split.IsApplicable(context.get(), transformation_context)); + split.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_split = R"( @@ -541,19 +550,22 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Illegal to split between the merge and the conditional branch. ASSERT_FALSE(TransformationSplitBlock( MakeInstructionDescriptor(9, SpvOpSwitch, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE(TransformationSplitBlock( MakeInstructionDescriptor(15, SpvOpSwitch, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); auto split = TransformationSplitBlock( MakeInstructionDescriptor(9, SpvOpSelectionMerge, 0), 100); - ASSERT_TRUE(split.IsApplicable(context.get(), fact_manager)); - split.Apply(context.get(), &fact_manager); + ASSERT_TRUE(split.IsApplicable(context.get(), transformation_context)); + split.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_split = R"( @@ -674,18 +686,21 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // We cannot split before OpPhi instructions, since the number of incoming // blocks may not appropriately match after splitting. ASSERT_FALSE( TransformationSplitBlock(MakeInstructionDescriptor(26, SpvOpPhi, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSplitBlock(MakeInstructionDescriptor(27, SpvOpPhi, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationSplitBlock(MakeInstructionDescriptor(27, SpvOpPhi, 1), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } TEST(TransformationSplitBlockTest, SplitOpPhiWithSinglePredecessor) { @@ -726,16 +741,19 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); ASSERT_TRUE( TransformationSplitBlock(MakeInstructionDescriptor(21, SpvOpPhi, 0), 100) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // An equivalent transformation to the above, just described with respect to a // different base instruction. auto split = TransformationSplitBlock(MakeInstructionDescriptor(20, SpvOpPhi, 0), 100); - ASSERT_TRUE(split.IsApplicable(context.get(), fact_manager)); - split.Apply(context.get(), &fact_manager); + ASSERT_TRUE(split.IsApplicable(context.get(), transformation_context)); + split.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); std::string after_split = R"( @@ -805,18 +823,21 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Record the fact that block 8 is dead. - fact_manager.AddFactBlockIsDead(8); + transformation_context.GetFactManager()->AddFactBlockIsDead(8); auto split = TransformationSplitBlock( MakeInstructionDescriptor(8, SpvOpBranch, 0), 100); - ASSERT_TRUE(split.IsApplicable(context.get(), fact_manager)); - split.Apply(context.get(), &fact_manager); + ASSERT_TRUE(split.IsApplicable(context.get(), transformation_context)); + split.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); - ASSERT_TRUE(fact_manager.BlockIsDead(8)); - ASSERT_TRUE(fact_manager.BlockIsDead(100)); + ASSERT_TRUE(transformation_context.GetFactManager()->BlockIsDead(8)); + ASSERT_TRUE(transformation_context.GetFactManager()->BlockIsDead(100)); std::string after_split = R"( OpCapability Shader @@ -845,6 +866,62 @@ ASSERT_TRUE(IsEqual(env, after_split, context.get())); } +TEST(TransformationSplitBlockTest, DoNotSplitUseOfOpSampledImage) { + // This checks that we cannot split the definition of an OpSampledImage + // from its use. + std::string shader = R"( + OpCapability Shader + OpCapability SampledBuffer + OpCapability ImageBuffer + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %40 %41 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 450 + OpDecorate %40 DescriptorSet 0 + OpDecorate %40 Binding 69 + OpDecorate %41 DescriptorSet 0 + OpDecorate %41 Binding 1 + %54 = OpTypeFloat 32 + %76 = OpTypeVector %54 4 + %55 = OpConstant %54 0 + %56 = OpTypeVector %54 3 + %94 = OpTypeVector %54 2 + %112 = OpConstantComposite %94 %55 %55 + %57 = OpConstantComposite %56 %55 %55 %55 + %15 = OpTypeImage %54 2D 2 0 0 1 Unknown + %114 = OpTypePointer UniformConstant %15 + %38 = OpTypeSampler + %125 = OpTypePointer UniformConstant %38 + %132 = OpTypeVoid + %133 = OpTypeFunction %132 + %45 = OpTypeSampledImage %15 + %40 = OpVariable %114 UniformConstant + %41 = OpVariable %125 UniformConstant + %2 = OpFunction %132 None %133 + %164 = OpLabel + %184 = OpLoad %15 %40 + %213 = OpLoad %38 %41 + %216 = OpSampledImage %45 %184 %213 + %217 = OpImageSampleImplicitLod %76 %216 %112 Bias %55 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + auto split = TransformationSplitBlock( + MakeInstructionDescriptor(217, SpvOpImageSampleImplicitLod, 0), 500); + ASSERT_FALSE(split.IsApplicable(context.get(), transformation_context)); +} + } // namespace } // namespace fuzz } // namespace spvtools
diff --git a/test/fuzz/transformation_store_test.cpp b/test/fuzz/transformation_store_test.cpp index 3fb9b61..07d222f 100644 --- a/test/fuzz/transformation_store_test.cpp +++ b/test/fuzz/transformation_store_test.cpp
@@ -94,16 +94,26 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - fact_manager.AddFactValueOfPointeeIsIrrelevant(27); - fact_manager.AddFactValueOfPointeeIsIrrelevant(11); - fact_manager.AddFactValueOfPointeeIsIrrelevant(46); - fact_manager.AddFactValueOfPointeeIsIrrelevant(16); - fact_manager.AddFactValueOfPointeeIsIrrelevant(52); - fact_manager.AddFactValueOfPointeeIsIrrelevant(81); - fact_manager.AddFactValueOfPointeeIsIrrelevant(82); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 27); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 11); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 46); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 16); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 52); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 81); + transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant( + 82); - fact_manager.AddFactBlockIsDead(36); + transformation_context.GetFactManager()->AddFactBlockIsDead(36); // Variables with pointee types: // 52 - ptr_to(7) @@ -139,90 +149,91 @@ // Bad: attempt to store to 11 from outside its function ASSERT_FALSE(TransformationStore( 11, 80, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer is not available ASSERT_FALSE(TransformationStore( 81, 80, MakeInstructionDescriptor(45, SpvOpCopyObject, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: attempt to insert before OpVariable ASSERT_FALSE(TransformationStore( 52, 24, MakeInstructionDescriptor(27, SpvOpVariable, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer id does not exist ASSERT_FALSE(TransformationStore( 1000, 24, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer id exists but does not have a type ASSERT_FALSE(TransformationStore( 5, 24, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: pointer id exists and has a type, but is not a pointer ASSERT_FALSE(TransformationStore( 24, 24, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: attempt to store to a null pointer ASSERT_FALSE(TransformationStore( 60, 24, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: attempt to store to an undefined pointer ASSERT_FALSE(TransformationStore( 61, 21, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: %82 is not available at the program point ASSERT_FALSE( TransformationStore(82, 80, MakeInstructionDescriptor(37, SpvOpReturn, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: value id does not exist ASSERT_FALSE(TransformationStore( 27, 1000, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: value id exists but does not have a type ASSERT_FALSE(TransformationStore( 27, 15, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: value id exists but has the wrong type ASSERT_FALSE(TransformationStore( 27, 14, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: attempt to store to read-only variable ASSERT_FALSE(TransformationStore( 92, 93, MakeInstructionDescriptor(40, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: value is not available ASSERT_FALSE(TransformationStore( 27, 95, MakeInstructionDescriptor(40, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Bad: variable being stored to does not have an irrelevant pointee value, // and the store is not in a dead block. ASSERT_FALSE(TransformationStore( 20, 95, MakeInstructionDescriptor(45, SpvOpCopyObject, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // The described instruction does not exist. ASSERT_FALSE(TransformationStore( 27, 80, MakeInstructionDescriptor(1000, SpvOpAccessChain, 0)) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); { // Store to irrelevant variable from dead block. TransformationStore transformation( 27, 80, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } @@ -230,8 +241,9 @@ // Store to irrelevant variable from live block. TransformationStore transformation( 11, 95, MakeInstructionDescriptor(95, SpvOpReturnValue, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } @@ -239,8 +251,9 @@ // Store to irrelevant variable from live block. TransformationStore transformation( 46, 80, MakeInstructionDescriptor(95, SpvOpReturnValue, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } @@ -248,8 +261,9 @@ // Store to irrelevant variable from live block. TransformationStore transformation( 16, 21, MakeInstructionDescriptor(95, SpvOpReturnValue, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } @@ -257,8 +271,9 @@ // Store to non-irrelevant variable from dead block. TransformationStore transformation( 53, 21, MakeInstructionDescriptor(38, SpvOpAccessChain, 0)); - ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); - transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + transformation.Apply(context.get(), &transformation_context); ASSERT_TRUE(IsValid(env, context.get())); } @@ -336,6 +351,70 @@ ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); } +TEST(TransformationStoreTest, DoNotAllowStoresToReadOnlyMemory) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + OpMemberDecorate %10 0 Offset 0 + OpMemberDecorate %10 1 Offset 4 + OpDecorate %10 Block + OpMemberDecorate %23 0 Offset 0 + OpDecorate %23 Block + OpDecorate %25 DescriptorSet 0 + OpDecorate %25 Binding 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpTypeFloat 32 + %10 = OpTypeStruct %6 %9 + %11 = OpTypePointer PushConstant %10 + %12 = OpVariable %11 PushConstant + %13 = OpConstant %6 0 + %14 = OpTypePointer PushConstant %6 + %17 = OpConstant %6 1 + %18 = OpTypePointer PushConstant %9 + %23 = OpTypeStruct %9 + %24 = OpTypePointer UniformConstant %23 + %25 = OpVariable %24 UniformConstant + %26 = OpTypePointer UniformConstant %9 + %50 = OpConstant %9 0 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %15 = OpAccessChain %14 %12 %13 + %19 = OpAccessChain %18 %12 %17 + %27 = OpAccessChain %26 %25 %13 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + fact_manager.AddFactBlockIsDead(5); + + ASSERT_FALSE( + TransformationStore(15, 13, MakeInstructionDescriptor(27, SpvOpReturn, 0)) + .IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + TransformationStore(19, 50, MakeInstructionDescriptor(27, SpvOpReturn, 0)) + .IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE( + TransformationStore(27, 50, MakeInstructionDescriptor(27, SpvOpReturn, 0)) + .IsApplicable(context.get(), transformation_context)); +} + } // namespace } // namespace fuzz } // namespace spvtools
diff --git a/test/fuzz/transformation_swap_commutable_operands_test.cpp b/test/fuzz/transformation_swap_commutable_operands_test.cpp index f0591cf..c213dfe 100644 --- a/test/fuzz/transformation_swap_commutable_operands_test.cpp +++ b/test/fuzz/transformation_swap_commutable_operands_test.cpp
@@ -111,113 +111,140 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); - FactManager factManager; + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Tests existing commutative instructions auto instructionDescriptor = MakeInstructionDescriptor(22, SpvOpIAdd, 0); auto transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_TRUE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(28, SpvOpIMul, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_TRUE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(42, SpvOpFAdd, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_TRUE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(48, SpvOpFMul, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_TRUE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(66, SpvOpDot, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_TRUE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); // Tests existing non-commutative instructions instructionDescriptor = MakeInstructionDescriptor(1, SpvOpExtInstImport, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(5, SpvOpLabel, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(8, SpvOpConstant, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(11, SpvOpVariable, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(14, SpvOpConstantComposite, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); // Tests the base instruction id not existing instructionDescriptor = MakeInstructionDescriptor(67, SpvOpIAddCarry, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(68, SpvOpIEqual, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(69, SpvOpINotEqual, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(70, SpvOpFOrdEqual, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(71, SpvOpPtrEqual, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); // Tests there being no instruction with the desired opcode after the base // instruction id instructionDescriptor = MakeInstructionDescriptor(24, SpvOpIAdd, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(38, SpvOpIMul, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(45, SpvOpFAdd, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(66, SpvOpFMul, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); // Tests there being an instruction with the desired opcode after the base // instruction id, but the skip count associated with the instruction // descriptor being so high. instructionDescriptor = MakeInstructionDescriptor(11, SpvOpIAdd, 100); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(16, SpvOpIMul, 100); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(23, SpvOpFAdd, 100); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(32, SpvOpFMul, 100); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(37, SpvOpDot, 100); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationSwapCommutableOperandsTest, ApplyTest) { @@ -311,28 +338,31 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); - FactManager factManager; + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto instructionDescriptor = MakeInstructionDescriptor(22, SpvOpIAdd, 0); auto transformation = TransformationSwapCommutableOperands(instructionDescriptor); - transformation.Apply(context.get(), &factManager); + transformation.Apply(context.get(), &transformation_context); instructionDescriptor = MakeInstructionDescriptor(28, SpvOpIMul, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - transformation.Apply(context.get(), &factManager); + transformation.Apply(context.get(), &transformation_context); instructionDescriptor = MakeInstructionDescriptor(42, SpvOpFAdd, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - transformation.Apply(context.get(), &factManager); + transformation.Apply(context.get(), &transformation_context); instructionDescriptor = MakeInstructionDescriptor(48, SpvOpFMul, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - transformation.Apply(context.get(), &factManager); + transformation.Apply(context.get(), &transformation_context); instructionDescriptor = MakeInstructionDescriptor(66, SpvOpDot, 0); transformation = TransformationSwapCommutableOperands(instructionDescriptor); - transformation.Apply(context.get(), &factManager); + transformation.Apply(context.get(), &transformation_context); std::string variantShader = R"( OpCapability Shader
diff --git a/test/fuzz/transformation_toggle_access_chain_instruction_test.cpp b/test/fuzz/transformation_toggle_access_chain_instruction_test.cpp index 98e0a64..b20f59e 100644 --- a/test/fuzz/transformation_toggle_access_chain_instruction_test.cpp +++ b/test/fuzz/transformation_toggle_access_chain_instruction_test.cpp
@@ -111,78 +111,93 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); - FactManager factManager; + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Tests existing access chain instructions auto instructionDescriptor = MakeInstructionDescriptor(18, SpvOpAccessChain, 0); auto transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_TRUE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(20, SpvOpInBoundsAccessChain, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_TRUE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(24, SpvOpAccessChain, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_TRUE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(26, SpvOpInBoundsAccessChain, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_TRUE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); // Tests existing non-access chain instructions instructionDescriptor = MakeInstructionDescriptor(1, SpvOpExtInstImport, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(5, SpvOpLabel, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(14, SpvOpConstantComposite, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); // Tests the base instruction id not existing instructionDescriptor = MakeInstructionDescriptor(67, SpvOpAccessChain, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(68, SpvOpAccessChain, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(69, SpvOpInBoundsAccessChain, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); // Tests there being no instruction with the desired opcode after the base // instruction id instructionDescriptor = MakeInstructionDescriptor(65, SpvOpAccessChain, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(66, SpvOpInBoundsAccessChain, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); // Tests there being an instruction with the desired opcode after the base // instruction id, but the skip count associated with the instruction @@ -190,13 +205,15 @@ instructionDescriptor = MakeInstructionDescriptor(11, SpvOpAccessChain, 100); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); instructionDescriptor = MakeInstructionDescriptor(16, SpvOpInBoundsAccessChain, 100); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - ASSERT_FALSE(transformation.IsApplicable(context.get(), factManager)); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); } TEST(TransformationToggleAccessChainInstructionTest, ApplyTest) { @@ -290,35 +307,38 @@ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); ASSERT_TRUE(IsValid(env, context.get())); - FactManager factManager; + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); auto instructionDescriptor = MakeInstructionDescriptor(18, SpvOpAccessChain, 0); auto transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - transformation.Apply(context.get(), &factManager); + transformation.Apply(context.get(), &transformation_context); instructionDescriptor = MakeInstructionDescriptor(20, SpvOpInBoundsAccessChain, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - transformation.Apply(context.get(), &factManager); + transformation.Apply(context.get(), &transformation_context); instructionDescriptor = MakeInstructionDescriptor(24, SpvOpAccessChain, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - transformation.Apply(context.get(), &factManager); + transformation.Apply(context.get(), &transformation_context); instructionDescriptor = MakeInstructionDescriptor(26, SpvOpInBoundsAccessChain, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - transformation.Apply(context.get(), &factManager); + transformation.Apply(context.get(), &transformation_context); instructionDescriptor = MakeInstructionDescriptor(38, SpvOpAccessChain, 0); transformation = TransformationToggleAccessChainInstruction(instructionDescriptor); - transformation.Apply(context.get(), &factManager); + transformation.Apply(context.get(), &transformation_context); std::string variantShader = R"( OpCapability Shader
diff --git a/test/fuzz/transformation_vector_shuffle_test.cpp b/test/fuzz/transformation_vector_shuffle_test.cpp index 385c38b..a29c511 100644 --- a/test/fuzz/transformation_vector_shuffle_test.cpp +++ b/test/fuzz/transformation_vector_shuffle_test.cpp
@@ -86,249 +86,259 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(10, {}), - MakeDataDescriptor(12, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(11, {}), - MakeDataDescriptor(12, {1}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(10, {}), MakeDataDescriptor(12, {0}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(11, {}), MakeDataDescriptor(12, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(10, {}), - MakeDataDescriptor(16, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(11, {}), - MakeDataDescriptor(16, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(10, {}), - MakeDataDescriptor(16, {2}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(10, {}), MakeDataDescriptor(16, {0}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(11, {}), MakeDataDescriptor(16, {1}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(10, {}), MakeDataDescriptor(16, {2}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(10, {}), - MakeDataDescriptor(20, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(11, {}), - MakeDataDescriptor(20, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(10, {}), - MakeDataDescriptor(20, {2}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(11, {}), - MakeDataDescriptor(20, {3}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(10, {}), MakeDataDescriptor(20, {0}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(11, {}), MakeDataDescriptor(20, {1}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(10, {}), MakeDataDescriptor(20, {2}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(11, {}), MakeDataDescriptor(20, {3}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(25, {}), - MakeDataDescriptor(27, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(26, {}), - MakeDataDescriptor(27, {1}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(25, {}), MakeDataDescriptor(27, {0}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(26, {}), MakeDataDescriptor(27, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(25, {}), - MakeDataDescriptor(31, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(26, {}), - MakeDataDescriptor(31, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(25, {}), - MakeDataDescriptor(31, {2}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(25, {}), MakeDataDescriptor(31, {0}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(26, {}), MakeDataDescriptor(31, {1}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(25, {}), MakeDataDescriptor(31, {2}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(25, {}), - MakeDataDescriptor(35, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(26, {}), - MakeDataDescriptor(35, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(25, {}), - MakeDataDescriptor(35, {2}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(26, {}), - MakeDataDescriptor(35, {3}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(25, {}), MakeDataDescriptor(35, {0}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(26, {}), MakeDataDescriptor(35, {1}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(25, {}), MakeDataDescriptor(35, {2}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(26, {}), MakeDataDescriptor(35, {3}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(40, {}), - MakeDataDescriptor(42, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(41, {}), - MakeDataDescriptor(42, {1}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(40, {}), MakeDataDescriptor(42, {0}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(41, {}), MakeDataDescriptor(42, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(40, {}), - MakeDataDescriptor(46, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(41, {}), - MakeDataDescriptor(46, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(40, {}), - MakeDataDescriptor(46, {2}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(40, {}), MakeDataDescriptor(46, {0}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(41, {}), MakeDataDescriptor(46, {1}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(40, {}), MakeDataDescriptor(46, {2}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(40, {}), - MakeDataDescriptor(50, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(41, {}), - MakeDataDescriptor(50, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(40, {}), - MakeDataDescriptor(50, {2}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(41, {}), - MakeDataDescriptor(50, {3}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(40, {}), MakeDataDescriptor(50, {0}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(41, {}), MakeDataDescriptor(50, {1}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(40, {}), MakeDataDescriptor(50, {2}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(41, {}), MakeDataDescriptor(50, {3}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(55, {}), - MakeDataDescriptor(61, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(56, {}), - MakeDataDescriptor(61, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(55, {}), - MakeDataDescriptor(61, {2}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(55, {}), MakeDataDescriptor(61, {0}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(56, {}), MakeDataDescriptor(61, {1}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(55, {}), MakeDataDescriptor(61, {2}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(55, {}), - MakeDataDescriptor(65, {0}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(56, {}), - MakeDataDescriptor(65, {1}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(55, {}), - MakeDataDescriptor(65, {2}), context.get()); - fact_manager.AddFactDataSynonym(MakeDataDescriptor(56, {}), - MakeDataDescriptor(65, {3}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(55, {}), MakeDataDescriptor(65, {0}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(56, {}), MakeDataDescriptor(65, {1}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(55, {}), MakeDataDescriptor(65, {2}), context.get()); + transformation_context.GetFactManager()->AddFactDataSynonym( + MakeDataDescriptor(56, {}), MakeDataDescriptor(65, {3}), context.get()); // %103 does not dominate the return instruction. ASSERT_FALSE(TransformationVectorShuffle( MakeInstructionDescriptor(100, SpvOpReturn, 0), 200, 103, 65, {3, 5, 7}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Illegal to shuffle a bvec2 and a vec3 ASSERT_FALSE(TransformationVectorShuffle( MakeInstructionDescriptor(100, SpvOpReturn, 0), 200, 112, 61, {0, 2, 4}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Illegal to shuffle an ivec2 and a uvec4 ASSERT_FALSE(TransformationVectorShuffle( MakeInstructionDescriptor(100, SpvOpReturn, 0), 200, 27, 50, {1, 3, 5}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Vector 1 does not exist ASSERT_FALSE(TransformationVectorShuffle( MakeInstructionDescriptor(100, SpvOpReturn, 0), 200, 300, 50, {1, 3, 5}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Vector 2 does not exist ASSERT_FALSE(TransformationVectorShuffle( MakeInstructionDescriptor(100, SpvOpReturn, 0), 200, 27, 300, {1, 3, 5}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Index out of range ASSERT_FALSE( TransformationVectorShuffle( MakeInstructionDescriptor(100, SpvOpReturn, 0), 200, 12, 112, {0, 20}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Too many indices ASSERT_FALSE(TransformationVectorShuffle( MakeInstructionDescriptor(100, SpvOpReturn, 0), 200, 12, 112, {0, 1, 0, 1, 0, 1, 0, 1}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Too few indices ASSERT_FALSE( TransformationVectorShuffle( MakeInstructionDescriptor(100, SpvOpReturn, 0), 200, 12, 112, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Too few indices again ASSERT_FALSE( TransformationVectorShuffle( MakeInstructionDescriptor(100, SpvOpReturn, 0), 200, 12, 112, {0}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Indices define unknown type: we do not have vec2 ASSERT_FALSE( TransformationVectorShuffle( MakeInstructionDescriptor(100, SpvOpReturn, 0), 200, 65, 65, {0, 1}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // The instruction to insert before does not exist ASSERT_FALSE(TransformationVectorShuffle( MakeInstructionDescriptor(100, SpvOpCompositeConstruct, 1), 201, 20, 12, {0xFFFFFFFF, 3, 5}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // The 'fresh' id is already in use ASSERT_FALSE( TransformationVectorShuffle( MakeInstructionDescriptor(100, SpvOpReturn, 0), 12, 12, 112, {}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); protobufs::DataDescriptor temp_dd; TransformationVectorShuffle transformation1( MakeInstructionDescriptor(100, SpvOpReturn, 0), 200, 12, 112, {1, 0}); - ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); - transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation1.IsApplicable(context.get(), transformation_context)); + transformation1.Apply(context.get(), &transformation_context); temp_dd = MakeDataDescriptor(200, {0}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(11, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(11, {}), temp_dd)); temp_dd = MakeDataDescriptor(200, {1}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(10, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(10, {}), temp_dd)); TransformationVectorShuffle transformation2( MakeInstructionDescriptor(100, SpvOpReturn, 0), 201, 20, 12, {0xFFFFFFFF, 3, 5}); - ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); - transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation2.IsApplicable(context.get(), transformation_context)); + transformation2.Apply(context.get(), &transformation_context); temp_dd = MakeDataDescriptor(201, {1}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(11, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(11, {}), temp_dd)); temp_dd = MakeDataDescriptor(201, {2}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(11, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(11, {}), temp_dd)); TransformationVectorShuffle transformation3( MakeInstructionDescriptor(100, SpvOpReturn, 0), 202, 27, 35, {5, 4, 1}); - ASSERT_TRUE(transformation3.IsApplicable(context.get(), fact_manager)); - transformation3.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation3.IsApplicable(context.get(), transformation_context)); + transformation3.Apply(context.get(), &transformation_context); temp_dd = MakeDataDescriptor(202, {0}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(26, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(26, {}), temp_dd)); temp_dd = MakeDataDescriptor(202, {1}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(25, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(25, {}), temp_dd)); temp_dd = MakeDataDescriptor(202, {2}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(26, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(26, {}), temp_dd)); TransformationVectorShuffle transformation4( MakeInstructionDescriptor(100, SpvOpReturn, 0), 203, 42, 46, {0, 1}); - ASSERT_TRUE(transformation4.IsApplicable(context.get(), fact_manager)); - transformation4.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation4.IsApplicable(context.get(), transformation_context)); + transformation4.Apply(context.get(), &transformation_context); temp_dd = MakeDataDescriptor(203, {0}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(40, {}), temp_dd)); temp_dd = MakeDataDescriptor(203, {1}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(41, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(41, {}), temp_dd)); TransformationVectorShuffle transformation5( MakeInstructionDescriptor(100, SpvOpReturn, 0), 204, 42, 46, {2, 3, 4}); - ASSERT_TRUE(transformation5.IsApplicable(context.get(), fact_manager)); - transformation5.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation5.IsApplicable(context.get(), transformation_context)); + transformation5.Apply(context.get(), &transformation_context); temp_dd = MakeDataDescriptor(204, {0}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(40, {}), temp_dd)); temp_dd = MakeDataDescriptor(204, {1}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(41, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(41, {}), temp_dd)); temp_dd = MakeDataDescriptor(204, {2}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(40, {}), temp_dd)); TransformationVectorShuffle transformation6( MakeInstructionDescriptor(100, SpvOpReturn, 0), 205, 42, 42, {0, 1, 2, 3}); - ASSERT_TRUE(transformation6.IsApplicable(context.get(), fact_manager)); - transformation6.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation6.IsApplicable(context.get(), transformation_context)); + transformation6.Apply(context.get(), &transformation_context); temp_dd = MakeDataDescriptor(205, {0}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(40, {}), temp_dd)); temp_dd = MakeDataDescriptor(205, {1}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(41, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(41, {}), temp_dd)); temp_dd = MakeDataDescriptor(205, {2}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(40, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(40, {}), temp_dd)); temp_dd = MakeDataDescriptor(205, {3}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(41, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(41, {}), temp_dd)); // swizzle vec4 from vec4 and vec4 using some undefs TransformationVectorShuffle transformation7( MakeInstructionDescriptor(100, SpvOpReturn, 0), 206, 65, 65, {0xFFFFFFFF, 3, 6, 0xFFFFFFFF}); - ASSERT_TRUE(transformation7.IsApplicable(context.get(), fact_manager)); - transformation7.Apply(context.get(), &fact_manager); + ASSERT_TRUE( + transformation7.IsApplicable(context.get(), transformation_context)); + transformation7.Apply(context.get(), &transformation_context); temp_dd = MakeDataDescriptor(206, {1}); - ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(56, {}), temp_dd, - context.get())); + ASSERT_TRUE(transformation_context.GetFactManager()->IsSynonymous( + MakeDataDescriptor(56, {}), temp_dd)); std::string after_transformation = R"( OpCapability Shader @@ -479,52 +489,55 @@ ASSERT_TRUE(IsValid(env, context.get())); FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); // Cannot insert before the OpVariables of a function. ASSERT_FALSE( TransformationVectorShuffle( MakeInstructionDescriptor(101, SpvOpVariable, 0), 200, 14, 14, {0, 1}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationVectorShuffle( MakeInstructionDescriptor(101, SpvOpVariable, 1), 200, 14, 14, {1, 2}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationVectorShuffle( MakeInstructionDescriptor(102, SpvOpVariable, 0), 200, 14, 14, {1, 2}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // OK to insert right after the OpVariables. ASSERT_FALSE( TransformationVectorShuffle( MakeInstructionDescriptor(102, SpvOpBranch, 1), 200, 14, 14, {1, 1}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Cannot insert before the OpPhis of a block. ASSERT_FALSE( TransformationVectorShuffle(MakeInstructionDescriptor(60, SpvOpPhi, 0), 200, 14, 14, {2, 0}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); ASSERT_FALSE( TransformationVectorShuffle(MakeInstructionDescriptor(59, SpvOpPhi, 0), 200, 14, 14, {3, 0}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // OK to insert after the OpPhis. ASSERT_TRUE(TransformationVectorShuffle( MakeInstructionDescriptor(59, SpvOpAccessChain, 0), 200, 14, 14, {3, 4}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Cannot insert before OpLoopMerge ASSERT_FALSE(TransformationVectorShuffle( MakeInstructionDescriptor(33, SpvOpBranchConditional, 0), 200, 14, 14, {3}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); // Cannot insert before OpSelectionMerge ASSERT_FALSE(TransformationVectorShuffle( MakeInstructionDescriptor(21, SpvOpBranchConditional, 0), 200, 14, 14, {2}) - .IsApplicable(context.get(), fact_manager)); + .IsApplicable(context.get(), transformation_context)); } } // namespace
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt index 3954338..21a6529 100644 --- a/test/opt/CMakeLists.txt +++ b/test/opt/CMakeLists.txt
@@ -33,6 +33,7 @@ dead_branch_elim_test.cpp dead_insert_elim_test.cpp dead_variable_elim_test.cpp + debug_info_manager_test.cpp decompose_initialized_variables_test.cpp decoration_manager_test.cpp def_use_test.cpp
diff --git a/test/opt/debug_info_manager_test.cpp b/test/opt/debug_info_manager_test.cpp new file mode 100644 index 0000000..f19737f --- /dev/null +++ b/test/opt/debug_info_manager_test.cpp
@@ -0,0 +1,437 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/debug_info_manager.h" + +#include <memory> +#include <string> +#include <vector> + +#include "effcee/effcee.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/instruction.h" +#include "spirv-tools/libspirv.hpp" + +// Constants for OpenCL.DebugInfo.100 extension instructions. + +static const uint32_t kDebugFunctionOperandFunctionIndex = 13; +static const uint32_t kDebugInlinedAtOperandLineIndex = 4; +static const uint32_t kDebugInlinedAtOperandScopeIndex = 5; +static const uint32_t kDebugInlinedAtOperandInlinedIndex = 6; + +namespace spvtools { +namespace opt { +namespace analysis { +namespace { + +TEST(DebugInfoManager, GetDebugInlinedAt) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "OpenCL.DebugInfo.100" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %in_var_COLOR + OpExecutionMode %main OriginUpperLeft + %5 = OpString "ps.hlsl" + %14 = OpString "#line 1 \"ps.hlsl\" +void main(float in_var_color : COLOR) { + float color = in_var_color; +} +" + %17 = OpString "float" + %21 = OpString "main" + %24 = OpString "color" + OpName %in_var_COLOR "in.var.COLOR" + OpName %main "main" + OpDecorate %in_var_COLOR Location 0 + %uint = OpTypeInt 32 0 + %uint_32 = OpConstant %uint 32 + %float = OpTypeFloat 32 +%_ptr_Input_float = OpTypePointer Input %float + %void = OpTypeVoid + %27 = OpTypeFunction %void +%_ptr_Function_float = OpTypePointer Function %float +%in_var_COLOR = OpVariable %_ptr_Input_float Input + %13 = OpExtInst %void %1 DebugExpression + %15 = OpExtInst %void %1 DebugSource %5 %14 + %16 = OpExtInst %void %1 DebugCompilationUnit 1 4 %15 HLSL + %18 = OpExtInst %void %1 DebugTypeBasic %17 %uint_32 Float + %20 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %18 %18 + %22 = OpExtInst %void %1 DebugFunction %21 %20 %15 1 1 %16 %21 FlagIsProtected|FlagIsPrivate 1 %main + %100 = OpExtInst %void %1 DebugInlinedAt 7 %22 + %main = OpFunction %void None %27 + %28 = OpLabel + %31 = OpLoad %float %in_var_COLOR + OpStore %100 %31 + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + DebugInfoManager manager(context.get()); + + EXPECT_EQ(manager.GetDebugInlinedAt(150), nullptr); + EXPECT_EQ(manager.GetDebugInlinedAt(31), nullptr); + EXPECT_EQ(manager.GetDebugInlinedAt(22), nullptr); + + auto* inst = manager.GetDebugInlinedAt(100); + EXPECT_EQ(inst->GetSingleWordOperand(kDebugInlinedAtOperandLineIndex), 7); + EXPECT_EQ(inst->GetSingleWordOperand(kDebugInlinedAtOperandScopeIndex), 22); +} + +TEST(DebugInfoManager, CreateDebugInlinedAt) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "OpenCL.DebugInfo.100" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %in_var_COLOR + OpExecutionMode %main OriginUpperLeft + %5 = OpString "ps.hlsl" + %14 = OpString "#line 1 \"ps.hlsl\" +void main(float in_var_color : COLOR) { + float color = in_var_color; +} +" + %17 = OpString "float" + %21 = OpString "main" + %24 = OpString "color" + OpName %in_var_COLOR "in.var.COLOR" + OpName %main "main" + OpDecorate %in_var_COLOR Location 0 + %uint = OpTypeInt 32 0 + %uint_32 = OpConstant %uint 32 + %float = OpTypeFloat 32 +%_ptr_Input_float = OpTypePointer Input %float + %void = OpTypeVoid + %27 = OpTypeFunction %void +%_ptr_Function_float = OpTypePointer Function %float +%in_var_COLOR = OpVariable %_ptr_Input_float Input + %13 = OpExtInst %void %1 DebugExpression + %15 = OpExtInst %void %1 DebugSource %5 %14 + %16 = OpExtInst %void %1 DebugCompilationUnit 1 4 %15 HLSL + %18 = OpExtInst %void %1 DebugTypeBasic %17 %uint_32 Float + %20 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %18 %18 + %22 = OpExtInst %void %1 DebugFunction %21 %20 %15 1 1 %16 %21 FlagIsProtected|FlagIsPrivate 1 %main + %100 = OpExtInst %void %1 DebugInlinedAt 7 %22 + %main = OpFunction %void None %27 + %28 = OpLabel + %31 = OpLoad %float %in_var_COLOR + OpStore %100 %31 + OpReturn + OpFunctionEnd + )"; + + DebugScope scope(22U, 0U); + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + DebugInfoManager manager(context.get()); + + uint32_t inlined_at_id = manager.CreateDebugInlinedAt(nullptr, scope); + auto* inlined_at = manager.GetDebugInlinedAt(inlined_at_id); + EXPECT_NE(inlined_at, nullptr); + EXPECT_EQ(inlined_at->GetSingleWordOperand(kDebugInlinedAtOperandLineIndex), + 1); + EXPECT_EQ(inlined_at->GetSingleWordOperand(kDebugInlinedAtOperandScopeIndex), + 22); + EXPECT_EQ(inlined_at->NumOperands(), kDebugInlinedAtOperandScopeIndex + 1); + + const uint32_t line_number = 77U; + Instruction line(context.get(), SpvOpLine); + line.SetInOperands({ + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {5U}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {line_number}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {0U}}, + }); + + inlined_at_id = manager.CreateDebugInlinedAt(&line, scope); + inlined_at = manager.GetDebugInlinedAt(inlined_at_id); + EXPECT_NE(inlined_at, nullptr); + EXPECT_EQ(inlined_at->GetSingleWordOperand(kDebugInlinedAtOperandLineIndex), + line_number); + EXPECT_EQ(inlined_at->GetSingleWordOperand(kDebugInlinedAtOperandScopeIndex), + 22); + EXPECT_EQ(inlined_at->NumOperands(), kDebugInlinedAtOperandScopeIndex + 1); + + scope.SetInlinedAt(100U); + inlined_at_id = manager.CreateDebugInlinedAt(&line, scope); + inlined_at = manager.GetDebugInlinedAt(inlined_at_id); + EXPECT_NE(inlined_at, nullptr); + EXPECT_EQ(inlined_at->GetSingleWordOperand(kDebugInlinedAtOperandLineIndex), + line_number); + EXPECT_EQ(inlined_at->GetSingleWordOperand(kDebugInlinedAtOperandScopeIndex), + 22); + EXPECT_EQ(inlined_at->NumOperands(), kDebugInlinedAtOperandInlinedIndex + 1); + EXPECT_EQ( + inlined_at->GetSingleWordOperand(kDebugInlinedAtOperandInlinedIndex), + 100U); +} + +TEST(DebugInfoManager, GetDebugInfoNone) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "OpenCL.DebugInfo.100" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %in_var_COLOR + OpExecutionMode %main OriginUpperLeft + %5 = OpString "ps.hlsl" + %14 = OpString "#line 1 \"ps.hlsl\" +void main(float in_var_color : COLOR) { + float color = in_var_color; +} +" + %17 = OpString "float" + %21 = OpString "main" + %24 = OpString "color" + OpName %in_var_COLOR "in.var.COLOR" + OpName %main "main" + OpDecorate %in_var_COLOR Location 0 + %uint = OpTypeInt 32 0 + %uint_32 = OpConstant %uint 32 + %float = OpTypeFloat 32 +%_ptr_Input_float = OpTypePointer Input %float + %void = OpTypeVoid + %27 = OpTypeFunction %void +%_ptr_Function_float = OpTypePointer Function %float +%in_var_COLOR = OpVariable %_ptr_Input_float Input + %13 = OpExtInst %void %1 DebugExpression + %15 = OpExtInst %void %1 DebugSource %5 %14 + %16 = OpExtInst %void %1 DebugCompilationUnit 1 4 %15 HLSL + %18 = OpExtInst %void %1 DebugTypeBasic %17 %uint_32 Float + %20 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %18 %18 + %22 = OpExtInst %void %1 DebugFunction %21 %20 %15 1 1 %16 %21 FlagIsProtected|FlagIsPrivate 1 %main + %12 = OpExtInst %void %1 DebugInfoNone + %25 = OpExtInst %void %1 DebugLocalVariable %24 %18 %15 1 20 %22 FlagIsLocal 0 + %main = OpFunction %void None %27 + %28 = OpLabel + %100 = OpVariable %_ptr_Function_float Function + %31 = OpLoad %float %in_var_COLOR + OpStore %100 %31 + %36 = OpExtInst %void %1 DebugDeclare %25 %100 %13 + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + DebugInfoManager manager(context.get()); + + Instruction* debug_info_none_inst = manager.GetDebugInfoNone(); + EXPECT_NE(debug_info_none_inst, nullptr); + EXPECT_EQ(debug_info_none_inst->GetOpenCL100DebugOpcode(), + OpenCLDebugInfo100DebugInfoNone); + EXPECT_EQ(debug_info_none_inst->PreviousNode(), nullptr); +} + +TEST(DebugInfoManager, CreateDebugInfoNone) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "OpenCL.DebugInfo.100" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %in_var_COLOR + OpExecutionMode %main OriginUpperLeft + %5 = OpString "ps.hlsl" + %14 = OpString "#line 1 \"ps.hlsl\" +void main(float in_var_color : COLOR) { + float color = in_var_color; +} +" + %17 = OpString "float" + %21 = OpString "main" + %24 = OpString "color" + OpName %in_var_COLOR "in.var.COLOR" + OpName %main "main" + OpDecorate %in_var_COLOR Location 0 + %uint = OpTypeInt 32 0 + %uint_32 = OpConstant %uint 32 + %float = OpTypeFloat 32 +%_ptr_Input_float = OpTypePointer Input %float + %void = OpTypeVoid + %27 = OpTypeFunction %void +%_ptr_Function_float = OpTypePointer Function %float +%in_var_COLOR = OpVariable %_ptr_Input_float Input + %13 = OpExtInst %void %1 DebugExpression + %15 = OpExtInst %void %1 DebugSource %5 %14 + %16 = OpExtInst %void %1 DebugCompilationUnit 1 4 %15 HLSL + %18 = OpExtInst %void %1 DebugTypeBasic %17 %uint_32 Float + %20 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %18 %18 + %22 = OpExtInst %void %1 DebugFunction %21 %20 %15 1 1 %16 %21 FlagIsProtected|FlagIsPrivate 1 %main + %25 = OpExtInst %void %1 DebugLocalVariable %24 %18 %15 1 20 %22 FlagIsLocal 0 + %main = OpFunction %void None %27 + %28 = OpLabel + %100 = OpVariable %_ptr_Function_float Function + %31 = OpLoad %float %in_var_COLOR + OpStore %100 %31 + %36 = OpExtInst %void %1 DebugDeclare %25 %100 %13 + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + DebugInfoManager manager(context.get()); + + Instruction* debug_info_none_inst = manager.GetDebugInfoNone(); + EXPECT_NE(debug_info_none_inst, nullptr); + EXPECT_EQ(debug_info_none_inst->GetOpenCL100DebugOpcode(), + OpenCLDebugInfo100DebugInfoNone); + EXPECT_EQ(debug_info_none_inst->PreviousNode(), nullptr); +} + +TEST(DebugInfoManager, GetDebugFunction) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "OpenCL.DebugInfo.100" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %200 "200" %in_var_COLOR + OpExecutionMode %200 OriginUpperLeft + %5 = OpString "ps.hlsl" + %14 = OpString "#line 1 \"ps.hlsl\" +void 200(float in_var_color : COLOR) { + float color = in_var_color; +} +" + %17 = OpString "float" + %21 = OpString "200" + %24 = OpString "color" + OpName %in_var_COLOR "in.var.COLOR" + OpName %200 "200" + OpDecorate %in_var_COLOR Location 0 + %uint = OpTypeInt 32 0 + %uint_32 = OpConstant %uint 32 + %float = OpTypeFloat 32 +%_ptr_Input_float = OpTypePointer Input %float + %void = OpTypeVoid + %27 = OpTypeFunction %void +%_ptr_Function_float = OpTypePointer Function %float +%in_var_COLOR = OpVariable %_ptr_Input_float Input + %13 = OpExtInst %void %1 DebugExpression + %15 = OpExtInst %void %1 DebugSource %5 %14 + %16 = OpExtInst %void %1 DebugCompilationUnit 1 4 %15 HLSL + %18 = OpExtInst %void %1 DebugTypeBasic %17 %uint_32 Float + %20 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %18 %18 + %22 = OpExtInst %void %1 DebugFunction %21 %20 %15 1 1 %16 %21 FlagIsProtected|FlagIsPrivate 1 %200 + %25 = OpExtInst %void %1 DebugLocalVariable %24 %18 %15 1 20 %22 FlagIsLocal 0 + %200 = OpFunction %void None %27 + %28 = OpLabel + %100 = OpVariable %_ptr_Function_float Function + %31 = OpLoad %float %in_var_COLOR + OpStore %100 %31 + %36 = OpExtInst %void %1 DebugDeclare %25 %100 %13 + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + DebugInfoManager manager(context.get()); + + EXPECT_EQ(manager.GetDebugFunction(100), nullptr); + EXPECT_EQ(manager.GetDebugFunction(150), nullptr); + + Instruction* dbg_fn = manager.GetDebugFunction(200); + + EXPECT_EQ(dbg_fn->GetOpenCL100DebugOpcode(), OpenCLDebugInfo100DebugFunction); + EXPECT_EQ(dbg_fn->GetSingleWordOperand(kDebugFunctionOperandFunctionIndex), + 200); +} + +TEST(DebugInfoManager, CloneDebugInlinedAt) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "OpenCL.DebugInfo.100" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %in_var_COLOR + OpExecutionMode %main OriginUpperLeft + %5 = OpString "ps.hlsl" + %14 = OpString "#line 1 \"ps.hlsl\" +void main(float in_var_color : COLOR) { + float color = in_var_color; +} +" + %17 = OpString "float" + %21 = OpString "main" + %24 = OpString "color" + OpName %in_var_COLOR "in.var.COLOR" + OpName %main "main" + OpDecorate %in_var_COLOR Location 0 + %uint = OpTypeInt 32 0 + %uint_32 = OpConstant %uint 32 + %float = OpTypeFloat 32 +%_ptr_Input_float = OpTypePointer Input %float + %void = OpTypeVoid + %27 = OpTypeFunction %void +%_ptr_Function_float = OpTypePointer Function %float +%in_var_COLOR = OpVariable %_ptr_Input_float Input + %13 = OpExtInst %void %1 DebugExpression + %15 = OpExtInst %void %1 DebugSource %5 %14 + %16 = OpExtInst %void %1 DebugCompilationUnit 1 4 %15 HLSL + %18 = OpExtInst %void %1 DebugTypeBasic %17 %uint_32 Float + %20 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %18 %18 + %22 = OpExtInst %void %1 DebugFunction %21 %20 %15 1 1 %16 %21 FlagIsProtected|FlagIsPrivate 1 %main + %100 = OpExtInst %void %1 DebugInlinedAt 7 %22 + %main = OpFunction %void None %27 + %28 = OpLabel + %31 = OpLoad %float %in_var_COLOR + OpStore %100 %31 + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + DebugInfoManager manager(context.get()); + + EXPECT_EQ(manager.CloneDebugInlinedAt(150), nullptr); + EXPECT_EQ(manager.CloneDebugInlinedAt(22), nullptr); + + auto* inst = manager.CloneDebugInlinedAt(100); + EXPECT_EQ(inst->GetSingleWordOperand(kDebugInlinedAtOperandLineIndex), 7); + EXPECT_EQ(inst->GetSingleWordOperand(kDebugInlinedAtOperandScopeIndex), 22); + EXPECT_EQ(inst->NumOperands(), kDebugInlinedAtOperandScopeIndex + 1); + + Instruction* before_100 = nullptr; + for (auto it = context->module()->ext_inst_debuginfo_begin(); + it != context->module()->ext_inst_debuginfo_end(); ++it) { + if (it->result_id() == 100) break; + before_100 = &*it; + } + EXPECT_NE(inst, before_100); + + inst = manager.CloneDebugInlinedAt(100, manager.GetDebugInlinedAt(100)); + EXPECT_EQ(inst->GetSingleWordOperand(kDebugInlinedAtOperandLineIndex), 7); + EXPECT_EQ(inst->GetSingleWordOperand(kDebugInlinedAtOperandScopeIndex), 22); + EXPECT_EQ(inst->NumOperands(), kDebugInlinedAtOperandScopeIndex + 1); + + before_100 = nullptr; + for (auto it = context->module()->ext_inst_debuginfo_begin(); + it != context->module()->ext_inst_debuginfo_end(); ++it) { + if (it->result_id() == 100) break; + before_100 = &*it; + } + EXPECT_EQ(inst, before_100); +} + +} // namespace +} // namespace analysis +} // namespace opt +} // namespace spvtools
diff --git a/test/opt/eliminate_dead_functions_test.cpp b/test/opt/eliminate_dead_functions_test.cpp index 0a3d490..2f8fa9a 100644 --- a/test/opt/eliminate_dead_functions_test.cpp +++ b/test/opt/eliminate_dead_functions_test.cpp
@@ -204,6 +204,146 @@ /* skip_nop = */ true); } +TEST_F(EliminateDeadFunctionsBasicTest, DebugRemoveFunctionFromDebugFunction) { + // We want to remove id of OpFunction from DebugFunction. + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "OpenCL.DebugInfo.100" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 %4 +OpExecutionMode %2 OriginUpperLeft +%5 = OpString "ps.hlsl" +OpSource HLSL 600 %5 "float4 foo() { + return 1; +} +float4 main(float4 color : COLOR) : SV_TARGET { + return foo() + color; +} +" +%6 = OpString "float" +%7 = OpString "main" +%8 = OpString "foo" +; CHECK: [[foo:%\d+]] = OpString "foo" +OpDecorate %3 Location 0 +OpDecorate %4 Location 0 +%uint = OpTypeInt 32 0 +%uint_32 = OpConstant %uint 32 +%float = OpTypeFloat 32 +%float_1 = OpConstant %float 1 +%v4float = OpTypeVector %float 4 +%14 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%void = OpTypeVoid +%18 = OpTypeFunction %void +%19 = OpTypeFunction %v4float +%3 = OpVariable %_ptr_Input_v4float Input +%4 = OpVariable %_ptr_Output_v4float Output +%_ptr_Function_v4float = OpTypePointer Function %v4float +; CHECK: [[info_none:%\d+]] = OpExtInst %void %1 DebugInfoNone +%20 = OpExtInst %void %1 DebugSource %5 +%21 = OpExtInst %void %1 DebugCompilationUnit 1 4 %20 HLSL +%22 = OpExtInst %void %1 DebugTypeBasic %6 %uint_32 Float +%23 = OpExtInst %void %1 DebugTypeVector %22 4 +%24 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %23 %23 +%25 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %23 +%26 = OpExtInst %void %1 DebugFunction %7 %24 %20 4 1 %21 %7 FlagIsProtected|FlagIsPrivate 4 %2 +%27 = OpExtInst %void %1 DebugFunction %8 %25 %20 1 1 %21 %8 FlagIsProtected|FlagIsPrivate 1 %28 +; CHECK: {{%\d+}} = OpExtInst %void %1 DebugFunction [[foo]] {{%\d+}} {{%\d+}} 1 1 {{%\d+}} {{%\d+}} FlagIsProtected|FlagIsPrivate 1 [[info_none]] +%29 = OpExtInst %void %1 DebugLexicalBlock %20 1 14 %27 +%40 = OpExtInst %void %1 DebugInlinedAt 4 %26 +%2 = OpFunction %void None %18 +%30 = OpLabel +%39 = OpVariable %_ptr_Function_v4float Function +%41 = OpExtInst %void %1 DebugScope %27 %40 +OpStore %39 %14 +%32 = OpLoad %v4float %39 +%42 = OpExtInst %void %1 DebugScope %26 +%33 = OpLoad %v4float %3 +%34 = OpFAdd %v4float %32 %33 +OpStore %4 %34 +%43 = OpExtInst %void %1 DebugNoScope +OpReturn +OpFunctionEnd +%28 = OpFunction %v4float None %19 +%36 = OpLabel +OpReturnValue %14 +OpFunctionEnd +)"; + + SinglePassRunAndMatch<EliminateDeadFunctionsPass>(text, false); +} + +TEST_F(EliminateDeadFunctionsBasicTest, + DebugRemoveFunctionUsingExistingDebugInfoNone) { + // We want to remove id of OpFunction from DebugFunction. + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "OpenCL.DebugInfo.100" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 %4 +OpExecutionMode %2 OriginUpperLeft +%5 = OpString "ps.hlsl" +OpSource HLSL 600 %5 "float4 foo() { + return 1; +} +float4 main(float4 color : COLOR) : SV_TARGET { + return foo() + color; +} +" +%6 = OpString "float" +%7 = OpString "main" +%8 = OpString "foo" +; CHECK: [[foo:%\d+]] = OpString "foo" +OpDecorate %3 Location 0 +OpDecorate %4 Location 0 +%uint = OpTypeInt 32 0 +%uint_32 = OpConstant %uint 32 +%float = OpTypeFloat 32 +%float_1 = OpConstant %float 1 +%v4float = OpTypeVector %float 4 +%14 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%void = OpTypeVoid +%18 = OpTypeFunction %void +%19 = OpTypeFunction %v4float +%3 = OpVariable %_ptr_Input_v4float Input +%4 = OpVariable %_ptr_Output_v4float Output +%_ptr_Function_v4float = OpTypePointer Function %v4float +; CHECK: [[info_none:%\d+]] = OpExtInst %void %1 DebugInfoNone +%20 = OpExtInst %void %1 DebugSource %5 +%21 = OpExtInst %void %1 DebugCompilationUnit 1 4 %20 HLSL +%22 = OpExtInst %void %1 DebugTypeBasic %6 %uint_32 Float +%23 = OpExtInst %void %1 DebugTypeVector %22 4 +%24 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %23 %23 +%25 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %23 +%26 = OpExtInst %void %1 DebugFunction %7 %24 %20 4 1 %21 %7 FlagIsProtected|FlagIsPrivate 4 %2 +%27 = OpExtInst %void %1 DebugFunction %8 %25 %20 1 1 %21 %8 FlagIsProtected|FlagIsPrivate 1 %28 +; CHECK: {{%\d+}} = OpExtInst %void %1 DebugFunction [[foo]] {{%\d+}} {{%\d+}} 1 1 {{%\d+}} {{%\d+}} FlagIsProtected|FlagIsPrivate 1 [[info_none]] +%29 = OpExtInst %void %1 DebugLexicalBlock %20 1 14 %27 +%35 = OpExtInst %void %1 DebugInfoNone +%40 = OpExtInst %void %1 DebugInlinedAt 4 %26 +%2 = OpFunction %void None %18 +%30 = OpLabel +%39 = OpVariable %_ptr_Function_v4float Function +%41 = OpExtInst %void %1 DebugScope %27 %40 +OpStore %39 %14 +%32 = OpLoad %v4float %39 +%42 = OpExtInst %void %1 DebugScope %26 +%33 = OpLoad %v4float %3 +%34 = OpFAdd %v4float %32 %33 +OpStore %4 %34 +%43 = OpExtInst %void %1 DebugNoScope +OpReturn +OpFunctionEnd +%28 = OpFunction %v4float None %19 +%36 = OpLabel +OpReturnValue %14 +OpFunctionEnd +)"; + + SinglePassRunAndMatch<EliminateDeadFunctionsPass>(text, false); +} + } // namespace } // namespace opt } // namespace spvtools
diff --git a/test/opt/eliminate_dead_member_test.cpp b/test/opt/eliminate_dead_member_test.cpp index b6925d7..a9b0f28 100644 --- a/test/opt/eliminate_dead_member_test.cpp +++ b/test/opt/eliminate_dead_member_test.cpp
@@ -1085,4 +1085,103 @@ EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); } +TEST_F(EliminateDeadMemberTest, UpdateSpecConstOpExtract) { + // Test that an extract in an OpSpecConstantOp is correctly updated. + const std::string text = R"( +; CHECK: OpName +; CHECK-NEXT: OpMemberName %type__Globals 0 "y" +; CHECK-NOT: OpMemberName +; CHECK: OpDecorate [[spec_const:%\w+]] SpecId 1 +; CHECK: OpMemberDecorate %type__Globals 0 Offset 4 +; CHECK: %type__Globals = OpTypeStruct %uint +; CHECK: [[struct:%\w+]] = OpSpecConstantComposite %type__Globals [[spec_const]] +; CHECK: OpSpecConstantOp %uint CompositeExtract [[struct]] 0 + OpCapability Shader + OpCapability Addresses + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource HLSL 600 + OpName %type__Globals "type.$Globals" + OpMemberName %type__Globals 0 "x" + OpMemberName %type__Globals 1 "y" + OpMemberName %type__Globals 2 "z" + OpName %main "main" + OpDecorate %c_0 SpecId 0 + OpDecorate %c_1 SpecId 1 + OpDecorate %c_2 SpecId 2 + OpMemberDecorate %type__Globals 0 Offset 0 + OpMemberDecorate %type__Globals 1 Offset 4 + OpMemberDecorate %type__Globals 2 Offset 16 + %uint = OpTypeInt 32 0 + %c_0 = OpSpecConstant %uint 0 + %c_1 = OpSpecConstant %uint 1 + %c_2 = OpSpecConstant %uint 2 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %uint_2 = OpConstant %uint 2 + %uint_3 = OpConstant %uint 3 +%type__Globals = OpTypeStruct %uint %uint %uint +%spec_const_global = OpSpecConstantComposite %type__Globals %c_0 %c_1 %c_2 +%extract = OpSpecConstantOp %uint CompositeExtract %spec_const_global 1 + %void = OpTypeVoid + %14 = OpTypeFunction %void + %main = OpFunction %void None %14 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch<opt::EliminateDeadMembersPass>(text, true); +} + +TEST_F(EliminateDeadMemberTest, UpdateSpecConstOpInsert) { + // Test that an insert in an OpSpecConstantOp is correctly updated. + const std::string text = R"( +; CHECK: OpName +; CHECK-NEXT: OpMemberName %type__Globals 0 "y" +; CHECK-NOT: OpMemberName +; CHECK: OpDecorate [[spec_const:%\w+]] SpecId 1 +; CHECK: OpMemberDecorate %type__Globals 0 Offset 4 +; CHECK: %type__Globals = OpTypeStruct %uint +; CHECK: [[struct:%\w+]] = OpSpecConstantComposite %type__Globals [[spec_const]] +; CHECK: OpSpecConstantOp %type__Globals CompositeInsert %uint_3 [[struct]] 0 + OpCapability Shader + OpCapability Addresses + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource HLSL 600 + OpName %type__Globals "type.$Globals" + OpMemberName %type__Globals 0 "x" + OpMemberName %type__Globals 1 "y" + OpMemberName %type__Globals 2 "z" + OpName %main "main" + OpDecorate %c_0 SpecId 0 + OpDecorate %c_1 SpecId 1 + OpDecorate %c_2 SpecId 2 + OpMemberDecorate %type__Globals 0 Offset 0 + OpMemberDecorate %type__Globals 1 Offset 4 + OpMemberDecorate %type__Globals 2 Offset 16 + %uint = OpTypeInt 32 0 + %c_0 = OpSpecConstant %uint 0 + %c_1 = OpSpecConstant %uint 1 + %c_2 = OpSpecConstant %uint 2 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %uint_2 = OpConstant %uint 2 + %uint_3 = OpConstant %uint 3 +%type__Globals = OpTypeStruct %uint %uint %uint +%spec_const_global = OpSpecConstantComposite %type__Globals %c_0 %c_1 %c_2 +%insert = OpSpecConstantOp %type__Globals CompositeInsert %uint_3 %spec_const_global 1 +%extract = OpSpecConstantOp %uint CompositeExtract %insert 1 + %void = OpTypeVoid + %14 = OpTypeFunction %void + %main = OpFunction %void None %14 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch<opt::EliminateDeadMembersPass>(text, true); +} + } // namespace
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index db01924..beb26f2 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp
@@ -3297,7 +3297,16 @@ "%2 = OpIMul %int %3 %int_1\n" + "OpReturn\n" + "OpFunctionEnd", - 2, 3) + 2, 3), + // Test case 42: Don't fold comparisons of 64-bit types + // (https://github.com/KhronosGroup/SPIRV-Tools/issues/3343). + InstructionFoldingCase<uint32_t>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpSLessThan %bool %long_0 %long_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0) )); INSTANTIATE_TEST_SUITE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTest,
diff --git a/test/opt/inline_opaque_test.cpp b/test/opt/inline_opaque_test.cpp index d10913a..b8d2dfa 100644 --- a/test/opt/inline_opaque_test.cpp +++ b/test/opt/inline_opaque_test.cpp
@@ -102,12 +102,12 @@ OpStore %32 %31 %33 = OpLoad %S_t %s0 OpStore %param %33 -%41 = OpAccessChain %_ptr_Function_18 %param %int_2 -%42 = OpLoad %18 %41 -%43 = OpAccessChain %_ptr_Function_v2float %param %int_0 -%44 = OpLoad %v2float %43 -%45 = OpImageSampleImplicitLod %v4float %42 %44 -OpStore %outColor %45 +%42 = OpAccessChain %_ptr_Function_18 %param %int_2 +%43 = OpLoad %18 %42 +%44 = OpAccessChain %_ptr_Function_v2float %param %int_0 +%45 = OpLoad %v2float %44 +%46 = OpImageSampleImplicitLod %v4float %43 %45 +OpStore %outColor %46 OpReturn OpFunctionEnd )"; @@ -191,10 +191,10 @@ %34 = OpVariable %_ptr_Function_20 Function %35 = OpVariable %_ptr_Function_20 Function %25 = OpVariable %_ptr_Function_20 Function -%36 = OpLoad %20 %sampler16 -OpStore %34 %36 -%37 = OpLoad %20 %34 -OpStore %35 %37 +%37 = OpLoad %20 %sampler16 +OpStore %34 %37 +%38 = OpLoad %20 %34 +OpStore %35 %38 %26 = OpLoad %20 %35 OpStore %25 %26 %27 = OpLoad %20 %25 @@ -301,12 +301,12 @@ OpStore %33 %32 %34 = OpLoad %S_t %s0 OpStore %param %34 -%44 = OpAccessChain %_ptr_Function_19 %param %int_2 -%45 = OpLoad %19 %44 -%46 = OpAccessChain %_ptr_Function_v2float %param %int_0 -%47 = OpLoad %v2float %46 -%48 = OpImageSampleImplicitLod %v4float %45 %47 -OpStore %outColor %48 +%45 = OpAccessChain %_ptr_Function_19 %param %int_2 +%46 = OpLoad %19 %45 +%47 = OpAccessChain %_ptr_Function_v2float %param %int_0 +%48 = OpLoad %v2float %47 +%49 = OpImageSampleImplicitLod %v4float %46 %48 +OpStore %outColor %49 OpReturn OpFunctionEnd )";
diff --git a/test/opt/inline_test.cpp b/test/opt/inline_test.cpp index f44c04a..fc2197c 100644 --- a/test/opt/inline_test.cpp +++ b/test/opt/inline_test.cpp
@@ -115,12 +115,12 @@ "%param = OpVariable %_ptr_Function_v4float Function", "%22 = OpLoad %v4float %BaseColor", "OpStore %param %22", - "%33 = OpAccessChain %_ptr_Function_float %param %uint_0", - "%34 = OpLoad %float %33", - "%35 = OpAccessChain %_ptr_Function_float %param %uint_1", - "%36 = OpLoad %float %35", - "%37 = OpFAdd %float %34 %36", - "OpStore %32 %37", + "%34 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%35 = OpLoad %float %34", + "%36 = OpAccessChain %_ptr_Function_float %param %uint_1", + "%37 = OpLoad %float %36", + "%38 = OpFAdd %float %35 %37", + "OpStore %32 %38", "%23 = OpLoad %float %32", "%24 = OpCompositeConstruct %v4float %23 %23 %23 %23", "OpStore %color %24", @@ -248,7 +248,7 @@ // clang-format off "%main = OpFunction %void None %15", "%28 = OpLabel", - "%57 = OpVariable %_ptr_Function_float Function", + "%58 = OpVariable %_ptr_Function_float Function", "%46 = OpVariable %_ptr_Function_float Function", "%47 = OpVariable %_ptr_Function_float Function", "%48 = OpVariable %_ptr_Function_float Function", @@ -256,21 +256,21 @@ "%param_1 = OpVariable %_ptr_Function_v4float Function", "%29 = OpLoad %v4float %BaseColor", "OpStore %param_1 %29", - "%49 = OpAccessChain %_ptr_Function_float %param_1 %uint_0", - "%50 = OpLoad %float %49", - "%51 = OpAccessChain %_ptr_Function_float %param_1 %uint_1", - "%52 = OpLoad %float %51", - "%53 = OpFAdd %float %50 %52", - "OpStore %46 %53", - "%54 = OpAccessChain %_ptr_Function_float %param_1 %uint_2", - "%55 = OpLoad %float %54", - "OpStore %47 %55", - "%58 = OpLoad %float %46", - "%59 = OpLoad %float %47", - "%60 = OpFMul %float %58 %59", - "OpStore %57 %60", - "%56 = OpLoad %float %57", - "OpStore %48 %56", + "%50 = OpAccessChain %_ptr_Function_float %param_1 %uint_0", + "%51 = OpLoad %float %50", + "%52 = OpAccessChain %_ptr_Function_float %param_1 %uint_1", + "%53 = OpLoad %float %52", + "%54 = OpFAdd %float %51 %53", + "OpStore %46 %54", + "%55 = OpAccessChain %_ptr_Function_float %param_1 %uint_2", + "%56 = OpLoad %float %55", + "OpStore %47 %56", + "%60 = OpLoad %float %46", + "%61 = OpLoad %float %47", + "%62 = OpFMul %float %60 %61", + "OpStore %58 %62", + "%57 = OpLoad %float %58", + "OpStore %48 %57", "%30 = OpLoad %float %48", "%31 = OpCompositeConstruct %v4float %30 %30 %30 %30", "OpStore %color %31", @@ -390,13 +390,13 @@ "OpStore %b %24", "%25 = OpLoad %v4float %b", "OpStore %param %25", - "%39 = OpAccessChain %_ptr_Function_float %param %uint_0", - "%40 = OpLoad %float %39", - "%41 = OpAccessChain %_ptr_Function_float %param %uint_1", - "%42 = OpLoad %float %41", - "%43 = OpFAdd %float %40 %42", - "%44 = OpAccessChain %_ptr_Function_float %param %uint_2", - "OpStore %44 %43", + "%40 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%41 = OpLoad %float %40", + "%42 = OpAccessChain %_ptr_Function_float %param %uint_1", + "%43 = OpLoad %float %42", + "%44 = OpFAdd %float %41 %43", + "%45 = OpAccessChain %_ptr_Function_float %param %uint_2", + "OpStore %45 %44", "%27 = OpLoad %v4float %param", "OpStore %b %27", "%28 = OpAccessChain %_ptr_Function_float %b %uint_2", @@ -521,21 +521,21 @@ "%param = OpVariable %_ptr_Function_v4float Function", "%24 = OpLoad %v4float %BaseColor", "OpStore %param %24", - "%40 = OpAccessChain %_ptr_Function_float %param %uint_0", - "%41 = OpLoad %float %40", - "OpStore %38 %41", - "%42 = OpLoad %float %38", - "%43 = OpFOrdLessThan %bool %42 %float_0", - "OpSelectionMerge %44 None", - "OpBranchConditional %43 %45 %44", + "%41 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%42 = OpLoad %float %41", + "OpStore %38 %42", + "%43 = OpLoad %float %38", + "%44 = OpFOrdLessThan %bool %43 %float_0", + "OpSelectionMerge %48 None", + "OpBranchConditional %44 %45 %48", "%45 = OpLabel", "%46 = OpLoad %float %38", "%47 = OpFNegate %float %46", "OpStore %38 %47", - "OpBranch %44", - "%44 = OpLabel", - "%48 = OpLoad %float %38", - "OpStore %39 %48", + "OpBranch %48", + "%48 = OpLabel", + "%49 = OpLoad %float %38", + "OpStore %39 %49", "%25 = OpLoad %float %39", "%26 = OpCompositeConstruct %v4float %25 %25 %25 %25", "OpStore %color %26", @@ -675,8 +675,8 @@ // clang-format off "%main = OpFunction %void None %12", "%27 = OpLabel", - "%62 = OpVariable %_ptr_Function_float Function", "%63 = OpVariable %_ptr_Function_float Function", + "%64 = OpVariable %_ptr_Function_float Function", "%52 = OpVariable %_ptr_Function_float Function", "%53 = OpVariable %_ptr_Function_float Function", "%color = OpVariable %_ptr_Function_v4float Function", @@ -687,20 +687,20 @@ "%29 = OpAccessChain %_ptr_Function_float %color %uint_0", "%30 = OpLoad %float %29", "OpStore %param %30", - "%54 = OpLoad %float %param", - "OpStore %52 %54", - "%55 = OpLoad %float %52", - "%56 = OpFOrdLessThan %bool %55 %float_0", - "OpSelectionMerge %57 None", - "OpBranchConditional %56 %58 %57", + "%55 = OpLoad %float %param", + "OpStore %52 %55", + "%56 = OpLoad %float %52", + "%57 = OpFOrdLessThan %bool %56 %float_0", + "OpSelectionMerge %61 None", + "OpBranchConditional %57 %58 %61", "%58 = OpLabel", "%59 = OpLoad %float %52", "%60 = OpFNegate %float %59", "OpStore %52 %60", - "OpBranch %57", - "%57 = OpLabel", - "%61 = OpLoad %float %52", - "OpStore %53 %61", + "OpBranch %61", + "%61 = OpLabel", + "%62 = OpLoad %float %52", + "OpStore %53 %62", "%31 = OpLoad %float %53", "%32 = OpFOrdGreaterThan %bool %31 %float_2", "OpSelectionMerge %33 None", @@ -709,25 +709,25 @@ "%35 = OpAccessChain %_ptr_Function_float %color %uint_1", "%36 = OpLoad %float %35", "OpStore %param_0 %36", - "%64 = OpLoad %float %param_0", - "OpStore %62 %64", - "%65 = OpLoad %float %62", - "%66 = OpFOrdLessThan %bool %65 %float_0", - "OpSelectionMerge %67 None", - "OpBranchConditional %66 %68 %67", - "%68 = OpLabel", - "%69 = OpLoad %float %62", - "%70 = OpFNegate %float %69", - "OpStore %62 %70", - "OpBranch %67", - "%67 = OpLabel", - "%71 = OpLoad %float %62", + "%66 = OpLoad %float %param_0", + "OpStore %63 %66", + "%67 = OpLoad %float %63", + "%68 = OpFOrdLessThan %bool %67 %float_0", + "OpSelectionMerge %72 None", + "OpBranchConditional %68 %69 %72", + "%69 = OpLabel", + "%70 = OpLoad %float %63", + "%71 = OpFNegate %float %70", "OpStore %63 %71", - "%37 = OpLoad %float %63", + "OpBranch %72", + "%72 = OpLabel", + "%73 = OpLoad %float %63", + "OpStore %64 %73", + "%37 = OpLoad %float %64", "%38 = OpFOrdGreaterThan %bool %37 %float_2", "OpBranch %33", "%33 = OpLabel", - "%39 = OpPhi %bool %32 %57 %38 %67", + "%39 = OpPhi %bool %32 %61 %38 %72", "OpSelectionMerge %40 None", "OpBranchConditional %39 %41 %40", "%41 = OpLabel", @@ -902,28 +902,28 @@ "OpStore %color1 %42", "%43 = OpLoad %v4float %BaseColor", "OpStore %param %43", - "%68 = OpAccessChain %_ptr_Function_float %param %uint_0", - "%69 = OpLoad %float %68", - "OpStore %66 %69", - "%70 = OpLoad %float %66", - "%71 = OpFOrdLessThan %bool %70 %float_0", - "OpSelectionMerge %72 None", - "OpBranchConditional %71 %73 %72", + "%69 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%70 = OpLoad %float %69", + "OpStore %66 %70", + "%71 = OpLoad %float %66", + "%72 = OpFOrdLessThan %bool %71 %float_0", + "OpSelectionMerge %76 None", + "OpBranchConditional %72 %73 %76", "%73 = OpLabel", "%74 = OpLoad %float %66", "%75 = OpFNegate %float %74", "OpStore %66 %75", - "OpBranch %72", - "%72 = OpLabel", - "%76 = OpLoad %float %66", - "OpStore %67 %76", + "OpBranch %76", + "%76 = OpLabel", + "%77 = OpLoad %float %66", + "OpStore %67 %77", "%44 = OpLoad %float %67", "%45 = OpCompositeConstruct %v4float %44 %44 %44 %44", "OpStore %color2 %45", "%46 = OpLoad %25 %t2D", "%47 = OpLoad %27 %samp", - "%77 = OpSampledImage %29 %39 %40", - "%48 = OpImageSampleImplicitLod %v4float %77 %35", + "%78 = OpSampledImage %29 %39 %40", + "%48 = OpImageSampleImplicitLod %v4float %78 %35", "OpStore %color3 %48", "%49 = OpLoad %v4float %color1", "%50 = OpLoad %v4float %color2", @@ -1108,27 +1108,27 @@ "OpStore %color1 %43", "%46 = OpLoad %v4float %BaseColor", "OpStore %param %46", - "%70 = OpAccessChain %_ptr_Function_float %param %uint_0", - "%71 = OpLoad %float %70", - "OpStore %68 %71", - "%72 = OpLoad %float %68", - "%73 = OpFOrdLessThan %bool %72 %float_0", - "OpSelectionMerge %74 None", - "OpBranchConditional %73 %75 %74", + "%71 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%72 = OpLoad %float %71", + "OpStore %68 %72", + "%73 = OpLoad %float %68", + "%74 = OpFOrdLessThan %bool %73 %float_0", + "OpSelectionMerge %78 None", + "OpBranchConditional %74 %75 %78", "%75 = OpLabel", "%76 = OpLoad %float %68", "%77 = OpFNegate %float %76", "OpStore %68 %77", - "OpBranch %74", - "%74 = OpLabel", - "%78 = OpLoad %float %68", - "OpStore %69 %78", + "OpBranch %78", + "%78 = OpLabel", + "%79 = OpLoad %float %68", + "OpStore %69 %79", "%47 = OpLoad %float %69", "%48 = OpCompositeConstruct %v4float %47 %47 %47 %47", "OpStore %color2 %48", - "%79 = OpSampledImage %30 %40 %41", - "%80 = OpImage %26 %79", - "%49 = OpSampledImage %30 %80 %45", + "%80 = OpSampledImage %30 %40 %41", + "%81 = OpImage %26 %80", + "%49 = OpSampledImage %30 %81 %45", "%50 = OpImageSampleImplicitLod %v4float %49 %36", "OpStore %color3 %50", "%51 = OpLoad %v4float %color1", @@ -1314,28 +1314,28 @@ "OpStore %color1 %43", "%47 = OpLoad %v4float %BaseColor", "OpStore %param %47", - "%70 = OpAccessChain %_ptr_Function_float %param %uint_0", - "%71 = OpLoad %float %70", - "OpStore %68 %71", - "%72 = OpLoad %float %68", - "%73 = OpFOrdLessThan %bool %72 %float_0", - "OpSelectionMerge %74 None", - "OpBranchConditional %73 %75 %74", + "%71 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%72 = OpLoad %float %71", + "OpStore %68 %72", + "%73 = OpLoad %float %68", + "%74 = OpFOrdLessThan %bool %73 %float_0", + "OpSelectionMerge %78 None", + "OpBranchConditional %74 %75 %78", "%75 = OpLabel", "%76 = OpLoad %float %68", "%77 = OpFNegate %float %76", "OpStore %68 %77", - "OpBranch %74", - "%74 = OpLabel", - "%78 = OpLoad %float %68", - "OpStore %69 %78", + "OpBranch %78", + "%78 = OpLabel", + "%79 = OpLoad %float %68", + "OpStore %69 %79", "%48 = OpLoad %float %69", "%49 = OpCompositeConstruct %v4float %48 %48 %48 %48", "OpStore %color2 %49", - "%79 = OpSampledImage %30 %40 %41", - "%80 = OpImage %26 %79", - "%81 = OpSampledImage %30 %80 %45", - "%50 = OpImageSampleImplicitLod %v4float %81 %36", + "%80 = OpSampledImage %30 %40 %41", + "%81 = OpImage %26 %80", + "%82 = OpSampledImage %30 %81 %45", + "%50 = OpImageSampleImplicitLod %v4float %82 %36", "OpStore %color3 %50", "%51 = OpLoad %v4float %color1", "%52 = OpLoad %v4float %color2", @@ -1355,292 +1355,6 @@ /* skip_nop = */ false, /* do_validate = */ true); } -TEST_F(InlineTest, EarlyReturnFunctionInlined) { - // #version 140 - // - // in vec4 BaseColor; - // - // float foo(vec4 bar) - // { - // if (bar.x < 0.0) - // return 0.0; - // return bar.x; - // } - // - // void main() - // { - // vec4 color = vec4(foo(BaseColor)); - // gl_FragColor = color; - // } - - const std::string predefs = - R"(OpCapability Shader -%1 = OpExtInstImport "GLSL.std.450" -OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor -OpExecutionMode %main OriginUpperLeft -OpSource GLSL 140 -OpName %main "main" -OpName %foo_vf4_ "foo(vf4;" -OpName %bar "bar" -OpName %color "color" -OpName %BaseColor "BaseColor" -OpName %param "param" -OpName %gl_FragColor "gl_FragColor" -%void = OpTypeVoid -%10 = OpTypeFunction %void -%float = OpTypeFloat 32 -%v4float = OpTypeVector %float 4 -%_ptr_Function_v4float = OpTypePointer Function %v4float -%14 = OpTypeFunction %float %_ptr_Function_v4float -%uint = OpTypeInt 32 0 -%uint_0 = OpConstant %uint 0 -%_ptr_Function_float = OpTypePointer Function %float -%float_0 = OpConstant %float 0 -%bool = OpTypeBool -%_ptr_Input_v4float = OpTypePointer Input %v4float -%BaseColor = OpVariable %_ptr_Input_v4float Input -%_ptr_Output_v4float = OpTypePointer Output %v4float -%gl_FragColor = OpVariable %_ptr_Output_v4float Output -)"; - - const std::string nonEntryFuncs = - R"(%foo_vf4_ = OpFunction %float None %14 -%bar = OpFunctionParameter %_ptr_Function_v4float -%27 = OpLabel -%28 = OpAccessChain %_ptr_Function_float %bar %uint_0 -%29 = OpLoad %float %28 -%30 = OpFOrdLessThan %bool %29 %float_0 -OpSelectionMerge %31 None -OpBranchConditional %30 %32 %31 -%32 = OpLabel -OpReturnValue %float_0 -%31 = OpLabel -%33 = OpAccessChain %_ptr_Function_float %bar %uint_0 -%34 = OpLoad %float %33 -OpReturnValue %34 -OpFunctionEnd -)"; - - const std::string before = - R"(%main = OpFunction %void None %10 -%22 = OpLabel -%color = OpVariable %_ptr_Function_v4float Function -%param = OpVariable %_ptr_Function_v4float Function -%23 = OpLoad %v4float %BaseColor -OpStore %param %23 -%24 = OpFunctionCall %float %foo_vf4_ %param -%25 = OpCompositeConstruct %v4float %24 %24 %24 %24 -OpStore %color %25 -%26 = OpLoad %v4float %color -OpStore %gl_FragColor %26 -OpReturn -OpFunctionEnd -)"; - - const std::string after = - R"(%false = OpConstantFalse %bool -%main = OpFunction %void None %10 -%22 = OpLabel -%35 = OpVariable %_ptr_Function_float Function -%color = OpVariable %_ptr_Function_v4float Function -%param = OpVariable %_ptr_Function_v4float Function -%23 = OpLoad %v4float %BaseColor -OpStore %param %23 -OpBranch %36 -%36 = OpLabel -OpLoopMerge %37 %38 None -OpBranch %39 -%39 = OpLabel -%40 = OpAccessChain %_ptr_Function_float %param %uint_0 -%41 = OpLoad %float %40 -%42 = OpFOrdLessThan %bool %41 %float_0 -OpSelectionMerge %43 None -OpBranchConditional %42 %44 %43 -%44 = OpLabel -OpStore %35 %float_0 -OpBranch %37 -%43 = OpLabel -%45 = OpAccessChain %_ptr_Function_float %param %uint_0 -%46 = OpLoad %float %45 -OpStore %35 %46 -OpBranch %37 -%38 = OpLabel -OpBranchConditional %false %36 %37 -%37 = OpLabel -%24 = OpLoad %float %35 -%25 = OpCompositeConstruct %v4float %24 %24 %24 %24 -OpStore %color %25 -%26 = OpLoad %v4float %color -OpStore %gl_FragColor %26 -OpReturn -OpFunctionEnd -)"; - - SinglePassRunAndCheck<InlineExhaustivePass>(predefs + before + nonEntryFuncs, - predefs + after + nonEntryFuncs, - false, true); -} - -TEST_F(InlineTest, EarlyReturnNotAppearingLastInFunctionInlined) { - // Example from https://github.com/KhronosGroup/SPIRV-Tools/issues/755 - // - // Original example is derived from: - // - // #version 450 - // - // float foo() { - // if (true) { - // } - // } - // - // void main() { foo(); } - // - // But the order of basic blocks in foo is changed so that the return - // block is listed second-last. There is only one return in the callee - // but it does not appear last. - - const std::string predefs = - R"(OpCapability Shader -OpMemoryModel Logical GLSL450 -OpEntryPoint Vertex %main "main" -OpSource GLSL 450 -OpName %main "main" -OpName %foo_ "foo(" -%void = OpTypeVoid -%4 = OpTypeFunction %void -%bool = OpTypeBool -%true = OpConstantTrue %bool -)"; - - const std::string nonEntryFuncs = - R"(%foo_ = OpFunction %void None %4 -%7 = OpLabel -OpSelectionMerge %8 None -OpBranchConditional %true %9 %8 -%8 = OpLabel -OpReturn -%9 = OpLabel -OpBranch %8 -OpFunctionEnd -)"; - - const std::string before = - R"(%main = OpFunction %void None %4 -%10 = OpLabel -%11 = OpFunctionCall %void %foo_ -OpReturn -OpFunctionEnd -)"; - - const std::string after = - R"(%main = OpFunction %void None %4 -%10 = OpLabel -OpSelectionMerge %12 None -OpBranchConditional %true %13 %12 -%12 = OpLabel -OpBranch %14 -%13 = OpLabel -OpBranch %12 -%14 = OpLabel -OpReturn -OpFunctionEnd -)"; - - SinglePassRunAndCheck<InlineExhaustivePass>(predefs + nonEntryFuncs + before, - predefs + nonEntryFuncs + after, - false, true); -} - -TEST_F(InlineTest, ForwardReferencesInPhiInlined) { - // The basic structure of the test case is like this: - // - // int foo() { - // int result = 1; - // if (true) { - // result = 1; - // } - // return result; - // } - // - // void main() { - // int x = foo(); - // } - // - // but with modifications: Using Phi instead of load/store, and the - // return block in foo appears before the "then" block. - - const std::string predefs = - R"(OpCapability Shader -%1 = OpExtInstImport "GLSL.std.450" -OpMemoryModel Logical GLSL450 -OpEntryPoint Vertex %main "main" -OpSource GLSL 450 -OpName %main "main" -OpName %foo_ "foo(" -OpName %x "x" -%void = OpTypeVoid -%6 = OpTypeFunction %void -%int = OpTypeInt 32 1 -%8 = OpTypeFunction %int -%bool = OpTypeBool -%true = OpConstantTrue %bool -%int_0 = OpConstant %int 0 -%_ptr_Function_int = OpTypePointer Function %int -)"; - - const std::string nonEntryFuncs = - R"(%foo_ = OpFunction %int None %8 -%13 = OpLabel -%14 = OpCopyObject %int %int_0 -OpSelectionMerge %15 None -OpBranchConditional %true %16 %15 -%15 = OpLabel -%17 = OpPhi %int %14 %13 %18 %16 -OpReturnValue %17 -%16 = OpLabel -%18 = OpCopyObject %int %int_0 -OpBranch %15 -OpFunctionEnd -)"; - - const std::string before = - R"(%main = OpFunction %void None %6 -%19 = OpLabel -%x = OpVariable %_ptr_Function_int Function -%20 = OpFunctionCall %int %foo_ -OpStore %x %20 -OpReturn -OpFunctionEnd -)"; - - const std::string after = - R"(%main = OpFunction %void None %6 -%19 = OpLabel -%21 = OpVariable %_ptr_Function_int Function -%x = OpVariable %_ptr_Function_int Function -%22 = OpCopyObject %int %int_0 -OpSelectionMerge %23 None -OpBranchConditional %true %24 %23 -%23 = OpLabel -%26 = OpPhi %int %22 %19 %25 %24 -OpStore %21 %26 -OpBranch %27 -%24 = OpLabel -%25 = OpCopyObject %int %int_0 -OpBranch %23 -%27 = OpLabel -%20 = OpLoad %int %21 -OpStore %x %20 -OpReturn -OpFunctionEnd -)"; - - SinglePassRunAndCheck<InlineExhaustivePass>(predefs + nonEntryFuncs + before, - predefs + nonEntryFuncs + after, - false, true); -} - TEST_F(InlineTest, EarlyReturnInLoopIsNotInlined) { // #version 140 // @@ -1820,8 +1534,8 @@ OpBranch %10 %10 = OpLabel OpLoopMerge %12 %10 None -OpBranch %13 -%13 = OpLabel +OpBranch %14 +%14 = OpLabel OpBranchConditional %true %10 %12 %12 = OpLabel OpReturn @@ -1890,11 +1604,11 @@ OpBranch %18 %18 = OpLabel %19 = OpCopyObject %int %int_3 -%25 = OpCopyObject %int %int_1 +%26 = OpCopyObject %int %int_1 OpLoopMerge %22 %23 None -OpBranch %26 -%26 = OpLabel -%27 = OpCopyObject %int %int_2 +OpBranch %27 +%27 = OpLabel +%28 = OpCopyObject %int %int_2 %21 = OpCopyObject %int %int_4 OpBranchConditional %true %23 %22 %23 = OpLabel @@ -1983,11 +1697,11 @@ OpLoopMerge %16 %13 None OpBranch %17 %17 = OpLabel -%18 = OpCopyObject %bool %true -OpSelectionMerge %19 None -OpBranchConditional %true %19 %19 -%19 = OpLabel -%20 = OpPhi %bool %18 %17 +%19 = OpCopyObject %bool %true +OpSelectionMerge %20 None +OpBranchConditional %true %20 %20 +%20 = OpLabel +%21 = OpPhi %bool %19 %17 OpBranchConditional %true %13 %16 %16 = OpLabel OpReturn @@ -2060,11 +1774,11 @@ OpLoopMerge %22 %23 None OpBranch %25 %25 = OpLabel -%26 = OpCopyObject %int %int_1 -OpSelectionMerge %27 None -OpBranchConditional %true %27 %27 -%27 = OpLabel -%28 = OpCopyObject %int %int_2 +%27 = OpCopyObject %int %int_1 +OpSelectionMerge %28 None +OpBranchConditional %true %28 %28 +%28 = OpLabel +%29 = OpCopyObject %int %int_2 %21 = OpCopyObject %int %int_4 OpBranchConditional %true %23 %22 %23 = OpLabel @@ -2080,165 +1794,6 @@ false, true); } -TEST_F( - InlineTest, - SingleBlockLoopCallsMultiBlockCalleeHavingSelectionMergeAndMultiReturns) { - // This is similar to SingleBlockLoopCallsMultiBlockCalleeHavingSelectionMerge - // except that in addition to starting with a selection header, the - // callee also has multi returns. - // - // So now we have to accommodate: - // - The caller's OpLoopMerge (which must move to the first block) - // - The single-trip loop to wrap the multi returns, and - // - The callee's selection merge in its first block. - // Each of these must go into their own blocks. - - const std::string predefs = - R"(OpCapability Shader -OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %1 "main" -OpSource OpenCL_C 120 -%bool = OpTypeBool -%int = OpTypeInt 32 1 -%true = OpConstantTrue %bool -%false = OpConstantFalse %bool -%int_0 = OpConstant %int 0 -%int_1 = OpConstant %int 1 -%int_2 = OpConstant %int 2 -%int_3 = OpConstant %int 3 -%int_4 = OpConstant %int 4 -%void = OpTypeVoid -%12 = OpTypeFunction %void -)"; - - const std::string nonEntryFuncs = - R"(%13 = OpFunction %void None %12 -%14 = OpLabel -%15 = OpCopyObject %int %int_0 -OpReturn -%16 = OpLabel -%17 = OpCopyObject %int %int_1 -OpReturn -OpFunctionEnd -)"; - - const std::string before = - R"(%1 = OpFunction %void None %12 -%18 = OpLabel -OpBranch %19 -%19 = OpLabel -%20 = OpCopyObject %int %int_2 -%21 = OpFunctionCall %void %13 -%22 = OpCopyObject %int %int_3 -OpLoopMerge %23 %19 None -OpBranchConditional %true %19 %23 -%23 = OpLabel -%24 = OpCopyObject %int %int_4 -OpReturn -OpFunctionEnd -)"; - - const std::string after = - R"(%1 = OpFunction %void None %12 -%18 = OpLabel -OpBranch %19 -%19 = OpLabel -%20 = OpCopyObject %int %int_2 -%25 = OpCopyObject %int %int_0 -OpLoopMerge %23 %19 None -OpBranch %26 -%27 = OpLabel -%28 = OpCopyObject %int %int_1 -OpBranch %26 -%26 = OpLabel -%22 = OpCopyObject %int %int_3 -OpBranchConditional %true %19 %23 -%23 = OpLabel -%24 = OpCopyObject %int %int_4 -OpReturn -OpFunctionEnd -)"; - - SinglePassRunAndCheck<InlineExhaustivePass>(predefs + nonEntryFuncs + before, - predefs + nonEntryFuncs + after, - false, true); -} - -TEST_F(InlineTest, CalleeWithMultiReturnAndPhiRequiresEntryBlockRemapping) { - // The case from https://github.com/KhronosGroup/SPIRV-Tools/issues/790 - // - // The callee has multiple returns, and so must be wrapped with a single-trip - // loop. That code must remap the callee entry block ID to the introduced - // loop body's ID. Otherwise you can get a dominance error in a cloned OpPhi. - - const std::string predefs = - R"(OpCapability Shader -OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %1 "main" -OpSource OpenCL_C 120 -%int = OpTypeInt 32 1 -%int_0 = OpConstant %int 0 -%int_1 = OpConstant %int 1 -%int_2 = OpConstant %int 2 -%int_3 = OpConstant %int 3 -%int_4 = OpConstant %int 4 -%void = OpTypeVoid -%9 = OpTypeFunction %void -%bool = OpTypeBool -%false = OpConstantFalse %bool -)"; - - // This callee has multiple returns, and a Phi in the second block referencing - // a value generated in the entry block. - const std::string nonEntryFuncs = - R"(%12 = OpFunction %void None %9 -%13 = OpLabel -%14 = OpCopyObject %int %int_0 -OpBranch %15 -%15 = OpLabel -%16 = OpPhi %int %14 %13 -%17 = OpCopyObject %int %int_1 -OpReturn -%18 = OpLabel -%19 = OpCopyObject %int %int_2 -OpReturn -OpFunctionEnd -)"; - - const std::string before = - R"(%1 = OpFunction %void None %9 -%20 = OpLabel -%21 = OpCopyObject %int %int_3 -%22 = OpFunctionCall %void %12 -%23 = OpCopyObject %int %int_4 -OpReturn -OpFunctionEnd -)"; - - const std::string after = - R"(%1 = OpFunction %void None %9 -%20 = OpLabel -%21 = OpCopyObject %int %int_3 -%24 = OpCopyObject %int %int_0 -OpBranch %25 -%25 = OpLabel -%26 = OpPhi %int %24 %20 -%27 = OpCopyObject %int %int_1 -OpBranch %28 -%29 = OpLabel -%30 = OpCopyObject %int %int_2 -OpBranch %28 -%28 = OpLabel -%23 = OpCopyObject %int %int_4 -OpReturn -OpFunctionEnd -)"; - - SinglePassRunAndCheck<InlineExhaustivePass>(predefs + nonEntryFuncs + before, - predefs + nonEntryFuncs + after, - false, true); -} - TEST_F(InlineTest, NonInlinableCalleeWithSingleReturn) { // The case from https://github.com/KhronosGroup/SPIRV-Tools/issues/2018 // @@ -2324,138 +1879,6 @@ predefs + caller + callee, predefs + caller + callee, false, true); } -TEST_F(InlineTest, CalleeWithSingleReturnNeedsSingleTripLoopWrapper) { - // The case from https://github.com/KhronosGroup/SPIRV-Tools/issues/2018 - // - // The callee has a single return, but needs single-trip loop wrapper - // to be inlined because the return is in a selection structure. - - const std::string predefs = - R"(OpCapability Shader -%1 = OpExtInstImport "GLSL.std.450" -OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %main "main" %_GLF_color -OpExecutionMode %main OriginUpperLeft -OpSource ESSL 310 -OpName %main "main" -OpName %f_ "f(" -OpName %i "i" -OpName %_GLF_color "_GLF_color" -OpDecorate %_GLF_color Location 0 -%void = OpTypeVoid -%7 = OpTypeFunction %void -%float = OpTypeFloat 32 -%9 = OpTypeFunction %float -%float_1 = OpConstant %float 1 -%bool = OpTypeBool -%false = OpConstantFalse %bool -%true = OpConstantTrue %bool -%int = OpTypeInt 32 1 -%_ptr_Function_int = OpTypePointer Function %int -%int_0 = OpConstant %int 0 -%int_1 = OpConstant %int 1 -%v4float = OpTypeVector %float 4 -%_ptr_Output_v4float = OpTypePointer Output %v4float -%_GLF_color = OpVariable %_ptr_Output_v4float Output -%float_0 = OpConstant %float 0 -%21 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 -%22 = OpConstantComposite %v4float %float_0 %float_1 %float_0 %float_1 -)"; - - const std::string new_predefs = - R"(%_ptr_Function_float = OpTypePointer Function %float -)"; - - const std::string main_before = - R"(%main = OpFunction %void None %7 -%23 = OpLabel -%i = OpVariable %_ptr_Function_int Function -OpStore %i %int_0 -OpBranch %24 -%24 = OpLabel -OpLoopMerge %25 %26 None -OpBranch %27 -%27 = OpLabel -%28 = OpLoad %int %i -%29 = OpSLessThan %bool %28 %int_1 -OpBranchConditional %29 %30 %25 -%30 = OpLabel -OpStore %_GLF_color %21 -%31 = OpFunctionCall %float %f_ -OpBranch %26 -%26 = OpLabel -%32 = OpLoad %int %i -%33 = OpIAdd %int %32 %int_1 -OpStore %i %33 -OpBranch %24 -%25 = OpLabel -OpStore %_GLF_color %22 -OpReturn -OpFunctionEnd -)"; - - const std::string main_after = - R"(%main = OpFunction %void None %7 -%23 = OpLabel -%38 = OpVariable %_ptr_Function_float Function -%i = OpVariable %_ptr_Function_int Function -OpStore %i %int_0 -OpBranch %24 -%24 = OpLabel -OpLoopMerge %25 %26 None -OpBranch %27 -%27 = OpLabel -%28 = OpLoad %int %i -%29 = OpSLessThan %bool %28 %int_1 -OpBranchConditional %29 %30 %25 -%30 = OpLabel -OpStore %_GLF_color %21 -OpBranch %39 -%39 = OpLabel -OpLoopMerge %40 %41 None -OpBranch %42 -%42 = OpLabel -OpSelectionMerge %43 None -OpBranchConditional %true %44 %43 -%44 = OpLabel -OpStore %38 %float_1 -OpBranch %40 -%43 = OpLabel -OpStore %38 %float_1 -OpBranch %40 -%41 = OpLabel -OpBranchConditional %false %39 %40 -%40 = OpLabel -%31 = OpLoad %float %38 -OpBranch %26 -%26 = OpLabel -%32 = OpLoad %int %i -%33 = OpIAdd %int %32 %int_1 -OpStore %i %33 -OpBranch %24 -%25 = OpLabel -OpStore %_GLF_color %22 -OpReturn -OpFunctionEnd -)"; - - const std::string callee = - R"(%f_ = OpFunction %float None %9 -%34 = OpLabel -OpSelectionMerge %35 None -OpBranchConditional %true %36 %35 -%36 = OpLabel -OpReturnValue %float_1 -%35 = OpLabel -OpReturnValue %float_1 -OpFunctionEnd -)"; - - SinglePassRunAndCheck<InlineExhaustivePass>( - predefs + main_before + callee, - predefs + new_predefs + main_after + callee, false, true); -} - TEST_F(InlineTest, Decorated1) { // Same test as Simple with the difference // that OpFAdd in the outlined function is @@ -2526,7 +1949,7 @@ )"; const std::string after = - R"(OpDecorate %37 RelaxedPrecision + R"(OpDecorate %38 RelaxedPrecision %void = OpTypeVoid %11 = OpTypeFunction %void %float = OpTypeFloat 32 @@ -2548,12 +1971,12 @@ %param = OpVariable %_ptr_Function_v4float Function %23 = OpLoad %v4float %BaseColor OpStore %param %23 -%33 = OpAccessChain %_ptr_Function_float %param %uint_0 -%34 = OpLoad %float %33 -%35 = OpAccessChain %_ptr_Function_float %param %uint_1 -%36 = OpLoad %float %35 -%37 = OpFAdd %float %34 %36 -OpStore %32 %37 +%34 = OpAccessChain %_ptr_Function_float %param %uint_0 +%35 = OpLoad %float %34 +%36 = OpAccessChain %_ptr_Function_float %param %uint_1 +%37 = OpLoad %float %36 +%38 = OpFAdd %float %35 %37 +OpStore %32 %38 %24 = OpLoad %float %32 %25 = OpCompositeConstruct %v4float %24 %24 %24 %24 OpStore %color %25 @@ -2672,12 +2095,12 @@ %param = OpVariable %_ptr_Function_v4float Function %22 = OpLoad %v4float %BaseColor OpStore %param %22 -%33 = OpAccessChain %_ptr_Function_float %param %uint_0 -%34 = OpLoad %float %33 -%35 = OpAccessChain %_ptr_Function_float %param %uint_1 -%36 = OpLoad %float %35 -%37 = OpFAdd %float %34 %36 -OpStore %32 %37 +%34 = OpAccessChain %_ptr_Function_float %param %uint_0 +%35 = OpLoad %float %34 +%36 = OpAccessChain %_ptr_Function_float %param %uint_1 +%37 = OpLoad %float %36 +%38 = OpFAdd %float %35 %37 +OpStore %32 %38 %23 = OpLoad %float %32 %24 = OpCompositeConstruct %v4float %23 %23 %23 %23 OpStore %color %24 @@ -3017,7 +2440,7 @@ %main = OpFunction %void None %3 %5 = OpLabel OpKill -%17 = OpLabel +%18 = OpLabel OpReturn OpFunctionEnd %kill_ = OpFunction %void None %3 @@ -3030,6 +2453,1305 @@ SinglePassRunAndCheck<InlineExhaustivePass>(before, after, false, true); } +TEST_F(InlineTest, EarlyReturnFunctionInlined) { + // #version 140 + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // if (bar.x < 0.0) + // return 0.0; + // return bar.x; + // } + // + // void main() + // { + // vec4 color = vec4(foo(BaseColor)); + // gl_FragColor = color; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %foo_vf4_ "foo(vf4;" +OpName %bar "bar" +OpName %color "color" +OpName %BaseColor "BaseColor" +OpName %param "param" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%14 = OpTypeFunction %float %_ptr_Function_v4float +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string foo = + R"(%foo_vf4_ = OpFunction %float None %14 +%bar = OpFunctionParameter %_ptr_Function_v4float +%27 = OpLabel +%28 = OpAccessChain %_ptr_Function_float %bar %uint_0 +%29 = OpLoad %float %28 +%30 = OpFOrdLessThan %bool %29 %float_0 +OpSelectionMerge %31 None +OpBranchConditional %30 %32 %31 +%32 = OpLabel +OpReturnValue %float_0 +%31 = OpLabel +%33 = OpAccessChain %_ptr_Function_float %bar %uint_0 +%34 = OpLoad %float %33 +OpReturnValue %34 +OpFunctionEnd +)"; + + const std::string fooMergeReturn = + R"(%foo_vf4_ = OpFunction %float None %14 +%bar = OpFunctionParameter %_ptr_Function_v4float +%27 = OpLabel +%41 = OpVariable %_ptr_Function_bool Function %false +%36 = OpVariable %_ptr_Function_float Function +OpSelectionMerge %35 None +OpSwitch %uint_0 %38 +%38 = OpLabel +%28 = OpAccessChain %_ptr_Function_float %bar %uint_0 +%29 = OpLoad %float %28 +%30 = OpFOrdLessThan %bool %29 %float_0 +OpSelectionMerge %31 None +OpBranchConditional %30 %32 %31 +%32 = OpLabel +OpStore %41 %true +OpStore %36 %float_0 +OpBranch %35 +%31 = OpLabel +%33 = OpAccessChain %_ptr_Function_float %bar %uint_0 +%34 = OpLoad %float %33 +OpStore %41 %true +OpStore %36 %34 +OpBranch %35 +%35 = OpLabel +%37 = OpLoad %float %36 +OpReturnValue %37 +OpFunctionEnd +)"; + + const std::string before = + R"(%main = OpFunction %void None %10 +%22 = OpLabel +%color = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%23 = OpLoad %v4float %BaseColor +OpStore %param %23 +%24 = OpFunctionCall %float %foo_vf4_ %param +%25 = OpCompositeConstruct %v4float %24 %24 %24 %24 +OpStore %color %25 +%26 = OpLoad %v4float %color +OpStore %gl_FragColor %26 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%false = OpConstantFalse %bool +%_ptr_Function_bool = OpTypePointer Function %bool +%true = OpConstantTrue %bool +%main = OpFunction %void None %10 +%22 = OpLabel +%43 = OpVariable %_ptr_Function_bool Function %false +%44 = OpVariable %_ptr_Function_float Function +%45 = OpVariable %_ptr_Function_float Function +%color = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%23 = OpLoad %v4float %BaseColor +OpStore %param %23 +OpStore %43 %false +OpSelectionMerge %55 None +OpSwitch %uint_0 %47 +%47 = OpLabel +%48 = OpAccessChain %_ptr_Function_float %param %uint_0 +%49 = OpLoad %float %48 +%50 = OpFOrdLessThan %bool %49 %float_0 +OpSelectionMerge %52 None +OpBranchConditional %50 %51 %52 +%51 = OpLabel +OpStore %43 %true +OpStore %44 %float_0 +OpBranch %55 +%52 = OpLabel +%53 = OpAccessChain %_ptr_Function_float %param %uint_0 +%54 = OpLoad %float %53 +OpStore %43 %true +OpStore %44 %54 +OpBranch %55 +%55 = OpLabel +%56 = OpLoad %float %44 +OpStore %45 %56 +%24 = OpLoad %float %45 +%25 = OpCompositeConstruct %v4float %24 %24 %24 %24 +OpStore %color %25 +%26 = OpLoad %v4float %color +OpStore %gl_FragColor %26 +OpReturn +OpFunctionEnd +)"; + + // The early return case must be handled by merge-return first. + AddPass<MergeReturnPass>(); + AddPass<InlineExhaustivePass>(); + RunAndCheck(predefs + before + foo, predefs + after + fooMergeReturn); +} + +TEST_F(InlineTest, EarlyReturnNotAppearingLastInFunctionInlined) { + // Example from https://github.com/KhronosGroup/SPIRV-Tools/issues/755 + // + // Original example is derived from: + // + // #version 450 + // + // float foo() { + // if (true) { + // } + // } + // + // void main() { foo(); } + // + // But the order of basic blocks in foo is changed so that the return + // block is listed second-last. There is only one return in the callee + // but it does not appear last. + + const std::string predefs = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" +OpSource GLSL 450 +OpName %main "main" +OpName %foo_ "foo(" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +)"; + + const std::string foo = + R"(%foo_ = OpFunction %void None %4 +%7 = OpLabel +OpSelectionMerge %8 None +OpBranchConditional %true %9 %8 +%8 = OpLabel +OpReturn +%9 = OpLabel +OpBranch %8 +OpFunctionEnd +)"; + + const std::string fooMergeReturn = + R"(%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%false = OpConstantFalse %bool +%_ptr_Function_bool = OpTypePointer Function %bool +%foo_ = OpFunction %void None %4 +%7 = OpLabel +%18 = OpVariable %_ptr_Function_bool Function %false +OpSelectionMerge %12 None +OpSwitch %uint_0 %13 +%13 = OpLabel +OpSelectionMerge %8 None +OpBranchConditional %true %9 %8 +%8 = OpLabel +OpStore %18 %true +OpBranch %12 +%9 = OpLabel +OpBranch %8 +%12 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string before = + R"(%main = OpFunction %void None %4 +%10 = OpLabel +%11 = OpFunctionCall %void %foo_ +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %4 +%10 = OpLabel +%19 = OpVariable %_ptr_Function_bool Function %false +OpStore %19 %false +OpSelectionMerge %24 None +OpSwitch %uint_0 %21 +%21 = OpLabel +OpSelectionMerge %22 None +OpBranchConditional %true %23 %22 +%22 = OpLabel +OpStore %19 %true +OpBranch %24 +%23 = OpLabel +OpBranch %22 +%24 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // The early return case must be handled by merge-return first. + AddPass<MergeReturnPass>(); + AddPass<InlineExhaustivePass>(); + RunAndCheck(predefs + foo + before, predefs + fooMergeReturn + after); +} + +TEST_F(InlineTest, CalleeWithSingleReturnNeedsSingleTripLoopWrapper) { + // The case from https://github.com/KhronosGroup/SPIRV-Tools/issues/2018 + // + // The callee has a single return, but needs single-trip loop wrapper + // to be inlined because the return is in a selection structure. + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %_GLF_color +OpExecutionMode %main OriginUpperLeft +OpSource ESSL 310 +OpName %main "main" +OpName %f_ "f(" +OpName %i "i" +OpName %_GLF_color "_GLF_color" +OpDecorate %_GLF_color Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%9 = OpTypeFunction %float +%float_1 = OpConstant %float 1 +%bool = OpTypeBool +%false = OpConstantFalse %bool +%true = OpConstantTrue %bool +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_GLF_color = OpVariable %_ptr_Output_v4float Output +%float_0 = OpConstant %float 0 +%21 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%22 = OpConstantComposite %v4float %float_0 %float_1 %float_0 %float_1 +)"; + + const std::string new_predefs = + R"(%_ptr_Function_float = OpTypePointer Function %float +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Function_bool = OpTypePointer Function %bool +)"; + + const std::string main_before = + R"(%main = OpFunction %void None %7 +%23 = OpLabel +%i = OpVariable %_ptr_Function_int Function +OpStore %i %int_0 +OpBranch %24 +%24 = OpLabel +OpLoopMerge %25 %26 None +OpBranch %27 +%27 = OpLabel +%28 = OpLoad %int %i +%29 = OpSLessThan %bool %28 %int_1 +OpBranchConditional %29 %30 %25 +%30 = OpLabel +OpStore %_GLF_color %21 +%31 = OpFunctionCall %float %f_ +OpBranch %26 +%26 = OpLabel +%32 = OpLoad %int %i +%33 = OpIAdd %int %32 %int_1 +OpStore %i %33 +OpBranch %24 +%25 = OpLabel +OpStore %_GLF_color %22 +OpReturn +OpFunctionEnd +)"; + + const std::string main_after = + R"(%main = OpFunction %void None %7 +%23 = OpLabel +%46 = OpVariable %_ptr_Function_bool Function %false +%47 = OpVariable %_ptr_Function_float Function +%48 = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +OpStore %i %int_0 +OpBranch %24 +%24 = OpLabel +OpLoopMerge %25 %26 None +OpBranch %27 +%27 = OpLabel +%28 = OpLoad %int %i +%29 = OpSLessThan %bool %28 %int_1 +OpBranchConditional %29 %30 %25 +%30 = OpLabel +OpStore %_GLF_color %21 +OpStore %46 %false +OpSelectionMerge %53 None +OpSwitch %uint_0 %50 +%50 = OpLabel +OpSelectionMerge %52 None +OpBranchConditional %true %51 %52 +%51 = OpLabel +OpStore %46 %true +OpStore %47 %float_1 +OpBranch %53 +%52 = OpLabel +OpStore %46 %true +OpStore %47 %float_1 +OpBranch %53 +%53 = OpLabel +%54 = OpLoad %float %47 +OpStore %48 %54 +%31 = OpLoad %float %48 +OpBranch %26 +%26 = OpLabel +%32 = OpLoad %int %i +%33 = OpIAdd %int %32 %int_1 +OpStore %i %33 +OpBranch %24 +%25 = OpLabel +OpStore %_GLF_color %22 +OpReturn +OpFunctionEnd +)"; + + const std::string callee = + R"(%f_ = OpFunction %float None %9 +%34 = OpLabel +OpSelectionMerge %35 None +OpBranchConditional %true %36 %35 +%36 = OpLabel +OpReturnValue %float_1 +%35 = OpLabel +OpReturnValue %float_1 +OpFunctionEnd +)"; + + const std::string calleeMergeReturn = + R"(%f_ = OpFunction %float None %9 +%34 = OpLabel +%45 = OpVariable %_ptr_Function_bool Function %false +%39 = OpVariable %_ptr_Function_float Function +OpSelectionMerge %37 None +OpSwitch %uint_0 %41 +%41 = OpLabel +OpSelectionMerge %35 None +OpBranchConditional %true %36 %35 +%36 = OpLabel +OpStore %45 %true +OpStore %39 %float_1 +OpBranch %37 +%35 = OpLabel +OpStore %45 %true +OpStore %39 %float_1 +OpBranch %37 +%37 = OpLabel +%40 = OpLoad %float %39 +OpReturnValue %40 +OpFunctionEnd +)"; + + // The early return case must be handled by merge-return first. + AddPass<MergeReturnPass>(); + AddPass<InlineExhaustivePass>(); + RunAndCheck(predefs + main_before + callee, + predefs + new_predefs + main_after + calleeMergeReturn); +} + +TEST_F(InlineTest, ForwardReferencesInPhiInlined) { + // The basic structure of the test case is like this: + // + // int foo() { + // int result = 1; + // if (true) { + // result = 1; + // } + // return result; + // } + // + // void main() { + // int x = foo(); + // } + // + // but with modifications: Using Phi instead of load/store, and the + // return block in foo appears before the "then" block. + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" +OpSource GLSL 450 +OpName %main "main" +OpName %foo_ "foo(" +OpName %x "x" +%void = OpTypeVoid +%6 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%8 = OpTypeFunction %int +%bool = OpTypeBool +%true = OpConstantTrue %bool +%int_0 = OpConstant %int 0 +%_ptr_Function_int = OpTypePointer Function %int +)"; + + const std::string callee = + R"(%foo_ = OpFunction %int None %8 +%13 = OpLabel +%14 = OpCopyObject %int %int_0 +OpSelectionMerge %15 None +OpBranchConditional %true %16 %15 +%15 = OpLabel +%17 = OpPhi %int %14 %13 %18 %16 +OpReturnValue %17 +%16 = OpLabel +%18 = OpCopyObject %int %int_0 +OpBranch %15 +OpFunctionEnd +)"; + + const std::string calleeMergeReturn = + R"(%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%false = OpConstantFalse %bool +%_ptr_Function_bool = OpTypePointer Function %bool +%foo_ = OpFunction %int None %8 +%13 = OpLabel +%29 = OpVariable %_ptr_Function_bool Function %false +%22 = OpVariable %_ptr_Function_int Function +OpSelectionMerge %21 None +OpSwitch %uint_0 %24 +%24 = OpLabel +%14 = OpCopyObject %int %int_0 +OpSelectionMerge %15 None +OpBranchConditional %true %16 %15 +%15 = OpLabel +%17 = OpPhi %int %14 %24 %18 %16 +OpStore %29 %true +OpStore %22 %17 +OpBranch %21 +%16 = OpLabel +%18 = OpCopyObject %int %int_0 +OpBranch %15 +%21 = OpLabel +%23 = OpLoad %int %22 +OpReturnValue %23 +OpFunctionEnd +)"; + + const std::string before = + R"(%main = OpFunction %void None %6 +%19 = OpLabel +%x = OpVariable %_ptr_Function_int Function +%20 = OpFunctionCall %int %foo_ +OpStore %x %20 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %6 +%19 = OpLabel +%30 = OpVariable %_ptr_Function_bool Function %false +%31 = OpVariable %_ptr_Function_int Function +%32 = OpVariable %_ptr_Function_int Function +%x = OpVariable %_ptr_Function_int Function +OpStore %30 %false +OpSelectionMerge %40 None +OpSwitch %uint_0 %34 +%34 = OpLabel +%35 = OpCopyObject %int %int_0 +OpSelectionMerge %36 None +OpBranchConditional %true %38 %36 +%36 = OpLabel +%37 = OpPhi %int %35 %34 %39 %38 +OpStore %30 %true +OpStore %31 %37 +OpBranch %40 +%38 = OpLabel +%39 = OpCopyObject %int %int_0 +OpBranch %36 +%40 = OpLabel +%41 = OpLoad %int %31 +OpStore %32 %41 +%20 = OpLoad %int %32 +OpStore %x %20 +OpReturn +OpFunctionEnd +)"; + + AddPass<MergeReturnPass>(); + AddPass<InlineExhaustivePass>(); + RunAndCheck(predefs + callee + before, predefs + calleeMergeReturn + after); +} + +TEST_F(InlineTest, DebugSimple) { + // Check that it correctly generates DebugInlinedAt and maps it to DebugScope + // for the inlined function foo(). + const std::string text = R"( +; CHECK: [[main_name:%\d+]] = OpString "main" +; CHECK: [[foo_name:%\d+]] = OpString "foo" +; CHECK: [[dbg_main:%\d+]] = OpExtInst %void {{%\d+}} DebugFunction [[main_name]] {{%\d+}} {{%\d+}} 4 1 {{%\d+}} [[main_name]] FlagIsProtected|FlagIsPrivate 4 [[main:%\d+]] +; CHECK: [[dbg_foo:%\d+]] = OpExtInst %void {{%\d+}} DebugFunction [[foo_name]] {{%\d+}} {{%\d+}} 1 1 {{%\d+}} [[foo_name]] FlagIsProtected|FlagIsPrivate 1 [[foo:%\d+]] +; CHECK: [[foo_bb:%\d+]] = OpExtInst %void {{%\d+}} DebugLexicalBlock {{%\d+}} 1 14 [[dbg_foo]] +; CHECK: [[inlined_at:%\d+]] = OpExtInst %void {{%\d+}} DebugInlinedAt 4 [[dbg_main]] +; CHECK: [[main]] = OpFunction %void None +; CHECK: {{%\d+}} = OpExtInst %void {{%\d+}} DebugScope [[foo_bb]] [[inlined_at]] +; CHECK: [[foo]] = OpFunction %v4float None + OpCapability Shader + %1 = OpExtInstImport "OpenCL.DebugInfo.100" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %3 %4 + OpExecutionMode %main OriginUpperLeft + %5 = OpString "ps.hlsl" + OpSource HLSL 600 %5 + %6 = OpString "float" + %main_name = OpString "main" + %foo_name = OpString "foo" + OpDecorate %3 Location 0 + OpDecorate %4 Location 0 + %uint = OpTypeInt 32 0 + %uint_32 = OpConstant %uint 32 + %float = OpTypeFloat 32 + %float_1 = OpConstant %float 1 + %v4float = OpTypeVector %float 4 + %14 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float + %void = OpTypeVoid + %18 = OpTypeFunction %void + %19 = OpTypeFunction %v4float + %3 = OpVariable %_ptr_Input_v4float Input + %4 = OpVariable %_ptr_Output_v4float Output + %20 = OpExtInst %void %1 DebugSource %5 + %21 = OpExtInst %void %1 DebugCompilationUnit 1 4 %20 HLSL + %22 = OpExtInst %void %1 DebugTypeBasic %6 %uint_32 Float + %23 = OpExtInst %void %1 DebugTypeVector %22 4 + %24 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %23 %23 + %25 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %23 + %dbg_main = OpExtInst %void %1 DebugFunction %main_name %24 %20 4 1 %21 %main_name FlagIsProtected|FlagIsPrivate 4 %main + %dbg_foo = OpExtInst %void %1 DebugFunction %foo_name %25 %20 1 1 %21 %foo_name FlagIsProtected|FlagIsPrivate 1 %foo + %29 = OpExtInst %void %1 DebugLexicalBlock %20 1 14 %dbg_foo + %main = OpFunction %void None %18 + %30 = OpLabel + %31 = OpExtInst %void %1 DebugScope %dbg_main + %32 = OpFunctionCall %v4float %foo + %33 = OpLoad %v4float %3 + %34 = OpFAdd %v4float %32 %33 + OpStore %4 %34 + OpReturn + OpFunctionEnd + %foo = OpFunction %v4float None %19 + %35 = OpExtInst %void %1 DebugScope %dbg_foo + %36 = OpLabel + %37 = OpExtInst %void %1 DebugScope %29 + OpReturnValue %14 + OpFunctionEnd +)"; + + SinglePassRunAndMatch<InlineExhaustivePass>(text, true); +} + +TEST_F(InlineTest, DebugNested) { + // When function main() calls function zoo() and function zoo() calls + // function bar() and function bar() calls function foo(), check that + // the inline pass correctly generates DebugInlinedAt instructions + // for the nested function calls. + const std::string text = R"( +; CHECK: [[v4f1:%\d+]] = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +; CHECK: [[v4f2:%\d+]] = OpConstantComposite %v4float %float_2 %float_2 %float_2 %float_2 +; CHECK: [[v4f3:%\d+]] = OpConstantComposite %v4float %float_3 %float_3 %float_3 %float_3 +; CHECK: [[color:%\d+]] = OpVariable %_ptr_Input_v4float Input +; CHECK: [[dbg_main:%\d+]] = OpExtInst %void [[ext:%\d+]] DebugFunction {{%\d+}} {{%\d+}} {{%\d+}} 10 1 {{%\d+}} {{%\d+}} FlagIsProtected|FlagIsPrivate 10 [[main:%\d+]] +; CHECK: [[dbg_foo:%\d+]] = OpExtInst %void [[ext]] DebugFunction {{%\d+}} {{%\d+}} {{%\d+}} 1 1 {{%\d+}} {{%\d+}} FlagIsProtected|FlagIsPrivate 1 [[foo:%\d+]] +; CHECK: [[dbg_bar:%\d+]] = OpExtInst %void [[ext]] DebugFunction {{%\d+}} {{%\d+}} {{%\d+}} 4 1 {{%\d+}} {{%\d+}} FlagIsProtected|FlagIsPrivate 4 [[bar:%\d+]] +; CHECK: [[dbg_zoo:%\d+]] = OpExtInst %void [[ext]] DebugFunction {{%\d+}} {{%\d+}} {{%\d+}} 7 1 {{%\d+}} {{%\d+}} FlagIsProtected|FlagIsPrivate 7 [[zoo:%\d+]] +; CHECK: [[inlined_to_main:%\d+]] = OpExtInst %void [[ext]] DebugInlinedAt 600 [[dbg_main]] +; CHECK: [[inlined_to_zoo:%\d+]] = OpExtInst %void [[ext]] DebugInlinedAt 700 [[dbg_zoo]] [[inlined_to_main]] +; CHECK: [[inlined_to_bar:%\d+]] = OpExtInst %void [[ext]] DebugInlinedAt 300 [[dbg_bar]] [[inlined_to_zoo]] +; CHECK: [[main]] = OpFunction %void None +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_foo]] [[inlined_to_bar]] +; CHECK-NEXT: OpLine {{%\d+}} 100 0 +; CHECK-NEXT: OpStore {{%\d+}} [[v4f1]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_bar]] [[inlined_to_zoo]] +; CHECK-NEXT: OpLine {{%\d+}} 300 0 +; CHECK-NEXT: [[foo_ret:%\d+]] = OpLoad %v4float +; CHECK-NEXT: OpLine {{%\d+}} 400 0 +; CHECK-NEXT: {{%\d+}} = OpFAdd %v4float [[foo_ret]] [[v4f2]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_zoo]] [[inlined_to_main]] +; CHECK-NEXT: OpLine {{%\d+}} 700 0 +; CHECK-NEXT: [[bar_ret:%\d+]] = OpLoad %v4float +; CHECK-NEXT: {{%\d+}} = OpFAdd %v4float [[bar_ret]] [[v4f3]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_main]] +; CHECK-NEXT: OpLine {{%\d+}} 600 0 +; CHECK-NEXT: [[zoo_ret:%\d+]] = OpLoad %v4float +; CHECK-NEXT: [[color_val:%\d+]] = OpLoad %v4float [[color]] +; CHECK-NEXT: {{%\d+}} = OpFAdd %v4float [[zoo_ret]] [[color_val]] +; CHECK: [[foo]] = OpFunction %v4float None +; CHECK: [[bar]] = OpFunction %v4float None +; CHECK: [[zoo]] = OpFunction %v4float None + OpCapability Shader + %1 = OpExtInstImport "OpenCL.DebugInfo.100" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %3 %4 + OpExecutionMode %main OriginUpperLeft + %5 = OpString "ps.hlsl" + OpSource HLSL 600 %5 + %6 = OpString "float" + %7 = OpString "main" + %8 = OpString "foo" + %9 = OpString "bar" + %10 = OpString "zoo" + OpDecorate %3 Location 0 + OpDecorate %4 Location 0 + %uint = OpTypeInt 32 0 + %uint_32 = OpConstant %uint 32 + %float = OpTypeFloat 32 + %float_1 = OpConstant %float 1 + %float_2 = OpConstant %float 2 + %float_3 = OpConstant %float 3 + %v4float = OpTypeVector %float 4 + %18 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 + %19 = OpConstantComposite %v4float %float_2 %float_2 %float_2 %float_2 + %20 = OpConstantComposite %v4float %float_3 %float_3 %float_3 %float_3 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float + %void = OpTypeVoid + %24 = OpTypeFunction %void + %25 = OpTypeFunction %v4float + %3 = OpVariable %_ptr_Input_v4float Input + %4 = OpVariable %_ptr_Output_v4float Output + %26 = OpExtInst %void %1 DebugSource %5 + %27 = OpExtInst %void %1 DebugCompilationUnit 1 4 %26 HLSL + %28 = OpExtInst %void %1 DebugTypeBasic %6 %uint_32 Float + %29 = OpExtInst %void %1 DebugTypeVector %28 4 + %30 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %29 %29 + %31 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %29 + %32 = OpExtInst %void %1 DebugFunction %7 %30 %26 10 1 %27 %7 FlagIsProtected|FlagIsPrivate 10 %main + %33 = OpExtInst %void %1 DebugFunction %8 %31 %26 1 1 %27 %8 FlagIsProtected|FlagIsPrivate 1 %foo + %35 = OpExtInst %void %1 DebugFunction %9 %31 %26 4 1 %27 %9 FlagIsProtected|FlagIsPrivate 4 %bar + %37 = OpExtInst %void %1 DebugFunction %10 %31 %26 7 1 %27 %10 FlagIsProtected|FlagIsPrivate 7 %zoo + %main = OpFunction %void None %24 + %39 = OpLabel + %40 = OpExtInst %void %1 DebugScope %32 + OpLine %5 600 0 + %41 = OpFunctionCall %v4float %zoo + %42 = OpLoad %v4float %3 + %43 = OpFAdd %v4float %41 %42 + OpStore %4 %43 + OpReturn + OpFunctionEnd + %foo = OpFunction %v4float None %25 + %44 = OpExtInst %void %1 DebugScope %33 + %45 = OpLabel + OpLine %5 100 0 + OpReturnValue %18 + OpFunctionEnd + OpLine %5 200 0 + %bar = OpFunction %v4float None %25 + %46 = OpExtInst %void %1 DebugScope %35 + %47 = OpLabel + OpLine %5 300 0 + %48 = OpFunctionCall %v4float %foo + OpLine %5 400 0 + %49 = OpFAdd %v4float %48 %19 + OpLine %5 500 0 + OpReturnValue %49 + OpFunctionEnd + %zoo = OpFunction %v4float None %25 + %50 = OpExtInst %void %1 DebugScope %37 + %51 = OpLabel + OpLine %5 700 0 + %52 = OpFunctionCall %v4float %bar + %53 = OpFAdd %v4float %52 %20 + OpReturnValue %53 + OpFunctionEnd +)"; + + SinglePassRunAndMatch<InlineExhaustivePass>(text, true); +} + +TEST_F(InlineTest, DebugSimpleHLSLPixelShader) { + const std::string text = R"( +; CHECK: [[dbg_main:%\d+]] = OpExtInst %void [[ext:%\d+]] DebugFunction {{%\d+}} {{%\d+}} {{%\d+}} 1 1 {{%\d+}} {{%\d+}} FlagIsProtected|FlagIsPrivate 1 %src_main +; CHECK: [[lex_blk:%\d+]] = OpExtInst %void [[ext]] DebugLexicalBlock {{%\d+}} 1 47 [[dbg_main]] +; CHECK: %main = OpFunction %void None +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_main]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugDeclare {{%\d+}} %param_var_color +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[lex_blk]] +; CHECK: OpLine {{%\d+}} 2 10 +; CHECK: {{%\d+}} = OpLoad %v4float %param_var_color +; CHECK: OpLine {{%\d+}} 2 3 +; CHECK: OpFunctionEnd +; CHECK: %src_main = OpFunction %v4float None + OpCapability Shader + %1 = OpExtInstImport "OpenCL.DebugInfo.100" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %in_var_COLOR %out_var_SV_TARGET + OpExecutionMode %main OriginUpperLeft + %5 = OpString "ps.hlsl" + OpSource HLSL 600 %5 + %14 = OpString "#line 1 \"ps.hlsl\" +float4 main(float4 color : COLOR) : SV_TARGET { + return color; +} +" + %17 = OpString "float" + %21 = OpString "src.main" + %24 = OpString "color" + OpName %in_var_COLOR "in.var.COLOR" + OpName %out_var_SV_TARGET "out.var.SV_TARGET" + OpName %main "main" + OpName %param_var_color "param.var.color" + OpName %src_main "src.main" + OpName %color "color" + OpName %bb_entry "bb.entry" + OpDecorate %in_var_COLOR Location 0 + OpDecorate %out_var_SV_TARGET Location 0 + %uint = OpTypeInt 32 0 + %uint_32 = OpConstant %uint 32 + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float + %void = OpTypeVoid + %27 = OpTypeFunction %void +%_ptr_Function_v4float = OpTypePointer Function %v4float + %33 = OpTypeFunction %v4float %_ptr_Function_v4float +%in_var_COLOR = OpVariable %_ptr_Input_v4float Input +%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output + %13 = OpExtInst %void %1 DebugExpression + %15 = OpExtInst %void %1 DebugSource %5 %14 + %16 = OpExtInst %void %1 DebugCompilationUnit 1 4 %15 HLSL + %18 = OpExtInst %void %1 DebugTypeBasic %17 %uint_32 Float + %19 = OpExtInst %void %1 DebugTypeVector %18 4 + %20 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %19 %19 + %22 = OpExtInst %void %1 DebugFunction %21 %20 %15 1 1 %16 %21 FlagIsProtected|FlagIsPrivate 1 %src_main + %25 = OpExtInst %void %1 DebugLocalVariable %24 %19 %15 1 20 %22 FlagIsLocal 0 + %26 = OpExtInst %void %1 DebugLexicalBlock %15 1 47 %22 + %main = OpFunction %void None %27 + %28 = OpLabel +%param_var_color = OpVariable %_ptr_Function_v4float Function + %31 = OpLoad %v4float %in_var_COLOR + OpStore %param_var_color %31 + %32 = OpFunctionCall %v4float %src_main %param_var_color + OpStore %out_var_SV_TARGET %32 + OpReturn + OpFunctionEnd + OpLine %5 1 1 + %src_main = OpFunction %v4float None %33 + %34 = OpExtInst %void %1 DebugScope %22 + %color = OpFunctionParameter %_ptr_Function_v4float + %36 = OpExtInst %void %1 DebugDeclare %25 %color %13 + %bb_entry = OpLabel + %38 = OpExtInst %void %1 DebugScope %26 + OpLine %5 2 10 + %39 = OpLoad %v4float %color + OpLine %5 2 3 + OpReturnValue %39 + OpFunctionEnd +)"; + + SinglePassRunAndMatch<InlineExhaustivePass>(text, true); +} + +TEST_F(InlineTest, DebugDeclareForCalleeFunctionParam) { + // Check that InlinePass correctly generates DebugDeclare instructions + // for callee function's parameters and maps them to corresponding + // local variables of caller function. + const std::string text = R"( +; CHECK: [[add:%\d+]] = OpString "add" +; CHECK: [[a:%\d+]] = OpString "a" +; CHECK: [[b:%\d+]] = OpString "b" +; CHECK: [[dbg_add:%\d+]] = OpExtInst %void [[ext:%\d+]] DebugFunction [[add]] +; CHECK: [[dbg_a:%\d+]] = OpExtInst %void [[ext]] DebugLocalVariable [[a]] +; CHECK: [[dbg_b:%\d+]] = OpExtInst %void [[ext]] DebugLocalVariable [[b]] +; CHECK: [[inlinedat:%\d+]] = OpExtInst %void [[ext]] DebugInlinedAt 5 +; CHECK: OpStore [[param_a:%\d+]] +; CHECK: OpStore [[param_b:%\d+]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_add]] [[inlinedat]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugDeclare [[dbg_a]] [[param_a]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugDeclare [[dbg_b]] [[param_b]] + +OpCapability Shader +%ext = OpExtInstImport "OpenCL.DebugInfo.100" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_COLOR %out_var_SV_TARGET +OpExecutionMode %main OriginUpperLeft +%file_name = OpString "ps.hlsl" +OpSource HLSL 600 %file_name +%float_name = OpString "float" +%main_name = OpString "main" +%add_name = OpString "add" +%a_name = OpString "a" +%b_name = OpString "b" +OpDecorate %in_var_COLOR Location 0 +OpDecorate %out_var_SV_TARGET Location 0 +%uint = OpTypeInt 32 0 +%uint_32 = OpConstant %uint 32 +%float = OpTypeFloat 32 +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%v4float = OpTypeVector %float 4 +%v4f1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%v4f2 = OpConstantComposite %v4float %float_2 %float_2 %float_2 %float_2 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%add_fn_type = OpTypeFunction %v4float %_ptr_Function_v4float %_ptr_Function_v4float +%void = OpTypeVoid +%void_fn_type = OpTypeFunction %void +%v4f_fn_type = OpTypeFunction %v4float +%in_var_COLOR = OpVariable %_ptr_Input_v4float Input +%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output +%null_expr = OpExtInst %void %ext DebugExpression +%src = OpExtInst %void %ext DebugSource %file_name +%cu = OpExtInst %void %ext DebugCompilationUnit 1 4 %src HLSL +%dbg_f = OpExtInst %void %ext DebugTypeBasic %float_name %uint_32 Float +%dbg_v4f = OpExtInst %void %ext DebugTypeVector %dbg_f 4 +%main_ty = OpExtInst %void %ext DebugTypeFunction FlagIsProtected|FlagIsPrivate %dbg_v4f %dbg_v4f +%add_ty = OpExtInst %void %ext DebugTypeFunction FlagIsProtected|FlagIsPrivate %dbg_v4f %dbg_v4f %dbg_v4f +%dbg_main = OpExtInst %void %ext DebugFunction %main_name %main_ty %src 5 1 %cu %main_name FlagIsProtected|FlagIsPrivate 10 %main +%dbg_add = OpExtInst %void %ext DebugFunction %add_name %add_ty %src 1 1 %cu %add_name FlagIsProtected|FlagIsPrivate 1 %add +%dbg_a = OpExtInst %void %ext DebugLocalVariable %a_name %dbg_v4f %src 1 13 %dbg_add FlagIsLocal 0 +%dbg_b = OpExtInst %void %ext DebugLocalVariable %b_name %dbg_v4f %src 1 20 %dbg_add FlagIsLocal 1 +%add_lb = OpExtInst %void %ext DebugLexicalBlock %src 1 23 %dbg_add +%main = OpFunction %void None %void_fn_type +%main_bb = OpLabel +%param_a = OpVariable %_ptr_Function_v4float Function +%param_b = OpVariable %_ptr_Function_v4float Function +%scope0 = OpExtInst %void %ext DebugScope %dbg_main +OpStore %param_a %v4f1 +OpStore %param_b %v4f2 +%result = OpFunctionCall %v4float %add %param_a %param_b +OpStore %out_var_SV_TARGET %result +OpReturn +OpFunctionEnd +%add = OpFunction %v4float None %add_fn_type +%scope1 = OpExtInst %void %ext DebugScope %dbg_add +%a = OpFunctionParameter %_ptr_Function_v4float +%b = OpFunctionParameter %_ptr_Function_v4float +%decl0 = OpExtInst %void %ext DebugDeclare %dbg_a %a %null_expr +%decl1 = OpExtInst %void %ext DebugDeclare %dbg_b %b %null_expr +%add_bb = OpLabel +%scope2 = OpExtInst %void %ext DebugScope %add_lb +%a_val = OpLoad %v4float %a +%b_val = OpLoad %v4float %b +%res = OpFAdd %v4float %a_val %b_val +OpReturnValue %res +OpFunctionEnd +)"; + + SinglePassRunAndMatch<InlineExhaustivePass>(text, true); +} + +TEST_F(InlineTest, DebugDeclareForCalleeLocalVar) { + // Check that InlinePass correctly generates DebugDeclare instructions + // for callee function's local variables and maps them to corresponding + // local variables of caller function. + const std::string text = R"( +; CHECK: [[add:%\d+]] = OpString "add" +; CHECK: [[foo:%\d+]] = OpString "foo" +; CHECK: [[dbg_add:%\d+]] = OpExtInst %void [[ext:%\d+]] DebugFunction [[add]] +; CHECK: [[dbg_foo:%\d+]] = OpExtInst %void [[ext]] DebugLocalVariable [[foo]] {{%\d+}} {{%\d+}} 2 2 [[dbg_add]] +; CHECK: [[inlinedat:%\d+]] = OpExtInst %void [[ext]] DebugInlinedAt 5 + +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_add]] [[inlinedat]] +; CHECK: [[new_foo:%\d+]] = OpVariable %_ptr_Function_v4float Function + +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_add]] [[inlinedat]] +; CHECK: [[a_val:%\d+]] = OpLoad %v4float +; CHECK: [[b_val:%\d+]] = OpLoad %v4float +; CHECK: [[res:%\d+]] = OpFAdd %v4float [[a_val]] [[b_val]] +; CHECK: OpStore [[new_foo]] [[res]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugDeclare [[dbg_foo]] [[new_foo]] + +OpCapability Shader +%ext = OpExtInstImport "OpenCL.DebugInfo.100" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_COLOR %out_var_SV_TARGET +OpExecutionMode %main OriginUpperLeft +%file_name = OpString "ps.hlsl" +OpSource HLSL 600 %file_name +%float_name = OpString "float" +%main_name = OpString "main" +%add_name = OpString "add" +%foo_name = OpString "foo" +OpDecorate %in_var_COLOR Location 0 +OpDecorate %out_var_SV_TARGET Location 0 +%uint = OpTypeInt 32 0 +%uint_32 = OpConstant %uint 32 +%float = OpTypeFloat 32 +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%v4float = OpTypeVector %float 4 +%v4f1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%v4f2 = OpConstantComposite %v4float %float_2 %float_2 %float_2 %float_2 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%add_fn_type = OpTypeFunction %v4float %_ptr_Function_v4float %_ptr_Function_v4float +%void = OpTypeVoid +%void_fn_type = OpTypeFunction %void +%v4f_fn_type = OpTypeFunction %v4float +%in_var_COLOR = OpVariable %_ptr_Input_v4float Input +%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output +%null_expr = OpExtInst %void %ext DebugExpression +%src = OpExtInst %void %ext DebugSource %file_name +%cu = OpExtInst %void %ext DebugCompilationUnit 1 4 %src HLSL +%dbg_f = OpExtInst %void %ext DebugTypeBasic %float_name %uint_32 Float +%dbg_v4f = OpExtInst %void %ext DebugTypeVector %dbg_f 4 +%main_ty = OpExtInst %void %ext DebugTypeFunction FlagIsProtected|FlagIsPrivate %dbg_v4f %dbg_v4f +%add_ty = OpExtInst %void %ext DebugTypeFunction FlagIsProtected|FlagIsPrivate %dbg_v4f %dbg_v4f %dbg_v4f +%dbg_main = OpExtInst %void %ext DebugFunction %main_name %main_ty %src 5 1 %cu %main_name FlagIsProtected|FlagIsPrivate 10 %main +%dbg_add = OpExtInst %void %ext DebugFunction %add_name %add_ty %src 1 1 %cu %add_name FlagIsProtected|FlagIsPrivate 1 %add +%dbg_foo = OpExtInst %void %ext DebugLocalVariable %foo_name %dbg_v4f %src 2 2 %dbg_add FlagIsLocal +%main = OpFunction %void None %void_fn_type +%main_bb = OpLabel +%param_a = OpVariable %_ptr_Function_v4float Function +%param_b = OpVariable %_ptr_Function_v4float Function +%scope0 = OpExtInst %void %ext DebugScope %dbg_main +OpStore %param_a %v4f1 +OpStore %param_b %v4f2 +%result = OpFunctionCall %v4float %add %param_a %param_b +OpStore %out_var_SV_TARGET %result +OpReturn +OpFunctionEnd +%add = OpFunction %v4float None %add_fn_type +%scope1 = OpExtInst %void %ext DebugScope %dbg_add +%a = OpFunctionParameter %_ptr_Function_v4float +%b = OpFunctionParameter %_ptr_Function_v4float +%add_bb = OpLabel +%foo = OpVariable %_ptr_Function_v4float Function +%a_val = OpLoad %v4float %a +%b_val = OpLoad %v4float %b +%res = OpFAdd %v4float %a_val %b_val +OpStore %foo %res +%decl = OpExtInst %void %ext DebugDeclare %dbg_foo %foo %null_expr +%foo_val = OpLoad %v4float %foo +OpReturnValue %foo_val +OpFunctionEnd +)"; + + SinglePassRunAndMatch<InlineExhaustivePass>(text, true); +} + +TEST_F(InlineTest, DebugDeclareMultiple) { + // Check that InlinePass correctly generates DebugDeclare instructions + // for callee function's parameters and maps them to corresponding + // local variables of caller function. + const std::string text = R"( +; CHECK: [[add:%\d+]] = OpString "add" +; CHECK: [[a:%\d+]] = OpString "a" +; CHECK: [[b:%\d+]] = OpString "b" +; CHECK: [[dbg_add:%\d+]] = OpExtInst %void [[ext:%\d+]] DebugFunction [[add]] +; CHECK: [[dbg_a:%\d+]] = OpExtInst %void [[ext]] DebugLocalVariable [[a]] +; CHECK: [[dbg_b:%\d+]] = OpExtInst %void [[ext]] DebugLocalVariable [[b]] +; CHECK: OpFunction +; CHECK-NOT: OpFunctionEnd +; CHECK: OpStore [[param_a:%\d+]] +; CHECK: OpStore [[param_b:%\d+]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_add]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugDeclare [[dbg_a]] [[param_a]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugDeclare [[dbg_b]] [[param_b]] +; CHECK: [[a_val:%\d+]] = OpLoad %v4float [[param_a]] +; CHECK: OpStore [[foo:%\d+]] [[a_val]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugValue [[dbg_a]] [[foo]] + +OpCapability Shader +%ext = OpExtInstImport "OpenCL.DebugInfo.100" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_COLOR %out_var_SV_TARGET +OpExecutionMode %main OriginUpperLeft +%file_name = OpString "ps.hlsl" +OpSource HLSL 600 %file_name +%float_name = OpString "float" +%main_name = OpString "main" +%add_name = OpString "add" +%a_name = OpString "a" +%b_name = OpString "b" +OpDecorate %in_var_COLOR Location 0 +OpDecorate %out_var_SV_TARGET Location 0 +%uint = OpTypeInt 32 0 +%uint_32 = OpConstant %uint 32 +%float = OpTypeFloat 32 +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%v4float = OpTypeVector %float 4 +%v4f1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%v4f2 = OpConstantComposite %v4float %float_2 %float_2 %float_2 %float_2 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%add_fn_type = OpTypeFunction %v4float %_ptr_Function_v4float %_ptr_Function_v4float +%void = OpTypeVoid +%void_fn_type = OpTypeFunction %void +%v4f_fn_type = OpTypeFunction %v4float +%in_var_COLOR = OpVariable %_ptr_Input_v4float Input +%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output +%null_expr = OpExtInst %void %ext DebugExpression +%src = OpExtInst %void %ext DebugSource %file_name +%cu = OpExtInst %void %ext DebugCompilationUnit 1 4 %src HLSL +%dbg_f = OpExtInst %void %ext DebugTypeBasic %float_name %uint_32 Float +%dbg_v4f = OpExtInst %void %ext DebugTypeVector %dbg_f 4 +%main_ty = OpExtInst %void %ext DebugTypeFunction FlagIsProtected|FlagIsPrivate %dbg_v4f %dbg_v4f +%add_ty = OpExtInst %void %ext DebugTypeFunction FlagIsProtected|FlagIsPrivate %dbg_v4f %dbg_v4f %dbg_v4f +%dbg_main = OpExtInst %void %ext DebugFunction %main_name %main_ty %src 5 1 %cu %main_name FlagIsProtected|FlagIsPrivate 10 %main +%dbg_add = OpExtInst %void %ext DebugFunction %add_name %add_ty %src 1 1 %cu %add_name FlagIsProtected|FlagIsPrivate 1 %add +%dbg_a = OpExtInst %void %ext DebugLocalVariable %a_name %dbg_v4f %src 1 13 %dbg_add FlagIsLocal 0 +%dbg_b = OpExtInst %void %ext DebugLocalVariable %b_name %dbg_v4f %src 1 20 %dbg_add FlagIsLocal 1 +%main = OpFunction %void None %void_fn_type +%main_bb = OpLabel +%param_a = OpVariable %_ptr_Function_v4float Function +%param_b = OpVariable %_ptr_Function_v4float Function +%scope0 = OpExtInst %void %ext DebugScope %dbg_main +OpStore %param_a %v4f1 +OpStore %param_b %v4f2 +%result = OpFunctionCall %v4float %add %param_a %param_b +OpStore %out_var_SV_TARGET %result +OpReturn +OpFunctionEnd +%add = OpFunction %v4float None %add_fn_type +%scope1 = OpExtInst %void %ext DebugScope %dbg_add +%a = OpFunctionParameter %_ptr_Function_v4float +%b = OpFunctionParameter %_ptr_Function_v4float +%decl0 = OpExtInst %void %ext DebugDeclare %dbg_a %a %null_expr +%add_bb = OpLabel +%decl1 = OpExtInst %void %ext DebugDeclare %dbg_b %b %null_expr +%foo = OpVariable %_ptr_Function_v4float Function +%a_val = OpLoad %v4float %a +OpStore %foo %a_val +%dbg_val = OpExtInst %void %ext DebugValue %dbg_a %foo %null_expr +%b_val = OpLoad %v4float %b +%res = OpFAdd %v4float %a_val %b_val +OpReturnValue %res +OpFunctionEnd +)"; + + SinglePassRunAndMatch<InlineExhaustivePass>(text, true); +} + +TEST_F(InlineTest, DebugValueForFunctionCallReturn) { + // Check that InlinePass correctly generates DebugValue instruction + // for function call's return value and maps it to a corresponding + // value in the caller function. + const std::string text = R"( +; CHECK: [[main:%\d+]] = OpString "main" +; CHECK: [[add:%\d+]] = OpString "add" +; CHECK: [[result:%\d+]] = OpString "result" +; CHECK: [[dbg_main:%\d+]] = OpExtInst %void [[ext:%\d+]] DebugFunction [[main]] +; CHECK: [[dbg_add:%\d+]] = OpExtInst %void [[ext:%\d+]] DebugFunction [[add]] +; CHECK: [[dbg_result:%\d+]] = OpExtInst %void [[ext]] DebugLocalVariable [[result]] {{%\d+}} {{%\d+}} 6 2 [[dbg_main]] +; CHECK: [[inlinedat:%\d+]] = OpExtInst %void [[ext]] DebugInlinedAt 5 +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_add]] [[inlinedat]] +; CHECK: [[a_val:%\d+]] = OpLoad %v4float +; CHECK: [[b_val:%\d+]] = OpLoad %v4float +; CHECK: [[res:%\d+]] = OpFAdd %v4float [[a_val]] [[b_val]] +; CHECK: OpStore [[new_result:%\d+]] [[res]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_main]] +; CHECK: [[result_val:%\d+]] = OpLoad %v4float [[new_result]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugValue [[dbg_result]] [[result_val]] + +OpCapability Shader +%ext = OpExtInstImport "OpenCL.DebugInfo.100" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_COLOR %out_var_SV_TARGET +OpExecutionMode %main OriginUpperLeft +%file_name = OpString "ps.hlsl" +OpSource HLSL 600 %file_name +%float_name = OpString "float" +%main_name = OpString "main" +%add_name = OpString "add" +%result_name = OpString "result" +OpDecorate %in_var_COLOR Location 0 +OpDecorate %out_var_SV_TARGET Location 0 +%uint = OpTypeInt 32 0 +%uint_32 = OpConstant %uint 32 +%float = OpTypeFloat 32 +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%v4float = OpTypeVector %float 4 +%v4f1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%v4f2 = OpConstantComposite %v4float %float_2 %float_2 %float_2 %float_2 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%add_fn_type = OpTypeFunction %v4float %_ptr_Function_v4float %_ptr_Function_v4float +%void = OpTypeVoid +%void_fn_type = OpTypeFunction %void +%v4f_fn_type = OpTypeFunction %v4float +%in_var_COLOR = OpVariable %_ptr_Input_v4float Input +%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output +%null_expr = OpExtInst %void %ext DebugExpression +%src = OpExtInst %void %ext DebugSource %file_name +%cu = OpExtInst %void %ext DebugCompilationUnit 1 4 %src HLSL +%dbg_f = OpExtInst %void %ext DebugTypeBasic %float_name %uint_32 Float +%dbg_v4f = OpExtInst %void %ext DebugTypeVector %dbg_f 4 +%main_ty = OpExtInst %void %ext DebugTypeFunction FlagIsProtected|FlagIsPrivate %dbg_v4f %dbg_v4f +%add_ty = OpExtInst %void %ext DebugTypeFunction FlagIsProtected|FlagIsPrivate %dbg_v4f %dbg_v4f %dbg_v4f +%dbg_main = OpExtInst %void %ext DebugFunction %main_name %main_ty %src 5 1 %cu %main_name FlagIsProtected|FlagIsPrivate 10 %main +%dbg_add = OpExtInst %void %ext DebugFunction %add_name %add_ty %src 1 1 %cu %add_name FlagIsProtected|FlagIsPrivate 1 %add +%dbg_result = OpExtInst %void %ext DebugLocalVariable %result_name %dbg_v4f %src 6 2 %dbg_main FlagIsLocal +%main = OpFunction %void None %void_fn_type +%main_bb = OpLabel +%param_a = OpVariable %_ptr_Function_v4float Function +%param_b = OpVariable %_ptr_Function_v4float Function +%scope0 = OpExtInst %void %ext DebugScope %dbg_main +OpStore %param_a %v4f1 +OpStore %param_b %v4f2 +%result = OpFunctionCall %v4float %add %param_a %param_b +%value = OpExtInst %void %ext DebugValue %dbg_result %result %null_expr +OpStore %out_var_SV_TARGET %result +OpReturn +OpFunctionEnd +%add = OpFunction %v4float None %add_fn_type +%scope1 = OpExtInst %void %ext DebugScope %dbg_add +%a = OpFunctionParameter %_ptr_Function_v4float +%b = OpFunctionParameter %_ptr_Function_v4float +%add_bb = OpLabel +%a_val = OpLoad %v4float %a +%b_val = OpLoad %v4float %b +%res = OpFAdd %v4float %a_val %b_val +OpReturnValue %res +OpFunctionEnd +)"; + + SinglePassRunAndMatch<InlineExhaustivePass>(text, true); +} + +TEST_F(InlineTest, NestedWithAnExistingDebugInlinedAt) { + // When a DebugScope instruction in a callee function already has a + // DebugInlinedAt information, we have to create a recursive + // DebugInlinedAt chain. See inlined_to_zoo and inlined_to_bar in + // the following code. + const std::string text = R"( +; CHECK: [[main:%\d+]] = OpString "main" +; CHECK: [[foo:%\d+]] = OpString "foo" +; CHECK: [[bar:%\d+]] = OpString "bar" +; CHECK: [[zoo:%\d+]] = OpString "zoo" +; CHECK: [[v4f1:%\d+]] = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +; CHECK: [[v4f2:%\d+]] = OpConstantComposite %v4float %float_2 %float_2 %float_2 %float_2 +; CHECK: [[v4f3:%\d+]] = OpConstantComposite %v4float %float_3 %float_3 %float_3 %float_3 +; CHECK: [[dbg_main:%\d+]] = OpExtInst %void [[ext:%\d+]] DebugFunction [[main]] +; CHECK: [[dbg_foo:%\d+]] = OpExtInst %void [[ext]] DebugFunction [[foo]] +; CHECK: [[dbg_bar:%\d+]] = OpExtInst %void [[ext]] DebugFunction [[bar]] +; CHECK: [[dbg_zoo:%\d+]] = OpExtInst %void [[ext]] DebugFunction [[zoo]] +; CHECK: [[inlined_to_main:%\d+]] = OpExtInst %void [[ext]] DebugInlinedAt 10 [[dbg_main]] +; CHECK: [[inlined_to_zoo:%\d+]] = OpExtInst %void [[ext]] DebugInlinedAt 7 [[dbg_zoo]] [[inlined_to_main]] +; CHECK: [[inlined_to_main:%\d+]] = OpExtInst %void [[ext]] DebugInlinedAt 10 [[dbg_main]] +; CHECK: [[inlined_to_bar:%\d+]] = OpExtInst %void [[ext]] DebugInlinedAt 4 [[dbg_bar]] [[inlined_to_zoo]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_foo]] [[inlined_to_bar]] +; CHECK: OpStore [[foo_ret:%\d+]] [[v4f1]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_bar]] [[inlined_to_zoo]] +; CHECK: [[foo_ret_val:%\d+]] = OpLoad %v4float [[foo_ret]] +; CHECK: [[bar_ret:%\d+]] = OpFAdd %v4float [[foo_ret_val]] [[v4f2]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_zoo]] [[inlined_to_main]] +; CHECK: [[zoo_result:%\d+]] = OpFAdd %v4float [[bar_ret]] [[v4f3]] +; CHECK: OpStore [[zoo_ret:%\d+]] [[zoo_result]] +; CHECK: {{%\d+}} = OpExtInst %void [[ext]] DebugScope [[dbg_main]] +; CHECK: [[zoo_ret_val:%\d+]] = OpLoad %v4float [[zoo_ret]] +; CHECK: {{%\d+}} = OpFAdd %v4float [[zoo_ret_val]] {{%\d+}} + +OpCapability Shader +%ext = OpExtInstImport "OpenCL.DebugInfo.100" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_COLOR %out_var_SV_TARGET +OpExecutionMode %main OriginUpperLeft +%file_name = OpString "ps.hlsl" +OpSource HLSL 600 %file_name +%float_name = OpString "float" +%main_name = OpString "main" +%foo_name = OpString "foo" +%bar_name = OpString "bar" +%zoo_name = OpString "zoo" +OpDecorate %in_var_COLOR Location 0 +OpDecorate %out_var_SV_TARGET Location 0 +%uint = OpTypeInt 32 0 +%uint_32 = OpConstant %uint 32 +%float = OpTypeFloat 32 +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%float_3 = OpConstant %float 3 +%v4float = OpTypeVector %float 4 +%v4f1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%v4f2 = OpConstantComposite %v4float %float_2 %float_2 %float_2 %float_2 +%v4f3 = OpConstantComposite %v4float %float_3 %float_3 %float_3 %float_3 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%void = OpTypeVoid +%void_fn_type = OpTypeFunction %void +%v4f_fn_type = OpTypeFunction %v4float +%in_var_COLOR = OpVariable %_ptr_Input_v4float Input +%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output +%src = OpExtInst %void %ext DebugSource %file_name +%cu = OpExtInst %void %ext DebugCompilationUnit 1 4 %src HLSL +%dbg_f = OpExtInst %void %ext DebugTypeBasic %float_name %uint_32 Float +%dbg_v4f = OpExtInst %void %ext DebugTypeVector %dbg_f 4 +%main_ty = OpExtInst %void %ext DebugTypeFunction FlagIsProtected|FlagIsPrivate %dbg_v4f %dbg_v4f +%foo_ty = OpExtInst %void %ext DebugTypeFunction FlagIsProtected|FlagIsPrivate %dbg_v4f +%dbg_main = OpExtInst %void %ext DebugFunction %main_name %main_ty %src 10 1 %cu %main_name FlagIsProtected|FlagIsPrivate 10 %main +%dbg_foo = OpExtInst %void %ext DebugFunction %foo_name %foo_ty %src 1 1 %cu %foo_name FlagIsProtected|FlagIsPrivate 1 %foo +%dbg_bar = OpExtInst %void %ext DebugFunction %bar_name %foo_ty %src 4 1 %cu %bar_name FlagIsProtected|FlagIsPrivate 4 %bar +%dbg_zoo = OpExtInst %void %ext DebugFunction %zoo_name %foo_ty %src 7 1 %cu %zoo_name FlagIsProtected|FlagIsPrivate 7 %zoo +%inlined_to_zoo = OpExtInst %void %ext DebugInlinedAt 7 %dbg_zoo +%main = OpFunction %void None %void_fn_type +%main_bb = OpLabel +%scope0 = OpExtInst %void %ext DebugScope %dbg_main +%zoo_val = OpFunctionCall %v4float %zoo +%color = OpLoad %v4float %in_var_COLOR +%result = OpFAdd %v4float %zoo_val %color +OpStore %out_var_SV_TARGET %result +OpReturn +OpFunctionEnd +%foo = OpFunction %v4float None %v4f_fn_type +%scope1 = OpExtInst %void %ext DebugScope %dbg_foo +%foo_bb = OpLabel +OpReturnValue %v4f1 +OpFunctionEnd +%zoo = OpFunction %v4float None %v4f_fn_type +%scope3 = OpExtInst %void %ext DebugScope %dbg_zoo +%zoo_bb = OpLabel +%scope2 = OpExtInst %void %ext DebugScope %dbg_bar %inlined_to_zoo +%foo_val = OpFunctionCall %v4float %foo +%bar_val = OpFAdd %v4float %foo_val %v4f2 +%scope4 = OpExtInst %void %ext DebugScope %dbg_zoo +%zoo_ret = OpFAdd %v4float %bar_val %v4f3 +OpReturnValue %zoo_ret +OpFunctionEnd +%bar = OpFunction %v4float None %v4f_fn_type +%scope5 = OpExtInst %void %ext DebugScope %dbg_bar +%bar_bb = OpLabel +%foo_val0 = OpFunctionCall %v4float %foo +%bar_ret = OpFAdd %v4float %foo_val0 %v4f2 +OpReturnValue %bar_ret +OpFunctionEnd +)"; + + SinglePassRunAndMatch<InlineExhaustivePass>(text, true); +} + // TODO(greg-lunarg): Add tests to verify handling of these cases: // // Empty modules
diff --git a/test/opt/instruction_test.cpp b/test/opt/instruction_test.cpp index 1995c5b..c5b92ef 100644 --- a/test/opt/instruction_test.cpp +++ b/test/opt/instruction_test.cpp
@@ -29,8 +29,8 @@ namespace opt { namespace { -using spvtest::MakeInstruction; using ::testing::Eq; +using spvtest::MakeInstruction; using DescriptorTypeTest = PassTest<::testing::Test>; using OpaqueTypeTest = PassTest<::testing::Test>; using GetBaseTest = PassTest<::testing::Test>; @@ -74,6 +74,18 @@ EXPECT_EQ("abcde", operand.AsString()); } +TEST(InstructionTest, OperandAsLiteralUint64_32bits) { + Operand::OperandData words{0x1234}; + Operand operand(SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, std::move(words)); + EXPECT_EQ(uint64_t(0x1234), operand.AsLiteralUint64()); +} + +TEST(InstructionTest, OperandAsLiteralUint64_64bits) { + Operand::OperandData words{0x1234, 0x89ab}; + Operand operand(SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, std::move(words)); + EXPECT_EQ((uint64_t(0x89ab) << 32 | 0x1234), operand.AsLiteralUint64()); +} + // The words for an OpTypeInt for 32-bit signed integer resulting in Id 44. uint32_t kSampleInstructionWords[] = {(4 << 16) | uint32_t(SpvOpTypeInt), 44, 32, 1}; @@ -327,6 +339,7 @@ %3 = OpVariable %8 UniformConstant %2 = OpFunction %4 None %5 %9 = OpLabel + %10 = OpCopyObject %8 %3 OpReturn OpFunctionEnd )"; @@ -341,7 +354,10 @@ EXPECT_FALSE(type->IsVulkanUniformBuffer()); Instruction* variable = context->get_def_use_mgr()->GetDef(3); - EXPECT_FALSE(variable->IsReadOnlyVariable()); + EXPECT_FALSE(variable->IsReadOnlyPointer()); + + Instruction* object_copy = context->get_def_use_mgr()->GetDef(10); + EXPECT_FALSE(object_copy->IsReadOnlyPointer()); } TEST_F(DescriptorTypeTest, SampledImage) { @@ -363,6 +379,7 @@ %3 = OpVariable %8 UniformConstant %2 = OpFunction %4 None %5 %9 = OpLabel + %10 = OpCopyObject %8 %3 OpReturn OpFunctionEnd )"; @@ -377,7 +394,10 @@ EXPECT_FALSE(type->IsVulkanUniformBuffer()); Instruction* variable = context->get_def_use_mgr()->GetDef(3); - EXPECT_TRUE(variable->IsReadOnlyVariable()); + EXPECT_TRUE(variable->IsReadOnlyPointer()); + + Instruction* object_copy = context->get_def_use_mgr()->GetDef(10); + EXPECT_TRUE(object_copy->IsReadOnlyPointer()); } TEST_F(DescriptorTypeTest, StorageTexelBuffer) { @@ -399,6 +419,7 @@ %3 = OpVariable %8 UniformConstant %2 = OpFunction %4 None %5 %9 = OpLabel + %10 = OpCopyObject %8 %3 OpReturn OpFunctionEnd )"; @@ -413,7 +434,10 @@ EXPECT_FALSE(type->IsVulkanUniformBuffer()); Instruction* variable = context->get_def_use_mgr()->GetDef(3); - EXPECT_FALSE(variable->IsReadOnlyVariable()); + EXPECT_FALSE(variable->IsReadOnlyPointer()); + + Instruction* object_copy = context->get_def_use_mgr()->GetDef(10); + EXPECT_FALSE(object_copy->IsReadOnlyPointer()); } TEST_F(DescriptorTypeTest, StorageBuffer) { @@ -438,6 +462,7 @@ %3 = OpVariable %10 Uniform %2 = OpFunction %4 None %5 %11 = OpLabel + %12 = OpCopyObject %8 %3 OpReturn OpFunctionEnd )"; @@ -452,7 +477,10 @@ EXPECT_FALSE(type->IsVulkanUniformBuffer()); Instruction* variable = context->get_def_use_mgr()->GetDef(3); - EXPECT_FALSE(variable->IsReadOnlyVariable()); + EXPECT_FALSE(variable->IsReadOnlyPointer()); + + Instruction* object_copy = context->get_def_use_mgr()->GetDef(12); + EXPECT_FALSE(object_copy->IsReadOnlyPointer()); } TEST_F(DescriptorTypeTest, UniformBuffer) { @@ -477,6 +505,7 @@ %3 = OpVariable %10 Uniform %2 = OpFunction %4 None %5 %11 = OpLabel + %12 = OpCopyObject %10 %3 OpReturn OpFunctionEnd )"; @@ -491,7 +520,10 @@ EXPECT_TRUE(type->IsVulkanUniformBuffer()); Instruction* variable = context->get_def_use_mgr()->GetDef(3); - EXPECT_TRUE(variable->IsReadOnlyVariable()); + EXPECT_TRUE(variable->IsReadOnlyPointer()); + + Instruction* object_copy = context->get_def_use_mgr()->GetDef(12); + EXPECT_TRUE(object_copy->IsReadOnlyPointer()); } TEST_F(DescriptorTypeTest, NonWritableIsReadOnly) { @@ -517,6 +549,7 @@ %3 = OpVariable %10 Uniform %2 = OpFunction %4 None %5 %11 = OpLabel + %12 = OpCopyObject %8 %3 OpReturn OpFunctionEnd )"; @@ -524,7 +557,107 @@ std::unique_ptr<IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); Instruction* variable = context->get_def_use_mgr()->GetDef(3); - EXPECT_TRUE(variable->IsReadOnlyVariable()); + EXPECT_TRUE(variable->IsReadOnlyPointer()); + + // This demonstrates that the check for whether a pointer is read-only is not + // precise: copying a NonWritable-decorated variable can yield a pointer that + // the check does not regard as read-only. + Instruction* object_copy = context->get_def_use_mgr()->GetDef(12); + EXPECT_FALSE(object_copy->IsReadOnlyPointer()); +} + +TEST_F(DescriptorTypeTest, AccessChainIntoReadOnlyStructIsReadOnly) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 320 + OpMemberDecorate %3 0 Offset 0 + OpMemberDecorate %3 1 Offset 4 + OpDecorate %3 Block + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFloat 32 + %3 = OpTypeStruct %6 %8 + %9 = OpTypePointer PushConstant %3 + %10 = OpVariable %9 PushConstant + %11 = OpConstant %6 0 + %12 = OpTypePointer PushConstant %6 + %13 = OpConstant %6 1 + %14 = OpTypePointer PushConstant %8 + %2 = OpFunction %4 None %5 + %15 = OpLabel + %16 = OpVariable %7 Function + %17 = OpAccessChain %12 %10 %11 + %18 = OpAccessChain %14 %10 %13 + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + + Instruction* push_constant_struct_variable = + context->get_def_use_mgr()->GetDef(10); + EXPECT_TRUE(push_constant_struct_variable->IsReadOnlyPointer()); + + Instruction* push_constant_struct_field_0 = + context->get_def_use_mgr()->GetDef(17); + EXPECT_TRUE(push_constant_struct_field_0->IsReadOnlyPointer()); + + Instruction* push_constant_struct_field_1 = + context->get_def_use_mgr()->GetDef(18); + EXPECT_TRUE(push_constant_struct_field_1->IsReadOnlyPointer()); +} + +TEST_F(DescriptorTypeTest, ReadOnlyPointerParameter) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 320 + OpMemberDecorate %3 0 Offset 0 + OpMemberDecorate %3 1 Offset 4 + OpDecorate %3 Block + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFloat 32 + %3 = OpTypeStruct %6 %8 + %9 = OpTypePointer PushConstant %3 + %10 = OpVariable %9 PushConstant + %11 = OpConstant %6 0 + %12 = OpTypePointer PushConstant %6 + %13 = OpConstant %6 1 + %14 = OpTypePointer PushConstant %8 + %15 = OpTypeFunction %4 %9 + %2 = OpFunction %4 None %5 + %16 = OpLabel + %17 = OpVariable %7 Function + %18 = OpAccessChain %12 %10 %11 + %19 = OpAccessChain %14 %10 %13 + OpReturn + OpFunctionEnd + %20 = OpFunction %4 None %15 + %21 = OpFunctionParameter %9 + %22 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + + Instruction* push_constant_struct_parameter = + context->get_def_use_mgr()->GetDef(21); + EXPECT_TRUE(push_constant_struct_parameter->IsReadOnlyPointer()); } TEST_F(OpaqueTypeTest, BaseOpaqueTypesShader) {
diff --git a/test/opt/ir_context_test.cpp b/test/opt/ir_context_test.cpp index d5710fc..e72561c 100644 --- a/test/opt/ir_context_test.cpp +++ b/test/opt/ir_context_test.cpp
@@ -19,6 +19,7 @@ #include <string> #include <utility> +#include "OpenCLDebugInfo100.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "source/opt/pass.h" @@ -372,6 +373,123 @@ EXPECT_TRUE(context->annotations().empty()); } +TEST_F(IRContextTest, KillFunctionFromDebugFunction) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "OpenCL.DebugInfo.100" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + %3 = OpString "ps.hlsl" + %4 = OpString "foo" + OpSource HLSL 600 + %void = OpTypeVoid + %6 = OpTypeFunction %void + %7 = OpExtInst %void %1 DebugSource %3 + %8 = OpExtInst %void %1 DebugCompilationUnit 1 4 %7 HLSL + %9 = OpExtInst %void %1 DebugTypeFunction FlagIsProtected|FlagIsPrivate %void + %10 = OpExtInst %void %1 DebugFunction %4 %9 %7 1 1 %8 %4 FlagIsProtected|FlagIsPrivate 1 %11 + %2 = OpFunction %void None %6 + %12 = OpLabel + OpReturn + OpFunctionEnd + %11 = OpFunction %void None %6 + %13 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + + // Delete the second variable. + context->KillDef(11); + + // Get DebugInfoNone id. + uint32_t debug_info_none_id = 0; + for (auto it = context->ext_inst_debuginfo_begin(); + it != context->ext_inst_debuginfo_end(); ++it) { + if (it->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugInfoNone) { + debug_info_none_id = it->result_id(); + } + } + EXPECT_NE(0, debug_info_none_id); + + // Check the Function operand of DebugFunction is DebugInfoNone. + const uint32_t kDebugFunctionOperandFunctionIndex = 13; + bool checked = false; + for (auto it = context->ext_inst_debuginfo_begin(); + it != context->ext_inst_debuginfo_end(); ++it) { + if (it->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugFunction) { + EXPECT_FALSE(checked); + EXPECT_EQ(it->GetOperand(kDebugFunctionOperandFunctionIndex).words[0], + debug_info_none_id); + checked = true; + } + } + EXPECT_TRUE(checked); +} + +TEST_F(IRContextTest, KillVariableFromDebugGlobalVariable) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "OpenCL.DebugInfo.100" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + %3 = OpString "ps.hlsl" + %4 = OpString "foo" + %5 = OpString "int" + OpSource HLSL 600 + %uint = OpTypeInt 32 0 + %uint_32 = OpConstant %uint 32 +%_ptr_Private_uint = OpTypePointer Private %uint + %void = OpTypeVoid + %10 = OpTypeFunction %void + %11 = OpVariable %_ptr_Private_uint Private + %12 = OpExtInst %void %1 DebugSource %3 + %13 = OpExtInst %void %1 DebugCompilationUnit 1 4 %12 HLSL + %14 = OpExtInst %void %1 DebugTypeBasic %5 %uint_32 Signed + %15 = OpExtInst %void %1 DebugGlobalVariable %4 %14 %12 1 12 %13 %4 %11 FlagIsDefinition + %2 = OpFunction %void None %10 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + + // Delete the second variable. + context->KillDef(11); + + // Get DebugInfoNone id. + uint32_t debug_info_none_id = 0; + for (auto it = context->ext_inst_debuginfo_begin(); + it != context->ext_inst_debuginfo_end(); ++it) { + if (it->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugInfoNone) { + debug_info_none_id = it->result_id(); + } + } + EXPECT_NE(0, debug_info_none_id); + + // Check the Function operand of DebugFunction is DebugInfoNone. + const uint32_t kDebugGlobalVariableOperandVariableIndex = 11; + bool checked = false; + for (auto it = context->ext_inst_debuginfo_begin(); + it != context->ext_inst_debuginfo_end(); ++it) { + if (it->GetOpenCL100DebugOpcode() == + OpenCLDebugInfo100DebugGlobalVariable) { + EXPECT_FALSE(checked); + EXPECT_EQ( + it->GetOperand(kDebugGlobalVariableOperandVariableIndex).words[0], + debug_info_none_id); + checked = true; + } + } + EXPECT_TRUE(checked); +} + TEST_F(IRContextTest, BasicVisitFromEntryPoint) { // Make sure we visit the entry point, and the function it calls. // Do not visit Dead or Exported.
diff --git a/test/opt/ir_loader_test.cpp b/test/opt/ir_loader_test.cpp index 50e3a08..2104492 100644 --- a/test/opt/ir_loader_test.cpp +++ b/test/opt/ir_loader_test.cpp
@@ -948,6 +948,42 @@ EXPECT_EQ(text, disassembled_text); } +TEST(IrBuilder, DebugInfoForTerminationInsts) { + // Check that DebugScope instructions for termination instructions are + // preserved. + DoRoundTripCheck(R"(OpCapability Shader +%1 = OpExtInstImport "OpenCL.DebugInfo.100" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%3 = OpString "simple_vs.hlsl" +OpSource HLSL 600 %3 +OpName %main "main" +%void = OpTypeVoid +%5 = OpTypeFunction %void +%6 = OpExtInst %void %1 DebugSource %3 +%7 = OpExtInst %void %1 DebugCompilationUnit 2 4 %6 HLSL +%main = OpFunction %void None %5 +%8 = OpLabel +%20 = OpExtInst %void %1 DebugScope %7 +OpBranch %10 +%21 = OpExtInst %void %1 DebugNoScope +%10 = OpLabel +%22 = OpExtInst %void %1 DebugScope %7 +OpKill +%23 = OpExtInst %void %1 DebugNoScope +%14 = OpLabel +%24 = OpExtInst %void %1 DebugScope %7 +OpUnreachable +%25 = OpExtInst %void %1 DebugNoScope +%17 = OpLabel +%26 = OpExtInst %void %1 DebugScope %7 +OpReturn +%27 = OpExtInst %void %1 DebugNoScope +OpFunctionEnd +)"); +} + TEST(IrBuilder, LocalGlobalVariables) { // #version 310 es //
diff --git a/test/opt/local_ssa_elim_test.cpp b/test/opt/local_ssa_elim_test.cpp index 7afbb4c..d29a554 100644 --- a/test/opt/local_ssa_elim_test.cpp +++ b/test/opt/local_ssa_elim_test.cpp
@@ -1998,6 +1998,32 @@ EXPECT_EQ(Pass::Status::Failure, std::get<1>(result)); } +TEST_F(LocalSSAElimTest, OpConstantNull) { + const std::string text = R"( +OpCapability Addresses +OpCapability Kernel +OpCapability Int64 +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %4 "A" +OpSource OpenCL_C 200000 +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%6 = OpTypeInt 32 0 +%11 = OpTypePointer CrossWorkgroup %6 +%16 = OpConstantNull %11 +%20 = OpConstant %6 269484031 +%4 = OpFunction %2 None %3 +%17 = OpLabel +%18 = OpLoad %6 %16 Aligned 536870912 +%19 = OpBitwiseXor %6 %18 %20 +OpStore %16 %19 Aligned 536870912 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunToBinary<SSARewritePass>(text, false); +} + // TODO(greg-lunarg): Add tests to verify handling of these cases: // // No optimization in the presence of
diff --git a/test/opt/loop_optimizations/loop_descriptions.cpp b/test/opt/loop_optimizations/loop_descriptions.cpp index 91dbdc6..4d2f989 100644 --- a/test/opt/loop_optimizations/loop_descriptions.cpp +++ b/test/opt/loop_optimizations/loop_descriptions.cpp
@@ -379,6 +379,43 @@ EXPECT_EQ(loop.GetLatchBlock()->id(), 30u); } +TEST_F(PassClassTest, UnreachableMerge) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %3 = OpTypeFunction %void + %1 = OpFunction %void None %3 + %4 = OpLabel + OpBranch %5 + %5 = OpLabel + OpLoopMerge %6 %7 None + OpBranch %8 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpBranch %7 + %7 = OpLabel + OpBranch %5 + %6 = OpLabel + OpUnreachable + OpFunctionEnd +)"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 1); + LoopDescriptor ld{context.get(), f}; + + EXPECT_EQ(ld.NumLoops(), 1u); +} + } // namespace } // namespace opt } // namespace spvtools
diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp index d16b65c..e55a48f 100644 --- a/test/opt/pass_merge_return_test.cpp +++ b/test/opt/pass_merge_return_test.cpp
@@ -1970,6 +1970,113 @@ EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); } +TEST_F(MergeReturnPassTest, SingleReturnInMiddle) { + const std::string before = + R"( +; CHECK: OpFunction +; CHECK: OpReturn +; CHECK-NEXT: OpFunctionEnd + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpName %main "main" + OpName %foo_ "foo(" + %void = OpTypeVoid + %4 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %foo_ = OpFunction %void None %4 + %7 = OpLabel + OpSelectionMerge %8 None + OpBranchConditional %true %9 %8 + %8 = OpLabel + OpReturn + %9 = OpLabel + OpBranch %8 + OpFunctionEnd + %main = OpFunction %void None %4 + %10 = OpLabel + %11 = OpFunctionCall %void %foo_ + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch<MergeReturnPass>(before, false); +} + +TEST_F(MergeReturnPassTest, PhiWithTooManyEntries) { + // Check that the OpPhi node has the correct number of entries. This is + // checked by doing validation with the match. + const std::string before = + R"( +; CHECK: OpLoopMerge [[merge:%\w+]] +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: {{%\w+}} = OpPhi %int {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + %void = OpTypeVoid + %4 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %6 = OpTypeFunction %int + %bool = OpTypeBool + %int_1 = OpConstant %int 1 + %false = OpConstantFalse %bool + %2 = OpFunction %void None %4 + %10 = OpLabel + %11 = OpFunctionCall %int %12 + OpReturn + OpFunctionEnd + %12 = OpFunction %int None %6 + %13 = OpLabel + OpBranch %14 + %14 = OpLabel + %15 = OpPhi %int %int_1 %13 %16 %17 + OpLoopMerge %18 %17 None + OpBranch %19 + %19 = OpLabel + %20 = OpUndef %bool + OpBranch %21 + %21 = OpLabel + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + OpSelectionMerge %25 None + OpBranchConditional %20 %22 %25 + %25 = OpLabel + OpReturnValue %int_1 + %23 = OpLabel + OpBranch %21 + %22 = OpLabel + OpSelectionMerge %26 None + OpBranchConditional %20 %27 %26 + %27 = OpLabel + OpBranch %28 + %28 = OpLabel + OpLoopMerge %29 %30 None + OpBranch %31 + %31 = OpLabel + OpReturnValue %int_1 + %30 = OpLabel + OpBranch %28 + %29 = OpLabel + OpUnreachable + %26 = OpLabel + OpBranch %17 + %17 = OpLabel + %16 = OpPhi %int %15 %26 + OpBranchConditional %false %14 %18 + %18 = OpLabel + OpReturnValue %16 + OpFunctionEnd +)"; + + SinglePassRunAndMatch<MergeReturnPass>(before, true); +} + } // namespace } // namespace opt } // namespace spvtools
diff --git a/test/opt/struct_cfg_analysis_test.cpp b/test/opt/struct_cfg_analysis_test.cpp index 0451a8b..2d1deec 100644 --- a/test/opt/struct_cfg_analysis_test.cpp +++ b/test/opt/struct_cfg_analysis_test.cpp
@@ -1369,6 +1369,35 @@ auto c = analysis.FindFuncsCalledFromContinue(); EXPECT_THAT(c, UnorderedElementsAre(14u, 16u, 21u)); } + +TEST_F(StructCFGAnalysisTest, SingleBlockLoop) { + const std::string text = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + %void = OpTypeVoid + %bool = OpTypeBool + %undef = OpUndef %bool + %void_fn = OpTypeFunction %void + %main = OpFunction %void None %void_fn + %2 = OpLabel + OpBranch %3 + %3 = OpLabel + OpLoopMerge %4 %3 None + OpBranchConditional %undef %3 %4 + %4 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + + StructuredCFGAnalysis analysis(context.get()); + + EXPECT_TRUE(analysis.IsInContinueConstruct(3)); +} } // namespace } // namespace opt } // namespace spvtools
diff --git a/test/opt/value_table_test.cpp b/test/opt/value_table_test.cpp index a0942cc..76e7f73 100644 --- a/test/opt/value_table_test.cpp +++ b/test/opt/value_table_test.cpp
@@ -684,6 +684,50 @@ vtable.GetValueNumber(inst); } +TEST_F(ValueTableTest, RedundantSampledImageLoad) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %gl_FragColor + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + OpName %tex0 "tex0" + OpName %gl_FragColor "gl_FragColor" + OpDecorate %tex0 Location 0 + OpDecorate %tex0 DescriptorSet 0 + OpDecorate %tex0 Binding 0 + OpDecorate %gl_FragColor Location 0 + %void = OpTypeVoid + %6 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %9 = OpTypeImage %float 2D 0 0 0 1 Unknown + %10 = OpTypeSampledImage %9 +%_ptr_UniformConstant_10 = OpTypePointer UniformConstant %10 + %tex0 = OpVariable %_ptr_UniformConstant_10 UniformConstant +%_ptr_Output_v4float = OpTypePointer Output %v4float + %13 = OpConstantNull %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output + %14 = OpUndef %v4float + %main = OpFunction %void None %6 + %15 = OpLabel + %16 = OpLoad %10 %tex0 + %17 = OpImageSampleProjImplicitLod %v4float %16 %13 + %18 = OpImageSampleProjImplicitLod %v4float %16 %13 + %19 = OpFAdd %v4float %18 %17 + OpStore %gl_FragColor %19 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* load1 = context->get_def_use_mgr()->GetDef(17); + Instruction* load2 = context->get_def_use_mgr()->GetDef(18); + EXPECT_EQ(vtable.GetValueNumber(load1), vtable.GetValueNumber(load2)); +} + } // namespace } // namespace opt } // namespace spvtools
diff --git a/test/opt/wrap_opkill_test.cpp b/test/opt/wrap_opkill_test.cpp index d50af28..33e52f0 100644 --- a/test/opt/wrap_opkill_test.cpp +++ b/test/opt/wrap_opkill_test.cpp
@@ -513,6 +513,139 @@ EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } +TEST_F(WrapOpKillTest, SetParentBlock) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%undef = OpUndef %bool +%void_fn = OpTypeFunction %void +%main = OpFunction %void None %void_fn +%entry = OpLabel +OpBranch %loop +%loop = OpLabel +OpLoopMerge %merge %continue None +OpBranchConditional %undef %merge %continue +%continue = OpLabel +%call = OpFunctionCall %void %kill_func +OpBranch %loop +%merge = OpLabel +OpReturn +OpFunctionEnd +%kill_func = OpFunction %void None %void_fn +%kill_entry = OpLabel +OpKill +OpFunctionEnd +)"; + + auto result = SinglePassRunToBinary<WrapOpKill>(text, true); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); + result = SinglePassRunToBinary<WrapOpKill>(text, true); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); +} + +TEST_F(WrapOpKillTest, KillInSingleBlockLoop) { + const std::string text = R"( +; CHECK: OpFunction %void +; CHECK: OpFunction %void +; CHECK-NOT: OpKill +; CHECK: OpFunctionCall %void [[new_kill:%\w+]] +; CHECK-NOT: OpKill +; CHECK: [[new_kill]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpKill +; CHECK-NEXT: OpFunctionEnd + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + %void = OpTypeVoid + %bool = OpTypeBool + %undef = OpUndef %bool + %void_fn = OpTypeFunction %void + %main = OpFunction %void None %void_fn +%main_entry = OpLabel + OpBranch %loop + %loop = OpLabel + %call = OpFunctionCall %void %sub + OpLoopMerge %exit %loop None + OpBranchConditional %undef %loop %exit + %exit = OpLabel + OpReturn + OpFunctionEnd + %sub = OpFunction %void None %void_fn + %sub_entry = OpLabel + OpSelectionMerge %ret None + OpBranchConditional %undef %kill %ret + %kill = OpLabel + OpKill + %ret = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch<WrapOpKill>(text, true); +} + +TEST_F(WrapOpKillTest, DebugInfoSimple) { + const std::string text = R"( +; CHECK: OpEntryPoint Fragment [[main:%\w+]] +; CHECK: [[main]] = OpFunction +; CHECK: OpFunctionCall %void [[orig_kill:%\w+]] +; CHECK: [[orig_kill]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: {{%\d+}} = OpExtInst %void [[ext:%\d+]] DebugScope +; CHECK-NEXT: OpLine [[file:%\d+]] 100 200 +; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]] +; CHECK-NEXT: {{%\d+}} = OpExtInst %void [[ext]] DebugNoScope +; CHECK-NEXT: OpReturn +; CHECK: [[new_kill]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpKill +; CHECK-NEXT: OpFunctionEnd + OpCapability Shader + %1 = OpExtInstImport "OpenCL.DebugInfo.100" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + %2 = OpString "File name" + OpSource GLSL 330 + OpName %main "main" + %void = OpTypeVoid + %5 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %3 = OpExtInst %void %1 DebugSource %2 + %4 = OpExtInst %void %1 DebugCompilationUnit 0 0 %3 GLSL + %main = OpFunction %void None %5 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpLoopMerge %10 %11 None + OpBranch %12 + %12 = OpLabel + OpBranchConditional %true %13 %10 + %13 = OpLabel + OpBranch %11 + %11 = OpLabel + %14 = OpFunctionCall %void %kill_ + OpBranch %9 + %10 = OpLabel + OpReturn + OpFunctionEnd + %kill_ = OpFunction %void None %5 + %15 = OpLabel + %16 = OpExtInst %void %1 DebugScope %4 + OpLine %2 100 200 + OpKill + OpFunctionEnd + )"; + + SinglePassRunAndMatch<WrapOpKill>(text, true); +} + } // namespace } // namespace opt } // namespace spvtools
diff --git a/test/reduce/CMakeLists.txt b/test/reduce/CMakeLists.txt index b19bba4..652f0ab 100644 --- a/test/reduce/CMakeLists.txt +++ b/test/reduce/CMakeLists.txt
@@ -24,7 +24,8 @@ remove_block_test.cpp remove_function_test.cpp remove_selection_test.cpp - remove_unreferenced_instruction_test.cpp + remove_unused_instruction_test.cpp + remove_unused_struct_member_test.cpp structured_loop_to_selection_test.cpp validation_during_reduction_test.cpp conditional_branch_to_simple_conditional_branch_test.cpp
diff --git a/test/reduce/reducer_test.cpp b/test/reduce/reducer_test.cpp index 59f2803..0de5af1 100644 --- a/test/reduce/reducer_test.cpp +++ b/test/reduce/reducer_test.cpp
@@ -16,7 +16,7 @@ #include "source/opt/build_module.h" #include "source/reduce/operand_to_const_reduction_opportunity_finder.h" -#include "source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.h" +#include "source/reduce/remove_unused_instruction_reduction_opportunity_finder.h" #include "test/reduce/reduce_test_util.h" namespace spvtools { @@ -157,35 +157,17 @@ OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 - OpEntryPoint Fragment %4 "main" %60 + OpEntryPoint Fragment %4 "main" OpExecutionMode %4 OriginUpperLeft - OpMemberDecorate %16 0 Offset 0 - OpDecorate %16 Block - OpDecorate %18 DescriptorSet 0 - OpDecorate %18 Binding 2 - OpMemberDecorate %25 0 Offset 0 - OpDecorate %25 Block - OpDecorate %27 DescriptorSet 0 - OpDecorate %27 Binding 1 - OpDecorate %60 Location 0 %2 = OpTypeVoid %3 = OpTypeFunction %2 %6 = OpTypeInt 32 1 %9 = OpConstant %6 0 - %16 = OpTypeStruct %6 - %17 = OpTypePointer Uniform %16 - %18 = OpVariable %17 Uniform %22 = OpTypeBool %100 = OpConstantTrue %22 %24 = OpTypeFloat 32 - %25 = OpTypeStruct %24 - %26 = OpTypePointer Uniform %25 - %27 = OpVariable %26 Uniform %31 = OpConstant %24 2 %56 = OpConstant %6 1 - %58 = OpTypeVector %24 4 - %59 = OpTypePointer Output %58 - %60 = OpVariable %59 Output %72 = OpUndef %24 %74 = OpUndef %6 %4 = OpFunction %2 None %3 @@ -218,8 +200,7 @@ return ping_pong_interesting.IsInteresting(binary); }); reducer.AddReductionPass( - MakeUnique<RemoveUnreferencedInstructionReductionOpportunityFinder>( - false)); + MakeUnique<RemoveUnusedInstructionReductionOpportunityFinder>(false)); reducer.AddReductionPass( MakeUnique<OperandToConstReductionOpportunityFinder>());
diff --git a/test/reduce/remove_unreferenced_instruction_test.cpp b/test/reduce/remove_unused_instruction_test.cpp similarity index 62% rename from test/reduce/remove_unreferenced_instruction_test.cpp rename to test/reduce/remove_unused_instruction_test.cpp index 3caf88c..68bc601 100644 --- a/test/reduce/remove_unreferenced_instruction_test.cpp +++ b/test/reduce/remove_unused_instruction_test.cpp
@@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "source/reduce/remove_unreferenced_instruction_reduction_opportunity_finder.h" +#include "source/reduce/remove_unused_instruction_reduction_opportunity_finder.h" #include "source/opt/build_module.h" #include "source/reduce/reduction_opportunity.h" @@ -25,11 +25,11 @@ const spv_target_env kEnv = SPV_ENV_UNIVERSAL_1_3; -TEST(RemoveUnreferencedInstructionReductionPassTest, RemoveStores) { +TEST(RemoveUnusedInstructionReductionPassTest, RemoveStores) { // A module with some unused instructions, including some unused OpStore // instructions. - RemoveUnreferencedInstructionReductionOpportunityFinder finder(true); + RemoveUnusedInstructionReductionOpportunityFinder finder(true); const std::string original = R"( OpCapability Shader @@ -223,11 +223,11 @@ ASSERT_EQ(0, ops.size()); } -TEST(RemoveUnreferencedInstructionReductionPassTest, Referenced) { +TEST(RemoveUnusedInstructionReductionPassTest, Referenced) { // A module with some unused global variables, constants, and types. Some will // not be removed initially because of the OpDecorate instructions. - RemoveUnreferencedInstructionReductionOpportunityFinder finder(true); + RemoveUnusedInstructionReductionOpportunityFinder finder(true); const std::string shader = R"( OpCapability Shader @@ -375,6 +375,189 @@ ASSERT_EQ(0, ops.size()); } +TEST(RemoveUnusedResourceVariableTest, RemoveUnusedResourceVariables) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" + OpExecutionMode %4 LocalSize 1 1 1 + OpMemberDecorate %9 0 Offset 0 + OpDecorate %9 Block + OpDecorate %11 DescriptorSet 0 + OpDecorate %11 Binding 1 + OpMemberDecorate %16 0 Offset 0 + OpMemberDecorate %16 1 Offset 4 + OpDecorate %16 Block + OpDecorate %18 DescriptorSet 0 + OpDecorate %18 Binding 0 + OpMemberDecorate %19 0 Offset 0 + OpDecorate %19 BufferBlock + OpDecorate %21 DescriptorSet 1 + OpDecorate %21 Binding 0 + OpMemberDecorate %22 0 Offset 0 + OpDecorate %22 Block + OpDecorate %29 DescriptorSet 1 + OpDecorate %29 Binding 1 + OpDecorate %32 DescriptorSet 1 + OpDecorate %32 Binding 2 + OpDecorate %32 NonReadable + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %9 = OpTypeStruct %6 + %10 = OpTypePointer Uniform %9 + %11 = OpVariable %10 Uniform + %13 = OpTypePointer Uniform %6 + %16 = OpTypeStruct %6 %6 + %17 = OpTypePointer Uniform %16 + %18 = OpVariable %17 Uniform + %19 = OpTypeStruct %6 + %20 = OpTypePointer Uniform %19 + %21 = OpVariable %20 Uniform + %22 = OpTypeStruct %6 + %23 = OpTypePointer PushConstant %22 + %24 = OpVariable %23 PushConstant + %25 = OpTypeFloat 32 + %26 = OpTypeImage %25 2D 0 0 0 1 Unknown + %27 = OpTypeSampledImage %26 + %28 = OpTypePointer UniformConstant %27 + %29 = OpVariable %28 UniformConstant + %30 = OpTypeImage %25 2D 0 0 0 2 Unknown + %31 = OpTypePointer UniformConstant %30 + %32 = OpVariable %31 UniformConstant + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, shader, kReduceAssembleOption); + + auto ops = RemoveUnusedInstructionReductionOpportunityFinder(true) + .GetAvailableOpportunities(context.get()); + ASSERT_EQ(7, ops.size()); + + for (auto& op : ops) { + ASSERT_TRUE(op->PreconditionHolds()); + op->TryToApply(); + } + + std::string expected_1 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" + OpExecutionMode %4 LocalSize 1 1 1 + OpMemberDecorate %9 0 Offset 0 + OpDecorate %9 Block + OpMemberDecorate %16 0 Offset 0 + OpMemberDecorate %16 1 Offset 4 + OpDecorate %16 Block + OpMemberDecorate %19 0 Offset 0 + OpDecorate %19 BufferBlock + OpMemberDecorate %22 0 Offset 0 + OpDecorate %22 Block + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %9 = OpTypeStruct %6 + %10 = OpTypePointer Uniform %9 + %16 = OpTypeStruct %6 %6 + %17 = OpTypePointer Uniform %16 + %19 = OpTypeStruct %6 + %20 = OpTypePointer Uniform %19 + %22 = OpTypeStruct %6 + %23 = OpTypePointer PushConstant %22 + %25 = OpTypeFloat 32 + %26 = OpTypeImage %25 2D 0 0 0 1 Unknown + %27 = OpTypeSampledImage %26 + %28 = OpTypePointer UniformConstant %27 + %30 = OpTypeImage %25 2D 0 0 0 2 Unknown + %31 = OpTypePointer UniformConstant %30 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CheckEqual(env, expected_1, context.get()); + + ops = RemoveUnusedInstructionReductionOpportunityFinder(true) + .GetAvailableOpportunities(context.get()); + ASSERT_EQ(6, ops.size()); + + for (auto& op : ops) { + ASSERT_TRUE(op->PreconditionHolds()); + op->TryToApply(); + } + + std::string expected_2 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" + OpExecutionMode %4 LocalSize 1 1 1 + OpMemberDecorate %9 0 Offset 0 + OpDecorate %9 Block + OpMemberDecorate %16 0 Offset 0 + OpMemberDecorate %16 1 Offset 4 + OpDecorate %16 Block + OpMemberDecorate %19 0 Offset 0 + OpDecorate %19 BufferBlock + OpMemberDecorate %22 0 Offset 0 + OpDecorate %22 Block + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %9 = OpTypeStruct %6 + %16 = OpTypeStruct %6 %6 + %19 = OpTypeStruct %6 + %22 = OpTypeStruct %6 + %25 = OpTypeFloat 32 + %26 = OpTypeImage %25 2D 0 0 0 1 Unknown + %27 = OpTypeSampledImage %26 + %30 = OpTypeImage %25 2D 0 0 0 2 Unknown + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CheckEqual(env, expected_2, context.get()); + + ops = RemoveUnusedInstructionReductionOpportunityFinder(true) + .GetAvailableOpportunities(context.get()); + ASSERT_EQ(6, ops.size()); + + for (auto& op : ops) { + ASSERT_TRUE(op->PreconditionHolds()); + op->TryToApply(); + } + + std::string expected_3 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" + OpExecutionMode %4 LocalSize 1 1 1 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %25 = OpTypeFloat 32 + %26 = OpTypeImage %25 2D 0 0 0 1 Unknown + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CheckEqual(env, expected_3, context.get()); +} + } // namespace } // namespace reduce } // namespace spvtools
diff --git a/test/reduce/remove_unused_struct_member_test.cpp b/test/reduce/remove_unused_struct_member_test.cpp new file mode 100644 index 0000000..402ef2d --- /dev/null +++ b/test/reduce/remove_unused_struct_member_test.cpp
@@ -0,0 +1,238 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/reduce/remove_unused_struct_member_reduction_opportunity_finder.h" + +#include "source/opt/build_module.h" +#include "source/reduce/reduction_opportunity.h" +#include "test/reduce/reduce_test_util.h" + +namespace spvtools { +namespace reduce { +namespace { + +TEST(RemoveUnusedStructMemberTest, RemoveOneMember) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypeStruct %6 %6 + %8 = OpTypePointer Function %7 + %50 = OpConstant %6 0 + %10 = OpConstant %6 1 + %11 = OpConstant %6 2 + %12 = OpConstantComposite %7 %10 %11 + %13 = OpConstant %6 4 + %14 = OpTypePointer Function %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %9 = OpVariable %8 Function + OpStore %9 %12 + %15 = OpAccessChain %14 %9 %10 + %22 = OpInBoundsAccessChain %14 %9 %10 + %20 = OpLoad %7 %9 + %21 = OpCompositeExtract %6 %20 1 + %23 = OpCompositeInsert %7 %10 %20 1 + OpStore %15 %13 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, shader, kReduceAssembleOption); + + auto ops = RemoveUnusedStructMemberReductionOpportunityFinder() + .GetAvailableOpportunities(context.get()); + ASSERT_EQ(1, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypeStruct %6 + %8 = OpTypePointer Function %7 + %50 = OpConstant %6 0 + %10 = OpConstant %6 1 + %11 = OpConstant %6 2 + %12 = OpConstantComposite %7 %11 + %13 = OpConstant %6 4 + %14 = OpTypePointer Function %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %9 = OpVariable %8 Function + OpStore %9 %12 + %15 = OpAccessChain %14 %9 %50 + %22 = OpInBoundsAccessChain %14 %9 %50 + %20 = OpLoad %7 %9 + %21 = OpCompositeExtract %6 %20 0 + %23 = OpCompositeInsert %7 %10 %20 0 + OpStore %15 %13 + OpReturn + OpFunctionEnd + )"; + + CheckEqual(env, expected, context.get()); +} + +TEST(RemoveUnusedStructMemberTest, RemoveUniformBufferMember) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpMemberDecorate %10 0 Offset 0 + OpMemberDecorate %10 1 Offset 4 + OpDecorate %10 Block + OpDecorate %12 DescriptorSet 0 + OpDecorate %12 Binding 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypePointer Function %6 + %9 = OpTypeInt 32 1 + %10 = OpTypeStruct %9 %6 + %11 = OpTypePointer Uniform %10 + %12 = OpVariable %11 Uniform + %13 = OpConstant %9 1 + %20 = OpConstant %9 0 + %14 = OpTypePointer Uniform %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %15 = OpAccessChain %14 %12 %13 + %16 = OpLoad %6 %15 + OpStore %8 %16 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, shader, kReduceAssembleOption); + + auto ops = RemoveUnusedStructMemberReductionOpportunityFinder() + .GetAvailableOpportunities(context.get()); + ASSERT_EQ(1, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpMemberDecorate %10 0 Offset 4 + OpDecorate %10 Block + OpDecorate %12 DescriptorSet 0 + OpDecorate %12 Binding 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypePointer Function %6 + %9 = OpTypeInt 32 1 + %10 = OpTypeStruct %6 + %11 = OpTypePointer Uniform %10 + %12 = OpVariable %11 Uniform + %13 = OpConstant %9 1 + %20 = OpConstant %9 0 + %14 = OpTypePointer Uniform %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %15 = OpAccessChain %14 %12 %20 + %16 = OpLoad %6 %15 + OpStore %8 %16 + OpReturn + OpFunctionEnd + )"; + + CheckEqual(env, expected, context.get()); +} + +TEST(RemoveUnusedStructMemberTest, DoNotRemoveNamedMemberRemoveOneMember) { + // This illustrates that naming a member is enough to prevent its removal. + // Removal of names is done by a different pass. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpMemberName %7 0 "someName" + OpMemberName %7 1 "someOtherName" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypeStruct %6 %6 + %8 = OpTypePointer Function %7 + %50 = OpConstant %6 0 + %10 = OpConstant %6 1 + %11 = OpConstant %6 2 + %12 = OpConstantComposite %7 %10 %11 + %13 = OpConstant %6 4 + %14 = OpTypePointer Function %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %9 = OpVariable %8 Function + OpStore %9 %12 + %15 = OpAccessChain %14 %9 %10 + %22 = OpInBoundsAccessChain %14 %9 %10 + %20 = OpLoad %7 %9 + %21 = OpCompositeExtract %6 %20 1 + %23 = OpCompositeInsert %7 %10 %20 1 + OpStore %15 %13 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, shader, kReduceAssembleOption); + + auto ops = RemoveUnusedStructMemberReductionOpportunityFinder() + .GetAvailableOpportunities(context.get()); + ASSERT_EQ(0, ops.size()); +} + +} // namespace +} // namespace reduce +} // namespace spvtools
diff --git a/test/tools/opt/flags.py b/test/tools/opt/flags.py index 2f6c0a7..f7dc64c 100644 --- a/test/tools/opt/flags.py +++ b/test/tools/opt/flags.py
@@ -73,7 +73,10 @@ '--remove-duplicates', '--replace-invalid-opcode', '--ssa-rewrite', '--scalar-replacement', '--scalar-replacement=42', '--strength-reduction', '--strip-debug', '--strip-reflect', '--vector-dce', '--workaround-1209', - '--unify-const' + '--unify-const', '--legalize-vector-shuffle', + '--split-invalid-unreachable', '--generate-webgpu-initializers', + '--decompose-initialized-variables', '--graphics-robust-access', + '--wrap-opkill', '--amd-ext-to-khr' ] expected_passes = [ 'wrap-opkill', @@ -120,7 +123,14 @@ 'strip-reflect', 'vector-dce', 'workaround-1209', - 'unify-const' + 'unify-const', + 'legalize-vector-shuffle', + 'split-invalid-unreachable', + 'generate-webgpu-initializers', + 'decompose-initialized-variables', + 'graphics-robust-access', + 'wrap-opkill', + 'amd-ext-to-khr' ] shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm') output = placeholder.TempFileName('output.spv') @@ -153,9 +163,18 @@ 'eliminate-dead-code-aggressive', 'ccp', 'eliminate-dead-code-aggressive', + 'loop-unroll', + 'eliminate-dead-branches', 'redundancy-elimination', 'combine-access-chains', 'simplify-instructions', + 'scalar-replacement=100', + 'convert-local-access-chains', + 'eliminate-local-single-block', + 'eliminate-local-single-store', + 'eliminate-dead-code-aggressive', + 'ssa-rewrite', + 'eliminate-dead-code-aggressive', 'vector-dce', 'eliminate-dead-inserts', 'eliminate-dead-branches',
diff --git a/test/val/val_barriers_test.cpp b/test/val/val_barriers_test.cpp index fa2b153..3643883 100644 --- a/test/val/val_barriers_test.cpp +++ b/test/val/val_barriers_test.cpp
@@ -70,6 +70,7 @@ %subgroup = OpConstant %u32 3 %invocation = OpConstant %u32 4 %queuefamily = OpConstant %u32 5 +%shadercall = OpConstant %u32 6 %none = OpConstant %u32 0 %acquire = OpConstant %u32 2 @@ -1586,6 +1587,79 @@ "CooperativeMatrixNV capability is present")); } +TEST_F(ValidateBarriers, OpMemoryBarrierShaderCallRayGenSuccess) { + const std::string body = + "OpMemoryBarrier %shadercall %release_uniform_workgroup"; + + CompileSuccessfully(GenerateShaderCodeImpl(body, + // capabilities_and_extensions + R"( + OpCapability VulkanMemoryModelKHR + OpCapability RayTracingProvisionalKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpExtension "SPV_KHR_ray_tracing" + )", + // definitions + "", + // execution_model + "RayGenerationKHR", + // memory_model + "OpMemoryModel Logical VulkanKHR"), + SPV_ENV_VULKAN_1_1); + + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierShaderCallComputeFailure) { + const std::string body = + "OpMemoryBarrier %shadercall %release_uniform_workgroup"; + + CompileSuccessfully(GenerateShaderCodeImpl(body, + // capabilities_and_extensions + R"( + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + )", + // definitions + "", + // execution_model + "GLCompute", + // memory_model + "OpMemoryModel Logical VulkanKHR"), + SPV_ENV_VULKAN_1_1); + + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "ShaderCallKHR Memory Scope requires a ray tracing execution model")); +} + +TEST_F(ValidateBarriers, OpControlBarrierShaderCallRayGenFailure) { + const std::string body = "OpControlBarrier %shadercall %shadercall %none"; + + CompileSuccessfully(GenerateShaderCodeImpl(body, + // capabilities_and_extensions + R"( + OpCapability VulkanMemoryModelKHR + OpCapability RayTracingProvisionalKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpExtension "SPV_KHR_ray_tracing" + )", + // definitions + "", + // execution_model + "RayGenerationKHR", + // memory_model + "OpMemoryModel Logical VulkanKHR"), + SPV_ENV_VULKAN_1_1); + + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("in Vulkan environment Execution Scope is limited to " + "Workgroup and Subgroup")); +} + } // namespace } // namespace val } // namespace spvtools
diff --git a/test/val/val_capability_test.cpp b/test/val/val_capability_test.cpp index 098fa2f..8580818 100644 --- a/test/val/val_capability_test.cpp +++ b/test/val/val_capability_test.cpp
@@ -1229,14 +1229,18 @@ "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), std::make_pair(std::string(kOpenCLMemoryModel) + + // Block applies to struct type. "OpEntryPoint Kernel %func \"compute\" \n" - "OpDecorate %intt Block\n" - "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + "OpDecorate %block Block\n" + "%intt = OpTypeInt 32 0\n" + "%block = OpTypeStruct %intt\n" + std::string(kVoidFVoid), ShaderDependencies()), std::make_pair(std::string(kOpenCLMemoryModel) + + // BufferBlock applies to struct type. "OpEntryPoint Kernel %func \"compute\" \n" - "OpDecorate %intt BufferBlock\n" - "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + "OpDecorate %block BufferBlock\n" + "%intt = OpTypeInt 32 0\n" + "%block = OpTypeStruct %intt\n" + std::string(kVoidFVoid), ShaderDependencies()), std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n"
diff --git a/test/val/val_decoration_test.cpp b/test/val/val_decoration_test.cpp index 256e115..204f468 100644 --- a/test/val/val_decoration_test.cpp +++ b/test/val/val_decoration_test.cpp
@@ -700,6 +700,61 @@ EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); } +TEST_F(ValidateDecorations, BlockDecoratingArrayBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %int = OpTypeInt 32 1 + %int_3 = OpConstant %int 3 + %Output = OpTypeArray %float %int_3 +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Block decoration on a non-struct type")); +} + +TEST_F(ValidateDecorations, BlockDecoratingIntBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %Output = OpTypeInt 32 1 +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Block decoration on a non-struct type")); +} + TEST_F(ValidateDecorations, BlockMissingOffsetBad) { std::string spirv = R"( OpCapability Shader @@ -6995,6 +7050,53 @@ "contains an array with stride 4, but with an element size of 16")); } +TEST_F(ValidateDecorations, FunctionsWithOpGroupDecorate) { + std::string spirv = R"( + OpCapability Addresses + OpCapability Linkage + OpCapability Kernel + OpCapability Int8 + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical32 OpenCL + OpName %foo "foo" + OpName %entry "entry" + OpName %bar "bar" + OpName %entry_0 "entry" + OpName %k "k" + OpName %entry_1 "entry" + OpName %b "b" + OpDecorate %28 FuncParamAttr Zext + %28 = OpDecorationGroup + OpDecorate %k LinkageAttributes "k" Export + OpDecorate %foo LinkageAttributes "foo" Export + OpDecorate %bar LinkageAttributes "bar" Export + OpDecorate %b Alignment 1 + OpGroupDecorate %28 %foo %bar + %uchar = OpTypeInt 8 0 + %bool = OpTypeBool + %3 = OpTypeFunction %bool + %void = OpTypeVoid + %10 = OpTypeFunction %void + %_ptr_Function_uchar = OpTypePointer Function %uchar + %true = OpConstantTrue %bool + %foo = OpFunction %bool DontInline %3 + %entry = OpLabel + OpReturnValue %true + OpFunctionEnd + %bar = OpFunction %bool DontInline %3 + %entry_0 = OpLabel + OpReturnValue %true + OpFunctionEnd + %k = OpFunction %void DontInline %10 + %entry_1 = OpLabel + %b = OpVariable %_ptr_Function_uchar Function + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + } // namespace } // namespace val } // namespace spvtools
diff --git a/test/val/val_ext_inst_test.cpp b/test/val/val_ext_inst_test.cpp index aa73989..d8d0010 100644 --- a/test/val/val_ext_inst_test.cpp +++ b/test/val/val_ext_inst_test.cpp
@@ -889,6 +889,31 @@ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } +TEST_F(ValidateOpenCL100DebugInfo, DebugTypeCompositeSizeDebugInfoNone) { + const std::string src = R"( +%src = OpString "simple.hlsl" +%code = OpString "OpaqueType foo; +main() {} +" +%ty_name = OpString "struct VS_OUTPUT" +)"; + + const std::string dbg_inst_header = R"( +%dbg_none = OpExtInst %void %DbgExt DebugInfoNone +%dbg_src = OpExtInst %void %DbgExt DebugSource %src %code +%comp_unit = OpExtInst %void %DbgExt DebugCompilationUnit 2 4 %dbg_src HLSL +%opaque = OpExtInst %void %DbgExt DebugTypeComposite %ty_name Class %dbg_src 1 1 %comp_unit %ty_name %dbg_none FlagIsPublic +)"; + + const std::string extension = R"( +%DbgExt = OpExtInstImport "OpenCL.DebugInfo.100" +)"; + + CompileSuccessfully(GenerateShaderCodeForDebugInfo(src, "", dbg_inst_header, + "", extension, "Vertex")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + TEST_F(ValidateOpenCL100DebugInfo, DebugTypeCompositeForwardReference) { const std::string src = R"( %src = OpString "simple.hlsl"
diff --git a/test/val/val_image_test.cpp b/test/val/val_image_test.cpp index 570dd16..1a6e79c 100644 --- a/test/val/val_image_test.cpp +++ b/test/val/val_image_test.cpp
@@ -5004,6 +5004,70 @@ "opcodes and OpImageFetch")); } +TEST_F(ValidateImage, GatherBiasAMDSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageGather %f32vec4 %simg %f32vec4_0000 %u32_1 Bias %f32_1 +)"; + + const std::string extra = R"( +OpCapability ImageGatherBiasLodAMD +OpExtension "SPV_AMD_texture_gather_bias_lod" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, GatherLodAMDSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageGather %f32vec4 %simg %f32vec4_0000 %u32_1 Lod %f32_1 +)"; + + const std::string extra = R"( +OpCapability ImageGatherBiasLodAMD +OpExtension "SPV_AMD_texture_gather_bias_lod" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SparseGatherBiasAMDSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseGather %struct_u32_f32vec4 %simg %f32vec4_0000 %u32_1 Bias %f32_1 +)"; + + const std::string extra = R"( +OpCapability ImageGatherBiasLodAMD +OpExtension "SPV_AMD_texture_gather_bias_lod" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SparseGatherLodAMDSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseGather %struct_u32_f32vec4 %simg %f32vec4_0000 %u32_1 Lod %f32_1 +)"; + + const std::string extra = R"( +OpCapability ImageGatherBiasLodAMD +OpExtension "SPV_AMD_texture_gather_bias_lod" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + // No negative tests for ZeroExtend since we don't truly know the // texel format.
diff --git a/tools/fuzz/fuzz.cpp b/tools/fuzz/fuzz.cpp index 718d038..2c9807d 100644 --- a/tools/fuzz/fuzz.cpp +++ b/tools/fuzz/fuzz.cpp
@@ -145,6 +145,13 @@ --version Display fuzzer version information. +Supported validator options are as follows. See `spirv-val --help` for details. + --before-hlsl-legalization + --relax-block-layout + --relax-logical-pointer + --relax-struct-store + --scalar-block-layout + --skip-block-layout )", program, program, program, program); } @@ -166,7 +173,8 @@ std::vector<std::string>* interestingness_test, std::string* shrink_transformations_file, std::string* shrink_temp_file_prefix, - spvtools::FuzzerOptions* fuzzer_options) { + spvtools::FuzzerOptions* fuzzer_options, + spvtools::ValidatorOptions* validator_options) { uint32_t positional_arg_index = 0; bool only_positional_arguments_remain = false; bool force_render_red = false; @@ -227,6 +235,18 @@ sizeof("--shrinker-temp-file-prefix=") - 1)) { const auto split_flag = spvtools::utils::SplitFlagArgs(cur_arg); *shrink_temp_file_prefix = std::string(split_flag.second); + } else if (0 == strcmp(cur_arg, "--before-hlsl-legalization")) { + validator_options->SetBeforeHlslLegalization(true); + } else if (0 == strcmp(cur_arg, "--relax-logical-pointer")) { + validator_options->SetRelaxLogicalPointer(true); + } else if (0 == strcmp(cur_arg, "--relax-block-layout")) { + validator_options->SetRelaxBlockLayout(true); + } else if (0 == strcmp(cur_arg, "--scalar-block-layout")) { + validator_options->SetScalarBlockLayout(true); + } else if (0 == strcmp(cur_arg, "--skip-block-layout")) { + validator_options->SetSkipBlockLayout(true); + } else if (0 == strcmp(cur_arg, "--relax-struct-store")) { + validator_options->SetRelaxStructStore(true); } else if (0 == strcmp(cur_arg, "--")) { only_positional_arguments_remain = true; } else { @@ -357,6 +377,7 @@ bool Replay(const spv_target_env& target_env, spv_const_fuzzer_options fuzzer_options, + spv_validator_options validator_options, const std::vector<uint32_t>& binary_in, const spvtools::fuzz::protobufs::FactSequence& initial_facts, const std::string& replay_transformations_file, @@ -368,8 +389,8 @@ &transformation_sequence)) { return false; } - spvtools::fuzz::Replayer replayer(target_env, - fuzzer_options->replay_validation_enabled); + spvtools::fuzz::Replayer replayer( + target_env, fuzzer_options->replay_validation_enabled, validator_options); replayer.SetMessageConsumer(spvtools::utils::CLIMessageConsumer); auto replay_result_status = replayer.Run(binary_in, initial_facts, transformation_sequence, @@ -380,6 +401,7 @@ bool Shrink(const spv_target_env& target_env, spv_const_fuzzer_options fuzzer_options, + spv_validator_options validator_options, const std::vector<uint32_t>& binary_in, const spvtools::fuzz::protobufs::FactSequence& initial_facts, const std::string& shrink_transformations_file, @@ -393,9 +415,9 @@ &transformation_sequence)) { return false; } - spvtools::fuzz::Shrinker shrinker(target_env, - fuzzer_options->shrinker_step_limit, - fuzzer_options->replay_validation_enabled); + spvtools::fuzz::Shrinker shrinker( + target_env, fuzzer_options->shrinker_step_limit, + fuzzer_options->replay_validation_enabled, validator_options); shrinker.SetMessageConsumer(spvtools::utils::CLIMessageConsumer); assert(!interestingness_command.empty() && @@ -434,6 +456,7 @@ bool Fuzz(const spv_target_env& target_env, spv_const_fuzzer_options fuzzer_options, + spv_validator_options validator_options, const std::vector<uint32_t>& binary_in, const spvtools::fuzz::protobufs::FactSequence& initial_facts, const std::string& donors, std::vector<uint32_t>* binary_out, @@ -469,7 +492,7 @@ fuzzer_options->has_random_seed ? fuzzer_options->random_seed : static_cast<uint32_t>(std::random_device()()), - fuzzer_options->fuzzer_pass_validation_enabled); + fuzzer_options->fuzzer_pass_validation_enabled, validator_options); fuzzer.SetMessageConsumer(message_consumer); auto fuzz_result_status = fuzzer.Run(binary_in, initial_facts, donor_suppliers, binary_out, @@ -513,11 +536,13 @@ std::string shrink_temp_file_prefix = "temp_"; spvtools::FuzzerOptions fuzzer_options; + spvtools::ValidatorOptions validator_options; - FuzzStatus status = ParseFlags( - argc, argv, &in_binary_file, &out_binary_file, &donors_file, - &replay_transformations_file, &interestingness_test, - &shrink_transformations_file, &shrink_temp_file_prefix, &fuzzer_options); + FuzzStatus status = + ParseFlags(argc, argv, &in_binary_file, &out_binary_file, &donors_file, + &replay_transformations_file, &interestingness_test, + &shrink_transformations_file, &shrink_temp_file_prefix, + &fuzzer_options, &validator_options); if (status.action == FuzzActions::STOP) { return status.code; @@ -555,20 +580,22 @@ switch (status.action) { case FuzzActions::FORCE_RENDER_RED: - if (!spvtools::fuzz::ForceRenderRed(target_env, binary_in, initial_facts, + if (!spvtools::fuzz::ForceRenderRed(target_env, validator_options, + binary_in, initial_facts, &binary_out)) { return 1; } break; case FuzzActions::FUZZ: - if (!Fuzz(target_env, fuzzer_options, binary_in, initial_facts, - donors_file, &binary_out, &transformations_applied)) { + if (!Fuzz(target_env, fuzzer_options, validator_options, binary_in, + initial_facts, donors_file, &binary_out, + &transformations_applied)) { return 1; } break; case FuzzActions::REPLAY: - if (!Replay(target_env, fuzzer_options, binary_in, initial_facts, - replay_transformations_file, &binary_out, + if (!Replay(target_env, fuzzer_options, validator_options, binary_in, + initial_facts, replay_transformations_file, &binary_out, &transformations_applied)) { return 1; } @@ -579,9 +606,9 @@ << std::endl; return 1; } - if (!Shrink(target_env, fuzzer_options, binary_in, initial_facts, - shrink_transformations_file, shrink_temp_file_prefix, - interestingness_test, &binary_out, + if (!Shrink(target_env, fuzzer_options, validator_options, binary_in, + initial_facts, shrink_transformations_file, + shrink_temp_file_prefix, interestingness_test, &binary_out, &transformations_applied)) { return 1; }
diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp index 658bd5b..66f9228 100644 --- a/tools/opt/opt.cpp +++ b/tools/opt/opt.cpp
@@ -114,6 +114,10 @@ and VK_AMD_shader_trinary_minmax with equivalent code using core instructions and capabilities.)"); printf(R"( + --before-hlsl-legalization + Forwards this option to the validator. See the validator help + for details.)"); + printf(R"( --ccp Apply the conditional constant propagation transform. This will propagate constant values throughout the program, and simplify @@ -403,14 +407,21 @@ Looks for instructions in the same function that compute the same value, and deletes the redundant ones.)"); printf(R"( + --relax-block-layout + Forwards this option to the validator. See the validator help + for details.)"); + printf(R"( --relax-float-ops Decorate all float operations with RelaxedPrecision if not already so decorated. This does not decorate types or variables.)"); printf(R"( + --relax-logical-pointer + Forwards this option to the validator. See the validator help + for details.)"); + printf(R"( --relax-struct-store - Allow store from one struct type to a different type with - compatible layout and members. This option is forwarded to the - validator.)"); + Forwards this option to the validator. See the validator help + for details.)"); printf(R"( --remove-duplicates Removes duplicate types, decorations, capabilities and extension @@ -425,6 +436,10 @@ Replace loads and stores to function local variables with operations on SSA IDs.)"); printf(R"( + --scalar-block-layout + Forwards this option to the validator. See the validator help + for details.)"); + printf(R"( --scalar-replacement[=<n>] Replace aggregate function scope variables that are only accessed via their elements with new function variables representing each @@ -444,6 +459,10 @@ Will simplify all instructions in the function as much as possible.)"); printf(R"( + --skip-block-layout + Forwards this option to the validator. See the validator help + for details.)"); + printf(R"( --split-invalid-unreachable Attempts to legalize for WebGPU cases where an unreachable merge-block is also a continue-target by splitting it into two @@ -822,6 +841,18 @@ optimizer->RegisterWebGPUToVulkanPasses(); } else if (0 == strcmp(cur_arg, "--validate-after-all")) { optimizer->SetValidateAfterAll(true); + } else if (0 == strcmp(cur_arg, "--before-hlsl-legalization")) { + validator_options->SetBeforeHlslLegalization(true); + } else if (0 == strcmp(cur_arg, "--relax-logical-pointer")) { + validator_options->SetRelaxLogicalPointer(true); + } else if (0 == strcmp(cur_arg, "--relax-block-layout")) { + validator_options->SetRelaxBlockLayout(true); + } else if (0 == strcmp(cur_arg, "--scalar-block-layout")) { + validator_options->SetScalarBlockLayout(true); + } else if (0 == strcmp(cur_arg, "--skip-block-layout")) { + validator_options->SetSkipBlockLayout(true); + } else if (0 == strcmp(cur_arg, "--relax-struct-store")) { + validator_options->SetRelaxStructStore(true); } else { // Some passes used to accept the form '--pass arg', canonicalize them // to '--pass=arg'.
diff --git a/tools/sva/yarn.lock b/tools/sva/yarn.lock index be19e7c..11ba12f 100644 --- a/tools/sva/yarn.lock +++ b/tools/sva/yarn.lock
@@ -47,9 +47,9 @@ integrity sha512-tiNTrP1MP0QrChmD2DdupCr6HWSFeKVw5d/dHTu4Y7rkAkRhU/Dt7dphAfIUyxtHpl/eBVip5uTNSpQJHylpAw== acorn@^7.0.0: - version "7.0.0" - resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.0.0.tgz#26b8d1cd9a9b700350b71c0905546f64d1284e7a" - integrity sha512-PaF/MduxijYYt7unVGRuds1vBC9bFxbNf+VWqhOClfdgy7RlVkQqt610ig1/yxTgsDIfW1cWDel5EBbOy3jdtQ== + version "7.1.1" + resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.1.1.tgz#e35668de0b402f359de515c5482a1ab9f89a69bf" + integrity sha512-add7dgA5ppRPxCFJoAGfMDi7PIBXq1RtGo7BhbLaxwrXPOmw8gq48Y9ozT01hUKy9byMjlR20EJhu5zlkErEkg== ajv@6.5.3: version "6.5.3"
diff --git a/utils/vscode/src/parser/parser.go b/utils/vscode/src/parser/parser.go index 9f5691b..1775b0f 100644 --- a/utils/vscode/src/parser/parser.go +++ b/utils/vscode/src/parser/parser.go
@@ -310,7 +310,7 @@ tok := &Token{Type: Operator, Range: Range{Start: l.pos, End: l.pos}} for l.e == nil { switch l.next() { - case '=': + case '=', '|': tok.Range.End = l.pos l.toks = append(l.toks, tok) return @@ -374,7 +374,7 @@ case r == '"': l.restore(s) l.string() - case r == '=': + case r == '=', r == '|': l.restore(s) l.operator() case r == ';': @@ -556,21 +556,34 @@ s := tok.Text(p.lines) for _, e := range k.Enumerants { if e.Enumerant == s { - n := 1 + count := 1 for _, param := range e.Parameters { - p, c := p.operand(param.Name, param.Kind, i+n, false) + p, c := p.operand(param.Name, param.Kind, i+count, false) if p != nil { op.Tokens = append(op.Tokens, p.Tokens...) op.Parameters = append(op.Parameters, p) } - n += c + count += c } - return op, n + + // Handle bitfield '|' chains + if p.tok(i+count).Text(p.lines) == "|" { + count++ // '|' + p, c := p.operand(n, k, i+count, false) + if p != nil { + op.Tokens = append(op.Tokens, p.Tokens...) + op.Parameters = append(op.Parameters, p) + } + count += c + } + + return op, count } } if !optional { p.err(p.tok(i), "invalid operand value '%s'", s) } + return nil, 0 case schema.OperandCategoryID:
diff --git a/utils/vscode/src/schema/schema.go b/utils/vscode/src/schema/schema.go index 66bd7dc..0d57cb1 100755 --- a/utils/vscode/src/schema/schema.go +++ b/utils/vscode/src/schema/schema.go
@@ -97,6 +97,7 @@ OperandCategoryComposite = "Composite" ) +// OpcodeMap is a map of opcode name to Opcode type. type OpcodeMap map[string]*Opcode var ( @@ -467,11 +468,41 @@ "OpGroupNonUniformPartitionNV": OpGroupNonUniformPartitionNV, "OpWritePackedPrimitiveIndices4x8NV": OpWritePackedPrimitiveIndices4x8NV, "OpReportIntersectionNV": OpReportIntersectionNV, + "OpReportIntersectionKHR": OpReportIntersectionKHR, "OpIgnoreIntersectionNV": OpIgnoreIntersectionNV, + "OpIgnoreIntersectionKHR": OpIgnoreIntersectionKHR, "OpTerminateRayNV": OpTerminateRayNV, + "OpTerminateRayKHR": OpTerminateRayKHR, "OpTraceNV": OpTraceNV, + "OpTraceRayKHR": OpTraceRayKHR, "OpTypeAccelerationStructureNV": OpTypeAccelerationStructureNV, + "OpTypeAccelerationStructureKHR": OpTypeAccelerationStructureKHR, + "OpTypeRayQueryProvisionalKHR": OpTypeRayQueryProvisionalKHR, + "OpRayQueryInitializeKHR": OpRayQueryInitializeKHR, + "OpRayQueryTerminateKHR": OpRayQueryTerminateKHR, + "OpRayQueryGenerateIntersectionKHR": OpRayQueryGenerateIntersectionKHR, + "OpRayQueryConfirmIntersectionKHR": OpRayQueryConfirmIntersectionKHR, + "OpRayQueryProceedKHR": OpRayQueryProceedKHR, + "OpRayQueryGetIntersectionTypeKHR": OpRayQueryGetIntersectionTypeKHR, + "OpRayQueryGetRayTMinKHR": OpRayQueryGetRayTMinKHR, + "OpRayQueryGetRayFlagsKHR": OpRayQueryGetRayFlagsKHR, + "OpRayQueryGetIntersectionTKHR": OpRayQueryGetIntersectionTKHR, + "OpRayQueryGetIntersectionInstanceCustomIndexKHR": OpRayQueryGetIntersectionInstanceCustomIndexKHR, + "OpRayQueryGetIntersectionInstanceIdKHR": OpRayQueryGetIntersectionInstanceIdKHR, + "OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR": OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR, + "OpRayQueryGetIntersectionGeometryIndexKHR": OpRayQueryGetIntersectionGeometryIndexKHR, + "OpRayQueryGetIntersectionPrimitiveIndexKHR": OpRayQueryGetIntersectionPrimitiveIndexKHR, + "OpRayQueryGetIntersectionBarycentricsKHR": OpRayQueryGetIntersectionBarycentricsKHR, + "OpRayQueryGetIntersectionFrontFaceKHR": OpRayQueryGetIntersectionFrontFaceKHR, + "OpRayQueryGetIntersectionCandidateAABBOpaqueKHR": OpRayQueryGetIntersectionCandidateAABBOpaqueKHR, + "OpRayQueryGetIntersectionObjectRayDirectionKHR": OpRayQueryGetIntersectionObjectRayDirectionKHR, + "OpRayQueryGetIntersectionObjectRayOriginKHR": OpRayQueryGetIntersectionObjectRayOriginKHR, + "OpRayQueryGetWorldRayDirectionKHR": OpRayQueryGetWorldRayDirectionKHR, + "OpRayQueryGetWorldRayOriginKHR": OpRayQueryGetWorldRayOriginKHR, + "OpRayQueryGetIntersectionObjectToWorldKHR": OpRayQueryGetIntersectionObjectToWorldKHR, + "OpRayQueryGetIntersectionWorldToObjectKHR": OpRayQueryGetIntersectionWorldToObjectKHR, "OpExecuteCallableNV": OpExecuteCallableNV, + "OpExecuteCallableKHR": OpExecuteCallableKHR, "OpTypeCooperativeMatrixNV": OpTypeCooperativeMatrixNV, "OpCooperativeMatrixLoadNV": OpCooperativeMatrixLoadNV, "OpCooperativeMatrixStoreNV": OpCooperativeMatrixStoreNV, @@ -10573,6 +10604,33 @@ }, }, } + OpReportIntersectionKHR = &Opcode { + Opname: "OpReportIntersectionKHR", + Class: "Reserved", + Opcode: 5334, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Hit'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'HitKind'", + Quantifier: "", + }, + }, + } OpIgnoreIntersectionNV = &Opcode { Opname: "OpIgnoreIntersectionNV", Class: "Reserved", @@ -10580,6 +10638,13 @@ Operands: []Operand { }, } + OpIgnoreIntersectionKHR = &Opcode { + Opname: "OpIgnoreIntersectionKHR", + Class: "Reserved", + Opcode: 5335, + Operands: []Operand { + }, + } OpTerminateRayNV = &Opcode { Opname: "OpTerminateRayNV", Class: "Reserved", @@ -10587,6 +10652,13 @@ Operands: []Operand { }, } + OpTerminateRayKHR = &Opcode { + Opname: "OpTerminateRayKHR", + Class: "Reserved", + Opcode: 5336, + Operands: []Operand { + }, + } OpTraceNV = &Opcode { Opname: "OpTraceNV", Class: "Reserved", @@ -10649,6 +10721,68 @@ }, }, } + OpTraceRayKHR = &Opcode { + Opname: "OpTraceRayKHR", + Class: "Reserved", + Opcode: 5337, + Operands: []Operand { + Operand { + Kind: OperandKindIdRef, + Name: "'Accel'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Ray Flags'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Cull Mask'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'SBT Offset'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'SBT Stride'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Miss Index'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Ray Origin'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Ray Tmin'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Ray Direction'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Ray Tmax'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'PayloadId'", + Quantifier: "", + }, + }, + } OpTypeAccelerationStructureNV = &Opcode { Opname: "OpTypeAccelerationStructureNV", Class: "Reserved", @@ -10661,6 +10795,601 @@ }, }, } + OpTypeAccelerationStructureKHR = &Opcode { + Opname: "OpTypeAccelerationStructureKHR", + Class: "Reserved", + Opcode: 5341, + Operands: []Operand { + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + }, + } + OpTypeRayQueryProvisionalKHR = &Opcode { + Opname: "OpTypeRayQueryProvisionalKHR", + Class: "Reserved", + Opcode: 4472, + Operands: []Operand { + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + }, + } + OpRayQueryInitializeKHR = &Opcode { + Opname: "OpRayQueryInitializeKHR", + Class: "Reserved", + Opcode: 4473, + Operands: []Operand { + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Accel'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayFlags'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'CullMask'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayOrigin'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayTMin'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayDirection'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayTMax'", + Quantifier: "", + }, + }, + } + OpRayQueryTerminateKHR = &Opcode { + Opname: "OpRayQueryTerminateKHR", + Class: "Reserved", + Opcode: 4474, + Operands: []Operand { + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + }, + } + OpRayQueryGenerateIntersectionKHR = &Opcode { + Opname: "OpRayQueryGenerateIntersectionKHR", + Class: "Reserved", + Opcode: 4475, + Operands: []Operand { + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'HitT'", + Quantifier: "", + }, + }, + } + OpRayQueryConfirmIntersectionKHR = &Opcode { + Opname: "OpRayQueryConfirmIntersectionKHR", + Class: "Reserved", + Opcode: 4476, + Operands: []Operand { + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + }, + } + OpRayQueryProceedKHR = &Opcode { + Opname: "OpRayQueryProceedKHR", + Class: "Reserved", + Opcode: 4477, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionTypeKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionTypeKHR", + Class: "Reserved", + Opcode: 4479, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Intersection'", + Quantifier: "", + }, + }, + } + OpRayQueryGetRayTMinKHR = &Opcode { + Opname: "OpRayQueryGetRayTMinKHR", + Class: "Reserved", + Opcode: 6016, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + }, + } + OpRayQueryGetRayFlagsKHR = &Opcode { + Opname: "OpRayQueryGetRayFlagsKHR", + Class: "Reserved", + Opcode: 6017, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionTKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionTKHR", + Class: "Reserved", + Opcode: 6018, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Intersection'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionInstanceCustomIndexKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionInstanceCustomIndexKHR", + Class: "Reserved", + Opcode: 6019, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Intersection'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionInstanceIdKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionInstanceIdKHR", + Class: "Reserved", + Opcode: 6020, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Intersection'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR", + Class: "Reserved", + Opcode: 6021, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Intersection'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionGeometryIndexKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionGeometryIndexKHR", + Class: "Reserved", + Opcode: 6022, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Intersection'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionPrimitiveIndexKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionPrimitiveIndexKHR", + Class: "Reserved", + Opcode: 6023, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Intersection'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionBarycentricsKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionBarycentricsKHR", + Class: "Reserved", + Opcode: 6024, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Intersection'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionFrontFaceKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionFrontFaceKHR", + Class: "Reserved", + Opcode: 6025, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Intersection'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionCandidateAABBOpaqueKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionCandidateAABBOpaqueKHR", + Class: "Reserved", + Opcode: 6026, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionObjectRayDirectionKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionObjectRayDirectionKHR", + Class: "Reserved", + Opcode: 6027, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Intersection'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionObjectRayOriginKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionObjectRayOriginKHR", + Class: "Reserved", + Opcode: 6028, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Intersection'", + Quantifier: "", + }, + }, + } + OpRayQueryGetWorldRayDirectionKHR = &Opcode { + Opname: "OpRayQueryGetWorldRayDirectionKHR", + Class: "Reserved", + Opcode: 6029, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + }, + } + OpRayQueryGetWorldRayOriginKHR = &Opcode { + Opname: "OpRayQueryGetWorldRayOriginKHR", + Class: "Reserved", + Opcode: 6030, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionObjectToWorldKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionObjectToWorldKHR", + Class: "Reserved", + Opcode: 6031, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Intersection'", + Quantifier: "", + }, + }, + } + OpRayQueryGetIntersectionWorldToObjectKHR = &Opcode { + Opname: "OpRayQueryGetIntersectionWorldToObjectKHR", + Class: "Reserved", + Opcode: 6032, + Operands: []Operand { + Operand { + Kind: OperandKindIdResultType, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdResult, + Name: "", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'RayQuery'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Intersection'", + Quantifier: "", + }, + }, + } OpExecuteCallableNV = &Opcode { Opname: "OpExecuteCallableNV", Class: "Reserved", @@ -10678,6 +11407,23 @@ }, }, } + OpExecuteCallableKHR = &Opcode { + Opname: "OpExecuteCallableKHR", + Class: "Reserved", + Opcode: 5344, + Operands: []Operand { + Operand { + Kind: OperandKindIdRef, + Name: "'SBT Index'", + Quantifier: "", + }, + Operand { + Kind: OperandKindIdRef, + Name: "'Callable DataId'", + Quantifier: "", + }, + }, + } OpTypeCooperativeMatrixNV = &Opcode { Opname: "OpTypeCooperativeMatrixNV", Class: "Reserved", @@ -19430,6 +20176,90 @@ }, Bases: []*OperandKind {}, } + OperandKindRayFlags = &OperandKind { + Kind: "RayFlags", + Category: "BitEnum", + Enumerants: []Enumerant { + Enumerant{ + Enumerant: "NoneKHR", + Value: 0x0000, + Capabilities: []string{"RayQueryProvisionalKHR","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "OpaqueKHR", + Value: 0x0001, + Capabilities: []string{"RayQueryProvisionalKHR","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "NoOpaqueKHR", + Value: 0x0002, + Capabilities: []string{"RayQueryProvisionalKHR","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "TerminateOnFirstHitKHR", + Value: 0x0004, + Capabilities: []string{"RayQueryProvisionalKHR","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "SkipClosestHitShaderKHR", + Value: 0x0008, + Capabilities: []string{"RayQueryProvisionalKHR","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "CullBackFacingTrianglesKHR", + Value: 0x0010, + Capabilities: []string{"RayQueryProvisionalKHR","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "CullFrontFacingTrianglesKHR", + Value: 0x0020, + Capabilities: []string{"RayQueryProvisionalKHR","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "CullOpaqueKHR", + Value: 0x0040, + Capabilities: []string{"RayQueryProvisionalKHR","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "CullNoOpaqueKHR", + Value: 0x0080, + Capabilities: []string{"RayQueryProvisionalKHR","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "SkipTrianglesKHR", + Value: 0x0100, + Capabilities: []string{"RayTraversalPrimitiveCullingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "SkipAABBsKHR", + Value: 0x0200, + Capabilities: []string{"RayTraversalPrimitiveCullingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + }, + Bases: []*OperandKind {}, + } OperandKindSourceLanguage = &OperandKind { Kind: "SourceLanguage", Category: "ValueEnum", @@ -19549,42 +20379,84 @@ Enumerant{ Enumerant: "RayGenerationNV", Value: 5313, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "RayGenerationKHR", + Value: 5313, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "IntersectionNV", Value: 5314, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "IntersectionKHR", + Value: 5314, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "AnyHitNV", Value: 5315, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "AnyHitKHR", + Value: 5315, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "ClosestHitNV", Value: 5316, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "ClosestHitKHR", + Value: 5316, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "MissNV", Value: 5317, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "MissKHR", + Value: 5317, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "CallableNV", Value: 5318, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "CallableKHR", + Value: 5318, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, @@ -20172,42 +21044,84 @@ Enumerant{ Enumerant: "CallableDataNV", Value: 5328, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "CallableDataKHR", + Value: 5328, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "IncomingCallableDataNV", Value: 5329, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "IncomingCallableDataKHR", + Value: 5329, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "RayPayloadNV", Value: 5338, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "RayPayloadKHR", + Value: 5338, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "HitAttributeNV", Value: 5339, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "HitAttributeKHR", + Value: 5339, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "IncomingRayPayloadNV", Value: 5342, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "IncomingRayPayloadKHR", + Value: 5342, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "ShaderRecordBufferNV", Value: 5343, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "ShaderRecordBufferKHR", + Value: 5343, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, @@ -21593,7 +22507,7 @@ Enumerant{ Enumerant: "PrimitiveId", Value: 7, - Capabilities: []string{"Geometry","Tessellation","RayTracingNV",}, + Capabilities: []string{"Geometry","Tessellation","RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "", }, @@ -22139,98 +23053,203 @@ Enumerant{ Enumerant: "LaunchIdNV", Value: 5319, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "LaunchIdKHR", + Value: 5319, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "LaunchSizeNV", Value: 5320, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "LaunchSizeKHR", + Value: 5320, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "WorldRayOriginNV", Value: 5321, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "WorldRayOriginKHR", + Value: 5321, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "WorldRayDirectionNV", Value: 5322, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "WorldRayDirectionKHR", + Value: 5322, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "ObjectRayOriginNV", Value: 5323, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "ObjectRayOriginKHR", + Value: 5323, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "ObjectRayDirectionNV", Value: 5324, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "ObjectRayDirectionKHR", + Value: 5324, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "RayTminNV", Value: 5325, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "RayTminKHR", + Value: 5325, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "RayTmaxNV", Value: 5326, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "RayTmaxKHR", + Value: 5326, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "InstanceCustomIndexNV", Value: 5327, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "InstanceCustomIndexKHR", + Value: 5327, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "ObjectToWorldNV", Value: 5330, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "ObjectToWorldKHR", + Value: 5330, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "WorldToObjectNV", Value: 5331, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "WorldToObjectKHR", + Value: 5331, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "HitTNV", Value: 5332, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "HitTKHR", + Value: 5332, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "HitKindNV", Value: 5333, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "HitKindKHR", + Value: 5333, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, Enumerant{ Enumerant: "IncomingRayFlagsNV", Value: 5351, - Capabilities: []string{"RayTracingNV",}, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "IncomingRayFlagsKHR", + Value: 5351, + Capabilities: []string{"RayTracingNV","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "RayGeometryIndexKHR", + Value: 5352, + Capabilities: []string{"RayTracingProvisionalKHR",}, Parameters: []Parameter{}, Version: "None", }, @@ -22318,6 +23337,13 @@ Parameters: []Parameter{}, Version: "1.5", }, + Enumerant{ + Enumerant: "ShaderCallKHR", + Value: 6, + Capabilities: []string{"RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, }, Bases: []*OperandKind {}, } @@ -23054,6 +24080,20 @@ Version: "1.4", }, Enumerant{ + Enumerant: "RayQueryProvisionalKHR", + Value: 4471, + Capabilities: []string{"Shader",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ + Enumerant: "RayTraversalPrimitiveCullingProvisionalKHR", + Value: 4478, + Capabilities: []string{"RayQueryProvisionalKHR","RayTracingProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ Enumerant: "Float16ImageAMD", Value: 5008, Capabilities: []string{"Shader",}, @@ -23425,6 +24465,13 @@ Version: "None", }, Enumerant{ + Enumerant: "RayTracingProvisionalKHR", + Value: 5353, + Capabilities: []string{"Shader",}, + Parameters: []Parameter{}, + Version: "None", + }, + Enumerant{ Enumerant: "CooperativeMatrixNV", Value: 5357, Capabilities: []string{"Shader",}, @@ -23525,6 +24572,76 @@ }, Bases: []*OperandKind {}, } + OperandKindRayQueryIntersection = &OperandKind { + Kind: "RayQueryIntersection", + Category: "ValueEnum", + Enumerants: []Enumerant { + Enumerant{ + Enumerant: "RayQueryCandidateIntersectionKHR", + Value: 0, + Capabilities: []string{"RayQueryProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "RayQueryCommittedIntersectionKHR", + Value: 1, + Capabilities: []string{"RayQueryProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + }, + Bases: []*OperandKind {}, + } + OperandKindRayQueryCommittedIntersectionType = &OperandKind { + Kind: "RayQueryCommittedIntersectionType", + Category: "ValueEnum", + Enumerants: []Enumerant { + Enumerant{ + Enumerant: "RayQueryCommittedIntersectionNoneKHR", + Value: 0, + Capabilities: []string{"RayQueryProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "RayQueryCommittedIntersectionTriangleKHR", + Value: 1, + Capabilities: []string{"RayQueryProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "RayQueryCommittedIntersectionGeneratedKHR", + Value: 2, + Capabilities: []string{"RayQueryProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + }, + Bases: []*OperandKind {}, + } + OperandKindRayQueryCandidateIntersectionType = &OperandKind { + Kind: "RayQueryCandidateIntersectionType", + Category: "ValueEnum", + Enumerants: []Enumerant { + Enumerant{ + Enumerant: "RayQueryCandidateIntersectionTriangleKHR", + Value: 0, + Capabilities: []string{"RayQueryProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + Enumerant{ + Enumerant: "RayQueryCandidateIntersectionAABBKHR", + Value: 1, + Capabilities: []string{"RayQueryProvisionalKHR",}, + Parameters: []Parameter{}, + Version: "", + }, + }, + Bases: []*OperandKind {}, + } OperandKindIdResultType = &OperandKind { Kind: "IdResultType", Category: "Id",