[flang] Improve detection of default Handle() callback
authorpeter klausler <pklausler@nvidia.com>
Wed, 6 Mar 2019 17:11:17 +0000 (09:11 -0800)
committerpeter klausler <pklausler@nvidia.com>
Thu, 7 Mar 2019 00:15:51 +0000 (16:15 -0800)
Original-commit: flang-compiler/f18@c09c4c9e00d15a5dd391216662665d48a908c4fd
Reviewed-on: https://github.com/flang-compiler/f18/pull/316
Tree-same-pre-rewrite: false

flang/lib/evaluate/traversal.h

index 0c708d2..9504391 100644 (file)
 #define FORTRAN_EVALUATE_TRAVERSAL_H_
 
 #include "expression.h"
+#include <type_traits>
 
 // Implements an expression traversal utility framework.
 namespace Fortran::evaluate {
 
-template<typename RESULT>
-class TraversalBase {
+template<typename RESULT> class TraversalBase {
 public:
   using Result = RESULT;
-  template<typename A> void Handle(const A &) { defaultHandle_ = true; }
+  // Note the weird return type; it distinguishes this default Handle
+  // from any void-valued override.
+  template<typename A> std::nullptr_t Handle(const A &) { return nullptr; }
   template<typename A> void Pre(const A &) {}
   template<typename A> void Post(const A &) {}
-  template<typename... A> void Return(A &&...x) {
+  template<typename... A> void Return(A &&... x) {
     result_.emplace(std::move(x)...);
   }
+
 protected:
   std::optional<Result> result_;
-  bool defaultHandle_{false};
 };
 
 // Descend() is a helper function template for Traversal::Visit().
 // Do not use directly.
 namespace descend {
-template<typename VISITOR, typename EXPR> void Descend(VISITOR &, const EXPR &) {}
+template<typename VISITOR, typename EXPR>
+void Descend(VISITOR &, const EXPR &) {}
 template<typename V, typename A> void Descend(V &visitor, const A *p) {
   if (p != nullptr) {
     visitor.Visit(*p);
   }
 }
-template<typename V, typename A> void Descend(V &visitor, const std::optional<A> *o) {
+template<typename V, typename A>
+void Descend(V &visitor, const std::optional<A> *o) {
   if (o.has_value()) {
     visitor.Visit(*o);
   }
 }
-template<typename V, typename A> void Descend(V &visitor, const CopyableIndirection<A> &p) {
+template<typename V, typename A>
+void Descend(V &visitor, const CopyableIndirection<A> &p) {
   visitor.Visit(p.value());
 }
-template<typename V, typename... A> void Descend(V &visitor, const std::variant<A...> &u) {
-  std::visit([&](const auto &x){ visitor.Visit(x); }, u);
+template<typename V, typename... A>
+void Descend(V &visitor, const std::variant<A...> &u) {
+  std::visit([&](const auto &x) { visitor.Visit(x); }, u);
 }
-template<typename V, typename A> void Descend(V &visitor, const std::vector<A> &xs) {
+template<typename V, typename A>
+void Descend(V &visitor, const std::vector<A> &xs) {
   for (const auto &x : xs) {
     visitor.Visit(x);
   }
@@ -64,32 +71,41 @@ template<typename V, typename T> void Descend(V &visitor, const Expr<T> &expr) {
   visitor.Visit(expr.u);
 }
 template<typename V, typename D, typename R, typename... O>
-void Descend(V &visitor, const Operation<D,R,O...> &op) {
+void Descend(V &visitor, const Operation<D, R, O...> &op) {
   visitor.Visit(op.left());
   if constexpr (op.operands > 1) {
     visitor.Visit(op.right());
   }
 }
-template<typename V, typename R> void Descend(V &visitor, const ImpliedDo<R> &ido) {
+template<typename V, typename R>
+void Descend(V &visitor, const ImpliedDo<R> &ido) {
   visitor.Visit(ido.lower());
   visitor.Visit(ido.upper());
   visitor.Visit(ido.stride());
   visitor.Visit(ido.values());
 }
-template<typename V, typename R> void Descend(V &visitor, const ArrayConstructorValue<R> &av) {
+template<typename V, typename R>
+void Descend(V &visitor, const ArrayConstructorValue<R> &av) {
   visitor.Visit(av.u);
 }
-template<typename V, typename R> void Descend(V &visitor, const ArrayConstructorValues<R> &avs) {
+template<typename V, typename R>
+void Descend(V &visitor, const ArrayConstructorValues<R> &avs) {
   visitor.Visit(avs.values());
 }
-template<typename V, int KIND> void Descend(V &visitor, const ArrayConstructor<Type<TypeCategory::Character, KIND>> &ac) {
-  visitor.Visit(static_cast<ArrayConstructorValues<Type<TypeCategory::Character, KIND>>>(ac));
+template<typename V, int KIND>
+void Descend(V &visitor,
+    const ArrayConstructor<Type<TypeCategory::Character, KIND>> &ac) {
+  visitor.Visit(
+      static_cast<ArrayConstructorValues<Type<TypeCategory::Character, KIND>>>(
+          ac));
   visitor.Visit(ac.LEN());
 }
-template<typename V> void Descend(V &visitor, const semantics::ParamValue &param) {
+template<typename V>
+void Descend(V &visitor, const semantics::ParamValue &param) {
   visitor.Visit(param.GetExplicit());
 }
-template<typename V> void Descend(V &visitor, const semantics::DerivedTypeSpec &derived) {
+template<typename V>
+void Descend(V &visitor, const semantics::DerivedTypeSpec &derived) {
   for (const auto &pair : derived.parameters()) {
     visitor.Visit(pair.second);
   }
@@ -107,7 +123,8 @@ template<typename V> void Descend(V &visitor, const Component &component) {
   visitor.Visit(component.base());
   visitor.Visit(component.GetLastSymbol());
 }
-template<typename V, int KIND> void Descend(V &visitor, const TypeParamInquiry<KIND> &inq) {
+template<typename V, int KIND>
+void Descend(V &visitor, const TypeParamInquiry<KIND> &inq) {
   visitor.Visit(inq.base());
   visitor.Visit(inq.parameter());
 }
@@ -136,10 +153,12 @@ template<typename V> void Descend(V &visitor, const DataRef &data) {
 template<typename V> void Descend(V &visitor, const ComplexPart &z) {
   visitor.Visit(z.complex());
 }
-template<typename V, typename T> void Descend(V &visitor, const Designator<T> &designator) {
+template<typename V, typename T>
+void Descend(V &visitor, const Designator<T> &designator) {
   visitor.Visit(designator.u);
 }
-template<typename V, typename T> void Descend(V &visitor, const Variable<T> &var) {
+template<typename V, typename T>
+void Descend(V &visitor, const Variable<T> &var) {
   visitor.Visit(var.u);
 }
 template<typename V> void Descend(V &visitor, const ActualArgument &arg) {
@@ -161,8 +180,10 @@ public:
   using Base = TraversalBase<Result>;
   using Base::Handle, Base::Pre, Base::Post;
   using A::Handle..., A::Pre..., A::Post...;
+
 private:
-  using TraversalBase<Result>::result_, TraversalBase<Result>::defaultHandle_;
+  using TraversalBase<Result>::result_;
+
 public:
   template<typename... B> Traversal(B... x) : A{x}... {}
   template<typename B> std::optional<Result> Traverse(const B &x) {
@@ -173,9 +194,9 @@ public:
   // TODO: make private, make Descend instances friends
   template<typename B> void Visit(const B &x) {
     if (!result_.has_value()) {
-      defaultHandle_ = false;
-      Handle(x);
-      if (defaultHandle_) {
+      if constexpr (std::is_same_v<std::decay_t<decltype(Handle(x))>,
+                        std::nullptr_t>) {
+        // No visitation class defines Handle(B), so try Pre()/Post().
         Pre(x);
         if (!result_.has_value()) {
           descend::Descend(*this, x);
@@ -183,6 +204,8 @@ public:
             Post(x);
           }
         }
+      } else {
+        Handle(x);
       }
     }
   }