OpSampledImage propagation fix

This CL fixes issues when the result of an OpSampledImage instruction
is copied. We can no longer rely on the following instructions being
able to query if the previous instruction was OpSampledImage, since
it's also acceptable to use copy operations on the result of
OpSampledImage operations.

If a sampler is associated with an image, either when performing an
OpImage, OpSampledImage or OpCopy* operation, the sampler is now
properly propagated and queried in the ImageInstruction constructor
in order to retrieve the proper sampler, even if it doesn't come
directly from an OpSampledImage instruction.

Bug: b/236957718
Change-Id: I7e6a7834f955fb0be5a50f4f5b1717a471a7aaf6
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/67809
Tested-by: Alexis Hétu <sugoi@google.com>
Presubmit-Ready: Alexis Hétu <sugoi@google.com>
Commit-Queue: Alexis Hétu <sugoi@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 875d534..eac3fef 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -1710,9 +1710,12 @@
 
 	switch(getType(typeId).opcode())
 	{
+	case spv::OpTypeSampledImage:
+		object.kind = Object::Kind::SampledImage;
+		break;
+
 	case spv::OpTypePointer:
 	case spv::OpTypeImage:
-	case spv::OpTypeSampledImage:
 	case spv::OpTypeSampler:
 		object.kind = Object::Kind::Pointer;
 		break;
@@ -2183,7 +2186,7 @@
 	case spv::OpImageDrefGather:
 	case spv::OpImageFetch:
 	case spv::OpImageQueryLod:
-		return EmitImageSample(ImageInstruction(insn, *this), state);
+		return EmitImageSample(ImageInstruction(insn, *this, state), state);
 
 	case spv::OpImageQuerySizeLod:
 		return EmitImageQuerySizeLod(insn, state);
@@ -2198,17 +2201,19 @@
 		return EmitImageQuerySamples(insn, state);
 
 	case spv::OpImageRead:
-		return EmitImageRead(ImageInstruction(insn, *this), state);
+		return EmitImageRead(ImageInstruction(insn, *this, state), state);
 
 	case spv::OpImageWrite:
-		return EmitImageWrite(ImageInstruction(insn, *this), state);
+		return EmitImageWrite(ImageInstruction(insn, *this, state), state);
 
 	case spv::OpImageTexelPointer:
-		return EmitImageTexelPointer(ImageInstruction(insn, *this), state);
+		return EmitImageTexelPointer(ImageInstruction(insn, *this, state), state);
 
 	case spv::OpSampledImage:
+		return EmitSampledImage(insn, state);
+
 	case spv::OpImage:
-		return EmitSampledImageCombineOrSplit(insn, state);
+		return EmitImage(insn, state);
 
 	case spv::OpCopyObject:
 	case spv::OpCopyLogical:
@@ -2650,6 +2655,10 @@
 	{
 		state->createPointer(insn.resultId(), src.Pointer(0));
 	}
+	else if(src.isSampledImage())
+	{
+		state->createSampledImage(insn.resultId(), src.SampledImage(0));
+	}
 	else
 	{
 		auto type = getType(insn.resultTypeId());
@@ -2797,15 +2806,17 @@
     : constant(object.kind == SpirvShader::Object::Kind::Constant ? object.constantValue.data() : nullptr)
     , intermediate(object.kind == SpirvShader::Object::Kind::Intermediate ? &state->getIntermediate(object.id()) : nullptr)
     , pointer(object.kind == SpirvShader::Object::Kind::Pointer ? &state->getPointer(object.id()) : nullptr)
+    , sampledImage(object.kind == SpirvShader::Object::Kind::SampledImage ? &state->getSampledImage(object.id()) : nullptr)
     , componentCount(intermediate ? intermediate->componentCount : object.constantValue.size())
 {
-	ASSERT(intermediate || constant || pointer);
+	ASSERT(intermediate || constant || pointer || sampledImage);
 }
 
 SpirvShader::Operand::Operand(const Intermediate &value)
     : constant(nullptr)
     , intermediate(&value)
     , pointer(nullptr)
+    , sampledImage(nullptr)
     , componentCount(value.componentCount)
 {
 }
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index b0288b2..51df887 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -153,6 +153,8 @@
 
 class SpirvShader
 {
+	class EmitState;
+
 public:
 	SpirvBinary insns;
 
@@ -380,6 +382,9 @@
 			// Pointer held by SpirvRoutine::pointers
 			Pointer,
 
+			// Combination of an image pointer and a sampler ID
+			SampledImage,
+
 			// A pointer to a vk::DescriptorSet*.
 			// Pointer held by SpirvRoutine::pointers.
 			DescriptorSet,
@@ -388,6 +393,16 @@
 		Kind kind = Kind::Unknown;
 	};
 
+	class SampledImage : public SIMD::Pointer
+	{
+	public:
+		SampledImage(SIMD::Pointer image, Object::ID sampler)
+		    : SIMD::Pointer(image)
+		    , samplerId(sampler)
+		{}
+		Object::ID samplerId;
+	};
+
 	// Block is an interval of SPIR-V instructions, starting with the
 	// opening OpLabel, and ending with a termination instruction.
 	class Block
@@ -606,7 +621,7 @@
 
 	struct ImageInstruction : public ImageInstructionSignature
 	{
-		ImageInstruction(InsnIterator insn, const SpirvShader &spirv);
+		ImageInstruction(InsnIterator insn, const SpirvShader &spirv, EmitState *state);
 
 		const uint32_t position;
 
@@ -1170,9 +1185,33 @@
 			return it->second;
 		}
 
+		void createSampledImage(Object::ID id, SampledImage ptr)
+		{
+			bool added = sampledImages.emplace(id, ptr).second;
+			ASSERT_MSG(added, "Sampled image %d created twice", id.value());
+		}
+
+		SampledImage const &getSampledImage(Object::ID id) const
+		{
+			auto it = sampledImages.find(id);
+			ASSERT_MSG(it != sampledImages.end(), "Unknown sampled image %d", id.value());
+			return it->second;
+		}
+
+		bool isSampledImage(Object::ID id) const
+		{
+			return sampledImages.find(id) != sampledImages.end();
+		}
+
+		SIMD::Pointer const &getImage(Object::ID id) const
+		{
+			return isSampledImage(id) ? getSampledImage(id) : getPointer(id);
+		}
+
 	private:
 		std::unordered_map<Object::ID, Intermediate> intermediates;
 		std::unordered_map<Object::ID, SIMD::Pointer> pointers;
+		std::unordered_map<Object::ID, SampledImage> sampledImages;
 
 		const unsigned int multiSampleCount;
 	};
@@ -1239,6 +1278,18 @@
 			return (pointer != nullptr);
 		}
 
+		const SampledImage &SampledImage(uint32_t i) const
+		{
+			ASSERT(intermediate == nullptr);
+
+			return sampledImage[i];
+		}
+
+		bool isSampledImage() const
+		{
+			return (sampledImage != nullptr);
+		}
+
 	private:
 		RR_PRINT_ONLY(friend struct rr::PrintValue::Ty<Operand>;)
 
@@ -1248,6 +1299,7 @@
 		const uint32_t *constant;
 		const Intermediate *intermediate;
 		const SIMD::Pointer *pointer;
+		const SpirvShader::SampledImage *sampledImage;
 
 	public:
 		const uint32_t componentCount;
@@ -1383,7 +1435,8 @@
 	EmitResult EmitImageTexelPointer(const ImageInstruction &instruction, EmitState *state) const;
 	EmitResult EmitAtomicOp(InsnIterator insn, EmitState *state) const;
 	EmitResult EmitAtomicCompareExchange(InsnIterator insn, EmitState *state) const;
-	EmitResult EmitSampledImageCombineOrSplit(InsnIterator insn, EmitState *state) const;
+	EmitResult EmitSampledImage(InsnIterator insn, EmitState *state) const;
+	EmitResult EmitImage(InsnIterator insn, EmitState *state) const;
 	EmitResult EmitCopyObject(InsnIterator insn, EmitState *state) const;
 	EmitResult EmitCopyMemory(InsnIterator insn, EmitState *state) const;
 	EmitResult EmitControlBarrier(InsnIterator insn, EmitState *state) const;
diff --git a/src/Pipeline/SpirvShaderImage.cpp b/src/Pipeline/SpirvShaderImage.cpp
index 764ad3e..4f28a39 100644
--- a/src/Pipeline/SpirvShaderImage.cpp
+++ b/src/Pipeline/SpirvShaderImage.cpp
@@ -74,7 +74,7 @@
 	}
 }
 
-SpirvShader::ImageInstruction::ImageInstruction(InsnIterator insn, const SpirvShader &spirv)
+SpirvShader::ImageInstruction::ImageInstruction(InsnIterator insn, const SpirvShader &spirv, EmitState *state)
     : ImageInstructionSignature(parseVariantAndMethod(insn))
     , position(insn.distanceFrom(spirv.begin()))
 {
@@ -95,13 +95,15 @@
 		}
 		else
 		{
+			// sampledImageId is either the result of an OpSampledImage instruction or
+			// an externally combined sampler and image.
 			Object::ID sampledImageId = insn.word(3);
-			const Object &sampledImage = spirv.getObject(sampledImageId);
 
-			if(sampledImage.opcode() == spv::OpSampledImage)
+			if(state->isSampledImage(sampledImageId))  // Result of an OpSampledImage instruction
 			{
-				imageId = sampledImage.definition.word(3);
-				samplerId = sampledImage.definition.word(4);
+				const SampledImage &sampledImage = state->getSampledImage(sampledImageId);
+				imageId = spirv.getObject(sampledImageId).definition.word(3);
+				samplerId = sampledImage.samplerId;
 			}
 			else  // Combined image/sampler
 			{
@@ -348,7 +350,7 @@
 
 void SpirvShader::EmitImageSampleUnconditional(Array<SIMD::Float> &out, const ImageInstruction &instruction, EmitState *state) const
 {
-	Pointer<Byte> imageDescriptor = state->getPointer(instruction.imageId).getUniformPointer();  // vk::SampledImageDescriptor*
+	Pointer<Byte> imageDescriptor = state->getImage(instruction.imageId).getUniformPointer();  // vk::SampledImageDescriptor*
 
 	Pointer<Byte> samplerFunction = lookupSamplerFunction(imageDescriptor, instruction, state);
 
@@ -1544,15 +1546,25 @@
 	return EmitResult::Continue;
 }
 
-SpirvShader::EmitResult SpirvShader::EmitSampledImageCombineOrSplit(InsnIterator insn, EmitState *state) const
+SpirvShader::EmitResult SpirvShader::EmitSampledImage(InsnIterator insn, EmitState *state) const
 {
-	// Propagate the image pointer in both cases.
-	// Consumers of OpSampledImage will look through to find the sampler pointer.
+	Object::ID resultId = insn.word(2);
+	Object::ID imageId = insn.word(3);
+	Object::ID samplerId = insn.word(4);
 
+	// Create a sampled image, containing both a sampler and an image
+	state->createSampledImage(resultId, { state->getPointer(imageId), samplerId });
+
+	return EmitResult::Continue;
+}
+
+SpirvShader::EmitResult SpirvShader::EmitImage(InsnIterator insn, EmitState *state) const
+{
 	Object::ID resultId = insn.word(2);
 	Object::ID imageId = insn.word(3);
 
-	state->createPointer(resultId, state->getPointer(imageId));
+	// Extract the image from a sampled image.
+	state->createPointer(resultId, state->getImage(imageId));
 
 	return EmitResult::Continue;
 }