op->getName(), op->getAttrDictionary(), op->getResultTypes());
// - Operands
- for (Value operand : op->getOperands())
+ ValueRange operands = op->getOperands();
+ SmallVector<Value> operandStorage;
+ if (op->hasTrait<mlir::OpTrait::IsCommutative>()) {
+ operandStorage.append(operands.begin(), operands.end());
+ llvm::sort(operandStorage, [](Value a, Value b) -> bool {
+ return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+ });
+ operands = operandStorage;
+ }
+ for (Value operand : operands)
hash = llvm::hash_combine(hash, hashOperands(operand));
+
// - Operands
for (Value result : op->getResults())
hash = llvm::hash_combine(hash, hashResults(result));
if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
return false;
+ ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
+ SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
+ if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
+ lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end());
+ llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool {
+ return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+ });
+ lhsOperands = lhsOperandStorage;
+
+ rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end());
+ llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool {
+ return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+ });
+ rhsOperands = rhsOperandStorage;
+ }
auto checkValueRangeMapping =
[](ValueRange lhs, ValueRange rhs,
function_ref<LogicalResult(Value, Value)> mapValues) {
return true;
};
// Check mapping of operands and results.
- if (!checkValueRangeMapping(lhs->getOperands(), rhs->getOperands(),
- mapOperands))
+ if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands))
return false;
if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
return false;
%2 = arith.addi %0, %1 : i32
return %2 : i32
}
+
+/// This test is checking that identical commutative operation are gracefully
+/// handled but the CSE pass.
+// CHECK-LABEL: func @check_cummutative_cse
+func @check_cummutative_cse(%a : i32, %b : i32) -> i32 {
+ // CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
+ %1 = arith.addi %a, %b : i32
+ %2 = arith.addi %b, %a : i32
+ // CHECK-NEXT: arith.muli %[[ADD1]], %[[ADD1]] : i32
+ %3 = arith.muli %1, %2 : i32
+ return %3 : i32
+}