Change infrastructure to support constant folding across built-in functions, as requi...
authorJohn Kessenich <cepheus@frii.com>
Thu, 7 Mar 2013 19:22:07 +0000 (19:22 +0000)
committerJohn Kessenich <cepheus@frii.com>
Thu, 7 Mar 2013 19:22:07 +0000 (19:22 +0000)
git-svn-id: https://cvs.khronos.org/svn/repos/ogl/trunk/ecosystem/public/sdk/tools/glslang@20806 e7fa87d3-cd2b-0410-9028-fcbf551c1848

14 files changed:
Test/constErrors.frag
Test/constFold.frag [new file with mode: 0644]
Test/testlist
glslang.vcxproj
glslang.vcxproj.filters
glslang/Include/intermediate.h
glslang/MachineIndependent/Constant.cpp [new file with mode: 0644]
glslang/MachineIndependent/Intermediate.cpp
glslang/MachineIndependent/ParseHelper.cpp
glslang/MachineIndependent/ParseHelper.h
glslang/MachineIndependent/glslang.l
glslang/MachineIndependent/glslang.y
glslang/MachineIndependent/localintermediate.h
glslang/MachineIndependent/parseConst.cpp

index 282e5b9..f031d47 100644 (file)
@@ -18,5 +18,5 @@ void main()
     vec4 e[constInt + uniformInt]; // error
     vec4 f[uniformInt + constInt]; // error
 
-    vec4 g[sin(3.2)];              // okay
+    vec4 g[int(sin(0.3)) + 1];     // okay
 }
diff --git a/Test/constFold.frag b/Test/constFold.frag
new file mode 100644 (file)
index 0000000..3b0d360
--- /dev/null
@@ -0,0 +1,29 @@
+#version 430
+
+const int a = 1;
+const int b = 2;
+const int c = a + b; // 3
+const int d = c - a; // 2
+const float e = float(d); // 2.0
+const float f = e * float(c); // 6.0
+const float g = f / float(d); // 3.0
+
+in vec4 inv;
+out vec4 FragColor;
+
+void main()
+{
+    vec4 dx = dFdx(inv);
+    const ivec4 v = ivec4(a, b, c, d);
+    vec4 array2[v.y];                  // 2
+    const ivec4 u = ~v;
+
+    const float h = degrees(g); // 171.88
+
+    FragColor = vec4(e, f, g, h);  // 2, 6, 3, 171.88
+
+    vec4 array3[c];               // 3
+    vec4 arrayMax[int(max(float(array2.length()), float(array3.length())))];
+    vec4 arrayMin[int(min(float(array2.length()), float(array3.length())))];
+    FragColor = vec4(arrayMax.length(), arrayMin.length(), sin(3.14), cos(3.14));  // 3, 2, .00159, -.999
+}
index 723314c..6a65cae 100644 (file)
@@ -26,4 +26,5 @@ comment.frag
 330.frag
 330comp.frag
 constErrors.frag
+constFold.frag
 errors.frag
index 60d5aa6..85b12fb 100644 (file)
@@ -147,6 +147,7 @@ xcopy /y $(IntDir)$(TargetName)$(TargetExt) Test</Command>
     </ResourceCompile>\r
   </ItemDefinitionGroup>\r
   <ItemGroup>\r
+    <ClCompile Include="glslang\MachineIndependent\Constant.cpp" />\r
     <ClCompile Include="glslang\MachineIndependent\gen_glslang.cpp" />\r
     <ClCompile Include="glslang\MachineIndependent\glslang_tab.cpp" />\r
     <ClCompile Include="glslang\MachineIndependent\InfoSink.cpp" />\r
index 6561d78..4083edc 100644 (file)
     <ClCompile Include="glslang\MachineIndependent\Versions.cpp">\r
       <Filter>Machine Independent</Filter>\r
     </ClCompile>\r
+    <ClCompile Include="glslang\MachineIndependent\Constant.cpp">\r
+      <Filter>Machine Independent</Filter>\r
+    </ClCompile>\r
   </ItemGroup>\r
   <ItemGroup>\r
     <ClInclude Include="glslang\MachineIndependent\Initialize.h">\r
index f15a26c..d0a4d67 100644 (file)
@@ -406,6 +406,7 @@ public:
     virtual TIntermConstantUnion* getAsConstantUnion()  { return this; }
     virtual void traverse(TIntermTraverser* );
     virtual TIntermTyped* fold(TOperator, TIntermTyped*, TInfoSink&);
+    virtual TIntermTyped* fold(TOperator, const TType&, TInfoSink&);
 protected:
     constUnion *unionArrayPointer;
 };
diff --git a/glslang/MachineIndependent/Constant.cpp b/glslang/MachineIndependent/Constant.cpp
new file mode 100644 (file)
index 0000000..df09a1c
--- /dev/null
@@ -0,0 +1,608 @@
+//
+//Copyright (C) 2002-2005  3Dlabs Inc. Ltd.
+//Copyright (C) 2012-2013 LunarG, Inc.
+//
+//All rights reserved.
+//
+//Redistribution and use in source and binary forms, with or without
+//modification, are permitted provided that the following conditions
+//are met:
+//
+//    Redistributions of source code must retain the above copyright
+//    notice, this list of conditions and the following disclaimer.
+//
+//    Redistributions in binary form must reproduce the above
+//    copyright notice, this list of conditions and the following
+//    disclaimer in the documentation and/or other materials provided
+//    with the distribution.
+//
+//    Neither the name of 3Dlabs Inc. Ltd. nor the names of its
+//    contributors may be used to endorse or promote products derived
+//    from this software without specific prior written permission.
+//
+//THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+//"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+//LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+//FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+//COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+//BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+//LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+//CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+//LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+//ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+//POSSIBILITY OF SUCH DAMAGE.
+//
+
+#include "localintermediate.h"
+
+namespace {
+
+// Some helper functions
+
+const double pi = 3.1415926535897932384626433832795;
+
+bool CompareStruct(const TType& leftNodeType, constUnion* rightUnionArray, constUnion* leftUnionArray);
+
+bool CompareStructure(const TType& leftNodeType, constUnion* rightUnionArray, constUnion* leftUnionArray)
+{
+    if (leftNodeType.isArray()) {
+        TType typeWithoutArrayness = leftNodeType;
+        typeWithoutArrayness.dereference();
+
+        int arraySize = leftNodeType.getArraySize();
+
+        for (int i = 0; i < arraySize; ++i) {
+            int offset = typeWithoutArrayness.getObjectSize() * i;
+            if (! CompareStruct(typeWithoutArrayness, &rightUnionArray[offset], &leftUnionArray[offset]))
+                return false;
+        }
+    } else
+        return CompareStruct(leftNodeType, rightUnionArray, leftUnionArray);
+
+    return true;
+}
+
+bool CompareStruct(const TType& leftNodeType, constUnion* rightUnionArray, constUnion* leftUnionArray)
+{
+    TTypeList* fields = leftNodeType.getStruct();
+
+    size_t structSize = fields->size();
+    int index = 0;
+
+    for (size_t j = 0; j < structSize; j++) {
+        int size = (*fields)[j].type->getObjectSize();
+        for (int i = 0; i < size; i++) {
+            if ((*fields)[j].type->getBasicType() == EbtStruct) {
+                if (!CompareStructure(*(*fields)[j].type, &rightUnionArray[index], &leftUnionArray[index]))
+                    return false;
+            } else {
+                if (leftUnionArray[index] != rightUnionArray[index])
+                    return false;
+                index++;
+            }
+
+        }
+    }
+    return true;
+}
+
+}; // end anonymous namespace
+
+//
+// The fold functions see if an operation on a constant can be done in place,
+// without generating run-time code.
+//
+// Returns the node to keep using, which may or may not be the node passed in.
+//
+// Note: As of version 1.2, all constant operations must be folded.  It is
+// not opportunistic, but rather a semantic requirement.
+//
+
+//
+// Do folding between a pair of nodes
+//
+TIntermTyped* TIntermConstantUnion::fold(TOperator op, TIntermTyped* constantNode, TInfoSink& infoSink)
+{
+    constUnion *unionArray = getUnionArrayPointer();
+    int objectSize = getType().getObjectSize();
+    constUnion* newConstArray = 0;
+
+    // For most cases, the return type matches the argument type, so set that
+    // up and just code to exceptions below.
+    TType returnType = getType();
+
+    //
+    // A pair of nodes is to be folded together
+    //
+
+    TIntermConstantUnion *node = constantNode->getAsConstantUnion();
+    constUnion *rightUnionArray = node->getUnionArrayPointer();
+
+    if (getType().getBasicType() != node->getBasicType()) {
+        infoSink.info.message(EPrefixInternalError, "Constant folding basic types don't match", getLine());
+        return 0;
+    }
+
+    if (constantNode->getType().getObjectSize() == 1 && objectSize > 1) {
+        // for a case like float f = vec4(2,3,4,5) + 1.2;
+        rightUnionArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; ++i)
+            rightUnionArray[i] = *node->getUnionArrayPointer();
+    } else if (constantNode->getType().getObjectSize() > 1 && objectSize == 1) {
+        // for a case like float f = 1.2 + vec4(2,3,4,5);
+        rightUnionArray = node->getUnionArrayPointer();
+        unionArray = new constUnion[constantNode->getType().getObjectSize()];
+        for (int i = 0; i < constantNode->getType().getObjectSize(); ++i)
+            unionArray[i] = *getUnionArrayPointer();
+        returnType = node->getType();
+        objectSize = constantNode->getType().getObjectSize();
+    }
+
+    int index = 0;
+    bool boolNodeFlag = false;
+    switch(op) {
+    case EOpAdd:
+        newConstArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; i++)
+            newConstArray[i] = unionArray[i] + rightUnionArray[i];
+        break;
+    case EOpSub:
+        newConstArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; i++)
+            newConstArray[i] = unionArray[i] - rightUnionArray[i];
+        break;
+
+    case EOpMul:
+    case EOpVectorTimesScalar:
+    case EOpMatrixTimesScalar:
+        newConstArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; i++)
+            newConstArray[i] = unionArray[i] * rightUnionArray[i];
+        break;
+    case EOpMatrixTimesMatrix:
+        newConstArray = new constUnion[getMatrixRows() * node->getMatrixCols()];
+        for (int row = 0; row < getMatrixRows(); row++) {
+            for (int column = 0; column < node->getMatrixCols(); column++) {
+                float sum = 0.0f;
+                for (int i = 0; i < node->getMatrixRows(); i++)
+                    sum += unionArray[i * getMatrixRows() + row].getFConst() * rightUnionArray[column * node->getMatrixRows() + i].getFConst();
+                newConstArray[column * getMatrixRows() + row].setFConst(sum);
+            }
+        }
+        returnType = TType(getType().getBasicType(), EvqConst, 0, getMatrixRows(), node->getMatrixCols());
+        break;
+    case EOpDiv:
+        newConstArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; i++) {
+            switch (getType().getBasicType()) {
+            case EbtFloat:
+                if (rightUnionArray[i] == 0.0f) {
+                    infoSink.info.message(EPrefixWarning, "Divide by zero error during constant folding", getLine());
+                    newConstArray[i].setFConst(FLT_MAX);
+                } else
+                    newConstArray[i].setFConst(unionArray[i].getFConst() / rightUnionArray[i].getFConst());
+            break;
+
+            case EbtInt:
+                if (rightUnionArray[i] == 0) {
+                    infoSink.info.message(EPrefixWarning, "Divide by zero error during constant folding", getLine());
+                    newConstArray[i].setIConst(0xEFFFFFFF);
+                } else
+                    newConstArray[i].setIConst(unionArray[i].getIConst() / rightUnionArray[i].getIConst());
+                break;
+            default:
+                infoSink.info.message(EPrefixInternalError, "Constant folding cannot be done for \"/\"", getLine());
+                return 0;
+            }
+        }
+        break;
+
+    case EOpMatrixTimesVector:
+        newConstArray = new constUnion[getMatrixRows()];
+        for (int i = 0; i < getMatrixRows(); i++) {
+            float sum = 0.0f;
+            for (int j = 0; j < node->getVectorSize(); j++) {
+                sum += unionArray[j*getMatrixRows() + i].getFConst() * rightUnionArray[j].getFConst();
+            }
+            newConstArray[i].setFConst(sum);
+        }
+
+        returnType = TType(getBasicType(), EvqConst, getMatrixRows());
+        break;
+
+    case EOpVectorTimesMatrix:
+        newConstArray = new constUnion[node->getMatrixCols()];
+        for (int i = 0; i < node->getMatrixCols(); i++) {
+            float sum = 0.0f;
+            for (int j = 0; j < getVectorSize(); j++)
+                sum += unionArray[j].getFConst() * rightUnionArray[i*node->getMatrixRows() + j].getFConst();
+            newConstArray[i].setFConst(sum);
+        }
+
+        returnType = TType(getBasicType(), EvqConst, node->getMatrixCols());
+        break;
+
+    case EOpMod:
+        newConstArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; i++)
+            newConstArray[i] = unionArray[i] % rightUnionArray[i];
+        break;
+
+    case EOpRightShift:
+        newConstArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; i++)
+            newConstArray[i] = unionArray[i] >> rightUnionArray[i];
+        break;
+
+    case EOpLeftShift:
+        newConstArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; i++)
+            newConstArray[i] = unionArray[i] << rightUnionArray[i];
+        break;
+
+    case EOpAnd:
+        newConstArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; i++)
+            newConstArray[i] = unionArray[i] & rightUnionArray[i];
+        break;
+    case EOpInclusiveOr:
+        newConstArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; i++)
+            newConstArray[i] = unionArray[i] | rightUnionArray[i];
+        break;
+    case EOpExclusiveOr:
+        newConstArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; i++)
+            newConstArray[i] = unionArray[i] ^ rightUnionArray[i];
+        break;
+
+    case EOpLogicalAnd: // this code is written for possible future use, will not get executed currently
+        newConstArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; i++)
+            newConstArray[i] = unionArray[i] && rightUnionArray[i];
+        break;
+
+    case EOpLogicalOr: // this code is written for possible future use, will not get executed currently
+        newConstArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; i++)
+            newConstArray[i] = unionArray[i] || rightUnionArray[i];
+        break;
+
+    case EOpLogicalXor:
+        newConstArray = new constUnion[objectSize];
+        for (int i = 0; i < objectSize; i++) {
+            switch (getType().getBasicType()) {
+            case EbtBool: newConstArray[i].setBConst((unionArray[i] == rightUnionArray[i]) ? false : true); break;
+            default: assert(false && "Default missing");
+            }
+        }
+        break;
+
+    case EOpLessThan:
+        assert(objectSize == 1);
+        newConstArray = new constUnion[1];
+        newConstArray->setBConst(*unionArray < *rightUnionArray);
+        returnType = TType(EbtBool, EvqConst);
+        break;
+    case EOpGreaterThan:
+        assert(objectSize == 1);
+        newConstArray = new constUnion[1];
+        newConstArray->setBConst(*unionArray > *rightUnionArray);
+        returnType = TType(EbtBool, EvqConst);
+        break;
+    case EOpLessThanEqual:
+    {
+        assert(objectSize == 1);
+        constUnion constant;
+        constant.setBConst(*unionArray > *rightUnionArray);
+        newConstArray = new constUnion[1];
+        newConstArray->setBConst(!constant.getBConst());
+        returnType = TType(EbtBool, EvqConst);
+        break;
+    }
+    case EOpGreaterThanEqual:
+    {
+        assert(objectSize == 1);
+        constUnion constant;
+        constant.setBConst(*unionArray < *rightUnionArray);
+        newConstArray = new constUnion[1];
+        newConstArray->setBConst(!constant.getBConst());
+        returnType = TType(EbtBool, EvqConst);
+        break;
+    }
+
+    case EOpEqual:
+        if (getType().getBasicType() == EbtStruct) {
+            if (! CompareStructure(node->getType(), node->getUnionArrayPointer(), unionArray))
+                boolNodeFlag = true;
+        } else {
+            for (int i = 0; i < objectSize; i++) {
+                if (unionArray[i] != rightUnionArray[i]) {
+                    boolNodeFlag = true;
+                    break;  // break out of for loop
+                }
+            }
+        }
+
+        newConstArray = new constUnion[1];
+        newConstArray->setBConst(! boolNodeFlag);
+        returnType = TType(EbtBool, EvqConst);
+        break;
+
+    case EOpNotEqual:
+        if (getType().getBasicType() == EbtStruct) {
+            if (CompareStructure(node->getType(), node->getUnionArrayPointer(), unionArray))
+                boolNodeFlag = true;
+        } else {
+            for (int i = 0; i < objectSize; i++) {
+                if (unionArray[i] == rightUnionArray[i]) {
+                    boolNodeFlag = true;
+                    break;  // break out of for loop
+                }
+            }
+        }
+
+        newConstArray = new constUnion[1];
+        newConstArray->setBConst(! boolNodeFlag);
+        returnType = TType(EbtBool, EvqConst);
+        break;
+
+    default:
+        infoSink.info.message(EPrefixInternalError, "Invalid operator for constant folding", getLine());
+
+        return 0;
+    }
+
+    TIntermConstantUnion *newNode = new TIntermConstantUnion(newConstArray, returnType);
+    newNode->setLine(getLine());
+
+    return newNode;
+}
+
+//
+// Do single unary node folding
+//
+TIntermTyped* TIntermConstantUnion::fold(TOperator op, const TType& returnType, TInfoSink& infoSink)
+{
+    constUnion *unionArray = getUnionArrayPointer();
+    int objectSize = getType().getObjectSize();
+
+    // First, size the result, which is mostly the same as the argument's size,
+    // but not always.
+    constUnion* newConstArray;
+    switch (op) {
+    // TODO: functionality: constant folding: finish listing exceptions to size here
+    case EOpDeterminant:
+    case EOpAny:
+    case EOpAll:
+        newConstArray = new constUnion[1];
+        break;
+    default:
+        newConstArray = new constUnion[objectSize];
+    }
+
+    // TODO: Functionality: constant folding: separate component-wise from non-component-wise
+    for (int i = 0; i < objectSize; i++) {
+        switch (op) {
+        case EOpNegative:
+            switch (getType().getBasicType()) {
+            case EbtFloat: newConstArray[i].setFConst(-unionArray[i].getFConst()); break;
+            case EbtInt:   newConstArray[i].setIConst(-unionArray[i].getIConst()); break;
+            default:
+                infoSink.info.message(EPrefixInternalError, "Unary operation not folded into constant", getLine());
+                return 0;
+            }
+            break;
+        case EOpLogicalNot:
+        case EOpVectorLogicalNot:
+            switch (getType().getBasicType()) {
+            case EbtBool:  newConstArray[i].setBConst(!unionArray[i].getBConst()); break;
+            default:
+                infoSink.info.message(EPrefixInternalError, "Unary operation not folded into constant", getLine());
+                return 0;
+            }
+            break;
+        case EOpBitwiseNot:
+            newConstArray[i] = ~unionArray[i];
+            break;
+        case EOpRadians:
+            newConstArray[i].setFConst(static_cast<float>(unionArray[i].getFConst() * pi / 180.0));
+            break;
+        case EOpDegrees:
+            newConstArray[i].setFConst(static_cast<float>(unionArray[i].getFConst() * 180.0 / pi));
+            break;
+        case EOpSin:
+            newConstArray[i].setFConst(sin(unionArray[i].getFConst()));
+            break;
+        case EOpCos:
+            newConstArray[i].setFConst(cos(unionArray[i].getFConst()));
+            break;
+        case EOpTan:
+            newConstArray[i].setFConst(tan(unionArray[i].getFConst()));
+            break;
+        case EOpAsin:
+            newConstArray[i].setFConst(asin(unionArray[i].getFConst()));
+            break;
+        case EOpAcos:
+            newConstArray[i].setFConst(acos(unionArray[i].getFConst()));
+            break;
+        case EOpAtan:
+            newConstArray[i].setFConst(atan(unionArray[i].getFConst()));
+            break;
+
+        // TODO: Functionality: constant folding: the rest of the ops have to be fleshed out
+
+        case EOpExp:
+        case EOpLog:
+        case EOpExp2:
+        case EOpLog2:
+        case EOpSqrt:
+        case EOpInverseSqrt:
+
+        case EOpAbs:
+        case EOpSign:
+        case EOpFloor:
+        case EOpCeil:
+        case EOpFract:
+
+        case EOpLength:
+
+        case EOpDPdx:
+        case EOpDPdy:
+        case EOpFwidth:
+            // The derivatives are all mandated to create a constant 0.
+
+        case EOpDeterminant:
+        case EOpMatrixInverse:
+        case EOpTranspose:
+
+        case EOpAny:
+        case EOpAll:
+
+        default:
+            infoSink.info.message(EPrefixInternalError, "Invalid operator for constant folding", getLine());
+            return 0;
+        }
+    }
+
+    TIntermConstantUnion *newNode = new TIntermConstantUnion(newConstArray, returnType);
+    newNode->getTypePointer()->getQualifier().storage = EvqConst;
+    newNode->setLine(getLine());
+
+    return newNode;
+}
+
+//
+// Do constant folding for an aggregate node that has all its children
+// as constants and an operator that requires constant folding.
+// 
+TIntermTyped* TIntermediate::fold(TIntermAggregate* aggrNode)
+{
+    if (! areAllChildConst(aggrNode))
+        return aggrNode;
+
+    if (aggrNode->isConstructor())
+        return foldConstructor(aggrNode);
+
+    TIntermSequence& children = aggrNode->getSequence();
+
+    // First, see if this is an operation to constant fold, kick out if not,
+    // see what size the result is if so.
+    int objectSize;
+    switch (aggrNode->getOp()) {
+    case EOpMin:
+    case EOpMax:
+    case EOpReflect:
+    case EOpRefract:
+    case EOpFaceForward:
+    case EOpAtan:
+    case EOpPow:
+    case EOpClamp:
+    case EOpMix:
+    case EOpDistance:
+    case EOpCross:
+    case EOpNormalize:
+        objectSize = children[0]->getAsConstantUnion()->getType().getObjectSize();
+        break;
+    case EOpDot:
+        objectSize = 1;
+        break;
+    case EOpOuterProduct:
+        objectSize = children[0]->getAsTyped()->getType().getVectorSize() *
+                     children[1]->getAsTyped()->getType().getVectorSize();
+        break;
+    case EOpStep:
+        objectSize = std::max(children[0]->getAsTyped()->getType().getVectorSize(),
+                              children[1]->getAsTyped()->getType().getVectorSize());
+        break;
+    case EOpSmoothStep:
+        objectSize = std::max(children[0]->getAsTyped()->getType().getVectorSize(),
+                              children[2]->getAsTyped()->getType().getVectorSize());
+        break;
+    default:
+        return aggrNode;
+    }
+    constUnion* newConstArray = new constUnion[objectSize];
+
+    TVector<constUnion*> childConstUnions;
+    for (unsigned int i = 0; i < children.size(); ++i)
+        childConstUnions.push_back(children[i]->getAsConstantUnion()->getUnionArrayPointer());
+
+    // Second, do the actual folding
+
+    // TODO: Functionality: constant folding: separate component-wise from non-component-wise
+    switch (aggrNode->getOp()) {
+    case EOpMin:
+    case EOpMax:
+        for (int i = 0; i < objectSize; i++) {
+            if (aggrNode->getOp() == EOpMax)
+                newConstArray[i].setFConst(std::max(childConstUnions[0]->getFConst(), childConstUnions[1]->getFConst()));
+            else
+                newConstArray[i].setFConst(std::min(childConstUnions[0]->getFConst(), childConstUnions[1]->getFConst()));
+        }
+        break;
+
+    // TODO: Functionality: constant folding: the rest of the ops have to be fleshed out
+
+    case EOpAtan:
+    case EOpPow:
+    case EOpClamp:
+    case EOpMix:
+    case EOpStep:
+    case EOpSmoothStep:
+    case EOpDistance:
+    case EOpDot:
+    case EOpCross:
+    case EOpNormalize:
+    case EOpFaceForward:
+    case EOpReflect:
+    case EOpRefract:
+    case EOpOuterProduct:
+        infoSink.info.message(EPrefixInternalError, "constant folding operation not implemented", aggrNode->getLine());
+        return aggrNode;
+
+    default:
+        return aggrNode;
+    }
+
+    TIntermConstantUnion *newNode = new TIntermConstantUnion(newConstArray, aggrNode->getType());
+    newNode->getTypePointer()->getQualifier().storage = EvqConst;
+    newNode->setLine(aggrNode->getLine());
+
+    return newNode;
+}
+
+bool TIntermediate::areAllChildConst(TIntermAggregate* aggrNode)
+{
+    bool allConstant = true;
+
+    // check if all the child nodes are constants so that they can be inserted into
+    // the parent node
+    if (aggrNode) {
+        TIntermSequence& childSequenceVector = aggrNode->getSequence();
+        for (TIntermSequence::iterator p  = childSequenceVector.begin();
+                                       p != childSequenceVector.end(); p++) {
+            if (!(*p)->getAsTyped()->getAsConstantUnion())
+                return false;
+        }
+    }
+
+    return allConstant;
+}
+
+TIntermTyped* TIntermediate::foldConstructor(TIntermAggregate* aggrNode)
+{
+    bool returnVal = false;
+
+    constUnion* unionArray = new constUnion[aggrNode->getType().getObjectSize()];
+    if (aggrNode->getSequence().size() == 1)
+        returnVal = parseConstTree(aggrNode->getLine(), aggrNode, unionArray, aggrNode->getOp(), aggrNode->getType(), true);
+    else
+        returnVal = parseConstTree(aggrNode->getLine(), aggrNode, unionArray, aggrNode->getOp(), aggrNode->getType());
+
+    if (returnVal)
+        return aggrNode;
+
+    return addConstantUnion(unionArray, aggrNode->getType(), aggrNode->getLine());
+}
index eac2701..7b50ee1 100644 (file)
@@ -43,8 +43,6 @@
 #include "RemoveTree.h"
 #include <float.h>
 
-bool CompareStructure(const TType& leftNodeType, constUnion* rightUnionArray, constUnion* leftUnionArray);
-
 ////////////////////////////////////////////////////////////////////////////
 //
 // First set of functions are to help build the intermediate representation.
@@ -221,7 +219,7 @@ TIntermTyped* TIntermediate::addUnaryMath(TOperator op, TIntermNode* childNode,
         if (child->getType().getBasicType() == EbtStruct || child->getType().isArray())
             return 0;
     }
-    
+
     //
     // Do we need to promote the operand?
     //
@@ -270,7 +268,7 @@ TIntermTyped* TIntermediate::addUnaryMath(TOperator op, TIntermNode* childNode,
         return 0;
 
     if (childTempConstant)  {
-        TIntermTyped* newChild = childTempConstant->fold(op, 0, infoSink);
+        TIntermTyped* newChild = childTempConstant->fold(op, node->getType(), infoSink);
         
         if (newChild)
             return newChild;
@@ -289,7 +287,7 @@ TIntermTyped* TIntermediate::addUnaryMath(TOperator op, TIntermNode* childNode,
 // Returns an aggregate node, which could be the one passed in if
 // it was already an aggregate.
 //
-TIntermAggregate* TIntermediate::setAggregateOperator(TIntermNode* node, TOperator op, TSourceLoc line)
+TIntermTyped* TIntermediate::setAggregateOperator(TIntermNode* node, TOperator op, const TType& type, TSourceLoc line)
 {
     TIntermAggregate* aggNode;
 
@@ -317,7 +315,9 @@ TIntermAggregate* TIntermediate::setAggregateOperator(TIntermNode* node, TOperat
     if (line != 0)
         aggNode->setLine(line);
 
-    return aggNode;
+    aggNode->setType(type);
+
+    return fold(aggNode);
 }
 
 //
@@ -431,7 +431,7 @@ TIntermTyped* TIntermediate::addConversion(TOperator op, const TType& type, TInt
     
     if (node->getAsConstantUnion()) {
 
-        return (promoteConstantUnion(promoteTo, node->getAsConstantUnion()));
+        return promoteConstantUnion(promoteTo, node->getAsConstantUnion());
     } else {    
         //
         // Add a new newNode for the conversion.
@@ -822,6 +822,7 @@ bool TIntermOperator::isConstructor() const
 {
     return op > EOpConstructGuardStart && op < EOpConstructGuardEnd;
 }
+
 //
 // Make sure the type of a unary operator is appropriate for its 
 // combination of operation and operand type.
@@ -833,10 +834,13 @@ bool TIntermUnary::promote(TInfoSink&)
     switch (op) {
     case EOpLogicalNot:
         if (operand->getBasicType() != EbtBool)
+
             return false;
         break;
     case EOpBitwiseNot:
-        if (operand->getBasicType() != EbtInt)
+        if (operand->getBasicType() != EbtInt &&
+            operand->getBasicType() != EbtUint)
+
             return false;
         break;
     case EOpNegative:
@@ -844,22 +848,53 @@ bool TIntermUnary::promote(TInfoSink&)
     case EOpPostDecrement:
     case EOpPreIncrement:
     case EOpPreDecrement:
-        if (operand->getBasicType() == EbtBool)
+        if (operand->getBasicType() != EbtInt && 
+            operand->getBasicType() != EbtUint && 
+            operand->getBasicType() != EbtFloat)
+
             return false;
         break;
 
-    // operators for built-ins are already type checked against their prototype
+    //
+    // Operators for built-ins are already type checked against their prototype.
+    // Special case the non-float ones, just so we don't give an error.
+    //
+
     case EOpAny:
     case EOpAll:
+        setType(TType(EbtBool));
+
+        return true;
+
     case EOpVectorLogicalNot:
+        break;
+
+    case EOpLength:
+        setType(TType(EbtFloat, EvqTemporary, operand->getQualifier().precision));
+
+        return true;
+
+    case EOpTranspose:
+        setType(TType(operand->getType().getBasicType(), EvqTemporary, operand->getQualifier().precision, 0, 
+                                                                       operand->getType().getMatrixRows(),
+                                                                       operand->getType().getMatrixCols()));
+        return true;
+        
+    case EOpDeterminant:
+        setType(TType(operand->getType().getBasicType(), EvqTemporary, operand->getQualifier().precision));
+
         return true;
 
     default:
+        // TODO: functionality: uint/int versions of built-ins
+        //       make sure all paths set the type
         if (operand->getBasicType() != EbtFloat)
+
             return false;
     }
-    
+
     setType(operand->getType());
+    getTypePointer()->getQualifier().storage = EvqTemporary;
 
     return true;
 }
@@ -1125,30 +1160,6 @@ bool TIntermBinary::promote(TInfoSink& infoSink)
     return true;
 }
 
-bool CompareStruct(const TType& leftNodeType, constUnion* rightUnionArray, constUnion* leftUnionArray)
-{
-    TTypeList* fields = leftNodeType.getStruct();
-
-    size_t structSize = fields->size();
-    int index = 0;
-
-    for (size_t j = 0; j < structSize; j++) {
-        int size = (*fields)[j].type->getObjectSize();
-        for (int i = 0; i < size; i++) {
-            if ((*fields)[j].type->getBasicType() == EbtStruct) {
-                if (!CompareStructure(*(*fields)[j].type, &rightUnionArray[index], &leftUnionArray[index]))
-                    return false;
-            } else {
-                if (leftUnionArray[index] != rightUnionArray[index])
-                    return false;
-                index++;
-            }    
-            
-        }
-    }
-    return true;
-}
-
 void TIntermTyped::propagatePrecision(TPrecisionQualifier newPrecision)
 {
     if (getQualifier().precision != EpqNone || (getBasicType() != EbtInt && getBasicType() != EbtFloat))
@@ -1196,350 +1207,6 @@ void TIntermTyped::propagatePrecision(TPrecisionQualifier newPrecision)
     //    indexing?
 }
 
-bool CompareStructure(const TType& leftNodeType, constUnion* rightUnionArray, constUnion* leftUnionArray)
-{
-    if (leftNodeType.isArray()) {
-        TType typeWithoutArrayness = leftNodeType;
-        typeWithoutArrayness.dereference();
-
-        int arraySize = leftNodeType.getArraySize();
-
-        for (int i = 0; i < arraySize; ++i) {
-            int offset = typeWithoutArrayness.getObjectSize() * i;
-            if (!CompareStruct(typeWithoutArrayness, &rightUnionArray[offset], &leftUnionArray[offset]))
-                return false;
-        }
-    } else
-        return CompareStruct(leftNodeType, rightUnionArray, leftUnionArray);    
-    
-    return true;
-} 
-
-//
-// The fold functions see if an operation on a constant can be done in place,
-// without generating run-time code.
-//
-// Returns the node to keep using, which may or may not be the node passed in.
-//
-
-TIntermTyped* TIntermConstantUnion::fold(TOperator op, TIntermTyped* constantNode, TInfoSink& infoSink)
-{   
-    constUnion *unionArray = getUnionArrayPointer();
-    int objectSize = getType().getObjectSize();
-
-    if (constantNode) {  // binary operations
-        TIntermConstantUnion *node = constantNode->getAsConstantUnion();
-        constUnion *rightUnionArray = node->getUnionArrayPointer();
-        TType returnType = getType();
-
-        if (getType().getBasicType() != node->getBasicType()) {
-            infoSink.info.message(EPrefixInternalError, "Constant folding basic types don't match", getLine());
-            return 0;
-        }
-
-        if (constantNode->getType().getObjectSize() == 1 && objectSize > 1) {
-            // for a case like float f = vec4(2,3,4,5) + 1.2;
-            rightUnionArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; ++i)
-                rightUnionArray[i] = *node->getUnionArrayPointer();
-        } else if (constantNode->getType().getObjectSize() > 1 && objectSize == 1) {
-            // for a case like float f = 1.2 + vec4(2,3,4,5);
-            rightUnionArray = node->getUnionArrayPointer();
-            unionArray = new constUnion[constantNode->getType().getObjectSize()];
-            for (int i = 0; i < constantNode->getType().getObjectSize(); ++i)
-                unionArray[i] = *getUnionArrayPointer();
-            returnType = node->getType();
-            objectSize = constantNode->getType().getObjectSize();
-        }
-        
-        constUnion* tempConstArray = 0;
-        TIntermConstantUnion *tempNode;
-        int index = 0;
-        bool boolNodeFlag = false;
-        switch(op) {
-        case EOpAdd: 
-            tempConstArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; i++)
-                tempConstArray[i] = unionArray[i] + rightUnionArray[i];
-            break;
-        case EOpSub: 
-            tempConstArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; i++)
-                tempConstArray[i] = unionArray[i] - rightUnionArray[i];
-            break;
-
-        case EOpMul:
-        case EOpVectorTimesScalar:
-        case EOpMatrixTimesScalar: 
-            tempConstArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; i++)
-                tempConstArray[i] = unionArray[i] * rightUnionArray[i];
-            break;
-        case EOpMatrixTimesMatrix:
-            tempConstArray = new constUnion[getMatrixRows() * node->getMatrixCols()];
-            for (int row = 0; row < getMatrixRows(); row++) {
-                for (int column = 0; column < node->getMatrixCols(); column++) {
-                    float sum = 0.0f;                    
-                    for (int i = 0; i < node->getMatrixRows(); i++)
-                        sum += unionArray[i * getMatrixRows() + row].getFConst() * rightUnionArray[column * node->getMatrixRows() + i].getFConst();
-                    tempConstArray[column * getMatrixRows() + row].setFConst(sum);
-                }
-            }
-            returnType = TType(getType().getBasicType(), EvqConst, 0, getMatrixRows(), node->getMatrixCols());
-            break;
-        case EOpOuterProduct:
-            // TODO: functionality >= 120
-            break;
-        case EOpDeterminant:
-            // TODO: functionality >= 150
-            break;
-        case EOpMatrixInverse:
-            // TODO: functionality >= 150
-            break;
-        case EOpTranspose:
-            // TODO: functionality >= 120
-            break;
-        case EOpDiv: 
-            tempConstArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; i++) {
-                switch (getType().getBasicType()) {
-                case EbtFloat: 
-                    if (rightUnionArray[i] == 0.0f) {
-                        infoSink.info.message(EPrefixWarning, "Divide by zero error during constant folding", getLine());
-                        tempConstArray[i].setFConst(FLT_MAX);
-                    } else
-                        tempConstArray[i].setFConst(unionArray[i].getFConst() / rightUnionArray[i].getFConst());
-                break;
-
-                case EbtInt:   
-                    if (rightUnionArray[i] == 0) {
-                        infoSink.info.message(EPrefixWarning, "Divide by zero error during constant folding", getLine());
-                        tempConstArray[i].setIConst(0xEFFFFFFF);
-                    } else
-                        tempConstArray[i].setIConst(unionArray[i].getIConst() / rightUnionArray[i].getIConst());
-                    break;            
-                default: 
-                    infoSink.info.message(EPrefixInternalError, "Constant folding cannot be done for \"/\"", getLine());
-                    return 0;
-                }
-            }
-            break;
-
-        case EOpMatrixTimesVector:
-            tempConstArray = new constUnion[getMatrixRows()];
-            for (int i = 0; i < getMatrixRows(); i++) {
-                float sum = 0.0f;
-                for (int j = 0; j < node->getVectorSize(); j++) {
-                    sum += unionArray[j*getMatrixRows() + i].getFConst() * rightUnionArray[j].getFConst();
-                }
-                tempConstArray[i].setFConst(sum);
-            }
-
-            tempNode = new TIntermConstantUnion(tempConstArray, TType(getBasicType(), EvqConst, getMatrixRows()));
-            tempNode->setLine(getLine());
-
-            return tempNode;                
-
-        case EOpVectorTimesMatrix:
-            tempConstArray = new constUnion[node->getMatrixCols()];
-            for (int i = 0; i < node->getMatrixCols(); i++) {
-                float sum = 0.0f;
-                for (int j = 0; j < getVectorSize(); j++)
-                    sum += unionArray[j].getFConst() * rightUnionArray[i*node->getMatrixRows() + j].getFConst();
-                tempConstArray[i].setFConst(sum);
-            }
-
-            tempNode = new TIntermConstantUnion(tempConstArray, TType(getBasicType(), EvqConst, node->getMatrixCols()));
-            tempNode->setLine(getLine());
-
-            return tempNode;                
-
-        case EOpMod:
-            tempConstArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; i++)
-                tempConstArray[i] = unionArray[i] % rightUnionArray[i];
-            break;
-    
-        case EOpRightShift:
-            tempConstArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; i++)
-                tempConstArray[i] = unionArray[i] >> rightUnionArray[i];
-            break;
-
-        case EOpLeftShift:
-            tempConstArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; i++)
-                tempConstArray[i] = unionArray[i] << rightUnionArray[i];
-            break;
-    
-        case EOpAnd:
-            tempConstArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; i++)
-                tempConstArray[i] = unionArray[i] & rightUnionArray[i];
-            break;
-        case EOpInclusiveOr:
-            tempConstArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; i++)
-                tempConstArray[i] = unionArray[i] | rightUnionArray[i];
-            break;
-        case EOpExclusiveOr:
-            tempConstArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; i++)
-                tempConstArray[i] = unionArray[i] ^ rightUnionArray[i];
-            break;
-
-        case EOpLogicalAnd: // this code is written for possible future use, will not get executed currently
-            tempConstArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; i++)
-                tempConstArray[i] = unionArray[i] && rightUnionArray[i];
-            break;
-
-        case EOpLogicalOr: // this code is written for possible future use, will not get executed currently
-            tempConstArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; i++)
-                tempConstArray[i] = unionArray[i] || rightUnionArray[i];
-            break;
-
-        case EOpLogicalXor:  
-            tempConstArray = new constUnion[objectSize];
-            for (int i = 0; i < objectSize; i++) {
-                switch (getType().getBasicType()) {
-                case EbtBool: tempConstArray[i].setBConst((unionArray[i] == rightUnionArray[i]) ? false : true); break;
-                default: assert(false && "Default missing");
-                }
-            }
-            break;
-
-        case EOpLessThan:         
-            assert(objectSize == 1);
-            tempConstArray = new constUnion[1];
-            tempConstArray->setBConst(*unionArray < *rightUnionArray);
-            returnType = TType(EbtBool, EvqConst);
-            break;
-        case EOpGreaterThan:      
-            assert(objectSize == 1);
-            tempConstArray = new constUnion[1];
-            tempConstArray->setBConst(*unionArray > *rightUnionArray);
-            returnType = TType(EbtBool, EvqConst);
-            break;
-        case EOpLessThanEqual:
-        {
-            assert(objectSize == 1);
-            constUnion constant;
-            constant.setBConst(*unionArray > *rightUnionArray);
-            tempConstArray = new constUnion[1];
-            tempConstArray->setBConst(!constant.getBConst());
-            returnType = TType(EbtBool, EvqConst);
-            break;
-        }
-        case EOpGreaterThanEqual: 
-        {
-            assert(objectSize == 1);
-            constUnion constant;
-            constant.setBConst(*unionArray < *rightUnionArray);
-            tempConstArray = new constUnion[1];
-            tempConstArray->setBConst(!constant.getBConst());
-            returnType = TType(EbtBool, EvqConst);
-            break;
-        }
-
-        case EOpEqual: 
-            if (getType().getBasicType() == EbtStruct) {
-                if (!CompareStructure(node->getType(), node->getUnionArrayPointer(), unionArray))
-                    boolNodeFlag = true;
-            } else {
-                for (int i = 0; i < objectSize; i++) {    
-                    if (unionArray[i] != rightUnionArray[i]) {
-                        boolNodeFlag = true;
-                        break;  // break out of for loop
-                    }
-                }
-            }
-
-            tempConstArray = new constUnion[1];
-            if (!boolNodeFlag) {
-                tempConstArray->setBConst(true);
-            }
-            else {
-                tempConstArray->setBConst(false);
-            }
-            
-            tempNode = new TIntermConstantUnion(tempConstArray, TType(EbtBool, EvqConst));
-            tempNode->setLine(getLine());
-
-            return tempNode;         
-
-        case EOpNotEqual: 
-            if (getType().getBasicType() == EbtStruct) {
-                if (CompareStructure(node->getType(), node->getUnionArrayPointer(), unionArray))
-                    boolNodeFlag = true;
-            } else {
-                for (int i = 0; i < objectSize; i++) {    
-                    if (unionArray[i] == rightUnionArray[i]) {
-                        boolNodeFlag = true;
-                        break;  // break out of for loop
-                    }
-                }
-            }
-
-            tempConstArray = new constUnion[1];
-            if (!boolNodeFlag) {
-                tempConstArray->setBConst(true);
-            }
-            else {
-                tempConstArray->setBConst(false);
-            }
-            
-            tempNode = new TIntermConstantUnion(tempConstArray, TType(EbtBool, EvqConst));
-            tempNode->setLine(getLine());
-
-            return tempNode;         
-        
-        default: 
-            infoSink.info.message(EPrefixInternalError, "Invalid operator for constant folding", getLine());
-            return 0;
-        }
-        tempNode = new TIntermConstantUnion(tempConstArray, returnType);
-        tempNode->setLine(getLine());
-
-        return tempNode;                
-    } else { 
-        //
-        // Do unary operations
-        //
-        TIntermConstantUnion *newNode = 0;
-        constUnion* tempConstArray = new constUnion[objectSize];
-        for (int i = 0; i < objectSize; i++) {
-            switch(op) {
-            case EOpNegative:                                       
-                switch (getType().getBasicType()) {
-                case EbtFloat: tempConstArray[i].setFConst(-unionArray[i].getFConst()); break;
-                case EbtInt:   tempConstArray[i].setIConst(-unionArray[i].getIConst()); break;
-                default: 
-                    infoSink.info.message(EPrefixInternalError, "Unary operation not folded into constant", getLine());
-                    return 0;
-                }
-                break;
-            case EOpLogicalNot: // this code is written for possible future use, will not get executed currently                                      
-                switch (getType().getBasicType()) {
-                case EbtBool:  tempConstArray[i].setBConst(!unionArray[i].getBConst()); break;
-                default: 
-                    infoSink.info.message(EPrefixInternalError, "Unary operation not folded into constant", getLine());
-                    return 0;
-                }
-                break;
-            default: 
-                return 0;
-            }
-        }
-        newNode = new TIntermConstantUnion(tempConstArray, getType());
-        newNode->setLine(getLine());
-        return newNode;     
-    }
-
-    return this;
-}
-
 TIntermTyped* TIntermediate::promoteConstantUnion(TBasicType promoteTo, TIntermConstantUnion* node) 
 {
     if (node->getType().isArray())
index f404b14..7b597e4 100644 (file)
@@ -1161,33 +1161,12 @@ bool TParseContext::executeInitializer(TSourceLoc line, TString& identifier, TPu
     return false;
 }
 
-bool TParseContext::areAllChildConst(TIntermAggregate* aggrNode)
-{
-    if (!aggrNode->isConstructor())
-        return false;
-
-    bool allConstant = true;
-
-    // check if all the child nodes are constants so that they can be inserted into 
-    // the parent node
-    if (aggrNode) {
-        TIntermSequence &childSequenceVector = aggrNode->getSequence() ;
-        for (TIntermSequence::iterator p = childSequenceVector.begin(); 
-                                    p != childSequenceVector.end(); p++) {
-            if (!(*p)->getAsTyped()->getAsConstantUnion())
-                return false;
-        }
-    }
-
-    return allConstant;
-}
-
 // This function is used to test for the correctness of the parameters passed to various constructor functions
 // and also convert them to the right datatype if it is allowed and required. 
 //
 // Returns 0 for an error or the constructed node (aggregate or typed) for no error.
 //
-TIntermTyped* TParseContext::addConstructor(TIntermNode* node, const TType* type, TOperator op, TFunction* fnCall, TSourceLoc line)
+TIntermTyped* TParseContext::addConstructor(TIntermNode* node, const TType& type, TOperator op, TFunction* fnCall, TSourceLoc line)
 {
     if (node == 0)
         return 0;
@@ -1196,10 +1175,10 @@ TIntermTyped* TParseContext::addConstructor(TIntermNode* node, const TType* type
     
     TTypeList::iterator memberTypes;
     if (op == EOpConstructStruct)
-        memberTypes = type->getStruct()->begin();
+        memberTypes = type.getStruct()->begin();
     
-    TType elementType = *type;
-    if (type->isArray())
+    TType elementType = type;
+    if (type.isArray())
         elementType.dereference();
 
     bool singleArg;
@@ -1215,18 +1194,15 @@ TIntermTyped* TParseContext::addConstructor(TIntermNode* node, const TType* type
     if (singleArg) {
         // If structure constructor or array constructor is being called 
         // for only one parameter inside the structure, we need to call constructStruct function once.
-        if (type->isArray())
-            newNode = constructStruct(node, &elementType, 1, node->getLine(), false);
+        if (type.isArray())
+            newNode = constructStruct(node, elementType, 1, node->getLine());
         else if (op == EOpConstructStruct)
-            newNode = constructStruct(node, (*memberTypes).type, 1, node->getLine(), false);
+            newNode = constructStruct(node, *(*memberTypes).type, 1, node->getLine());
         else
             newNode = constructBuiltIn(type, op, node, node->getLine(), false);
 
-        if (newNode && newNode->getAsAggregate()) {
-            TIntermTyped* constConstructor = foldConstConstructor(newNode->getAsAggregate(), *type);
-            if (constConstructor)
-                return constConstructor;
-        }
+        if (newNode && (type.isArray() || op == EOpConstructStruct))
+            newNode = intermediate.setAggregateOperator(newNode, EOpConstructStruct, type, line);
 
         return newNode;
     }
@@ -1246,10 +1222,10 @@ TIntermTyped* TParseContext::addConstructor(TIntermNode* node, const TType* type
     
     for (TIntermSequence::iterator p = sequenceVector.begin(); 
                                    p != sequenceVector.end(); p++, paramCount++) {
-        if (type->isArray())
-            newNode = constructStruct(*p, &elementType, paramCount+1, node->getLine(), true);
+        if (type.isArray())
+            newNode = constructStruct(*p, elementType, paramCount+1, node->getLine());
         else if (op == EOpConstructStruct)
-            newNode = constructStruct(*p, (memberTypes[paramCount]).type, paramCount+1, node->getLine(), true);
+            newNode = constructStruct(*p, *(memberTypes[paramCount]).type, paramCount+1, node->getLine());
         else
             newNode = constructBuiltIn(type, op, *p, node->getLine(), true);
         
@@ -1259,36 +1235,11 @@ TIntermTyped* TParseContext::addConstructor(TIntermNode* node, const TType* type
         }
     }
 
-    TIntermTyped* constructor = intermediate.setAggregateOperator(aggrNode, op, line);
-    TIntermTyped* constConstructor = foldConstConstructor(constructor->getAsAggregate(), *type);
-    if (constConstructor)
-        return constConstructor;
+    TIntermTyped* constructor = intermediate.setAggregateOperator(aggrNode, op, type, line);
 
     return constructor;
 }
 
-TIntermTyped* TParseContext::foldConstConstructor(TIntermAggregate* aggrNode, const TType& type)
-{
-    bool canBeFolded = areAllChildConst(aggrNode);
-    aggrNode->setType(type);
-    if (canBeFolded) {
-        bool returnVal = false;
-        constUnion* unionArray = new constUnion[type.getObjectSize()];
-        if (aggrNode->getSequence().size() == 1)  {
-            returnVal = intermediate.parseConstTree(aggrNode->getLine(), aggrNode, unionArray, aggrNode->getOp(), symbolTable,  type, true);
-        }
-        else {
-            returnVal = intermediate.parseConstTree(aggrNode->getLine(), aggrNode, unionArray, aggrNode->getOp(), symbolTable,  type);
-        }
-        if (returnVal)
-            return 0;
-
-        return intermediate.addConstantUnion(unionArray, type, aggrNode->getLine());
-    }
-
-    return 0;
-}
-
 // Function for constructor implementation. Calls addUnaryMath with appropriate EOp value
 // for the parameter to the constructor (passed to this function). Essentially, it converts
 // the parameter types correctly. If a constructor expects an int (like ivec2) and is passed a 
@@ -1296,7 +1247,7 @@ TIntermTyped* TParseContext::foldConstConstructor(TIntermAggregate* aggrNode, co
 //
 // Returns 0 for an error or the constructed node.
 //
-TIntermTyped* TParseContext::constructBuiltIn(const TType* type, TOperator op, TIntermNode* node, TSourceLoc line, bool subset)
+TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, TIntermNode* node, TSourceLoc line, bool subset)
 {
     TIntermTyped* newNode;
     TOperator basicOp;
@@ -1368,11 +1319,11 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType* type, TOperator op, T
     //
     
     // Otherwise, skip out early.
-    if (subset || newNode != node && newNode->getType() == *type)
+    if (subset || newNode != node && newNode->getType() == type)
         return newNode;
 
     // setAggregateOperator will insert a new node for the constructor, as needed.
-    return intermediate.setAggregateOperator(newNode, op, line);
+    return intermediate.setAggregateOperator(newNode, op, type, line);
 }
 
 // This function tests for the type of the parameters to the structures constructors. Raises
@@ -1380,21 +1331,18 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType* type, TOperator op, T
 //
 // Returns 0 for an error or the input node itself if the expected and the given parameter types match.
 //
-TIntermTyped* TParseContext::constructStruct(TIntermNode* node, TType* type, int paramCount, TSourceLoc line, bool subset)
+TIntermTyped* TParseContext::constructStruct(TIntermNode* node, const TType& type, int paramCount, TSourceLoc line)
 {
-    TIntermNode* converted = intermediate.addConversion(EOpConstructStruct, *type, node->getAsTyped());
-    if (converted->getAsTyped()->getType() == *type) {
-        if (subset)
-            return converted->getAsTyped();
-        else
-            return intermediate.setAggregateOperator(converted->getAsTyped(), EOpConstructStruct, line);
-    } else {
+    TIntermTyped* converted = intermediate.addConversion(EOpConstructStruct, type, node->getAsTyped());
+    if (! converted || converted->getType() != type) {
         error(line, "", "constructor", "cannot convert parameter %d from '%s' to '%s'", paramCount,
-                node->getAsTyped()->getType().getCompleteTypeString().c_str(), type->getCompleteTypeString().c_str());
+                node->getAsTyped()->getType().getCompleteTypeString().c_str(), type.getCompleteTypeString().c_str());
         recover();
+
+        return 0;
     }
 
-    return 0;
+    return converted;
 }
 
 //
index 4044d1e..97bca38 100644 (file)
@@ -130,11 +130,9 @@ struct TParseContext {
     const TFunction* findFunction(int line, TFunction* pfnCall, bool *builtIn = 0);
     bool executeInitializer(TSourceLoc line, TString& identifier, TPublicType& pType,
                             TIntermTyped* initializer, TIntermNode*& intermNode, TVariable* variable = 0);
-    bool areAllChildConst(TIntermAggregate* aggrNode);
-    TIntermTyped* addConstructor(TIntermNode*, const TType*, TOperator, TFunction*, TSourceLoc);
-    TIntermTyped* foldConstConstructor(TIntermAggregate* aggrNode, const TType& type);
-    TIntermTyped* constructStruct(TIntermNode*, TType*, int, TSourceLoc, bool subset);
-    TIntermTyped* constructBuiltIn(const TType*, TOperator, TIntermNode*, TSourceLoc, bool subset);
+    TIntermTyped* addConstructor(TIntermNode*, const TType&, TOperator, TFunction*, TSourceLoc);
+    TIntermTyped* constructStruct(TIntermNode*, const TType&, int, TSourceLoc);
+    TIntermTyped* constructBuiltIn(const TType&, TOperator, TIntermNode*, TSourceLoc, bool subset);
     TIntermTyped* addConstVectorNode(TVectorFields&, TIntermTyped*, TSourceLoc);
     TIntermTyped* addConstMatrixNode(int , TIntermTyped*, TSourceLoc);
     TIntermTyped* addConstArrayNode(int index, TIntermTyped* node, TSourceLoc line);
index c80545f..caf143e 100644 (file)
@@ -105,7 +105,7 @@ int yy_input(char* buf, int max_size);
 \r
 \r
 %%\r
-<*>"//"[^\n]*"\n"     { /* ?? carriage and/or line-feed? */ };\r
+<*>"//"[^\n]*"\n"     { /* CPP should have taken care of this */ };\r
 \r
 "attribute"    {  pyylval->lex.line = yylineno; return(ATTRIBUTE); }  // TODO ES 30 reserved\r
 "const"        {  pyylval->lex.line = yylineno; return(CONST); }\r
index 711e26a..93dff17 100644 (file)
@@ -336,6 +336,7 @@ postfix_expression
                 $$ = parseContext.intermediate.addIndex(EOpIndexIndirect, $1, $3, $2.line);\r
             }\r
         }\r
+\r
         if ($$ == 0) {\r
             constUnion *unionArray = new constUnion[1];\r
             unionArray->setFConst(0.0f);\r
@@ -344,8 +345,7 @@ postfix_expression
             TType newType = $1->getType();\r
             newType.dereference();\r
             $$->setType(newType);\r
-            //?? why wouldn't the code above get the type right?\r
-            //?? write a dereference test\r
+            // TODO: testing: write a set of dereference tests\r
         }\r
     }\r
     | function_call {\r
@@ -511,14 +511,13 @@ function_call
                 //\r
                 // It's a constructor, of type 'type'.\r
                 //\r
-                $$ = parseContext.addConstructor($1.intermNode, &type, op, fnCall, $1.line);\r
+                $$ = parseContext.addConstructor($1.intermNode, type, op, fnCall, $1.line);\r
             }\r
 \r
             if ($$ == 0) {\r
                 parseContext.recover();\r
-                $$ = parseContext.intermediate.setAggregateOperator(0, op, $1.line);\r
+                $$ = parseContext.intermediate.setAggregateOperator(0, op, type, $1.line);\r
             }\r
-            $$->setType(type);\r
         } else {\r
             //\r
             // Not a constructor.  Find it in the symbol table.\r
@@ -538,7 +537,9 @@ function_call
                     //\r
                     if (fnCandidate->getParamCount() == 1) {\r
                         //\r
-                        // Treat it like a built-in unary operator.\r
+                        // Treat it like a built-in unary operator.  \r
+                        // addUnaryMath() should get the type correct on its own;\r
+                        // including constness (which would differ from the prototype).\r
                         //\r
                         $$ = parseContext.intermediate.addUnaryMath(op, $1.intermNode, 0, parseContext.symbolTable);\r
                         if ($$ == 0)  {\r
@@ -548,13 +549,12 @@ function_call
                             YYERROR;\r
                         }\r
                     } else {\r
-                        $$ = parseContext.intermediate.setAggregateOperator($1.intermAggregate, op, $1.line);\r
+                        $$ = parseContext.intermediate.setAggregateOperator($1.intermAggregate, op, fnCandidate->getReturnType(), $1.line);\r
                     }\r
                 } else {\r
                     // This is a real function call\r
 \r
-                    $$ = parseContext.intermediate.setAggregateOperator($1.intermAggregate, EOpFunctionCall, $1.line);\r
-                    $$->setType(fnCandidate->getReturnType());\r
+                    $$ = parseContext.intermediate.setAggregateOperator($1.intermAggregate, EOpFunctionCall, fnCandidate->getReturnType(), $1.line);\r
 \r
                     // this is how we know whether the given function is a builtIn function or a user defined function\r
                     // if builtIn == false, it's a userDefined -> could be an overloaded builtIn function also\r
@@ -576,7 +576,6 @@ function_call
                         qualifierList.push_back(qual);\r
                     }\r
                 }\r
-                $$->setType(fnCandidate->getReturnType());\r
             } else {\r
                 // error message was put out by PaFindFunction()\r
                 // Put on a dummy node for error recovery\r
@@ -2991,7 +2990,7 @@ function_definition
                 paramNodes = parseContext.intermediate.growAggregate(paramNodes, parseContext.intermediate.addSymbol(0, "", *param.type, $1.line), $1.line);\r
             }\r
         }\r
-        parseContext.intermediate.setAggregateOperator(paramNodes, EOpParameters, $1.line);\r
+        parseContext.intermediate.setAggregateOperator(paramNodes, EOpParameters, TType(EbtVoid), $1.line);\r
         $1.intermAggregate = paramNodes;\r
         parseContext.loopNestingLevel = 0;\r
     }\r
@@ -3003,9 +3002,8 @@ function_definition
         }\r
         parseContext.symbolTable.pop(&parseContext.defaultPrecision[0]);\r
         $$ = parseContext.intermediate.growAggregate($1.intermAggregate, $3, 0);\r
-        parseContext.intermediate.setAggregateOperator($$, EOpFunction, $1.line);\r
+        parseContext.intermediate.setAggregateOperator($$, EOpFunction, $1.function->getReturnType(), $1.line);\r
         $$->getAsAggregate()->setName($1.function->getMangledName().c_str());\r
-        $$->getAsAggregate()->setType($1.function->getReturnType());\r
 \r
         // store the pragma information for debug and optimize and other vendor specific\r
         // information. This information can be queried from the parse tree\r
index 0676bcc..f2e4512 100644 (file)
@@ -63,14 +63,17 @@ public:
     bool canImplicitlyPromote(TBasicType from, TBasicType to);
     TIntermAggregate* growAggregate(TIntermNode* left, TIntermNode* right, TSourceLoc);
     TIntermAggregate* makeAggregate(TIntermNode* node, TSourceLoc);
-    TIntermAggregate* setAggregateOperator(TIntermNode*, TOperator, TSourceLoc);
+    TIntermTyped* setAggregateOperator(TIntermNode*, TOperator, const TType& type, TSourceLoc);
+    bool areAllChildConst(TIntermAggregate* aggrNode);
+    TIntermTyped* fold(TIntermAggregate* aggrNode);
+    TIntermTyped* foldConstructor(TIntermAggregate* aggrNode);
     TIntermNode*  addSelection(TIntermTyped* cond, TIntermNodePair code, TSourceLoc);
     TIntermTyped* addSelection(TIntermTyped* cond, TIntermTyped* trueBlock, TIntermTyped* falseBlock, TSourceLoc);
     TIntermTyped* addComma(TIntermTyped* left, TIntermTyped* right, TSourceLoc);
     TIntermTyped* addMethod(TIntermTyped*, const TType&, const TString*, TSourceLoc);
     TIntermConstantUnion* addConstantUnion(constUnion*, const TType&, TSourceLoc);
     TIntermTyped* promoteConstantUnion(TBasicType, TIntermConstantUnion*) ;
-    bool parseConstTree(TSourceLoc, TIntermNode*, constUnion*, TOperator, TSymbolTable&, TType, bool singleConstantParam = false);        
+    bool parseConstTree(TSourceLoc, TIntermNode*, constUnion*, TOperator, TType, bool singleConstantParam = false);        
     TIntermNode* addLoop(TIntermNode*, TIntermTyped*, TIntermTyped*, bool testFirst, TSourceLoc);
     TIntermBranch* addBranch(TOperator, TSourceLoc);
     TIntermBranch* addBranch(TOperator, TIntermTyped*, TSourceLoc);
index abd9714..c797374 100644 (file)
@@ -40,8 +40,8 @@
 //
 class TConstTraverser : public TIntermTraverser {
 public:
-    TConstTraverser(constUnion* cUnion, bool singleConstParam, TOperator constructType, TInfoSink& sink, TSymbolTable& symTable, TType& t) : unionArray(cUnion), type(t),
-        constructorType(constructType), singleConstantParam(singleConstParam), infoSink(sink), symbolTable(symTable), error(false), isMatrix(false), 
+    TConstTraverser(constUnion* cUnion, bool singleConstParam, TOperator constructType, TInfoSink& sink, TType& t) : unionArray(cUnion), type(t),
+        constructorType(constructType), singleConstantParam(singleConstParam), infoSink(sink), error(false), isMatrix(false), 
         matrixCols(0), matrixRows(0) {  index = 0; tOp = EOpNull;}
     int index ;
     constUnion *unionArray;
@@ -50,7 +50,6 @@ public:
     TOperator constructorType;
     bool singleConstantParam;
     TInfoSink& infoSink;
-    TSymbolTable& symbolTable;
     bool error;
     int size; // size of the constructor ( 4 for vec4)
     bool isMatrix;
@@ -256,12 +255,12 @@ bool ParseBranch(bool /* previsit*/, TIntermBranch* node, TIntermTraverser* it)
 // Individual functions can be initialized to 0 to skip processing of that
 // type of node.  It's children will still be processed.
 //
-bool TIntermediate::parseConstTree(TSourceLoc line, TIntermNode* root, constUnion* unionArray, TOperator constructorType, TSymbolTable& symbolTable, TType t, bool singleConstantParam)
+bool TIntermediate::parseConstTree(TSourceLoc line, TIntermNode* root, constUnion* unionArray, TOperator constructorType, TType t, bool singleConstantParam)
 {
     if (root == 0)
         return false;
 
-    TConstTraverser it(unionArray, singleConstantParam, constructorType, infoSink, symbolTable, t);
+    TConstTraverser it(unionArray, singleConstantParam, constructorType, infoSink, t);
     
     it.visitAggregate = ParseAggregate;
     it.visitBinary = ParseBinary;