[flang][hlfir] Hoist forall bounds computation when possible
authorJean Perier <jperier@nvidia.com>
Tue, 23 May 2023 07:17:36 +0000 (09:17 +0200)
committerJean Perier <jperier@nvidia.com>
Tue, 23 May 2023 07:17:44 +0000 (09:17 +0200)
When inner forall bound computations do not depend on previous
forall indices, they can be hoisted.
This is possible because:
 - bound computation are required to be pure (so evaluating them only
   once is possible).
 - If the bound computation depends on a value previously assigned, the
   forall scheduling analysis created different run for it: the
   assignment impacting the bounds value is not part of the current loop
   nest.

The reason this optimization is done at that point and not as part of
generic loop hoisting optimization is that having the all the loop
bound computation hoisted will allow allocating simple temporary
storages. The number of iteration can be pre-computed and used as the
extent for the temporary.

Differential Revision: https://reviews.llvm.org/D151110

flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
flang/test/HLFIR/order_assignments/forall-codegen-no-conflict.fir

index d49bc1e..0317f83 100644 (file)
@@ -141,7 +141,13 @@ private:
   /// code, except the final yield, at the current execution point.
   /// If the value was saved in a previous run, this fetches the saved value
   /// from the temporary storage and returns the value.
-  mlir::Value generateYieldedScalarValue(mlir::Region &region);
+  /// Inside Forall, the value will be hoisted outside of the forall loops if
+  /// it does not depend on the forall indices.
+  /// An optional type can be provided to get a value from a specific type
+  /// (the cast will be hoisted if the computation is hoisted).
+  mlir::Value generateYieldedScalarValue(
+      mlir::Region &region,
+      std::optional<mlir::Type> castToType = std::nullopt);
 
   /// Generate an entity yielded by an ordered assignment tree region, and
   /// optionally return the (uncloned) yield if there is any clean-up that
@@ -149,7 +155,8 @@ private:
   /// this will return the saved value if the region was saved in a previous
   /// run.
   std::pair<mlir::Value, std::optional<hlfir::YieldOp>>
-  generateYieldedEntity(mlir::Region &region);
+  generateYieldedEntity(mlir::Region &region,
+                        std::optional<mlir::Type> castToType = std::nullopt);
 
   /// If \p maybeYield is present and has a clean-up, generate the clean-up
   /// at the current insertion point (by cloning).
@@ -215,20 +222,20 @@ void OrderedAssignmentRewriter::walk(
 
 void OrderedAssignmentRewriter::pre(hlfir::ForallOp forallOp) {
   /// Create a fir.do_loop given the hlfir.forall control values.
-  mlir::Value rawLowerBound =
-      generateYieldedScalarValue(forallOp.getLbRegion());
-  mlir::Location loc = forallOp.getLoc();
   mlir::Type idxTy = builder.getIndexType();
-  mlir::Value lb = builder.createConvert(loc, idxTy, rawLowerBound);
-  mlir::Value rawUpperBound =
-      generateYieldedScalarValue(forallOp.getUbRegion());
-  mlir::Value ub = builder.createConvert(loc, idxTy, rawUpperBound);
+  mlir::Location loc = forallOp.getLoc();
+  mlir::Value lb = generateYieldedScalarValue(forallOp.getLbRegion(), idxTy);
+  mlir::Value ub = generateYieldedScalarValue(forallOp.getUbRegion(), idxTy);
   mlir::Value step;
   if (forallOp.getStepRegion().empty()) {
+    auto insertionPoint = builder.saveInsertionPoint();
+    if (!constructStack.empty())
+      builder.setInsertionPoint(constructStack[0]);
     step = builder.createIntegerConstant(loc, idxTy, 1);
+    if (!constructStack.empty())
+      builder.restoreInsertionPoint(insertionPoint);
   } else {
-    step = generateYieldedScalarValue(forallOp.getStepRegion());
-    step = builder.createConvert(loc, idxTy, step);
+    step = generateYieldedScalarValue(forallOp.getStepRegion(), idxTy);
   }
   auto doLoop = builder.create<fir::DoLoopOp>(loc, lb, ub, step);
   builder.setInsertionPointToStart(doLoop.getBody());
@@ -256,8 +263,8 @@ void OrderedAssignmentRewriter::pre(hlfir::ForallIndexOp forallIndexOp) {
 
 void OrderedAssignmentRewriter::pre(hlfir::ForallMaskOp forallMaskOp) {
   mlir::Location loc = forallMaskOp.getLoc();
-  mlir::Value mask = generateYieldedScalarValue(forallMaskOp.getMaskRegion());
-  mask = builder.createConvert(loc, builder.getI1Type(), mask);
+  mlir::Value mask = generateYieldedScalarValue(forallMaskOp.getMaskRegion(),
+                                                builder.getI1Type());
   auto ifOp = builder.create<fir::IfOp>(loc, std::nullopt, mask, false);
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   constructStack.push_back(ifOp);
@@ -350,35 +357,84 @@ void OrderedAssignmentRewriter::post(hlfir::ElseWhereOp elseWhereOp) {
   builder.setInsertionPointAfter(constructStack.pop_back_val());
 }
 
+/// Is this value a Forall index?
+/// Forall index are block arguments of hlfir.forall body, or the result
+/// of hlfir.forall_index.
+static bool isForallIndex(mlir::Value value) {
+  if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(value)) {
+    if (mlir::Block *block = blockArg.getOwner())
+      return block->isEntryBlock() &&
+             mlir::isa_and_nonnull<hlfir::ForallOp>(block->getParentOp());
+    return false;
+  }
+  return value.getDefiningOp<hlfir::ForallIndexOp>();
+}
+
 std::pair<mlir::Value, std::optional<hlfir::YieldOp>>
-OrderedAssignmentRewriter::generateYieldedEntity(mlir::Region &region) {
+OrderedAssignmentRewriter::generateYieldedEntity(
+    mlir::Region &region, std::optional<mlir::Type> castToType) {
   // TODO: if the region was saved, use that instead of generating code again.
   if (whereLoopNest.has_value()) {
     mlir::Location loc = region.getParentOp()->getLoc();
     return {generateMaskedEntity(loc, region), std::nullopt};
   }
   assert(region.hasOneBlock() && "region must contain one block");
-  // Clone all operations except the final hlfir.yield.
+  auto oldYield = mlir::dyn_cast_or_null<hlfir::YieldOp>(
+      region.back().getOperations().back());
+  assert(oldYield && "region computing entities must end with a YieldOp");
   mlir::Block::OpListType &ops = region.back().getOperations();
+
+  // Inside Forall, scalars that do not depend on forall indices can be hoisted
+  // here because their evaluation is required to only call pure procedures, and
+  // if they depend on a variable previously assigned to in a forall assignment,
+  // this assignment must have been scheduled in a previous run. Hoisting of
+  // scalars is done here to help creating simple temporary storage if needed.
+  // Inner forall bounds can often be hoisted, and this allows computing the
+  // total number of iterations to create temporary storages.
+  bool hoistComputation = false;
+  if (fir::isa_trivial(oldYield.getEntity().getType()) &&
+      !constructStack.empty()) {
+    hoistComputation = true;
+    for (mlir::Operation &op : ops)
+      if (llvm::any_of(op.getOperands(), [](mlir::Value value) {
+            return isForallIndex(value);
+          })) {
+        hoistComputation = false;
+        break;
+      }
+  }
+  auto insertionPoint = builder.saveInsertionPoint();
+  if (hoistComputation)
+    builder.setInsertionPoint(constructStack[0]);
+
+  // Clone all operations except the final hlfir.yield.
   assert(!ops.empty() && "yield block cannot be empty");
   auto end = ops.end();
   for (auto opIt = ops.begin(); std::next(opIt) != end; ++opIt)
     (void)builder.clone(*opIt, mapper);
-  auto oldYield = mlir::dyn_cast_or_null<hlfir::YieldOp>(
-      region.back().getOperations().back());
-  assert(oldYield && "region computing scalar must end with a YieldOp");
   // Get the value for the yielded entity, it may be the result of an operation
   // that was cloned, or it may be the same as the previous value if the yield
   // operand was created before the ordered assignment tree.
   mlir::Value newEntity = mapper.lookupOrDefault(oldYield.getEntity());
+  if (castToType.has_value())
+    newEntity =
+        builder.createConvert(newEntity.getLoc(), *castToType, newEntity);
+
+  if (hoistComputation) {
+    // Hoisted trivial scalars clean-up can be done right away, the value is
+    // in registers.
+    generateCleanupIfAny(oldYield);
+    builder.restoreInsertionPoint(insertionPoint);
+    return {newEntity, std::nullopt};
+  }
   if (oldYield.getCleanup().empty())
     return {newEntity, std::nullopt};
   return {newEntity, oldYield};
 }
 
-mlir::Value
-OrderedAssignmentRewriter::generateYieldedScalarValue(mlir::Region &region) {
-  auto [value, maybeYield] = generateYieldedEntity(region);
+mlir::Value OrderedAssignmentRewriter::generateYieldedScalarValue(
+    mlir::Region &region, std::optional<mlir::Type> castToType) {
+  auto [value, maybeYield] = generateYieldedEntity(region, castToType);
   assert(fir::isa_trivial(value.getType()) && "not a trivial scalar value");
   generateCleanupIfAny(maybeYield);
   return value;
index dace9b2..784367f 100644 (file)
@@ -24,10 +24,10 @@ func.func @test_simple(%x: !fir.ref<!fir.array<10xi32>>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 10 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK:           fir.do_loop %[[VAL_4:.*]] = %[[VAL_1]] to %[[VAL_2]] step %[[VAL_3]] {
-// CHECK:             %[[VAL_5:.*]] = arith.constant 42 : i32
-// CHECK:             %[[VAL_6:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_4]])  : (!fir.ref<!fir.array<10xi32>>, index) -> !fir.ref<i32>
-// CHECK:             hlfir.assign %[[VAL_5]] to %[[VAL_6]] : i32, !fir.ref<i32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 42 : i32
+// CHECK:           fir.do_loop %[[VAL_5:.*]] = %[[VAL_1]] to %[[VAL_2]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_6:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_5]])  : (!fir.ref<!fir.array<10xi32>>, index) -> !fir.ref<i32>
+// CHECK:             hlfir.assign %[[VAL_4]] to %[[VAL_6]] : i32, !fir.ref<i32>
 // CHECK:           }
 
 func.func @test_index(%x: !fir.ref<!fir.array<10xi32>>) {
@@ -122,11 +122,11 @@ func.func @split_schedule(%arg0: !fir.box<!fir.array<?xf32>>, %arg1: !fir.box<!f
 // CHECK:           %[[VAL_17:.*]] = fir.convert %[[VAL_5]] : (i64) -> index
 // CHECK:           %[[VAL_18:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
 // CHECK:           %[[VAL_19:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_22:.*]] = fir.convert %[[VAL_5]] : (i64) -> index
+// CHECK:           %[[VAL_23:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
+// CHECK:           %[[VAL_24:.*]] = arith.constant 1 : index
 // CHECK:           fir.do_loop %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_19]] {
 // CHECK:             %[[VAL_21:.*]] = fir.convert %[[VAL_20]] : (index) -> i64
-// CHECK:             %[[VAL_22:.*]] = fir.convert %[[VAL_5]] : (i64) -> index
-// CHECK:             %[[VAL_23:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
-// CHECK:             %[[VAL_24:.*]] = arith.constant 1 : index
 // CHECK:             fir.do_loop %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_23]] step %[[VAL_24]] {
 // CHECK:               %[[VAL_26:.*]] = fir.convert %[[VAL_25]] : (index) -> i64
 // CHECK:               %[[VAL_27:.*]] = arith.subi %[[VAL_3]], %[[VAL_21]] : i64
@@ -181,15 +181,15 @@ func.func @test_mask(%arg0: !fir.box<!fir.array<?x?xf32>>, %arg1: !fir.box<!fir.
 // CHECK:           %[[VAL_8:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
 // CHECK:           %[[VAL_9:.*]] = fir.convert %[[VAL_3]] : (i64) -> index
 // CHECK:           %[[VAL_10:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_16:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
+// CHECK:           %[[VAL_18:.*]] = arith.constant 1 : index
 // CHECK:           fir.do_loop %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_10]] {
 // CHECK:             %[[VAL_12:.*]] = fir.convert %[[VAL_11]] : (index) -> i64
 // CHECK:             %[[VAL_13:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_12]])  : (!fir.box<!fir.array<?x!fir.logical<4>>>, i64) -> !fir.ref<!fir.logical<4>>
 // CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<!fir.logical<4>>
 // CHECK:             %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (!fir.logical<4>) -> i1
 // CHECK:             fir.if %[[VAL_15]] {
-// CHECK:               %[[VAL_16:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
 // CHECK:               %[[VAL_17:.*]] = fir.convert %[[VAL_12]] : (i64) -> index
-// CHECK:               %[[VAL_18:.*]] = arith.constant 1 : index
 // CHECK:               fir.do_loop %[[VAL_19:.*]] = %[[VAL_16]] to %[[VAL_17]] step %[[VAL_18]] {
 // CHECK:                 %[[VAL_20:.*]] = fir.convert %[[VAL_19]] : (index) -> i64
 // CHECK:                 %[[VAL_21:.*]] = hlfir.designate %[[VAL_7]]#0 (%[[VAL_12]], %[[VAL_20]])  : (!fir.box<!fir.array<?x?xf32>>, i64, i64) -> !fir.ref<f32>