Move constant specialization to SpirvShader

SPIRV-Tools' pass to do this is incomplete, and implementing the
remaining pieces as a SPIRV->SPIRV transform is more work than
implementing them here.

Related SPIRV-Tools issues: 2585 2586

Bug: b/127454276
Test: dEQP-VK.*quantize*
Change-Id: I892de80dff366fb3bc417315b5db15d850577c47
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/31348
Tested-by: Chris Forbes <chrisforbes@google.com>
Presubmit-Ready: Chris Forbes <chrisforbes@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 8162f2f..e57cfc7 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -604,12 +604,15 @@
 			}
 
 			case spv::OpConstant:
+			case spv::OpSpecConstant:
 				CreateConstant(insn).constantValue[0] = insn.word(3);
 				break;
 			case spv::OpConstantFalse:
+			case spv::OpSpecConstantFalse:
 				CreateConstant(insn).constantValue[0] = 0;		// represent boolean false as zero
 				break;
 			case spv::OpConstantTrue:
+			case spv::OpSpecConstantTrue:
 				CreateConstant(insn).constantValue[0] = ~0u;	// represent boolean true as all bits set
 				break;
 			case spv::OpConstantNull:
@@ -626,6 +629,7 @@
 				break;
 			}
 			case spv::OpConstantComposite:
+			case spv::OpSpecConstantComposite:
 			{
 				auto &object = CreateConstant(insn);
 				auto offset = 0u;
@@ -657,6 +661,9 @@
 				}
 				break;
 			}
+			case spv::OpSpecConstantOp:
+				EvalSpecConstantOp(insn);
+				break;
 
 			case spv::OpCapability:
 				break; // Various capabilities will be declared, but none affect our code generation at this point.
@@ -723,11 +730,6 @@
 
 			case spv::OpFunctionParameter:
 			case spv::OpFunctionCall:
-			case spv::OpSpecConstant:
-			case spv::OpSpecConstantComposite:
-			case spv::OpSpecConstantFalse:
-			case spv::OpSpecConstantOp:
-			case spv::OpSpecConstantTrue:
 				// These should have all been removed by preprocessing passes. If we see them here,
 				// our assumptions are wrong and we will probably generate wrong code.
 				UNREACHABLE("%s should have already been lowered.", OpcodeName(opcode).c_str());
@@ -1165,7 +1167,7 @@
 		case spv::OpTypeArray:
 		{
 			// Element count * element size. Array sizes come from constant ids.
-			auto arraySize = GetConstantInt(insn.word(3));
+			auto arraySize = GetConstScalarInt(insn.word(3));
 			return getType(insn.word(2)).sizeInComponents * arraySize;
 		}
 
@@ -1262,7 +1264,7 @@
 		}
 		case spv::OpTypeArray:
 		{
-			auto arraySize = GetConstantInt(obj.definition.word(3));
+			auto arraySize = GetConstScalarInt(obj.definition.word(3));
 			for (auto i = 0u; i < arraySize; i++)
 			{
 				d.Location = VisitInterfaceInner<F>(obj.definition.word(2), d, f);
@@ -1344,7 +1346,7 @@
 			break;
 		case spv::OpTypeArray:
 		{
-			auto arraySize = GetConstantInt(type.definition.word(3));
+			auto arraySize = GetConstScalarInt(type.definition.word(3));
 			for (auto i = 0u; i < arraySize; i++)
 			{
 				ASSERT(d.HasArrayStride);
@@ -1440,7 +1442,7 @@
 			{
 			case spv::OpTypeStruct:
 			{
-				int memberIndex = GetConstantInt(indexIds[i]);
+				int memberIndex = GetConstScalarInt(indexIds[i]);
 				ApplyDecorationsForIdMember(d, typeId, memberIndex);
 				typeId = type.definition.word(2u + memberIndex);
 				break;
@@ -1449,7 +1451,7 @@
 			case spv::OpTypeRuntimeArray:
 				if (dd->InputAttachmentIndex >= 0)
 				{
-					dd->InputAttachmentIndex += GetConstantInt(indexIds[i]);
+					dd->InputAttachmentIndex += GetConstScalarInt(indexIds[i]);
 				}
 				typeId = type.element;
 				break;
@@ -1482,7 +1484,7 @@
 			if (type == spv::OpTypeArray || type == spv::OpTypeRuntimeArray)
 			{
 				ASSERT(getObject(indexIds[0]).kind == Object::Kind::Constant);
-				arrayIndex = GetConstantInt(indexIds[0]);
+				arrayIndex = GetConstScalarInt(indexIds[0]);
 
 				numIndexes--;
 				indexIds++;
@@ -1503,7 +1505,7 @@
 			{
 			case spv::OpTypeStruct:
 			{
-				int memberIndex = GetConstantInt(indexIds[i]);
+				int memberIndex = GetConstScalarInt(indexIds[i]);
 				ApplyDecorationsForIdMember(&d, typeId, memberIndex);
 				ASSERT(d.HasOffset);
 				constantOffset += d.Offset;
@@ -1518,7 +1520,7 @@
 				auto & obj = getObject(indexIds[i]);
 				if (obj.kind == Object::Kind::Constant)
 				{
-					constantOffset += d.ArrayStride * GetConstantInt(indexIds[i]);
+					constantOffset += d.ArrayStride * GetConstScalarInt(indexIds[i]);
 				}
 				else
 				{
@@ -1536,7 +1538,7 @@
 				auto & obj = getObject(indexIds[i]);
 				if (obj.kind == Object::Kind::Constant)
 				{
-					constantOffset += columnStride * GetConstantInt(indexIds[i]);
+					constantOffset += columnStride * GetConstScalarInt(indexIds[i]);
 				}
 				else
 				{
@@ -1551,7 +1553,7 @@
 				auto & obj = getObject(indexIds[i]);
 				if (obj.kind == Object::Kind::Constant)
 				{
-					constantOffset += elemStride * GetConstantInt(indexIds[i]);
+					constantOffset += elemStride * GetConstScalarInt(indexIds[i]);
 				}
 				else
 				{
@@ -1587,7 +1589,7 @@
 			{
 			case spv::OpTypeStruct:
 			{
-				int memberIndex = GetConstantInt(indexIds[i]);
+				int memberIndex = GetConstScalarInt(indexIds[i]);
 				int offsetIntoStruct = 0;
 				for (auto j = 0; j < memberIndex; j++) {
 					auto memberType = type.definition.word(2u + j);
@@ -1618,7 +1620,7 @@
 					ASSERT(d.Binding >= 0);
 					auto setLayout = routine->pipelineLayout->getDescriptorSetLayout(d.DescriptorSet);
 					auto stride = setLayout->getBindingStride(d.Binding);
-					ptr.base += stride * GetConstantInt(indexIds[i]);
+					ptr.base += stride * GetConstScalarInt(indexIds[i]);
 				}
 				else
 				{
@@ -1626,7 +1628,7 @@
 					auto & obj = getObject(indexIds[i]);
 					if (obj.kind == Object::Kind::Constant)
 					{
-						ptr += stride * GetConstantInt(indexIds[i]);
+						ptr += stride * GetConstScalarInt(indexIds[i]);
 					}
 					else
 					{
@@ -1860,20 +1862,6 @@
 		object.definition = insn;
 	}
 
-	uint32_t SpirvShader::GetConstantInt(Object::ID id) const
-	{
-		// Slightly hackish access to constants very early in translation.
-		// General consumption of constants by other instructions should
-		// probably be just lowered to Reactor.
-
-		// TODO: not encountered yet since we only use this for array sizes etc,
-		// but is possible to construct integer constant 0 via OpConstantNull.
-		auto insn = getObject(id).definition;
-		ASSERT(insn.opcode() == spv::OpConstant);
-		ASSERT(getType(insn.word(1)).opcode() == spv::OpTypeInt);
-		return insn.word(3);
-	}
-
 	// emit-time
 
 	void SpirvShader::emitProlog(SpirvRoutine *routine) const
@@ -2213,6 +2201,11 @@
 		case spv::OpConstantTrue:
 		case spv::OpConstantFalse:
 		case spv::OpConstantComposite:
+		case spv::OpSpecConstant:
+		case spv::OpSpecConstantTrue:
+		case spv::OpSpecConstantFalse:
+		case spv::OpSpecConstantComposite:
+		case spv::OpSpecConstantOp:
 		case spv::OpUndef:
 		case spv::OpExtension:
 		case spv::OpCapability:
@@ -3203,6 +3196,7 @@
 			}
 			case spv::OpQuantizeToF16:
 			{
+				// Note: keep in sync with the specialization constant version in EvalSpecConstantUnaryOp
 				auto abs = Abs(src.Float(i));
 				auto sign = src.Int(i) & SIMD::Int(0x80000000);
 				auto isZero = CmpLT(abs, SIMD::Float(0.000061035f));
@@ -5726,6 +5720,297 @@
 		return scopeObj.constantValue[0];
 	}
 
+	void SpirvShader::EvalSpecConstantOp(InsnIterator insn)
+	{
+		auto opcode = static_cast<spv::Op>(insn.word(3));
+
+		switch (opcode)
+		{
+		case spv::OpIAdd:
+		case spv::OpISub:
+		case spv::OpIMul:
+		case spv::OpUDiv:
+		case spv::OpSDiv:
+		case spv::OpUMod:
+		case spv::OpSMod:
+		case spv::OpSRem:
+		case spv::OpShiftRightLogical:
+		case spv::OpShiftRightArithmetic:
+		case spv::OpShiftLeftLogical:
+		case spv::OpBitwiseOr:
+		case spv::OpLogicalOr:
+		case spv::OpBitwiseAnd:
+		case spv::OpLogicalAnd:
+		case spv::OpBitwiseXor:
+		case spv::OpLogicalEqual:
+		case spv::OpIEqual:
+		case spv::OpLogicalNotEqual:
+		case spv::OpINotEqual:
+		case spv::OpULessThan:
+		case spv::OpSLessThan:
+		case spv::OpUGreaterThan:
+		case spv::OpSGreaterThan:
+		case spv::OpULessThanEqual:
+		case spv::OpSLessThanEqual:
+		case spv::OpUGreaterThanEqual:
+		case spv::OpSGreaterThanEqual:
+			EvalSpecConstantBinaryOp(insn);
+			break;
+
+		case spv::OpSConvert:
+		case spv::OpFConvert:
+		case spv::OpUConvert:
+		case spv::OpSNegate:
+		case spv::OpNot:
+		case spv::OpLogicalNot:
+		case spv::OpQuantizeToF16:
+			EvalSpecConstantUnaryOp(insn);
+			break;
+
+		case spv::OpSelect:
+		{
+			auto &result = CreateConstant(insn);
+			auto const &cond = getObject(insn.word(4));
+			auto const &left = getObject(insn.word(5));
+			auto const &right = getObject(insn.word(6));
+
+			for (auto i = 0u; i < getType(result.type).sizeInComponents; i++)
+			{
+				result.constantValue[i] = cond.constantValue[i] ? left.constantValue[i] : right.constantValue[i];
+			}
+			break;
+		}
+
+		case spv::OpCompositeExtract:
+		{
+			auto &result = CreateConstant(insn);
+			auto const &compositeObject = getObject(insn.word(4));
+			auto firstComponent = WalkLiteralAccessChain(compositeObject.type, insn.wordCount() - 5, insn.wordPointer(5));
+
+			for (auto i = 0u; i < getType(result.type).sizeInComponents; i++)
+			{
+				result.constantValue[i] = compositeObject.constantValue[firstComponent + i];
+			}
+			break;
+		}
+
+		case spv::OpCompositeInsert:
+		{
+			auto &result = CreateConstant(insn);
+			auto const &newPart = getObject(insn.word(4));
+			auto const &oldObject = getObject(insn.word(5));
+			auto firstNewComponent = WalkLiteralAccessChain(result.type, insn.wordCount() - 6, insn.wordPointer(6));
+
+			// old components before
+			for (auto i = 0u; i < firstNewComponent; i++)
+			{
+				result.constantValue[i] = oldObject.constantValue[i];
+			}
+			// new part
+			for (auto i = 0u; i < getType(newPart.type).sizeInComponents; i++)
+			{
+				result.constantValue[firstNewComponent + i] = newPart.constantValue[i];
+			}
+			// old components after
+			for (auto i = firstNewComponent + getType(newPart.type).sizeInComponents; i < getType(result.type).sizeInComponents; i++)
+			{
+				result.constantValue[i] = oldObject.constantValue[i];
+			}
+			break;
+		}
+
+		case spv::OpVectorShuffle:
+		{
+			auto &result = CreateConstant(insn);
+			auto const &firstHalf = getObject(insn.word(4));
+			auto const &secondHalf = getObject(insn.word(5));
+
+			for (auto i = 0u; i < getType(result.type).sizeInComponents; i++)
+			{
+				auto selector = insn.word(6 + i);
+				if (selector == static_cast<uint32_t>(-1))
+				{
+					// Undefined value, we'll use zero
+					result.constantValue[i] = 0;
+				}
+				else if (selector < getType(firstHalf.type).sizeInComponents)
+				{
+					result.constantValue[i] = firstHalf.constantValue[selector];
+				}
+				else
+				{
+					result.constantValue[i] = secondHalf.constantValue[selector - getType(firstHalf.type).sizeInComponents];
+				}
+			}
+			break;
+		}
+
+		default:
+			// Other spec constant ops are possible, but require capabilities that are
+			// not exposed in our Vulkan implementation (eg Kernel), so we should never
+			// get here for correct shaders.
+			UNSUPPORTED("EvalSpecConstantOp op: %s", OpcodeName(opcode).c_str());
+		}
+	}
+
+	void SpirvShader::EvalSpecConstantUnaryOp(InsnIterator insn)
+	{
+		auto &result = CreateConstant(insn);
+
+		auto opcode = static_cast<spv::Op>(insn.word(3));
+		auto const &lhs = getObject(insn.word(4));
+		auto size = getType(lhs.type).sizeInComponents;
+
+		for (auto i = 0u; i < size; i++)
+		{
+			auto &v = result.constantValue[i];
+			auto l = lhs.constantValue[i];
+
+			switch (opcode)
+			{
+			case spv::OpSConvert:
+			case spv::OpFConvert:
+			case spv::OpUConvert:
+				UNREACHABLE("Not possible until we have multiple bit widths");
+				break;
+
+			case spv::OpSNegate:
+				v = -l;
+				break;
+			case spv::OpNot:
+			case spv::OpLogicalNot:
+				v = ~l;
+				break;
+
+			case spv::OpQuantizeToF16:
+			{
+				// Can do this nicer with host code, but want to perfectly mirror the reactor code we emit.
+				auto abs = bit_cast<float>(l & 0x7FFFFFFF);
+				auto sign = l & 0x80000000;
+				auto isZero = abs < 0.000061035f ? ~0u : 0u;
+				auto isInf = abs > 65504.0f ? ~0u : 0u;
+				auto isNaN = (abs != abs) ? ~0u : 0u;
+				auto isInfOrNan = isInf | isNaN;
+				v = l & 0xFFFFE000;
+				v &= ~isZero | 0x80000000;
+				v = sign | (isInfOrNan & 0x7F800000) | (~isInfOrNan & v);
+				v |= isNaN & 0x400000;
+				break;
+			}
+			default:
+				UNREACHABLE("EvalSpecConstantUnaryOp op: %s", OpcodeName(opcode).c_str());
+			}
+		}
+	}
+
+	void SpirvShader::EvalSpecConstantBinaryOp(InsnIterator insn)
+	{
+		auto &result = CreateConstant(insn);
+
+		auto opcode = static_cast<spv::Op>(insn.word(3));
+		auto const &lhs = getObject(insn.word(4));
+		auto const &rhs = getObject(insn.word(5));
+		auto size = getType(lhs.type).sizeInComponents;
+
+		for (auto i = 0u; i < size; i++)
+		{
+			auto &v = result.constantValue[i];
+			auto l = lhs.constantValue[i];
+			auto r = rhs.constantValue[i];
+
+			switch (opcode)
+			{
+			case spv::OpIAdd:
+				v = l + r;
+				break;
+			case spv::OpISub:
+				v = l - r;
+				break;
+			case spv::OpIMul:
+				v = l * r;
+				break;
+			case spv::OpUDiv:
+				v = (r == 0) ? 0 : l / r;
+				break;
+			case spv::OpUMod:
+				v = (r == 0) ? 0 : l % r;
+				break;
+			case spv::OpSDiv:
+				if (r == 0) r = UINT32_MAX;
+				if (l == static_cast<uint32_t>(INT32_MIN)) l = UINT32_MAX;
+				v = static_cast<int32_t>(l) / static_cast<int32_t>(r);
+				break;
+			case spv::OpSRem:
+				if (r == 0) r = UINT32_MAX;
+				if (l == static_cast<uint32_t>(INT32_MIN)) l = UINT32_MAX;
+				v = static_cast<int32_t>(l) % static_cast<int32_t>(r);
+				break;
+			case spv::OpSMod:
+				if (r == 0) r = UINT32_MAX;
+				if (l == static_cast<uint32_t>(INT32_MIN)) l = UINT32_MAX;
+				if (l * r < 0)
+					v = static_cast<int32_t>(l) % static_cast<int32_t>(r) + r;
+				else
+					v = static_cast<int32_t>(l) % static_cast<int32_t>(r);
+				break;
+			case spv::OpShiftRightLogical:
+				v = l >> r;
+				break;
+			case spv::OpShiftRightArithmetic:
+				v = static_cast<int32_t>(l) >> r;
+				break;
+			case spv::OpShiftLeftLogical:
+				v = l << r;
+				break;
+			case spv::OpBitwiseOr:
+			case spv::OpLogicalOr:
+				v = l | r;
+				break;
+			case spv::OpBitwiseAnd:
+			case spv::OpLogicalAnd:
+				v = l & r;
+				break;
+			case spv::OpBitwiseXor:
+				v = l ^ r;
+				break;
+			case spv::OpLogicalEqual:
+			case spv::OpIEqual:
+				v = (l == r) ? ~0u : 0u;
+				break;
+			case spv::OpLogicalNotEqual:
+			case spv::OpINotEqual:
+				v = (l != r) ? ~0u : 0u;
+				break;
+			case spv::OpULessThan:
+				v = l < r ? ~0u : 0u;
+				break;
+			case spv::OpSLessThan:
+				v = static_cast<int32_t>(l) < static_cast<int32_t>(r) ? ~0u : 0u;
+				break;
+			case spv::OpUGreaterThan:
+				v = l > r ? ~0u : 0u;
+				break;
+			case spv::OpSGreaterThan:
+				v = static_cast<int32_t>(l) > static_cast<int32_t>(r) ? ~0u : 0u;
+				break;
+			case spv::OpULessThanEqual:
+				v = l <= r ? ~0u : 0u;
+				break;
+			case spv::OpSLessThanEqual:
+				v = static_cast<int32_t>(l) <= static_cast<int32_t>(r) ? ~0u : 0u;
+				break;
+			case spv::OpUGreaterThanEqual:
+				v = l >= r ? ~0u : 0u;
+				break;
+			case spv::OpSGreaterThanEqual:
+				v = static_cast<int32_t>(l) >= static_cast<int32_t>(r) ? ~0u : 0u;
+				break;
+			default:
+				UNREACHABLE("EvalSpecConstantBinaryOp op: %s", OpcodeName(opcode).c_str());
+			}
+		}
+	}
+
 	void SpirvShader::emitEpilog(SpirvRoutine *routine) const
 	{
 		for (auto insn : *this)
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 6addd10..c55c6fe 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -814,7 +814,6 @@
 		template<typename F>
 		void VisitMemoryObjectInner(Type::ID id, Decorations d, uint32_t &index, uint32_t offset, F f) const;
 
-		uint32_t GetConstantInt(Object::ID id) const;
 		Object& CreateConstant(InsnIterator it);
 
 		void ProcessInterfaceVariable(Object &object);
@@ -959,6 +958,9 @@
 		void GetImageDimensions(SpirvRoutine const *routine, Type const &resultTy, Object::ID imageId, Object::ID lodId, Intermediate &dst) const;
 		SIMD::Pointer GetTexelAddress(SpirvRoutine const *routine, SIMD::Pointer base, GenericValue const & coordinate, Type const & imageType, Pointer<Byte> descriptor, int texelSize, Object::ID sampleId, bool useStencilAspect) const;
 		uint32_t GetConstScalarInt(Object::ID id) const;
+		void EvalSpecConstantOp(InsnIterator insn);
+		void EvalSpecConstantUnaryOp(InsnIterator insn);
+		void EvalSpecConstantBinaryOp(InsnIterator insn);
 
 		// LoadPhi loads the phi values from the alloca storage and places the
 		// load values into the intermediate with the phi's result id.
diff --git a/src/Vulkan/VkPipeline.cpp b/src/Vulkan/VkPipeline.cpp
index 2c341d3..597002d 100644
--- a/src/Vulkan/VkPipeline.cpp
+++ b/src/Vulkan/VkPipeline.cpp
@@ -194,9 +194,6 @@
 		}
 		opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(specializations));
 	}
-	// Freeze specialization constants into normal constants, and propagate through
-	opt.RegisterPass(spvtools::CreateFreezeSpecConstantValuePass());
-	opt.RegisterPass(spvtools::CreateFoldSpecConstantOpAndCompositePass());
 
 	// Basic optimization passes to primarily address glslang's love of loads &
 	// stores. Significantly reduces time spent in LLVM passes and codegen.