[flang] add rewriting as well as const visitation
authorpeter klausler <pklausler@nvidia.com>
Thu, 7 Mar 2019 00:12:59 +0000 (16:12 -0800)
committerpeter klausler <pklausler@nvidia.com>
Thu, 7 Mar 2019 00:15:54 +0000 (16:15 -0800)
Original-commit: flang-compiler/f18@1224eaee85c65d70724385d222ba72541dd80213
Reviewed-on: https://github.com/flang-compiler/f18/pull/316
Tree-same-pre-rewrite: false

flang/lib/evaluate/fold.cc
flang/lib/evaluate/traversal.h
flang/lib/semantics/check-do-concurrent.cc

index 7f40d96..9a41c9a 100644 (file)
@@ -800,7 +800,7 @@ FOR_EACH_TYPE_AND_KIND(template class ExpressionBase)
 // the expression may reference derived type kind parameters whose values
 // are not yet known.
 
-class IsConstantExprVisitor : public virtual TraversalBase<bool> {
+class IsConstantExprVisitor : public virtual VisitorBase<bool> {
 public:
   explicit IsConstantExprVisitor(int) { result() = true; }
 
@@ -831,7 +831,7 @@ private:
 };
 
 bool IsConstantExpr(const Expr<SomeType> &expr) {
-  return Traversal<bool, IsConstantExprVisitor>{0}.Traverse(expr);
+  return Visitor<bool, IsConstantExprVisitor>{0}.Traverse(expr);
 }
 
 std::optional<std::int64_t> ToInt64(const Expr<SomeInteger> &expr) {
index 9ce7c01..ca79ba3 100644 (file)
 #ifndef FORTRAN_EVALUATE_TRAVERSAL_H_
 #define FORTRAN_EVALUATE_TRAVERSAL_H_
 
-#include "expression.h"
+#include "descender.h"
 #include <type_traits>
 
 // Implements an expression traversal utility framework.
 // See fold.cc to see how this framework is used to implement detection
 // of constant expressions.
 //
-// To use, define one or more client visitation classes of the form:
-//   class MyVisitor : public virtual TraversalBase<RESULT> {
+// To use for non-mutating visitation, define one or more client visitation
+// classes of the form:
+//   class MyVisitor : public virtual VisitorBase<RESULT> {
 //     explicit MyVisitor(ARGTYPE);  // single-argument constructor
 //     void Handle(const T1 &);  // callback for type T1 objects
 //     void Pre(const T2 &);  // callback before visiting T2
 //     ...
 //   };
 // RESULT should have some default-constructible type.
-// Then instantiate and construct a Traversal and its embedded MyVisitor via:
-//   Traversal<RESULT, MyVisitor, ...> t{value};  // value is ARGTYPE &&
+// Then instantiate and construct a Visitor and its embedded MyVisitor via:
+//   Visitor<RESULT, MyVisitor, ...> v{value};  // value is ARGTYPE &&
 // and call:
-//   RESULT result{t.Traverse(topLevelExpr)};
+//   RESULT result{v.Traverse(topLevelExpr)};
 // Within the callback routines (Handle, Pre, Post), one may call
 //   void Return(RESULT &&);  // to define the result and end traversal
 //   void Return();  // to end traversal with current result
 // For any given expression object type T for which a callback is defined
 // in any visitor class, the callback must be distinct from all others.
 // Further, if there is a Handle(const T &) callback, there cannot be a
-// Pre() or a Post().
+// Pre(const T &) or a Post(const T &).
+//
+// For rewriting traversals, the paradigm is similar; however, the
+// argument types are rvalues and the non-void result types match
+// the arguments:
+//   class MyRewriter : public virtual RewriterBase<RESULT> {
+//     explicit MyRewriter(ARGTYPE);  // single-argument constructor
+//     T1 Handle(T1 &&);  // rewriting callback for type T1 objects
+//     void Pre(T2 &);  // in-place mutating callback before visiting T2
+//     T2 Post(T2 &&);  // rewriting callback after visiting T2
+//     ...
+//   };
+//   Rewriter<MyRewriter, ...> rw{value};
+//   topLevelExpr = rw.Traverse(std::move(topLevelExpr));
 
 namespace Fortran::evaluate {
 
-template<typename RESULT> class TraversalBase {
+template<typename RESULT> class VisitorBase {
 public:
   using Result = RESULT;
 
@@ -70,18 +84,18 @@ protected:
 };
 
 template<typename RESULT, typename... A>
-class Traversal : public virtual TraversalBase<RESULT>, public A... {
+class Visitor : public virtual VisitorBase<RESULT>, public A... {
 public:
   using Result = RESULT;
-  using Base = TraversalBase<Result>;
+  using Base = VisitorBase<Result>;
   using Base::Handle, Base::Pre, Base::Post;
   using A::Handle..., A::Pre..., A::Post...;
 
 private:
-  using TraversalBase<Result>::done_, TraversalBase<Result>::result_;
+  using VisitorBase<Result>::done_, VisitorBase<Result>::result_;
 
 public:
-  template<typename... B> Traversal(B... x) : A{x}... {}
+  template<typename... B> Visitor(B... x) : A{x}... {}
   template<typename B> Result Traverse(const B &x) {
     Visit(x);
     return std::move(result_);
@@ -95,7 +109,7 @@ private:
         // No visitation class defines Handle(B), so try Pre()/Post().
         Pre(x);
         if (!done_) {
-          Descend(x);
+          descender_.Descend(x);
           if (!done_) {
             Post(x);
           }
@@ -110,108 +124,56 @@ private:
     }
   }
 
-  template<typename X> void Descend(const X &) {}  // default case
+  friend class Descender<Visitor>;
+  Descender<Visitor> descender_{*this};
+};
 
-  template<typename X> void Descend(const X *p) {
-    if (p != nullptr) {
-      Visit(*p);
-    }
-  }
-  template<typename X> void Descend(const std::optional<X> &o) {
-    if (o.has_value()) {
-      Visit(*o);
-    }
-  }
-  template<typename X> void Descend(const CopyableIndirection<X> &p) {
-    Visit(p.value());
-  }
-  template<typename... X> void Descend(const std::variant<X...> &u) {
-    std::visit([&](const auto &x) { Visit(x); }, u);
-  }
-  template<typename X> void Descend(const std::vector<X> &xs) {
-    for (const auto &x : xs) {
-      Visit(x);
-    }
-  }
-  template<typename T> void Descend(const Expr<T> &expr) { Visit(expr.u); }
-  template<typename D, typename R, typename... O>
-  void Descend(const Operation<D, R, O...> &op) {
-    Visit(op.left());
-    if constexpr (op.operands > 1) {
-      Visit(op.right());
-    }
-  }
-  template<typename R> void Descend(const ImpliedDo<R> &ido) {
-    Visit(ido.lower());
-    Visit(ido.upper());
-    Visit(ido.stride());
-    Visit(ido.values());
-  }
-  template<typename R> void Descend(const ArrayConstructorValue<R> &av) {
-    Visit(av.u);
-  }
-  template<typename R> void Descend(const ArrayConstructorValues<R> &avs) {
-    Visit(avs.values());
-  }
-  template<int KIND>
-  void Descend(
-      const ArrayConstructor<Type<TypeCategory::Character, KIND>> &ac) {
-    Visit(static_cast<
-        ArrayConstructorValues<Type<TypeCategory::Character, KIND>>>(ac));
-    Visit(ac.LEN());
-  }
-  void Descend(const semantics::ParamValue &param) {
-    Visit(param.GetExplicit());
-  }
-  void Descend(const semantics::DerivedTypeSpec &derived) {
-    for (const auto &pair : derived.parameters()) {
-      Visit(pair.second);
-    }
+class RewriterBase {
+public:
+  template<typename A> A Handle(A &&x) {
+    defaultHandleCalled_ = true;
+    return std::move(x);
   }
-  void Descend(const StructureConstructor &sc) {
-    Visit(sc.derivedTypeSpec());
-    for (const auto &pair : sc.values()) {
-      Visit(pair.second);
+  template<typename A> void Pre(const A &) {}
+  template<typename A> A Post(A &&x) { return std::move(x); }
+
+  void Return() { done_ = true; }
+
+protected:
+  bool done_{false};
+  bool defaultHandleCalled_{false};
+};
+
+template<typename... A>
+class Rewriter : public virtual RewriterBase, public A... {
+public:
+  using RewriterBase::Handle, RewriterBase::Pre, RewriterBase::Post;
+  using A::Handle..., A::Pre..., A::Post...;
+
+  template<typename... B> Rewriter(B... x) : A{x}... {}
+
+private:
+  using RewriterBase::done_, RewriterBase::defaultHandleCalled_;
+
+public:
+  template<typename B> B Traverse(B &&x) {
+    if (!done_) {
+      defaultHandleCalled_ = false;
+      x = Handle(std::move(x));
+      if (defaultHandleCalled_) {
+        Pre(x);
+        if (!done_) {
+          descender_.Descend(x);
+          if (!done_) {
+            x = Post(std::move(x));
+          }
+        }
+      }
     }
+    return x;
   }
-  void Descend(const BaseObject &object) { Visit(object.u); }
-  void Descend(const Component &component) {
-    Visit(component.base());
-    Visit(component.GetLastSymbol());
-  }
-  template<int KIND> void Descend(const TypeParamInquiry<KIND> &inq) {
-    Visit(inq.base());
-    Visit(inq.parameter());
-  }
-  void Descend(const Triplet &triplet) {
-    Visit(triplet.lower());
-    Visit(triplet.upper());
-    Visit(triplet.stride());
-  }
-  void Descend(const Subscript &sscript) { Visit(sscript.u); }
-  void Descend(const ArrayRef &aref) {
-    Visit(aref.base());
-    Visit(aref.subscript());
-  }
-  void Descend(const CoarrayRef &caref) {
-    Visit(caref.base());
-    Visit(caref.subscript());
-    Visit(caref.cosubscript());
-    Visit(caref.stat());
-    Visit(caref.team());
-  }
-  void Descend(const DataRef &data) { Visit(data.u); }
-  void Descend(const ComplexPart &z) { Visit(z.complex()); }
-  template<typename T> void Descend(const Designator<T> &designator) {
-    Visit(designator.u);
-  }
-  template<typename T> void Descend(const Variable<T> &var) { Visit(var.u); }
-  void Descend(const ActualArgument &arg) { Visit(arg.value()); }
-  void Descend(const ProcedureDesignator &p) { Visit(p.u); }
-  void Descend(const ProcedureRef &call) {
-    Visit(call.proc());
-    Visit(call.arguments());
-  }
+
+  Descender<Rewriter> descender_{*this};
 };
 }
 #endif  // FORTRAN_EVALUATE_TRAVERSAL_H_
index f7f65c2..cb3aa54 100644 (file)
@@ -305,11 +305,11 @@ static CS GatherLocalVariableNames(
 static CS GatherReferencesFromExpression(const parser::Expr &expression) {
   // Use the new expression traversal framework if possible, for testing.
   if (expression.typedExpr.has_value()) {
-    struct CollectSymbols : public virtual evaluate::TraversalBase<CS> {
+    struct CollectSymbols : public virtual evaluate::VisitorBase<CS> {
       explicit CollectSymbols(int) {}
       void Handle(const Symbol *symbol) { result().push_back(symbol); }
     };
-    return evaluate::Traversal<CS, CollectSymbols>{0}.Traverse(
+    return evaluate::Visitor<CS, CollectSymbols>{0}.Traverse(
         expression.typedExpr.value());
   } else {
     GatherSymbols gatherSymbols;