[mlir][Index][NFC] Migrate index dialect to the new fold API
authorMarkus Böck <markus.boeck02@gmail.com>
Tue, 10 Jan 2023 19:05:49 +0000 (20:05 +0100)
committerMarkus Böck <markus.boeck02@gmail.com>
Wed, 11 Jan 2023 20:47:25 +0000 (21:47 +0100)
See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context

Similar to the patch for the arith dialect, the index dialects fold implementations make heavy use of generic fold functions, hence the change being comparatively mechanical and mostly changing the function signature.

Differential Revision: https://reviews.llvm.org/D141502

mlir/include/mlir/Dialect/Index/IR/IndexDialect.td
mlir/lib/Dialect/Index/IR/IndexOps.cpp

index be0fea7..7e1130c 100644 (file)
@@ -83,6 +83,7 @@ def IndexDialect : Dialect {
 
   let hasConstantMaterializer = 1;
   let useDefaultAttributePrinterParser = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 #endif // INDEX_DIALECT
index dee6025..598bb9a 100644 (file)
@@ -115,36 +115,40 @@ static OpFoldResult foldBinaryOpChecked(
 // AddOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // SubOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // DivSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // Don't fold division by zero.
         if (rhs.isZero())
           return std::nullopt;
@@ -156,9 +160,10 @@ OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
 // DivUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult DivUOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // Don't fold division by zero.
         if (rhs.isZero())
           return std::nullopt;
@@ -193,18 +198,19 @@ static Optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
   return (n + x).sdiv(m) + 1;
 }
 
-OpFoldResult CeilDivSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, calculateCeilDivS);
+OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS);
 }
 
 //===----------------------------------------------------------------------===//
 // CeilDivUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult CeilDivUOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
   // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
   return foldBinaryOpChecked(
-      operands, [](const APInt &n, const APInt &m) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &n, const APInt &m) -> Optional<APInt> {
         // Don't fold division by zero.
         if (m.isZero())
           return std::nullopt;
@@ -242,56 +248,58 @@ static Optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
   return -1 - (x - n).sdiv(m);
 }
 
-OpFoldResult FloorDivSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, calculateFloorDivS);
+OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS);
 }
 
 //===----------------------------------------------------------------------===//
 // RemSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult RemSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.srem(rhs);
-  });
+OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs.srem(rhs); });
 }
 
 //===----------------------------------------------------------------------===//
 // RemUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult RemUOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.urem(rhs);
-  });
+OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs.urem(rhs); });
 }
 
 //===----------------------------------------------------------------------===//
 // MaxSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MaxSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.sgt(rhs) ? lhs : rhs;
-  });
+OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(),
+                             [](const APInt &lhs, const APInt &rhs) {
+                               return lhs.sgt(rhs) ? lhs : rhs;
+                             });
 }
 
 //===----------------------------------------------------------------------===//
 // MaxUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MaxUOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.ugt(rhs) ? lhs : rhs;
-  });
+OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(),
+                             [](const APInt &lhs, const APInt &rhs) {
+                               return lhs.ugt(rhs) ? lhs : rhs;
+                             });
 }
 
 //===----------------------------------------------------------------------===//
 // MinSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MinSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
+OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
     return lhs.slt(rhs) ? lhs : rhs;
   });
 }
@@ -300,8 +308,8 @@ OpFoldResult MinSOp::fold(ArrayRef<Attribute> operands) {
 // MinUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MinUOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
+OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
     return lhs.ult(rhs) ? lhs : rhs;
   });
 }
@@ -310,9 +318,10 @@ OpFoldResult MinUOp::fold(ArrayRef<Attribute> operands) {
 // ShlOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // We cannot fold if the RHS is greater than or equal to 32 because
         // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
         // already treated as unsigned.
@@ -326,9 +335,10 @@ OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
 // ShrSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // Don't fold if RHS is greater than or equal to 32.
         if (rhs.uge(32))
           return {};
@@ -340,9 +350,10 @@ OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
 // ShrUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // Don't fold if RHS is greater than or equal to 32.
         if (rhs.uge(32))
           return {};
@@ -354,27 +365,30 @@ OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
 // AndOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // OrOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // XOrOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -425,10 +439,9 @@ bool compareIndices(const APInt &lhs, const APInt &rhs,
   llvm_unreachable("unhandled IndexCmpPredicate predicate");
 }
 
-OpFoldResult CmpOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "compare expected 2 operands");
-  auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
-  auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
+OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
+  auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
+  auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
   if (!lhs || !rhs)
     return {};
 
@@ -453,9 +466,7 @@ void ConstantOp::getAsmResultNames(
   setNameFn(getResult(), specialName.str());
 }
 
-OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
-  return getValueAttr();
-}
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
 
 void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
   build(b, state, b.getIndexType(), b.getIndexAttr(value));
@@ -465,7 +476,7 @@ void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
 // BoolConstantOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
   return getValueAttr();
 }