[mlir][Arithmetic] Use common constant fold function in RemSI and RemUI to cover...
authorjacquesguan <Jianjian.Guan@streamcomputing.com>
Fri, 22 Apr 2022 08:13:14 +0000 (08:13 +0000)
committerjacquesguan <Jianjian.Guan@streamcomputing.com>
Fri, 22 Apr 2022 09:20:18 +0000 (09:20 +0000)
This patch replaces current fold function with the common constant fold funtion in order to cover the situation of constant splat.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D124236

mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/test/Dialect/Arithmetic/canonicalize.mlir

index 5a10440..8f26e66 100644 (file)
@@ -444,23 +444,22 @@ OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
-  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
-  if (!rhs)
-    return {};
-  auto rhsValue = rhs.getValue();
-
-  // x % 1 = 0
-  if (rhsValue.isOneValue())
-    return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
+  // remui (x, 1) -> 0.
+  if (matchPattern(getRhs(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
 
-  // Don't fold if it requires division by zero.
-  if (rhsValue.isNullValue())
-    return {};
+  // Don't fold if it would require a division by zero.
+  bool div0 = false;
+  auto result =
+      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+        if (div0 || b.isNullValue()) {
+          div0 = true;
+          return a;
+        }
+        return a.urem(b);
+      });
 
-  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
-  if (!lhs)
-    return {};
-  return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
+  return div0 ? Attribute() : result;
 }
 
 //===----------------------------------------------------------------------===//
@@ -468,23 +467,22 @@ OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
-  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
-  if (!rhs)
-    return {};
-  auto rhsValue = rhs.getValue();
-
-  // x % 1 = 0
-  if (rhsValue.isOneValue())
-    return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
+  // remsi (x, 1) -> 0.
+  if (matchPattern(getRhs(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
 
-  // Don't fold if it requires division by zero.
-  if (rhsValue.isNullValue())
-    return {};
+  // Don't fold if it would require a division by zero.
+  bool div0 = false;
+  auto result =
+      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+        if (div0 || b.isNullValue()) {
+          div0 = true;
+          return a;
+        }
+        return a.srem(b);
+      });
 
-  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
-  if (!lhs)
-    return {};
-  return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
+  return div0 ? Attribute() : result;
 }
 
 //===----------------------------------------------------------------------===//
index 5ea88b3..1f6b473 100644 (file)
@@ -1319,3 +1319,49 @@ func.func @test_negf() -> (f32) {
   %0 = arith.negf %c : f32
   return %0: f32
 }
+
+// -----
+
+// CHECK-LABEL: @test_remui(
+// CHECK: %[[res:.+]] = arith.constant dense<[0, 0, 4, 2]> : vector<4xi32>
+// CHECK: return %[[res]]
+func @test_remui() -> (vector<4xi32>) {
+  %v1 = arith.constant dense<[9, 9, 9, 9]> : vector<4xi32>
+  %v2 = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+  %0 = arith.remui %v1, %v2 : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// // -----
+
+// CHECK-LABEL: @test_remui_1(
+// CHECK: %[[res:.+]] = arith.constant dense<0> : vector<4xi32>
+// CHECK: return %[[res]]
+func @test_remui_1(%arg : vector<4xi32>) -> (vector<4xi32>) {
+  %v = arith.constant dense<[1, 1, 1, 1]> : vector<4xi32>
+  %0 = arith.remui %arg, %v : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_remsi(
+// CHECK: %[[res:.+]] = arith.constant dense<[0, 0, 4, 2]> : vector<4xi32>
+// CHECK: return %[[res]]
+func @test_remsi() -> (vector<4xi32>) {
+  %v1 = arith.constant dense<[9, 9, 9, 9]> : vector<4xi32>
+  %v2 = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+  %0 = arith.remsi %v1, %v2 : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// // -----
+
+// CHECK-LABEL: @test_remsi_1(
+// CHECK: %[[res:.+]] = arith.constant dense<0> : vector<4xi32>
+// CHECK: return %[[res]]
+func @test_remsi_1(%arg : vector<4xi32>) -> (vector<4xi32>) {
+  %v = arith.constant dense<[1, 1, 1, 1]> : vector<4xi32>
+  %0 = arith.remsi %arg, %v : vector<4xi32>
+  return %0 : vector<4xi32>
+}