[mlir][Interfaces] ValueBoundsOpInterface: Check if two values are equal
authorMatthias Springer <me@m-sp.org>
Thu, 25 May 2023 17:03:17 +0000 (19:03 +0200)
committerMatthias Springer <me@m-sp.org>
Thu, 25 May 2023 17:08:48 +0000 (19:08 +0200)
Add a helper function that computes if two SSA values have the same value, utilizing the `ValueBoundsOpInterface` infrastructure. Two SSA values have the same value, an equality bound of 0 can be derived for their subtraction.

The helper function can also be used to determine if two tensor dimension sizes are equal.

Differential Revision: https://reviews.llvm.org/D151443

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp

index c4d446e..367ae28 100644 (file)
@@ -156,6 +156,35 @@ public:
                        StopConditionFn stopCondition = nullptr,
                        bool closedUB = false);
 
+  /// Compute a constant bound for the given affine map, where dims and symbols
+  /// are bound to the given operands. The affine map must have exactly one
+  /// result.
+  ///
+  /// This function traverses the backward slice of the given operands in a
+  /// worklist-driven manner until `stopCondition` evaluates to "true". The
+  /// constraint set is populated according to `ValueBoundsOpInterface` for each
+  /// visited value. (No constraints are added for values for which the stop
+  /// condition evaluates to "true".)
+  ///
+  /// The stop condition is optional: If none is specified, the backward slice
+  /// is traversed in a breadth-first manner until a constant bound could be
+  /// computed.
+  ///
+  /// By default, lower/equal bounds are closed and upper bounds are open. If
+  /// `closedUB` is set to "true", upper bounds are also closed.
+  static FailureOr<int64_t> computeConstantBound(
+      presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
+      StopConditionFn stopCondition = nullptr, bool closedUB = false);
+
+  /// Compute whether the given values/dimensions are equal. Return "failure" if
+  /// equality could not be determined.
+  ///
+  /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
+  /// index-typed.
+  static FailureOr<bool> areEqual(Value value1, Value value2,
+                                  std::optional<int64_t> dim1 = std::nullopt,
+                                  std::optional<int64_t> dim2 = std::nullopt);
+
   /// Add a bound for the given index-typed value or shaped value. This function
   /// returns a builder that adds the bound.
   BoundBuilder bound(Value value) { return BoundBuilder(*this, value); }
@@ -199,13 +228,23 @@ protected:
   int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;
 
   /// Insert a value/dimension into the constraint set. If `isSymbol` is set to
-  /// "false", a dimension is added.
+  /// "false", a dimension is added. The value/dimension is added to the
+  /// worklist.
   ///
   /// Note: There are certain affine restrictions wrt. dimensions. E.g., they
   /// cannot be multiplied. Furthermore, bounds can only be queried for
   /// dimensions but not for symbols.
   int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true);
 
+  /// Insert an anonymous column into the constraint set. The column is not
+  /// bound to any value/dimension. If `isSymbol` is set to "false", a dimension
+  /// is added.
+  ///
+  /// Note: There are certain affine restrictions wrt. dimensions. E.g., they
+  /// cannot be multiplied. Furthermore, bounds can only be queried for
+  /// dimensions but not for symbols.
+  int64_t insert(bool isSymbol = true);
+
   /// Project out the given column in the constraint set.
   void projectOut(int64_t pos);
 
@@ -213,7 +252,7 @@ protected:
   void projectOut(function_ref<bool(ValueDim)> condition);
 
   /// Mapping of columns to values/shape dimensions.
-  SmallVector<ValueDim> positionToValueDim;
+  SmallVector<std::optional<ValueDim>> positionToValueDim;
   /// Reverse mapping of values/shape dimensions to columns.
   DenseMap<ValueDim, int64_t> valueDimToPosition;
 
index 28f34a9..bc7d6b4 100644 (file)
@@ -124,12 +124,24 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
   positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
   // Update reverse mapping.
   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
-    valueDimToPosition[positionToValueDim[i]] = i;
+    if (positionToValueDim[i].has_value())
+      valueDimToPosition[*positionToValueDim[i]] = i;
 
   worklist.push(pos);
   return pos;
 }
 
+int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
+  int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
+                         : cstr.appendVar(VarKind::SetDim);
+  positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt);
+  // Update reverse mapping.
+  for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
+    if (positionToValueDim[i].has_value())
+      valueDimToPosition[*positionToValueDim[i]] = i;
+  return pos;
+}
+
 int64_t ValueBoundsConstraintSet::getPos(Value value,
                                          std::optional<int64_t> dim) const {
 #ifndef NDEBUG
@@ -155,7 +167,9 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
   while (!worklist.empty()) {
     int64_t pos = worklist.front();
     worklist.pop();
-    ValueDim valueDim = positionToValueDim[pos];
+    assert(positionToValueDim[pos].has_value() &&
+           "did not expect std::nullopt on worklist");
+    ValueDim valueDim = *positionToValueDim[pos];
     Value value = valueDim.first;
     int64_t dim = valueDim.second;
 
@@ -191,20 +205,24 @@ void ValueBoundsConstraintSet::projectOut(int64_t pos) {
   assert(pos >= 0 && pos < static_cast<int64_t>(positionToValueDim.size()) &&
          "invalid position");
   cstr.projectOut(pos);
-  bool erased = valueDimToPosition.erase(positionToValueDim[pos]);
-  (void)erased;
-  assert(erased && "inconsistent reverse mapping");
+  if (positionToValueDim[pos].has_value()) {
+    bool erased = valueDimToPosition.erase(*positionToValueDim[pos]);
+    (void)erased;
+    assert(erased && "inconsistent reverse mapping");
+  }
   positionToValueDim.erase(positionToValueDim.begin() + pos);
   // Update reverse mapping.
   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
-    valueDimToPosition[positionToValueDim[i]] = i;
+    if (positionToValueDim[i].has_value())
+      valueDimToPosition[*positionToValueDim[i]] = i;
 }
 
 void ValueBoundsConstraintSet::projectOut(
     function_ref<bool(ValueDim)> condition) {
   int64_t nextPos = 0;
   while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
-    if (condition(positionToValueDim[nextPos])) {
+    if (positionToValueDim[nextPos].has_value() &&
+        condition(*positionToValueDim[nextPos])) {
       projectOut(nextPos);
       // The column was projected out so another column is now at that position.
       // Do not increase the counter.
@@ -332,7 +350,9 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
       replacementSymbols.push_back(b.getAffineSymbolExpr(numSymbols++));
     }
 
-    ValueBoundsConstraintSet::ValueDim valueDim = cstr.positionToValueDim[i];
+    assert(cstr.positionToValueDim[i].has_value() &&
+           "cannot build affine map in terms of anonymous column");
+    ValueBoundsConstraintSet::ValueDim valueDim = *cstr.positionToValueDim[i];
     Value value = valueDim.first;
     int64_t dim = valueDim.second;
     if (dim == ValueBoundsConstraintSet::kIndexValue) {
@@ -406,10 +426,35 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
   assertValidValueDim(value, dim);
 #endif // NDEBUG
 
-  // Process the backward slice of `value` (i.e., reverse use-def chain) until
-  // `stopCondition` is met.
-  ValueBoundsConstraintSet cstr(value.getContext());
-  int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
+  AffineMap map =
+      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
+                     Builder(value.getContext()).getAffineDimExpr(0));
+  return computeConstantBound(type, map, {{value, dim}}, stopCondition,
+                              closedUB);
+}
+
+FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
+    presburger::BoundType type, AffineMap map, ValueDimList operands,
+    StopConditionFn stopCondition, bool closedUB) {
+  assert(map.getNumResults() == 1 && "expected affine map with one result");
+  ValueBoundsConstraintSet cstr(map.getContext());
+  int64_t pos = cstr.insert(/*isSymbol=*/false);
+
+  // Add map and operands to the constraint set. Dimensions are converted to
+  // symbols. All operands are added to the worklist.
+  auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
+    return cstr.getExpr(v.first, v.second);
+  };
+  SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
+      llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
+  SmallVector<AffineExpr> symReplacements = llvm::to_vector(
+      llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
+  cstr.addBound(
+      presburger::BoundType::EQ, pos,
+      map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
+
+  // Process the backward slice of `operands` (i.e., reverse use-def chain)
+  // until `stopCondition` is met.
   if (stopCondition) {
     cstr.processWorklist(stopCondition);
   } else {
@@ -428,6 +473,27 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
   return failure();
 }
 
+FailureOr<bool>
+ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
+                                   std::optional<int64_t> dim1,
+                                   std::optional<int64_t> dim2) {
+#ifndef NDEBUG
+  assertValidValueDim(value1, dim1);
+  assertValidValueDim(value2, dim2);
+#endif // NDEBUG
+
+  // Subtract the two values/dimensions from each other. If the result is 0,
+  // both are equal.
+  Builder b(value1.getContext());
+  AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
+                                 b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
+  FailureOr<int64_t> bound = computeConstantBound(
+      presburger::BoundType::EQ, map, {{value1, dim1}, {value2, dim2}});
+  if (failed(bound))
+    return failure();
+  return *bound == 0;
+}
+
 ValueBoundsConstraintSet::BoundBuilder &
 ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
   assert(!this->dim.has_value() && "dim was already set");
index 614c601..45520da 100644 (file)
@@ -156,3 +156,49 @@ func.func @rank(%t: tensor<5xf32>) -> index {
   %1 = "test.reify_bound"(%0) : (index) -> (index)
   return %1 : index
 }
+
+// -----
+
+func.func @dynamic_dims_are_equal(%t: tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %dim0 = tensor.dim %t, %c0 : tensor<?xf32>
+  %dim1 = tensor.dim %t, %c0 : tensor<?xf32>
+  // expected-remark @below {{equal}}
+  "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @dynamic_dims_are_different(%t: tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %t, %c0 : tensor<?xf32>
+  %val = arith.addi %dim0, %c1 : index
+  // expected-remark @below {{different}}
+  "test.are_equal"(%dim0, %val) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @dynamic_dims_are_maybe_equal_1(%t: tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c5 = arith.constant 5 : index
+  %dim0 = tensor.dim %t, %c0 : tensor<?xf32>
+  // expected-error @below {{could not determine equality}}
+  "test.are_equal"(%dim0, %c5) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @dynamic_dims_are_maybe_equal_2(%t: tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  // expected-error @below {{could not determine equality}}
+  "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
+  return
+}
index dff619e..db3b9a1 100644 (file)
@@ -175,10 +175,38 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
   return failure(result.wasInterrupted());
 }
 
+/// Look for "test.are_equal" ops and emit errors/remarks.
+static LogicalResult testEquality(func::FuncOp funcOp) {
+  IRRewriter rewriter(funcOp.getContext());
+  WalkResult result = funcOp.walk([&](Operation *op) {
+    // Look for test.are_equal ops.
+    if (op->getName().getStringRef() == "test.are_equal") {
+      if (op->getNumOperands() != 2 || !op->getOperand(0).getType().isIndex() ||
+          !op->getOperand(1).getType().isIndex()) {
+        op->emitOpError("invalid op");
+        return WalkResult::skip();
+      }
+      FailureOr<bool> equal = ValueBoundsConstraintSet::areEqual(
+          op->getOperand(0), op->getOperand(1));
+      if (failed(equal)) {
+        op->emitError("could not determine equality");
+      } else if (*equal) {
+        op->emitRemark("equal");
+      } else {
+        op->emitRemark("different");
+      }
+    }
+    return WalkResult::advance();
+  });
+  return failure(result.wasInterrupted());
+}
+
 void TestReifyValueBounds::runOnOperation() {
   if (failed(
           testReifyValueBounds(getOperation(), reifyToFuncArgs, useArithOps)))
     signalPassFailure();
+  if (failed(testEquality(getOperation())))
+    signalPassFailure();
 }
 
 namespace mlir {