[mlir][Arithmetic] Use matchPattern to simplify code.
authorjacquesguan <Jianjian.Guan@streamcomputing.com>
Fri, 22 Apr 2022 07:46:01 +0000 (07:46 +0000)
committerjacquesguan <Jianjian.Guan@streamcomputing.com>
Fri, 22 Apr 2022 08:42:51 +0000 (08:42 +0000)
This patch replaces some code with matchPattern and move them before the constant folder function in order to avoid redundant invoking.

Differential Revision: https://reviews.llvm.org/D124235

mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp

index 1fa4b1b..5a10440 100644 (file)
@@ -262,6 +262,10 @@ OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
+  // divui (x, 1) -> x.
+  if (matchPattern(getRhs(), m_One()))
+    return getLhs();
+
   // Don't fold if it would require a division by zero.
   bool div0 = false;
   auto result =
@@ -273,15 +277,6 @@ OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
         return a.udiv(b);
       });
 
-  // Fold out division by one. Assumes all tensors of all ones are splats.
-  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
-    if (rhs.getValue() == 1)
-      return getLhs();
-  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
-    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
-      return getLhs();
-  }
-
   return div0 ? Attribute() : result;
 }
 
@@ -290,6 +285,10 @@ OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
+  // divsi (x, 1) -> x.
+  if (matchPattern(getRhs(), m_One()))
+    return getLhs();
+
   // Don't fold if it would overflow or if it requires a division by zero.
   bool overflowOrDiv0 = false;
   auto result =
@@ -301,15 +300,6 @@ OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
         return a.sdiv_ov(b, overflowOrDiv0);
       });
 
-  // Fold out division by one. Assumes all tensors of all ones are splats.
-  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
-    if (rhs.getValue() == 1)
-      return getLhs();
-  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
-    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
-      return getLhs();
-  }
-
   return overflowOrDiv0 ? Attribute() : result;
 }
 
@@ -330,6 +320,10 @@ static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
+  // ceildivui (x, 1) -> x.
+  if (matchPattern(getRhs(), m_One()))
+    return getLhs();
+
   bool overflowOrDiv0 = false;
   auto result =
       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
@@ -343,15 +337,6 @@ OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
         APInt one(a.getBitWidth(), 1, true);
         return quotient.uadd_ov(one, overflowOrDiv0);
       });
-  // Fold out ceil division by one. Assumes all tensors of all ones are
-  // splats.
-  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
-    if (rhs.getValue() == 1)
-      return getLhs();
-  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
-    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
-      return getLhs();
-  }
 
   return overflowOrDiv0 ? Attribute() : result;
 }
@@ -361,6 +346,10 @@ OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
+  // ceildivsi (x, 1) -> x.
+  if (matchPattern(getRhs(), m_One()))
+    return getLhs();
+
   // Don't fold if it would overflow or if it requires a division by zero.
   bool overflowOrDiv0 = false;
   auto result =
@@ -398,16 +387,6 @@ OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
         return zero.ssub_ov(div, overflowOrDiv0);
       });
 
-  // Fold out ceil division by one. Assumes all tensors of all ones are
-  // splats.
-  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
-    if (rhs.getValue() == 1)
-      return getLhs();
-  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
-    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
-      return getLhs();
-  }
-
   return overflowOrDiv0 ? Attribute() : result;
 }
 
@@ -416,6 +395,10 @@ OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
+  // floordivsi (x, 1) -> x.
+  if (matchPattern(getRhs(), m_One()))
+    return getLhs();
+
   // Don't fold if it would overflow or if it requires a division by zero.
   bool overflowOrDiv0 = false;
   auto result =
@@ -453,16 +436,6 @@ OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
         return zero.ssub_ov(ceil, overflowOrDiv0);
       });
 
-  // Fold out floor division by one. Assumes all tensors of all ones are
-  // splats.
-  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
-    if (rhs.getValue() == 1)
-      return getLhs();
-  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
-    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
-      return getLhs();
-  }
-
   return overflowOrDiv0 ? Attribute() : result;
 }