[flang] Upstream lowering of real control loops
authorDiana Picus <diana.picus@linaro.org>
Tue, 31 May 2022 10:55:56 +0000 (10:55 +0000)
committerDiana Picus <diana.picus@linaro.org>
Wed, 1 Jun 2022 08:00:45 +0000 (08:00 +0000)
Upstream the code for handling loops with real control variables from
the fir-dev branch at
https://github.com/flang-compiler/f18-llvm-project/tree/fir-dev/

Also add a test.

Loops with real-valued control variables are always lowered to
unstructured loops. The real-valued control variables are handled the
same as integer ones, the only difference is that they need to use
floating point instructions instead of the integer equivalents.

Co-authored-by: V Donaldson <vdonaldson@nvidia.com>
flang/lib/Lower/Bridge.cpp
flang/test/Lower/do_loop.f90

index a33ddee..bbce38e 100644 (file)
@@ -95,6 +95,7 @@ struct IncrementLoopInfo {
   fir::DoLoopOp doLoop = nullptr;
 
   // Data members for unstructured loops.
+  bool hasRealControl = false;
   mlir::Value tripVariable = nullptr;
   mlir::Block *headerBlock = nullptr; // loop entry and test block
   mlir::Block *bodyBlock = nullptr;   // first loop body block
@@ -997,6 +998,8 @@ private:
           bounds->step);
       if (unstructuredContext) {
         maybeStartBlock(preheaderBlock);
+        info.hasRealControl = info.loopVariableSym.GetType()->IsNumeric(
+            Fortran::common::TypeCategory::Real);
         info.headerBlock = headerBlock;
         info.bodyBlock = bodyBlock;
         info.exitBlock = exitBlock;
@@ -1034,6 +1037,9 @@ private:
       if (expr)
         return builder->createConvert(loc, controlType,
                                       createFIRExpr(loc, expr, stmtCtx));
+
+      if (info.hasRealControl)
+        return builder->createRealConstant(loc, controlType, 1u);
       return builder->createIntegerConstant(loc, controlType, 1); // step
     };
     for (IncrementLoopInfo &info : incrementLoopNestInfo) {
@@ -1059,12 +1065,24 @@ private:
 
       // Unstructured loop preheader - initialize tripVariable and loopVariable.
       mlir::Value tripCount;
-      auto diff1 =
-          builder->create<mlir::arith::SubIOp>(loc, upperValue, lowerValue);
-      auto diff2 =
-          builder->create<mlir::arith::AddIOp>(loc, diff1, info.stepValue);
-      tripCount =
-          builder->create<mlir::arith::DivSIOp>(loc, diff2, info.stepValue);
+      if (info.hasRealControl) {
+        auto diff1 =
+            builder->create<mlir::arith::SubFOp>(loc, upperValue, lowerValue);
+        auto diff2 =
+            builder->create<mlir::arith::AddFOp>(loc, diff1, info.stepValue);
+        tripCount =
+            builder->create<mlir::arith::DivFOp>(loc, diff2, info.stepValue);
+        tripCount =
+            builder->createConvert(loc, builder->getIndexType(), tripCount);
+
+      } else {
+        auto diff1 =
+            builder->create<mlir::arith::SubIOp>(loc, upperValue, lowerValue);
+        auto diff2 =
+            builder->create<mlir::arith::AddIOp>(loc, diff1, info.stepValue);
+        tripCount =
+            builder->create<mlir::arith::DivSIOp>(loc, diff2, info.stepValue);
+      }
       info.tripVariable = builder->createTemporary(loc, tripCount.getType());
       builder->create<fir::StoreOp>(loc, tripCount, info.tripVariable);
       builder->create<fir::StoreOp>(loc, lowerValue, info.loopVariable);
@@ -1117,7 +1135,12 @@ private:
       tripCount = builder->create<mlir::arith::SubIOp>(loc, tripCount, one);
       builder->create<fir::StoreOp>(loc, tripCount, info.tripVariable);
       mlir::Value value = builder->create<fir::LoadOp>(loc, info.loopVariable);
-      value = builder->create<mlir::arith::AddIOp>(loc, value, info.stepValue);
+      if (info.hasRealControl)
+        value =
+            builder->create<mlir::arith::AddFOp>(loc, value, info.stepValue);
+      else
+        value =
+            builder->create<mlir::arith::AddIOp>(loc, value, info.stepValue);
       builder->create<fir::StoreOp>(loc, value, info.loopVariable);
 
       genFIRBranch(info.headerBlock);
index 6ef6de2..61190ec 100644 (file)
@@ -207,3 +207,40 @@ subroutine loop_with_non_default_integer(s,e,st)
   ! CHECK: %[[I_RES_CVT:.*]] = fir.convert %[[I_RES]] : (index) -> i64
   ! CHECK: fir.store %[[I_RES_CVT]] to %[[I_REF]] : !fir.ref<i64>
 end subroutine
+
+! Test real loop control.
+! CHECK-LABEL: loop_with_real_control
+! CHECK-SAME: (%[[S_REF:.*]]: !fir.ref<f32> {fir.bindc_name = "s"}, %[[E_REF:.*]]: !fir.ref<f32> {fir.bindc_name = "e"}, %[[ST_REF:.*]]: !fir.ref<f32> {fir.bindc_name = "st"}) {
+subroutine loop_with_real_control(s,e,st)
+  ! CHECK-DAG: %[[INDEX_REF:.*]] = fir.alloca index
+  ! CHECK-DAG: %[[X_REF:.*]] = fir.alloca f32 {bindc_name = "x", uniq_name = "_QFloop_with_real_controlEx"}
+  ! CHECK-DAG: %[[S:.*]] = fir.load %[[S_REF]] : !fir.ref<f32>
+  ! CHECK-DAG: %[[E:.*]] = fir.load %[[E_REF]] : !fir.ref<f32>
+  ! CHECK-DAG: %[[ST:.*]] = fir.load %[[ST_REF]] : !fir.ref<f32>
+  real :: x, s, e, st
+
+  ! CHECK: %[[DIFF:.*]] = arith.subf %[[E]], %[[S]] : f32
+  ! CHECK: %[[RANGE:.*]] = arith.addf %[[DIFF]], %[[ST]] : f32
+  ! CHECK: %[[HIGH:.*]] = arith.divf %[[RANGE]], %[[ST]] : f32
+  ! CHECK: %[[HIGH_INDEX:.*]] = fir.convert %[[HIGH]] : (f32) -> index
+  ! CHECK: fir.store %[[HIGH_INDEX]] to %[[INDEX_REF]] : !fir.ref<index>
+  ! CHECK: fir.store %[[S]] to %[[X_REF]] : !fir.ref<f32>
+
+  ! CHECK: br ^[[HDR:.*]]
+  ! CHECK: ^[[HDR]]:  // 2 preds: ^{{.*}}, ^[[EXIT:.*]]
+  ! CHECK-DAG: %[[INDEX:.*]] = fir.load %[[INDEX_REF]] : !fir.ref<index>
+  ! CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  ! CHECK: %[[COND:.*]] = arith.cmpi sgt, %[[INDEX]], %[[C0]] : index
+  ! CHECK: cond_br %[[COND]], ^[[BODY:.*]], ^[[EXIT:.*]]
+  do x=s,e,st
+    ! CHECK: ^[[BODY]]:  // pred: ^[[HDR]]
+    ! CHECK-DAG: %[[INDEX2:.*]] = fir.load %[[INDEX_REF]] : !fir.ref<index>
+    ! CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    ! CHECK: %[[INC:.*]] = arith.subi %[[INDEX2]], %[[C1]] : index
+    ! CHECK: fir.store %[[INC]] to %[[INDEX_REF]] : !fir.ref<index>
+    ! CHECK: %[[X2:.*]] = fir.load %[[X_REF]] : !fir.ref<f32>
+    ! CHECK: %[[XINC:.*]] = arith.addf %[[X2]], %[[ST]] : f32
+    ! CHECK: fir.store %[[XINC]] to %[[X_REF]] : !fir.ref<f32>
+    ! CHECK: br ^[[HDR]]
+  end do
+end subroutine