[mlir][Transforms] CSE of ops with a single block.
authorMahesh Ravishankar <ravishankarm@google.com>
Tue, 20 Sep 2022 00:49:01 +0000 (00:49 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Wed, 16 Nov 2022 02:55:43 +0000 (02:55 +0000)
Currently CSE does not support CSE of ops with regions. This patch
extends the CSE support to ops with a single region.

Differential Revision: https://reviews.llvm.org/D134306
Depends on D137857

mlir/lib/IR/OperationSupport.cpp
mlir/lib/Transforms/CSE.cpp
mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
mlir/test/Transforms/cse.mlir
mlir/test/lib/Dialect/Test/TestOps.td

index d46f1b4..97d09eb 100644 (file)
@@ -721,16 +721,34 @@ bool OperationEquivalence::isEquivalentTo(
   ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
   SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
   if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
-    lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end());
-    llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool {
-      return a.getAsOpaquePointer() < b.getAsOpaquePointer();
-    });
-    lhsOperands = lhsOperandStorage;
+    auto sortValues = [](ValueRange values) {
+      SmallVector<Value> sortedValues = llvm::to_vector(values);
+      llvm::sort(sortedValues, [](Value a, Value b) {
+        auto aArg = a.dyn_cast<BlockArgument>();
+        auto bArg = b.dyn_cast<BlockArgument>();
+
+        // Case 1. Both `a` and `b` are `BlockArgument`s.
+        if (aArg && bArg) {
+          if (aArg.getParentBlock() == bArg.getParentBlock())
+            return aArg.getArgNumber() < bArg.getArgNumber();
+          return aArg.getParentBlock() < bArg.getParentBlock();
+        }
 
-    rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end());
-    llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool {
-      return a.getAsOpaquePointer() < b.getAsOpaquePointer();
-    });
+        // Case 2. One of then is a `BlockArgument` and other is not. Treat
+        // `BlockArgument` as lesser.
+        if (aArg && !bArg)
+          return true;
+        if (bArg && !aArg)
+          return false;
+
+        // Case 3. Both are values.
+        return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+      });
+      return sortedValues;
+    };
+    lhsOperandStorage = sortValues(lhsOperands);
+    lhsOperands = lhsOperandStorage;
+    rhsOperandStorage = sortValues(rhsOperands);
     rhsOperands = rhsOperandStorage;
   }
   auto checkValueRangeMapping =
index 3df419c..97f6cfd 100644 (file)
@@ -47,11 +47,70 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
+
+    // If op has no regions, operation equivalence w.r.t operands alone is
+    // enough.
+    if (lhs->getNumRegions() == 0 && rhs->getNumRegions() == 0) {
+      return OperationEquivalence::isEquivalentTo(
+          const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
+          OperationEquivalence::exactValueMatch,
+          OperationEquivalence::ignoreValueEquivalence,
+          OperationEquivalence::IgnoreLocations);
+    }
+
+    // If lhs or rhs does not have a single region with a single block, they
+    // aren't CSEed for now.
+    if (lhs->getNumRegions() != 1 || rhs->getNumRegions() != 1 ||
+        !llvm::hasSingleElement(lhs->getRegion(0)) ||
+        !llvm::hasSingleElement(rhs->getRegion(0)))
+      return false;
+
+    // Compare the two blocks.
+    Block &lhsBlock = lhs->getRegion(0).front();
+    Block &rhsBlock = rhs->getRegion(0).front();
+
+    // Don't CSE if number of arguments differ.
+    if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
+      return false;
+
+    // Map to store `Value`s from `lhsBlock` that are equivalent to `Value`s in
+    // `rhsBlock`. `Value`s from `lhsBlock` are the key.
+    DenseMap<Value, Value> areEquivalentValues;
+    for (auto bbArgs : llvm::zip(lhs->getRegion(0).getArguments(),
+                                 rhs->getRegion(0).getArguments())) {
+      areEquivalentValues[std::get<0>(bbArgs)] = std::get<1>(bbArgs);
+    }
+
+    // Helper function to get the parent operation.
+    auto getParent = [](Value v) -> Operation * {
+      if (auto blockArg = v.dyn_cast<BlockArgument>())
+        return blockArg.getParentBlock()->getParentOp();
+      return v.getDefiningOp()->getParentOp();
+    };
+
+    // Callback to compare if operands of ops in the region of `lhs` and `rhs`
+    // are equivalent.
+    auto mapOperands = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+      if (lhsValue == rhsValue)
+        return success();
+      if (areEquivalentValues.lookup(lhsValue) == rhsValue)
+        return success();
+      return failure();
+    };
+
+    // Callback to compare if results of ops in the region of `lhs` and `rhs`
+    // are equivalent.
+    auto mapResults = [&](Value lhsResult, Value rhsResult) -> LogicalResult {
+      if (getParent(lhsResult) == lhs && getParent(rhsResult) == rhs) {
+        auto insertion = areEquivalentValues.insert({lhsResult, rhsResult});
+        return success(insertion.first->second == rhsResult);
+      }
+      return success();
+    };
+
     return OperationEquivalence::isEquivalentTo(
         const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
-        /*mapOperands=*/OperationEquivalence::exactValueMatch,
-        /*mapResults=*/OperationEquivalence::ignoreValueEquivalence,
-        OperationEquivalence::IgnoreLocations);
+        mapOperands, mapResults, OperationEquivalence::IgnoreLocations);
   }
 };
 } // namespace
@@ -204,7 +263,8 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
   // Don't simplify operations with nested blocks. We don't currently model
   // equality comparisons correctly among other things. It is also unclear
   // whether we would want to CSE such operations.
-  if (op->getNumRegions() != 0)
+  if (!(op->getNumRegions() == 0 ||
+        (op->getNumRegions() == 1 && llvm::hasSingleElement(op->getRegion(0)))))
     return failure();
 
   // Some simple use case of operation with memory side-effect are dealt with
index cc5f861..dede407 100644 (file)
@@ -17,7 +17,6 @@
 //       CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T4]] : memref<16xindex>)
 //       CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64>
 //       CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref<?xf64>
-//       CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[T6]] : memref<16xf64>)
 //       CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>)
 //       CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<1xindex>
 //       CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]]
index dbc2d5e..08429f7 100644 (file)
@@ -322,3 +322,127 @@ func.func @check_cummutative_cse(%a : i32, %b : i32) -> i32 {
   %3 = arith.muli %1, %2 : i32
   return %3 : i32
 }
+
+// Check that an operation with a single region can CSE.
+func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
+  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32):
+    test.region_yield %arg0 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32):
+    test.region_yield %arg0 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @cse_single_block_ops
+//       CHECK:   %[[OP:.+]] = test.cse_of_single_block_op
+//   CHECK-NOT:   test.cse_of_single_block_op
+//       CHECK:   return %[[OP]], %[[OP]]
+
+// Operations with different number of bbArgs dont CSE.
+func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
+  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+    test.region_yield %arg0 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32):
+    test.region_yield %arg0 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @no_cse_varied_bbargs
+//       CHECK:   %[[OP0:.+]] = test.cse_of_single_block_op
+//       CHECK:   %[[OP1:.+]] = test.cse_of_single_block_op
+//       CHECK:   return %[[OP0]], %[[OP1]]
+
+// Operations with different regions dont CSE
+func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
+  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+    test.region_yield %arg0 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+    test.region_yield %arg1 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @no_cse_region_difference_simple
+//       CHECK:   %[[OP0:.+]] = test.cse_of_single_block_op
+//       CHECK:   %[[OP1:.+]] = test.cse_of_single_block_op
+//       CHECK:   return %[[OP0]], %[[OP1]]
+
+// Operation with identical region with multiple statements CSE.
+func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
+  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+    %1 = arith.divf %arg0, %arg1 : f32
+    %2 = arith.remf %arg0, %c : f32
+    %3 = arith.select %d, %1, %2 : f32
+    test.region_yield %3 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+    %1 = arith.divf %arg0, %arg1 : f32
+    %2 = arith.remf %arg0, %c : f32
+    %3 = arith.select %d, %1, %2 : f32
+    test.region_yield %3 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @cse_single_block_ops_identical_bodies
+//       CHECK:   %[[OP:.+]] = test.cse_of_single_block_op
+//   CHECK-NOT:   test.cse_of_single_block_op
+//       CHECK:   return %[[OP]], %[[OP]]
+
+// Operation with non-identical regions dont CSE.
+func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
+  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+    %1 = arith.divf %arg0, %arg1 : f32
+    %2 = arith.remf %arg0, %c : f32
+    %3 = arith.select %d, %1, %2 : f32
+    test.region_yield %3 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+    %1 = arith.divf %arg0, %arg1 : f32
+    %2 = arith.remf %arg0, %c : f32
+    %3 = arith.select %d, %2, %1 : f32
+    test.region_yield %3 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @no_cse_single_block_ops_different_bodies
+//       CHECK:   %[[OP0:.+]] = test.cse_of_single_block_op
+//       CHECK:   %[[OP1:.+]] = test.cse_of_single_block_op
+//       CHECK:   return %[[OP0]], %[[OP1]]
+
+// Account for commutative ops within regions during CSE.
+func.func @cse_single_block_with_commutative_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32)
+  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+    %1 = arith.addf %arg0, %arg1 : f32
+    %2 = arith.mulf %1, %c : f32
+    test.region_yield %2 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+    %1 = arith.addf %arg1, %arg0 : f32
+    %2 = arith.mulf %c, %1 : f32
+    test.region_yield %2 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @cse_single_block_with_commutative_ops
+//       CHECK:   %[[OP:.+]] = test.cse_of_single_block_op
+//   CHECK-NOT:   test.cse_of_single_block_op
+//       CHECK:   return %[[OP]], %[[OP]]
index 84dd37f..cd447d7 100644 (file)
@@ -670,8 +670,8 @@ def TestProducingBranchOp : TEST_Op<"producing_br",
 
 // Produces an error value on the error path
 def TestInternalBranchOp : TEST_Op<"internal_br",
-       [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator,
-        AttrSizedOperandSegments]> {
+    [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator,
+     AttrSizedOperandSegments]> {
 
   let arguments = (ins Variadic<AnyType>:$successOperands,
                        Variadic<AnyType>:$errorOperands);
@@ -3045,4 +3045,19 @@ def RecursivelySpeculatableOp : TEST_Op<"recursively_speculatable_op", [
   let regions = (region SizedRegion<1>:$body);
 }
 
+//===---------------------------------------------------------------------===//
+// Test CSE
+//===---------------------------------------------------------------------===//
+
+def TestCSEOfSingleBlockOp : TEST_Op<"cse_of_single_block_op",
+    [SingleBlockImplicitTerminator<"RegionYieldOp">, Pure]> {
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs Variadic<AnyType>:$outputs);
+  let regions = (region SizedRegion<1>:$region);
+  let assemblyFormat = [{
+    attr-dict `inputs` `(` $inputs `)`
+    $region `:` type($inputs)  `->` type($outputs)
+  }];
+}
+
 #endif // TEST_OPS