[flang] SELECT CASE constructs with character selectors that require a temp
authorValentin Clement <clementval@gmail.com>
Thu, 30 Jun 2022 07:03:49 +0000 (09:03 +0200)
committerValentin Clement <clementval@gmail.com>
Thu, 30 Jun 2022 07:04:27 +0000 (09:04 +0200)
Here is a character SELECT CASE construct that requires a temp to hold the
result of the TRIM intrinsic call:

```
module m
      character(len=6) :: s
    contains
      subroutine sc
        n = 0
        if (lge(s,'00')) then
          select case(trim(s))
          case('11')
             n = 1
          case default
             continue
          case('22')
             n = 2
          case('33')
             n = 3
          case('44':'55','66':'77','88':)
             n = 4
          end select
        end if
        print*, n
      end subroutine
    end module m
```

This SELECT CASE construct is implemented as an IF/ELSE-IF/ELSE comparison
sequence.  The temp must be retained until some comparison is successful.
At that point the temp may be freed.  Generalize statement context processing
to allow multiple finalize calls to do this, such that the program always
executes exactly one freemem call.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: klausler, vdonaldson

Differential Revision: https://reviews.llvm.org/D128852

Co-authored-by: V Donaldson <vdonaldson@nvidia.com>
flang/include/flang/Lower/StatementContext.h
flang/lib/Lower/Bridge.cpp
flang/lib/Lower/ConvertExpr.cpp
flang/lib/Lower/IO.cpp
flang/test/Lower/select-case-statement.f90

index 58cb9e9..69ceeae 100644 (file)
@@ -35,7 +35,7 @@ public:
 
   ~StatementContext() {
     if (!cufs.empty())
-      finalize(/*popScope=*/true);
+      finalizeAndPop();
     assert(cufs.empty() && "invalid StatementContext destructor call");
   }
 
@@ -61,15 +61,29 @@ public:
     }
   }
 
-  /// Make cleanup calls.  Pop or reset the stack top list.
-  void finalize(bool popScope = false) {
+  /// Make cleanup calls.  Retain the stack top list for a repeat call.
+  void finalizeAndKeep() {
     assert(!cufs.empty() && "invalid finalize statement context");
     if (cufs.back())
       (*cufs.back())();
-    if (popScope)
-      cufs.pop_back();
-    else
-      cufs.back().reset();
+  }
+
+  /// Make cleanup calls.  Pop the stack top list.
+  void finalizeAndPop() {
+    finalizeAndKeep();
+    cufs.pop_back();
+  }
+
+  /// Make cleanup calls.  Clear the stack top list.
+  void finalize() {
+    finalizeAndKeep();
+    cufs.back().reset();
+  }
+
+  bool workListIsEmpty() const {
+    return cufs.empty() || llvm::all_of(cufs, [](auto &opt) -> bool {
+             return !opt.hasValue();
+           });
   }
 
 private:
index b7d180e..d3bc95a 100644 (file)
@@ -1749,8 +1749,11 @@ private:
     // Generate a sequence of case value comparisons and branches.
     auto caseValue = valueList.begin();
     auto caseBlock = blockList.begin();
-    for (mlir::Attribute attr : attrList) {
-      if (attr.isa<mlir::UnitAttr>()) {
+    bool skipFinalization = false;
+    for (const auto attr : llvm::enumerate(attrList)) {
+      if (attr.value().isa<mlir::UnitAttr>()) {
+        if (attrList.size() == 1)
+          stmtCtx.finalize();
         genFIRBranch(*caseBlock++);
         break;
       }
@@ -1767,16 +1770,30 @@ private:
             charHelper.createUnboxChar(rhs);
         mlir::Value &rhsAddr = rhsVal.first;
         mlir::Value &rhsLen = rhsVal.second;
-        return fir::runtime::genCharCompare(*builder, loc, pred, lhsAddr,
-                                            lhsLen, rhsAddr, rhsLen);
+        mlir::Value result = fir::runtime::genCharCompare(
+            *builder, loc, pred, lhsAddr, lhsLen, rhsAddr, rhsLen);
+        if (stmtCtx.workListIsEmpty() || skipFinalization)
+          return result;
+        if (attr.index() == attrList.size() - 2) {
+          stmtCtx.finalize();
+          return result;
+        }
+        fir::IfOp ifOp = builder->create<fir::IfOp>(loc, result,
+                                                    /*withElseRegion=*/false);
+        builder->setInsertionPointToStart(&ifOp.getThenRegion().front());
+        stmtCtx.finalizeAndKeep();
+        builder->setInsertionPointAfter(ifOp);
+        return result;
       };
       mlir::Block *newBlock = insertBlock(*caseBlock);
-      if (attr.isa<fir::ClosedIntervalAttr>()) {
+      if (attr.value().isa<fir::ClosedIntervalAttr>()) {
         mlir::Block *newBlock2 = insertBlock(*caseBlock);
+        skipFinalization = true;
         mlir::Value cond =
             genCond(*caseValue++, mlir::arith::CmpIPredicate::sge);
         genFIRConditionalBranch(cond, newBlock, newBlock2);
         builder->setInsertionPointToEnd(newBlock);
+        skipFinalization = false;
         mlir::Value cond2 =
             genCond(*caseValue++, mlir::arith::CmpIPredicate::sle);
         genFIRConditionalBranch(cond2, *caseBlock++, newBlock2);
@@ -1784,12 +1801,13 @@ private:
         continue;
       }
       mlir::arith::CmpIPredicate pred;
-      if (attr.isa<fir::PointIntervalAttr>()) {
+      if (attr.value().isa<fir::PointIntervalAttr>()) {
         pred = mlir::arith::CmpIPredicate::eq;
-      } else if (attr.isa<fir::LowerBoundAttr>()) {
+      } else if (attr.value().isa<fir::LowerBoundAttr>()) {
         pred = mlir::arith::CmpIPredicate::sge;
       } else {
-        assert(attr.isa<fir::UpperBoundAttr>() && "unexpected predicate");
+        assert(attr.value().isa<fir::UpperBoundAttr>() &&
+               "unexpected predicate");
         pred = mlir::arith::CmpIPredicate::sle;
       }
       mlir::Value cond = genCond(*caseValue++, pred);
@@ -1798,12 +1816,7 @@ private:
     }
     assert(caseValue == valueList.end() && caseBlock == blockList.end() &&
            "select case list mismatch");
-    // Clean-up the selector at the end of the construct if it is a temporary
-    // (which is possible with characters).
-    mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
-    builder->setInsertionPointToEnd(eval.parentConstruct->constructExit->block);
-    stmtCtx.finalize();
-    builder->restoreInsertionPoint(insertPt);
+    assert(stmtCtx.workListIsEmpty() && "statement context must be empty");
   }
 
   fir::ExtendedValue
index 5771a89..e0f7773 100644 (file)
@@ -3813,7 +3813,7 @@ public:
     // be needed afterwards.
     stmtCtx.pushScope();
     [[maybe_unused]] ExtValue loopRes = lowerArrayExpression(expr);
-    stmtCtx.finalize(/*popScope=*/true);
+    stmtCtx.finalizeAndPop();
     assert(fir::getBase(loopRes));
   }
 
@@ -4719,7 +4719,7 @@ private:
   /// fir::ResultOp at the end of the innermost loop.
   void finalizeElementCtx() {
     if (elementCtx) {
-      stmtCtx.finalize(/*popScope=*/true);
+      stmtCtx.finalizeAndPop();
       elementCtx = false;
     }
   }
@@ -6433,7 +6433,7 @@ private:
         builder.create<fir::StoreOp>(loc, castLen, charLen.value());
       }
     }
-    stmtCtx.finalize(/*popScope=*/true);
+    stmtCtx.finalizeAndPop();
 
     builder.create<fir::ResultOp>(loc, mem);
     builder.restoreInsertionPoint(insPt);
index 849288b..b3bc3a9 100644 (file)
@@ -196,7 +196,7 @@ static mlir::Value genEndIO(Fortran::lower::AbstractConverter &converter,
                                           mlir::ValueRange{cookie});
   mlir::Value iostat = call.getResult(0);
   if (csi.bigUnitIfOp) {
-    stmtCtx.finalize(/*popScope=*/true);
+    stmtCtx.finalizeAndPop();
     builder.create<fir::ResultOp>(loc, iostat);
     builder.setInsertionPointAfter(csi.bigUnitIfOp);
     iostat = csi.bigUnitIfOp.getResult(0);
index 2efbcc0..e188f88 100644 (file)
     print*, nn
   end
 
-  ! CHECK-LABEL: func @_QPtest_char_temp_selector
-  subroutine test_char_temp_selector()
-    ! Test that character selector that are temps are deallocated
-    ! only after they have been used in the select case comparisons.
-    interface
-      function gen_char_temp_selector()
-        character(:), allocatable :: gen_char_temp_selector
-      end function
-    end interface
-    select case (gen_char_temp_selector())
-    case ('case1')
-      call foo1()
-    case ('case2')
-      call foo2()
-    case ('case3')
-      call foo3()
+  ! CHECK-LABEL: func @_QPscharacter1
+  subroutine scharacter1(s)
+    ! CHECK-DAG: %[[V_0:[0-9]+]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>>
+    character(len=3) :: s
+    ! CHECK-DAG: %[[V_1:[0-9]+]] = fir.alloca i32 {bindc_name = "n", uniq_name = "_QFscharacter1En"}
+    ! CHECK:     fir.store %c0{{.*}} to %[[V_1]] : !fir.ref<i32>
+    n = 0
+
+    ! CHECK:     %[[V_8:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+    ! CHECK:     %[[V_9:[0-9]+]] = arith.cmpi sge, %[[V_8]], %c0{{.*}} : i32
+    ! CHECK:     cond_br %[[V_9]], ^bb1, ^bb15
+    ! CHECK:   ^bb1:  // pred: ^bb0
+    if (lge(s,'00')) then
+
+      ! CHECK:   %[[V_18:[0-9]+]] = fir.load %[[V_0]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
+      ! CHECK:   %[[V_20:[0-9]+]] = fir.box_addr %[[V_18]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
+      ! CHECK:   %[[V_42:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+      ! CHECK:   %[[V_43:[0-9]+]] = arith.cmpi eq, %[[V_42]], %c0{{.*}} : i32
+      ! CHECK:   fir.if %[[V_43]] {
+      ! CHECK:     fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+      ! CHECK:   }
+      ! CHECK:   cond_br %[[V_43]], ^bb3, ^bb2
+      ! CHECK: ^bb2:  // pred: ^bb1
+      select case(trim(s))
+      case('11')
+        n = 1
+
+      case default
+        continue
+
+      ! CHECK:   %[[V_48:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+      ! CHECK:   %[[V_49:[0-9]+]] = arith.cmpi eq, %[[V_48]], %c0{{.*}} : i32
+      ! CHECK:   fir.if %[[V_49]] {
+      ! CHECK:     fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+      ! CHECK:   }
+      ! CHECK:   cond_br %[[V_49]], ^bb6, ^bb5
+      ! CHECK: ^bb3:  // pred: ^bb1
+      ! CHECK:   fir.store %c1{{.*}} to %[[V_1]] : !fir.ref<i32>
+      ! CHECK: ^bb4:  // pred: ^bb13
+      ! CHECK: ^bb5:  // pred: ^bb2
+      case('22')
+        n = 2
+
+      ! CHECK:   %[[V_54:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+      ! CHECK:   %[[V_55:[0-9]+]] = arith.cmpi eq, %[[V_54]], %c0{{.*}} : i32
+      ! CHECK:   fir.if %[[V_55]] {
+      ! CHECK:     fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+      ! CHECK:   }
+      ! CHECK:   cond_br %[[V_55]], ^bb8, ^bb7
+      ! CHECK: ^bb6:  // pred: ^bb2
+      ! CHECK:   fir.store %c2{{.*}} to %[[V_1]] : !fir.ref<i32>
+      ! CHECK: ^bb7:  // pred: ^bb5
+      case('33')
+        n = 3
+
+      case('44':'55','66':'77','88':)
+        n = 4
+      ! CHECK:   %[[V_60:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+      ! CHECK:   %[[V_61:[0-9]+]] = arith.cmpi sge, %[[V_60]], %c0{{.*}} : i32
+      ! CHECK:   cond_br %[[V_61]], ^bb9, ^bb10
+      ! CHECK: ^bb8:  // pred: ^bb5
+      ! CHECK:   fir.store %c3{{.*}} to %[[V_1]] : !fir.ref<i32>
+      ! CHECK: ^bb9:  // pred: ^bb7
+      ! CHECK:   %[[V_66:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+      ! CHECK:   %[[V_67:[0-9]+]] = arith.cmpi sle, %[[V_66]], %c0{{.*}} : i32
+      ! CHECK:   fir.if %[[V_67]] {
+      ! CHECK:     fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+      ! CHECK:   }
+      ! CHECK:   cond_br %[[V_67]], ^bb14, ^bb10
+      ! CHECK: ^bb10:  // 2 preds: ^bb7, ^bb9
+      ! CHECK:   %[[V_72:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+      ! CHECK:   %[[V_73:[0-9]+]] = arith.cmpi sge, %[[V_72]], %c0{{.*}} : i32
+      ! CHECK:   cond_br %[[V_73]], ^bb11, ^bb12
+      ! CHECK: ^bb11:  // pred: ^bb10
+      ! CHECK:   %[[V_78:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+      ! CHECK:   %[[V_79:[0-9]+]] = arith.cmpi sle, %[[V_78]], %c0{{.*}} : i32
+      ! CHECK:   fir.if %[[V_79]] {
+      ! CHECK:     fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+      ! CHECK:   }
+      ! CHECK: ^bb12:  // 2 preds: ^bb10, ^bb11
+      ! CHECK:   %[[V_84:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+      ! CHECK:   %[[V_85:[0-9]+]] = arith.cmpi sge, %[[V_84]], %c0{{.*}} : i32
+      ! CHECK:   fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+      ! CHECK:   cond_br %[[V_85]], ^bb14, ^bb13
+      ! CHECK: ^bb13:  // pred: ^bb12
+      ! CHECK: ^bb14:  // 3 preds: ^bb9, ^bb11, ^bb12
+      ! CHECK:   fir.store %c4{{.*}} to %[[V_1]] : !fir.ref<i32>
+      ! CHECK: ^bb15:  // 6 preds: ^bb0, ^bb3, ^bb4, ^bb6, ^bb8, ^bb14
+      end select
+    end if
+    ! CHECK:     %[[V_89:[0-9]+]] = fir.load %[[V_1]] : !fir.ref<i32>
+    print*, n
+  end subroutine
+
+
+  ! CHECK-LABEL: func @_QPscharacter2
+  subroutine scharacter2(s)
+    ! CHECK-DAG: %[[V_0:[0-9]+]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>>
+    ! CHECK:   %[[V_1:[0-9]+]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>>
+    character(len=3) :: s
+    n = 0
+
+    ! CHECK:   %[[V_12:[0-9]+]] = fir.load %[[V_1]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
+    ! CHECK:   %[[V_13:[0-9]+]] = fir.box_addr %[[V_12]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
+    ! CHECK:   fir.freemem %[[V_13]] : !fir.heap<!fir.char<1,?>>
+    ! CHECK:   br ^bb1
+    ! CHECK: ^bb1:  // pred: ^bb0
+    ! CHECK:   br ^bb2
+    n = -10
+    select case(trim(s))
     case default
-      call foo_default()
+      n = 9
+    end select
+    print*, n
+
+    ! CHECK: ^bb2:  // pred: ^bb1
+    ! CHECK:   %[[V_28:[0-9]+]] = fir.load %[[V_0]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
+    ! CHECK:   %[[V_29:[0-9]+]] = fir.box_addr %[[V_28]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
+    ! CHECK:   fir.freemem %[[V_29]] : !fir.heap<!fir.char<1,?>>
+    ! CHECK:   br ^bb3
+    ! CHECK: ^bb3:  // pred: ^bb2
+    n = -2
+    select case(trim(s))
     end select
-    ! CHECK:   %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>> {bindc_name = ".result"}
-    ! CHECK:   %[[VAL_1:.*]] = fir.call @_QPgen_char_temp_selector() : () -> !fir.box<!fir.heap<!fir.char<1,?>>>
-    ! CHECK:   fir.save_result %[[VAL_1]] to %[[VAL_0]] : !fir.box<!fir.heap<!fir.char<1,?>>>, !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
-    ! CHECK:   cond_br %{{.*}}, ^bb2, ^bb1
-    ! CHECK: ^bb1:
-    ! CHECK:   cond_br %{{.*}}, ^bb4, ^bb3
-    ! CHECK: ^bb2:
-    ! CHECK:   fir.call @_QPfoo1() : () -> ()
-    ! CHECK:   br ^bb8
-    ! CHECK: ^bb3:
-    ! CHECK:   cond_br %{{.*}}, ^bb6, ^bb5
-    ! CHECK: ^bb4:
-    ! CHECK:   fir.call @_QPfoo2() : () -> ()
-    ! CHECK:   br ^bb8
-    ! CHECK: ^bb5:
-    ! CHECK:   br ^bb7
-    ! CHECK: ^bb6:
-    ! CHECK:   fir.call @_QPfoo3() : () -> ()
-    ! CHECK:   br ^bb8
-    ! CHECK: ^bb7:
-    ! CHECK:   fir.call @_QPfoo_default() : () -> ()
-    ! CHECK:   br ^bb8
-    ! CHECK: ^bb8:
-    ! CHECK:   %[[VAL_36:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
-    ! CHECK:   %[[VAL_37:.*]] = fir.box_addr %[[VAL_36]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
-    ! CHECK:   %[[VAL_38:.*]] = fir.convert %[[VAL_37]] : (!fir.heap<!fir.char<1,?>>) -> i64
-    ! CHECK:   %[[VAL_39:.*]] = arith.constant 0 : i64
-    ! CHECK:   %[[VAL_40:.*]] = arith.cmpi ne, %[[VAL_38]], %[[VAL_39]] : i64
-    ! CHECK:   fir.if %[[VAL_40]] {
-    ! CHECK:     fir.freemem %[[VAL_37]]
-    ! CHECK:   }
+    print*, n
   end subroutine
+
+  ! CHECK-LABEL: main
+  program p
+    integer sinteger, v(10)
+
+    n = -10
+    do j = 1, 4
+      do k = 1, 10
+        n = n + 1
+        v(k) = sinteger(n)
+      enddo
+      ! expected output:  1 1 1 1 1 1 1 1 1 1
+      !                   1 2 3 4 4 6 7 7 7 7
+      !                   7 7 7 7 7 0 0 0 0 0
+      !                   7 7 7 7 7 7 7 7 7 7
+      print*, v
+    enddo
+
+    print*
+    call slogical(.false.)    ! expected output:  0 1 0 3 1 1 3 1
+    call slogical(.true.)     ! expected output:  0 0 2 3 2 3 2 2
+
+    print*
+    call scharacter('aa')     ! expected output: 10
+    call scharacter('d')      ! expected output: 10
+    call scharacter('f')      ! expected output: -1
+    call scharacter('ff')     ! expected output: 20
+    call scharacter('fff')    ! expected output: 20
+    call scharacter('ffff')   ! expected output: 20
+    call scharacter('fffff')  ! expected output: -1
+    call scharacter('jj')     ! expected output: -1
+    call scharacter('m')      ! expected output: 30
+    call scharacter('q')      ! expected output: -1
+    call scharacter('qq')     ! expected output: 40
+    call scharacter('qqq')    ! expected output: -1
+    call scharacter('vv')     ! expected output: -1
+    call scharacter('xx')     ! expected output: 50
+    call scharacter('zz')     ! expected output: 50
+
+    print*
+    call scharacter1('99 ')   ! expected output:  4
+    call scharacter1('88 ')   ! expected output:  4
+    call scharacter1('77 ')   ! expected output:  4
+    call scharacter1('66 ')   ! expected output:  4
+    call scharacter1('55 ')   ! expected output:  4
+    call scharacter1('44 ')   ! expected output:  4
+    call scharacter1('33 ')   ! expected output:  3
+    call scharacter1('22 ')   ! expected output:  2
+    call scharacter1('11 ')   ! expected output:  1
+    call scharacter1('00 ')   ! expected output:  0
+    call scharacter1('.  ')   ! expected output:  0
+    call scharacter1('   ')   ! expected output:  0
+    print*
+    call scharacter2('99 ')   ! expected output:  9 -2
+    call scharacter2('22 ')   ! expected output:  9 -2
+    call scharacter2('.  ')   ! expected output:  9 -2
+    call scharacter2('   ')   ! expected output:  9 -2
+  end