added constant folding & branch elimination to skslc
authorethannicholas <ethannicholas@google.com>
Wed, 9 Nov 2016 21:26:45 +0000 (13:26 -0800)
committerCommit bot <commit-bot@chromium.org>
Wed, 9 Nov 2016 21:26:45 +0000 (13:26 -0800)
BUG=skia:
GOLD_TRYBOT_URL= https://gold.skia.org/search?issue=2489673002

Committed: https://skia.googlesource.com/skia/+/6136310ee8f43247548bcefcaeca6d43023c10aa
Review-Url: https://codereview.chromium.org/2489673002

src/sksl/SkSLIRGenerator.cpp
src/sksl/SkSLIRGenerator.h
tests/SkSLErrorTest.cpp
tests/SkSLGLSLTest.cpp

index ec64fa9..1a4c775 100644 (file)
@@ -246,6 +246,19 @@ std::unique_ptr<Statement> IRGenerator::convertIf(const ASTIfStatement& s) {
             return nullptr;
         }
     }
+    if (test->fKind == Expression::kBoolLiteral_Kind) {
+        // static boolean value, fold down to a single branch
+        if (((BoolLiteral&) *test).fValue) {
+            return ifTrue;
+        } else if (s.fIfFalse) {
+            return ifFalse;
+        } else {
+            // False & no else clause. Not an error, so don't return null!
+            std::vector<std::unique_ptr<Statement>> empty;
+            return std::unique_ptr<Statement>(new Block(s.fPosition, std::move(empty), 
+                                                        fSymbolTable));
+        }
+    }
     return std::unique_ptr<Statement>(new IfStatement(s.fPosition, std::move(test), 
                                                       std::move(ifTrue), std::move(ifFalse)));
 }
@@ -794,6 +807,78 @@ static bool determine_binary_type(const Context& context,
     return false;
 }
 
+/**
+ * If both operands are compile-time constants and can be folded, returns an expression representing
+ * the folded value. Otherwise, returns null. Note that unlike most other functions here, null does
+ * not represent a compilation error.
+ */
+std::unique_ptr<Expression> IRGenerator::constantFold(const Expression& left,
+                                                      Token::Kind op,
+                                                      const Expression& right) {
+    // Note that we expressly do not worry about precision and overflow here -- we use the maximum
+    // precision to calculate the results and hope the result makes sense. The plan is to move the
+    // Skia caps into SkSL, so we have access to all of them including the precisions of the various
+    // types, which will let us be more intelligent about this.
+    if (left.fKind == Expression::kBoolLiteral_Kind && 
+        right.fKind == Expression::kBoolLiteral_Kind) {
+        bool leftVal  = ((BoolLiteral&) left).fValue;
+        bool rightVal = ((BoolLiteral&) right).fValue;
+        bool result;
+        switch (op) {
+            case Token::LOGICALAND: result = leftVal && rightVal; break;
+            case Token::LOGICALOR:  result = leftVal || rightVal; break;
+            case Token::LOGICALXOR: result = leftVal ^  rightVal; break;
+            default: return nullptr;
+        }
+        return std::unique_ptr<Expression>(new BoolLiteral(fContext, left.fPosition, result));
+    }
+    #define RESULT(t, op) std::unique_ptr<Expression>(new t ## Literal(fContext, left.fPosition, \
+                                                                       leftVal op rightVal))
+    if (left.fKind == Expression::kIntLiteral_Kind && right.fKind == Expression::kIntLiteral_Kind) {
+        int64_t leftVal  = ((IntLiteral&) left).fValue;
+        int64_t rightVal = ((IntLiteral&) right).fValue;
+        switch (op) {
+            case Token::PLUS:       return RESULT(Int,  +);
+            case Token::MINUS:      return RESULT(Int,  -);
+            case Token::STAR:       return RESULT(Int,  *);
+            case Token::SLASH:      return RESULT(Int,  /);
+            case Token::PERCENT:    return RESULT(Int,  %);
+            case Token::BITWISEAND: return RESULT(Int,  &);
+            case Token::BITWISEOR:  return RESULT(Int,  |);
+            case Token::BITWISEXOR: return RESULT(Int,  ^);
+            case Token::SHL:        return RESULT(Int,  <<);
+            case Token::SHR:        return RESULT(Int,  >>);
+            case Token::EQEQ:       return RESULT(Bool, ==);
+            case Token::NEQ:        return RESULT(Bool, !=);
+            case Token::GT:         return RESULT(Bool, >);
+            case Token::GTEQ:       return RESULT(Bool, >=);
+            case Token::LT:         return RESULT(Bool, <);
+            case Token::LTEQ:       return RESULT(Bool, <=);
+            default:                return nullptr;
+        }
+    }
+    if (left.fKind == Expression::kFloatLiteral_Kind && 
+        right.fKind == Expression::kFloatLiteral_Kind) {
+        double leftVal  = ((FloatLiteral&) left).fValue;
+        double rightVal = ((FloatLiteral&) right).fValue;
+        switch (op) {
+            case Token::PLUS:       return RESULT(Float, +);
+            case Token::MINUS:      return RESULT(Float, -);
+            case Token::STAR:       return RESULT(Float, *);
+            case Token::SLASH:      return RESULT(Float, /);
+            case Token::EQEQ:       return RESULT(Bool,  ==);
+            case Token::NEQ:        return RESULT(Bool,  !=);
+            case Token::GT:         return RESULT(Bool,  >);
+            case Token::GTEQ:       return RESULT(Bool,  >=);
+            case Token::LT:         return RESULT(Bool,  <);
+            case Token::LTEQ:       return RESULT(Bool,  <=);
+            default:                return nullptr;
+        }
+    }
+    #undef RESULT
+    return nullptr;
+}
+
 std::unique_ptr<Expression> IRGenerator::convertBinaryExpression(
                                                             const ASTBinaryExpression& expression) {
     std::unique_ptr<Expression> left = this->convertExpression(*expression.fLeft);
@@ -823,11 +908,16 @@ std::unique_ptr<Expression> IRGenerator::convertBinaryExpression(
     if (!left || !right) {
         return nullptr;
     }
-    return std::unique_ptr<Expression>(new BinaryExpression(expression.fPosition, 
-                                                            std::move(left), 
-                                                            expression.fOperator, 
-                                                            std::move(right), 
-                                                            *resultType));
+    std::unique_ptr<Expression> result = this->constantFold(*left.get(), expression.fOperator, 
+                                                            *right.get());
+    if (!result) {
+        result = std::unique_ptr<Expression>(new BinaryExpression(expression.fPosition,
+                                                                  std::move(left),
+                                                                  expression.fOperator,
+                                                                  std::move(right),
+                                                                  *resultType));
+    }
+    return result;
 }
 
 std::unique_ptr<Expression> IRGenerator::convertTernaryExpression(  
@@ -858,6 +948,14 @@ std::unique_ptr<Expression> IRGenerator::convertTernaryExpression(
     ASSERT(trueType == falseType);
     ifTrue = this->coerce(std::move(ifTrue), *trueType);
     ifFalse = this->coerce(std::move(ifFalse), *falseType);
+    if (test->fKind == Expression::kBoolLiteral_Kind) {
+        // static boolean test, just return one of the branches
+        if (((BoolLiteral&) *test).fValue) {
+            return ifTrue;
+        } else {
+            return ifFalse;
+        }
+    }
     return std::unique_ptr<Expression>(new TernaryExpression(expression.fPosition, 
                                                              std::move(test),
                                                              std::move(ifTrue), 
@@ -1126,6 +1224,10 @@ std::unique_ptr<Expression> IRGenerator::convertPrefixExpression(
                               "' cannot operate on '" + base->fType.description() + "'");
                 return nullptr;
             }
+            if (base->fKind == Expression::kBoolLiteral_Kind) {
+                return std::unique_ptr<Expression>(new BoolLiteral(fContext, base->fPosition,
+                                                                   !((BoolLiteral&) *base).fValue));
+            }
             break;
         case Token::BITWISENOT:
             if (base->fType != *fContext.fInt_Type) {
index a1b86f4..036f242 100644 (file)
@@ -89,6 +89,11 @@ private:
     std::unique_ptr<Statement> convertDiscard(const ASTDiscardStatement& d);
     std::unique_ptr<Statement> convertDo(const ASTDoStatement& d);
     std::unique_ptr<Expression> convertBinaryExpression(const ASTBinaryExpression& expression);
+    // Returns null if it cannot fold the expression. Note that unlike most other functions here, a 
+    // null return does not represent a compilation error.
+    std::unique_ptr<Expression> constantFold(const Expression& left,
+                                             Token::Kind op,
+                                             const Expression& right);
     std::unique_ptr<Extension> convertExtension(const ASTExtension& e);
     std::unique_ptr<Statement> convertExpressionStatement(const ASTExpressionStatement& s);
     std::unique_ptr<Statement> convertFor(const ASTForStatement& f);
index 9e89b71..fc20fa2 100644 (file)
@@ -352,3 +352,19 @@ DEF_TEST(SkSLContinueOutsideLoop, r) {
                  "void foo() { for(;;); continue; }",
                  "error: 1: continue statement must be inside a loop\n1 error\n");
 }
+
+DEF_TEST(SkSLStaticIfError, r) {
+    // ensure eliminated branch of static if / ternary is still checked for errors
+    test_failure(r,
+                 "void foo() { if (true); else x = 5; }",
+                 "error: 1: unknown identifier 'x'\n1 error\n");
+    test_failure(r,
+                 "void foo() { if (false) x = 5; }",
+                 "error: 1: unknown identifier 'x'\n1 error\n");
+    test_failure(r,
+                 "void foo() { true ? 5 : x; }",
+                 "error: 1: unknown identifier 'x'\n1 error\n");
+    test_failure(r,
+                 "void foo() { false ? x : 5; }",
+                 "error: 1: unknown identifier 'x'\n1 error\n");
+}
index 610ff2b..38fce87 100644 (file)
@@ -43,7 +43,7 @@ DEF_TEST(SkSLHelloWorld, r) {
 DEF_TEST(SkSLControl, r) {
     test(r,
          "void main() {"
-         "if (1 + 2 + 3 > 5) { sk_FragColor = vec4(0.75); } else { discard; }"
+         "if (sqrt(2) > 5) { sk_FragColor = vec4(0.75); } else { discard; }"
          "int i = 0;"
          "while (i < 10) sk_FragColor *= 0.5;"
          "do { sk_FragColor += 0.01; } while (sk_FragColor.x < 0.7);"
@@ -55,7 +55,7 @@ DEF_TEST(SkSLControl, r) {
          default_caps(),
          "#version 400\n"
          "void main() {\n"
-         "    if ((1 + 2) + 3 > 5) {\n"
+         "    if (sqrt(2.0) > 5.0) {\n"
          "        gl_FragColor = vec4(0.75);\n"
          "    } else {\n"
          "        discard;\n"
@@ -104,7 +104,7 @@ DEF_TEST(SkSLOperators, r) {
          "x = x + y * z * x * (y - z);"
          "y = x / y / z;"
          "z = (z / 2 % 3 << 4) >> 2 << 1;"
-         "bool b = (x > 4) == x < 2 || 2 >= 5 && y <= z && 12 != 11;"
+         "bool b = (x > 4) == x < 2 || 2 >= sqrt(2) && y <= z;"
          "x += 12;"
          "x -= 12;"
          "x *= y /= z = 10;"
@@ -126,7 +126,7 @@ DEF_TEST(SkSLOperators, r) {
          "    x = x + ((y * float(z)) * x) * (y - float(z));\n"
          "    y = (x / y) / float(z);\n"
          "    z = (((z / 2) % 3 << 4) >> 2) << 1;\n"
-         "    bool b = x > 4.0 == x < 2.0 || (2 >= 5 && y <= float(z)) && 12 != 11;\n"
+         "    bool b = x > 4.0 == x < 2.0 || 2.0 >= sqrt(2.0) && y <= float(z);\n"
          "    x += 12.0;\n"
          "    x -= 12.0;\n"
          "    x *= (y /= float(z = 10));\n"
@@ -430,3 +430,100 @@ DEF_TEST(SkSLDerivatives, r) {
          "    float x = dFdx(1.0);\n"
          "}\n");
 }
+
+DEF_TEST(SkSLConstantFolding, r) {
+    test(r,
+         "void main() {"
+         "float f_add = 32 + 2;"
+         "float f_sub = 32 - 2;"
+         "float f_mul = 32 * 2;"
+         "float f_div = 32 / 2;"
+         "float mixed = (12 > 2.0) ? (10 * 2 / 5 + 18 - 3) : 0;"
+         "int i_add = 32 + 2;"
+         "int i_sub = 32 - 2;"
+         "int i_mul = 32 * 2;"
+         "int i_div = 32 / 2;"
+         "int i_or = 12 | 6;"
+         "int i_and = 254 & 7;"
+         "int i_xor = 2 ^ 7;"
+         "int i_shl = 1 << 4;"
+         "int i_shr = 128 >> 2;"
+         "bool gt_it = 6 > 5;"
+         "bool gt_if = 6 > 6;"
+         "bool gt_ft = 6.0 > 5.0;"
+         "bool gt_ff = 6.0 > 6.0;"
+         "bool gte_it = 6 >= 6;"
+         "bool gte_if = 6 >= 7;"
+         "bool gte_ft = 6.0 >= 6.0;"
+         "bool gte_ff = 6.0 >= 7.0;"
+         "bool lte_it = 6 <= 6;"
+         "bool lte_if = 6 <= 5;"
+         "bool lte_ft = 6.0 <= 6.0;"
+         "bool lte_ff = 6.0 <= 5.0;"
+         "bool or_t = 1 == 1 || 2 == 8;"
+         "bool or_f = 1 > 1 || 2 == 8;"
+         "bool and_t = 1 == 1 && 2 <= 8;"
+         "bool and_f = 1 == 2 && 2 == 8;"
+         "bool xor_t = 1 == 1 ^^ 1 != 1;"
+         "bool xor_f = 1 == 1 ^^ 1 == 1;"
+         "int ternary = 10 > 5 ? 10 : 5;"
+         "}",
+         default_caps(),
+         "#version 400\n"
+         "void main() {\n"
+         "    float f_add = 34.0;\n"
+         "    float f_sub = 30.0;\n"
+         "    float f_mul = 64.0;\n"
+         "    float f_div = 16.0;\n"
+         "    float mixed = 19.0;\n"
+         "    int i_add = 34;\n"
+         "    int i_sub = 30;\n"
+         "    int i_mul = 64;\n"
+         "    int i_div = 16;\n"
+         "    int i_or = 14;\n"
+         "    int i_and = 6;\n"
+         "    int i_xor = 5;\n"
+         "    int i_shl = 16;\n"
+         "    int i_shr = 32;\n"
+         "    bool gt_it = true;\n"
+         "    bool gt_if = false;\n"
+         "    bool gt_ft = true;\n"
+         "    bool gt_ff = false;\n"
+         "    bool gte_it = true;\n"
+         "    bool gte_if = false;\n"
+         "    bool gte_ft = true;\n"
+         "    bool gte_ff = false;\n"
+         "    bool lte_it = true;\n"
+         "    bool lte_if = false;\n"
+         "    bool lte_ft = true;\n"
+         "    bool lte_ff = false;\n"
+         "    bool or_t = true;\n"
+         "    bool or_f = false;\n"
+         "    bool and_t = true;\n"
+         "    bool and_f = false;\n"
+         "    bool xor_t = true;\n"
+         "    bool xor_f = false;\n"
+         "    int ternary = 10;\n"
+         "}\n");
+}
+
+DEF_TEST(SkSLStaticIf, r) {
+    test(r,
+         "void main() {"
+         "int x;"
+         "if (true) x = 1;"
+         "if (2 > 1) x = 2; else x = 3;"
+         "if (1 > 2) x = 4; else x = 5;"
+         "if (false) x = 6;"
+         "}",
+         default_caps(),
+         "#version 400\n"
+         "void main() {\n"
+         "    int x;\n"
+         "    x = 1;\n"
+         "    x = 2;\n"
+         "    x = 5;\n"
+         "    {\n"
+         "    }\n"
+         "}\n");
+}