If an `scf.for` loop yields a value that is equal to the corresponding iter_arg (for index-typed values) or a shaped value with the same dimension sizes as the corresponding iter_arg, bounds can be computed for the iter_arg and the OpResult of the `scf.for` op.
Differential Revision: https://reviews.llvm.org/D146306
/// The stop condition when traversing the backward slice of a shaped value/
/// index-type value. The traversal continues until the stop condition
/// evaluates to "true" for a value.
- using StopConditionFn = function_ref<bool(Value)>;
+ ///
+ /// The first parameter of the function is the shaped value/index-typed
+ /// value. The second parameter is the dimension in case of a shaped value.
+ using StopConditionFn =
+ function_ref<bool(Value, std::optional<int64_t> /*dim*/)>;
/// Compute a bound for the given index-typed value or shape dimension size.
/// The computed bound is stored in `resultMap`. The operands of the bound are
/// stored in `mapOperands`. An operand is either an index-typed SSA value
/// or a shaped value and a dimension.
///
/// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
- /// is computed in terms of values for which `stopCondition` evaluates to
- /// "true". To that end, the backward slice (reverse use-def chain) of the
- /// given value is visited in a worklist-driven manner and the constraint set
- /// is populated according to `ValueBoundsOpInterface` for each visited value.
+ /// is computed in terms of values/dimensions for which `stopCondition`
+ /// evaluates to "true". To that end, the backward slice (reverse use-def
+ /// chain) of the given value is visited in a worklist-driven manner and the
+ /// constraint set is populated according to `ValueBoundsOpInterface` for each
+ /// visited value.
static LogicalResult computeBound(AffineMap &resultMap,
ValueDimList &mapOperands,
presburger::BoundType type, Value value,
std::optional<int64_t> dim,
StopConditionFn stopCondition);
+ /// Compute a bound in terms of the values/dimensions in `dependencies`. The
+ /// computed bound consists of only constant terms and dependent values (or
+ /// dimension sizes thereof).
+ static LogicalResult computeBound(AffineMap &resultMap,
+ ValueDimList &mapOperands,
+ presburger::BoundType type, Value value,
+ std::optional<int64_t> dim,
+ ValueDimList dependencies);
+
/// Compute a constant bound for the given index-typed value or shape
/// dimension size.
///
std::optional<int64_t> dim) {
// We are trying to reify a bound for `value`. Construct a stop condition that
// evaluates to "true" for any SSA value except for `value`. I.e., the bound
- // will be computed in terms of any SSA values expect for `value`. The first
+ // will be computed in terms of any SSA values except for `value`. The first
// such values are operands of the owner of `value`.
- auto stopCondition = [&](Value v) {
+ auto stopCondition = [&](Value v, std::optional<int64_t> d) {
// Reify in terms of SSA values that are different from `value`.
return v != value;
};
return reifyValueBound(b, loc, type, value, dim, stopCondition);
}
-FailureOr<OpFoldResult>
-mlir::reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
- Value value, std::optional<int64_t> dim,
- function_ref<bool(Value)> stopCondition) {
+FailureOr<OpFoldResult> mlir::reifyValueBound(
+ OpBuilder &b, Location loc, presburger::BoundType type, Value value,
+ std::optional<int64_t> dim,
+ function_ref<bool(Value, std::optional<int64_t>)> stopCondition) {
// Compute bound.
AffineMap boundMap;
ValueDimList mapOperands;
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
+using presburger::BoundType;
namespace mlir {
namespace scf {
/// External model that attaches `ValueBoundsOpInterface` to `scf.for`
/// (`ForOp`). The induction variable is bounded by the loop's lower bound
/// (inclusive) and upper bound (exclusive). An iter_arg or OpResult receives
/// an EQ bound to the matching init arg, but only when the yielded value (or
/// its dimension size) is found to be equal to the corresponding iter_arg.
struct ForOpInterface
: public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
+
+ /// Populate bounds of values/dimensions for iter_args/OpResults.
+ static void populateIterArgBounds(scf::ForOp forOp, Value value,
+ std::optional<int64_t> dim,
+ ValueBoundsConstraintSet &cstr) {
+ // `value` is an iter_arg or an OpResult.
+ int64_t iterArgIdx;
+ if (auto iterArg = value.dyn_cast<BlockArgument>()) {
+ iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
+ } else {
+ iterArgIdx = value.cast<OpResult>().getResultNumber();
+ }
+
+ // An EQ constraint can be added if the yielded value (dimension size)
+ // equals the corresponding block argument (dimension size).
+ assert(forOp.getLoopBody().hasOneBlock() &&
+ "multiple blocks not supported");
+ Value yieldedValue =
+ cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator())
+ .getOperand(iterArgIdx);
+ Value iterArg = forOp.getRegionIterArg(iterArgIdx);
+ Value initArg = forOp.getInitArgs()[iterArgIdx];
+
+ auto addEqBound = [&]() {
+ if (dim.has_value()) {
+ cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
+ } else {
+ cstr.bound(value) == initArg;
+ }
+ };
+
+ if (yieldedValue == iterArg) {
+ addEqBound();
+ return;
+ }
+
+ // Compute EQ bound for yielded value.
+ AffineMap bound;
+ ValueDimList boundOperands;
+ LogicalResult status = ValueBoundsConstraintSet::computeBound(
+ bound, boundOperands, BoundType::EQ, yieldedValue, dim,
+ [&](Value v, std::optional<int64_t> d) {
+ // Stop when reaching a block argument of the loop body.
+ if (auto bbArg = v.dyn_cast<BlockArgument>())
+ return bbArg.getOwner()->getParentOp() == forOp;
+ // Stop when reaching a value that is defined outside of the loop. It
+ // is impossible to reach an iter_arg from there.
+ Operation *op = v.getDefiningOp();
+ return forOp.getLoopBody().findAncestorOpInRegion(*op) == nullptr;
+ });
+ if (failed(status))
+ return;
+ if (bound.getNumResults() != 1)
+ return;
+
+ // Check if computed bound equals the corresponding iter_arg.
+ Value singleValue = nullptr;
+ std::optional<int64_t> singleDim = std::nullopt;
+ if (auto dimExpr = bound.getResult(0).dyn_cast<AffineDimExpr>()) {
+ int64_t idx = dimExpr.getPosition();
+ singleValue = boundOperands[idx].first;
+ singleDim = boundOperands[idx].second;
+ } else if (auto symExpr = bound.getResult(0).dyn_cast<AffineSymbolExpr>()) {
+ int64_t idx = symExpr.getPosition() + bound.getNumDims();
+ singleValue = boundOperands[idx].first;
+ singleDim = boundOperands[idx].second;
+ }
+ if (singleValue == iterArg && singleDim == dim)
+ addEqBound();
+ }
+
/// Adds constraints for an index-typed `value` that belongs to `op`. The
/// induction variable is constrained to `lowerBound <= iv < upperBound` (the
/// step size is not taken into account); every other value is handled as an
/// iter_arg/OpResult.
void populateBoundsForIndexValue(Operation *op, Value value,
ValueBoundsConstraintSet &cstr) const {
auto forOp = cast<ForOp>(op);
- // Only IV is supported at the moment.
- if (value != forOp.getInductionVar())
+
+ if (value == forOp.getInductionVar()) {
+ // TODO: Take into account step size.
+ cstr.bound(value) >= forOp.getLowerBound();
+ cstr.bound(value) < forOp.getUpperBound();
return;
+ }
- // TODO: Take into account step size.
- cstr.bound(value) >= forOp.getLowerBound();
- cstr.bound(value) < forOp.getUpperBound();
+ // Handle iter_args and OpResults.
+ populateIterArgBounds(forOp, value, std::nullopt, cstr);
}
/// Adds constraints for dimension `dim` of a shaped `value` that belongs to
/// `op`, treating `value` as an iter_arg/OpResult of the loop.
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
ValueBoundsConstraintSet &cstr) const {
- // iter_arg / return value not supported.
- return;
+ auto forOp = cast<ForOp>(op);
+ // Handle iter_args and OpResults.
+ populateIterArgBounds(forOp, value, dim, cstr);
}
};
}
// Do not process any further if the stop condition is met.
- if (stopCondition(value))
+ auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
+ if (stopCondition(value, maybeDim))
continue;
// Query `ValueBoundsOpInterface` for constraints. New items may be added to
Value value, std::optional<int64_t> dim, StopConditionFn stopCondition) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
+ assert(!stopCondition(value, dim) &&
+ "stop condition should not be satisfied for starting point");
#endif // NDEBUG
Builder b(value.getContext());
mapOperands.clear();
- if (stopCondition(value)) {
+ if (stopCondition(value, dim)) {
// Special case: If the stop condition is satisfied for the input
// value/dimension, directly return it.
mapOperands.push_back(std::make_pair(value, dim));
// Do not project out `valueDim`.
if (valueDim == p)
return false;
- return !stopCondition(p.first);
+ auto maybeDim =
+ p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
+ return !stopCondition(p.first, maybeDim);
});
// Compute lower and upper bounds for `valueDim`.
return success();
}
/// Convenience overload: forwards to the `StopConditionFn` overload with a
/// stop condition that is "true" exactly for the (value, dim) pairs contained
/// in `dependencies`.
+LogicalResult ValueBoundsConstraintSet::computeBound(
+ AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
+ Value value, std::optional<int64_t> dim, ValueDimList dependencies) {
+ return computeBound(resultMap, mapOperands, type, value, dim,
+ [&](Value v, std::optional<int64_t> d) {
+ return llvm::is_contained(dependencies,
+ std::make_pair(v, d));
+ });
+}
+
FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType type, Value value, std::optional<int64_t> dim,
StopConditionFn stopCondition) {
} else {
// No stop condition specified: Keep adding constraints until a bound could
// be computed.
- cstr.processWorklist(/*stopCondition=*/[&](Value v) {
- return cstr.cstr.getConstantBound64(type, pos).has_value();
- });
+ cstr.processWorklist(
+ /*stopCondition=*/[&](Value v, std::optional<int64_t> dim) {
+ return cstr.cstr.getConstantBound64(type, pos).has_value();
+ });
}
// Compute constant bound for `valueDim`.
}
return
}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_index_result_small(
+// CHECK-SAME: %[[i:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+// CHECK: "test.some_use"(%[[i]])
+// CHECK: "test.some_use"(%[[i]])
// The loop yields its index-typed iter_arg unchanged. The test expects the EQ
// bound of both %arg and the loop result %0 to reify to the init value %i.
+func.func @scf_for_index_result_small(%i: index, %a: index, %b: index, %c: index) {
+ %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %i) -> index {
+ %1 = "test.reify_bound"(%arg) {type = "EQ"} : (index) -> (index)
+ "test.some_use"(%1) : (index) -> ()
+ scf.yield %arg : index
+ }
+ %2 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
+ "test.some_use"(%2) : (index) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_index_result(
+// CHECK-SAME: %[[i:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+// CHECK: "test.some_use"(%[[i]])
+// CHECK: "test.some_use"(%[[i]])
// The yielded value is (%arg + %a) - %a rather than %arg itself. The test
// expects the EQ bound of both %arg and the loop result %0 to still reify to
// the init value %i.
+func.func @scf_for_index_result(%i: index, %a: index, %b: index, %c: index) {
+ %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %i) -> index {
+ %add = arith.addi %arg, %a : index
+ %sub = arith.subi %add, %a : index
+
+ %1 = "test.reify_bound"(%arg) {type = "EQ"} : (index) -> (index)
+ "test.some_use"(%1) : (index) -> ()
+ scf.yield %sub : index
+ }
+ %2 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
+ "test.some_use"(%2) : (index) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_tensor_result_small(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
+// CHECK: "test.some_use"(%[[dim]])
+// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
+// CHECK: "test.some_use"(%[[dim]])
// The loop yields its tensor iter_arg unchanged. The test expects dim 0 of
// both %arg and the loop result %0 to reify to a `tensor.dim` of the init
// tensor %t.
+func.func @scf_for_tensor_result_small(%t: tensor<?xf32>, %a: index, %b: index, %c: index) {
+ %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %t) -> tensor<?xf32> {
+ %1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+ "test.some_use"(%1) : (index) -> ()
+ scf.yield %arg : tensor<?xf32>
+ }
+ %2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+ "test.some_use"(%2) : (index) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_tensor_result(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
+// CHECK: "test.some_use"(%[[dim]])
+// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
+// CHECK: "test.some_use"(%[[dim]])
// The loop yields the result of a `linalg.fill` whose `outs` operand is the
// iter_arg. The test expects dim 0 of both %arg and the loop result %0 to
// still reify to a `tensor.dim` of the init tensor %t.
+func.func @scf_for_tensor_result(%t: tensor<?xf32>, %a: index, %b: index, %c: index) {
+ %cst = arith.constant 5.0 : f32
+ %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %t) -> tensor<?xf32> {
+ %filled = linalg.fill ins(%cst : f32) outs(%arg : tensor<?xf32>) -> tensor<?xf32>
+ %1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+ "test.some_use"(%1) : (index) -> ()
+ scf.yield %filled : tensor<?xf32>
+ }
+ %2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+ "test.some_use"(%2) : (index) -> ()
+ return
+}
+
+// -----
+
// Negative test: the two iter_args are yielded in swapped order, so the value
// yielded at position 0 is derived from %arg2 instead of %arg1. Reifying dim 0
// of %r1 is therefore expected to fail (see the expected-error below).
+func.func @scf_for_swapping_yield(%t1: tensor<?xf32>, %t2: tensor<?xf32>, %a: index, %b: index, %c: index) {
+ %cst = arith.constant 5.0 : f32
+ %r1, %r2 = scf.for %iv = %a to %b step %c iter_args(%arg1 = %t1, %arg2 = %t2) -> (tensor<?xf32>, tensor<?xf32>) {
+ %filled1 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+ %filled2 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+ scf.yield %filled2, %filled1 : tensor<?xf32>, tensor<?xf32>
+ }
+ // expected-error @below{{could not reify bound}}
+ %reify1 = "test.reify_bound"(%r1) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+ "test.some_use"(%reify1) : (index) -> ()
+ return
+}
MLIRIR
MLIRPass
MLIRSupport
+ MLIRMemRefDialect
+ MLIRTensorDialect
MLIRVectorUtils
)
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Pass/Pass.h"
TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){};
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<AffineDialect>();
+ registry
+ .insert<AffineDialect, tensor::TensorDialect, memref::MemRefDialect>();
}
void runOnOperation() override;
// Prepare stop condition. By default, reify in terms of the op's
// operands. No stop condition is used when a constant was requested.
- std::function<bool(Value)> stopCondition = [&](Value v) {
- // Reify in terms of SSA values that are different from `value`.
- return v != value;
- };
+ std::function<bool(Value, std::optional<int64_t>)> stopCondition =
+ [&](Value v, std::optional<int64_t> d) {
+ // Reify in terms of SSA values that are different from `value`.
+ return v != value;
+ };
if (reifyToFuncArgs) {
// Reify in terms of function block arguments.
- stopCondition = stopCondition = [](Value v) {
+ stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
auto bbArg = v.dyn_cast<BlockArgument>();
if (!bbArg)
return false;
"//mlir:Pass",
"//mlir:SCFDialect",
"//mlir:Support",
+ "//mlir:TensorDialect",
"//mlir:Transforms",
"//mlir:ValueBoundsOpInterface",
"//mlir:VectorDialect",