[mlir][Interfaces] ValueBoundsOpInterface: Support LB and UB bounds
authorMatthias Springer <springerm@google.com>
Fri, 7 Apr 2023 01:47:22 +0000 (10:47 +0900)
committerMatthias Springer <springerm@google.com>
Fri, 7 Apr 2023 01:48:19 +0000 (10:48 +0900)
This change also adds support for `affine.min` and `affine.max` ops.

Differential Revision: https://reviews.llvm.org/D145787

mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp

index 2abd329..e957fb1 100644 (file)
@@ -51,14 +51,15 @@ FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);
 
 /// Reify a bound for the given index-typed value or shape dimension size in
 /// terms of the owning op's operands. `dim` must be `nullopt` if and only if
-/// `value` is index-typed.
+/// `value` is index-typed. LB and EQ bounds are closed, UB bounds are open.
 FailureOr<OpFoldResult> reifyValueBound(OpBuilder &b, Location loc,
                                         presburger::BoundType type, Value value,
                                         std::optional<int64_t> dim);
 
 /// Reify a bound for the given index-typed value or shape dimension size in
 /// terms of SSA values for which `stopCondition` is met. `dim` must be
-/// `nullopt` if and only if `value` is index-typed.
+/// `nullopt` if and only if `value` is index-typed. LB and EQ bounds are
+/// closed, UB bounds are open.
 ///
 /// Example:
 /// %0 = arith.addi %a, %b : index
index ed8b9a7..0036023 100644 (file)
@@ -38,6 +38,48 @@ struct AffineApplyOpInterface
   }
 };
 
+struct AffineMinOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AffineMinOpInterface,
+                                                   AffineMinOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto minOp = cast<AffineMinOp>(op);
+    assert(value == minOp.getResult() && "invalid value");
+
+    // Align affine map results with dims/symbols in the constraint set.
+    for (AffineExpr expr : minOp.getAffineMap().getResults()) {
+      SmallVector<AffineExpr> dimReplacements = llvm::to_vector(llvm::map_range(
+          minOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); }));
+      SmallVector<AffineExpr> symReplacements = llvm::to_vector(llvm::map_range(
+          minOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); }));
+      AffineExpr bound =
+          expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
+      cstr.bound(value) <= bound;
+    }
+  };
+};
+
+struct AffineMaxOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AffineMaxOpInterface,
+                                                   AffineMaxOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto maxOp = cast<AffineMaxOp>(op);
+    assert(value == maxOp.getResult() && "invalid value");
+
+    // Align affine map results with dims/symbols in the constraint set.
+    for (AffineExpr expr : maxOp.getAffineMap().getResults()) {
+      SmallVector<AffineExpr> dimReplacements = llvm::to_vector(llvm::map_range(
+          maxOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); }));
+      SmallVector<AffineExpr> symReplacements = llvm::to_vector(llvm::map_range(
+          maxOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); }));
+      AffineExpr bound =
+          expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
+      cstr.bound(value) >= bound;
+    }
+  };
+};
+
 } // namespace
 } // namespace mlir
 
@@ -45,5 +87,7 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, AffineDialect *dialect) {
     AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx);
+    AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx);
+    AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
   });
 }
index 73757b6..8db5e68 100644 (file)
@@ -214,9 +214,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
   assertValidValueDim(value, dim);
 #endif // NDEBUG
 
-  // Only EQ bounds are supported at the moment.
-  assert(type == BoundType::EQ && "unsupported bound type");
-
   Builder b(value.getContext());
   mapOperands.clear();
 
@@ -249,16 +246,39 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
   SmallVector<AffineMap> lb(1), ub(1);
   cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub,
                            /*getClosedUB=*/true);
+
   // Note: There are TODOs in the implementation of `getSliceBounds`. In such a
   // case, no lower/upper bound can be computed at the moment.
-  if (lb.empty() || !lb[0] || ub.empty() || !ub[0] ||
-      lb[0].getNumResults() != 1 || ub[0].getNumResults() != 1)
+  // EQ, UB bounds: upper bound is needed.
+  if ((type != BoundType::LB) &&
+      (ub.empty() || !ub[0] || ub[0].getNumResults() == 0))
     return failure();
+  // EQ, LB bounds: lower bound is needed.
+  if ((type != BoundType::UB) &&
+      (lb.empty() || !lb[0] || lb[0].getNumResults() == 0))
+    return failure();
+
+  // TODO: Generate an affine map with multiple results.
+  if (type != BoundType::LB)
+    assert(ub.size() == 1 && ub[0].getNumResults() == 1 &&
+           "multiple bounds not supported");
+  if (type != BoundType::UB)
+    assert(lb.size() == 1 && lb[0].getNumResults() == 1 &&
+           "multiple bounds not supported");
 
-  // Look for same lower and upper bound: EQ bound.
-  if (ub[0] != lb[0])
+  // EQ bound: lower and upper bound must match.
+  if (type == BoundType::EQ && ub[0] != lb[0])
     return failure();
 
+  AffineMap bound;
+  if (type == BoundType::EQ || type == BoundType::LB) {
+    bound = lb[0];
+  } else {
+    // Computed UB is a closed bound. Turn into an open bound.
+    bound = AffineMap::get(ub[0].getNumDims(), ub[0].getNumSymbols(),
+                           ub[0].getResult(0) + 1);
+  }
+
   // Gather all SSA values that are used in the computed bound.
   assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() &&
          "inconsistent mapping state");
@@ -273,10 +293,10 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
     bool used = false;
     bool isDim = i < cstr.cstr.getNumDimVars();
     if (isDim) {
-      if (lb[0].isFunctionOfDim(i))
+      if (bound.isFunctionOfDim(i))
         used = true;
     } else {
-      if (lb[0].isFunctionOfSymbol(i - cstr.cstr.getNumDimVars()))
+      if (bound.isFunctionOfSymbol(i - cstr.cstr.getNumDimVars()))
         used = true;
     }
 
@@ -312,7 +332,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
     mapOperands.push_back(std::make_pair(value, dim));
   }
 
-  resultMap = lb[0].replaceDimsAndSymbols(replacementDims, replacementSymbols,
+  resultMap = bound.replaceDimsAndSymbols(replacementDims, replacementSymbols,
                                           numDims, numSymbols);
   return success();
 }
index 436ec4a..338c48c 100644 (file)
@@ -12,3 +12,49 @@ func.func @affine_apply(%a: index, %b: index) -> index {
   %1 = "test.reify_bound"(%0) : (index) -> (index)
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @affine_max_lb(
+//  CHECK-SAME:     %[[a:.*]]: index
+//       CHECK:   %[[c2:.*]] = arith.constant 2 : index
+//       CHECK:   return %[[c2]]
+func.func @affine_max_lb(%a: index) -> (index) {
+  // Note: There are two LBs: s0 and 2. FlatAffineValueConstraints always
+  // returns the constant one at the moment.
+  %1 = affine.max affine_map<()[s0] -> (s0, 2)>()[%a]
+  %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+func.func @affine_max_ub(%a: index) -> (index) {
+  %1 = affine.max affine_map<()[s0] -> (s0, 2)>()[%a]
+  // expected-error @below{{could not reify bound}}
+  %2 = "test.reify_bound"(%1) {type = "UB"}: (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @affine_min_ub(
+//  CHECK-SAME:     %[[a:.*]]: index
+//       CHECK:   %[[c3:.*]] = arith.constant 3 : index
+//       CHECK:   return %[[c3]]
+func.func @affine_min_ub(%a: index) -> (index) {
+  // Note: There are two UBs: s0 + 1 and 3. FlatAffineValueConstraints always
+  // returns the constant one at the moment.
+  %1 = affine.min affine_map<()[s0] -> (s0, 2)>()[%a]
+  %2 = "test.reify_bound"(%1) {type = "UB"}: (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+func.func @affine_min_lb(%a: index) -> (index) {
+  %1 = affine.min affine_map<()[s0] -> (s0, 2)>()[%a]
+  // expected-error @below{{could not reify bound}}
+  %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index)
+  return %2 : index
+}
index 5828315..e2a06a5 100644 (file)
@@ -16,6 +16,7 @@
 #define PASS_NAME "test-affine-reify-value-bounds"
 
 using namespace mlir;
+using mlir::presburger::BoundType;
 
 namespace {
 
@@ -45,6 +46,16 @@ private:
 
 } // namespace
 
+FailureOr<BoundType> parseBoundType(std::string type) {
+  if (type == "EQ")
+    return BoundType::EQ;
+  if (type == "LB")
+    return BoundType::LB;
+  if (type == "UB")
+    return BoundType::UB;
+  return failure();
+}
+
 /// Look for "test.reify_bound" ops in the input and replace their results with
 /// the reified values.
 static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
@@ -67,6 +78,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
         return WalkResult::skip();
       }
 
+      // Get bound type.
+      std::string boundTypeStr = "EQ";
+      if (auto boundTypeAttr = op->getAttrOfType<StringAttr>("type"))
+        boundTypeStr = boundTypeAttr.str();
+      auto boundType = parseBoundType(boundTypeStr);
+      if (failed(boundType)) {
+        op->emitOpError("invalid op");
+        return WalkResult::interrupt();
+      }
+
+      // Get shape dimension (if any).
       auto dim = value.getType().isIndex()
                      ? std::nullopt
                      : std::make_optional<int64_t>(
@@ -77,8 +99,8 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
       FailureOr<OpFoldResult> reified;
       if (!reifyToFuncArgs) {
         // Reify in terms of the op's operands.
-        reified = reifyValueBound(rewriter, op->getLoc(),
-                                  presburger::BoundType::EQ, value, dim);
+        reified =
+            reifyValueBound(rewriter, op->getLoc(), *boundType, value, dim);
       } else {
         // Reify in terms of function block arguments.
         auto stopCondition = [](Value v) {
@@ -88,9 +110,8 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
           return isa<FunctionOpInterface>(
               bbArg.getParentBlock()->getParentOp());
         };
-        reified =
-            reifyValueBound(rewriter, op->getLoc(), presburger::BoundType::EQ,
-                            value, dim, stopCondition);
+        reified = reifyValueBound(rewriter, op->getLoc(), *boundType, value,
+                                  dim, stopCondition);
       }
       if (failed(reified)) {
         op->emitOpError("could not reify bound");