#ifndef FORTRAN_EVALUATE_TRAVERSAL_H_
#define FORTRAN_EVALUATE_TRAVERSAL_H_
-#include "expression.h"
+#include "descender.h"
#include <type_traits>
// Implements an expression traversal utility framework.
// See fold.cc to see how this framework is used to implement detection
// of constant expressions.
//
-// To use, define one or more client visitation classes of the form:
-// class MyVisitor : public virtual TraversalBase<RESULT> {
+// To use for non-mutating visitation, define one or more client visitation
+// classes of the form:
+// class MyVisitor : public virtual VisitorBase<RESULT> {
// explicit MyVisitor(ARGTYPE); // single-argument constructor
// void Handle(const T1 &); // callback for type T1 objects
// void Pre(const T2 &); // callback before visiting T2
// ...
// };
// RESULT should have some default-constructible type.
-// Then instantiate and construct a Traversal and its embedded MyVisitor via:
-// Traversal<RESULT, MyVisitor, ...> t{value}; // value is ARGTYPE &&
+// Then instantiate and construct a Visitor and its embedded MyVisitor via:
+// Visitor<RESULT, MyVisitor, ...> v{value}; // value is ARGTYPE &&
// and call:
-// RESULT result{t.Traverse(topLevelExpr)};
+// RESULT result{v.Traverse(topLevelExpr)};
// Within the callback routines (Handle, Pre, Post), one may call
// void Return(RESULT &&); // to define the result and end traversal
// void Return(); // to end traversal with current result
// For any given expression object type T for which a callback is defined
// in any visitor class, the callback must be distinct from all others.
// Further, if there is a Handle(const T &) callback, there cannot be a
-// Pre() or a Post().
+// Pre(const T &) or a Post(const T &).
+//
+// For rewriting traversals, the paradigm is similar; however, the
+// argument types are rvalues and the non-void result types match
+// the arguments:
+// class MyRewriter : public virtual RewriterBase<RESULT> {
+// explicit MyRewriter(ARGTYPE); // single-argument constructor
+// T1 Handle(T1 &&); // rewriting callback for type T1 objects
+// void Pre(T2 &); // in-place mutating callback before visiting T2
+// T2 Post(T2 &&); // rewriting callback after visiting T2
+// ...
+// };
+// Rewriter<MyRewriter, ...> rw{value};
+// topLevelExpr = rw.Traverse(std::move(topLevelExpr));
namespace Fortran::evaluate {
-template<typename RESULT> class TraversalBase {
+template<typename RESULT> class VisitorBase {
public:
using Result = RESULT;
};
template<typename RESULT, typename... A>
-class Traversal : public virtual TraversalBase<RESULT>, public A... {
+class Visitor : public virtual VisitorBase<RESULT>, public A... {
public:
using Result = RESULT;
- using Base = TraversalBase<Result>;
+ using Base = VisitorBase<Result>;
using Base::Handle, Base::Pre, Base::Post;
using A::Handle..., A::Pre..., A::Post...;
private:
- using TraversalBase<Result>::done_, TraversalBase<Result>::result_;
+ using VisitorBase<Result>::done_, VisitorBase<Result>::result_;
public:
- template<typename... B> Traversal(B... x) : A{x}... {}
+ template<typename... B> Visitor(B... x) : A{x}... {}
template<typename B> Result Traverse(const B &x) {
Visit(x);
return std::move(result_);
// No visitation class defines Handle(B), so try Pre()/Post().
Pre(x);
if (!done_) {
- Descend(x);
+ descender_.Descend(x);
if (!done_) {
Post(x);
}
}
}
- template<typename X> void Descend(const X &) {} // default case
+ friend class Descender<Visitor>;
+ Descender<Visitor> descender_{*this};
+};
- template<typename X> void Descend(const X *p) {
- if (p != nullptr) {
- Visit(*p);
- }
- }
- template<typename X> void Descend(const std::optional<X> &o) {
- if (o.has_value()) {
- Visit(*o);
- }
- }
- template<typename X> void Descend(const CopyableIndirection<X> &p) {
- Visit(p.value());
- }
- template<typename... X> void Descend(const std::variant<X...> &u) {
- std::visit([&](const auto &x) { Visit(x); }, u);
- }
- template<typename X> void Descend(const std::vector<X> &xs) {
- for (const auto &x : xs) {
- Visit(x);
- }
- }
- template<typename T> void Descend(const Expr<T> &expr) { Visit(expr.u); }
- template<typename D, typename R, typename... O>
- void Descend(const Operation<D, R, O...> &op) {
- Visit(op.left());
- if constexpr (op.operands > 1) {
- Visit(op.right());
- }
- }
- template<typename R> void Descend(const ImpliedDo<R> &ido) {
- Visit(ido.lower());
- Visit(ido.upper());
- Visit(ido.stride());
- Visit(ido.values());
- }
- template<typename R> void Descend(const ArrayConstructorValue<R> &av) {
- Visit(av.u);
- }
- template<typename R> void Descend(const ArrayConstructorValues<R> &avs) {
- Visit(avs.values());
- }
- template<int KIND>
- void Descend(
- const ArrayConstructor<Type<TypeCategory::Character, KIND>> &ac) {
- Visit(static_cast<
- ArrayConstructorValues<Type<TypeCategory::Character, KIND>>>(ac));
- Visit(ac.LEN());
- }
- void Descend(const semantics::ParamValue ¶m) {
- Visit(param.GetExplicit());
- }
- void Descend(const semantics::DerivedTypeSpec &derived) {
- for (const auto &pair : derived.parameters()) {
- Visit(pair.second);
- }
+class RewriterBase {
+public:
+ template<typename A> A Handle(A &&x) {
+ defaultHandleCalled_ = true;
+ return std::move(x);
}
- void Descend(const StructureConstructor &sc) {
- Visit(sc.derivedTypeSpec());
- for (const auto &pair : sc.values()) {
- Visit(pair.second);
+ template<typename A> void Pre(const A &) {}
+ template<typename A> A Post(A &&x) { return std::move(x); }
+
+ void Return() { done_ = true; }
+
+protected:
+ bool done_{false};
+ bool defaultHandleCalled_{false};
+};
+
+template<typename... A>
+class Rewriter : public virtual RewriterBase, public A... {
+public:
+ using RewriterBase::Handle, RewriterBase::Pre, RewriterBase::Post;
+ using A::Handle..., A::Pre..., A::Post...;
+
+ template<typename... B> Rewriter(B... x) : A{x}... {}
+
+private:
+ using RewriterBase::done_, RewriterBase::defaultHandleCalled_;
+
+public:
+ template<typename B> B Traverse(B &&x) {
+ if (!done_) {
+ defaultHandleCalled_ = false;
+ x = Handle(std::move(x));
+ if (defaultHandleCalled_) {
+ Pre(x);
+ if (!done_) {
+ descender_.Descend(x);
+ if (!done_) {
+ x = Post(std::move(x));
+ }
+ }
+ }
}
+ return x;
}
- void Descend(const BaseObject &object) { Visit(object.u); }
- void Descend(const Component &component) {
- Visit(component.base());
- Visit(component.GetLastSymbol());
- }
- template<int KIND> void Descend(const TypeParamInquiry<KIND> &inq) {
- Visit(inq.base());
- Visit(inq.parameter());
- }
- void Descend(const Triplet &triplet) {
- Visit(triplet.lower());
- Visit(triplet.upper());
- Visit(triplet.stride());
- }
- void Descend(const Subscript &sscript) { Visit(sscript.u); }
- void Descend(const ArrayRef &aref) {
- Visit(aref.base());
- Visit(aref.subscript());
- }
- void Descend(const CoarrayRef &caref) {
- Visit(caref.base());
- Visit(caref.subscript());
- Visit(caref.cosubscript());
- Visit(caref.stat());
- Visit(caref.team());
- }
- void Descend(const DataRef &data) { Visit(data.u); }
- void Descend(const ComplexPart &z) { Visit(z.complex()); }
- template<typename T> void Descend(const Designator<T> &designator) {
- Visit(designator.u);
- }
- template<typename T> void Descend(const Variable<T> &var) { Visit(var.u); }
- void Descend(const ActualArgument &arg) { Visit(arg.value()); }
- void Descend(const ProcedureDesignator &p) { Visit(p.u); }
- void Descend(const ProcedureRef &call) {
- Visit(call.proc());
- Visit(call.arguments());
- }
+
+ Descender<Rewriter> descender_{*this};
};
}
#endif // FORTRAN_EVALUATE_TRAVERSAL_H_