[mlir][arith] Canonicalize `addi(x, muli(y, -1))` -> `subi(x, y)`
authorJakub Kuderski <kubak@google.com>
Tue, 7 Mar 2023 00:28:39 +0000 (19:28 -0500)
committerJakub Kuderski <kubak@google.com>
Tue, 7 Mar 2023 00:28:39 +0000 (19:28 -0500)
These propagate all the way down to SPIR-V and result in some fishy code
with large constants.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D145423

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir

index abf3db1..7c68714 100644 (file)
@@ -49,6 +49,27 @@ def AddISubConstantLHS :
           (ConstantLikeMatcher APIntAttr:$c1)),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
 
+def IsScalarOrSplatNegativeOne :
+    Constraint<And<[
+      CPred<"succeeded(getIntOrSplatIntValue($0))">,
+      CPred<"getIntOrSplatIntValue($0)->isAllOnes()">]>>;
+
+// addi(x, muli(y, -1)) -> subi(x, y)
+def AddIMulNegativeOneRhs :
+    Pat<(Arith_AddIOp
+           $x,
+           (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0))),
+        (Arith_SubIOp $x, $y),
+        [(IsScalarOrSplatNegativeOne $c0)]>;
+
+// addi(muli(x, -1), y) -> subi(y, x)
+def AddIMulNegativeOneLhs :
+    Pat<(Arith_AddIOp
+           (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0)),
+           $y),
+        (Arith_SubIOp $y, $x),
+        [(IsScalarOrSplatNegativeOne $c0)]>;
+
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
 //===----------------------------------------------------------------------===//
index f6308a6..e56f452 100644 (file)
@@ -258,8 +258,8 @@ OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
 
 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
-      context);
+  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
+               AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
 }
 
 //===----------------------------------------------------------------------===//
index eaafa9e..396f5ee 100644 (file)
@@ -735,6 +735,72 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
   return %add : index
 }
 
+// CHECK-LABEL: @addiMuliToSubiRhsI32
+//  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiRhsI32(%arg0: i32, %arg1: i32) -> i32 {
+  %c-1 = arith.constant -1 : i32
+  %neg = arith.muli %arg1, %c-1 : i32
+  %add = arith.addi %arg0, %neg : i32
+  return %add : i32
+}
+
+// CHECK-LABEL: @addiMuliToSubiRhsIndex
+//  CHECK-SAME:   (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiRhsIndex(%arg0: index, %arg1: index) -> index {
+  %c-1 = arith.constant -1 : index
+  %neg = arith.muli %arg1, %c-1 : index
+  %add = arith.addi %arg0, %neg : index
+  return %add : index
+}
+
+// CHECK-LABEL: @addiMuliToSubiRhsVector
+//  CHECK-SAME:   (%[[ARG0:.+]]: vector<3xi64>, %[[ARG1:.+]]: vector<3xi64>)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : vector<3xi64>
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiRhsVector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi64> {
+  %c-1 = arith.constant dense<-1> : vector<3xi64>
+  %neg = arith.muli %arg1, %c-1 : vector<3xi64>
+  %add = arith.addi %arg0, %neg : vector<3xi64>
+  return %add : vector<3xi64>
+}
+
+// CHECK-LABEL: @addiMuliToSubiLhsI32
+//  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiLhsI32(%arg0: i32, %arg1: i32) -> i32 {
+  %c-1 = arith.constant -1 : i32
+  %neg = arith.muli %arg1, %c-1 : i32
+  %add = arith.addi %neg, %arg0 : i32
+  return %add : i32
+}
+
+// CHECK-LABEL: @addiMuliToSubiLhsIndex
+//  CHECK-SAME:   (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiLhsIndex(%arg0: index, %arg1: index) -> index {
+  %c-1 = arith.constant -1 : index
+  %neg = arith.muli %arg1, %c-1 : index
+  %add = arith.addi %neg, %arg0 : index
+  return %add : index
+}
+
+// CHECK-LABEL: @addiMuliToSubiLhsVector
+//  CHECK-SAME:   (%[[ARG0:.+]]: vector<3xi64>, %[[ARG1:.+]]: vector<3xi64>)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : vector<3xi64>
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiLhsVector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi64> {
+  %c-1 = arith.constant dense<-1> : vector<3xi64>
+  %neg = arith.muli %arg1, %c-1 : vector<3xi64>
+  %add = arith.addi %neg, %arg0 : vector<3xi64>
+  return %add : vector<3xi64>
+}
+
 // CHECK-LABEL: @adduiExtendedZeroRhs
 //  CHECK-NEXT:   %[[false:.+]] = arith.constant false
 //  CHECK-NEXT:   return %arg0, %[[false]]