/// code, except the final yield, at the current execution point.
/// If the value was saved in a previous run, this fetches the saved value
/// from the temporary storage and returns the value.
- mlir::Value generateYieldedScalarValue(mlir::Region &region);
+ /// Inside Forall, the value will be hoisted outside of the forall loops if
+ /// it does not depend on the forall indices.
+ /// An optional type can be provided to get the value converted to that
+ /// specific type (the cast will be hoisted if the computation is hoisted).
+ mlir::Value generateYieldedScalarValue(
+ mlir::Region &region,
+ std::optional<mlir::Type> castToType = std::nullopt);
/// Generate an entity yielded by an ordered assignment tree region, and
/// optionally return the (uncloned) yield if there is any clean-up that
/// should be generated after the entity is used. Like
/// generateYieldedScalarValue, this will return the saved value if the
/// region was saved in a previous run.
std::pair<mlir::Value, std::optional<hlfir::YieldOp>>
- generateYieldedEntity(mlir::Region &region);
+ generateYieldedEntity(mlir::Region &region,
+ std::optional<mlir::Type> castToType = std::nullopt);
/// If \p maybeYield is present and has a clean-up, generate the clean-up
/// at the current insertion point (by cloning).
void OrderedAssignmentRewriter::pre(hlfir::ForallOp forallOp) {
/// Create a fir.do_loop given the hlfir.forall control values.
- mlir::Value rawLowerBound =
- generateYieldedScalarValue(forallOp.getLbRegion());
- mlir::Location loc = forallOp.getLoc();
mlir::Type idxTy = builder.getIndexType();
- mlir::Value lb = builder.createConvert(loc, idxTy, rawLowerBound);
- mlir::Value rawUpperBound =
- generateYieldedScalarValue(forallOp.getUbRegion());
- mlir::Value ub = builder.createConvert(loc, idxTy, rawUpperBound);
+ mlir::Location loc = forallOp.getLoc();
+ mlir::Value lb = generateYieldedScalarValue(forallOp.getLbRegion(), idxTy);
+ mlir::Value ub = generateYieldedScalarValue(forallOp.getUbRegion(), idxTy);
mlir::Value step;
if (forallOp.getStepRegion().empty()) {
+ auto insertionPoint = builder.saveInsertionPoint();
+ if (!constructStack.empty())
+ builder.setInsertionPoint(constructStack[0]);
step = builder.createIntegerConstant(loc, idxTy, 1);
+ if (!constructStack.empty())
+ builder.restoreInsertionPoint(insertionPoint);
} else {
- step = generateYieldedScalarValue(forallOp.getStepRegion());
- step = builder.createConvert(loc, idxTy, step);
+ step = generateYieldedScalarValue(forallOp.getStepRegion(), idxTy);
}
auto doLoop = builder.create<fir::DoLoopOp>(loc, lb, ub, step);
builder.setInsertionPointToStart(doLoop.getBody());
void OrderedAssignmentRewriter::pre(hlfir::ForallMaskOp forallMaskOp) {
mlir::Location loc = forallMaskOp.getLoc();
- mlir::Value mask = generateYieldedScalarValue(forallMaskOp.getMaskRegion());
- mask = builder.createConvert(loc, builder.getI1Type(), mask);
+ mlir::Value mask = generateYieldedScalarValue(forallMaskOp.getMaskRegion(),
+ builder.getI1Type());
auto ifOp = builder.create<fir::IfOp>(loc, std::nullopt, mask, false);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
constructStack.push_back(ifOp);
builder.setInsertionPointAfter(constructStack.pop_back_val());
}
+/// Return true if \p value is a Forall index.
+/// Forall indices are either block arguments of the entry block of an
+/// hlfir.forall body, or the result of an hlfir.forall_index operation.
+static bool isForallIndex(mlir::Value value) {
+ if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(value)) {
+ if (mlir::Block *block = blockArg.getOwner())
+ return block->isEntryBlock() &&
+ mlir::isa_and_nonnull<hlfir::ForallOp>(block->getParentOp());
+ return false;
+ }
+ return value.getDefiningOp<hlfir::ForallIndexOp>();
+}
+
std::pair<mlir::Value, std::optional<hlfir::YieldOp>>
-OrderedAssignmentRewriter::generateYieldedEntity(mlir::Region &region) {
+OrderedAssignmentRewriter::generateYieldedEntity(
+ mlir::Region &region, std::optional<mlir::Type> castToType) {
// TODO: if the region was saved, use that instead of generating code again.
if (whereLoopNest.has_value()) {
mlir::Location loc = region.getParentOp()->getLoc();
return {generateMaskedEntity(loc, region), std::nullopt};
}
assert(region.hasOneBlock() && "region must contain one block");
- // Clone all operations except the final hlfir.yield.
+ auto oldYield = mlir::dyn_cast_or_null<hlfir::YieldOp>(
+ region.back().getOperations().back());
+ assert(oldYield && "region computing entities must end with a YieldOp");
mlir::Block::OpListType &ops = region.back().getOperations();
+
+ // Inside Forall, scalars that do not depend on forall indices can be hoisted
+ // here because their evaluation is required to only call pure procedures, and
+ // if they depend on a variable previously assigned to in a forall assignment,
+ // this assignment must have been scheduled in a previous run. Hoisting of
+ // scalars is done here to help create simple temporary storage if needed.
+ // Inner forall bounds can often be hoisted, and this allows computing the
+ // total number of iterations to create temporary storages.
+ bool hoistComputation = false;
+ if (fir::isa_trivial(oldYield.getEntity().getType()) &&
+ !constructStack.empty()) {
+ hoistComputation = true;
+ for (mlir::Operation &op : ops)
+ if (llvm::any_of(op.getOperands(), [](mlir::Value value) {
+ return isForallIndex(value);
+ })) {
+ hoistComputation = false;
+ break;
+ }
+ }
+ auto insertionPoint = builder.saveInsertionPoint();
+ if (hoistComputation)
+ builder.setInsertionPoint(constructStack[0]);
+
+ // Clone all operations except the final hlfir.yield.
assert(!ops.empty() && "yield block cannot be empty");
auto end = ops.end();
for (auto opIt = ops.begin(); std::next(opIt) != end; ++opIt)
(void)builder.clone(*opIt, mapper);
- auto oldYield = mlir::dyn_cast_or_null<hlfir::YieldOp>(
- region.back().getOperations().back());
- assert(oldYield && "region computing scalar must end with a YieldOp");
// Get the value for the yielded entity, it may be the result of an operation
// that was cloned, or it may be the same as the previous value if the yield
// operand was created before the ordered assignment tree.
mlir::Value newEntity = mapper.lookupOrDefault(oldYield.getEntity());
+ if (castToType.has_value())
+ newEntity =
+ builder.createConvert(newEntity.getLoc(), *castToType, newEntity);
+
+ if (hoistComputation) {
+ // The clean-up of hoisted trivial scalars can be done right away: the
+ // value is in registers.
+ generateCleanupIfAny(oldYield);
+ builder.restoreInsertionPoint(insertionPoint);
+ return {newEntity, std::nullopt};
+ }
if (oldYield.getCleanup().empty())
return {newEntity, std::nullopt};
return {newEntity, oldYield};
}
-mlir::Value
-OrderedAssignmentRewriter::generateYieldedScalarValue(mlir::Region &region) {
- auto [value, maybeYield] = generateYieldedEntity(region);
+mlir::Value OrderedAssignmentRewriter::generateYieldedScalarValue(
+ mlir::Region &region, std::optional<mlir::Type> castToType) {
+ auto [value, maybeYield] = generateYieldedEntity(region, castToType);
assert(fir::isa_trivial(value.getType()) && "not a trivial scalar value");
generateCleanupIfAny(maybeYield);
return value;
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 10 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK: fir.do_loop %[[VAL_4:.*]] = %[[VAL_1]] to %[[VAL_2]] step %[[VAL_3]] {
-// CHECK: %[[VAL_5:.*]] = arith.constant 42 : i32
-// CHECK: %[[VAL_6:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_4]]) : (!fir.ref<!fir.array<10xi32>>, index) -> !fir.ref<i32>
-// CHECK: hlfir.assign %[[VAL_5]] to %[[VAL_6]] : i32, !fir.ref<i32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 42 : i32
+// CHECK: fir.do_loop %[[VAL_5:.*]] = %[[VAL_1]] to %[[VAL_2]] step %[[VAL_3]] {
+// CHECK: %[[VAL_6:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_5]]) : (!fir.ref<!fir.array<10xi32>>, index) -> !fir.ref<i32>
+// CHECK: hlfir.assign %[[VAL_4]] to %[[VAL_6]] : i32, !fir.ref<i32>
// CHECK: }
func.func @test_index(%x: !fir.ref<!fir.array<10xi32>>) {
// CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_5]] : (i64) -> index
// CHECK: %[[VAL_18:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
// CHECK: %[[VAL_19:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_22:.*]] = fir.convert %[[VAL_5]] : (i64) -> index
+// CHECK: %[[VAL_23:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
+// CHECK: %[[VAL_24:.*]] = arith.constant 1 : index
// CHECK: fir.do_loop %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_19]] {
// CHECK: %[[VAL_21:.*]] = fir.convert %[[VAL_20]] : (index) -> i64
-// CHECK: %[[VAL_22:.*]] = fir.convert %[[VAL_5]] : (i64) -> index
-// CHECK: %[[VAL_23:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
-// CHECK: %[[VAL_24:.*]] = arith.constant 1 : index
// CHECK: fir.do_loop %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_23]] step %[[VAL_24]] {
// CHECK: %[[VAL_26:.*]] = fir.convert %[[VAL_25]] : (index) -> i64
// CHECK: %[[VAL_27:.*]] = arith.subi %[[VAL_3]], %[[VAL_21]] : i64
// CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
// CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_3]] : (i64) -> index
// CHECK: %[[VAL_10:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
+// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index
// CHECK: fir.do_loop %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_10]] {
// CHECK: %[[VAL_12:.*]] = fir.convert %[[VAL_11]] : (index) -> i64
// CHECK: %[[VAL_13:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_12]]) : (!fir.box<!fir.array<?x!fir.logical<4>>>, i64) -> !fir.ref<!fir.logical<4>>
// CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<!fir.logical<4>>
// CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (!fir.logical<4>) -> i1
// CHECK: fir.if %[[VAL_15]] {
-// CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
// CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_12]] : (i64) -> index
-// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index
// CHECK: fir.do_loop %[[VAL_19:.*]] = %[[VAL_16]] to %[[VAL_17]] step %[[VAL_18]] {
// CHECK: %[[VAL_20:.*]] = fir.convert %[[VAL_19]] : (index) -> i64
// CHECK: %[[VAL_21:.*]] = hlfir.designate %[[VAL_7]]#0 (%[[VAL_12]], %[[VAL_20]]) : (!fir.box<!fir.array<?x?xf32>>, i64, i64) -> !fir.ref<f32>