[flang] fix original failure (reshape intrinsic argument check)
authorpeter klausler <pklausler@nvidia.com>
Wed, 3 Apr 2019 23:04:13 +0000 (16:04 -0700)
committerpeter klausler <pklausler@nvidia.com>
Fri, 5 Apr 2019 19:56:04 +0000 (12:56 -0700)
Original-commit: flang-compiler/f18@8bba330b32d928a5cf5d581d139c0cec02294b58
Reviewed-on: https://github.com/flang-compiler/f18/pull/386
Tree-same-pre-rewrite: false

flang/lib/evaluate/call.cc
flang/lib/evaluate/call.h
flang/lib/evaluate/common.h
flang/lib/evaluate/fold.cc
flang/lib/evaluate/intrinsics.cc
flang/lib/evaluate/shape.cc
flang/lib/evaluate/shape.h
flang/lib/evaluate/variable.cc
flang/lib/evaluate/variable.h
flang/lib/parser/message.h

index 102b9fe..5313069 100644 (file)
@@ -29,14 +29,6 @@ bool ActualArgument::operator==(const ActualArgument &that) const {
       isAlternateReturn == that.isAlternateReturn && value() == that.value();
 }
 
-std::optional<int> ActualArgument::VectorSize() const {
-  if (Rank() != 1) {
-    return std::nullopt;
-  }
-  // TODO: get shape vector of value, return its length
-  return std::nullopt;
-}
-
 bool SpecificIntrinsic::operator==(const SpecificIntrinsic &that) const {
   return name == that.name && type == that.type && rank == that.rank &&
       attrs == that.attrs;
index 5b11df1..4dab138 100644 (file)
@@ -45,7 +45,6 @@ public:
   int Rank() const;
   bool operator==(const ActualArgument &) const;
   std::ostream &AsFortran(std::ostream &) const;
-  std::optional<int> VectorSize() const;
 
   std::optional<parser::CharBlock> keyword;
   bool isAlternateReturn{false};  // when true, "value" is a label number
index d91da13..63f9950 100644 (file)
@@ -201,6 +201,7 @@ template<typename A> class Expr;
 
 class FoldingContext {
 public:
+  FoldingContext() = default;
   explicit FoldingContext(const parser::ContextualMessages &m,
       Rounding round = defaultRounding, bool flush = false)
     : messages_{m}, rounding_{round}, flushSubnormalsToZero_{flush} {}
index dd4e13e..1d194f1 100644 (file)
@@ -19,6 +19,7 @@
 #include "host.h"
 #include "int-power.h"
 #include "intrinsics-library-templates.h"
+#include "shape.h"
 #include "tools.h"
 #include "traversal.h"
 #include "type.h"
@@ -473,13 +474,17 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
       }
       return FoldElementalIntrinsic<T, T, T, T>(
           context, std::move(funcRef), &Scalar<T>::MERGE_BITS);
+    } else if (name == "rank") {
+      // TODO pmk: get rank
+    } else if (name == "shape") {
+      // TODO pmk: call GetShape on argument, massage result
     }
     // TODO:
     // ceiling, count, cshift, dot_product, eoshift,
     // findloc, floor, iachar, iall, iany, iparity, ibits, ichar, image_status,
     // index, ishftc, lbound, len_trim, matmul, max, maxloc, maxval, merge, min,
     // minloc, minval, mod, modulo, nint, not, pack, product, reduce, reshape,
-    // scan, selected_char_kind, selected_int_kind, selected_real_kind, shape,
+    // scan, selected_char_kind, selected_int_kind, selected_real_kind,
     // sign, size, spread, sum, transfer, transpose, ubound, unpack, verify
   }
   return Expr<T>{std::move(funcRef)};
index b645385..f42052b 100644 (file)
@@ -15,6 +15,7 @@
 #include "intrinsics.h"
 #include "expression.h"
 #include "fold.h"
+#include "shape.h"
 #include "tools.h"
 #include "type.h"
 #include "../common/Fortran.h"
@@ -502,6 +503,7 @@ static const IntrinsicInterface genericIntrinsicFunction[]{
     {"product",
         {{"array", SameNumeric, Rank::array}, OptionalDIM, OptionalMASK},
         SameNumeric, Rank::dimReduced},
+    // TODO pmk: "rank"
     {"real", {{"a", AnyNumeric, Rank::elementalOrBOZ}, DefaultingKIND},
         KINDReal},
     {"reduce",
@@ -607,7 +609,7 @@ static const IntrinsicInterface genericIntrinsicFunction[]{
 //   COSHAPE
 // TODO: Object characteristic inquiry functions
 //   ALLOCATED, ASSOCIATED, EXTENDS_TYPE_OF, IS_CONTIGUOUS,
-//   PRESENT, RANK, SAME_TYPE, STORAGE_SIZE
+//   PRESENT, SAME_TYPE, STORAGE_SIZE
 // TODO: Type inquiry intrinsic functions - these return constants
 //  BIT_SIZE, DIGITS, EPSILON, HUGE, KIND, MAXEXPONENT, MINEXPONENT,
 //  NEW_LINE, PRECISION, RADIX, RANGE, TINY
@@ -939,7 +941,7 @@ std::optional<SpecificCall> IntrinsicInterface::Match(
   // Check the ranks of the arguments against the intrinsic's interface.
   const ActualArgument *arrayArg{nullptr};
   const ActualArgument *knownArg{nullptr};
-  const ActualArgument *shapeArg{nullptr};
+  std::optional<int> shapeArgSize;
   int elementalRank{0};
   for (std::size_t j{0}; j < dummies; ++j) {
     const IntrinsicDummyArgument &d{dummy[std::min(j, dummyArgPatterns - 1)]};
@@ -963,9 +965,21 @@ std::optional<SpecificCall> IntrinsicInterface::Match(
       case Rank::scalar: argOk = rank == 0; break;
       case Rank::vector: argOk = rank == 1; break;
       case Rank::shape:
-        CHECK(shapeArg == nullptr);
-        shapeArg = arg;
-        argOk = rank == 1 && arg->VectorSize().has_value();
+        CHECK(!shapeArgSize.has_value());
+        if (rank == 1) {
+          if (auto shape{GetShape(*arg)}) {
+            CHECK(shape->size() == 1);
+            if (auto value{ToInt64(shape->at(0))}) {
+              shapeArgSize = *value;
+              argOk = *value >= 0;
+            }
+          }
+        }
+        if (!argOk) {
+          messages.Say(
+              "'shape=' argument must be a vector of known size"_err_en_US);
+          return std::nullopt;
+        }
         break;
       case Rank::matrix: argOk = rank == 2; break;
       case Rank::array:
@@ -1134,8 +1148,8 @@ std::optional<SpecificCall> IntrinsicInterface::Match(
     resultRank = knownArg->Rank() + 1;
     break;
   case Rank::shaped:
-    CHECK(shapeArg != nullptr);
-    resultRank = shapeArg->VectorSize().value();
+    CHECK(shapeArgSize.has_value());
+    resultRank = *shapeArgSize;
     break;
   case Rank::elementalOrBOZ:
   case Rank::shape:
index 63b1663..e17ff1a 100644 (file)
 #include "../semantics/symbol.h"
 
 namespace Fortran::evaluate {
-std::optional<Shape> GetShape(
-    const semantics::Symbol &symbol, const Component *component) {
+
+static Extent GetLowerBound(const semantics::Symbol &symbol,
+    const Component *component, int dimension) {
   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
-    Shape result;
-    int dimension{1};
+    int j{0};
     for (const auto &shapeSpec : details->shape()) {
-      if (shapeSpec.isExplicit()) {
-        result.emplace_back(
-            common::Clone(shapeSpec.ubound().GetExplicit().value()) -
-            common::Clone(shapeSpec.lbound().GetExplicit().value()) +
-            Expr<SubscriptInteger>{1});
-      } else if (component != nullptr) {
-        result.emplace_back(Expr<SubscriptInteger>{DescriptorInquiry{
-            *component, DescriptorInquiry::Field::Extent, dimension}});
-      } else {
-        result.emplace_back(Expr<SubscriptInteger>{DescriptorInquiry{
-            symbol, DescriptorInquiry::Field::Extent, dimension}});
+      if (j++ == dimension) {
+        if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
+          return *bound;
+        } else if (component != nullptr) {
+          return Expr<SubscriptInteger>{DescriptorInquiry{
+              *component, DescriptorInquiry::Field::LowerBound, dimension}};
+        } else {
+          return Expr<SubscriptInteger>{DescriptorInquiry{
+              symbol, DescriptorInquiry::Field::LowerBound, dimension}};
+        }
       }
-      ++dimension;
     }
-    return result;
-  } else {
-    return std::nullopt;
   }
+  return std::nullopt;
 }
 
-std::optional<Shape> GetShape(const Component &component) {
-  const Symbol &symbol{component.GetLastSymbol()};
-  if (symbol.Rank() > 0) {
-    return GetShape(symbol, &component);
-  } else {
-    return GetShape(component.base());
+static Extent GetExtent(const semantics::Symbol &symbol,
+    const Component *component, int dimension) {
+  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
+    int j{0};
+    for (const auto &shapeSpec : details->shape()) {
+      if (j++ == dimension) {
+        if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) {
+          if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) {
+            FoldingContext noFoldingContext;
+            return Fold(noFoldingContext,
+                common::Clone(ubound.value()) - common::Clone(lbound.value()) +
+                    Expr<SubscriptInteger>{1});
+          }
+        }
+        if (component != nullptr) {
+          return Expr<SubscriptInteger>{DescriptorInquiry{
+              *component, DescriptorInquiry::Field::Extent, dimension}};
+        } else {
+          return Expr<SubscriptInteger>{DescriptorInquiry{
+              &symbol, DescriptorInquiry::Field::Extent, dimension}};
+        }
+      }
+    }
   }
+  return std::nullopt;
 }
-static Extent GetExtent(const Subscript &subscript) {
+
+static Extent GetExtent(const Subscript &subscript, const Symbol &symbol,
+    const Component *component, int dimension) {
   return std::visit(
       common::visitors{
-          [](const Triplet &triplet) -> Extent {
-            if (auto lower{triplet.lower()}) {
-              if (auto lowerValue{ToInt64(*lower)}) {
-                if (auto upper{triplet.upper()}) {
-                  if (auto upperValue{ToInt64(*upper)}) {
-                    if (auto strideValue{ToInt64(triplet.stride())}) {
-                      if (*strideValue != 0) {
-                        std::int64_t extent{
-                            (*upperValue - *lowerValue + *strideValue) /
-                            *strideValue};
-                        return Expr<SubscriptInteger>{extent > 0 ? extent : 0};
-                      }
-                    }
-                  }
-                }
+          [&](const Triplet &triplet) -> Extent {
+            Extent upper{triplet.upper()};
+            if (!upper.has_value()) {
+              upper = GetExtent(symbol, component, dimension);
+            }
+            if (upper.has_value()) {
+              Extent lower{triplet.lower()};
+              if (!lower.has_value()) {
+                lower = GetLowerBound(symbol, component, dimension);
+              }
+              if (lower.has_value()) {
+                auto span{
+                    (std::move(*upper) - std::move(*lower) + triplet.stride()) /
+                    triplet.stride()};
+                Expr<SubscriptInteger> extent{
+                    Extremum<SubscriptInteger>{std::move(span),
+                        Expr<SubscriptInteger>{0}, Ordering::Greater}};
+                FoldingContext noFoldingContext;
+                return Fold(noFoldingContext, std::move(extent));
               }
             }
             return std::nullopt;
           },
           [](const IndirectSubscriptIntegerExpr &subs) -> Extent {
             if (auto shape{GetShape(subs.value())}) {
-              if (shape->size() == 1) {
+              if (shape->size() > 0) {
+                CHECK(shape->size() == 1);  // vector-valued subscript
                 return std::move(shape->at(0));
               }
             }
@@ -86,12 +107,48 @@ static Extent GetExtent(const Subscript &subscript) {
       },
       subscript.u);
 }
+
+std::optional<Shape> GetShape(
+    const semantics::Symbol &symbol, const Component *component) {
+  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
+    Shape result;
+    int n = details->shape().size();
+    for (int dimension{0}; dimension < n; ++dimension) {
+      result.emplace_back(GetExtent(symbol, component, dimension++));
+    }
+    return result;
+  } else {
+    return std::nullopt;
+  }
+}
+
+std::optional<Shape> GetShape(const BaseObject &object) {
+  if (const Symbol * symbol{object.symbol()}) {
+    return GetShape(*symbol);
+  } else {
+    return Shape{};
+  }
+}
+
+std::optional<Shape> GetShape(const Component &component) {
+  const Symbol &symbol{component.GetLastSymbol()};
+  if (symbol.Rank() > 0) {
+    return GetShape(symbol, &component);
+  } else {
+    return GetShape(component.base());
+  }
+}
+
 std::optional<Shape> GetShape(const ArrayRef &arrayRef) {
   Shape shape;
+  const Symbol &symbol{arrayRef.GetLastSymbol()};
+  const Component *component{std::get_if<Component>(&arrayRef.base())};
+  int dimension{0};
   for (const Subscript &ss : arrayRef.subscript()) {
     if (ss.Rank() > 0) {
-      shape.emplace_back(GetExtent(ss));
+      shape.emplace_back(GetExtent(ss, symbol, component, dimension));
     }
+    ++dimension;
   }
   if (shape.empty()) {
     return GetShape(arrayRef.base());
@@ -99,12 +156,18 @@ std::optional<Shape> GetShape(const ArrayRef &arrayRef) {
     return shape;
   }
 }
+
 std::optional<Shape> GetShape(const CoarrayRef &coarrayRef) {
   Shape shape;
+  SymbolOrComponent base{coarrayRef.GetBaseSymbolOrComponent()};
+  const Symbol &symbol{coarrayRef.GetLastSymbol()};
+  const Component *component{std::get_if<Component>(&base)};
+  int dimension{0};
   for (const Subscript &ss : coarrayRef.subscript()) {
     if (ss.Rank() > 0) {
-      shape.emplace_back(GetExtent(ss));
+      shape.emplace_back(GetExtent(ss, symbol, component, dimension));
     }
+    ++dimension;
   }
   if (shape.empty()) {
     return GetShape(coarrayRef.GetLastSymbol());
@@ -112,9 +175,11 @@ std::optional<Shape> GetShape(const CoarrayRef &coarrayRef) {
     return shape;
   }
 }
+
 std::optional<Shape> GetShape(const DataRef &dataRef) {
   return std::visit([](const auto &x) { return GetShape(x); }, dataRef.u);
 }
+
 std::optional<Shape> GetShape(const Substring &substring) {
   if (const auto *dataRef{substring.GetParentIf<DataRef>()}) {
     return GetShape(*dataRef);
@@ -122,7 +187,49 @@ std::optional<Shape> GetShape(const Substring &substring) {
     return std::nullopt;
   }
 }
+
 std::optional<Shape> GetShape(const ComplexPart &part) {
   return GetShape(part.complex());
 }
+
+std::optional<Shape> GetShape(const ActualArgument &arg) {
+  return GetShape(arg.value());
+}
+
+std::optional<Shape> GetShape(const ProcedureRef &call) {
+  if (call.Rank() == 0) {
+    return Shape{};
+  } else if (call.IsElemental()) {
+    for (const auto &arg : call.arguments()) {
+      if (arg.has_value() && arg->Rank() > 0) {
+        return GetShape(*arg);
+      }
+    }
+  } else if (const Symbol * symbol{call.proc().GetSymbol()}) {
+    return GetShape(*symbol);
+  } else if (const auto *intrinsic{
+                 std::get_if<SpecificIntrinsic>(&call.proc().u)}) {
+    if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
+        intrinsic->name == "ubound") {
+      return Shape{Extent{Expr<SubscriptInteger>{
+          call.arguments().front().value().value().Rank()}}};
+    }
+    // TODO: shapes of other non-elemental intrinsic results
+    // esp. reshape, where shape is value of second argument
+  }
+  return std::nullopt;
+}
+
+std::optional<Shape> GetShape(const StructureConstructor &) {
+  return Shape{};  // always scalar
+}
+
+std::optional<Shape> GetShape(const BOZLiteralConstant &) {
+  return Shape{};  // always scalar
+}
+
+std::optional<Shape> GetShape(const NullPointer &) {
+  return {};  // not an object
+}
+
 }
index 30e5e92..44829b6 100644 (file)
@@ -30,51 +30,81 @@ using Extent = std::optional<Expr<SubscriptInteger>>;
 using Shape = std::vector<Extent>;
 
 template<typename A> std::optional<Shape> GetShape(const A &) {
-  return std::nullopt;
+  return std::nullopt;  // default case
 }
 
-template<typename T> std::optional<Shape> GetShape(const Expr<T> &);
-
+// Forward declarations
+template<typename... A>
+std::optional<Shape> GetShape(const std::variant<A...> &);
 template<typename A, bool COPY>
-std::optional<Shape> GetShape(const common::Indirection<A, COPY> &p) {
-  return GetShape(p.value());
-}
-
-template<typename A> std::optional<Shape> GetShape(const std::optional<A> &x) {
-  if (x.has_value()) {
-    return GetShape(*x);
-  } else {
-    return std::nullopt;
-  }
-}
+std::optional<Shape> GetShape(const common::Indirection<A, COPY> &);
+template<typename A> std::optional<Shape> GetShape(const std::optional<A> &);
 
-template<typename... A>
-std::optional<Shape> GetShape(const std::variant<A...> &u) {
-  return std::visit([](const auto &x) { return GetShape(x); }, u);
+template<typename T> std::optional<Shape> GetShape(const Expr<T> &expr) {
+  return GetShape(expr.u);
 }
 
 std::optional<Shape> GetShape(
     const semantics::Symbol &, const Component * = nullptr);
-std::optional<Shape> GetShape(const DataRef &);
-std::optional<Shape> GetShape(const ComplexPart &);
-std::optional<Shape> GetShape(const Substring &);
+std::optional<Shape> GetShape(const BaseObject &);
 std::optional<Shape> GetShape(const Component &);
 std::optional<Shape> GetShape(const ArrayRef &);
 std::optional<Shape> GetShape(const CoarrayRef &);
+std::optional<Shape> GetShape(const DataRef &);
+std::optional<Shape> GetShape(const Substring &);
+std::optional<Shape> GetShape(const ComplexPart &);
+std::optional<Shape> GetShape(const ActualArgument &);
+std::optional<Shape> GetShape(const ProcedureRef &);
+std::optional<Shape> GetShape(const StructureConstructor &);
+std::optional<Shape> GetShape(const BOZLiteralConstant &);
+std::optional<Shape> GetShape(const NullPointer &);
 
 template<typename T>
 std::optional<Shape> GetShape(const Designator<T> &designator) {
-  return std::visit([](const auto &x) { return GetShape(x); }, designator.u);
+  return GetShape(designator.u);
 }
 
-template<typename T> std::optional<Shape> GetShape(const Expr<T> &expr) {
-  return std::visit(
-      common::visitors{
-          [](const BOZLiteralConstant &) { return Shape{}; },
-          [](const NullPointer &) { return std::nullopt; },
-          [](const auto &x) { return GetShape(x); },
-      },
-      expr.u);
+template<typename T>
+std::optional<Shape> GetShape(const Variable<T> &variable) {
+  return GetShape(variable.u);
+}
+
+template<typename D, typename R, typename... O>
+std::optional<Shape> GetShape(const Operation<D, R, O...> &operation) {
+  if constexpr (operation.operands > 1) {
+    if (operation.right().Rank() > 0) {
+      return GetShape(operation.right());
+    }
+  }
+  return GetShape(operation.left());
+}
+
+template<int KIND>
+std::optional<Shape> GetShape(const TypeParamInquiry<KIND> &) {
+  return Shape{};  // always scalar
+}
+
+template<typename T>
+std::optional<Shape> GetShape(const ArrayConstructorValues<T> &aconst) {
+  return std::nullopt;  // TODO pmk much more here!!
+}
+
+template<typename... A>
+std::optional<Shape> GetShape(const std::variant<A...> &u) {
+  return std::visit([](const auto &x) { return GetShape(x); }, u);
+}
+
+template<typename A, bool COPY>
+std::optional<Shape> GetShape(const common::Indirection<A, COPY> &p) {
+  return GetShape(p.value());
+}
+
+template<typename A> std::optional<Shape> GetShape(const std::optional<A> &x) {
+  if (x.has_value()) {
+    return GetShape(*x);
+  } else {
+    return std::nullopt;
+  }
 }
 }
 #endif  // FORTRAN_EVALUATE_SHAPE_H_
index 8f1a7e8..abcb8b2 100644 (file)
@@ -57,9 +57,7 @@ std::optional<Expr<SubscriptInteger>> Triplet::upper() const {
   return std::nullopt;
 }
 
-const Expr<SubscriptInteger> &Triplet::stride() const {
-  return stride_.value();
-}
+Expr<SubscriptInteger> Triplet::stride() const { return stride_.value(); }
 
 bool Triplet::IsStrideOne() const {
   if (auto stride{ToInt64(stride_.value())}) {
@@ -359,18 +357,18 @@ int ArrayRef::Rank() const {
   }
   return std::visit(
       common::visitors{
-          [=](const Symbol *s) { return s->Rank(); },
+          [=](const Symbol *s) { return 0; },
           [=](const Component &c) { return c.Rank(); },
       },
       base_);
 }
 
 int CoarrayRef::Rank() const {
-  int rank{0};
-  for (const auto &expr : subscript_) {
-    rank += expr.Rank();
-  }
-  if (rank > 0) {
+  if (!subscript_.empty()) {
+    int rank{0};
+    for (const auto &expr : subscript_) {
+      rank += expr.Rank();
+    }
     return rank;
   } else {
     return base_.back()->Rank();
@@ -519,6 +517,21 @@ template<typename T> std::optional<DynamicType> Designator<T>::GetType() const {
   }
 }
 
+SymbolOrComponent CoarrayRef::GetBaseSymbolOrComponent() const {
+  SymbolOrComponent base{base_.front()};
+  int j{0};
+  for (const Symbol *symbol : base_) {
+    if (j == 0) {  // X - already captured the symbol above
+    } else if (j == 1) {  // X%Y
+      base = Component{DataRef{std::get<const Symbol *>(base)}, *symbol};
+    } else {  // X%Y%Z or more
+      base = Component{DataRef{std::move(std::get<Component>(base))}, *symbol};
+    }
+    ++j;
+  }
+  return base;
+}
+
 // Equality testing
 
 bool BaseObject::operator==(const BaseObject &that) const {
index 14ddf5f..468145d 100644 (file)
@@ -142,7 +142,7 @@ public:
       std::optional<Expr<SubscriptInteger>> &&);
   std::optional<Expr<SubscriptInteger>> lower() const;
   std::optional<Expr<SubscriptInteger>> upper() const;
-  const Expr<SubscriptInteger> &stride() const;
+  Expr<SubscriptInteger> stride() const;
   bool operator==(const Triplet &) const;
   bool IsStrideOne() const;
   std::ostream &AsFortran(std::ostream &) const;
@@ -237,6 +237,7 @@ public:
   int Rank() const;
   const Symbol &GetFirstSymbol() const;
   const Symbol &GetLastSymbol() const;
+  SymbolOrComponent GetBaseSymbolOrComponent() const;
   Expr<SubscriptInteger> LEN() const;
   bool operator==(const CoarrayRef &) const;
   std::ostream &AsFortran(std::ostream &) const;
@@ -404,7 +405,7 @@ public:
 private:
   SymbolOrComponent base_{nullptr};
   Field field_;
-  int dimension_{0};
+  int dimension_{0};  // zero-based
 };
 
 #define INSTANTIATE_VARIABLE_TEMPLATES \
index 771bc8b..b26b8c3 100644 (file)
@@ -241,6 +241,7 @@ private:
 
 class ContextualMessages {
 public:
+  ContextualMessages() = default;
   ContextualMessages(CharBlock at, Messages *m) : at_{at}, messages_{m} {}
   ContextualMessages(const ContextualMessages &that)
     : at_{that.at_}, messages_{that.messages_} {}