Strongly type object / type identifiers.

Prevents you from mixing them up by mistake, and provides better self-documentation on function signatures.

Bug: b/126126820
Change-Id: I21ce20ded434ca3d5d03ebf3f9027cf6f6b5386f
Reviewed-on: https://swiftshader-review.googlesource.com/c/25068
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
Tested-by: Ben Clayton <bclayton@google.com>
diff --git a/src/Pipeline/SpirvID.hpp b/src/Pipeline/SpirvID.hpp
new file mode 100644
index 0000000..079d518
--- /dev/null
+++ b/src/Pipeline/SpirvID.hpp
@@ -0,0 +1,61 @@
+// Copyright 2019 The SwiftShader Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef sw_ID_hpp
+#define sw_ID_hpp
+
+#include <unordered_map>
+#include <cstdint>
+
+namespace sw
+{
+	// SpirvID is a strongly-typed identifier backed by a uint32_t.
+	// The template parameter T is not actually used by the implementation of
+	// ID; instead it is used to prevent implicit casts between idenfitifers of
+	// different T types.
+	// IDs are typically used as a map key to value of type T.
+	template <typename T>
+	class SpirvID
+	{
+	public:
+		SpirvID() : id(0) {}
+		SpirvID(uint32_t id) : id(id) {}
+		bool operator == (const SpirvID<T>& rhs) const { return id == rhs.id; }
+		bool operator < (const SpirvID<T>& rhs) const { return id < rhs.id; }
+
+		// value returns the numerical value of the identifier.
+		uint32_t value() const { return id; }
+	private:
+		uint32_t id;
+	};
+
+	// HandleMap<T> is an unordered map of SpirvID<T> to T.
+	template <typename T>
+	using HandleMap = std::unordered_map<SpirvID<T>, T>;
+}
+
+namespace std
+{
+	// std::hash implementation for sw::SpirvID<T>
+	template<typename T>
+	struct hash< sw::SpirvID<T> >
+	{
+		std::size_t operator()(const sw::SpirvID<T>& id) const noexcept
+		{
+			return std::hash<uint32_t>()(id.value());
+		}
+	};
+}
+
+#endif  // sw_ID_hpp
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 4bae36d..77fd452 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -41,7 +41,7 @@
 
 			case spv::OpDecorate:
 			{
-				auto targetId = insn.word(1);
+				TypeOrObjectID targetId = insn.word(1);
 				auto decoration = static_cast<spv::Decoration>(insn.word(2));
 				decorations[targetId].Apply(
 						decoration,
@@ -54,7 +54,7 @@
 
 			case spv::OpMemberDecorate:
 			{
-				auto targetId = insn.word(1);
+				TypeID targetId = insn.word(1);
 				auto memberIndex = insn.word(2);
 				auto &d = memberDecorations[targetId];
 				if (memberIndex >= d.size())
@@ -116,7 +116,7 @@
 			case spv::OpTypePointer:
 			case spv::OpTypeFunction:
 			{
-				auto resultId = insn.word(1);
+				TypeID resultId = insn.word(1);
 				auto &type = types[resultId];
 				type.definition = insn;
 				type.sizeInComponents = ComputeTypeSize(insn);
@@ -140,7 +140,7 @@
 				}
 				else if (insn.opcode() == spv::OpTypePointer)
 				{
-					auto pointeeType = insn.word(3);
+					TypeID pointeeType = insn.word(3);
 					type.isBuiltInBlock = getType(pointeeType).isBuiltInBlock;
 				}
 				break;
@@ -148,8 +148,8 @@
 
 			case spv::OpVariable:
 			{
-				auto typeId = insn.word(1);
-				auto resultId = insn.word(2);
+				TypeID typeId = insn.word(1);
+				ObjectID resultId = insn.word(2);
 				auto storageClass = static_cast<spv::StorageClass>(insn.word(3));
 				if (insn.wordCount() > 4)
 					UNIMPLEMENTED("Variable initializers not yet supported");
@@ -244,8 +244,8 @@
 			case spv::OpAccessChain:
 				// Instructions that yield an ssavalue.
 			{
-				auto typeId = insn.word(1);
-				auto resultId = insn.word(2);
+				TypeID typeId = insn.word(1);
+				ObjectID resultId = insn.word(2);
 				auto &object = defs[resultId];
 				object.kind = Object::Kind::Value;
 				object.definition = insn;
@@ -279,8 +279,8 @@
 
 	SpirvShader::Object& SpirvShader::CreateConstant(InsnIterator insn)
 	{
-		auto typeId = insn.word(1);
-		auto resultId = insn.word(2);
+		TypeID typeId = insn.word(1);
+		ObjectID resultId = insn.word(2);
 		auto &object = defs[resultId];
 		object.kind = Object::Kind::Constant;
 		object.definition = insn;
@@ -296,13 +296,13 @@
 		auto &builtinInterface = (object.storageClass == spv::StorageClassInput) ? inputBuiltins : outputBuiltins;
 		auto &userDefinedInterface = (object.storageClass == spv::StorageClassInput) ? inputs : outputs;
 
-		auto resultId = object.definition.word(2);
+		ObjectID resultId = object.definition.word(2);
 		if (object.isBuiltInBlock)
 		{
 			// walk the builtin block, registering each of its members separately.
 			auto ptrType = getType(object.definition.word(1)).definition;
 			assert(ptrType.opcode() == spv::OpTypePointer);
-			auto pointeeType = ptrType.word(3);
+			TypeID pointeeType = ptrType.word(3);
 			auto m = memberDecorations.find(pointeeType);
 			assert(m != memberDecorations.end());        // otherwise we wouldn't have marked the type chain
 			auto &structType = getType(pointeeType).definition;
@@ -439,7 +439,7 @@
 	}
 
 	template<typename F>
-	int SpirvShader::VisitInterfaceInner(uint32_t id, Decorations d, F f) const
+	int SpirvShader::VisitInterfaceInner(TypeID id, Decorations d, F f) const
 	{
 		// Recursively walks variable definition and its type tree, taking into account
 		// any explicit Location or Component decorations encountered; where explicit
@@ -510,7 +510,7 @@
 	}
 
 	template<typename F>
-	void SpirvShader::VisitInterface(uint32_t id, F f) const
+	void SpirvShader::VisitInterface(ObjectID id, F f) const
 	{
 		// Walk a variable definition and call f for each component in it.
 		Decorations d{};
@@ -521,7 +521,7 @@
 		VisitInterfaceInner<F>(def.word(1), d, f);
 	}
 
-	Int4 SpirvShader::WalkAccessChain(uint32_t id, uint32_t numIndexes, uint32_t const *indexIds, SpirvRoutine *routine) const
+	Int4 SpirvShader::WalkAccessChain(ObjectID id, uint32_t numIndexes, uint32_t const *indexIds, SpirvRoutine *routine) const
 	{
 		// TODO: think about explicit layout (UBO/SSBO) storage classes
 		// TODO: avoid doing per-lane work in some cases if we can?
@@ -529,7 +529,7 @@
 		int constantOffset = 0;
 		Int4 dynamicOffset = Int4(0);
 		auto & baseObject = getObject(id);
-		auto typeId = baseObject.definition.word(1);
+		TypeID typeId = baseObject.definition.word(1);
 
 		// The <base> operand is an intermediate value itself, ie produced by a previous OpAccessChain.
 		// Start with its offset and build from there.
@@ -646,14 +646,14 @@
 		BufferBlock |= src.BufferBlock;
 	}
 
-	void SpirvShader::ApplyDecorationsForId(Decorations *d, uint32_t id) const
+	void SpirvShader::ApplyDecorationsForId(Decorations *d, TypeOrObjectID id) const
 	{
 		auto it = decorations.find(id);
 		if (it != decorations.end())
 			d->Apply(it->second);
 	}
 
-	void SpirvShader::ApplyDecorationsForIdMember(Decorations *d, uint32_t id, uint32_t member) const
+	void SpirvShader::ApplyDecorationsForIdMember(Decorations *d, TypeID id, uint32_t member) const
 	{
 		auto it = memberDecorations.find(id);
 		if (it != memberDecorations.end() && member < it->second.size())
@@ -662,7 +662,7 @@
 		}
 	}
 
-	uint32_t SpirvShader::GetConstantInt(uint32_t id) const
+	uint32_t SpirvShader::GetConstantInt(ObjectID id) const
 	{
 		// Slightly hackish access to constants very early in translation.
 		// General consumption of constants by other instructions should
@@ -686,7 +686,7 @@
 			{
 			case spv::OpVariable:
 			{
-				auto resultId = insn.word(2);
+				ObjectID resultId = insn.word(2);
 				auto &object = getObject(resultId);
 				// TODO: what to do about zero-slot objects?
 				if (object.sizeInComponents > 0)
@@ -748,7 +748,7 @@
 
 			case spv::OpVariable:
 			{
-				auto resultId = insn.word(2);
+				ObjectID resultId = insn.word(2);
 				auto &object = getObject(resultId);
 				if (object.kind == Object::Kind::InterfaceVariable && object.storageClass == spv::StorageClassInput)
 				{
@@ -907,7 +907,7 @@
 			{
 			case spv::OpVariable:
 			{
-				auto resultId = insn.word(2);
+				ObjectID resultId = insn.word(2);
 				auto &object = getObject(resultId);
 				if (object.kind == Object::Kind::InterfaceVariable && object.storageClass == spv::StorageClassOutput)
 				{
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 1256a47..0a33861 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -18,6 +18,7 @@
 #include "System/Types.hpp"
 #include "Vulkan/VkDebug.hpp"
 #include "ShaderCore.hpp"
+#include "SpirvID.hpp"
 
 #include <string>
 #include <vector>
@@ -74,43 +75,7 @@
 		uint32_t size;
 	};
 
-	class SpirvRoutine
-	{
-	public:
-		using Value = Array<Float4>;
-		std::unordered_map<uint32_t, Value> lvalues;
-
-		std::unordered_map<uint32_t, Intermediate> intermediates;
-
-		Value inputs = Value{MAX_INTERFACE_COMPONENTS};
-		Value outputs = Value{MAX_INTERFACE_COMPONENTS};
-
-		void createLvalue(uint32_t id, uint32_t size)
-		{
-			lvalues.emplace(id, Value(size));
-		}
-
-		void createIntermediate(uint32_t id, uint32_t size)
-		{
-			intermediates.emplace(std::piecewise_construct,
-					std::forward_as_tuple(id),
-					std::forward_as_tuple(size));
-		}
-
-		Value& getValue(uint32_t id)
-		{
-			auto it = lvalues.find(id);
-			assert(it != lvalues.end());
-			return it->second;
-		}
-
-		Intermediate& getIntermediate(uint32_t id)
-		{
-			auto it = intermediates.find(id);
-			assert(it != intermediates.end());
-			return it->second;
-		}
-	};
+	class SpirvRoutine;
 
 	class SpirvShader
 	{
@@ -218,6 +183,25 @@
 			} kind = Kind::Unknown;
 		};
 
+		using TypeID = SpirvID<Type>;
+		using ObjectID = SpirvID<Object>;
+
+		struct TypeOrObject {}; // Dummy struct to represent a Type or Object.
+
+		// TypeOrObjectID is an identifier that represents a Type or an Object,
+		// and supports implicit casting to and from TypeID or ObjectID.
+		class TypeOrObjectID : public SpirvID<TypeOrObject>
+		{
+		public:
+			using Hash = std::hash<SpirvID<TypeOrObject>>;
+
+			inline TypeOrObjectID(uint32_t id) : SpirvID(id) {}
+			inline TypeOrObjectID(TypeID id) : SpirvID(id.value()) {}
+			inline TypeOrObjectID(ObjectID id) : SpirvID(id.value()) {}
+			inline operator TypeID() const { return TypeID(value()); }
+			inline operator ObjectID() const { return ObjectID(value()); }
+		};
+
 		int getSerialID() const
 		{
 			return serialID;
@@ -288,8 +272,8 @@
 			void Apply(spv::Decoration decoration, uint32_t arg);
 		};
 
-		std::unordered_map<uint32_t, Decorations> decorations;
-		std::unordered_map<uint32_t, std::vector<Decorations>> memberDecorations;
+		std::unordered_map<TypeOrObjectID, Decorations, TypeOrObjectID::Hash> decorations;
+		std::unordered_map<TypeID, std::vector<Decorations>> memberDecorations;
 
 		struct InterfaceComponent
 		{
@@ -306,7 +290,7 @@
 
 		struct BuiltinMapping
 		{
-			uint32_t Id;
+			ObjectID Id;
 			uint32_t FirstComponent;
 			uint32_t SizeInComponents;
 		};
@@ -322,14 +306,14 @@
 		std::unordered_map<spv::BuiltIn, BuiltinMapping, BuiltInHash> inputBuiltins;
 		std::unordered_map<spv::BuiltIn, BuiltinMapping, BuiltInHash> outputBuiltins;
 
-		Type const &getType(uint32_t id) const
+		Type const &getType(TypeID id) const
 		{
 			auto it = types.find(id);
 			assert(it != types.end());
 			return it->second;
 		}
 
-		Object const &getObject(uint32_t id) const
+		Object const &getObject(ObjectID id) const
 		{
 			auto it = defs.find(id);
 			assert(it != defs.end());
@@ -340,28 +324,67 @@
 		const int serialID;
 		static volatile int serialCounter;
 		Modes modes;
-		std::unordered_map<uint32_t, Type> types;
-		std::unordered_map<uint32_t, Object> defs;
+		HandleMap<Type> types;
+		HandleMap<Object> defs;
 
 		void ProcessExecutionMode(InsnIterator it);
 
 		uint32_t ComputeTypeSize(InsnIterator insn);
-		void ApplyDecorationsForId(Decorations *d, uint32_t id) const;
-		void ApplyDecorationsForIdMember(Decorations *d, uint32_t id, uint32_t member) const;
+		void ApplyDecorationsForId(Decorations *d, TypeOrObjectID id) const;
+		void ApplyDecorationsForIdMember(Decorations *d, TypeID id, uint32_t member) const;
 
 		template<typename F>
-		int VisitInterfaceInner(uint32_t id, Decorations d, F f) const;
+		int VisitInterfaceInner(TypeID id, Decorations d, F f) const;
 
 		template<typename F>
-		void VisitInterface(uint32_t id, F f) const;
+		void VisitInterface(ObjectID id, F f) const;
 
-		uint32_t GetConstantInt(uint32_t id) const;
+		uint32_t GetConstantInt(ObjectID id) const;
 		Object& CreateConstant(InsnIterator it);
 
 		void ProcessInterfaceVariable(Object &object);
 
-		Int4 WalkAccessChain(uint32_t id, uint32_t numIndexes, uint32_t const *indexIds, SpirvRoutine *routine) const;
+		Int4 WalkAccessChain(ObjectID id, uint32_t numIndexes, uint32_t const *indexIds, SpirvRoutine *routine) const;
 	};
+
+	class SpirvRoutine
+	{
+	public:
+		using Value = Array<Float4>;
+		std::unordered_map<SpirvShader::ObjectID, Value> lvalues;
+
+		std::unordered_map<SpirvShader::ObjectID, Intermediate> intermediates;
+
+		Value inputs = Value{MAX_INTERFACE_COMPONENTS};
+		Value outputs = Value{MAX_INTERFACE_COMPONENTS};
+
+		void createLvalue(SpirvShader::ObjectID id, uint32_t size)
+		{
+			lvalues.emplace(id, Value(size));
+		}
+
+		void createIntermediate(SpirvShader::ObjectID id, uint32_t size)
+		{
+			intermediates.emplace(std::piecewise_construct,
+					std::forward_as_tuple(id),
+					std::forward_as_tuple(size));
+		}
+
+		Value& getValue(SpirvShader::ObjectID id)
+		{
+			auto it = lvalues.find(id);
+			assert(it != lvalues.end());
+			return it->second;
+		}
+
+		Intermediate& getIntermediate(SpirvShader::ObjectID id)
+		{
+			auto it = intermediates.find(id);
+			assert(it != intermediates.end());
+			return it->second;
+		}
+	};
+
 }
 
 #endif  // sw_SpirvShader_hpp