Obtain all sampler parameters through SamplingRoutineCache::Key

This change ensures that the descriptor state identifiers used to
perform lookups in the sampling routine cache are all that is used to
obtain the state itself which is used for specializing sampling
routine generation.

The createSamplingRoutine lambda function is made to capture only the
'device' variable, instead of allowing access to all local variables.
The device is required to obtain the sampler state from the sampler
identifier.

Bug: b/152227757
Change-Id: Id7f5e18e09f078589a1a1edc12622ed40126cd32
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/53068
Commit-Queue: Alexis Hétu <sugoi@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Result: kokoro <noreply+kokoro@google.com>
Tested-by: Nicolas Capens <nicolascapens@google.com>
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 69335b8..e52ddf3 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -46,11 +46,13 @@
 
 namespace vk {
 
+class Device;
 class PipelineLayout;
 class ImageView;
 class Sampler;
 class RenderPass;
 struct SampledImageDescriptor;
+struct SamplerState;
 
 namespace dbg {
 class Context;
@@ -1291,13 +1293,13 @@
 	// Returns the pair <significand, exponent>
 	std::pair<SIMD::Float, SIMD::Int> Frexp(RValue<SIMD::Float> val) const;
 
-	static ImageSampler *getImageSampler(uint32_t instruction, vk::SampledImageDescriptor const *imageDescriptor, const vk::Sampler *sampler);
+	static ImageSampler *getImageSampler(uint32_t instruction, uint32_t imageViewId, uint32_t samplerId, const vk::Device *device);
 	static std::shared_ptr<rr::Routine> emitSamplerRoutine(ImageInstruction instruction, const Sampler &samplerState);
 
 	// TODO(b/129523279): Eliminate conversion and use vk::Sampler members directly.
-	static sw::FilterType convertFilterMode(const vk::Sampler *sampler, VkImageViewType imageViewType, SamplerMethod samplerMethod);
-	static sw::MipmapType convertMipmapMode(const vk::Sampler *sampler);
-	static sw::AddressingMode convertAddressingMode(int coordinateIndex, const vk::Sampler *sampler, VkImageViewType imageViewType);
+	static sw::FilterType convertFilterMode(const vk::SamplerState *samplerState, VkImageViewType imageViewType, SamplerMethod samplerMethod);
+	static sw::MipmapType convertMipmapMode(const vk::SamplerState *samplerState);
+	static sw::AddressingMode convertAddressingMode(int coordinateIndex, const vk::SamplerState *samplerState, VkImageViewType imageViewType);
 
 	// Returns 0 when invalid.
 	static VkShaderStageFlagBits executionModelToStage(spv::ExecutionModel model);
@@ -1366,7 +1368,7 @@
 	struct SamplerCache
 	{
 		Pointer<Byte> imageDescriptor = nullptr;
-		Pointer<Byte> sampler;
+		Int samplerId;
 		Pointer<Byte> function;
 	};
 
diff --git a/src/Pipeline/SpirvShaderImage.cpp b/src/Pipeline/SpirvShaderImage.cpp
index d0ea81a..3be1fa0 100644
--- a/src/Pipeline/SpirvShaderImage.cpp
+++ b/src/Pipeline/SpirvShaderImage.cpp
@@ -157,15 +157,15 @@
 
 	auto coordinate = Operand(this, state, coordinateId);
 
-	Pointer<Byte> sampler = samplerDescriptor + OFFSET(vk::SampledImageDescriptor, sampler);  // vk::Sampler*
-	Pointer<Byte> texture = imageDescriptor + OFFSET(vk::SampledImageDescriptor, texture);    // sw::Texture*
+	rr::Int samplerId = *Pointer<rr::Int>(samplerDescriptor + OFFSET(vk::SampledImageDescriptor, sampler) + OFFSET(vk::Sampler, id));  // vk::Sampler::id
+	Pointer<Byte> texture = imageDescriptor + OFFSET(vk::SampledImageDescriptor, texture);                                             // sw::Texture*
 
 	// Above we assumed that if the SampledImage operand is not the result of an OpSampledImage,
 	// it must be a combined image sampler loaded straight from the descriptor set. For OpImageFetch
 	// it's just an Image operand, so there's no sampler descriptor data.
 	if(getType(sampledImage).opcode() != spv::OpTypeSampledImage)
 	{
-		sampler = Pointer<Byte>(nullptr);
+		samplerId = Int(0);
 	}
 
 	uint32_t imageOperands = spv::ImageOperandsMaskNone;
@@ -325,13 +325,15 @@
 	auto cacheIt = state->routine->samplerCache.find(insn.resultId());
 	ASSERT(cacheIt != state->routine->samplerCache.end());
 	auto &cache = cacheIt->second;
-	auto cacheHit = cache.imageDescriptor == imageDescriptor && cache.sampler == sampler;
+	auto cacheHit = cache.imageDescriptor == imageDescriptor && cache.samplerId == samplerId;
 
 	If(!cacheHit)
 	{
-		cache.function = Call(getImageSampler, instruction.parameters, imageDescriptor, sampler);
+		rr::Int imageViewId = *Pointer<rr::Int>(imageDescriptor + OFFSET(vk::SampledImageDescriptor, imageViewId));
+		Pointer<Byte> device = *Pointer<Pointer<Byte>>(imageDescriptor + OFFSET(vk::SampledImageDescriptor, device));
+		cache.function = Call(getImageSampler, instruction.parameters, imageViewId, samplerId, device);
 		cache.imageDescriptor = imageDescriptor;
-		cache.sampler = sampler;
+		cache.samplerId = samplerId;
 	}
 
 	Call<ImageSampler>(cache.function, texture, &in[0], &out[0], state->routine->constants);
diff --git a/src/Pipeline/SpirvShaderSampling.cpp b/src/Pipeline/SpirvShaderSampling.cpp
index ecd2f74..c36e43c 100644
--- a/src/Pipeline/SpirvShaderSampling.cpp
+++ b/src/Pipeline/SpirvShaderSampling.cpp
@@ -30,19 +30,20 @@
 
 namespace sw {
 
-SpirvShader::ImageSampler *SpirvShader::getImageSampler(uint32_t inst, vk::SampledImageDescriptor const *imageDescriptor, const vk::Sampler *sampler)
+SpirvShader::ImageSampler *SpirvShader::getImageSampler(uint32_t inst, uint32_t imageViewId, uint32_t samplerId, const vk::Device *device)
 {
 	ImageInstruction instruction(inst);
-	const auto samplerId = sampler ? sampler->id : 0;
-	ASSERT(imageDescriptor->imageViewId != 0 && (samplerId != 0 || instruction.samplerMethod == Fetch));
-	ASSERT(imageDescriptor->device);
+	ASSERT(imageViewId != 0 && (samplerId != 0 || instruction.samplerMethod == Fetch));
+	ASSERT(device);
 
-	vk::Device::SamplingRoutineCache::Key key = { inst, imageDescriptor->imageViewId, samplerId };
+	vk::Device::SamplingRoutineCache::Key key = { inst, samplerId, imageViewId };
 
-	vk::Device::SamplingRoutineCache *cache = imageDescriptor->device->getSamplingRoutineCache();
+	vk::Device::SamplingRoutineCache *cache = device->getSamplingRoutineCache();
 
-	auto createSamplingRoutine = [&](const vk::Device::SamplingRoutineCache::Key &key) {
-		const vk::Identifier::State imageViewState = vk::Identifier(imageDescriptor->imageViewId).getState();
+	auto createSamplingRoutine = [&device](const vk::Device::SamplingRoutineCache::Key &key) {
+		ImageInstruction instruction(key.instruction);
+		const vk::Identifier::State imageViewState = vk::Identifier(key.imageView).getState();
+		const vk::SamplerState *vkSamplerState = (key.sampler != 0) ? device->findSampler(key.sampler) : nullptr;
 
 		auto type = imageViewState.imageViewType;
 		auto samplerMethod = static_cast<SamplerMethod>(instruction.samplerMethod);
@@ -51,34 +52,34 @@
 		samplerState.textureType = type;
 		samplerState.textureFormat = imageViewState.format;
 
-		samplerState.addressingModeU = convertAddressingMode(0, sampler, type);
-		samplerState.addressingModeV = convertAddressingMode(1, sampler, type);
-		samplerState.addressingModeW = convertAddressingMode(2, sampler, type);
+		samplerState.addressingModeU = convertAddressingMode(0, vkSamplerState, type);
+		samplerState.addressingModeV = convertAddressingMode(1, vkSamplerState, type);
+		samplerState.addressingModeW = convertAddressingMode(2, vkSamplerState, type);
 
-		samplerState.mipmapFilter = convertMipmapMode(sampler);
+		samplerState.mipmapFilter = convertMipmapMode(vkSamplerState);
 		samplerState.swizzle = imageViewState.mapping;
 		samplerState.gatherComponent = instruction.gatherComponent;
 
-		if(sampler)
+		if(vkSamplerState)
 		{
-			samplerState.textureFilter = convertFilterMode(sampler, type, samplerMethod);
-			samplerState.border = sampler->borderColor;
+			samplerState.textureFilter = convertFilterMode(vkSamplerState, type, samplerMethod);
+			samplerState.border = vkSamplerState->borderColor;
 
-			samplerState.mipmapFilter = convertMipmapMode(sampler);
-			samplerState.highPrecisionFiltering = (sampler->filteringPrecision == VK_SAMPLER_FILTERING_PRECISION_MODE_HIGH_GOOGLE);
+			samplerState.mipmapFilter = convertMipmapMode(vkSamplerState);
+			samplerState.highPrecisionFiltering = (vkSamplerState->filteringPrecision == VK_SAMPLER_FILTERING_PRECISION_MODE_HIGH_GOOGLE);
 
-			samplerState.compareEnable = (sampler->compareEnable != VK_FALSE);
-			samplerState.compareOp = sampler->compareOp;
-			samplerState.unnormalizedCoordinates = (sampler->unnormalizedCoordinates != VK_FALSE);
+			samplerState.compareEnable = (vkSamplerState->compareEnable != VK_FALSE);
+			samplerState.compareOp = vkSamplerState->compareOp;
+			samplerState.unnormalizedCoordinates = (vkSamplerState->unnormalizedCoordinates != VK_FALSE);
 
-			samplerState.ycbcrModel = sampler->ycbcrModel;
-			samplerState.studioSwing = sampler->studioSwing;
-			samplerState.swappedChroma = sampler->swappedChroma;
+			samplerState.ycbcrModel = vkSamplerState->ycbcrModel;
+			samplerState.studioSwing = vkSamplerState->studioSwing;
+			samplerState.swappedChroma = vkSamplerState->swappedChroma;
 
-			samplerState.mipLodBias = sampler->mipLodBias;
-			samplerState.maxAnisotropy = sampler->maxAnisotropy;
-			samplerState.minLod = sampler->minLod;
-			samplerState.maxLod = sampler->maxLod;
+			samplerState.mipLodBias = vkSamplerState->mipLodBias;
+			samplerState.maxAnisotropy = vkSamplerState->maxAnisotropy;
+			samplerState.minLod = vkSamplerState->minLod;
+			samplerState.maxLod = vkSamplerState->maxLod;
 		}
 		else
 		{
@@ -200,7 +201,7 @@
 	return function("sampler");
 }
 
-sw::FilterType SpirvShader::convertFilterMode(const vk::Sampler *sampler, VkImageViewType imageViewType, SamplerMethod samplerMethod)
+sw::FilterType SpirvShader::convertFilterMode(const vk::SamplerState *samplerState, VkImageViewType imageViewType, SamplerMethod samplerMethod)
 {
 	if(samplerMethod == Gather)
 	{
@@ -212,7 +213,7 @@
 		return FILTER_POINT;
 	}
 
-	if(sampler->anisotropyEnable != VK_FALSE)
+	if(samplerState->anisotropyEnable != VK_FALSE)
 	{
 		if(imageViewType == VK_IMAGE_VIEW_TYPE_2D || imageViewType == VK_IMAGE_VIEW_TYPE_2D_ARRAY)
 		{
@@ -223,25 +224,25 @@
 		}
 	}
 
-	switch(sampler->magFilter)
+	switch(samplerState->magFilter)
 	{
 	case VK_FILTER_NEAREST:
-		switch(sampler->minFilter)
+		switch(samplerState->minFilter)
 		{
 		case VK_FILTER_NEAREST: return FILTER_POINT;
 		case VK_FILTER_LINEAR: return FILTER_MIN_LINEAR_MAG_POINT;
 		default:
-			UNSUPPORTED("minFilter %d", sampler->minFilter);
+			UNSUPPORTED("minFilter %d", samplerState->minFilter);
 			return FILTER_POINT;
 		}
 		break;
 	case VK_FILTER_LINEAR:
-		switch(sampler->minFilter)
+		switch(samplerState->minFilter)
 		{
 		case VK_FILTER_NEAREST: return FILTER_MIN_POINT_MAG_LINEAR;
 		case VK_FILTER_LINEAR: return FILTER_LINEAR;
 		default:
-			UNSUPPORTED("minFilter %d", sampler->minFilter);
+			UNSUPPORTED("minFilter %d", samplerState->minFilter);
 			return FILTER_POINT;
 		}
 		break;
@@ -249,34 +250,34 @@
 		break;
 	}
 
-	UNSUPPORTED("magFilter %d", sampler->magFilter);
+	UNSUPPORTED("magFilter %d", samplerState->magFilter);
 	return FILTER_POINT;
 }
 
-sw::MipmapType SpirvShader::convertMipmapMode(const vk::Sampler *sampler)
+sw::MipmapType SpirvShader::convertMipmapMode(const vk::SamplerState *samplerState)
 {
-	if(!sampler)
+	if(!samplerState)
 	{
 		return MIPMAP_POINT;  // Samplerless operations (OpImageFetch) can take an integer Lod operand.
 	}
 
-	if(sampler->ycbcrModel != VK_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
+	if(samplerState->ycbcrModel != VK_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
 	{
 		// TODO(b/151263485): Check image view level count instead.
 		return MIPMAP_NONE;
 	}
 
-	switch(sampler->mipmapMode)
+	switch(samplerState->mipmapMode)
 	{
 	case VK_SAMPLER_MIPMAP_MODE_NEAREST: return MIPMAP_POINT;
 	case VK_SAMPLER_MIPMAP_MODE_LINEAR: return MIPMAP_LINEAR;
 	default:
-		UNSUPPORTED("mipmapMode %d", sampler->mipmapMode);
+		UNSUPPORTED("mipmapMode %d", samplerState->mipmapMode);
 		return MIPMAP_POINT;
 	}
 }
 
-sw::AddressingMode SpirvShader::convertAddressingMode(int coordinateIndex, const vk::Sampler *sampler, VkImageViewType imageViewType)
+sw::AddressingMode SpirvShader::convertAddressingMode(int coordinateIndex, const vk::SamplerState *samplerState, VkImageViewType imageViewType)
 {
 	switch(imageViewType)
 	{
@@ -321,7 +322,7 @@
 		return ADDRESSING_WRAP;
 	}
 
-	if(!sampler)
+	if(!samplerState)
 	{
 		// OpImageFetch does not take a sampler descriptor, but still needs a valid
 		// addressing mode that prevents out-of-bounds accesses:
@@ -339,9 +340,9 @@
 	VkSamplerAddressMode addressMode = VK_SAMPLER_ADDRESS_MODE_REPEAT;
 	switch(coordinateIndex)
 	{
-	case 0: addressMode = sampler->addressModeU; break;
-	case 1: addressMode = sampler->addressModeV; break;
-	case 2: addressMode = sampler->addressModeW; break;
+	case 0: addressMode = samplerState->addressModeU; break;
+	case 1: addressMode = samplerState->addressModeV; break;
+	case 2: addressMode = samplerState->addressModeW; break;
 	default: UNSUPPORTED("coordinateIndex: %d", coordinateIndex);
 	}
 
diff --git a/src/Vulkan/VkDevice.cpp b/src/Vulkan/VkDevice.cpp
index e420921..4a742cf 100644
--- a/src/Vulkan/VkDevice.cpp
+++ b/src/Vulkan/VkDevice.cpp
@@ -105,6 +105,16 @@
 	}
 }
 
+const SamplerState *Device::SamplerIndexer::find(uint32_t id)
+{
+	marl::lock lock(mutex);
+
+	auto it = std::find_if(std::begin(map), std::end(map),
+	                       [&id](auto &&p) { return p.second.id == id; });
+
+	return (it != std::end(map)) ? &(it->first) : nullptr;
+}
+
 Device::Device(const VkDeviceCreateInfo *pCreateInfo, void *mem, PhysicalDevice *physicalDevice, const VkPhysicalDeviceFeatures *enabledFeatures, const std::shared_ptr<marl::Scheduler> &scheduler)
     : physicalDevice(physicalDevice)
     , queues(reinterpret_cast<Queue *>(mem))
@@ -395,6 +405,11 @@
 	samplerIndexer->remove(samplerState);
 }
 
+const SamplerState *Device::findSampler(uint32_t samplerId) const
+{
+	return samplerIndexer->find(samplerId);
+}
+
 VkResult Device::setDebugUtilsObjectName(const VkDebugUtilsObjectNameInfoEXT *pNameInfo)
 {
 	// Optionally maps user-friendly name to an object
diff --git a/src/Vulkan/VkDevice.hpp b/src/Vulkan/VkDevice.hpp
index fdcce63..70330ef 100644
--- a/src/Vulkan/VkDevice.hpp
+++ b/src/Vulkan/VkDevice.hpp
@@ -141,6 +141,7 @@
 
 		uint32_t index(const SamplerState &samplerState);
 		void remove(const SamplerState &samplerState);
+		const SamplerState *find(uint32_t id);
 
 	private:
 		struct Identifier
@@ -157,6 +158,7 @@
 
 	uint32_t indexSampler(const SamplerState &samplerState);
 	void removeSampler(const SamplerState &samplerState);
+	const SamplerState *findSampler(uint32_t samplerId) const;
 
 	std::shared_ptr<vk::dbg::Context> getDebuggerContext() const
 	{