[mlir][CSE] Add ability to remove commutative operations
authorValentin Clement <clementval@gmail.com>
Sat, 16 Apr 2022 19:08:16 +0000 (21:08 +0200)
committerValentin Clement <clementval@gmail.com>
Sat, 16 Apr 2022 19:09:47 +0000 (21:09 +0200)
This patch takes advantage of the Commutative trait on operation
to remove identical commutative operations where the operands are swapped.

The second operation below can be removed since `arith.addi` is commutative.
```
%1 = arith.addi %a, %b : i32
%2 = arith.addi %b, %a : i32
```

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D123492

mlir/lib/IR/OperationSupport.cpp
mlir/test/Transforms/cse.mlir

index 012d980..aaae4c8 100644 (file)
@@ -633,8 +633,18 @@ llvm::hash_code OperationEquivalence::computeHash(
       op->getName(), op->getAttrDictionary(), op->getResultTypes());
 
   //   - Operands
-  for (Value operand : op->getOperands())
+  ValueRange operands = op->getOperands();
+  SmallVector<Value> operandStorage;
+  if (op->hasTrait<mlir::OpTrait::IsCommutative>()) {
+    operandStorage.append(operands.begin(), operands.end());
+    llvm::sort(operandStorage, [](Value a, Value b) -> bool {
+      return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+    });
+    operands = operandStorage;
+  }
+  for (Value operand : operands)
     hash = llvm::hash_combine(hash, hashOperands(operand));
+
   //   - Operands
   for (Value result : op->getResults())
     hash = llvm::hash_combine(hash, hashResults(result));
@@ -710,6 +720,21 @@ bool OperationEquivalence::isEquivalentTo(
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
 
+  ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
+  SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
+  if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
+    lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end());
+    llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool {
+      return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+    });
+    lhsOperands = lhsOperandStorage;
+
+    rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end());
+    llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool {
+      return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+    });
+    rhsOperands = rhsOperandStorage;
+  }
   auto checkValueRangeMapping =
       [](ValueRange lhs, ValueRange rhs,
          function_ref<LogicalResult(Value, Value)> mapValues) {
@@ -724,8 +749,7 @@ bool OperationEquivalence::isEquivalentTo(
         return true;
       };
   // Check mapping of operands and results.
-  if (!checkValueRangeMapping(lhs->getOperands(), rhs->getOperands(),
-                              mapOperands))
+  if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands))
     return false;
   if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
     return false;
index 189cdde..216218e 100644 (file)
@@ -310,3 +310,15 @@ func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
   %2 = arith.addi %0, %1 : i32
   return %2 : i32
 }
+
+/// This test is checking that identical commutative operation are gracefully
+/// handled but the CSE pass.
+// CHECK-LABEL: func @check_cummutative_cse
+func @check_cummutative_cse(%a : i32, %b : i32) -> i32 {
+  // CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
+  %1 = arith.addi %a, %b : i32
+  %2 = arith.addi %b, %a : i32
+  // CHECK-NEXT:  arith.muli %[[ADD1]], %[[ADD1]] : i32
+  %3 = arith.muli %1, %2 : i32
+  return %3 : i32
+}