[mlir][Arithmetic] Support commutative canonicalization for continuous XOrIOp.
authorjacquesguan <Jianjian.Guan@streamcomputing.com>
Tue, 20 Sep 2022 07:22:15 +0000 (15:22 +0800)
committerjacquesguan <Jianjian.Guan@streamcomputing.com>
Mon, 26 Sep 2022 07:01:24 +0000 (15:01 +0800)
This patch adds commutative canonicalization support for D116383.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D134258

mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/test/Dialect/Arithmetic/canonicalize.mlir

index bd69347..1891ce8 100644 (file)
@@ -627,9 +627,21 @@ OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
   if (getLhs() == getRhs())
     return Builder(getContext()).getZeroAttr(getType());
   /// xor(xor(x, a), a) -> x
-  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
+  /// xor(xor(a, x), a) -> x
+  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
     if (prev.getRhs() == getRhs())
       return prev.getLhs();
+    if (prev.getLhs() == getRhs())
+      return prev.getRhs();
+  }
+  /// xor(a, xor(x, a)) -> x
+  /// xor(a, xor(a, x)) -> x
+  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
+    if (prev.getRhs() == getLhs())
+      return prev.getLhs();
+    if (prev.getLhs() == getLhs())
+      return prev.getRhs();
+  }
 
   return constFoldBinaryOp<IntegerAttr>(
       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
index 649da01..632e7af 100644 (file)
@@ -1585,3 +1585,51 @@ func.func @test_andi_not_fold_lhs(%arg0 : index) -> index {
     %2 = arith.andi %1, %arg0 : index
     return %2 : index
 }
+
+// -----
+/// xor(xor(x, a), a) -> x
+
+// CHECK-LABEL: @xorxor0(
+//       CHECK-NOT: xori
+//       CHECK:   return %arg0
+func.func @xorxor0(%a : i32, %b : i32) -> i32 {
+  %c = arith.xori %a, %b : i32
+  %res = arith.xori %c, %b : i32
+  return %res : i32
+}
+
+// -----
+/// xor(xor(a, x), a) -> x
+
+// CHECK-LABEL: @xorxor1(
+//       CHECK-NOT: xori
+//       CHECK:   return %arg0
+func.func @xorxor1(%a : i32, %b : i32) -> i32 {
+  %c = arith.xori %b, %a : i32
+  %res = arith.xori %c, %b : i32
+  return %res : i32
+}
+
+// -----
+/// xor(a, xor(x, a)) -> x
+
+// CHECK-LABEL: @xorxor2(
+//       CHECK-NOT: xori
+//       CHECK:   return %arg0
+func.func @xorxor2(%a : i32, %b : i32) -> i32 {
+  %c = arith.xori %a, %b : i32
+  %res = arith.xori %b, %c : i32
+  return %res : i32
+}
+
+// -----
+/// xor(a, xor(a, x)) -> x
+
+// CHECK-LABEL: @xorxor3(
+//       CHECK-NOT: xori
+//       CHECK:   return %arg0
+func.func @xorxor3(%a : i32, %b : i32) -> i32 {
+  %c = arith.xori %b, %a : i32
+  %res = arith.xori %b, %c : i32
+  return %res : i32
+}