[flang] Simplify expression visitor usage
authorpeter klausler <pklausler@nvidia.com>
Thu, 4 Apr 2019 20:58:46 +0000 (13:58 -0700)
committerpeter klausler <pklausler@nvidia.com>
Fri, 5 Apr 2019 19:56:09 +0000 (12:56 -0700)
Original-commit: flang-compiler/f18@9ab121d6a6dc96b12042a242e5b4eec903455990
Reviewed-on: https://github.com/flang-compiler/f18/pull/386
Tree-same-pre-rewrite: false

flang/lib/common/template.h
flang/lib/evaluate/fold.cc
flang/lib/evaluate/intrinsics.cc
flang/lib/evaluate/shape.cc
flang/lib/evaluate/shape.h
flang/lib/evaluate/traversal.h
flang/lib/semantics/check-do-concurrent.cc
flang/lib/semantics/mod-file.cc

index 84a756fdc2525c1c5da5682a380b32e709fa491f..135c8baa5d3bdbcf8656889c12024e26f1ed296f 100644 (file)
@@ -157,6 +157,21 @@ struct AreTypesDistinctHelper {
 template<typename... Ts>
 constexpr bool AreTypesDistinct{AreTypesDistinctHelper<Ts...>::value()};
 
+template<typename A, typename... Ts> struct AreSameTypeHelper {
+  using type = A;
+  static constexpr bool value() {
+    if constexpr (sizeof...(Ts) == 0) {
+      return true;
+    } else {
+      using Rest = AreSameTypeHelper<Ts...>;
+      return std::is_same_v<type, typename Rest::type> && Rest::value();
+    }
+  }
+};
+
+template<typename... Ts>
+constexpr bool AreSameType{AreSameTypeHelper<Ts...>::value()};
+
 template<typename> struct TupleToVariantHelper;
 template<typename... Ts> struct TupleToVariantHelper<std::tuple<Ts...>> {
   static_assert(AreTypesDistinct<Ts...> ||
index 920ba055d4d66c0049cf4205d48dd5311d8e9c0f..7ffbb58663cc07d371b6c93de60ed6a88d064cd0 100644 (file)
@@ -475,11 +475,18 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
       return FoldElementalIntrinsic<T, T, T, T>(
           context, std::move(funcRef), &Scalar<T>::MERGE_BITS);
     } else if (name == "rank") {
-      // TODO pmk: get rank
+      // TODO assumed-rank dummy argument
+      return Expr<T>{args[0].value().Rank()};
     } else if (name == "shape") {
-      // TODO pmk: call GetShape on argument, massage result
+      if (auto shape{GetShape(args[0].value())}) {
+        if (auto shapeExpr{AsShapeArrayExpr(*shape)}) {
+          return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
+        }
+      }
     } else if (name == "size") {
-      // TODO pmk: call GetShape on argument, extract or compute result
+      if (auto shape{GetShape(args[0].value())}) {
+        // TODO pmk: extract or compute result
+      }
     }
     // TODO:
     // ceiling, count, cshift, dot_product, eoshift,
@@ -1373,6 +1380,7 @@ FOR_EACH_TYPE_AND_KIND(template class ExpressionBase, )
 
 class IsConstantExprVisitor : public virtual VisitorBase<bool> {
 public:
+  using Result = bool;
   explicit IsConstantExprVisitor(int) { result() = true; }
 
   template<int KIND> void Handle(const TypeParamInquiry<KIND> &inq) {
@@ -1402,7 +1410,7 @@ private:
 };
 
 bool IsConstantExpr(const Expr<SomeType> &expr) {
-  return Visitor<bool, IsConstantExprVisitor>{0}.Traverse(expr);
+  return Visitor<IsConstantExprVisitor>{0}.Traverse(expr);
 }
 
 std::optional<std::int64_t> ToInt64(const Expr<SomeInteger> &expr) {
index f42052b273f1aa4fa5f521d3347f1988a42714de..5c65f75c94a390e3f3fa4fef4057d9fd13f25d80 100644 (file)
@@ -503,7 +503,7 @@ static const IntrinsicInterface genericIntrinsicFunction[]{
     {"product",
         {{"array", SameNumeric, Rank::array}, OptionalDIM, OptionalMASK},
         SameNumeric, Rank::dimReduced},
-    // TODO pmk: "rank"
+    {"rank", {{"a", Anything, Rank::anyOrAssumedRank}}, DefaultInt},
     {"real", {{"a", AnyNumeric, Rank::elementalOrBOZ}, DefaultingKIND},
         KINDReal},
     {"reduce",
@@ -968,10 +968,10 @@ std::optional<SpecificCall> IntrinsicInterface::Match(
         CHECK(!shapeArgSize.has_value());
         if (rank == 1) {
           if (auto shape{GetShape(*arg)}) {
-            CHECK(shape->size() == 1);
-            if (auto value{ToInt64(shape->at(0))}) {
-              shapeArgSize = *value;
-              argOk = *value >= 0;
+            if (auto constShape{AsConstantShape(*shape)}) {
+              shapeArgSize = (**constShape).ToInt64();
+              CHECK(shapeArgSize >= 0);
+              argOk = true;
             }
           }
         }
index 6aff28ad430f8b1f678510829d2d5447621f4b84..90e42ecddc3c1d43f81750e468b5cd0e801d5897 100644 (file)
@@ -33,19 +33,27 @@ Shape AsGeneralShape(const Constant<ExtentType> &constShape) {
   return result;
 }
 
-std::optional<Constant<ExtentType>> AsConstantShape(const Shape &shape) {
-  std::vector<Scalar<ExtentType>> extents;
+std::optional<ExtentExpr> AsShapeArrayExpr(const Shape &shape) {
+  ArrayConstructorValues<ExtentType> values;
   for (const auto &dim : shape) {
     if (dim.has_value()) {
-      if (const auto cdim{UnwrapExpr<Constant<ExtentType>>(*dim)}) {
-        extents.emplace_back(**cdim);
-        continue;
-      }
+      values.Push(common::Clone(*dim));
+    } else {
+      return std::nullopt;
+    }
+  }
+  return ExtentExpr{ArrayConstructor<ExtentType>{std::move(values)}};
+}
+
+std::optional<Constant<ExtentType>> AsConstantShape(const Shape &shape) {
+  if (auto shapeArray{AsShapeArrayExpr(shape)}) {
+    FoldingContext noFoldingContext;
+    auto folded{Fold(noFoldingContext, std::move(*shapeArray))};
+    if (auto *p{UnwrapExpr<Constant<ExtentType>>(folded)}) {
+      return std::move(*p);
     }
-    return std::nullopt;
   }
-  std::vector<std::int64_t> rshape{static_cast<std::int64_t>(shape.size())};
-  return Constant<ExtentType>{std::move(extents), std::move(rshape)};
+  return std::nullopt;
 }
 
 static ExtentExpr ComputeTripCount(
@@ -90,8 +98,8 @@ MaybeExtent GetSize(Shape &&shape) {
   return extent;
 }
 
-static MaybeExtent GetLowerBound(const semantics::Symbol &symbol,
-    const Component *component, int dimension) {
+static MaybeExtent GetLowerBound(
+    const Symbol &symbol, const Component *component, int dimension) {
   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
     int j{0};
     for (const auto &shapeSpec : details->shape()) {
@@ -111,8 +119,8 @@ static MaybeExtent GetLowerBound(const semantics::Symbol &symbol,
   return std::nullopt;
 }
 
-static MaybeExtent GetExtent(const semantics::Symbol &symbol,
-    const Component *component, int dimension) {
+static MaybeExtent GetExtent(
+    const Symbol &symbol, const Component *component, int dimension) {
   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
     int j{0};
     for (const auto &shapeSpec : details->shape()) {
@@ -169,14 +177,15 @@ static MaybeExtent GetExtent(const Subscript &subscript, const Symbol &symbol,
 
 bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
   struct MyVisitor : public virtual VisitorBase<bool> {
+    using Result = bool;
     explicit MyVisitor(int) { result() = false; }
     void Handle(const ImpliedDoIndex &) { Return(true); }
   };
-  return Visitor<bool, MyVisitor>{0}.Traverse(expr);
+  return Visitor<MyVisitor>{0}.Traverse(expr);
 }
 
 std::optional<Shape> GetShape(
-    const semantics::Symbol &symbol, const Component *component) {
+    const Symbol &symbol, const Component *component) {
   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
     Shape result;
     int n = details->shape().size();
@@ -189,6 +198,14 @@ std::optional<Shape> GetShape(
   }
 }
 
+std::optional<Shape> GetShape(const Symbol *symbol) {
+  if (symbol != nullptr) {
+    return GetShape(*symbol);
+  } else {
+    return std::nullopt;
+  }
+}
+
 std::optional<Shape> GetShape(const BaseObject &object) {
   if (const Symbol * symbol{object.symbol()}) {
     return GetShape(*symbol);
@@ -244,7 +261,7 @@ std::optional<Shape> GetShape(const CoarrayRef &coarrayRef) {
 }
 
 std::optional<Shape> GetShape(const DataRef &dataRef) {
-  return std::visit([](const auto &x) { return GetShape(x); }, dataRef.u);
+  return GetShape(dataRef.u);
 }
 
 std::optional<Shape> GetShape(const Substring &substring) {
@@ -287,10 +304,22 @@ std::optional<Shape> GetShape(const ProcedureRef &call) {
   return std::nullopt;
 }
 
+std::optional<Shape> GetShape(const Relational<SomeType> &relation) {
+  return GetShape(relation.u);
+}
+
 std::optional<Shape> GetShape(const StructureConstructor &) {
   return Shape{};  // always scalar
 }
 
+std::optional<Shape> GetShape(const ImpliedDoIndex &) {
+  return Shape{};  // always scalar
+}
+
+std::optional<Shape> GetShape(const DescriptorInquiry &) {
+  return Shape{};  // always scalar
+}
+
 std::optional<Shape> GetShape(const BOZLiteralConstant &) {
   return Shape{};  // always scalar
 }
index 3b87beb58e18f1b8ff02041902a63d57dc17ce8d..f882fa25feaa33577a3b46e70d086935757f88e9 100644 (file)
@@ -34,9 +34,10 @@ using Shape = std::vector<MaybeExtent>;
 
 // Convert a constant shape to the expression form, and vice versa.
 Shape AsGeneralShape(const Constant<ExtentType> &);
+std::optional<ExtentExpr> AsShapeArrayExpr(const Shape &);  // array constructor
 std::optional<Constant<ExtentType>> AsConstantShape(const Shape &);
 
-// Compute a trip count for a triplet or implied DO.
+// Compute an element count for a triplet or trip count for a DO.
 ExtentExpr CountTrips(
     ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride);
 ExtentExpr CountTrips(
@@ -47,10 +48,6 @@ MaybeExtent CountTrips(
 // Computes SIZE() == PRODUCT(shape)
 MaybeExtent GetSize(Shape &&);
 
-template<typename A> std::optional<Shape> GetShape(const A &) {
-  return std::nullopt;  // default case  TODO pmk remove
-}
-
 // Forward declarations
 template<typename... A>
 std::optional<Shape> GetShape(const std::variant<A...> &);
@@ -62,8 +59,8 @@ template<typename T> std::optional<Shape> GetShape(const Expr<T> &expr) {
   return GetShape(expr.u);
 }
 
-std::optional<Shape> GetShape(
-    const semantics::Symbol &, const Component * = nullptr);
+std::optional<Shape> GetShape(const Symbol &, const Component * = nullptr);
+std::optional<Shape> GetShape(const Symbol *);
 std::optional<Shape> GetShape(const BaseObject &);
 std::optional<Shape> GetShape(const Component &);
 std::optional<Shape> GetShape(const ArrayRef &);
@@ -73,7 +70,10 @@ std::optional<Shape> GetShape(const Substring &);
 std::optional<Shape> GetShape(const ComplexPart &);
 std::optional<Shape> GetShape(const ActualArgument &);
 std::optional<Shape> GetShape(const ProcedureRef &);
+std::optional<Shape> GetShape(const ImpliedDoIndex &);
+std::optional<Shape> GetShape(const Relational<SomeType> &);
 std::optional<Shape> GetShape(const StructureConstructor &);
+std::optional<Shape> GetShape(const DescriptorInquiry &);
 std::optional<Shape> GetShape(const BOZLiteralConstant &);
 std::optional<Shape> GetShape(const NullPointer &);
 
@@ -94,7 +94,7 @@ std::optional<Shape> GetShape(const Variable<T> &variable) {
 
 template<typename D, typename R, typename... O>
 std::optional<Shape> GetShape(const Operation<D, R, O...> &operation) {
-  if constexpr (operation.operands > 1) {
+  if constexpr (sizeof...(O) > 1) {
     if (operation.right().Rank() > 0) {
       return GetShape(operation.right());
     }
index ca79ba335877ce5ce89b672e7a75db4c45c17107..45c334fbf24113afc03141e7dfe3e4f3e229d5c6 100644 (file)
 // To use for non-mutating visitation, define one or more client visitation
 // classes of the form:
 //   class MyVisitor : public virtual VisitorBase<RESULT> {
+//     using Result = RESULT;
 //     explicit MyVisitor(ARGTYPE);  // single-argument constructor
 //     void Handle(const T1 &);  // callback for type T1 objects
 //     void Pre(const T2 &);  // callback before visiting T2
 //     void Post(const T2 &);  // callback after visiting T2
 //     ...
 //   };
-// RESULT should have some default-constructible type.
-// Then instantiate and construct a Visitor and its embedded MyVisitor via:
-//   Visitor<RESULT, MyVisitor, ...> v{value};  // value is ARGTYPE &&
+// RESULT should have some default-constructible type, and it must be
+// the same type in all of the visitors that you combine in the next step.
+//
+// Then instantiate and construct a Visitor and its embedded visitors via:
+//   Visitor<MyVisitor, ...> v{value...};  // value is/are ARGTYPE &&
 // and call:
 //   RESULT result{v.Traverse(topLevelExpr)};
 // Within the callback routines (Handle, Pre, Post), one may call
@@ -49,6 +52,7 @@
 // argument types are rvalues and the non-void result types match
 // the arguments:
 //   class MyRewriter : public virtual RewriterBase<RESULT> {
+//     using Result = RESULT;
 //     explicit MyRewriter(ARGTYPE);  // single-argument constructor
 //     T1 Handle(T1 &&);  // rewriting callback for type T1 objects
 //     void Pre(T2 &);  // in-place mutating callback before visiting T2
@@ -83,10 +87,18 @@ protected:
   Result result_;
 };
 
-template<typename RESULT, typename... A>
-class Visitor : public virtual VisitorBase<RESULT>, public A... {
+template<typename A, typename... B> struct VisitorResultTypeHelper {
+  using type = typename A::Result;
+  static_assert(common::AreSameType<type, typename B::Result...>);
+};
+template<typename... A>
+using VisitorResultType = typename VisitorResultTypeHelper<A...>::type;
+
+template<typename... A>
+class Visitor : public virtual VisitorBase<VisitorResultType<A...>>,
+                public A... {
 public:
-  using Result = RESULT;
+  using Result = VisitorResultType<A...>;
   using Base = VisitorBase<Result>;
   using Base::Handle, Base::Pre, Base::Post;
   using A::Handle..., A::Pre..., A::Post...;
index a4fd2e180183f11c27730d7f44211c037bde1a26..163cc90e868ce7f024a9f3635becd0aa4fa4effc 100644 (file)
@@ -382,11 +382,11 @@ static CS GatherReferencesFromExpression(const parser::Expr &expression) {
   // Use the new expression traversal framework if possible, for testing.
   if (expression.typedExpr) {
     struct CollectSymbols : public virtual evaluate::VisitorBase<CS> {
+      using Result = CS;
       explicit CollectSymbols(int) {}
       void Handle(const Symbol *symbol) { result().push_back(symbol); }
     };
-    return evaluate::Visitor<CS, CollectSymbols>{0}.Traverse(
-        *expression.typedExpr);
+    return evaluate::Visitor<CollectSymbols>{0}.Traverse(*expression.typedExpr);
   } else {
     GatherSymbols gatherSymbols;
     parser::Walk(expression, gatherSymbols);
index 2fc07a8b99d0eaa9562ac2b485a934f84a910d5e..9de5086f09d0c18aad6616c21985d85ae256e6cb 100644 (file)
@@ -87,12 +87,13 @@ private:
 
   using SymbolVector = std::vector<const Symbol *>;
   struct SymbolVisitor : public virtual evaluate::VisitorBase<SymbolVector> {
+    using Result = SymbolVector;
     explicit SymbolVisitor(int) {}
     void Handle(const Symbol *symbol) { result().push_back(symbol); }
   };
 
   template<typename T> void DoExpr(evaluate::Expr<T> expr) {
-    evaluate::Visitor<SymbolVector, SymbolVisitor> visitor{0};
+    evaluate::Visitor<SymbolVisitor> visitor{0};
     for (const Symbol *symbol : visitor.Traverse(expr)) {
       CHECK(symbol && "bad symbol from Traverse");
       DoSymbol(*symbol);