[flang] More folding of SIZE()
authorpeter klausler <pklausler@nvidia.com>
Fri, 5 Apr 2019 00:10:07 +0000 (17:10 -0700)
committerpeter klausler <pklausler@nvidia.com>
Fri, 5 Apr 2019 19:56:12 +0000 (12:56 -0700)
Original-commit: flang-compiler/f18@23f62fea1d076311bc4fd63d7b749fbb8423763c
Reviewed-on: https://github.com/flang-compiler/f18/pull/386
Tree-same-pre-rewrite: false

flang/lib/common/template.h
flang/lib/evaluate/fold.cc
flang/lib/evaluate/shape.cc
flang/lib/evaluate/shape.h

index 135c8ba..be4027a 100644 (file)
@@ -20,6 +20,7 @@
 #include <tuple>
 #include <type_traits>
 #include <variant>
+#include <vector>
 
 // Utility templates for metaprogramming and for composing the
 // std::optional<>, std::tuple<>, and std::variant<> containers.
@@ -234,6 +235,24 @@ std::optional<std::tuple<A...>> AllElementsPresent(
       std::move(t), std::index_sequence_for<A...>{});
 }
 
+// std::vector<std::optional<A>> -> std::optional<std::vector<A>>
+// i.e., inverts a vector of optional values into an optional vector that
+// has a value of if all of the original elements were present.
+template<typename A>
+std::optional<std::vector<A>> AllElementsPresent(
+    std::vector<std::optional<A>> &&v) {
+  for (const auto &maybeA : v) {
+    if (!maybeA.has_value()) {
+      return std::nullopt;
+    }
+  }
+  std::vector<A> result;
+  for (auto &&maybeA : std::move(v)) {
+    result.emplace_back(std::move(*maybeA));
+  }
+  return result;
+}
+
 // (std::optional<>...) -> std::optional<std::tuple<...>>
 // i.e., given some number of optional values, return a optional tuple of
 // those values that is present only of all of the values were so.
index 7ffbb58..6f980e8 100644 (file)
@@ -485,7 +485,28 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
       }
     } else if (name == "size") {
       if (auto shape{GetShape(args[0].value())}) {
-        // TODO pmk: extract or compute result
+        if (auto &dimArg{args[1]}) {  // DIM= is present, get one extent
+          if (auto dim{ToInt64(dimArg->value())}) {
+            std::int64_t rank = shape->size();
+            if (*dim >= 1 && *dim <= rank) {
+              if (auto &extent{shape->at(*dim - 1)}) {
+                return Fold(context, ConvertToType<T>(std::move(*extent)));
+              }
+            } else {
+              context.messages().Say(
+                  "size(array,dim=%jd) dimension is out of range for rank-%d array"_en_US,
+                  static_cast<std::intmax_t>(*dim), static_cast<int>(rank));
+            }
+          }
+        } else if (auto extents{
+                       common::AllElementsPresent(std::move(*shape))}) {
+          // DIM= is absent; compute PRODUCT(SHAPE())
+          ExtentExpr product{1};
+          for (auto &&extent : std::move(*extents)) {
+            product = std::move(product) * std::move(extent);
+          }
+          return Expr<T>{ConvertToType<T>(Fold(context, std::move(product)))};
+        }
       }
     }
     // TODO:
index 90e42ec..62c8972 100644 (file)
 
 namespace Fortran::evaluate {
 
-Shape AsGeneralShape(const Constant<ExtentType> &constShape) {
-  CHECK(constShape.Rank() == 1);
+Shape AsShape(const Constant<ExtentType> &arrayConstant) {
+  CHECK(arrayConstant.Rank() == 1);
   Shape result;
-  std::size_t dimensions{constShape.size()};
+  std::size_t dimensions{arrayConstant.size()};
   for (std::size_t j{0}; j < dimensions; ++j) {
-    Scalar<ExtentType> extent{constShape.values().at(j)};
+    Scalar<ExtentType> extent{arrayConstant.values().at(j)};
     result.emplace_back(MaybeExtent{ExtentExpr{extent}});
   }
   return result;
 }
 
+std::optional<Shape> AsShape(ExtentExpr &&arrayExpr) {
+  if (auto *constArray{UnwrapExpr<Constant<ExtentType>>(arrayExpr)}) {
+    return AsShape(*constArray);
+  }
+  if (auto *constructor{UnwrapExpr<ArrayConstructor<ExtentType>>(arrayExpr)}) {
+    Shape result;
+    for (const auto &value : constructor->values()) {
+      if (const auto *expr{
+              std::get_if<common::CopyableIndirection<ExtentExpr>>(&value.u)}) {
+        if (expr->value().Rank() == 0) {
+          result.emplace_back(std::move(expr->value()));
+          continue;
+        }
+      }
+      return std::nullopt;
+    }
+    return result;
+  }
+  // TODO: linearize other array-valued expressions of known shape, e.g. A+B
+  // as well as conversions of arrays; this will be easier given a
+  // general-purpose array expression flattener (pmk)
+  return std::nullopt;
+}
+
 std::optional<ExtentExpr> AsShapeArrayExpr(const Shape &shape) {
   ArrayConstructorValues<ExtentType> values;
   for (const auto &dim : shape) {
@@ -297,9 +321,12 @@ std::optional<Shape> GetShape(const ProcedureRef &call) {
         intrinsic->name == "ubound") {
       return Shape{MaybeExtent{
           ExtentExpr{call.arguments().front().value().value().Rank()}}};
+    } else if (intrinsic->name == "reshape") {
+      if (call.arguments().size() >= 2 && call.arguments().at(1).has_value()) {
+      }
+    } else {
+      // TODO: shapes of other non-elemental intrinsic results
     }
-    // TODO: shapes of other non-elemental intrinsic results
-    // esp. reshape, where shape is value of second argument
   }
   return std::nullopt;
 }
index f882fa2..df28c32 100644 (file)
@@ -32,8 +32,9 @@ using ExtentExpr = Expr<ExtentType>;
 using MaybeExtent = std::optional<ExtentExpr>;
 using Shape = std::vector<MaybeExtent>;
 
-// Convert a constant shape to the expression form, and vice versa.
-Shape AsGeneralShape(const Constant<ExtentType> &);
+// Convert between various representations of shapes
+Shape AsShape(const Constant<ExtentType> &arrayConstant);
+std::optional<Shape> AsShape(ExtentExpr &&arrayExpr);
 std::optional<ExtentExpr> AsShapeArrayExpr(const Shape &);  // array constructor
 std::optional<Constant<ExtentType>> AsConstantShape(const Shape &);
 
@@ -79,7 +80,7 @@ std::optional<Shape> GetShape(const NullPointer &);
 
 template<typename T> std::optional<Shape> GetShape(const Constant<T> &c) {
   Constant<ExtentType> shape{c.SHAPE()};
-  return AsGeneralShape(shape);
+  return AsShape(shape);
 }
 
 template<typename T>