shaderSampledImageArrayNonUniformIndexing support

This CL implements the shaderSampledImageArrayNonUniformIndexing
feature, which allows arrays of samplers or sampled images to be
indexed by non-uniform integer expressions in shader code.

This is a brute force implementation, which simply repeats the
sampling operation for every lane. This leaves room for future
optimization when some of them or all of them don't diverge.

Bug: b/236957718
Change-Id: Ic4d8b58b1bd8281a07c28a43d9a96f3ae2ea9bf4
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/67810
Tested-by: Alexis Hétu <sugoi@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Commit-Queue: Alexis Hétu <sugoi@google.com>
Presubmit-Ready: Alexis Hétu <sugoi@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index eac3fef..ee3e1a5 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -459,6 +459,7 @@
 				case spv::CapabilityUniformTexelBufferArrayDynamicIndexing: capabilities.UniformTexelBufferArrayDynamicIndexing = true; break;
 				case spv::CapabilityStorageTexelBufferArrayDynamicIndexing: capabilities.StorageTexelBufferArrayDynamicIndexing = true; break;
 				case spv::CapabilityUniformBufferArrayNonUniformIndexing: capabilities.UniformBufferArrayNonUniformIndex = true; break;
+				case spv::CapabilitySampledImageArrayNonUniformIndexing: capabilities.SampledImageArrayNonUniformIndexing = true; break;
 				case spv::CapabilityPhysicalStorageBufferAddresses: capabilities.PhysicalStorageBufferAddresses = true; break;
 				default:
 					UNSUPPORTED("Unsupported capability %u", insn.word(1));
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 51df887..74dc1c6 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -393,10 +393,10 @@
 		Kind kind = Kind::Unknown;
 	};
 
-	class SampledImage : public SIMD::Pointer
+	class SampledImagePointer : public SIMD::Pointer
 	{
 	public:
-		SampledImage(SIMD::Pointer image, Object::ID sampler)
+		SampledImagePointer(SIMD::Pointer image, Object::ID sampler)
 		    : SIMD::Pointer(image)
 		    , samplerId(sampler)
 		{}
@@ -749,6 +749,7 @@
 		bool UniformTexelBufferArrayNonUniformIndexing : 1;
 		bool UniformTexelBufferArrayDynamicIndexing : 1;
 		bool UniformBufferArrayNonUniformIndex : 1;
+		bool SampledImageArrayNonUniformIndexing : 1;
 		bool PhysicalStorageBufferAddresses : 1;
 	};
 
@@ -1185,13 +1186,13 @@
 			return it->second;
 		}
 
-		void createSampledImage(Object::ID id, SampledImage ptr)
+		void createSampledImage(Object::ID id, SampledImagePointer ptr)
 		{
 			bool added = sampledImages.emplace(id, ptr).second;
 			ASSERT_MSG(added, "Sampled image %d created twice", id.value());
 		}
 
-		SampledImage const &getSampledImage(Object::ID id) const
+		SampledImagePointer const &getSampledImage(Object::ID id) const
 		{
 			auto it = sampledImages.find(id);
 			ASSERT_MSG(it != sampledImages.end(), "Unknown sampled image %d", id.value());
@@ -1211,7 +1212,7 @@
 	private:
 		std::unordered_map<Object::ID, Intermediate> intermediates;
 		std::unordered_map<Object::ID, SIMD::Pointer> pointers;
-		std::unordered_map<Object::ID, SampledImage> sampledImages;
+		std::unordered_map<Object::ID, SampledImagePointer> sampledImages;
 
 		const unsigned int multiSampleCount;
 	};
@@ -1278,7 +1279,7 @@
 			return (pointer != nullptr);
 		}
 
-		const SampledImage &SampledImage(uint32_t i) const
+		const SampledImagePointer &SampledImage(uint32_t i) const
 		{
 			ASSERT(intermediate == nullptr);
 
@@ -1299,7 +1300,7 @@
 		const uint32_t *constant;
 		const Intermediate *intermediate;
 		const SIMD::Pointer *pointer;
-		const SpirvShader::SampledImage *sampledImage;
+		const SampledImagePointer *sampledImage;
 
 	public:
 		const uint32_t componentCount;
@@ -1448,7 +1449,9 @@
 	// Emits code to sample an image, regardless of whether any SIMD lanes are active.
 	void EmitImageSampleUnconditional(Array<SIMD::Float> &out, const ImageInstruction &instruction, EmitState *state) const;
 
-	Pointer<Byte> lookupSamplerFunction(Pointer<Byte> imageDescriptor, const ImageInstruction &instruction, EmitState *state) const;
+	Pointer<Byte> getSamplerDescriptor(Pointer<Byte> imageDescriptor, const ImageInstruction &instruction, EmitState *state) const;
+	Pointer<Byte> getSamplerDescriptor(Pointer<Byte> imageDescriptor, const ImageInstruction &instruction, int laneIdx, EmitState *state) const;
+	Pointer<Byte> lookupSamplerFunction(Pointer<Byte> imageDescriptor, Pointer<Byte> samplerDescriptor, const ImageInstruction &instruction, EmitState *state) const;
 	void callSamplerFunction(Pointer<Byte> samplerFunction, Array<SIMD::Float> &out, Pointer<Byte> imageDescriptor, const ImageInstruction &instruction, EmitState *state) const;
 
 	void GetImageDimensions(EmitState const *state, Type const &resultTy, Object::ID imageId, Object::ID lodId, Intermediate &dst) const;
diff --git a/src/Pipeline/SpirvShaderImage.cpp b/src/Pipeline/SpirvShaderImage.cpp
index 4f28a39..1e7c97b 100644
--- a/src/Pipeline/SpirvShaderImage.cpp
+++ b/src/Pipeline/SpirvShaderImage.cpp
@@ -101,7 +101,7 @@
 
 			if(state->isSampledImage(sampledImageId))  // Result of an OpSampledImage instruction
 			{
-				const SampledImage &sampledImage = state->getSampledImage(sampledImageId);
+				const SampledImagePointer &sampledImage = state->getSampledImage(sampledImageId);
 				imageId = spirv.getObject(sampledImageId).definition.word(3);
 				samplerId = sampledImage.samplerId;
 			}
@@ -350,23 +350,56 @@
 
 void SpirvShader::EmitImageSampleUnconditional(Array<SIMD::Float> &out, const ImageInstruction &instruction, EmitState *state) const
 {
-	Pointer<Byte> imageDescriptor = state->getImage(instruction.imageId).getUniformPointer();  // vk::SampledImageDescriptor*
+	auto decorations = GetDecorationsForId(instruction.imageId);
 
-	Pointer<Byte> samplerFunction = lookupSamplerFunction(imageDescriptor, instruction, state);
+	if(decorations.NonUniform)
+	{
+		SIMD::Int activeLaneMask = state->activeLaneMask();
+		SIMD::Pointer imagePointer = state->getImage(instruction.imageId);
+		// PerLane output
+		for(int laneIdx = 0; laneIdx < SIMD::Width; laneIdx++)
+		{
+			Array<SIMD::Float> laneOut(out.getArraySize());
+			If(Extract(activeLaneMask, laneIdx) != 0)
+			{
+				Pointer<Byte> imageDescriptor = imagePointer.getPointerForLane(laneIdx);  // vk::SampledImageDescriptor*
+				Pointer<Byte> samplerDescriptor = getSamplerDescriptor(imageDescriptor, instruction, laneIdx, state);
 
-	callSamplerFunction(samplerFunction, out, imageDescriptor, instruction, state);
+				Pointer<Byte> samplerFunction = lookupSamplerFunction(imageDescriptor, samplerDescriptor, instruction, state);
+
+				callSamplerFunction(samplerFunction, laneOut, imageDescriptor, instruction, state);
+			}
+
+			for(int outIdx = 0; outIdx < out.getArraySize(); outIdx++)
+			{
+				out[outIdx] = Insert(out[outIdx], Extract(laneOut[outIdx], laneIdx), laneIdx);
+			}
+		}
+	}
+	else
+	{
+		Pointer<Byte> imageDescriptor = state->getImage(instruction.imageId).getUniformPointer();  // vk::SampledImageDescriptor*
+		Pointer<Byte> samplerDescriptor = getSamplerDescriptor(imageDescriptor, instruction, state);
+
+		Pointer<Byte> samplerFunction = lookupSamplerFunction(imageDescriptor, samplerDescriptor, instruction, state);
+
+		callSamplerFunction(samplerFunction, out, imageDescriptor, instruction, state);
+	}
 }
 
-Pointer<Byte> SpirvShader::lookupSamplerFunction(Pointer<Byte> imageDescriptor, const ImageInstruction &instruction, EmitState *state) const
+Pointer<Byte> SpirvShader::getSamplerDescriptor(Pointer<Byte> imageDescriptor, const ImageInstruction &instruction, EmitState *state) const
 {
-	Int samplerId = 0;
+	return ((instruction.samplerId == instruction.imageId) || (instruction.samplerId == 0)) ? imageDescriptor : state->getImage(instruction.samplerId).getUniformPointer();
+}
 
-	if(instruction.samplerId != 0)
-	{
-		Pointer<Byte> samplerDescriptor = state->getPointer(instruction.samplerId).getUniformPointer();  // vk::SampledImageDescriptor*
+Pointer<Byte> SpirvShader::getSamplerDescriptor(Pointer<Byte> imageDescriptor, const ImageInstruction &instruction, int laneIdx, EmitState *state) const
+{
+	return ((instruction.samplerId == instruction.imageId) || (instruction.samplerId == 0)) ? imageDescriptor : state->getImage(instruction.samplerId).getPointerForLane(laneIdx);
+}
 
-		samplerId = *Pointer<rr::Int>(samplerDescriptor + OFFSET(vk::SampledImageDescriptor, samplerId));  // vk::Sampler::id
-	}
+Pointer<Byte> SpirvShader::lookupSamplerFunction(Pointer<Byte> imageDescriptor, Pointer<Byte> samplerDescriptor, const ImageInstruction &instruction, EmitState *state) const
+{
+	Int samplerId = (instruction.samplerId != 0) ? *Pointer<rr::Int>(samplerDescriptor + OFFSET(vk::SampledImageDescriptor, samplerId)) : Int(0);
 
 	auto &cache = state->routine->samplerCache.at(instruction.position);
 	Bool cacheHit = (cache.imageDescriptor == imageDescriptor) && (cache.samplerId == samplerId);  // TODO(b/205566405): Skip sampler ID check for samplerless instructions.
@@ -1252,17 +1285,18 @@
 	SIMD::Pointer ptr = state->getPointer(instruction.imageId);
 	if(ptr.isBasePlusOffset)
 	{
-		Pointer<Byte> descriptor = ptr.getUniformPointer();  // vk::StorageImageDescriptor*
+		Pointer<Byte> imageDescriptor = ptr.getUniformPointer();  // vk::StorageImageDescriptor* or vk::SampledImageDescriptor*
+		Pointer<Byte> samplerDescriptor = getSamplerDescriptor(imageDescriptor, instruction, state);
 
 		if(imageFormat == VK_FORMAT_UNDEFINED)  // spv::ImageFormatUnknown
 		{
-			Pointer<Byte> samplerFunction = lookupSamplerFunction(descriptor, instruction, state);
+			Pointer<Byte> samplerFunction = lookupSamplerFunction(imageDescriptor, samplerDescriptor, instruction, state);
 
-			Call<ImageSampler>(samplerFunction, descriptor, &coord, &texelAndMask, state->routine->constants);
+			Call<ImageSampler>(samplerFunction, imageDescriptor, &coord, &texelAndMask, state->routine->constants);
 		}
 		else
 		{
-			WriteImage(instruction, descriptor, &coord, &texelAndMask, imageFormat);
+			WriteImage(instruction, imageDescriptor, &coord, &texelAndMask, imageFormat);
 		}
 	}
 	else
@@ -1272,16 +1306,17 @@
 			SIMD::Int singleLaneMask = 0;
 			singleLaneMask = Insert(singleLaneMask, 0xffffffff, j);
 			texelAndMask[4] = state->activeStoresAndAtomicsMask() & singleLaneMask;
-			Pointer<Byte> descriptor = ptr.getPointerForLane(j);
+			Pointer<Byte> imageDescriptor = ptr.getPointerForLane(j);
+			Pointer<Byte> samplerDescriptor = getSamplerDescriptor(imageDescriptor, instruction, j, state);
 			if(imageFormat == VK_FORMAT_UNDEFINED)  // spv::ImageFormatUnknown
 			{
-				Pointer<Byte> samplerFunction = lookupSamplerFunction(descriptor, instruction, state);
+				Pointer<Byte> samplerFunction = lookupSamplerFunction(imageDescriptor, samplerDescriptor, instruction, state);
 
-				Call<ImageSampler>(samplerFunction, descriptor, &coord, &texelAndMask, state->routine->constants);
+				Call<ImageSampler>(samplerFunction, imageDescriptor, &coord, &texelAndMask, state->routine->constants);
 			}
 			else
 			{
-				WriteImage(instruction, descriptor, &coord, &texelAndMask, imageFormat);
+				WriteImage(instruction, imageDescriptor, &coord, &texelAndMask, imageFormat);
 			}
 		}
 	}
diff --git a/src/Vulkan/VkPhysicalDevice.cpp b/src/Vulkan/VkPhysicalDevice.cpp
index 0f8d200..59eef6b 100644
--- a/src/Vulkan/VkPhysicalDevice.cpp
+++ b/src/Vulkan/VkPhysicalDevice.cpp
@@ -247,7 +247,7 @@
 	features->shaderUniformTexelBufferArrayDynamicIndexing = VK_TRUE;
 	features->shaderStorageTexelBufferArrayDynamicIndexing = VK_TRUE;
 	features->shaderUniformBufferArrayNonUniformIndexing = VK_TRUE;
-	features->shaderSampledImageArrayNonUniformIndexing = VK_FALSE;
+	features->shaderSampledImageArrayNonUniformIndexing = VK_TRUE;
 	features->shaderStorageBufferArrayNonUniformIndexing = VK_TRUE;
 	features->shaderStorageImageArrayNonUniformIndexing = VK_FALSE;
 	features->shaderInputAttachmentArrayNonUniformIndexing = VK_FALSE;