[mlir][Arith] Fold integer shift op with zero.
authorjacquesguan <Jianjian.Guan@streamcomputing.com>
Thu, 29 Dec 2022 02:51:05 +0000 (10:51 +0800)
committerjacquesguan <Jianjian.Guan@streamcomputing.com>
Fri, 30 Dec 2022 09:19:23 +0000 (17:19 +0800)
This revision folds arith.shrui, arith.shrsi and arith.shli with zero
rhs to lhs.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D140749

mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir

index 844905c..f812b3c 100644 (file)
@@ -2221,6 +2221,9 @@ LogicalResult arith::SelectOp::verify() {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
+  // shli(x, 0) -> x
+  if (matchPattern(getRhs(), m_Zero()))
+    return getLhs();
   // Don't fold if shifting more than the bit width.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
@@ -2236,6 +2239,9 @@ OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
+  // shrui(x, 0) -> x
+  if (matchPattern(getRhs(), m_Zero()))
+    return getLhs();
   // Don't fold if shifting more than the bit width.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
@@ -2251,6 +2257,9 @@ OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
+  // shrsi(x, 0) -> x
+  if (matchPattern(getRhs(), m_Zero()))
+    return getLhs();
   // Don't fold if shifting more than the bit width.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
index d181ae9..02cbaa2 100644 (file)
@@ -2107,3 +2107,30 @@ func.func @wideMulToMulSIExtendedBadShift2(%a: i32, %b: i32) -> i32 {
   %hi = arith.trunci %sh: i64 to i32
   return %hi : i32
 }
+
+// CHECK-LABEL: @foldShli0
+// CHECK-SAME: (%[[ARG:.*]]: i64)
+//       CHECK:   return %[[ARG]] : i64
+func.func @foldShli0(%x : i64) -> i64 {
+  %c0 = arith.constant 0 : i64
+  %r = arith.shli %x, %c0 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @foldShrui0
+// CHECK-SAME: (%[[ARG:.*]]: i64)
+//       CHECK:   return %[[ARG]] : i64
+func.func @foldShrui0(%x : i64) -> i64 {
+  %c0 = arith.constant 0 : i64
+  %r = arith.shrui %x, %c0 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @foldShrsi0
+// CHECK-SAME: (%[[ARG:.*]]: i64)
+//       CHECK:   return %[[ARG]] : i64
+func.func @foldShrsi0(%x : i64) -> i64 {
+  %c0 = arith.constant 0 : i64
+  %r = arith.shrsi %x, %c0 : i64
+  return %r : i64
+}