From 041bc485bf2122b238eb1a336d3a38168feb8eaa Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 7 Apr 2023 10:47:22 +0900 Subject: [PATCH] [mlir][Interfaces] ValueBoundsOpInterface: Support LB and UB bounds This change also adds support for `affine.min` and `affine.max` ops. Differential Revision: https://reviews.llvm.org/D145787 --- .../mlir/Dialect/Affine/Transforms/Transforms.h | 5 ++- .../Affine/IR/ValueBoundsOpInterfaceImpl.cpp | 44 +++++++++++++++++++++ mlir/lib/Interfaces/ValueBoundsOpInterface.cpp | 40 ++++++++++++++----- .../Affine/value-bounds-op-interface-impl.mlir | 46 ++++++++++++++++++++++ .../lib/Dialect/Affine/TestReifyValueBounds.cpp | 31 ++++++++++++--- 5 files changed, 149 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h index 2abd329..e957fb1 100644 --- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h @@ -51,14 +51,15 @@ FailureOr decompose(RewriterBase &rewriter, AffineApplyOp op); /// Reify a bound for the given index-typed value or shape dimension size in /// terms of the owning op's operands. `dim` must be `nullopt` if and only if -/// `value` is index-typed. +/// `value` is index-typed. LB and EQ bounds are closed, UB bounds are open. FailureOr reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, Value value, std::optional dim); /// Reify a bound for the given index-typed value or shape dimension size in /// terms of SSA values for which `stopCondition` is met. `dim` must be -/// `nullopt` if and only if `value` is index-typed. +/// `nullopt` if and only if `value` is index-typed. LB and EQ bounds are +/// closed, UB bounds are open. /// /// Example: /// %0 = arith.addi %a, %b : index diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp index ed8b9a7..0036023 100644 --- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -38,6 +38,48 @@ struct AffineApplyOpInterface } }; +struct AffineMinOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto minOp = cast(op); + assert(value == minOp.getResult() && "invalid value"); + + // Align affine map results with dims/symbols in the constraint set. + for (AffineExpr expr : minOp.getAffineMap().getResults()) { + SmallVector dimReplacements = llvm::to_vector(llvm::map_range( + minOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); })); + SmallVector symReplacements = llvm::to_vector(llvm::map_range( + minOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); })); + AffineExpr bound = + expr.replaceDimsAndSymbols(dimReplacements, symReplacements); + cstr.bound(value) <= bound; + } + }; +}; + +struct AffineMaxOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto maxOp = cast(op); + assert(value == maxOp.getResult() && "invalid value"); + + // Align affine map results with dims/symbols in the constraint set. + for (AffineExpr expr : maxOp.getAffineMap().getResults()) { + SmallVector dimReplacements = llvm::to_vector(llvm::map_range( + maxOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); })); + SmallVector symReplacements = llvm::to_vector(llvm::map_range( + maxOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); })); + AffineExpr bound = + expr.replaceDimsAndSymbols(dimReplacements, symReplacements); + cstr.bound(value) >= bound; + } + }; +}; + } // namespace } // namespace mlir @@ -45,5 +87,7 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, AffineDialect *dialect) { AffineApplyOp::attachInterface(*ctx); + AffineMaxOp::attachInterface(*ctx); + AffineMinOp::attachInterface(*ctx); }); } diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index 73757b6..8db5e68 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -214,9 +214,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound( assertValidValueDim(value, dim); #endif // NDEBUG - // Only EQ bounds are supported at the moment. - assert(type == BoundType::EQ && "unsupported bound type"); - Builder b(value.getContext()); mapOperands.clear(); @@ -249,16 +246,39 @@ LogicalResult ValueBoundsConstraintSet::computeBound( SmallVector lb(1), ub(1); cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub, /*getClosedUB=*/true); + // Note: There are TODOs in the implementation of `getSliceBounds`. In such a // case, no lower/upper bound can be computed at the moment. - if (lb.empty() || !lb[0] || ub.empty() || !ub[0] || - lb[0].getNumResults() != 1 || ub[0].getNumResults() != 1) + // EQ, UB bounds: upper bound is needed. + if ((type != BoundType::LB) && + (ub.empty() || !ub[0] || ub[0].getNumResults() == 0)) return failure(); + // EQ, LB bounds: lower bound is needed. + if ((type != BoundType::UB) && + (lb.empty() || !lb[0] || lb[0].getNumResults() == 0)) + return failure(); + + // TODO: Generate an affine map with multiple results. + if (type != BoundType::LB) + assert(ub.size() == 1 && ub[0].getNumResults() == 1 && + "multiple bounds not supported"); + if (type != BoundType::UB) + assert(lb.size() == 1 && lb[0].getNumResults() == 1 && + "multiple bounds not supported"); - // Look for same lower and upper bound: EQ bound. - if (ub[0] != lb[0]) + // EQ bound: lower and upper bound must match. + if (type == BoundType::EQ && ub[0] != lb[0]) return failure(); + AffineMap bound; + if (type == BoundType::EQ || type == BoundType::LB) { + bound = lb[0]; + } else { + // Computed UB is a closed bound. Turn into an open bound. + bound = AffineMap::get(ub[0].getNumDims(), ub[0].getNumSymbols(), + ub[0].getResult(0) + 1); + } + // Gather all SSA values that are used in the computed bound. assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() && "inconsistent mapping state"); @@ -273,10 +293,10 @@ LogicalResult ValueBoundsConstraintSet::computeBound( bool used = false; bool isDim = i < cstr.cstr.getNumDimVars(); if (isDim) { - if (lb[0].isFunctionOfDim(i)) + if (bound.isFunctionOfDim(i)) used = true; } else { - if (lb[0].isFunctionOfSymbol(i - cstr.cstr.getNumDimVars())) + if (bound.isFunctionOfSymbol(i - cstr.cstr.getNumDimVars())) used = true; } @@ -312,7 +332,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound( mapOperands.push_back(std::make_pair(value, dim)); } - resultMap = lb[0].replaceDimsAndSymbols(replacementDims, replacementSymbols, + resultMap = bound.replaceDimsAndSymbols(replacementDims, replacementSymbols, numDims, numSymbols); return success(); } diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir index 436ec4a..338c48c 100644 --- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir @@ -12,3 +12,49 @@ func.func @affine_apply(%a: index, %b: index) -> index { %1 = "test.reify_bound"(%0) : (index) -> (index) return %1 : index } + +// ----- + +// CHECK-LABEL: func @affine_max_lb( +// CHECK-SAME: %[[a:.*]]: index +// CHECK: %[[c2:.*]] = arith.constant 2 : index +// CHECK: return %[[c2]] +func.func @affine_max_lb(%a: index) -> (index) { + // Note: There are two LBs: s0 and 2. FlatAffineValueConstraints always + // returns the constant one at the moment. + %1 = affine.max affine_map<()[s0] -> (s0, 2)>()[%a] + %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index) + return %2 : index +} + +// ----- + +func.func @affine_max_ub(%a: index) -> (index) { + %1 = affine.max affine_map<()[s0] -> (s0, 2)>()[%a] + // expected-error @below{{could not reify bound}} + %2 = "test.reify_bound"(%1) {type = "UB"}: (index) -> (index) + return %2 : index +} + +// ----- + +// CHECK-LABEL: func @affine_min_ub( +// CHECK-SAME: %[[a:.*]]: index +// CHECK: %[[c3:.*]] = arith.constant 3 : index +// CHECK: return %[[c3]] +func.func @affine_min_ub(%a: index) -> (index) { + // Note: There are two UBs: s0 + 1 and 3. FlatAffineValueConstraints always + // returns the constant one at the moment. + %1 = affine.min affine_map<()[s0] -> (s0, 2)>()[%a] + %2 = "test.reify_bound"(%1) {type = "UB"}: (index) -> (index) + return %2 : index +} + +// ----- + +func.func @affine_min_lb(%a: index) -> (index) { + %1 = affine.min affine_map<()[s0] -> (s0, 2)>()[%a] + // expected-error @below{{could not reify bound}} + %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index) + return %2 : index +} diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index 5828315..e2a06a5 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -16,6 +16,7 @@ #define PASS_NAME "test-affine-reify-value-bounds" using namespace mlir; +using mlir::presburger::BoundType; namespace { @@ -45,6 +46,16 @@ private: } // namespace +FailureOr parseBoundType(std::string type) { + if (type == "EQ") + return BoundType::EQ; + if (type == "LB") + return BoundType::LB; + if (type == "UB") + return BoundType::UB; + return failure(); +} + /// Look for "test.reify_bound" ops in the input and replace their results with /// the reified values. static LogicalResult testReifyValueBounds(func::FuncOp funcOp, @@ -67,6 +78,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, return WalkResult::skip(); } + // Get bound type. + std::string boundTypeStr = "EQ"; + if (auto boundTypeAttr = op->getAttrOfType("type")) + boundTypeStr = boundTypeAttr.str(); + auto boundType = parseBoundType(boundTypeStr); + if (failed(boundType)) { + op->emitOpError("invalid op"); + return WalkResult::interrupt(); + } + + // Get shape dimension (if any). auto dim = value.getType().isIndex() ? std::nullopt : std::make_optional( @@ -77,8 +99,8 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, FailureOr reified; if (!reifyToFuncArgs) { // Reify in terms of the op's operands. - reified = reifyValueBound(rewriter, op->getLoc(), - presburger::BoundType::EQ, value, dim); + reified = + reifyValueBound(rewriter, op->getLoc(), *boundType, value, dim); } else { // Reify in terms of function block arguments. auto stopCondition = [](Value v) { @@ -88,9 +110,8 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, return isa( bbArg.getParentBlock()->getParentOp()); }; - reified = - reifyValueBound(rewriter, op->getLoc(), presburger::BoundType::EQ, - value, dim, stopCondition); + reified = reifyValueBound(rewriter, op->getLoc(), *boundType, value, + dim, stopCondition); } if (failed(reified)) { op->emitOpError("could not reify bound"); -- 2.7.4