[mlir][SCF] ValueBoundsOpInterface: Support `scf.for` results and iter_args
authorMatthias Springer <springerm@google.com>
Fri, 7 Apr 2023 02:43:41 +0000 (11:43 +0900)
committerMatthias Springer <springerm@google.com>
Fri, 7 Apr 2023 02:57:44 +0000 (11:57 +0900)
If an `scf.for` loop yields an equal index-typed value or a shaped value with the same dimension sizes (in comparison to the corresponding iter_arg), bounds can be computed for the iter_arg and the OpResult of the `scf.for` op.

Differential Revision: https://reviews.llvm.org/D146306

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
mlir/test/lib/Dialect/Affine/CMakeLists.txt
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

index a4a7c98..694c7b0 100644 (file)
@@ -84,7 +84,11 @@ public:
   /// The stop condition when traversing the backward slice of a shaped value/
   /// index-type value. The traversal continues until the stop condition
   /// evaluates to "true" for a value.
-  using StopConditionFn = function_ref<bool(Value)>;
+  ///
+  /// The first parameter of the function is the shaped value/index-typed
+  /// value. The second parameter is the dimension in case of a shaped value.
+  using StopConditionFn =
+      function_ref<bool(Value, std::optional<int64_t> /*dim*/)>;
 
   /// Compute a bound for the given index-typed value or shape dimension size.
   /// The computed bound is stored in `resultMap`. The operands of the bound are
@@ -92,16 +96,26 @@ public:
   /// or a shaped value and a dimension.
   ///
   /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
-  /// is computed in terms of values for which `stopCondition` evaluates to
-  /// "true". To that end, the backward slice (reverse use-def chain) of the
-  /// given value is visited in a worklist-driven manner and the constraint set
-  /// is populated according to `ValueBoundsOpInterface` for each visited value.
+  /// is computed in terms of values/dimensions for which `stopCondition`
+  /// evaluates to "true". To that end, the backward slice (reverse use-def
+  /// chain) of the given value is visited in a worklist-driven manner and the
+  /// constraint set is populated according to `ValueBoundsOpInterface` for each
+  /// visited value.
   static LogicalResult computeBound(AffineMap &resultMap,
                                     ValueDimList &mapOperands,
                                     presburger::BoundType type, Value value,
                                     std::optional<int64_t> dim,
                                     StopConditionFn stopCondition);
 
+  /// Compute a bound in terms of the values/dimensions in `dependencies`. The
+  /// computed bound consists of only constant terms and dependent values (or
+  /// dimension sizes thereof).
+  static LogicalResult computeBound(AffineMap &resultMap,
+                                    ValueDimList &mapOperands,
+                                    presburger::BoundType type, Value value,
+                                    std::optional<int64_t> dim,
+                                    ValueDimList dependencies);
+
   /// Compute a constant bound for the given index-typed value or shape
   /// dimension size.
   ///
index 80b4daa..3e93bf0 100644 (file)
@@ -21,19 +21,19 @@ FailureOr<OpFoldResult> mlir::reifyValueBound(OpBuilder &b, Location loc,
                                               std::optional<int64_t> dim) {
   // We are trying to reify a bound for `value`. Construct a stop condition that
   // evaluates to "true" for any SSA value expect for `value`. I.e., the bound
-  // will be computed in terms of any SSA values expect for `value`. The first
+  // will be computed in terms of any SSA values except for `value`. The first
   // such values are operands of the owner of `value`.
-  auto stopCondition = [&](Value v) {
+  auto stopCondition = [&](Value v, std::optional<int64_t> d) {
     // Reify in terms of SSA values that are different from `value`.
     return v != value;
   };
   return reifyValueBound(b, loc, type, value, dim, stopCondition);
 }
 
-FailureOr<OpFoldResult>
-mlir::reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
-                      Value value, std::optional<int64_t> dim,
-                      function_ref<bool(Value)> stopCondition) {
+FailureOr<OpFoldResult> mlir::reifyValueBound(
+    OpBuilder &b, Location loc, presburger::BoundType type, Value value,
+    std::optional<int64_t> dim,
+    function_ref<bool(Value, std::optional<int64_t>)> stopCondition) {
   // Compute bound.
   AffineMap boundMap;
   ValueDimList mapOperands;
index c83457b..668629a 100644 (file)
@@ -12,6 +12,7 @@
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 using namespace mlir;
+using presburger::BoundType;
 
 namespace mlir {
 namespace scf {
@@ -19,22 +20,97 @@ namespace {
 
 struct ForOpInterface
     : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
+
+  /// Populate bounds of values/dimensions for iter_args/OpResults.
+  static void populateIterArgBounds(scf::ForOp forOp, Value value,
+                                    std::optional<int64_t> dim,
+                                    ValueBoundsConstraintSet &cstr) {
+    // `value` is an iter_arg or an OpResult.
+    int64_t iterArgIdx;
+    if (auto iterArg = value.dyn_cast<BlockArgument>()) {
+      iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
+    } else {
+      iterArgIdx = value.cast<OpResult>().getResultNumber();
+    }
+
+    // An EQ constraint can be added if the yielded value (dimension size)
+    // equals the corresponding block argument (dimension size).
+    assert(forOp.getLoopBody().hasOneBlock() &&
+           "multiple blocks not supported");
+    Value yieldedValue =
+        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator())
+            .getOperand(iterArgIdx);
+    Value iterArg = forOp.getRegionIterArg(iterArgIdx);
+    Value initArg = forOp.getInitArgs()[iterArgIdx];
+
+    auto addEqBound = [&]() {
+      if (dim.has_value()) {
+        cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
+      } else {
+        cstr.bound(value) == initArg;
+      }
+    };
+
+    if (yieldedValue == iterArg) {
+      addEqBound();
+      return;
+    }
+
+    // Compute EQ bound for yielded value.
+    AffineMap bound;
+    ValueDimList boundOperands;
+    LogicalResult status = ValueBoundsConstraintSet::computeBound(
+        bound, boundOperands, BoundType::EQ, yieldedValue, dim,
+        [&](Value v, std::optional<int64_t> d) {
+          // Stop when reaching a block argument of the loop body.
+          if (auto bbArg = v.dyn_cast<BlockArgument>())
+            return bbArg.getOwner()->getParentOp() == forOp;
+          // Stop when reaching a value that is defined outside of the loop. It
+          // is impossible to reach an iter_arg from there.
+          Operation *op = v.getDefiningOp();
+          return forOp.getLoopBody().findAncestorOpInRegion(*op) == nullptr;
+        });
+    if (failed(status))
+      return;
+    if (bound.getNumResults() != 1)
+      return;
+
+    // Check if computed bound equals the corresponding iter_arg.
+    Value singleValue = nullptr;
+    std::optional<int64_t> singleDim = std::nullopt;
+    if (auto dimExpr = bound.getResult(0).dyn_cast<AffineDimExpr>()) {
+      int64_t idx = dimExpr.getPosition();
+      singleValue = boundOperands[idx].first;
+      singleDim = boundOperands[idx].second;
+    } else if (auto symExpr = bound.getResult(0).dyn_cast<AffineSymbolExpr>()) {
+      int64_t idx = symExpr.getPosition() + bound.getNumDims();
+      singleValue = boundOperands[idx].first;
+      singleDim = boundOperands[idx].second;
+    }
+    if (singleValue == iterArg && singleDim == dim)
+      addEqBound();
+  }
+
   void populateBoundsForIndexValue(Operation *op, Value value,
                                    ValueBoundsConstraintSet &cstr) const {
     auto forOp = cast<ForOp>(op);
-    // Only IV is supported at the moment.
-    if (value != forOp.getInductionVar())
+
+    if (value == forOp.getInductionVar()) {
+      // TODO: Take into account step size.
+      cstr.bound(value) >= forOp.getLowerBound();
+      cstr.bound(value) < forOp.getUpperBound();
       return;
+    }
 
-    // TODO: Take into account step size.
-    cstr.bound(value) >= forOp.getLowerBound();
-    cstr.bound(value) < forOp.getUpperBound();
+    // Handle iter_args and OpResults.
+    populateIterArgBounds(forOp, value, std::nullopt, cstr);
   }
 
   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                        ValueBoundsConstraintSet &cstr) const {
-    // iter_arg / return value not supported.
-    return;
+    auto forOp = cast<ForOp>(op);
+    // Handle iter_args and OpResults.
+    populateIterArgBounds(forOp, value, dim, cstr);
   }
 };
 
index a2885e2..45193fd 100644 (file)
@@ -164,7 +164,8 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
     }
 
     // Do not process any further if the stop condition is met.
-    if (stopCondition(value))
+    auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
+    if (stopCondition(value, maybeDim))
       continue;
 
     // Query `ValueBoundsOpInterface` for constraints. New items may be added to
@@ -213,12 +214,14 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
     Value value, std::optional<int64_t> dim, StopConditionFn stopCondition) {
 #ifndef NDEBUG
   assertValidValueDim(value, dim);
+  assert(!stopCondition(value, dim) &&
+         "stop condition should not be satisfied for starting point");
 #endif // NDEBUG
 
   Builder b(value.getContext());
   mapOperands.clear();
 
-  if (stopCondition(value)) {
+  if (stopCondition(value, dim)) {
     // Special case: If the stop condition is satisfied for the input
     // value/dimension, directly return it.
     mapOperands.push_back(std::make_pair(value, dim));
@@ -239,7 +242,9 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
     // Do not project out `valueDim`.
     if (valueDim == p)
       return false;
-    return !stopCondition(p.first);
+    auto maybeDim =
+        p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
+    return !stopCondition(p.first, maybeDim);
   });
 
   // Compute lower and upper bounds for `valueDim`.
@@ -338,6 +343,16 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
   return success();
 }
 
+LogicalResult ValueBoundsConstraintSet::computeBound(
+    AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
+    Value value, std::optional<int64_t> dim, ValueDimList dependencies) {
+  return computeBound(resultMap, mapOperands, type, value, dim,
+                      [&](Value v, std::optional<int64_t> d) {
+                        return llvm::is_contained(dependencies,
+                                                  std::make_pair(v, d));
+                      });
+}
+
 FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
     presburger::BoundType type, Value value, std::optional<int64_t> dim,
     StopConditionFn stopCondition) {
@@ -354,9 +369,10 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
   } else {
     // No stop condition specified: Keep adding constraints until a bound could
     // be computed.
-    cstr.processWorklist(/*stopCondition=*/[&](Value v) {
-      return cstr.cstr.getConstantBound64(type, pos).has_value();
-    });
+    cstr.processWorklist(
+        /*stopCondition=*/[&](Value v, std::optional<int64_t> dim) {
+          return cstr.cstr.getConstantBound64(type, pos).has_value();
+        });
   }
 
   // Compute constant bound for `valueDim`.
index 18f1904..e4d7141 100644 (file)
@@ -12,3 +12,95 @@ func.func @scf_for(%a: index, %b: index, %c: index) {
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_for_index_result_small(
+//  CHECK-SAME:     %[[i:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+//       CHECK:   "test.some_use"(%[[i]])
+//       CHECK:   "test.some_use"(%[[i]])
+func.func @scf_for_index_result_small(%i: index, %a: index, %b: index, %c: index) {
+  %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %i) -> index {
+    %1 = "test.reify_bound"(%arg) {type = "EQ"} : (index) -> (index)
+    "test.some_use"(%1) : (index) -> ()
+    scf.yield %arg : index
+  }
+  %2 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
+  "test.some_use"(%2) : (index) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_index_result(
+//  CHECK-SAME:     %[[i:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+//       CHECK:   "test.some_use"(%[[i]])
+//       CHECK:   "test.some_use"(%[[i]])
+func.func @scf_for_index_result(%i: index, %a: index, %b: index, %c: index) {
+  %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %i) -> index {
+    %add = arith.addi %arg, %a : index
+    %sub = arith.subi %add, %a : index
+
+    %1 = "test.reify_bound"(%arg) {type = "EQ"} : (index) -> (index)
+    "test.some_use"(%1) : (index) -> ()
+    scf.yield %sub : index
+  }
+  %2 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
+  "test.some_use"(%2) : (index) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_tensor_result_small(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]]
+//       CHECK:   "test.some_use"(%[[dim]])
+//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]]
+//       CHECK:   "test.some_use"(%[[dim]])
+func.func @scf_for_tensor_result_small(%t: tensor<?xf32>, %a: index, %b: index, %c: index) {
+  %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %t) -> tensor<?xf32> {
+    %1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+    "test.some_use"(%1) : (index) -> ()
+    scf.yield %arg : tensor<?xf32>
+  }
+  %2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+  "test.some_use"(%2) : (index) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_tensor_result(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]]
+//       CHECK:   "test.some_use"(%[[dim]])
+//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]]
+//       CHECK:   "test.some_use"(%[[dim]])
+func.func @scf_for_tensor_result(%t: tensor<?xf32>, %a: index, %b: index, %c: index) {
+  %cst = arith.constant 5.0 : f32
+  %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %t) -> tensor<?xf32> {
+    %filled = linalg.fill ins(%cst : f32) outs(%arg : tensor<?xf32>) -> tensor<?xf32>
+    %1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+    "test.some_use"(%1) : (index) -> ()
+    scf.yield %filled : tensor<?xf32>
+  }
+  %2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+  "test.some_use"(%2) : (index) -> ()
+  return
+}
+
+// -----
+
+func.func @scf_for_swapping_yield(%t1: tensor<?xf32>, %t2: tensor<?xf32>, %a: index, %b: index, %c: index) {
+  %cst = arith.constant 5.0 : f32
+  %r1, %r2 = scf.for %iv = %a to %b step %c iter_args(%arg1 = %t1, %arg2 = %t2) -> (tensor<?xf32>, tensor<?xf32>) {
+    %filled1 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+    %filled2 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+    scf.yield %filled2, %filled1 : tensor<?xf32>, tensor<?xf32>
+  }
+  // expected-error @below{{could not reify bound}}
+  %reify1 = "test.reify_bound"(%r1) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+  "test.some_use"(%reify1) : (index) -> ()
+  return
+}
index d76884d..a6e73e0 100644 (file)
@@ -25,5 +25,7 @@ add_mlir_library(MLIRAffineTransformsTestPasses
   MLIRIR
   MLIRPass
   MLIRSupport
+  MLIRMemRefDialect
+  MLIRTensorDialect
   MLIRVectorUtils
   )
index 7f66db3..922c657 100644 (file)
@@ -9,6 +9,8 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Pass/Pass.h"
@@ -33,7 +35,8 @@ struct TestReifyValueBounds
   TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){};
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect>();
+    registry
+        .insert<AffineDialect, tensor::TensorDialect, memref::MemRefDialect>();
   }
 
   void runOnOperation() override;
@@ -101,13 +104,14 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
 
       // Prepare stop condition. By default, reify in terms of the op's
       // operands. No stop condition is used when a constant was requested.
-      std::function<bool(Value)> stopCondition = [&](Value v) {
-        // Reify in terms of SSA values that are different from `value`.
-        return v != value;
-      };
+      std::function<bool(Value, std::optional<int64_t>)> stopCondition =
+          [&](Value v, std::optional<int64_t> d) {
+            // Reify in terms of SSA values that are different from `value`.
+            return v != value;
+          };
       if (reifyToFuncArgs) {
         // Reify in terms of function block arguments.
-        stopCondition = stopCondition = [](Value v) {
+        stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
           auto bbArg = v.dyn_cast<BlockArgument>();
           if (!bbArg)
             return false;
index 28a1a19..8f4639d 100644 (file)
@@ -558,6 +558,7 @@ cc_library(
         "//mlir:Pass",
         "//mlir:SCFDialect",
         "//mlir:Support",
+        "//mlir:TensorDialect",
         "//mlir:Transforms",
         "//mlir:ValueBoundsOpInterface",
         "//mlir:VectorDialect",