Update SPIR-V Tools to dd4b663e1
Changes:
dd4b663e1 Prepare release v2024.2 (#5651)
be6fb2a54 build(deps): bump the github-actions group across 1 directory with 4 updates (#5650)
dadb3012d Add SPIRV_TOOLS_EXPORT to public C++ API (#5591)
53c073606 A fix to support of SPV_QCOM_image_processing2 (#5646)
67a3ed670 opt: add GroupNonUniformPartitionedNV capability to trim pass (#5648)
2904985ae spirv-val: Add Vulkan check for Rect Dim in OpTypeImage (#5644)
02470f606 Validate duplicate decorations and execution modes (#5641)
6761288d3 Validator: Support SPV_NV_raw_access_chains (#5568)
3983d15a1 Fix rebuilding types with circular references (#5623). (#5637)
ade1f7cfd Add AliasedPointer decoration (#5635)
24f2cdad8 build(deps): bump the github-actions group with 1 update (#5634)
58ab8baf7 docs: explain LunarG is the source of truth for releases (#5627)
7fe5f75e5 Roll external/re2/ 6598a8ecd..917047f36 (3 commits) (#5626)
87721a100 Roll external/spirv-headers/ 7d500c4d7..4f7b471f1 (1 commit) (#5625)
fe7bae090 Minor fix to cmakelists to avoid rerunning command each build (#5620)
fc4286556 build(deps): bump the github-actions group with 2 updates (#5621)
67451ebf6 Roll external/spirv-headers/ 04db24d69..7d500c4d7 (1 commit) (#5619)
dda7731e9 build(deps): bump the github-actions group with 2 updates (#5618)
3fafcc20e Roll external/spirv-headers/ 8b246ff75..04db24d69 (1 commit) (#5617)
3a0471c3b build(deps): bump the github-actions group with 1 update (#5615)
6c3f632a2 roll deps (#5614)
c6615779e Roll external/googletest/ b479e7a3c..c231e6f5b (1 commit) (#5613)
f20663ca7 add support for vulkan-shader-profiler external passes (#5512)
f74f4e74c Roll external/re2/ ed9fc269e..d00d1e937 (2 commits) (#5589)
e39cabca2 build(deps): bump the github-actions group with 2 updates (#5610)
efb0fce2d Use bazel 7 and bzlmod (#5601)
02c79e908 kokoro: Update bazel to 7.0.2 for Linux builds (#5609)
f869d391a [OPT] Fix handling of analyses rebuild (#5608)
d15a7aa25 kokoro: Update bazel to 7.0.2 for Mac builds (#5606)
04896c462 Prepare release v2024.1 (#5605)
7c363050d Add operand types for SPV_NV_raw_access_chains (#5602)
5bc7c2876 build(deps): bump the github-actions group with 2 updates (#5598)
75ad1345d Remove redundant function declarations from source/operand.h (#5584)
9bd44d028 Support for SPV_QCOM_image_processing2 (#5582)
0b027bafa Support operand kind for SPV_INTEL_maximum_registers (#5580)
fbc7a14b3 Fix access chain struct checks (#5592)
99a3ad32f build(deps): bump the github-actions group with 1 update (#5594)
c3a9ffd74 build(deps): bump the github-actions group with 1 update (#5593)
1b643eac5 spirv-val: Make Constant evaluation consistent (#5587)
dc6676445 Roll external/googletest/ 6eb225cb8..5df0241ea (2 commits) (#5583)
7da2c941f Update WORKSPACE (#5588)
16af142c1 build(deps): bump the github-actions group with 1 update (#5586)
b0a5c4ac1 SPV_NV_shader_atomic_fp16_vector (#5581)
55cb3989e build(deps): bump the github-actions group with 1 update (#5578)
11afeb4bb roll deps (#5576)
7604147c2 [OPT] Add removed unused interface var pass to legalization passes (#5579)
f9184c650 spirv-val: Revert Validate PhysicalStorageBuffer Stage Interface (#5575)
20ad38c18 spirv-val: Multiple interface var with same SC (#5528)
e08c012b1 [OPT] Identify arrays with unknown length in copy prop arrays (#5570)
56a51dd94 Roll external/spirv-headers/ e77d03080..d3c2a6fa9 (1 commit) (#5574)
0c986f596 update image enum tests to remove Kernel capability (#5562)
b7413609c [OPT] Use new instruction folder for for all opcodes in spec consti folding (#5569)
784b064f9 spirv-val: Validate PhysicalStorageBuffer Stage Interface (#5539)
a8959dc65 Fold 64-bit int operations (#5561)
80926d97a roll deps (#5566)
9a7b1af90 build(deps): bump the github-actions group with 1 update (#5564)
1a2cbabd8 Roll external/googletest/ 48729681a..64be1c79f (1 commit) (#5563)
7657cb1c6 build(deps): bump the github-actions group with 3 updates (#5560)
032c15aaf [NFC] Refactor code to fold instruction in fold tests. (#5558)
9938f5bc2 Roll external/googletest/ 456574145..48729681a (1 commit) (#5559)
ab59dc608 opt: prevent meld to merge block with MaximalReconvergence (#5557)
6c11c2bd4 Roll external/re2/ 283636ffb..ab7c5918b (2 commits) (#5555)
a8afbe941 roll deps (#5550)
8d3ee2e8f spirv-opt: Fix OpCompositeExtract relaxation with struct operands (#5536)
61c51d4ba spirv-val: Add Mesh Primitive Built-In validation (#5529)
5d3c8b73f opt: Add OpEntryPoint to DescriptorScalarReplacement pass (#5553)
de65e8174 [NFC] Remove unused code (#5554)
ad11927e6 opt: add SPV_EXT_mesh_shader to opt allowlist (#5551)
27ffe976e build(deps): bump the github-actions group with 2 updates (#5549)
e5fcb7fac Roll external/re2/ 264e71e88..826ad10e5 (1 commit) (#5538)
80bc99c3d Skip entire test/ folder if SPIRV_SKIP_TESTS is set. (#5548)
0a6f0d189 opt: Add TrimCapabilities pass to spirv-opt tool (#5545)
b951948ea SPV_KHR_quad_control (#5547)
69197ba90 Add modify-maximal-reconvergence to spirv-opt help (#5546)
0045b01ff opt: Add VulkanMemoryModelDeviceScope to trim (#5544)
ef2f43236 Add support for SPV_KHR_float_controls2 (#5543)
de3d5acc0 Add tooling support for SPV_KHR_maximal_reconvergence (#5542)
14000ad47 Use python3 explicitly. (#5540)
359012927 workflow: add vulkan-sdk tags as release tags (#5518)
3e6bdd0f9 build(deps): bump the github-actions group with 3 updates (#5537)
ed6835aff Roll external/re2/ c042630ed..32c181e0a (1 commit) (#5532)
c96fe8b94 spirv-val: Re-enable OpControlBarrier VU (#5527)
5dbdc7b60 build(deps): bump the github-actions group with 4 updates (#5531)
155728b2e Add preserver-interface option to spirv-opt (#5524)
01ee1bf31 Roll external/googletest/ b10fad38c..76bb2afb8 (1 commit) (#5485)
36be541ee Remove unnecessary debug code (#5523)
c7affa170 opt: add Int16 and Float16 to capability trim pass (#5519)
0a9f3d1f2 Revert "Fix(cmake): CMake doesn't find system installed SPIRV-Headers (#5422)" (#5517)
7d2429594 Fix(cmake): CMake doesn't find system installed SPIRV-Headers (#5422)
f0cc85efd Prepare release v2023.6 (#5510)
e03c8f5c8 Fix broken build (#5505)
d75b3cfbb Zero initialize local variables (#5501)
6b4f0c9d0 instrument: Fix handling of gl_InvocationID (#5493)
b5d60826e printf: Remove stage specific info (#5495)
e7a52b70f build(deps): bump the github-actions group with 1 update (#5498)
2da75e152 Do not crash when trying to fold unsupported spec constant (#5496)
0d8784553 Remove uses of std::system(nullptr) (#5494)
f4a73dd7a std::system requires include of <cstdlib> (#5486)
ffe645023 Add iOS build to CI (#5490)
afaf8fda2 Fix iOS / Android CMake builds (#5482)
7d2a618bf build(deps): bump the github-actions group with 1 update (#5484)
2a238ed24 Roll external/spirv-headers/ 38f39dae5..cca08c63c (2 commits) (#5480)
246e6d4c6 spirv-val: Loosen restriction on base type of DebugTypePointer and DebugTypeQualifier (#5479)
0df791f97 Fix nullptr argument in MarkInsertChain (#5465)
fb91e6f0e Flush stdout before changing mode back to text (#5477)
560eea6d7 build(deps): bump the github-actions group with 1 update (#5478)
c8510a5e8 Fix python warning seen on Fedora 39 (#5474)
8ee3ae524 Add comment to --inst-debug-printf option (#5466)
f43c464d5 opt: add PhysicalStorageBufferAddresses to trim (#5476)
c91e9d09b opt: add StorageImageReadWithoutFormat to cap trim (#5475)
d88742fbd fix(build): git describe all tagged versions (#5447)
6b1e609ef Support missing git in update_build_version.py (#5473)
fbf047cc8 Roll external/re2/ 24d460a9d..974f44c8d (4 commits) (#5470)
9e7a1f2dd Fix array size calculation (#5463)
eacc969b7 build(deps): bump the github-actions group with 2 updates (#5457)
7210d247c Roll external/googletest/ 518387203..5b7fd63d6 (1 commit) (#5454)
a08f648c8 Remove references to __FILE__ (#5462)
c87755bb9 spirv-val: Add WorkgroupMemoryExplicitLayoutKHR check for Block (#5461)
4f014aff9 Roll external/re2/ 601d9ea3e..a0b3bc60c (1 commit) (#5453)
33bac5144 Roll external/googletest/ 116b7e552..518387203 (1 commit) (#5450)
01e851be9 Roll external/re2/ 928a015e6..601d9ea3e (1 commit) (#5448)
1928c76cd Roll external/googletest/ 2dd1c1319..829c19901 (1 commit) (#5444)
73876defc opt: support 64-bit OpAccessChain index in FixStorageClass (#5446)
5084f58e5 build(deps): bump the github-actions group with 4 updates (#5445)
a9c61d137 update_build_version.py produce deterministic header. (#5426)
5bb595091 Add ComputeDerivativeGroup*NV capabilities to trim capabilities pass. (#5430)
3985f0da0 Roll external/spirv-headers/ e867c0663..4183b260f (1 commit) (#5439)
661f429b1 Roll external/re2/ b673de358..ece4cecab (2 commits) (#5437)
360d469b9 Prepare release v2023.5.rc1 (#5423)
74005dfa6 Roll external/re2/ 35bb195de..b673de358 (2 commits) (#5433)
933db564c roll deps (#5432)
ce995b319 Hash pin workflows and config dependabot (#5412)
df2f2a031 build(deps): bump get-func-name from 2.0.0 to 2.0.2 in /tools/sva (#5418)
866e60def Roll external/spirv-headers/ 79743b899..e867c0663 (2 commits) (#5427)
023a8c79e opt: add Float64 capability to trim pass (#5428)
4fab7435b Roll external/googletest/ e47544ad3..beb552fb4 (2 commits) (#5424)
847715d6c instrument: Ensure linking works even if nothing is changed (#5419)
dc9900967 Update BUILD.gn to include header for new pass (#5421)
1bc0e6f59 Add a new legalization pass to dedupe invocation interlock instructions (#5409)
48c97c131 roll deps (#5415)
27673a054 Remove reviewer from autoroller (#5414)
ee7598d49 instrument: Use Import linkage for instrumentation functions (#5355)
a40483d31 roll deps (#5408)
a996591b1 Update SPIRV-Headers, add cache control operand kinds (#5406)
fc54e178e Change autoroll pr review id (#5404)
2d6996f73 Check for git repository before git commands (#5403)
361638cfd Make sure that fragment shader interlock instructions are not removed by DCE (#5400)
5e6054c1c Roll external/re2/ e0077036c..a807e8a3a (6 commits) (#5401)
47b63a4d7 val: re-add ImageMSArray validation (#5394)
d660bb55b Add SPV_KHR_physical_storage_buffer to allowlists (#5402)
440f018cc Fix `AddMemberDecoration` variable names. (#5399)
4e0b94ed7 opt: add ImageMSArray capability to trim pass. (#5395)
d474a0708 Add SPV_EXT_fragment_shader_interlock to allow lists (#5393)
1f07f483e opt: add raytracing/rayquery to trim pass (#5397)
158bc7bd6 Roll external/re2/ 523f9b097..e0077036c (2 commits) (#5391)
1121c2319 opt: add Int64 capability to trim pass (#5398)
3cc7e1c4c NFC: rename tests using capability as prefix (#5396)
4c16c35b1 opt: add FragmentShader*InterlockEXT to capability trim pass (#5390)
9b923f7cc QNX has support for ANSI ESC codes, default terminal is QANSI. (#5387)
51367c40f Enable OpenSSF Scorecard and Badge (#5377)
d09c753a4 Roll external/re2/ 73031bbc0..523f9b097 (1 commit) (#5389)
b6893ccdf Roll external/googletest/ 460ae9826..8a6feabf0 (1 commit) (#5388)
1b3c4cb68 roll deps (#5386)
abd548b81 roll deps (#5384)
2601f644e Roll external/googletest/ 9fce54804..61332bd7e (2 commits) (#5383)
714966003 opt: Add SwitchDescriptorSetPass (#5375)
6520d83ef linker: Add --use-highest-version option (#5376)
bfc94f63a roll deps (#5382)
b12fc2904 Roll external/googletest/ 7e33b6a1c..987e22561 (5 commits) (#5381)
89ca3aa57 SPV_QCOM_image_processing support (#5223)
c55888661 Fix failing action when PR is already open. (#5380)
0f17d05c4 opt: add bitmask support for capability trimming (#5372)
fddcc8ced Roll external/re2/ 9dc7ae7b5..6148386f0 (3 commits) (#5379)
7ddc65c72 Support 2 Intel extensions (#5357)
43b888649 roll deps (#5374)
d6300ee92 Fix -Wunreachable-code-loop-increment warning (#5373)
8714d7fad enable StorageUniform16 (#5371)
8e3da01b4 Move token version/cap/ext checks from parsing to validation (#5370)
4788ff157 opt: add StorageUniformBufferBlock16 to trim pass (#5367)
ebda56e35 opt: add StoragePushConstant16 to trim pass (#5366)
3af4244ae Roll external/googletest/ 46db91ef6..89b25572d (1 commit) (#5365)
60e684fe7 opt: fix StorageInputOutput16 trimming. (#5359)
13892fe86 Roll external/googletest/ 6f6ab4212..e7fd109b5 (2 commits) (#5356)
727f4346d docs: update references to `main` branch (#5363)
e553b884c Prepare release for v2023.4.rc2 (#5362)
4a9881fe9 Use absolute path to depot_tools (#5360)
09b76c23e Update SPIRV-Headers; test some coop matrix enums (#5361)
1d14d84f2 opt: fix missing CreateTrimCapabilitiesPass definition (#5353)
47fff21d5 instrument: Reduce number of inst_bindless_stream_write_6 calls (#5327)
02cd71d41 roll deps (#5352)
e68fe9be4 Add SPV_EXT_shader_atomic_float_add to allow lists (#5348)
c6d0b0480 build: fix missing files in BUILD.gn (#5351)
b5f600c08 Roll external/googletest/ 01e18376e..40412d851 (1 commit) (#5347)
a0f1c8727 opt: Fix incorrect half float conversion (#5349)
35d8b05de opt: add capability trimming pass (not default). (#5278)
ec90d2872 roll deps (#5345)
d52c39c37 Do not crash when folding 16-bit OpFDiv (#5338)
17d9669d5 enumset: add iterator based constructor/insert (#5344)
daad2295c Roll external/googletest/ cc366710b..d66ce5851 (2 commits) (#5337)
45f7e55af Bump word-wrap from 1.2.3 to 1.2.4 in /tools/sva (#5343)
bf03d4092 opt: change Get* functions to return const& (#5331)
876ccc6cd Add /bigobj to test_opt for VS 2017 (#5336)
c50bc49f5 Fix link flags for Clang-based MinGW cross compile (#5342)
2813da268 kokoro: rename glslang (#5339)
883417544 Set cmake_policy CMP0128 (#5341)
6c7e1acc5 NFC: fix missing algorithm include in enumset file (#5334)
61221e7d6 Add python3 requirement for the script (#5326)
4b6bd5a66 Prepare release v2023.4 (#5330)
9e0b780ff Create SECURITY.md (#5325)
7dd5f95d2 [spirv-opt] Handle OpFunction in GetPtr (#5316)
6add9ccf0 Add support for LiteralFloat type (#5323)
85a448213 NFC: makes the FeatureManager immutable for users (#5329)
29431859f NFC: replace EnumSet::ForEach with range-based-for (#5322)
d6b9389f6 Roll external/spirv-headers/ d0006a393..f1ba373ef (2 commits) (#5320)
5b4fb072e enumset: fix bug in the new iterator class (#5321)
9ab811a12 NFC: fix missing comments on functions (#5318)
9266197c3 instrument: Cast gl_VertexIndex and InstanceIndex to uint (#5319)
ee50fa7d8 Roll external/googletest/ 4a1a299b2..cc366710b (1 commit) (#5317)
3424b16c1 enumset: STL-ize container (#5311)
7ff331af6 source: Give better message if using new Source Language (#5314)
abcd228d9 Update README to say Android NDK r25c is required (#5312)
0530a532f Validate GroupNonUniform instructions (#5296)
4594ffce9 Roll external/re2/ a57a1d646..e66463312 (1 commit) (#5313)
4be7d0e3c Use android ndk r25 (#5309)
e751c7e7d Treat spir-v.xml as utf-8 (#5306)
0f3bea06e NFC: rewrite EnumSet to handle larger enums. (#5289)
870fd1e17 spirv-val: Label SPV_KHR_cooperative_matrix VUID (#5301)
a1e8fff14 Roll external/re2/ 2d39b703d..1c1ffbe3c (1 commit) (#5304)
58459c2b1 roll deps (#5300)
d3b0a522c Roll external/googletest/ 687c58994..251e72039 (1 commit) (#5299)
ea5af2fb5 roll deps (#5297)
f83f50d23 Roll external/googletest/ ec4fed932..8e32de89c (2 commits) (#5294)
7520bfa6b build: remove last references of c++11 (#5295)
e090ce9c4 Update CMakeLists.txt (#5293)
bfb40a240 fix ndk build standard to c++17 (#5290)
310a67020 Validate layouts for PhysicalStorageBuffer pointers (#5291)
c640b1934 Update CMakeLists.txt (#5288)
cfb99efd7 Roll external/googletest/ af39146b4..ec4fed932 (1 commit) (#5287)
04cdb2d34 SPV_KHR_cooperative_matrix (#5286)
16098b3c1 Have effcee add abseil subdirectory (#5281)
daee1e7d3 instrument: Combine descriptor length and init state checking (#5274)
a68ef7b2c cmake: Remove unused SPIRV-Headers variables (#5284)
b12c0fe6f Roll external/googletest/ fb11778f4..af39146b4 (1 commit) (#5285)
54691dcd7 Migrate `exec_tools` back to `tools`. (#5280)
a6b57f2f0 Roll external/googletest/ 9b12f749f..fb11778f4 (4 commits) (#5279)
a63ac9f73 cmake: Use modern Python3 CMake support (#5277)
951980e5a Enable vector constant folding (#4913) (#5272)
a720a6926 Roll external/googletest/ 18fa6a4db..9b12f749f (1 commit) (#5276)
6b9fc7933 Fold negation of integer vectors (#5269)
285f6cefa roll deps (#5273)
d33bea584 instrument: Fix buffer address length calculations (#5257)
b4f352e54 Expose preserve_interface in Optimizer::Register*Passes. (#5268)
40dde04ca Roll external/googletest/ 65cfeca1a..e9078161e (1 commit) (#5267)
6d0e3cf6a Roll external/googletest/ 334704df2..65cfeca1a (1 commit) (#5265)
9c66587d1 spirv-diff: Update test expectations (#5264)
ae1843b67 spirv-diff: Leave undefined ids unpaired. (#5262)
93c13345e spirv-diff: Properly match SPV_KHR_ray_query types. (#5259)
9da026922 roll deps (#5263)
1d7dec3c5 Use windows 2019 to workaround bazel issue (#5261)
59b4febd8 Allow OpTypeBool in UniformConstant (#5237)
5ed21eb1e Add folding rule for OpTranspose (#5241)
ec244c859 Increase tested Android API level (#5253)
c7e436921 roll deps (#5243)
182fd9ebc Allow physical storage buffer pointer in IO (#5251)
226c3bbe6 Fix broken link in README (#5250)
9ed2ac257 Fix pairing of function parameters. (#5225)
8841d560c Add c++ version to .bazelrc (#5247)
cf62673e4 Error for invalid location type (#5249)
673d8bfcb Checkout abseil in the smoketest (#5248)
06bbd7f53 Update deps in sva (#5246)
23cb9b96c spirv-val: Remove VUID from 1.3.251 spec (#5244)
1021ec302 Add Abseil dep to the README (#5242)
3e82fa067 Revert "Disable RE2 autoroll (#5234)" (#5239)
e0936b646 Roll external/spirv-headers/ bdbfd019b..69155b22b (1 commit) (#5238)
af27ece75 Check if const is zero before getting components. (#5217)
235800182 Add Abseil as a dep and update RE2 (#5236)
f29e11dcb diff: Don't give up entry point matching too early. (#5224)
82b1a87b2 Add SPV_NV_bindless_texture to spirv optimizations (#5231)
60c546f3f Roll external/googletest/ bc860af08..bb2941fcc (1 commit) (#5220)
dcfea36ab Have the macos bazel build us git-sync-deps (#5235)
44c9da6fe Remove const zero image operands (#5232)
e357a36cc Disable RE2 autoroll (#5234)
e7c6084fd Prepare release 2023.3 (#5222)
17a26b45f Improve an error message in the assembler (#5219)
7c39951f6 spirv-val: Label Interface Location/Component VUIDs (#5221)
Commands:
./third_party/update-spirvtools.sh
Bug: b/123642959
Change-Id: I4e497c6d70e17beb10e859266dbfa905d1f99136
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/73373
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Tested-by: Romaric Jodin <rjodin@chromium.org>
Presubmit-Ready: Romaric Jodin <rjodin@chromium.org>
Tested-by: Ben Clayton <bclayton@google.com>
diff --git a/third_party/SPIRV-Tools/.bazelrc b/third_party/SPIRV-Tools/.bazelrc
new file mode 100644
index 0000000..79ad594
--- /dev/null
+++ b/third_party/SPIRV-Tools/.bazelrc
@@ -0,0 +1,7 @@
+# Enable Bzlmod for every Bazel command
+common --enable_bzlmod
+
+build --enable_platform_specific_config
+build:linux --cxxopt=-std=c++17
+build:macos --cxxopt=-std=c++17
+build:windows --cxxopt=/std:c++17
diff --git a/third_party/SPIRV-Tools/.bazelversion b/third_party/SPIRV-Tools/.bazelversion
index 0062ac9..a8907c0 100644
--- a/third_party/SPIRV-Tools/.bazelversion
+++ b/third_party/SPIRV-Tools/.bazelversion
@@ -1 +1 @@
-5.0.0
+7.0.2
diff --git a/third_party/SPIRV-Tools/.github/dependabot.yml b/third_party/SPIRV-Tools/.github/dependabot.yml
new file mode 100644
index 0000000..dca857a
--- /dev/null
+++ b/third_party/SPIRV-Tools/.github/dependabot.yml
@@ -0,0 +1,25 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+version: 2
+updates:
+ - package-ecosystem: github-actions
+ directory: /
+ schedule:
+ interval: daily
+ groups:
+ github-actions:
+ patterns:
+ - "*"
+ open-pull-requests-limit: 3
diff --git a/third_party/SPIRV-Tools/.github/workflows/autoroll.yml b/third_party/SPIRV-Tools/.github/workflows/autoroll.yml
index 4520309..ed33622 100644
--- a/third_party/SPIRV-Tools/.github/workflows/autoroll.yml
+++ b/third_party/SPIRV-Tools/.github/workflows/autoroll.yml
@@ -16,14 +16,14 @@
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@1d96c772d19495a3b5c517cd2bc0cb401ea0529f # v4.1.3
# Checkout the depot tools they are needed by roll_deps.sh
- name: Checkout depot tools
run: git clone https://chromium.googlesource.com/chromium/tools/depot_tools.git
- name: Update PATH
- run: echo "./depot_tools" >> $GITHUB_PATH
+ run: echo "$(pwd)/depot_tools" >> $GITHUB_PATH
- name: Download dependencies
run: python3 utils/git-sync-deps
@@ -47,6 +47,10 @@
if: steps.update_dependencies.outputs.changed == 'true'
run: |
git push --force --set-upstream origin roll_deps
- gh pr create --label 'kokoro:run' --base main -f -r s-perron
+ # Create a PR. If it already exists, the command fails, so ignore the return code.
+ gh pr create --base main -f || true
+ # Add the 'kokoro:run' label so that the kokoro tests will be run.
+ gh pr edit --add-label 'kokoro:run'
+ gh pr merge --auto --squash
env:
GITHUB_TOKEN: ${{ github.token }}
diff --git a/third_party/SPIRV-Tools/.github/workflows/bazel.yml b/third_party/SPIRV-Tools/.github/workflows/bazel.yml
index 88700c4..43c99d6 100644
--- a/third_party/SPIRV-Tools/.github/workflows/bazel.yml
+++ b/third_party/SPIRV-Tools/.github/workflows/bazel.yml
@@ -13,30 +13,22 @@
timeout-minutes: 120
strategy:
matrix:
- os: [ubuntu-latest, macos-latest, windows-latest]
+ os: [ubuntu-latest, macos-latest, windows-2019]
runs-on: ${{matrix.os}}
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@1d96c772d19495a3b5c517cd2bc0cb401ea0529f # v4.1.3
with:
fetch-depth: '0'
- name: Download dependencies
run: python3 utils/git-sync-deps
- name: Mount Bazel cache
- uses: actions/cache@v3
+ uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
with:
path: ~/.bazel/cache
key: bazel-cache-${{ runner.os }}
- - name: Build All (Windows)
- if: ${{matrix.os == 'windows-latest' }}
- run: bazel --output_user_root=~/.bazel/cache build --cxxopt=/std:c++17 //...
- - name: Test All (Windows)
- if: ${{matrix.os == 'windows-latest' }}
- run: bazel --output_user_root=~/.bazel/cache test --cxxopt=/std:c++17 //...
- - name: Build All (Linux, MacOS)
- if: ${{ matrix.os != 'windows-latest' }}
- run: bazel --output_user_root=~/.bazel/cache build --cxxopt=-std=c++17 //...
- - name: Test All (Linux, MacOS)
- if: ${{ matrix.os != 'windows-latest' }}
- run: bazel --output_user_root=~/.bazel/cache test --cxxopt=-std=c++17 //...
+ - name: Build All
+ run: bazel --output_user_root=~/.bazel/cache build //...
+ - name: Test All
+ run: bazel --output_user_root=~/.bazel/cache test //...
diff --git a/third_party/SPIRV-Tools/.github/workflows/ios.yml b/third_party/SPIRV-Tools/.github/workflows/ios.yml
new file mode 100644
index 0000000..feb64a7
--- /dev/null
+++ b/third_party/SPIRV-Tools/.github/workflows/ios.yml
@@ -0,0 +1,30 @@
+name: iOS
+permissions:
+ contents: read
+
+on: [push, pull_request, workflow_dispatch]
+
+jobs:
+ build:
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ macos-12, macos-13 ]
+ steps:
+ - uses: actions/checkout@1d96c772d19495a3b5c517cd2bc0cb401ea0529f # v4.1.3
+ - uses: lukka/get-cmake@4931ab1fc1604964c055eb330edb3f6b26ba0cfa # v3.29.2
+ - name: Download dependencies
+ run: python3 utils/git-sync-deps
+ # NOTE: The MacOS SDK ships universal binaries. CI should reflect this.
+ - name: Configure Universal Binary for iOS
+ run: |
+ cmake -S . -B build \
+ -D CMAKE_BUILD_TYPE=Debug \
+ -D CMAKE_SYSTEM_NAME=iOS \
+ "-D CMAKE_OSX_ARCHITECTURES=arm64;x86_64" \
+ -G Ninja
+ env:
+ # Linker warnings as errors
+ LDFLAGS: -Wl,-fatal_warnings
+ - run: cmake --build build
+ - run: cmake --install build --prefix /tmp
diff --git a/third_party/SPIRV-Tools/.github/workflows/release.yml b/third_party/SPIRV-Tools/.github/workflows/release.yml
index ada9431..583c8f1 100644
--- a/third_party/SPIRV-Tools/.github/workflows/release.yml
+++ b/third_party/SPIRV-Tools/.github/workflows/release.yml
@@ -6,13 +6,14 @@
push:
tags:
- 'v[0-9]+.[0-9]+'
+ - 'vulkan-sdk-[0-9]+.[0-9]+.[0-9]+.[0-9]+'
- '!v[0-9]+.[0-9]+.rc*'
jobs:
prepare-release-job:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@1d96c772d19495a3b5c517cd2bc0cb401ea0529f # v4.1.3
- name: Prepare CHANGELOG for version
run: |
python utils/generate_changelog.py CHANGES "${{ github.ref_name }}" VERSION_CHANGELOG
diff --git a/third_party/SPIRV-Tools/.github/workflows/scorecard.yml b/third_party/SPIRV-Tools/.github/workflows/scorecard.yml
new file mode 100644
index 0000000..adcfa76
--- /dev/null
+++ b/third_party/SPIRV-Tools/.github/workflows/scorecard.yml
@@ -0,0 +1,53 @@
+name: Scorecard supply-chain security
+on:
+ # For Branch-Protection check. Only the default branch is supported. See
+ # https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection
+ branch_protection_rule:
+ # To guarantee Maintained check is occasionally updated. See
+ # https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained
+ schedule:
+ - cron: '36 17 * * 5'
+ push:
+ branches: [ "main" ]
+
+# Declare default permissions as read only.
+permissions: read-all
+
+jobs:
+ analysis:
+ name: Scorecard analysis
+ runs-on: ubuntu-latest
+ permissions:
+ security-events: write # to upload the results to code-scanning dashboard
+ id-token: write # to publish results and get a badge
+
+ steps:
+ - name: "Checkout code"
+ uses: actions/checkout@1d96c772d19495a3b5c517cd2bc0cb401ea0529f # v4.1.3
+ with:
+ persist-credentials: false
+
+ - name: "Run analysis"
+ uses: ossf/scorecard-action@0864cf19026789058feabb7e87baa5f140aac736 # v2.3.1
+ with:
+ results_file: results.sarif
+ results_format: sarif
+ # To enable Branch-Protection uncomment the `repo_token` line below
+ # To create the Fine-grained PAT, follow the steps in https://github.com/ossf/scorecard-action#authentication-with-fine-grained-pat-optional.
+ # repo_token: ${{ secrets.SCORECARD_TOKEN }}
+ publish_results: true # allows the repo to include the Scorecard badge
+
+ # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
+ # format to the repository Actions tab.
+ - name: "Upload artifact"
+ uses: actions/upload-artifact@1746f4ab65b179e0ea60a494b83293b640dd5bba # v4.3.2
+ with:
+ name: SARIF file
+ path: results.sarif
+ retention-days: 5
+
+ # Upload the results to GitHub's code scanning dashboard.
+ - name: "Upload to code-scanning"
+ uses: github/codeql-action/upload-sarif@c7f9125735019aa87cfc361530512d50ea439c71 # v3.25.1
+ with:
+ sarif_file: results.sarif
diff --git a/third_party/SPIRV-Tools/.github/workflows/wasm.yml b/third_party/SPIRV-Tools/.github/workflows/wasm.yml
index 62c9af3..6807b3d 100644
--- a/third_party/SPIRV-Tools/.github/workflows/wasm.yml
+++ b/third_party/SPIRV-Tools/.github/workflows/wasm.yml
@@ -9,7 +9,7 @@
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@1d96c772d19495a3b5c517cd2bc0cb401ea0529f # v4.1.3
with:
fetch-depth: '0'
- name: Build web
diff --git a/third_party/SPIRV-Tools/.gitignore b/third_party/SPIRV-Tools/.gitignore
index ec709ba..e85cea9 100644
--- a/third_party/SPIRV-Tools/.gitignore
+++ b/third_party/SPIRV-Tools/.gitignore
@@ -4,6 +4,7 @@
compile_commands.json
/build*/
/buildtools/
+/external/abseil_cpp/
/external/googletest
/external/SPIRV-Headers
/external/spirv-headers
@@ -22,6 +23,7 @@
bazel-spirv-tools
bazel-SPIRV-Tools
bazel-testlogs
+MODULE.bazel.lock
# Vim
[._]*.s[a-w][a-z]
diff --git a/third_party/SPIRV-Tools/BUILD.bazel b/third_party/SPIRV-Tools/BUILD.bazel
index ae7f35c..48a688e 100644
--- a/third_party/SPIRV-Tools/BUILD.bazel
+++ b/third_party/SPIRV-Tools/BUILD.bazel
@@ -58,6 +58,8 @@
generate_vendor_tables(extension = "nonsemantic.clspvreflection")
+generate_vendor_tables(extension = "nonsemantic.vkspreflection")
+
generate_vendor_tables(
extension = "opencl.debuginfo.100",
operand_kind_prefix = "CLDEBUG100_",
@@ -94,7 +96,7 @@
outs = ["generators.inc"],
cmd = "$(location :generate_registry_tables) --xml=$(location @spirv_headers//:spirv_xml_registry) --generator-output=$(location generators.inc)",
cmd_bat = "$(location :generate_registry_tables) --xml=$(location @spirv_headers//:spirv_xml_registry) --generator-output=$(location generators.inc)",
- exec_tools = [":generate_registry_tables"],
+ tools = [":generate_registry_tables"],
)
py_binary(
@@ -108,10 +110,8 @@
outs = ["build-version.inc"],
cmd = "SOURCE_DATE_EPOCH=0 $(location :update_build_version) $(location CHANGES) $(location build-version.inc)",
cmd_bat = "set SOURCE_DATE_EPOCH=0 && $(location :update_build_version) $(location CHANGES) $(location build-version.inc)",
- # This is explicitly tools and not exec_tools because we run it locally (on the host platform) instead of
- # (potentially remotely) on the execution platform.
- tools = [":update_build_version"],
local = True,
+ tools = [":update_build_version"],
)
# Libraries
@@ -146,15 +146,16 @@
":gen_extinst_lang_headers_OpenCLDebugInfo100",
":gen_glsl_tables_unified1",
":gen_opencl_tables_unified1",
- ":generators_inc",
":gen_vendor_tables_debuginfo",
":gen_vendor_tables_nonsemantic_clspvreflection",
+ ":gen_vendor_tables_nonsemantic_vkspreflection",
":gen_vendor_tables_nonsemantic_shader_debuginfo_100",
":gen_vendor_tables_opencl_debuginfo_100",
":gen_vendor_tables_spv_amd_gcn_shader",
":gen_vendor_tables_spv_amd_shader_ballot",
":gen_vendor_tables_spv_amd_shader_explicit_vertex_parameter",
":gen_vendor_tables_spv_amd_shader_trinary_minmax",
+ ":generators_inc",
],
hdrs = [
"include/spirv-tools/libspirv.h",
@@ -307,17 +308,17 @@
cc_binary(
name = "spirv-objdump",
srcs = [
- "tools/objdump/objdump.cpp",
"tools/objdump/extract_source.cpp",
"tools/objdump/extract_source.h",
+ "tools/objdump/objdump.cpp",
],
copts = COMMON_COPTS,
visibility = ["//visibility:public"],
deps = [
- ":tools_io",
- ":tools_util",
":spirv_tools_internal",
":spirv_tools_opt_internal",
+ ":tools_io",
+ ":tools_util",
"@spirv_headers//:spirv_cpp_headers",
],
)
@@ -428,7 +429,7 @@
copts = TEST_COPTS,
deps = [
":spirv_tools_internal",
- "@com_google_googletest//:gtest",
+ "@googletest//:gtest",
],
)
@@ -439,23 +440,25 @@
name = "base_{testcase}_test".format(testcase = f[len("test/"):-len("_test.cpp")]),
size = "small",
srcs = [f],
- copts = TEST_COPTS + ['-DTESTING'],
+ copts = TEST_COPTS + ["-DTESTING"],
linkstatic = 1,
target_compatible_with = {
"test/timer_test.cpp": incompatible_with(["@bazel_tools//src/conditions:windows"]),
}.get(f, []),
deps = [
+ "tools_util",
":spirv_tools_internal",
":test_lib",
- "tools_util",
- "@com_google_googletest//:gtest",
- "@com_google_googletest//:gtest_main",
+ "@googletest//:gtest",
+ "@googletest//:gtest_main",
],
) for f in glob(
- ["test/*_test.cpp", "test/tools/*_test.cpp"],
+ [
+ "test/*_test.cpp",
+ "test/tools/*_test.cpp",
+ ],
exclude = [
"test/cpp_interface_test.cpp",
- "test/log_test.cpp",
"test/pch_test.cpp",
],
)]
@@ -467,8 +470,8 @@
linkstatic = 1,
deps = [
":spirv_tools_opt_internal",
- "@com_google_googletest//:gtest",
- "@com_google_googletest//:gtest_main",
+ "@googletest//:gtest",
+ "@googletest//:gtest_main",
"@spirv_headers//:spirv_cpp11_headers",
],
)
@@ -481,21 +484,8 @@
linkstatic = 1,
deps = [
":spirv_tools_internal",
- "@com_google_googletest//:gtest",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_test(
- name = "base_log_test",
- size = "small",
- srcs = ["test/log_test.cpp"],
- copts = TEST_COPTS,
- linkstatic = 1,
- deps = [
- ":spirv_tools_opt_internal",
- "@com_google_googletest//:gtest",
- "@com_google_googletest//:gtest_main",
+ "@googletest//:gtest",
+ "@googletest//:gtest_main",
],
)
@@ -521,8 +511,8 @@
linkstatic = 1,
deps = [
":link_test_lib",
- "@com_google_googletest//:gtest",
- "@com_google_googletest//:gtest_main",
+ "@googletest//:gtest",
+ "@googletest//:gtest_main",
],
) for f in glob(
["test/link/*_test.cpp"],
@@ -538,8 +528,8 @@
":spirv_tools",
":spirv_tools_lint_internal",
":spirv_tools_opt_internal",
- "@com_google_googletest//:gtest",
- "@com_google_googletest//:gtest_main",
+ "@googletest//:gtest",
+ "@googletest//:gtest_main",
],
) for f in glob(
["test/lint/*_test.cpp"],
@@ -563,7 +553,7 @@
":spirv_tools_internal",
":spirv_tools_opt_internal",
"@com_google_effcee//:effcee",
- "@com_google_googletest//:gtest",
+ "@googletest//:gtest",
],
)
@@ -579,8 +569,8 @@
":spirv_tools_opt_internal",
":test_lib",
"@com_google_effcee//:effcee",
- "@com_google_googletest//:gtest",
- "@com_google_googletest//:gtest_main",
+ "@googletest//:gtest",
+ "@googletest//:gtest_main",
],
) for f in glob(["test/opt/*_test.cpp"])]
@@ -593,8 +583,8 @@
deps = [
":opt_test_lib",
":spirv_tools_opt_internal",
- "@com_google_googletest//:gtest",
- "@com_google_googletest//:gtest_main",
+ "@googletest//:gtest",
+ "@googletest//:gtest_main",
],
) for f in glob(
["test/opt/dominator_tree/*.cpp"],
@@ -612,8 +602,8 @@
":spirv_tools",
":spirv_tools_opt_internal",
"@com_google_effcee//:effcee",
- "@com_google_googletest//:gtest",
- "@com_google_googletest//:gtest_main",
+ "@googletest//:gtest",
+ "@googletest//:gtest_main",
],
) for f in glob(
["test/opt/loop_optimizations/*.cpp"],
@@ -634,7 +624,7 @@
":spirv_tools_reduce",
":test_lib",
":tools_io",
- "@com_google_googletest//:gtest",
+ "@googletest//:gtest",
],
)
@@ -649,7 +639,7 @@
":spirv_tools_internal",
":spirv_tools_opt_internal",
":spirv_tools_reduce",
- "@com_google_googletest//:gtest_main",
+ "@googletest//:gtest_main",
],
) for f in glob(["test/reduce/*_test.cpp"])]
@@ -661,8 +651,8 @@
linkstatic = 1,
deps = [
":spirv_tools_internal",
- "@com_google_googletest//:gtest",
- "@com_google_googletest//:gtest_main",
+ "@googletest//:gtest",
+ "@googletest//:gtest_main",
],
) for f in glob(["test/util/*_test.cpp"])]
@@ -693,8 +683,8 @@
":spirv_tools_internal",
":test_lib",
":val_test_lib",
- "@com_google_googletest//:gtest",
- "@com_google_googletest//:gtest_main",
+ "@googletest//:gtest",
+ "@googletest//:gtest_main",
],
) for f in glob(
["test/val/val_*_test.cpp"],
@@ -715,8 +705,8 @@
":spirv_tools_internal",
":test_lib",
":val_test_lib",
- "@com_google_googletest//:gtest",
- "@com_google_googletest//:gtest_main",
+ "@googletest//:gtest",
+ "@googletest//:gtest_main",
],
)
@@ -732,7 +722,7 @@
deps = [
":test_lib",
":val_test_lib",
- "@com_google_googletest//:gtest",
- "@com_google_googletest//:gtest_main",
+ "@googletest//:gtest",
+ "@googletest//:gtest_main",
],
)
diff --git a/third_party/SPIRV-Tools/BUILD.gn b/third_party/SPIRV-Tools/BUILD.gn
index e8622a1..4848fdd 100644
--- a/third_party/SPIRV-Tools/BUILD.gn
+++ b/third_party/SPIRV-Tools/BUILD.gn
@@ -331,6 +331,10 @@
"...nil...",
],
[
+ "nonsemantic.vkspreflection",
+ "...nil...",
+ ],
+ [
"nonsemantic.shader.debuginfo.100",
"SHDEBUG100_",
],
@@ -696,6 +700,8 @@
"source/opt/interface_var_sroa.h",
"source/opt/interp_fixup_pass.cpp",
"source/opt/interp_fixup_pass.h",
+ "source/opt/invocation_interlock_placement_pass.cpp",
+ "source/opt/invocation_interlock_placement_pass.h",
"source/opt/ir_builder.h",
"source/opt/ir_context.cpp",
"source/opt/ir_context.h",
@@ -738,6 +744,8 @@
"source/opt/mem_pass.h",
"source/opt/merge_return_pass.cpp",
"source/opt/merge_return_pass.h",
+ "source/opt/modify_maximal_reconvergence.cpp",
+ "source/opt/modify_maximal_reconvergence.h",
"source/opt/module.cpp",
"source/opt/module.h",
"source/opt/null_pass.h",
@@ -792,7 +800,11 @@
"source/opt/strip_nonsemantic_info_pass.h",
"source/opt/struct_cfg_analysis.cpp",
"source/opt/struct_cfg_analysis.h",
+ "source/opt/switch_descriptorset_pass.cpp",
+ "source/opt/switch_descriptorset_pass.h",
"source/opt/tree_iterator.h",
+ "source/opt/trim_capabilities_pass.cpp",
+ "source/opt/trim_capabilities_pass.h",
"source/opt/type_manager.cpp",
"source/opt/type_manager.h",
"source/opt/types.cpp",
diff --git a/third_party/SPIRV-Tools/CHANGES b/third_party/SPIRV-Tools/CHANGES
index dbe31a0..102703a 100644
--- a/third_party/SPIRV-Tools/CHANGES
+++ b/third_party/SPIRV-Tools/CHANGES
@@ -1,5 +1,150 @@
Revision history for SPIRV-Tools
+v2024.2 2024-04-22
+ - General
+ - Add SPIRV_TOOLS_EXPORT to public C++ API (#5591)
+ - Use bazel 7 and bzlmod (#5601)
+ - Optimizer
+ - opt: add GroupNonUniformPartitionedNV capability to trim pass (#5648)
+ - Fix rebuilding types with circular references. (#5637)
+ - Add AliasedPointer decoration (#5635)
+ - add support for vulkan-shader-profiler external passes (#5512)
+ - Validator
+ - A fix to support of SPV_QCOM_image_processing2 (#5646)
+ - spirv-val: Add Vulkan check for Rect Dim in OpTypeImage (#5644)
+ - Validate duplicate decorations and execution modes (#5641)
+ - Validator: Support SPV_NV_raw_access_chains (#5568)
+
+v2024.1 2024-03-06
+ - General
+ - Add tooling support for SPV_KHR_maximal_reconvergence (#5542)
+ - Add support for SPV_KHR_float_controls2 (#5543)
+ - SPV_KHR_quad_control (#5547)
+ - Fold 64-bit int operations (#5561)
+ - update image enum tests to remove Kernel capability (#5562)
+ - Support operand kind for SPV_INTEL_maximum_registers (#5580)
+ - SPV_NV_shader_atomic_fp16_vector (#5581)
+ - Support for SPV_QCOM_image_processing2 (#5582)
+ - Fix access chain struct checks (#5592)
+ - Optimizer
+ - opt: add Int16 and Float16 to capability trim pass (#5519)
+ - Add preserver-interface option to spirv-opt (#5524)
+ - spirv-opt: Fix OpCompositeExtract relaxation with struct operands (#5536)
+ - opt: Add VulkanMemoryModelDeviceScope to trim (#5544)
+ - opt: Add TrimCapabilities pass to spirv-opt tool (#5545)
+ - Add modify-maximal-reconvergence to spirv-opt help (#5546)
+ - opt: add SPV_EXT_mesh_shader to opt allowlist (#5551)
+ - opt: Add OpEntryPoint to DescriptorScalarReplacement pass (#5553)
+ - opt: prevent meld to merge block with MaximalReconvergence (#5557)
+ - [OPT] Use new instruction folder for all opcodes in spec constant folding (#5569)
+ - [OPT] Identify arrays with unknown length in copy prop arrays (#5570)
+ - [OPT] Add removed unused interface var pass to legalization passes (#5579)
+ - Validator
+ - spirv-val: Re-enable OpControlBarrier VU (#5527)
+ - spirv-val: Add Mesh Primitive Built-In validation (#5529)
+ - spirv-val: Validate PhysicalStorageBuffer Stage Interface (#5539)
+ - spirv-val: Multiple interface var with same SC (#5528)
+ - spirv-val: Revert Validate PhysicalStorageBuffer Stage Interface (#5575)
+ - spirv-val: Make Constant evaluation consistent (#5587)
+
+v2023.6 2023-12-18
+ - General
+ - update_build_version.py produce deterministic header. (#5426)
+ - Support missing git in update_build_version.py (#5473)
+ - Optimizer
+ - Add ComputeDerivativeGroup*NV capabilities to trim capabilities pass. (#5430)
+ - Do not crash when trying to fold unsupported spec constant (#5496)
+ - instrument: Fix handling of gl_InvocationID (#5493)
+ - Fix nullptr argument in MarkInsertChain (#5465)
+ - opt: support 64-bit OpAccessChain index in FixStorageClass (#5446)
+ - opt: add StorageImageReadWithoutFormat to cap trim (#5475)
+ - opt: add PhysicalStorageBufferAddresses to trim (#5476)
+ - Fix array size calculation (#5463)
+ - Validator
+ - spirv-val: Loosen restriction on base type of DebugTypePointer and DebugTypeQualifier (#5479)
+ - spirv-val: Add WorkgroupMemoryExplicitLayoutKHR check for Block (#5461)
+
+v2023.5 2023-10-15
+ - General
+ - Support 2 Intel extensions (#5357)
+ - SPV_QCOM_image_processing support (#5223)
+ - Optimizer
+ - opt: fix StorageInputOutput16 trimming. (#5359)
+ - opt: add StoragePushConstant16 to trim pass (#5366)
+ - opt: enable StorageUniform16 (#5371)
+ - opt: add bitmask support for capability trimming (#5372)
+ - opt: Add SwitchDescriptorSetPass (#5375)
+ - opt: add FragmentShader*InterlockEXT to capability trim pass (#5390)
+ - opt: add Int64 capability to trim pass (#5398)
+ - opt: add Float64 capability to trim pass (#5428)
+ - opt: add raytracing/rayquery to trim pass (#5397)
+ - opt: add ImageMSArray capability to trim pass. (#5395)
+ - Add SPV_KHR_physical_storage_buffer to allowlists (#5402)
+ - Add SPV_EXT_fragment_shader_interlock to allow lists (#5393)
+ - Make sure that fragment shader interlock instructions are not removed by DCE (#5400)
+ - instrument: Use Import linkage for instrumentation functions (#5355)
+ - Add a new legalization pass to dedupe invocation interlock instructions (#5409)
+ - instrument: Ensure linking works even if nothing is changed (#5419)
+ - Validator
+ - Move token version/cap/ext checks from parsing to validation (#5370)
+ - val: re-add ImageMSArray validation (#5394)
+ - Linker
+ - linker: Add --use-highest-version option
+
+v2023.4 2023-07-17
+ - General
+ - Set cmake_policy CMP0128 (#5341)
+ - Add python3 requirement for the script (#5326)
+ - Add support for LiteralFloat type (#5323)
+ - SPV_KHR_cooperative_matrix (#5286)
+ - Allow OpTypeBool in UniformConstant (#5237)
+ - Allow physical storage buffer pointer in IO (#5251)
+ - Remove const zero image operands (#5232)
+ - Optimizer
+ - Enable vector constant folding (#4913) (#5272)
+ - Fold negation of integer vectors (#5269)
+ - Add folding rule for OpTranspose (#5241)
+ - Add SPV_NV_bindless_texture to spirv optimizations (#5231)
+ - Fix incorrect half float conversion (#5349)
+ - Add SPV_EXT_shader_atomic_float_add to allow lists (#5348)
+ - Instrument
+ - instrument: Cast gl_VertexIndex and InstanceIndex to uint (#5319)
+ - instrument: Fix buffer address length calculations (#5257)
+ - instrument: Reduce number of inst_bindless_stream_write_6 calls (#5327)
+ - Validator
+ - Validate GroupNonUniform instructions (#5296)
+ - spirv-val: Label SPV_KHR_cooperative_matrix VUID (#5301)
+ - Validate layouts for PhysicalStorageBuffer pointers (#5291)
+ - spirv-val: Remove VUID from 1.3.251 spec (#5244)
+ - Diff
+ - spirv-diff: Update test expectations (#5264)
+ - spirv-diff: Leave undefined ids unpaired. (#5262)
+ - spirv-diff: Properly match SPV_KHR_ray_query types. (#5259)
+ - diff: Don't give up entry point matching too early. (#5224)
+
+v2023.3 2023-05-15
+ - General
+ - Update spirv_headers to include SPV_KHR_ray_tracing_position_fetch (#5205)
+ - spirv-tools: Add support for QNX (#5211)
+ - build: set std=c++17 for BUILD.gn (#5162)
+ - Optimizer
+ - Run ADCE when the printf extension is used. (#5215)
+ - Don't convert struct members to half (#5201)
+ - Apply scalar replacement on vars with Pointer decorations (#5208)
+ - opt: Fix null deref in OpMatrixTimesVector and OpVectorTimesMatrix (#5199)
+ - instrument: Add set and binding to bindless error records (#5204)
+ - instrument: Change descriptor state storage format (#5178)
+ - Fix LICMPass (#5087)
+ - Add Vulkan memory model to allow lists (#5173)
+ - Do not remove control barrier after spv1.3 (#5174)
+ - Validator
+ - spirv-val: Label Interface Location/Component VUIDs (#5221)
+ - Add support for SPV_EXT_shader_tile_image (#5188)
+ - Fix vector OpConstantComposite type validation (#5191)
+ - spirv-val: Label new Vulkan VUID 07951 (#5154)
+ - Fuzz
+ - Do not define GOOGLE_PROTOBUF_INTERNAL_DONATE_STEAL_INLINE if it is already defined. (#5200)
+
v2023.2 2023-03-10
- General
- build: move from c++11 to c++17 (#4983)
diff --git a/third_party/SPIRV-Tools/CMakeLists.txt b/third_party/SPIRV-Tools/CMakeLists.txt
index 71cdc00..0ba173f 100644
--- a/third_party/SPIRV-Tools/CMakeLists.txt
+++ b/third_party/SPIRV-Tools/CMakeLists.txt
@@ -16,6 +16,16 @@
project(spirv-tools)
+# Avoid a bug in CMake 3.22.1. By default it will set -std=c++11 for
+# targets in test/*, when those tests need -std=c++17.
+# https://github.com/KhronosGroup/SPIRV-Tools/issues/5340
+# The bug is fixed in CMake 3.22.2
+if (${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.22.1")
+ if (${CMAKE_VERSION} VERSION_LESS "3.22.2")
+ cmake_policy(SET CMP0128 NEW)
+ endif()
+endif()
+
set_property(GLOBAL PROPERTY USE_FOLDERS ON)
enable_testing()
@@ -53,6 +63,8 @@
add_definitions(-DSPIRV_IOS)
elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "tvOS")
add_definitions(-DSPIRV_TVOS)
+elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "visionOS")
+ add_definitions(-DSPIRV_VISIONOS)
elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Android")
add_definitions(-DSPIRV_ANDROID)
set(SPIRV_TIMER_ENABLED ${SPIRV_ALLOW_TIMERS})
@@ -231,7 +243,7 @@
# For MinGW cross compile, statically link to the C++ runtime.
# But it still depends on MSVCRT.dll.
if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
- if (${CMAKE_CXX_COMPILER_ID} MATCHES "GNU")
+ if (NOT MSVC)
set_target_properties(${TARGET} PROPERTIES
LINK_FLAGS -static -static-libgcc -static-libstdc++)
endif()
@@ -250,7 +262,7 @@
endif()
# Tests require Python3
-find_host_package(PythonInterp 3 REQUIRED)
+find_host_package(Python3 REQUIRED)
# Check for symbol exports on Linux.
# At the moment, this check will fail on the OSX build machines for the Android NDK.
@@ -259,7 +271,7 @@
macro(spvtools_check_symbol_exports TARGET)
if (NOT "${SPIRV_SKIP_TESTS}")
add_test(NAME spirv-tools-symbol-exports-${TARGET}
- COMMAND ${PYTHON_EXECUTABLE}
+ COMMAND Python3::Interpreter
${spirv-tools_SOURCE_DIR}/utils/check_symbol_exports.py "$<TARGET_FILE:${TARGET}>")
endif()
endmacro()
@@ -292,15 +304,23 @@
endmacro()
endif()
-# Defaults to OFF if the user didn't set it.
-option(SPIRV_SKIP_EXECUTABLES
- "Skip building the executable and tests along with the library"
- ${SPIRV_SKIP_EXECUTABLES})
-option(SPIRV_SKIP_TESTS
- "Skip building tests along with the library" ${SPIRV_SKIP_TESTS})
-if ("${SPIRV_SKIP_EXECUTABLES}")
+# Currently iOS and Android are very similar.
+# They both have their own packaging (APP/APK).
+# Which makes regular executables/testing problematic.
+#
+# Currently the only deliverables for these platforms are
+# libraries (either STATIC or SHARED).
+#
+# Furthermore testing is equally problematic.
+if (IOS OR ANDROID)
+ set(SPIRV_SKIP_EXECUTABLES ON)
+endif()
+
+option(SPIRV_SKIP_EXECUTABLES "Skip building the executable and tests along with the library")
+if (SPIRV_SKIP_EXECUTABLES)
set(SPIRV_SKIP_TESTS ON)
endif()
+option(SPIRV_SKIP_TESTS "Skip building tests along with the library")
# Defaults to ON. The checks can be time consuming.
# Turn off if they take too long.
@@ -358,7 +378,7 @@
if (NOT "${SPIRV_SKIP_TESTS}")
add_test(NAME spirv-tools-copyrights
- COMMAND ${PYTHON_EXECUTABLE} utils/check_copyright.py
+ COMMAND Python3::Interpreter utils/check_copyright.py
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
@@ -367,7 +387,8 @@
# Build pkg-config file
# Use a first-class target so it's regenerated when relevant files are updated.
-add_custom_target(spirv-tools-pkg-config ALL
+add_custom_command(
+ OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/SPIRV-Tools.pc
COMMAND ${CMAKE_COMMAND}
-DCHANGES_FILE=${CMAKE_CURRENT_SOURCE_DIR}/CHANGES
-DTEMPLATE_FILE=${CMAKE_CURRENT_SOURCE_DIR}/cmake/SPIRV-Tools.pc.in
@@ -377,8 +398,9 @@
-DCMAKE_INSTALL_INCLUDEDIR=${CMAKE_INSTALL_INCLUDEDIR}
-DSPIRV_LIBRARIES=${SPIRV_LIBRARIES}
-P ${CMAKE_CURRENT_SOURCE_DIR}/cmake/write_pkg_config.cmake
- DEPENDS "CHANGES" "cmake/SPIRV-Tools.pc.in" "cmake/write_pkg_config.cmake")
-add_custom_target(spirv-tools-shared-pkg-config ALL
+ DEPENDS "CHANGES" "${CMAKE_CURRENT_SOURCE_DIR}/cmake/SPIRV-Tools.pc.in" "${CMAKE_CURRENT_SOURCE_DIR}/cmake/write_pkg_config.cmake")
+add_custom_command(
+ OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/SPIRV-Tools-shared.pc
COMMAND ${CMAKE_COMMAND}
-DCHANGES_FILE=${CMAKE_CURRENT_SOURCE_DIR}/CHANGES
-DTEMPLATE_FILE=${CMAKE_CURRENT_SOURCE_DIR}/cmake/SPIRV-Tools-shared.pc.in
@@ -388,7 +410,10 @@
-DCMAKE_INSTALL_INCLUDEDIR=${CMAKE_INSTALL_INCLUDEDIR}
-DSPIRV_SHARED_LIBRARIES=${SPIRV_SHARED_LIBRARIES}
-P ${CMAKE_CURRENT_SOURCE_DIR}/cmake/write_pkg_config.cmake
- DEPENDS "CHANGES" "cmake/SPIRV-Tools-shared.pc.in" "cmake/write_pkg_config.cmake")
+ DEPENDS "CHANGES" "${CMAKE_CURRENT_SOURCE_DIR}/cmake/SPIRV-Tools-shared.pc.in" "${CMAKE_CURRENT_SOURCE_DIR}/cmake/write_pkg_config.cmake")
+add_custom_target(spirv-tools-pkg-config
+ ALL
+ DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/SPIRV-Tools-shared.pc ${CMAKE_CURRENT_BINARY_DIR}/SPIRV-Tools.pc)
# Install pkg-config file
if (ENABLE_SPIRV_TOOLS_INSTALL)
diff --git a/third_party/SPIRV-Tools/CONTRIBUTING.md b/third_party/SPIRV-Tools/CONTRIBUTING.md
index 893998e..11fb4e2 100644
--- a/third_party/SPIRV-Tools/CONTRIBUTING.md
+++ b/third_party/SPIRV-Tools/CONTRIBUTING.md
@@ -3,7 +3,7 @@
## For users: Reporting bugs and requesting features
We organize known future work in GitHub projects. See
-[Tracking SPIRV-Tools work with GitHub projects](https://github.com/KhronosGroup/SPIRV-Tools/blob/master/docs/projects.md)
+[Tracking SPIRV-Tools work with GitHub projects](https://github.com/KhronosGroup/SPIRV-Tools/blob/main/docs/projects.md)
for more.
To report a new bug or request a new feature, please file a GitHub issue. Please
@@ -46,7 +46,7 @@
approved it, but you must do it before we can put your code into our codebase.
See
-[README.md](https://github.com/KhronosGroup/SPIRV-Tools/blob/master/README.md)
+[README.md](https://github.com/KhronosGroup/SPIRV-Tools/blob/main/README.md)
for instruction on how to get, build, and test the source. Once you have made
your changes:
@@ -59,7 +59,7 @@
* If your patch completely fixes bug 1234, the commit message should say
`Fixes https://github.com/KhronosGroup/SPIRV-Tools/issues/1234` When you do
this, the issue will be closed automatically when the commit goes into
- master. Also, this helps us update the [CHANGES](CHANGES) file.
+ main. Also, this helps us update the [CHANGES](CHANGES) file.
* Watch the continuous builds to make sure they pass.
* Request a code review.
@@ -107,7 +107,7 @@
## For maintainers: Merging a PR
-We intend to maintain a linear history on the GitHub master branch, and the
+We intend to maintain a linear history on the GitHub main branch, and the
build and its tests should pass at each commit in that history. A linear
always-working history is easier to understand and to bisect in case we want to
find which commit introduced a bug. The
diff --git a/third_party/SPIRV-Tools/DEPS b/third_party/SPIRV-Tools/DEPS
index 9b6039e..8413d1b 100644
--- a/third_party/SPIRV-Tools/DEPS
+++ b/third_party/SPIRV-Tools/DEPS
@@ -3,18 +3,23 @@
vars = {
'github': 'https://github.com',
- 'effcee_revision': '66edefd2bb641de8a2f46b476de21f227fc03a28',
+ 'abseil_revision': '79ca5d7aad63973c83a4962a66ab07cd623131ea',
- 'googletest_revision': 'bc860af08783b8113005ca7697da5f5d49a8056f',
+ 'effcee_revision': '19b4aa87af25cb4ee779a071409732f34bfc305c',
+
+ 'googletest_revision': '5a37b517ad4ab6738556f0284c256cae1466c5b4',
# Use protobufs before they gained the dependency on abseil
'protobuf_revision': 'v21.12',
- 're2_revision': 'c9cba76063cf4235c1a15dd14a24a4ef8d623761',
- 'spirv_headers_revision': '268a061764ee69f09a477a695bf6a11ffe311b8d',
+ 're2_revision': '917047f3606d3ba9e2de0d383c3cd80c94ed732c',
+ 'spirv_headers_revision': '4f7b471f1a66b6d06462cd4ba57628cc0cd087d7',
}
deps = {
+ 'external/abseil_cpp':
+ Var('github') + '/abseil/abseil-cpp.git@' + Var('abseil_revision'),
+
'external/effcee':
Var('github') + '/google/effcee.git@' + Var('effcee_revision'),
diff --git a/third_party/SPIRV-Tools/MODULE.bazel b/third_party/SPIRV-Tools/MODULE.bazel
new file mode 100644
index 0000000..c36fe45
--- /dev/null
+++ b/third_party/SPIRV-Tools/MODULE.bazel
@@ -0,0 +1,7 @@
+bazel_dep(name = "bazel_skylib", version = "1.5.0")
+
+bazel_dep(name = "googletest", dev_dependency = True)
+local_path_override(
+ module_name = "googletest",
+ path = "external/googletest",
+)
diff --git a/third_party/SPIRV-Tools/README.md b/third_party/SPIRV-Tools/README.md
index 92e4d3c..7db5bd4 100644
--- a/third_party/SPIRV-Tools/README.md
+++ b/third_party/SPIRV-Tools/README.md
@@ -1,4 +1,5 @@
# SPIR-V Tools
+[![OpenSSF Scorecard](https://api.securityscorecards.dev/projects/github.com/KhronosGroup/SPIRV-Tools/badge)](https://securityscorecards.dev/viewer/?uri=github.com/KhronosGroup/SPIRV-Tools)
NEWS 2023-01-11: Development occurs on the `main` branch.
@@ -23,6 +24,13 @@
## Downloads
+The official releases for SPIRV-Tools can be found on LunarG's
+[SDK download page](https://vulkan.lunarg.com/sdk/home).
+
+For convenience, here are also links to the latest builds (HEAD).
+These are untested automated builds. They are not official releases, and they
+are not guaranteed to work. Official release builds are in the Vulkan SDK.
+
<img alt="Linux" src="kokoro/img/linux.png" width="20px" height="20px" hspace="2px"/>[![Linux Build Status](https://storage.googleapis.com/spirv-tools/badges/build_status_linux_clang_release.svg)](https://storage.googleapis.com/spirv-tools/badges/build_link_linux_clang_release.html)
<img alt="MacOS" src="kokoro/img/macos.png" width="20px" height="20px" hspace="2px"/>[![MacOS Build Status](https://storage.googleapis.com/spirv-tools/badges/build_status_macos_clang_release.svg)](https://storage.googleapis.com/spirv-tools/badges/build_link_macos_clang_release.html)
<img alt="Windows" src="kokoro/img/windows.png" width="20px" height="20px" hspace="2px"/>[![Windows Build Status](https://storage.googleapis.com/spirv-tools/badges/build_status_windows_release.svg)](https://storage.googleapis.com/spirv-tools/badges/build_link_windows_vs2019_release.html)
@@ -48,17 +56,14 @@
## Releases
-Some versions of SPIRV-Tools are tagged as stable releases (see
-[tags](https://github.com/KhronosGroup/SPIRV-Tools/tags) on github).
-These versions undergo extra testing.
-Releases are not directly related to releases (or versions) of
-[SPIRV-Headers][spirv-headers].
-Releases of SPIRV-Tools are tested against the version of SPIRV-Headers listed
-in the [DEPS](DEPS) file.
-The release generally uses the most recent compatible version of SPIRV-Headers
-available at the time of release.
-No version of SPIRV-Headers other than the one listed in the DEPS file is
-guaranteed to work with the SPIRV-Tools release.
+The official releases for SPIRV-Tools can be found on LunarG's
+[SDK download page](https://vulkan.lunarg.com/sdk/home).
+
+You can either find the prebuilt, QA-tested binaries there, or download the
+SDK Config, which lists the commits to use to build the release from scratch.
+
+GitHub releases are deprecated, and we will not publish new releases until
+further notice.
## Supported features
@@ -292,16 +297,18 @@
git clone https://github.com/google/googletest.git spirv-tools/external/googletest
git clone https://github.com/google/effcee.git spirv-tools/external/effcee
git clone https://github.com/google/re2.git spirv-tools/external/re2
+ git clone https://github.com/abseil/abseil-cpp.git spirv-tools/external/abseil_cpp
#### Dependency on Effcee
Some tests depend on the [Effcee][effcee] library for stateful matching.
-Effcee itself depends on [RE2][re2].
+Effcee itself depends on [RE2][re2], and RE2 depends on [Abseil][abseil-cpp].
* If SPIRV-Tools is configured as part of a larger project that already uses
Effcee, then that project should include Effcee before SPIRV-Tools.
-* Otherwise, SPIRV-Tools expects Effcee sources to appear in `external/effcee`
- and RE2 sources to appear in `external/re2`.
+* Otherwise, SPIRV-Tools expects Effcee sources to appear in `external/effcee`,
+ RE2 sources to appear in `external/re2`, and Abseil sources to appear in
+ `external/abseil_cpp`.
### Source code organization
@@ -313,6 +320,9 @@
* `external/re2`: Location of [RE2][re2] sources, if the `re2` library is not already
configured by an enclosing project.
(The Effcee project already requires RE2.)
+* `external/abseil_cpp`: Location of [Abseil][abseil-cpp] sources, if Abseil is
+ not already configured by an enclosing project.
+ (The RE2 project already requires Abseil.)
* `include/`: API clients should add this directory to the include search path
* `external/spirv-headers`: Intended location for
[SPIR-V headers][spirv-headers], not provided
@@ -381,15 +391,8 @@
### Build using Bazel
You can also use [Bazel](https://bazel.build/) to build the project.
-On linux:
```sh
-cd <spirv-dir>
-bazel build --cxxopt=-std=c++17 :all
-```
-
-On windows:
-```sh
-bazel build --cxxopt=/std:c++17 :all
+bazel build :all
```
### Build a node.js package using Emscripten
@@ -427,7 +430,7 @@
- [Python 3](http://www.python.org/): for utility scripts and running the test
suite.
- [Bazel](https://bazel.build/) (optional): if building the source with Bazel,
-you need to install Bazel Version 5.0.0 on your machine. Other versions may
+you need to install Bazel Version 7.0.2 on your machine. Other versions may
also work, but are not verified.
- [Emscripten SDK](https://emscripten.org) (optional): if building the
WebAssembly module.
@@ -480,12 +483,12 @@
### Android ndk-build
SPIR-V Tools supports building static libraries `libSPIRV-Tools.a` and
-`libSPIRV-Tools-opt.a` for Android:
+`libSPIRV-Tools-opt.a` for Android. Using the Android NDK r25c or later:
```
cd <spirv-dir>
-export ANDROID_NDK=/path/to/your/ndk
+export ANDROID_NDK=/path/to/your/ndk # NDK r25c or later
mkdir build && cd build
mkdir libs
@@ -798,6 +801,7 @@
[googletest-issue-610]: https://github.com/google/googletest/issues/610
[effcee]: https://github.com/google/effcee
[re2]: https://github.com/google/re2
+[abseil-cpp]: https://github.com/abseil/abseil-cpp
[CMake]: https://cmake.org/
[cpp-style-guide]: https://google.github.io/styleguide/cppguide.html
[clang-sanitizers]: http://clang.llvm.org/docs/UsersManual.html#controlling-code-generation
diff --git a/third_party/SPIRV-Tools/SECURITY.md b/third_party/SPIRV-Tools/SECURITY.md
new file mode 100644
index 0000000..99c5f44
--- /dev/null
+++ b/third_party/SPIRV-Tools/SECURITY.md
@@ -0,0 +1,13 @@
+# Security Policy
+
+## Supported Versions
+
+Security updates are applied only to the latest release.
+
+## Reporting a Vulnerability
+
+If you have discovered a security vulnerability in this project, please report it privately. **Do not disclose it as a public issue.** This gives us time to work with you to fix the issue before public exposure, reducing the chance that the exploit will be used before a patch is released.
+
+Please disclose it at [security advisory](https://github.com/KhronosGroup/SPIRV-Tools/security/advisories/new).
+
+This project is maintained by a team of volunteers on a reasonable-effort basis. As such, please give us at least 90 days to work on a fix before public exposure.
diff --git a/third_party/SPIRV-Tools/WORKSPACE b/third_party/SPIRV-Tools/WORKSPACE
index 5abfc98..6e78059 100644
--- a/third_party/SPIRV-Tools/WORKSPACE
+++ b/third_party/SPIRV-Tools/WORKSPACE
@@ -4,11 +4,6 @@
)
local_repository(
- name = "com_google_googletest",
- path = "external/googletest",
-)
-
-local_repository(
name = "com_googlesource_code_re2",
path = "external/re2",
)
@@ -17,3 +12,8 @@
name = "com_google_effcee",
path = "external/effcee",
)
+
+local_repository(
+ name = "abseil-cpp",
+ path = "external/abseil_cpp",
+)
diff --git a/third_party/SPIRV-Tools/build_defs.bzl b/third_party/SPIRV-Tools/build_defs.bzl
index 4d6f15c..76bf3e7 100644
--- a/third_party/SPIRV-Tools/build_defs.bzl
+++ b/third_party/SPIRV-Tools/build_defs.bzl
@@ -88,7 +88,7 @@
outs = outs.values(),
cmd = cmd,
cmd_bat = cmd,
- exec_tools = [":generate_grammar_tables"],
+ tools = [":generate_grammar_tables"],
visibility = ["//visibility:private"],
)
@@ -123,7 +123,7 @@
outs = outs.values(),
cmd = cmd,
cmd_bat = cmd,
- exec_tools = [":generate_grammar_tables"],
+ tools = [":generate_grammar_tables"],
visibility = ["//visibility:private"],
)
@@ -151,7 +151,7 @@
outs = outs.values(),
cmd = cmd,
cmd_bat = cmd,
- exec_tools = [":generate_grammar_tables"],
+ tools = [":generate_grammar_tables"],
visibility = ["//visibility:private"],
)
@@ -179,7 +179,7 @@
outs = outs.values(),
cmd = cmd,
cmd_bat = cmd,
- exec_tools = [":generate_grammar_tables"],
+ tools = [":generate_grammar_tables"],
visibility = ["//visibility:private"],
)
@@ -207,7 +207,7 @@
outs = outs.values(),
cmd = cmd,
cmd_bat = cmd,
- exec_tools = [":generate_grammar_tables"],
+ tools = [":generate_grammar_tables"],
visibility = ["//visibility:private"],
)
@@ -229,6 +229,6 @@
outs = outs.values(),
cmd = cmd,
cmd_bat = cmd,
- exec_tools = [":generate_language_headers"],
+ tools = [":generate_language_headers"],
visibility = ["//visibility:private"],
)
diff --git a/third_party/SPIRV-Tools/docs/downloads.md b/third_party/SPIRV-Tools/docs/downloads.md
index 168937a..0454b9e 100644
--- a/third_party/SPIRV-Tools/docs/downloads.md
+++ b/third_party/SPIRV-Tools/docs/downloads.md
@@ -1,8 +1,24 @@
# Downloads
-## Latest builds
+## Vulkan SDK
-Download the latest builds of the [master](https://github.com/KhronosGroup/SPIRV-Tools/tree/master) branch.
+The official releases for SPIRV-Tools can be found on LunarG's
+[SDK download page](https://vulkan.lunarg.com/sdk/home).
+The Vulkan SDK is updated approximately every six weeks.
+
+## Android NDK
+
+SPIRV-Tools host executables and library sources are published as
+part of the [Android NDK](https://developer.android.com/ndk/downloads).
+
+## Automated builds
+
+For convenience, here are also links to the latest builds (HEAD).
+These are untested automated builds. They are not official releases, and they
+are not guaranteed to work. Official release builds are in the Android NDK or
+Vulkan SDK.
+
+Download the latest builds of the [main](https://github.com/KhronosGroup/SPIRV-Tools/tree/main) branch.
### Release build
| Windows | Linux | MacOS |
@@ -15,14 +31,3 @@
| --- | --- | --- |
| [MSVC 2017](https://storage.googleapis.com/spirv-tools/badges/build_link_windows_vs2017_debug.html) | [clang](https://storage.googleapis.com/spirv-tools/badges/build_link_linux_clang_debug.html) | [clang](https://storage.googleapis.com/spirv-tools/badges/build_link_macos_clang_debug.html) |
| | [gcc](https://storage.googleapis.com/spirv-tools/badges/build_link_linux_gcc_debug.html) | |
-
-
-## Vulkan SDK
-
-SPIRV-Tools is published as part of the [LunarG Vulkan SDK](https://www.lunarg.com/vulkan-sdk/).
-The Vulkan SDK is updated approximately every six weeks.
-
-## Android NDK
-
-SPIRV-Tools host executables, and library sources are published as
-part of the [Android NDK](https://developer.android.com/ndk/downloads).
diff --git a/third_party/SPIRV-Tools/docs/projects.md b/third_party/SPIRV-Tools/docs/projects.md
index 8f7f0bc..cc88cb3 100644
--- a/third_party/SPIRV-Tools/docs/projects.md
+++ b/third_party/SPIRV-Tools/docs/projects.md
@@ -34,7 +34,7 @@
ones.
* They determine if the work for a card has been completed.
* Normally they are the person (or persons) who can approve and merge a pull
- request into the `master` branch.
+ request into the `main` branch.
Our projects organize cards into the following columns:
* `Ideas`: Work which could be done, captured either as Cards or Notes.
@@ -51,7 +51,7 @@
claimed by someone.
* `Done`: Issues which have been resolved, by completing their work.
* The changes have been applied to the repository, typically by being pushed
- into the `master` branch.
+ into the `main` branch.
* Other kinds of work could update repository settings, for example.
* `Rejected ideas`: Work which has been considered, but which we don't want
implemented.
diff --git a/third_party/SPIRV-Tools/external/CMakeLists.txt b/third_party/SPIRV-Tools/external/CMakeLists.txt
index 6ee37d9..5d8a3da 100644
--- a/third_party/SPIRV-Tools/external/CMakeLists.txt
+++ b/third_party/SPIRV-Tools/external/CMakeLists.txt
@@ -41,8 +41,6 @@
# Do this so enclosing projects can use SPIRV-Headers_SOURCE_DIR to find
# headers to include.
if (NOT DEFINED SPIRV-Headers_SOURCE_DIR)
- set(SPIRV_HEADERS_SKIP_INSTALL ON)
- set(SPIRV_HEADERS_SKIP_EXAMPLES ON)
add_subdirectory(${SPIRV_HEADER_DIR})
endif()
else()
@@ -93,10 +91,22 @@
# Find Effcee and RE2, for testing.
+ # RE2 depends on Abseil. We set absl_SOURCE_DIR if it is not already set, so
+ # that effcee can find abseil.
+ if(NOT TARGET absl::base)
+ if (NOT absl_SOURCE_DIR)
+ if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/abseil_cpp)
+ set(absl_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/abseil_cpp" CACHE STRING "Abseil source dir" )
+ endif()
+ endif()
+ endif()
+
# First find RE2, since Effcee depends on it.
# If already configured, then use that. Otherwise, prefer to find it under 're2'
# in this directory.
if (NOT TARGET re2)
+
+
# If we are configuring RE2, then turn off its testing. It takes a long time and
# does not add much value for us. If an enclosing project configured RE2, then it
# has already chosen whether to enable RE2 testing.
diff --git a/third_party/SPIRV-Tools/include/spirv-tools/instrument.hpp b/third_party/SPIRV-Tools/include/spirv-tools/instrument.hpp
index 448cf8a..0a6e630 100644
--- a/third_party/SPIRV-Tools/include/spirv-tools/instrument.hpp
+++ b/third_party/SPIRV-Tools/include/spirv-tools/instrument.hpp
@@ -73,213 +73,14 @@
// which generated the validation error.
static const int kInstCommonOutInstructionIdx = 2;
-// This is the stage which generated the validation error. This word is used
-// to determine the contents of the next two words in the record.
-// 0:Vert, 1:TessCtrl, 2:TessEval, 3:Geom, 4:Frag, 5:Compute
-static const int kInstCommonOutStageIdx = 3;
-static const int kInstCommonOutCnt = 4;
-
-// Stage-specific Stream Record Offsets
-//
-// Each stage will contain different values in the next set of words of the
-// record used to identify which instantiation of the shader generated the
-// validation error.
-//
-// Vertex Shader Output Record Offsets
-static const int kInstVertOutVertexIndex = kInstCommonOutCnt;
-static const int kInstVertOutInstanceIndex = kInstCommonOutCnt + 1;
-static const int kInstVertOutUnused = kInstCommonOutCnt + 2;
-
-// Frag Shader Output Record Offsets
-static const int kInstFragOutFragCoordX = kInstCommonOutCnt;
-static const int kInstFragOutFragCoordY = kInstCommonOutCnt + 1;
-static const int kInstFragOutUnused = kInstCommonOutCnt + 2;
-
-// Compute Shader Output Record Offsets
-static const int kInstCompOutGlobalInvocationIdX = kInstCommonOutCnt;
-static const int kInstCompOutGlobalInvocationIdY = kInstCommonOutCnt + 1;
-static const int kInstCompOutGlobalInvocationIdZ = kInstCommonOutCnt + 2;
-
-// Tessellation Control Shader Output Record Offsets
-static const int kInstTessCtlOutInvocationId = kInstCommonOutCnt;
-static const int kInstTessCtlOutPrimitiveId = kInstCommonOutCnt + 1;
-static const int kInstTessCtlOutUnused = kInstCommonOutCnt + 2;
-
-// Tessellation Eval Shader Output Record Offsets
-static const int kInstTessEvalOutPrimitiveId = kInstCommonOutCnt;
-static const int kInstTessEvalOutTessCoordU = kInstCommonOutCnt + 1;
-static const int kInstTessEvalOutTessCoordV = kInstCommonOutCnt + 2;
-
-// Geometry Shader Output Record Offsets
-static const int kInstGeomOutPrimitiveId = kInstCommonOutCnt;
-static const int kInstGeomOutInvocationId = kInstCommonOutCnt + 1;
-static const int kInstGeomOutUnused = kInstCommonOutCnt + 2;
-
-// Ray Tracing Shader Output Record Offsets
-static const int kInstRayTracingOutLaunchIdX = kInstCommonOutCnt;
-static const int kInstRayTracingOutLaunchIdY = kInstCommonOutCnt + 1;
-static const int kInstRayTracingOutLaunchIdZ = kInstCommonOutCnt + 2;
-
-// Mesh Shader Output Record Offsets
-static const int kInstMeshOutGlobalInvocationIdX = kInstCommonOutCnt;
-static const int kInstMeshOutGlobalInvocationIdY = kInstCommonOutCnt + 1;
-static const int kInstMeshOutGlobalInvocationIdZ = kInstCommonOutCnt + 2;
-
-// Task Shader Output Record Offsets
-static const int kInstTaskOutGlobalInvocationIdX = kInstCommonOutCnt;
-static const int kInstTaskOutGlobalInvocationIdY = kInstCommonOutCnt + 1;
-static const int kInstTaskOutGlobalInvocationIdZ = kInstCommonOutCnt + 2;
-
-// Size of Common and Stage-specific Members
-static const int kInstStageOutCnt = kInstCommonOutCnt + 3;
-
-// Validation Error Code Offset
-//
-// This identifies the validation error. It also helps to identify
-// how many words follow in the record and their meaning.
-static const int kInstValidationOutError = kInstStageOutCnt;
-
-// Validation-specific Output Record Offsets
-//
-// Each different validation will generate a potentially different
-// number of words at the end of the record giving more specifics
-// about the validation error.
-//
-// A bindless bounds error will output the index and the bound.
-static const int kInstBindlessBoundsOutDescSet = kInstStageOutCnt + 1;
-static const int kInstBindlessBoundsOutDescBinding = kInstStageOutCnt + 2;
-static const int kInstBindlessBoundsOutDescIndex = kInstStageOutCnt + 3;
-static const int kInstBindlessBoundsOutDescBound = kInstStageOutCnt + 4;
-static const int kInstBindlessBoundsOutUnused = kInstStageOutCnt + 5;
-static const int kInstBindlessBoundsOutCnt = kInstStageOutCnt + 6;
-
-// A descriptor uninitialized error will output the index.
-static const int kInstBindlessUninitOutDescSet = kInstStageOutCnt + 1;
-static const int kInstBindlessUninitOutBinding = kInstStageOutCnt + 2;
-static const int kInstBindlessUninitOutDescIndex = kInstStageOutCnt + 3;
-static const int kInstBindlessUninitOutUnused = kInstStageOutCnt + 4;
-static const int kInstBindlessUninitOutUnused2 = kInstStageOutCnt + 5;
-static const int kInstBindlessUninitOutCnt = kInstStageOutCnt + 6;
-
-// A buffer out-of-bounds error will output the descriptor
-// index, the buffer offset and the buffer size
-static const int kInstBindlessBuffOOBOutDescSet = kInstStageOutCnt + 1;
-static const int kInstBindlessBuffOOBOutDescBinding = kInstStageOutCnt + 2;
-static const int kInstBindlessBuffOOBOutDescIndex = kInstStageOutCnt + 3;
-static const int kInstBindlessBuffOOBOutBuffOff = kInstStageOutCnt + 4;
-static const int kInstBindlessBuffOOBOutBuffSize = kInstStageOutCnt + 5;
-static const int kInstBindlessBuffOOBOutCnt = kInstStageOutCnt + 6;
-
-// A buffer address unalloc error will output the 64-bit pointer in
-// two 32-bit pieces, lower bits first.
-static const int kInstBuffAddrUnallocOutDescPtrLo = kInstStageOutCnt + 1;
-static const int kInstBuffAddrUnallocOutDescPtrHi = kInstStageOutCnt + 2;
-static const int kInstBuffAddrUnallocOutCnt = kInstStageOutCnt + 3;
-
-// Maximum Output Record Member Count
-static const int kInstMaxOutCnt = kInstStageOutCnt + 6;
-
-// Validation Error Codes
-//
-// These are the possible validation error codes.
-static const int kInstErrorBindlessBounds = 0;
-static const int kInstErrorBindlessUninit = 1;
-static const int kInstErrorBuffAddrUnallocRef = 2;
-// Deleted: static const int kInstErrorBindlessBuffOOB = 3;
-// This comment will will remain for 2 releases to allow
-// for the transition of all builds. Buffer OOB is
-// generating the following four differentiated codes instead:
-static const int kInstErrorBuffOOBUniform = 4;
-static const int kInstErrorBuffOOBStorage = 5;
-static const int kInstErrorBuffOOBUniformTexel = 6;
-static const int kInstErrorBuffOOBStorageTexel = 7;
-static const int kInstErrorMax = kInstErrorBuffOOBStorageTexel;
-
-// Direct Input Buffer Offsets
-//
-// The following values provide member offsets into the input buffers
-// consumed by InstrumentPass::GenDebugDirectRead(). This method is utilized
-// by InstBindlessCheckPass.
-//
-// The only object in an input buffer is a runtime array of unsigned
-// integers. Each validation will have its own formatting of this array.
-static const int kDebugInputDataOffset = 0;
-
// Debug Buffer Bindings
//
// These are the bindings for the different buffers which are
// read or written by the instrumentation passes.
//
-// This is the output buffer written by InstBindlessCheckPass,
-// InstBuffAddrCheckPass, and possibly other future validations.
-static const int kDebugOutputBindingStream = 0;
-
-// The binding for the input buffer read by InstBindlessCheckPass.
-static const int kDebugInputBindingBindless = 1;
-
-// The binding for the input buffer read by InstBuffAddrCheckPass.
-static const int kDebugInputBindingBuffAddr = 2;
-
// This is the output buffer written by InstDebugPrintfPass.
static const int kDebugOutputPrintfStream = 3;
-// clang-format off
-// Bindless Validation Input Buffer Format
-//
-// An input buffer for bindless validation has this structure:
-// GLSL:
-// layout(buffer_reference, std430, buffer_reference_align = 8) buffer DescriptorSetData {
-// uint num_bindings;
-// uint data[];
-// };
-//
-// layout(set = 7, binding = 1, std430) buffer inst_bindless_InputBuffer
-// {
-// DescriptorSetData desc_sets[32];
-// } inst_bindless_input_buffer;
-//
-//
-// To look up the length of a binding:
-// uint length = inst_bindless_input_buffer[set].data[binding];
-// Scalar bindings have a length of 1.
-//
-// To look up the initialization state of a descriptor in a binding:
-// uint num_bindings = inst_bindless_input_buffer[set].num_bindings;
-// uint binding_state_start = inst_bindless_input_buffer[set].data[num_bindings + binding];
-// uint init_state = inst_bindless_input_buffer[set].data[binding_state_start + index];
-//
-// For scalar bindings, use 0 for the index.
-// clang-format on
-//
-// The size of the inst_bindless_input_buffer array, regardless of how many
-// descriptor sets the device supports.
-static const int kDebugInputBindlessMaxDescSets = 32;
-
-// Buffer Device Address Input Buffer Format
-//
-// An input buffer for buffer device address validation consists of a single
-// array of unsigned 64-bit integers we will call Data[]. This array is
-// formatted as follows:
-//
-// At offset kDebugInputBuffAddrPtrOffset is a list of sorted valid buffer
-// addresses. The list is terminated with the address 0xffffffffffffffff.
-// If 0x0 is not a valid buffer address, this address is inserted at the
-// start of the list.
-//
-static const int kDebugInputBuffAddrPtrOffset = 1;
-//
-// At offset kDebugInputBuffAddrLengthOffset in Data[] is a single uint64 which
-// gives an offset to the start of the buffer length data. More
-// specifically, for a buffer whose pointer is located at input buffer offset
-// i, the length is located at:
-//
-// Data[ i - kDebugInputBuffAddrPtrOffset
-// + Data[ kDebugInputBuffAddrLengthOffset ] ]
-//
-// The length associated with the 0xffffffffffffffff address is zero. If
-// not a valid buffer, the length associated with the 0x0 address is zero.
-static const int kDebugInputBuffAddrLengthOffset = 0;
-
} // namespace spvtools
#endif // INCLUDE_SPIRV_TOOLS_INSTRUMENT_HPP_
diff --git a/third_party/SPIRV-Tools/include/spirv-tools/libspirv.h b/third_party/SPIRV-Tools/include/spirv-tools/libspirv.h
index 542b745..83b1a8e 100644
--- a/third_party/SPIRV-Tools/include/spirv-tools/libspirv.h
+++ b/third_party/SPIRV-Tools/include/spirv-tools/libspirv.h
@@ -33,15 +33,19 @@
#else
#define SPIRV_TOOLS_EXPORT __declspec(dllimport)
#endif
+#define SPIRV_TOOLS_LOCAL
#else
#if defined(SPIRV_TOOLS_IMPLEMENTATION)
#define SPIRV_TOOLS_EXPORT __attribute__((visibility("default")))
+#define SPIRV_TOOLS_LOCAL __attribute__((visibility("hidden")))
#else
#define SPIRV_TOOLS_EXPORT
+#define SPIRV_TOOLS_LOCAL
#endif
#endif
#else
#define SPIRV_TOOLS_EXPORT
+#define SPIRV_TOOLS_LOCAL
#endif
// Helpers
@@ -143,6 +147,7 @@
// may be larger than 32, which would require such a typed literal value to
// occupy multiple SPIR-V words.
SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
+ SPV_OPERAND_TYPE_LITERAL_FLOAT, // Always 32-bit float.
// Set 3: The literal string operand type.
SPV_OPERAND_TYPE_LITERAL_STRING,
@@ -285,6 +290,28 @@
// An optional packed vector format
SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT,
+ // Concrete operand types for cooperative matrix.
+ SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS,
+ // Optional cooperative matrix operands
+ SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS,
+ SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT,
+ SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE,
+
+ // Enum type from SPV_INTEL_global_variable_fpga_decorations
+ SPV_OPERAND_TYPE_INITIALIZATION_MODE_QUALIFIER,
+ // Enum type from SPV_INTEL_global_variable_host_access
+ SPV_OPERAND_TYPE_HOST_ACCESS_QUALIFIER,
+ // Enum type from SPV_INTEL_cache_controls
+ SPV_OPERAND_TYPE_LOAD_CACHE_CONTROL,
+ // Enum type from SPV_INTEL_cache_controls
+ SPV_OPERAND_TYPE_STORE_CACHE_CONTROL,
+ // Enum type from SPV_INTEL_maximum_registers
+ SPV_OPERAND_TYPE_NAMED_MAXIMUM_NUMBER_OF_REGISTERS,
+ // Enum type from SPV_NV_raw_access_chains
+ SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS,
+ // Optional enum type from SPV_NV_raw_access_chains
+ SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS,
+
// This is a sentinel value, and does not represent an operand type.
// It should come last.
SPV_OPERAND_TYPE_NUM_OPERAND_TYPES,
@@ -310,6 +337,7 @@
SPV_EXT_INST_TYPE_OPENCL_DEBUGINFO_100,
SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION,
SPV_EXT_INST_TYPE_NONSEMANTIC_SHADER_DEBUGINFO_100,
+ SPV_EXT_INST_TYPE_NONSEMANTIC_VKSPREFLECTION,
// Multiple distinct extended instruction set types could return this
// value, if they are prefixed with NonSemantic. and are otherwise
@@ -949,9 +977,16 @@
spv_optimizer_t* optimizer, const char* flag);
// Registers passes specified by length number of flags in an optimizer object.
+// Passes may remove interface variables that are unused.
SPIRV_TOOLS_EXPORT bool spvOptimizerRegisterPassesFromFlags(
spv_optimizer_t* optimizer, const char** flags, const size_t flag_count);
+// Registers passes specified by length number of flags in an optimizer object.
+// Passes will not remove interface variables.
+SPIRV_TOOLS_EXPORT bool
+spvOptimizerRegisterPassesFromFlagsWhilePreservingTheInterface(
+ spv_optimizer_t* optimizer, const char** flags, const size_t flag_count);
+
// Optimizes the SPIR-V code of size |word_count| pointed to by |binary| and
// returns an optimized spv_binary in |optimized_binary|.
//
diff --git a/third_party/SPIRV-Tools/include/spirv-tools/libspirv.hpp b/third_party/SPIRV-Tools/include/spirv-tools/libspirv.hpp
index ee6c846..59ff82b 100644
--- a/third_party/SPIRV-Tools/include/spirv-tools/libspirv.hpp
+++ b/third_party/SPIRV-Tools/include/spirv-tools/libspirv.hpp
@@ -37,7 +37,7 @@
std::function<spv_result_t(const spv_parsed_instruction_t& instruction)>;
// C++ RAII wrapper around the C context object spv_context.
-class Context {
+class SPIRV_TOOLS_EXPORT Context {
public:
// Constructs a context targeting the given environment |env|.
//
@@ -73,7 +73,7 @@
};
// A RAII wrapper around a validator options object.
-class ValidatorOptions {
+class SPIRV_TOOLS_EXPORT ValidatorOptions {
public:
ValidatorOptions() : options_(spvValidatorOptionsCreate()) {}
~ValidatorOptions() { spvValidatorOptionsDestroy(options_); }
@@ -163,7 +163,7 @@
};
// A C++ wrapper around an optimization options object.
-class OptimizerOptions {
+class SPIRV_TOOLS_EXPORT OptimizerOptions {
public:
OptimizerOptions() : options_(spvOptimizerOptionsCreate()) {}
~OptimizerOptions() { spvOptimizerOptionsDestroy(options_); }
@@ -205,7 +205,7 @@
};
// A C++ wrapper around a reducer options object.
-class ReducerOptions {
+class SPIRV_TOOLS_EXPORT ReducerOptions {
public:
ReducerOptions() : options_(spvReducerOptionsCreate()) {}
~ReducerOptions() { spvReducerOptionsDestroy(options_); }
@@ -236,7 +236,7 @@
};
// A C++ wrapper around a fuzzer options object.
-class FuzzerOptions {
+class SPIRV_TOOLS_EXPORT FuzzerOptions {
public:
FuzzerOptions() : options_(spvFuzzerOptionsCreate()) {}
~FuzzerOptions() { spvFuzzerOptionsDestroy(options_); }
@@ -283,7 +283,7 @@
// provides methods for assembling, disassembling, and validating.
//
// Instances of this class provide basic thread-safety guarantee.
-class SpirvTools {
+class SPIRV_TOOLS_EXPORT SpirvTools {
public:
enum {
// Default assembling option used by assemble():
@@ -388,7 +388,8 @@
bool IsValid() const;
private:
- struct Impl; // Opaque struct for holding the data fields used by this class.
+ struct SPIRV_TOOLS_LOCAL
+ Impl; // Opaque struct for holding the data fields used by this class.
std::unique_ptr<Impl> impl_; // Unique pointer to implementation data.
};
diff --git a/third_party/SPIRV-Tools/include/spirv-tools/linker.hpp b/third_party/SPIRV-Tools/include/spirv-tools/linker.hpp
index d2f3e72..6ba6e96 100644
--- a/third_party/SPIRV-Tools/include/spirv-tools/linker.hpp
+++ b/third_party/SPIRV-Tools/include/spirv-tools/linker.hpp
@@ -24,13 +24,8 @@
namespace spvtools {
-class LinkerOptions {
+class SPIRV_TOOLS_EXPORT LinkerOptions {
public:
- LinkerOptions()
- : create_library_(false),
- verify_ids_(false),
- allow_partial_linkage_(false) {}
-
// Returns whether a library or an executable should be produced by the
// linking phase.
//
@@ -63,10 +58,16 @@
allow_partial_linkage_ = allow_partial_linkage;
}
+ bool GetUseHighestVersion() const { return use_highest_version_; }
+ void SetUseHighestVersion(bool use_highest_vers) {
+ use_highest_version_ = use_highest_vers;
+ }
+
private:
- bool create_library_;
- bool verify_ids_;
- bool allow_partial_linkage_;
+ bool create_library_{false};
+ bool verify_ids_{false};
+ bool allow_partial_linkage_{false};
+ bool use_highest_version_{false};
};
// Links one or more SPIR-V modules into a new SPIR-V module. That is, combine
@@ -83,14 +84,15 @@
// * Some entry points were defined multiple times;
// * Some imported symbols did not have an exported counterpart;
// * Possibly other reasons.
-spv_result_t Link(const Context& context,
- const std::vector<std::vector<uint32_t>>& binaries,
- std::vector<uint32_t>* linked_binary,
- const LinkerOptions& options = LinkerOptions());
-spv_result_t Link(const Context& context, const uint32_t* const* binaries,
- const size_t* binary_sizes, size_t num_binaries,
- std::vector<uint32_t>* linked_binary,
- const LinkerOptions& options = LinkerOptions());
+SPIRV_TOOLS_EXPORT spv_result_t
+Link(const Context& context, const std::vector<std::vector<uint32_t>>& binaries,
+ std::vector<uint32_t>* linked_binary,
+ const LinkerOptions& options = LinkerOptions());
+SPIRV_TOOLS_EXPORT spv_result_t
+Link(const Context& context, const uint32_t* const* binaries,
+ const size_t* binary_sizes, size_t num_binaries,
+ std::vector<uint32_t>* linked_binary,
+ const LinkerOptions& options = LinkerOptions());
} // namespace spvtools
diff --git a/third_party/SPIRV-Tools/include/spirv-tools/linter.hpp b/third_party/SPIRV-Tools/include/spirv-tools/linter.hpp
index 52ed5a4..ccbcf0c 100644
--- a/third_party/SPIRV-Tools/include/spirv-tools/linter.hpp
+++ b/third_party/SPIRV-Tools/include/spirv-tools/linter.hpp
@@ -24,7 +24,7 @@
// provides a method for linting.
//
// Instances of this class provides basic thread-safety guarantee.
-class Linter {
+class SPIRV_TOOLS_EXPORT Linter {
public:
explicit Linter(spv_target_env env);
@@ -40,7 +40,7 @@
bool Run(const uint32_t* binary, size_t binary_size);
private:
- struct Impl;
+ struct SPIRV_TOOLS_LOCAL Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace spvtools
diff --git a/third_party/SPIRV-Tools/include/spirv-tools/optimizer.hpp b/third_party/SPIRV-Tools/include/spirv-tools/optimizer.hpp
index 8bdd4e8..a3119d9 100644
--- a/third_party/SPIRV-Tools/include/spirv-tools/optimizer.hpp
+++ b/third_party/SPIRV-Tools/include/spirv-tools/optimizer.hpp
@@ -37,14 +37,14 @@
// provides methods for registering optimization passes and optimizing.
//
// Instances of this class provides basic thread-safety guarantee.
-class Optimizer {
+class SPIRV_TOOLS_EXPORT Optimizer {
public:
// The token for an optimization pass. It is returned via one of the
// Create*Pass() standalone functions at the end of this header file and
// consumed by the RegisterPass() method. Tokens are one-time objects that
// only support move; copying is not allowed.
struct PassToken {
- struct Impl; // Opaque struct for holding internal data.
+ struct SPIRV_TOOLS_LOCAL Impl; // Opaque struct for holding internal data.
PassToken(std::unique_ptr<Impl>);
@@ -97,12 +97,20 @@
// Registers passes that attempt to improve performance of generated code.
// This sequence of passes is subject to constant review and will change
// from time to time.
+ //
+ // If |preserve_interface| is true, all non-io variables in the entry point
+ // interface are considered live and are not eliminated.
Optimizer& RegisterPerformancePasses();
+ Optimizer& RegisterPerformancePasses(bool preserve_interface);
// Registers passes that attempt to improve the size of generated code.
// This sequence of passes is subject to constant review and will change
// from time to time.
+ //
+ // If |preserve_interface| is true, all non-io variables in the entry point
+ // interface are considered live and are not eliminated.
Optimizer& RegisterSizePasses();
+ Optimizer& RegisterSizePasses(bool preserve_interface);
// Registers passes that attempt to legalize the generated code.
//
@@ -112,7 +120,11 @@
//
// This sequence of passes is subject to constant review and will change
// from time to time.
+ //
+ // If |preserve_interface| is true, all non-io variables in the entry point
+ // interface are considered live and are not eliminated.
Optimizer& RegisterLegalizationPasses();
+ Optimizer& RegisterLegalizationPasses(bool preserve_interface);
// Register passes specified in the list of |flags|. Each flag must be a
// string of a form accepted by Optimizer::FlagHasValidForm().
@@ -121,8 +133,13 @@
// error message is emitted to the MessageConsumer object (use
// Optimizer::SetMessageConsumer to define a message consumer, if needed).
//
+ // If |preserve_interface| is true, all non-io variables in the entry point
+ // interface are considered live and are not eliminated.
+ //
// If all the passes are registered successfully, it returns true.
bool RegisterPassesFromFlags(const std::vector<std::string>& flags);
+ bool RegisterPassesFromFlags(const std::vector<std::string>& flags,
+ bool preserve_interface);
// Registers the optimization pass associated with |flag|. This only accepts
// |flag| values of the form "--pass_name[=pass_args]". If no such pass
@@ -139,7 +156,11 @@
//
// --legalize-hlsl: Registers all passes that legalize SPIR-V generated by an
// HLSL front-end.
+ //
+ // If |preserve_interface| is true, all non-io variables in the entry point
+ // interface are considered live and are not eliminated.
bool RegisterPassFromFlag(const std::string& flag);
+ bool RegisterPassFromFlag(const std::string& flag, bool preserve_interface);
// Validates that |flag| has a valid format. Strings accepted:
//
@@ -218,7 +239,7 @@
Optimizer& SetValidateAfterAll(bool validate);
private:
- struct Impl; // Opaque struct for holding internal data.
+ struct SPIRV_TOOLS_LOCAL Impl; // Opaque struct for holding internal data.
std::unique_ptr<Impl> impl_; // Unique pointer to internal data.
};
@@ -748,19 +769,9 @@
// potentially de-optimizing the instrument code, for example, inlining
// the debug record output function throughout the module.
//
-// The instrumentation will read and write buffers in debug
-// descriptor set |desc_set|. It will write |shader_id| in each output record
+// The instrumentation will write |shader_id| in each output record
// to identify the shader module which generated the record.
-// |desc_length_enable| controls instrumentation of runtime descriptor array
-// references, |desc_init_enable| controls instrumentation of descriptor
-// initialization checking, and |buff_oob_enable| controls instrumentation
-// of storage and uniform buffer bounds checking, all of which require input
-// buffer support. |texbuff_oob_enable| controls instrumentation of texel
-// buffers, which does not require input buffer support.
-Optimizer::PassToken CreateInstBindlessCheckPass(
- uint32_t desc_set, uint32_t shader_id, bool desc_length_enable = false,
- bool desc_init_enable = false, bool buff_oob_enable = false,
- bool texbuff_oob_enable = false);
+Optimizer::PassToken CreateInstBindlessCheckPass(uint32_t shader_id);
// Create a pass to instrument physical buffer address checking
// This pass instruments all physical buffer address references to check that
@@ -781,8 +792,7 @@
// The instrumentation will read and write buffers in debug
// descriptor set |desc_set|. It will write |shader_id| in each output record
// to identify the shader module which generated the record.
-Optimizer::PassToken CreateInstBuffAddrCheckPass(uint32_t desc_set,
- uint32_t shader_id);
+Optimizer::PassToken CreateInstBuffAddrCheckPass(uint32_t shader_id);
// Create a pass to instrument OpDebugPrintf instructions.
// This pass replaces all OpDebugPrintf instructions with instructions to write
@@ -971,6 +981,32 @@
// object, currently the pass would remove accesschain pointer argument passed
// to the function
Optimizer::PassToken CreateFixFuncCallArgumentsPass();
+
+// Creates a trim-capabilities pass.
+// This pass removes unused capabilities for a given module, and if possible,
+// associated extensions.
+// See `trim_capabilities.h` for the list of supported capabilities.
+//
+// If the module contains unsupported capabilities, this pass will ignore them.
+// This should be fine in most cases, but could lead to incorrect results if
+// the unknown capability interacts with one of the trimmed capabilities.
+Optimizer::PassToken CreateTrimCapabilitiesPass();
+
+// Creates a switch-descriptorset pass.
+// This pass changes any DescriptorSet decorations with the value |ds_from| to
+// use the new value |ds_to|.
+Optimizer::PassToken CreateSwitchDescriptorSetPass(uint32_t ds_from,
+ uint32_t ds_to);
+
+// Creates an invocation interlock placement pass.
+// This pass ensures that an entry point will have at most one
+// OpBeginInterlockInvocationEXT and one OpEndInterlockInvocationEXT, in that
+// order.
+Optimizer::PassToken CreateInvocationInterlockPlacementPass();
+
+// Creates a pass to add/remove maximal reconvergence execution mode.
+// This pass either adds or removes maximal reconvergence from all entry points.
+Optimizer::PassToken CreateModifyMaximalReconvergencePass(bool add);
} // namespace spvtools
#endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_
diff --git a/third_party/SPIRV-Tools/kokoro/macos-clang-release-bazel/build.sh b/third_party/SPIRV-Tools/kokoro/macos-clang-release-bazel/build.sh
index 2465d9c..4bb889a 100644
--- a/third_party/SPIRV-Tools/kokoro/macos-clang-release-bazel/build.sh
+++ b/third_party/SPIRV-Tools/kokoro/macos-clang-release-bazel/build.sh
@@ -30,20 +30,16 @@
git config --global --add safe.directory $SRC
cd $SRC
-git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers
-git clone https://github.com/google/googletest external/googletest
-cd external && cd googletest && git reset --hard 1fb1bb23bb8418dc73a5a9a82bbed31dc610fec7 && cd .. && cd ..
-git clone --depth=1 https://github.com/google/effcee external/effcee
-git clone --depth=1 https://github.com/google/re2 external/re2
+/usr/bin/python3 utils/git-sync-deps --treeless
-# Get bazel 5.0.0
-gsutil cp gs://bazel/5.0.0/release/bazel-5.0.0-darwin-x86_64 .
-chmod +x bazel-5.0.0-darwin-x86_64
+# Get bazel 7.0.2
+gsutil cp gs://bazel/7.0.2/release/bazel-7.0.2-darwin-x86_64 .
+chmod +x bazel-7.0.2-darwin-x86_64
echo $(date): Build everything...
-./bazel-5.0.0-darwin-x86_64 build --cxxopt=-std=c++17 :all
+./bazel-7.0.2-darwin-x86_64 build --cxxopt=-std=c++17 :all
echo $(date): Build completed.
echo $(date): Starting bazel test...
-./bazel-5.0.0-darwin-x86_64 test --cxxopt=-std=c++17 :all
+./bazel-7.0.2-darwin-x86_64 test --cxxopt=-std=c++17 :all
echo $(date): Bazel test completed.
diff --git a/third_party/SPIRV-Tools/kokoro/scripts/linux/build-docker.sh b/third_party/SPIRV-Tools/kokoro/scripts/linux/build-docker.sh
index f2a06e0..e47037d 100755
--- a/third_party/SPIRV-Tools/kokoro/scripts/linux/build-docker.sh
+++ b/third_party/SPIRV-Tools/kokoro/scripts/linux/build-docker.sh
@@ -131,6 +131,7 @@
git clone https://github.com/KhronosGroup/SPIRV-Headers.git spirv-headers
git clone https://github.com/google/re2
git clone https://github.com/google/effcee
+ git clone https://github.com/abseil/abseil-cpp abseil_cpp
cd $SHADERC_DIR
mkdir build
@@ -141,7 +142,7 @@
cmake -GNinja -DRE2_BUILD_TESTING=OFF -DCMAKE_BUILD_TYPE="Release" ..
echo $(date): Build glslang...
- ninja glslangValidator
+ ninja glslang-standalone
echo $(date): Build everything...
ninja
@@ -155,7 +156,7 @@
echo $(date): ctest completed.
elif [ $TOOL = "cmake-android-ndk" ]; then
using cmake-3.17.2
- using ndk-r21d
+ using ndk-r25c
using ninja-1.10.0
clean_dir "$ROOT_DIR/build"
@@ -163,7 +164,7 @@
echo $(date): Starting build...
cmake -DCMAKE_BUILD_TYPE=Release \
- -DANDROID_NATIVE_API_LEVEL=android-16 \
+ -DANDROID_NATIVE_API_LEVEL=android-24 \
-DANDROID_ABI="armeabi-v7a with NEON" \
-DSPIRV_SKIP_TESTS=ON \
-DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK_HOME/build/cmake/android.toolchain.cmake" \
@@ -175,7 +176,7 @@
ninja
echo $(date): Build completed.
elif [ $TOOL = "android-ndk-build" ]; then
- using ndk-r21d
+ using ndk-r25c
clean_dir "$ROOT_DIR/build"
cd "$ROOT_DIR/build"
@@ -190,7 +191,7 @@
echo $(date): ndk-build completed.
elif [ $TOOL = "bazel" ]; then
- using bazel-5.0.0
+ using bazel-7.0.2
echo $(date): Build everything...
bazel build --cxxopt=-std=c++17 :all
diff --git a/third_party/SPIRV-Tools/kokoro/scripts/windows/build.bat b/third_party/SPIRV-Tools/kokoro/scripts/windows/build.bat
index bb14da3..fe15f2d 100644
--- a/third_party/SPIRV-Tools/kokoro/scripts/windows/build.bat
+++ b/third_party/SPIRV-Tools/kokoro/scripts/windows/build.bat
@@ -30,6 +30,9 @@
if %VS_VERSION% == 2017 (
call "C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsall.bat" x64
echo "Using VS 2017..."
+
+ :: RE2 does not support VS2017, so we must disable tests.
+ set BUILD_TESTS=NO
) else if %VS_VERSION% == 2019 (
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvarsall.bat" x64
echo "Using VS 2019..."
@@ -56,6 +59,10 @@
:: Build spirv-fuzz
set CMAKE_FLAGS=%CMAKE_FLAGS% -DSPIRV_BUILD_FUZZER=ON
+if "%BUILD_TESTS%" == "NO" (
+ set CMAKE_FLAGS=-DSPIRV_SKIP_TESTS=ON %CMAKE_FLAGS%
+)
+
cmake %CMAKE_FLAGS% ..
if %ERRORLEVEL% NEQ 0 exit /b %ERRORLEVEL%
@@ -71,10 +78,12 @@
:: ################################################
:: Run the tests
:: ################################################
-echo "Running Tests... %DATE% %TIME%"
-ctest -C %BUILD_TYPE% --output-on-failure --timeout 300
-if !ERRORLEVEL! NEQ 0 exit /b !ERRORLEVEL!
-echo "Tests Completed %DATE% %TIME%"
+if "%BUILD_TESTS%" NEQ "NO" (
+ echo "Running Tests... %DATE% %TIME%"
+ ctest -C %BUILD_TYPE% --output-on-failure --timeout 300
+ if !ERRORLEVEL! NEQ 0 exit /b !ERRORLEVEL!
+ echo "Tests Completed %DATE% %TIME%"
+)
:: ################################################
:: Install and package.
diff --git a/third_party/SPIRV-Tools/source/CMakeLists.txt b/third_party/SPIRV-Tools/source/CMakeLists.txt
index acfa0c1..d0454c6c 100644
--- a/third_party/SPIRV-Tools/source/CMakeLists.txt
+++ b/third_party/SPIRV-Tools/source/CMakeLists.txt
@@ -31,7 +31,7 @@
set(GRAMMAR_INSTS_INC_FILE "${spirv-tools_BINARY_DIR}/core.insts-${CONFIG_VERSION}.inc")
set(GRAMMAR_KINDS_INC_FILE "${spirv-tools_BINARY_DIR}/operand.kinds-${CONFIG_VERSION}.inc")
add_custom_command(OUTPUT ${GRAMMAR_INSTS_INC_FILE} ${GRAMMAR_KINDS_INC_FILE}
- COMMAND ${PYTHON_EXECUTABLE} ${GRAMMAR_PROCESSING_SCRIPT}
+ COMMAND Python3::Interpreter ${GRAMMAR_PROCESSING_SCRIPT}
--spirv-core-grammar=${GRAMMAR_JSON_FILE}
--extinst-debuginfo-grammar=${DEBUGINFO_GRAMMAR_JSON_FILE}
--extinst-cldebuginfo100-grammar=${CLDEBUGINFO100_GRAMMAR_JSON_FILE}
@@ -53,7 +53,7 @@
set(GRAMMAR_ENUM_STRING_MAPPING_INC_FILE "${spirv-tools_BINARY_DIR}/enum_string_mapping.inc")
add_custom_command(OUTPUT ${GRAMMAR_EXTENSION_ENUM_INC_FILE}
${GRAMMAR_ENUM_STRING_MAPPING_INC_FILE}
- COMMAND ${PYTHON_EXECUTABLE} ${GRAMMAR_PROCESSING_SCRIPT}
+ COMMAND Python3::Interpreter ${GRAMMAR_PROCESSING_SCRIPT}
--spirv-core-grammar=${GRAMMAR_JSON_FILE}
--extinst-debuginfo-grammar=${DEBUGINFO_GRAMMAR_JSON_FILE}
--extinst-cldebuginfo100-grammar=${CLDEBUGINFO100_GRAMMAR_JSON_FILE}
@@ -75,7 +75,7 @@
set(OPENCL_GRAMMAR_JSON_FILE "${SPIRV_HEADER_INCLUDE_DIR}/spirv/${CONFIG_VERSION}/extinst.opencl.std.100.grammar.json")
set(VIMSYNTAX_FILE "${spirv-tools_BINARY_DIR}/spvasm.vim")
add_custom_command(OUTPUT ${VIMSYNTAX_FILE}
- COMMAND ${PYTHON_EXECUTABLE} ${VIMSYNTAX_PROCESSING_SCRIPT}
+ COMMAND Python3::Interpreter ${VIMSYNTAX_PROCESSING_SCRIPT}
--spirv-core-grammar=${GRAMMAR_JSON_FILE}
--extinst-debuginfo-grammar=${DEBUGINFO_GRAMMAR_JSON_FILE}
--extinst-glsl-grammar=${GLSL_GRAMMAR_JSON_FILE}
@@ -91,7 +91,7 @@
set(GLSL_GRAMMAR_JSON_FILE "${SPIRV_HEADER_INCLUDE_DIR}/spirv/${CONFIG_VERSION}/extinst.glsl.std.450.grammar.json")
set(GRAMMAR_INC_FILE "${spirv-tools_BINARY_DIR}/glsl.std.450.insts.inc")
add_custom_command(OUTPUT ${GRAMMAR_INC_FILE}
- COMMAND ${PYTHON_EXECUTABLE} ${GRAMMAR_PROCESSING_SCRIPT}
+ COMMAND Python3::Interpreter ${GRAMMAR_PROCESSING_SCRIPT}
--extinst-glsl-grammar=${GLSL_GRAMMAR_JSON_FILE}
--glsl-insts-output=${GRAMMAR_INC_FILE}
--output-language=c++
@@ -105,7 +105,7 @@
set(OPENCL_GRAMMAR_JSON_FILE "${SPIRV_HEADER_INCLUDE_DIR}/spirv/${CONFIG_VERSION}/extinst.opencl.std.100.grammar.json")
set(GRAMMAR_INC_FILE "${spirv-tools_BINARY_DIR}/opencl.std.insts.inc")
add_custom_command(OUTPUT ${GRAMMAR_INC_FILE}
- COMMAND ${PYTHON_EXECUTABLE} ${GRAMMAR_PROCESSING_SCRIPT}
+ COMMAND Python3::Interpreter ${GRAMMAR_PROCESSING_SCRIPT}
--extinst-opencl-grammar=${OPENCL_GRAMMAR_JSON_FILE}
--opencl-insts-output=${GRAMMAR_INC_FILE}
DEPENDS ${GRAMMAR_PROCESSING_SCRIPT} ${CORE_GRAMMAR_JSON_FILE} ${OPENCL_GRAMMAR_JSON_FILE}
@@ -120,7 +120,7 @@
set(GRAMMAR_FILE "${spirv-tools_SOURCE_DIR}/source/extinst.${VENDOR_TABLE}.grammar.json")
endif()
add_custom_command(OUTPUT ${INSTS_FILE}
- COMMAND ${PYTHON_EXECUTABLE} ${GRAMMAR_PROCESSING_SCRIPT}
+ COMMAND Python3::Interpreter ${GRAMMAR_PROCESSING_SCRIPT}
--extinst-vendor-grammar=${GRAMMAR_FILE}
--vendor-insts-output=${INSTS_FILE}
--vendor-operand-kind-prefix=${OPERAND_KIND_PREFIX}
@@ -134,7 +134,7 @@
macro(spvtools_extinst_lang_headers NAME GRAMMAR_FILE)
set(OUT_H ${spirv-tools_BINARY_DIR}/${NAME}.h)
add_custom_command(OUTPUT ${OUT_H}
- COMMAND ${PYTHON_EXECUTABLE} ${LANG_HEADER_PROCESSING_SCRIPT}
+ COMMAND Python3::Interpreter ${LANG_HEADER_PROCESSING_SCRIPT}
--extinst-grammar=${GRAMMAR_FILE}
--extinst-output-path=${OUT_H}
DEPENDS ${LANG_HEADER_PROCESSING_SCRIPT} ${GRAMMAR_FILE}
@@ -156,6 +156,7 @@
spvtools_vendor_tables("opencl.debuginfo.100" "cldi100" "CLDEBUG100_")
spvtools_vendor_tables("nonsemantic.shader.debuginfo.100" "shdi100" "SHDEBUG100_")
spvtools_vendor_tables("nonsemantic.clspvreflection" "clspvreflection" "")
+spvtools_vendor_tables("nonsemantic.vkspreflection" "vkspreflection" "")
spvtools_extinst_lang_headers("DebugInfo" ${DEBUGINFO_GRAMMAR_JSON_FILE})
spvtools_extinst_lang_headers("OpenCLDebugInfo100" ${CLDEBUGINFO100_GRAMMAR_JSON_FILE})
spvtools_extinst_lang_headers("NonSemanticShaderDebugInfo100" ${VKDEBUGINFO100_GRAMMAR_JSON_FILE})
@@ -168,7 +169,7 @@
set(GENERATOR_INC_FILE ${spirv-tools_BINARY_DIR}/generators.inc)
set(SPIRV_XML_REGISTRY_FILE ${SPIRV_HEADER_INCLUDE_DIR}/spirv/spir-v.xml)
add_custom_command(OUTPUT ${GENERATOR_INC_FILE}
- COMMAND ${PYTHON_EXECUTABLE} ${XML_REGISTRY_PROCESSING_SCRIPT}
+ COMMAND Python3::Interpreter ${XML_REGISTRY_PROCESSING_SCRIPT}
--xml=${SPIRV_XML_REGISTRY_FILE}
--generator-output=${GENERATOR_INC_FILE}
DEPENDS ${XML_REGISTRY_PROCESSING_SCRIPT} ${SPIRV_XML_REGISTRY_FILE}
@@ -198,7 +199,7 @@
set(SPIRV_TOOLS_CHANGES_FILE
${spirv-tools_SOURCE_DIR}/CHANGES)
add_custom_command(OUTPUT ${SPIRV_TOOLS_BUILD_VERSION_INC}
- COMMAND ${PYTHON_EXECUTABLE}
+ COMMAND Python3::Interpreter
${SPIRV_TOOLS_BUILD_VERSION_INC_GENERATOR}
${SPIRV_TOOLS_CHANGES_FILE} ${SPIRV_TOOLS_BUILD_VERSION_INC}
DEPENDS ${SPIRV_TOOLS_BUILD_VERSION_INC_GENERATOR}
@@ -418,12 +419,6 @@
endif()
endif()
-if (ANDROID)
- foreach(target ${SPIRV_TOOLS_TARGETS})
- target_link_libraries(${target} PRIVATE android log)
- endforeach()
-endif()
-
if(ENABLE_SPIRV_TOOLS_INSTALL)
install(TARGETS ${SPIRV_TOOLS_TARGETS} EXPORT ${SPIRV_TOOLS}Targets)
export(EXPORT ${SPIRV_TOOLS}Targets FILE ${SPIRV_TOOLS}Target.cmake)
diff --git a/third_party/SPIRV-Tools/source/assembly_grammar.cpp b/third_party/SPIRV-Tools/source/assembly_grammar.cpp
index 6df823e..0092d01 100644
--- a/third_party/SPIRV-Tools/source/assembly_grammar.cpp
+++ b/third_party/SPIRV-Tools/source/assembly_grammar.cpp
@@ -21,6 +21,7 @@
#include "source/ext_inst.h"
#include "source/opcode.h"
#include "source/operand.h"
+#include "source/spirv_target_env.h"
#include "source/table.h"
namespace spvtools {
@@ -154,11 +155,12 @@
CASE(InBoundsAccessChain),
CASE(PtrAccessChain),
CASE(InBoundsPtrAccessChain),
- CASE(CooperativeMatrixLengthNV)
+ CASE(CooperativeMatrixLengthNV),
+ CASE(CooperativeMatrixLengthKHR)
};
// The 60 is determined by counting the opcodes listed in the spec.
-static_assert(60 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
+static_assert(61 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
"OpSpecConstantOp opcode table is incomplete");
#undef CASE
// clang-format on
@@ -175,15 +177,18 @@
CapabilitySet AssemblyGrammar::filterCapsAgainstTargetEnv(
const spv::Capability* cap_array, uint32_t count) const {
CapabilitySet cap_set;
+ const auto version = spvVersionForTargetEnv(target_env_);
for (uint32_t i = 0; i < count; ++i) {
- spv_operand_desc cap_desc = {};
+ spv_operand_desc entry = {};
if (SPV_SUCCESS == lookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
static_cast<uint32_t>(cap_array[i]),
- &cap_desc)) {
- // spvOperandTableValueLookup() filters capabilities internally
- // according to the current target environment by itself. So we
- // should be safe to add this capability if the lookup succeeds.
- cap_set.Add(cap_array[i]);
+ &entry)) {
+ // This token is visible in this environment if it's in an appropriate
+ // core version, or it is enabled by a capability or an extension.
+ if ((version >= entry->minVersion && version <= entry->lastVersion) ||
+ entry->numExtensions > 0u || entry->numCapabilities > 0u) {
+ cap_set.insert(cap_array[i]);
+ }
}
}
return cap_set;
diff --git a/third_party/SPIRV-Tools/source/binary.cpp b/third_party/SPIRV-Tools/source/binary.cpp
index beb56be..cf1f0b7 100644
--- a/third_party/SPIRV-Tools/source/binary.cpp
+++ b/third_party/SPIRV-Tools/source/binary.cpp
@@ -546,6 +546,13 @@
parsed_operand.number_bit_width = 32;
break;
+ case SPV_OPERAND_TYPE_LITERAL_FLOAT:
+ // These are regular single-word literal float operands.
+ parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_FLOAT;
+ parsed_operand.number_kind = SPV_NUMBER_FLOATING;
+ parsed_operand.number_bit_width = 32;
+ break;
+
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER:
parsed_operand.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
@@ -626,7 +633,6 @@
} break;
case SPV_OPERAND_TYPE_CAPABILITY:
- case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
case SPV_OPERAND_TYPE_MEMORY_MODEL:
@@ -664,7 +670,8 @@
case SPV_OPERAND_TYPE_QUANTIZATION_MODES:
case SPV_OPERAND_TYPE_OVERFLOW_MODES:
case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
- case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT: {
+ case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
+ case SPV_OPERAND_TYPE_NAMED_MAXIMUM_NUMBER_OF_REGISTERS: {
// A single word that is a plain enum value.
// Map an optional operand type to its corresponding concrete type.
@@ -683,22 +690,44 @@
spvPushOperandTypes(entry->operandTypes, expected_operands);
} break;
+ case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: {
+ spv_operand_desc entry;
+ if (grammar_.lookupOperand(type, word, &entry)) {
+ return diagnostic()
+ << "Invalid " << spvOperandTypeStr(parsed_operand.type)
+ << " operand: " << word
+ << ", if you are creating a new source language please use "
+ "value 0 "
+ "(Unknown) and when ready, add your source language to "
+ "SPRIV-Headers";
+ }
+ // Prepare to accept operands to this operand, if needed.
+ spvPushOperandTypes(entry->operandTypes, expected_operands);
+ } break;
+
case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
case SPV_OPERAND_TYPE_LOOP_CONTROL:
case SPV_OPERAND_TYPE_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
+ case SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS:
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
- case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: {
+ case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
+ case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
// This operand is a mask.
// Map an optional operand type to its corresponding concrete type.
if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
parsed_operand.type = SPV_OPERAND_TYPE_IMAGE;
- else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
+ if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
+ if (type == SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS)
+ parsed_operand.type = SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS;
+ if (type == SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS)
+ parsed_operand.type = SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS;
// Check validity of set mask bits. Also prepare for operands for those
// masks if they have any. To get operand order correct, scan from
diff --git a/third_party/SPIRV-Tools/source/diff/diff.cpp b/third_party/SPIRV-Tools/source/diff/diff.cpp
index 6daed32..6269af5 100644
--- a/third_party/SPIRV-Tools/source/diff/diff.cpp
+++ b/third_party/SPIRV-Tools/source/diff/diff.cpp
@@ -101,9 +101,12 @@
return from < id_map_.size() && id_map_[from] != 0;
}
- // Map any ids in src and dst that have not been mapped to new ids in dst and
- // src respectively.
- void MapUnmatchedIds(IdMap& other_way);
+ bool IsMapped(const opt::Instruction* from_inst) const {
+ assert(from_inst != nullptr);
+ assert(!from_inst->HasResultId());
+
+ return inst_map_.find(from_inst) != inst_map_.end();
+ }
// Some instructions don't have result ids. Those are mapped by pointer.
void MapInsts(const opt::Instruction* from_inst,
@@ -117,6 +120,12 @@
uint32_t IdBound() const { return static_cast<uint32_t>(id_map_.size()); }
+ // Generate a fresh id in this mapping's domain.
+ uint32_t MakeFreshId() {
+ id_map_.push_back(0);
+ return static_cast<uint32_t>(id_map_.size()) - 1;
+ }
+
private:
// Given an id, returns the corresponding id in the other module, or 0 if not
// matched yet.
@@ -150,10 +159,16 @@
bool IsSrcMapped(uint32_t src) { return src_to_dst_.IsMapped(src); }
bool IsDstMapped(uint32_t dst) { return dst_to_src_.IsMapped(dst); }
+ bool IsDstMapped(const opt::Instruction* dst_inst) {
+ return dst_to_src_.IsMapped(dst_inst);
+ }
// Map any ids in src and dst that have not been mapped to new ids in dst and
- // src respectively.
- void MapUnmatchedIds();
+ // src respectively. Use src_insn_defined and dst_insn_defined to ignore ids
+ // that are simply never defined. (Since we assume the inputs are valid
+ // SPIR-V, this implies they are also never used.)
+ void MapUnmatchedIds(std::function<bool(uint32_t)> src_insn_defined,
+ std::function<bool(uint32_t)> dst_insn_defined);
// Some instructions don't have result ids. Those are mapped by pointer.
void MapInsts(const opt::Instruction* src_inst,
@@ -203,6 +218,11 @@
void MapIdToInstruction(uint32_t id, const opt::Instruction* inst);
+ // Return true if id is mapped to any instruction, false otherwise.
+ bool IsDefined(uint32_t id) {
+ return id < inst_map_.size() && inst_map_[id] != nullptr;
+ }
+
void MapIdsToInstruction(
opt::IteratorRange<opt::Module::const_inst_iterator> section);
void MapIdsToInfos(
@@ -338,6 +358,59 @@
std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
match_group);
+ // Bucket `src_ids` and `dst_ids` by the key ids returned by `get_group`, and
+ // then call `match_group` on pairs of buckets whose key ids are matched with
+ // each other.
+ //
+ // For example, suppose we want to pair up groups of instructions with the
+ // same type. Naturally, the source instructions refer to their types by their
+ // ids in the source, and the destination instructions use destination type
+ // ids, so simply comparing source and destination type ids as integers, as
+ // `GroupIdsAndMatch` would do, is meaningless. But if a prior call to
+ // `MatchTypeIds` has established type matches between the two modules, then
+ // we can consult those to pair source and destination buckets whose types are
+ // equivalent.
+ //
+ // Suppose our input groups are as follows:
+ //
+ // - src_ids: { 1 -> 100, 2 -> 300, 3 -> 100, 4 -> 200 }
+ // - dst_ids: { 5 -> 10, 6 -> 20, 7 -> 10, 8 -> 300 }
+ //
+ // Here, `X -> Y` means that the instruction with SPIR-V id `X` is a member of
+ // the group, and `Y` is the id of its type. If we use
+ // `Differ::GroupIdsHelperGetTypeId` for `get_group`, then
+ // `get_group(X) == Y`.
+ //
+ // These instructions are bucketed by type as follows:
+ //
+ // - source: [1, 3] -> 100
+ // [4] -> 200
+ // [2] -> 300
+ //
+ // - destination: [5, 7] -> 10
+ // [6] -> 20
+ // [8] -> 300
+ //
+ // Now suppose that we have previously matched up src type 100 with dst type
+ // 10, and src type 200 with dst type 20, but no other types are matched.
+ //
+ // Then `match_group` is called twice:
+ // - Once with ([1,3], [5, 7]), corresponding to 100/10
+ // - Once with ([4],[6]), corresponding to 200/20
+ //
+ // The source type 300 isn't matched with anything, so the fact that there's a
+ // destination type 300 is irrelevant, and thus 2 and 8 are never passed to
+ // `match_group`.
+ //
+ // This function isn't specific to types; it simply buckets by the ids
+ // returned from `get_group`, and consults existing matches to pair up the
+ // resulting buckets.
+ void GroupIdsAndMatchByMappedId(
+ const IdGroup& src_ids, const IdGroup& dst_ids,
+ uint32_t (Differ::*get_group)(const IdInstructions&, uint32_t),
+ std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
+ match_group);
+
// Helper functions that determine if two instructions match
bool DoIdsMatch(uint32_t src_id, uint32_t dst_id);
bool DoesOperandMatch(const opt::Operand& src_operand,
@@ -504,36 +577,27 @@
FunctionMap dst_funcs_;
};
-void IdMap::MapUnmatchedIds(IdMap& other_way) {
- const uint32_t src_id_bound = static_cast<uint32_t>(id_map_.size());
- const uint32_t dst_id_bound = static_cast<uint32_t>(other_way.id_map_.size());
-
- uint32_t next_src_id = src_id_bound;
- uint32_t next_dst_id = dst_id_bound;
+void SrcDstIdMap::MapUnmatchedIds(
+ std::function<bool(uint32_t)> src_insn_defined,
+ std::function<bool(uint32_t)> dst_insn_defined) {
+ const uint32_t src_id_bound = static_cast<uint32_t>(src_to_dst_.IdBound());
+ const uint32_t dst_id_bound = static_cast<uint32_t>(dst_to_src_.IdBound());
for (uint32_t src_id = 1; src_id < src_id_bound; ++src_id) {
- if (!IsMapped(src_id)) {
- MapIds(src_id, next_dst_id);
-
- other_way.id_map_.push_back(0);
- other_way.MapIds(next_dst_id++, src_id);
+ if (!src_to_dst_.IsMapped(src_id) && src_insn_defined(src_id)) {
+ uint32_t fresh_dst_id = dst_to_src_.MakeFreshId();
+ MapIds(src_id, fresh_dst_id);
}
}
for (uint32_t dst_id = 1; dst_id < dst_id_bound; ++dst_id) {
- if (!other_way.IsMapped(dst_id)) {
- id_map_.push_back(0);
- MapIds(next_src_id, dst_id);
-
- other_way.MapIds(dst_id, next_src_id++);
+ if (!dst_to_src_.IsMapped(dst_id) && dst_insn_defined(dst_id)) {
+ uint32_t fresh_src_id = src_to_dst_.MakeFreshId();
+ MapIds(fresh_src_id, dst_id);
}
}
}
-void SrcDstIdMap::MapUnmatchedIds() {
- src_to_dst_.MapUnmatchedIds(dst_to_src_);
-}
-
void IdInstructions::MapIdToInstruction(uint32_t id,
const opt::Instruction* inst) {
assert(id != 0);
@@ -889,6 +953,37 @@
}
}
+void Differ::GroupIdsAndMatchByMappedId(
+ const IdGroup& src_ids, const IdGroup& dst_ids,
+ uint32_t (Differ::*get_group)(const IdInstructions&, uint32_t),
+ std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
+ match_group) {
+ // Group the ids based on a key (get_group)
+ std::map<uint32_t, IdGroup> src_groups;
+ std::map<uint32_t, IdGroup> dst_groups;
+
+ GroupIds<uint32_t>(src_ids, true, &src_groups, get_group);
+ GroupIds<uint32_t>(dst_ids, false, &dst_groups, get_group);
+
+ // Iterate over pairs of groups whose keys map to each other.
+ for (const auto& iter : src_groups) {
+ const uint32_t& src_key = iter.first;
+ const IdGroup& src_group = iter.second;
+
+ if (src_key == 0) {
+ continue;
+ }
+
+ if (id_map_.IsSrcMapped(src_key)) {
+ const uint32_t& dst_key = id_map_.MappedDstId(src_key);
+ const IdGroup& dst_group = dst_groups[dst_key];
+
+ // Let the caller match the groups as appropriate.
+ match_group(src_group, dst_group);
+ }
+ }
+}
+
bool Differ::DoIdsMatch(uint32_t src_id, uint32_t dst_id) {
assert(dst_id != 0);
return id_map_.MappedDstId(src_id) == dst_id;
@@ -1419,7 +1514,6 @@
GroupIdsAndMatch<std::string>(
src, dst, "", &Differ::GetSanitizedName,
[this](const IdGroup& src_group, const IdGroup& dst_group) {
-
// Match only if there's a unique forward declaration with this debug
// name.
if (src_group.size() == 1 && dst_group.size() == 1) {
@@ -1574,6 +1668,8 @@
id_map_.MapIds(match_result.src_id, match_result.dst_id);
+ MatchFunctionParamIds(src_funcs_[match_result.src_id],
+ dst_funcs_[match_result.dst_id]);
MatchIdsInFunctionBodies(src_func_insts.at(match_result.src_id),
dst_func_insts.at(match_result.dst_id),
match_result.src_match, match_result.dst_match, 0);
@@ -1598,7 +1694,6 @@
GroupIdsAndMatch<std::string>(
src_params, dst_params, "", &Differ::GetSanitizedName,
[this](const IdGroup& src_group, const IdGroup& dst_group) {
-
// There shouldn't be two parameters with the same name, so the ids
// should match. There is nothing restricting the SPIR-V however to have
// two parameters with the same name, so be resilient against that.
@@ -1609,17 +1704,17 @@
// Then match the parameters by their type. If there are multiple of them,
// match them by their order.
- GroupIdsAndMatch<uint32_t>(
- src_params, dst_params, 0, &Differ::GroupIdsHelperGetTypeId,
+ GroupIdsAndMatchByMappedId(
+ src_params, dst_params, &Differ::GroupIdsHelperGetTypeId,
[this](const IdGroup& src_group_by_type_id,
const IdGroup& dst_group_by_type_id) {
-
const size_t shared_param_count =
std::min(src_group_by_type_id.size(), dst_group_by_type_id.size());
for (size_t param_index = 0; param_index < shared_param_count;
++param_index) {
- id_map_.MapIds(src_group_by_type_id[0], dst_group_by_type_id[0]);
+ id_map_.MapIds(src_group_by_type_id[param_index],
+ dst_group_by_type_id[param_index]);
}
});
}
@@ -1943,6 +2038,10 @@
// Always unsigned integers.
*number_bit_width = 32;
return SPV_NUMBER_UNSIGNED_INT;
+ case SPV_OPERAND_TYPE_LITERAL_FLOAT:
+ // Always float.
+ *number_bit_width = 32;
+ return SPV_NUMBER_FLOATING;
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER:
switch (inst.opcode()) {
@@ -2064,9 +2163,10 @@
}
// Otherwise match them by name.
- bool matched = false;
for (const opt::Instruction* src_inst : src_insts) {
for (const opt::Instruction* dst_inst : dst_insts) {
+ if (id_map_.IsDstMapped(dst_inst)) continue;
+
const opt::Operand& src_name = src_inst->GetOperand(2);
const opt::Operand& dst_name = dst_inst->GetOperand(2);
@@ -2075,13 +2175,9 @@
uint32_t dst_id = dst_inst->GetSingleWordOperand(1);
id_map_.MapIds(src_id, dst_id);
id_map_.MapInsts(src_inst, dst_inst);
- matched = true;
break;
}
}
- if (matched) {
- break;
- }
}
}
}
@@ -2126,7 +2222,6 @@
spv::StorageClass::Max, &Differ::GroupIdsHelperGetTypePointerStorageClass,
[this](const IdGroup& src_group_by_storage_class,
const IdGroup& dst_group_by_storage_class) {
-
// Group them further by the type they are pointing to and loop over
// them.
GroupIdsAndMatch<spv::Op>(
@@ -2134,7 +2229,6 @@
spv::Op::Max, &Differ::GroupIdsHelperGetTypePointerTypeOp,
[this](const IdGroup& src_group_by_type_op,
const IdGroup& dst_group_by_type_op) {
-
// Group them even further by debug info, if possible and match by
// debug name.
MatchTypeForwardPointersByName(src_group_by_type_op,
@@ -2199,7 +2293,9 @@
case spv::Op::OpTypeVoid:
case spv::Op::OpTypeBool:
case spv::Op::OpTypeSampler:
- // void, bool and sampler are unique, match them.
+ case spv::Op::OpTypeAccelerationStructureNV:
+ case spv::Op::OpTypeRayQueryKHR:
+ // the above types have no operands and are unique, match them.
return true;
case spv::Op::OpTypeInt:
case spv::Op::OpTypeFloat:
@@ -2378,7 +2474,6 @@
GroupIdsAndMatch<std::string>(
src_func_ids, dst_func_ids, "", &Differ::GetSanitizedName,
[this](const IdGroup& src_group, const IdGroup& dst_group) {
-
// If there is a single function with this name in src and dst, it's a
// definite match.
if (src_group.size() == 1 && dst_group.size() == 1) {
@@ -2392,7 +2487,6 @@
&Differ::GroupIdsHelperGetTypeId,
[this](const IdGroup& src_group_by_type_id,
const IdGroup& dst_group_by_type_id) {
-
if (src_group_by_type_id.size() == 1 &&
dst_group_by_type_id.size() == 1) {
id_map_.MapIds(src_group_by_type_id[0],
@@ -2437,7 +2531,6 @@
src_func_ids, dst_func_ids, 0, &Differ::GroupIdsHelperGetTypeId,
[this](const IdGroup& src_group_by_type_id,
const IdGroup& dst_group_by_type_id) {
-
BestEffortMatchFunctions(src_group_by_type_id, dst_group_by_type_id,
src_func_insts_, dst_func_insts_);
});
@@ -2647,7 +2740,9 @@
}
spv_result_t Differ::Output() {
- id_map_.MapUnmatchedIds();
+ id_map_.MapUnmatchedIds(
+ [this](uint32_t src_id) { return src_id_to_.IsDefined(src_id); },
+ [this](uint32_t dst_id) { return dst_id_to_.IsDefined(dst_id); });
src_id_to_.inst_map_.resize(id_map_.SrcToDstMap().IdBound(), nullptr);
dst_id_to_.inst_map_.resize(id_map_.DstToSrcMap().IdBound(), nullptr);
diff --git a/third_party/SPIRV-Tools/source/disassemble.cpp b/third_party/SPIRV-Tools/source/disassemble.cpp
index f862efd..f8f6f44 100644
--- a/third_party/SPIRV-Tools/source/disassemble.cpp
+++ b/third_party/SPIRV-Tools/source/disassemble.cpp
@@ -357,7 +357,8 @@
stream_ << opcode_desc->name;
} break;
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
- case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: {
+ case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
+ case SPV_OPERAND_TYPE_LITERAL_FLOAT: {
SetRed();
EmitNumericLiteral(&stream_, inst, operand);
ResetColor();
@@ -424,6 +425,7 @@
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
+ case SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS:
EmitMaskOperand(operand.type, word);
break;
default:
diff --git a/third_party/SPIRV-Tools/source/enum_set.h b/third_party/SPIRV-Tools/source/enum_set.h
index 28ee5fe..a375138 100644
--- a/third_party/SPIRV-Tools/source/enum_set.h
+++ b/third_party/SPIRV-Tools/source/enum_set.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2016 Google Inc.
+// Copyright (c) 2023 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,195 +12,456 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <functional>
+#include <initializer_list>
+#include <limits>
+#include <type_traits>
+#include <vector>
+
#ifndef SOURCE_ENUM_SET_H_
#define SOURCE_ENUM_SET_H_
-#include <cstdint>
-#include <functional>
-#include <memory>
-#include <set>
-#include <utility>
-
#include "source/latest_version_spirv_header.h"
-#include "source/util/make_unique.h"
namespace spvtools {
-// A set of values of a 32-bit enum type.
-// It is fast and compact for the common case, where enum values
-// are at most 63. But it can represent enums with larger values,
-// as may appear in extensions.
-template <typename EnumType>
+// This container is optimized to store and retrieve unsigned enum values.
+// The base model for this implementation is an open-addressing hashtable with
+// linear probing. For small enums (max index < 64), all operations are O(1).
+//
+// - Enums are stored in buckets (64 contiguous values max per bucket)
+// - Buckets ranges don't overlap, but don't have to be contiguous.
+// - Enums are packed into 64-bits buckets, using 1 bit per enum value.
+//
+// Example:
+// - MyEnum { A = 0, B = 1, C = 64, D = 65 }
+// - 2 buckets are required:
+// - bucket 0, storing values in the range [ 0; 64[
+// - bucket 1, storing values in the range [64; 128[
+//
+// - Buckets are stored in a sorted vector (sorted by bucket range).
+// - Retrieval is done by computing the theoretical bucket index using the enum
+// value, and
+// doing a linear scan from this position.
+// - Insertion is done by retrieving the bucket and either:
+// - inserting a new bucket in the sorted vector when no buckets has a
+// compatible range.
+// - setting the corresponding bit in the bucket.
+// This means insertion in the middle/beginning can cause a memmove when no
+// bucket is available. In our case, this happens at most 23 times for the
+// largest enum we have (Opcodes).
+template <typename T>
class EnumSet {
private:
- // The ForEach method will call the functor on enum values in
- // enum value order (lowest to highest). To make that easier, use
- // an ordered set for the overflow values.
- using OverflowSetType = std::set<uint32_t>;
+ using BucketType = uint64_t;
+ using ElementType = std::underlying_type_t<T>;
+ static_assert(std::is_enum_v<T>, "EnumSets only works with enums.");
+ static_assert(std::is_signed_v<ElementType> == false,
+ "EnumSet doesn't supports signed enums.");
+
+ // Each bucket can hold up to `kBucketSize` distinct, contiguous enum values.
+ // The first value a bucket can hold must be aligned on `kBucketSize`.
+ struct Bucket {
+ // bit mask to store `kBucketSize` enums.
+ BucketType data;
+ // 1st enum this bucket can represent.
+ T start;
+
+ friend bool operator==(const Bucket& lhs, const Bucket& rhs) {
+ return lhs.start == rhs.start && lhs.data == rhs.data;
+ }
+ };
+
+ // How many distinct values can a bucket hold? 1 bit per value.
+ static constexpr size_t kBucketSize = sizeof(BucketType) * 8ULL;
public:
- // Construct an empty set.
- EnumSet() {}
- // Construct an set with just the given enum value.
- explicit EnumSet(EnumType c) { Add(c); }
- // Construct an set from an initializer list of enum values.
- EnumSet(std::initializer_list<EnumType> cs) {
- for (auto c : cs) Add(c);
- }
- EnumSet(uint32_t count, const EnumType* ptr) {
- for (uint32_t i = 0; i < count; ++i) Add(ptr[i]);
- }
- // Copy constructor.
- EnumSet(const EnumSet& other) { *this = other; }
- // Move constructor. The moved-from set is emptied.
- EnumSet(EnumSet&& other) {
- mask_ = other.mask_;
- overflow_ = std::move(other.overflow_);
- other.mask_ = 0;
- other.overflow_.reset(nullptr);
- }
- // Assignment operator.
- EnumSet& operator=(const EnumSet& other) {
- if (&other != this) {
- mask_ = other.mask_;
- overflow_.reset(other.overflow_ ? new OverflowSetType(*other.overflow_)
- : nullptr);
+ class Iterator {
+ public:
+ typedef Iterator self_type;
+ typedef T value_type;
+ typedef T& reference;
+ typedef T* pointer;
+ typedef std::forward_iterator_tag iterator_category;
+ typedef size_t difference_type;
+
+ Iterator(const Iterator& other)
+ : set_(other.set_),
+ bucketIndex_(other.bucketIndex_),
+ bucketOffset_(other.bucketOffset_) {}
+
+ Iterator& operator++() {
+ do {
+ if (bucketIndex_ >= set_->buckets_.size()) {
+ bucketIndex_ = set_->buckets_.size();
+ bucketOffset_ = 0;
+ break;
+ }
+
+ if (bucketOffset_ + 1 == kBucketSize) {
+ bucketOffset_ = 0;
+ ++bucketIndex_;
+ } else {
+ ++bucketOffset_;
+ }
+
+ } while (bucketIndex_ < set_->buckets_.size() &&
+ !set_->HasEnumAt(bucketIndex_, bucketOffset_));
+ return *this;
}
+
+ Iterator operator++(int) {
+ Iterator old = *this;
+ operator++();
+ return old;
+ }
+
+ T operator*() const {
+ assert(set_->HasEnumAt(bucketIndex_, bucketOffset_) &&
+ "operator*() called on an invalid iterator.");
+ return GetValueFromBucket(set_->buckets_[bucketIndex_], bucketOffset_);
+ }
+
+ bool operator!=(const Iterator& other) const {
+ return set_ != other.set_ || bucketOffset_ != other.bucketOffset_ ||
+ bucketIndex_ != other.bucketIndex_;
+ }
+
+ bool operator==(const Iterator& other) const {
+ return !(operator!=(other));
+ }
+
+ Iterator& operator=(const Iterator& other) {
+ set_ = other.set_;
+ bucketIndex_ = other.bucketIndex_;
+ bucketOffset_ = other.bucketOffset_;
+ return *this;
+ }
+
+ private:
+ Iterator(const EnumSet* set, size_t bucketIndex, ElementType bucketOffset)
+ : set_(set), bucketIndex_(bucketIndex), bucketOffset_(bucketOffset) {}
+
+ private:
+ const EnumSet* set_ = nullptr;
+ // Index of the bucket in the vector.
+ size_t bucketIndex_ = 0;
+ // Offset in bits in the current bucket.
+ ElementType bucketOffset_ = 0;
+
+ friend class EnumSet;
+ };
+
+ // Required to allow the use of std::inserter.
+ using value_type = T;
+ using const_iterator = Iterator;
+ using iterator = Iterator;
+
+ public:
+ iterator cbegin() const noexcept {
+ auto it = iterator(this, /* bucketIndex= */ 0, /* bucketOffset= */ 0);
+ if (buckets_.size() == 0) {
+ return it;
+ }
+
+ // The iterator has the logic to find the next valid bit. If the value 0
+ // is not stored, use it to find the next valid bit.
+ if (!HasEnumAt(it.bucketIndex_, it.bucketOffset_)) {
+ ++it;
+ }
+
+ return it;
+ }
+
+ iterator begin() const noexcept { return cbegin(); }
+
+ iterator cend() const noexcept {
+ return iterator(this, buckets_.size(), /* bucketOffset= */ 0);
+ }
+
+ iterator end() const noexcept { return cend(); }
+
+ // Creates an empty set.
+ EnumSet() : buckets_(0), size_(0) {}
+
+ // Creates a set and store `value` in it.
+ EnumSet(T value) : EnumSet() { insert(value); }
+
+ // Creates a set and stores each `values` in it.
+ EnumSet(std::initializer_list<T> values) : EnumSet() {
+ for (auto item : values) {
+ insert(item);
+ }
+ }
+
+ // Creates a set, and insert `count` enum values pointed by `array` in it.
+ EnumSet(ElementType count, const T* array) : EnumSet() {
+ for (ElementType i = 0; i < count; i++) {
+ insert(array[i]);
+ }
+ }
+
+ // Creates a set initialized with the content of the range [begin; end[.
+ template <class InputIt>
+ EnumSet(InputIt begin, InputIt end) : EnumSet() {
+ for (; begin != end; ++begin) {
+ insert(*begin);
+ }
+ }
+
+ // Copies the EnumSet `other` into a new EnumSet.
+ EnumSet(const EnumSet& other)
+ : buckets_(other.buckets_), size_(other.size_) {}
+
+ // Moves the EnumSet `other` into a new EnumSet.
+ EnumSet(EnumSet&& other)
+ : buckets_(std::move(other.buckets_)), size_(other.size_) {}
+
+ // Deep-copies the EnumSet `other` into this EnumSet.
+ EnumSet& operator=(const EnumSet& other) {
+ buckets_ = other.buckets_;
+ size_ = other.size_;
return *this;
}
- friend bool operator==(const EnumSet& a, const EnumSet& b) {
- if (a.mask_ != b.mask_) {
- return false;
+ // Matches std::unordered_set::insert behavior.
+ std::pair<iterator, bool> insert(const T& value) {
+ const size_t index = FindBucketForValue(value);
+ const ElementType offset = ComputeBucketOffset(value);
+
+ if (index >= buckets_.size() ||
+ buckets_[index].start != ComputeBucketStart(value)) {
+ size_ += 1;
+ InsertBucketFor(index, value);
+ return std::make_pair(Iterator(this, index, offset), true);
}
- if (a.overflow_ == nullptr && b.overflow_ == nullptr) {
+ auto& bucket = buckets_[index];
+ const auto mask = ComputeMaskForValue(value);
+ if (bucket.data & mask) {
+ return std::make_pair(Iterator(this, index, offset), false);
+ }
+
+ size_ += 1;
+ bucket.data |= ComputeMaskForValue(value);
+ return std::make_pair(Iterator(this, index, offset), true);
+ }
+
+ // Inserts `value` in the set if possible.
+ // Similar to `std::unordered_set::insert`, except the hint is ignored.
+ // Returns an iterator to the inserted element, or the element preventing
+ // insertion.
+ iterator insert(const_iterator, const T& value) {
+ return insert(value).first;
+ }
+
+ // Inserts `value` in the set if possible.
+ // Similar to `std::unordered_set::insert`, except the hint is ignored.
+ // Returns an iterator to the inserted element, or the element preventing
+ // insertion.
+ iterator insert(const_iterator, T&& value) { return insert(value).first; }
+
+ // Inserts all the values in the range [`first`; `last`[.
+ // Similar to `std::unordered_set::insert`.
+ template <class InputIt>
+ void insert(InputIt first, InputIt last) {
+ for (auto it = first; it != last; ++it) {
+ insert(*it);
+ }
+ }
+
+ // Removes the value `value` from the set.
+ // Similar to `std::unordered_set::erase`.
+ // Returns the number of erased elements.
+ size_t erase(const T& value) {
+ const size_t index = FindBucketForValue(value);
+ if (index >= buckets_.size() ||
+ buckets_[index].start != ComputeBucketStart(value)) {
+ return 0;
+ }
+
+ auto& bucket = buckets_[index];
+ const auto mask = ComputeMaskForValue(value);
+ if (!(bucket.data & mask)) {
+ return 0;
+ }
+
+ size_ -= 1;
+ bucket.data &= ~mask;
+ if (bucket.data == 0) {
+ buckets_.erase(buckets_.cbegin() + index);
+ }
+ return 1;
+ }
+
+ // Returns true if `value` is present in the set.
+ bool contains(T value) const {
+ const size_t index = FindBucketForValue(value);
+ if (index >= buckets_.size() ||
+ buckets_[index].start != ComputeBucketStart(value)) {
+ return false;
+ }
+ auto& bucket = buckets_[index];
+ return bucket.data & ComputeMaskForValue(value);
+ }
+
+ // Returns `1` if `value` is present in the set, `0` otherwise.
+ inline size_t count(T value) const { return contains(value) ? 1 : 0; }
+
+ // Returns true if the set holds no values.
+ inline bool empty() const { return size_ == 0; }
+
+ // Returns the number of enums stored in this set.
+ size_t size() const { return size_; }
+
+ // Returns true if this set contains at least one value contained in `in_set`.
+ // Note: If `in_set` is empty, this function returns true.
+ bool HasAnyOf(const EnumSet<T>& in_set) const {
+ if (in_set.empty()) {
return true;
}
- if (a.overflow_ == nullptr || b.overflow_ == nullptr) {
- return false;
- }
+ auto lhs = buckets_.cbegin();
+ auto rhs = in_set.buckets_.cbegin();
- return *a.overflow_ == *b.overflow_;
- }
+ while (lhs != buckets_.cend() && rhs != in_set.buckets_.cend()) {
+ if (lhs->start == rhs->start) {
+ if (lhs->data & rhs->data) {
+ // At least 1 bit is shared. Early return.
+ return true;
+ }
- friend bool operator!=(const EnumSet& a, const EnumSet& b) {
- return !(a == b);
- }
+ lhs++;
+ rhs++;
+ continue;
+ }
- // Adds the given enum value to the set. This has no effect if the
- // enum value is already in the set.
- void Add(EnumType c) { AddWord(ToWord(c)); }
+ // LHS bucket is smaller than the current RHS bucket. Catching up on RHS.
+ if (lhs->start < rhs->start) {
+ lhs++;
+ continue;
+ }
- // Removes the given enum value from the set. This has no effect if the
- // enum value is not in the set.
- void Remove(EnumType c) { RemoveWord(ToWord(c)); }
-
- // Returns true if this enum value is in the set.
- bool Contains(EnumType c) const { return ContainsWord(ToWord(c)); }
-
- // Applies f to each enum in the set, in order from smallest enum
- // value to largest.
- void ForEach(std::function<void(EnumType)> f) const {
- for (uint32_t i = 0; i < 64; ++i) {
- if (mask_ & AsMask(i)) f(static_cast<EnumType>(i));
- }
- if (overflow_) {
- for (uint32_t c : *overflow_) f(static_cast<EnumType>(c));
- }
- }
-
- // Returns true if the set is empty.
- bool IsEmpty() const {
- if (mask_) return false;
- if (overflow_ && !overflow_->empty()) return false;
- return true;
- }
-
- // Returns true if the set contains ANY of the elements of |in_set|,
- // or if |in_set| is empty.
- bool HasAnyOf(const EnumSet<EnumType>& in_set) const {
- if (in_set.IsEmpty()) return true;
-
- if (mask_ & in_set.mask_) return true;
-
- if (!overflow_ || !in_set.overflow_) return false;
-
- for (uint32_t item : *in_set.overflow_) {
- if (overflow_->find(item) != overflow_->end()) return true;
+ // Otherwise, RHS needs to catch up on LHS.
+ rhs++;
}
return false;
}
private:
- // Adds the given enum value (as a 32-bit word) to the set. This has no
- // effect if the enum value is already in the set.
- void AddWord(uint32_t word) {
- if (auto new_bits = AsMask(word)) {
- mask_ |= new_bits;
- } else {
- Overflow().insert(word);
+ // Returns the index of the last bucket in which `value` could be stored.
+ static constexpr inline size_t ComputeLargestPossibleBucketIndexFor(T value) {
+ return static_cast<size_t>(value) / kBucketSize;
+ }
+
+ // Returns the smallest enum value that could be contained in the same bucket
+ // as `value`.
+ static constexpr inline T ComputeBucketStart(T value) {
+ return static_cast<T>(kBucketSize *
+ ComputeLargestPossibleBucketIndexFor(value));
+ }
+
+ // Returns the index of the bit that corresponds to `value` in the bucket.
+ static constexpr inline ElementType ComputeBucketOffset(T value) {
+ return static_cast<ElementType>(value) % kBucketSize;
+ }
+
+ // Returns the bitmask used to represent the enum `value` in its bucket.
+ static constexpr inline BucketType ComputeMaskForValue(T value) {
+ return 1ULL << ComputeBucketOffset(value);
+ }
+
+ // Returns the `enum` stored in `bucket` at `offset`.
+ // `offset` is the bit-offset in the bucket storage.
+ static constexpr inline T GetValueFromBucket(const Bucket& bucket,
+ BucketType offset) {
+ return static_cast<T>(static_cast<ElementType>(bucket.start) + offset);
+ }
+
+ // For a given enum `value`, finds the bucket index that could contain this
+ // value. If no such bucket is found, the index at which the new bucket should
+ // be inserted is returned.
+ size_t FindBucketForValue(T value) const {
+ // Set is empty, insert at 0.
+ if (buckets_.size() == 0) {
+ return 0;
}
- }
- // Removes the given enum value (as a 32-bit word) from the set. This has no
- // effect if the enum value is not in the set.
- void RemoveWord(uint32_t word) {
- if (auto new_bits = AsMask(word)) {
- mask_ &= ~new_bits;
- } else {
- auto itr = Overflow().find(word);
- if (itr != Overflow().end()) Overflow().erase(itr);
+ const T wanted_start = ComputeBucketStart(value);
+ assert(buckets_.size() > 0 &&
+ "Size must not be 0 here. Has the code above changed?");
+ size_t index = std::min(buckets_.size() - 1,
+ ComputeLargestPossibleBucketIndexFor(value));
+
+ // This loop behaves like std::upper_bound with a reverse iterator.
+ // Buckets are sorted. 3 main cases:
+ // - The bucket matches
+ // => returns the bucket index.
+ // - The found bucket is larger
+ // => scans left until it finds the correct bucket, or insertion point.
+ // - The found bucket is smaller
+ // => We are at the end, so we return past-end index for insertion.
+ for (; buckets_[index].start >= wanted_start; index--) {
+ if (index == 0) {
+ return 0;
+ }
}
+
+ return index + 1;
}
- // Returns true if the enum represented as a 32-bit word is in the set.
- bool ContainsWord(uint32_t word) const {
- // We shouldn't call Overflow() since this is a const method.
- if (auto bits = AsMask(word)) {
- return (mask_ & bits) != 0;
- } else if (auto overflow = overflow_.get()) {
- return overflow->find(word) != overflow->end();
+ // Creates a new bucket to store `value` and inserts it at `index`.
+ // If the `index` is past the end, the bucket is inserted at the end of the
+ // vector.
+ void InsertBucketFor(size_t index, T value) {
+ const T bucket_start = ComputeBucketStart(value);
+ Bucket bucket = {1ULL << ComputeBucketOffset(value), bucket_start};
+ auto it = buckets_.emplace(buckets_.begin() + index, std::move(bucket));
+#if defined(NDEBUG)
+ (void)it; // Silencing unused variable warning.
+#else
+ assert(std::next(it) == buckets_.end() ||
+ std::next(it)->start > bucket_start);
+ assert(it == buckets_.begin() || std::prev(it)->start < bucket_start);
+#endif
+ }
+
+ // Returns true if the bucket at `bucketIndex` stores the enum at
+ // `bucketOffset`, false otherwise.
+ bool HasEnumAt(size_t bucketIndex, BucketType bucketOffset) const {
+ assert(bucketIndex < buckets_.size());
+ assert(bucketOffset < kBucketSize);
+ return buckets_[bucketIndex].data & (1ULL << bucketOffset);
+ }
+
+ // Returns true if `lhs` and `rhs` hold the exact same values.
+ friend bool operator==(const EnumSet& lhs, const EnumSet& rhs) {
+ if (lhs.size_ != rhs.size_) {
+ return false;
}
- // The word is large, but the set doesn't have large members, so
- // it doesn't have an overflow set.
- return false;
- }
- // Returns the enum value as a uint32_t.
- uint32_t ToWord(EnumType value) const {
- static_assert(sizeof(EnumType) <= sizeof(uint32_t),
- "EnumType must statically castable to uint32_t");
- return static_cast<uint32_t>(value);
- }
-
- // Determines whether the given enum value can be represented
- // as a bit in a uint64_t mask. If so, then returns that mask bit.
- // Otherwise, returns 0.
- uint64_t AsMask(uint32_t word) const {
- if (word > 63) return 0;
- return uint64_t(1) << word;
- }
-
- // Ensures that overflow_set_ references a set. A new empty set is
- // allocated if one doesn't exist yet. Returns overflow_set_.
- OverflowSetType& Overflow() {
- if (overflow_.get() == nullptr) {
- overflow_ = MakeUnique<OverflowSetType>();
+ if (lhs.buckets_.size() != rhs.buckets_.size()) {
+ return false;
}
- return *overflow_;
+ return lhs.buckets_ == rhs.buckets_;
}
- // Enums with values up to 63 are stored as bits in this mask.
- uint64_t mask_ = 0;
- // Enums with values larger than 63 are stored in this set.
- // This set should normally be empty or very small.
- std::unique_ptr<OverflowSetType> overflow_ = {};
+ // Returns true if `lhs` and `rhs` hold at least 1 different value.
+ friend bool operator!=(const EnumSet& lhs, const EnumSet& rhs) {
+ return !(lhs == rhs);
+ }
+
+ // Storage for the buckets.
+ std::vector<Bucket> buckets_;
+ // How many enums is this set storing.
+ size_t size_ = 0;
};
-// A set of spv::Capability, optimized for small capability values.
+// A set of spv::Capability.
using CapabilitySet = EnumSet<spv::Capability>;
} // namespace spvtools
diff --git a/third_party/SPIRV-Tools/source/ext_inst.cpp b/third_party/SPIRV-Tools/source/ext_inst.cpp
index 4e27954..9a5ba84 100644
--- a/third_party/SPIRV-Tools/source/ext_inst.cpp
+++ b/third_party/SPIRV-Tools/source/ext_inst.cpp
@@ -30,6 +30,7 @@
#include "glsl.std.450.insts.inc"
#include "nonsemantic.clspvreflection.insts.inc"
#include "nonsemantic.shader.debuginfo.100.insts.inc"
+#include "nonsemantic.vkspreflection.insts.inc"
#include "opencl.debuginfo.100.insts.inc"
#include "opencl.std.insts.inc"
@@ -62,6 +63,9 @@
{SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION,
ARRAY_SIZE(nonsemantic_clspvreflection_entries),
nonsemantic_clspvreflection_entries},
+ {SPV_EXT_INST_TYPE_NONSEMANTIC_VKSPREFLECTION,
+ ARRAY_SIZE(nonsemantic_vkspreflection_entries),
+ nonsemantic_vkspreflection_entries},
};
static const spv_ext_inst_table_t kTable_1_0 = {ARRAY_SIZE(kGroups_1_0),
@@ -138,6 +142,9 @@
if (!strncmp("NonSemantic.ClspvReflection.", name, 28)) {
return SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION;
}
+ if (!strncmp("NonSemantic.VkspReflection.", name, 27)) {
+ return SPV_EXT_INST_TYPE_NONSEMANTIC_VKSPREFLECTION;
+ }
// ensure to add any known non-semantic extended instruction sets
// above this point, and update spvExtInstIsNonSemantic()
if (!strncmp("NonSemantic.", name, 12)) {
@@ -149,7 +156,8 @@
bool spvExtInstIsNonSemantic(const spv_ext_inst_type_t type) {
if (type == SPV_EXT_INST_TYPE_NONSEMANTIC_UNKNOWN ||
type == SPV_EXT_INST_TYPE_NONSEMANTIC_SHADER_DEBUGINFO_100 ||
- type == SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION) {
+ type == SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION ||
+ type == SPV_EXT_INST_TYPE_NONSEMANTIC_VKSPREFLECTION) {
return true;
}
return false;
diff --git a/third_party/SPIRV-Tools/source/extensions.cpp b/third_party/SPIRV-Tools/source/extensions.cpp
index ebf6bec..ac987fc 100644
--- a/third_party/SPIRV-Tools/source/extensions.cpp
+++ b/third_party/SPIRV-Tools/source/extensions.cpp
@@ -40,8 +40,9 @@
std::string ExtensionSetToString(const ExtensionSet& extensions) {
std::stringstream ss;
- extensions.ForEach(
- [&ss](Extension ext) { ss << ExtensionToString(ext) << " "; });
+ for (auto extension : extensions) {
+ ss << ExtensionToString(extension) << " ";
+ }
return ss.str();
}
diff --git a/third_party/SPIRV-Tools/source/extensions.h b/third_party/SPIRV-Tools/source/extensions.h
index 8023444..cda4924 100644
--- a/third_party/SPIRV-Tools/source/extensions.h
+++ b/third_party/SPIRV-Tools/source/extensions.h
@@ -15,6 +15,7 @@
#ifndef SOURCE_EXTENSIONS_H_
#define SOURCE_EXTENSIONS_H_
+#include <cstdint>
#include <string>
#include "source/enum_set.h"
@@ -23,7 +24,7 @@
namespace spvtools {
// The known SPIR-V extensions.
-enum Extension {
+enum Extension : uint32_t {
#include "extension_enum.inc"
};
diff --git a/third_party/SPIRV-Tools/source/fuzz/transformation_add_no_contraction_decoration.cpp b/third_party/SPIRV-Tools/source/fuzz/transformation_add_no_contraction_decoration.cpp
index 07a31e5..87393e9 100644
--- a/third_party/SPIRV-Tools/source/fuzz/transformation_add_no_contraction_decoration.cpp
+++ b/third_party/SPIRV-Tools/source/fuzz/transformation_add_no_contraction_decoration.cpp
@@ -36,6 +36,11 @@
if (!instr) {
return false;
}
+ // |instr| must not be decorated with NoContraction.
+ if (ir_context->get_decoration_mgr()->HasDecoration(
+ message_.result_id(), spv::Decoration::NoContraction)) {
+ return false;
+ }
// The instruction must be arithmetic.
return IsArithmetic(instr->opcode());
}
diff --git a/third_party/SPIRV-Tools/source/fuzz/transformation_add_relaxed_decoration.cpp b/third_party/SPIRV-Tools/source/fuzz/transformation_add_relaxed_decoration.cpp
index 6cd4ecb..601546c 100644
--- a/third_party/SPIRV-Tools/source/fuzz/transformation_add_relaxed_decoration.cpp
+++ b/third_party/SPIRV-Tools/source/fuzz/transformation_add_relaxed_decoration.cpp
@@ -36,6 +36,11 @@
if (!instr) {
return false;
}
+ // |instr| must not be decorated with RelaxedPrecision.
+ if (ir_context->get_decoration_mgr()->HasDecoration(
+ message_.result_id(), spv::Decoration::RelaxedPrecision)) {
+ return false;
+ }
opt::BasicBlock* cur_block = ir_context->get_instr_block(instr);
// The instruction must have a block.
if (cur_block == nullptr) {
@@ -46,6 +51,7 @@
cur_block->id()))) {
return false;
}
+
// The instruction must be numeric.
return IsNumeric(instr->opcode());
}
diff --git a/third_party/SPIRV-Tools/source/link/linker.cpp b/third_party/SPIRV-Tools/source/link/linker.cpp
index e50391a..58930e4 100644
--- a/third_party/SPIRV-Tools/source/link/linker.cpp
+++ b/third_party/SPIRV-Tools/source/link/linker.cpp
@@ -91,7 +91,8 @@
// should be non-null. |max_id_bound| should be strictly greater than 0.
spv_result_t GenerateHeader(const MessageConsumer& consumer,
const std::vector<opt::Module*>& modules,
- uint32_t max_id_bound, opt::ModuleHeader* header);
+ uint32_t max_id_bound, opt::ModuleHeader* header,
+ const LinkerOptions& options);
// Merge all the modules from |in_modules| into a single module owned by
// |linked_context|.
@@ -202,7 +203,8 @@
spv_result_t GenerateHeader(const MessageConsumer& consumer,
const std::vector<opt::Module*>& modules,
- uint32_t max_id_bound, opt::ModuleHeader* header) {
+ uint32_t max_id_bound, opt::ModuleHeader* header,
+ const LinkerOptions& options) {
spv_position_t position = {};
if (modules.empty())
@@ -212,10 +214,12 @@
return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
<< "|max_id_bound| of GenerateHeader should not be null.";
- const uint32_t linked_version = modules.front()->version();
+ uint32_t linked_version = modules.front()->version();
for (std::size_t i = 1; i < modules.size(); ++i) {
const uint32_t module_version = modules[i]->version();
- if (module_version != linked_version)
+ if (options.GetUseHighestVersion()) {
+ linked_version = std::max(linked_version, module_version);
+ } else if (module_version != linked_version) {
return DiagnosticStream({0, 0, 1}, consumer, "", SPV_ERROR_INTERNAL)
<< "Conflicting SPIR-V versions: "
<< SPV_SPIRV_VERSION_MAJOR_PART(linked_version) << "."
@@ -224,6 +228,7 @@
<< SPV_SPIRV_VERSION_MAJOR_PART(module_version) << "."
<< SPV_SPIRV_VERSION_MINOR_PART(module_version)
<< " (input module " << (i + 1) << ").";
+ }
}
header->magic_number = spv::MagicNumber;
@@ -753,7 +758,7 @@
// Phase 2: Generate the header
opt::ModuleHeader header;
- res = GenerateHeader(consumer, modules, max_id_bound, &header);
+ res = GenerateHeader(consumer, modules, max_id_bound, &header, options);
if (res != SPV_SUCCESS) return res;
IRContext linked_context(c_context->target_env, consumer);
linked_context.module()->SetHeader(header);
diff --git a/third_party/SPIRV-Tools/source/opcode.cpp b/third_party/SPIRV-Tools/source/opcode.cpp
index d26024a..787dbb3 100644
--- a/third_party/SPIRV-Tools/source/opcode.cpp
+++ b/third_party/SPIRV-Tools/source/opcode.cpp
@@ -274,6 +274,7 @@
case spv::Op::OpTypeArray:
case spv::Op::OpTypeStruct:
case spv::Op::OpTypeCooperativeMatrixNV:
+ case spv::Op::OpTypeCooperativeMatrixKHR:
return true;
default:
return false;
@@ -294,6 +295,7 @@
case spv::Op::OpPtrAccessChain:
case spv::Op::OpLoad:
case spv::Op::OpConstantNull:
+ case spv::Op::OpRawAccessChainNV:
return true;
default:
return false;
@@ -308,6 +310,7 @@
case spv::Op::OpFunctionParameter:
case spv::Op::OpImageTexelPointer:
case spv::Op::OpCopyObject:
+ case spv::Op::OpRawAccessChainNV:
return true;
default:
return false;
@@ -340,6 +343,7 @@
case spv::Op::OpTypeNamedBarrier:
case spv::Op::OpTypeAccelerationStructureNV:
case spv::Op::OpTypeCooperativeMatrixNV:
+ case spv::Op::OpTypeCooperativeMatrixKHR:
// case spv::Op::OpTypeAccelerationStructureKHR: covered by
// spv::Op::OpTypeAccelerationStructureNV
case spv::Op::OpTypeRayQueryKHR:
@@ -532,6 +536,8 @@
case spv::Op::OpGroupNonUniformQuadBroadcast:
case spv::Op::OpGroupNonUniformQuadSwap:
case spv::Op::OpGroupNonUniformRotateKHR:
+ case spv::Op::OpGroupNonUniformQuadAllKHR:
+ case spv::Op::OpGroupNonUniformQuadAnyKHR:
return true;
default:
return false;
@@ -750,6 +756,7 @@
case spv::Op::OpInBoundsAccessChain:
case spv::Op::OpPtrAccessChain:
case spv::Op::OpInBoundsPtrAccessChain:
+ case spv::Op::OpRawAccessChainNV:
return true;
default:
return false;
diff --git a/third_party/SPIRV-Tools/source/operand.cpp b/third_party/SPIRV-Tools/source/operand.cpp
index 31a6c59..7848846 100644
--- a/third_party/SPIRV-Tools/source/operand.cpp
+++ b/third_party/SPIRV-Tools/source/operand.cpp
@@ -26,7 +26,6 @@
#include "source/macro.h"
#include "source/opcode.h"
#include "source/spirv_constant.h"
-#include "source/spirv_target_env.h"
// For now, assume unified1 contains up to SPIR-V 1.3 and no later
// SPIR-V version.
@@ -48,7 +47,7 @@
return SPV_SUCCESS;
}
-spv_result_t spvOperandTableNameLookup(spv_target_env env,
+spv_result_t spvOperandTableNameLookup(spv_target_env,
const spv_operand_table table,
const spv_operand_type_t type,
const char* name,
@@ -57,31 +56,18 @@
if (!table) return SPV_ERROR_INVALID_TABLE;
if (!name || !pEntry) return SPV_ERROR_INVALID_POINTER;
- const auto version = spvVersionForTargetEnv(env);
for (uint64_t typeIndex = 0; typeIndex < table->count; ++typeIndex) {
const auto& group = table->types[typeIndex];
if (type != group.type) continue;
for (uint64_t index = 0; index < group.count; ++index) {
const auto& entry = group.entries[index];
// We consider the current operand as available as long as
- // 1. The target environment satisfies the minimal requirement of the
- // operand; or
- // 2. There is at least one extension enabling this operand; or
- // 3. There is at least one capability enabling this operand.
- //
- // Note that the second rule assumes the extension enabling this operand
- // is indeed requested in the SPIR-V code; checking that should be
- // validator's work.
+ // it is in the grammar. It might not be *valid* to use,
+ // but that should be checked by the validator, not by parsing.
if (nameLength == strlen(entry.name) &&
!strncmp(entry.name, name, nameLength)) {
- if ((version >= entry.minVersion && version <= entry.lastVersion) ||
- entry.numExtensions > 0u || entry.numCapabilities > 0u) {
- *pEntry = &entry;
- return SPV_SUCCESS;
- } else {
- // if there is no extension/capability then the version is wrong
- return SPV_ERROR_WRONG_VERSION;
- }
+ *pEntry = &entry;
+ return SPV_SUCCESS;
}
}
}
@@ -89,7 +75,7 @@
return SPV_ERROR_INVALID_LOOKUP;
}
-spv_result_t spvOperandTableValueLookup(spv_target_env env,
+spv_result_t spvOperandTableValueLookup(spv_target_env,
const spv_operand_table table,
const spv_operand_type_t type,
const uint32_t value,
@@ -110,33 +96,15 @@
const auto beg = group.entries;
const auto end = group.entries + group.count;
- // We need to loop here because there can exist multiple symbols for the
- // same operand value, and they can be introduced in different target
- // environments, which means they can have different minimal version
- // requirements. For example, SubgroupEqMaskKHR can exist in any SPIR-V
- // version as long as the SPV_KHR_shader_ballot extension is there; but
- // starting from SPIR-V 1.3, SubgroupEqMask, which has the same numeric
- // value as SubgroupEqMaskKHR, is available in core SPIR-V without extension
- // requirements.
// Assumes the underlying table is already sorted ascendingly according to
// opcode value.
- const auto version = spvVersionForTargetEnv(env);
- for (auto it = std::lower_bound(beg, end, needle, comp);
- it != end && it->value == value; ++it) {
- // We consider the current operand as available as long as
- // 1. The target environment satisfies the minimal requirement of the
- // operand; or
- // 2. There is at least one extension enabling this operand; or
- // 3. There is at least one capability enabling this operand.
- //
- // Note that the second rule assumes the extension enabling this operand
- // is indeed requested in the SPIR-V code; checking that should be
- // validator's work.
- if ((version >= it->minVersion && version <= it->lastVersion) ||
- it->numExtensions > 0u || it->numCapabilities > 0u) {
- *pEntry = it;
- return SPV_SUCCESS;
- }
+ auto it = std::lower_bound(beg, end, needle, comp);
+ if (it != end && it->value == value) {
+ // The current operand is considered available as long as
+ // it is in the grammar. It might not be *valid* to use,
+ // but that should be checked by the validator, not by parsing.
+ *pEntry = it;
+ return SPV_SUCCESS;
}
}
@@ -155,6 +123,7 @@
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER:
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_NUMBER:
+ case SPV_OPERAND_TYPE_LITERAL_FLOAT:
return "literal number";
case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER:
return "possibly multi-word literal integer";
@@ -236,6 +205,26 @@
case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
return "packed vector format";
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
+ case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
+ return "cooperative matrix operands";
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT:
+ return "cooperative matrix layout";
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
+ return "cooperative matrix use";
+ case SPV_OPERAND_TYPE_INITIALIZATION_MODE_QUALIFIER:
+ return "initialization mode qualifier";
+ case SPV_OPERAND_TYPE_HOST_ACCESS_QUALIFIER:
+ return "host access qualifier";
+ case SPV_OPERAND_TYPE_LOAD_CACHE_CONTROL:
+ return "load cache control";
+ case SPV_OPERAND_TYPE_STORE_CACHE_CONTROL:
+ return "store cache control";
+ case SPV_OPERAND_TYPE_NAMED_MAXIMUM_NUMBER_OF_REGISTERS:
+ return "named maximum number of registers";
+ case SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS:
+ case SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS:
+ return "raw access chain operands";
case SPV_OPERAND_TYPE_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
return "image";
@@ -325,6 +314,7 @@
}
switch (type) {
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
+ case SPV_OPERAND_TYPE_LITERAL_FLOAT:
case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER:
case SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER:
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
@@ -369,6 +359,13 @@
case SPV_OPERAND_TYPE_QUANTIZATION_MODES:
case SPV_OPERAND_TYPE_OVERFLOW_MODES:
case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT:
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
+ case SPV_OPERAND_TYPE_INITIALIZATION_MODE_QUALIFIER:
+ case SPV_OPERAND_TYPE_HOST_ACCESS_QUALIFIER:
+ case SPV_OPERAND_TYPE_LOAD_CACHE_CONTROL:
+ case SPV_OPERAND_TYPE_STORE_CACHE_CONTROL:
+ case SPV_OPERAND_TYPE_NAMED_MAXIMUM_NUMBER_OF_REGISTERS:
return true;
default:
break;
@@ -387,6 +384,8 @@
case SPV_OPERAND_TYPE_FRAGMENT_SHADING_RATE:
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
+ case SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS:
return true;
default:
break;
@@ -405,7 +404,9 @@
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING:
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
+ case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_CIV:
+ case SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS:
return true;
default:
break;
diff --git a/third_party/SPIRV-Tools/source/operand.h b/third_party/SPIRV-Tools/source/operand.h
index a3010d9..f74c933 100644
--- a/third_party/SPIRV-Tools/source/operand.h
+++ b/third_party/SPIRV-Tools/source/operand.h
@@ -57,12 +57,6 @@
// Gets the name string of the non-variable operand type.
const char* spvOperandTypeStr(spv_operand_type_t type);
-// Returns true if the given type is concrete.
-bool spvOperandIsConcrete(spv_operand_type_t type);
-
-// Returns true if the given type is concrete and also a mask.
-bool spvOperandIsConcreteMask(spv_operand_type_t type);
-
// Returns true if an operand of the given type is optional.
bool spvOperandIsOptional(spv_operand_type_t type);
diff --git a/third_party/SPIRV-Tools/source/opt/CMakeLists.txt b/third_party/SPIRV-Tools/source/opt/CMakeLists.txt
index eea3c47..4e7d92d 100644
--- a/third_party/SPIRV-Tools/source/opt/CMakeLists.txt
+++ b/third_party/SPIRV-Tools/source/opt/CMakeLists.txt
@@ -71,6 +71,7 @@
instruction_list.h
instrument_pass.h
interface_var_sroa.h
+ invocation_interlock_placement_pass.h
interp_fixup_pass.h
ir_builder.h
ir_context.h
@@ -93,6 +94,7 @@
loop_unswitch_pass.h
mem_pass.h
merge_return_pass.h
+ modify_maximal_reconvergence.h
module.h
null_pass.h
passes.h
@@ -121,7 +123,9 @@
strip_debug_info_pass.h
strip_nonsemantic_info_pass.h
struct_cfg_analysis.h
+ switch_descriptorset_pass.h
tree_iterator.h
+ trim_capabilities_pass.h
type_manager.h
types.h
unify_const_pass.h
@@ -189,6 +193,7 @@
instruction_list.cpp
instrument_pass.cpp
interface_var_sroa.cpp
+ invocation_interlock_placement_pass.cpp
interp_fixup_pass.cpp
ir_context.cpp
ir_loader.cpp
@@ -210,6 +215,7 @@
loop_unswitch_pass.cpp
mem_pass.cpp
merge_return_pass.cpp
+ modify_maximal_reconvergence.cpp
module.cpp
optimizer.cpp
pass.cpp
@@ -236,6 +242,8 @@
strip_debug_info_pass.cpp
strip_nonsemantic_info_pass.cpp
struct_cfg_analysis.cpp
+ switch_descriptorset_pass.cpp
+ trim_capabilities_pass.cpp
type_manager.cpp
types.cpp
unify_const_pass.cpp
diff --git a/third_party/SPIRV-Tools/source/opt/aggressive_dead_code_elim_pass.cpp b/third_party/SPIRV-Tools/source/opt/aggressive_dead_code_elim_pass.cpp
index 1645638..4737da5 100644
--- a/third_party/SPIRV-Tools/source/opt/aggressive_dead_code_elim_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/aggressive_dead_code_elim_pass.cpp
@@ -438,6 +438,9 @@
const Instruction* inst) {
assert(inst->opcode() == spv::Op::OpFunctionCall);
std::vector<uint32_t> live_variables;
+ // NOTE: we should only be checking function call parameters here, not the
+ // function itself, however, `IsPtr` will trivially return false for
+ // OpFunction
inst->ForEachInId([this, &live_variables](const uint32_t* operand_id) {
if (!IsPtr(*operand_id)) return;
uint32_t var_id = GetVariableId(*operand_id);
@@ -938,6 +941,8 @@
void AggressiveDCEPass::InitExtensions() {
extensions_allowlist_.clear();
+
+ // clang-format off
extensions_allowlist_.insert({
"SPV_AMD_shader_explicit_vertex_parameter",
"SPV_AMD_shader_trinary_minmax",
@@ -980,11 +985,13 @@
"SPV_NV_shader_image_footprint",
"SPV_NV_shading_rate",
"SPV_NV_mesh_shader",
+ "SPV_EXT_mesh_shader",
"SPV_NV_ray_tracing",
"SPV_KHR_ray_tracing",
"SPV_KHR_ray_query",
"SPV_EXT_fragment_invocation_density",
"SPV_EXT_physical_storage_buffer",
+ "SPV_KHR_physical_storage_buffer",
"SPV_KHR_terminate_invocation",
"SPV_KHR_shader_clock",
"SPV_KHR_vulkan_memory_model",
@@ -994,7 +1001,12 @@
"SPV_KHR_non_semantic_info",
"SPV_KHR_uniform_group_instructions",
"SPV_KHR_fragment_shader_barycentric",
+ "SPV_NV_bindless_texture",
+ "SPV_EXT_shader_atomic_float_add",
+ "SPV_EXT_fragment_shader_interlock",
+ "SPV_NV_compute_shader_derivatives"
});
+ // clang-format on
}
Instruction* AggressiveDCEPass::GetHeaderBranch(BasicBlock* blk) {
diff --git a/third_party/SPIRV-Tools/source/opt/block_merge_util.cpp b/third_party/SPIRV-Tools/source/opt/block_merge_util.cpp
index fe23e36..42f695f 100644
--- a/third_party/SPIRV-Tools/source/opt/block_merge_util.cpp
+++ b/third_party/SPIRV-Tools/source/opt/block_merge_util.cpp
@@ -98,6 +98,17 @@
return false;
}
+ // Note: This means that the instructions in a break block will execute as if
+ // they were still diverged according to the loop iteration. This restricts
+ // potential transformations an implementation may perform on the IR to match
+ // shader author expectations. Similarly, instructions in the loop construct
+ // cannot be moved into the continue construct unless it can be proven that
+ // invocations are always converged.
+ if (succ_is_merge && context->get_feature_mgr()->HasExtension(
+ kSPV_KHR_maximal_reconvergence)) {
+ return false;
+ }
+
if (pred_is_merge && IsContinue(context, lab_id)) {
// Cannot merge a continue target with a merge block.
return false;
diff --git a/third_party/SPIRV-Tools/source/opt/const_folding_rules.cpp b/third_party/SPIRV-Tools/source/opt/const_folding_rules.cpp
index 2610808..17900af 100644
--- a/third_party/SPIRV-Tools/source/opt/const_folding_rules.cpp
+++ b/third_party/SPIRV-Tools/source/opt/const_folding_rules.cpp
@@ -21,6 +21,59 @@
namespace {
constexpr uint32_t kExtractCompositeIdInIdx = 0;
+// Returns the value obtained by extracting the |number_of_bits| least
+// significant bits from |value|, and sign-extending it to 64-bits.
+uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) {
+ if (number_of_bits == 64) return value;
+
+ uint64_t mask_for_sign_bit = 1ull << (number_of_bits - 1);
+ uint64_t mask_for_significant_bits = (mask_for_sign_bit << 1) - 1ull;
+ if (value & mask_for_sign_bit) {
+ // Set upper bits to 1
+ value |= ~mask_for_significant_bits;
+ } else {
+ // Clear the upper bits
+ value &= mask_for_significant_bits;
+ }
+ return value;
+}
+
+// Returns the value obtained by extracting the |number_of_bits| least
+// significant bits from |value|, and zero-extending it to 64-bits.
+uint64_t ZeroExtendValue(uint64_t value, uint32_t number_of_bits) {
+ if (number_of_bits == 64) return value;
+
+ uint64_t mask_for_first_bit_to_clear = 1ull << (number_of_bits);
+ uint64_t mask_for_bits_to_keep = mask_for_first_bit_to_clear - 1;
+ value &= mask_for_bits_to_keep;
+ return value;
+}
+
+// Returns a constant of scalar integer type `integer_type` holding `result`,
+// created via `const_mgr`; widths <= 32 are first sign/zero-extended.
+const analysis::Constant* GenerateIntegerConstant(
+ const analysis::Integer* integer_type, uint64_t result,
+ analysis::ConstantManager* const_mgr) {
+ assert(integer_type != nullptr);
+
+ std::vector<uint32_t> words;
+ if (integer_type->width() == 64) {
+ // In the 64-bit case, two words are needed to represent the value.
+ words = {static_cast<uint32_t>(result),
+ static_cast<uint32_t>(result >> 32)};
+ } else {
+ // In all other cases, only a single word is needed.
+ assert(integer_type->width() <= 32);
+ if (integer_type->IsSigned()) {
+ result = SignExtendValue(result, integer_type->width());
+ } else {
+ result = ZeroExtendValue(result, integer_type->width());
+ }
+ words = {static_cast<uint32_t>(result)};
+ }
+ return const_mgr->GetConstant(integer_type, words);
+}
+
// Returns a constants with the value NaN of the given type. Only works for
// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
const analysis::Constant* GetNan(const analysis::Type* type,
@@ -88,6 +141,22 @@
return nullptr;
}
+// Returns a constant of type |result_type| with the value |-val|, computed
+// with two's-complement wraparound. |result_type| must be an integer type.
+const analysis::Constant* NegateIntConst(const analysis::Type* result_type,
+                                         const analysis::Constant* val,
+                                         analysis::ConstantManager* const_mgr) {
+  const analysis::Integer* int_type = result_type->AsInteger();
+  assert(int_type != nullptr);
+
+  // A null constant is zero, and negating zero yields zero.
+  if (val->AsNullConstant()) {
+    return val;
+  }
+
+  // Negate in unsigned arithmetic: negating INT64_MIN as an int64_t is signed
+  // overflow (undefined behavior), while unsigned subtraction wraps. Only the
+  // low |width| bits of the result are kept by GetIntConst.
+  uint64_t new_value =
+      uint64_t{0} - static_cast<uint64_t>(val->GetSignExtendedValue());
+  return const_mgr->GetIntConst(new_value, int_type->width(),
+                                int_type->IsSigned());
+}
+
// Folds an OpcompositeExtract where input is a composite constant.
ConstantFoldingRule FoldExtractWithConstants() {
return [](IRContext* context, Instruction* inst,
@@ -341,6 +410,69 @@
};
}
+// Returns the constant that results from transposing |matrix|. The result
+// will have type |result_type|, and |matrix| must exist in |context|. The
+// result constant will also exist in |context|.
+const analysis::Constant* TransposeMatrix(const analysis::Constant* matrix,
+                                          analysis::Matrix* result_type,
+                                          IRContext* context) {
+  analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+  if (matrix->AsNullConstant() != nullptr) {
+    return const_mgr->GetNullCompositeConstant(result_type);
+  }
+
+  const auto& columns = matrix->AsMatrixConstant()->GetComponents();
+  uint32_t number_of_rows = columns[0]->type()->AsVector()->element_count();
+
+  // Row |r| of |matrix| becomes column |r| of the result. Gather, for every
+  // result column, the ids of its scalar elements.
+  std::vector<std::vector<uint32_t>> result_elements(number_of_rows);
+  for (const analysis::Constant* column : columns) {
+    // Expand a null column into a composite constant so that its components
+    // can be read individually.
+    if (column->AsNullConstant()) {
+      column = const_mgr->GetNullCompositeConstant(column->type());
+    }
+    const auto& column_components = column->AsVectorConstant()->GetComponents();
+
+    for (uint32_t row = 0; row < number_of_rows; ++row) {
+      result_elements[row].push_back(
+          const_mgr->GetDefiningInstruction(column_components[row])
+              ->result_id());
+    }
+  }
+
+  // Create the vector constant for each column of the result (one per row of
+  // |matrix|), and collect the ids.
+  std::vector<uint32_t> result_columns(number_of_rows);
+  for (uint32_t col = 0; col < number_of_rows; ++col) {
+    auto* element = const_mgr->GetConstant(result_type->element_type(),
+                                           result_elements[col]);
+    result_columns[col] =
+        const_mgr->GetDefiningInstruction(element)->result_id();
+  }
+
+  // Create the matrix constant from the column ids, and return it.
+  return const_mgr->GetConstant(result_type, result_columns);
+}
+
+// Constant-folds an OpTranspose whose operand is a constant matrix. Returns
+// nullptr when the operand is not a constant, or when the result type holds
+// floating-point values and |inst| does not permit floating-point folding.
+const analysis::Constant* FoldTranspose(
+    IRContext* context, Instruction* inst,
+    const std::vector<const analysis::Constant*>& constants) {
+  assert(inst->opcode() == spv::Op::OpTranspose);
+
+  analysis::Type* result_type =
+      context->get_type_mgr()->GetType(inst->type_id());
+  const bool fp_folding_blocked =
+      !inst->IsFloatingPointFoldingAllowed() && HasFloatingPoint(result_type);
+  if (fp_folding_blocked) return nullptr;
+
+  const analysis::Constant* matrix = constants[0];
+  if (matrix == nullptr) return nullptr;
+
+  return TransposeMatrix(matrix, result_type->AsMatrix(), context);
+}
+
ConstantFoldingRule FoldVectorTimesMatrix() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
@@ -376,13 +508,7 @@
assert(c1->type()->AsVector()->element_type() == element_type &&
c2->type()->AsMatrix()->element_type() == vector_type);
- // Get a float vector that is the result of vector-times-matrix.
- std::vector<const analysis::Constant*> c1_components =
- c1->GetVectorComponents(const_mgr);
- std::vector<const analysis::Constant*> c2_components =
- c2->AsMatrixConstant()->GetComponents();
uint32_t resultVectorSize = result_type->AsVector()->element_count();
-
std::vector<uint32_t> ids;
if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
@@ -395,6 +521,12 @@
return const_mgr->GetConstant(vector_type, ids);
}
+ // Get a float vector that is the result of vector-times-matrix.
+ std::vector<const analysis::Constant*> c1_components =
+ c1->GetVectorComponents(const_mgr);
+ std::vector<const analysis::Constant*> c2_components =
+ c2->AsMatrixConstant()->GetComponents();
+
if (float_type->width() == 32) {
for (uint32_t i = 0; i < resultVectorSize; ++i) {
float result_scalar = 0.0f;
@@ -472,13 +604,7 @@
assert(c1->type()->AsMatrix()->element_type() == vector_type);
assert(c2->type()->AsVector()->element_type() == element_type);
- // Get a float vector that is the result of matrix-times-vector.
- std::vector<const analysis::Constant*> c1_components =
- c1->AsMatrixConstant()->GetComponents();
- std::vector<const analysis::Constant*> c2_components =
- c2->GetVectorComponents(const_mgr);
uint32_t resultVectorSize = result_type->AsVector()->element_count();
-
std::vector<uint32_t> ids;
if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
@@ -491,6 +617,12 @@
return const_mgr->GetConstant(vector_type, ids);
}
+ // Get a float vector that is the result of matrix-times-vector.
+ std::vector<const analysis::Constant*> c1_components =
+ c1->AsMatrixConstant()->GetComponents();
+ std::vector<const analysis::Constant*> c2_components =
+ c2->GetVectorComponents(const_mgr);
+
if (float_type->width() == 32) {
for (uint32_t i = 0; i < resultVectorSize; ++i) {
float result_scalar = 0.0f;
@@ -587,13 +719,13 @@
const analysis::Type* result_type, const analysis::Constant* a,
const analysis::Constant* b, analysis::ConstantManager*)>;
-// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
-// using |scalar_rule| and unary float point vectors ops by applying
+// Returns a |ConstantFoldingRule| that folds unary scalar ops
+// using |scalar_rule| and unary vector ops by applying
// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
// that is returned assumes that |constants| contains 1 entry. If they are
// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
// whose element type is |Float| or |Integer|.
-ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
+ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
return [scalar_rule](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
@@ -602,10 +734,6 @@
const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
const analysis::Vector* vector_type = result_type->AsVector();
- if (!inst->IsFloatingPointFoldingAllowed()) {
- return nullptr;
- }
-
const analysis::Constant* arg =
(inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
@@ -640,6 +768,83 @@
};
}
+// Returns a |ConstantFoldingRule| that folds binary scalar ops using
+// |scalar_rule|, and binary vector ops by applying |scalar_rule| to each pair
+// of corresponding vector elements. The folding rule assumes that the op has
+// two inputs. For regular instructions, those are in operands 0 and 1. For
+// extended instructions, they are in operands 1 and 2. If an element in
+// |constants| is not nullptr, then the constant's type is |Float|, |Integer|,
+// or |Vector| whose element type is |Float| or |Integer|. The rule returns
+// nullptr if either input is not a constant, or if |scalar_rule| fails for
+// any element.
+ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) {
+  return [scalar_rule](IRContext* context, Instruction* inst,
+                       const std::vector<const analysis::Constant*>& constants)
+             -> const analysis::Constant* {
+    assert(constants.size() == inst->NumInOperands());
+    assert(constants.size() == (inst->opcode() == spv::Op::OpExtInst ? 3 : 2));
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    analysis::TypeManager* type_mgr = context->get_type_mgr();
+    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
+    const analysis::Vector* vector_type = result_type->AsVector();
+
+    const bool is_ext_inst = inst->opcode() == spv::Op::OpExtInst;
+    const analysis::Constant* arg1 = is_ext_inst ? constants[1] : constants[0];
+    const analysis::Constant* arg2 = is_ext_inst ? constants[2] : constants[1];
+
+    if (arg1 == nullptr || arg2 == nullptr) {
+      return nullptr;
+    }
+
+    if (vector_type == nullptr) {
+      return scalar_rule(result_type, arg1, arg2, const_mgr);
+    }
+
+    std::vector<const analysis::Constant*> a_components =
+        arg1->GetVectorComponents(const_mgr);
+    std::vector<const analysis::Constant*> b_components =
+        arg2->GetVectorComponents(const_mgr);
+    assert(a_components.size() == b_components.size());
+
+    // Fold every component first, so that no defining instructions are
+    // created when any component fails to fold.
+    std::vector<const analysis::Constant*> results_components;
+    results_components.reserve(a_components.size());
+    for (size_t i = 0; i < a_components.size(); ++i) {
+      const analysis::Constant* folded =
+          scalar_rule(vector_type->element_type(), a_components[i],
+                      b_components[i], const_mgr);
+      if (folded == nullptr) {
+        return nullptr;
+      }
+      results_components.push_back(folded);
+    }
+
+    // Build the constant object and return it.
+    std::vector<uint32_t> ids;
+    ids.reserve(results_components.size());
+    for (const analysis::Constant* member : results_components) {
+      ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
+    }
+    return const_mgr->GetConstant(vector_type, ids);
+  };
+}
+
+// Returns a |ConstantFoldingRule| that behaves like FoldUnaryOp(|scalar_rule|)
+// (scalar operands are folded with |scalar_rule|; vector operands are folded
+// element by element), except that it returns nullptr for any instruction
+// that does not allow floating-point folding.
+ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
+  ConstantFoldingRule generic_rule = FoldUnaryOp(scalar_rule);
+  return [generic_rule](IRContext* context, Instruction* inst,
+                        const std::vector<const analysis::Constant*>& constants)
+             -> const analysis::Constant* {
+    // Respect instructions that forbid floating-point folding.
+    return inst->IsFloatingPointFoldingAllowed()
+               ? generic_rule(context, inst, constants)
+               : nullptr;
+  };
+}
+
// Returns the result of folding the constants in |constants| according the
// |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied
// per component.
@@ -872,6 +1077,11 @@
return FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
}
+ uint32_t width = denominator->type()->AsFloat()->width();
+ if (width != 32 && width != 64) {
+ return nullptr;
+ }
+
const analysis::FloatConstant* denominator_float =
denominator->AsFloatConstant();
if (denominator_float && denominator->GetValueAsDouble() == -0.0) {
@@ -1042,18 +1252,8 @@
};
}
-// This function defines a |UnaryScalarFoldingRule| that subtracts the constant
-// from zero.
-UnaryScalarFoldingRule FoldFNegateOp() {
- return [](const analysis::Type* result_type, const analysis::Constant* a,
- analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
- assert(result_type != nullptr && a != nullptr);
- assert(result_type == a->type());
- return NegateFPConst(result_type, a, const_mgr);
- };
-}
-
-ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
+// Folds OpFNegate by negating the floating-point constant (or each element of
+// a vector operand). Skipped when floating-point folding is not allowed.
+ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(NegateFPConst); }
+// Folds OpSNegate by negating the integer constant (or each element of a
+// vector operand) with NegateIntConst.
+ConstantFoldingRule FoldSNegate() { return FoldUnaryOp(NegateIntConst); }
ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
return [cmp_opcode](IRContext* context, Instruction* inst,
@@ -1497,6 +1697,74 @@
return nullptr;
};
}
+
+// Selects how FoldBinaryIntegerOperation widens its integer operands to 64
+// bits before applying the operation: sign extension (Signed) or zero
+// extension (Unsigned).
+enum Sign { Signed, Unsigned };
+
+// Returns a BinaryScalarFoldingRule that applies |op| to two scalar integer
+// constants and wraps the result in a constant of the result type (via
+// GenerateIntegerConstant). The template argument |signedness| selects how
+// the operands are widened to 64 bits before |op| sees them: sign extension
+// for Signed, zero extension for Unsigned. The result type and both operands
+// must be integer types of the same width.
+template <Sign signedness>
+BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t,
+                                                                  uint64_t)) {
+  return [op](const analysis::Type* result_type, const analysis::Constant* a,
+              const analysis::Constant* b,
+              analysis::ConstantManager* const_mgr)
+             -> const analysis::Constant* {
+    assert(result_type != nullptr && a != nullptr && b != nullptr);
+    const analysis::Integer* integer_type = result_type->AsInteger();
+    assert(integer_type != nullptr);
+    assert(a->type()->kind() == analysis::Type::kInteger);
+    assert(b->type()->kind() == analysis::Type::kInteger);
+    assert(integer_type->width() == a->type()->AsInteger()->width());
+    assert(integer_type->width() == b->type()->AsInteger()->width());
+
+    // A SPIR-V integer type does not by itself decide how an operation reads
+    // its operands; the opcode does. The caller encodes that choice in
+    // |signedness|.
+    auto widen = [](const analysis::Constant* c) -> uint64_t {
+      if (signedness == Signed) return c->GetSignExtendedValue();
+      return c->GetZeroExtendedValue();
+    };
+    return GenerateIntegerConstant(integer_type, op(widen(a), widen(b)),
+                                   const_mgr);
+  };
+}
+
+// A scalar folding rule that folds OpSConvert: the operand |a| is
+// sign-extended to 64 bits and then encoded in |result_type| by
+// GenerateIntegerConstant.
+const analysis::Constant* FoldScalarSConvert(
+    const analysis::Type* result_type, const analysis::Constant* a,
+    analysis::ConstantManager* const_mgr) {
+  assert(result_type != nullptr);
+  assert(a != nullptr);
+  assert(const_mgr != nullptr);
+  const analysis::Integer* integer_type = result_type->AsInteger();
+  assert(integer_type &&
+         "The result type of an SConvert must be an integer type");
+  const uint64_t value = static_cast<uint64_t>(a->GetSignExtendedValue());
+  return GenerateIntegerConstant(integer_type, value, const_mgr);
+}
+
+// A scalar folding rule that folds OpUConvert: the operand |a| is
+// zero-extended from its own width to 64 bits and then encoded in
+// |result_type| by GenerateIntegerConstant.
+const analysis::Constant* FoldScalarUConvert(
+    const analysis::Type* result_type, const analysis::Constant* a,
+    analysis::ConstantManager* const_mgr) {
+  assert(result_type != nullptr);
+  assert(a != nullptr);
+  assert(const_mgr != nullptr);
+  const analysis::Integer* integer_type = result_type->AsInteger();
+  assert(integer_type &&
+         "The result type of an UConvert must be an integer type");
+  uint64_t value = a->GetZeroExtendedValue();
+
+  // Clear any bits above the operand's width, in case the stored value of a
+  // narrow operand carries sign-extension bits.
+  const analysis::Integer* operand_type = a->type()->AsInteger();
+  assert(operand_type && "The operand of an UConvert must be an integer type");
+  value = ZeroExtendValue(value, operand_type->width());
+  return GenerateIntegerConstant(integer_type, value, const_mgr);
+}
} // namespace
void ConstantFoldingRules::AddFoldingRules() {
@@ -1514,6 +1782,8 @@
rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
rules_[spv::Op::OpConvertSToF].push_back(FoldIToF());
rules_[spv::Op::OpConvertUToF].push_back(FoldIToF());
+ rules_[spv::Op::OpSConvert].push_back(FoldUnaryOp(FoldScalarSConvert));
+ rules_[spv::Op::OpUConvert].push_back(FoldUnaryOp(FoldScalarUConvert));
rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
@@ -1566,10 +1836,52 @@
rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
+ rules_[spv::Op::OpTranspose].push_back(FoldTranspose);
rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
+ rules_[spv::Op::OpSNegate].push_back(FoldSNegate());
rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
+ rules_[spv::Op::OpIAdd].push_back(
+ FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
+ [](uint64_t a, uint64_t b) { return a + b; })));
+ rules_[spv::Op::OpISub].push_back(
+ FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
+ [](uint64_t a, uint64_t b) { return a - b; })));
+ rules_[spv::Op::OpIMul].push_back(
+ FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
+ [](uint64_t a, uint64_t b) { return a * b; })));
+ rules_[spv::Op::OpUDiv].push_back(
+ FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
+ [](uint64_t a, uint64_t b) { return (b != 0 ? a / b : 0); })));
+  // Signed division/remainder rules. Operands arrive sign-extended to 64 bits.
+  // Division by zero folds to 0. A divisor of -1 is special-cased because
+  // INT64_MIN / -1 and INT64_MIN % -1 overflow int64_t, which is undefined
+  // behavior and traps (SIGFPE) on x86.
+  rules_[spv::Op::OpSDiv].push_back(
+      FoldBinaryOp(FoldBinaryIntegerOperation<Signed>(
+          [](uint64_t a, uint64_t b) -> uint64_t {
+            if (b == 0) return 0;
+            // Dividing by -1 is a negation; do it in unsigned arithmetic,
+            // which wraps instead of overflowing.
+            if (static_cast<int64_t>(b) == -1) return uint64_t{0} - a;
+            return static_cast<uint64_t>(static_cast<int64_t>(a) /
+                                         static_cast<int64_t>(b));
+          })));
+  rules_[spv::Op::OpUMod].push_back(
+      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
+          [](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0); })));
+
+  rules_[spv::Op::OpSRem].push_back(
+      FoldBinaryOp(FoldBinaryIntegerOperation<Signed>(
+          [](uint64_t a, uint64_t b) -> uint64_t {
+            // x rem -1 is always 0.
+            if (b == 0 || static_cast<int64_t>(b) == -1) return 0;
+            return static_cast<uint64_t>(static_cast<int64_t>(a) %
+                                         static_cast<int64_t>(b));
+          })));
+
+  rules_[spv::Op::OpSMod].push_back(
+      FoldBinaryOp(FoldBinaryIntegerOperation<Signed>(
+          [](uint64_t a, uint64_t b) -> uint64_t {
+            // x mod -1 is always 0.
+            if (b == 0 || static_cast<int64_t>(b) == -1) return 0;
+
+            int64_t signed_a = static_cast<int64_t>(a);
+            int64_t signed_b = static_cast<int64_t>(b);
+            int64_t result = signed_a % signed_b;
+            // OpSMod's result takes the sign of the divisor, while '%' takes
+            // the sign of the dividend. Adjust only a non-zero remainder: a
+            // zero remainder must stay zero (e.g. 6 mod -3 == 0).
+            if (result != 0 && (signed_b < 0) != (result < 0)) {
+              result += signed_b;
+            }
+            return static_cast<uint64_t>(result);
+          })));
+
// Add rules for GLSLstd450
FeatureManager* feature_manager = context_->get_feature_mgr();
uint32_t ext_inst_glslstd450_id =
diff --git a/third_party/SPIRV-Tools/source/opt/constants.cpp b/third_party/SPIRV-Tools/source/opt/constants.cpp
index 9b4c89a..6eebbb5 100644
--- a/third_party/SPIRV-Tools/source/opt/constants.cpp
+++ b/third_party/SPIRV-Tools/source/opt/constants.cpp
@@ -435,6 +435,8 @@
words_per_element = float_type->width() / 32;
else if (const auto* int_type = element_type->AsInteger())
words_per_element = int_type->width() / 32;
+ else if (element_type->AsBool() != nullptr)
+ words_per_element = 1;
if (words_per_element != 1 && words_per_element != 2) return nullptr;
@@ -487,6 +489,31 @@
return GetDefiningInstruction(c)->result_id();
}
+// Returns an integer constant of width |bitWidth| and signedness |isSigned|
+// whose value is the low |bitWidth| bits of |val|. Signed values are
+// sign-extended from |bitWidth|; unsigned values have their upper bits
+// cleared. |bitWidth| must be in [1, 64].
+const Constant* ConstantManager::GetIntConst(uint64_t val, int32_t bitWidth,
+                                             bool isSigned) {
+  // Widths outside [1, 64] would make the shifts below undefined.
+  assert(bitWidth > 0 && bitWidth <= 64 && "unsupported integer width");
+  Type* int_type = context()->get_type_mgr()->GetIntType(bitWidth, isSigned);
+
+  if (isSigned) {
+    // Sign extend the value: move the sign bit to bit 63, then shift back
+    // arithmetically.
+    int32_t num_of_bit_to_ignore = 64 - bitWidth;
+    val = static_cast<int64_t>(val << num_of_bit_to_ignore) >>
+          num_of_bit_to_ignore;
+  } else if (bitWidth < 64) {
+    // Clear the upper bits that are not used.
+    uint64_t mask = ((1ull << bitWidth) - 1);
+    val &= mask;
+  }
+
+  if (bitWidth <= 32) {
+    return GetConstant(int_type, {static_cast<uint32_t>(val)});
+  }
+
+  // Values wider than 32 bits are split into two 32-bit words, low-order word
+  // first.
+  return GetConstant(
+      int_type, {static_cast<uint32_t>(val), static_cast<uint32_t>(val >> 32)});
+}
+
uint32_t ConstantManager::GetUIntConstId(uint32_t val) {
Type* uint_type = context()->get_type_mgr()->GetUIntType();
const Constant* c = GetConstant(uint_type, {val});
diff --git a/third_party/SPIRV-Tools/source/opt/constants.h b/third_party/SPIRV-Tools/source/opt/constants.h
index 410304e..ae8dc62 100644
--- a/third_party/SPIRV-Tools/source/opt/constants.h
+++ b/third_party/SPIRV-Tools/source/opt/constants.h
@@ -659,6 +659,12 @@
// Returns the id of a 32-bit signed integer constant with value |val|.
uint32_t GetSIntConstId(int32_t val);
+ // Returns an integer constant with `bitWidth` and value |val|. If `isSigned`
+ // is true, the constant will be a signed integer. Otherwise it will be
+ // unsigned. Only the `bitWidth` lower order bits of |val| will be used. The
+ // rest will be ignored.
+ const Constant* GetIntConst(uint64_t val, int32_t bitWidth, bool isSigned);
+
// Returns the id of a 32-bit unsigned integer constant with value |val|.
uint32_t GetUIntConstId(uint32_t val);
diff --git a/third_party/SPIRV-Tools/source/opt/convert_to_half_pass.cpp b/third_party/SPIRV-Tools/source/opt/convert_to_half_pass.cpp
index 2c4a631..e243bed 100644
--- a/third_party/SPIRV-Tools/source/opt/convert_to_half_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/convert_to_half_pass.cpp
@@ -63,6 +63,10 @@
void ConvertToHalfPass::AddRelaxed(uint32_t id) { relaxed_ids_set_.insert(id); }
+// Returns false if |inst| is one of the image operations in |image_ops_|, and
+// true otherwise. A user that fails this check keeps the value it consumes
+// from being marked relaxed.
+bool ConvertToHalfPass::CanRelaxOpOperands(Instruction* inst) {
+  return image_ops_.count(inst->opcode()) == 0;
+}
+
analysis::Type* ConvertToHalfPass::FloatScalarType(uint32_t width) {
analysis::Float float_ty(width);
return context()->get_type_mgr()->GetRegisteredType(&float_ty);
@@ -167,6 +171,19 @@
bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
bool modified = false;
+ // If this is a OpCompositeExtract instruction and has a struct operand, we
+ // should not relax this instruction. Doing so could cause a mismatch between
+ // the result type and the struct member type.
+ bool hasStructOperand = false;
+ if (inst->opcode() == spv::Op::OpCompositeExtract) {
+ inst->ForEachInId([&hasStructOperand, this](uint32_t* idp) {
+ Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
+ if (IsStruct(op_inst)) hasStructOperand = true;
+ });
+ if (hasStructOperand) {
+ return false;
+ }
+ }
// Convert all float32 based operands to float16 equivalent and change
// instruction type to float16 equivalent.
inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
@@ -299,12 +316,19 @@
if (closure_ops_.count(inst->opcode()) == 0) return false;
// Can relax if all float operands are relaxed
bool relax = true;
- inst->ForEachInId([&relax, this](uint32_t* idp) {
+ bool hasStructOperand = false;
+ inst->ForEachInId([&relax, &hasStructOperand, this](uint32_t* idp) {
Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
- if (IsStruct(op_inst)) relax = false;
+ if (IsStruct(op_inst)) hasStructOperand = true;
if (!IsFloat(op_inst, 32)) return;
if (!IsRelaxed(*idp)) relax = false;
});
+ // If the instruction has a struct operand, we should not relax it, even if
+ // all its uses are relaxed. Doing so could cause a mismatch between the
+ // result type and the struct member type.
+ if (hasStructOperand) {
+ return false;
+ }
if (relax) {
AddRelaxed(inst->result_id());
return true;
@@ -313,7 +337,8 @@
relax = true;
get_def_use_mgr()->ForEachUser(inst, [&relax, this](Instruction* uinst) {
if (uinst->result_id() == 0 || !IsFloat(uinst, 32) ||
- (!IsDecoratedRelaxed(uinst) && !IsRelaxed(uinst->result_id()))) {
+ (!IsDecoratedRelaxed(uinst) && !IsRelaxed(uinst->result_id())) ||
+ !CanRelaxOpOperands(uinst)) {
relax = false;
return;
}
diff --git a/third_party/SPIRV-Tools/source/opt/convert_to_half_pass.h b/third_party/SPIRV-Tools/source/opt/convert_to_half_pass.h
index 24a478f..8e10c4f 100644
--- a/third_party/SPIRV-Tools/source/opt/convert_to_half_pass.h
+++ b/third_party/SPIRV-Tools/source/opt/convert_to_half_pass.h
@@ -56,6 +56,9 @@
// Add |id| to the relaxed id set
void AddRelaxed(uint32_t id);
+ // Return true if the instruction's operands can be relaxed
+ bool CanRelaxOpOperands(Instruction* inst);
+
// Return type id for float with |width|
analysis::Type* FloatScalarType(uint32_t width);
@@ -133,13 +136,13 @@
// Set of 450 extension operations to be processed
std::unordered_set<uint32_t> target_ops_450_;
- // Set of sample operations
+ // Set of all sample operations, including dref and non-dref operations
std::unordered_set<spv::Op, hasher> image_ops_;
- // Set of dref sample operations
+ // Set of only dref sample operations
std::unordered_set<spv::Op, hasher> dref_image_ops_;
- // Set of dref sample operations
+ // Set of operations that can be marked as relaxed
std::unordered_set<spv::Op, hasher> closure_ops_;
// Set of ids of all relaxed instructions
diff --git a/third_party/SPIRV-Tools/source/opt/copy_prop_arrays.cpp b/third_party/SPIRV-Tools/source/opt/copy_prop_arrays.cpp
index 66a268f..c2bea8a 100644
--- a/third_party/SPIRV-Tools/source/opt/copy_prop_arrays.cpp
+++ b/third_party/SPIRV-Tools/source/opt/copy_prop_arrays.cpp
@@ -35,6 +35,32 @@
dbg_opcode == CommonDebugInfoDebugValue;
}
+// Returns how many members a value of |type| has: struct fields, array
+// length, vector components, or matrix columns. Returns 0 when |type| is not
+// one of those composite types, or when the array length is not a constant
+// known to the constant manager (for example, an OpSpecConstant length).
+uint32_t GetNumberOfMembers(const analysis::Type* type, IRContext* context) {
+  if (const analysis::Struct* struct_type = type->AsStruct()) {
+    return static_cast<uint32_t>(struct_type->element_types().size());
+  }
+  if (const analysis::Vector* vector_type = type->AsVector()) {
+    return vector_type->element_count();
+  }
+  if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
+    return matrix_type->element_count();
+  }
+
+  const analysis::Array* array_type = type->AsArray();
+  if (array_type == nullptr) {
+    return 0;
+  }
+  const analysis::Constant* length_const =
+      context->get_constant_mgr()->FindDeclaredConstant(array_type->LengthId());
+  if (length_const == nullptr) {
+    // The length is not a declared constant, so it is unknown here.
+    return 0;
+  }
+  assert(length_const->type()->AsInteger());
+  return length_const->GetU32();
+}
+
} // namespace
Pass::Status CopyPropagateArrays::Process() {
@@ -357,22 +383,9 @@
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
analysis::TypeManager* type_mgr = context()->get_type_mgr();
- analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
- uint32_t number_of_elements = 0;
- if (const analysis::Struct* struct_type = result_type->AsStruct()) {
- number_of_elements =
- static_cast<uint32_t>(struct_type->element_types().size());
- } else if (const analysis::Array* array_type = result_type->AsArray()) {
- const analysis::Constant* length_const =
- const_mgr->FindDeclaredConstant(array_type->LengthId());
- number_of_elements = length_const->GetU32();
- } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
- number_of_elements = vector_type->element_count();
- } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
- number_of_elements = matrix_type->element_count();
- }
+ uint32_t number_of_elements = GetNumberOfMembers(result_type, context());
if (number_of_elements == 0) {
return nullptr;
@@ -800,23 +813,8 @@
std::vector<uint32_t> access_indices = GetAccessIds();
type = type_mgr->GetMemberType(type, access_indices);
- if (const analysis::Struct* struct_type = type->AsStruct()) {
- return static_cast<uint32_t>(struct_type->element_types().size());
- } else if (const analysis::Array* array_type = type->AsArray()) {
- const analysis::Constant* length_const =
- context->get_constant_mgr()->FindDeclaredConstant(
- array_type->LengthId());
- assert(length_const->type()->AsInteger());
- return length_const->GetU32();
- } else if (const analysis::Vector* vector_type = type->AsVector()) {
- return vector_type->element_count();
- } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
- return matrix_type->element_count();
- } else {
- return 0;
- }
+ return opt::GetNumberOfMembers(type, context);
}
-
template <class iterator>
CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
iterator begin, iterator end)
diff --git a/third_party/SPIRV-Tools/source/opt/copy_prop_arrays.h b/third_party/SPIRV-Tools/source/opt/copy_prop_arrays.h
index 7486f80..c6ca7d2 100644
--- a/third_party/SPIRV-Tools/source/opt/copy_prop_arrays.h
+++ b/third_party/SPIRV-Tools/source/opt/copy_prop_arrays.h
@@ -101,7 +101,8 @@
bool IsMember() const { return !access_chain_.empty(); }
// Returns the number of members in the object represented by |this|. If
- // |this| does not represent a composite type, the return value will be 0.
+ // |this| does not represent a composite type or the number of components is
+ // not known at compile time, the return value will be 0.
uint32_t GetNumberOfMembers();
// Returns the owning variable that the memory object is contained in.
@@ -207,7 +208,7 @@
// Returns the memory object that at some point was equivalent to the result
// of |insert_inst|. If a memory object cannot be identified, the return
- // value is |nullptr\. The opcode of |insert_inst| must be
+ // value is |nullptr|. The opcode of |insert_inst| must be
// |OpCompositeInsert|. This function looks for a series of
// |OpCompositeInsert| instructions that insert the elements one at a time in
// order from beginning to end.
diff --git a/third_party/SPIRV-Tools/source/opt/dead_insert_elim_pass.cpp b/third_party/SPIRV-Tools/source/opt/dead_insert_elim_pass.cpp
index a486903..f985e4c 100644
--- a/third_party/SPIRV-Tools/source/opt/dead_insert_elim_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/dead_insert_elim_pass.cpp
@@ -213,7 +213,8 @@
} break;
default: {
// Mark inserts in chain for all components
- MarkInsertChain(&*ii, nullptr, 0, nullptr);
+ std::unordered_set<uint32_t> visited_phis;
+ MarkInsertChain(&*ii, nullptr, 0, &visited_phis);
} break;
}
});
diff --git a/third_party/SPIRV-Tools/source/opt/decoration_manager.cpp b/third_party/SPIRV-Tools/source/opt/decoration_manager.cpp
index 1393d48..3e95dbc 100644
--- a/third_party/SPIRV-Tools/source/opt/decoration_manager.cpp
+++ b/third_party/SPIRV-Tools/source/opt/decoration_manager.cpp
@@ -461,7 +461,7 @@
bool DecorationManager::WhileEachDecoration(
uint32_t id, uint32_t decoration,
- std::function<bool(const Instruction&)> f) {
+ std::function<bool(const Instruction&)> f) const {
for (const Instruction* inst : GetDecorationsFor(id, true)) {
switch (inst->opcode()) {
case spv::Op::OpMemberDecorate:
@@ -485,14 +485,19 @@
void DecorationManager::ForEachDecoration(
uint32_t id, uint32_t decoration,
- std::function<void(const Instruction&)> f) {
+ std::function<void(const Instruction&)> f) const {
WhileEachDecoration(id, decoration, [&f](const Instruction& inst) {
f(inst);
return true;
});
}
-bool DecorationManager::HasDecoration(uint32_t id, uint32_t decoration) {
+// Overload that takes the decoration as a spv::Decoration enum value; it
+// forwards to the uint32_t overload.
+bool DecorationManager::HasDecoration(uint32_t id,
+                                      spv::Decoration decoration) const {
+  return HasDecoration(id, static_cast<uint32_t>(decoration));
+}
+
+bool DecorationManager::HasDecoration(uint32_t id, uint32_t decoration) const {
bool has_decoration = false;
ForEachDecoration(id, decoration, [&has_decoration](const Instruction&) {
has_decoration = true;
diff --git a/third_party/SPIRV-Tools/source/opt/decoration_manager.h b/third_party/SPIRV-Tools/source/opt/decoration_manager.h
index 1a0d1b1..2be016a 100644
--- a/third_party/SPIRV-Tools/source/opt/decoration_manager.h
+++ b/third_party/SPIRV-Tools/source/opt/decoration_manager.h
@@ -92,20 +92,21 @@
// Returns whether a decoration instruction for |id| with decoration
// |decoration| exists or not.
- bool HasDecoration(uint32_t id, uint32_t decoration);
+ bool HasDecoration(uint32_t id, uint32_t decoration) const;
+ bool HasDecoration(uint32_t id, spv::Decoration decoration) const;
// |f| is run on each decoration instruction for |id| with decoration
// |decoration|. Processed are all decorations which target |id| either
// directly or indirectly by Decoration Groups.
void ForEachDecoration(uint32_t id, uint32_t decoration,
- std::function<void(const Instruction&)> f);
+ std::function<void(const Instruction&)> f) const;
// |f| is run on each decoration instruction for |id| with decoration
// |decoration|. Processes all decoration which target |id| either directly or
// indirectly through decoration groups. If |f| returns false, iteration is
// terminated and this function returns false.
bool WhileEachDecoration(uint32_t id, uint32_t decoration,
- std::function<bool(const Instruction&)> f);
+ std::function<bool(const Instruction&)> f) const;
// |f| is run on each decoration instruction for |id| with decoration
// |decoration|. Processes all decoration which target |id| either directly or
@@ -141,7 +142,7 @@
uint32_t decoration_value);
// Add |decoration, decoration_value| of |inst_id, member| to module.
- void AddMemberDecoration(uint32_t member, uint32_t inst_id,
+ void AddMemberDecoration(uint32_t inst_id, uint32_t member,
uint32_t decoration, uint32_t decoration_value);
friend bool operator==(const DecorationManager&, const DecorationManager&);
diff --git a/third_party/SPIRV-Tools/source/opt/def_use_manager.h b/third_party/SPIRV-Tools/source/opt/def_use_manager.h
index a8dbbc6..13cf9bd 100644
--- a/third_party/SPIRV-Tools/source/opt/def_use_manager.h
+++ b/third_party/SPIRV-Tools/source/opt/def_use_manager.h
@@ -27,28 +27,6 @@
namespace opt {
namespace analysis {
-// Class for representing a use of id. Note that:
-// * Result type id is a use.
-// * Ids referenced in OpSectionMerge & OpLoopMerge are considered as use.
-// * Ids referenced in OpPhi's in operands are considered as use.
-struct Use {
- Instruction* inst; // Instruction using the id.
- uint32_t operand_index; // logical operand index of the id use. This can be
- // the index of result type id.
-};
-
-inline bool operator==(const Use& lhs, const Use& rhs) {
- return lhs.inst == rhs.inst && lhs.operand_index == rhs.operand_index;
-}
-
-inline bool operator!=(const Use& lhs, const Use& rhs) { return !(lhs == rhs); }
-
-inline bool operator<(const Use& lhs, const Use& rhs) {
- if (lhs.inst < rhs.inst) return true;
- if (lhs.inst > rhs.inst) return false;
- return lhs.operand_index < rhs.operand_index;
-}
-
// Definition should never be null. User can be null, however, such an entry
// should be used only for searching (e.g. all users of a particular definition)
// and never stored in a container.
diff --git a/third_party/SPIRV-Tools/source/opt/desc_sroa.cpp b/third_party/SPIRV-Tools/source/opt/desc_sroa.cpp
index 8da0c86..2c0f482 100644
--- a/third_party/SPIRV-Tools/source/opt/desc_sroa.cpp
+++ b/third_party/SPIRV-Tools/source/opt/desc_sroa.cpp
@@ -54,9 +54,10 @@
bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) {
std::vector<Instruction*> access_chain_work_list;
std::vector<Instruction*> load_work_list;
+ std::vector<Instruction*> entry_point_work_list;
bool failed = !get_def_use_mgr()->WhileEachUser(
- var->result_id(),
- [this, &access_chain_work_list, &load_work_list](Instruction* use) {
+ var->result_id(), [this, &access_chain_work_list, &load_work_list,
+ &entry_point_work_list](Instruction* use) {
if (use->opcode() == spv::Op::OpName) {
return true;
}
@@ -73,6 +74,9 @@
case spv::Op::OpLoad:
load_work_list.push_back(use);
return true;
+ case spv::Op::OpEntryPoint:
+ entry_point_work_list.push_back(use);
+ return true;
default:
context()->EmitErrorMessage(
"Variable cannot be replaced: invalid instruction", use);
@@ -95,6 +99,11 @@
return false;
}
}
+ for (Instruction* use : entry_point_work_list) {
+ if (!ReplaceEntryPoint(var, use)) {
+ return false;
+ }
+ }
return true;
}
@@ -147,6 +156,42 @@
return true;
}
+bool DescriptorScalarReplacement::ReplaceEntryPoint(Instruction* var,
+ Instruction* use) {
+ // Build a new |OperandList| for |use| that removes |var| and adds its
+ // replacement variables.
+ Instruction::OperandList new_operands;
+
+ // Copy all operands except |var|.
+ bool found = false;
+ for (uint32_t idx = 0; idx < use->NumOperands(); idx++) {
+ Operand& op = use->GetOperand(idx);
+ if (op.type == SPV_OPERAND_TYPE_ID && op.words[0] == var->result_id()) {
+ found = true;
+ } else {
+ new_operands.emplace_back(op);
+ }
+ }
+
+ if (!found) {
+ context()->EmitErrorMessage(
+ "Variable cannot be replaced: invalid instruction", use);
+ return false;
+ }
+
+ // Add all new replacement variables.
+ uint32_t num_replacement_vars =
+ descsroautil::GetNumberOfElementsForArrayOrStruct(context(), var);
+ for (uint32_t i = 0; i < num_replacement_vars; i++) {
+ new_operands.push_back(
+ {SPV_OPERAND_TYPE_ID, {GetReplacementVariable(var, i)}});
+ }
+
+ use->ReplaceOperands(new_operands);
+ context()->UpdateDefUse(use);
+ return true;
+}
+
uint32_t DescriptorScalarReplacement::GetReplacementVariable(Instruction* var,
uint32_t idx) {
auto replacement_vars = replacement_variables_.find(var);
diff --git a/third_party/SPIRV-Tools/source/opt/desc_sroa.h b/third_party/SPIRV-Tools/source/opt/desc_sroa.h
index 6a24fd8..901be3e 100644
--- a/third_party/SPIRV-Tools/source/opt/desc_sroa.h
+++ b/third_party/SPIRV-Tools/source/opt/desc_sroa.h
@@ -64,6 +64,11 @@
// otherwise.
bool ReplaceLoadedValue(Instruction* var, Instruction* value);
+ // Replaces the given composite variable |var| in the OpEntryPoint with the
+ // new replacement variables, one for each element of the array |var|. Returns
+ // |true| if successful, and |false| otherwise.
+ bool ReplaceEntryPoint(Instruction* var, Instruction* use);
+
// Replaces the given OpCompositeExtract |extract| and all of its references
// with an OpLoad of a replacement variable. |var| is the variable with
// composite type whose value is being used by |extract|. Assumes that
diff --git a/third_party/SPIRV-Tools/source/opt/feature_manager.cpp b/third_party/SPIRV-Tools/source/opt/feature_manager.cpp
index 07e053b..5188370 100644
--- a/third_party/SPIRV-Tools/source/opt/feature_manager.cpp
+++ b/third_party/SPIRV-Tools/source/opt/feature_manager.cpp
@@ -40,31 +40,33 @@
const std::string name = ext->GetInOperand(0u).AsString();
Extension extension;
if (GetExtensionFromString(name.c_str(), &extension)) {
- extensions_.Add(extension);
+ extensions_.insert(extension);
}
}
void FeatureManager::RemoveExtension(Extension ext) {
- if (!extensions_.Contains(ext)) return;
- extensions_.Remove(ext);
+ if (!extensions_.contains(ext)) return;
+ extensions_.erase(ext);
}
void FeatureManager::AddCapability(spv::Capability cap) {
- if (capabilities_.Contains(cap)) return;
+ if (capabilities_.contains(cap)) return;
- capabilities_.Add(cap);
+ capabilities_.insert(cap);
spv_operand_desc desc = {};
if (SPV_SUCCESS == grammar_.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
uint32_t(cap), &desc)) {
- CapabilitySet(desc->numCapabilities, desc->capabilities)
- .ForEach([this](spv::Capability c) { AddCapability(c); });
+ for (auto capability :
+ CapabilitySet(desc->numCapabilities, desc->capabilities)) {
+ AddCapability(capability);
+ }
}
}
void FeatureManager::RemoveCapability(spv::Capability cap) {
- if (!capabilities_.Contains(cap)) return;
- capabilities_.Remove(cap);
+ if (!capabilities_.contains(cap)) return;
+ capabilities_.erase(cap);
}
void FeatureManager::AddCapabilities(Module* module) {
diff --git a/third_party/SPIRV-Tools/source/opt/feature_manager.h b/third_party/SPIRV-Tools/source/opt/feature_manager.h
index b96988d..d150a2f 100644
--- a/third_party/SPIRV-Tools/source/opt/feature_manager.h
+++ b/third_party/SPIRV-Tools/source/opt/feature_manager.h
@@ -25,27 +25,19 @@
// Tracks features enabled by a module. The IRContext has a FeatureManager.
class FeatureManager {
public:
- explicit FeatureManager(const AssemblyGrammar& grammar) : grammar_(grammar) {}
-
// Returns true if |ext| is an enabled extension in the module.
- bool HasExtension(Extension ext) const { return extensions_.Contains(ext); }
-
- // Removes the given |extension| from the current FeatureManager.
- void RemoveExtension(Extension extension);
+ bool HasExtension(Extension ext) const { return extensions_.contains(ext); }
// Returns true if |cap| is an enabled capability in the module.
bool HasCapability(spv::Capability cap) const {
- return capabilities_.Contains(cap);
+ return capabilities_.contains(cap);
}
- // Removes the given |capability| from the current FeatureManager.
- void RemoveCapability(spv::Capability capability);
+ // Returns the capabilities the module declares.
+ inline const CapabilitySet& GetCapabilities() const { return capabilities_; }
- // Analyzes |module| and records enabled extensions and capabilities.
- void Analyze(Module* module);
-
- CapabilitySet* GetCapabilities() { return &capabilities_; }
- const CapabilitySet* GetCapabilities() const { return &capabilities_; }
+ // Returns the extensions the module imports.
+ inline const ExtensionSet& GetExtensions() const { return extensions_; }
uint32_t GetExtInstImportId_GLSLstd450() const {
return extinst_importid_GLSLstd450_;
@@ -64,23 +56,34 @@
return !(a == b);
}
- // Adds the given |capability| and all implied capabilities into the current
- // FeatureManager.
- void AddCapability(spv::Capability capability);
+ private:
+ explicit FeatureManager(const AssemblyGrammar& grammar) : grammar_(grammar) {}
+
+ // Analyzes |module| and records enabled extensions and capabilities.
+ void Analyze(Module* module);
// Add the extension |ext| to the feature manager.
void AddExtension(Instruction* ext);
- // Analyzes |module| and records imported external instruction sets.
- void AddExtInstImportIds(Module* module);
-
- private:
// Analyzes |module| and records enabled extensions.
void AddExtensions(Module* module);
+ // Removes the given |extension| from the current FeatureManager.
+ void RemoveExtension(Extension extension);
+
+ // Adds the given |capability| and all implied capabilities into the current
+ // FeatureManager.
+ void AddCapability(spv::Capability capability);
+
// Analyzes |module| and records enabled capabilities.
void AddCapabilities(Module* module);
+ // Removes the given |capability| from the current FeatureManager.
+ void RemoveCapability(spv::Capability capability);
+
+ // Analyzes |module| and records imported external instruction sets.
+ void AddExtInstImportIds(Module* module);
+
// Auxiliary object for querying SPIR-V grammar facts.
const AssemblyGrammar& grammar_;
@@ -100,6 +103,8 @@
// Common NonSemanticShader100DebugInfo external instruction import ids,
// cached for performance.
uint32_t extinst_importid_Shader100DebugInfo_ = 0;
+
+ friend class IRContext;
};
} // namespace opt
diff --git a/third_party/SPIRV-Tools/source/opt/fix_storage_class.cpp b/third_party/SPIRV-Tools/source/opt/fix_storage_class.cpp
index 5597e82..564cd1b 100644
--- a/third_party/SPIRV-Tools/source/opt/fix_storage_class.cpp
+++ b/third_party/SPIRV-Tools/source/opt/fix_storage_class.cpp
@@ -318,7 +318,13 @@
const analysis::Constant* index_const =
context()->get_constant_mgr()->FindDeclaredConstant(
inst->GetSingleWordInOperand(i));
- uint32_t index = index_const->GetU32();
+ // It is highly unlikely that any type would have more fields than could
+ // be indexed by a 32-bit integer, and GetSingleWordInOperand only takes
+ // a 32-bit value, so we would not be able to handle it anyway. But the
+ // specification does allow any scalar integer type, treated as signed,
+ // so we simply downcast the index to 32-bits.
+ uint32_t index =
+ static_cast<uint32_t>(index_const->GetSignExtendedValue());
id = type_inst->GetSingleWordInOperand(index);
break;
}
diff --git a/third_party/SPIRV-Tools/source/opt/fold.cpp b/third_party/SPIRV-Tools/source/opt/fold.cpp
index 453756f..942da68 100644
--- a/third_party/SPIRV-Tools/source/opt/fold.cpp
+++ b/third_party/SPIRV-Tools/source/opt/fold.cpp
@@ -70,58 +70,6 @@
uint32_t InstructionFolder::BinaryOperate(spv::Op opcode, uint32_t a,
uint32_t b) const {
switch (opcode) {
- // Arthimetics
- case spv::Op::OpIAdd:
- return a + b;
- case spv::Op::OpISub:
- return a - b;
- case spv::Op::OpIMul:
- return a * b;
- case spv::Op::OpUDiv:
- if (b != 0) {
- return a / b;
- } else {
- // Dividing by 0 is undefined, so we will just pick 0.
- return 0;
- }
- case spv::Op::OpSDiv:
- if (b != 0u) {
- return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
- } else {
- // Dividing by 0 is undefined, so we will just pick 0.
- return 0;
- }
- case spv::Op::OpSRem: {
- // The sign of non-zero result comes from the first operand: a. This is
- // guaranteed by C++11 rules for integer division operator. The division
- // result is rounded toward zero, so the result of '%' has the sign of
- // the first operand.
- if (b != 0u) {
- return static_cast<int32_t>(a) % static_cast<int32_t>(b);
- } else {
- // Remainder when dividing with 0 is undefined, so we will just pick 0.
- return 0;
- }
- }
- case spv::Op::OpSMod: {
- // The sign of non-zero result comes from the second operand: b
- if (b != 0u) {
- int32_t rem = BinaryOperate(spv::Op::OpSRem, a, b);
- int32_t b_prim = static_cast<int32_t>(b);
- return (rem + b_prim) % b_prim;
- } else {
- // Mod with 0 is undefined, so we will just pick 0.
- return 0;
- }
- }
- case spv::Op::OpUMod:
- if (b != 0u) {
- return (a % b);
- } else {
- // Mod with 0 is undefined, so we will just pick 0.
- return 0;
- }
-
// Shifting
case spv::Op::OpShiftRightLogical:
if (b >= 32) {
@@ -627,7 +575,8 @@
Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
- if (!inst->IsFoldableByFoldScalar() && !HasConstFoldingRule(inst)) {
+ if (!inst->IsFoldableByFoldScalar() && !inst->IsFoldableByFoldVector() &&
+ !GetConstantFoldingRules().HasFoldingRule(inst)) {
return nullptr;
}
// Collect the values of the constant parameters.
@@ -661,29 +610,58 @@
}
}
- uint32_t result_val = 0;
bool successful = false;
+
// If all parameters are constant, fold the instruction to a constant.
- if (!missing_constants && inst->IsFoldableByFoldScalar()) {
- result_val = FoldScalars(inst->opcode(), constants);
- successful = true;
+ if (inst->IsFoldableByFoldScalar()) {
+ uint32_t result_val = 0;
+
+ if (!missing_constants) {
+ result_val = FoldScalars(inst->opcode(), constants);
+ successful = true;
+ }
+
+ if (!successful) {
+ successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
+ }
+
+ if (successful) {
+ const analysis::Constant* result_const =
+ const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
+ Instruction* folded_inst =
+ const_mgr->GetDefiningInstruction(result_const, inst->type_id());
+ return folded_inst;
+ }
+ } else if (inst->IsFoldableByFoldVector()) {
+ std::vector<uint32_t> result_val;
+
+ if (!missing_constants) {
+ if (Instruction* inst_type =
+ context_->get_def_use_mgr()->GetDef(inst->type_id())) {
+ result_val = FoldVectors(
+ inst->opcode(), inst_type->GetSingleWordInOperand(1), constants);
+ successful = true;
+ }
+ }
+
+ if (successful) {
+ const analysis::Constant* result_const =
+ const_mgr->GetNumericVectorConstantWithWords(
+ const_mgr->GetType(inst)->AsVector(), result_val);
+ Instruction* folded_inst =
+ const_mgr->GetDefiningInstruction(result_const, inst->type_id());
+ return folded_inst;
+ }
}
- if (!successful && inst->IsFoldableByFoldScalar()) {
- successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
- }
-
- if (successful) {
- const analysis::Constant* result_const =
- const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
- Instruction* folded_inst =
- const_mgr->GetDefiningInstruction(result_const, inst->type_id());
- return folded_inst;
- }
return nullptr;
}
bool InstructionFolder::IsFoldableType(Instruction* type_inst) const {
+ return IsFoldableScalarType(type_inst) || IsFoldableVectorType(type_inst);
+}
+
+bool InstructionFolder::IsFoldableScalarType(Instruction* type_inst) const {
// Support 32-bit integers.
if (type_inst->opcode() == spv::Op::OpTypeInt) {
return type_inst->GetSingleWordInOperand(0) == 32;
@@ -696,6 +674,19 @@
return false;
}
+bool InstructionFolder::IsFoldableVectorType(Instruction* type_inst) const {
+ // Support vectors with foldable components
+ if (type_inst->opcode() == spv::Op::OpTypeVector) {
+ uint32_t component_type_id = type_inst->GetSingleWordInOperand(0);
+ Instruction* def_component_type =
+ context_->get_def_use_mgr()->GetDef(component_type_id);
+ return def_component_type != nullptr &&
+ IsFoldableScalarType(def_component_type);
+ }
+ // Nothing else yet.
+ return false;
+}
+
bool InstructionFolder::FoldInstruction(Instruction* inst) const {
bool modified = false;
Instruction* folded_inst(inst);
diff --git a/third_party/SPIRV-Tools/source/opt/fold.h b/third_party/SPIRV-Tools/source/opt/fold.h
index 9a131d0..42da65e 100644
--- a/third_party/SPIRV-Tools/source/opt/fold.h
+++ b/third_party/SPIRV-Tools/source/opt/fold.h
@@ -86,6 +86,14 @@
// result type is |type_inst|.
bool IsFoldableType(Instruction* type_inst) const;
+ // Returns true if |FoldInstructionToConstant| could fold an instruction whose
+ // result type is |type_inst|.
+ bool IsFoldableScalarType(Instruction* type_inst) const;
+
+ // Returns true if |FoldInstructionToConstant| could fold an instruction whose
+ // result type is |type_inst|.
+ bool IsFoldableVectorType(Instruction* type_inst) const;
+
// Tries to fold |inst| to a single constant, when the input ids to |inst|
// have been substituted using |id_map|. Returns a pointer to the OpConstant*
// instruction if successful. If necessary, a new constant instruction is
diff --git a/third_party/SPIRV-Tools/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/third_party/SPIRV-Tools/source/opt/fold_spec_constant_op_and_composite_pass.cpp
index f6d6155..c568027 100644
--- a/third_party/SPIRV-Tools/source/opt/fold_spec_constant_op_and_composite_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/fold_spec_constant_op_and_composite_pass.cpp
@@ -115,20 +115,9 @@
"The first in-operand of OpSpecConstantOp instruction must be of "
"SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER type");
- switch (static_cast<spv::Op>(inst->GetSingleWordInOperand(0))) {
- case spv::Op::OpCompositeExtract:
- case spv::Op::OpVectorShuffle:
- case spv::Op::OpCompositeInsert:
- case spv::Op::OpQuantizeToF16:
- folded_inst = FoldWithInstructionFolder(pos);
- break;
- default:
- // TODO: This should use the instruction folder as well, but some folding
- // rules are missing.
-
- // Component-wise operations.
- folded_inst = DoComponentWiseOperation(pos);
- break;
+ folded_inst = FoldWithInstructionFolder(pos);
+ if (!folded_inst) {
+ folded_inst = DoComponentWiseOperation(pos);
}
if (!folded_inst) return false;
@@ -176,8 +165,9 @@
Instruction* new_const_inst =
context()->get_instruction_folder().FoldInstructionToConstant(
inst.get(), identity_map);
- assert(new_const_inst != nullptr &&
- "Failed to fold instruction that must be folded.");
+
+ // new_const_inst == null indicates we cannot fold this spec constant
+ if (!new_const_inst) return nullptr;
// Get the instruction before |pos| to insert after. |pos| cannot be the
// first instruction in the list because its type has to come first.
diff --git a/third_party/SPIRV-Tools/source/opt/folding_rules.cpp b/third_party/SPIRV-Tools/source/opt/folding_rules.cpp
index 7730ac1..5c68e29 100644
--- a/third_party/SPIRV-Tools/source/opt/folding_rules.cpp
+++ b/third_party/SPIRV-Tools/source/opt/folding_rules.cpp
@@ -2067,7 +2067,8 @@
}
// Returns the number of elements in the composite type |type|. Returns 0 if
-// |type| is a scalar value.
+// |type| is a scalar value. Return UINT32_MAX when the size is unknown at
+// compile time.
uint32_t GetNumberOfElements(const analysis::Type* type) {
if (auto* vector_type = type->AsVector()) {
return vector_type->element_count();
@@ -2079,21 +2080,27 @@
return static_cast<uint32_t>(struct_type->element_types().size());
}
if (auto* array_type = type->AsArray()) {
- return array_type->length_info().words[0];
+ if (array_type->length_info().words[0] ==
+ analysis::Array::LengthInfo::kConstant &&
+ array_type->length_info().words.size() == 2) {
+ return array_type->length_info().words[1];
+ }
+ return UINT32_MAX;
}
return 0;
}
// Returns a map with the set of values that were inserted into an object by
// the chain of OpCompositeInsertInstruction starting with |inst|.
-// The map will map the index to the value inserted at that index.
+// The map will map the index to the value inserted at that index. An empty map
+// will be returned if the map could not be properly generated.
std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
std::map<uint32_t, uint32_t> values_inserted;
Instruction* current_inst = inst;
while (current_inst->opcode() == spv::Op::OpCompositeInsert) {
if (current_inst->NumInOperands() > inst->NumInOperands()) {
- // This is the catch the case
+ // This is to catch the case
// %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
// %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
// %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
@@ -2884,8 +2891,12 @@
"Offset and ConstOffset may not be used together");
if (offset_operand_index < inst->NumOperands()) {
if (constants[offset_operand_index]) {
- image_operands =
- image_operands | uint32_t(spv::ImageOperandsMask::ConstOffset);
+ if (constants[offset_operand_index]->IsZero()) {
+ inst->RemoveInOperand(offset_operand_index);
+ } else {
+ image_operands = image_operands |
+ uint32_t(spv::ImageOperandsMask::ConstOffset);
+ }
image_operands =
image_operands & ~uint32_t(spv::ImageOperandsMask::Offset);
inst->SetInOperand(operand_index, {image_operands});
diff --git a/third_party/SPIRV-Tools/source/opt/graphics_robust_access_pass.cpp b/third_party/SPIRV-Tools/source/opt/graphics_robust_access_pass.cpp
index 8fff8a0..e765c39 100644
--- a/third_party/SPIRV-Tools/source/opt/graphics_robust_access_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/graphics_robust_access_pass.cpp
@@ -573,9 +573,9 @@
context()->module()->AddExtInstImport(std::move(import_inst));
module_status_.modified = true;
context()->AnalyzeDefUse(inst);
- // Reanalyze the feature list, since we added an extended instruction
- // set improt.
- context()->get_feature_mgr()->Analyze(context()->module());
+ // Invalidates the feature manager, since we added an extended instruction
+ // set import.
+ context()->ResetFeatureManager();
}
}
return module_status_.glsl_insts_id;
diff --git a/third_party/SPIRV-Tools/source/opt/inline_pass.cpp b/third_party/SPIRV-Tools/source/opt/inline_pass.cpp
index 3f160b2..3186433 100644
--- a/third_party/SPIRV-Tools/source/opt/inline_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/inline_pass.cpp
@@ -213,6 +213,19 @@
{(uint32_t)spv::StorageClass::Function}}}));
new_vars->push_back(std::move(var_inst));
get_decoration_mgr()->CloneDecorations(calleeFn->result_id(), returnVarId);
+
+ // Decorate the return var with AliasedPointer if the storage class of the
+ // pointee type is PhysicalStorageBuffer.
+ auto const pointee_type =
+ type_mgr->GetType(returnVarTypeId)->AsPointer()->pointee_type();
+ if (pointee_type->AsPointer() != nullptr) {
+ if (pointee_type->AsPointer()->storage_class() ==
+ spv::StorageClass::PhysicalStorageBuffer) {
+ get_decoration_mgr()->AddDecoration(
+ returnVarId, uint32_t(spv::Decoration::AliasedPointer));
+ }
+ }
+
return returnVarId;
}
diff --git a/third_party/SPIRV-Tools/source/opt/inst_bindless_check_pass.cpp b/third_party/SPIRV-Tools/source/opt/inst_bindless_check_pass.cpp
index e8c412f..8e7d4f8 100644
--- a/third_party/SPIRV-Tools/source/opt/inst_bindless_check_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/inst_bindless_check_pass.cpp
@@ -31,532 +31,88 @@
constexpr int kSpvAccessChainBaseIdInIdx = 0;
constexpr int kSpvAccessChainIndex0IdInIdx = 1;
constexpr int kSpvTypeArrayTypeIdInIdx = 0;
-constexpr int kSpvTypeArrayLengthIdInIdx = 1;
-constexpr int kSpvConstantValueInIdx = 0;
constexpr int kSpvVariableStorageClassInIdx = 0;
constexpr int kSpvTypePtrTypeIdInIdx = 1;
constexpr int kSpvTypeImageDim = 1;
constexpr int kSpvTypeImageDepth = 2;
constexpr int kSpvTypeImageArrayed = 3;
constexpr int kSpvTypeImageMS = 4;
-constexpr int kSpvTypeImageSampled = 5;
} // namespace
-void InstBindlessCheckPass::SetupInputBufferIds() {
- if (input_buffer_id_ != 0) {
- return;
+// This is a stub function for use with Import linkage
+// clang-format off
+// GLSL:
+//bool inst_bindless_check_desc(const uint shader_id, const uint inst_num, const uvec4 stage_info, const uint desc_set,
+// const uint binding, const uint desc_index, const uint byte_offset) {
+//}
+// clang-format on
+uint32_t InstBindlessCheckPass::GenDescCheckFunctionId() {
+ enum {
+ kShaderId = 0,
+ kInstructionIndex = 1,
+ kStageInfo = 2,
+ kDescSet = 3,
+ kDescBinding = 4,
+ kDescIndex = 5,
+ kByteOffset = 6,
+ kNumArgs
+ };
+ if (check_desc_func_id_ != 0) {
+ return check_desc_func_id_;
}
- AddStorageBufferExt();
- if (!get_feature_mgr()->HasExtension(kSPV_KHR_physical_storage_buffer)) {
- context()->AddExtension("SPV_KHR_physical_storage_buffer");
- }
- context()->AddCapability(spv::Capability::PhysicalStorageBufferAddresses);
- Instruction* memory_model = get_module()->GetMemoryModel();
- // TODO should this be just Physical64?
- memory_model->SetInOperand(
- 0u, {uint32_t(spv::AddressingModel::PhysicalStorageBuffer64)});
- analysis::DecorationManager* deco_mgr = get_decoration_mgr();
analysis::TypeManager* type_mgr = context()->get_type_mgr();
- constexpr uint32_t width = 32u;
+ const analysis::Integer* uint_type = GetInteger(32, false);
+ const analysis::Vector v4uint(uint_type, 4);
+ const analysis::Type* v4uint_type = type_mgr->GetRegisteredType(&v4uint);
+ std::vector<const analysis::Type*> param_types(kNumArgs, uint_type);
+ param_types[2] = v4uint_type;
- // declare the DescriptorSetData struct
- analysis::Struct* desc_set_struct =
- GetStruct({type_mgr->GetUIntType(), GetUintRuntimeArrayType(width)});
- desc_set_type_id_ = type_mgr->GetTypeInstruction(desc_set_struct);
- // By the Vulkan spec, a pre-existing struct containing a RuntimeArray
- // must be a block, and will therefore be decorated with Block. Therefore
- // the undecorated type returned here will not be pre-existing and can
- // safely be decorated. Since this type is now decorated, it is out of
- // sync with the TypeManager and therefore the TypeManager must be
- // invalidated after this pass.
- assert(context()->get_def_use_mgr()->NumUses(desc_set_type_id_) == 0 &&
- "used struct type returned");
- deco_mgr->AddDecoration(desc_set_type_id_, uint32_t(spv::Decoration::Block));
- deco_mgr->AddMemberDecoration(desc_set_type_id_, 0,
- uint32_t(spv::Decoration::Offset), 0);
- deco_mgr->AddMemberDecoration(desc_set_type_id_, 1,
- uint32_t(spv::Decoration::Offset), 4);
- context()->AddDebug2Inst(
- NewGlobalName(desc_set_type_id_, "DescriptorSetData"));
- context()->AddDebug2Inst(NewMemberName(desc_set_type_id_, 0, "num_bindings"));
- context()->AddDebug2Inst(NewMemberName(desc_set_type_id_, 1, "data"));
+ const uint32_t func_id = TakeNextId();
+ std::unique_ptr<Function> func =
+ StartFunction(func_id, type_mgr->GetBoolType(), param_types);
- // declare buffer address reference to DescriptorSetData
- desc_set_ptr_id_ = type_mgr->FindPointerToType(
- desc_set_type_id_, spv::StorageClass::PhysicalStorageBuffer);
- // runtime array of buffer addresses
- analysis::Type* rarr_ty = GetArray(type_mgr->GetType(desc_set_ptr_id_),
- kDebugInputBindlessMaxDescSets);
- deco_mgr->AddDecorationVal(type_mgr->GetId(rarr_ty),
- uint32_t(spv::Decoration::ArrayStride), 8u);
+ func->SetFunctionEnd(EndFunction());
- // declare the InputBuffer type, a struct wrapper around the runtime array
- analysis::Struct* input_buffer_struct = GetStruct({rarr_ty});
- input_buffer_struct_id_ = type_mgr->GetTypeInstruction(input_buffer_struct);
- deco_mgr->AddDecoration(input_buffer_struct_id_,
- uint32_t(spv::Decoration::Block));
- deco_mgr->AddMemberDecoration(input_buffer_struct_id_, 0,
- uint32_t(spv::Decoration::Offset), 0);
- context()->AddDebug2Inst(
- NewGlobalName(input_buffer_struct_id_, "InputBuffer"));
- context()->AddDebug2Inst(
- NewMemberName(input_buffer_struct_id_, 0, "desc_sets"));
-
- input_buffer_ptr_id_ = type_mgr->FindPointerToType(
- input_buffer_struct_id_, spv::StorageClass::StorageBuffer);
-
- // declare the input_buffer global variable
- input_buffer_id_ = TakeNextId();
-
- const std::vector<Operand> var_operands = {
+ static const std::string func_name{"inst_bindless_check_desc"};
+ context()->AddFunctionDeclaration(std::move(func));
+ context()->AddDebug2Inst(NewName(func_id, func_name));
+ std::vector<Operand> operands{
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {func_id}},
{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
- {uint32_t(spv::StorageClass::StorageBuffer)}},
+ {uint32_t(spv::Decoration::LinkageAttributes)}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_STRING,
+ utils::MakeVector(func_name.c_str())},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_LINKAGE_TYPE,
+ {uint32_t(spv::LinkageType::Import)}},
};
- auto new_var_op = spvtools::MakeUnique<Instruction>(
- context(), spv::Op::OpVariable, input_buffer_ptr_id_, input_buffer_id_,
- var_operands);
+ get_decoration_mgr()->AddDecoration(spv::Op::OpDecorate, operands);
- context()->AddGlobalValue(std::move(new_var_op));
- context()->AddDebug2Inst(NewGlobalName(input_buffer_id_, "input_buffer"));
- deco_mgr->AddDecorationVal(
- input_buffer_id_, uint32_t(spv::Decoration::DescriptorSet), desc_set_);
- deco_mgr->AddDecorationVal(input_buffer_id_,
- uint32_t(spv::Decoration::Binding),
- GetInputBufferBinding());
- if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
- // Add the new buffer to all entry points.
- for (auto& entry : get_module()->entry_points()) {
- entry.AddOperand({SPV_OPERAND_TYPE_ID, {input_buffer_id_}});
- context()->AnalyzeUses(&entry);
- }
- }
-}
-
-// clang-format off
-// GLSL:
-// uint inst_bindless_read_binding_length(uint desc_set_idx, uint binding_idx)
-// {
-// if (desc_set_idx >= inst_bindless_input_buffer.desc_sets.length()) {
-// return 0;
-// }
-//
-// DescriptorSetData set_data = inst_bindless_input_buffer.desc_sets[desc_set_idx];
-// uvec2 ptr_as_vec = uvec2(set_data);
-// if ((ptr_as_vec.x == 0u) && (_ptr_as_vec.y == 0u))
-// {
-// return 0u;
-// }
-// uint num_bindings = set_data.num_bindings;
-// if (binding_idx >= num_bindings) {
-// return 0;
-// }
-// return set_data.data[binding_idx];
-// }
-// clang-format on
-uint32_t InstBindlessCheckPass::GenDebugReadLengthFunctionId() {
- if (read_length_func_id_ != 0) {
- return read_length_func_id_;
- }
- SetupInputBufferIds();
- const analysis::Integer* uint_type = GetInteger(32, false);
- const std::vector<const analysis::Type*> param_types(2, uint_type);
-
- const uint32_t func_id = TakeNextId();
- std::unique_ptr<Function> func =
- StartFunction(func_id, uint_type, param_types);
-
- const std::vector<uint32_t> param_ids = AddParameters(*func, param_types);
-
- // Create block
- auto new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(TakeNextId()));
- InstructionBuilder builder(
- context(), new_blk_ptr.get(),
- IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- Instruction* inst;
-
- inst = builder.AddBinaryOp(
- GetBoolId(), spv::Op::OpUGreaterThanEqual, param_ids[0],
- builder.GetUintConstantId(kDebugInputBindlessMaxDescSets));
- const uint32_t desc_cmp_id = inst->result_id();
-
- uint32_t error_blk_id = TakeNextId();
- uint32_t merge_blk_id = TakeNextId();
- std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
- std::unique_ptr<Instruction> error_label(NewLabel(error_blk_id));
- (void)builder.AddConditionalBranch(desc_cmp_id, error_blk_id, merge_blk_id,
- merge_blk_id);
-
- func->AddBasicBlock(std::move(new_blk_ptr));
-
- // error return
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
- builder.SetInsertPoint(&*new_blk_ptr);
- (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
- builder.GetUintConstantId(0));
- func->AddBasicBlock(std::move(new_blk_ptr));
-
- // check descriptor set table entry is non-null
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
- builder.SetInsertPoint(&*new_blk_ptr);
-
- analysis::TypeManager* type_mgr = context()->get_type_mgr();
- const uint32_t desc_set_ptr_ptr = type_mgr->FindPointerToType(
- desc_set_ptr_id_, spv::StorageClass::StorageBuffer);
-
- inst = builder.AddAccessChain(desc_set_ptr_ptr, input_buffer_id_,
- {builder.GetUintConstantId(0), param_ids[0]});
- const uint32_t set_access_chain_id = inst->result_id();
-
- inst = builder.AddLoad(desc_set_ptr_id_, set_access_chain_id);
- const uint32_t desc_set_ptr_id = inst->result_id();
-
- inst =
- builder.AddUnaryOp(GetVecUintId(2), spv::Op::OpBitcast, desc_set_ptr_id);
- const uint32_t ptr_as_uvec_id = inst->result_id();
-
- inst = builder.AddCompositeExtract(GetUintId(), ptr_as_uvec_id, {0});
- const uint32_t uvec_x = inst->result_id();
-
- inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpIEqual, uvec_x,
- builder.GetUintConstantId(0));
- const uint32_t x_is_zero_id = inst->result_id();
-
- inst = builder.AddCompositeExtract(GetUintId(), ptr_as_uvec_id, {1});
- const uint32_t uvec_y = inst->result_id();
-
- inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpIEqual, uvec_y,
- builder.GetUintConstantId(0));
- const uint32_t y_is_zero_id = inst->result_id();
-
- inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpLogicalAnd, x_is_zero_id,
- y_is_zero_id);
- const uint32_t is_null_id = inst->result_id();
-
- error_blk_id = TakeNextId();
- merge_blk_id = TakeNextId();
- merge_label = NewLabel(merge_blk_id);
- error_label = NewLabel(error_blk_id);
- (void)builder.AddConditionalBranch(is_null_id, error_blk_id, merge_blk_id,
- merge_blk_id);
- func->AddBasicBlock(std::move(new_blk_ptr));
- // error return
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
- builder.SetInsertPoint(&*new_blk_ptr);
- (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
- builder.GetUintConstantId(0));
- func->AddBasicBlock(std::move(new_blk_ptr));
-
- // check binding is in range
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
- builder.SetInsertPoint(&*new_blk_ptr);
-
- const uint32_t uint_ptr = type_mgr->FindPointerToType(
- GetUintId(), spv::StorageClass::PhysicalStorageBuffer);
-
- inst = builder.AddAccessChain(uint_ptr, desc_set_ptr_id,
- {builder.GetUintConstantId(0)});
- const uint32_t binding_access_chain_id = inst->result_id();
-
- inst = builder.AddLoad(GetUintId(), binding_access_chain_id, 8);
- const uint32_t num_bindings_id = inst->result_id();
-
- inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpUGreaterThanEqual,
- param_ids[1], num_bindings_id);
- const uint32_t bindings_cmp_id = inst->result_id();
-
- error_blk_id = TakeNextId();
- merge_blk_id = TakeNextId();
- merge_label = NewLabel(merge_blk_id);
- error_label = NewLabel(error_blk_id);
- (void)builder.AddConditionalBranch(bindings_cmp_id, error_blk_id,
- merge_blk_id, merge_blk_id);
- func->AddBasicBlock(std::move(new_blk_ptr));
- // error return
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
- builder.SetInsertPoint(&*new_blk_ptr);
- (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
- builder.GetUintConstantId(0));
- func->AddBasicBlock(std::move(new_blk_ptr));
-
- // read binding length
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
- builder.SetInsertPoint(&*new_blk_ptr);
-
- inst = builder.AddAccessChain(uint_ptr, desc_set_ptr_id,
- {{builder.GetUintConstantId(1), param_ids[1]}});
- const uint32_t length_ac_id = inst->result_id();
-
- inst = builder.AddLoad(GetUintId(), length_ac_id, sizeof(uint32_t));
- const uint32_t length_id = inst->result_id();
-
- (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue, length_id);
-
- func->AddBasicBlock(std::move(new_blk_ptr));
- func->SetFunctionEnd(EndFunction());
-
- context()->AddFunction(std::move(func));
- context()->AddDebug2Inst(NewGlobalName(func_id, "read_binding_length"));
-
- read_length_func_id_ = func_id;
- // Make sure this function doesn't get processed by
- // InstrumentPass::InstProcessCallTreeFromRoots()
- param2output_func_id_[2] = func_id;
- return read_length_func_id_;
-}
-
-// clang-format off
-// GLSL:
-// result = inst_bindless_read_binding_length(desc_set_id, binding_id);
-// clang-format on
-uint32_t InstBindlessCheckPass::GenDebugReadLength(
- uint32_t var_id, InstructionBuilder* builder) {
- const uint32_t func_id = GenDebugReadLengthFunctionId();
-
- const std::vector<uint32_t> args = {
- builder->GetUintConstantId(var2desc_set_[var_id]),
- builder->GetUintConstantId(var2binding_[var_id]),
- };
- return GenReadFunctionCall(func_id, args, builder);
-}
-
-// clang-format off
-// GLSL:
-// uint inst_bindless_read_desc_init(uint desc_set_idx, uint binding_idx, uint desc_idx)
-// {
-// if (desc_set_idx >= uint(inst_bindless_input_buffer.desc_sets.length()))
-// {
-// return 0u;
-// }
-// DescriptorSetData set_data = inst_bindless_input_buffer.desc_sets[desc_set_idx];
-// uvec2 ptr_as_vec = uvec2(set_data)
-// if ((ptr_as_vec .x == 0u) && (ptr_as_vec.y == 0u))
-// {
-// return 0u;
-// }
-// if (binding_idx >= set_data.num_bindings)
-// {
-// return 0u;
-// }
-// if (desc_idx >= set_data.data[binding_idx])
-// {
-// return 0u;
-// }
-// uint desc_records_start = set_data.data[set_data.num_bindings + binding_idx];
-// return set_data.data[desc_records_start + desc_idx];
-// }
-// clang-format on
-uint32_t InstBindlessCheckPass::GenDebugReadInitFunctionId() {
- if (read_init_func_id_ != 0) {
- return read_init_func_id_;
- }
- SetupInputBufferIds();
- const analysis::Integer* uint_type = GetInteger(32, false);
- const std::vector<const analysis::Type*> param_types(3, uint_type);
-
- const uint32_t func_id = TakeNextId();
- std::unique_ptr<Function> func =
- StartFunction(func_id, uint_type, param_types);
-
- const std::vector<uint32_t> param_ids = AddParameters(*func, param_types);
-
- // Create block
- auto new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(TakeNextId()));
- InstructionBuilder builder(
- context(), new_blk_ptr.get(),
- IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- Instruction* inst;
-
- inst = builder.AddBinaryOp(
- GetBoolId(), spv::Op::OpUGreaterThanEqual, param_ids[0],
- builder.GetUintConstantId(kDebugInputBindlessMaxDescSets));
- const uint32_t desc_cmp_id = inst->result_id();
-
- uint32_t error_blk_id = TakeNextId();
- uint32_t merge_blk_id = TakeNextId();
- std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
- std::unique_ptr<Instruction> error_label(NewLabel(error_blk_id));
- (void)builder.AddConditionalBranch(desc_cmp_id, error_blk_id, merge_blk_id,
- merge_blk_id);
- func->AddBasicBlock(std::move(new_blk_ptr));
-
- // error return
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
- builder.SetInsertPoint(&*new_blk_ptr);
- (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
- builder.GetUintConstantId(0));
- func->AddBasicBlock(std::move(new_blk_ptr));
-
- // check descriptor set table entry is non-null
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
- builder.SetInsertPoint(&*new_blk_ptr);
-
- analysis::TypeManager* type_mgr = context()->get_type_mgr();
- const uint32_t desc_set_ptr_ptr = type_mgr->FindPointerToType(
- desc_set_ptr_id_, spv::StorageClass::StorageBuffer);
-
- inst = builder.AddAccessChain(desc_set_ptr_ptr, input_buffer_id_,
- {builder.GetUintConstantId(0), param_ids[0]});
- const uint32_t set_access_chain_id = inst->result_id();
-
- inst = builder.AddLoad(desc_set_ptr_id_, set_access_chain_id);
- const uint32_t desc_set_ptr_id = inst->result_id();
-
- inst =
- builder.AddUnaryOp(GetVecUintId(2), spv::Op::OpBitcast, desc_set_ptr_id);
- const uint32_t ptr_as_uvec_id = inst->result_id();
-
- inst = builder.AddCompositeExtract(GetUintId(), ptr_as_uvec_id, {0});
- const uint32_t uvec_x = inst->result_id();
-
- inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpIEqual, uvec_x,
- builder.GetUintConstantId(0));
- const uint32_t x_is_zero_id = inst->result_id();
-
- inst = builder.AddCompositeExtract(GetUintId(), ptr_as_uvec_id, {1});
- const uint32_t uvec_y = inst->result_id();
-
- inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpIEqual, uvec_y,
- builder.GetUintConstantId(0));
- const uint32_t y_is_zero_id = inst->result_id();
-
- inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpLogicalAnd, x_is_zero_id,
- y_is_zero_id);
- const uint32_t is_null_id = inst->result_id();
-
- error_blk_id = TakeNextId();
- merge_blk_id = TakeNextId();
- merge_label = NewLabel(merge_blk_id);
- error_label = NewLabel(error_blk_id);
- (void)builder.AddConditionalBranch(is_null_id, error_blk_id, merge_blk_id,
- merge_blk_id);
- func->AddBasicBlock(std::move(new_blk_ptr));
- // error return
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
- builder.SetInsertPoint(&*new_blk_ptr);
- (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
- builder.GetUintConstantId(0));
- func->AddBasicBlock(std::move(new_blk_ptr));
-
- // check binding is in range
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
- builder.SetInsertPoint(&*new_blk_ptr);
-
- const uint32_t uint_ptr = type_mgr->FindPointerToType(
- GetUintId(), spv::StorageClass::PhysicalStorageBuffer);
-
- inst = builder.AddAccessChain(uint_ptr, desc_set_ptr_id,
- {builder.GetUintConstantId(0)});
- const uint32_t binding_access_chain_id = inst->result_id();
-
- inst = builder.AddLoad(GetUintId(), binding_access_chain_id, 8);
- const uint32_t num_bindings_id = inst->result_id();
-
- inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpUGreaterThanEqual,
- param_ids[1], num_bindings_id);
- const uint32_t bindings_cmp_id = inst->result_id();
-
- error_blk_id = TakeNextId();
- merge_blk_id = TakeNextId();
- merge_label = NewLabel(merge_blk_id);
- error_label = NewLabel(error_blk_id);
- (void)builder.AddConditionalBranch(bindings_cmp_id, error_blk_id,
- merge_blk_id, merge_blk_id);
- func->AddBasicBlock(std::move(new_blk_ptr));
- // error return
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
- builder.SetInsertPoint(&*new_blk_ptr);
- (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
- builder.GetUintConstantId(0));
- func->AddBasicBlock(std::move(new_blk_ptr));
-
- // read binding length
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
- builder.SetInsertPoint(&*new_blk_ptr);
-
- inst = builder.AddAccessChain(uint_ptr, desc_set_ptr_id,
- {{builder.GetUintConstantId(1), param_ids[1]}});
- const uint32_t length_ac_id = inst->result_id();
-
- inst = builder.AddLoad(GetUintId(), length_ac_id, sizeof(uint32_t));
- const uint32_t length_id = inst->result_id();
-
- // Check descriptor index in bounds
- inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpUGreaterThanEqual,
- param_ids[2], length_id);
- const uint32_t desc_idx_range_id = inst->result_id();
-
- error_blk_id = TakeNextId();
- merge_blk_id = TakeNextId();
- merge_label = NewLabel(merge_blk_id);
- error_label = NewLabel(error_blk_id);
- (void)builder.AddConditionalBranch(desc_idx_range_id, error_blk_id,
- merge_blk_id, merge_blk_id);
- func->AddBasicBlock(std::move(new_blk_ptr));
- // Error return
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
- builder.SetInsertPoint(&*new_blk_ptr);
- (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
- builder.GetUintConstantId(0));
- func->AddBasicBlock(std::move(new_blk_ptr));
-
- // Read descriptor init status
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
- builder.SetInsertPoint(&*new_blk_ptr);
-
- inst = builder.AddIAdd(GetUintId(), num_bindings_id, param_ids[1]);
- const uint32_t state_offset_id = inst->result_id();
-
- inst =
- builder.AddAccessChain(uint_ptr, desc_set_ptr_id,
- {{builder.GetUintConstantId(1), state_offset_id}});
- const uint32_t state_start_ac_id = inst->result_id();
-
- inst = builder.AddLoad(GetUintId(), state_start_ac_id, sizeof(uint32_t));
- const uint32_t state_start_id = inst->result_id();
-
- inst = builder.AddIAdd(GetUintId(), state_start_id, param_ids[2]);
- const uint32_t state_entry_id = inst->result_id();
-
- // Note: length starts from the beginning of the buffer, not the beginning of
- // the data array
- inst =
- builder.AddAccessChain(uint_ptr, desc_set_ptr_id,
- {{builder.GetUintConstantId(1), state_entry_id}});
- const uint32_t init_ac_id = inst->result_id();
-
- inst = builder.AddLoad(GetUintId(), init_ac_id, sizeof(uint32_t));
- const uint32_t init_status_id = inst->result_id();
-
- (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue, init_status_id);
-
- func->AddBasicBlock(std::move(new_blk_ptr));
- func->SetFunctionEnd(EndFunction());
-
- context()->AddFunction(std::move(func));
- context()->AddDebug2Inst(NewGlobalName(func_id, "read_desc_init"));
-
- read_init_func_id_ = func_id;
+ check_desc_func_id_ = func_id;
// Make sure function doesn't get processed by
// InstrumentPass::InstProcessCallTreeFromRoots()
param2output_func_id_[3] = func_id;
- return read_init_func_id_;
+ return check_desc_func_id_;
}
// clang-format off
// GLSL:
-// result = inst_bindless_read_desc_init(desc_set_id, binding_id, desc_idx_id);
+// result = inst_bindless_check_desc(shader_id, inst_idx, stage_info, desc_set, binding, desc_idx, offset);
//
// clang-format on
-uint32_t InstBindlessCheckPass::GenDebugReadInit(uint32_t var_id,
- uint32_t desc_idx_id,
- InstructionBuilder* builder) {
- const uint32_t func_id = GenDebugReadInitFunctionId();
+uint32_t InstBindlessCheckPass::GenDescCheckCall(
+ uint32_t inst_idx, uint32_t stage_idx, uint32_t var_id,
+ uint32_t desc_idx_id, uint32_t offset_id, InstructionBuilder* builder) {
+ const uint32_t func_id = GenDescCheckFunctionId();
const std::vector<uint32_t> args = {
+ builder->GetUintConstantId(shader_id_),
+ builder->GetUintConstantId(inst_idx),
+ GenStageInfo(stage_idx, builder),
builder->GetUintConstantId(var2desc_set_[var_id]),
builder->GetUintConstantId(var2binding_[var_id]),
- GenUintCastCode(desc_idx_id, builder)};
- return GenReadFunctionCall(func_id, args, builder);
+ GenUintCastCode(desc_idx_id, builder),
+ offset_id};
+ return GenReadFunctionCall(GetBoolId(), func_id, args, builder);
}
uint32_t InstBindlessCheckPass::CloneOriginalImage(
@@ -1017,8 +573,7 @@
}
void InstBindlessCheckPass::GenCheckCode(
- uint32_t check_id, uint32_t error_id, uint32_t offset_id,
- uint32_t length_id, uint32_t stage_idx, RefAnalysis* ref,
+ uint32_t check_id, RefAnalysis* ref,
std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
BasicBlock* back_blk_ptr = &*new_blocks->back();
InstructionBuilder builder(
@@ -1047,30 +602,7 @@
// Gen invalid block
new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
builder.SetInsertPoint(&*new_blk_ptr);
- const uint32_t u_set_id = builder.GetUintConstantId(ref->set);
- const uint32_t u_binding_id = builder.GetUintConstantId(ref->binding);
- const uint32_t u_index_id = GenUintCastCode(ref->desc_idx_id, &builder);
- const uint32_t u_length_id = GenUintCastCode(length_id, &builder);
- if (offset_id != 0) {
- const uint32_t u_offset_id = GenUintCastCode(offset_id, &builder);
- // Buffer OOB
- GenDebugStreamWrite(uid2offset_[ref->ref_inst->unique_id()], stage_idx,
- {error_id, u_set_id, u_binding_id, u_index_id,
- u_offset_id, u_length_id},
- &builder);
- } else if (buffer_bounds_enabled_ || texel_buffer_enabled_) {
- // Uninitialized Descriptor - Return additional unused zero so all error
- // modes will use same debug stream write function
- GenDebugStreamWrite(uid2offset_[ref->ref_inst->unique_id()], stage_idx,
- {error_id, u_set_id, u_binding_id, u_index_id,
- u_length_id, builder.GetUintConstantId(0)},
- &builder);
- } else {
- // Uninitialized Descriptor - Normal error return
- GenDebugStreamWrite(
- uid2offset_[ref->ref_inst->unique_id()], stage_idx,
- {error_id, u_set_id, u_binding_id, u_index_id, u_length_id}, &builder);
- }
+
// Generate a ConstantNull, converting to uint64 if the type cannot be a null.
if (new_ref_id != 0) {
analysis::TypeManager* type_mgr = context()->get_type_mgr();
@@ -1106,77 +638,42 @@
context()->KillInst(ref->ref_inst);
}
-void InstBindlessCheckPass::GenDescIdxCheckCode(
- BasicBlock::iterator ref_inst_itr,
- UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
- std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
- // Look for reference through indexed descriptor. If found, analyze and
- // save components. If not, return.
- RefAnalysis ref;
- if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
- Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref.ptr_id);
- if (ptr_inst->opcode() != spv::Op::OpAccessChain) return;
- // If index and bound both compile-time constants and index < bound,
- // return without changing
- Instruction* var_inst = get_def_use_mgr()->GetDef(ref.var_id);
- Instruction* desc_type_inst = GetPointeeTypeInst(var_inst);
- uint32_t length_id = 0;
- if (desc_type_inst->opcode() == spv::Op::OpTypeArray) {
- length_id =
- desc_type_inst->GetSingleWordInOperand(kSpvTypeArrayLengthIdInIdx);
- Instruction* index_inst = get_def_use_mgr()->GetDef(ref.desc_idx_id);
- Instruction* length_inst = get_def_use_mgr()->GetDef(length_id);
- if (index_inst->opcode() == spv::Op::OpConstant &&
- length_inst->opcode() == spv::Op::OpConstant &&
- index_inst->GetSingleWordInOperand(kSpvConstantValueInIdx) <
- length_inst->GetSingleWordInOperand(kSpvConstantValueInIdx))
- return;
- } else if (!desc_idx_enabled_ ||
- desc_type_inst->opcode() != spv::Op::OpTypeRuntimeArray) {
- return;
- }
- // Move original block's preceding instructions into first new block
- std::unique_ptr<BasicBlock> new_blk_ptr;
- MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
- InstructionBuilder builder(
- context(), &*new_blk_ptr,
- IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- new_blocks->push_back(std::move(new_blk_ptr));
- uint32_t error_id = builder.GetUintConstantId(kInstErrorBindlessBounds);
- // If length id not yet set, descriptor array is runtime size so
- // generate load of length from stage's debug input buffer.
- if (length_id == 0) {
- assert(desc_type_inst->opcode() == spv::Op::OpTypeRuntimeArray &&
- "unexpected bindless type");
- length_id = GenDebugReadLength(ref.var_id, &builder);
- }
- // Generate full runtime bounds test code with true branch
- // being full reference and false branch being debug output and zero
- // for the referenced value.
- uint32_t desc_idx_32b_id = Gen32BitCvtCode(ref.desc_idx_id, &builder);
- uint32_t length_32b_id = Gen32BitCvtCode(length_id, &builder);
- Instruction* ult_inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpULessThan,
- desc_idx_32b_id, length_32b_id);
- ref.desc_idx_id = desc_idx_32b_id;
- GenCheckCode(ult_inst->result_id(), error_id, 0u, length_id, stage_idx, &ref,
- new_blocks);
- // Move original block's remaining code into remainder/merge block and add
- // to new blocks
- BasicBlock* back_blk_ptr = &*new_blocks->back();
- MovePostludeCode(ref_block_itr, back_blk_ptr);
-}
-
-void InstBindlessCheckPass::GenDescInitCheckCode(
+void InstBindlessCheckPass::GenDescCheckCode(
BasicBlock::iterator ref_inst_itr,
UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
// Look for reference through descriptor. If not, return.
RefAnalysis ref;
if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
+ std::unique_ptr<BasicBlock> new_blk_ptr;
+ // Move original block's preceding instructions into first new block
+ MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
+ InstructionBuilder builder(
+ context(), &*new_blk_ptr,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+ new_blocks->push_back(std::move(new_blk_ptr));
// Determine if we can only do initialization check
- bool init_check = false;
- if (ref.desc_load_id != 0 || !buffer_bounds_enabled_) {
- init_check = true;
+ uint32_t ref_id = builder.GetUintConstantId(0u);
+ spv::Op op = ref.ref_inst->opcode();
+ if (ref.desc_load_id != 0) {
+ uint32_t num_in_oprnds = ref.ref_inst->NumInOperands();
+ if ((op == spv::Op::OpImageRead && num_in_oprnds == 2) ||
+ (op == spv::Op::OpImageFetch && num_in_oprnds == 2) ||
+ (op == spv::Op::OpImageWrite && num_in_oprnds == 3)) {
+ Instruction* image_inst = get_def_use_mgr()->GetDef(ref.image_id);
+ uint32_t image_ty_id = image_inst->type_id();
+ Instruction* image_ty_inst = get_def_use_mgr()->GetDef(image_ty_id);
+ if (spv::Dim(image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDim)) ==
+ spv::Dim::Buffer) {
+ if ((image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDepth) == 0) &&
+ (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageArrayed) ==
+ 0) &&
+ (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageMS) == 0)) {
+ ref_id = GenUintCastCode(ref.ref_inst->GetSingleWordInOperand(1),
+ &builder);
+ }
+ }
+ }
} else {
// For now, only do bounds check for non-aggregate types. Otherwise
// just do descriptor initialization check.
@@ -1184,106 +681,24 @@
Instruction* ref_ptr_inst = get_def_use_mgr()->GetDef(ref.ptr_id);
Instruction* pte_type_inst = GetPointeeTypeInst(ref_ptr_inst);
spv::Op pte_type_op = pte_type_inst->opcode();
- if (pte_type_op == spv::Op::OpTypeArray ||
- pte_type_op == spv::Op::OpTypeRuntimeArray ||
- pte_type_op == spv::Op::OpTypeStruct)
- init_check = true;
+ if (pte_type_op != spv::Op::OpTypeArray &&
+ pte_type_op != spv::Op::OpTypeRuntimeArray &&
+ pte_type_op != spv::Op::OpTypeStruct) {
+ ref_id = GenLastByteIdx(&ref, &builder);
+ }
}
- // If initialization check and not enabled, return
- if (init_check && !desc_init_enabled_) return;
- // Move original block's preceding instructions into first new block
- std::unique_ptr<BasicBlock> new_blk_ptr;
- MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
- InstructionBuilder builder(
- context(), &*new_blk_ptr,
- IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- new_blocks->push_back(std::move(new_blk_ptr));
- // If initialization check, use reference value of zero.
- // Else use the index of the last byte referenced.
- uint32_t ref_id = init_check ? builder.GetUintConstantId(0u)
- : GenLastByteIdx(&ref, &builder);
// Read initialization/bounds from debug input buffer. If index id not yet
// set, binding is single descriptor, so set index to constant 0.
if (ref.desc_idx_id == 0) ref.desc_idx_id = builder.GetUintConstantId(0u);
- uint32_t init_id = GenDebugReadInit(ref.var_id, ref.desc_idx_id, &builder);
- // Generate runtime initialization/bounds test code with true branch
- // being full reference and false branch being debug output and zero
- // for the referenced value.
- Instruction* ult_inst =
- builder.AddBinaryOp(GetBoolId(), spv::Op::OpULessThan, ref_id, init_id);
- uint32_t error =
- init_check
- ? kInstErrorBindlessUninit
- : (spv::StorageClass(ref.strg_class) == spv::StorageClass::Uniform
- ? kInstErrorBuffOOBUniform
- : kInstErrorBuffOOBStorage);
- uint32_t error_id = builder.GetUintConstantId(error);
- GenCheckCode(ult_inst->result_id(), error_id, init_check ? 0 : ref_id,
- init_check ? builder.GetUintConstantId(0u) : init_id, stage_idx,
- &ref, new_blocks);
- // Move original block's remaining code into remainder/merge block and add
- // to new blocks
- BasicBlock* back_blk_ptr = &*new_blocks->back();
- MovePostludeCode(ref_block_itr, back_blk_ptr);
-}
+ uint32_t check_id =
+ GenDescCheckCall(ref.ref_inst->unique_id(), stage_idx, ref.var_id,
+ ref.desc_idx_id, ref_id, &builder);
-void InstBindlessCheckPass::GenTexBuffCheckCode(
- BasicBlock::iterator ref_inst_itr,
- UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
- std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
- // Only process OpImageRead and OpImageWrite with no optional operands
- Instruction* ref_inst = &*ref_inst_itr;
- spv::Op op = ref_inst->opcode();
- uint32_t num_in_oprnds = ref_inst->NumInOperands();
- if (!((op == spv::Op::OpImageRead && num_in_oprnds == 2) ||
- (op == spv::Op::OpImageFetch && num_in_oprnds == 2) ||
- (op == spv::Op::OpImageWrite && num_in_oprnds == 3)))
- return;
- // Pull components from descriptor reference
- RefAnalysis ref;
- if (!AnalyzeDescriptorReference(ref_inst, &ref)) return;
- // Only process if image is texel buffer
- Instruction* image_inst = get_def_use_mgr()->GetDef(ref.image_id);
- uint32_t image_ty_id = image_inst->type_id();
- Instruction* image_ty_inst = get_def_use_mgr()->GetDef(image_ty_id);
- if (spv::Dim(image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDim)) !=
- spv::Dim::Buffer) {
- return;
- }
- if (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDepth) != 0) return;
- if (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageArrayed) != 0) return;
- if (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageMS) != 0) return;
- // Enable ImageQuery Capability if not yet enabled
- context()->AddCapability(spv::Capability::ImageQuery);
- // Move original block's preceding instructions into first new block
- std::unique_ptr<BasicBlock> new_blk_ptr;
- MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
- InstructionBuilder builder(
- context(), &*new_blk_ptr,
- IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- new_blocks->push_back(std::move(new_blk_ptr));
- // Get texel coordinate
- uint32_t coord_id =
- GenUintCastCode(ref_inst->GetSingleWordInOperand(1), &builder);
- // If index id not yet set, binding is single descriptor, so set index to
- // constant 0.
- if (ref.desc_idx_id == 0) ref.desc_idx_id = builder.GetUintConstantId(0u);
- // Get texel buffer size.
- Instruction* size_inst =
- builder.AddUnaryOp(GetUintId(), spv::Op::OpImageQuerySize, ref.image_id);
- uint32_t size_id = size_inst->result_id();
// Generate runtime initialization/bounds test code with true branch
- // being full reference and false branch being debug output and zero
+ // being full reference and false branch being zero
// for the referenced value.
- Instruction* ult_inst =
- builder.AddBinaryOp(GetBoolId(), spv::Op::OpULessThan, coord_id, size_id);
- uint32_t error =
- (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageSampled) == 2)
- ? kInstErrorBuffOOBStorageTexel
- : kInstErrorBuffOOBUniformTexel;
- uint32_t error_id = builder.GetUintConstantId(error);
- GenCheckCode(ult_inst->result_id(), error_id, coord_id, size_id, stage_idx,
- &ref, new_blocks);
+ GenCheckCode(check_id, &ref, new_blocks);
+
// Move original block's remaining code into remainder/merge block and add
// to new blocks
BasicBlock* back_blk_ptr = &*new_blocks->back();
@@ -1293,59 +708,48 @@
void InstBindlessCheckPass::InitializeInstBindlessCheck() {
// Initialize base class
InitializeInstrument();
- // If runtime array length support or buffer bounds checking are enabled,
- // create variable mappings. Length support is always enabled if descriptor
- // init check is enabled.
- if (desc_idx_enabled_ || buffer_bounds_enabled_ || texel_buffer_enabled_)
- for (auto& anno : get_module()->annotations())
- if (anno.opcode() == spv::Op::OpDecorate) {
- if (spv::Decoration(anno.GetSingleWordInOperand(1u)) ==
- spv::Decoration::DescriptorSet) {
- var2desc_set_[anno.GetSingleWordInOperand(0u)] =
- anno.GetSingleWordInOperand(2u);
- } else if (spv::Decoration(anno.GetSingleWordInOperand(1u)) ==
- spv::Decoration::Binding) {
- var2binding_[anno.GetSingleWordInOperand(0u)] =
- anno.GetSingleWordInOperand(2u);
- }
+ for (auto& anno : get_module()->annotations()) {
+ if (anno.opcode() == spv::Op::OpDecorate) {
+ if (spv::Decoration(anno.GetSingleWordInOperand(1u)) ==
+ spv::Decoration::DescriptorSet) {
+ var2desc_set_[anno.GetSingleWordInOperand(0u)] =
+ anno.GetSingleWordInOperand(2u);
+ } else if (spv::Decoration(anno.GetSingleWordInOperand(1u)) ==
+ spv::Decoration::Binding) {
+ var2binding_[anno.GetSingleWordInOperand(0u)] =
+ anno.GetSingleWordInOperand(2u);
}
+ }
+ }
}
Pass::Status InstBindlessCheckPass::ProcessImpl() {
- // Perform bindless bounds check on each entry point function in module
+ // The memory model and linkage must always be updated for spirv-link to work
+ // correctly.
+ AddStorageBufferExt();
+ if (!get_feature_mgr()->HasExtension(kSPV_KHR_physical_storage_buffer)) {
+ context()->AddExtension("SPV_KHR_physical_storage_buffer");
+ }
+
+ context()->AddCapability(spv::Capability::PhysicalStorageBufferAddresses);
+ Instruction* memory_model = get_module()->GetMemoryModel();
+ memory_model->SetInOperand(
+ 0u, {uint32_t(spv::AddressingModel::PhysicalStorageBuffer64)});
+
+ context()->AddCapability(spv::Capability::Linkage);
+
InstProcessFunction pfn =
[this](BasicBlock::iterator ref_inst_itr,
UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
- return GenDescIdxCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
- new_blocks);
+ return GenDescCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
+ new_blocks);
};
- bool modified = InstProcessEntryPointCallTree(pfn);
- if (desc_init_enabled_ || buffer_bounds_enabled_) {
- // Perform descriptor initialization and/or buffer bounds check on each
- // entry point function in module
- pfn = [this](BasicBlock::iterator ref_inst_itr,
- UptrVectorIterator<BasicBlock> ref_block_itr,
- uint32_t stage_idx,
- std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
- return GenDescInitCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
- new_blocks);
- };
- modified |= InstProcessEntryPointCallTree(pfn);
- }
- if (texel_buffer_enabled_) {
- // Perform texel buffer bounds check on each entry point function in
- // module. Generate after descriptor bounds and initialization checks.
- pfn = [this](BasicBlock::iterator ref_inst_itr,
- UptrVectorIterator<BasicBlock> ref_block_itr,
- uint32_t stage_idx,
- std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
- return GenTexBuffCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
- new_blocks);
- };
- modified |= InstProcessEntryPointCallTree(pfn);
- }
- return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+
+ InstProcessEntryPointCallTree(pfn);
+ // This pass always changes the memory model, so that linking will work
+ // properly.
+ return Status::SuccessWithChange;
}
Pass::Status InstBindlessCheckPass::Process() {
diff --git a/third_party/SPIRV-Tools/source/opt/inst_bindless_check_pass.h b/third_party/SPIRV-Tools/source/opt/inst_bindless_check_pass.h
index f89af02..243cba7 100644
--- a/third_party/SPIRV-Tools/source/opt/inst_bindless_check_pass.h
+++ b/third_party/SPIRV-Tools/source/opt/inst_bindless_check_pass.h
@@ -28,16 +28,8 @@
// external design may change as the layer evolves.
class InstBindlessCheckPass : public InstrumentPass {
public:
- InstBindlessCheckPass(uint32_t desc_set, uint32_t shader_id,
- bool desc_idx_enable, bool desc_init_enable,
- bool buffer_bounds_enable, bool texel_buffer_enable,
- bool opt_direct_reads)
- : InstrumentPass(desc_set, shader_id, kInstValidationIdBindless,
- opt_direct_reads),
- desc_idx_enabled_(desc_idx_enable),
- desc_init_enabled_(desc_init_enable),
- buffer_bounds_enabled_(buffer_bounds_enable),
- texel_buffer_enabled_(texel_buffer_enable) {}
+ InstBindlessCheckPass(uint32_t shader_id)
+ : InstrumentPass(0, shader_id, true, true) {}
~InstBindlessCheckPass() override = default;
@@ -47,82 +39,16 @@
const char* name() const override { return "inst-bindless-check-pass"; }
private:
- // These functions do bindless checking instrumentation on a single
- // instruction which references through a descriptor (ie references into an
- // image or buffer). Refer to Vulkan API for further information on
- // descriptors. GenDescIdxCheckCode checks that an index into a descriptor
- // array (array of images or buffers) is in-bounds. GenDescInitCheckCode
- // checks that the referenced descriptor has been initialized, if the
- // SPV_EXT_descriptor_indexing extension is enabled, and initialized large
- // enough to handle the reference, if RobustBufferAccess is disabled.
- // GenDescInitCheckCode checks for uniform and storage buffer overrun.
- // GenTexBuffCheckCode checks for texel buffer overrun and should be
- // run after GenDescInitCheckCode to first make sure that the descriptor
- // is initialized because it uses OpImageQuerySize on the descriptor.
- //
- // The functions are designed to be passed to
- // InstrumentPass::InstProcessEntryPointCallTree(), which applies the
- // function to each instruction in a module and replaces the instruction
- // if warranted.
- //
- // If |ref_inst_itr| is a bindless reference, return in |new_blocks| the
- // result of instrumenting it with validation code within its block at
- // |ref_block_itr|. The validation code first executes a check for the
- // specific condition called for. If the check passes, it executes
- // the remainder of the reference, otherwise writes a record to the debug
- // output buffer stream including |function_idx, instruction_idx, stage_idx|
- // and replaces the reference with the null value of the original type. The
- // block at |ref_block_itr| can just be replaced with the blocks in
- // |new_blocks|, which will contain at least two blocks. The last block will
- // comprise all instructions following |ref_inst_itr|,
- // preceded by a phi instruction.
- //
- // These instrumentation functions utilize GenDebugDirectRead() to read data
- // from the debug input buffer, specifically the lengths of variable length
- // descriptor arrays, and the initialization status of each descriptor.
- // The format of the debug input buffer is documented in instrument.hpp.
- //
- // These instrumentation functions utilize GenDebugStreamWrite() to write its
- // error records. The validation-specific part of the error record will
- // have the format:
- //
- // Validation Error Code (=kInstErrorBindlessBounds)
- // Descriptor Index
- // Descriptor Array Size
- //
- // The Descriptor Index is the index which has been determined to be
- // out-of-bounds.
- //
- // The Descriptor Array Size is the size of the descriptor array which was
- // indexed.
- void GenDescIdxCheckCode(
- BasicBlock::iterator ref_inst_itr,
- UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
- std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
+ void GenDescCheckCode(BasicBlock::iterator ref_inst_itr,
+ UptrVectorIterator<BasicBlock> ref_block_itr,
+ uint32_t stage_idx,
+ std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
- void GenDescInitCheckCode(
- BasicBlock::iterator ref_inst_itr,
- UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
- std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
+ uint32_t GenDescCheckFunctionId();
- void GenTexBuffCheckCode(
- BasicBlock::iterator ref_inst_itr,
- UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
- std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
-
- void SetupInputBufferIds();
- uint32_t GenDebugReadLengthFunctionId();
-
- // Generate instructions into |builder| to read length of runtime descriptor
- // array |var_id| from debug input buffer and return id of value.
- uint32_t GenDebugReadLength(uint32_t var_id, InstructionBuilder* builder);
-
- uint32_t GenDebugReadInitFunctionId();
- // Generate instructions into |builder| to read initialization status of
- // descriptor array |image_id| at |index_id| from debug input buffer and
- // return id of value.
- uint32_t GenDebugReadInit(uint32_t image_id, uint32_t index_id,
- InstructionBuilder* builder);
+ uint32_t GenDescCheckCall(uint32_t inst_idx, uint32_t stage_idx,
+ uint32_t var_id, uint32_t index_id,
+ uint32_t byte_offset, InstructionBuilder* builder);
// Analysis data for descriptor reference components, generated by
// AnalyzeDescriptorReference. It is necessary and sufficient for further
@@ -179,8 +105,7 @@
// writes debug error output utilizing |ref|, |error_id|, |length_id| and
// |stage_idx|. Generate merge block for valid and invalid branches. Kill
// original reference.
- void GenCheckCode(uint32_t check_id, uint32_t error_id, uint32_t offset_id,
- uint32_t length_id, uint32_t stage_idx, RefAnalysis* ref,
+ void GenCheckCode(uint32_t check_id, RefAnalysis* ref,
std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
// Initialize state for instrumenting bindless checking
@@ -190,30 +115,13 @@
// GenDescInitCheckCode to every instruction in module.
Pass::Status ProcessImpl();
- // Enable instrumentation of runtime array length checking
- bool desc_idx_enabled_;
-
- // Enable instrumentation of descriptor initialization checking
- bool desc_init_enabled_;
-
- // Enable instrumentation of uniform and storage buffer overrun checking
- bool buffer_bounds_enabled_;
-
- // Enable instrumentation of texel buffer overrun checking
- bool texel_buffer_enabled_;
-
// Mapping from variable to descriptor set
std::unordered_map<uint32_t, uint32_t> var2desc_set_;
// Mapping from variable to binding
std::unordered_map<uint32_t, uint32_t> var2binding_;
- uint32_t read_length_func_id_{0};
- uint32_t read_init_func_id_{0};
- uint32_t desc_set_type_id_{0};
- uint32_t desc_set_ptr_id_{0};
- uint32_t input_buffer_struct_id_{0};
- uint32_t input_buffer_ptr_id_{0};
+ uint32_t check_desc_func_id_{0};
};
} // namespace opt
diff --git a/third_party/SPIRV-Tools/source/opt/inst_buff_addr_check_pass.cpp b/third_party/SPIRV-Tools/source/opt/inst_buff_addr_check_pass.cpp
index 4954706..e6c5508 100644
--- a/third_party/SPIRV-Tools/source/opt/inst_buff_addr_check_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/inst_buff_addr_check_pass.cpp
@@ -19,24 +19,6 @@
namespace spvtools {
namespace opt {
-bool InstBuffAddrCheckPass::InstrumentFunction(Function* func,
- uint32_t stage_idx,
- InstProcessFunction& pfn) {
- // The bindless instrumentation pass adds functions that use
- // BufferDeviceAddress They should not be instrumented by this pass.
- Instruction* func_name_inst =
- context()->GetNames(func->DefInst().result_id()).begin()->second;
- if (func_name_inst) {
- static const std::string kPrefix{"inst_bindless_"};
- std::string func_name = func_name_inst->GetOperand(1).AsString();
- if (func_name.size() >= kPrefix.size() &&
- func_name.compare(0, kPrefix.size(), kPrefix) == 0) {
- return false;
- }
- }
- return InstrumentPass::InstrumentFunction(func, stage_idx, pfn);
-}
-
uint32_t InstBuffAddrCheckPass::CloneOriginalReference(
Instruction* ref_inst, InstructionBuilder* builder) {
// Clone original ref with new result id (if load)
@@ -76,8 +58,7 @@
// TODO(greg-lunarg): Refactor with InstBindlessCheckPass::GenCheckCode() ??
void InstBuffAddrCheckPass::GenCheckCode(
- uint32_t check_id, uint32_t error_id, uint32_t ref_uptr_id,
- uint32_t stage_idx, Instruction* ref_inst,
+ uint32_t check_id, Instruction* ref_inst,
std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
BasicBlock* back_blk_ptr = &*new_blocks->back();
InstructionBuilder builder(
@@ -104,18 +85,6 @@
// Gen invalid block
new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
builder.SetInsertPoint(&*new_blk_ptr);
- // Convert uptr from uint64 to 2 uint32
- Instruction* lo_uptr_inst =
- builder.AddUnaryOp(GetUintId(), spv::Op::OpUConvert, ref_uptr_id);
- Instruction* rshift_uptr_inst =
- builder.AddBinaryOp(GetUint64Id(), spv::Op::OpShiftRightLogical,
- ref_uptr_id, builder.GetUintConstantId(32));
- Instruction* hi_uptr_inst = builder.AddUnaryOp(
- GetUintId(), spv::Op::OpUConvert, rshift_uptr_inst->result_id());
- GenDebugStreamWrite(
- uid2offset_[ref_inst->unique_id()], stage_idx,
- {error_id, lo_uptr_inst->result_id(), hi_uptr_inst->result_id()},
- &builder);
// Gen zero for invalid load. If pointer type, need to convert uint64
// zero to pointer; cannot create ConstantNull of pointer type.
uint32_t null_id = 0;
@@ -150,48 +119,13 @@
context()->KillInst(ref_inst);
}
-uint32_t InstBuffAddrCheckPass::GetTypeAlignment(uint32_t type_id) {
- Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
- switch (type_inst->opcode()) {
- case spv::Op::OpTypeFloat:
- case spv::Op::OpTypeInt:
- case spv::Op::OpTypeVector:
- return GetTypeLength(type_id);
- case spv::Op::OpTypeMatrix:
- return GetTypeAlignment(type_inst->GetSingleWordInOperand(0));
- case spv::Op::OpTypeArray:
- case spv::Op::OpTypeRuntimeArray:
- return GetTypeAlignment(type_inst->GetSingleWordInOperand(0));
- case spv::Op::OpTypeStruct: {
- uint32_t max = 0;
- type_inst->ForEachInId([&max, this](const uint32_t* iid) {
- uint32_t alignment = GetTypeAlignment(*iid);
- max = (alignment > max) ? alignment : max;
- });
- return max;
- }
- case spv::Op::OpTypePointer:
- assert(spv::StorageClass(type_inst->GetSingleWordInOperand(0)) ==
- spv::StorageClass::PhysicalStorageBufferEXT &&
- "unexpected pointer type");
- return 8u;
- default:
- assert(false && "unexpected type");
- return 0;
- }
-}
-
uint32_t InstBuffAddrCheckPass::GetTypeLength(uint32_t type_id) {
Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
switch (type_inst->opcode()) {
case spv::Op::OpTypeFloat:
case spv::Op::OpTypeInt:
return type_inst->GetSingleWordInOperand(0) / 8u;
- case spv::Op::OpTypeVector: {
- uint32_t raw_cnt = type_inst->GetSingleWordInOperand(1);
- uint32_t adj_cnt = (raw_cnt == 3u) ? 4u : raw_cnt;
- return adj_cnt * GetTypeLength(type_inst->GetSingleWordInOperand(0));
- }
+ case spv::Op::OpTypeVector:
case spv::Op::OpTypeMatrix:
return type_inst->GetSingleWordInOperand(1) *
GetTypeLength(type_inst->GetSingleWordInOperand(0));
@@ -207,18 +141,19 @@
return cnt * GetTypeLength(type_inst->GetSingleWordInOperand(0));
}
case spv::Op::OpTypeStruct: {
- uint32_t len = 0;
- type_inst->ForEachInId([&len, this](const uint32_t* iid) {
- // Align struct length
- uint32_t alignment = GetTypeAlignment(*iid);
- uint32_t mod = len % alignment;
- uint32_t diff = (mod != 0) ? alignment - mod : 0;
- len += diff;
- // Increment struct length by component length
- uint32_t comp_len = GetTypeLength(*iid);
- len += comp_len;
+ // Figure out the location of the last byte of the last member of the
+ // structure.
+ uint32_t last_offset = 0, last_len = 0;
+
+ get_decoration_mgr()->ForEachDecoration(
+ type_id, uint32_t(spv::Decoration::Offset),
+ [&last_offset](const Instruction& deco_inst) {
+ last_offset = deco_inst.GetSingleWordInOperand(3);
+ });
+ type_inst->ForEachInId([&last_len, this](const uint32_t* iid) {
+ last_len = GetTypeLength(*iid);
});
- return len;
+ return last_offset + last_len;
}
case spv::Op::OpTypeRuntimeArray:
default:
@@ -238,201 +173,86 @@
(*input_func)->AddParameter(std::move(param_inst));
}
+// This is a stub function for use with Import linkage
+// clang-format off
+// GLSL:
+//bool inst_bindless_search_and_test(const uint shader_id, const uint inst_num, const uvec4 stage_info,
+// const uint64 ref_ptr, const uint length) {
+//}
+// clang-format on
uint32_t InstBuffAddrCheckPass::GetSearchAndTestFuncId() {
- if (search_test_func_id_ == 0) {
- // Generate function "bool search_and_test(uint64_t ref_ptr, uint32_t len)"
- // which searches input buffer for buffer which most likely contains the
- // pointer value |ref_ptr| and verifies that the entire reference of
- // length |len| bytes is contained in the buffer.
- search_test_func_id_ = TakeNextId();
- analysis::TypeManager* type_mgr = context()->get_type_mgr();
- std::vector<const analysis::Type*> param_types = {
- type_mgr->GetType(GetUint64Id()), type_mgr->GetType(GetUintId())};
- analysis::Function func_ty(type_mgr->GetType(GetBoolId()), param_types);
- analysis::Type* reg_func_ty = type_mgr->GetRegisteredType(&func_ty);
- std::unique_ptr<Instruction> func_inst(
- new Instruction(get_module()->context(), spv::Op::OpFunction,
- GetBoolId(), search_test_func_id_,
- {{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
- {uint32_t(spv::FunctionControlMask::MaskNone)}},
- {spv_operand_type_t::SPV_OPERAND_TYPE_ID,
- {type_mgr->GetTypeInstruction(reg_func_ty)}}}));
- get_def_use_mgr()->AnalyzeInstDefUse(&*func_inst);
- std::unique_ptr<Function> input_func =
- MakeUnique<Function>(std::move(func_inst));
- std::vector<uint32_t> param_vec;
- // Add ref_ptr and length parameters
- AddParam(GetUint64Id(), ¶m_vec, &input_func);
- AddParam(GetUintId(), ¶m_vec, &input_func);
- // Empty first block.
- uint32_t first_blk_id = TakeNextId();
- std::unique_ptr<Instruction> first_blk_label(NewLabel(first_blk_id));
- std::unique_ptr<BasicBlock> first_blk_ptr =
- MakeUnique<BasicBlock>(std::move(first_blk_label));
- InstructionBuilder builder(
- context(), &*first_blk_ptr,
- IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- uint32_t hdr_blk_id = TakeNextId();
- // Branch to search loop header
- std::unique_ptr<Instruction> hdr_blk_label(NewLabel(hdr_blk_id));
- (void)builder.AddBranch(hdr_blk_id);
- input_func->AddBasicBlock(std::move(first_blk_ptr));
- // Linear search loop header block
- // TODO(greg-lunarg): Implement binary search
- std::unique_ptr<BasicBlock> hdr_blk_ptr =
- MakeUnique<BasicBlock>(std::move(hdr_blk_label));
- builder.SetInsertPoint(&*hdr_blk_ptr);
- // Phi for search index. Starts with 1.
- uint32_t cont_blk_id = TakeNextId();
- std::unique_ptr<Instruction> cont_blk_label(NewLabel(cont_blk_id));
- // Deal with def-use cycle caused by search loop index computation.
- // Create Add and Phi instructions first, then do Def analysis on Add.
- // Add Phi and Add instructions and do Use analysis later.
- uint32_t idx_phi_id = TakeNextId();
- uint32_t idx_inc_id = TakeNextId();
- std::unique_ptr<Instruction> idx_inc_inst(new Instruction(
- context(), spv::Op::OpIAdd, GetUintId(), idx_inc_id,
- {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {idx_phi_id}},
- {spv_operand_type_t::SPV_OPERAND_TYPE_ID,
- {builder.GetUintConstantId(1u)}}}));
- std::unique_ptr<Instruction> idx_phi_inst(new Instruction(
- context(), spv::Op::OpPhi, GetUintId(), idx_phi_id,
- {{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
- {builder.GetUintConstantId(1u)}},
- {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {first_blk_id}},
- {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {idx_inc_id}},
- {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cont_blk_id}}}));
- get_def_use_mgr()->AnalyzeInstDef(&*idx_inc_inst);
- // Add (previously created) search index phi
- (void)builder.AddInstruction(std::move(idx_phi_inst));
- // LoopMerge
- uint32_t bound_test_blk_id = TakeNextId();
- std::unique_ptr<Instruction> bound_test_blk_label(
- NewLabel(bound_test_blk_id));
- (void)builder.AddLoopMerge(bound_test_blk_id, cont_blk_id,
- uint32_t(spv::LoopControlMask::MaskNone));
- // Branch to continue/work block
- (void)builder.AddBranch(cont_blk_id);
- input_func->AddBasicBlock(std::move(hdr_blk_ptr));
- // Continue/Work Block. Read next buffer pointer and break if greater
- // than ref_ptr arg.
- std::unique_ptr<BasicBlock> cont_blk_ptr =
- MakeUnique<BasicBlock>(std::move(cont_blk_label));
- builder.SetInsertPoint(&*cont_blk_ptr);
- // Add (previously created) search index increment now.
- (void)builder.AddInstruction(std::move(idx_inc_inst));
- // Load next buffer address from debug input buffer
- uint32_t ibuf_id = GetInputBufferId();
- uint32_t ibuf_ptr_id = GetInputBufferPtrId();
- Instruction* uptr_ac_inst = builder.AddTernaryOp(
- ibuf_ptr_id, spv::Op::OpAccessChain, ibuf_id,
- builder.GetUintConstantId(kDebugInputDataOffset), idx_inc_id);
- uint32_t ibuf_type_id = GetInputBufferTypeId();
- Instruction* uptr_load_inst = builder.AddUnaryOp(
- ibuf_type_id, spv::Op::OpLoad, uptr_ac_inst->result_id());
- // If loaded address greater than ref_ptr arg, break, else branch back to
- // loop header
- Instruction* uptr_test_inst =
- builder.AddBinaryOp(GetBoolId(), spv::Op::OpUGreaterThan,
- uptr_load_inst->result_id(), param_vec[0]);
- (void)builder.AddConditionalBranch(
- uptr_test_inst->result_id(), bound_test_blk_id, hdr_blk_id, kInvalidId,
- uint32_t(spv::SelectionControlMask::MaskNone));
- input_func->AddBasicBlock(std::move(cont_blk_ptr));
- // Bounds test block. Read length of selected buffer and test that
- // all len arg bytes are in buffer.
- std::unique_ptr<BasicBlock> bound_test_blk_ptr =
- MakeUnique<BasicBlock>(std::move(bound_test_blk_label));
- builder.SetInsertPoint(&*bound_test_blk_ptr);
- // Decrement index to point to previous/candidate buffer address
- Instruction* cand_idx_inst =
- builder.AddBinaryOp(GetUintId(), spv::Op::OpISub, idx_inc_id,
- builder.GetUintConstantId(1u));
- // Load candidate buffer address
- Instruction* cand_ac_inst =
- builder.AddTernaryOp(ibuf_ptr_id, spv::Op::OpAccessChain, ibuf_id,
- builder.GetUintConstantId(kDebugInputDataOffset),
- cand_idx_inst->result_id());
- Instruction* cand_load_inst = builder.AddUnaryOp(
- ibuf_type_id, spv::Op::OpLoad, cand_ac_inst->result_id());
- // Compute offset of ref_ptr from candidate buffer address
- Instruction* offset_inst =
- builder.AddBinaryOp(ibuf_type_id, spv::Op::OpISub, param_vec[0],
- cand_load_inst->result_id());
- // Convert ref length to uint64
- Instruction* ref_len_64_inst =
- builder.AddUnaryOp(ibuf_type_id, spv::Op::OpUConvert, param_vec[1]);
- // Add ref length to ref offset to compute end of reference
- Instruction* ref_end_inst = builder.AddBinaryOp(
- ibuf_type_id, spv::Op::OpIAdd, offset_inst->result_id(),
- ref_len_64_inst->result_id());
- // Load starting index of lengths in input buffer and convert to uint32
- Instruction* len_start_ac_inst =
- builder.AddTernaryOp(ibuf_ptr_id, spv::Op::OpAccessChain, ibuf_id,
- builder.GetUintConstantId(kDebugInputDataOffset),
- builder.GetUintConstantId(0u));
- Instruction* len_start_load_inst = builder.AddUnaryOp(
- ibuf_type_id, spv::Op::OpLoad, len_start_ac_inst->result_id());
- Instruction* len_start_32_inst = builder.AddUnaryOp(
- GetUintId(), spv::Op::OpUConvert, len_start_load_inst->result_id());
- // Decrement search index to get candidate buffer length index
- Instruction* cand_len_idx_inst = builder.AddBinaryOp(
- GetUintId(), spv::Op::OpISub, cand_idx_inst->result_id(),
- builder.GetUintConstantId(1u));
- // Add candidate length index to start index
- Instruction* len_idx_inst = builder.AddBinaryOp(
- GetUintId(), spv::Op::OpIAdd, cand_len_idx_inst->result_id(),
- len_start_32_inst->result_id());
- // Load candidate buffer length
- Instruction* len_ac_inst =
- builder.AddTernaryOp(ibuf_ptr_id, spv::Op::OpAccessChain, ibuf_id,
- builder.GetUintConstantId(kDebugInputDataOffset),
- len_idx_inst->result_id());
- Instruction* len_load_inst = builder.AddUnaryOp(
- ibuf_type_id, spv::Op::OpLoad, len_ac_inst->result_id());
- // Test if reference end within candidate buffer length
- Instruction* len_test_inst = builder.AddBinaryOp(
- GetBoolId(), spv::Op::OpULessThanEqual, ref_end_inst->result_id(),
- len_load_inst->result_id());
- // Return test result
- (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
- len_test_inst->result_id());
- // Close block
- input_func->AddBasicBlock(std::move(bound_test_blk_ptr));
- // Close function and add function to module
- std::unique_ptr<Instruction> func_end_inst(new Instruction(
- get_module()->context(), spv::Op::OpFunctionEnd, 0, 0, {}));
- get_def_use_mgr()->AnalyzeInstDefUse(&*func_end_inst);
- input_func->SetFunctionEnd(std::move(func_end_inst));
- context()->AddFunction(std::move(input_func));
- context()->AddDebug2Inst(
- NewGlobalName(search_test_func_id_, "search_and_test"));
+ enum {
+ kShaderId = 0,
+ kInstructionIndex = 1,
+ kStageInfo = 2,
+ kRefPtr = 3,
+ kLength = 4,
+ kNumArgs
+ };
+ if (search_test_func_id_ != 0) {
+ return search_test_func_id_;
}
+ // Generate function "bool search_and_test(uint64_t ref_ptr, uint32_t len)"
+ // which searches input buffer for buffer which most likely contains the
+ // pointer value |ref_ptr| and verifies that the entire reference of
+ // length |len| bytes is contained in the buffer.
+ analysis::TypeManager* type_mgr = context()->get_type_mgr();
+ const analysis::Integer* uint_type = GetInteger(32, false);
+ const analysis::Vector v4uint(uint_type, 4);
+ const analysis::Type* v4uint_type = type_mgr->GetRegisteredType(&v4uint);
+
+ std::vector<const analysis::Type*> param_types = {
+ uint_type, uint_type, v4uint_type, type_mgr->GetType(GetUint64Id()),
+ uint_type};
+
+ const std::string func_name{"inst_buff_addr_search_and_test"};
+ const uint32_t func_id = TakeNextId();
+ std::unique_ptr<Function> func =
+ StartFunction(func_id, type_mgr->GetBoolType(), param_types);
+ func->SetFunctionEnd(EndFunction());
+ context()->AddFunctionDeclaration(std::move(func));
+ context()->AddDebug2Inst(NewName(func_id, func_name));
+
+ std::vector<Operand> operands{
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {func_id}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
+ {uint32_t(spv::Decoration::LinkageAttributes)}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_STRING,
+ utils::MakeVector(func_name.c_str())},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_LINKAGE_TYPE,
+ {uint32_t(spv::LinkageType::Import)}},
+ };
+ get_decoration_mgr()->AddDecoration(spv::Op::OpDecorate, operands);
+
+ search_test_func_id_ = func_id;
return search_test_func_id_;
}
uint32_t InstBuffAddrCheckPass::GenSearchAndTest(Instruction* ref_inst,
InstructionBuilder* builder,
- uint32_t* ref_uptr_id) {
+ uint32_t* ref_uptr_id,
+ uint32_t stage_idx) {
// Enable Int64 if necessary
- context()->AddCapability(spv::Capability::Int64);
// Convert reference pointer to uint64
- uint32_t ref_ptr_id = ref_inst->GetSingleWordInOperand(0);
+ const uint32_t ref_ptr_id = ref_inst->GetSingleWordInOperand(0);
Instruction* ref_uptr_inst =
builder->AddUnaryOp(GetUint64Id(), spv::Op::OpConvertPtrToU, ref_ptr_id);
*ref_uptr_id = ref_uptr_inst->result_id();
// Compute reference length in bytes
analysis::DefUseManager* du_mgr = get_def_use_mgr();
Instruction* ref_ptr_inst = du_mgr->GetDef(ref_ptr_id);
- uint32_t ref_ptr_ty_id = ref_ptr_inst->type_id();
+ const uint32_t ref_ptr_ty_id = ref_ptr_inst->type_id();
Instruction* ref_ptr_ty_inst = du_mgr->GetDef(ref_ptr_ty_id);
- uint32_t ref_len = GetTypeLength(ref_ptr_ty_inst->GetSingleWordInOperand(1));
- uint32_t ref_len_id = builder->GetUintConstantId(ref_len);
+ const uint32_t ref_len =
+ GetTypeLength(ref_ptr_ty_inst->GetSingleWordInOperand(1));
// Gen call to search and test function
- Instruction* call_inst = builder->AddFunctionCall(
- GetBoolId(), GetSearchAndTestFuncId(), {*ref_uptr_id, ref_len_id});
- uint32_t retval = call_inst->result_id();
- return retval;
+ const uint32_t func_id = GetSearchAndTestFuncId();
+ const std::vector<uint32_t> args = {
+ builder->GetUintConstantId(shader_id_),
+ builder->GetUintConstantId(ref_inst->unique_id()),
+ GenStageInfo(stage_idx, builder), *ref_uptr_id,
+ builder->GetUintConstantId(ref_len)};
+ return GenReadFunctionCall(GetBoolId(), func_id, args, builder);
}
void InstBuffAddrCheckPass::GenBuffAddrCheckCode(
@@ -450,16 +270,16 @@
context(), &*new_blk_ptr,
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
new_blocks->push_back(std::move(new_blk_ptr));
- uint32_t error_id = builder.GetUintConstantId(kInstErrorBuffAddrUnallocRef);
// Generate code to do search and test if all bytes of reference
// are within a listed buffer. Return reference pointer converted to uint64.
uint32_t ref_uptr_id;
- uint32_t valid_id = GenSearchAndTest(ref_inst, &builder, &ref_uptr_id);
+ uint32_t valid_id =
+ GenSearchAndTest(ref_inst, &builder, &ref_uptr_id, stage_idx);
// Generate test of search results with true branch
// being full reference and false branch being debug output and zero
// for the referenced value.
- GenCheckCode(valid_id, error_id, ref_uptr_id, stage_idx, ref_inst,
- new_blocks);
+ GenCheckCode(valid_id, ref_inst, new_blocks);
+
// Move original block's remaining code into remainder/merge block and add
// to new blocks
BasicBlock* back_blk_ptr = &*new_blocks->back();
@@ -474,6 +294,20 @@
}
Pass::Status InstBuffAddrCheckPass::ProcessImpl() {
+ // The memory model and linkage must always be updated for spirv-link to work
+ // correctly.
+ AddStorageBufferExt();
+ if (!get_feature_mgr()->HasExtension(kSPV_KHR_physical_storage_buffer)) {
+ context()->AddExtension("SPV_KHR_physical_storage_buffer");
+ }
+
+ context()->AddCapability(spv::Capability::PhysicalStorageBufferAddresses);
+ Instruction* memory_model = get_module()->GetMemoryModel();
+ memory_model->SetInOperand(
+ 0u, {uint32_t(spv::AddressingModel::PhysicalStorageBuffer64)});
+
+ context()->AddCapability(spv::Capability::Int64);
+ context()->AddCapability(spv::Capability::Linkage);
// Perform bindless bounds check on each entry point function in module
InstProcessFunction pfn =
[this](BasicBlock::iterator ref_inst_itr,
@@ -482,14 +316,13 @@
return GenBuffAddrCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
new_blocks);
};
- bool modified = InstProcessEntryPointCallTree(pfn);
- return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+ InstProcessEntryPointCallTree(pfn);
+ // This pass always changes the memory model, so that linking will work
+ // properly.
+ return Status::SuccessWithChange;
}
Pass::Status InstBuffAddrCheckPass::Process() {
- if (!get_feature_mgr()->HasCapability(
- spv::Capability::PhysicalStorageBufferAddressesEXT))
- return Status::SuccessWithoutChange;
InitInstBuffAddrCheck();
return ProcessImpl();
}
diff --git a/third_party/SPIRV-Tools/source/opt/inst_buff_addr_check_pass.h b/third_party/SPIRV-Tools/source/opt/inst_buff_addr_check_pass.h
index 2ec212b..f07f98a 100644
--- a/third_party/SPIRV-Tools/source/opt/inst_buff_addr_check_pass.h
+++ b/third_party/SPIRV-Tools/source/opt/inst_buff_addr_check_pass.h
@@ -29,10 +29,10 @@
class InstBuffAddrCheckPass : public InstrumentPass {
public:
// For test harness only
- InstBuffAddrCheckPass() : InstrumentPass(7, 23, kInstValidationIdBuffAddr) {}
+ InstBuffAddrCheckPass() : InstrumentPass(0, 23, false, true) {}
// For all other interfaces
- InstBuffAddrCheckPass(uint32_t desc_set, uint32_t shader_id)
- : InstrumentPass(desc_set, shader_id, kInstValidationIdBuffAddr) {}
+ InstBuffAddrCheckPass(uint32_t shader_id)
+ : InstrumentPass(0, shader_id, false, true) {}
~InstBuffAddrCheckPass() override = default;
@@ -41,14 +41,7 @@
const char* name() const override { return "inst-buff-addr-check-pass"; }
- bool InstrumentFunction(Function* func, uint32_t stage_idx,
- InstProcessFunction& pfn) override;
-
private:
- // Return byte alignment of type |type_id|. Must be int, float, vector,
- // matrix, struct, array or physical pointer. Uses std430 alignment.
- uint32_t GetTypeAlignment(uint32_t type_id);
-
// Return byte length of type |type_id|. Must be int, float, vector, matrix,
// struct, array or physical pointer. Uses std430 alignment and sizes.
uint32_t GetTypeLength(uint32_t type_id);
@@ -65,7 +58,7 @@
// are within the buffer. Returns id of boolean value which is true if
// search and test is successful, false otherwise.
uint32_t GenSearchAndTest(Instruction* ref_inst, InstructionBuilder* builder,
- uint32_t* ref_uptr_id);
+ uint32_t* ref_uptr_id, uint32_t stage_idx);
// This function does checking instrumentation on a single
// instruction which references through a physical storage buffer address.
@@ -118,8 +111,7 @@
// writes debug error output utilizing |ref_inst|, |error_id| and
// |stage_idx|. Generate merge block for valid and invalid reference blocks.
// Kill original reference.
- void GenCheckCode(uint32_t check_id, uint32_t error_id, uint32_t length_id,
- uint32_t stage_idx, Instruction* ref_inst,
+ void GenCheckCode(uint32_t check_id, Instruction* ref_inst,
std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
// Initialize state for instrumenting physical buffer address checking
diff --git a/third_party/SPIRV-Tools/source/opt/inst_debug_printf_pass.cpp b/third_party/SPIRV-Tools/source/opt/inst_debug_printf_pass.cpp
index 4f97277..abd25e9 100644
--- a/third_party/SPIRV-Tools/source/opt/inst_debug_printf_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/inst_debug_printf_pass.cpp
@@ -16,6 +16,7 @@
#include "inst_debug_printf_pass.h"
+#include "source/spirv_constant.h"
#include "source/util/string_utils.h"
#include "spirv/unified1/NonSemanticDebugPrintf.h"
@@ -137,7 +138,7 @@
}
void InstDebugPrintfPass::GenOutputCode(
- Instruction* printf_inst, uint32_t stage_idx,
+ Instruction* printf_inst,
std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
BasicBlock* back_blk_ptr = &*new_blocks->back();
InstructionBuilder builder(
@@ -165,14 +166,16 @@
GenOutputValues(opnd_inst, &val_ids, &builder);
}
});
- GenDebugStreamWrite(uid2offset_[printf_inst->unique_id()], stage_idx, val_ids,
- &builder);
+ GenDebugStreamWrite(
+ builder.GetUintConstantId(shader_id_),
+ builder.GetUintConstantId(uid2offset_[printf_inst->unique_id()]), val_ids,
+ &builder);
context()->KillInst(printf_inst);
}
void InstDebugPrintfPass::GenDebugPrintfCode(
BasicBlock::iterator ref_inst_itr,
- UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
+ UptrVectorIterator<BasicBlock> ref_block_itr,
std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
// If not DebugPrintf OpExtInst, return.
Instruction* printf_inst = &*ref_inst_itr;
@@ -188,7 +191,7 @@
MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
new_blocks->push_back(std::move(new_blk_ptr));
// Generate instructions to output printf args to printf buffer
- GenOutputCode(printf_inst, stage_idx, new_blocks);
+ GenOutputCode(printf_inst, new_blocks);
// Caller expects at least two blocks with last block containing remaining
// code, so end block after instrumentation, create remainder block, and
// branch to it
@@ -208,19 +211,243 @@
new_blocks->push_back(std::move(new_blk_ptr));
}
+// Return id for output buffer
+uint32_t InstDebugPrintfPass::GetOutputBufferId() {
+ if (output_buffer_id_ == 0) {
+ // If not created yet, create one
+ analysis::DecorationManager* deco_mgr = get_decoration_mgr();
+ analysis::TypeManager* type_mgr = context()->get_type_mgr();
+ analysis::RuntimeArray* reg_uint_rarr_ty = GetUintRuntimeArrayType(32);
+ analysis::Integer* reg_uint_ty = GetInteger(32, false);
+ analysis::Type* reg_buf_ty =
+ GetStruct({reg_uint_ty, reg_uint_ty, reg_uint_rarr_ty});
+ uint32_t obufTyId = type_mgr->GetTypeInstruction(reg_buf_ty);
+ // By the Vulkan spec, a pre-existing struct containing a RuntimeArray
+ // must be a block, and will therefore be decorated with Block. Therefore
+ // the undecorated type returned here will not be pre-existing and can
+ // safely be decorated. Since this type is now decorated, it is out of
+ // sync with the TypeManager and therefore the TypeManager must be
+ // invalidated after this pass.
+ assert(context()->get_def_use_mgr()->NumUses(obufTyId) == 0 &&
+ "used struct type returned");
+ deco_mgr->AddDecoration(obufTyId, uint32_t(spv::Decoration::Block));
+ deco_mgr->AddMemberDecoration(obufTyId, kDebugOutputFlagsOffset,
+ uint32_t(spv::Decoration::Offset), 0);
+ deco_mgr->AddMemberDecoration(obufTyId, kDebugOutputSizeOffset,
+ uint32_t(spv::Decoration::Offset), 4);
+ deco_mgr->AddMemberDecoration(obufTyId, kDebugOutputDataOffset,
+ uint32_t(spv::Decoration::Offset), 8);
+ uint32_t obufTyPtrId_ =
+ type_mgr->FindPointerToType(obufTyId, spv::StorageClass::StorageBuffer);
+ output_buffer_id_ = TakeNextId();
+ std::unique_ptr<Instruction> newVarOp(new Instruction(
+ context(), spv::Op::OpVariable, obufTyPtrId_, output_buffer_id_,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
+ {uint32_t(spv::StorageClass::StorageBuffer)}}}));
+ context()->AddGlobalValue(std::move(newVarOp));
+ context()->AddDebug2Inst(NewGlobalName(obufTyId, "OutputBuffer"));
+ context()->AddDebug2Inst(NewMemberName(obufTyId, 0, "flags"));
+ context()->AddDebug2Inst(NewMemberName(obufTyId, 1, "written_count"));
+ context()->AddDebug2Inst(NewMemberName(obufTyId, 2, "data"));
+ context()->AddDebug2Inst(NewGlobalName(output_buffer_id_, "output_buffer"));
+ deco_mgr->AddDecorationVal(
+ output_buffer_id_, uint32_t(spv::Decoration::DescriptorSet), desc_set_);
+ deco_mgr->AddDecorationVal(output_buffer_id_,
+ uint32_t(spv::Decoration::Binding),
+ GetOutputBufferBinding());
+ AddStorageBufferExt();
+ if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
+ // Add the new buffer to all entry points.
+ for (auto& entry : get_module()->entry_points()) {
+ entry.AddOperand({SPV_OPERAND_TYPE_ID, {output_buffer_id_}});
+ context()->AnalyzeUses(&entry);
+ }
+ }
+ }
+ return output_buffer_id_;
+}
+
+uint32_t InstDebugPrintfPass::GetOutputBufferPtrId() {
+ if (output_buffer_ptr_id_ == 0) {
+ output_buffer_ptr_id_ = context()->get_type_mgr()->FindPointerToType(
+ GetUintId(), spv::StorageClass::StorageBuffer);
+ }
+ return output_buffer_ptr_id_;
+}
+
+uint32_t InstDebugPrintfPass::GetOutputBufferBinding() {
+ return kDebugOutputPrintfStream;
+}
+
+void InstDebugPrintfPass::GenDebugOutputFieldCode(uint32_t base_offset_id,
+ uint32_t field_offset,
+ uint32_t field_value_id,
+ InstructionBuilder* builder) {
+ // Cast value to 32-bit unsigned if necessary
+ uint32_t val_id = GenUintCastCode(field_value_id, builder);
+ // Store value
+ Instruction* data_idx_inst = builder->AddIAdd(
+ GetUintId(), base_offset_id, builder->GetUintConstantId(field_offset));
+ uint32_t buf_id = GetOutputBufferId();
+ uint32_t buf_uint_ptr_id = GetOutputBufferPtrId();
+ Instruction* achain_inst = builder->AddAccessChain(
+ buf_uint_ptr_id, buf_id,
+ {builder->GetUintConstantId(kDebugOutputDataOffset),
+ data_idx_inst->result_id()});
+ (void)builder->AddStore(achain_inst->result_id(), val_id);
+}
+
+uint32_t InstDebugPrintfPass::GetStreamWriteFunctionId(uint32_t param_cnt) {
+ enum {
+ kShaderId = 0,
+ kInstructionIndex = 1,
+ kFirstParam = 2,
+ };
+ // Total param count is common params plus validation-specific
+ // params
+ if (param2output_func_id_[param_cnt] == 0) {
+ // Create function
+ param2output_func_id_[param_cnt] = TakeNextId();
+ analysis::TypeManager* type_mgr = context()->get_type_mgr();
+
+ const analysis::Type* uint_type = GetInteger(32, false);
+
+ std::vector<const analysis::Type*> param_types(kFirstParam + param_cnt,
+ uint_type);
+ std::unique_ptr<Function> output_func = StartFunction(
+ param2output_func_id_[param_cnt], type_mgr->GetVoidType(), param_types);
+
+ std::vector<uint32_t> param_ids = AddParameters(*output_func, param_types);
+
+ // Create first block
+ auto new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(TakeNextId()));
+
+ InstructionBuilder builder(
+ context(), &*new_blk_ptr,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+ // Gen test if debug output buffer size will not be exceeded.
+ const uint32_t first_param_offset = kInstCommonOutInstructionIdx + 1;
+ const uint32_t obuf_record_sz = first_param_offset + param_cnt;
+ const uint32_t buf_id = GetOutputBufferId();
+ const uint32_t buf_uint_ptr_id = GetOutputBufferPtrId();
+ Instruction* obuf_curr_sz_ac_inst = builder.AddAccessChain(
+ buf_uint_ptr_id, buf_id,
+ {builder.GetUintConstantId(kDebugOutputSizeOffset)});
+ // Fetch the current debug buffer written size atomically, adding the
+ // size of the record to be written.
+ uint32_t obuf_record_sz_id = builder.GetUintConstantId(obuf_record_sz);
+ uint32_t mask_none_id =
+ builder.GetUintConstantId(uint32_t(spv::MemoryAccessMask::MaskNone));
+ uint32_t scope_invok_id =
+ builder.GetUintConstantId(uint32_t(spv::Scope::Invocation));
+ Instruction* obuf_curr_sz_inst = builder.AddQuadOp(
+ GetUintId(), spv::Op::OpAtomicIAdd, obuf_curr_sz_ac_inst->result_id(),
+ scope_invok_id, mask_none_id, obuf_record_sz_id);
+ uint32_t obuf_curr_sz_id = obuf_curr_sz_inst->result_id();
+ // Compute new written size
+ Instruction* obuf_new_sz_inst =
+ builder.AddIAdd(GetUintId(), obuf_curr_sz_id,
+ builder.GetUintConstantId(obuf_record_sz));
+ // Fetch the data bound
+ Instruction* obuf_bnd_inst =
+ builder.AddIdLiteralOp(GetUintId(), spv::Op::OpArrayLength,
+ GetOutputBufferId(), kDebugOutputDataOffset);
+ // Test that new written size is less than or equal to debug output
+ // data bound
+ Instruction* obuf_safe_inst = builder.AddBinaryOp(
+ GetBoolId(), spv::Op::OpULessThanEqual, obuf_new_sz_inst->result_id(),
+ obuf_bnd_inst->result_id());
+ uint32_t merge_blk_id = TakeNextId();
+ uint32_t write_blk_id = TakeNextId();
+ std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
+ std::unique_ptr<Instruction> write_label(NewLabel(write_blk_id));
+ (void)builder.AddConditionalBranch(
+ obuf_safe_inst->result_id(), write_blk_id, merge_blk_id, merge_blk_id,
+ uint32_t(spv::SelectionControlMask::MaskNone));
+ // Close safety test block and gen write block
+ output_func->AddBasicBlock(std::move(new_blk_ptr));
+ new_blk_ptr = MakeUnique<BasicBlock>(std::move(write_label));
+ builder.SetInsertPoint(&*new_blk_ptr);
+ // Generate common and stage-specific debug record members
+ GenDebugOutputFieldCode(obuf_curr_sz_id, kInstCommonOutSize,
+ builder.GetUintConstantId(obuf_record_sz),
+ &builder);
+ // Store Shader Id
+ GenDebugOutputFieldCode(obuf_curr_sz_id, kInstCommonOutShaderId,
+ param_ids[kShaderId], &builder);
+ // Store Instruction Idx
+ GenDebugOutputFieldCode(obuf_curr_sz_id, kInstCommonOutInstructionIdx,
+ param_ids[kInstructionIndex], &builder);
+ // Gen writes of validation specific data
+ for (uint32_t i = 0; i < param_cnt; ++i) {
+ GenDebugOutputFieldCode(obuf_curr_sz_id, first_param_offset + i,
+ param_ids[kFirstParam + i], &builder);
+ }
+ // Close write block and gen merge block
+ (void)builder.AddBranch(merge_blk_id);
+ output_func->AddBasicBlock(std::move(new_blk_ptr));
+ new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
+ builder.SetInsertPoint(&*new_blk_ptr);
+ // Close merge block and function and add function to module
+ (void)builder.AddNullaryOp(0, spv::Op::OpReturn);
+
+ output_func->AddBasicBlock(std::move(new_blk_ptr));
+ output_func->SetFunctionEnd(EndFunction());
+ context()->AddFunction(std::move(output_func));
+
+ std::string name("stream_write_");
+ name += std::to_string(param_cnt);
+
+ context()->AddDebug2Inst(
+ NewGlobalName(param2output_func_id_[param_cnt], name));
+ }
+ return param2output_func_id_[param_cnt];
+}
+
+void InstDebugPrintfPass::GenDebugStreamWrite(
+ uint32_t shader_id, uint32_t instruction_idx_id,
+ const std::vector<uint32_t>& validation_ids, InstructionBuilder* builder) {
+ // Call debug output function. Pass func_idx, instruction_idx and
+ // validation ids as args.
+ uint32_t val_id_cnt = static_cast<uint32_t>(validation_ids.size());
+ std::vector<uint32_t> args = {shader_id, instruction_idx_id};
+ (void)args.insert(args.end(), validation_ids.begin(), validation_ids.end());
+ (void)builder->AddFunctionCall(GetVoidId(),
+ GetStreamWriteFunctionId(val_id_cnt), args);
+}
+
+std::unique_ptr<Instruction> InstDebugPrintfPass::NewGlobalName(
+ uint32_t id, const std::string& name_str) {
+ std::string prefixed_name{"inst_printf_"};
+ prefixed_name += name_str;
+ return NewName(id, prefixed_name);
+}
+
+std::unique_ptr<Instruction> InstDebugPrintfPass::NewMemberName(
+ uint32_t id, uint32_t member_index, const std::string& name_str) {
+ return MakeUnique<Instruction>(
+ context(), spv::Op::OpMemberName, 0, 0,
+ std::initializer_list<Operand>{
+ {SPV_OPERAND_TYPE_ID, {id}},
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {member_index}},
+ {SPV_OPERAND_TYPE_LITERAL_STRING, utils::MakeVector(name_str)}});
+}
+
void InstDebugPrintfPass::InitializeInstDebugPrintf() {
// Initialize base class
InitializeInstrument();
+ output_buffer_id_ = 0;
+ output_buffer_ptr_id_ = 0;
}
Pass::Status InstDebugPrintfPass::ProcessImpl() {
// Perform printf instrumentation on each entry point function in module
InstProcessFunction pfn =
[this](BasicBlock::iterator ref_inst_itr,
- UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
+ UptrVectorIterator<BasicBlock> ref_block_itr,
+ [[maybe_unused]] uint32_t stage_idx,
std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
- return GenDebugPrintfCode(ref_inst_itr, ref_block_itr, stage_idx,
- new_blocks);
+ return GenDebugPrintfCode(ref_inst_itr, ref_block_itr, new_blocks);
};
(void)InstProcessEntryPointCallTree(pfn);
// Remove DebugPrintf OpExtInstImport instruction
@@ -239,15 +466,7 @@
}
}
if (!non_sem_set_seen) {
- for (auto c_itr = context()->module()->extension_begin();
- c_itr != context()->module()->extension_end(); ++c_itr) {
- const std::string ext_name = c_itr->GetInOperand(0).AsString();
- if (ext_name == "SPV_KHR_non_semantic_info") {
- context()->KillInst(&*c_itr);
- break;
- }
- }
- context()->get_feature_mgr()->RemoveExtension(kSPV_KHR_non_semantic_info);
+ context()->RemoveExtension(kSPV_KHR_non_semantic_info);
}
return Status::SuccessWithChange;
}
diff --git a/third_party/SPIRV-Tools/source/opt/inst_debug_printf_pass.h b/third_party/SPIRV-Tools/source/opt/inst_debug_printf_pass.h
index 70b0a72..5688d38 100644
--- a/third_party/SPIRV-Tools/source/opt/inst_debug_printf_pass.h
+++ b/third_party/SPIRV-Tools/source/opt/inst_debug_printf_pass.h
@@ -28,10 +28,10 @@
class InstDebugPrintfPass : public InstrumentPass {
public:
// For test harness only
- InstDebugPrintfPass() : InstrumentPass(7, 23, kInstValidationIdDebugPrintf) {}
+ InstDebugPrintfPass() : InstrumentPass(7, 23, false, false) {}
// For all other interfaces
InstDebugPrintfPass(uint32_t desc_set, uint32_t shader_id)
- : InstrumentPass(desc_set, shader_id, kInstValidationIdDebugPrintf) {}
+ : InstrumentPass(desc_set, shader_id, false, false) {}
~InstDebugPrintfPass() override = default;
@@ -41,12 +41,92 @@
const char* name() const override { return "inst-printf-pass"; }
private:
+ // Gen code into |builder| to write |field_value_id| into debug output
+ // buffer at |base_offset_id| + |field_offset|.
+ void GenDebugOutputFieldCode(uint32_t base_offset_id, uint32_t field_offset,
+ uint32_t field_value_id,
+ InstructionBuilder* builder);
+
+ // Generate instructions in |builder| which will atomically fetch and
+ // increment the size of the debug output buffer stream of the current
+ // validation and write a record to the end of the stream, if enough space
+ // in the buffer remains. The record will contain the index of the function
+ // and instruction within that function |func_idx, instruction_idx| which
+ // generated the record. Finally, the record will contain validation-specific
+ // data contained in |validation_ids| which will identify the validation
+ // error as well as the values involved in the error.
+ //
+ // The output buffer binding written to by the code generated by the function
+ // is determined by the validation id specified when each specific
+ // instrumentation pass is created.
+ //
+ // The output buffer is a sequence of 32-bit values with the following
+ // format (where all elements are unsigned 32-bit unless otherwise noted):
+ //
+ // Size
+ // Record0
+ // Record1
+ // Record2
+ // ...
+ //
+ // Size is the number of 32-bit values that have been written or
+ // attempted to be written to the output buffer, excluding the Size. It is
+ // initialized to 0. If the size of attempts to write the buffer exceeds
+ // the actual size of the buffer, it is possible that this field can exceed
+ // the actual size of the buffer.
+ //
+ // Each Record* is a variable-length sequence of 32-bit values with the
+ // following format defined using static const offsets in the .cpp file:
+ //
+ // Record Size
+ // Shader ID
+ // Instruction Index
+ // ...
+ // Validation Error Code
+ // Validation-specific Word 0
+ // Validation-specific Word 1
+ // Validation-specific Word 2
+ // ...
+ //
+ // Each record consists of two subsections: members common across all
+ // validation and members specific to a
+ // validation.
+ //
+ // The Record Size is the number of 32-bit words in the record, including
+ // the Record Size word.
+ //
+ // Shader ID is a value that identifies which shader has generated the
+ // validation error. It is passed when the instrumentation pass is created.
+ //
+ // The Instruction Index is the position of the instruction within the
+ // SPIR-V file which is in error.
+ //
+ // The Validation Error Code specifies the exact error which has occurred.
+ // These are enumerated with the kInstError* static consts. This allows
+ // multiple validation layers to use the same, single output buffer.
+ //
+ // The Validation-specific Words are a validation-specific number of 32-bit
+ // words which give further information on the validation error that
+ // occurred. These are documented further in each file containing the
+ // validation-specific class which derives from this base class.
+ //
+ // Because the code that is generated checks against the size of the buffer
+ // before writing, the size of the debug out buffer can be used by the
+ // validation layer to control the number of error records that are written.
+ void GenDebugStreamWrite(uint32_t shader_id, uint32_t instruction_idx_id,
+ const std::vector<uint32_t>& validation_ids,
+ InstructionBuilder* builder);
+
+ // Return id for output function. Define if it doesn't exist with
+ // |val_spec_param_cnt| validation-specific uint32 parameters.
+ uint32_t GetStreamWriteFunctionId(uint32_t val_spec_param_cnt);
+
// Generate instructions for OpDebugPrintf.
//
// If |ref_inst_itr| is an OpDebugPrintf, return in |new_blocks| the result
// of replacing it with buffer write instructions within its block at
// |ref_block_itr|. The instructions write a record to the printf
- // output buffer stream including |function_idx, instruction_idx, stage_idx|
+ // output buffer stream including |function_idx, instruction_idx|
// and removes the OpDebugPrintf. The block at |ref_block_itr| can just be
// replaced with the block in |new_blocks|. Besides the buffer writes, this
// block will comprise all instructions preceding and following
@@ -64,7 +144,6 @@
// DebugPrintf.
void GenDebugPrintfCode(BasicBlock::iterator ref_inst_itr,
UptrVectorIterator<BasicBlock> ref_block_itr,
- uint32_t stage_idx,
std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
// Generate a sequence of uint32 instructions in |builder| (if necessary)
@@ -77,16 +156,40 @@
// Generate instructions to write a record containing the operands of
// |printf_inst| arguments to printf buffer, adding new code to the end of
// the last block in |new_blocks|. Kill OpDebugPrintf instruction.
- void GenOutputCode(Instruction* printf_inst, uint32_t stage_idx,
+ void GenOutputCode(Instruction* printf_inst,
std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
+ // Set the name for a function or global variable, names will be
+ // prefixed to identify which instrumentation pass generated them.
+ std::unique_ptr<Instruction> NewGlobalName(uint32_t id,
+ const std::string& name_str);
+
+ // Set the name for a structure member
+ std::unique_ptr<Instruction> NewMemberName(uint32_t id, uint32_t member_index,
+ const std::string& name_str);
+
+ // Return id for debug output buffer
+ uint32_t GetOutputBufferId();
+
+ // Return id for buffer uint type
+ uint32_t GetOutputBufferPtrId();
+
+ // Return binding for output buffer for current validation.
+ uint32_t GetOutputBufferBinding();
+
// Initialize state for instrumenting bindless checking
void InitializeInstDebugPrintf();
// Apply GenDebugPrintfCode to every instruction in module.
Pass::Status ProcessImpl();
- uint32_t ext_inst_printf_id_;
+ uint32_t ext_inst_printf_id_{0};
+
+ // id for output buffer variable
+ uint32_t output_buffer_id_{0};
+
+ // ptr type id for output buffer element
+ uint32_t output_buffer_ptr_id_{0};
};
} // namespace opt
diff --git a/third_party/SPIRV-Tools/source/opt/instruction.cpp b/third_party/SPIRV-Tools/source/opt/instruction.cpp
index ece6baf..aa4ae26 100644
--- a/third_party/SPIRV-Tools/source/opt/instruction.cpp
+++ b/third_party/SPIRV-Tools/source/opt/instruction.cpp
@@ -751,7 +751,7 @@
}
bool Instruction::IsFoldable() const {
- return IsFoldableByFoldScalar() ||
+ return IsFoldableByFoldScalar() || IsFoldableByFoldVector() ||
context()->get_instruction_folder().HasConstFoldingRule(this);
}
@@ -762,7 +762,7 @@
}
Instruction* type = context()->get_def_use_mgr()->GetDef(type_id());
- if (!folder.IsFoldableType(type)) {
+ if (!folder.IsFoldableScalarType(type)) {
return false;
}
@@ -773,7 +773,29 @@
Instruction* def_inst = context()->get_def_use_mgr()->GetDef(*op_id);
Instruction* def_inst_type =
context()->get_def_use_mgr()->GetDef(def_inst->type_id());
- return folder.IsFoldableType(def_inst_type);
+ return folder.IsFoldableScalarType(def_inst_type);
+ });
+}
+
+bool Instruction::IsFoldableByFoldVector() const {
+ const InstructionFolder& folder = context()->get_instruction_folder();
+ if (!folder.IsFoldableOpcode(opcode())) {
+ return false;
+ }
+
+ Instruction* type = context()->get_def_use_mgr()->GetDef(type_id());
+ if (!folder.IsFoldableVectorType(type)) {
+ return false;
+ }
+
+ // Even if the type of the instruction is foldable, its operands may not be
+ // foldable (e.g., comparisons of 64bit types). Check that all operand types
+ // are foldable before accepting the instruction.
+ return WhileEachInOperand([&folder, this](const uint32_t* op_id) {
+ Instruction* def_inst = context()->get_def_use_mgr()->GetDef(*op_id);
+ Instruction* def_inst_type =
+ context()->get_def_use_mgr()->GetDef(def_inst->type_id());
+ return folder.IsFoldableVectorType(def_inst_type);
});
}
diff --git a/third_party/SPIRV-Tools/source/opt/instruction.h b/third_party/SPIRV-Tools/source/opt/instruction.h
index d50e625..c2617fb 100644
--- a/third_party/SPIRV-Tools/source/opt/instruction.h
+++ b/third_party/SPIRV-Tools/source/opt/instruction.h
@@ -294,6 +294,8 @@
// It is the responsibility of the caller to make sure
// that the instruction remains valid.
inline void AddOperand(Operand&& operand);
+ // Adds a copy of |operand| to the list of operands of this instruction.
+ inline void AddOperand(const Operand& operand);
// Gets the |index|-th logical operand as a single SPIR-V word. This method is
// not expected to be used with logical operands consisting of multiple SPIR-V
// words.
@@ -522,6 +524,10 @@
// constant value by |FoldScalar|.
bool IsFoldableByFoldScalar() const;
+ // Returns true if |this| is an instruction which could be folded into a
+ // constant value by |FoldVector|.
+ bool IsFoldableByFoldVector() const;
+
// Returns true if we are allowed to fold or otherwise manipulate the
// instruction that defines |id| in the given context. This includes not
// handling NaN values.
@@ -676,6 +682,10 @@
operands_.push_back(std::move(operand));
}
+inline void Instruction::AddOperand(const Operand& operand) {
+ operands_.push_back(operand);
+}
+
inline void Instruction::SetInOperand(uint32_t index,
Operand::OperandData&& data) {
SetOperand(index + TypeResultIdCount(), std::move(data));
diff --git a/third_party/SPIRV-Tools/source/opt/instrument_pass.cpp b/third_party/SPIRV-Tools/source/opt/instrument_pass.cpp
index 9233ffd..b6845a5 100644
--- a/third_party/SPIRV-Tools/source/opt/instrument_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/instrument_pass.cpp
@@ -22,9 +22,6 @@
namespace spvtools {
namespace opt {
namespace {
-// Common Parameter Positions
-constexpr int kInstCommonParamInstIdx = 0;
-constexpr int kInstCommonParamCnt = 1;
// Indices of operands in SPIR-V instructions
constexpr int kEntryPointFunctionIdInIdx = 1;
} // namespace
@@ -134,38 +131,6 @@
{SPV_OPERAND_TYPE_LITERAL_STRING, utils::MakeVector(name_str)}});
}
-std::unique_ptr<Instruction> InstrumentPass::NewGlobalName(
- uint32_t id, const std::string& name_str) {
- std::string prefixed_name;
- switch (validation_id_) {
- case kInstValidationIdBindless:
- prefixed_name = "inst_bindless_";
- break;
- case kInstValidationIdBuffAddr:
- prefixed_name = "inst_buff_addr_";
- break;
- case kInstValidationIdDebugPrintf:
- prefixed_name = "inst_printf_";
- break;
- default:
- assert(false); // add new instrumentation pass here
- prefixed_name = "inst_pass_";
- break;
- }
- prefixed_name += name_str;
- return NewName(id, prefixed_name);
-}
-
-std::unique_ptr<Instruction> InstrumentPass::NewMemberName(
- uint32_t id, uint32_t member_index, const std::string& name_str) {
- return MakeUnique<Instruction>(
- context(), spv::Op::OpMemberName, 0, 0,
- std::initializer_list<Operand>{
- {SPV_OPERAND_TYPE_ID, {id}},
- {SPV_OPERAND_TYPE_LITERAL_INTEGER, {member_index}},
- {SPV_OPERAND_TYPE_LITERAL_STRING, utils::MakeVector(name_str)}});
-}
-
uint32_t InstrumentPass::Gen32BitCvtCode(uint32_t val_id,
InstructionBuilder* builder) {
// Convert integer value to 32-bit if necessary
@@ -198,52 +163,6 @@
->result_id();
}
-void InstrumentPass::GenDebugOutputFieldCode(uint32_t base_offset_id,
- uint32_t field_offset,
- uint32_t field_value_id,
- InstructionBuilder* builder) {
- // Cast value to 32-bit unsigned if necessary
- uint32_t val_id = GenUintCastCode(field_value_id, builder);
- // Store value
- Instruction* data_idx_inst = builder->AddIAdd(
- GetUintId(), base_offset_id, builder->GetUintConstantId(field_offset));
- uint32_t buf_id = GetOutputBufferId();
- uint32_t buf_uint_ptr_id = GetOutputBufferPtrId();
- Instruction* achain_inst = builder->AddAccessChain(
- buf_uint_ptr_id, buf_id,
- {builder->GetUintConstantId(kDebugOutputDataOffset),
- data_idx_inst->result_id()});
- (void)builder->AddStore(achain_inst->result_id(), val_id);
-}
-
-void InstrumentPass::GenCommonStreamWriteCode(uint32_t record_sz,
- uint32_t inst_id,
- uint32_t stage_idx,
- uint32_t base_offset_id,
- InstructionBuilder* builder) {
- // Store record size
- GenDebugOutputFieldCode(base_offset_id, kInstCommonOutSize,
- builder->GetUintConstantId(record_sz), builder);
- // Store Shader Id
- GenDebugOutputFieldCode(base_offset_id, kInstCommonOutShaderId,
- builder->GetUintConstantId(shader_id_), builder);
- // Store Instruction Idx
- GenDebugOutputFieldCode(base_offset_id, kInstCommonOutInstructionIdx, inst_id,
- builder);
- // Store Stage Idx
- GenDebugOutputFieldCode(base_offset_id, kInstCommonOutStageIdx,
- builder->GetUintConstantId(stage_idx), builder);
-}
-
-void InstrumentPass::GenFragCoordEltDebugOutputCode(
- uint32_t base_offset_id, uint32_t uint_frag_coord_id, uint32_t element,
- InstructionBuilder* builder) {
- Instruction* element_val_inst =
- builder->AddCompositeExtract(GetUintId(), uint_frag_coord_id, {element});
- GenDebugOutputFieldCode(base_offset_id, kInstFragOutFragCoordX + element,
- element_val_inst->result_id(), builder);
-}
-
uint32_t InstrumentPass::GenVarLoad(uint32_t var_id,
InstructionBuilder* builder) {
Instruction* var_inst = get_def_use_mgr()->GetDef(var_id);
@@ -252,28 +171,24 @@
return load_inst->result_id();
}
-void InstrumentPass::GenBuiltinOutputCode(uint32_t builtin_id,
- uint32_t builtin_off,
- uint32_t base_offset_id,
- InstructionBuilder* builder) {
- // Load and store builtin
- uint32_t load_id = GenVarLoad(builtin_id, builder);
- GenDebugOutputFieldCode(base_offset_id, builtin_off, load_id, builder);
-}
-
-void InstrumentPass::GenStageStreamWriteCode(uint32_t stage_idx,
- uint32_t base_offset_id,
- InstructionBuilder* builder) {
+uint32_t InstrumentPass::GenStageInfo(uint32_t stage_idx,
+ InstructionBuilder* builder) {
+ std::vector<uint32_t> ids(4, builder->GetUintConstantId(0));
+ ids[0] = builder->GetUintConstantId(stage_idx);
+ // %289 = OpCompositeConstruct %v4uint %uint_0 %285 %288 %uint_0
// TODO(greg-lunarg): Add support for all stages
switch (spv::ExecutionModel(stage_idx)) {
case spv::ExecutionModel::Vertex: {
// Load and store VertexId and InstanceId
- GenBuiltinOutputCode(
+ uint32_t load_id = GenVarLoad(
context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::VertexIndex)),
- kInstVertOutVertexIndex, base_offset_id, builder);
- GenBuiltinOutputCode(context()->GetBuiltinInputVarId(
+ builder);
+ ids[1] = GenUintCastCode(load_id, builder);
+
+ load_id = GenVarLoad(context()->GetBuiltinInputVarId(
uint32_t(spv::BuiltIn::InstanceIndex)),
- kInstVertOutInstanceIndex, base_offset_id, builder);
+ builder);
+ ids[2] = GenUintCastCode(load_id, builder);
} break;
case spv::ExecutionModel::GLCompute:
case spv::ExecutionModel::TaskNV:
@@ -284,56 +199,50 @@
uint32_t load_id = GenVarLoad(context()->GetBuiltinInputVarId(uint32_t(
spv::BuiltIn::GlobalInvocationId)),
builder);
- Instruction* x_inst =
- builder->AddCompositeExtract(GetUintId(), load_id, {0});
- Instruction* y_inst =
- builder->AddCompositeExtract(GetUintId(), load_id, {1});
- Instruction* z_inst =
- builder->AddCompositeExtract(GetUintId(), load_id, {2});
- GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationIdX,
- x_inst->result_id(), builder);
- GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationIdY,
- y_inst->result_id(), builder);
- GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationIdZ,
- z_inst->result_id(), builder);
+ for (uint32_t u = 0; u < 3u; ++u) {
+ ids[u + 1] = builder->AddCompositeExtract(GetUintId(), load_id, {u})
+ ->result_id();
+ }
} break;
case spv::ExecutionModel::Geometry: {
// Load and store PrimitiveId and InvocationId.
- GenBuiltinOutputCode(
+ uint32_t load_id = GenVarLoad(
context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::PrimitiveId)),
- kInstGeomOutPrimitiveId, base_offset_id, builder);
- GenBuiltinOutputCode(
+ builder);
+ ids[1] = load_id;
+ load_id = GenVarLoad(
context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::InvocationId)),
- kInstGeomOutInvocationId, base_offset_id, builder);
+ builder);
+ ids[2] = GenUintCastCode(load_id, builder);
} break;
case spv::ExecutionModel::TessellationControl: {
// Load and store InvocationId and PrimitiveId
- GenBuiltinOutputCode(
+ uint32_t load_id = GenVarLoad(
context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::InvocationId)),
- kInstTessCtlOutInvocationId, base_offset_id, builder);
- GenBuiltinOutputCode(
+ builder);
+ ids[1] = GenUintCastCode(load_id, builder);
+ load_id = GenVarLoad(
context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::PrimitiveId)),
- kInstTessCtlOutPrimitiveId, base_offset_id, builder);
+ builder);
+ ids[2] = load_id;
} break;
case spv::ExecutionModel::TessellationEvaluation: {
// Load and store PrimitiveId and TessCoord.uv
- GenBuiltinOutputCode(
- context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::PrimitiveId)),
- kInstTessEvalOutPrimitiveId, base_offset_id, builder);
uint32_t load_id = GenVarLoad(
+ context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::PrimitiveId)),
+ builder);
+ ids[1] = load_id;
+ load_id = GenVarLoad(
context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::TessCoord)),
builder);
Instruction* uvec3_cast_inst =
builder->AddUnaryOp(GetVec3UintId(), spv::Op::OpBitcast, load_id);
uint32_t uvec3_cast_id = uvec3_cast_inst->result_id();
- Instruction* u_inst =
- builder->AddCompositeExtract(GetUintId(), uvec3_cast_id, {0});
- Instruction* v_inst =
- builder->AddCompositeExtract(GetUintId(), uvec3_cast_id, {1});
- GenDebugOutputFieldCode(base_offset_id, kInstTessEvalOutTessCoordU,
- u_inst->result_id(), builder);
- GenDebugOutputFieldCode(base_offset_id, kInstTessEvalOutTessCoordV,
- v_inst->result_id(), builder);
+ for (uint32_t u = 0; u < 2u; ++u) {
+ ids[u + 2] =
+ builder->AddCompositeExtract(GetUintId(), uvec3_cast_id, {u})
+ ->result_id();
+ }
} break;
case spv::ExecutionModel::Fragment: {
// Load FragCoord and convert to Uint
@@ -342,9 +251,13 @@
context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::FragCoord)));
Instruction* uint_frag_coord_inst = builder->AddUnaryOp(
GetVec4UintId(), spv::Op::OpBitcast, frag_coord_inst->result_id());
- for (uint32_t u = 0; u < 2u; ++u)
- GenFragCoordEltDebugOutputCode(
- base_offset_id, uint_frag_coord_inst->result_id(), u, builder);
+ for (uint32_t u = 0; u < 2u; ++u) {
+ ids[u + 1] =
+ builder
+ ->AddCompositeExtract(GetUintId(),
+ uint_frag_coord_inst->result_id(), {u})
+ ->result_id();
+ }
} break;
case spv::ExecutionModel::RayGenerationNV:
case spv::ExecutionModel::IntersectionNV:
@@ -356,33 +269,14 @@
uint32_t launch_id = GenVarLoad(
context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::LaunchIdNV)),
builder);
- Instruction* x_launch_inst =
- builder->AddCompositeExtract(GetUintId(), launch_id, {0});
- Instruction* y_launch_inst =
- builder->AddCompositeExtract(GetUintId(), launch_id, {1});
- Instruction* z_launch_inst =
- builder->AddCompositeExtract(GetUintId(), launch_id, {2});
- GenDebugOutputFieldCode(base_offset_id, kInstRayTracingOutLaunchIdX,
- x_launch_inst->result_id(), builder);
- GenDebugOutputFieldCode(base_offset_id, kInstRayTracingOutLaunchIdY,
- y_launch_inst->result_id(), builder);
- GenDebugOutputFieldCode(base_offset_id, kInstRayTracingOutLaunchIdZ,
- z_launch_inst->result_id(), builder);
+ for (uint32_t u = 0; u < 3u; ++u) {
+ ids[u + 1] = builder->AddCompositeExtract(GetUintId(), launch_id, {u})
+ ->result_id();
+ }
} break;
default: { assert(false && "unsupported stage"); } break;
}
-}
-
-void InstrumentPass::GenDebugStreamWrite(
- uint32_t instruction_idx, uint32_t stage_idx,
- const std::vector<uint32_t>& validation_ids, InstructionBuilder* builder) {
- // Call debug output function. Pass func_idx, instruction_idx and
- // validation ids as args.
- uint32_t val_id_cnt = static_cast<uint32_t>(validation_ids.size());
- std::vector<uint32_t> args = {builder->GetUintConstantId(instruction_idx)};
- (void)args.insert(args.end(), validation_ids.begin(), validation_ids.end());
- (void)builder->AddFunctionCall(
- GetVoidId(), GetStreamWriteFunctionId(stage_idx, val_id_cnt), args);
+ return builder->AddCompositeConstruct(GetVec4UintId(), ids)->result_id();
}
bool InstrumentPass::AllConstant(const std::vector<uint32_t>& ids) {
@@ -393,16 +287,9 @@
return true;
}
-uint32_t InstrumentPass::GenDebugDirectRead(
- const std::vector<uint32_t>& offset_ids, InstructionBuilder* builder) {
- // Call debug input function. Pass func_idx and offset ids as args.
- const uint32_t off_id_cnt = static_cast<uint32_t>(offset_ids.size());
- const uint32_t input_func_id = GetDirectReadFunctionId(off_id_cnt);
- return GenReadFunctionCall(input_func_id, offset_ids, builder);
-}
-
uint32_t InstrumentPass::GenReadFunctionCall(
- uint32_t func_id, const std::vector<uint32_t>& func_call_args,
+ uint32_t return_id, uint32_t func_id,
+ const std::vector<uint32_t>& func_call_args,
InstructionBuilder* ref_builder) {
// If optimizing direct reads and the call has already been generated,
// use its result
@@ -423,8 +310,7 @@
builder.SetInsertPoint(insert_before);
}
uint32_t res_id =
- builder.AddFunctionCall(GetUintId(), func_id, func_call_args)
- ->result_id();
+ builder.AddFunctionCall(return_id, func_id, func_call_args)->result_id();
if (insert_in_first_block) call2id_[func_call_args] = res_id;
return res_id;
}
@@ -494,53 +380,6 @@
});
}
-uint32_t InstrumentPass::GetOutputBufferPtrId() {
- if (output_buffer_ptr_id_ == 0) {
- output_buffer_ptr_id_ = context()->get_type_mgr()->FindPointerToType(
- GetUintId(), spv::StorageClass::StorageBuffer);
- }
- return output_buffer_ptr_id_;
-}
-
-uint32_t InstrumentPass::GetInputBufferTypeId() {
- return (validation_id_ == kInstValidationIdBuffAddr) ? GetUint64Id()
- : GetUintId();
-}
-
-uint32_t InstrumentPass::GetInputBufferPtrId() {
- if (input_buffer_ptr_id_ == 0) {
- input_buffer_ptr_id_ = context()->get_type_mgr()->FindPointerToType(
- GetInputBufferTypeId(), spv::StorageClass::StorageBuffer);
- }
- return input_buffer_ptr_id_;
-}
-
-uint32_t InstrumentPass::GetOutputBufferBinding() {
- switch (validation_id_) {
- case kInstValidationIdBindless:
- return kDebugOutputBindingStream;
- case kInstValidationIdBuffAddr:
- return kDebugOutputBindingStream;
- case kInstValidationIdDebugPrintf:
- return kDebugOutputPrintfStream;
- default:
- assert(false && "unexpected validation id");
- }
- return 0;
-}
-
-uint32_t InstrumentPass::GetInputBufferBinding() {
- switch (validation_id_) {
- case kInstValidationIdBindless:
- return kDebugInputBindingBindless;
- case kInstValidationIdBuffAddr:
- return kDebugInputBindingBuffAddr;
- default:
- assert(false && "unexpected validation id");
- }
- return 0;
-}
-
analysis::Integer* InstrumentPass::GetInteger(uint32_t width, bool is_signed) {
analysis::Integer i(width, is_signed);
analysis::Type* type = context()->get_type_mgr()->GetRegisteredType(&i);
@@ -621,110 +460,6 @@
storage_buffer_ext_defined_ = true;
}
-// Return id for output buffer
-uint32_t InstrumentPass::GetOutputBufferId() {
- if (output_buffer_id_ == 0) {
- // If not created yet, create one
- analysis::DecorationManager* deco_mgr = get_decoration_mgr();
- analysis::TypeManager* type_mgr = context()->get_type_mgr();
- analysis::RuntimeArray* reg_uint_rarr_ty = GetUintRuntimeArrayType(32);
- analysis::Integer* reg_uint_ty = GetInteger(32, false);
- analysis::Type* reg_buf_ty =
- GetStruct({reg_uint_ty, reg_uint_ty, reg_uint_rarr_ty});
- uint32_t obufTyId = type_mgr->GetTypeInstruction(reg_buf_ty);
- // By the Vulkan spec, a pre-existing struct containing a RuntimeArray
- // must be a block, and will therefore be decorated with Block. Therefore
- // the undecorated type returned here will not be pre-existing and can
- // safely be decorated. Since this type is now decorated, it is out of
- // sync with the TypeManager and therefore the TypeManager must be
- // invalidated after this pass.
- assert(context()->get_def_use_mgr()->NumUses(obufTyId) == 0 &&
- "used struct type returned");
- deco_mgr->AddDecoration(obufTyId, uint32_t(spv::Decoration::Block));
- deco_mgr->AddMemberDecoration(obufTyId, kDebugOutputFlagsOffset,
- uint32_t(spv::Decoration::Offset), 0);
- deco_mgr->AddMemberDecoration(obufTyId, kDebugOutputSizeOffset,
- uint32_t(spv::Decoration::Offset), 4);
- deco_mgr->AddMemberDecoration(obufTyId, kDebugOutputDataOffset,
- uint32_t(spv::Decoration::Offset), 8);
- uint32_t obufTyPtrId_ =
- type_mgr->FindPointerToType(obufTyId, spv::StorageClass::StorageBuffer);
- output_buffer_id_ = TakeNextId();
- std::unique_ptr<Instruction> newVarOp(new Instruction(
- context(), spv::Op::OpVariable, obufTyPtrId_, output_buffer_id_,
- {{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
- {uint32_t(spv::StorageClass::StorageBuffer)}}}));
- context()->AddGlobalValue(std::move(newVarOp));
- context()->AddDebug2Inst(NewGlobalName(obufTyId, "OutputBuffer"));
- context()->AddDebug2Inst(NewMemberName(obufTyId, 0, "flags"));
- context()->AddDebug2Inst(NewMemberName(obufTyId, 1, "written_count"));
- context()->AddDebug2Inst(NewMemberName(obufTyId, 2, "data"));
- context()->AddDebug2Inst(NewGlobalName(output_buffer_id_, "output_buffer"));
- deco_mgr->AddDecorationVal(
- output_buffer_id_, uint32_t(spv::Decoration::DescriptorSet), desc_set_);
- deco_mgr->AddDecorationVal(output_buffer_id_,
- uint32_t(spv::Decoration::Binding),
- GetOutputBufferBinding());
- AddStorageBufferExt();
- if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
- // Add the new buffer to all entry points.
- for (auto& entry : get_module()->entry_points()) {
- entry.AddOperand({SPV_OPERAND_TYPE_ID, {output_buffer_id_}});
- context()->AnalyzeUses(&entry);
- }
- }
- }
- return output_buffer_id_;
-}
-
-uint32_t InstrumentPass::GetInputBufferId() {
- if (input_buffer_id_ == 0) {
- // If not created yet, create one
- analysis::DecorationManager* deco_mgr = get_decoration_mgr();
- analysis::TypeManager* type_mgr = context()->get_type_mgr();
- uint32_t width = (validation_id_ == kInstValidationIdBuffAddr) ? 64u : 32u;
- analysis::Type* reg_uint_rarr_ty = GetUintRuntimeArrayType(width);
- analysis::Struct* reg_buf_ty = GetStruct({reg_uint_rarr_ty});
- uint32_t ibufTyId = type_mgr->GetTypeInstruction(reg_buf_ty);
- // By the Vulkan spec, a pre-existing struct containing a RuntimeArray
- // must be a block, and will therefore be decorated with Block. Therefore
- // the undecorated type returned here will not be pre-existing and can
- // safely be decorated. Since this type is now decorated, it is out of
- // sync with the TypeManager and therefore the TypeManager must be
- // invalidated after this pass.
- assert(context()->get_def_use_mgr()->NumUses(ibufTyId) == 0 &&
- "used struct type returned");
- deco_mgr->AddDecoration(ibufTyId, uint32_t(spv::Decoration::Block));
- deco_mgr->AddMemberDecoration(ibufTyId, 0,
- uint32_t(spv::Decoration::Offset), 0);
- uint32_t ibufTyPtrId_ =
- type_mgr->FindPointerToType(ibufTyId, spv::StorageClass::StorageBuffer);
- input_buffer_id_ = TakeNextId();
- std::unique_ptr<Instruction> newVarOp(new Instruction(
- context(), spv::Op::OpVariable, ibufTyPtrId_, input_buffer_id_,
- {{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
- {uint32_t(spv::StorageClass::StorageBuffer)}}}));
- context()->AddGlobalValue(std::move(newVarOp));
- context()->AddDebug2Inst(NewGlobalName(ibufTyId, "InputBuffer"));
- context()->AddDebug2Inst(NewMemberName(ibufTyId, 0, "data"));
- context()->AddDebug2Inst(NewGlobalName(input_buffer_id_, "input_buffer"));
- deco_mgr->AddDecorationVal(
- input_buffer_id_, uint32_t(spv::Decoration::DescriptorSet), desc_set_);
- deco_mgr->AddDecorationVal(input_buffer_id_,
- uint32_t(spv::Decoration::Binding),
- GetInputBufferBinding());
- AddStorageBufferExt();
- if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
- // Add the new buffer to all entry points.
- for (auto& entry : get_module()->entry_points()) {
- entry.AddOperand({SPV_OPERAND_TYPE_ID, {input_buffer_id_}});
- context()->AnalyzeUses(&entry);
- }
- }
- }
- return input_buffer_id_;
-}
-
uint32_t InstrumentPass::GetFloatId() {
if (float_id_ == 0) {
analysis::TypeManager* type_mgr = context()->get_type_mgr();
@@ -817,159 +552,6 @@
return void_id_;
}
-uint32_t InstrumentPass::GetStreamWriteFunctionId(uint32_t stage_idx,
- uint32_t val_spec_param_cnt) {
- // Total param count is common params plus validation-specific
- // params
- uint32_t param_cnt = kInstCommonParamCnt + val_spec_param_cnt;
- if (param2output_func_id_[param_cnt] == 0) {
- // Create function
- param2output_func_id_[param_cnt] = TakeNextId();
- analysis::TypeManager* type_mgr = context()->get_type_mgr();
-
- const std::vector<const analysis::Type*> param_types(param_cnt,
- GetInteger(32, false));
- std::unique_ptr<Function> output_func = StartFunction(
- param2output_func_id_[param_cnt], type_mgr->GetVoidType(), param_types);
-
- std::vector<uint32_t> param_ids = AddParameters(*output_func, param_types);
-
- // Create first block
- auto new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(TakeNextId()));
-
- InstructionBuilder builder(
- context(), &*new_blk_ptr,
- IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- // Gen test if debug output buffer size will not be exceeded.
- uint32_t val_spec_offset = kInstStageOutCnt;
- uint32_t obuf_record_sz = val_spec_offset + val_spec_param_cnt;
- uint32_t buf_id = GetOutputBufferId();
- uint32_t buf_uint_ptr_id = GetOutputBufferPtrId();
- Instruction* obuf_curr_sz_ac_inst = builder.AddAccessChain(
- buf_uint_ptr_id, buf_id,
- {builder.GetUintConstantId(kDebugOutputSizeOffset)});
- // Fetch the current debug buffer written size atomically, adding the
- // size of the record to be written.
- uint32_t obuf_record_sz_id = builder.GetUintConstantId(obuf_record_sz);
- uint32_t mask_none_id =
- builder.GetUintConstantId(uint32_t(spv::MemoryAccessMask::MaskNone));
- uint32_t scope_invok_id =
- builder.GetUintConstantId(uint32_t(spv::Scope::Invocation));
- Instruction* obuf_curr_sz_inst = builder.AddQuadOp(
- GetUintId(), spv::Op::OpAtomicIAdd, obuf_curr_sz_ac_inst->result_id(),
- scope_invok_id, mask_none_id, obuf_record_sz_id);
- uint32_t obuf_curr_sz_id = obuf_curr_sz_inst->result_id();
- // Compute new written size
- Instruction* obuf_new_sz_inst =
- builder.AddIAdd(GetUintId(), obuf_curr_sz_id,
- builder.GetUintConstantId(obuf_record_sz));
- // Fetch the data bound
- Instruction* obuf_bnd_inst =
- builder.AddIdLiteralOp(GetUintId(), spv::Op::OpArrayLength,
- GetOutputBufferId(), kDebugOutputDataOffset);
- // Test that new written size is less than or equal to debug output
- // data bound
- Instruction* obuf_safe_inst = builder.AddBinaryOp(
- GetBoolId(), spv::Op::OpULessThanEqual, obuf_new_sz_inst->result_id(),
- obuf_bnd_inst->result_id());
- uint32_t merge_blk_id = TakeNextId();
- uint32_t write_blk_id = TakeNextId();
- std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
- std::unique_ptr<Instruction> write_label(NewLabel(write_blk_id));
- (void)builder.AddConditionalBranch(
- obuf_safe_inst->result_id(), write_blk_id, merge_blk_id, merge_blk_id,
- uint32_t(spv::SelectionControlMask::MaskNone));
- // Close safety test block and gen write block
- output_func->AddBasicBlock(std::move(new_blk_ptr));
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(write_label));
- builder.SetInsertPoint(&*new_blk_ptr);
- // Generate common and stage-specific debug record members
- GenCommonStreamWriteCode(obuf_record_sz, param_ids[kInstCommonParamInstIdx],
- stage_idx, obuf_curr_sz_id, &builder);
- GenStageStreamWriteCode(stage_idx, obuf_curr_sz_id, &builder);
- // Gen writes of validation specific data
- for (uint32_t i = 0; i < val_spec_param_cnt; ++i) {
- GenDebugOutputFieldCode(obuf_curr_sz_id, val_spec_offset + i,
- param_ids[kInstCommonParamCnt + i], &builder);
- }
- // Close write block and gen merge block
- (void)builder.AddBranch(merge_blk_id);
- output_func->AddBasicBlock(std::move(new_blk_ptr));
- new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
- builder.SetInsertPoint(&*new_blk_ptr);
- // Close merge block and function and add function to module
- (void)builder.AddNullaryOp(0, spv::Op::OpReturn);
-
- output_func->AddBasicBlock(std::move(new_blk_ptr));
- output_func->SetFunctionEnd(EndFunction());
- context()->AddFunction(std::move(output_func));
-
- std::string name("stream_write_");
- name += std::to_string(param_cnt);
-
- context()->AddDebug2Inst(
- NewGlobalName(param2output_func_id_[param_cnt], name));
- }
- return param2output_func_id_[param_cnt];
-}
-
-uint32_t InstrumentPass::GetDirectReadFunctionId(uint32_t param_cnt) {
- uint32_t func_id = param2input_func_id_[param_cnt];
- if (func_id != 0) return func_id;
- // Create input function for param_cnt.
- func_id = TakeNextId();
- analysis::Integer* uint_type = GetInteger(32, false);
- std::vector<const analysis::Type*> param_types(param_cnt, uint_type);
-
- std::unique_ptr<Function> input_func =
- StartFunction(func_id, uint_type, param_types);
- std::vector<uint32_t> param_ids = AddParameters(*input_func, param_types);
-
- // Create block
- auto new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(TakeNextId()));
- InstructionBuilder builder(
- context(), &*new_blk_ptr,
- IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- // For each offset parameter, generate new offset with parameter, adding last
- // loaded value if it exists, and load value from input buffer at new offset.
- // Return last loaded value.
- uint32_t ibuf_type_id = GetInputBufferTypeId();
- uint32_t buf_id = GetInputBufferId();
- uint32_t buf_ptr_id = GetInputBufferPtrId();
- uint32_t last_value_id = 0;
- for (uint32_t p = 0; p < param_cnt; ++p) {
- uint32_t offset_id;
- if (p == 0) {
- offset_id = param_ids[0];
- } else {
- if (ibuf_type_id != GetUintId()) {
- last_value_id =
- builder.AddUnaryOp(GetUintId(), spv::Op::OpUConvert, last_value_id)
- ->result_id();
- }
- offset_id = builder.AddIAdd(GetUintId(), last_value_id, param_ids[p])
- ->result_id();
- }
- Instruction* ac_inst = builder.AddAccessChain(
- buf_ptr_id, buf_id,
- {builder.GetUintConstantId(kDebugInputDataOffset), offset_id});
- last_value_id =
- builder.AddLoad(ibuf_type_id, ac_inst->result_id())->result_id();
- }
- (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue, last_value_id);
- // Close block and function and add function to module
- input_func->AddBasicBlock(std::move(new_blk_ptr));
- input_func->SetFunctionEnd(EndFunction());
- context()->AddFunction(std::move(input_func));
-
- std::string name("direct_read_");
- name += std::to_string(param_cnt);
- context()->AddDebug2Inst(NewGlobalName(func_id, name));
-
- param2input_func_id_[param_cnt] = func_id;
- return func_id;
-}
-
void InstrumentPass::SplitBlock(
BasicBlock::iterator inst_itr, UptrVectorIterator<BasicBlock> block_itr,
std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
@@ -1071,52 +653,54 @@
}
bool InstrumentPass::InstProcessEntryPointCallTree(InstProcessFunction& pfn) {
- // Make sure all entry points have the same execution model. Do not
- // instrument if they do not.
- // TODO(greg-lunarg): Handle mixed stages. Technically, a shader module
- // can contain entry points with different execution models, although
- // such modules will likely be rare as GLSL and HLSL are geared toward
- // one model per module. In such cases we will need
- // to clone any functions which are in the call trees of entrypoints
- // with differing execution models.
- spv::ExecutionModel stage = context()->GetStage();
- // Check for supported stages
- if (stage != spv::ExecutionModel::Vertex &&
- stage != spv::ExecutionModel::Fragment &&
- stage != spv::ExecutionModel::Geometry &&
- stage != spv::ExecutionModel::GLCompute &&
- stage != spv::ExecutionModel::TessellationControl &&
- stage != spv::ExecutionModel::TessellationEvaluation &&
- stage != spv::ExecutionModel::TaskNV &&
- stage != spv::ExecutionModel::MeshNV &&
- stage != spv::ExecutionModel::RayGenerationNV &&
- stage != spv::ExecutionModel::IntersectionNV &&
- stage != spv::ExecutionModel::AnyHitNV &&
- stage != spv::ExecutionModel::ClosestHitNV &&
- stage != spv::ExecutionModel::MissNV &&
- stage != spv::ExecutionModel::CallableNV &&
- stage != spv::ExecutionModel::TaskEXT &&
- stage != spv::ExecutionModel::MeshEXT) {
- if (consumer()) {
- std::string message = "Stage not supported by instrumentation";
- consumer()(SPV_MSG_ERROR, 0, {0, 0, 0}, message.c_str());
+ uint32_t stage_id;
+ if (use_stage_info_) {
+ // Make sure all entry points have the same execution model. Do not
+ // instrument if they do not.
+ // TODO(greg-lunarg): Handle mixed stages. Technically, a shader module
+ // can contain entry points with different execution models, although
+ // such modules will likely be rare as GLSL and HLSL are geared toward
+ // one model per module. In such cases we will need
+ // to clone any functions which are in the call trees of entrypoints
+ // with differing execution models.
+ spv::ExecutionModel stage = context()->GetStage();
+ // Check for supported stages
+ if (stage != spv::ExecutionModel::Vertex &&
+ stage != spv::ExecutionModel::Fragment &&
+ stage != spv::ExecutionModel::Geometry &&
+ stage != spv::ExecutionModel::GLCompute &&
+ stage != spv::ExecutionModel::TessellationControl &&
+ stage != spv::ExecutionModel::TessellationEvaluation &&
+ stage != spv::ExecutionModel::TaskNV &&
+ stage != spv::ExecutionModel::MeshNV &&
+ stage != spv::ExecutionModel::RayGenerationNV &&
+ stage != spv::ExecutionModel::IntersectionNV &&
+ stage != spv::ExecutionModel::AnyHitNV &&
+ stage != spv::ExecutionModel::ClosestHitNV &&
+ stage != spv::ExecutionModel::MissNV &&
+ stage != spv::ExecutionModel::CallableNV &&
+ stage != spv::ExecutionModel::TaskEXT &&
+ stage != spv::ExecutionModel::MeshEXT) {
+ if (consumer()) {
+ std::string message = "Stage not supported by instrumentation";
+ consumer()(SPV_MSG_ERROR, 0, {0, 0, 0}, message.c_str());
+ }
+ return false;
}
- return false;
+ stage_id = static_cast<uint32_t>(stage);
+ } else {
+ stage_id = 0;
}
// Add together the roots of all entry points
std::queue<uint32_t> roots;
for (auto& e : get_module()->entry_points()) {
roots.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx));
}
- bool modified = InstProcessCallTreeFromRoots(pfn, &roots, uint32_t(stage));
+ bool modified = InstProcessCallTreeFromRoots(pfn, &roots, stage_id);
return modified;
}
void InstrumentPass::InitializeInstrument() {
- output_buffer_id_ = 0;
- output_buffer_ptr_id_ = 0;
- input_buffer_ptr_id_ = 0;
- input_buffer_id_ = 0;
float_id_ = 0;
v4float_id_ = 0;
uint_id_ = 0;
diff --git a/third_party/SPIRV-Tools/source/opt/instrument_pass.h b/third_party/SPIRV-Tools/source/opt/instrument_pass.h
index 4bbbb09..e4408c9 100644
--- a/third_party/SPIRV-Tools/source/opt/instrument_pass.h
+++ b/third_party/SPIRV-Tools/source/opt/instrument_pass.h
@@ -55,14 +55,6 @@
namespace spvtools {
namespace opt {
-namespace {
-// Validation Ids
-// These are used to identify the general validation being done and map to
-// its output buffers.
-constexpr uint32_t kInstValidationIdBindless = 0;
-constexpr uint32_t kInstValidationIdBuffAddr = 1;
-constexpr uint32_t kInstValidationIdDebugPrintf = 2;
-} // namespace
class InstrumentPass : public Pass {
using cbb_ptr = const BasicBlock*;
@@ -85,13 +77,13 @@
// set |desc_set| for debug input and output buffers and writes |shader_id|
// into debug output records. |opt_direct_reads| indicates that the pass
// will see direct input buffer reads and should prepare to optimize them.
- InstrumentPass(uint32_t desc_set, uint32_t shader_id, uint32_t validation_id,
- bool opt_direct_reads = false)
+ InstrumentPass(uint32_t desc_set, uint32_t shader_id, bool opt_direct_reads,
+ bool use_stage_info)
: Pass(),
desc_set_(desc_set),
shader_id_(shader_id),
- validation_id_(validation_id),
- opt_direct_reads_(opt_direct_reads) {}
+ opt_direct_reads_(opt_direct_reads),
+ use_stage_info_(use_stage_info) {}
// Initialize state for instrumentation of module.
void InitializeInstrument();
@@ -113,108 +105,10 @@
void MovePostludeCode(UptrVectorIterator<BasicBlock> ref_block_itr,
BasicBlock* new_blk_ptr);
- // Generate instructions in |builder| which will atomically fetch and
- // increment the size of the debug output buffer stream of the current
- // validation and write a record to the end of the stream, if enough space
- // in the buffer remains. The record will contain the index of the function
- // and instruction within that function |func_idx, instruction_idx| which
- // generated the record. It will also contain additional information to
- // identify the instance of the shader, depending on the stage |stage_idx|
- // of the shader. Finally, the record will contain validation-specific
- // data contained in |validation_ids| which will identify the validation
- // error as well as the values involved in the error.
- //
- // The output buffer binding written to by the code generated by the function
- // is determined by the validation id specified when each specific
- // instrumentation pass is created.
- //
- // The output buffer is a sequence of 32-bit values with the following
- // format (where all elements are unsigned 32-bit unless otherwise noted):
- //
- // Size
- // Record0
- // Record1
- // Record2
- // ...
- //
- // Size is the number of 32-bit values that have been written or
- // attempted to be written to the output buffer, excluding the Size. It is
- // initialized to 0. If the size of attempts to write the buffer exceeds
- // the actual size of the buffer, it is possible that this field can exceed
- // the actual size of the buffer.
- //
- // Each Record* is a variable-length sequence of 32-bit values with the
- // following format defined using static const offsets in the .cpp file:
- //
- // Record Size
- // Shader ID
- // Instruction Index
- // Stage
- // Stage-specific Word 0
- // Stage-specific Word 1
- // ...
- // Validation Error Code
- // Validation-specific Word 0
- // Validation-specific Word 1
- // Validation-specific Word 2
- // ...
- //
- // Each record consists of three subsections: members common across all
- // validation, members specific to the stage, and members specific to a
- // validation.
- //
- // The Record Size is the number of 32-bit words in the record, including
- // the Record Size word.
- //
- // Shader ID is a value that identifies which shader has generated the
- // validation error. It is passed when the instrumentation pass is created.
- //
- // The Instruction Index is the position of the instruction within the
- // SPIR-V file which is in error.
- //
- // The Stage is the pipeline stage which has generated the error as defined
- // by the SpvExecutionModel_ enumeration. This is used to interpret the
- // following Stage-specific words.
- //
- // The Stage-specific Words identify which invocation of the shader generated
- // the error. Every stage will write a fixed number of words. Vertex shaders
- // will write the Vertex and Instance ID. Fragment shaders will write
- // FragCoord.xy. Compute shaders will write the GlobalInvocation ID.
- // The tessellation eval shader will write the Primitive ID and TessCoords.uv.
- // The tessellation control shader and geometry shader will write the
- // Primitive ID and Invocation ID.
- //
- // The Validation Error Code specifies the exact error which has occurred.
- // These are enumerated with the kInstError* static consts. This allows
- // multiple validation layers to use the same, single output buffer.
- //
- // The Validation-specific Words are a validation-specific number of 32-bit
- // words which give further information on the validation error that
- // occurred. These are documented further in each file containing the
- // validation-specific class which derives from this base class.
- //
- // Because the code that is generated checks against the size of the buffer
- // before writing, the size of the debug out buffer can be used by the
- // validation layer to control the number of error records that are written.
- void GenDebugStreamWrite(uint32_t instruction_idx, uint32_t stage_idx,
- const std::vector<uint32_t>& validation_ids,
- InstructionBuilder* builder);
-
// Return true if all instructions in |ids| are constants or spec constants.
bool AllConstant(const std::vector<uint32_t>& ids);
- // Generate in |builder| instructions to read the unsigned integer from the
- // input buffer specified by the offsets in |offset_ids|. Given offsets
- // o0, o1, ... oN, and input buffer ibuf, return the id for the value:
- //
- // ibuf[...ibuf[ibuf[o0]+o1]...+oN]
- //
- // The binding and the format of the input buffer is determined by each
- // specific validation, which is specified at the creation of the pass.
- uint32_t GenDebugDirectRead(const std::vector<uint32_t>& offset_ids,
- InstructionBuilder* builder);
-
- uint32_t GenReadFunctionCall(uint32_t func_id,
+ uint32_t GenReadFunctionCall(uint32_t return_id, uint32_t func_id,
const std::vector<uint32_t>& args,
InstructionBuilder* builder);
@@ -242,15 +136,6 @@
std::unique_ptr<Instruction> NewName(uint32_t id,
const std::string& name_str);
- // Set the name for a function or global variable, names will be
- // prefixed to identify which instrumentation pass generated them.
- std::unique_ptr<Instruction> NewGlobalName(uint32_t id,
- const std::string& name_str);
-
- // Set the name for a structure member
- std::unique_ptr<Instruction> NewMemberName(uint32_t id, uint32_t member_index,
- const std::string& name_str);
-
// Return id for 32-bit unsigned type
uint32_t GetUintId();
@@ -282,30 +167,9 @@
// Return pointer to type for runtime array of uint
analysis::RuntimeArray* GetUintRuntimeArrayType(uint32_t width);
- // Return id for buffer uint type
- uint32_t GetOutputBufferPtrId();
-
- // Return id for buffer uint type
- uint32_t GetInputBufferTypeId();
-
- // Return id for buffer uint type
- uint32_t GetInputBufferPtrId();
-
- // Return binding for output buffer for current validation.
- uint32_t GetOutputBufferBinding();
-
- // Return binding for input buffer for current validation.
- uint32_t GetInputBufferBinding();
-
// Add storage buffer extension if needed
void AddStorageBufferExt();
- // Return id for debug output buffer
- uint32_t GetOutputBufferId();
-
- // Return id for debug input buffer
- uint32_t GetInputBufferId();
-
// Return id for 32-bit float type
uint32_t GetFloatId();
@@ -321,15 +185,6 @@
// Return id for v3uint type
uint32_t GetVec3UintId();
- // Return id for output function. Define if it doesn't exist with
- // |val_spec_param_cnt| validation-specific uint32 parameters.
- uint32_t GetStreamWriteFunctionId(uint32_t stage_idx,
- uint32_t val_spec_param_cnt);
-
- // Return id for input function taking |param_cnt| uint32 parameters. Define
- // if it doesn't exist.
- uint32_t GetDirectReadFunctionId(uint32_t param_cnt);
-
// Split block |block_itr| into two new blocks where the second block
// contains |inst_itr| and place in |new_blocks|.
void SplitBlock(BasicBlock::iterator inst_itr,
@@ -349,40 +204,11 @@
std::queue<uint32_t>* roots,
uint32_t stage_idx);
- // Gen code into |builder| to write |field_value_id| into debug output
- // buffer at |base_offset_id| + |field_offset|.
- void GenDebugOutputFieldCode(uint32_t base_offset_id, uint32_t field_offset,
- uint32_t field_value_id,
- InstructionBuilder* builder);
-
- // Generate instructions into |builder| which will write the members
- // of the debug output record common for all stages and validations at
- // |base_off|.
- void GenCommonStreamWriteCode(uint32_t record_sz, uint32_t instruction_idx,
- uint32_t stage_idx, uint32_t base_off,
- InstructionBuilder* builder);
-
- // Generate instructions into |builder| which will write
- // |uint_frag_coord_id| at |component| of the record at |base_offset_id| of
- // the debug output buffer .
- void GenFragCoordEltDebugOutputCode(uint32_t base_offset_id,
- uint32_t uint_frag_coord_id,
- uint32_t component,
- InstructionBuilder* builder);
-
// Generate instructions into |builder| which will load |var_id| and return
// its result id.
uint32_t GenVarLoad(uint32_t var_id, InstructionBuilder* builder);
- // Generate instructions into |builder| which will load the uint |builtin_id|
- // and write it into the debug output buffer at |base_off| + |builtin_off|.
- void GenBuiltinOutputCode(uint32_t builtin_id, uint32_t builtin_off,
- uint32_t base_off, InstructionBuilder* builder);
-
- // Generate instructions into |builder| which will write the |stage_idx|-
- // specific members of the debug output stream at |base_off|.
- void GenStageStreamWriteCode(uint32_t stage_idx, uint32_t base_off,
- InstructionBuilder* builder);
+ uint32_t GenStageInfo(uint32_t stage_idx, InstructionBuilder* builder);
// Return true if instruction must be in the same block that its result
// is used.
@@ -418,62 +244,47 @@
// Map from instruction's unique id to offset in original file.
std::unordered_map<uint32_t, uint32_t> uid2offset_;
- // result id for OpConstantFalse
- uint32_t validation_id_;
-
- // id for output buffer variable
- uint32_t output_buffer_id_;
-
- // ptr type id for output buffer element
- uint32_t output_buffer_ptr_id_;
-
- // ptr type id for input buffer element
- uint32_t input_buffer_ptr_id_;
-
// id for debug output function
std::unordered_map<uint32_t, uint32_t> param2output_func_id_;
// ids for debug input functions
std::unordered_map<uint32_t, uint32_t> param2input_func_id_;
- // id for input buffer variable
- uint32_t input_buffer_id_;
-
// id for 32-bit float type
- uint32_t float_id_;
+ uint32_t float_id_{0};
// id for v4float type
- uint32_t v4float_id_;
+ uint32_t v4float_id_{0};
// id for v4uint type
- uint32_t v4uint_id_;
+ uint32_t v4uint_id_{0};
// id for v3uint type
- uint32_t v3uint_id_;
+ uint32_t v3uint_id_{0};
// id for 32-bit unsigned type
- uint32_t uint_id_;
+ uint32_t uint_id_{0};
// id for 64-bit unsigned type
- uint32_t uint64_id_;
+ uint32_t uint64_id_{0};
// id for 8-bit unsigned type
- uint32_t uint8_id_;
+ uint32_t uint8_id_{0};
// id for bool type
- uint32_t bool_id_;
+ uint32_t bool_id_{0};
// id for void type
- uint32_t void_id_;
+ uint32_t void_id_{0};
// boolean to remember storage buffer extension
- bool storage_buffer_ext_defined_;
+ bool storage_buffer_ext_defined_{false};
// runtime array of uint type
- analysis::RuntimeArray* uint64_rarr_ty_;
+ analysis::RuntimeArray* uint64_rarr_ty_{nullptr};
// runtime array of uint type
- analysis::RuntimeArray* uint32_rarr_ty_;
+ analysis::RuntimeArray* uint32_rarr_ty_{nullptr};
// Pre-instrumentation same-block insts
std::unordered_map<uint32_t, Instruction*> same_block_pre_;
@@ -498,11 +309,15 @@
std::unordered_map<std::vector<uint32_t>, uint32_t, vector_hash_> call2id_;
// Function currently being instrumented
- Function* curr_func_;
+ Function* curr_func_{nullptr};
// Optimize direct debug input buffer reads. Specifically, move all such
// reads with constant args to first block and reuse them.
- bool opt_direct_reads_;
+ const bool opt_direct_reads_;
+
+ // Set true if the instrumentation needs to know the current stage.
+ // Note that this does not work with multi-stage modules.
+ const bool use_stage_info_;
};
} // namespace opt
diff --git a/third_party/SPIRV-Tools/source/opt/invocation_interlock_placement_pass.cpp b/third_party/SPIRV-Tools/source/opt/invocation_interlock_placement_pass.cpp
new file mode 100644
index 0000000..642e2d2
--- /dev/null
+++ b/third_party/SPIRV-Tools/source/opt/invocation_interlock_placement_pass.cpp
@@ -0,0 +1,493 @@
+// Copyright (c) 2023 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/invocation_interlock_placement_pass.h"
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <functional>
+#include <optional>
+#include <queue>
+#include <stack>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "source/enum_set.h"
+#include "source/enum_string_mapping.h"
+#include "source/opt/ir_context.h"
+#include "source/opt/reflect.h"
+#include "source/spirv_target_env.h"
+#include "source/util/string_utils.h"
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+constexpr uint32_t kEntryPointExecutionModelInIdx = 0;
+constexpr uint32_t kEntryPointFunctionIdInIdx = 1;
+constexpr uint32_t kFunctionCallFunctionIdInIdx = 0;
+} // namespace
+
+bool InvocationInterlockPlacementPass::hasSingleNextBlock(uint32_t block_id,
+ bool reverse_cfg) {
+ if (reverse_cfg) {
+ // We are traversing forward, so check whether there is a single successor.
+ BasicBlock* block = cfg()->block(block_id);
+ // OpSwitch in-operands are: selector, default, then (literal, label) pairs.
+ switch (block->tail()->opcode()) {
+ case spv::Op::OpBranchConditional:
+ return false;
+ case spv::Op::OpSwitch:
+ return block->tail()->NumInOperandWords() == 2; // default-only switch
+ default:
+ return !block->tail()->IsReturnOrAbort();
+ }
+ } else {
+ // Traversing backward: check whether there is a single predecessor.
+ return cfg()->preds(block_id).size() == 1;
+ }
+}
+
+void InvocationInterlockPlacementPass::forEachNext(
+ uint32_t block_id, bool reverse_cfg, std::function<void(uint32_t)> f) {
+ if (reverse_cfg) {
+ BasicBlock* block = cfg()->block(block_id);
+ block->ForEachSuccessorLabel([f](uint32_t succ_id) { f(succ_id); });
+ } else {
+ for (uint32_t pred_id : cfg()->preds(block_id)) {
+ f(pred_id);
+ }
+ }
+}
+
+void InvocationInterlockPlacementPass::addInstructionAtBlockBoundary(
+ BasicBlock* block, spv::Op opcode, bool at_end) {
+ if (at_end) {
+ assert(hasSingleNextBlock(block->id(), /* reverse_cfg= */ true) &&
+ "addInstructionAtBlockBoundary expects to be called with at_end == "
+ "true only if there is a single successor to block");
+ // A merge instruction must immediately precede the terminator, so insert
+ // before the merge instruction when present, else before the terminator.
+ Instruction* merge_inst = block->GetMergeInst();
+ Instruction* begin_inst = new Instruction(context(), opcode);
+ begin_inst->InsertBefore(merge_inst ? merge_inst : &*block->tail());
+ } else {
+ assert(block->begin()->opcode() != spv::Op::OpPhi &&
+ "addInstructionAtBlockBoundary expects to be called with at_end == "
+ "false only if there is a single predecessor to block");
+ // Insert an end instruction at the beginning of the block.
+ Instruction* end_inst = new Instruction(context(), opcode);
+ end_inst->InsertBefore(&*block->begin());
+ }
+}
+
+bool InvocationInterlockPlacementPass::killDuplicateBegin(BasicBlock* block) {
+ bool found = false;
+
+ return context()->KillInstructionIf(
+ block->begin(), block->end(), [&found](Instruction* inst) {
+ if (inst->opcode() == spv::Op::OpBeginInvocationInterlockEXT) {
+ if (found) {
+ return true;
+ }
+ found = true;
+ }
+ return false;
+ });
+}
+
+bool InvocationInterlockPlacementPass::killDuplicateEnd(BasicBlock* block) {
+ std::vector<Instruction*> to_kill;
+ block->ForEachInst([&to_kill](Instruction* inst) {
+ if (inst->opcode() == spv::Op::OpEndInvocationInterlockEXT) {
+ to_kill.push_back(inst);
+ }
+ });
+
+ if (to_kill.size() <= 1) {
+ return false;
+ }
+
+ to_kill.pop_back();
+
+ for (Instruction* inst : to_kill) {
+ context()->KillInst(inst);
+ }
+
+ return true;
+}
+
+void InvocationInterlockPlacementPass::recordBeginOrEndInFunction(
+ Function* func) {
+ if (extracted_functions_.count(func)) {
+ return;
+ }
+
+ bool had_begin = false;
+ bool had_end = false;
+
+ func->ForEachInst([this, &had_begin, &had_end](Instruction* inst) {
+ switch (inst->opcode()) {
+ case spv::Op::OpBeginInvocationInterlockEXT:
+ had_begin = true;
+ break;
+ case spv::Op::OpEndInvocationInterlockEXT:
+ had_end = true;
+ break;
+ case spv::Op::OpFunctionCall: {
+ uint32_t function_id =
+ inst->GetSingleWordInOperand(kFunctionCallFunctionIdInIdx);
+ Function* inner_func = context()->GetFunction(function_id);
+ recordBeginOrEndInFunction(inner_func);
+ ExtractionResult result = extracted_functions_[inner_func];
+ had_begin = had_begin || result.had_begin;
+ had_end = had_end || result.had_end;
+ break;
+ }
+ default:
+ break;
+ }
+ });
+
+ ExtractionResult result = {had_begin, had_end};
+ extracted_functions_[func] = result;
+}
+
+bool InvocationInterlockPlacementPass::
+ removeBeginAndEndInstructionsFromFunction(Function* func) {
+ bool modified = false;
+ func->ForEachInst([this, &modified](Instruction* inst) {
+ switch (inst->opcode()) {
+ case spv::Op::OpBeginInvocationInterlockEXT:
+ context()->KillInst(inst);
+ modified = true;
+ break;
+ case spv::Op::OpEndInvocationInterlockEXT:
+ context()->KillInst(inst);
+ modified = true;
+ break;
+ default:
+ break;
+ }
+ });
+ return modified;
+}
+
+bool InvocationInterlockPlacementPass::extractInstructionsFromCalls(
+ std::vector<BasicBlock*> blocks) {
+ bool modified = false;
+
+ for (BasicBlock* block : blocks) {
+ block->ForEachInst([this, &modified](Instruction* inst) {
+ if (inst->opcode() == spv::Op::OpFunctionCall) {
+ uint32_t function_id =
+ inst->GetSingleWordInOperand(kFunctionCallFunctionIdInIdx);
+ Function* func = context()->GetFunction(function_id);
+ ExtractionResult result = extracted_functions_[func];
+
+ if (result.had_begin) {
+ Instruction* new_inst = new Instruction(
+ context(), spv::Op::OpBeginInvocationInterlockEXT);
+ new_inst->InsertBefore(inst);
+ modified = true;
+ }
+ if (result.had_end) {
+ Instruction* new_inst =
+ new Instruction(context(), spv::Op::OpEndInvocationInterlockEXT);
+ new_inst->InsertAfter(inst);
+ modified = true;
+ }
+ }
+ });
+ }
+ return modified;
+}
+
+void InvocationInterlockPlacementPass::recordExistingBeginAndEndBlock(
+ std::vector<BasicBlock*> blocks) {
+ for (BasicBlock* block : blocks) {
+ block->ForEachInst([this, block](Instruction* inst) {
+ switch (inst->opcode()) {
+ case spv::Op::OpBeginInvocationInterlockEXT:
+ begin_.insert(block->id());
+ break;
+ case spv::Op::OpEndInvocationInterlockEXT:
+ end_.insert(block->id());
+ break;
+ default:
+ break;
+ }
+ });
+ }
+}
+
+InvocationInterlockPlacementPass::BlockSet
+InvocationInterlockPlacementPass::computeReachableBlocks(
+ BlockSet& previous_inside, const BlockSet& starting_nodes,
+ bool reverse_cfg) {
+ BlockSet inside = starting_nodes;
+
+ std::deque<uint32_t> worklist;
+ worklist.insert(worklist.begin(), starting_nodes.begin(),
+ starting_nodes.end());
+
+ while (!worklist.empty()) {
+ uint32_t block_id = worklist.front();
+ worklist.pop_front();
+
+ forEachNext(block_id, reverse_cfg,
+ [&inside, &previous_inside, &worklist](uint32_t next_id) {
+ previous_inside.insert(next_id);
+ if (inside.insert(next_id).second) {
+ worklist.push_back(next_id);
+ }
+ });
+ }
+
+ return inside;
+}
+
+bool InvocationInterlockPlacementPass::removeUnneededInstructions(
+ BasicBlock* block) {
+ bool modified = false;
+ if (!predecessors_after_begin_.count(block->id()) &&
+ after_begin_.count(block->id())) {
+ // None of the previous blocks are in the critical section, but this block
+ // is. This can only happen if this block already has at least one begin
+ // instruction. Leave the first begin instruction, and remove any others.
+ modified |= killDuplicateBegin(block);
+ } else if (predecessors_after_begin_.count(block->id())) {
+ // At least one previous block is in the critical section; remove all
+ // begin instructions in this block.
+ modified |= context()->KillInstructionIf(
+ block->begin(), block->end(), [](Instruction* inst) {
+ return inst->opcode() == spv::Op::OpBeginInvocationInterlockEXT;
+ });
+ }
+
+ if (!successors_before_end_.count(block->id()) &&
+ before_end_.count(block->id())) {
+ // Same as above
+ modified |= killDuplicateEnd(block);
+ } else if (successors_before_end_.count(block->id())) {
+ modified |= context()->KillInstructionIf(
+ block->begin(), block->end(), [](Instruction* inst) {
+ return inst->opcode() == spv::Op::OpEndInvocationInterlockEXT;
+ });
+ }
+ return modified;
+}
+
+BasicBlock* InvocationInterlockPlacementPass::splitEdge(BasicBlock* block,
+ uint32_t succ_id) {
+ // Create a new block to replace the critical edge.
+ auto new_succ_temp = MakeUnique<BasicBlock>(
+ MakeUnique<Instruction>(context(), spv::Op::OpLabel, 0, TakeNextId(),
+ std::initializer_list<Operand>{}));
+ auto* new_succ = new_succ_temp.get();
+
+ // Insert the new block into the function.
+ block->GetParent()->InsertBasicBlockAfter(std::move(new_succ_temp), block);
+
+ new_succ->AddInstruction(MakeUnique<Instruction>(
+ context(), spv::Op::OpBranch, 0, 0,
+ std::initializer_list<Operand>{
+ Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {succ_id})}));
+
+ assert(block->tail()->opcode() == spv::Op::OpBranchConditional ||
+ block->tail()->opcode() == spv::Op::OpSwitch);
+
+ // Update the first branch to successor to instead branch to
+ // the new successor. If there are multiple edges, we arbitrarily choose the
+ // first time it appears in the list. The other edges to `succ_id` will have
+ // to be split by another call to `splitEdge`.
+ block->tail()->WhileEachInId([new_succ, succ_id](uint32_t* branch_id) {
+ if (*branch_id == succ_id) {
+ *branch_id = new_succ->id();
+ return false;
+ }
+ return true;
+ });
+
+ return new_succ;
+}
+
+bool InvocationInterlockPlacementPass::placeInstructionsForEdge(
+ BasicBlock* block, uint32_t next_id, BlockSet& inside,
+ BlockSet& previous_inside, spv::Op opcode, bool reverse_cfg) {
+ bool modified = false;
+
+ if (previous_inside.count(next_id) && !inside.count(block->id())) {
+ // This block is not in the critical section but the next has at least one
+ // other previous block that is, so this block should enter it as well.
+ // We need to add begin or end instructions to the edge.
+
+ modified = true;
+
+ if (hasSingleNextBlock(block->id(), reverse_cfg)) {
+ // This is the only next block.
+
+ // Additionally, because `next_id` is in `previous_inside`, we know that
+ // `next_id` has at least one previous block in `inside`. And because
+ // `block` is not in `inside`, that means the `next_id` has to have at
+ // least one other previous block in `inside`.
+
+#ifndef NDEBUG
+ // Debug-only: recompute whether `next_id` has a previous block in `inside`
+ // to verify that `previous_inside` was computed correctly.
+ bool next_has_previous_inside = false;
+ // Passing !reverse_cfg makes forEachNext iterate over previous blocks.
+ forEachNext(next_id, !reverse_cfg,
+ [&next_has_previous_inside, &inside](uint32_t previous_id) {
+ if (inside.count(previous_id)) {
+ next_has_previous_inside = true;
+ }
+ });
+ assert(next_has_previous_inside &&
+ "`previous_inside` must be the set of blocks with at least one "
+ "previous block in `inside`");
+#endif
+
+ addInstructionAtBlockBoundary(block, opcode, reverse_cfg);
+ } else {
+ // This block has multiple next blocks. Split the edge and insert the
+ // instruction in the new next block.
+ BasicBlock* new_branch;
+ if (reverse_cfg) {
+ new_branch = splitEdge(block, next_id);
+ } else {
+ new_branch = splitEdge(cfg()->block(next_id), block->id());
+ }
+
+ auto inst = new Instruction(context(), opcode);
+ inst->InsertBefore(&*new_branch->tail());
+ }
+ }
+
+ return modified;
+}
+
+bool InvocationInterlockPlacementPass::placeInstructions(BasicBlock* block) {
+ bool modified = false;
+
+ block->ForEachSuccessorLabel([this, block, &modified](uint32_t succ_id) {
+ modified |= placeInstructionsForEdge(
+ block, succ_id, after_begin_, predecessors_after_begin_,
+ spv::Op::OpBeginInvocationInterlockEXT, /* reverse_cfg= */ true);
+ modified |= placeInstructionsForEdge(cfg()->block(succ_id), block->id(),
+ before_end_, successors_before_end_,
+ spv::Op::OpEndInvocationInterlockEXT,
+ /* reverse_cfg= */ false);
+ });
+
+ return modified;
+}
+
+bool InvocationInterlockPlacementPass::processFragmentShaderEntry(
+ Function* entry_func) {
+ bool modified = false;
+
+ // Save the original order of blocks in the function, so we don't iterate over
+ // newly-added blocks.
+ std::vector<BasicBlock*> original_blocks;
+ for (auto bi = entry_func->begin(); bi != entry_func->end(); ++bi) {
+ original_blocks.push_back(&*bi);
+ }
+
+ modified |= extractInstructionsFromCalls(original_blocks);
+ recordExistingBeginAndEndBlock(original_blocks);
+
+ after_begin_ = computeReachableBlocks(predecessors_after_begin_, begin_,
+ /* reverse_cfg= */ true);
+ before_end_ = computeReachableBlocks(successors_before_end_, end_,
+ /* reverse_cfg= */ false);
+
+ for (BasicBlock* block : original_blocks) {
+ modified |= removeUnneededInstructions(block);
+ modified |= placeInstructions(block);
+ }
+ return modified;
+}
+
+bool InvocationInterlockPlacementPass::isFragmentShaderInterlockEnabled() {
+ if (!context()->get_feature_mgr()->HasExtension(
+ kSPV_EXT_fragment_shader_interlock)) {
+ return false;
+ }
+
+ if (context()->get_feature_mgr()->HasCapability(
+ spv::Capability::FragmentShaderSampleInterlockEXT)) {
+ return true;
+ }
+
+ if (context()->get_feature_mgr()->HasCapability(
+ spv::Capability::FragmentShaderPixelInterlockEXT)) {
+ return true;
+ }
+
+ if (context()->get_feature_mgr()->HasCapability(
+ spv::Capability::FragmentShaderShadingRateInterlockEXT)) {
+ return true;
+ }
+
+ return false;
+}
+
+Pass::Status InvocationInterlockPlacementPass::Process() {
+ // Skip this pass if the necessary extension or capability is missing
+ if (!isFragmentShaderInterlockEnabled()) {
+ return Status::SuccessWithoutChange;
+ }
+
+ bool modified = false;
+
+ std::unordered_set<Function*> entry_points;
+ for (Instruction& entry_inst : context()->module()->entry_points()) {
+ uint32_t entry_id =
+ entry_inst.GetSingleWordInOperand(kEntryPointFunctionIdInIdx);
+ entry_points.insert(context()->GetFunction(entry_id));
+ }
+
+ for (auto fi = context()->module()->begin(); fi != context()->module()->end();
+ ++fi) {
+ Function* func = &*fi;
+ recordBeginOrEndInFunction(func);
+ if (!entry_points.count(func) && extracted_functions_.count(func)) {
+ modified |= removeBeginAndEndInstructionsFromFunction(func);
+ }
+ }
+
+ for (Instruction& entry_inst : context()->module()->entry_points()) {
+ uint32_t entry_id =
+ entry_inst.GetSingleWordInOperand(kEntryPointFunctionIdInIdx);
+ Function* entry_func = context()->GetFunction(entry_id);
+
+ auto execution_model = spv::ExecutionModel(
+ entry_inst.GetSingleWordInOperand(kEntryPointExecutionModelInIdx));
+
+ if (execution_model != spv::ExecutionModel::Fragment) {
+ continue;
+ }
+
+ modified |= processFragmentShaderEntry(entry_func);
+ }
+
+ return modified ? Pass::Status::SuccessWithChange
+ : Pass::Status::SuccessWithoutChange;
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/third_party/SPIRV-Tools/source/opt/invocation_interlock_placement_pass.h b/third_party/SPIRV-Tools/source/opt/invocation_interlock_placement_pass.h
new file mode 100644
index 0000000..4e85be8
--- /dev/null
+++ b/third_party/SPIRV-Tools/source/opt/invocation_interlock_placement_pass.h
@@ -0,0 +1,158 @@
+// Copyright (c) 2023 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_DEDUPE_INTERLOCK_INVOCATION_PASS_H_
+#define SOURCE_OPT_DEDUPE_INTERLOCK_INVOCATION_PASS_H_
+
+#include <algorithm>
+#include <array>
+#include <functional>
+#include <optional>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "source/enum_set.h"
+#include "source/extensions.h"
+#include "source/opt/ir_context.h"
+#include "source/opt/module.h"
+#include "source/opt/pass.h"
+#include "source/spirv_target_env.h"
+
+namespace spvtools {
+namespace opt {
+
+// This pass ensures that each entry point has at most one
+// OpBeginInvocationInterlockEXT and one OpEndInvocationInterlockEXT, in that
+// order.
+class InvocationInterlockPlacementPass : public Pass {
+ public:
+ InvocationInterlockPlacementPass() {}
+ InvocationInterlockPlacementPass(const InvocationInterlockPlacementPass&) =
+ delete;
+ InvocationInterlockPlacementPass(InvocationInterlockPlacementPass&&) = delete;
+
+ const char* name() const override { return "dedupe-interlock-invocation"; }
+ Status Process() override;
+
+ private:
+ using BlockSet = std::unordered_set<uint32_t>;
+
+ // Whether a function, or any function it calls, had a begin/end instruction.
+ struct ExtractionResult {
+ bool had_begin : 1;
+ bool had_end : 1;
+ };
+
+ // Check if a block has only a single next block, depending on the direction
+ // that we are traversing the CFG. If reverse_cfg is true, we are walking
+ // forward through the CFG, and will return if the block has only one
+ // successor. Otherwise, we are walking backward through the CFG, and will
+ // return if the block has only one predecessor.
+ bool hasSingleNextBlock(uint32_t block_id, bool reverse_cfg);
+
+ // Iterate over each of a block's predecessors or successors, depending on
+ // direction. If reverse_cfg is true, we are walking forward through the CFG,
+ // and need to iterate over the successors. Otherwise, we are walking backward
+ // through the CFG, and need to iterate over the predecessors.
+ void forEachNext(uint32_t block_id, bool reverse_cfg,
+ std::function<void(uint32_t)> f);
+
+ // Add either a begin or end instruction to the edge of the basic block. If
+ // at_end is true, add the instruction to the end of the block; otherwise add
+ // the instruction to the beginning of the basic block.
+ void addInstructionAtBlockBoundary(BasicBlock* block, spv::Op opcode,
+ bool at_end);
+
+ // Remove every OpBeginInvocationInterlockEXT instruction in block after the
+ // first. Returns whether any instructions were removed.
+ bool killDuplicateBegin(BasicBlock* block);
+ // Remove every OpEndInvocationInterlockEXT instruction in block before the
+ // last. Returns whether any instructions were removed.
+ bool killDuplicateEnd(BasicBlock* block);
+
+ // Records whether a function will potentially execute a begin or end
+ // instruction.
+ void recordBeginOrEndInFunction(Function* func);
+
+ // Removes every begin and end instruction directly contained in func (calls
+ // are not followed). Returns whether any instructions were removed.
+ bool removeBeginAndEndInstructionsFromFunction(Function* func);
+
+ // For every function call in any of the passed blocks, move any begin or end
+ // instructions outside of the function call. Returns whether any extractions
+ // occurred.
+ bool extractInstructionsFromCalls(std::vector<BasicBlock*> blocks);
+
+ // Finds the sets of blocks that contain OpBeginInvocationInterlockEXT and
+ // OpEndInvocationInterlockEXT, storing them in the member variables begin_
+ // and end_ respectively.
+ void recordExistingBeginAndEndBlock(std::vector<BasicBlock*> blocks);
+
+ // Returns the set of blocks in or reachable from starting_nodes, and adds to
+ // previous_inside every block that has a previous block in that set.
+ // If reverse_cfg is true, move forward through the CFG, computing
+ // after_begin_ and predecessors_after_begin_; otherwise, move backward
+ // through the CFG, computing before_end_ and successors_before_end_.
+ // "Reachable" follows the same direction as the traversal.
+ BlockSet computeReachableBlocks(BlockSet& previous_inside,
+ const BlockSet& starting_nodes,
+ bool reverse_cfg);
+
+ // Remove unneeded begin and end instructions in block.
+ bool removeUnneededInstructions(BasicBlock* block);
+
+ // Given a block which branches to multiple successors, and a specific
+ // successor, creates a new block that only branches to that successor, and
+ // redirects the first matching branch target of block to the new block.
+ BasicBlock* splitEdge(BasicBlock* block, uint32_t succ_id);
+
+ // For the edge from block to next_id, places a begin or end instruction on
+ // the edge, based on the direction we are walking the CFG, specified in
+ // reverse_cfg.
+ bool placeInstructionsForEdge(BasicBlock* block, uint32_t next_id,
+ BlockSet& inside, BlockSet& previous_inside,
+ spv::Op opcode, bool reverse_cfg);
+ // Calls placeInstructionsForEdge for each edge in block.
+ bool placeInstructions(BasicBlock* block);
+
+ // Processes a single fragment shader entry function.
+ bool processFragmentShaderEntry(Function* entry_func);
+
+ // Returns whether the module has the SPV_EXT_fragment_shader_interlock
+ // extension and one of the FragmentShader*InterlockEXT capabilities.
+ bool isFragmentShaderInterlockEnabled();
+
+ // Maps a function to whether that function originally held a begin or end
+ // instruction.
+ std::unordered_map<Function*, ExtractionResult> extracted_functions_;
+
+ // The set of blocks which have an OpBeginInvocationInterlockEXT instruction.
+ BlockSet begin_;
+ // The set of blocks which have an OpEndInvocationInterlockEXT instruction.
+ BlockSet end_;
+ // The set of blocks which either have a begin instruction, or are reachable
+ // from a block which has a begin instruction.
+ BlockSet after_begin_;
+ // The set of blocks which either have an end instruction, or from which a
+ // block with an end instruction is reachable.
+ BlockSet before_end_;
+ // The set of blocks which have a predecessor in after_begin_.
+ BlockSet predecessors_after_begin_;
+ // The set of blocks which have a successor in before_end_.
+ BlockSet successors_before_end_;
+};
+
+} // namespace opt
+} // namespace spvtools
+#endif // SOURCE_OPT_DEDUPE_INTERLOCK_INVOCATION_PASS_H_
diff --git a/third_party/SPIRV-Tools/source/opt/ir_builder.h b/third_party/SPIRV-Tools/source/opt/ir_builder.h
index 48e08ee..f3e0afc 100644
--- a/third_party/SPIRV-Tools/source/opt/ir_builder.h
+++ b/third_party/SPIRV-Tools/source/opt/ir_builder.h
@@ -440,6 +440,22 @@
return GetContext()->get_constant_mgr()->GetDefiningInstruction(constant);
}
+ Instruction* GetBoolConstant(bool value) {
+ analysis::Bool type;
+ uint32_t type_id = GetContext()->get_type_mgr()->GetTypeInstruction(&type);
+ analysis::Type* rebuilt_type =
+ GetContext()->get_type_mgr()->GetType(type_id);
+ uint32_t word = value;
+ const analysis::Constant* constant =
+ GetContext()->get_constant_mgr()->GetConstant(rebuilt_type, {word});
+ return GetContext()->get_constant_mgr()->GetDefiningInstruction(constant);
+ }
+
+ uint32_t GetBoolConstantId(bool value) {
+ Instruction* inst = GetBoolConstant(value);
+ return (inst != nullptr ? inst->result_id() : 0);
+ }
+
Instruction* AddCompositeExtract(uint32_t type, uint32_t id_of_composite,
const std::vector<uint32_t>& index_list) {
std::vector<Operand> operands;
diff --git a/third_party/SPIRV-Tools/source/opt/ir_context.cpp b/third_party/SPIRV-Tools/source/opt/ir_context.cpp
index 26501c2..d864b7c 100644
--- a/third_party/SPIRV-Tools/source/opt/ir_context.cpp
+++ b/third_party/SPIRV-Tools/source/opt/ir_context.cpp
@@ -88,6 +88,9 @@
if (set & kAnalysisDebugInfo) {
BuildDebugInfoManager();
}
+ if (set & kAnalysisLiveness) {
+ BuildLivenessManager();
+ }
}
void IRContext::InvalidateAnalysesExceptFor(
@@ -220,6 +223,28 @@
return next_instruction;
}
+bool IRContext::KillInstructionIf(Module::inst_iterator begin,
+ Module::inst_iterator end,
+ std::function<bool(Instruction*)> condition) {
+ bool removed = false;
+ for (auto it = begin; it != end;) {
+ if (!condition(&*it)) {
+ ++it;
+ continue;
+ }
+
+ removed = true;
+ // `it` is an iterator on an intrusive list. Next is invalidated on the
+ // current node when an instruction is killed. The iterator must be moved
+ // forward before deleting the node.
+ auto instruction = &*it;
+ ++it;
+ KillInst(instruction);
+ }
+
+ return removed;
+}
+
void IRContext::CollectNonSemanticTree(
Instruction* inst, std::unordered_set<Instruction*>* to_kill) {
if (!inst->HasResultId()) return;
@@ -251,6 +276,36 @@
return false;
}
+bool IRContext::RemoveCapability(spv::Capability capability) {
+ const bool removed = KillInstructionIf(
+ module()->capability_begin(), module()->capability_end(),
+ [capability](Instruction* inst) {
+ return static_cast<spv::Capability>(inst->GetSingleWordOperand(0)) ==
+ capability;
+ });
+
+ if (removed && feature_mgr_ != nullptr) {
+ feature_mgr_->RemoveCapability(capability);
+ }
+
+ return removed;
+}
+
+bool IRContext::RemoveExtension(Extension extension) {
+ const std::string_view extensionName = ExtensionToString(extension);
+ const bool removed = KillInstructionIf(
+ module()->extension_begin(), module()->extension_end(),
+ [&extensionName](Instruction* inst) {
+ return inst->GetOperand(0).AsString() == extensionName;
+ });
+
+ if (removed && feature_mgr_ != nullptr) {
+ feature_mgr_->RemoveExtension(extension);
+ }
+
+ return removed;
+}
+
bool IRContext::ReplaceAllUsesWith(uint32_t before, uint32_t after) {
return ReplaceAllUsesWithPredicate(before, after,
[](Instruction*) { return true; });
@@ -718,9 +773,9 @@
}
void IRContext::InitializeCombinators() {
- get_feature_mgr()->GetCapabilities()->ForEach([this](spv::Capability cap) {
- AddCombinatorsForCapability(uint32_t(cap));
- });
+ for (auto capability : get_feature_mgr()->GetCapabilities()) {
+ AddCombinatorsForCapability(uint32_t(capability));
+ }
for (auto& extension : module()->ext_inst_imports()) {
AddCombinatorsForExtension(&extension);
diff --git a/third_party/SPIRV-Tools/source/opt/ir_context.h b/third_party/SPIRV-Tools/source/opt/ir_context.h
index 8419ee7..ef7c458 100644
--- a/third_party/SPIRV-Tools/source/opt/ir_context.h
+++ b/third_party/SPIRV-Tools/source/opt/ir_context.h
@@ -27,6 +27,7 @@
#include <vector>
#include "source/assembly_grammar.h"
+#include "source/enum_string_mapping.h"
#include "source/opt/cfg.h"
#include "source/opt/constants.h"
#include "source/opt/debug_info_manager.h"
@@ -83,7 +84,7 @@
kAnalysisTypes = 1 << 15,
kAnalysisDebugInfo = 1 << 16,
kAnalysisLiveness = 1 << 17,
- kAnalysisEnd = 1 << 17
+ kAnalysisEnd = 1 << 18
};
using ProcessFunction = std::function<bool(Function*)>;
@@ -153,13 +154,19 @@
inline IteratorRange<Module::inst_iterator> capabilities();
inline IteratorRange<Module::const_inst_iterator> capabilities() const;
+ // Iterators for extensions instructions contained in this module.
+ inline Module::inst_iterator extension_begin();
+ inline Module::inst_iterator extension_end();
+ inline IteratorRange<Module::inst_iterator> extensions();
+ inline IteratorRange<Module::const_inst_iterator> extensions() const;
+
// Iterators for types, constants and global variables instructions.
inline Module::inst_iterator types_values_begin();
inline Module::inst_iterator types_values_end();
inline IteratorRange<Module::inst_iterator> types_values();
inline IteratorRange<Module::const_inst_iterator> types_values() const;
- // Iterators for extension instructions contained in this module.
+ // Iterators for ext_inst import instructions contained in this module.
inline Module::inst_iterator ext_inst_import_begin();
inline Module::inst_iterator ext_inst_import_end();
inline IteratorRange<Module::inst_iterator> ext_inst_imports();
@@ -204,17 +211,26 @@
// Add |capability| to the module, if it is not already enabled.
inline void AddCapability(spv::Capability capability);
-
// Appends a capability instruction to this module.
inline void AddCapability(std::unique_ptr<Instruction>&& c);
+ // Removes instruction declaring `capability` from this module.
+ // Returns true if the capability was removed, false otherwise.
+ bool RemoveCapability(spv::Capability capability);
+
// Appends an extension instruction to this module.
inline void AddExtension(const std::string& ext_name);
inline void AddExtension(std::unique_ptr<Instruction>&& e);
+ // Removes instruction declaring `extension` from this module.
+ // Returns true if the extension was removed, false otherwise.
+ bool RemoveExtension(Extension extension);
+
// Appends an extended instruction set instruction to this module.
inline void AddExtInstImport(const std::string& name);
inline void AddExtInstImport(std::unique_ptr<Instruction>&& e);
// Set the memory model for this module.
inline void SetMemoryModel(std::unique_ptr<Instruction>&& m);
+ // Get the memory model for this module.
+ inline const Instruction* GetMemoryModel() const;
// Appends an entry point instruction to this module.
inline void AddEntryPoint(std::unique_ptr<Instruction>&& e);
// Appends an execution mode instruction to this module.
@@ -238,6 +254,8 @@
inline void AddType(std::unique_ptr<Instruction>&& t);
// Appends a constant, global variable, or OpUndef instruction to this module.
inline void AddGlobalValue(std::unique_ptr<Instruction>&& v);
+ // Prepends a function declaration to this module.
+ inline void AddFunctionDeclaration(std::unique_ptr<Function>&& f);
// Appends a function to this module.
inline void AddFunction(std::unique_ptr<Function>&& f);
@@ -422,6 +440,15 @@
// instruction exists.
Instruction* KillInst(Instruction* inst);
+ // Deletes all the instruction in the range [`begin`; `end`[, for which the
+ // unary predicate `condition` returned true.
+ // Returns true if at least one instruction was removed, false otherwise.
+ //
+ // Pointer and iterator pointing to the deleted instructions become invalid.
+ // However other pointers and iterators are still valid.
+ bool KillInstructionIf(Module::inst_iterator begin, Module::inst_iterator end,
+ std::function<bool(Instruction*)> condition);
+
// Collects the non-semantic instruction tree that uses |inst|'s result id
// to be killed later.
void CollectNonSemanticTree(Instruction* inst,
@@ -772,7 +799,8 @@
// Analyzes the features in the owned module. Builds the manager if required.
void AnalyzeFeatures() {
- feature_mgr_ = MakeUnique<FeatureManager>(grammar_);
+ feature_mgr_ =
+ std::unique_ptr<FeatureManager>(new FeatureManager(grammar_));
feature_mgr_->Analyze(module());
}
@@ -964,6 +992,22 @@
return ((const Module*)module())->capabilities();
}
+Module::inst_iterator IRContext::extension_begin() {
+ return module()->extension_begin();
+}
+
+Module::inst_iterator IRContext::extension_end() {
+ return module()->extension_end();
+}
+
+IteratorRange<Module::inst_iterator> IRContext::extensions() {
+ return module()->extensions();
+}
+
+IteratorRange<Module::const_inst_iterator> IRContext::extensions() const {
+ return ((const Module*)module())->extensions();
+}
+
Module::inst_iterator IRContext::types_values_begin() {
return module()->types_values_begin();
}
@@ -1114,6 +1158,10 @@
module()->SetMemoryModel(std::move(m));
}
+const Instruction* IRContext::GetMemoryModel() const {
+ return module()->GetMemoryModel();
+}
+
void IRContext::AddEntryPoint(std::unique_ptr<Instruction>&& e) {
module()->AddEntryPoint(std::move(e));
}
@@ -1173,6 +1221,10 @@
module()->AddGlobalValue(std::move(v));
}
+void IRContext::AddFunctionDeclaration(std::unique_ptr<Function>&& f) {
+ module()->AddFunctionDeclaration(std::move(f));
+}
+
void IRContext::AddFunction(std::unique_ptr<Function>&& f) {
module()->AddFunction(std::move(f));
}
diff --git a/third_party/SPIRV-Tools/source/opt/local_access_chain_convert_pass.cpp b/third_party/SPIRV-Tools/source/opt/local_access_chain_convert_pass.cpp
index 6ec0c2d..7ba75cb 100644
--- a/third_party/SPIRV-Tools/source/opt/local_access_chain_convert_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/local_access_chain_convert_pass.cpp
@@ -420,13 +420,16 @@
"SPV_EXT_demote_to_helper_invocation", "SPV_EXT_descriptor_indexing",
"SPV_NV_fragment_shader_barycentric",
"SPV_NV_compute_shader_derivatives", "SPV_NV_shader_image_footprint",
- "SPV_NV_shading_rate", "SPV_NV_mesh_shader", "SPV_NV_ray_tracing",
- "SPV_KHR_ray_tracing", "SPV_KHR_ray_query",
+ "SPV_NV_shading_rate", "SPV_NV_mesh_shader", "SPV_EXT_mesh_shader",
+ "SPV_NV_ray_tracing", "SPV_KHR_ray_tracing", "SPV_KHR_ray_query",
"SPV_EXT_fragment_invocation_density", "SPV_KHR_terminate_invocation",
"SPV_KHR_subgroup_uniform_control_flow", "SPV_KHR_integer_dot_product",
"SPV_EXT_shader_image_int64", "SPV_KHR_non_semantic_info",
"SPV_KHR_uniform_group_instructions",
- "SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model"});
+ "SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
+ "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
+ "SPV_EXT_fragment_shader_interlock",
+ "SPV_NV_compute_shader_derivatives"});
}
bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
diff --git a/third_party/SPIRV-Tools/source/opt/local_single_block_elim_pass.cpp b/third_party/SPIRV-Tools/source/opt/local_single_block_elim_pass.cpp
index 063d1b9..d7a9295 100644
--- a/third_party/SPIRV-Tools/source/opt/local_single_block_elim_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/local_single_block_elim_pass.cpp
@@ -273,11 +273,13 @@
"SPV_NV_shader_image_footprint",
"SPV_NV_shading_rate",
"SPV_NV_mesh_shader",
+ "SPV_EXT_mesh_shader",
"SPV_NV_ray_tracing",
"SPV_KHR_ray_tracing",
"SPV_KHR_ray_query",
"SPV_EXT_fragment_invocation_density",
"SPV_EXT_physical_storage_buffer",
+ "SPV_KHR_physical_storage_buffer",
"SPV_KHR_terminate_invocation",
"SPV_KHR_subgroup_uniform_control_flow",
"SPV_KHR_integer_dot_product",
@@ -285,7 +287,11 @@
"SPV_KHR_non_semantic_info",
"SPV_KHR_uniform_group_instructions",
"SPV_KHR_fragment_shader_barycentric",
- "SPV_KHR_vulkan_memory_model"});
+ "SPV_KHR_vulkan_memory_model",
+ "SPV_NV_bindless_texture",
+ "SPV_EXT_shader_atomic_float_add",
+ "SPV_EXT_fragment_shader_interlock",
+ "SPV_NV_compute_shader_derivatives"});
}
} // namespace opt
diff --git a/third_party/SPIRV-Tools/source/opt/local_single_store_elim_pass.cpp b/third_party/SPIRV-Tools/source/opt/local_single_store_elim_pass.cpp
index a0de44c..7cd6b0e 100644
--- a/third_party/SPIRV-Tools/source/opt/local_single_store_elim_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/local_single_store_elim_pass.cpp
@@ -124,10 +124,12 @@
"SPV_NV_shader_image_footprint",
"SPV_NV_shading_rate",
"SPV_NV_mesh_shader",
+ "SPV_EXT_mesh_shader",
"SPV_NV_ray_tracing",
"SPV_KHR_ray_query",
"SPV_EXT_fragment_invocation_density",
"SPV_EXT_physical_storage_buffer",
+ "SPV_KHR_physical_storage_buffer",
"SPV_KHR_terminate_invocation",
"SPV_KHR_subgroup_uniform_control_flow",
"SPV_KHR_integer_dot_product",
@@ -135,7 +137,11 @@
"SPV_KHR_non_semantic_info",
"SPV_KHR_uniform_group_instructions",
"SPV_KHR_fragment_shader_barycentric",
- "SPV_KHR_vulkan_memory_model"});
+ "SPV_KHR_vulkan_memory_model",
+ "SPV_NV_bindless_texture",
+ "SPV_EXT_shader_atomic_float_add",
+ "SPV_EXT_fragment_shader_interlock",
+ "SPV_NV_compute_shader_derivatives"});
}
bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
std::vector<Instruction*> users;
diff --git a/third_party/SPIRV-Tools/source/opt/log.h b/third_party/SPIRV-Tools/source/opt/log.h
index 6805100..4fb66fd 100644
--- a/third_party/SPIRV-Tools/source/opt/log.h
+++ b/third_party/SPIRV-Tools/source/opt/log.h
@@ -23,7 +23,7 @@
#include "spirv-tools/libspirv.hpp"
// Asserts the given condition is true. Otherwise, sends a message to the
-// consumer and exits the problem with failure code. Accepts the following
+// consumer and exits the program with failure code. Accepts the following
// formats:
//
// SPIRV_ASSERT(<message-consumer>, <condition-expression>);
@@ -36,7 +36,9 @@
#if !defined(NDEBUG)
#define SPIRV_ASSERT(consumer, ...) SPIRV_ASSERT_IMPL(consumer, __VA_ARGS__)
#else
-#define SPIRV_ASSERT(consumer, ...)
+// Adding a use to avoid errors in the release build related to unused
+// consumers.
+#define SPIRV_ASSERT(consumer, ...) (void)(consumer)
#endif
// Logs a debug message to the consumer. Accepts the following formats:
@@ -49,26 +51,11 @@
#if !defined(NDEBUG) && defined(SPIRV_LOG_DEBUG)
#define SPIRV_DEBUG(consumer, ...) SPIRV_DEBUG_IMPL(consumer, __VA_ARGS__)
#else
-#define SPIRV_DEBUG(consumer, ...)
+// Adding a use to avoid errors in the release build related to unused
+// consumers.
+#define SPIRV_DEBUG(consumer, ...) (void)(consumer)
#endif
-// Logs an error message to the consumer saying the given feature is
-// unimplemented.
-#define SPIRV_UNIMPLEMENTED(consumer, feature) \
- do { \
- spvtools::Log(consumer, SPV_MSG_INTERNAL_ERROR, __FILE__, \
- {static_cast<size_t>(__LINE__), 0, 0}, \
- "unimplemented: " feature); \
- } while (0)
-
-// Logs an error message to the consumer saying the code location
-// should be unreachable.
-#define SPIRV_UNREACHABLE(consumer) \
- do { \
- spvtools::Log(consumer, SPV_MSG_INTERNAL_ERROR, __FILE__, \
- {static_cast<size_t>(__LINE__), 0, 0}, "unreachable"); \
- } while (0)
-
// Helper macros for concatenating arguments.
#define SPIRV_CONCATENATE(a, b) SPIRV_CONCATENATE_(a, b)
#define SPIRV_CONCATENATE_(a, b) a##b
diff --git a/third_party/SPIRV-Tools/source/opt/mem_pass.cpp b/third_party/SPIRV-Tools/source/opt/mem_pass.cpp
index 9f95785..9972c4f 100644
--- a/third_party/SPIRV-Tools/source/opt/mem_pass.cpp
+++ b/third_party/SPIRV-Tools/source/opt/mem_pass.cpp
@@ -76,6 +76,11 @@
bool MemPass::IsPtr(uint32_t ptrId) {
uint32_t varId = ptrId;
Instruction* ptrInst = get_def_use_mgr()->GetDef(varId);
+ if (ptrInst->opcode() == spv::Op::OpFunction) {
+    // A function is not a pointer, but its return type could be, which would
+    // erroneously lead to this function returning true later on.
+ return false;
+ }
while (ptrInst->opcode() == spv::Op::OpCopyObject) {
varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
ptrInst = get_def_use_mgr()->GetDef(varId);
diff --git a/third_party/SPIRV-Tools/source/opt/modify_maximal_reconvergence.cpp b/third_party/SPIRV-Tools/source/opt/modify_maximal_reconvergence.cpp
new file mode 100644
index 0000000..dd79b62
--- /dev/null
+++ b/third_party/SPIRV-Tools/source/opt/modify_maximal_reconvergence.cpp
@@ -0,0 +1,114 @@
+// Copyright (c) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "modify_maximal_reconvergence.h"
+
+#include <initializer_list>
+#include <unordered_set>
+#include <vector>
+
+#include "source/opt/ir_context.h"
+#include "source/util/make_unique.h"
+
+namespace spvtools {
+namespace opt {
+
+// Runs the add or the remove variant, as selected at construction time.
+Pass::Status ModifyMaximalReconvergence::Process() {
+  const bool changed =
+      add_ ? AddMaximalReconvergence() : RemoveMaximalReconvergence();
+  return changed ? Pass::Status::SuccessWithChange
+                 : Pass::Status::SuccessWithoutChange;
+}
+
+// Adds the MaximallyReconvergesKHR execution mode to every entry point that
+// does not already have it, adding the SPV_KHR_maximal_reconvergence extension
+// and the Shader capability the first time they are needed. Returns true if
+// the module was modified.
+bool ModifyMaximalReconvergence::AddMaximalReconvergence() {
+  bool changed = false;
+  bool has_extension = false;
+  bool has_shader =
+      context()->get_feature_mgr()->HasCapability(spv::Capability::Shader);
+  // Iterate by reference to avoid copying every Instruction visited.
+  for (const auto& extension : context()->extensions()) {
+    if (extension.GetOperand(0).AsString() == "SPV_KHR_maximal_reconvergence") {
+      has_extension = true;
+      break;
+    }
+  }
+
+  // Ids of entry points that already carry the mode (in-operand 0 of the
+  // execution-mode instruction), so the mode is never added twice.
+  std::unordered_set<uint32_t> entry_points_with_mode;
+  for (const auto& mode : get_module()->execution_modes()) {
+    if (spv::ExecutionMode(mode.GetSingleWordInOperand(1)) ==
+        spv::ExecutionMode::MaximallyReconvergesKHR) {
+      entry_points_with_mode.insert(mode.GetSingleWordInOperand(0));
+    }
+  }
+
+  for (const auto& entry_point : get_module()->entry_points()) {
+    // In-operand 1 of OpEntryPoint is the entry point's function id.
+    const uint32_t id = entry_point.GetSingleWordInOperand(1);
+    if (!entry_points_with_mode.count(id)) {
+      changed = true;
+      if (!has_extension) {
+        context()->AddExtension("SPV_KHR_maximal_reconvergence");
+        has_extension = true;
+      }
+      if (!has_shader) {
+        context()->AddCapability(spv::Capability::Shader);
+        has_shader = true;
+      }
+      context()->AddExecutionMode(MakeUnique<Instruction>(
+          context(), spv::Op::OpExecutionMode, 0, 0,
+          std::initializer_list<Operand>{
+              {SPV_OPERAND_TYPE_ID, {id}},
+              {SPV_OPERAND_TYPE_EXECUTION_MODE,
+               {static_cast<uint32_t>(
+                   spv::ExecutionMode::MaximallyReconvergesKHR)}}}));
+      // Record the id so that another OpEntryPoint naming the same function
+      // does not receive a duplicate execution mode.
+      entry_points_with_mode.insert(id);
+    }
+  }
+
+  return changed;
+}
+
+// Removes every MaximallyReconvergesKHR execution mode and the
+// SPV_KHR_maximal_reconvergence extension. Returns true if the module was
+// modified.
+bool ModifyMaximalReconvergence::RemoveMaximalReconvergence() {
+  // Collect the matching instructions first and kill them afterwards, so the
+  // execution-mode list is never modified while it is being iterated.
+  std::vector<Instruction*> to_remove;
+  for (auto& mode : get_module()->execution_modes()) {
+    if (spv::ExecutionMode(mode.GetSingleWordInOperand(1)) ==
+        spv::ExecutionMode::MaximallyReconvergesKHR) {
+      to_remove.push_back(&mode);
+    }
+  }
+  for (Instruction* mode : to_remove) {
+    context()->KillInst(mode);
+  }
+
+  bool changed = !to_remove.empty();
+  changed |=
+      context()->RemoveExtension(Extension::kSPV_KHR_maximal_reconvergence);
+  return changed;
+}
+}  // namespace opt
+}  // namespace spvtools
diff --git a/third_party/SPIRV-Tools/source/opt/modify_maximal_reconvergence.h b/third_party/SPIRV-Tools/source/opt/modify_maximal_reconvergence.h
new file mode 100644
index 0000000..8d9a698
--- /dev/null
+++ b/third_party/SPIRV-Tools/source/opt/modify_maximal_reconvergence.h
@@ -0,0 +1,53 @@
+// Copyright (c) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_MODIFY_MAXIMAL_RECONVERGENCE_H_
+#define LIBSPIRV_OPT_MODIFY_MAXIMAL_RECONVERGENCE_H_
+
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// Adds or removes the MaximallyReconvergesKHR execution mode on entry points.
+//
+// In add mode, every entry point in the module ends up with the execution
+// mode; in remove mode, it is stripped from all of them. When adding, the pass
+// does not check whether ray tracing invocation repack instructions might be
+// executed, because that is a runtime restriction; it is left to the user.
+class ModifyMaximalReconvergence : public Pass {
+ public:
+  const char* name() const override { return "modify-maximal-reconvergence"; }
+  Status Process() override;
+
+  explicit ModifyMaximalReconvergence(bool add = true) : Pass(), add_(add) {}
+
+  IRContext::Analysis GetPreservedAnalyses() override {
+    return IRContext::kAnalysisDefUse |
+           IRContext::kAnalysisInstrToBlockMapping |
+           IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators |
+           IRContext::kAnalysisCFG | IRContext::kAnalysisNameMap |
+           IRContext::kAnalysisConstants | IRContext::kAnalysisTypes;
+  }
+
+ private:
+  bool AddMaximalReconvergence();     // Returns true if the module changed.
+  bool RemoveMaximalReconvergence();  // Returns true if the module changed.
+
+  bool add_;  // true: add the execution mode; false: remove it.
+};
+} // namespace opt
+} // namespace spvtools
+
+#endif // LIBSPIRV_OPT_MODIFY_MAXIMAL_RECONVERGENCE_H_
diff --git a/third_party/SPIRV-Tools/source/opt/module.h b/third_party/SPIRV-Tools/source/opt/module.h
index ed2f345..98c16dc 100644
--- a/third_party/SPIRV-Tools/source/opt/module.h
+++ b/third_party/SPIRV-Tools/source/opt/module.h
@@ -17,6 +17,7 @@
#include <functional>
#include <memory>
+#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>
@@ -119,6 +120,9 @@
// Appends a constant, global variable, or OpUndef instruction to this module.
inline void AddGlobalValue(std::unique_ptr<Instruction> v);
+ // Prepends a function declaration to this module.
+ inline void AddFunctionDeclaration(std::unique_ptr<Function> f);
+
// Appends a function to this module.
inline void AddFunction(std::unique_ptr<Function> f);
@@ -379,6 +383,11 @@
types_values_.push_back(std::move(v));
}
+inline void Module::AddFunctionDeclaration(std::unique_ptr<Function> f) {
+ // function declarations must come before function definitions.
+ functions_.emplace(functions_.begin(), std::move(f));
+}
+
inline void Module::AddFunction(std::unique_ptr<Function> f) {
functions_.emplace_back(std::move(f));
}
diff --git a/third_party/SPIRV-Tools/source/opt/optimizer.cpp b/third_party/SPIRV-Tools/source/opt/optimizer.cpp
index 46a92dd..c4c2b0f 100644
--- a/third_party/SPIRV-Tools/source/opt/optimizer.cpp
+++ b/third_party/SPIRV-Tools/source/opt/optimizer.cpp
@@ -15,6 +15,7 @@
#include "spirv-tools/optimizer.hpp"
#include <cassert>
+#include <charconv>
#include <memory>
#include <string>
#include <unordered_map>
@@ -32,6 +33,15 @@
namespace spvtools {
+// Copies |string_count| C strings from |strings| into std::string objects.
+// |strings| may be null only when |string_count| is 0.
+std::vector<std::string> GetVectorOfStrings(const char** strings,
+                                            const size_t string_count) {
+  // Range construction allocates once and needs no loop counter, so there is
+  // no 32-bit index that could wrap for very large |string_count| values.
+  return std::vector<std::string>(strings, strings + string_count);
+}
+
struct Optimizer::PassToken::Impl {
Impl(std::unique_ptr<opt::Pass> p) : pass(std::move(p)) {}
@@ -109,7 +119,7 @@
// The legalization problem is essentially a very general copy propagation
// problem. The optimization we use are all used to either do copy propagation
// or enable more copy propagation.
-Optimizer& Optimizer::RegisterLegalizationPasses() {
+Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface) {
return
// Wrap OpKill instructions so all other code can be inlined.
RegisterPass(CreateWrapOpKillPass())
@@ -129,16 +139,16 @@
// Propagate the value stored to the loads in very simple cases.
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
.RegisterPass(CreateLocalSingleStoreElimPass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
// Split up aggregates so they are easier to deal with.
.RegisterPass(CreateScalarReplacementPass(0))
// Remove loads and stores so everything is in intermediate values.
// Takes care of copy propagation of non-members.
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
.RegisterPass(CreateLocalSingleStoreElimPass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateLocalMultiStoreElimPass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
// Propagate constants to get as many constant conditions on branches
// as possible.
.RegisterPass(CreateCCPPass())
@@ -147,7 +157,7 @@
// Copy propagate members. Cleans up code sequences generated by
// scalar replacement. Also important for removing OpPhi nodes.
.RegisterPass(CreateSimplificationPass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateCopyPropagateArraysPass())
// May need loop unrolling here see
// https://github.com/Microsoft/DirectXShaderCompiler/pull/930
@@ -156,30 +166,36 @@
.RegisterPass(CreateVectorDCEPass())
.RegisterPass(CreateDeadInsertElimPass())
.RegisterPass(CreateReduceLoadSizePass())
- .RegisterPass(CreateAggressiveDCEPass())
- .RegisterPass(CreateInterpolateFixupPass());
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
+ .RegisterPass(CreateRemoveUnusedInterfaceVariablesPass())
+ .RegisterPass(CreateInterpolateFixupPass())
+ .RegisterPass(CreateInvocationInterlockPlacementPass());
}
-Optimizer& Optimizer::RegisterPerformancePasses() {
+Optimizer& Optimizer::RegisterLegalizationPasses() {
+ return RegisterLegalizationPasses(false);
+}
+
+Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) {
return RegisterPass(CreateWrapOpKillPass())
.RegisterPass(CreateDeadBranchElimPass())
.RegisterPass(CreateMergeReturnPass())
.RegisterPass(CreateInlineExhaustivePass())
.RegisterPass(CreateEliminateDeadFunctionsPass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreatePrivateToLocalPass())
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
.RegisterPass(CreateLocalSingleStoreElimPass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateScalarReplacementPass())
.RegisterPass(CreateLocalAccessChainConvertPass())
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
.RegisterPass(CreateLocalSingleStoreElimPass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateLocalMultiStoreElimPass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateCCPPass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateLoopUnrollPass(true))
.RegisterPass(CreateDeadBranchElimPass())
.RegisterPass(CreateRedundancyEliminationPass())
@@ -189,9 +205,9 @@
.RegisterPass(CreateLocalAccessChainConvertPass())
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
.RegisterPass(CreateLocalSingleStoreElimPass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateSSARewritePass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateVectorDCEPass())
.RegisterPass(CreateDeadInsertElimPass())
.RegisterPass(CreateDeadBranchElimPass())
@@ -199,7 +215,7 @@
.RegisterPass(CreateIfConversionPass())
.RegisterPass(CreateCopyPropagateArraysPass())
.RegisterPass(CreateReduceLoadSizePass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateBlockMergePass())
.RegisterPass(CreateRedundancyEliminationPass())
.RegisterPass(CreateDeadBranchElimPass())
@@ -207,7 +223,11 @@
.RegisterPass(CreateSimplificationPass());
}
-Optimizer& Optimizer::RegisterSizePasses() {
+Optimizer& Optimizer::RegisterPerformancePasses() {
+ return RegisterPerformancePasses(false);
+}
+
+Optimizer& Optimizer::RegisterSizePasses(bool preserve_interface) {
return RegisterPass(CreateWrapOpKillPass())
.RegisterPass(CreateDeadBranchElimPass())
.RegisterPass(CreateMergeReturnPass())
@@ -224,12 +244,12 @@
.RegisterPass(CreateLocalSingleStoreElimPass())
.RegisterPass(CreateIfConversionPass())
.RegisterPass(CreateSimplificationPass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateDeadBranchElimPass())
.RegisterPass(CreateBlockMergePass())
.RegisterPass(CreateLocalAccessChainConvertPass())
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateCopyPropagateArraysPass())
.RegisterPass(CreateVectorDCEPass())
.RegisterPass(CreateDeadInsertElimPass())
@@ -239,13 +259,20 @@
.RegisterPass(CreateLocalMultiStoreElimPass())
.RegisterPass(CreateRedundancyEliminationPass())
.RegisterPass(CreateSimplificationPass())
- .RegisterPass(CreateAggressiveDCEPass())
+ .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateCFGCleanupPass());
}
+Optimizer& Optimizer::RegisterSizePasses() { return RegisterSizePasses(false); }
+
bool Optimizer::RegisterPassesFromFlags(const std::vector<std::string>& flags) {
+ return RegisterPassesFromFlags(flags, false);
+}
+
+bool Optimizer::RegisterPassesFromFlags(const std::vector<std::string>& flags,
+ bool preserve_interface) {
for (const auto& flag : flags) {
- if (!RegisterPassFromFlag(flag)) {
+ if (!RegisterPassFromFlag(flag, preserve_interface)) {
return false;
}
}
@@ -269,6 +296,11 @@
}
bool Optimizer::RegisterPassFromFlag(const std::string& flag) {
+ return RegisterPassFromFlag(flag, false);
+}
+
+bool Optimizer::RegisterPassFromFlag(const std::string& flag,
+ bool preserve_interface) {
if (!FlagHasValidForm(flag)) {
return false;
}
@@ -330,7 +362,7 @@
} else if (pass_name == "descriptor-scalar-replacement") {
RegisterPass(CreateDescriptorScalarReplacementPass());
} else if (pass_name == "eliminate-dead-code-aggressive") {
- RegisterPass(CreateAggressiveDCEPass());
+ RegisterPass(CreateAggressiveDCEPass(preserve_interface));
} else if (pass_name == "eliminate-insert-extract") {
RegisterPass(CreateInsertExtractElimPass());
} else if (pass_name == "eliminate-local-single-block") {
@@ -419,32 +451,26 @@
RegisterPass(CreateWorkaround1209Pass());
} else if (pass_name == "replace-invalid-opcode") {
RegisterPass(CreateReplaceInvalidOpcodePass());
- } else if (pass_name == "inst-bindless-check") {
- RegisterPass(CreateInstBindlessCheckPass(7, 23, false, false));
+ } else if (pass_name == "inst-bindless-check" ||
+ pass_name == "inst-desc-idx-check" ||
+ pass_name == "inst-buff-oob-check") {
+ // preserve legacy names
+ RegisterPass(CreateInstBindlessCheckPass(23));
RegisterPass(CreateSimplificationPass());
RegisterPass(CreateDeadBranchElimPass());
RegisterPass(CreateBlockMergePass());
- RegisterPass(CreateAggressiveDCEPass(true));
- } else if (pass_name == "inst-desc-idx-check") {
- RegisterPass(CreateInstBindlessCheckPass(7, 23, true, true));
- RegisterPass(CreateSimplificationPass());
- RegisterPass(CreateDeadBranchElimPass());
- RegisterPass(CreateBlockMergePass());
- RegisterPass(CreateAggressiveDCEPass(true));
- } else if (pass_name == "inst-buff-oob-check") {
- RegisterPass(CreateInstBindlessCheckPass(7, 23, false, false, true, true));
- RegisterPass(CreateSimplificationPass());
- RegisterPass(CreateDeadBranchElimPass());
- RegisterPass(CreateBlockMergePass());
- RegisterPass(CreateAggressiveDCEPass(true));
} else if (pass_name == "inst-buff-addr-check") {
- RegisterPass(CreateInstBuffAddrCheckPass(7, 23));
- RegisterPass(CreateAggressiveDCEPass(true));
+ RegisterPass(CreateInstBuffAddrCheckPass(23));
} else if (pass_name == "convert-relaxed-to-half") {
RegisterPass(CreateConvertRelaxedToHalfPass());
} else if (pass_name == "relax-float-ops") {
RegisterPass(CreateRelaxFloatOpsPass());
} else if (pass_name == "inst-debug-printf") {
+ // This private option is not for user consumption.
+ // It is here to assist in debugging and fixing the debug printf
+ // instrumentation pass.
+ // For users who wish to utilize debug printf, see the white paper at
+ // https://www.lunarg.com/wp-content/uploads/2021/08/Using-Debug-Printf-02August2021.pdf
RegisterPass(CreateInstDebugPrintfPass(7, 23));
} else if (pass_name == "simplify-instructions") {
RegisterPass(CreateSimplificationPass());
@@ -507,11 +533,11 @@
} else if (pass_name == "fix-storage-class") {
RegisterPass(CreateFixStorageClassPass());
} else if (pass_name == "O") {
- RegisterPerformancePasses();
+ RegisterPerformancePasses(preserve_interface);
} else if (pass_name == "Os") {
- RegisterSizePasses();
+ RegisterSizePasses(preserve_interface);
} else if (pass_name == "legalize-hlsl") {
- RegisterLegalizationPasses();
+ RegisterLegalizationPasses(preserve_interface);
} else if (pass_name == "remove-unused-interface-variables") {
RegisterPass(CreateRemoveUnusedInterfaceVariablesPass());
} else if (pass_name == "graphics-robust-access") {
@@ -548,6 +574,58 @@
pass_args.c_str());
return false;
}
+ } else if (pass_name == "switch-descriptorset") {
+ if (pass_args.size() == 0) {
+ Error(consumer(), nullptr, {},
+ "--switch-descriptorset requires a from:to argument.");
+ return false;
+ }
+ uint32_t from_set = 0, to_set = 0;
+ const char* start = pass_args.data();
+ const char* end = pass_args.data() + pass_args.size();
+
+ auto result = std::from_chars(start, end, from_set);
+ if (result.ec != std::errc()) {
+ Errorf(consumer(), nullptr, {},
+ "Invalid argument for --switch-descriptorset: %s",
+ pass_args.c_str());
+ return false;
+ }
+ start = result.ptr;
+ if (start[0] != ':') {
+ Errorf(consumer(), nullptr, {},
+ "Invalid argument for --switch-descriptorset: %s",
+ pass_args.c_str());
+ return false;
+ }
+ start++;
+ result = std::from_chars(start, end, to_set);
+ if (result.ec != std::errc() || result.ptr != end) {
+ Errorf(consumer(), nullptr, {},
+ "Invalid argument for --switch-descriptorset: %s",
+ pass_args.c_str());
+ return false;
+ }
+ RegisterPass(CreateSwitchDescriptorSetPass(from_set, to_set));
+ } else if (pass_name == "modify-maximal-reconvergence") {
+ if (pass_args.size() == 0) {
+ Error(consumer(), nullptr, {},
+ "--modify-maximal-reconvergence requires an argument");
+ return false;
+ }
+ if (pass_args == "add") {
+ RegisterPass(CreateModifyMaximalReconvergencePass(true));
+ } else if (pass_args == "remove") {
+ RegisterPass(CreateModifyMaximalReconvergencePass(false));
+ } else {
+ Errorf(consumer(), nullptr, {},
+ "Invalid argument for --modify-maximal-reconvergence: %s (must be "
+ "'add' or 'remove')",
+ pass_args.c_str());
+ return false;
+ }
+ } else if (pass_name == "trim-capabilities") {
+ RegisterPass(CreateTrimCapabilitiesPass());
} else {
Errorf(consumer(), nullptr, {},
"Unknown flag '--%s'. Use --help for a list of valid flags",
@@ -945,14 +1023,9 @@
MakeUnique<opt::UpgradeMemoryModel>());
}
-Optimizer::PassToken CreateInstBindlessCheckPass(
- uint32_t desc_set, uint32_t shader_id, bool desc_length_enable,
- bool desc_init_enable, bool buff_oob_enable, bool texbuff_oob_enable) {
+Optimizer::PassToken CreateInstBindlessCheckPass(uint32_t shader_id) {
return MakeUnique<Optimizer::PassToken::Impl>(
- MakeUnique<opt::InstBindlessCheckPass>(
- desc_set, shader_id, desc_length_enable, desc_init_enable,
- buff_oob_enable, texbuff_oob_enable,
- desc_length_enable || desc_init_enable || buff_oob_enable));
+ MakeUnique<opt::InstBindlessCheckPass>(shader_id));
}
Optimizer::PassToken CreateInstDebugPrintfPass(uint32_t desc_set,
@@ -961,10 +1034,9 @@
MakeUnique<opt::InstDebugPrintfPass>(desc_set, shader_id));
}
-Optimizer::PassToken CreateInstBuffAddrCheckPass(uint32_t desc_set,
- uint32_t shader_id) {
+Optimizer::PassToken CreateInstBuffAddrCheckPass(uint32_t shader_id) {
return MakeUnique<Optimizer::PassToken::Impl>(
- MakeUnique<opt::InstBuffAddrCheckPass>(desc_set, shader_id));
+ MakeUnique<opt::InstBuffAddrCheckPass>(shader_id));
}
Optimizer::PassToken CreateConvertRelaxedToHalfPass() {
@@ -1074,6 +1146,26 @@
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::FixFuncCallArgumentsPass>());
}
+
+Optimizer::PassToken CreateTrimCapabilitiesPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::TrimCapabilitiesPass>());
+}
+
+Optimizer::PassToken CreateSwitchDescriptorSetPass(uint32_t from, uint32_t to) {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::SwitchDescriptorSetPass>(from, to));
+}
+
+Optimizer::PassToken CreateInvocationInterlockPlacementPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::InvocationInterlockPlacementPass>());
+}
+
+Optimizer::PassToken CreateModifyMaximalReconvergencePass(bool add) {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::ModifyMaximalReconvergence>(add));
+}
} // namespace spvtools
extern "C" {
@@ -1122,13 +1214,19 @@
SPIRV_TOOLS_EXPORT bool spvOptimizerRegisterPassesFromFlags(
spv_optimizer_t* optimizer, const char** flags, const size_t flag_count) {
- std::vector<std::string> opt_flags;
- for (uint32_t i = 0; i < flag_count; i++) {
- opt_flags.emplace_back(flags[i]);
- }
+ std::vector<std::string> opt_flags =
+ spvtools::GetVectorOfStrings(flags, flag_count);
+ return reinterpret_cast<spvtools::Optimizer*>(optimizer)
+ ->RegisterPassesFromFlags(opt_flags, false);
+}
- return reinterpret_cast<spvtools::Optimizer*>(optimizer)->
- RegisterPassesFromFlags(opt_flags);
+SPIRV_TOOLS_EXPORT bool
+spvOptimizerRegisterPassesFromFlagsWhilePreservingTheInterface(
+ spv_optimizer_t* optimizer, const char** flags, const size_t flag_count) {
+ std::vector<std::string> opt_flags =
+ spvtools::GetVectorOfStrings(flags, flag_count);
+ return reinterpret_cast<spvtools::Optimizer*>(optimizer)
+ ->RegisterPassesFromFlags(opt_flags, true);
}
SPIRV_TOOLS_EXPORT
diff --git a/third_party/SPIRV-Tools/source/opt/passes.h b/third_party/SPIRV-Tools/source/opt/passes.h
index eb3b1e5..9d027fb 100644
--- a/third_party/SPIRV-Tools/source/opt/passes.h
+++ b/third_party/SPIRV-Tools/source/opt/passes.h
@@ -53,6 +53,7 @@
#include "source/opt/inst_debug_printf_pass.h"
#include "source/opt/interface_var_sroa.h"
#include "source/opt/interp_fixup_pass.h"
+#include "source/opt/invocation_interlock_placement_pass.h"
#include "source/opt/licm_pass.h"
#include "source/opt/local_access_chain_convert_pass.h"
#include "source/opt/local_redundancy_elimination.h"
@@ -64,6 +65,7 @@
#include "source/opt/loop_unroller.h"
#include "source/opt/loop_unswitch_pass.h"
#include "source/opt/merge_return_pass.h"
+#include "source/opt/modify_maximal_reconvergence.h"
#include "source/opt/null_pass.h"
#include "source/opt/private_to_local_pass.h"
#include "source/opt/reduce_load_size.h"
@@ -82,6 +84,8 @@
#include "source/opt/strength_reduction_pass.h"
#include "source/opt/strip_debug_info_pass.h"
#include "source/opt/strip_nonsemantic_info_pass.h"
+#include "source/opt/switch_descriptorset_pass.h"
+#include "source/opt/trim_capabilities_pass.h"
#include "source/opt/unify_const_pass.h"
#include "source/opt/upgrade_memory_model.h"
#include "source/opt/vector_dce.h"
diff --git a/third_party/SPIRV-Tools/source/opt/switch_descriptorset_pass.cpp b/third_party/SPIRV-Tools/source/opt/switch_descriptorset_pass.cpp
new file mode 100644
index 0000000..f07c917
--- /dev/null
+++ b/third_party/SPIRV-Tools/source/opt/switch_descriptorset_pass.cpp
@@ -0,0 +1,46 @@
+// Copyright (c) 2023 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/switch_descriptorset_pass.h"
+
+#include "source/opt/ir_builder.h"
+#include "source/util/string_utils.h"
+
+namespace spvtools {
+namespace opt {
+
+Pass::Status SwitchDescriptorSetPass::Process() {
+  // Retarget every OpVariable whose DescriptorSet decoration equals ds_from_.
+  bool modified = false;
+  auto* const decoration_mgr = context()->get_decoration_mgr();
+  for (Instruction& inst : context()->types_values()) {
+    if (inst.opcode() != spv::Op::OpVariable) continue;
+    const auto decorations =
+        decoration_mgr->GetDecorationsFor(inst.result_id(), false);
+    for (Instruction* decoration : decorations) {
+      const auto kind = spv::Decoration(decoration->GetSingleWordInOperand(1u));
+      if (kind != spv::Decoration::DescriptorSet ||
+          decoration->GetSingleWordInOperand(2u) != ds_from_) {
+        continue;
+      }
+      decoration->SetInOperand(2u, {ds_to_});
+      modified = true;
+      break;  // Only the first matching decoration per variable is rewritten.
+    }
+  }
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/third_party/SPIRV-Tools/source/opt/switch_descriptorset_pass.h b/third_party/SPIRV-Tools/source/opt/switch_descriptorset_pass.h
new file mode 100644
index 0000000..2084e9c
--- /dev/null
+++ b/third_party/SPIRV-Tools/source/opt/switch_descriptorset_pass.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2023 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstdio>
+#include <memory>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+//
+// Moves resource variables from one descriptor set number to another by
+// editing their DescriptorSet decorations.
+class SwitchDescriptorSetPass : public Pass {
+ public:
+  // `ds_from`: the set number to replace; `ds_to`: its replacement.
+  SwitchDescriptorSetPass(uint32_t ds_from, uint32_t ds_to)
+      : ds_from_(ds_from), ds_to_(ds_to) {}
+
+  const char* name() const override { return "switch-descriptorset"; }
+
+  Status Process() override;
+
+  // Only decoration instructions are edited, so every analysis other than
+  // the decoration analysis is reported as preserved.
+  IRContext::Analysis GetPreservedAnalyses() override {
+    const uint32_t every_analysis =
+        static_cast<uint32_t>((IRContext::kAnalysisEnd << 1) - 1);
+    const uint32_t decorations =
+        static_cast<uint32_t>(IRContext::kAnalysisDecorations);
+    return static_cast<IRContext::Analysis>(every_analysis & ~decorations);
+  }
+
+ private:
+  uint32_t ds_from_;  // set number being replaced
+  uint32_t ds_to_;    // replacement set number
+};
+
+} // namespace opt
+} // namespace spvtools
diff --git a/third_party/SPIRV-Tools/source/opt/trim_capabilities_pass.cpp b/third_party/SPIRV-Tools/source/opt/trim_capabilities_pass.cpp
new file mode 100644
index 0000000..24f9e46
--- /dev/null
+++ b/third_party/SPIRV-Tools/source/opt/trim_capabilities_pass.cpp
@@ -0,0 +1,649 @@
+// Copyright (c) 2023 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/trim_capabilities_pass.h"
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <functional>
+#include <optional>
+#include <queue>
+#include <stack>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "source/enum_set.h"
+#include "source/enum_string_mapping.h"
+#include "source/opt/ir_context.h"
+#include "source/opt/reflect.h"
+#include "source/spirv_target_env.h"
+#include "source/util/string_utils.h"
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+// In-operand indices used by the handlers below.
+// OpTypeInt / OpTypeFloat: the bit width comes first.
+constexpr uint32_t kOpTypeScalarBitWidthIndex = 0;
+constexpr uint32_t kOpTypeIntSizeIndex = 0;
+constexpr uint32_t kOpTypeFloatSizeIndex = 0;
+// OpTypePointer: storage class, then pointee type.
+constexpr uint32_t kOpTypePointerStorageClassIndex = 0;
+constexpr uint32_t kTypePointerTypeIdInIndex = 1;
+// OpTypeVector / OpTypeMatrix / OpTypeArray / OpTypeRuntimeArray: element
+// (or column) type.
+constexpr uint32_t kTypeArrayTypeIndex = 0;
+// OpTypeImage: Sampled Type, Dim, Depth, Arrayed, MS, Sampled, Image Format.
+constexpr uint32_t kOpTypeImageDimIndex = 1;
+constexpr uint32_t kOpTypeImageArrayedIndex = 3;
+constexpr uint32_t kOpTypeImageMSIndex = 4;
+constexpr uint32_t kOpTypeImageSampledIndex = 5;
+constexpr uint32_t kOpTypeImageFormatIndex = 6;
+// OpImageRead / OpImageSparseRead: the image operand comes first.
+constexpr uint32_t kOpImageReadImageIndex = 0;
+constexpr uint32_t kOpImageSparseReadImageIndex = 0;
+
+// Depth-first walk over the type graph rooted at `instruction`.
+// `condition` is called once per distinct type id reached. If it returns true,
+// the children of that node (pointee, element/column, or member types) are
+// scheduled; if it returns false, that subtree is pruned.
+//
+// Type graphs can be cyclic: OpTypeForwardPointer lets a struct contain a
+// pointer to itself. Each id is therefore visited at most once; without this
+// the walk never terminates on such modules.
+template <class UnaryPredicate>
+static void DFSWhile(const Instruction* instruction, UnaryPredicate condition) {
+  std::stack<uint32_t> instructions_to_visit;
+  std::unordered_set<uint32_t> visited;
+  instructions_to_visit.push(instruction->result_id());
+  const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
+
+  while (!instructions_to_visit.empty()) {
+    const uint32_t id = instructions_to_visit.top();
+    instructions_to_visit.pop();
+
+    // Already handled: breaks cycles and skips types reachable twice.
+    if (!visited.insert(id).second) {
+      continue;
+    }
+
+    const Instruction* item = def_use_mgr->GetDef(id);
+    if (!condition(item)) {
+      continue;
+    }
+
+    if (item->opcode() == spv::Op::OpTypePointer) {
+      instructions_to_visit.push(
+          item->GetSingleWordInOperand(kTypePointerTypeIdInIndex));
+      continue;
+    }
+
+    if (item->opcode() == spv::Op::OpTypeMatrix ||
+        item->opcode() == spv::Op::OpTypeVector ||
+        item->opcode() == spv::Op::OpTypeArray ||
+        item->opcode() == spv::Op::OpTypeRuntimeArray) {
+      instructions_to_visit.push(
+          item->GetSingleWordInOperand(kTypeArrayTypeIndex));
+      continue;
+    }
+
+    if (item->opcode() == spv::Op::OpTypeStruct) {
+      item->ForEachInOperand([&instructions_to_visit](const uint32_t* op_id) {
+        instructions_to_visit.push(*op_id);
+      });
+      continue;
+    }
+  }
+}
+
+// Returns true if `predicate` holds for the type defined by `instruction` or
+// for any type reachable from it (pointee, element and member types).
+// `instruction` must be an OpType* instruction.
+template <class UnaryPredicate>
+static bool AnyTypeOf(const Instruction* instruction,
+                      UnaryPredicate predicate) {
+  assert(IsTypeInst(instruction->opcode()) &&
+         "AnyTypeOf called with a non-type instruction.");
+
+  bool matched = false;
+  auto visit = [&matched, predicate](const Instruction* node) {
+    // Once a match is recorded, prune everything that remains.
+    if (matched) {
+      return false;
+    }
+    if (predicate(node)) {
+      matched = true;
+      return false;
+    }
+    return true;
+  };
+  DFSWhile(instruction, visit);
+  return matched;
+}
+
+// True when `instruction` declares a 16-bit OpTypeInt or OpTypeFloat.
+static bool is16bitType(const Instruction* instruction) {
+  const spv::Op op = instruction->opcode();
+  const bool is_scalar = op == spv::Op::OpTypeInt || op == spv::Op::OpTypeFloat;
+  // Only scalar types carry a width operand, so check the opcode first.
+  return is_scalar &&
+         instruction->GetSingleWordInOperand(kOpTypeScalarBitWidthIndex) == 16;
+}
+
+// True when the module declares the Float16 or the Int16 capability. The
+// 16-bit storage handlers use this as an early-out.
+static bool Has16BitCapability(const FeatureManager* feature_manager) {
+  const CapabilitySet& declared = feature_manager->GetCapabilities();
+  for (const auto capability :
+       {spv::Capability::Float16, spv::Capability::Int16}) {
+    if (declared.contains(capability)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+} // namespace
+
+// ============== Begin opcode handler implementations. =======================
+//
+// Adding support for a new capability should only require adding a new handler,
+// and updating the
+// kSupportedCapabilities/kUntouchableCapabilities/kForbiddenCapabilities lists.
+//
+// Handler names follow the following convention:
+// Handler_<Opcode>_<Capability>()
+
+// A 16-bit OpTypeFloat requires Float16.
+static std::optional<spv::Capability> Handler_OpTypeFloat_Float16(
+    const Instruction* instruction) {
+  assert(instruction->opcode() == spv::Op::OpTypeFloat &&
+         "This handler only support OpTypeFloat opcodes.");
+
+  if (instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex) != 16) {
+    return std::nullopt;
+  }
+  return spv::Capability::Float16;
+}
+
+// A 64-bit OpTypeFloat requires Float64.
+static std::optional<spv::Capability> Handler_OpTypeFloat_Float64(
+    const Instruction* instruction) {
+  assert(instruction->opcode() == spv::Op::OpTypeFloat &&
+         "This handler only support OpTypeFloat opcodes.");
+
+  if (instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex) != 64) {
+    return std::nullopt;
+  }
+  return spv::Capability::Float64;
+}
+
+// An Input/Output pointer whose pointee reaches a 16-bit scalar requires
+// StorageInputOutput16 (only checked when Float16 or Int16 is declared).
+static std::optional<spv::Capability>
+Handler_OpTypePointer_StorageInputOutput16(const Instruction* instruction) {
+  assert(instruction->opcode() == spv::Op::OpTypePointer &&
+         "This handler only support OpTypePointer opcodes.");
+
+  const auto storage_class = spv::StorageClass(
+      instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
+  const bool is_io = storage_class == spv::StorageClass::Input ||
+                     storage_class == spv::StorageClass::Output;
+  if (!is_io ||
+      !Has16BitCapability(instruction->context()->get_feature_mgr())) {
+    return std::nullopt;
+  }
+
+  if (AnyTypeOf(instruction, is16bitType)) {
+    return spv::Capability::StorageInputOutput16;
+  }
+  return std::nullopt;
+}
+
+// A PushConstant pointer whose pointee reaches a 16-bit scalar requires
+// StoragePushConstant16 (only checked when Float16 or Int16 is declared).
+static std::optional<spv::Capability>
+Handler_OpTypePointer_StoragePushConstant16(const Instruction* instruction) {
+  assert(instruction->opcode() == spv::Op::OpTypePointer &&
+         "This handler only support OpTypePointer opcodes.");
+
+  const auto storage_class = spv::StorageClass(
+      instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
+  if (storage_class != spv::StorageClass::PushConstant ||
+      !Has16BitCapability(instruction->context()->get_feature_mgr())) {
+    return std::nullopt;
+  }
+
+  if (AnyTypeOf(instruction, is16bitType)) {
+    return spv::Capability::StoragePushConstant16;
+  }
+  return std::nullopt;
+}
+
+// A Uniform pointer that reaches a BufferBlock-decorated type which itself
+// contains a 16-bit scalar requires StorageUniformBufferBlock16 (only checked
+// when Float16 or Int16 is declared).
+static std::optional<spv::Capability>
+Handler_OpTypePointer_StorageUniformBufferBlock16(
+    const Instruction* instruction) {
+  assert(instruction->opcode() == spv::Op::OpTypePointer &&
+         "This handler only support OpTypePointer opcodes.");
+
+  const auto storage_class = spv::StorageClass(
+      instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
+  if (storage_class != spv::StorageClass::Uniform ||
+      !Has16BitCapability(instruction->context()->get_feature_mgr())) {
+    return std::nullopt;
+  }
+
+  const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
+  auto is_16bit_buffer_block = [decoration_mgr](const Instruction* item) {
+    return decoration_mgr->HasDecoration(item->result_id(),
+                                         spv::Decoration::BufferBlock) &&
+           AnyTypeOf(item, is16bitType);
+  };
+
+  if (!AnyTypeOf(instruction, is_16bit_buffer_block)) {
+    return std::nullopt;
+  }
+  return spv::Capability::StorageUniformBufferBlock16;
+}
+
+// A Uniform pointer whose pointee reaches a 16-bit scalar requires
+// StorageUniform16. When StorageUniformBufferBlock16 is declared, subtrees
+// rooted at BufferBlock-decorated types are not searched, because that
+// capability covers them.
+static std::optional<spv::Capability> Handler_OpTypePointer_StorageUniform16(
+    const Instruction* instruction) {
+  assert(instruction->opcode() == spv::Op::OpTypePointer &&
+         "This handler only support OpTypePointer opcodes.");
+
+  const auto storage_class = spv::StorageClass(
+      instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
+  if (storage_class != spv::StorageClass::Uniform) {
+    return std::nullopt;
+  }
+
+  const auto* feature_manager = instruction->context()->get_feature_mgr();
+  if (!Has16BitCapability(feature_manager)) {
+    return std::nullopt;
+  }
+
+  const bool skip_buffer_blocks = feature_manager->GetCapabilities().contains(
+      spv::Capability::StorageUniformBufferBlock16);
+  const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
+
+  bool has_16bit = false;
+  DFSWhile(instruction, [&](const Instruction* item) {
+    if (has_16bit) {
+      return false;  // Answer already known; stop descending.
+    }
+    if (skip_buffer_blocks &&
+        decoration_mgr->HasDecoration(item->result_id(),
+                                      spv::Decoration::BufferBlock)) {
+      return false;
+    }
+    has_16bit = is16bitType(item);
+    return !has_16bit;
+  });
+
+  if (!has_16bit) {
+    return std::nullopt;
+  }
+  return spv::Capability::StorageUniform16;
+}
+
+// A 16-bit OpTypeInt requires Int16.
+static std::optional<spv::Capability> Handler_OpTypeInt_Int16(
+    const Instruction* instruction) {
+  assert(instruction->opcode() == spv::Op::OpTypeInt &&
+         "This handler only support OpTypeInt opcodes.");
+
+  if (instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex) != 16) {
+    return std::nullopt;
+  }
+  return spv::Capability::Int16;
+}
+
+// A 64-bit OpTypeInt requires Int64.
+static std::optional<spv::Capability> Handler_OpTypeInt_Int64(
+    const Instruction* instruction) {
+  assert(instruction->opcode() == spv::Op::OpTypeInt &&
+         "This handler only support OpTypeInt opcodes.");
+
+  if (instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex) != 64) {
+    return std::nullopt;
+  }
+  return spv::Capability::Int64;
+}
+
+// An image type that is arrayed, multisampled and has Sampled == 2 (used
+// without a sampler) requires ImageMSArray.
+static std::optional<spv::Capability> Handler_OpTypeImage_ImageMSArray(
+    const Instruction* instruction) {
+  assert(instruction->opcode() == spv::Op::OpTypeImage &&
+         "This handler only support OpTypeImage opcodes.");
+
+  const bool is_arrayed =
+      instruction->GetSingleWordInOperand(kOpTypeImageArrayedIndex) == 1;
+  const bool is_multisampled =
+      instruction->GetSingleWordInOperand(kOpTypeImageMSIndex) == 1;
+  const bool without_sampler =
+      instruction->GetSingleWordInOperand(kOpTypeImageSampledIndex) == 2;
+
+  if (is_arrayed && is_multisampled && without_sampler) {
+    return spv::Capability::ImageMSArray;
+  }
+  return std::nullopt;
+}
+
+// OpImageRead on an image whose format is Unknown requires
+// StorageImageReadWithoutFormat, except for SubpassData images.
+static std::optional<spv::Capability>
+Handler_OpImageRead_StorageImageReadWithoutFormat(
+    const Instruction* instruction) {
+  assert(instruction->opcode() == spv::Op::OpImageRead &&
+         "This handler only support OpImageRead opcodes.");
+  const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
+
+  const Instruction* image = def_use_mgr->GetDef(
+      instruction->GetSingleWordInOperand(kOpImageReadImageIndex));
+  const Instruction* image_type = def_use_mgr->GetDef(image->type_id());
+  const auto dim =
+      spv::Dim(image_type->GetSingleWordInOperand(kOpTypeImageDimIndex));
+  const auto format = spv::ImageFormat(
+      image_type->GetSingleWordInOperand(kOpTypeImageFormatIndex));
+
+  if (format != spv::ImageFormat::Unknown || dim == spv::Dim::SubpassData) {
+    return std::nullopt;
+  }
+  return spv::Capability::StorageImageReadWithoutFormat;
+}
+
+// OpImageSparseRead on an image whose format is Unknown requires
+// StorageImageReadWithoutFormat.
+static std::optional<spv::Capability>
+Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
+    const Instruction* instruction) {
+  assert(instruction->opcode() == spv::Op::OpImageSparseRead &&
+         "This handler only support OpImageSparseRead opcodes.");
+  const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
+
+  const Instruction* image = def_use_mgr->GetDef(
+      instruction->GetSingleWordInOperand(kOpImageSparseReadImageIndex));
+  const Instruction* image_type = def_use_mgr->GetDef(image->type_id());
+  const auto format = spv::ImageFormat(
+      image_type->GetSingleWordInOperand(kOpTypeImageFormatIndex));
+
+  if (format != spv::ImageFormat::Unknown) {
+    return std::nullopt;
+  }
+  return spv::Capability::StorageImageReadWithoutFormat;
+}
+
+// Opcodes of interest to determine capability requirements, each paired with
+// one handler. An opcode may appear several times (once per capability it can
+// imply), but each handler must be listed only once, or it runs repeatedly on
+// every matching instruction. Keep the array size in sync with the entries:
+// a larger size would value-initialize trailing {opcode 0, nullptr} entries.
+constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 11> kOpcodeHandlers{{
+    // clang-format off
+    {spv::Op::OpImageRead,       Handler_OpImageRead_StorageImageReadWithoutFormat},
+    {spv::Op::OpImageSparseRead, Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
+    {spv::Op::OpTypeFloat,       Handler_OpTypeFloat_Float16 },
+    {spv::Op::OpTypeFloat,       Handler_OpTypeFloat_Float64 },
+    {spv::Op::OpTypeImage,       Handler_OpTypeImage_ImageMSArray},
+    {spv::Op::OpTypeInt,         Handler_OpTypeInt_Int16 },
+    {spv::Op::OpTypeInt,         Handler_OpTypeInt_Int64 },
+    {spv::Op::OpTypePointer,     Handler_OpTypePointer_StorageInputOutput16},
+    {spv::Op::OpTypePointer,     Handler_OpTypePointer_StoragePushConstant16},
+    {spv::Op::OpTypePointer,     Handler_OpTypePointer_StorageUniform16},
+    {spv::Op::OpTypePointer,     Handler_OpTypePointer_StorageUniformBufferBlock16},
+    // clang-format on
+}};
+
+// ============== End opcode handler implementations. =======================
+
+namespace {
+// Collects every extension that `grammar` lists for any capability in
+// `capabilities`. Capabilities that `grammar` cannot look up are skipped.
+ExtensionSet getExtensionsRelatedTo(const CapabilitySet& capabilities,
+                                    const AssemblyGrammar& grammar) {
+  ExtensionSet related;
+  for (const auto capability : capabilities) {
+    const spv_operand_desc_t* entry = nullptr;
+    const spv_result_t lookup = grammar.lookupOperand(
+        SPV_OPERAND_TYPE_CAPABILITY, static_cast<uint32_t>(capability), &entry);
+    if (lookup != SPV_SUCCESS) {
+      continue;
+    }
+    related.insert(entry->extensions,
+                   entry->extensions + entry->numExtensions);
+  }
+  return related;
+}
+} // namespace
+
+// Copies the static capability lists and the opcode-handler table into the
+// sets and multimap that are queried while processing a module.
+TrimCapabilitiesPass::TrimCapabilitiesPass()
+    : supportedCapabilities_(kSupportedCapabilities.begin(),
+                             kSupportedCapabilities.end()),
+      forbiddenCapabilities_(kForbiddenCapabilities.begin(),
+                             kForbiddenCapabilities.end()),
+      untouchableCapabilities_(kUntouchableCapabilities.begin(),
+                               kUntouchableCapabilities.end()),
+      opcodeHandlers_(kOpcodeHandlers.begin(), kOpcodeHandlers.end()) {}
+
+// Adds the supported capabilities and the extensions that the grammar entry
+// for `opcode` declares. Opcodes unknown to the grammar add nothing.
+void TrimCapabilitiesPass::addInstructionRequirementsForOpcode(
+    spv::Op opcode, CapabilitySet* capabilities,
+    ExtensionSet* extensions) const {
+  // The invocation-interlock instructions list three alternative
+  // capabilities, only one of which is needed, so they are skipped here.
+  switch (opcode) {
+    case spv::Op::OpBeginInvocationInterlockEXT:
+    case spv::Op::OpEndInvocationInterlockEXT:
+      return;
+    default:
+      break;
+  }
+
+  const spv_opcode_desc_t* desc = nullptr;
+  if (context()->grammar().lookupOpcode(opcode, &desc) != SPV_SUCCESS) {
+    return;
+  }
+  addSupportedCapabilitiesToSet(desc, capabilities);
+  addSupportedExtensionsToSet(desc, extensions);
+}
+
+// Adds the capabilities and extensions required by one operand.
+//
+// Only single-word operands are inspected. Plain enumerant operands are looked
+// up directly in the grammar. Bitmask operands are split into individual bits,
+// because each flag can carry its own requirements. Scope and memory-semantics
+// operands hold the <id> of a constant rather than the value itself, so the
+// constant is resolved first. If it is not a plain OpConstant, its value is
+// unknown and no grammar lookup is made.
+void TrimCapabilitiesPass::addInstructionRequirementsForOperand(
+    const Operand& operand, CapabilitySet* capabilities,
+    ExtensionSet* extensions) const {
+  // No supported capability relies on a 2+-word operand.
+  if (operand.words.size() != 1) {
+    return;
+  }
+
+  // No supported capability relies on a literal string operand or an ID.
+  if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING ||
+      operand.type == SPV_OPERAND_TYPE_ID ||
+      operand.type == SPV_OPERAND_TYPE_RESULT_ID) {
+    return;
+  }
+
+  uint32_t value = operand.words[0];
+  bool is_mask = spvOperandIsConcreteMask(operand.type);
+
+  const bool is_scope = operand.type == SPV_OPERAND_TYPE_SCOPE_ID;
+  if (is_scope || operand.type == SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID) {
+    // `value` is an <id>; looking it up as an enumerant would mix id numbers
+    // with scope/semantics values.
+    const Instruction* def = context()->get_def_use_mgr()->GetDef(value);
+    const bool is_constant =
+        def != nullptr && def->opcode() == spv::Op::OpConstant;
+
+    // If the Vulkan memory model is declared and any instruction uses Device
+    // scope, the VulkanMemoryModelDeviceScope capability must be declared.
+    // This rule cannot be covered by the grammar, so must be checked
+    // explicitly. A scope that is not a known constant is conservatively
+    // assumed to be Device.
+    if (is_scope) {
+      const Instruction* memory_model = context()->GetMemoryModel();
+      const bool uses_vulkan_memory_model =
+          memory_model && memory_model->GetSingleWordInOperand(1u) ==
+                              uint32_t(spv::MemoryModel::Vulkan);
+      const bool may_be_device =
+          !is_constant || def->GetSingleWordInOperand(0u) ==
+                              uint32_t(spv::Scope::Device);
+      if (uses_vulkan_memory_model && may_be_device) {
+        capabilities->insert(spv::Capability::VulkanMemoryModelDeviceScope);
+      }
+    }
+
+    if (!is_constant) {
+      return;
+    }
+    value = def->GetSingleWordInOperand(0u);
+    // Memory semantics are bit flags; a scope is a single value.
+    is_mask = !is_scope;
+  }
+
+  // case 1: Operand is a single value, can directly lookup.
+  if (!is_mask) {
+    const spv_operand_desc_t* desc = nullptr;
+    if (context()->grammar().lookupOperand(operand.type, value, &desc) !=
+        SPV_SUCCESS) {
+      return;
+    }
+    addSupportedCapabilitiesToSet(desc, capabilities);
+    addSupportedExtensionsToSet(desc, extensions);
+    return;
+  }
+
+  // case 2: operand is a bitmask; each set bit is looked up on its own.
+  for (uint32_t i = 0; i < 32; i++) {
+    const uint32_t mask = (1u << i) & value;
+    if (!mask) {
+      continue;
+    }
+
+    const spv_operand_desc_t* desc = nullptr;
+    if (context()->grammar().lookupOperand(operand.type, mask, &desc) !=
+        SPV_SUCCESS) {
+      continue;
+    }
+
+    addSupportedCapabilitiesToSet(desc, capabilities);
+    addSupportedExtensionsToSet(desc, extensions);
+  }
+}
+
+// Accumulates into `capabilities` and `extensions` what `instruction` needs:
+// the requirements of its opcode, of each operand, and of every custom handler
+// registered for the opcode.
+void TrimCapabilitiesPass::addInstructionRequirements(
+    Instruction* instruction, CapabilitySet* capabilities,
+    ExtensionSet* extensions) const {
+  const spv::Op opcode = instruction->opcode();
+  // OpCapability / OpExtension declare requirements rather than create them.
+  if (opcode == spv::Op::OpCapability || opcode == spv::Op::OpExtension) {
+    return;
+  }
+
+  // 1. The opcode itself may be gated by a capability.
+  addInstructionRequirementsForOpcode(opcode, capabilities, extensions);
+
+  // 2. Any operand may be gated by a capability.
+  for (uint32_t index = 0, count = instruction->NumOperands(); index < count;
+       ++index) {
+    addInstructionRequirementsForOperand(instruction->GetOperand(index),
+                                         capabilities, extensions);
+  }
+
+  // 3. Requirements that need bespoke logic.
+  const auto handlers = opcodeHandlers_.equal_range(opcode);
+  for (auto it = handlers.first; it != handlers.second; ++it) {
+    if (const auto required = it->second(instruction); required.has_value()) {
+      capabilities->insert(*required);
+    }
+  }
+}
+
+// Looks `value` up in the grammar table for operand `type` and, on success,
+// records the extensions it needs for the current target environment.
+void TrimCapabilitiesPass::AddExtensionsForOperand(
+    const spv_operand_type_t type, const uint32_t value,
+    ExtensionSet* extensions) const {
+  const spv_operand_desc_t* desc = nullptr;
+  if (context()->grammar().lookupOperand(type, value, &desc) == SPV_SUCCESS) {
+    addSupportedExtensionsToSet(desc, extensions);
+  }
+}
+
+std::pair<CapabilitySet, ExtensionSet>
+TrimCapabilitiesPass::DetermineRequiredCapabilitiesAndExtensions() const {
+  CapabilitySet capabilities;
+  ExtensionSet extensions;
+
+  // Gather requirements from every instruction of the module.
+  get_module()->ForEachInst(
+      [this, &capabilities, &extensions](Instruction* inst) {
+        addInstructionRequirements(inst, &capabilities, &extensions);
+      });
+
+  // A required capability may in turn be provided by an extension.
+  for (auto capability : capabilities) {
+    AddExtensionsForOperand(SPV_OPERAND_TYPE_CAPABILITY,
+                            static_cast<uint32_t>(capability), &extensions);
+  }
+
+#if !defined(NDEBUG)
+  // Debug-only sanity check: the supported-capabilities list tells API users
+  // whether the pass applies to them, so it must stay in sync with the
+  // handlers. A capability detected as required but missing from the list
+  // means the list is out of date. This does not catch every case.
+  for (auto capability : capabilities) {
+    assert(supportedCapabilities_.contains(capability) &&
+           "Module is using a capability that is not listed as supported.");
+  }
+#endif
+
+  return {std::move(capabilities), std::move(extensions)};
+}
+
+// Removes each declared capability that is supported by this pass, not
+// untouchable, and absent from `required_capabilities`.
+// Returns SuccessWithChange if anything was removed.
+Pass::Status TrimCapabilitiesPass::TrimUnrequiredCapabilities(
+    const CapabilitySet& required_capabilities) const {
+  const FeatureManager* feature_manager = context()->get_feature_mgr();
+
+  // Collect first, then remove: RemoveCapability presumably edits the
+  // feature manager's set that is being iterated here.
+  CapabilitySet removable;
+  for (auto capability : feature_manager->GetCapabilities()) {
+    const bool keep = untouchableCapabilities_.contains(capability) ||
+                      !supportedCapabilities_.contains(capability) ||
+                      required_capabilities.contains(capability);
+    if (!keep) {
+      removable.insert(capability);
+    }
+  }
+
+  for (auto capability : removable) {
+    context()->RemoveCapability(capability);
+  }
+
+  if (removable.size() > 0) {
+    return Pass::Status::SuccessWithChange;
+  }
+  return Pass::Status::SuccessWithoutChange;
+}
+
+// Removes extensions linked to a supported capability that are absent from
+// `required_extensions`. Returns SuccessWithChange if any extension was
+// actually removed from the module.
+Pass::Status TrimCapabilitiesPass::TrimUnrequiredExtensions(
+    const ExtensionSet& required_extensions) const {
+  bool modified = false;
+  for (auto extension :
+       getExtensionsRelatedTo(supportedCapabilities_, context()->grammar())) {
+    if (!required_extensions.contains(extension) &&
+        context()->RemoveExtension(extension)) {
+      modified = true;
+    }
+  }
+
+  if (!modified) {
+    return Pass::Status::SuccessWithoutChange;
+  }
+  return Pass::Status::SuccessWithChange;
+}
+
+// Returns true if the module declares any capability from
+// kForbiddenCapabilities.
+bool TrimCapabilitiesPass::HasForbiddenCapabilities() const {
+  // EnumSet.HasAnyOf returns `true` for an empty argument, so an empty
+  // forbidden list must be handled separately.
+  return forbiddenCapabilities_.size() != 0 &&
+         context()->get_feature_mgr()->GetCapabilities().HasAnyOf(
+             forbiddenCapabilities_);
+}
+
+// Entry point: computes what the module needs, then trims capabilities and
+// extensions beyond that. Modules using a forbidden capability are skipped.
+Pass::Status TrimCapabilitiesPass::Process() {
+  if (HasForbiddenCapabilities()) {
+    return Status::SuccessWithoutChange;
+  }
+
+  const auto requirements = DetermineRequiredCapabilitiesAndExtensions();
+
+  // Both trims always run; only afterwards are their results combined.
+  const bool capabilities_changed =
+      TrimUnrequiredCapabilities(requirements.first) ==
+      Pass::Status::SuccessWithChange;
+  const bool extensions_changed =
+      TrimUnrequiredExtensions(requirements.second) ==
+      Pass::Status::SuccessWithChange;
+
+  if (capabilities_changed || extensions_changed) {
+    return Pass::Status::SuccessWithChange;
+  }
+  return Pass::Status::SuccessWithoutChange;
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/third_party/SPIRV-Tools/source/opt/trim_capabilities_pass.h b/third_party/SPIRV-Tools/source/opt/trim_capabilities_pass.h
new file mode 100644
index 0000000..81c07b8
--- /dev/null
+++ b/third_party/SPIRV-Tools/source/opt/trim_capabilities_pass.h
@@ -0,0 +1,205 @@
+// Copyright (c) 2023 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_
+#define SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_
+
+#include <algorithm>
+#include <array>
+#include <functional>
+#include <optional>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "source/enum_set.h"
+#include "source/extensions.h"
+#include "source/opt/ir_context.h"
+#include "source/opt/module.h"
+#include "source/opt/pass.h"
+#include "source/spirv_target_env.h"
+
+namespace spvtools {
+namespace opt {
+
+// Hash functor for the class enums used as unordered-container keys. Required
+// for the NDK build, whose unordered_set/unordered_map implementation does
+// not work with class enums out of the box.
+struct ClassEnumHash {
+  std::size_t operator()(spv::Capability value) const {
+    return std::hash<std::underlying_type_t<spv::Capability>>{}(
+        static_cast<std::underlying_type_t<spv::Capability>>(value));
+  }
+
+  std::size_t operator()(spv::Op value) const {
+    return std::hash<std::underlying_type_t<spv::Op>>{}(
+        static_cast<std::underlying_type_t<spv::Op>>(value));
+  }
+};
+
+// An opcode handler inspects one instruction and returns the capability it
+// requires, or std::nullopt. Each handler checks a single requirement case.
+//
+// Example:
+// - `OpTypeImage` can have an operand `A` that requires capability 1.
+// - `OpTypeImage` can also have an operand `B` that requires capability 2.
+// -> There are two handlers: `Handler_OpTypeImage_1` and
+//    `Handler_OpTypeImage_2`.
+using OpcodeHandler = std::optional<spv::Capability> (*)(const Instruction*);
+
+// This pass tries to remove superfluous capabilities declared in the module.
+// - If all the capabilities listed by an extension are removed, the extension
+//   is also trimmed.
+// - If the module contains any capability listed in `kForbiddenCapabilities`,
+//   the module is left untouched.
+// - No capabilities listed in `kUntouchableCapabilities` are trimmed, even when
+//   not used.
+// - Only capabilities listed in `kSupportedCapabilities` are supported.
+// - If the module contains unsupported capabilities, results might be
+//   incorrect.
+class TrimCapabilitiesPass : public Pass {
+ private:
+  // All the capabilities supported by this optimization pass. If your module
+  // contains unsupported instructions, the pass could yield bad results.
+  static constexpr std::array kSupportedCapabilities{
+      // clang-format off
+      spv::Capability::ComputeDerivativeGroupLinearNV,
+      spv::Capability::ComputeDerivativeGroupQuadsNV,
+      spv::Capability::Float16,
+      spv::Capability::Float64,
+      spv::Capability::FragmentShaderPixelInterlockEXT,
+      spv::Capability::FragmentShaderSampleInterlockEXT,
+      spv::Capability::FragmentShaderShadingRateInterlockEXT,
+      spv::Capability::Groups,
+      spv::Capability::ImageMSArray,
+      spv::Capability::Int16,
+      spv::Capability::Int64,
+      spv::Capability::Linkage,
+      spv::Capability::MinLod,
+      spv::Capability::PhysicalStorageBufferAddresses,
+      spv::Capability::RayQueryKHR,
+      spv::Capability::RayTracingKHR,
+      spv::Capability::RayTraversalPrimitiveCullingKHR,
+      spv::Capability::Shader,
+      spv::Capability::ShaderClockKHR,
+      spv::Capability::StorageImageReadWithoutFormat,
+      spv::Capability::StorageInputOutput16,
+      spv::Capability::StoragePushConstant16,
+      spv::Capability::StorageUniform16,
+      spv::Capability::StorageUniformBufferBlock16,
+      spv::Capability::VulkanMemoryModelDeviceScope,
+      spv::Capability::GroupNonUniformPartitionedNV
+      // clang-format on
+  };
+
+  // The presence of any of these capabilities disables every transformation
+  // of the module.
+  static constexpr std::array kForbiddenCapabilities{
+      spv::Capability::Linkage,
+  };
+
+  // These capabilities are never removed, because the SPIR-V alone cannot
+  // tell whether they are required.
+  static constexpr std::array kUntouchableCapabilities{
+      spv::Capability::Shader,
+  };
+
+ public:
+  TrimCapabilitiesPass();
+  TrimCapabilitiesPass(const TrimCapabilitiesPass&) = delete;
+  TrimCapabilitiesPass(TrimCapabilitiesPass&&) = delete;
+
+  const char* name() const override { return "trim-capabilities"; }
+  Status Process() override;
+
+ private:
+  // Adds to `output` each capability listed by `descriptor` that this pass
+  // supports. `Descriptor` is e.g. `spv_opcode_desc_t` or `spv_operand_desc_t`.
+  template <class Descriptor>
+  void addSupportedCapabilitiesToSet(const Descriptor* const descriptor,
+                                     CapabilitySet* output) const {
+    const auto* first = descriptor->capabilities;
+    const auto* last = first + descriptor->numCapabilities;
+    for (const auto* it = first; it != last; ++it) {
+      if (supportedCapabilities_.contains(*it)) {
+        output->insert(*it);
+      }
+    }
+  }
+
+  // Adds to `output` the extensions listed by `descriptor`, unless the feature
+  // is already core in the targeted SPIR-V version. `Descriptor` is e.g.
+  // `spv_opcode_desc_t` or `spv_operand_desc_t`.
+  template <class Descriptor>
+  void addSupportedExtensionsToSet(const Descriptor* const descriptor,
+                                   ExtensionSet* output) const {
+    const uint32_t target_version =
+        spvVersionForTargetEnv(context()->GetTargetEnv());
+    if (descriptor->minVersion > target_version) {
+      output->insert(descriptor->extensions,
+                     descriptor->extensions + descriptor->numExtensions);
+    }
+  }
+
+  // Requirements contributed by the opcode alone.
+  void addInstructionRequirementsForOpcode(spv::Op opcode,
+                                           CapabilitySet* capabilities,
+                                           ExtensionSet* extensions) const;
+  // Requirements contributed by a single operand.
+  void addInstructionRequirementsForOperand(const Operand& operand,
+                                            CapabilitySet* capabilities,
+                                            ExtensionSet* extensions) const;
+
+  // Determines what `instruction` requires and adds it to `capabilities` and
+  // `extensions`. The added capabilities form a subset of
+  // kSupportedCapabilities.
+  void addInstructionRequirements(Instruction* instruction,
+                                  CapabilitySet* capabilities,
+                                  ExtensionSet* extensions) const;
+
+  // Adds to `extensions` the extensions that operand `type` with `value`
+  // would require.
+  void AddExtensionsForOperand(const spv_operand_type_t type,
+                               const uint32_t value,
+                               ExtensionSet* extensions) const;
+
+  // Returns the capabilities and extensions the module requires. The
+  // capabilities form a subset of kSupportedCapabilities.
+  std::pair<CapabilitySet, ExtensionSet>
+  DetermineRequiredCapabilitiesAndExtensions() const;
+
+  // Trims capabilities not listed in `required_capabilities` where possible.
+  // Returns whether or not the module was modified.
+  Pass::Status TrimUnrequiredCapabilities(
+      const CapabilitySet& required_capabilities) const;
+
+  // Trims extensions not listed in `required_extensions` if supported by this
+  // pass. An extension is considered supported as soon as one capability this
+  // pass supports requires it.
+  Pass::Status TrimUnrequiredExtensions(
+      const ExtensionSet& required_extensions) const;
+
+  // Returns whether the analyzed module contains any forbidden capability.
+  bool HasForbiddenCapabilities() const;
+
+  const CapabilitySet supportedCapabilities_;
+  const CapabilitySet forbiddenCapabilities_;
+  const CapabilitySet untouchableCapabilities_;
+  const std::unordered_multimap<spv::Op, OpcodeHandler, ClassEnumHash>
+      opcodeHandlers_;
+};
+
+} // namespace opt
+} // namespace spvtools
+#endif  // SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_
diff --git a/third_party/SPIRV-Tools/source/opt/type_manager.cpp b/third_party/SPIRV-Tools/source/opt/type_manager.cpp
index 1b1aead..7b609bc 100644
--- a/third_party/SPIRV-Tools/source/opt/type_manager.cpp
+++ b/third_party/SPIRV-Tools/source/opt/type_manager.cpp
@@ -423,6 +423,23 @@
{SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}}});
break;
}
+ case Type::kCooperativeMatrixKHR: {
+ auto coop_mat = type->AsCooperativeMatrixKHR();
+ uint32_t const component_type =
+ GetTypeInstruction(coop_mat->component_type());
+ if (component_type == 0) {
+ return 0;
+ }
+ typeInst = MakeUnique<Instruction>(
+ context(), spv::Op::OpTypeCooperativeMatrixKHR, 0, id,
+ std::initializer_list<Operand>{
+ {SPV_OPERAND_TYPE_ID, {component_type}},
+ {SPV_OPERAND_TYPE_SCOPE_ID, {coop_mat->scope_id()}},
+ {SPV_OPERAND_TYPE_ID, {coop_mat->rows_id()}},
+ {SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}},
+ {SPV_OPERAND_TYPE_ID, {coop_mat->use_id()}}});
+ break;
+ }
default:
assert(false && "Unexpected type");
break;
@@ -500,13 +517,24 @@
context()->get_def_use_mgr()->AnalyzeInstUse(inst);
}
-Type* TypeManager::RebuildType(const Type& type) {
+Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
+ assert(type_id != 0);
+
// The comparison and hash on the type pool will avoid inserting the rebuilt
// type if an equivalent type already exists. The rebuilt type will be deleted
// when it goes out of scope at the end of the function in that case. Repeated
// insertions of the same Type will, at most, keep one corresponding object in
// the type pool.
std::unique_ptr<Type> rebuilt_ty;
+
+ // If |type_id| is already present in the type pool, return the existing type.
+ // This saves extra work in the type builder and prevents running into
+ // circular issues (https://github.com/KhronosGroup/SPIRV-Tools/issues/5623).
+ Type* pool_ty = GetType(type_id);
+ if (pool_ty != nullptr) {
+ return pool_ty;
+ }
+
switch (type.kind()) {
#define DefineNoSubtypeCase(kind) \
case Type::k##kind: \
@@ -533,43 +561,46 @@
case Type::kVector: {
const Vector* vec_ty = type.AsVector();
const Type* ele_ty = vec_ty->element_type();
- rebuilt_ty =
- MakeUnique<Vector>(RebuildType(*ele_ty), vec_ty->element_count());
+ rebuilt_ty = MakeUnique<Vector>(RebuildType(GetId(ele_ty), *ele_ty),
+ vec_ty->element_count());
break;
}
case Type::kMatrix: {
const Matrix* mat_ty = type.AsMatrix();
const Type* ele_ty = mat_ty->element_type();
- rebuilt_ty =
- MakeUnique<Matrix>(RebuildType(*ele_ty), mat_ty->element_count());
+ rebuilt_ty = MakeUnique<Matrix>(RebuildType(GetId(ele_ty), *ele_ty),
+ mat_ty->element_count());
break;
}
case Type::kImage: {
const Image* image_ty = type.AsImage();
const Type* ele_ty = image_ty->sampled_type();
- rebuilt_ty =
- MakeUnique<Image>(RebuildType(*ele_ty), image_ty->dim(),
- image_ty->depth(), image_ty->is_arrayed(),
- image_ty->is_multisampled(), image_ty->sampled(),
- image_ty->format(), image_ty->access_qualifier());
+ rebuilt_ty = MakeUnique<Image>(
+ RebuildType(GetId(ele_ty), *ele_ty), image_ty->dim(),
+ image_ty->depth(), image_ty->is_arrayed(),
+ image_ty->is_multisampled(), image_ty->sampled(), image_ty->format(),
+ image_ty->access_qualifier());
break;
}
case Type::kSampledImage: {
const SampledImage* image_ty = type.AsSampledImage();
const Type* ele_ty = image_ty->image_type();
- rebuilt_ty = MakeUnique<SampledImage>(RebuildType(*ele_ty));
+ rebuilt_ty =
+ MakeUnique<SampledImage>(RebuildType(GetId(ele_ty), *ele_ty));
break;
}
case Type::kArray: {
const Array* array_ty = type.AsArray();
- rebuilt_ty =
- MakeUnique<Array>(array_ty->element_type(), array_ty->length_info());
+ const Type* ele_ty = array_ty->element_type();
+ rebuilt_ty = MakeUnique<Array>(RebuildType(GetId(ele_ty), *ele_ty),
+ array_ty->length_info());
break;
}
case Type::kRuntimeArray: {
const RuntimeArray* array_ty = type.AsRuntimeArray();
const Type* ele_ty = array_ty->element_type();
- rebuilt_ty = MakeUnique<RuntimeArray>(RebuildType(*ele_ty));
+ rebuilt_ty =
+ MakeUnique<RuntimeArray>(RebuildType(GetId(ele_ty), *ele_ty));
break;
}
case Type::kStruct: {
@@ -577,7 +608,7 @@
std::vector<const Type*> subtypes;
subtypes.reserve(struct_ty->element_types().size());
for (const auto* ele_ty : struct_ty->element_types()) {
- subtypes.push_back(RebuildType(*ele_ty));
+ subtypes.push_back(RebuildType(GetId(ele_ty), *ele_ty));
}
rebuilt_ty = MakeUnique<Struct>(subtypes);
Struct* rebuilt_struct = rebuilt_ty->AsStruct();
@@ -594,7 +625,7 @@
case Type::kPointer: {
const Pointer* pointer_ty = type.AsPointer();
const Type* ele_ty = pointer_ty->pointee_type();
- rebuilt_ty = MakeUnique<Pointer>(RebuildType(*ele_ty),
+ rebuilt_ty = MakeUnique<Pointer>(RebuildType(GetId(ele_ty), *ele_ty),
pointer_ty->storage_class());
break;
}
@@ -604,9 +635,10 @@
std::vector<const Type*> param_types;
param_types.reserve(function_ty->param_types().size());
for (const auto* param_ty : function_ty->param_types()) {
- param_types.push_back(RebuildType(*param_ty));
+ param_types.push_back(RebuildType(GetId(param_ty), *param_ty));
}
- rebuilt_ty = MakeUnique<Function>(RebuildType(*ret_ty), param_types);
+ rebuilt_ty = MakeUnique<Function>(RebuildType(GetId(ret_ty), *ret_ty),
+ param_types);
break;
}
case Type::kForwardPointer: {
@@ -616,7 +648,7 @@
const Pointer* target_ptr = forward_ptr_ty->target_pointer();
if (target_ptr) {
rebuilt_ty->AsForwardPointer()->SetTargetPointer(
- RebuildType(*target_ptr)->AsPointer());
+ RebuildType(GetId(target_ptr), *target_ptr)->AsPointer());
}
break;
}
@@ -624,8 +656,17 @@
const CooperativeMatrixNV* cm_type = type.AsCooperativeMatrixNV();
const Type* component_type = cm_type->component_type();
rebuilt_ty = MakeUnique<CooperativeMatrixNV>(
- RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(),
- cm_type->columns_id());
+ RebuildType(GetId(component_type), *component_type),
+ cm_type->scope_id(), cm_type->rows_id(), cm_type->columns_id());
+ break;
+ }
+ case Type::kCooperativeMatrixKHR: {
+ const CooperativeMatrixKHR* cm_type = type.AsCooperativeMatrixKHR();
+ const Type* component_type = cm_type->component_type();
+ rebuilt_ty = MakeUnique<CooperativeMatrixKHR>(
+ RebuildType(GetId(component_type), *component_type),
+ cm_type->scope_id(), cm_type->rows_id(), cm_type->columns_id(),
+ cm_type->use_id());
break;
}
default:
@@ -644,7 +685,7 @@
void TypeManager::RegisterType(uint32_t id, const Type& type) {
// Rebuild |type| so it and all its constituent types are owned by the type
// pool.
- Type* rebuilt = RebuildType(type);
+ Type* rebuilt = RebuildType(id, type);
assert(rebuilt->IsSame(&type));
id_to_type_[id] = rebuilt;
if (GetId(rebuilt) == 0) {
@@ -863,6 +904,12 @@
inst.GetSingleWordInOperand(2),
inst.GetSingleWordInOperand(3));
break;
+ case spv::Op::OpTypeCooperativeMatrixKHR:
+ type = new CooperativeMatrixKHR(
+ GetType(inst.GetSingleWordInOperand(0)),
+ inst.GetSingleWordInOperand(1), inst.GetSingleWordInOperand(2),
+ inst.GetSingleWordInOperand(3), inst.GetSingleWordInOperand(4));
+ break;
case spv::Op::OpTypeRayQueryKHR:
type = new RayQueryKHR();
break;
@@ -870,7 +917,7 @@
type = new HitObjectNV();
break;
default:
- SPIRV_UNIMPLEMENTED(consumer_, "unhandled type");
+ assert(false && "Type not handled by the type manager.");
break;
}
@@ -912,12 +959,10 @@
}
if (Struct* st = type->AsStruct()) {
st->AddMemberDecoration(index, std::move(data));
- } else {
- SPIRV_UNIMPLEMENTED(consumer_, "OpMemberDecorate non-struct type");
}
} break;
default:
- SPIRV_UNREACHABLE(consumer_);
+ assert(false && "Unexpected opcode for a decoration instruction.");
break;
}
}
diff --git a/third_party/SPIRV-Tools/source/opt/type_manager.h b/third_party/SPIRV-Tools/source/opt/type_manager.h
index c49e193..948b691 100644
--- a/third_party/SPIRV-Tools/source/opt/type_manager.h
+++ b/third_party/SPIRV-Tools/source/opt/type_manager.h
@@ -144,18 +144,17 @@
// |type| (e.g. should be called in loop of |type|'s decorations).
void AttachDecoration(const Instruction& inst, Type* type);
- Type* GetUIntType() {
- Integer int_type(32, false);
- return GetRegisteredType(&int_type);
- }
+ Type* GetUIntType() { return GetIntType(32, false); }
uint32_t GetUIntTypeId() { return GetTypeInstruction(GetUIntType()); }
- Type* GetSIntType() {
- Integer int_type(32, true);
+ Type* GetIntType(int32_t bitWidth, bool isSigned) {
+ Integer int_type(bitWidth, isSigned);
return GetRegisteredType(&int_type);
}
+ Type* GetSIntType() { return GetIntType(32, true); }
+
uint32_t GetSIntTypeId() { return GetTypeInstruction(GetSIntType()); }
Type* GetFloatType() {
@@ -261,7 +260,9 @@
// Returns an equivalent pointer to |type| built in terms of pointers owned by
// |type_pool_|. For example, if |type| is a vec3 of bool, it will be rebuilt
// replacing the bool subtype with one owned by |type_pool_|.
- Type* RebuildType(const Type& type);
+ //
+ // The re-built type will have ID |type_id|.
+ Type* RebuildType(uint32_t type_id, const Type& type);
// Completes the incomplete type |type|, by replaces all references to
// ForwardPointer by the defining Pointer.
diff --git a/third_party/SPIRV-Tools/source/opt/types.cpp b/third_party/SPIRV-Tools/source/opt/types.cpp
index 49eec9b..b18b8cb 100644
--- a/third_party/SPIRV-Tools/source/opt/types.cpp
+++ b/third_party/SPIRV-Tools/source/opt/types.cpp
@@ -128,6 +128,7 @@
DeclareKindCase(NamedBarrier);
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
+ DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
#undef DeclareKindCase
@@ -175,6 +176,7 @@
DeclareKindCase(NamedBarrier);
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
+ DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
#undef DeclareKindCase
@@ -230,6 +232,7 @@
DeclareKindCase(NamedBarrier);
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
+ DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
#undef DeclareKindCase
@@ -708,6 +711,45 @@
columns_id_ == mt->columns_id_ && HasSameDecorations(that);
}
+// Builds an OpTypeCooperativeMatrixKHR type object.
+//
+// |type| is the (already pool-owned) component type and must be non-null.
+// |scope|, |rows|, |columns| and |use| are the result ids of the instructions
+// supplying those operands; they are stored as-is (not evaluated). Every one
+// of them is an <id> operand, so 0 is never a legal value for any of them.
+CooperativeMatrixKHR::CooperativeMatrixKHR(const Type* type,
+                                           const uint32_t scope,
+                                           const uint32_t rows,
+                                           const uint32_t columns,
+                                           const uint32_t use)
+    : Type(kCooperativeMatrixKHR),
+      component_type_(type),
+      scope_id_(scope),
+      rows_id_(rows),
+      columns_id_(columns),
+      use_id_(use) {
+  assert(type != nullptr);
+  assert(scope != 0);
+  assert(rows != 0);
+  assert(columns != 0);
+  // |use| is an <id> operand just like scope/rows/columns; check it too.
+  assert(use != 0);
+}
+
+// Renders the type as "<component, scope, rows, columns, use>", where the
+// last four entries are the raw operand ids.
+std::string CooperativeMatrixKHR::str() const {
+  std::ostringstream stream;
+  stream << "<" << component_type_->str();
+  stream << ", " << scope_id_;
+  stream << ", " << rows_id_;
+  stream << ", " << columns_id_;
+  stream << ", " << use_id_;
+  stream << ">";
+  return stream.str();
+}
+
+// Mixes the four operand ids into |hash|, then folds in the hash of the
+// component type (tracking |seen| to cope with recursive types).
+size_t CooperativeMatrixKHR::ComputeExtraStateHash(size_t hash,
+                                                   SeenTypes* seen) const {
+  const size_t with_operands =
+      hash_combine(hash, scope_id_, rows_id_, columns_id_, use_id_);
+  return component_type_->ComputeHashValue(with_operands, seen);
+}
+
+// Returns true if |that| is a CooperativeMatrixKHR with an equivalent
+// component type, identical scope/rows/columns/use operand ids and the same
+// decorations.
+//
+// |use_id_| must take part in the comparison: it is already mixed into
+// ComputeExtraStateHash(), and matrices that differ only in their Use (e.g.
+// the A, B and accumulator operands of OpCooperativeMatrixMulAddKHR) are
+// distinct types that the type manager must not merge.
+bool CooperativeMatrixKHR::IsSameImpl(const Type* that,
+                                      IsSameCache* seen) const {
+  const CooperativeMatrixKHR* mt = that->AsCooperativeMatrixKHR();
+  if (!mt) return false;
+  return component_type_->IsSameImpl(mt->component_type_, seen) &&
+         scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ &&
+         columns_id_ == mt->columns_id_ && use_id_ == mt->use_id_ &&
+         HasSameDecorations(that);
+}
+
} // namespace analysis
} // namespace opt
} // namespace spvtools
diff --git a/third_party/SPIRV-Tools/source/opt/types.h b/third_party/SPIRV-Tools/source/opt/types.h
index 26c058c..16a948c 100644
--- a/third_party/SPIRV-Tools/source/opt/types.h
+++ b/third_party/SPIRV-Tools/source/opt/types.h
@@ -60,6 +60,7 @@
class NamedBarrier;
class AccelerationStructureNV;
class CooperativeMatrixNV;
+class CooperativeMatrixKHR;
class RayQueryKHR;
class HitObjectNV;
@@ -100,6 +101,7 @@
kNamedBarrier,
kAccelerationStructureNV,
kCooperativeMatrixNV,
+ kCooperativeMatrixKHR,
kRayQueryKHR,
kHitObjectNV,
kLast
@@ -201,6 +203,7 @@
DeclareCastMethod(NamedBarrier)
DeclareCastMethod(AccelerationStructureNV)
DeclareCastMethod(CooperativeMatrixNV)
+ DeclareCastMethod(CooperativeMatrixKHR)
DeclareCastMethod(RayQueryKHR)
DeclareCastMethod(HitObjectNV)
#undef DeclareCastMethod
@@ -624,6 +627,38 @@
const uint32_t columns_id_;
};
+// Type object for OpTypeCooperativeMatrixKHR. Besides the component type it
+// records the result ids of the scope, rows, columns and use operands; these
+// ids are stored as given and are not resolved to constant values here.
+class CooperativeMatrixKHR : public Type {
+ public:
+ // |type| is the component type (must be non-null); |scope|, |rows|,
+ // |columns| and |use| are the ids of the corresponding operands.
+ CooperativeMatrixKHR(const Type* type, const uint32_t scope,
+ const uint32_t rows, const uint32_t columns,
+ const uint32_t use);
+ CooperativeMatrixKHR(const CooperativeMatrixKHR&) = default;
+
+ // Human-readable form: "<component, scope, rows, columns, use>".
+ std::string str() const override;
+
+ CooperativeMatrixKHR* AsCooperativeMatrixKHR() override { return this; }
+ const CooperativeMatrixKHR* AsCooperativeMatrixKHR() const override {
+ return this;
+ }
+
+ // Mixes the operand ids and the component type's hash into |hash|.
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
+
+ // Accessors. The *_id() values are the operand ids passed to the
+ // constructor.
+ const Type* component_type() const { return component_type_; }
+ uint32_t scope_id() const { return scope_id_; }
+ uint32_t rows_id() const { return rows_id_; }
+ uint32_t columns_id() const { return columns_id_; }
+ uint32_t use_id() const { return use_id_; }
+
+ private:
+ // Returns true if |that| is a CooperativeMatrixKHR structurally equal to
+ // this one.
+ bool IsSameImpl(const Type* that, IsSameCache*) const override;
+
+ const Type* component_type_;
+ const uint32_t scope_id_;
+ const uint32_t rows_id_;
+ const uint32_t columns_id_;
+ const uint32_t use_id_;
+};
+
#define DefineParameterlessType(type, name) \
class type : public Type { \
public: \
diff --git a/third_party/SPIRV-Tools/source/parsed_operand.cpp b/third_party/SPIRV-Tools/source/parsed_operand.cpp
index 5f8e94d..cc33f8b 100644
--- a/third_party/SPIRV-Tools/source/parsed_operand.cpp
+++ b/third_party/SPIRV-Tools/source/parsed_operand.cpp
@@ -24,6 +24,7 @@
void EmitNumericLiteral(std::ostream* out, const spv_parsed_instruction_t& inst,
const spv_parsed_operand_t& operand) {
if (operand.type != SPV_OPERAND_TYPE_LITERAL_INTEGER &&
+ operand.type != SPV_OPERAND_TYPE_LITERAL_FLOAT &&
operand.type != SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER &&
operand.type != SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER &&
operand.type != SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER)
diff --git a/third_party/SPIRV-Tools/source/print.cpp b/third_party/SPIRV-Tools/source/print.cpp
index 6c94e2b..f36812e 100644
--- a/third_party/SPIRV-Tools/source/print.cpp
+++ b/third_party/SPIRV-Tools/source/print.cpp
@@ -17,7 +17,7 @@
#if defined(SPIRV_ANDROID) || defined(SPIRV_LINUX) || defined(SPIRV_MAC) || \
defined(SPIRV_IOS) || defined(SPIRV_TVOS) || defined(SPIRV_FREEBSD) || \
defined(SPIRV_OPENBSD) || defined(SPIRV_EMSCRIPTEN) || \
- defined(SPIRV_FUCHSIA) || defined(SPIRV_GNU)
+ defined(SPIRV_FUCHSIA) || defined(SPIRV_GNU) || defined(SPIRV_QNX)
namespace spvtools {
clr::reset::operator const char*() { return "\x1b[0m"; }
diff --git a/third_party/SPIRV-Tools/source/table.h b/third_party/SPIRV-Tools/source/table.h
index 8097f13..4f1dc1f 100644
--- a/third_party/SPIRV-Tools/source/table.h
+++ b/third_party/SPIRV-Tools/source/table.h
@@ -74,7 +74,7 @@
const uint32_t ext_inst;
const uint32_t numCapabilities;
const spv::Capability* capabilities;
- const spv_operand_type_t operandTypes[16]; // TODO: Smaller/larger?
+ const spv_operand_type_t operandTypes[40]; // vksp needs at least 40
} spv_ext_inst_desc_t;
typedef struct spv_ext_inst_group_t {
diff --git a/third_party/SPIRV-Tools/source/text.cpp b/third_party/SPIRV-Tools/source/text.cpp
index 8f77d62..263bacd 100644
--- a/third_party/SPIRV-Tools/source/text.cpp
+++ b/third_party/SPIRV-Tools/source/text.cpp
@@ -312,6 +312,17 @@
}
} break;
+ case SPV_OPERAND_TYPE_LITERAL_FLOAT: {
+ // The current operand is a 32-bit float.
+ // That's just how the grammar works.
+ spvtools::IdType expected_type = {
+ 32, false, spvtools::IdTypeClass::kScalarFloatType};
+ if (auto error = context->binaryEncodeNumericLiteral(
+ textValue, error_code_for_literals, expected_type, pInst)) {
+ return error;
+ }
+ } break;
+
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_NUMBER:
// This is a context-independent literal number which can be a 32-bit
// number of floating point value.
@@ -400,9 +411,11 @@
case SPV_OPERAND_TYPE_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
+ case SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS:
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
- case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS: {
+ case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
+ case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
uint32_t value;
if (auto error = grammar.parseMaskOperand(type, textValue, &value)) {
return context->diagnostic(error)
@@ -544,7 +557,8 @@
std::string equal_sign;
error = context->getWord(&equal_sign, &nextPosition);
if ("=" != equal_sign)
- return context->diagnostic() << "'=' expected after result id.";
+ return context->diagnostic() << "'=' expected after result id but found '"
+ << equal_sign << "'.";
// The <opcode> after the '=' sign.
context->setPosition(nextPosition);
diff --git a/third_party/SPIRV-Tools/source/val/validate.cpp b/third_party/SPIRV-Tools/source/val/validate.cpp
index e73e466..3236807 100644
--- a/third_party/SPIRV-Tools/source/val/validate.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate.cpp
@@ -141,6 +141,13 @@
}
}
+ if (auto error = ValidateFloatControls2(_)) {
+ return error;
+ }
+ if (auto error = ValidateDuplicateExecutionModes(_)) {
+ return error;
+ }
+
return SPV_SUCCESS;
}
@@ -381,6 +388,8 @@
for (const auto& inst : vstate->ordered_instructions()) {
if (auto error = ValidateExecutionLimitations(*vstate, &inst)) return error;
if (auto error = ValidateSmallTypeUses(*vstate, &inst)) return error;
+ if (auto error = ValidateQCOMImageProcessingTextureUsages(*vstate, &inst))
+ return error;
}
return SPV_SUCCESS;
diff --git a/third_party/SPIRV-Tools/source/val/validate.h b/third_party/SPIRV-Tools/source/val/validate.h
index 2cd229f..78093ce 100644
--- a/third_party/SPIRV-Tools/source/val/validate.h
+++ b/third_party/SPIRV-Tools/source/val/validate.h
@@ -82,6 +82,25 @@
/// @return SPV_SUCCESS if no errors are found.
spv_result_t ValidateInterfaces(ValidationState_t& _);
+/// @brief Validates entry point call tree requirements of
+/// SPV_KHR_float_controls2
+///
+/// Checks that no entry point using FPFastMathDefault uses:
+/// * FPFastMathMode Fast
+/// * NoContraction
+///
+/// @param[in] _ the validation state of the module
+///
+/// @return SPV_SUCCESS if no errors are found.
+spv_result_t ValidateFloatControls2(ValidationState_t& _);
+
+/// @brief Validates duplicated execution modes for each entry point.
+///
+/// @param[in] _ the validation state of the module
+///
+/// @return SPV_SUCCESS if no errors are found.
+spv_result_t ValidateDuplicateExecutionModes(ValidationState_t& _);
+
/// @brief Validates memory instructions
///
/// @param[in] _ the validation state of the module
@@ -220,6 +239,14 @@
spv_result_t ValidateSmallTypeUses(ValidationState_t& _,
const Instruction* inst);
+/// Validates restricted uses of QCOM decorated textures
+///
+/// The textures that are decorated with some of QCOM image processing
+/// decorations must be used in the specified QCOM image processing built-in
+/// functions and not used in any other image functions.
+spv_result_t ValidateQCOMImageProcessingTextureUsages(ValidationState_t& _,
+ const Instruction* inst);
+
/// @brief Validate the ID's within a SPIR-V binary
///
/// @param[in] pInstructions array of instructions
diff --git a/third_party/SPIRV-Tools/source/val/validate_annotation.cpp b/third_party/SPIRV-Tools/source/val/validate_annotation.cpp
index 73d0285..dac3585 100644
--- a/third_party/SPIRV-Tools/source/val/validate_annotation.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate_annotation.cpp
@@ -161,7 +161,8 @@
case spv::Decoration::RestrictPointer:
case spv::Decoration::AliasedPointer:
if (target->opcode() != spv::Op::OpVariable &&
- target->opcode() != spv::Op::OpFunctionParameter) {
+ target->opcode() != spv::Op::OpFunctionParameter &&
+ target->opcode() != spv::Op::OpRawAccessChainNV) {
return fail(0) << "must be a memory object declaration";
}
if (_.GetIdOpcode(target->type_id()) != spv::Op::OpTypePointer) {
@@ -267,6 +268,34 @@
}
}
+ if (decoration == spv::Decoration::FPFastMathMode) {
+ if (_.HasDecoration(target_id, spv::Decoration::NoContraction)) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "FPFastMathMode and NoContraction cannot decorate the same "
+ "target";
+ }
+ auto mask = inst->GetOperandAs<spv::FPFastMathModeMask>(2);
+ if ((mask & spv::FPFastMathModeMask::AllowTransform) !=
+ spv::FPFastMathModeMask::MaskNone &&
+ ((mask & (spv::FPFastMathModeMask::AllowContract |
+ spv::FPFastMathModeMask::AllowReassoc)) !=
+ (spv::FPFastMathModeMask::AllowContract |
+ spv::FPFastMathModeMask::AllowReassoc))) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "AllowReassoc and AllowContract must be specified when "
+ "AllowTransform is specified";
+ }
+ }
+
+ // This is checked from both sides since we register decorations as we go.
+ if (decoration == spv::Decoration::NoContraction) {
+ if (_.HasDecoration(target_id, spv::Decoration::FPFastMathMode)) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "FPFastMathMode and NoContraction cannot decorate the same "
+ "target";
+ }
+ }
+
if (DecorationTakesIdParameters(decoration)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Decorations taking ID parameters may not be used with "
diff --git a/third_party/SPIRV-Tools/source/val/validate_arithmetics.cpp b/third_party/SPIRV-Tools/source/val/validate_arithmetics.cpp
index 4e7dd5e..b608a85 100644
--- a/third_party/SPIRV-Tools/source/val/validate_arithmetics.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate_arithmetics.cpp
@@ -42,14 +42,29 @@
opcode != spv::Op::OpFMod);
if (!_.IsFloatScalarType(result_type) &&
!_.IsFloatVectorType(result_type) &&
- !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)))
+ !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)) &&
+ !(opcode == spv::Op::OpFMul &&
+ _.IsCooperativeMatrixKHRType(result_type) &&
+ _.IsFloatCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected floating scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
for (size_t operand_index = 2; operand_index < inst->operands().size();
++operand_index) {
- if (_.GetOperandTypeId(inst, operand_index) != result_type)
+ if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
+ const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+ if (!_.IsCooperativeMatrixKHRType(type_id) ||
+ !_.IsFloatCooperativeMatrixType(type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected arithmetic operands to be of Result Type: "
+ << spvOpcodeString(opcode) << " operand index "
+ << operand_index;
+ }
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, type_id, result_type);
+ if (ret != SPV_SUCCESS) return ret;
+ } else if (_.GetOperandTypeId(inst, operand_index) != result_type)
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index "
@@ -71,7 +86,19 @@
for (size_t operand_index = 2; operand_index < inst->operands().size();
++operand_index) {
- if (_.GetOperandTypeId(inst, operand_index) != result_type)
+ if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
+ const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+ if (!_.IsCooperativeMatrixKHRType(type_id) ||
+ !_.IsUnsignedIntCooperativeMatrixType(type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected arithmetic operands to be of Result Type: "
+ << spvOpcodeString(opcode) << " operand index "
+ << operand_index;
+ }
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, type_id, result_type);
+ if (ret != SPV_SUCCESS) return ret;
+ } else if (_.GetOperandTypeId(inst, operand_index) != result_type)
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index "
@@ -91,7 +118,10 @@
(opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem &&
opcode != spv::Op::OpSMod);
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
- !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)))
+ !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
+ !(opcode == spv::Op::OpIMul &&
+ _.IsCooperativeMatrixKHRType(result_type) &&
+ _.IsIntCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
@@ -102,9 +132,26 @@
for (size_t operand_index = 2; operand_index < inst->operands().size();
++operand_index) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+
+ if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
+ if (!_.IsCooperativeMatrixKHRType(type_id) ||
+ !_.IsIntCooperativeMatrixType(type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected arithmetic operands to be of Result Type: "
+ << spvOpcodeString(opcode) << " operand index "
+ << operand_index;
+ }
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, type_id, result_type);
+ if (ret != SPV_SUCCESS) return ret;
+ }
+
if (!type_id ||
(!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
- !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type))))
+ !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
+ !(opcode == spv::Op::OpIMul &&
+ _.IsCooperativeMatrixKHRType(result_type) &&
+ _.IsIntCooperativeMatrixType(result_type))))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as operand: "
<< spvOpcodeString(opcode) << " operand index "
@@ -187,7 +234,7 @@
case spv::Op::OpMatrixTimesScalar: {
if (!_.IsFloatMatrixType(result_type) &&
- !_.IsCooperativeMatrixType(result_type))
+ !(_.IsCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected float matrix type as Result Type: "
<< spvOpcodeString(opcode);
@@ -459,22 +506,108 @@
const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
- if (!_.IsCooperativeMatrixType(A_type_id)) {
+ if (!_.IsCooperativeMatrixNVType(A_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as A Type: "
<< spvOpcodeString(opcode);
}
- if (!_.IsCooperativeMatrixType(B_type_id)) {
+ if (!_.IsCooperativeMatrixNVType(B_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as B Type: "
<< spvOpcodeString(opcode);
}
- if (!_.IsCooperativeMatrixType(C_type_id)) {
+ if (!_.IsCooperativeMatrixNVType(C_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as C Type: "
<< spvOpcodeString(opcode);
}
- if (!_.IsCooperativeMatrixType(D_type_id)) {
+ if (!_.IsCooperativeMatrixNVType(D_type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected cooperative matrix type as Result Type: "
+ << spvOpcodeString(opcode);
+ }
+
+ const auto A = _.FindDef(A_type_id);
+ const auto B = _.FindDef(B_type_id);
+ const auto C = _.FindDef(C_type_id);
+ const auto D = _.FindDef(D_type_id);
+
+ std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
+ A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
+
+ A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
+ B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
+ C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
+ D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
+
+ A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
+ B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
+ C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
+ D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
+
+ A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
+ B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
+ C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
+ D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
+
+ const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
+ std::tuple<bool, bool, uint32_t> Y) {
+ return (std::get<1>(X) && std::get<1>(Y) &&
+ std::get<2>(X) != std::get<2>(Y));
+ };
+
+ if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
+ notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
+ notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix scopes must match: "
+ << spvOpcodeString(opcode);
+ }
+
+ if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
+ notEqual(C_rows, D_rows)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix 'M' mismatch: "
+ << spvOpcodeString(opcode);
+ }
+
+ if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
+ notEqual(C_cols, D_cols)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix 'N' mismatch: "
+ << spvOpcodeString(opcode);
+ }
+
+ if (notEqual(A_cols, B_rows)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix 'K' mismatch: "
+ << spvOpcodeString(opcode);
+ }
+ break;
+ }
+
+ case spv::Op::OpCooperativeMatrixMulAddKHR: {
+ const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
+ const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
+ const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
+ const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
+
+ if (!_.IsCooperativeMatrixAType(A_type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix type must be A Type: "
+ << spvOpcodeString(opcode);
+ }
+ if (!_.IsCooperativeMatrixBType(B_type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix type must be B Type: "
+ << spvOpcodeString(opcode);
+ }
+ if (!_.IsCooperativeMatrixAccType(C_type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix type must be Accumulator Type: "
+ << spvOpcodeString(opcode);
+ }
+ if (!_.IsCooperativeMatrixKHRType(D_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as Result Type: "
<< spvOpcodeString(opcode);
diff --git a/third_party/SPIRV-Tools/source/val/validate_atomics.cpp b/third_party/SPIRV-Tools/source/val/validate_atomics.cpp
index b745a9e..8ddef17 100644
--- a/third_party/SPIRV-Tools/source/val/validate_atomics.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate_atomics.cpp
@@ -144,12 +144,13 @@
case spv::Op::OpAtomicFlagClear: {
const uint32_t result_type = inst->type_id();
- // All current atomics only are scalar result
// Validate return type first so can just check if pointer type is same
// (if applicable)
if (HasReturnType(opcode)) {
if (HasOnlyFloatReturnType(opcode) &&
- !_.IsFloatScalarType(result_type)) {
+ (!(_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
+ _.IsFloat16Vector2Or4Type(result_type)) &&
+ !_.IsFloatScalarType(result_type))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": expected Result Type to be float scalar type";
@@ -160,6 +161,9 @@
<< ": expected Result Type to be integer scalar type";
} else if (HasIntOrFloatReturnType(opcode) &&
!_.IsFloatScalarType(result_type) &&
+ !(opcode == spv::Op::OpAtomicExchange &&
+ _.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
+ _.IsFloat16Vector2Or4Type(result_type)) &&
!_.IsIntScalarType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
@@ -222,12 +226,21 @@
if (opcode == spv::Op::OpAtomicFAddEXT) {
// result type being float checked already
- if ((_.GetBitWidth(result_type) == 16) &&
- (!_.HasCapability(spv::Capability::AtomicFloat16AddEXT))) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << spvOpcodeString(opcode)
- << ": float add atomics require the AtomicFloat32AddEXT "
- "capability";
+ if (_.GetBitWidth(result_type) == 16) {
+ if (_.IsFloat16Vector2Or4Type(result_type)) {
+ if (!_.HasCapability(spv::Capability::AtomicFloat16VectorNV))
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode)
+ << ": float vector atomics require the "
+ "AtomicFloat16VectorNV capability";
+ } else {
+ if (!_.HasCapability(spv::Capability::AtomicFloat16AddEXT)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode)
+ << ": float add atomics require the AtomicFloat32AddEXT "
+ "capability";
+ }
+ }
}
if ((_.GetBitWidth(result_type) == 32) &&
(!_.HasCapability(spv::Capability::AtomicFloat32AddEXT))) {
@@ -245,12 +258,21 @@
}
} else if (opcode == spv::Op::OpAtomicFMinEXT ||
opcode == spv::Op::OpAtomicFMaxEXT) {
- if ((_.GetBitWidth(result_type) == 16) &&
- (!_.HasCapability(spv::Capability::AtomicFloat16MinMaxEXT))) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << spvOpcodeString(opcode)
- << ": float min/max atomics require the "
- "AtomicFloat16MinMaxEXT capability";
+ if (_.GetBitWidth(result_type) == 16) {
+ if (_.IsFloat16Vector2Or4Type(result_type)) {
+ if (!_.HasCapability(spv::Capability::AtomicFloat16VectorNV))
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode)
+ << ": float vector atomics require the "
+ "AtomicFloat16VectorNV capability";
+ } else {
+ if (!_.HasCapability(spv::Capability::AtomicFloat16MinMaxEXT)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode)
+ << ": float min/max atomics require the "
+ "AtomicFloat16MinMaxEXT capability";
+ }
+ }
}
if ((_.GetBitWidth(result_type) == 32) &&
(!_.HasCapability(spv::Capability::AtomicFloat32MinMaxEXT))) {
diff --git a/third_party/SPIRV-Tools/source/val/validate_builtins.cpp b/third_party/SPIRV-Tools/source/val/validate_builtins.cpp
index 3e81712..a7e9942 100644
--- a/third_party/SPIRV-Tools/source/val/validate_builtins.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate_builtins.cpp
@@ -118,13 +118,15 @@
VUIDErrorMax,
} VUIDError;
-const static uint32_t NumVUIDBuiltins = 36;
+const static uint32_t NumVUIDBuiltins = 39;
typedef struct {
spv::BuiltIn builtIn;
uint32_t vuid[VUIDErrorMax]; // execution mode, storage class, type VUIDs
} BuiltinVUIDMapping;
+// Many built-ins have the same checks (Storage Class, Type, etc)
+// This table provides a nice LUT for the VUIDs
std::array<BuiltinVUIDMapping, NumVUIDBuiltins> builtinVUIDInfo = {{
// clang-format off
{spv::BuiltIn::SubgroupEqMask, {0, 4370, 4371}},
@@ -163,8 +165,11 @@
{spv::BuiltIn::CullMaskKHR, {6735, 6736, 6737}},
{spv::BuiltIn::BaryCoordKHR, {4154, 4155, 4156}},
{spv::BuiltIn::BaryCoordNoPerspKHR, {4160, 4161, 4162}},
- // clang-format off
-} };
+ {spv::BuiltIn::PrimitivePointIndicesEXT, {7041, 7043, 7044}},
+ {spv::BuiltIn::PrimitiveLineIndicesEXT, {7047, 7049, 7050}},
+ {spv::BuiltIn::PrimitiveTriangleIndicesEXT, {7053, 7055, 7056}},
+ // clang-format on
+}};
uint32_t GetVUIDForBuiltin(spv::BuiltIn builtIn, VUIDError type) {
uint32_t vuid = 0;
@@ -356,6 +361,9 @@
spv_result_t ValidateRayTracingBuiltinsAtDefinition(
const Decoration& decoration, const Instruction& inst);
+ spv_result_t ValidateMeshShadingEXTBuiltinsAtDefinition(
+ const Decoration& decoration, const Instruction& inst);
+
// The following section contains functions which are called when id defined
// by |referenced_inst| is
// 1. referenced by |referenced_from_inst|
@@ -546,6 +554,11 @@
const Instruction& referenced_inst,
const Instruction& referenced_from_inst);
+ spv_result_t ValidateMeshShadingEXTBuiltinsAtReference(
+ const Decoration& decoration, const Instruction& built_in_inst,
+ const Instruction& referenced_inst,
+ const Instruction& referenced_from_inst);
+
// Validates that |built_in_inst| is not (even indirectly) referenced from
// within a function which can be called with |execution_model|.
//
@@ -581,6 +594,10 @@
spv_result_t ValidateI32Arr(
const Decoration& decoration, const Instruction& inst,
const std::function<spv_result_t(const std::string& message)>& diag);
+ spv_result_t ValidateArrayedI32Vec(
+ const Decoration& decoration, const Instruction& inst,
+ uint32_t num_components,
+ const std::function<spv_result_t(const std::string& message)>& diag);
spv_result_t ValidateOptionalArrayedI32(
const Decoration& decoration, const Instruction& inst,
const std::function<spv_result_t(const std::string& message)>& diag);
@@ -909,6 +926,45 @@
return SPV_SUCCESS;
}
+// Checks that the type decorated by |decoration| on |inst| is an OpTypeArray
+// whose element type is a |num_components|-component vector of 32-bit
+// integers. On any mismatch, |diag| is called with a description of what was
+// actually found and its result is returned; otherwise returns SPV_SUCCESS.
+spv_result_t BuiltInsValidator::ValidateArrayedI32Vec(
+    const Decoration& decoration, const Instruction& inst,
+    uint32_t num_components,
+    const std::function<spv_result_t(const std::string& message)>& diag) {
+  uint32_t underlying_type = 0;
+  if (spv_result_t error =
+          GetUnderlyingType(_, decoration, inst, &underlying_type)) {
+    return error;
+  }
+
+  const Instruction* const type_inst = _.FindDef(underlying_type);
+  if (type_inst->opcode() != spv::Op::OpTypeArray) {
+    return diag(GetDefinitionDesc(decoration, inst) + " is not an array.");
+  }
+
+  // OpTypeArray: word(2) is the Element Type id (word(3) is the Length id).
+  const uint32_t component_type = type_inst->word(2);
+  if (!_.IsIntVectorType(component_type)) {
+    return diag(GetDefinitionDesc(decoration, inst) + " is not an int vector.");
+  }
+
+  const uint32_t actual_num_components = _.GetDimension(component_type);
+  if (actual_num_components != num_components) {
+    std::ostringstream ss;
+    ss << GetDefinitionDesc(decoration, inst) << " has "
+       << actual_num_components << " components.";
+    return diag(ss.str());
+  }
+
+  // The bit width queried on the vector type is reported as the width of
+  // its components.
+  const uint32_t bit_width = _.GetBitWidth(component_type);
+  if (bit_width != 32) {
+    std::ostringstream ss;
+    ss << GetDefinitionDesc(decoration, inst)
+       << " has components with bit width " << bit_width << ".";
+    return diag(ss.str());
+  }
+
+  return SPV_SUCCESS;
+}
+
spv_result_t BuiltInsValidator::ValidateOptionalArrayedF32Vec(
const Decoration& decoration, const Instruction& inst,
uint32_t num_components,
@@ -1064,7 +1120,7 @@
if (num_components != 0) {
uint64_t actual_num_components = 0;
- if (!_.GetConstantValUint64(type_inst->word(3), &actual_num_components)) {
+ if (!_.EvalConstantValUint64(type_inst->word(3), &actual_num_components)) {
assert(0 && "Array type definition is corrupt");
}
if (actual_num_components != num_components) {
@@ -4108,6 +4164,119 @@
return SPV_SUCCESS;
}
+// Definition-time checks for the mesh shading primitive index built-ins.
+// In Vulkan environments:
+//   PrimitivePointIndicesEXT    -> array of 32-bit ints
+//   PrimitiveLineIndicesEXT     -> array of 2-component 32-bit int vectors
+//   PrimitiveTriangleIndicesEXT -> array of 3-component 32-bit int vectors
+// A type mismatch is reported with the built-in's type VUID. Regardless of
+// environment, the reference-time checks are then seeded with |inst| itself.
+spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition(
+    const Decoration& decoration, const Instruction& inst) {
+  if (spvIsVulkanEnv(_.context()->target_env)) {
+    const spv::BuiltIn builtin = spv::BuiltIn(decoration.params()[0]);
+    const uint32_t vuid = GetVUIDForBuiltin(builtin, VUIDErrorType);
+
+    // All three built-ins share the same diagnostic; only the description of
+    // the expected type differs, so build the callback from that text.
+    const auto make_diag = [this, &inst, &decoration,
+                            vuid](const char* expected_type) {
+      return [this, &inst, &decoration, vuid,
+              expected_type](const std::string& message) -> spv_result_t {
+        return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+               << _.VkErrorID(vuid) << "According to the "
+               << spvLogStringForEnv(_.context()->target_env)
+               << " spec BuiltIn "
+               << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                                decoration.params()[0])
+               << " variable needs to be a " << expected_type << message;
+      };
+    };
+
+    spv_result_t error = SPV_SUCCESS;
+    switch (builtin) {
+      case spv::BuiltIn::PrimitivePointIndicesEXT:
+        error = ValidateI32Arr(decoration, inst,
+                               make_diag("32-bit int array."));
+        break;
+      case spv::BuiltIn::PrimitiveLineIndicesEXT:
+        error = ValidateArrayedI32Vec(
+            decoration, inst, 2, make_diag("2-component 32-bit int array."));
+        break;
+      case spv::BuiltIn::PrimitiveTriangleIndicesEXT:
+        error = ValidateArrayedI32Vec(
+            decoration, inst, 3, make_diag("3-component 32-bit int array."));
+        break;
+      default:
+        break;
+    }
+    if (error != SPV_SUCCESS) return error;
+  }
+
+  // Seed at reference checks with this built-in.
+  return ValidateMeshShadingEXTBuiltinsAtReference(decoration, inst, inst,
+                                                   inst);
+}
+
+// Reference-time checks for the mesh shading primitive index built-ins.
+// In Vulkan environments:
+//   - the storage class of |referenced_from_inst| must be Output
+//     (a storage class of StorageClass::Max is not diagnosed), and
+//   - every model in execution_models_ must be MeshEXT.
+// When running outside any function (function_id_ == 0), a copy of this check
+// is registered in id_to_at_reference_checks_ under the id of
+// |referenced_from_inst|, presumably so that ids using it are checked too.
+spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtReference(
+    const Decoration& decoration, const Instruction& built_in_inst,
+    const Instruction& referenced_inst,
+    const Instruction& referenced_from_inst) {
+  const auto env = _.context()->target_env;
+  if (spvIsVulkanEnv(env)) {
+    const spv::BuiltIn builtin = spv::BuiltIn(decoration.params()[0]);
+
+    const spv::StorageClass storage_class =
+        GetStorageClass(referenced_from_inst);
+    const bool is_bad_storage_class =
+        storage_class != spv::StorageClass::Max &&
+        storage_class != spv::StorageClass::Output;
+    if (is_bad_storage_class) {
+      return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
+             << _.VkErrorID(GetVUIDForBuiltin(builtin, VUIDErrorStorageClass))
+             << spvLogStringForEnv(env) << " spec allows BuiltIn "
+             << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                              uint32_t(builtin))
+             << " to be only used for variables with Output storage class. "
+             << GetReferenceDesc(decoration, built_in_inst, referenced_inst,
+                                 referenced_from_inst)
+             << " " << GetStorageClassDesc(referenced_from_inst);
+    }
+
+    for (const spv::ExecutionModel model : execution_models_) {
+      if (model == spv::ExecutionModel::MeshEXT) continue;
+      // The first non-MeshEXT model is the one reported.
+      return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
+             << _.VkErrorID(
+                    GetVUIDForBuiltin(builtin, VUIDErrorExecutionModel))
+             << spvLogStringForEnv(env) << " spec allows BuiltIn "
+             << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                              uint32_t(builtin))
+             << " to be used only with MeshEXT execution model. "
+             << GetReferenceDesc(decoration, built_in_inst, referenced_inst,
+                                 referenced_from_inst, model);
+    }
+  }
+
+  if (function_id_ == 0) {
+    // Propagate this rule to all dependent ids in the global scope.
+    id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
+        &BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtReference, this,
+        decoration, built_in_inst, referenced_from_inst,
+        std::placeholders::_1));
+  }
+
+  return SPV_SUCCESS;
+}
+
spv_result_t BuiltInsValidator::ValidateSingleBuiltInAtDefinition(
const Decoration& decoration, const Instruction& inst) {
const spv::BuiltIn label = spv::BuiltIn(decoration.params()[0]);
@@ -4283,6 +4452,11 @@
case spv::BuiltIn::CullMaskKHR: {
return ValidateRayTracingBuiltinsAtDefinition(decoration, inst);
}
+ case spv::BuiltIn::PrimitivePointIndicesEXT:
+ case spv::BuiltIn::PrimitiveLineIndicesEXT:
+ case spv::BuiltIn::PrimitiveTriangleIndicesEXT: {
+ return ValidateMeshShadingEXTBuiltinsAtDefinition(decoration, inst);
+ }
case spv::BuiltIn::PrimitiveShadingRateKHR: {
return ValidatePrimitiveShadingRateAtDefinition(decoration, inst);
}
diff --git a/third_party/SPIRV-Tools/source/val/validate_capability.cpp b/third_party/SPIRV-Tools/source/val/validate_capability.cpp
index 98dab42..81d2ad5 100644
--- a/third_party/SPIRV-Tools/source/val/validate_capability.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate_capability.cpp
@@ -240,7 +240,7 @@
ExtensionSet operand_exts(operand_desc->numExtensions,
operand_desc->extensions);
- if (operand_exts.IsEmpty()) return false;
+ if (operand_exts.empty()) return false;
return _.HasAnyOfExtensions(operand_exts);
}
diff --git a/third_party/SPIRV-Tools/source/val/validate_cfg.cpp b/third_party/SPIRV-Tools/source/val/validate_cfg.cpp
index 7b3f748..9b7161f 100644
--- a/third_party/SPIRV-Tools/source/val/validate_cfg.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate_cfg.cpp
@@ -190,6 +190,8 @@
"ID of an OpLabel instruction";
}
+ // A similar requirement for SPV_KHR_maximal_reconvergence is deferred until
+ // entry point call trees have been recorded.
if (_.version() >= SPV_SPIRV_VERSION_WORD(1, 6) && true_id == false_id) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "In SPIR-V 1.6 or later, True Label and False Label must be "
@@ -875,6 +877,95 @@
return SPV_SUCCESS;
}
+// Validates the additional control-flow rules for functions reachable from an
+// entry point that declares the MaximallyReconvergesKHR execution mode:
+//  1. OpBranchConditional must not use the same label for True and False.
+//  2. A block with two or more unique predecessors must be a loop header, or
+//     be referenced by OpSelectionMerge, OpLoopMerge, or OpSwitch (i.e. be a
+//     merge target, continue target, or switch target/default).
+// Reachability comes from EntryPointReferences(), which is why this runs at
+// the end of PerformCfgChecks rather than alongside the per-branch checks.
+// Returns the first violation found, or SPV_SUCCESS.
+spv_result_t MaximalReconvergenceChecks(ValidationState_t& _) {
+ // Find all the entry points with the MaximallyReconvergesKHR execution
+ // mode.
+ // maximal_funcs also includes the entry point functions themselves.
+ std::unordered_set<uint32_t> maximal_funcs;
+ std::unordered_set<uint32_t> maximal_entry_points;
+ for (auto entry_point : _.entry_points()) {
+ const auto* exec_modes = _.GetExecutionModes(entry_point);
+ if (exec_modes &&
+ exec_modes->count(spv::ExecutionMode::MaximallyReconvergesKHR)) {
+ maximal_entry_points.insert(entry_point);
+ maximal_funcs.insert(entry_point);
+ }
+ }
+
+ // Nothing to check when no entry point opts into maximal reconvergence.
+ if (maximal_entry_points.empty()) {
+ return SPV_SUCCESS;
+ }
+
+ // Find all the functions reachable from a maximal reconvergence entry point.
+ for (const auto& func : _.functions()) {
+ const auto& entry_points = _.EntryPointReferences(func.id());
+ for (auto id : entry_points) {
+ if (maximal_entry_points.count(id)) {
+ maximal_funcs.insert(func.id());
+ break;
+ }
+ }
+ }
+
+ // Check for conditional branches with the same true and false targets.
+ for (const auto& inst : _.ordered_instructions()) {
+ if (inst.opcode() == spv::Op::OpBranchConditional) {
+ // Operand 0 is the Condition; operands 1 and 2 are the True and False
+ // Labels.
+ const auto true_id = inst.GetOperandAs<uint32_t>(1);
+ const auto false_id = inst.GetOperandAs<uint32_t>(2);
+ if (true_id == false_id && maximal_funcs.count(inst.function()->id())) {
+ return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ << "In entry points using the MaximallyReconvergesKHR execution "
+ "mode, True Label and False Label must be different labels";
+ }
+ }
+ }
+
+ // Check for invalid multiple predecessors. Only loop headers, continue
+ // targets, merge targets or switch targets or defaults may have multiple
+ // unique predecessors.
+ for (const auto& func : _.functions()) {
+ if (!maximal_funcs.count(func.id())) continue;
+
+ for (const auto* block : func.ordered_blocks()) {
+ // Predecessors are de-duplicated by id, so repeated edges from one
+ // block count once.
+ std::unordered_set<uint32_t> unique_preds;
+ const auto* preds = block->predecessors();
+ if (!preds) continue;
+
+ for (const auto* pred : *preds) {
+ unique_preds.insert(pred->id());
+ }
+ if (unique_preds.size() < 2) continue;
+
+ // Assumes |terminator| points into ordered_instructions(), so the
+ // instruction just before it in module order can be found by index.
+ // An OpLoopMerge there marks this block as a loop header.
+ const auto* terminator = block->terminator();
+ const auto index = terminator - &_.ordered_instructions()[0];
+ const auto* pre_terminator = &_.ordered_instructions()[index - 1];
+ if (pre_terminator->opcode() == spv::Op::OpLoopMerge) continue;
+
+ // Otherwise the block's label must be used as a merge/continue target
+ // or as a switch target/default.
+ const auto* label = _.FindDef(block->id());
+ bool ok = false;
+ for (const auto& pair : label->uses()) {
+ const auto* use_inst = pair.first;
+ switch (use_inst->opcode()) {
+ case spv::Op::OpSelectionMerge:
+ case spv::Op::OpLoopMerge:
+ case spv::Op::OpSwitch:
+ ok = true;
+ break;
+ default:
+ break;
+ }
+ }
+ if (!ok) {
+ return _.diag(SPV_ERROR_INVALID_CFG, label)
+ << "In entry points using the MaximallyReconvergesKHR "
+ "execution mode, this basic block must not have multiple "
+ "unique predecessors";
+ }
+ }
+ }
+
+ return SPV_SUCCESS;
+}
+
spv_result_t PerformCfgChecks(ValidationState_t& _) {
for (auto& function : _.functions()) {
// Check all referenced blocks are defined within a function
@@ -999,6 +1090,11 @@
return error;
}
}
+
+ if (auto error = MaximalReconvergenceChecks(_)) {
+ return error;
+ }
+
return SPV_SUCCESS;
}
diff --git a/third_party/SPIRV-Tools/source/val/validate_composites.cpp b/third_party/SPIRV-Tools/source/val/validate_composites.cpp
index 2b83c63..26486da 100644
--- a/third_party/SPIRV-Tools/source/val/validate_composites.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate_composites.cpp
@@ -94,7 +94,7 @@
break;
}
- if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) {
+ if (!_.EvalConstantValUint64(type_inst->word(3), &array_size)) {
assert(0 && "Array type definition is corrupt");
}
if (component_index >= array_size) {
@@ -122,6 +122,7 @@
*member_type = type_inst->word(component_index + 2);
break;
}
+ case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeMatrixNV: {
*member_type = type_inst->word(2);
break;
@@ -288,7 +289,7 @@
}
uint64_t array_size = 0;
- if (!_.GetConstantValUint64(array_inst->word(3), &array_size)) {
+ if (!_.EvalConstantValUint64(array_inst->word(3), &array_size)) {
assert(0 && "Array type definition is corrupt");
}
@@ -335,6 +336,25 @@
break;
}
+ case spv::Op::OpTypeCooperativeMatrixKHR: {
+ const auto result_type_inst = _.FindDef(result_type);
+ assert(result_type_inst);
+ const auto component_type_id =
+ result_type_inst->GetOperandAs<uint32_t>(1);
+
+ if (3 != num_operands) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Must be only one constituent";
+ }
+
+ const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
+
+ if (operand_type_id != component_type_id) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected Constituent type to be equal to the component type";
+ }
+ break;
+ }
case spv::Op::OpTypeCooperativeMatrixNV: {
const auto result_type_inst = _.FindDef(result_type);
assert(result_type_inst);
diff --git a/third_party/SPIRV-Tools/source/val/validate_constants.cpp b/third_party/SPIRV-Tools/source/val/validate_constants.cpp
index 006e504..4deaa49 100644
--- a/third_party/SPIRV-Tools/source/val/validate_constants.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate_constants.cpp
@@ -243,6 +243,7 @@
}
}
} break;
+ case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeMatrixNV: {
if (1 != constituent_count) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
@@ -310,6 +311,7 @@
case spv::Op::OpTypeArray:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
+ case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeVector: {
auto base_type = _.FindDef(instruction[2]);
return base_type && IsTypeNullable(base_type->words(), _);
diff --git a/third_party/SPIRV-Tools/source/val/validate_conversion.cpp b/third_party/SPIRV-Tools/source/val/validate_conversion.cpp
index 476c1fe..b2892a8 100644
--- a/third_party/SPIRV-Tools/source/val/validate_conversion.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate_conversion.cpp
@@ -473,7 +473,10 @@
const bool input_is_pointer = _.IsPointerType(input_type);
const bool input_is_int_scalar = _.IsIntScalarType(input_type);
- if (!result_is_pointer && !result_is_int_scalar &&
+ const bool result_is_coopmat = _.IsCooperativeMatrixType(result_type);
+ const bool input_is_coopmat = _.IsCooperativeMatrixType(input_type);
+
+ if (!result_is_pointer && !result_is_int_scalar && !result_is_coopmat &&
!_.IsIntVectorType(result_type) &&
!_.IsFloatScalarType(result_type) &&
!_.IsFloatVectorType(result_type))
@@ -481,13 +484,24 @@
<< "Expected Result Type to be a pointer or int or float vector "
<< "or scalar type: " << spvOpcodeString(opcode);
- if (!input_is_pointer && !input_is_int_scalar &&
+ if (!input_is_pointer && !input_is_int_scalar && !input_is_coopmat &&
!_.IsIntVectorType(input_type) && !_.IsFloatScalarType(input_type) &&
!_.IsFloatVectorType(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be a pointer or int or float vector "
<< "or scalar: " << spvOpcodeString(opcode);
+ if (result_is_coopmat != input_is_coopmat)
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix can only be cast to another cooperative "
+ << "matrix: " << spvOpcodeString(opcode);
+
+ if (result_is_coopmat) {
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+ if (ret != SPV_SUCCESS) return ret;
+ }
+
if (_.version() >= SPV_SPIRV_VERSION_WORD(1, 5) ||
_.HasExtension(kSPV_KHR_physical_storage_buffer)) {
const bool result_is_int_vector = _.IsIntVectorType(result_type);
diff --git a/third_party/SPIRV-Tools/source/val/validate_decorations.cpp b/third_party/SPIRV-Tools/source/val/validate_decorations.cpp
index c1fca45..0a7df65 100644
--- a/third_party/SPIRV-Tools/source/val/validate_decorations.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate_decorations.cpp
@@ -452,7 +452,16 @@
return ds;
};
- const auto& members = getStructMembers(struct_id, vstate);
+ // If we are checking physical storage buffer pointers, we may not actually
+ // have a struct here. Instead, pretend we have a struct with a single member
+ // at offset 0.
+ const auto& struct_type = vstate.FindDef(struct_id);
+ std::vector<uint32_t> members;
+ if (struct_type->opcode() == spv::Op::OpTypeStruct) {
+ members = getStructMembers(struct_id, vstate);
+ } else {
+ members.push_back(struct_id);
+ }
// To check for member overlaps, we want to traverse the members in
// offset order.
@@ -461,31 +470,38 @@
uint32_t offset;
};
std::vector<MemberOffsetPair> member_offsets;
- member_offsets.reserve(members.size());
- for (uint32_t memberIdx = 0, numMembers = uint32_t(members.size());
- memberIdx < numMembers; memberIdx++) {
- uint32_t offset = 0xffffffff;
- auto member_decorations =
- vstate.id_member_decorations(struct_id, memberIdx);
- for (auto decoration = member_decorations.begin;
- decoration != member_decorations.end; ++decoration) {
- assert(decoration->struct_member_index() == (int)memberIdx);
- switch (decoration->dec_type()) {
- case spv::Decoration::Offset:
- offset = decoration->params()[0];
- break;
- default:
- break;
+
+ // With physical storage buffers, we might be checking layouts that do not
+ // originate from a structure.
+ if (struct_type->opcode() == spv::Op::OpTypeStruct) {
+ member_offsets.reserve(members.size());
+ for (uint32_t memberIdx = 0, numMembers = uint32_t(members.size());
+ memberIdx < numMembers; memberIdx++) {
+ uint32_t offset = 0xffffffff;
+ auto member_decorations =
+ vstate.id_member_decorations(struct_id, memberIdx);
+ for (auto decoration = member_decorations.begin;
+ decoration != member_decorations.end; ++decoration) {
+ assert(decoration->struct_member_index() == (int)memberIdx);
+ switch (decoration->dec_type()) {
+ case spv::Decoration::Offset:
+ offset = decoration->params()[0];
+ break;
+ default:
+ break;
+ }
}
+ member_offsets.push_back(
+ MemberOffsetPair{memberIdx, incoming_offset + offset});
}
- member_offsets.push_back(
- MemberOffsetPair{memberIdx, incoming_offset + offset});
+ std::stable_sort(
+ member_offsets.begin(), member_offsets.end(),
+ [](const MemberOffsetPair& lhs, const MemberOffsetPair& rhs) {
+ return lhs.offset < rhs.offset;
+ });
+ } else {
+ member_offsets.push_back({0, 0});
}
- std::stable_sort(
- member_offsets.begin(), member_offsets.end(),
- [](const MemberOffsetPair& lhs, const MemberOffsetPair& rhs) {
- return lhs.offset < rhs.offset;
- });
// Now scan from lowest offset to highest offset.
uint32_t nextValidOffset = 0;
@@ -906,9 +922,9 @@
}
}
- if (vstate.HasCapability(
- spv::Capability::WorkgroupMemoryExplicitLayoutKHR) &&
- num_workgroup_variables > 0 &&
+ const bool workgroup_blocks_allowed = vstate.HasCapability(
+ spv::Capability::WorkgroupMemoryExplicitLayoutKHR);
+ if (workgroup_blocks_allowed && num_workgroup_variables > 0 &&
num_workgroup_variables_with_block > 0) {
if (num_workgroup_variables != num_workgroup_variables_with_block) {
return vstate.diag(SPV_ERROR_INVALID_BINARY, vstate.FindDef(entry_point))
@@ -929,6 +945,13 @@
"Entry point id "
<< entry_point << " does not meet this requirement.";
}
+ } else if (!workgroup_blocks_allowed &&
+ num_workgroup_variables_with_block > 0) {
+ return vstate.diag(SPV_ERROR_INVALID_BINARY,
+ vstate.FindDef(entry_point))
+ << "Workgroup Storage Class variables can't be decorated with "
+ "Block unless declaring the WorkgroupMemoryExplicitLayoutKHR "
+ "capability.";
}
}
}
@@ -1023,6 +1046,8 @@
std::unordered_set<uint32_t> uses_push_constant;
for (const auto& inst : vstate.ordered_instructions()) {
const auto& words = inst.words();
+ auto type_id = inst.type_id();
+ const Instruction* type_inst = vstate.FindDef(type_id);
if (spv::Op::OpVariable == inst.opcode()) {
const auto var_id = inst.id();
// For storage class / decoration combinations, see Vulkan 14.5.4 "Offset
@@ -1276,6 +1301,23 @@
}
}
}
+ } else if (type_inst && type_inst->opcode() == spv::Op::OpTypePointer &&
+ type_inst->GetOperandAs<spv::StorageClass>(1u) ==
+ spv::StorageClass::PhysicalStorageBuffer) {
+ const bool scalar_block_layout = vstate.options()->scalar_block_layout;
+ MemberConstraints constraints;
+ const bool buffer = true;
+ const auto data_type_id = type_inst->GetOperandAs<uint32_t>(2u);
+ const auto* data_type_inst = vstate.FindDef(data_type_id);
+ if (data_type_inst->opcode() == spv::Op::OpTypeStruct) {
+ ComputeMemberConstraintsForStruct(&constraints, data_type_id,
+ LayoutConstraints(), vstate);
+ }
+ if (auto res = checkLayout(data_type_id, "PhysicalStorageBuffer", "Block",
+ !buffer, scalar_block_layout, 0, constraints,
+ vstate)) {
+ return res;
+ }
}
}
return SPV_SUCCESS;
@@ -1283,21 +1325,14 @@
// Returns true if |decoration| cannot be applied to the same id more than once.
bool AtMostOncePerId(spv::Decoration decoration) {
- return decoration == spv::Decoration::ArrayStride;
+ return decoration != spv::Decoration::UserSemantic &&
+ decoration != spv::Decoration::FuncParamAttr;
}
// Returns true if |decoration| cannot be applied to the same member more than
// once.
bool AtMostOncePerMember(spv::Decoration decoration) {
- switch (decoration) {
- case spv::Decoration::Offset:
- case spv::Decoration::MatrixStride:
- case spv::Decoration::RowMajor:
- case spv::Decoration::ColMajor:
- return true;
- default:
- return false;
- }
+ return decoration != spv::Decoration::UserSemantic;
}
spv_result_t CheckDecorationsCompatibility(ValidationState_t& vstate) {
@@ -1514,7 +1549,8 @@
const auto opcode = inst.opcode();
const auto type_id = inst.type_id();
if (opcode != spv::Op::OpVariable &&
- opcode != spv::Op::OpFunctionParameter) {
+ opcode != spv::Op::OpFunctionParameter &&
+ opcode != spv::Op::OpRawAccessChainNV) {
return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
<< "Target of NonWritable decoration must be a memory object "
"declaration (a variable or a function parameter)";
@@ -1527,10 +1563,11 @@
vstate.features().nonwritable_var_in_function_or_private) {
// New permitted feature in SPIR-V 1.4.
} else if (
- // It may point to a UBO, SSBO, or storage image.
+ // It may point to a UBO, SSBO, storage image, or raw access chain.
vstate.IsPointerToUniformBlock(type_id) ||
vstate.IsPointerToStorageBuffer(type_id) ||
- vstate.IsPointerToStorageImage(type_id)) {
+ vstate.IsPointerToStorageImage(type_id) ||
+ opcode == spv::Op::OpRawAccessChainNV) {
} else {
return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
<< "Target of NonWritable decoration is invalid: must point to a "
diff --git a/third_party/SPIRV-Tools/source/val/validate_extensions.cpp b/third_party/SPIRV-Tools/source/val/validate_extensions.cpp
index 0ac62bf..7b73c9c 100644
--- a/third_party/SPIRV-Tools/source/val/validate_extensions.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate_extensions.cpp
@@ -482,8 +482,8 @@
return SPV_SUCCESS;
}
-spv_result_t ValidateClspvReflectionArgumentOffsetBuffer(ValidationState_t& _,
- const Instruction* inst) {
+spv_result_t ValidateClspvReflectionArgumentOffsetBuffer(
+ ValidationState_t& _, const Instruction* inst) {
const auto num_operands = inst->operands().size();
if (auto error = ValidateKernelDecl(_, inst)) {
return error;
@@ -802,7 +802,7 @@
}
spv_result_t ValidateClspvReflectionPrintfInfo(ValidationState_t& _,
- const Instruction* inst) {
+ const Instruction* inst) {
if (!IsUint32Constant(_, inst->GetOperandAs<uint32_t>(4))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "PrintfID must be a 32-bit unsigned integer OpConstant";
@@ -823,8 +823,8 @@
return SPV_SUCCESS;
}
-spv_result_t ValidateClspvReflectionPrintfStorageBuffer(ValidationState_t& _,
- const Instruction* inst) {
+spv_result_t ValidateClspvReflectionPrintfStorageBuffer(
+ ValidationState_t& _, const Instruction* inst) {
if (!IsUint32Constant(_, inst->GetOperandAs<uint32_t>(4))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "DescriptorSet must be a 32-bit unsigned integer OpConstant";
@@ -843,8 +843,8 @@
return SPV_SUCCESS;
}
-spv_result_t ValidateClspvReflectionPrintfPushConstant(ValidationState_t& _,
- const Instruction* inst) {
+spv_result_t ValidateClspvReflectionPrintfPushConstant(
+ ValidationState_t& _, const Instruction* inst) {
if (!IsUint32Constant(_, inst->GetOperandAs<uint32_t>(4))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Offset must be a 32-bit unsigned integer OpConstant";
@@ -3100,7 +3100,7 @@
uint32_t vector_count = inst->word(6);
uint64_t const_val;
- if (!_.GetConstantValUint64(vector_count, &const_val)) {
+ if (!_.EvalConstantValUint64(vector_count, &const_val)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< ext_inst_name()
<< ": Vector Count must be 32-bit integer OpConstant";
@@ -3168,16 +3168,16 @@
break;
}
case CommonDebugInfoDebugTypePointer: {
- auto validate_base_type =
- ValidateOperandBaseType(_, inst, 5, ext_inst_name);
+ auto validate_base_type = ValidateOperandDebugType(
+ _, "Base Type", inst, 5, ext_inst_name, false);
if (validate_base_type != SPV_SUCCESS) return validate_base_type;
CHECK_CONST_UINT_OPERAND("Storage Class", 6);
CHECK_CONST_UINT_OPERAND("Flags", 7);
break;
}
case CommonDebugInfoDebugTypeQualifier: {
- auto validate_base_type =
- ValidateOperandBaseType(_, inst, 5, ext_inst_name);
+ auto validate_base_type = ValidateOperandDebugType(
+ _, "Base Type", inst, 5, ext_inst_name, false);
if (validate_base_type != SPV_SUCCESS) return validate_base_type;
CHECK_CONST_UINT_OPERAND("Type Qualifier", 6);
break;
@@ -3191,7 +3191,7 @@
uint32_t component_count = inst->word(6);
if (vulkanDebugInfo) {
uint64_t const_val;
- if (!_.GetConstantValUint64(component_count, &const_val)) {
+ if (!_.EvalConstantValUint64(component_count, &const_val)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< ext_inst_name()
<< ": Component Count must be 32-bit integer OpConstant";
diff --git a/third_party/SPIRV-Tools/source/val/validate_id.cpp b/third_party/SPIRV-Tools/source/val/validate_id.cpp
index 92a4e8e..bcfeb59 100644
--- a/third_party/SPIRV-Tools/source/val/validate_id.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate_id.cpp
@@ -163,9 +163,12 @@
!inst->IsDebugInfo() && !inst->IsNonSemantic() &&
!spvOpcodeIsDecoration(opcode) && opcode != spv::Op::OpFunction &&
opcode != spv::Op::OpCooperativeMatrixLengthNV &&
+ opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
!(opcode == spv::Op::OpSpecConstantOp &&
- spv::Op(inst->word(3)) ==
- spv::Op::OpCooperativeMatrixLengthNV)) {
+ (spv::Op(inst->word(3)) ==
+ spv::Op::OpCooperativeMatrixLengthNV ||
+ spv::Op(inst->word(3)) ==
+ spv::Op::OpCooperativeMatrixLengthKHR))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Operand " << _.getIdName(operand_word)
<< " cannot be a type";
@@ -179,9 +182,12 @@
opcode != spv::Op::OpLoopMerge &&
opcode != spv::Op::OpFunction &&
opcode != spv::Op::OpCooperativeMatrixLengthNV &&
+ opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
!(opcode == spv::Op::OpSpecConstantOp &&
- spv::Op(inst->word(3)) ==
- spv::Op::OpCooperativeMatrixLengthNV)) {
+ (spv::Op(inst->word(3)) ==
+ spv::Op::OpCooperativeMatrixLengthNV ||
+ spv::Op(inst->word(3)) ==
+ spv::Op::OpCooperativeMatrixLengthKHR))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Operand " << _.getIdName(operand_word)
<< " requires a type";
diff --git a/third_party/SPIRV-Tools/source/val/validate_image.cpp b/third_party/SPIRV-Tools/source/val/validate_image.cpp
index 733556b..a1a76ea 100644
--- a/third_party/SPIRV-Tools/source/val/validate_image.cpp
+++ b/third_party/SPIRV-Tools/source/val/validate_image.cpp
@@ -1,4 +1,4 @@
-// Copyright (c) 2017 Google Inc.
+// Copyright (c) 2017 Google Inc.
// Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights
// reserved.
//
@@ -297,7 +297,6 @@
spv::ImageOperandsMask::ConstOffsets |
spv::ImageOperandsMask::Offsets)) > 1) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << _.VkErrorID(4662)
<< "Image Operands Offset, ConstOffset, ConstOffsets, Offsets "
"cannot be used together";
}
@@ -496,7 +495,7 @@
}
uint64_t array_size = 0;
- if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) {
+ if (!_.EvalConstantValUint64(type_inst->word(3), &array_size)) {
assert(0 && "Array type definition is corrupt");
}
@@ -694,16 +693,11 @@
<< "storage image";
}
- if (info.multisampled == 1 &&
+ if (info.multisampled == 1 && info.arrayed == 1 && info.sampled == 2 &&
!_.HasCapability(spv::Capability::ImageMSArray)) {
-#if 0
- // TODO(atgoo@github.com) The description of this rule in the spec
- // is unclear and Glslang doesn't declare ImageMSArray. Need to clarify
- // and reenable.
return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Capability ImageMSArray is required to access storage "
- << "image";
-#endif
+ << "Capability ImageMSArray is required to access storage "
+ << "image";
}
} else if (info.sampled != 0) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -920,7 +914,15 @@
if (info.dim == spv::Dim::SubpassData && info.arrayed != 0) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << _.VkErrorID(6214) << "Dim SubpassData requires Arrayed to be 0";
+ << _.VkErrorID(6214)
+ << "Dim SubpassData requires Arrayed to be 0 in the Vulkan "
+ "environment";
+ }
+
+ if (info.dim == spv::Dim::Rect) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << _.VkErrorID(9638)
+ << "Dim must not be Rect in the Vulkan environment";
}
}
@@ -984,6 +986,14 @@
case spv::Op::OpImageSparseGather:
case spv::Op::OpImageSparseDrefGather:
case spv::Op::OpCopyObject:
+ case spv::Op::OpImageSampleWeightedQCOM:
+ case spv::Op::OpImageBoxFilterQCOM:
+ case spv::Op::OpImageBlockMatchSSDQCOM:
+ case spv::Op::OpImageBlockMatchSADQCOM:
+ case spv::Op::OpImageBlockMatchWindowSADQCOM:
+ case spv::Op::OpImageBlockMatchWindowSSDQCOM:
+ case spv::Op::OpImageBlockMatchGatherSADQCOM:
+ case spv::Op::OpImageBlockMatchGatherSSDQCOM:
return true;
case spv::Op::OpStore:
if (_.HasCapability(spv::Capability::BindlessTextureNV)) return true;
@@ -1087,6 +1097,18 @@
}
}
}
+
+ const Instruction* ld_inst;
+ {
+ int t_idx = inst->GetOperandAs<int>(2);
+ ld_inst = _.FindDef(t_idx);
+ }
+
+ if (ld_inst->opcode() == spv::Op::OpLoad) {