Move constant specialization to SpirvShader
SPIRV-Tools' pass to do this is incomplete, and implementing the
remaining pieces as a SPIRV->SPIRV transform is more work than
implementing them here.
Related SPIRV-Tools issues: 2585 2586
Bug: b/127454276
Test: dEQP-VK.*quantize*
Change-Id: I892de80dff366fb3bc417315b5db15d850577c47
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/31348
Tested-by: Chris Forbes <chrisforbes@google.com>
Presubmit-Ready: Chris Forbes <chrisforbes@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 8162f2f..e57cfc7 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -604,12 +604,15 @@
}
case spv::OpConstant:
+ case spv::OpSpecConstant:
CreateConstant(insn).constantValue[0] = insn.word(3);
break;
case spv::OpConstantFalse:
+ case spv::OpSpecConstantFalse:
CreateConstant(insn).constantValue[0] = 0; // represent boolean false as zero
break;
case spv::OpConstantTrue:
+ case spv::OpSpecConstantTrue:
CreateConstant(insn).constantValue[0] = ~0u; // represent boolean true as all bits set
break;
case spv::OpConstantNull:
@@ -626,6 +629,7 @@
break;
}
case spv::OpConstantComposite:
+ case spv::OpSpecConstantComposite:
{
auto &object = CreateConstant(insn);
auto offset = 0u;
@@ -657,6 +661,9 @@
}
break;
}
+ case spv::OpSpecConstantOp:
+ EvalSpecConstantOp(insn);
+ break;
case spv::OpCapability:
break; // Various capabilities will be declared, but none affect our code generation at this point.
@@ -723,11 +730,6 @@
case spv::OpFunctionParameter:
case spv::OpFunctionCall:
- case spv::OpSpecConstant:
- case spv::OpSpecConstantComposite:
- case spv::OpSpecConstantFalse:
- case spv::OpSpecConstantOp:
- case spv::OpSpecConstantTrue:
// These should have all been removed by preprocessing passes. If we see them here,
// our assumptions are wrong and we will probably generate wrong code.
UNREACHABLE("%s should have already been lowered.", OpcodeName(opcode).c_str());
@@ -1165,7 +1167,7 @@
case spv::OpTypeArray:
{
// Element count * element size. Array sizes come from constant ids.
- auto arraySize = GetConstantInt(insn.word(3));
+ auto arraySize = GetConstScalarInt(insn.word(3));
return getType(insn.word(2)).sizeInComponents * arraySize;
}
@@ -1262,7 +1264,7 @@
}
case spv::OpTypeArray:
{
- auto arraySize = GetConstantInt(obj.definition.word(3));
+ auto arraySize = GetConstScalarInt(obj.definition.word(3));
for (auto i = 0u; i < arraySize; i++)
{
d.Location = VisitInterfaceInner<F>(obj.definition.word(2), d, f);
@@ -1344,7 +1346,7 @@
break;
case spv::OpTypeArray:
{
- auto arraySize = GetConstantInt(type.definition.word(3));
+ auto arraySize = GetConstScalarInt(type.definition.word(3));
for (auto i = 0u; i < arraySize; i++)
{
ASSERT(d.HasArrayStride);
@@ -1440,7 +1442,7 @@
{
case spv::OpTypeStruct:
{
- int memberIndex = GetConstantInt(indexIds[i]);
+ int memberIndex = GetConstScalarInt(indexIds[i]);
ApplyDecorationsForIdMember(d, typeId, memberIndex);
typeId = type.definition.word(2u + memberIndex);
break;
@@ -1449,7 +1451,7 @@
case spv::OpTypeRuntimeArray:
if (dd->InputAttachmentIndex >= 0)
{
- dd->InputAttachmentIndex += GetConstantInt(indexIds[i]);
+ dd->InputAttachmentIndex += GetConstScalarInt(indexIds[i]);
}
typeId = type.element;
break;
@@ -1482,7 +1484,7 @@
if (type == spv::OpTypeArray || type == spv::OpTypeRuntimeArray)
{
ASSERT(getObject(indexIds[0]).kind == Object::Kind::Constant);
- arrayIndex = GetConstantInt(indexIds[0]);
+ arrayIndex = GetConstScalarInt(indexIds[0]);
numIndexes--;
indexIds++;
@@ -1503,7 +1505,7 @@
{
case spv::OpTypeStruct:
{
- int memberIndex = GetConstantInt(indexIds[i]);
+ int memberIndex = GetConstScalarInt(indexIds[i]);
ApplyDecorationsForIdMember(&d, typeId, memberIndex);
ASSERT(d.HasOffset);
constantOffset += d.Offset;
@@ -1518,7 +1520,7 @@
auto & obj = getObject(indexIds[i]);
if (obj.kind == Object::Kind::Constant)
{
- constantOffset += d.ArrayStride * GetConstantInt(indexIds[i]);
+ constantOffset += d.ArrayStride * GetConstScalarInt(indexIds[i]);
}
else
{
@@ -1536,7 +1538,7 @@
auto & obj = getObject(indexIds[i]);
if (obj.kind == Object::Kind::Constant)
{
- constantOffset += columnStride * GetConstantInt(indexIds[i]);
+ constantOffset += columnStride * GetConstScalarInt(indexIds[i]);
}
else
{
@@ -1551,7 +1553,7 @@
auto & obj = getObject(indexIds[i]);
if (obj.kind == Object::Kind::Constant)
{
- constantOffset += elemStride * GetConstantInt(indexIds[i]);
+ constantOffset += elemStride * GetConstScalarInt(indexIds[i]);
}
else
{
@@ -1587,7 +1589,7 @@
{
case spv::OpTypeStruct:
{
- int memberIndex = GetConstantInt(indexIds[i]);
+ int memberIndex = GetConstScalarInt(indexIds[i]);
int offsetIntoStruct = 0;
for (auto j = 0; j < memberIndex; j++) {
auto memberType = type.definition.word(2u + j);
@@ -1618,7 +1620,7 @@
ASSERT(d.Binding >= 0);
auto setLayout = routine->pipelineLayout->getDescriptorSetLayout(d.DescriptorSet);
auto stride = setLayout->getBindingStride(d.Binding);
- ptr.base += stride * GetConstantInt(indexIds[i]);
+ ptr.base += stride * GetConstScalarInt(indexIds[i]);
}
else
{
@@ -1626,7 +1628,7 @@
auto & obj = getObject(indexIds[i]);
if (obj.kind == Object::Kind::Constant)
{
- ptr += stride * GetConstantInt(indexIds[i]);
+ ptr += stride * GetConstScalarInt(indexIds[i]);
}
else
{
@@ -1860,20 +1862,6 @@
object.definition = insn;
}
- uint32_t SpirvShader::GetConstantInt(Object::ID id) const
- {
- // Slightly hackish access to constants very early in translation.
- // General consumption of constants by other instructions should
- // probably be just lowered to Reactor.
-
- // TODO: not encountered yet since we only use this for array sizes etc,
- // but is possible to construct integer constant 0 via OpConstantNull.
- auto insn = getObject(id).definition;
- ASSERT(insn.opcode() == spv::OpConstant);
- ASSERT(getType(insn.word(1)).opcode() == spv::OpTypeInt);
- return insn.word(3);
- }
-
// emit-time
void SpirvShader::emitProlog(SpirvRoutine *routine) const
@@ -2213,6 +2201,11 @@
case spv::OpConstantTrue:
case spv::OpConstantFalse:
case spv::OpConstantComposite:
+ case spv::OpSpecConstant:
+ case spv::OpSpecConstantTrue:
+ case spv::OpSpecConstantFalse:
+ case spv::OpSpecConstantComposite:
+ case spv::OpSpecConstantOp:
case spv::OpUndef:
case spv::OpExtension:
case spv::OpCapability:
@@ -3203,6 +3196,7 @@
}
case spv::OpQuantizeToF16:
{
+ // Note: keep in sync with the specialization constant version in EvalSpecConstantUnaryOp
auto abs = Abs(src.Float(i));
auto sign = src.Int(i) & SIMD::Int(0x80000000);
auto isZero = CmpLT(abs, SIMD::Float(0.000061035f));
@@ -5726,6 +5720,297 @@
return scopeObj.constantValue[0];
}
+ void SpirvShader::EvalSpecConstantOp(InsnIterator insn)
+ {
+ auto opcode = static_cast<spv::Op>(insn.word(3));
+
+ switch (opcode)
+ {
+ case spv::OpIAdd:
+ case spv::OpISub:
+ case spv::OpIMul:
+ case spv::OpUDiv:
+ case spv::OpSDiv:
+ case spv::OpUMod:
+ case spv::OpSMod:
+ case spv::OpSRem:
+ case spv::OpShiftRightLogical:
+ case spv::OpShiftRightArithmetic:
+ case spv::OpShiftLeftLogical:
+ case spv::OpBitwiseOr:
+ case spv::OpLogicalOr:
+ case spv::OpBitwiseAnd:
+ case spv::OpLogicalAnd:
+ case spv::OpBitwiseXor:
+ case spv::OpLogicalEqual:
+ case spv::OpIEqual:
+ case spv::OpLogicalNotEqual:
+ case spv::OpINotEqual:
+ case spv::OpULessThan:
+ case spv::OpSLessThan:
+ case spv::OpUGreaterThan:
+ case spv::OpSGreaterThan:
+ case spv::OpULessThanEqual:
+ case spv::OpSLessThanEqual:
+ case spv::OpUGreaterThanEqual:
+ case spv::OpSGreaterThanEqual:
+ EvalSpecConstantBinaryOp(insn);
+ break;
+
+ case spv::OpSConvert:
+ case spv::OpFConvert:
+ case spv::OpUConvert:
+ case spv::OpSNegate:
+ case spv::OpNot:
+ case spv::OpLogicalNot:
+ case spv::OpQuantizeToF16:
+ EvalSpecConstantUnaryOp(insn);
+ break;
+
+ case spv::OpSelect:
+ {
+ auto &result = CreateConstant(insn);
+ auto const &cond = getObject(insn.word(4));
+ auto const &left = getObject(insn.word(5));
+ auto const &right = getObject(insn.word(6));
+
+ for (auto i = 0u; i < getType(result.type).sizeInComponents; i++)
+ {
+ result.constantValue[i] = cond.constantValue[i] ? left.constantValue[i] : right.constantValue[i];
+ }
+ break;
+ }
+
+ case spv::OpCompositeExtract:
+ {
+ auto &result = CreateConstant(insn);
+ auto const &compositeObject = getObject(insn.word(4));
+ auto firstComponent = WalkLiteralAccessChain(compositeObject.type, insn.wordCount() - 5, insn.wordPointer(5));
+
+ for (auto i = 0u; i < getType(result.type).sizeInComponents; i++)
+ {
+ result.constantValue[i] = compositeObject.constantValue[firstComponent + i];
+ }
+ break;
+ }
+
+ case spv::OpCompositeInsert:
+ {
+ auto &result = CreateConstant(insn);
+ auto const &newPart = getObject(insn.word(4));
+ auto const &oldObject = getObject(insn.word(5));
+ auto firstNewComponent = WalkLiteralAccessChain(result.type, insn.wordCount() - 6, insn.wordPointer(6));
+
+ // old components before
+ for (auto i = 0u; i < firstNewComponent; i++)
+ {
+ result.constantValue[i] = oldObject.constantValue[i];
+ }
+ // new part
+ for (auto i = 0u; i < getType(newPart.type).sizeInComponents; i++)
+ {
+ result.constantValue[firstNewComponent + i] = newPart.constantValue[i];
+ }
+ // old components after
+ for (auto i = firstNewComponent + getType(newPart.type).sizeInComponents; i < getType(result.type).sizeInComponents; i++)
+ {
+ result.constantValue[i] = oldObject.constantValue[i];
+ }
+ break;
+ }
+
+ case spv::OpVectorShuffle:
+ {
+ auto &result = CreateConstant(insn);
+ auto const &firstHalf = getObject(insn.word(4));
+ auto const &secondHalf = getObject(insn.word(5));
+
+ for (auto i = 0u; i < getType(result.type).sizeInComponents; i++)
+ {
+ auto selector = insn.word(6 + i);
+ if (selector == static_cast<uint32_t>(-1))
+ {
+ // Undefined value, we'll use zero
+ result.constantValue[i] = 0;
+ }
+ else if (selector < getType(firstHalf.type).sizeInComponents)
+ {
+ result.constantValue[i] = firstHalf.constantValue[selector];
+ }
+ else
+ {
+ result.constantValue[i] = secondHalf.constantValue[selector - getType(firstHalf.type).sizeInComponents];
+ }
+ }
+ break;
+ }
+
+ default:
+ // Other spec constant ops are possible, but require capabilities that are
+ // not exposed in our Vulkan implementation (eg Kernel), so we should never
+ // get here for correct shaders.
+ UNSUPPORTED("EvalSpecConstantOp op: %s", OpcodeName(opcode).c_str());
+ }
+ }
+
+ void SpirvShader::EvalSpecConstantUnaryOp(InsnIterator insn)
+ {
+ auto &result = CreateConstant(insn);
+
+ auto opcode = static_cast<spv::Op>(insn.word(3));
+ auto const &lhs = getObject(insn.word(4));
+ auto size = getType(lhs.type).sizeInComponents;
+
+ for (auto i = 0u; i < size; i++)
+ {
+ auto &v = result.constantValue[i];
+ auto l = lhs.constantValue[i];
+
+ switch (opcode)
+ {
+ case spv::OpSConvert:
+ case spv::OpFConvert:
+ case spv::OpUConvert:
+ UNREACHABLE("Not possible until we have multiple bit widths");
+ break;
+
+ case spv::OpSNegate:
+ v = -l;
+ break;
+ case spv::OpNot:
+ case spv::OpLogicalNot:
+ v = ~l;
+ break;
+
+ case spv::OpQuantizeToF16:
+ {
+ // Can do this nicer with host code, but want to perfectly mirror the reactor code we emit.
+ auto abs = bit_cast<float>(l & 0x7FFFFFFF);
+ auto sign = l & 0x80000000;
+ auto isZero = abs < 0.000061035f ? ~0u : 0u;
+ auto isInf = abs > 65504.0f ? ~0u : 0u;
+ auto isNaN = (abs != abs) ? ~0u : 0u;
+ auto isInfOrNan = isInf | isNaN;
+ v = l & 0xFFFFE000;
+ v &= ~isZero | 0x80000000;
+ v = sign | (isInfOrNan & 0x7F800000) | (~isInfOrNan & v);
+ v |= isNaN & 0x400000;
+ break;
+ }
+ default:
+ UNREACHABLE("EvalSpecConstantUnaryOp op: %s", OpcodeName(opcode).c_str());
+ }
+ }
+ }
+
+ void SpirvShader::EvalSpecConstantBinaryOp(InsnIterator insn)
+ {
+ auto &result = CreateConstant(insn);
+
+ auto opcode = static_cast<spv::Op>(insn.word(3));
+ auto const &lhs = getObject(insn.word(4));
+ auto const &rhs = getObject(insn.word(5));
+ auto size = getType(lhs.type).sizeInComponents;
+
+ for (auto i = 0u; i < size; i++)
+ {
+ auto &v = result.constantValue[i];
+ auto l = lhs.constantValue[i];
+ auto r = rhs.constantValue[i];
+
+ switch (opcode)
+ {
+ case spv::OpIAdd:
+ v = l + r;
+ break;
+ case spv::OpISub:
+ v = l - r;
+ break;
+ case spv::OpIMul:
+ v = l * r;
+ break;
+ case spv::OpUDiv:
+ v = (r == 0) ? 0 : l / r;
+ break;
+ case spv::OpUMod:
+ v = (r == 0) ? 0 : l % r;
+ break;
+ case spv::OpSDiv:
+ if (r == 0) r = UINT32_MAX;
+ if (l == static_cast<uint32_t>(INT32_MIN)) l = UINT32_MAX;
+ v = static_cast<int32_t>(l) / static_cast<int32_t>(r);
+ break;
+ case spv::OpSRem:
+ if (r == 0) r = UINT32_MAX;
+ if (l == static_cast<uint32_t>(INT32_MIN)) l = UINT32_MAX;
+ v = static_cast<int32_t>(l) % static_cast<int32_t>(r);
+ break;
+ case spv::OpSMod:
+ if (r == 0) r = UINT32_MAX;
+ if (l == static_cast<uint32_t>(INT32_MIN)) l = UINT32_MAX;
+ if (l * r < 0)
+ v = static_cast<int32_t>(l) % static_cast<int32_t>(r) + r;
+ else
+ v = static_cast<int32_t>(l) % static_cast<int32_t>(r);
+ break;
+ case spv::OpShiftRightLogical:
+ v = l >> r;
+ break;
+ case spv::OpShiftRightArithmetic:
+ v = static_cast<int32_t>(l) >> r;
+ break;
+ case spv::OpShiftLeftLogical:
+ v = l << r;
+ break;
+ case spv::OpBitwiseOr:
+ case spv::OpLogicalOr:
+ v = l | r;
+ break;
+ case spv::OpBitwiseAnd:
+ case spv::OpLogicalAnd:
+ v = l & r;
+ break;
+ case spv::OpBitwiseXor:
+ v = l ^ r;
+ break;
+ case spv::OpLogicalEqual:
+ case spv::OpIEqual:
+ v = (l == r) ? ~0u : 0u;
+ break;
+ case spv::OpLogicalNotEqual:
+ case spv::OpINotEqual:
+ v = (l != r) ? ~0u : 0u;
+ break;
+ case spv::OpULessThan:
+ v = l < r ? ~0u : 0u;
+ break;
+ case spv::OpSLessThan:
+ v = static_cast<int32_t>(l) < static_cast<int32_t>(r) ? ~0u : 0u;
+ break;
+ case spv::OpUGreaterThan:
+ v = l > r ? ~0u : 0u;
+ break;
+ case spv::OpSGreaterThan:
+ v = static_cast<int32_t>(l) > static_cast<int32_t>(r) ? ~0u : 0u;
+ break;
+ case spv::OpULessThanEqual:
+ v = l <= r ? ~0u : 0u;
+ break;
+ case spv::OpSLessThanEqual:
+ v = static_cast<int32_t>(l) <= static_cast<int32_t>(r) ? ~0u : 0u;
+ break;
+ case spv::OpUGreaterThanEqual:
+ v = l >= r ? ~0u : 0u;
+ break;
+ case spv::OpSGreaterThanEqual:
+ v = static_cast<int32_t>(l) >= static_cast<int32_t>(r) ? ~0u : 0u;
+ break;
+ default:
+ UNREACHABLE("EvalSpecConstantBinaryOp op: %s", OpcodeName(opcode).c_str());
+ }
+ }
+ }
+
void SpirvShader::emitEpilog(SpirvRoutine *routine) const
{
for (auto insn : *this)
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 6addd10..c55c6fe 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -814,7 +814,6 @@
template<typename F>
void VisitMemoryObjectInner(Type::ID id, Decorations d, uint32_t &index, uint32_t offset, F f) const;
- uint32_t GetConstantInt(Object::ID id) const;
Object& CreateConstant(InsnIterator it);
void ProcessInterfaceVariable(Object &object);
@@ -959,6 +958,9 @@
void GetImageDimensions(SpirvRoutine const *routine, Type const &resultTy, Object::ID imageId, Object::ID lodId, Intermediate &dst) const;
SIMD::Pointer GetTexelAddress(SpirvRoutine const *routine, SIMD::Pointer base, GenericValue const & coordinate, Type const & imageType, Pointer<Byte> descriptor, int texelSize, Object::ID sampleId, bool useStencilAspect) const;
uint32_t GetConstScalarInt(Object::ID id) const;
+ void EvalSpecConstantOp(InsnIterator insn);
+ void EvalSpecConstantUnaryOp(InsnIterator insn);
+ void EvalSpecConstantBinaryOp(InsnIterator insn);
// LoadPhi loads the phi values from the alloca storage and places the
// load values into the intermediate with the phi's result id.
diff --git a/src/Vulkan/VkPipeline.cpp b/src/Vulkan/VkPipeline.cpp
index 2c341d3..597002d 100644
--- a/src/Vulkan/VkPipeline.cpp
+++ b/src/Vulkan/VkPipeline.cpp
@@ -194,9 +194,6 @@
}
opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(specializations));
}
- // Freeze specialization constants into normal constants, and propagate through
- opt.RegisterPass(spvtools::CreateFreezeSpecConstantValuePass());
- opt.RegisterPass(spvtools::CreateFoldSpecConstantOpAndCompositePass());
// Basic optimization passes to primarily address glslang's love of loads &
// stores. Significantly reduces time spent in LLVM passes and codegen.