if (getLhs() == getRhs())
return Builder(getContext()).getZeroAttr(getType());
/// xor(xor(x, a), a) -> x
- if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
+ /// xor(xor(a, x), a) -> x
+ if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
if (prev.getRhs() == getRhs())
return prev.getLhs();
+ if (prev.getLhs() == getRhs())
+ return prev.getRhs();
+ }
+ /// xor(a, xor(x, a)) -> x
+ /// xor(a, xor(a, x)) -> x
+ if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
+ if (prev.getRhs() == getLhs())
+ return prev.getLhs();
+ if (prev.getLhs() == getLhs())
+ return prev.getRhs();
+ }
return constFoldBinaryOp<IntegerAttr>(
operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
%2 = arith.andi %1, %arg0 : index
return %2 : index
}
+
+// -----
+/// xor(xor(x, a), a) -> x
+
+// CHECK-LABEL: @xorxor0(
+// CHECK-NOT: xori
+// CHECK: return %arg0
+func.func @xorxor0(%a : i32, %b : i32) -> i32 {
+ %c = arith.xori %a, %b : i32
+ %res = arith.xori %c, %b : i32
+ return %res : i32
+}
+
+// -----
+/// xor(xor(a, x), a) -> x
+
+// CHECK-LABEL: @xorxor1(
+// CHECK-NOT: xori
+// CHECK: return %arg0
+func.func @xorxor1(%a : i32, %b : i32) -> i32 {
+ %c = arith.xori %b, %a : i32
+ %res = arith.xori %c, %b : i32
+ return %res : i32
+}
+
+// -----
+/// xor(a, xor(x, a)) -> x
+
+// CHECK-LABEL: @xorxor2(
+// CHECK-NOT: xori
+// CHECK: return %arg0
+func.func @xorxor2(%a : i32, %b : i32) -> i32 {
+ %c = arith.xori %a, %b : i32
+ %res = arith.xori %b, %c : i32
+ return %res : i32
+}
+
+// -----
+/// xor(a, xor(a, x)) -> x
+
+// CHECK-LABEL: @xorxor3(
+// CHECK-NOT: xori
+// CHECK: return %arg0
+func.func @xorxor3(%a : i32, %b : i32) -> i32 {
+ %c = arith.xori %b, %a : i32
+ %res = arith.xori %b, %c : i32
+ return %res : i32
+}