[mlir][math][NFC] Migrate math dialect to the new fold API
authorMarkus Böck <markus.boeck02@gmail.com>
Tue, 10 Jan 2023 18:52:56 +0000 (19:52 +0100)
committerMarkus Böck <markus.boeck02@gmail.com>
Wed, 11 Jan 2023 17:11:46 +0000 (18:11 +0100)
See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context

Similar to the patch for the arith dialect, the math dialects fold implementations make heavy use of generic fold functions, hence the change being comparatively mechanical and mostly changing the function signature.

Differential Revision: https://reviews.llvm.org/D141500

mlir/include/mlir/Dialect/Math/IR/MathBase.td
mlir/lib/Dialect/Math/IR/MathOps.cpp

index 0189fd5..e63db4c 100644 (file)
@@ -30,5 +30,6 @@ def Math_Dialect : Dialect {
     ```
   }];
   let hasConstantMaterializer = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 #endif // MATH_BASE
index 621bfa5..78186ba 100644 (file)
@@ -25,8 +25,8 @@ using namespace mlir::math;
 // AbsFOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::AbsFOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<FloatAttr>(operands,
+OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
                                      [](const APFloat &a) { return abs(a); });
 }
 
@@ -34,8 +34,8 @@ OpFoldResult math::AbsFOp::fold(ArrayRef<Attribute> operands) {
 // AbsIOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::AbsIOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<IntegerAttr>(operands,
+OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a) { return a.abs(); });
 }
 
@@ -43,9 +43,9 @@ OpFoldResult math::AbsIOp::fold(ArrayRef<Attribute> operands) {
 // AtanOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::AtanOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(atan(a.convertToDouble()));
@@ -61,9 +61,10 @@ OpFoldResult math::AtanOp::fold(ArrayRef<Attribute> operands) {
 // Atan2Op folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::Atan2Op::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
   return constFoldBinaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a, const APFloat &b) -> Optional<APFloat> {
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) -> Optional<APFloat> {
         if (a.isZero() && b.isZero())
           return llvm::APFloat::getNaN(a.getSemantics());
 
@@ -83,20 +84,21 @@ OpFoldResult math::Atan2Op::fold(ArrayRef<Attribute> operands) {
 // CeilOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
-    APFloat result(a);
-    result.roundToIntegral(llvm::RoundingMode::TowardPositive);
-    return result;
-  });
+OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<FloatAttr>(
+      adaptor.getOperands(), [](const APFloat &a) {
+        APFloat result(a);
+        result.roundToIntegral(llvm::RoundingMode::TowardPositive);
+        return result;
+      });
 }
 
 //===----------------------------------------------------------------------===//
 // CopySignOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldBinaryOp<FloatAttr>(operands,
+OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
+  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
                                       [](const APFloat &a, const APFloat &b) {
                                         APFloat result(a);
                                         result.copySign(b);
@@ -108,9 +110,9 @@ OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
 // CosOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::CosOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(cos(a.convertToDouble()));
@@ -126,9 +128,9 @@ OpFoldResult math::CosOp::fold(ArrayRef<Attribute> operands) {
 // SinOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::SinOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(sin(a.convertToDouble()));
@@ -144,39 +146,42 @@ OpFoldResult math::SinOp::fold(ArrayRef<Attribute> operands) {
 // CountLeadingZerosOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
-    return APInt(a.getBitWidth(), a.countLeadingZeros());
-  });
+OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a) {
+        return APInt(a.getBitWidth(), a.countLeadingZeros());
+      });
 }
 
 //===----------------------------------------------------------------------===//
 // CountTrailingZerosOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
-    return APInt(a.getBitWidth(), a.countTrailingZeros());
-  });
+OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a) {
+        return APInt(a.getBitWidth(), a.countTrailingZeros());
+      });
 }
 
 //===----------------------------------------------------------------------===//
 // CtPopOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
-    return APInt(a.getBitWidth(), a.countPopulation());
-  });
+OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a) {
+        return APInt(a.getBitWidth(), a.countPopulation());
+      });
 }
 
 //===----------------------------------------------------------------------===//
 // ErfOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::ErfOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(erf(a.convertToDouble()));
@@ -192,9 +197,10 @@ OpFoldResult math::ErfOp::fold(ArrayRef<Attribute> operands) {
 // IPowIOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::IPowIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) {
   return constFoldBinaryOpConditional<IntegerAttr>(
-      operands, [](const APInt &base, const APInt &power) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &base, const APInt &power) -> Optional<APInt> {
         unsigned width = base.getBitWidth();
         auto zeroValue = APInt::getZero(width);
         APInt oneValue{width, 1ULL, /*isSigned=*/true};
@@ -242,9 +248,9 @@ OpFoldResult math::IPowIOp::fold(ArrayRef<Attribute> operands) {
 // LogOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::LogOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         if (a.isNegative())
           return {};
 
@@ -262,9 +268,9 @@ OpFoldResult math::LogOp::fold(ArrayRef<Attribute> operands) {
 // Log2Op folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         if (a.isNegative())
           return {};
 
@@ -282,9 +288,9 @@ OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
 // Log10Op folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::Log10Op::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         if (a.isNegative())
           return {};
 
@@ -303,9 +309,9 @@ OpFoldResult math::Log10Op::fold(ArrayRef<Attribute> operands) {
 // Log1pOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::Log1pOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           if ((a + APFloat(1.0)).isNegative())
@@ -325,9 +331,10 @@ OpFoldResult math::Log1pOp::fold(ArrayRef<Attribute> operands) {
 // PowFOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
   return constFoldBinaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a, const APFloat &b) -> Optional<APFloat> {
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) -> Optional<APFloat> {
         if (a.getSizeInBits(a.getSemantics()) == 64 &&
             b.getSizeInBits(b.getSemantics()) == 64)
           return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
@@ -344,9 +351,9 @@ OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
 // SqrtOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         if (a.isNegative())
           return {};
 
@@ -365,9 +372,9 @@ OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) {
 // ExpOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::ExpOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(exp(a.convertToDouble()));
@@ -383,9 +390,9 @@ OpFoldResult math::ExpOp::fold(ArrayRef<Attribute> operands) {
 // Exp2Op folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::Exp2Op::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(exp2(a.convertToDouble()));
@@ -401,9 +408,9 @@ OpFoldResult math::Exp2Op::fold(ArrayRef<Attribute> operands) {
 // ExpM1Op folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::ExpM1Op::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(expm1(a.convertToDouble()));
@@ -419,9 +426,9 @@ OpFoldResult math::ExpM1Op::fold(ArrayRef<Attribute> operands) {
 // TanOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::TanOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(tan(a.convertToDouble()));
@@ -437,9 +444,9 @@ OpFoldResult math::TanOp::fold(ArrayRef<Attribute> operands) {
 // TanhOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::TanhOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(tanh(a.convertToDouble()));
@@ -455,33 +462,35 @@ OpFoldResult math::TanhOp::fold(ArrayRef<Attribute> operands) {
 // RoundEvenOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::RoundEvenOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
-    APFloat result(a);
-    result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
-    return result;
-  });
+OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<FloatAttr>(
+      adaptor.getOperands(), [](const APFloat &a) {
+        APFloat result(a);
+        result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
+        return result;
+      });
 }
 
 //===----------------------------------------------------------------------===//
 // FloorOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::FloorOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
-    APFloat result(a);
-    result.roundToIntegral(llvm::RoundingMode::TowardNegative);
-    return result;
-  });
+OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<FloatAttr>(
+      adaptor.getOperands(), [](const APFloat &a) {
+        APFloat result(a);
+        result.roundToIntegral(llvm::RoundingMode::TowardNegative);
+        return result;
+      });
 }
 
 //===----------------------------------------------------------------------===//
 // RoundOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::RoundOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(round(a.convertToDouble()));
@@ -497,9 +506,9 @@ OpFoldResult math::RoundOp::fold(ArrayRef<Attribute> operands) {
 // TruncOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::TruncOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(trunc(a.convertToDouble()));