[flang] Fix "EQ" comparison of arrays
authorPeter Steinfeld <psteinfeld@nvidia.com>
Fri, 13 Nov 2020 17:31:41 +0000 (09:31 -0800)
committerPeter Steinfeld <psteinfeld@nvidia.com>
Fri, 13 Nov 2020 23:18:13 +0000 (15:18 -0800)
When comparing arrays whose shapes do not conform, the contant folding
code ran into problems trying to get the value of an extent that did not
exist.  There were actually two problems.  First, the routine
"CheckConformance()" was returning "true" when the compiler was unable
to get the extent of an array.  Second, the function
"ApplyElementwise()" was calling "CheckConformance()" prior to folding
the elements of two arrays, but it was ignoring the return value.

Differential Revision: https://reviews.llvm.org/D91440

flang/lib/Evaluate/fold-implementation.h
flang/lib/Evaluate/shape.cpp
flang/test/Semantics/shape.f90 [new file with mode: 0644]

index 78df7e7..ee3aaa3 100644 (file)
@@ -1113,7 +1113,13 @@ auto ApplyElementwise(FoldingContext &context,
         if (rightExpr.Rank() > 0) {
           if (std::optional<Shape> rightShape{GetShape(context, rightExpr)}) {
             if (auto right{AsFlatArrayConstructor(rightExpr)}) {
-              CheckConformance(context.messages(), *leftShape, *rightShape);
+              if (CheckConformance(
+                      context.messages(), *leftShape, *rightShape)) {
+                return MapOperation(context, std::move(f), *leftShape,
+                    std::move(*left), std::move(*right));
+              } else {
+                return std::nullopt;
+              }
               return MapOperation(context, std::move(f), *leftShape,
                   std::move(*left), std::move(*right));
             }
index bfc2447..c672cc1 100644 (file)
@@ -682,6 +682,8 @@ auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
   return std::nullopt;
 }
 
+// Check conformance of the passed shapes.  Only return true if we can verify
+// that they conform
 bool CheckConformance(parser::ContextualMessages &messages, const Shape &left,
     const Shape &right, const char *leftIs, const char *rightIs) {
   int n{GetRank(left)};
@@ -693,15 +695,16 @@ bool CheckConformance(parser::ContextualMessages &messages, const Shape &left,
       return false;
     } else {
       for (int j{0}; j < n; ++j) {
-        if (auto leftDim{ToInt64(left[j])}) {
-          if (auto rightDim{ToInt64(right[j])}) {
-            if (*leftDim != *rightDim) {
-              messages.Say("Dimension %1$d of %2$s has extent %3$jd, "
-                           "but %4$s has extent %5$jd"_err_en_US,
-                  j + 1, leftIs, *leftDim, rightIs, *rightDim);
-              return false;
-            }
-          }
+        auto leftDim{ToInt64(left[j])};
+        auto rightDim{ToInt64(right[j])};
+        if (!leftDim || !rightDim) {
+          return false;
+        }
+        if (*leftDim != *rightDim) {
+          messages.Say("Dimension %1$d of %2$s has extent %3$jd, "
+                       "but %4$s has extent %5$jd"_err_en_US,
+              j + 1, leftIs, *leftDim, rightIs, *rightDim);
+          return false;
         }
       }
     }
diff --git a/flang/test/Semantics/shape.f90 b/flang/test/Semantics/shape.f90
new file mode 100644 (file)
index 0000000..ef0771b
--- /dev/null
@@ -0,0 +1,41 @@
+! RUN: %S/test_errors.sh %s %t %f18
+! Test comparisons that use the intrinsic SHAPE() as an operand
+program testShape
+contains
+  subroutine sub1(arrayDummy)
+    integer :: arrayDummy(:)
+    integer, allocatable :: arrayDeferred(:)
+    integer :: arrayLocal(2) = [88, 99]
+    if (all(shape(arrayDummy)==shape(8))) then
+      print *, "hello"
+    end if
+    if (all(shape(27)==shape(arrayDummy))) then
+      print *, "hello"
+    end if
+    if (all(64==shape(arrayDummy))) then
+      print *, "hello"
+    end if
+    if (all(shape(arrayDeferred)==shape(8))) then
+      print *, "hello"
+    end if
+    if (all(shape(27)==shape(arrayDeferred))) then
+      print *, "hello"
+    end if
+    if (all(64==shape(arrayDeferred))) then
+      print *, "hello"
+    end if
+    !ERROR: Dimension 1 of left operand has extent 1, but right operand has extent 0
+    !ERROR: Dimension 1 of left operand has extent 1, but right operand has extent 0
+    if (all(shape(arrayLocal)==shape(8))) then
+      print *, "hello"
+    end if
+    !ERROR: Dimension 1 of left operand has extent 0, but right operand has extent 1
+    !ERROR: Dimension 1 of left operand has extent 0, but right operand has extent 1
+    if (all(shape(27)==shape(arrayLocal))) then
+      print *, "hello"
+    end if
+    if (all(64==shape(arrayLocal))) then
+      print *, "hello"
+    end if
+  end subroutine sub1
+end program testShape