Fixed NxM matrix construction and minor issues

Fixed NxM matrix construction by
properly checking for these types in
TIntermOperator::isConstructor. Also
fixed a few areas of the code where
the secondary size wasn't properly
taken into account.

Change-Id: I646a41e37460255316f5712f1d744c3a06d8a64d
Reviewed-on: https://swiftshader-review.googlesource.com/3195
Tested-by: Alexis Hétu <sugoi@google.com>
Reviewed-by: Nicolas Capens <capn@google.com>
diff --git a/src/OpenGL/compiler/Intermediate.cpp b/src/OpenGL/compiler/Intermediate.cpp
index fad11c6..b1d2526 100644
--- a/src/OpenGL/compiler/Intermediate.cpp
+++ b/src/OpenGL/compiler/Intermediate.cpp
@@ -21,6 +21,38 @@
     return left > right ? left : right;
 }
 
+static bool ValidateMultiplication(TOperator op, const TType &left, const TType &right)
+{
+	switch(op)
+	{
+	case EOpMul:
+	case EOpMulAssign:
+		return left.getNominalSize() == right.getNominalSize() &&
+		       left.getSecondarySize() == right.getSecondarySize();
+	case EOpVectorTimesScalar:
+	case EOpVectorTimesScalarAssign:
+		return true;
+	case EOpVectorTimesMatrix:
+		return left.getNominalSize() == right.getSecondarySize();
+	case EOpVectorTimesMatrixAssign:
+		return left.getNominalSize() == right.getSecondarySize() &&
+		       left.getNominalSize() == right.getNominalSize();
+	case EOpMatrixTimesVector:
+		return left.getNominalSize() == right.getNominalSize();
+	case EOpMatrixTimesScalar:
+	case EOpMatrixTimesScalarAssign:
+		return true;
+	case EOpMatrixTimesMatrix:
+		return left.getNominalSize() == right.getSecondarySize();
+	case EOpMatrixTimesMatrixAssign:
+		return left.getNominalSize() == right.getNominalSize() &&
+		       left.getSecondarySize() == right.getSecondarySize();
+	default:
+		UNREACHABLE();
+		return false;
+	}
+}
+
 const char* getOperatorString(TOperator op) {
     switch (op) {
       case EOpInitialize: return "=";
@@ -645,7 +677,13 @@
         case EOpConstructVec3:
         case EOpConstructVec4:
         case EOpConstructMat2:
+        case EOpConstructMat2x3:
+        case EOpConstructMat2x4:
+        case EOpConstructMat3x2:
         case EOpConstructMat3:
+        case EOpConstructMat3x4:
+        case EOpConstructMat4x2:
+        case EOpConstructMat4x3:
         case EOpConstructMat4:
         case EOpConstructFloat:
         case EOpConstructIVec2:
@@ -752,7 +790,6 @@
     }
 
     int primarySize = std::max(left->getNominalSize(), right->getNominalSize());
-	int secondarySize = std::max(left->getSecondarySize(), right->getSecondarySize());
 
     //
     // All scalars. Code after this test assumes this case is removed!
@@ -791,25 +828,6 @@
 
     // If we reach here, at least one of the operands is vector or matrix.
     // The other operand could be a scalar, vector, or matrix.
-    // Are the sizes compatible?
-    //
-    if (left->getNominalSize() != right->getNominalSize()) {
-        // If the nominal size of operands do not match:
-        // One of them must be scalar.
-        if (left->getNominalSize() != 1 && right->getNominalSize() != 1)
-            return false;
-        // Operator cannot be of type pure assignment.
-        if (op == EOpAssign || op == EOpInitialize)
-            return false;
-    }
-
-	if (left->getSecondarySize() != right->getSecondarySize()) {
-        // Operator cannot be of type pure assignment.
-        if (op == EOpAssign || op == EOpInitialize)
-            return false;
-    }
-
-    //
     // Can these two operands be combined?
     //
     TBasicType basicType = left->getBasicType();
@@ -817,31 +835,45 @@
         case EOpMul:
             if (!left->isMatrix() && right->isMatrix()) {
                 if (left->isVector())
+                {
                     op = EOpVectorTimesMatrix;
+                    setType(TType(basicType, higherPrecision, EvqTemporary,
+                        static_cast<unsigned char>(right->getNominalSize()), 1));
+                }
                 else {
                     op = EOpMatrixTimesScalar;
-                    setType(TType(basicType, higherPrecision, EvqTemporary, primarySize, secondarySize));
+                    setType(TType(basicType, higherPrecision, EvqTemporary,
+                        static_cast<unsigned char>(right->getNominalSize()), static_cast<unsigned char>(right->getSecondarySize())));
                 }
             } else if (left->isMatrix() && !right->isMatrix()) {
                 if (right->isVector()) {
                     op = EOpMatrixTimesVector;
-                    setType(TType(basicType, higherPrecision, EvqTemporary, primarySize, 1));
+                    setType(TType(basicType, higherPrecision, EvqTemporary,
+                        static_cast<unsigned char>(left->getSecondarySize()), 1));
                 } else {
                     op = EOpMatrixTimesScalar;
                 }
             } else if (left->isMatrix() && right->isMatrix()) {
                 op = EOpMatrixTimesMatrix;
+                setType(TType(basicType, higherPrecision, EvqTemporary,
+                    static_cast<unsigned char>(right->getNominalSize()), static_cast<unsigned char>(left->getSecondarySize())));
             } else if (!left->isMatrix() && !right->isMatrix()) {
                 if (left->isVector() && right->isVector()) {
                     // leave as component product
                 } else if (left->isVector() || right->isVector()) {
                     op = EOpVectorTimesScalar;
-                    setType(TType(basicType, higherPrecision, EvqTemporary, primarySize, 1));
+                    setType(TType(basicType, higherPrecision, EvqTemporary,
+                        static_cast<unsigned char>(primarySize), 1));
                 }
             } else {
                 infoSink.info.message(EPrefixInternalError, "Missing elses", getLine());
                 return false;
             }
+
+            if(!ValidateMultiplication(op, left->getType(), right->getType()))
+            {
+                return false;
+            }
             break;
         case EOpMulAssign:
             if (!left->isMatrix() && right->isMatrix()) {
@@ -858,6 +890,8 @@
                 }
             } else if (left->isMatrix() && right->isMatrix()) {
                 op = EOpMatrixTimesMatrixAssign;
+                setType(TType(basicType, higherPrecision, EvqTemporary,
+                    static_cast<unsigned char>(right->getNominalSize()), static_cast<unsigned char>(left->getSecondarySize())));
             } else if (!left->isMatrix() && !right->isMatrix()) {
                 if (left->isVector() && right->isVector()) {
                     // leave as component product
@@ -865,16 +899,27 @@
                     if (! left->isVector())
                         return false;
                     op = EOpVectorTimesScalarAssign;
-                    setType(TType(basicType, higherPrecision, EvqTemporary, primarySize, 1));
+                    setType(TType(basicType, higherPrecision, EvqTemporary,
+                        static_cast<unsigned char>(left->getNominalSize()), 1));
                 }
             } else {
                 infoSink.info.message(EPrefixInternalError, "Missing elses", getLine());
                 return false;
             }
+
+            if(!ValidateMultiplication(op, left->getType(), right->getType()))
+            {
+                return false;
+            }
             break;
 
         case EOpAssign:
         case EOpInitialize:
+            // No more additional checks are needed.
+            if ((left->getNominalSize() != right->getNominalSize()) ||
+                (left->getSecondarySize() != right->getSecondarySize()))
+                return false;
+            break;
         case EOpAdd:
         case EOpSub:
         case EOpDiv:
@@ -915,6 +960,8 @@
             }
 
             {
+                const int secondarySize = std::max(
+                    left->getSecondarySize(), right->getSecondarySize());
                 setType(TType(basicType, higherPrecision, EvqTemporary,
                     static_cast<unsigned char>(primarySize), static_cast<unsigned char>(secondarySize)));
                 if(left->isArray())
@@ -923,8 +970,6 @@
                     type.setArraySize(left->getArraySize());
                 }
             }
-
-            setType(TType(basicType, higherPrecision, EvqTemporary, primarySize, secondarySize));
             break;
 
         case EOpEqual:
@@ -933,8 +978,8 @@
         case EOpGreaterThan:
         case EOpLessThanEqual:
         case EOpGreaterThanEqual:
-            if ((left->isMatrix() && right->isVector()) ||
-                (left->isVector() && right->isMatrix()))
+            if ((left->getNominalSize() != right->getNominalSize()) ||
+                (left->getSecondarySize() != right->getSecondarySize()))
                 return false;
             setType(TType(EbtBool, EbpUndefined));
             break;
@@ -1056,16 +1101,29 @@
                     return 0;
                 }
                 {// support MSVC++6.0
-                    int size = getNominalSize();
-                    tempConstArray = new ConstantUnion[size*size];
-                    for (int row = 0; row < size; row++) {
-                        for (int column = 0; column < size; column++) {
-                            tempConstArray[size * column + row].setFConst(0.0f);
-                            for (int i = 0; i < size; i++) {
-                                tempConstArray[size * column + row].setFConst(tempConstArray[size * column + row].getFConst() + unionArray[i * size + row].getFConst() * (rightUnionArray[column * size + i].getFConst()));
+                    int leftNumCols = getNominalSize();
+                    int leftNumRows = getSecondarySize();
+                    int rightNumCols = node->getNominalSize();
+                    int rightNumRows = node->getSecondarySize();
+                    if(leftNumCols != rightNumRows) {
+                        infoSink.info.message(EPrefixInternalError, "Constant Folding cannot be done for matrix multiply", getLine());
+                        return 0;
+                    }
+                    int tempNumCols = rightNumCols;
+                    int tempNumRows = leftNumRows;
+                    int tempNumAdds = leftNumCols;
+                    tempConstArray = new ConstantUnion[tempNumCols*tempNumRows];
+                    for (int row = 0; row < tempNumRows; row++) {
+                        for (int column = 0; column < tempNumCols; column++) {
+                            tempConstArray[tempNumRows * column + row].setFConst(0.0f);
+                            for (int i = 0; i < tempNumAdds; i++) {
+                                tempConstArray[tempNumRows * column + row].setFConst(tempConstArray[tempNumRows * column + row].getFConst() + unionArray[i * leftNumRows + row].getFConst() * (rightUnionArray[column * rightNumRows + i].getFConst()));
                             }
                         }
                     }
+                    // update return type for matrix product
+                    returnType.setNominalSize(static_cast<unsigned char>(tempNumCols));
+                    returnType.setSecondarySize(static_cast<unsigned char>(tempNumRows));
                 }
                 break;
             case EOpDiv:
diff --git a/src/OpenGL/compiler/OutputASM.cpp b/src/OpenGL/compiler/OutputASM.cpp
index d95d568..2af8826 100644
--- a/src/OpenGL/compiler/OutputASM.cpp
+++ b/src/OpenGL/compiler/OutputASM.cpp
@@ -1430,7 +1430,7 @@
 		}

 		else if(type.isMatrix())

 		{

-			return registers * type.getNominalSize();

+			return registers * type.getSecondarySize();

 		}

 		

 		UNREACHABLE();

@@ -1446,7 +1446,7 @@
 				return registerSize(*type.getStruct()->begin()->type, 0);

 			}

 

-			return type.getNominalSize();

+			return type.isMatrix() ? type.getSecondarySize() : type.getNominalSize();

 		}

 

 		if(type.isArray() && registers >= type.elementRegisterCount())

@@ -1590,7 +1590,7 @@
 	{

 		if(src &&

 			((src->isVector() && (!dst->isVector() || (dst->getNominalSize() != dst->getNominalSize()))) ||

-			 (src->isMatrix() && (!dst->isMatrix() || (src->getNominalSize() != dst->getNominalSize())))))

+			 (src->isMatrix() && (!dst->isMatrix() || (src->getNominalSize() != dst->getNominalSize()) || (src->getSecondarySize() != dst->getSecondarySize())))))

 		{

 			return mContext.error(src->getLine(), "Result type should match the l-value type in compound assignment", src->isVector() ? "vector" : "matrix");

 		}

diff --git a/src/OpenGL/compiler/ParseHelper.cpp b/src/OpenGL/compiler/ParseHelper.cpp
index f1287cc..2d9b3d5 100644
--- a/src/OpenGL/compiler/ParseHelper.cpp
+++ b/src/OpenGL/compiler/ParseHelper.cpp
@@ -114,7 +114,7 @@
 // Look at a '.' field selector string and change it into offsets
 // for a matrix.
 //
-bool TParseContext::parseMatrixFields(const TString& compString, int matSize, TMatrixFields& fields, int line)
+bool TParseContext::parseMatrixFields(const TString& compString, int matCols, int matRows, TMatrixFields& fields, int line)
 {
     fields.wholeRow = false;
     fields.wholeCol = false;
@@ -150,7 +150,7 @@
         fields.col = compString[1] - '0';
     }
 
-    if (fields.row >= matSize || fields.col >= matSize) {
+    if (fields.row >= matRows || fields.col >= matCols) {
         error(line, "matrix field selection out of range", compString.c_str());
         return false;
     }
diff --git a/src/OpenGL/compiler/ParseHelper.h b/src/OpenGL/compiler/ParseHelper.h
index 4f9fa96..c494572 100644
--- a/src/OpenGL/compiler/ParseHelper.h
+++ b/src/OpenGL/compiler/ParseHelper.h
@@ -76,7 +76,7 @@
     void recover();
 
     bool parseVectorFields(const TString&, int vecSize, TVectorFields&, int line);
-    bool parseMatrixFields(const TString&, int matSize, TMatrixFields&, int line);
+    bool parseMatrixFields(const TString&, int matCols, int matRows, TMatrixFields&, int line);
 
     bool reservedErrorCheck(int line, const TString& identifier);
     void assignError(int line, const char* op, TString left, TString right);
diff --git a/src/OpenGL/compiler/glslang.y b/src/OpenGL/compiler/glslang.y
index 0dff1c9..ce04736 100644
--- a/src/OpenGL/compiler/glslang.y
+++ b/src/OpenGL/compiler/glslang.y
@@ -323,9 +323,9 @@
             if ($1->getType().getQualifier() == EvqConstExpr)
                 $$->getTypePointer()->setQualifier(EvqConstExpr);
         } else if ($1->isMatrix() && $1->getType().getQualifier() == EvqConstExpr)
-            $$->setType(TType($1->getBasicType(), $1->getPrecision(), EvqConstExpr, $1->getNominalSize()));
+            $$->setType(TType($1->getBasicType(), $1->getPrecision(), EvqConstExpr, $1->getSecondarySize()));
         else if ($1->isMatrix())
-            $$->setType(TType($1->getBasicType(), $1->getPrecision(), EvqTemporary, $1->getNominalSize()));
+            $$->setType(TType($1->getBasicType(), $1->getPrecision(), EvqTemporary, $1->getSecondarySize()));
         else if ($1->isVector() && $1->getType().getQualifier() == EvqConstExpr)
             $$->setType(TType($1->getBasicType(), $1->getPrecision(), EvqConstExpr));
         else if ($1->isVector())
@@ -366,7 +366,7 @@
             }
         } else if ($1->isMatrix()) {
             TMatrixFields fields;
-            if (! context->parseMatrixFields(*$3.string, $1->getNominalSize(), fields, $3.line)) {
+            if (! context->parseMatrixFields(*$3.string, $1->getNominalSize(), $1->getSecondarySize(), fields, $3.line)) {
                 fields.wholeRow = false;
                 fields.wholeCol = false;
                 fields.row = 0;
diff --git a/src/OpenGL/compiler/glslang_tab.cpp b/src/OpenGL/compiler/glslang_tab.cpp
index a672362..0a5f2a5 100644
--- a/src/OpenGL/compiler/glslang_tab.cpp
+++ b/src/OpenGL/compiler/glslang_tab.cpp
@@ -2465,9 +2465,9 @@
             if ((yyvsp[(1) - (4)].interm.intermTypedNode)->getType().getQualifier() == EvqConstExpr)
                 (yyval.interm.intermTypedNode)->getTypePointer()->setQualifier(EvqConstExpr);
         } else if ((yyvsp[(1) - (4)].interm.intermTypedNode)->isMatrix() && (yyvsp[(1) - (4)].interm.intermTypedNode)->getType().getQualifier() == EvqConstExpr)
-            (yyval.interm.intermTypedNode)->setType(TType((yyvsp[(1) - (4)].interm.intermTypedNode)->getBasicType(), (yyvsp[(1) - (4)].interm.intermTypedNode)->getPrecision(), EvqConstExpr, (yyvsp[(1) - (4)].interm.intermTypedNode)->getNominalSize()));
+            (yyval.interm.intermTypedNode)->setType(TType((yyvsp[(1) - (4)].interm.intermTypedNode)->getBasicType(), (yyvsp[(1) - (4)].interm.intermTypedNode)->getPrecision(), EvqConstExpr, (yyvsp[(1) - (4)].interm.intermTypedNode)->getSecondarySize()));
         else if ((yyvsp[(1) - (4)].interm.intermTypedNode)->isMatrix())
-            (yyval.interm.intermTypedNode)->setType(TType((yyvsp[(1) - (4)].interm.intermTypedNode)->getBasicType(), (yyvsp[(1) - (4)].interm.intermTypedNode)->getPrecision(), EvqTemporary, (yyvsp[(1) - (4)].interm.intermTypedNode)->getNominalSize()));
+            (yyval.interm.intermTypedNode)->setType(TType((yyvsp[(1) - (4)].interm.intermTypedNode)->getBasicType(), (yyvsp[(1) - (4)].interm.intermTypedNode)->getPrecision(), EvqTemporary, (yyvsp[(1) - (4)].interm.intermTypedNode)->getSecondarySize()));
         else if ((yyvsp[(1) - (4)].interm.intermTypedNode)->isVector() && (yyvsp[(1) - (4)].interm.intermTypedNode)->getType().getQualifier() == EvqConstExpr)
             (yyval.interm.intermTypedNode)->setType(TType((yyvsp[(1) - (4)].interm.intermTypedNode)->getBasicType(), (yyvsp[(1) - (4)].interm.intermTypedNode)->getPrecision(), EvqConstExpr));
         else if ((yyvsp[(1) - (4)].interm.intermTypedNode)->isVector())
@@ -2516,7 +2516,7 @@
             }
         } else if ((yyvsp[(1) - (3)].interm.intermTypedNode)->isMatrix()) {
             TMatrixFields fields;
-            if (! context->parseMatrixFields(*(yyvsp[(3) - (3)].lex).string, (yyvsp[(1) - (3)].interm.intermTypedNode)->getNominalSize(), fields, (yyvsp[(3) - (3)].lex).line)) {
+            if (! context->parseMatrixFields(*(yyvsp[(3) - (3)].lex).string, (yyvsp[(1) - (3)].interm.intermTypedNode)->getNominalSize(), (yyvsp[(1) - (3)].interm.intermTypedNode)->getSecondarySize(), fields, (yyvsp[(3) - (3)].lex).line)) {
                 fields.wholeRow = false;
                 fields.wholeCol = false;
                 fields.row = 0;