Fix constant folding of vector compares.
Change-Id: If17c2429d38158663c2436e374691a460e3d588c
Reviewed-on: https://swiftshader-review.googlesource.com/5064
Tested-by: Nicolas Capens <capn@google.com>
Reviewed-by: Alexis Hétu <sugoi@google.com>
Reviewed-by: Nicolas Capens <capn@google.com>
diff --git a/src/OpenGL/compiler/ConstantUnion.h b/src/OpenGL/compiler/ConstantUnion.h
index e6bec81..b001ec4 100644
--- a/src/OpenGL/compiler/ConstantUnion.h
+++ b/src/OpenGL/compiler/ConstantUnion.h
@@ -174,7 +174,7 @@
}
bool operator>(const ConstantUnion& constant) const
- {
+ {
assert(type == constant.type);
switch (type) {
case EbtInt:
@@ -191,7 +191,7 @@
}
bool operator<(const ConstantUnion& constant) const
- {
+ {
assert(type == constant.type);
switch (type) {
case EbtInt:
@@ -207,8 +207,42 @@
return false;
}
+ bool operator<=(const ConstantUnion& constant) const
+ {
+ assert(type == constant.type);
+ switch (type) {
+ case EbtInt:
+ return iConst <= constant.iConst;
+ case EbtUInt:
+ return uConst <= constant.uConst;
+ case EbtFloat:
+ return fConst <= constant.fConst;
+ default:
+ return false; // Invalid operation, handled at semantic analysis
+ }
+
+ return false;
+ }
+
+ bool operator>=(const ConstantUnion& constant) const
+ {
+ assert(type == constant.type);
+ switch (type) {
+ case EbtInt:
+ return iConst >= constant.iConst;
+ case EbtUInt:
+ return uConst >= constant.uConst;
+ case EbtFloat:
+ return fConst >= constant.fConst;
+ default:
+ return false; // Invalid operation, handled at semantic analysis
+ }
+
+ return false;
+ }
+
ConstantUnion operator+(const ConstantUnion& constant) const
- {
+ {
ConstantUnion returnValue;
assert(type == constant.type);
switch (type) {
@@ -222,7 +256,7 @@
}
ConstantUnion operator-(const ConstantUnion& constant) const
- {
+ {
ConstantUnion returnValue;
assert(type == constant.type);
switch (type) {
@@ -236,13 +270,13 @@
}
ConstantUnion operator*(const ConstantUnion& constant) const
- {
+ {
ConstantUnion returnValue;
assert(type == constant.type);
switch (type) {
case EbtInt: returnValue.setIConst(iConst * constant.iConst); break;
case EbtUInt: returnValue.setUConst(uConst * constant.uConst); break;
- case EbtFloat: returnValue.setFConst(fConst * constant.fConst); break;
+ case EbtFloat: returnValue.setFConst(fConst * constant.fConst); break;
default: assert(false && "Default missing");
}
@@ -250,7 +284,7 @@
}
ConstantUnion operator%(const ConstantUnion& constant) const
- {
+ {
ConstantUnion returnValue;
assert(type == constant.type);
switch (type) {
@@ -263,7 +297,7 @@
}
ConstantUnion operator>>(const ConstantUnion& constant) const
- {
+ {
ConstantUnion returnValue;
assert(type == constant.type);
switch (type) {
@@ -276,7 +310,7 @@
}
ConstantUnion operator<<(const ConstantUnion& constant) const
- {
+ {
ConstantUnion returnValue;
// The signedness of the second parameter might be different, but we
// don't care, since the result is undefined if the second parameter is
@@ -292,7 +326,7 @@
}
ConstantUnion operator&(const ConstantUnion& constant) const
- {
+ {
ConstantUnion returnValue;
assert(constant.type == EbtInt || constant.type == EbtUInt);
switch (type) {
@@ -305,7 +339,7 @@
}
ConstantUnion operator|(const ConstantUnion& constant) const
- {
+ {
ConstantUnion returnValue;
assert(type == constant.type);
switch (type) {
@@ -318,7 +352,7 @@
}
ConstantUnion operator^(const ConstantUnion& constant) const
- {
+ {
ConstantUnion returnValue;
assert(type == constant.type);
switch (type) {
@@ -331,7 +365,7 @@
}
ConstantUnion operator&&(const ConstantUnion& constant) const
- {
+ {
ConstantUnion returnValue;
assert(type == constant.type);
switch (type) {
@@ -343,7 +377,7 @@
}
ConstantUnion operator||(const ConstantUnion& constant) const
- {
+ {
ConstantUnion returnValue;
assert(type == constant.type);
switch (type) {
diff --git a/src/OpenGL/compiler/Intermediate.cpp b/src/OpenGL/compiler/Intermediate.cpp
index bc79a68..eb33972 100644
--- a/src/OpenGL/compiler/Intermediate.cpp
+++ b/src/OpenGL/compiler/Intermediate.cpp
@@ -1043,7 +1043,7 @@
default:
return false;
}
-
+
return true;
}
@@ -1539,38 +1539,29 @@
break;
case EOpLessThan:
- assert(objectSize == 1);
- tempConstArray = new ConstantUnion[1];
- tempConstArray->setBConst(*unionArray < *rightUnionArray);
- returnType = TType(EbtBool, EbpUndefined, EvqConstExpr);
+ tempConstArray = new ConstantUnion[objectSize];
+ for(int i = 0; i < objectSize; i++)
+ tempConstArray[i].setBConst(unionArray[i] < rightUnionArray[i]);
+ returnType = TType(EbtBool, EbpUndefined, EvqConstExpr, objectSize);
break;
case EOpGreaterThan:
- assert(objectSize == 1);
- tempConstArray = new ConstantUnion[1];
- tempConstArray->setBConst(*unionArray > *rightUnionArray);
- returnType = TType(EbtBool, EbpUndefined, EvqConstExpr);
+ tempConstArray = new ConstantUnion[objectSize];
+ for(int i = 0; i < objectSize; i++)
+ tempConstArray[i].setBConst(unionArray[i] > rightUnionArray[i]);
+ returnType = TType(EbtBool, EbpUndefined, EvqConstExpr, objectSize);
break;
case EOpLessThanEqual:
- {
- assert(objectSize == 1);
- ConstantUnion constant;
- constant.setBConst(*unionArray > *rightUnionArray);
- tempConstArray = new ConstantUnion[1];
- tempConstArray->setBConst(!constant.getBConst());
- returnType = TType(EbtBool, EbpUndefined, EvqConstExpr);
- break;
- }
+ tempConstArray = new ConstantUnion[objectSize];
+ for(int i = 0; i < objectSize; i++)
+ tempConstArray[i].setBConst(unionArray[i] <= rightUnionArray[i]);
+ returnType = TType(EbtBool, EbpUndefined, EvqConstExpr, objectSize);
+ break;
case EOpGreaterThanEqual:
- {
- assert(objectSize == 1);
- ConstantUnion constant;
- constant.setBConst(*unionArray < *rightUnionArray);
- tempConstArray = new ConstantUnion[1];
- tempConstArray->setBConst(!constant.getBConst());
- returnType = TType(EbtBool, EbpUndefined, EvqConstExpr);
- break;
- }
-
+ tempConstArray = new ConstantUnion[objectSize];
+ for(int i = 0; i < objectSize; i++)
+ tempConstArray[i].setBConst(unionArray[i] >= rightUnionArray[i]);
+ returnType = TType(EbtBool, EbpUndefined, EvqConstExpr, objectSize);
+ break;
case EOpEqual:
if (getType().getBasicType() == EbtStruct) {
if (!CompareStructure(node->getType(), node->getUnionArrayPointer(), unionArray))