[flang] Better folding.
authorpeter klausler <pklausler@nvidia.com>
Fri, 13 Jul 2018 18:35:30 +0000 (11:35 -0700)
committerpeter klausler <pklausler@nvidia.com>
Tue, 24 Jul 2018 21:33:47 +0000 (14:33 -0700)
Original-commit: flang-compiler/f18@4a3117968aabf5462dd25ff67a09073f3ef1c01a
Reviewed-on: https://github.com/flang-compiler/f18/pull/144
Tree-same-pre-rewrite: false

flang/lib/evaluate/expression.cc
flang/lib/evaluate/expression.h

index 6a83b32..84cb173 100644 (file)
@@ -55,13 +55,13 @@ std::ostream &GenericExpr::Dump(std::ostream &o) const {
   return DumpExpr(o, u);
 }
 
-template<typename A>
-std::ostream &Unary<A>::Dump(std::ostream &o, const char *opr) const {
+template<typename A, typename CONST>
+std::ostream &Unary<A, CONST>::Dump(std::ostream &o, const char *opr) const {
   return operand().Dump(o << opr) << ')';
 }
 
-template<typename A, typename B>
-std::ostream &Binary<A, B>::Dump(
+template<typename A, typename B, typename CONST>
+std::ostream &Binary<A, B, CONST>::Dump(
     std::ostream &o, const char *opr, const char *before) const {
   return right().Dump(left().Dump(o << before) << opr) << ')';
 }
@@ -191,9 +191,22 @@ SubscriptIntegerExpr Expr<Category::Character, KIND>::LEN() const {
       u_);
 }
 
+template<typename A, typename CONST>
+std::optional<CONST> Unary<A, CONST>::Fold(FoldingContext &context) {
+  operand_->Fold(context);
+  return {};
+}
+
+template<typename A, typename B, typename CONST>
+std::optional<CONST> Binary<A, B, CONST>::Fold(FoldingContext &context) {
+  left_->Fold(context);
+  right_->Fold(context);
+  return {};
+}
+
 template<int KIND>
-std::optional<typename Expr<Category::Integer, KIND>::Constant>
-Expr<Category::Integer, KIND>::ConstantValue() const {
+std::optional<typename IntegerExpr<KIND>::Constant>
+IntegerExpr<KIND>::ConstantValue() const {
   if (auto c{std::get_if<Constant>(&u_)}) {
     return {*c};
   }
@@ -201,59 +214,68 @@ Expr<Category::Integer, KIND>::ConstantValue() const {
 }
 
 template<int KIND>
-void Expr<Category::Integer, KIND>::Fold(FoldingContext &context) {
-  std::visit(common::visitors{[&](Parentheses &p) {
-                                p.operand().Fold(context);
-                                if (auto c{p.operand().ConstantValue()}) {
-                                  u_ = std::move(*c);
-                                }
-                              },
-                 [&](Negate &n) {
-                   n.operand().Fold(context);
-                   if (auto c{n.operand().ConstantValue()}) {
-                     auto negated{c->Negate()};
-                     if (negated.overflow && context.messages != nullptr) {
-                       context.messages->Say(
-                           context.at, "integer negation overflowed"_en_US);
-                     }
-                     u_ = std::move(negated.value);
-                   }
-                 },
-                 [&](Add &a) {
-                   a.left().Fold(context);
-                   a.right().Fold(context);
-                   if (auto xc{a.left().ConstantValue()}) {
-                     if (auto yc{a.right().ConstantValue()}) {
-                       auto sum{xc->AddSigned(*yc)};
-                       if (sum.overflow && context.messages != nullptr) {
-                         context.messages->Say(
-                             context.at, "integer addition overflowed"_en_US);
-                       }
-                       u_ = std::move(sum.value);
-                     }
-                   }
-                 },
-                 [&](Multiply &a) {
-                   a.left().Fold(context);
-                   a.right().Fold(context);
-                   if (auto xc{a.left().ConstantValue()}) {
-                     if (auto yc{a.right().ConstantValue()}) {
-                       auto product{xc->MultiplySigned(*yc)};
-                       if (product.SignedMultiplicationOverflowed() &&
-                           context.messages != nullptr) {
-                         context.messages->Say(context.at,
-                             "integer multiplication overflowed"_en_US);
-                       }
-                       u_ = std::move(product.lower);
-                     }
-                   }
-                 },
-                 [&](Bin &b) {
-                   b.left().Fold(context);
-                   b.right().Fold(context);
-                 },
-                 [&](const auto &) {  // TODO: more
-                 }},
+std::optional<typename IntegerExpr<KIND>::Constant>
+IntegerExpr<KIND>::Negate::Fold(FoldingContext &context) {
+  if (auto c{operand().Fold(context)}) {
+    auto negated{c->Negate()};
+    if (negated.overflow && context.messages != nullptr) {
+      context.messages->Say(context.at, "integer negation overflowed"_en_US);
+    }
+    return {std::move(negated.value)};
+  }
+  return {};
+}
+
+template<int KIND>
+std::optional<typename IntegerExpr<KIND>::Constant>
+IntegerExpr<KIND>::Add::Fold(FoldingContext &context) {
+  auto lc{left().Fold(context)};
+  auto rc{right().Fold(context)};
+  if (lc && rc) {
+    auto sum{lc->AddSigned(*rc)};
+    if (sum.overflow && context.messages != nullptr) {
+      context.messages->Say(context.at, "integer addition overflowed"_en_US);
+    }
+    return {std::move(sum.value)};
+  }
+  return {};
+}
+
+template<int KIND>
+std::optional<typename IntegerExpr<KIND>::Constant>
+IntegerExpr<KIND>::Multiply::Fold(FoldingContext &context) {
+  auto lc{left().Fold(context)};
+  auto rc{right().Fold(context)};
+  if (lc && rc) {
+    auto product{lc->MultiplySigned(*rc)};
+    if (product.SignedMultiplicationOverflowed() &&
+        context.messages != nullptr) {
+      context.messages->Say(
+          context.at, "integer multiplication overflowed"_en_US);
+    }
+    return {std::move(product.lower)};
+  }
+  return {};
+}
+
+template<int KIND>
+std::optional<typename IntegerExpr<KIND>::Constant> IntegerExpr<KIND>::Fold(
+    FoldingContext &context) {
+  return std::visit(
+      [&](auto &x) -> std::optional<Constant> {
+        using Ty = typename std::decay<decltype(x)>::type;
+        if constexpr (std::is_same_v<Ty, Constant>) {
+          return {x};
+        }
+        if constexpr (std::is_base_of_v<Un, Ty> || std::is_base_of_v<Bin, Ty>) {
+          auto c{x.Fold(context)};
+          if (c.has_value()) {
+            u_ = *c;
+            return c;
+          }
+        }
+        return {};
+      },
       u_);
 }
 
index 361225e..f3a7e8c 100644 (file)
@@ -40,9 +40,10 @@ struct FoldingContext {
 };
 
 // Helper base classes for packaging subexpressions.
-template<typename A> class Unary {
+template<typename A, typename CONST = typename A::Constant> class Unary {
 public:
   using Operand = A;
+  using Constant = CONST;
   CLASS_BOILERPLATE(Unary)
   Unary(const A &a) : operand_{a} {}
   Unary(A &&a) : operand_{std::move(a)} {}
@@ -50,15 +51,18 @@ public:
   const A &operand() const { return *operand_; }
   A &operand() { return *operand_; }
   std::ostream &Dump(std::ostream &, const char *opr) const;
+  std::optional<CONST> Fold(FoldingContext &);
 
 private:
   CopyableIndirection<A> operand_;
 };
 
-template<typename A, typename B = A> class Binary {
+template<typename A, typename B = A, typename CONST = typename A::Constant>
+class Binary {
 public:
   using Left = A;
   using Right = B;
+  using Constant = CONST;
   CLASS_BOILERPLATE(Binary)
   Binary(const A &a, const B &b) : left_{a}, right_{b} {}
   Binary(A &&a, B &&b) : left_{std::move(a)}, right_{std::move(b)} {}
@@ -70,6 +74,7 @@ public:
   B &right() { return *right_; }
   std::ostream &Dump(
       std::ostream &, const char *opr, const char *before = "(") const;
+  std::optional<CONST> Fold(FoldingContext &);
 
 private:
   CopyableIndirection<A> left_;
@@ -80,37 +85,43 @@ template<int KIND> class Expr<Category::Integer, KIND> {
 public:
   using Result = Type<Category::Integer, KIND>;
   using Constant = typename Result::Value;
-  template<typename A> struct Convert : public Unary<A> {
-    using Unary<A>::Unary;
+  template<typename A> struct Convert : public Unary<A, Constant> {
+    using Unary<A, Constant>::Unary;
   };
-  using Un = Unary<Expr>;
-  using Bin = Binary<Expr>;
+  using Un = Unary<Expr, Constant>;
+  using Bin = Binary<Expr, Expr, Constant>;
   struct Parentheses : public Un {
-    using Un::Un;
+    using Un::Un, Un::operand;
+    std::optional<Constant> Fold(FoldingContext &c) {
+      return operand().Fold(c);
+    }
   };
   struct Negate : public Un {
-    using Un::Un;
+    using Un::Un, Un::operand;
+    std::optional<Constant> Fold(FoldingContext &);
   };
   struct Add : public Bin {
-    using Bin::Bin;
+    using Bin::Bin, Bin::left, Bin::right;
+    std::optional<Constant> Fold(FoldingContext &);
   };
   struct Subtract : public Bin {
-    using Bin::Bin;
+    using Bin::Bin, Bin::Fold;
   };
   struct Multiply : public Bin {
-    using Bin::Bin;
+    using Bin::Bin, Bin::left, Bin::right;
+    std::optional<Constant> Fold(FoldingContext &);
   };
   struct Divide : public Bin {
-    using Bin::Bin;
+    using Bin::Bin, Bin::Fold;
   };
   struct Power : public Bin {
-    using Bin::Bin;
+    using Bin::Bin, Bin::Fold;
   };
   struct Max : public Bin {
-    using Bin::Bin;
+    using Bin::Bin, Bin::Fold;
   };
   struct Min : public Bin {
-    using Bin::Bin;
+    using Bin::Bin, Bin::Fold;
   };
   // TODO: R916 type-param-inquiry
 
@@ -134,7 +145,7 @@ public:
   template<typename A> Expr(CopyableIndirection<A> &&x) : u_{std::move(x)} {}
 
   std::optional<Constant> ConstantValue() const;
-  void Fold(FoldingContext &);
+  std::optional<Constant> Fold(FoldingContext &c);
 
 private:
   std::variant<Constant, CopyableIndirection<DataRef>,
@@ -151,11 +162,11 @@ public:
   // N.B. Real->Complex and Complex->Real conversions are done with CMPLX
   // and part access operations (resp.).  Conversions between kinds of
   // Complex are done via decomposition to Real and reconstruction.
-  template<typename A> struct Convert : public Unary<A> {
-    using Unary<A>::Unary;
+  template<typename A> struct Convert : public Unary<A, Constant> {
+    using Unary<A, Constant>::Unary;
   };
-  using Un = Unary<Expr>;
-  using Bin = Binary<Expr>;
+  using Un = Unary<Expr, Constant>;
+  using Bin = Binary<Expr, Expr, Constant>;
   struct Parentheses : public Un {
     using Un::Un;
   };
@@ -177,8 +188,8 @@ public:
   struct Power : public Bin {
     using Bin::Bin;
   };
-  struct IntPower : public Binary<Expr, GenericIntegerExpr> {
-    using Binary<Expr, GenericIntegerExpr>::Binary;
+  struct IntPower : public Binary<Expr, GenericIntegerExpr, Constant> {
+    using Binary<Expr, GenericIntegerExpr, Constant>::Binary;
   };
   struct Max : public Bin {
     using Bin::Bin;
@@ -186,7 +197,7 @@ public:
   struct Min : public Bin {
     using Bin::Bin;
   };
-  using CplxUn = Unary<ComplexExpr<KIND>>;
+  using CplxUn = Unary<ComplexExpr<KIND>, Constant>;
   struct RealPart : public CplxUn {
     using CplxUn::CplxUn;
   };
@@ -220,8 +231,8 @@ template<int KIND> class Expr<Category::Complex, KIND> {
 public:
   using Result = Type<Category::Complex, KIND>;
   using Constant = typename Result::Value;
-  using Un = Unary<Expr>;
-  using Bin = Binary<Expr>;
+  using Un = Unary<Expr, Constant>;
+  using Bin = Binary<Expr, Expr, Constant>;
   struct Parentheses : public Un {
     using Un::Un;
   };
@@ -243,11 +254,11 @@ public:
   struct Power : public Bin {
     using Bin::Bin;
   };
-  struct IntPower : public Binary<Expr, GenericIntegerExpr> {
-    using Binary<Expr, GenericIntegerExpr>::Binary;
+  struct IntPower : public Binary<Expr, GenericIntegerExpr, Constant> {
+    using Binary<Expr, GenericIntegerExpr, Constant>::Binary;
   };
-  struct CMPLX : public Binary<RealExpr<KIND>> {
-    using Binary<RealExpr<KIND>>::Binary;
+  struct CMPLX : public Binary<RealExpr<KIND>, RealExpr<KIND>, Constant> {
+    using Binary<RealExpr<KIND>, RealExpr<KIND>, Constant>::Binary;
   };
 
   CLASS_BOILERPLATE(Expr)
@@ -268,7 +279,7 @@ template<int KIND> class Expr<Category::Character, KIND> {
 public:
   using Result = Type<Category::Character, KIND>;
   using Constant = typename Result::Value;
-  using Bin = Binary<Expr>;
+  using Bin = Binary<Expr, Expr, Constant>;
   struct Concat : public Bin {
     using Bin::Bin;
   };
@@ -301,12 +312,12 @@ private:
 // categories and kinds of comparable operands.
 ENUM_CLASS(RelationalOperator, LT, LE, EQ, NE, GE, GT)
 
-template<typename EXPR> struct Comparison : Binary<EXPR> {
+template<typename EXPR> struct Comparison : Binary<EXPR, EXPR, bool> {
   CLASS_BOILERPLATE(Comparison)
   Comparison(RelationalOperator r, const EXPR &a, const EXPR &b)
-    : Binary<EXPR>{a, b}, opr{r} {}
+    : Binary<EXPR, EXPR, bool>{a, b}, opr{r} {}
   Comparison(RelationalOperator r, EXPR &&a, EXPR &&b)
-    : Binary<EXPR>{std::move(a), std::move(b)}, opr{r} {}
+    : Binary<EXPR, EXPR, bool>{std::move(a), std::move(b)}, opr{r} {}
   RelationalOperator opr;
 };
 
@@ -343,10 +354,10 @@ template<Category CAT> struct CategoryComparison {
 template<> class Expr<Category::Logical, 1> {
 public:
   using Constant = bool;
-  struct Not : Unary<Expr> {
-    using Unary<Expr>::Unary;
+  struct Not : Unary<Expr, bool> {
+    using Unary<Expr, Constant>::Unary;
   };
-  using Bin = Binary<Expr, Expr>;
+  using Bin = Binary<Expr, Expr, bool>;
   struct And : public Bin {
     using Bin::Bin;
   };