Unify template-based and classic descriptor writes

The template-based descriptor write mechanism had got left behind as
more descriptor types were added. Reimplement the classic mechanism in
terms of a template entry, and implement all the current logic there.

Bug: b/123244275
Test: dEQP-VK.binding_model.*
Test: dEQP-VK.spirv_assembly.*
Test: dEQP-VK.glsl.*
Change-Id: Ide5a6bf70978774170f79d42652c746c2b5b8abd
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/29608
Tested-by: Chris Forbes <chrisforbes@google.com>
Presubmit-Ready: Chris Forbes <chrisforbes@google.com>
Reviewed-by: Alexis Hétu <sugoi@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
diff --git a/src/Vulkan/VkDescriptorSetLayout.cpp b/src/Vulkan/VkDescriptorSetLayout.cpp
index 3367b8d..0586a51 100644
--- a/src/Vulkan/VkDescriptorSetLayout.cpp
+++ b/src/Vulkan/VkDescriptorSetLayout.cpp
@@ -270,24 +270,24 @@
 	}
 }
 
-void DescriptorSetLayout::WriteDescriptorSet(const VkWriteDescriptorSet& writeDescriptorSet)
+void DescriptorSetLayout::WriteDescriptorSet(DescriptorSet *dstSet, VkDescriptorUpdateTemplateEntry const &entry, char const *src)
 {
-	DescriptorSet* dstSet = vk::Cast(writeDescriptorSet.dstSet);
 	DescriptorSetLayout* dstLayout = dstSet->layout;
 	ASSERT(dstLayout);
-	ASSERT(dstLayout->bindings[dstLayout->getBindingIndex(writeDescriptorSet.dstBinding)].descriptorType == writeDescriptorSet.descriptorType);
+	ASSERT(dstLayout->bindings[dstLayout->getBindingIndex(entry.dstBinding)].descriptorType == entry.descriptorType);
 
 	size_t typeSize = 0;
-	uint8_t* memToWrite = dstLayout->getOffsetPointer(dstSet, writeDescriptorSet.dstBinding, writeDescriptorSet.dstArrayElement, writeDescriptorSet.descriptorCount, &typeSize);
+	uint8_t* memToWrite = dstLayout->getOffsetPointer(dstSet, entry.dstBinding, entry.dstArrayElement, entry.descriptorCount, &typeSize);
 
-	if(writeDescriptorSet.descriptorType == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER)
+	if(entry.descriptorType == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER)
 	{
 		SampledImageDescriptor *imageSampler = reinterpret_cast<SampledImageDescriptor*>(memToWrite);
 
-		for(uint32_t i = 0; i < writeDescriptorSet.descriptorCount; i++)
+		for(uint32_t i = 0; i < entry.descriptorCount; i++)
 		{
-			vk::Sampler *sampler = vk::Cast(writeDescriptorSet.pImageInfo[i].sampler);
-			vk::ImageView *imageView = vk::Cast(writeDescriptorSet.pImageInfo[i].imageView);
+			auto update = reinterpret_cast<VkDescriptorImageInfo const *>(src + entry.offset + entry.stride * i);
+			vk::Sampler *sampler = vk::Cast(update->sampler);
+			vk::ImageView *imageView = vk::Cast(update->imageView);
 
 			imageSampler[i].sampler = sampler;
 			imageSampler[i].imageView = imageView;
@@ -439,27 +439,29 @@
 			}
 		}
 	}
-	else if (writeDescriptorSet.descriptorType == VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
+	else if (entry.descriptorType == VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
 	{
 		auto descriptor = reinterpret_cast<StorageImageDescriptor *>(memToWrite);
-		for(uint32_t i = 0; i < writeDescriptorSet.descriptorCount; i++)
+		for(uint32_t i = 0; i < entry.descriptorCount; i++)
 		{
-			auto imageView = vk::Cast(writeDescriptorSet.pImageInfo[i].imageView);
+			auto update = reinterpret_cast<VkDescriptorImageInfo const *>(src + entry.offset + entry.stride * i);
+			auto imageView = Cast(update->imageView);
 			descriptor[i].ptr = imageView->getOffsetPointer({0, 0, 0}, VK_IMAGE_ASPECT_COLOR_BIT);
 			descriptor[i].extent = imageView->getMipLevelExtent(0);
 			descriptor[i].rowPitchBytes = imageView->rowPitchBytes(VK_IMAGE_ASPECT_COLOR_BIT, 0);
 			descriptor[i].slicePitchBytes = imageView->getSubresourceRange().layerCount > 1
-					? imageView->layerPitchBytes(VK_IMAGE_ASPECT_COLOR_BIT)
-					: imageView->slicePitchBytes(VK_IMAGE_ASPECT_COLOR_BIT, 0);
+											? imageView->layerPitchBytes(VK_IMAGE_ASPECT_COLOR_BIT)
+											: imageView->slicePitchBytes(VK_IMAGE_ASPECT_COLOR_BIT, 0);
 			descriptor[i].arrayLayers = imageView->getSubresourceRange().layerCount;
 		}
 	}
-	else if (writeDescriptorSet.descriptorType == VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER)
+	else if (entry.descriptorType == VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER)
 	{
 		auto descriptor = reinterpret_cast<StorageImageDescriptor *>(memToWrite);
-		for (uint32_t i = 0; i < writeDescriptorSet.descriptorCount; i++)
+		for (uint32_t i = 0; i < entry.descriptorCount; i++)
 		{
-			auto bufferView = vk::Cast(writeDescriptorSet.pTexelBufferView[i]);
+			auto update = reinterpret_cast<VkBufferView const *>(src + entry.offset + entry.stride * i);
+			auto bufferView = Cast(*update);
 			descriptor[i].ptr = bufferView->getPointer();
 			descriptor[i].extent = {bufferView->getElementCount(), 1, 1};
 			descriptor[i].rowPitchBytes = 0;
@@ -475,11 +477,53 @@
 		// a binding has a descriptorCount of zero, it is skipped. This behavior
 		// applies recursively, with the update affecting consecutive bindings as
 		// needed to update all descriptorCount descriptors.
-		size_t writeSize = typeSize * writeDescriptorSet.descriptorCount;
-		memcpy(memToWrite, DescriptorSetLayout::GetInputData(writeDescriptorSet), writeSize);
+		for (auto i = 0u; i < entry.descriptorCount; i++)
+			memcpy(memToWrite + typeSize * i, src + entry.offset + entry.stride * i, typeSize);
 	}
 }
 
+void DescriptorSetLayout::WriteDescriptorSet(const VkWriteDescriptorSet& writeDescriptorSet)
+{
+	DescriptorSet* dstSet = vk::Cast(writeDescriptorSet.dstSet);
+	VkDescriptorUpdateTemplateEntry e;
+	e.descriptorType = writeDescriptorSet.descriptorType;
+	e.dstBinding = writeDescriptorSet.dstBinding;
+	e.dstArrayElement = writeDescriptorSet.dstArrayElement;
+	e.descriptorCount = writeDescriptorSet.descriptorCount;
+	e.offset = 0;
+	void const *ptr = nullptr;
+	switch (writeDescriptorSet.descriptorType)
+	{
+	case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
+	case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
+		ptr = writeDescriptorSet.pTexelBufferView;
+		e.stride = sizeof(*VkWriteDescriptorSet::pTexelBufferView);
+		break;
+
+	case VK_DESCRIPTOR_TYPE_SAMPLER:
+	case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
+	case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE:
+	case VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT:
+	case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
+		ptr = writeDescriptorSet.pImageInfo;
+		e.stride = sizeof(VkDescriptorImageInfo);
+		break;
+
+	case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
+	case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
+	case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:
+	case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC:
+		ptr = writeDescriptorSet.pBufferInfo;
+		e.stride = sizeof(VkDescriptorBufferInfo);
+		break;
+
+	default:
+		UNIMPLEMENTED("descriptor type %u", writeDescriptorSet.descriptorType);
+	}
+
+	WriteDescriptorSet(dstSet, e, reinterpret_cast<char const *>(ptr));
+}
+
 void DescriptorSetLayout::CopyDescriptorSet(const VkCopyDescriptorSet& descriptorCopies)
 {
 	DescriptorSet* srcSet = vk::Cast(descriptorCopies.srcSet);
diff --git a/src/Vulkan/VkDescriptorSetLayout.hpp b/src/Vulkan/VkDescriptorSetLayout.hpp
index 05e074a..131cdd4 100644
--- a/src/Vulkan/VkDescriptorSetLayout.hpp
+++ b/src/Vulkan/VkDescriptorSetLayout.hpp
@@ -58,6 +58,8 @@
 	static void WriteDescriptorSet(const VkWriteDescriptorSet& descriptorWrites);
 	static void CopyDescriptorSet(const VkCopyDescriptorSet& descriptorCopies);
 
+	static void WriteDescriptorSet(DescriptorSet *dstSet, VkDescriptorUpdateTemplateEntry const &entry, char const *src);
+
 	void initialize(VkDescriptorSet descriptorSet);
 
 	// Returns the total size of the descriptor set in bytes.
diff --git a/src/Vulkan/VkDescriptorUpdateTemplate.cpp b/src/Vulkan/VkDescriptorUpdateTemplate.cpp
index 2a1400f..7ea84f8 100644
--- a/src/Vulkan/VkDescriptorUpdateTemplate.cpp
+++ b/src/Vulkan/VkDescriptorUpdateTemplate.cpp
@@ -37,38 +37,13 @@
 
 	void DescriptorUpdateTemplate::updateDescriptorSet(VkDescriptorSet vkDescriptorSet, const void* pData)
 	{
+
 		DescriptorSet* descriptorSet = vk::Cast(vkDescriptorSet);
 
 		for(uint32_t i = 0; i < descriptorUpdateEntryCount; i++)
 		{
-			auto const &entry = descriptorUpdateEntries[i];
-			auto binding = entry.dstBinding;
-			auto arrayElement = entry.dstArrayElement;
-			for (uint32_t descriptorIndex = 0; descriptorIndex < entry.descriptorCount; descriptorIndex++)
-			{
-				while (arrayElement == descriptorSetLayout->getBindingLayout(binding).descriptorCount)
-				{
-					// If descriptorCount is greater than the number of remaining
-					// array elements in the destination binding, those affect
-					// consecutive bindings in a manner similar to
-					// VkWriteDescriptorSet.
-					// If a binding has a descriptorCount of zero, it is skipped.
-					arrayElement = 0;
-					binding++;
-				}
-
-				uint8_t *memToRead = (uint8_t *)pData + entry.offset + descriptorIndex * entry.stride;
-				size_t typeSize = 0;
-				uint8_t* memToWrite = descriptorSetLayout->getOffsetPointer(
-					descriptorSet,
-					binding,
-					arrayElement,
-					1, // count
-					&typeSize);
-				memcpy(memToWrite, memToRead, typeSize);
-
-				arrayElement++;
-			}
+			DescriptorSetLayout::WriteDescriptorSet(descriptorSet, descriptorUpdateEntries[i],
+													reinterpret_cast<char const *>(pData));
 		}
 	}
 }
\ No newline at end of file