b33af84d4f8e149779c1b2fed0e35cbe0b3e4754
[platform/upstream/glslang.git] / glslang / MachineIndependent / Constant.cpp
1 //
2 // Copyright (C) 2002-2005  3Dlabs Inc. Ltd.
3 // Copyright (C) 2012-2013 LunarG, Inc.
4 // Copyright (C) 2017 ARM Limited.
5 //
6 // All rights reserved.
7 //
8 // Redistribution and use in source and binary forms, with or without
9 // modification, are permitted provided that the following conditions
10 // are met:
11 //
12 //    Redistributions of source code must retain the above copyright
13 //    notice, this list of conditions and the following disclaimer.
14 //
15 //    Redistributions in binary form must reproduce the above
16 //    copyright notice, this list of conditions and the following
17 //    disclaimer in the documentation and/or other materials provided
18 //    with the distribution.
19 //
20 //    Neither the name of 3Dlabs Inc. Ltd. nor the names of its
21 //    contributors may be used to endorse or promote products derived
22 //    from this software without specific prior written permission.
23 //
24 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
25 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
26 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
27 // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
28 // COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
29 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
30 // BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
31 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
33 // LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
34 // ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35 // POSSIBILITY OF SUCH DAMAGE.
36 //
37
38 #include "localintermediate.h"
39 #include <cmath>
40 #include <cfloat>
41 #include <cstdlib>
42 #include <climits>
43
44 namespace {
45
46 using namespace glslang;
47
// Overlays a double with two ints so the two halves of the double's storage
// can be addressed individually as i[0] and i[1].
union DoubleIntUnion {
    double d;
    int i[2];
};
52
53 // Some helper functions
54
// Returns true when x is a NaN of any kind (quiet or signaling), false for
// every finite value and for +/- infinity.
bool isNan(double x)
{
    // Use the standard classification function rather than inspecting the bit
    // pattern by hand: a hand-written mask that requires the quiet bit misses
    // signaling NaNs, and picking the "high" word by array index depends on
    // the host's byte order.
    return std::isnan(x);
}
65
// Returns true when x is positive or negative infinity, false for every
// finite value and for NaN.
bool isInf(double x)
{
    // Use the standard classification function rather than inspecting the bit
    // pattern by hand, which depends on the host's byte order to decide which
    // int holds the exponent.
    return std::isinf(x);
}
76
// Value of pi used by the degree <-> radian conversions in this file.
const double pi(3.1415926535897932384626433832795);
78
79 } // end anonymous namespace
80
81
82 namespace glslang {
83
84 //
85 // The fold functions see if an operation on a constant can be done in place,
86 // without generating run-time code.
87 //
88 // Returns the node to keep using, which may or may not be the node passed in.
89 //
90 // Note: As of version 1.2, all constant operations must be folded.  It is
91 // not opportunistic, but rather a semantic requirement.
92 //
93
94 //
95 // Do folding between a pair of nodes.
96 // 'this' is the left-hand operand and 'rightConstantNode' is the right-hand operand.
97 //
98 // Returns a new node representing the result.
99 //
100 TIntermTyped* TIntermConstantUnion::fold(TOperator op, const TIntermTyped* rightConstantNode) const
101 {
102     // For most cases, the return type matches the argument type, so set that
103     // up and just code to exceptions below.
104     TType returnType;
105     returnType.shallowCopy(getType());
106
107     //
108     // A pair of nodes is to be folded together
109     //
110
111     const TIntermConstantUnion *rightNode = rightConstantNode->getAsConstantUnion();
112     TConstUnionArray leftUnionArray = getConstArray();
113     TConstUnionArray rightUnionArray = rightNode->getConstArray();
114
115     // Figure out the size of the result
116     int newComps;
117     int constComps;
118     switch(op) {
119     case EOpMatrixTimesMatrix:
120         newComps = rightNode->getMatrixCols() * getMatrixRows();
121         break;
122     case EOpMatrixTimesVector:
123         newComps = getMatrixRows();
124         break;
125     case EOpVectorTimesMatrix:
126         newComps = rightNode->getMatrixCols();
127         break;
128     default:
129         newComps = getType().computeNumComponents();
130         constComps = rightConstantNode->getType().computeNumComponents();
131         if (constComps == 1 && newComps > 1) {
132             // for a case like vec4 f = vec4(2,3,4,5) + 1.2;
133             TConstUnionArray smearedArray(newComps, rightNode->getConstArray()[0]);
134             rightUnionArray = smearedArray;
135         } else if (constComps > 1 && newComps == 1) {
136             // for a case like vec4 f = 1.2 + vec4(2,3,4,5);
137             newComps = constComps;
138             rightUnionArray = rightNode->getConstArray();
139             TConstUnionArray smearedArray(newComps, getConstArray()[0]);
140             leftUnionArray = smearedArray;
141             returnType.shallowCopy(rightNode->getType());
142         }
143         break;
144     }
145
146     TConstUnionArray newConstArray(newComps);
147     TType constBool(EbtBool, EvqConst);
148
149     switch(op) {
150     case EOpAdd:
151         for (int i = 0; i < newComps; i++)
152             newConstArray[i] = leftUnionArray[i] + rightUnionArray[i];
153         break;
154     case EOpSub:
155         for (int i = 0; i < newComps; i++)
156             newConstArray[i] = leftUnionArray[i] - rightUnionArray[i];
157         break;
158
159     case EOpMul:
160     case EOpVectorTimesScalar:
161     case EOpMatrixTimesScalar:
162         for (int i = 0; i < newComps; i++)
163             newConstArray[i] = leftUnionArray[i] * rightUnionArray[i];
164         break;
165     case EOpMatrixTimesMatrix:
166         for (int row = 0; row < getMatrixRows(); row++) {
167             for (int column = 0; column < rightNode->getMatrixCols(); column++) {
168                 double sum = 0.0f;
169                 for (int i = 0; i < rightNode->getMatrixRows(); i++)
170                     sum += leftUnionArray[i * getMatrixRows() + row].getDConst() * rightUnionArray[column * rightNode->getMatrixRows() + i].getDConst();
171                 newConstArray[column * getMatrixRows() + row].setDConst(sum);
172             }
173         }
174         returnType.shallowCopy(TType(getType().getBasicType(), EvqConst, 0, rightNode->getMatrixCols(), getMatrixRows()));
175         break;
176     case EOpDiv:
177         for (int i = 0; i < newComps; i++) {
178             switch (getType().getBasicType()) {
179             case EbtDouble:
180             case EbtFloat:
181             case EbtFloat16:
182                 if (rightUnionArray[i].getDConst() != 0.0)
183                     newConstArray[i].setDConst(leftUnionArray[i].getDConst() / rightUnionArray[i].getDConst());
184                 else if (leftUnionArray[i].getDConst() > 0.0)
185                     newConstArray[i].setDConst((double)INFINITY);
186                 else if (leftUnionArray[i].getDConst() < 0.0)
187                     newConstArray[i].setDConst(-(double)INFINITY);
188                 else
189                     newConstArray[i].setDConst((double)NAN);
190                 break;
191             case EbtInt8:
192                 if (rightUnionArray[i] == (signed char)0)
193                     newConstArray[i].setI8Const((signed char)0x7F);
194                 else if (rightUnionArray[i].getI8Const() == (signed char)-1 && leftUnionArray[i].getI8Const() == (signed char)-0x80)
195                     newConstArray[i].setI8Const((signed char)-0x80);
196                 else
197                     newConstArray[i].setI8Const(leftUnionArray[i].getI8Const() / rightUnionArray[i].getI8Const());
198                 break;
199
200             case EbtUint8:
201                 if (rightUnionArray[i] == (unsigned char)0u)
202                     newConstArray[i].setU8Const((unsigned char)0xFFu);
203                 else
204                     newConstArray[i].setU8Const(leftUnionArray[i].getU8Const() / rightUnionArray[i].getU8Const());
205                 break;
206
207            case EbtInt16:
208                 if (rightUnionArray[i] == (signed short)0)
209                     newConstArray[i].setI16Const((signed short)0x7FFF);
210                 else if (rightUnionArray[i].getI16Const() == (signed short)-1 && leftUnionArray[i].getI16Const() == (signed short)-0x8000)
211                     newConstArray[i].setI16Const((signed short)-0x8000);
212                 else
213                     newConstArray[i].setI16Const(leftUnionArray[i].getI16Const() / rightUnionArray[i].getI16Const());
214                 break;
215
216             case EbtUint16:
217                 if (rightUnionArray[i] == (unsigned short)0u)
218                     newConstArray[i].setU16Const((unsigned short)0xFFFFu);
219                 else
220                     newConstArray[i].setU16Const(leftUnionArray[i].getU16Const() / rightUnionArray[i].getU16Const());
221                 break;
222
223             case EbtInt:
224                 if (rightUnionArray[i] == 0)
225                     newConstArray[i].setIConst(0x7FFFFFFF);
226                 else if (rightUnionArray[i].getIConst() == -1 && leftUnionArray[i].getIConst() == (int)-0x80000000ll)
227                     newConstArray[i].setIConst((int)-0x80000000ll);
228                 else
229                     newConstArray[i].setIConst(leftUnionArray[i].getIConst() / rightUnionArray[i].getIConst());
230                 break;
231
232             case EbtUint:
233                 if (rightUnionArray[i] == 0u)
234                     newConstArray[i].setUConst(0xFFFFFFFFu);
235                 else
236                     newConstArray[i].setUConst(leftUnionArray[i].getUConst() / rightUnionArray[i].getUConst());
237                 break;
238
239             case EbtInt64:
240                 if (rightUnionArray[i] == 0ll)
241                     newConstArray[i].setI64Const(0x7FFFFFFFFFFFFFFFll);
242                 else if (rightUnionArray[i].getI64Const() == -1 && leftUnionArray[i].getI64Const() == (long long)-0x8000000000000000ll)
243                     newConstArray[i].setI64Const((long long)-0x8000000000000000ll);
244                 else
245                     newConstArray[i].setI64Const(leftUnionArray[i].getI64Const() / rightUnionArray[i].getI64Const());
246                 break;
247
248             case EbtUint64:
249                 if (rightUnionArray[i] == 0ull)
250                     newConstArray[i].setU64Const(0xFFFFFFFFFFFFFFFFull);
251                 else
252                     newConstArray[i].setU64Const(leftUnionArray[i].getU64Const() / rightUnionArray[i].getU64Const());
253                 break;
254             default:
255                 return 0;
256             }
257         }
258         break;
259
260     case EOpMatrixTimesVector:
261         for (int i = 0; i < getMatrixRows(); i++) {
262             double sum = 0.0f;
263             for (int j = 0; j < rightNode->getVectorSize(); j++) {
264                 sum += leftUnionArray[j*getMatrixRows() + i].getDConst() * rightUnionArray[j].getDConst();
265             }
266             newConstArray[i].setDConst(sum);
267         }
268
269         returnType.shallowCopy(TType(getBasicType(), EvqConst, getMatrixRows()));
270         break;
271
272     case EOpVectorTimesMatrix:
273         for (int i = 0; i < rightNode->getMatrixCols(); i++) {
274             double sum = 0.0f;
275             for (int j = 0; j < getVectorSize(); j++)
276                 sum += leftUnionArray[j].getDConst() * rightUnionArray[i*rightNode->getMatrixRows() + j].getDConst();
277             newConstArray[i].setDConst(sum);
278         }
279
280         returnType.shallowCopy(TType(getBasicType(), EvqConst, rightNode->getMatrixCols()));
281         break;
282
283     case EOpMod:
284         for (int i = 0; i < newComps; i++) {
285             if (rightUnionArray[i] == 0)
286                 newConstArray[i] = leftUnionArray[i];
287             else {
288                 switch (getType().getBasicType()) {
289                 case EbtInt:
290                     if (rightUnionArray[i].getIConst() == -1 && leftUnionArray[i].getIConst() == INT_MIN) {
291                         newConstArray[i].setIConst(0);
292                         break;
293                     } else goto modulo_default;
294
295                 case EbtInt64:
296                     if (rightUnionArray[i].getI64Const() == -1 && leftUnionArray[i].getI64Const() == LLONG_MIN) {
297                         newConstArray[i].setI64Const(0);
298                         break;
299                     } else goto modulo_default;
300 #ifdef AMD_EXTENSIONS
301                 case EbtInt16:
302                     if (rightUnionArray[i].getIConst() == -1 && leftUnionArray[i].getIConst() == SHRT_MIN) {
303                         newConstArray[i].setIConst(0);
304                         break;
305                     } else goto modulo_default;
306 #endif
307                 default:
308                 modulo_default:
309                     newConstArray[i] = leftUnionArray[i] % rightUnionArray[i];
310                 }
311             }
312         }
313         break;
314
315     case EOpRightShift:
316         for (int i = 0; i < newComps; i++)
317             newConstArray[i] = leftUnionArray[i] >> rightUnionArray[i];
318         break;
319
320     case EOpLeftShift:
321         for (int i = 0; i < newComps; i++)
322             newConstArray[i] = leftUnionArray[i] << rightUnionArray[i];
323         break;
324
325     case EOpAnd:
326         for (int i = 0; i < newComps; i++)
327             newConstArray[i] = leftUnionArray[i] & rightUnionArray[i];
328         break;
329     case EOpInclusiveOr:
330         for (int i = 0; i < newComps; i++)
331             newConstArray[i] = leftUnionArray[i] | rightUnionArray[i];
332         break;
333     case EOpExclusiveOr:
334         for (int i = 0; i < newComps; i++)
335             newConstArray[i] = leftUnionArray[i] ^ rightUnionArray[i];
336         break;
337
338     case EOpLogicalAnd: // this code is written for possible future use, will not get executed currently
339         for (int i = 0; i < newComps; i++)
340             newConstArray[i] = leftUnionArray[i] && rightUnionArray[i];
341         break;
342
343     case EOpLogicalOr: // this code is written for possible future use, will not get executed currently
344         for (int i = 0; i < newComps; i++)
345             newConstArray[i] = leftUnionArray[i] || rightUnionArray[i];
346         break;
347
348     case EOpLogicalXor:
349         for (int i = 0; i < newComps; i++) {
350             switch (getType().getBasicType()) {
351             case EbtBool: newConstArray[i].setBConst((leftUnionArray[i] == rightUnionArray[i]) ? false : true); break;
352             default: assert(false && "Default missing");
353             }
354         }
355         break;
356
357     case EOpLessThan:
358         newConstArray[0].setBConst(leftUnionArray[0] < rightUnionArray[0]);
359         returnType.shallowCopy(constBool);
360         break;
361     case EOpGreaterThan:
362         newConstArray[0].setBConst(leftUnionArray[0] > rightUnionArray[0]);
363         returnType.shallowCopy(constBool);
364         break;
365     case EOpLessThanEqual:
366         newConstArray[0].setBConst(! (leftUnionArray[0] > rightUnionArray[0]));
367         returnType.shallowCopy(constBool);
368         break;
369     case EOpGreaterThanEqual:
370         newConstArray[0].setBConst(! (leftUnionArray[0] < rightUnionArray[0]));
371         returnType.shallowCopy(constBool);
372         break;
373     case EOpEqual:
374         newConstArray[0].setBConst(rightNode->getConstArray() == leftUnionArray);
375         returnType.shallowCopy(constBool);
376         break;
377     case EOpNotEqual:
378         newConstArray[0].setBConst(rightNode->getConstArray() != leftUnionArray);
379         returnType.shallowCopy(constBool);
380         break;
381
382     default:
383         return 0;
384     }
385
386     TIntermConstantUnion *newNode = new TIntermConstantUnion(newConstArray, returnType);
387     newNode->setLoc(getLoc());
388
389     return newNode;
390 }
391
392 //
393 // Do single unary node folding
394 //
395 // Returns a new node representing the result.
396 //
397 TIntermTyped* TIntermConstantUnion::fold(TOperator op, const TType& returnType) const
398 {
399     // First, size the result, which is mostly the same as the argument's size,
400     // but not always, and classify what is componentwise.
401     // Also, eliminate cases that can't be compile-time constant.
402     int resultSize;
403     bool componentWise = true;
404
405     int objectSize = getType().computeNumComponents();
406     switch (op) {
407     case EOpDeterminant:
408     case EOpAny:
409     case EOpAll:
410     case EOpLength:
411         componentWise = false;
412         resultSize = 1;
413         break;
414
415     case EOpEmitStreamVertex:
416     case EOpEndStreamPrimitive:
417         // These don't actually fold
418         return 0;
419
420     case EOpPackSnorm2x16:
421     case EOpPackUnorm2x16:
422     case EOpPackHalf2x16:
423         componentWise = false;
424         resultSize = 1;
425         break;
426
427     case EOpUnpackSnorm2x16:
428     case EOpUnpackUnorm2x16:
429     case EOpUnpackHalf2x16:
430         componentWise = false;
431         resultSize = 2;
432         break;
433
434     case EOpPack16:
435     case EOpPack32:
436     case EOpPack64:
437     case EOpUnpack32:
438     case EOpUnpack16:
439     case EOpUnpack8:
440     case EOpNormalize:
441         componentWise = false;
442         resultSize = objectSize;
443         break;
444
445     default:
446         resultSize = objectSize;
447         break;
448     }
449
450     // Set up for processing
451     TConstUnionArray newConstArray(resultSize);
452     const TConstUnionArray& unionArray = getConstArray();
453
454     // Process non-component-wise operations
455     switch (op) {
456     case EOpLength:
457     case EOpNormalize:
458     {
459         double sum = 0;
460         for (int i = 0; i < objectSize; i++)
461             sum += unionArray[i].getDConst() * unionArray[i].getDConst();
462         double length = sqrt(sum);
463         if (op == EOpLength)
464             newConstArray[0].setDConst(length);
465         else {
466             for (int i = 0; i < objectSize; i++)
467                 newConstArray[i].setDConst(unionArray[i].getDConst() / length);
468         }
469         break;
470     }
471
472     case EOpAny:
473     {
474         bool result = false;
475         for (int i = 0; i < objectSize; i++) {
476             if (unionArray[i].getBConst())
477                 result = true;
478         }
479         newConstArray[0].setBConst(result);
480         break;
481     }
482     case EOpAll:
483     {
484         bool result = true;
485         for (int i = 0; i < objectSize; i++) {
486             if (! unionArray[i].getBConst())
487                 result = false;
488         }
489         newConstArray[0].setBConst(result);
490         break;
491     }
492
493     // TODO: 3.0 Functionality: unary constant folding: the rest of the ops have to be fleshed out
494
495     case EOpPackSnorm2x16:
496     case EOpPackUnorm2x16:
497     case EOpPackHalf2x16:
498     case EOpPack16:
499     case EOpPack32:
500     case EOpPack64:
501     case EOpUnpack32:
502     case EOpUnpack16:
503     case EOpUnpack8:
504
505     case EOpUnpackSnorm2x16:
506     case EOpUnpackUnorm2x16:
507     case EOpUnpackHalf2x16:
508
509     case EOpDeterminant:
510     case EOpMatrixInverse:
511     case EOpTranspose:
512         return 0;
513
514     default:
515         assert(componentWise);
516         break;
517     }
518
519     // Turn off the componentwise loop
520     if (! componentWise)
521         objectSize = 0;
522
523     // Process component-wise operations
524     for (int i = 0; i < objectSize; i++) {
525         switch (op) {
526         case EOpNegative:
527             switch (getType().getBasicType()) {
528             case EbtDouble:
529             case EbtFloat16:
530             case EbtFloat: newConstArray[i].setDConst(-unionArray[i].getDConst()); break;
531             case EbtInt8:  newConstArray[i].setI8Const(-unionArray[i].getI8Const()); break;
532             case EbtUint8: newConstArray[i].setU8Const(static_cast<unsigned int>(-static_cast<signed int>(unionArray[i].getU8Const())));  break;
533             case EbtInt16: newConstArray[i].setI16Const(-unionArray[i].getI16Const()); break;
534             case EbtUint16:newConstArray[i].setU16Const(static_cast<unsigned int>(-static_cast<signed int>(unionArray[i].getU16Const())));  break;
535             case EbtInt:   newConstArray[i].setIConst(-unionArray[i].getIConst()); break;
536             case EbtUint:  newConstArray[i].setUConst(static_cast<unsigned int>(-static_cast<int>(unionArray[i].getUConst())));  break;
537             case EbtInt64: newConstArray[i].setI64Const(-unionArray[i].getI64Const()); break;
538             case EbtUint64: newConstArray[i].setU64Const(static_cast<unsigned long long>(-static_cast<long long>(unionArray[i].getU64Const())));  break;
539             default:
540                 return 0;
541             }
542             break;
543         case EOpLogicalNot:
544         case EOpVectorLogicalNot:
545             switch (getType().getBasicType()) {
546             case EbtBool:  newConstArray[i].setBConst(!unionArray[i].getBConst()); break;
547             default:
548                 return 0;
549             }
550             break;
551         case EOpBitwiseNot:
552             newConstArray[i] = ~unionArray[i];
553             break;
554         case EOpRadians:
555             newConstArray[i].setDConst(unionArray[i].getDConst() * pi / 180.0);
556             break;
557         case EOpDegrees:
558             newConstArray[i].setDConst(unionArray[i].getDConst() * 180.0 / pi);
559             break;
560         case EOpSin:
561             newConstArray[i].setDConst(sin(unionArray[i].getDConst()));
562             break;
563         case EOpCos:
564             newConstArray[i].setDConst(cos(unionArray[i].getDConst()));
565             break;
566         case EOpTan:
567             newConstArray[i].setDConst(tan(unionArray[i].getDConst()));
568             break;
569         case EOpAsin:
570             newConstArray[i].setDConst(asin(unionArray[i].getDConst()));
571             break;
572         case EOpAcos:
573             newConstArray[i].setDConst(acos(unionArray[i].getDConst()));
574             break;
575         case EOpAtan:
576             newConstArray[i].setDConst(atan(unionArray[i].getDConst()));
577             break;
578
579         case EOpDPdx:
580         case EOpDPdy:
581         case EOpFwidth:
582         case EOpDPdxFine:
583         case EOpDPdyFine:
584         case EOpFwidthFine:
585         case EOpDPdxCoarse:
586         case EOpDPdyCoarse:
587         case EOpFwidthCoarse:
588             // The derivatives are all mandated to create a constant 0.
589             newConstArray[i].setDConst(0.0);
590             break;
591
592         case EOpExp:
593             newConstArray[i].setDConst(exp(unionArray[i].getDConst()));
594             break;
595         case EOpLog:
596             newConstArray[i].setDConst(log(unionArray[i].getDConst()));
597             break;
598         case EOpExp2:
599             {
600                 const double inv_log2_e = 0.69314718055994530941723212145818;
601                 newConstArray[i].setDConst(exp(unionArray[i].getDConst() * inv_log2_e));
602                 break;
603             }
604         case EOpLog2:
605             {
606                 const double log2_e = 1.4426950408889634073599246810019;
607                 newConstArray[i].setDConst(log2_e * log(unionArray[i].getDConst()));
608                 break;
609             }
610         case EOpSqrt:
611             newConstArray[i].setDConst(sqrt(unionArray[i].getDConst()));
612             break;
613         case EOpInverseSqrt:
614             newConstArray[i].setDConst(1.0 / sqrt(unionArray[i].getDConst()));
615             break;
616
617         case EOpAbs:
618             if (unionArray[i].getType() == EbtDouble)
619                 newConstArray[i].setDConst(fabs(unionArray[i].getDConst()));
620             else if (unionArray[i].getType() == EbtInt)
621                 newConstArray[i].setIConst(abs(unionArray[i].getIConst()));
622             else
623                 newConstArray[i] = unionArray[i];
624             break;
625         case EOpSign:
626             #define SIGN(X) (X == 0 ? 0 : (X < 0 ? -1 : 1))
627             if (unionArray[i].getType() == EbtDouble)
628                 newConstArray[i].setDConst(SIGN(unionArray[i].getDConst()));
629             else
630                 newConstArray[i].setIConst(SIGN(unionArray[i].getIConst()));
631             break;
632         case EOpFloor:
633             newConstArray[i].setDConst(floor(unionArray[i].getDConst()));
634             break;
635         case EOpTrunc:
636             if (unionArray[i].getDConst() > 0)
637                 newConstArray[i].setDConst(floor(unionArray[i].getDConst()));
638             else
639                 newConstArray[i].setDConst(ceil(unionArray[i].getDConst()));
640             break;
641         case EOpRound:
642             newConstArray[i].setDConst(floor(0.5 + unionArray[i].getDConst()));
643             break;
644         case EOpRoundEven:
645         {
646             double flr = floor(unionArray[i].getDConst());
647             bool even = flr / 2.0 == floor(flr / 2.0);
648             double rounded = even ? ceil(unionArray[i].getDConst() - 0.5) : floor(unionArray[i].getDConst() + 0.5);
649             newConstArray[i].setDConst(rounded);
650             break;
651         }
652         case EOpCeil:
653             newConstArray[i].setDConst(ceil(unionArray[i].getDConst()));
654             break;
655         case EOpFract:
656         {
657             double x = unionArray[i].getDConst();
658             newConstArray[i].setDConst(x - floor(x));
659             break;
660         }
661
662         case EOpIsNan:
663         {
664             newConstArray[i].setBConst(isNan(unionArray[i].getDConst()));
665             break;
666         }
667         case EOpIsInf:
668         {
669             newConstArray[i].setBConst(isInf(unionArray[i].getDConst()));
670             break;
671         }
672
673         // TODO: 3.0 Functionality: unary constant folding: the rest of the ops have to be fleshed out
674
675         case EOpSinh:
676         case EOpCosh:
677         case EOpTanh:
678         case EOpAsinh:
679         case EOpAcosh:
680         case EOpAtanh:
681
682         case EOpFloatBitsToInt:
683         case EOpFloatBitsToUint:
684         case EOpIntBitsToFloat:
685         case EOpUintBitsToFloat:
686         case EOpDoubleBitsToInt64:
687         case EOpDoubleBitsToUint64:
688         case EOpInt64BitsToDouble:
689         case EOpUint64BitsToDouble:
690         case EOpFloat16BitsToInt16:
691         case EOpFloat16BitsToUint16:
692         case EOpInt16BitsToFloat16:
693         case EOpUint16BitsToFloat16:
694         default:
695             return 0;
696         }
697     }
698
699     TIntermConstantUnion *newNode = new TIntermConstantUnion(newConstArray, returnType);
700     newNode->getWritableType().getQualifier().storage = EvqConst;
701     newNode->setLoc(getLoc());
702
703     return newNode;
704 }
705
706 //
707 // Do constant folding for an aggregate node that has all its children
708 // as constants and an operator that requires constant folding.
709 //
710 TIntermTyped* TIntermediate::fold(TIntermAggregate* aggrNode)
711 {
712     if (aggrNode == nullptr)
713         return aggrNode;
714
715     if (! areAllChildConst(aggrNode))
716         return aggrNode;
717
718     if (aggrNode->isConstructor())
719         return foldConstructor(aggrNode);
720
721     TIntermSequence& children = aggrNode->getSequence();
722
723     // First, see if this is an operation to constant fold, kick out if not,
724     // see what size the result is if so.
725
726     bool componentwise = false;  // will also say componentwise if a scalar argument gets repeated to make per-component results
727     int objectSize;
728     switch (aggrNode->getOp()) {
729     case EOpAtan:
730     case EOpPow:
731     case EOpMin:
732     case EOpMax:
733     case EOpMix:
734     case EOpClamp:
735     case EOpLessThan:
736     case EOpGreaterThan:
737     case EOpLessThanEqual:
738     case EOpGreaterThanEqual:
739     case EOpVectorEqual:
740     case EOpVectorNotEqual:
741         componentwise = true;
742         objectSize = children[0]->getAsConstantUnion()->getType().computeNumComponents();
743         break;
744     case EOpCross:
745     case EOpReflect:
746     case EOpRefract:
747     case EOpFaceForward:
748         objectSize = children[0]->getAsConstantUnion()->getType().computeNumComponents();
749         break;
750     case EOpDistance:
751     case EOpDot:
752         objectSize = 1;
753         break;
754     case EOpOuterProduct:
755         objectSize = children[0]->getAsTyped()->getType().getVectorSize() *
756                      children[1]->getAsTyped()->getType().getVectorSize();
757         break;
758     case EOpStep:
759         componentwise = true;
760         objectSize = std::max(children[0]->getAsTyped()->getType().getVectorSize(),
761                               children[1]->getAsTyped()->getType().getVectorSize());
762         break;
763     case EOpSmoothStep:
764         componentwise = true;
765         objectSize = std::max(children[0]->getAsTyped()->getType().getVectorSize(),
766                               children[2]->getAsTyped()->getType().getVectorSize());
767         break;
768     default:
769         return aggrNode;
770     }
771     TConstUnionArray newConstArray(objectSize);
772
773     TVector<TConstUnionArray> childConstUnions;
774     for (unsigned int arg = 0; arg < children.size(); ++arg)
775         childConstUnions.push_back(children[arg]->getAsConstantUnion()->getConstArray());
776
777     if (componentwise) {
778         for (int comp = 0; comp < objectSize; comp++) {
779
780             // some arguments are scalars instead of matching vectors; simulate a smear
781             int arg0comp = std::min(comp, children[0]->getAsTyped()->getType().getVectorSize() - 1);
782             int arg1comp = 0;
783             if (children.size() > 1)
784                 arg1comp = std::min(comp, children[1]->getAsTyped()->getType().getVectorSize() - 1);
785             int arg2comp = 0;
786             if (children.size() > 2)
787                 arg2comp = std::min(comp, children[2]->getAsTyped()->getType().getVectorSize() - 1);
788
789             switch (aggrNode->getOp()) {
790             case EOpAtan:
791                 newConstArray[comp].setDConst(atan2(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()));
792                 break;
793             case EOpPow:
794                 newConstArray[comp].setDConst(pow(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()));
795                 break;
796             case EOpMin:
797                 switch(children[0]->getAsTyped()->getBasicType()) {
798                 case EbtFloat16:
799                 case EbtFloat:
800                 case EbtDouble:
801                     newConstArray[comp].setDConst(std::min(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()));
802                     break;
803                 case EbtInt8:
804                     newConstArray[comp].setI8Const(std::min(childConstUnions[0][arg0comp].getI8Const(), childConstUnions[1][arg1comp].getI8Const()));
805                     break;
806                 case EbtUint8:
807                     newConstArray[comp].setU8Const(std::min(childConstUnions[0][arg0comp].getU8Const(), childConstUnions[1][arg1comp].getU8Const()));
808                     break;
809                 case EbtInt16:
810                     newConstArray[comp].setI16Const(std::min(childConstUnions[0][arg0comp].getI16Const(), childConstUnions[1][arg1comp].getI16Const()));
811                     break;
812                 case EbtUint16:
813                     newConstArray[comp].setU16Const(std::min(childConstUnions[0][arg0comp].getU16Const(), childConstUnions[1][arg1comp].getU16Const()));
814                     break;
815                 case EbtInt:
816                     newConstArray[comp].setIConst(std::min(childConstUnions[0][arg0comp].getIConst(), childConstUnions[1][arg1comp].getIConst()));
817                     break;
818                 case EbtUint:
819                     newConstArray[comp].setUConst(std::min(childConstUnions[0][arg0comp].getUConst(), childConstUnions[1][arg1comp].getUConst()));
820                     break;
821                 case EbtInt64:
822                     newConstArray[comp].setI64Const(std::min(childConstUnions[0][arg0comp].getI64Const(), childConstUnions[1][arg1comp].getI64Const()));
823                     break;
824                 case EbtUint64:
825                     newConstArray[comp].setU64Const(std::min(childConstUnions[0][arg0comp].getU64Const(), childConstUnions[1][arg1comp].getU64Const()));
826                     break;
827                 default: assert(false && "Default missing");
828                 }
829                 break;
830             case EOpMax:
831                 switch(children[0]->getAsTyped()->getBasicType()) {
832                 case EbtFloat16:
833                 case EbtFloat:
834                 case EbtDouble:
835                     newConstArray[comp].setDConst(std::max(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()));
836                     break;
837                 case EbtInt8:
838                     newConstArray[comp].setI8Const(std::max(childConstUnions[0][arg0comp].getI8Const(), childConstUnions[1][arg1comp].getI8Const()));
839                     break;
840                 case EbtUint8:
841                     newConstArray[comp].setU8Const(std::max(childConstUnions[0][arg0comp].getU8Const(), childConstUnions[1][arg1comp].getU8Const()));
842                     break;
843                 case EbtInt16:
844                     newConstArray[comp].setI16Const(std::max(childConstUnions[0][arg0comp].getI16Const(), childConstUnions[1][arg1comp].getI16Const()));
845                     break;
846                 case EbtUint16:
847                     newConstArray[comp].setU16Const(std::max(childConstUnions[0][arg0comp].getU16Const(), childConstUnions[1][arg1comp].getU16Const()));
848                     break;
849                 case EbtInt:
850                     newConstArray[comp].setIConst(std::max(childConstUnions[0][arg0comp].getIConst(), childConstUnions[1][arg1comp].getIConst()));
851                     break;
852                 case EbtUint:
853                     newConstArray[comp].setUConst(std::max(childConstUnions[0][arg0comp].getUConst(), childConstUnions[1][arg1comp].getUConst()));
854                     break;
855                 case EbtInt64:
856                     newConstArray[comp].setI64Const(std::max(childConstUnions[0][arg0comp].getI64Const(), childConstUnions[1][arg1comp].getI64Const()));
857                     break;
858                 case EbtUint64:
859                     newConstArray[comp].setU64Const(std::max(childConstUnions[0][arg0comp].getU64Const(), childConstUnions[1][arg1comp].getU64Const()));
860                     break;
861                 default: assert(false && "Default missing");
862                 }
863                 break;
864             case EOpClamp:
865                 switch(children[0]->getAsTyped()->getBasicType()) {
866                 case EbtFloat16:
867                 case EbtFloat:
868                 case EbtDouble:
869                     newConstArray[comp].setDConst(std::min(std::max(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()),
870                                                                                                                childConstUnions[2][arg2comp].getDConst()));
871                     break;
872                 case EbtInt8:
873                     newConstArray[comp].setI8Const(std::min(std::max(childConstUnions[0][arg0comp].getI8Const(), childConstUnions[1][arg1comp].getI8Const()),
874                                                                                                                    childConstUnions[2][arg2comp].getI8Const()));
875                     break;
876                 case EbtUint8:
877                      newConstArray[comp].setU8Const(std::min(std::max(childConstUnions[0][arg0comp].getU8Const(), childConstUnions[1][arg1comp].getU8Const()),
878                                                                                                                    childConstUnions[2][arg2comp].getU8Const()));
879                     break;
880                 case EbtInt16:
881                     newConstArray[comp].setI16Const(std::min(std::max(childConstUnions[0][arg0comp].getI16Const(), childConstUnions[1][arg1comp].getI16Const()),
882                                                                                                                    childConstUnions[2][arg2comp].getI16Const()));
883                     break;
884                 case EbtUint16:
885                     newConstArray[comp].setU16Const(std::min(std::max(childConstUnions[0][arg0comp].getU16Const(), childConstUnions[1][arg1comp].getU16Const()),
886                                                                                                                    childConstUnions[2][arg2comp].getU16Const()));
887                     break;
888                 case EbtInt:
889                     newConstArray[comp].setIConst(std::min(std::max(childConstUnions[0][arg0comp].getIConst(), childConstUnions[1][arg1comp].getIConst()),
890                                                                                                                    childConstUnions[2][arg2comp].getIConst()));
891                     break;
892                 case EbtUint:
893                     newConstArray[comp].setUConst(std::min(std::max(childConstUnions[0][arg0comp].getUConst(), childConstUnions[1][arg1comp].getUConst()),
894                                                                                                                    childConstUnions[2][arg2comp].getUConst()));
895                     break;
896                 case EbtInt64:
897                     newConstArray[comp].setI64Const(std::min(std::max(childConstUnions[0][arg0comp].getI64Const(), childConstUnions[1][arg1comp].getI64Const()),
898                                                                                                                        childConstUnions[2][arg2comp].getI64Const()));
899                     break;
900                 case EbtUint64:
901                     newConstArray[comp].setU64Const(std::min(std::max(childConstUnions[0][arg0comp].getU64Const(), childConstUnions[1][arg1comp].getU64Const()),
902                                                                                                                        childConstUnions[2][arg2comp].getU64Const()));
903                     break;
904                 default: assert(false && "Default missing");
905                 }
906                 break;
907             case EOpLessThan:
908                 newConstArray[comp].setBConst(childConstUnions[0][arg0comp] < childConstUnions[1][arg1comp]);
909                 break;
910             case EOpGreaterThan:
911                 newConstArray[comp].setBConst(childConstUnions[0][arg0comp] > childConstUnions[1][arg1comp]);
912                 break;
913             case EOpLessThanEqual:
914                 newConstArray[comp].setBConst(! (childConstUnions[0][arg0comp] > childConstUnions[1][arg1comp]));
915                 break;
916             case EOpGreaterThanEqual:
917                 newConstArray[comp].setBConst(! (childConstUnions[0][arg0comp] < childConstUnions[1][arg1comp]));
918                 break;
919             case EOpVectorEqual:
920                 newConstArray[comp].setBConst(childConstUnions[0][arg0comp] == childConstUnions[1][arg1comp]);
921                 break;
922             case EOpVectorNotEqual:
923                 newConstArray[comp].setBConst(childConstUnions[0][arg0comp] != childConstUnions[1][arg1comp]);
924                 break;
925             case EOpMix:
926                 if (children[2]->getAsTyped()->getBasicType() == EbtBool)
927                     newConstArray[comp].setDConst(childConstUnions[2][arg2comp].getBConst() ? childConstUnions[1][arg1comp].getDConst() :
928                                                                                               childConstUnions[0][arg0comp].getDConst());
929                 else
930                     newConstArray[comp].setDConst(childConstUnions[0][arg0comp].getDConst() * (1.0 - childConstUnions[2][arg2comp].getDConst()) +
931                                                   childConstUnions[1][arg1comp].getDConst() *        childConstUnions[2][arg2comp].getDConst());
932                 break;
933             case EOpStep:
934                 newConstArray[comp].setDConst(childConstUnions[1][arg1comp].getDConst() < childConstUnions[0][arg0comp].getDConst() ? 0.0 : 1.0);
935                 break;
936             case EOpSmoothStep:
937             {
938                 double t = (childConstUnions[2][arg2comp].getDConst() - childConstUnions[0][arg0comp].getDConst()) /
939                            (childConstUnions[1][arg1comp].getDConst() - childConstUnions[0][arg0comp].getDConst());
940                 if (t < 0.0)
941                     t = 0.0;
942                 if (t > 1.0)
943                     t = 1.0;
944                 newConstArray[comp].setDConst(t * t * (3.0 - 2.0 * t));
945                 break;
946             }
947             default:
948                 return aggrNode;
949             }
950         }
951     } else {
952         // Non-componentwise...
953
954         int numComps = children[0]->getAsConstantUnion()->getType().computeNumComponents();
955         double dot;
956
957         switch (aggrNode->getOp()) {
958         case EOpDistance:
959         {
960             double sum = 0.0;
961             for (int comp = 0; comp < numComps; ++comp) {
962                 double diff = childConstUnions[1][comp].getDConst() - childConstUnions[0][comp].getDConst();
963                 sum += diff * diff;
964             }
965             newConstArray[0].setDConst(sqrt(sum));
966             break;
967         }
968         case EOpDot:
969             newConstArray[0].setDConst(childConstUnions[0].dot(childConstUnions[1]));
970             break;
971         case EOpCross:
972             newConstArray[0] = childConstUnions[0][1] * childConstUnions[1][2] - childConstUnions[0][2] * childConstUnions[1][1];
973             newConstArray[1] = childConstUnions[0][2] * childConstUnions[1][0] - childConstUnions[0][0] * childConstUnions[1][2];
974             newConstArray[2] = childConstUnions[0][0] * childConstUnions[1][1] - childConstUnions[0][1] * childConstUnions[1][0];
975             break;
976         case EOpFaceForward:
977             // If dot(Nref, I) < 0 return N, otherwise return -N:  Arguments are (N, I, Nref).
978             dot = childConstUnions[1].dot(childConstUnions[2]);
979             for (int comp = 0; comp < numComps; ++comp) {
980                 if (dot < 0.0)
981                     newConstArray[comp] = childConstUnions[0][comp];
982                 else
983                     newConstArray[comp].setDConst(-childConstUnions[0][comp].getDConst());
984             }
985             break;
986         case EOpReflect:
987             // I - 2 * dot(N, I) * N:  Arguments are (I, N).
988             dot = childConstUnions[0].dot(childConstUnions[1]);
989             dot *= 2.0;
990             for (int comp = 0; comp < numComps; ++comp)
991                 newConstArray[comp].setDConst(childConstUnions[0][comp].getDConst() - dot * childConstUnions[1][comp].getDConst());
992             break;
993         case EOpRefract:
994         {
995             // Arguments are (I, N, eta).
996             // k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
997             // if (k < 0.0)
998             //     return dvec(0.0)
999             // else
1000             //     return eta * I - (eta * dot(N, I) + sqrt(k)) * N
1001             dot = childConstUnions[0].dot(childConstUnions[1]);
1002             double eta = childConstUnions[2][0].getDConst();
1003             double k = 1.0 - eta * eta * (1.0 - dot * dot);
1004             if (k < 0.0) {
1005                 for (int comp = 0; comp < numComps; ++comp)
1006                     newConstArray[comp].setDConst(0.0);
1007             } else {
1008                 for (int comp = 0; comp < numComps; ++comp)
1009                     newConstArray[comp].setDConst(eta * childConstUnions[0][comp].getDConst() - (eta * dot + sqrt(k)) * childConstUnions[1][comp].getDConst());
1010             }
1011             break;
1012         }
1013         case EOpOuterProduct:
1014         {
1015             int numRows = numComps;
1016             int numCols = children[1]->getAsConstantUnion()->getType().computeNumComponents();
1017             for (int row = 0; row < numRows; ++row)
1018                 for (int col = 0; col < numCols; ++col)
1019                     newConstArray[col * numRows + row] = childConstUnions[0][row] * childConstUnions[1][col];
1020             break;
1021         }
1022         default:
1023             return aggrNode;
1024         }
1025     }
1026
1027     TIntermConstantUnion *newNode = new TIntermConstantUnion(newConstArray, aggrNode->getType());
1028     newNode->getWritableType().getQualifier().storage = EvqConst;
1029     newNode->setLoc(aggrNode->getLoc());
1030
1031     return newNode;
1032 }
1033
1034 bool TIntermediate::areAllChildConst(TIntermAggregate* aggrNode)
1035 {
1036     bool allConstant = true;
1037
1038     // check if all the child nodes are constants so that they can be inserted into
1039     // the parent node
1040     if (aggrNode) {
1041         TIntermSequence& childSequenceVector = aggrNode->getSequence();
1042         for (TIntermSequence::iterator p  = childSequenceVector.begin();
1043                                        p != childSequenceVector.end(); p++) {
1044             if (!(*p)->getAsTyped()->getAsConstantUnion())
1045                 return false;
1046         }
1047     }
1048
1049     return allConstant;
1050 }
1051
1052 TIntermTyped* TIntermediate::foldConstructor(TIntermAggregate* aggrNode)
1053 {
1054     bool error = false;
1055
1056     TConstUnionArray unionArray(aggrNode->getType().computeNumComponents());
1057     if (aggrNode->getSequence().size() == 1)
1058         error = parseConstTree(aggrNode, unionArray, aggrNode->getOp(), aggrNode->getType(), true);
1059     else
1060         error = parseConstTree(aggrNode, unionArray, aggrNode->getOp(), aggrNode->getType());
1061
1062     if (error)
1063         return aggrNode;
1064
1065     return addConstantUnion(unionArray, aggrNode->getType(), aggrNode->getLoc());
1066 }
1067
1068 //
1069 // Constant folding of a bracket (array-style) dereference or struct-like dot
1070 // dereference.  Can handle anything except a multi-character swizzle, though
1071 // all swizzles may go to foldSwizzle().
1072 //
1073 TIntermTyped* TIntermediate::foldDereference(TIntermTyped* node, int index, const TSourceLoc& loc)
1074 {
1075     TType dereferencedType(node->getType(), index);
1076     dereferencedType.getQualifier().storage = EvqConst;
1077     TIntermTyped* result = 0;
1078     int size = dereferencedType.computeNumComponents();
1079
1080     // arrays, vectors, matrices, all use simple multiplicative math
1081     // while structures need to add up heterogeneous members
1082     int start;
1083     if (node->isArray() || ! node->isStruct())
1084         start = size * index;
1085     else {
1086         // it is a structure
1087         assert(node->isStruct());
1088         start = 0;
1089         for (int i = 0; i < index; ++i)
1090             start += (*node->getType().getStruct())[i].type->computeNumComponents();
1091     }
1092
1093     result = addConstantUnion(TConstUnionArray(node->getAsConstantUnion()->getConstArray(), start, size), node->getType(), loc);
1094
1095     if (result == 0)
1096         result = node;
1097     else
1098         result->setType(dereferencedType);
1099
1100     return result;
1101 }
1102
1103 //
1104 // Make a constant vector node or constant scalar node, representing a given
1105 // constant vector and constant swizzle into it.
1106 //
1107 TIntermTyped* TIntermediate::foldSwizzle(TIntermTyped* node, TSwizzleSelectors<TVectorSelector>& selectors, const TSourceLoc& loc)
1108 {
1109     const TConstUnionArray& unionArray = node->getAsConstantUnion()->getConstArray();
1110     TConstUnionArray constArray(selectors.size());
1111
1112     for (int i = 0; i < selectors.size(); i++)
1113         constArray[i] = unionArray[selectors[i]];
1114
1115     TIntermTyped* result = addConstantUnion(constArray, node->getType(), loc);
1116
1117     if (result == 0)
1118         result = node;
1119     else
1120         result->setType(TType(node->getBasicType(), EvqConst, selectors.size()));
1121
1122     return result;
1123 }
1124
1125 } // end namespace glslang