[flang] Fold LBOUND and UBOUND; do not insert empty triplets into whole array expressions
authorpeter klausler <pklausler@nvidia.com>
Tue, 30 Jul 2019 23:51:25 +0000 (16:51 -0700)
committerpeter klausler <pklausler@nvidia.com>
Fri, 2 Aug 2019 16:22:00 +0000 (09:22 -0700)
Original-commit: flang-compiler/f18@82fba68a665802c990f35d14222f8df4ac4e1dee
Reviewed-on: https://github.com/flang-compiler/f18/pull/611
Tree-same-pre-rewrite: false

flang/lib/evaluate/fold.cc
flang/lib/evaluate/shape.cc
flang/lib/evaluate/shape.h
flang/lib/evaluate/tools.h
flang/lib/evaluate/variable.h
flang/lib/semantics/expression.cc
flang/lib/semantics/tools.cc
flang/lib/semantics/tools.h
flang/test/evaluate/CMakeLists.txt
flang/test/evaluate/folding08.f90 [new file with mode: 0644]

index ec56acf..c368967 100644 (file)
@@ -31,6 +31,7 @@
 #include "../parser/message.h"
 #include "../semantics/scope.h"
 #include "../semantics/symbol.h"
+#include "../semantics/tools.h"
 #include <cmath>
 #include <complex>
 #include <cstdio>
@@ -206,15 +207,13 @@ using ScalarFuncWithContext =
 template<typename T>
 static inline Constant<T> *FoldConvertedArg(
     FoldingContext &context, std::optional<ActualArgument> &arg) {
-  if (arg.has_value()) {
-    if (auto *expr{arg->UnwrapExpr()}) {
-      if (UnwrapExpr<Expr<T>>(*expr) == nullptr) {
-        if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
-          *expr = Fold(context, std::move(*converted));
-        }
+  if (auto *expr{UnwrapExpr<Expr<SomeType>>(arg)}) {
+    if (UnwrapExpr<Expr<T>>(*expr) == nullptr) {
+      if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
+        *expr = Fold(context, std::move(*converted));
       }
-      return UnwrapConstantValue<T>(*expr);
     }
+    return UnwrapConstantValue<T>(*expr);
   }
   return nullptr;
 }
@@ -533,7 +532,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
               return std::invoke(fptr, i, static_cast<int>(pos.ToInt64()));
             }));
   } else if (name == "int") {
-    if (auto *expr{args[0].value().UnwrapExpr()}) {
+    if (auto *expr{UnwrapExpr<Expr<SomeType>>(args[0])}) {
       return std::visit(
           [&](auto &&x) -> Expr<T> {
             using From = std::decay_t<decltype(x)>;
@@ -551,6 +550,52 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
     } else {
       common::die("kind() result not integral");
     }
+  } else if (name == "lbound") {
+    if (const auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
+      if (int rank{array->Rank()}) {
+        std::optional<std::int64_t> dim;
+        if (args[1].has_value()) {
+          dim = GetInt64Arg(args[1]);
+          if (!dim.has_value()) {
+            // DIM= is present but not constant
+            return Expr<T>{std::move(funcRef)};
+          } else if (*dim < 1 || *dim > rank) {
+            context.messages().Say(
+                "LBOUND(array,dim=%jd) dimension is out of range for rank-%d array"_en_US,
+                static_cast<std::intmax_t>(*dim), rank);
+            return Expr<T>(std::move(funcRef));
+          }
+        }
+        bool lowerBoundsAreOne{true};
+        if (auto named{ExtractNamedEntity(*array)}) {
+          const Symbol &symbol{named->GetLastSymbol()};
+          if (symbol.Rank() == rank) {
+            lowerBoundsAreOne = false;
+            if (dim.has_value()) {
+              if (auto lb{
+                      GetLowerBound(context, *named, static_cast<int>(*dim))}) {
+                return Fold(context, ConvertToType<T>(std::move(*lb)));
+              }
+            } else if (auto lbounds{
+                           AsConstantShape(GetLowerBounds(context, *named))}) {
+              return Fold(context,
+                  ConvertToType<T>(Expr<ExtentType>{std::move(*lbounds)}));
+            }
+          } else {
+            lowerBoundsAreOne = symbol.Rank() == 0;  // component
+          }
+        }
+        if (lowerBoundsAreOne) {
+          if (dim.has_value()) {
+            return Expr<T>{1};
+          } else {
+            std::vector<Scalar<T>> ones(rank, Scalar<T>{1});
+            return Expr<T>{
+                Constant<T>{std::move(ones), ConstantSubscripts{rank}}};
+          }
+        }
+      }
+    }
   } else if (name == "leadz" || name == "trailz" || name == "poppar" ||
       name == "popcnt") {
     if (auto *sn{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
@@ -688,16 +733,16 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
       }
     }
   } else if (name == "shape") {
-    if (auto shape{GetShape(context, args[0].value())}) {
+    if (auto shape{GetShape(context, args[0])}) {
       if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
         return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
       }
     }
   } else if (name == "size") {
-    if (auto shape{GetShape(context, args[0].value())}) {
+    if (auto shape{GetShape(context, args[0])}) {
       if (auto &dimArg{args[1]}) {  // DIM= is present, get one extent
         if (auto dim{GetInt64Arg(args[1])}) {
-          int rank = GetRank(*shape);
+          int rank{GetRank(*shape)};
           if (*dim >= 1 && *dim <= rank) {
             if (auto &extent{shape->at(*dim - 1)}) {
               return Fold(context, ConvertToType<T>(std::move(*extent)));
@@ -717,13 +762,70 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
         return Expr<T>{ConvertToType<T>(Fold(context, std::move(product)))};
       }
     }
+  } else if (name == "ubound") {
+    if (auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
+      if (int rank{array->Rank()}; rank > 0) {
+        std::optional<std::int64_t> dim;
+        if (args[1].has_value()) {
+          dim = GetInt64Arg(args[1]);
+          if (!dim.has_value()) {
+            // DIM= is present but not constant
+            return Expr<T>{std::move(funcRef)};
+          } else if (*dim < 1 || *dim > rank) {
+            context.messages().Say(
+                "UBOUND(array,dim=%jd) dimension is out of range for rank-%d array"_en_US,
+                static_cast<std::intmax_t>(*dim), rank);
+            return Expr<T>(std::move(funcRef));
+          }
+        }
+        bool takeBoundsFromShape{true};
+        if (auto named{ExtractNamedEntity(*array)}) {
+          const Symbol &symbol{named->GetLastSymbol()};
+          if (symbol.Rank() == rank) {
+            takeBoundsFromShape = false;
+            if (dim.has_value()) {
+              if (semantics::IsAssumedSizeArray(symbol) && *dim == rank) {
+                return Expr<T>{-1};
+              } else if (auto ub{GetUpperBound(
+                             context, *named, static_cast<int>(*dim))}) {
+                return Fold(context, ConvertToType<T>(std::move(*ub)));
+              }
+            } else {
+              Shape ubounds{GetUpperBounds(context, *named)};
+              if (semantics::IsAssumedSizeArray(symbol)) {
+                CHECK(!ubounds.back().has_value());
+                ubounds.back() = ExtentExpr{-1};
+              }
+              if (auto constant{AsConstantShape(ubounds)}) {
+                return Fold(context,
+                    ConvertToType<T>(Expr<ExtentType>{std::move(*constant)}));
+              }
+            }
+          } else {
+            takeBoundsFromShape = symbol.Rank() == 0;  // component
+          }
+        }
+        if (takeBoundsFromShape) {
+          if (auto shape{GetShape(context, *array)}) {
+            if (dim.has_value()) {
+              if (auto &dimSize{shape->at(*dim)}) {
+                return Fold(context,
+                    ConvertToType<T>(Expr<ExtentType>{std::move(*dimSize)}));
+              }
+            } else if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
+              return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
+            }
+          }
+        }
+      }
+    }
   }
   // TODO:
   // ceiling, count, cshift, dot_product, eoshift,
   // findloc, floor, iall, iany, iparity, ibits, image_status, index, ishftc,
-  // lbound, len_trim, matmul, max, maxloc, maxval, merge, min,
+  // len_trim, matmul, max, maxloc, maxval, merge, min,
   // minloc, minval, mod, modulo, nint, not, pack, product, reduce,
-  // scan, sign, spread, sum, transfer, transpose, ubound, unpack, verify
+  // scan, sign, spread, sum, transfer, transpose, unpack, verify
   return Expr<T>{std::move(funcRef)};
 }
 
index f3dc9c3..4bf3d30 100644 (file)
@@ -266,7 +266,7 @@ MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript,
           [&](const Triplet &triplet) -> MaybeExtentExpr {
             MaybeExtentExpr upper{triplet.upper()};
             if (!upper.has_value()) {
-              upper = GetExtent(context, base, dimension);
+              upper = GetUpperBound(context, base, dimension);
             }
             MaybeExtentExpr lower{triplet.lower()};
             if (!lower.has_value()) {
@@ -298,12 +298,46 @@ MaybeExtentExpr GetUpperBound(FoldingContext &context, MaybeExtentExpr &&lower,
   }
 }
 
+MaybeExtentExpr GetUpperBound(
+    FoldingContext &context, const NamedEntity &base, int dimension) {
+  const Symbol &symbol{base.GetLastSymbol()};
+  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
+    int j{0};
+    for (const auto &shapeSpec : details->shape()) {
+      if (j++ == dimension) {
+        if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
+          return Fold(context, common::Clone(*bound));
+        } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
+          break;
+        } else {
+          return GetUpperBound(context, GetLowerBound(context, base, dimension),
+              GetExtent(context, base, dimension));
+        }
+      }
+    }
+  }
+  return std::nullopt;
+}
+
+Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) {
+  int rank{base.GetLastSymbol().Rank()};
+  Shape result;
+  for (int dim{0}; dim < rank; ++dim) {
+    result.emplace_back(GetUpperBound(context, base, dim));
+  }
+  return result;
+}
+
 void GetShapeVisitor::Handle(const Symbol &symbol) {
   Handle(NamedEntity{symbol});
 }
 
 void GetShapeVisitor::Handle(const Component &component) {
-  Handle(NamedEntity{Component{component}});
+  if (component.GetLastSymbol().Rank() > 0) {
+    Handle(NamedEntity{Component{component}});
+  } else {
+    Nested(component.base());
+  }
 }
 
 void GetShapeVisitor::Handle(const NamedEntity &base) {
@@ -326,6 +360,10 @@ void GetShapeVisitor::Handle(const NamedEntity &base) {
   Return();
 }
 
+void GetShapeVisitor::Handle(const Substring &substring) {
+  Nested(substring.parent());
+}
+
 void GetShapeVisitor::Handle(const ArrayRef &arrayRef) {
   Shape shape;
   int dimension{0};
index 069e66b..3614d02 100644 (file)
@@ -69,6 +69,9 @@ MaybeExtentExpr GetExtent(
     FoldingContext &, const Subscript &, const NamedEntity &, int dimension);
 MaybeExtentExpr GetUpperBound(
     FoldingContext &, MaybeExtentExpr &&lower, MaybeExtentExpr &&extent);
+MaybeExtentExpr GetUpperBound(
+    FoldingContext &, const NamedEntity &, int dimension);
+Shape GetUpperBounds(FoldingContext &, const NamedEntity &);
 
 // Compute an element count for a triplet or trip count for a DO.
 ExtentExpr CountTrips(
@@ -104,6 +107,7 @@ public:
   void Handle(const StaticDataObject::Pointer &) { Scalar(); }
   void Handle(const ArrayRef &);
   void Handle(const CoarrayRef &);
+  void Handle(const Substring &);
   void Handle(const ProcedureRef &);
   void Handle(const StructureConstructor &) { Scalar(); }
   template<typename T> void Handle(const ArrayConstructor<T> &aconst) {
index 02e9b23..390f863 100644 (file)
@@ -209,7 +209,7 @@ template<typename A, typename B> A *UnwrapExpr(std::optional<B> &x) {
 // If an expression simply wraps a DataRef, extract and return it.
 template<typename A>
 common::IfNoLvalue<std::optional<DataRef>, A> ExtractDataRef(const A &) {
-  return std::nullopt;  // default base casec
+  return std::nullopt;  // default base case
 }
 template<typename T>
 std::optional<DataRef> ExtractDataRef(const Designator<T> &d) {
@@ -235,6 +235,24 @@ std::optional<DataRef> ExtractDataRef(const std::optional<A> &x) {
   }
 }
 
+template<typename A> std::optional<NamedEntity> ExtractNamedEntity(const A &x) {
+  if (auto dataRef{ExtractDataRef(x)}) {
+    return std::visit(
+        common::visitors{
+            [](const Symbol *symbol) -> std::optional<NamedEntity> {
+              return NamedEntity{*symbol};
+            },
+            [](Component &&component) -> std::optional<NamedEntity> {
+              return NamedEntity{std::move(component)};
+            },
+            [](auto &&) -> std::optional<NamedEntity> { return std::nullopt; },
+        },
+        std::move(dataRef->u));
+  } else {
+    return std::nullopt;
+  }
+}
+
 // If an expression is simply a whole symbol data designator,
 // extract and return that symbol, else null.
 template<typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {
index e8feb9a..3bd95b9 100644 (file)
@@ -340,7 +340,7 @@ public:
 private:
   void SetBounds(std::optional<Expr<SubscriptInteger>> &,
       std::optional<Expr<SubscriptInteger>> &);
-  std::variant<DataRef, StaticDataObject::Pointer> parent_;
+  Parent parent_;
   std::optional<IndirectSubscriptIntegerExpr> lower_, upper_;
 };
 
index c537fe0..e306759 100644 (file)
@@ -153,25 +153,15 @@ MaybeExpr ExpressionAnalyzer::Designate(DataRef &&ref) {
 // subscripts are in hand.
 MaybeExpr ExpressionAnalyzer::CompleteSubscripts(ArrayRef &&ref) {
   const Symbol &symbol{ref.GetLastSymbol().GetUltimate()};
+  const auto *object{symbol.detailsIf<semantics::ObjectEntityDetails>()};
   int symbolRank{symbol.Rank()};
   int subscripts{static_cast<int>(ref.size())};
   if (subscripts == 0) {
-    if (semantics::IsAssumedSizeArray(symbol)) {
-      // Don't introduce a triplet that would later be caught
-      // as being invalid.
-      return Designate(DataRef{std::move(ref)});
-    }
-    // A -> A(:,:)
-    for (; subscripts < symbolRank; ++subscripts) {
-      ref.emplace_back(Triplet{});
-    }
-  }
-  if (subscripts != symbolRank) {
+    // nothing to check
+  } else if (subscripts != symbolRank) {
     Say("Reference to rank-%d object '%s' has %d subscripts"_err_en_US,
         symbolRank, symbol.name(), subscripts);
     return std::nullopt;
-  } else if (subscripts == 0) {
-    // nothing to check
   } else if (Component * component{ref.base().UnwrapComponent()}) {
     int baseRank{component->base().Rank()};
     if (baseRank > 0) {
@@ -186,11 +176,10 @@ MaybeExpr ExpressionAnalyzer::CompleteSubscripts(ArrayRef &&ref) {
         return std::nullopt;
       }
     }
-  } else if (const auto *details{
-                 symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
+  } else if (object != nullptr) {
     // C928 & C1002
     if (Triplet * last{std::get_if<Triplet>(&ref.subscript().back().u)}) {
-      if (!last->upper().has_value() && details->IsAssumedSize()) {
+      if (!last->upper().has_value() && object->IsAssumedSize()) {
         Say("Assumed-size array '%s' must have explicit final "
             "subscript upper bound value"_err_en_US,
             symbol.name());
@@ -221,10 +210,8 @@ MaybeExpr ExpressionAnalyzer::ApplySubscripts(
       std::move(dataRef.u));
 }
 
-// Top-level checks for data references.  Unsubscripted whole array references
-// get expanded -- e.g., MATRIX becomes MATRIX(:,:).
+// Top-level checks for data references.
 MaybeExpr ExpressionAnalyzer::TopLevelChecks(DataRef &&dataRef) {
-  bool addSubscripts{false};
   if (Component * component{std::get_if<Component>(&dataRef.u)}) {
     const Symbol &symbol{component->GetLastSymbol()};
     int componentRank{symbol.Rank()};
@@ -234,18 +221,8 @@ MaybeExpr ExpressionAnalyzer::TopLevelChecks(DataRef &&dataRef) {
         Say("Reference to whole rank-%d component '%%%s' of "
             "rank-%d array of derived type is not allowed"_err_en_US,
             componentRank, symbol.name(), baseRank);
-      } else {
-        addSubscripts = true;
       }
     }
-  } else if (const Symbol **symbol{std::get_if<const Symbol *>(&dataRef.u)}) {
-    addSubscripts = (*symbol)->Rank() > 0;
-  }
-  if (addSubscripts) {
-    if (MaybeExpr subscripted{
-            ApplySubscripts(std::move(dataRef), std::vector<Subscript>{})}) {
-      return subscripted;
-    }
   }
   return Designate(std::move(dataRef));
 }
index f3fabd3..7d017b3 100644 (file)
@@ -478,11 +478,6 @@ bool IsFinalizable(const Symbol &symbol) {
 
 bool IsCoarray(const Symbol &symbol) { return symbol.Corank() > 0; }
 
-bool IsAssumedSizeArray(const Symbol &symbol) {
-  const auto *details{symbol.detailsIf<ObjectEntityDetails>()};
-  return details && details->IsAssumedSize();
-}
-
 bool IsExternalInPureContext(const Symbol &symbol, const Scope &scope) {
   if (const auto *pureProc{semantics::FindPureProcedureContaining(&scope)}) {
     if (const Symbol * root{GetAssociationRoot(symbol)}) {
index 6cb8f4a..aeee4c7 100644 (file)
@@ -110,7 +110,10 @@ inline bool IsProtected(const Symbol &symbol) {
 }
 bool IsFinalizable(const Symbol &symbol);
 bool IsCoarray(const Symbol &symbol);
-bool IsAssumedSizeArray(const Symbol &symbol);
+inline bool IsAssumedSizeArray(const Symbol &symbol) {
+  const auto *details{symbol.detailsIf<ObjectEntityDetails>()};
+  return details && details->IsAssumedSize();
+}
 std::optional<parser::MessageFixedText> WhyNotModifiable(
     const Symbol &symbol, const Scope &scope);
 // Is the symbol modifiable in this scope
index f745042..62033d0 100644 (file)
@@ -129,6 +129,7 @@ set(FOLDING_TESTS
   folding05.f90
   folding06.f90
   folding07.f90
+  folding08.f90
 )
 
 
diff --git a/flang/test/evaluate/folding08.f90 b/flang/test/evaluate/folding08.f90
new file mode 100644 (file)
index 0000000..7c2dce9
--- /dev/null
@@ -0,0 +1,52 @@
+! Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
+!
+! Licensed under the Apache License, Version 2.0 (the "License");
+! you may not use this file except in compliance with the License.
+! You may obtain a copy of the License at
+!
+!     http://www.apache.org/licenses/LICENSE-2.0
+!
+! Unless required by applicable law or agreed to in writing, software
+! distributed under the License is distributed on an "AS IS" BASIS,
+! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+! implied.
+! See the License for the specific language governing permissions and
+! limitations under the License.
+
+! Test folding of LBOUND and UBOUND
+
+subroutine testlbound(n1,a1,a2)
+  integer, intent(in) :: n1
+  real, intent(in) :: a1(0:n1), a2(0:*)
+  type :: t
+    real :: a
+  end type
+  type(t) :: ta(0:2)
+  character(len=2) :: ca(-1:1)
+  integer, parameter :: lba1(:) = lbound(a1)
+  logical, parameter :: test_lba1 = all(lba1 == [0])
+  integer, parameter :: lba2(:) = lbound(a2)
+  logical, parameter :: test_lba2 = all(lba2 == [0])
+  integer, parameter :: uba2(:) = ubound(a2)
+  logical, parameter :: test_uba2 = all(uba2 == [-1])
+  integer, parameter :: lbta1(:) = lbound(ta)
+  logical, parameter :: test_lbta1 = all(lbta1 == [0])
+  integer, parameter :: ubta1(:) = ubound(ta)
+  logical, parameter :: test_ubta1 = all(ubta1 == [2])
+  integer, parameter :: lbta2(:) = lbound(ta(:))
+  logical, parameter :: test_lbta2 = all(lbta2 == [1])
+  integer, parameter :: ubta2(:) = ubound(ta(:))
+  logical, parameter :: test_ubta2 = all(ubta2 == [3])
+  integer, parameter :: lbta3(:) = lbound(ta%a)
+  logical, parameter :: test_lbta3 = all(lbta3 == [1])
+  integer, parameter :: ubta3(:) = ubound(ta%a)
+  logical, parameter :: test_ubta3 = all(ubta3 == [3])
+  integer, parameter :: lbca1(:) = lbound(ca)
+  logical, parameter :: test_lbca1 = all(lbca1 == [-1])
+  integer, parameter :: ubca1(:) = ubound(ca)
+  logical, parameter :: test_ubca1 = all(ubca1 == [1])
+  integer, parameter :: lbca2(:) = lbound(ca(:)(1:1))
+  logical, parameter :: test_lbca2 = all(lbca2 == [1])
+  integer, parameter :: ubca2(:) = ubound(ca(:)(1:1))
+  logical, parameter :: test_ubca2 = all(ubca2 == [3])
+end