SpirvShader: Rework pointer types

Previously, Objects of the Variable kind held an Object::ID referring to another pointerBase object, which was combined with a per-lane offset held in SpirvRoutine::intermediates.
This is almost exactly the same as PhysicalPointer, except that PhysicalPointer took its base address from SpirvRoutine::physicalPointers.

With descriptor indices, we need a dynamic base pointer (only known at emit time) with per-lane offsets.

This change transforms the Kind::Variable and Kind::PhysicalPointer kinds into Kind::DivergentPointer and Kind::NonDivergentPointer. This reduces complexity in loads and stores, and better represents the various forms of pointer we care about.

Bug: b/126330097
Change-Id: I514af5962b9cad4109197893066eda6f996be107
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/28390
Tested-by: Ben Clayton <bclayton@google.com>
Reviewed-by: Chris Forbes <chrisforbes@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Presubmit-Ready: Ben Clayton <bclayton@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 7a06ec2..4db352a 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -188,10 +188,9 @@
 					UNIMPLEMENTED("Variable initializers not yet supported");
 
 				auto &object = defs[resultId];
-				object.kind = Object::Kind::Variable;
+				object.kind = Object::Kind::NonDivergentPointer;
 				object.definition = insn;
 				object.type = typeId;
-				object.pointerBase = insn.word(2);	// base is itself
 
 				ASSERT(getType(typeId).storageClass == storageClass);
 
@@ -201,12 +200,10 @@
 				case spv::StorageClassOutput:
 					ProcessInterfaceVariable(object);
 					break;
+
 				case spv::StorageClassUniform:
 				case spv::StorageClassStorageBuffer:
 				case spv::StorageClassPushConstant:
-					object.kind = Object::Kind::PhysicalPointer;
-					break;
-
 				case spv::StorageClassPrivate:
 				case spv::StorageClassFunction:
 					break; // Correctly handled.
@@ -438,23 +435,16 @@
 			case spv::OpFwidthFine:
 			case spv::OpAtomicLoad:
 			case spv::OpPhi:
-				// Instructions that yield an intermediate value
+				// Instructions that yield an intermediate value or divergent
+				// pointer
 			{
 				Type::ID typeId = insn.word(1);
 				Object::ID resultId = insn.word(2);
 				auto &object = defs[resultId];
 				object.type = typeId;
-				object.kind = Object::Kind::Intermediate;
+				object.kind = (getType(typeId).opcode() == spv::OpTypePointer)
+					? Object::Kind::DivergentPointer : Object::Kind::Intermediate;
 				object.definition = insn;
-
-				if (insn.opcode() == spv::OpAccessChain || insn.opcode() == spv::OpInBoundsAccessChain)
-				{
-					// interior ptr has two parts:
-					// - logical base ptr, common across all lanes and known at compile time
-					// - per-lane offset
-					Object::ID baseId = insn.word(3);
-					object.pointerBase = getObject(baseId).pointerBase;
-				}
 				break;
 			}
 
@@ -816,23 +806,38 @@
 		VisitInterfaceInner<F>(def.word(1), d, f);
 	}
 
-	SIMD::Int SpirvShader::WalkExplicitLayoutAccessChain(Object::ID id, uint32_t numIndexes, uint32_t const *indexIds, SpirvRoutine *routine) const
+	std::pair<Pointer<Byte>, SIMD::Int> SpirvShader::GetPointerToData(Object::ID id, int arrayIndex, SpirvRoutine *routine) const
+	{ // Yields the (base pointer, per-lane offset) pair for the pointer object 'id'.
+		auto &object = getObject(id); // NOTE(review): arrayIndex is not read anywhere in this body.
+		switch (object.kind)
+		{
+			case Object::Kind::NonDivergentPointer: // Shares the return below with InterfaceVariable.
+			case Object::Kind::InterfaceVariable:
+				return std::make_pair(routine->getPointer(id), SIMD::Int(0)); // Base pointer only; the offset is SIMD::Int(0).
+
+			case Object::Kind::DivergentPointer:
+				return std::make_pair(routine->getPointer(id), routine->getIntermediate(id).Int(0)); // Offset is component 0 of the intermediate stored under the same id.
+
+			default: // Every kind not listed above is rejected.
+				UNREACHABLE("Invalid pointer kind %d", int(object.kind));
+				return std::make_pair(Pointer<Byte>(), SIMD::Int(0)); // Default-constructed pointer and SIMD::Int(0), returned after UNREACHABLE.
+		}
+	}
+
+	std::pair<Pointer<Byte>, SIMD::Int> SpirvShader::WalkExplicitLayoutAccessChain(Object::ID id, uint32_t numIndexes, uint32_t const *indexIds, SpirvRoutine *routine) const
 	{
 		// Produce a offset into external memory in sizeof(float) units
 
-		int constantOffset = 0;
-		SIMD::Int dynamicOffset = SIMD::Int(0);
 		auto &baseObject = getObject(id);
 		Type::ID typeId = getType(baseObject.type).element;
-		Decorations d{};
+		Decorations d = {};
 		ApplyDecorationsForId(&d, baseObject.type);
 
-		// The <base> operand is an intermediate value itself, ie produced by a previous OpAccessChain.
-		// Start with its offset and build from there.
-		if (baseObject.kind == Object::Kind::Intermediate)
-		{
-			dynamicOffset += routine->getIntermediate(id).Int(0);
-		}
+		SIMD::Int dynamicOffset;
+		Pointer<Byte> pointerBase;
+		std::tie(pointerBase, dynamicOffset) = GetPointerToData(id, 0, routine);
+
+		int constantOffset = 0;
 
 		for (auto i = 0u; i < numIndexes; i++)
 		{
@@ -890,7 +895,7 @@
 			}
 		}
 
-		return dynamicOffset + SIMD::Int(constantOffset);
+		return std::make_pair(pointerBase, dynamicOffset + SIMD::Int(constantOffset));
 	}
 
 	SIMD::Int SpirvShader::WalkAccessChain(Object::ID id, uint32_t numIndexes, uint32_t const *indexIds, SpirvRoutine *routine) const
@@ -903,9 +908,9 @@
 		auto &baseObject = getObject(id);
 		Type::ID typeId = getType(baseObject.type).element;
 
-		// The <base> operand is an intermediate value itself, ie produced by a previous OpAccessChain.
+		// The <base> operand is a divergent pointer itself.
 		// Start with its offset and build from there.
-		if (baseObject.kind == Object::Kind::Intermediate)
+		if (baseObject.kind == Object::Kind::DivergentPointer)
 		{
 			dynamicOffset += routine->getIntermediate(id).Int(0);
 		}
@@ -1699,8 +1704,16 @@
 		Object::ID resultId = insn.word(2);
 		auto &object = getObject(resultId);
 		auto &objectTy = getType(object.type);
+
 		switch (objectTy.storageClass)
 		{
+		case spv::StorageClassOutput:
+		case spv::StorageClassPrivate:
+		case spv::StorageClassFunction:
+		{
+			routine->createPointer(resultId, &routine->getVariable(resultId)[0]);
+			break;
+		}
 		case spv::StorageClassInput:
 		{
 			if (object.kind == Object::Kind::InterfaceVariable)
@@ -1713,6 +1726,7 @@
 									dst[offset++] = routine->inputs[scalarSlot];
 								});
 			}
+			routine->createPointer(resultId, &routine->getVariable(resultId)[0]);
 			break;
 		}
 		case spv::StorageClassUniform:
@@ -1741,12 +1755,12 @@
 				offset += routine->descriptorDynamicOffsets[dynamicBindingIndex];
 			}
 
-			routine->physicalPointers[resultId] = data + offset;
+			routine->createPointer(resultId, data + offset);
 			break;
 		}
 		case spv::StorageClassPushConstant:
 		{
-			routine->physicalPointers[resultId] = routine->pushConstants;
+			routine->createPointer(resultId, routine->pushConstants);
 			break;
 		}
 		default:
@@ -1765,8 +1779,7 @@
 		auto &result = getObject(resultId);
 		auto &resultTy = getType(result.type);
 		auto &pointer = getObject(pointerId);
-		auto &pointerBase = getObject(pointer.pointerBase);
-		auto &pointerBaseTy = getType(pointerBase.type);
+		auto &pointerTy = getType(pointer.type);
 		std::memory_order memoryOrder = std::memory_order_relaxed;
 
 		if(atomic)
@@ -1780,32 +1793,23 @@
 		ASSERT(Type::ID(insn.word(1)) == result.type);
 		ASSERT(!atomic || getType(getType(pointer.type).element).opcode() == spv::OpTypeInt);  // Vulkan 1.1: "Atomic instructions must declare a scalar 32-bit integer type, for the value pointed to by Pointer."
 
-		if (pointerBaseTy.storageClass == spv::StorageClassImage)
+		if (pointerTy.storageClass == spv::StorageClassImage)
 		{
 			UNIMPLEMENTED("StorageClassImage load not yet implemented");
 		}
 
+		SIMD::Int offsets;
 		Pointer<Float> ptrBase;
-		if (pointerBase.kind == Object::Kind::PhysicalPointer)
-		{
-			ptrBase = routine->getPhysicalPointer(pointer.pointerBase);
-		}
-		else
-		{
-			ptrBase = &routine->getVariable(pointer.pointerBase)[0];
-		}
+		std::tie(ptrBase, offsets) = GetPointerToData(pointerId, 0, routine);
 
-		bool interleavedByLane = IsStorageInterleavedByLane(pointerBaseTy.storageClass);
+		bool interleavedByLane = IsStorageInterleavedByLane(pointerTy.storageClass);
 		auto anyInactiveLanes = AnyFalse(state->activeLaneMask());
 
 		auto load = std::unique_ptr<SIMD::Float[]>(new SIMD::Float[resultTy.sizeInComponents]);
 
-		If(pointer.kind == Object::Kind::Intermediate || anyInactiveLanes)
+		If(pointer.kind == Object::Kind::DivergentPointer || anyInactiveLanes)
 		{
 			// Divergent offsets or masked lanes.
-			auto offsets = pointer.kind == Object::Kind::Intermediate ?
-					As<SIMD::Int>(routine->getIntermediate(pointerId).Int(0)) :
-					RValue<SIMD::Int>(SIMD::Int(0));
 			for (auto i = 0u; i < resultTy.sizeInComponents; i++)
 			{
 				// i wish i had a Float,Float,Float,Float constructor here..
@@ -1861,8 +1865,6 @@
 		auto &pointer = getObject(pointerId);
 		auto &pointerTy = getType(pointer.type);
 		auto &elementTy = getType(pointerTy.element);
-		auto &pointerBase = getObject(pointer.pointerBase);
-		auto &pointerBaseTy = getType(pointerBase.type);
 		std::memory_order memoryOrder = std::memory_order_relaxed;
 
 		if(atomic)
@@ -1874,34 +1876,26 @@
 
 		ASSERT(!atomic || elementTy.opcode() == spv::OpTypeInt);  // Vulkan 1.1: "Atomic instructions must declare a scalar 32-bit integer type, for the value pointed to by Pointer."
 
-		if (pointerBaseTy.storageClass == spv::StorageClassImage)
+		if (pointerTy.storageClass == spv::StorageClassImage)
 		{
 			UNIMPLEMENTED("StorageClassImage store not yet implemented");
 		}
 
+		SIMD::Int offsets;
 		Pointer<Float> ptrBase;
-		if (pointerBase.kind == Object::Kind::PhysicalPointer)
-		{
-			ptrBase = routine->getPhysicalPointer(pointer.pointerBase);
-		}
-		else
-		{
-			ptrBase = &routine->getVariable(pointer.pointerBase)[0];
-		}
+		std::tie(ptrBase, offsets) = GetPointerToData(pointerId, 0, routine);
 
-		bool interleavedByLane = IsStorageInterleavedByLane(pointerBaseTy.storageClass);
+		bool interleavedByLane = IsStorageInterleavedByLane(pointerTy.storageClass);
 		auto anyInactiveLanes = AnyFalse(state->activeLaneMask());
 
 		if (object.kind == Object::Kind::Constant)
 		{
 			// Constant source data.
 			auto src = reinterpret_cast<float *>(object.constantValue.get());
-			If(pointer.kind == Object::Kind::Intermediate || anyInactiveLanes)
+			If(pointer.kind == Object::Kind::DivergentPointer || anyInactiveLanes)
 			{
 				// Divergent offsets or masked lanes.
-				auto offsets = pointer.kind == Object::Kind::Intermediate ?
-						As<SIMD::Int>(routine->getIntermediate(pointerId).Int(0)) :
-						RValue<SIMD::Int>(SIMD::Int(0));
+
 				for (auto i = 0u; i < elementTy.sizeInComponents; i++)
 				{
 					for (int j = 0; j < SIMD::Width; j++)
@@ -1930,12 +1924,9 @@
 		{
 			// Intermediate source data.
 			auto &src = routine->getIntermediate(objectId);
-			If(pointer.kind == Object::Kind::Intermediate || anyInactiveLanes)
+			If(pointer.kind == Object::Kind::DivergentPointer || anyInactiveLanes)
 			{
 				// Divergent offsets or masked lanes.
-				auto offsets = pointer.kind == Object::Kind::Intermediate ?
-						As<SIMD::Int>(routine->getIntermediate(pointerId).Int(0)) :
-						RValue<SIMD::Int>(SIMD::Int(0));
 				for (auto i = 0u; i < elementTy.sizeInComponents; i++)
 				{
 					for (int j = 0; j < SIMD::Width; j++)
@@ -1986,19 +1977,21 @@
 		const uint32_t *indexes = insn.wordPointer(4);
 		auto &type = getType(typeId);
 		ASSERT(type.sizeInComponents == 1);
-		ASSERT(getObject(baseId).pointerBase == getObject(resultId).pointerBase);
-
-		auto &dst = routine->createIntermediate(resultId, type.sizeInComponents);
+		ASSERT(getObject(resultId).kind == Object::Kind::DivergentPointer);
 
 		if(type.storageClass == spv::StorageClassPushConstant ||
 		   type.storageClass == spv::StorageClassUniform ||
 		   type.storageClass == spv::StorageClassStorageBuffer)
 		{
-			dst.move(0, WalkExplicitLayoutAccessChain(baseId, numIndexes, indexes, routine));
+			auto baseAndOffset = WalkExplicitLayoutAccessChain(baseId, numIndexes, indexes, routine);
+			routine->createPointer(resultId, baseAndOffset.first);
+			routine->createIntermediate(resultId, type.sizeInComponents).move(0, baseAndOffset.second);
 		}
 		else
 		{
-			dst.move(0, WalkAccessChain(baseId, numIndexes, indexes, routine));
+			auto offset = WalkAccessChain(baseId, numIndexes, indexes, routine);
+			routine->createPointer(resultId, routine->getPointer(baseId));
+			routine->createIntermediate(resultId, type.sizeInComponents).move(0, offset);
 		}
 
 		return EmitResult::Continue;
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 0dda3d9..6b91160 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -233,17 +233,36 @@
 
 			InsnIterator definition;
 			Type::ID type;
-			ID pointerBase;
 			std::unique_ptr<uint32_t[]> constantValue = nullptr;
 
 			enum class Kind
 			{
-				Unknown,        /* for paranoia -- if we get left with an object in this state, the module was broken */
-				Variable,          // TODO: Document
-				InterfaceVariable, // TODO: Document
-				Constant,          // Values held by Object::constantValue
-				Intermediate,      // Values held by SpirvRoutine::intermediates
-				PhysicalPointer,   // Pointer held by SpirvRoutine::physicalPointers
+				// Invalid default kind.
+				// If we get left with an object in this state, the module was
+				// broken.
+				Unknown,
+
+				// TODO: Better document this kind.
+				// A shader interface variable pointer.
+				// Pointer with uniform address across all lanes.
+				// Pointer held by SpirvRoutine::pointers
+				InterfaceVariable,
+
+				// Constant value held by Object::constantValue.
+				Constant,
+
+				// Value held by SpirvRoutine::intermediates.
+				Intermediate,
+
+				// Pointer formed from a base pointer and a per-lane offset.
+				// Base pointer held by SpirvRoutine::pointers.
+				// Per-lane offset held by SpirvRoutine::intermediates.
+				DivergentPointer,
+
+				// Pointer with uniform address across all lanes.
+				// Pointer held by SpirvRoutine::pointers
+				NonDivergentPointer,
+
 			} kind = Kind::Unknown;
 		};
 
@@ -539,7 +558,15 @@
 
 		void ProcessInterfaceVariable(Object &object);
 
-		SIMD::Int WalkExplicitLayoutAccessChain(Object::ID id, uint32_t numIndexes, uint32_t const *indexIds, SpirvRoutine *routine) const;
+		// Returns a base pointer and per-lane offset to the underlying data for
+		// the given pointer object. Handles objects of the following kinds:
+		//  • DivergentPointer
+		//  • InterfaceVariable
+		//  • NonDivergentPointer
+		// Calling GetPointerToData with objects of any other kind will assert.
+		std::pair<Pointer<Byte>, SIMD::Int> GetPointerToData(Object::ID id, int arrayIndex, SpirvRoutine *routine) const;
+
+		std::pair<Pointer<Byte>, SIMD::Int> WalkExplicitLayoutAccessChain(Object::ID id, uint32_t numIndexes, uint32_t const *indexIds, SpirvRoutine *routine) const;
 		SIMD::Int WalkAccessChain(Object::ID id, uint32_t numIndexes, uint32_t const *indexIds, SpirvRoutine *routine) const;
 		uint32_t WalkLiteralAccessChain(Type::ID id, uint32_t numIndexes, uint32_t const *indexes) const;
 
@@ -655,7 +682,7 @@
 
 		std::unordered_map<SpirvShader::Object::ID, Intermediate> intermediates;
 
-		std::unordered_map<SpirvShader::Object::ID, Pointer<Byte> > physicalPointers;
+		std::unordered_map<SpirvShader::Object::ID, Pointer<Byte> > pointers;
 
 		Variable inputs = Variable{MAX_INTERFACE_COMPONENTS};
 		Variable outputs = Variable{MAX_INTERFACE_COMPONENTS};
@@ -671,6 +698,25 @@
 			ASSERT_MSG(added, "Variable %d created twice", id.value());
 		}
 
+		template <typename T>
+		void createPointer(SpirvShader::Object::ID id, Pointer<T> ptrBase)
+		{ // Stores ptrBase in 'pointers' under id; registering the same id twice trips the assert below.
+			bool added = pointers.emplace(id, ptrBase).second; // .second is false when id already had an entry.
+			ASSERT_MSG(added, "Pointer %d created twice", id.value());
+		}
+
+		template <typename T>
+		void createPointer(SpirvShader::Object::ID id, RValue<Pointer<T>> ptrBase)
+		{
+			createPointer(id, Pointer<T>(ptrBase)); // Build a Pointer<T> from the RValue and forward to the Pointer<T> overload.
+		}
+
+		template <typename T>
+		void createPointer(SpirvShader::Object::ID id, Reference<Pointer<T>> ptrBase)
+		{
+			createPointer(id, Pointer<T>(ptrBase)); // Build a Pointer<T> from the Reference and forward to the Pointer<T> overload.
+		}
+
 		Intermediate& createIntermediate(SpirvShader::Object::ID id, uint32_t size)
 		{
 			auto it = intermediates.emplace(std::piecewise_construct,
@@ -694,10 +740,10 @@
 			return it->second;
 		}
 
-		Pointer<Byte>& getPhysicalPointer(SpirvShader::Object::ID id)
+		Pointer<Byte>& getPointer(SpirvShader::Object::ID id) // Returns a reference to the entry of 'pointers' keyed by id.
 		{
-			auto it = physicalPointers.find(id);
-			ASSERT_MSG(it != physicalPointers.end(), "Unknown physical pointer %d", id.value());
+			auto it = pointers.find(id);
+			ASSERT_MSG(it != pointers.end(), "Unknown pointer %d", id.value()); // Asserts when id has no entry in 'pointers'.
 			return it->second;
 		}
 	};