[flang] Take result length into account in ApplyElementwise folding
authorJean Perier <jperier@nvidia.com>
Thu, 26 Aug 2021 07:44:24 +0000 (09:44 +0200)
committerJean Perier <jperier@nvidia.com>
Thu, 26 Aug 2021 07:46:14 +0000 (09:46 +0200)
ApplyElementwise on character operation was always creating a result
ArrayConstructor with the length of the left operand. This is not
correct for concatenation and SetLength operations.

Compute and thread the length to the spot creating the ArrayConstructor
so that the length is correct for those character operations.

Differential Revision: https://reviews.llvm.org/D108711

flang/lib/Evaluate/fold-implementation.h
flang/test/Evaluate/folding22.f90 [new file with mode: 0644]

index aeb9553..5f975a6 100644 (file)
@@ -898,12 +898,24 @@ Expr<RESULT> MapOperation(FoldingContext &context,
       context, std::move(result), AsConstantExtents(context, shape));
 }
 
+template <typename RESULT, typename A>
+ArrayConstructor<RESULT> ArrayConstructorFromMold(
+    const A &prototype, std::optional<Expr<SubscriptInteger>> &&length) {
+  if constexpr (RESULT::category == TypeCategory::Character) {
+    return ArrayConstructor<RESULT>{
+        std::move(length.value()), ArrayConstructorValues<RESULT>{}};
+  } else {
+    return ArrayConstructor<RESULT>{prototype};
+  }
+}
+
 // array * array case
 template <typename RESULT, typename LEFT, typename RIGHT>
 Expr<RESULT> MapOperation(FoldingContext &context,
     std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
-    const Shape &shape, Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
-  ArrayConstructor<RESULT> result{leftValues};
+    const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
+    Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
+  auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
   auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
   if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
     std::visit(
@@ -942,9 +954,9 @@ Expr<RESULT> MapOperation(FoldingContext &context,
 template <typename RESULT, typename LEFT, typename RIGHT>
 Expr<RESULT> MapOperation(FoldingContext &context,
     std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
-    const Shape &shape, Expr<LEFT> &&leftValues,
-    const Expr<RIGHT> &rightScalar) {
-  ArrayConstructor<RESULT> result{leftValues};
+    const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
+    Expr<LEFT> &&leftValues, const Expr<RIGHT> &rightScalar) {
+  auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
   auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
   for (auto &leftValue : leftArrConst) {
     auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
@@ -959,9 +971,9 @@ Expr<RESULT> MapOperation(FoldingContext &context,
 template <typename RESULT, typename LEFT, typename RIGHT>
 Expr<RESULT> MapOperation(FoldingContext &context,
     std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
-    const Shape &shape, const Expr<LEFT> &leftScalar,
-    Expr<RIGHT> &&rightValues) {
-  ArrayConstructor<RESULT> result{leftScalar};
+    const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
+    const Expr<LEFT> &leftScalar, Expr<RIGHT> &&rightValues) {
+  auto result{ArrayConstructorFromMold<RESULT>(leftScalar, std::move(length))};
   if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
     std::visit(
         [&](auto &&kindExpr) {
@@ -987,6 +999,15 @@ Expr<RESULT> MapOperation(FoldingContext &context,
       context, std::move(result), AsConstantExtents(context, shape));
 }
 
+template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
+std::optional<Expr<SubscriptInteger>> ComputeResultLength(
+    Operation<DERIVED, RESULT, LEFT, RIGHT> &operation) {
+  if constexpr (RESULT::category == TypeCategory::Character) {
+    return Expr<RESULT>{operation.derived()}.LEN();
+  }
+  return std::nullopt;
+}
+
 // ApplyElementwise() recursively folds the operand expression(s) of an
 // operation, then attempts to apply the operation to the (corresponding)
 // scalar element(s) of those operands.  Returns std::nullopt for scalars
@@ -1024,6 +1045,7 @@ auto ApplyElementwise(FoldingContext &context,
     Operation<DERIVED, RESULT, LEFT, RIGHT> &operation,
     std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f)
     -> std::optional<Expr<RESULT>> {
+  auto resultLength{ComputeResultLength(operation)};
   auto &leftExpr{operation.left()};
   leftExpr = Fold(context, std::move(leftExpr));
   auto &rightExpr{operation.right()};
@@ -1038,25 +1060,26 @@ auto ApplyElementwise(FoldingContext &context,
                       CheckConformanceFlags::EitherScalarExpandable)
                       .value_or(false /*fail if not known now to conform*/)) {
                 return MapOperation(context, std::move(f), *leftShape,
-                    std::move(*left), std::move(*right));
+                    std::move(resultLength), std::move(*left),
+                    std::move(*right));
               } else {
                 return std::nullopt;
               }
               return MapOperation(context, std::move(f), *leftShape,
-                  std::move(*left), std::move(*right));
+                  std::move(resultLength), std::move(*left), std::move(*right));
             }
           }
         } else if (IsExpandableScalar(rightExpr)) {
-          return MapOperation(
-              context, std::move(f), *leftShape, std::move(*left), rightExpr);
+          return MapOperation(context, std::move(f), *leftShape,
+              std::move(resultLength), std::move(*left), rightExpr);
         }
       }
     }
   } else if (rightExpr.Rank() > 0 && IsExpandableScalar(leftExpr)) {
     if (std::optional<Shape> shape{GetShape(context, rightExpr)}) {
       if (auto right{AsFlatArrayConstructor(rightExpr)}) {
-        return MapOperation(
-            context, std::move(f), *shape, leftExpr, std::move(*right));
+        return MapOperation(context, std::move(f), *shape,
+            std::move(resultLength), leftExpr, std::move(*right));
       }
     }
   }
diff --git a/flang/test/Evaluate/folding22.f90 b/flang/test/Evaluate/folding22.f90
new file mode 100644 (file)
index 0000000..6e7ccc7
--- /dev/null
@@ -0,0 +1,23 @@
+! RUN: %S/test_folding.sh %s %t %flang_fc1
+! REQUIRES: shell
+
+! Test character concatenation folding
+
+logical, parameter :: test_scalar_scalar =  ('ab' // 'cde').eq.('abcde')
+
+character(2), parameter :: scalar_array(2) =  ['1','2'] // 'a'
+logical, parameter :: test_scalar_array = all(scalar_array.eq.(['1a', '2a']))
+
+character(2), parameter :: array_scalar(2) =  '1' // ['a', 'b']
+logical, parameter :: test_array_scalar = all(array_scalar.eq.(['1a', '1b']))
+
+character(2), parameter :: array_array(2) =  ['1','2'] // ['a', 'b']
+logical, parameter :: test_array_array = all(array_array.eq.(['1a', '2b']))
+
+
+character(1), parameter :: input(2) = ['x', 'y']
+character(*), parameter :: zero_sized(*) = input(2:1:1) // 'abcde'
+logical, parameter :: test_zero_sized = len(zero_sized).eq.6
+
+end