From: jacquesguan Date: Tue, 20 Sep 2022 07:22:15 +0000 (+0800) Subject: [mlir][Arithmetic] Support commutative canonicalization for continuous XOrIOp. X-Git-Tag: upstream/17.0.6~32531 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9c8abbfa0aabb390e7ae79c5309e499fa43899e4;p=platform%2Fupstream%2Fllvm.git [mlir][Arithmetic] Support commutative canonicalization for continuous XOrIOp. This patch adds commutative canonicalization support for D116383. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D134258 --- diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp index bd69347..1891ce8 100644 --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -627,9 +627,21 @@ OpFoldResult arith::XOrIOp::fold(ArrayRef operands) { if (getLhs() == getRhs()) return Builder(getContext()).getZeroAttr(getType()); /// xor(xor(x, a), a) -> x - if (arith::XOrIOp prev = getLhs().getDefiningOp()) + /// xor(xor(a, x), a) -> x + if (arith::XOrIOp prev = getLhs().getDefiningOp()) { if (prev.getRhs() == getRhs()) return prev.getLhs(); + if (prev.getLhs() == getRhs()) + return prev.getRhs(); + } + /// xor(a, xor(x, a)) -> x + /// xor(a, xor(a, x)) -> x + if (arith::XOrIOp prev = getRhs().getDefiningOp()) { + if (prev.getRhs() == getLhs()) + return prev.getLhs(); + if (prev.getLhs() == getLhs()) + return prev.getRhs(); + } return constFoldBinaryOp( operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir index 649da01..632e7af 100644 --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -1585,3 +1585,51 @@ func.func @test_andi_not_fold_lhs(%arg0 : index) -> index { %2 = arith.andi %1, %arg0 : index return %2 : index } + +// ----- +/// xor(xor(x, a), a) -> x + +// CHECK-LABEL: @xorxor0( +// CHECK-NOT: xori +// CHECK: return %arg0 +func.func @xorxor0(%a : i32, %b : i32) -> i32 { + %c = arith.xori %a, %b : i32 + %res = arith.xori %c, %b : i32 + return %res : i32 +} + +// ----- +/// xor(xor(a, x), a) -> x + +// CHECK-LABEL: @xorxor1( +// CHECK-NOT: xori +// CHECK: return %arg0 +func.func @xorxor1(%a : i32, %b : i32) -> i32 { + %c = arith.xori %b, %a : i32 + %res = arith.xori %c, %b : i32 + return %res : i32 +} + +// ----- +/// xor(a, xor(x, a)) -> x + +// CHECK-LABEL: @xorxor2( +// CHECK-NOT: xori +// CHECK: return %arg0 +func.func @xorxor2(%a : i32, %b : i32) -> i32 { + %c = arith.xori %a, %b : i32 + %res = arith.xori %b, %c : i32 + return %res : i32 +} + +// ----- +/// xor(a, xor(a, x)) -> x + +// CHECK-LABEL: @xorxor3( +// CHECK-NOT: xori +// CHECK: return %arg0 +func.func @xorxor3(%a : i32, %b : i32) -> i32 { + %c = arith.xori %b, %a : i32 + %res = arith.xori %b, %c : i32 + return %res : i32 +}