[flang] complete GetShape, compile
authorpeter klausler <pklausler@nvidia.com>
Thu, 4 Apr 2019 20:12:21 +0000 (13:12 -0700)
committerpeter klausler <pklausler@nvidia.com>
Fri, 5 Apr 2019 19:56:06 +0000 (12:56 -0700)
Original-commit: flang-compiler/f18@ff124f69a9c4a23f51004b332736fbb1f47c431a
Reviewed-on: https://github.com/flang-compiler/f18/pull/386
Tree-same-pre-rewrite: false

flang/lib/common/template.h
flang/lib/evaluate/constant.cc
flang/lib/evaluate/constant.h
flang/lib/evaluate/fold.cc
flang/lib/evaluate/formatting.cc
flang/lib/evaluate/shape.cc
flang/lib/evaluate/shape.h
flang/lib/evaluate/variable.h

index acb2b36..84a756f 100644 (file)
@@ -240,6 +240,10 @@ std::optional<R> MapOptional(
   }
   return std::nullopt;
 }
+template<typename R, typename... A>
+std::optional<R> MapOptional(R (*f)(A &&...), std::optional<A> &&... x) {
+  return MapOptional(std::function<R(A && ...)>{f}, std::move(x)...);
+}
 
 // Given a VISITOR class of the general form
 //   struct VISITOR {
index 2e2cc73..05c4f03 100644 (file)
@@ -42,6 +42,12 @@ auto ConstantBase<RESULT, VALUE>::At(
   return values_.at(SubscriptsToOffset(index, shape_));
 }
 
+template<typename RESULT, typename VALUE>
+auto ConstantBase<RESULT, VALUE>::At(std::vector<std::int64_t> &&index) const
+    -> ScalarValue {
+  return values_.at(SubscriptsToOffset(index, shape_));
+}
+
 static Constant<SubscriptInteger> ShapeAsConstant(
     const std::vector<std::int64_t> &shape) {
   using IntType = Scalar<SubscriptInteger>;
index 78cb69b..1858f1d 100644 (file)
@@ -66,6 +66,7 @@ public:
 
   // Apply 1-based subscripts
   ScalarValue At(const std::vector<std::int64_t> &) const;
+  ScalarValue At(std::vector<std::int64_t> &&) const;
 
   Constant<SubscriptInteger> SHAPE() const;
   std::ostream &AsFortran(std::ostream &) const;
index 1d194f1..920ba05 100644 (file)
@@ -478,6 +478,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
       // TODO pmk: get rank
     } else if (name == "shape") {
       // TODO pmk: call GetShape on argument, massage result
+    } else if (name == "size") {
+      // TODO pmk: call GetShape on argument, extract or compute result
     }
     // TODO:
     // ceiling, count, cshift, dot_product, eoshift,
@@ -485,7 +487,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
     // index, ishftc, lbound, len_trim, matmul, max, maxloc, maxval, merge, min,
     // minloc, minval, mod, modulo, nint, not, pack, product, reduce, reshape,
     // scan, selected_char_kind, selected_int_kind, selected_real_kind,
-    // sign, size, spread, sum, transfer, transpose, ubound, unpack, verify
+    // sign, spread, sum, transfer, transpose, ubound, unpack, verify
   }
   return Expr<T>{std::move(funcRef)};
 }
index 3542a12..2ee4651 100644 (file)
@@ -575,8 +575,9 @@ std::ostream &Designator<T>::AsFortran(std::ostream &o) const {
 std::ostream &DescriptorInquiry::AsFortran(std::ostream &o) const {
   switch (field_) {
   case Field::LowerBound: o << "lbound("; break;
-  case Field::Extent: o << "%EXTENT("; break;
+  case Field::Extent: o << "size("; break;
   case Field::Stride: o << "%STRIDE("; break;
+  case Field::Rank: o << "rank("; break;
   }
   std::visit(
       common::visitors{
@@ -588,8 +589,8 @@ std::ostream &DescriptorInquiry::AsFortran(std::ostream &o) const {
           [&](const Component &comp) { EmitVar(o, comp); },
       },
       base_);
-  if (dimension_ > 0) {
-    o << ",dim=" << dimension_;
+  if (dimension_ >= 0) {
+    o << ",dim=" << (dimension_ + 1);
   }
   return o << ')';
 }
index e17ff1a..6aff28a 100644 (file)
 #include "shape.h"
 #include "fold.h"
 #include "tools.h"
+#include "traversal.h"
 #include "../common/idioms.h"
+#include "../common/template.h"
 #include "../semantics/symbol.h"
 
 namespace Fortran::evaluate {
 
-static Extent GetLowerBound(const semantics::Symbol &symbol,
+Shape AsGeneralShape(const Constant<ExtentType> &constShape) {
+  CHECK(constShape.Rank() == 1);
+  Shape result;
+  std::size_t dimensions{constShape.size()};
+  for (std::size_t j{0}; j < dimensions; ++j) {
+    Scalar<ExtentType> extent{constShape.values().at(j)};
+    result.emplace_back(MaybeExtent{ExtentExpr{extent}});
+  }
+  return result;
+}
+
+std::optional<Constant<ExtentType>> AsConstantShape(const Shape &shape) {
+  std::vector<Scalar<ExtentType>> extents;
+  for (const auto &dim : shape) {
+    if (dim.has_value()) {
+      if (const auto cdim{UnwrapExpr<Constant<ExtentType>>(*dim)}) {
+        extents.emplace_back(**cdim);
+        continue;
+      }
+    }
+    return std::nullopt;
+  }
+  std::vector<std::int64_t> rshape{static_cast<std::int64_t>(shape.size())};
+  return Constant<ExtentType>{std::move(extents), std::move(rshape)};
+}
+
+static ExtentExpr ComputeTripCount(
+    ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
+  ExtentExpr strideCopy{common::Clone(stride)};
+  ExtentExpr span{
+      (std::move(upper) - std::move(lower) + std::move(strideCopy)) /
+      std::move(stride)};
+  ExtentExpr extent{
+      Extremum<ExtentType>{std::move(span), ExtentExpr{0}, Ordering::Greater}};
+  FoldingContext noFoldingContext;
+  return Fold(noFoldingContext, std::move(extent));
+}
+
+ExtentExpr CountTrips(
+    ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
+  return ComputeTripCount(
+      std::move(lower), std::move(upper), std::move(stride));
+}
+
+ExtentExpr CountTrips(const ExtentExpr &lower, const ExtentExpr &upper,
+    const ExtentExpr &stride) {
+  return ComputeTripCount(
+      common::Clone(lower), common::Clone(upper), common::Clone(stride));
+}
+
+MaybeExtent CountTrips(
+    MaybeExtent &&lower, MaybeExtent &&upper, MaybeExtent &&stride) {
+  return common::MapOptional(
+      ComputeTripCount, std::move(lower), std::move(upper), std::move(stride));
+}
+
+MaybeExtent GetSize(Shape &&shape) {
+  ExtentExpr extent{1};
+  for (auto &&dim : std::move(shape)) {
+    if (dim.has_value()) {
+      extent = std::move(extent) * std::move(*dim);
+    } else {
+      return std::nullopt;
+    }
+  }
+  return extent;
+}
+
+static MaybeExtent GetLowerBound(const semantics::Symbol &symbol,
     const Component *component, int dimension) {
   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
     int j{0};
@@ -29,10 +99,10 @@ static Extent GetLowerBound(const semantics::Symbol &symbol,
         if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
           return *bound;
         } else if (component != nullptr) {
-          return Expr<SubscriptInteger>{DescriptorInquiry{
+          return ExtentExpr{DescriptorInquiry{
               *component, DescriptorInquiry::Field::LowerBound, dimension}};
         } else {
-          return Expr<SubscriptInteger>{DescriptorInquiry{
+          return ExtentExpr{DescriptorInquiry{
               symbol, DescriptorInquiry::Field::LowerBound, dimension}};
         }
       }
@@ -41,7 +111,7 @@ static Extent GetLowerBound(const semantics::Symbol &symbol,
   return std::nullopt;
 }
 
-static Extent GetExtent(const semantics::Symbol &symbol,
+static MaybeExtent GetExtent(const semantics::Symbol &symbol,
     const Component *component, int dimension) {
   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
     int j{0};
@@ -52,14 +122,14 @@ static Extent GetExtent(const semantics::Symbol &symbol,
             FoldingContext noFoldingContext;
             return Fold(noFoldingContext,
                 common::Clone(ubound.value()) - common::Clone(lbound.value()) +
-                    Expr<SubscriptInteger>{1});
+                    ExtentExpr{1});
           }
         }
         if (component != nullptr) {
-          return Expr<SubscriptInteger>{DescriptorInquiry{
+          return ExtentExpr{DescriptorInquiry{
               *component, DescriptorInquiry::Field::Extent, dimension}};
         } else {
-          return Expr<SubscriptInteger>{DescriptorInquiry{
+          return ExtentExpr{DescriptorInquiry{
               &symbol, DescriptorInquiry::Field::Extent, dimension}};
         }
       }
@@ -68,34 +138,23 @@ static Extent GetExtent(const semantics::Symbol &symbol,
   return std::nullopt;
 }
 
-static Extent GetExtent(const Subscript &subscript, const Symbol &symbol,
+static MaybeExtent GetExtent(const Subscript &subscript, const Symbol &symbol,
     const Component *component, int dimension) {
   return std::visit(
       common::visitors{
-          [&](const Triplet &triplet) -> Extent {
-            Extent upper{triplet.upper()};
+          [&](const Triplet &triplet) -> MaybeExtent {
+            MaybeExtent upper{triplet.upper()};
             if (!upper.has_value()) {
               upper = GetExtent(symbol, component, dimension);
             }
-            if (upper.has_value()) {
-              Extent lower{triplet.lower()};
-              if (!lower.has_value()) {
-                lower = GetLowerBound(symbol, component, dimension);
-              }
-              if (lower.has_value()) {
-                auto span{
-                    (std::move(*upper) - std::move(*lower) + triplet.stride()) /
-                    triplet.stride()};
-                Expr<SubscriptInteger> extent{
-                    Extremum<SubscriptInteger>{std::move(span),
-                        Expr<SubscriptInteger>{0}, Ordering::Greater}};
-                FoldingContext noFoldingContext;
-                return Fold(noFoldingContext, std::move(extent));
-              }
+            MaybeExtent lower{triplet.lower()};
+            if (!lower.has_value()) {
+              lower = GetLowerBound(symbol, component, dimension);
             }
-            return std::nullopt;
+            return CountTrips(std::move(lower), std::move(upper),
+                MaybeExtent{triplet.stride()});
           },
-          [](const IndirectSubscriptIntegerExpr &subs) -> Extent {
+          [](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtent {
             if (auto shape{GetShape(subs.value())}) {
               if (shape->size() > 0) {
                 CHECK(shape->size() == 1);  // vector-valued subscript
@@ -108,6 +167,14 @@ static Extent GetExtent(const Subscript &subscript, const Symbol &symbol,
       subscript.u);
 }
 
+bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
+  struct MyVisitor : public virtual VisitorBase<bool> {
+    explicit MyVisitor(int) { result() = false; }
+    void Handle(const ImpliedDoIndex &) { Return(true); }
+  };
+  return Visitor<bool, MyVisitor>{0}.Traverse(expr);
+}
+
 std::optional<Shape> GetShape(
     const semantics::Symbol &symbol, const Component *component) {
   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
@@ -211,8 +278,8 @@ std::optional<Shape> GetShape(const ProcedureRef &call) {
                  std::get_if<SpecificIntrinsic>(&call.proc().u)}) {
     if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
         intrinsic->name == "ubound") {
-      return Shape{Extent{Expr<SubscriptInteger>{
-          call.arguments().front().value().value().Rank()}}};
+      return Shape{MaybeExtent{
+          ExtentExpr{call.arguments().front().value().value().Rank()}}};
     }
     // TODO: shapes of other non-elemental intrinsic results
     // esp. reshape, where shape is value of second argument
index 44829b6..3b87beb 100644 (file)
@@ -19,6 +19,7 @@
 #define FORTRAN_EVALUATE_SHAPE_H_
 
 #include "expression.h"
+#include "tools.h"
 #include "type.h"
 #include "../common/indirection.h"
 #include <optional>
 
 namespace Fortran::evaluate {
 
-using Extent = std::optional<Expr<SubscriptInteger>>;
-using Shape = std::vector<Extent>;
+using ExtentType = SubscriptInteger;
+using ExtentExpr = Expr<ExtentType>;
+using MaybeExtent = std::optional<ExtentExpr>;
+using Shape = std::vector<MaybeExtent>;
+
+// Convert a constant shape to the expression form, and vice versa.
+Shape AsGeneralShape(const Constant<ExtentType> &);
+std::optional<Constant<ExtentType>> AsConstantShape(const Shape &);
+
+// Compute a trip count for a triplet or implied DO.
+ExtentExpr CountTrips(
+    ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride);
+ExtentExpr CountTrips(
+    const ExtentExpr &lower, const ExtentExpr &upper, const ExtentExpr &stride);
+MaybeExtent CountTrips(
+    MaybeExtent &&lower, MaybeExtent &&upper, MaybeExtent &&stride);
+
+// Computes SIZE() == PRODUCT(shape)
+MaybeExtent GetSize(Shape &&);
 
 template<typename A> std::optional<Shape> GetShape(const A &) {
-  return std::nullopt;  // default case
+  return std::nullopt;  // default case  TODO pmk remove
 }
 
 // Forward declarations
@@ -59,6 +77,11 @@ std::optional<Shape> GetShape(const StructureConstructor &);
 std::optional<Shape> GetShape(const BOZLiteralConstant &);
 std::optional<Shape> GetShape(const NullPointer &);
 
+template<typename T> std::optional<Shape> GetShape(const Constant<T> &c) {
+  Constant<ExtentType> shape{c.SHAPE()};
+  return AsGeneralShape(shape);
+}
+
 template<typename T>
 std::optional<Shape> GetShape(const Designator<T> &designator) {
   return GetShape(designator.u);
@@ -81,12 +104,58 @@ std::optional<Shape> GetShape(const Operation<D, R, O...> &operation) {
 
 template<int KIND>
 std::optional<Shape> GetShape(const TypeParamInquiry<KIND> &) {
-  return Shape{};  // always scalar
+  return Shape{};  // always scalar, even when applied to an array
+}
+
+// Utility predicate: does an expression reference any implied DO index?
+bool ContainsAnyImpliedDoIndex(const ExtentExpr &);
+
+template<typename T> MaybeExtent GetExtent(const ArrayConstructorValues<T> &);
+
+template<typename T>
+MaybeExtent GetExtent(const ArrayConstructorValue<T> &value) {
+  return std::visit(
+      common::visitors{
+          [](const common::CopyableIndirection<Expr<T>> &x) -> MaybeExtent {
+            if (std::optional<Shape> xShape{GetShape(x)}) {
+              // Array values in array constructors get linearized.
+              return GetSize(std::move(*xShape));
+            }
+            return std::nullopt;
+          },
+          [](const ImpliedDo<T> &ido) -> MaybeExtent {
+            // Don't be heroic and try to figure out triangular implied DO
+            // nests.
+            if (!ContainsAnyImpliedDoIndex(ido.lower()) &&
+                !ContainsAnyImpliedDoIndex(ido.upper()) &&
+                !ContainsAnyImpliedDoIndex(ido.stride())) {
+              if (auto nValues{GetExtent(ido.values())}) {
+                return std::move(*nValues) *
+                    CountTrips(ido.lower(), ido.upper(), ido.stride());
+              }
+            }
+            return std::nullopt;
+          },
+      },
+      value.u);
+}
+
+template<typename T>
+MaybeExtent GetExtent(const ArrayConstructorValues<T> &values) {
+  ExtentExpr result{0};
+  for (const auto &value : values.values()) {
+    if (MaybeExtent n{GetExtent(value)}) {
+      result = std::move(result) + std::move(*n);
+    } else {
+      return std::nullopt;
+    }
+  }
+  return result;
 }
 
 template<typename T>
-std::optional<Shape> GetShape(const ArrayConstructorValues<T> &aconst) {
-  return std::nullopt;  // TODO pmk much more here!!
+std::optional<Shape> GetShape(const ArrayConstructor<T> &aconst) {
+  return Shape{GetExtent(aconst)};
 }
 
 template<typename... A>
index 468145d..415924b 100644 (file)
@@ -385,8 +385,7 @@ template<typename T> struct Variable {
 class DescriptorInquiry {
 public:
   using Result = SubscriptInteger;
-  ENUM_CLASS(Field, LowerBound, Extent, Stride)
-  // TODO: Also type parameters and/or CHARACTER length?
+  ENUM_CLASS(Field, LowerBound, Extent, Stride, Rank)
 
   CLASS_BOILERPLATE(DescriptorInquiry)
   DescriptorInquiry(const Symbol &, Field, int);