Uniquely identify sampler state

To avoid re-generating sampling routines for sampler with identical
state, keep a map of sampler state to 32-bit integer identifiers.

Bug: b/151235334
Change-Id: I105151675afbf29bd29585e866b8cd976f66fb49
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/42468
Presubmit-Ready: Nicolas Capens <nicolascapens@google.com>
Reviewed-by: Alexis Hétu <sugoi@google.com>
Reviewed-by: Chris Forbes <chrisforbes@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Tested-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
diff --git a/src/Vulkan/VkDevice.cpp b/src/Vulkan/VkDevice.cpp
index 7f9bcc0..33c9731 100644
--- a/src/Vulkan/VkDevice.cpp
+++ b/src/Vulkan/VkDevice.cpp
@@ -59,6 +59,44 @@
 	cache.updateConstCache();
 }
 
+Device::SamplerIndexer::~SamplerIndexer()
+{
+	ASSERT(map.empty());
+}
+
+uint32_t Device::SamplerIndexer::index(const SamplerState &samplerState)
+{
+	std::lock_guard<std::mutex> lock(mutex);
+
+	auto it = map.find(samplerState);
+
+	if(it != map.end())
+	{
+		it->second.count++;
+		return it->second.id;
+	}
+
+	nextID++;
+
+	map.emplace(samplerState, Identifier{ nextID, 1 });
+
+	return nextID;
+}
+
+void Device::SamplerIndexer::remove(const SamplerState &samplerState)
+{
+	std::lock_guard<std::mutex> lock(mutex);
+
+	auto it = map.find(samplerState);
+	ASSERT(it != map.end());
+
+	auto count = --it->second.count;
+	if(count == 0)
+	{
+		map.erase(it);
+	}
+}
+
 Device::Device(const VkDeviceCreateInfo *pCreateInfo, void *mem, PhysicalDevice *physicalDevice, const VkPhysicalDeviceFeatures *enabledFeatures, const std::shared_ptr<marl::Scheduler> &scheduler)
     : physicalDevice(physicalDevice)
     , queues(reinterpret_cast<Queue *>(mem))
@@ -99,6 +137,7 @@
 	// FIXME (b/119409619): use an allocator here so we can control all memory allocations
 	blitter.reset(new sw::Blitter());
 	samplingRoutineCache.reset(new SamplingRoutineCache());
+	samplerIndexer.reset(new SamplerIndexer());
 
 #ifdef ENABLE_VK_DEBUGGER
 	static auto port = getenv("VK_DEBUGGER_PORT");
@@ -279,4 +318,14 @@
 	return samplingRoutineCacheMutex;
 }
 
+uint32_t Device::indexSampler(const SamplerState &samplerState)
+{
+	return samplerIndexer->index(samplerState);
+}
+
+void Device::removeSampler(const SamplerState &samplerState)
+{
+	samplerIndexer->remove(samplerState);
+}
+
 }  // namespace vk
diff --git a/src/Vulkan/VkDevice.hpp b/src/Vulkan/VkDevice.hpp
index 0a19f50..c2ba927 100644
--- a/src/Vulkan/VkDevice.hpp
+++ b/src/Vulkan/VkDevice.hpp
@@ -16,8 +16,11 @@
 #define VK_DEVICE_HPP_
 
 #include "VkObject.hpp"
+#include "VkSampler.hpp"
 #include "Device/LRUCache.hpp"
 #include "Reactor/Routine.hpp"
+
+#include <map>
 #include <memory>
 #include <mutex>
 
@@ -98,6 +101,30 @@
 	rr::Routine *findInConstCache(const SamplingRoutineCache::Key &key) const;
 	void updateSamplingRoutineConstCache();
 
+	class SamplerIndexer
+	{
+	public:
+		~SamplerIndexer();
+
+		uint32_t index(const SamplerState &samplerState);
+		void remove(const SamplerState &samplerState);
+
+	private:
+		struct Identifier
+		{
+			uint32_t id;
+			uint32_t count;  // Number of samplers sharing this state identifier.
+		};
+
+		std::map<SamplerState, Identifier> map;  // guarded by mutex
+		std::mutex mutex;
+
+		uint32_t nextID = 0;
+	};
+
+	uint32_t indexSampler(const SamplerState &samplerState);
+	void removeSampler(const SamplerState &samplerState);
+
 	std::shared_ptr<vk::dbg::Context> getDebuggerContext() const
 	{
 #ifdef ENABLE_VK_DEBUGGER
@@ -118,7 +145,9 @@
 	typedef char ExtensionName[VK_MAX_EXTENSION_NAME_SIZE];
 	ExtensionName *extensions = nullptr;
 	const VkPhysicalDeviceFeatures enabledFeatures = {};
+
 	std::shared_ptr<marl::Scheduler> scheduler;
+	std::unique_ptr<SamplerIndexer> samplerIndexer;
 
 #ifdef ENABLE_VK_DEBUGGER
 	struct
diff --git a/src/Vulkan/VkSampler.cpp b/src/Vulkan/VkSampler.cpp
index e5adb19..11cb00e 100644
--- a/src/Vulkan/VkSampler.cpp
+++ b/src/Vulkan/VkSampler.cpp
@@ -16,8 +16,6 @@
 
 namespace vk {
 
-std::atomic<uint32_t> Sampler::nextID(1);
-
 SamplerState::SamplerState(const VkSamplerCreateInfo *pCreateInfo, const vk::SamplerYcbcrConversion *ycbcrConversion)
     : Memset(this, 0)
     , magFilter(pCreateInfo->magFilter)
@@ -44,8 +42,9 @@
 	}
 }
 
-Sampler::Sampler(const VkSamplerCreateInfo *pCreateInfo, void *mem, const vk::SamplerYcbcrConversion *ycbcrConversion)
-    : SamplerState(pCreateInfo, ycbcrConversion)
+Sampler::Sampler(const VkSamplerCreateInfo *pCreateInfo, void *mem, const SamplerState &samplerState, uint32_t samplerID)
+    : SamplerState(samplerState)
+    , id(samplerID)
 {
 }
 
diff --git a/src/Vulkan/VkSampler.hpp b/src/Vulkan/VkSampler.hpp
index 84d3373..d2f51ac 100644
--- a/src/Vulkan/VkSampler.hpp
+++ b/src/Vulkan/VkSampler.hpp
@@ -58,17 +58,14 @@
 class Sampler : public Object<Sampler, VkSampler>, public SamplerState
 {
 public:
-	Sampler(const VkSamplerCreateInfo *pCreateInfo, void *mem, const vk::SamplerYcbcrConversion *ycbcrConversion);
+	Sampler(const VkSamplerCreateInfo *pCreateInfo, void *mem, const SamplerState &samplerState, uint32_t samplerID);
 
 	static size_t ComputeRequiredAllocationSize(const VkSamplerCreateInfo *pCreateInfo)
 	{
 		return 0;
 	}
 
-	const uint32_t id = nextID++;
-
-private:
-	static std::atomic<uint32_t> nextID;
+	const uint32_t id = 0;
 };
 
 class SamplerYcbcrConversion : public Object<SamplerYcbcrConversion, VkSamplerYcbcrConversion>
diff --git a/src/Vulkan/libVulkan.cpp b/src/Vulkan/libVulkan.cpp
index c583748..3e3e648 100644
--- a/src/Vulkan/libVulkan.cpp
+++ b/src/Vulkan/libVulkan.cpp
@@ -1857,7 +1857,18 @@
 		extensionCreateInfo = extensionCreateInfo->pNext;
 	}
 
-	return vk::Sampler::Create(pAllocator, pCreateInfo, pSampler, ycbcrConversion);
+	vk::SamplerState samplerState(pCreateInfo, ycbcrConversion);
+	uint32_t samplerID = vk::Cast(device)->indexSampler(samplerState);
+
+	VkResult result = vk::Sampler::Create(pAllocator, pCreateInfo, pSampler, samplerState, samplerID);
+
+	if(*pSampler == VK_NULL_HANDLE)
+	{
+		ASSERT(result != VK_SUCCESS);
+		vk::Cast(device)->removeSampler(samplerState);
+	}
+
+	return result;
 }
 
 VKAPI_ATTR void VKAPI_CALL vkDestroySampler(VkDevice device, VkSampler sampler, const VkAllocationCallbacks *pAllocator)
@@ -1865,7 +1876,12 @@
 	TRACE("(VkDevice device = %p, VkSampler sampler = %p, const VkAllocationCallbacks* pAllocator = %p)",
 	      device, static_cast<void *>(sampler), pAllocator);
 
-	vk::destroy(sampler, pAllocator);
+	if(sampler != VK_NULL_HANDLE)
+	{
+		vk::Cast(device)->removeSampler(*vk::Cast(sampler));
+
+		vk::destroy(sampler, pAllocator);
+	}
 }
 
 VKAPI_ATTR VkResult VKAPI_CALL vkCreateDescriptorSetLayout(VkDevice device, const VkDescriptorSetLayoutCreateInfo *pCreateInfo, const VkAllocationCallbacks *pAllocator, VkDescriptorSetLayout *pSetLayout)