[flang] Fix LBOUND & UBOUND(function()), add tests
authorpeter klausler <pklausler@nvidia.com>
Thu, 1 Aug 2019 19:32:17 +0000 (12:32 -0700)
committerpeter klausler <pklausler@nvidia.com>
Fri, 2 Aug 2019 16:22:10 +0000 (09:22 -0700)
Original-commit: flang-compiler/f18@1e093e9927f94a06a19700a6de26b5bfefa659d9
Reviewed-on: https://github.com/flang-compiler/f18/pull/611
Tree-same-pre-rewrite: false

flang/lib/evaluate/expression.cc
flang/lib/evaluate/shape.cc
flang/test/evaluate/folding08.f90

index 4d07e9b..0d412c9 100644 (file)
@@ -85,11 +85,11 @@ std::optional<DynamicType> ExpressionBase<A>::GetType() const {
     return Result::GetType();
   } else {
     return std::visit(
-        [&](const auto &x) {
+        [&](const auto &x) -> std::optional<DynamicType> {
           if constexpr (!common::HasMember<decltype(x), TypelessExpression>) {
             return x.GetType();
           } else {
-            return std::optional<DynamicType>{};
+            return std::nullopt;
           }
         },
         derived().u);
index d7d2b1a..6f03378 100644 (file)
@@ -357,7 +357,31 @@ Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) {
 }
 
 void GetShapeVisitor::Handle(const Symbol &symbol) {
-  Handle(NamedEntity{symbol});
+  std::visit(
+      common::visitors{
+          [&](const semantics::ObjectEntityDetails &object) {
+            Handle(NamedEntity{symbol});
+          },
+          [&](const semantics::AssocEntityDetails &assoc) {
+            Nested(assoc.expr());
+          },
+          [&](const semantics::SubprogramDetails &subp) {
+            if (subp.isFunction()) {
+              Handle(subp.result());
+            } else {
+              Return();
+            }
+          },
+          [&](const semantics::ProcBindingDetails &binding) {
+            Handle(binding.symbol());
+          },
+          [&](const semantics::UseDetails &use) { Handle(use.symbol()); },
+          [&](const semantics::HostAssocDetails &assoc) {
+            Handle(assoc.symbol());
+          },
+          [&](const auto &) { Return(); },
+      },
+      symbol.details());
 }
 
 void GetShapeVisitor::Handle(const Component &component) {
@@ -369,23 +393,21 @@ void GetShapeVisitor::Handle(const Component &component) {
 }
 
 void GetShapeVisitor::Handle(const NamedEntity &base) {
-  const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
-  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
+  const Symbol &symbol{base.GetLastSymbol()};
+  if (const auto *object{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
     if (IsImpliedShape(symbol)) {
-      Nested(details->init());
+      Nested(object->init());
     } else {
       Shape result;
-      int n{static_cast<int>(details->shape().size())};
+      int n{static_cast<int>(object->shape().size())};
       for (int dimension{0}; dimension < n; ++dimension) {
         result.emplace_back(GetExtent(context_, base, dimension));
       }
       Return(std::move(result));
     }
-  } else if (const auto *details{
-                 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
-    Nested(details->expr());
+  } else {
+    Return();  // error recovery
   }
-  Return();
 }
 
 void GetShapeVisitor::Handle(const Substring &substring) {
index 7c9dfe1..770dfa8 100644 (file)
 
 ! Test folding of LBOUND and UBOUND
 
-subroutine test(n1,a1,a2)
-  integer, intent(in) :: n1
-  real, intent(in) :: a1(0:n1), a2(0:*)
-  type :: t
-    real :: a
-  end type
-  type(t) :: ta(0:2)
-  character(len=2) :: ca(-1:1)
-  integer, parameter :: lba1(:) = lbound(a1)
-  logical, parameter :: test_lba1 = all(lba1 == [0])
-  integer, parameter :: lba2(:) = lbound(a2)
-  logical, parameter :: test_lba2 = all(lba2 == [0])
-  integer, parameter :: uba2(:) = ubound(a2)
-  logical, parameter :: test_uba2 = all(uba2 == [-1])
-  integer, parameter :: lbta1(:) = lbound(ta)
-  logical, parameter :: test_lbta1 = all(lbta1 == [0])
-  integer, parameter :: ubta1(:) = ubound(ta)
-  logical, parameter :: test_ubta1 = all(ubta1 == [2])
-  integer, parameter :: lbta2(:) = lbound(ta(:))
-  logical, parameter :: test_lbta2 = all(lbta2 == [1])
-  integer, parameter :: ubta2(:) = ubound(ta(:))
-  logical, parameter :: test_ubta2 = all(ubta2 == [3])
-  integer, parameter :: lbta3(:) = lbound(ta%a)
-  logical, parameter :: test_lbta3 = all(lbta3 == [1])
-  integer, parameter :: ubta3(:) = ubound(ta%a)
-  logical, parameter :: test_ubta3 = all(ubta3 == [3])
-  integer, parameter :: lbca1(:) = lbound(ca)
-  logical, parameter :: test_lbca1 = all(lbca1 == [-1])
-  integer, parameter :: ubca1(:) = ubound(ca)
-  logical, parameter :: test_ubca1 = all(ubca1 == [1])
-  integer, parameter :: lbca2(:) = lbound(ca(:)(1:1))
-  logical, parameter :: test_lbca2 = all(lbca2 == [1])
-  integer, parameter :: ubca2(:) = ubound(ca(:)(1:1))
-  logical, parameter :: test_ubca2 = all(ubca2 == [3])
+module m
+ contains
+  function foo()
+    real :: foo(2:3,4:6)
+  end function
+  subroutine test(n1,a1,a2)
+    integer, intent(in) :: n1
+    real, intent(in) :: a1(0:n1), a2(0:*)
+    type :: t
+      real :: a
+    end type
+    type(t) :: ta(0:2)
+    character(len=2) :: ca(-1:1)
+    integer, parameter :: lba1(:) = lbound(a1)
+    logical, parameter :: test_lba1 = all(lba1 == [0])
+    integer, parameter :: lba2(:) = lbound(a2)
+    logical, parameter :: test_lba2 = all(lba2 == [0])
+    integer, parameter :: uba2(:) = ubound(a2)
+    logical, parameter :: test_uba2 = all(uba2 == [-1])
+    integer, parameter :: lbta1(:) = lbound(ta)
+    logical, parameter :: test_lbta1 = all(lbta1 == [0])
+    integer, parameter :: ubta1(:) = ubound(ta)
+    logical, parameter :: test_ubta1 = all(ubta1 == [2])
+    integer, parameter :: lbta2(:) = lbound(ta(:))
+    logical, parameter :: test_lbta2 = all(lbta2 == [1])
+    integer, parameter :: ubta2(:) = ubound(ta(:))
+    logical, parameter :: test_ubta2 = all(ubta2 == [3])
+    integer, parameter :: lbta3(:) = lbound(ta%a)
+    logical, parameter :: test_lbta3 = all(lbta3 == [1])
+    integer, parameter :: ubta3(:) = ubound(ta%a)
+    logical, parameter :: test_ubta3 = all(ubta3 == [3])
+    integer, parameter :: lbca1(:) = lbound(ca)
+    logical, parameter :: test_lbca1 = all(lbca1 == [-1])
+    integer, parameter :: ubca1(:) = ubound(ca)
+    logical, parameter :: test_ubca1 = all(ubca1 == [1])
+    integer, parameter :: lbca2(:) = lbound(ca(:)(1:1))
+    logical, parameter :: test_lbca2 = all(lbca2 == [1])
+    integer, parameter :: ubca2(:) = ubound(ca(:)(1:1))
+    logical, parameter :: test_ubca2 = all(ubca2 == [3])
+    integer, parameter :: lbfoo(:) = lbound(foo())
+    logical, parameter :: test_lbfoo = all(lbfoo == [1,1])
+    integer, parameter :: ubfoo(:) = ubound(foo())
+    logical, parameter :: test_ubfoo = all(ubfoo == [2,3])
+  end subroutine
 end