[mlir][transforms] CSE ops with multiple regions
authorMatthias Springer <springerm@google.com>
Fri, 27 Jan 2023 11:14:10 +0000 (12:14 +0100)
committerMatthias Springer <springerm@google.com>
Fri, 27 Jan 2023 11:14:45 +0000 (12:14 +0100)
There were issues with the CSE equivalence analysis that have been fixed with D142558. This makes it possible to CSE ops with multiple regions.

Differential Revision: https://reviews.llvm.org/D142562

mlir/lib/Transforms/CSE.cpp
mlir/test/Transforms/cse.mlir

index 86debe7..93e5c95 100644 (file)
@@ -199,11 +199,11 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
     return success();
   }
 
-  // Don't simplify operations with nested blocks. We don't currently model
-  // equality comparisons correctly among other things. It is also unclear
-  // whether we would want to CSE such operations.
-  if (op->getNumRegions() != 0 &&
-      (op->getNumRegions() != 1 || !llvm::hasSingleElement(op->getRegion(0))))
+  // Don't simplify operations with regions that have multiple blocks.
+  // TODO: We need additional tests to verify that we handle such IR correctly.
+  if (!llvm::all_of(op->getRegions(), [](Region &r) {
+        return r.getBlocks().empty() || llvm::hasSingleElement(r.getBlocks());
+      }))
     return failure();
 
   // Some simple use case of operation with memory side-effect are dealt with
index 7fdbf95..7086f5f 100644 (file)
@@ -468,3 +468,28 @@ func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor
 //       CHECK:   %[[OP:.+]] = test.cse_of_single_block_op
 //       CHECK:     test.region_yield %[[TRUE]]
 //       CHECK:   return %[[OP]], %[[OP]]
+
+func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
+  %r1 = scf.if %c -> (tensor<5xf32>) {
+    %0 = tensor.empty() : tensor<5xf32>
+    scf.yield %0 : tensor<5xf32>
+  } else {
+    scf.yield %t : tensor<5xf32>
+  }
+  %r2 = scf.if %c -> (tensor<5xf32>) {
+    %0 = tensor.empty() : tensor<5xf32>
+    scf.yield %0 : tensor<5xf32>
+  } else {
+    scf.yield %t : tensor<5xf32>
+  }
+  return %r1, %r2 : tensor<5xf32>, tensor<5xf32>
+}
+// CHECK-LABEL: func @cse_multiple_regions
+//       CHECK:   %[[if:.*]] = scf.if {{.*}} {
+//       CHECK:     tensor.empty
+//       CHECK:     scf.yield
+//       CHECK:   } else {
+//       CHECK:     scf.yield
+//       CHECK:   }
+//   CHECK-NOT:   scf.if
+//       CHECK:   return %[[if]], %[[if]]