[mlir] Fix block merging with the result of a terminator
authorMarkus Böck <markus.boeck02@gmail.com>
Mon, 21 Mar 2022 12:26:00 +0000 (13:26 +0100)
committerMarkus Böck <markus.boeck02@gmail.com>
Mon, 21 Mar 2022 12:26:35 +0000 (13:26 +0100)
When the current implementation merges two blocks that have operands defined outside of their block respectively, it will merge these by adding a block argument in the resulting merged block and adding successor arguments to the predecessors.
There is a special case where this is incorrect however: If one of predecessors terminator produce the operand, inserting the block argument and updating the predecessor would lead to the terminator using its own result as successor argument.
IR Example:
```
  %0 = "test.producing_br"()[^bb1, ^bb2] {
        operand_segment_sizes = dense<0> : vector<2 x i32>
} : () -> i32

^bb1:
  "test.br"(%0)[^bb4] : (i32) -> ()
```
where `^bb1` is then merged with another block would lead to:
 ```
  %0 = "test.producing_br"(%0)[^bb1, ^bb2]
```

This patch fixes that issue during clustering by making sure that if the operand is from an outside block, that it is not produced by the terminator of a predecessor.

Differential Revision: https://reviews.llvm.org/D121988

mlir/lib/Transforms/Utils/RegionUtils.cpp
mlir/test/Transforms/canonicalize-block-merge.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td

index cf77610..9d72be5 100644 (file)
@@ -518,6 +518,23 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
       // Let the operands differ if they are defined in a different block. These
       // will become new arguments if the blocks get merged.
       if (!lhsIsInBlock) {
+
+        // Check whether the operands aren't the result of an immediate
+        // predecessors terminator. In that case we are not able to use it as a
+        // successor operand when branching to the merged block as it does not
+        // dominate its producing operation.
+        auto isValidSuccessorArg = [](Block *block, Value operand) {
+          if (operand.getDefiningOp() !=
+              operand.getParentBlock()->getTerminator())
+            return true;
+          return !llvm::is_contained(block->getPredecessors(),
+                                     operand.getParentBlock());
+        };
+
+        if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
+            !isValidSuccessorArg(mergeBlock, rhsOperand))
+          return mlir::failure();
+
         mismatchedOperands.emplace_back(opI, operand);
         continue;
       }
index 2a9cf97..5c692d9 100644 (file)
@@ -251,3 +251,27 @@ func @nomerge(%arg0: i32, %i: i32) {
 ^bb3:  // pred: ^bb1
   return
 }
+
+
+// CHECK-LABEL: func @mismatch_dominance(
+func @mismatch_dominance() -> i32 {
+  // CHECK: %[[RES:.*]] = "test.producing_br"()
+  %0 = "test.producing_br"()[^bb1, ^bb2] {
+        operand_segment_sizes = dense<0> : vector<2 x i32>
+       } : () -> i32
+
+^bb1:
+  // CHECK: "test.br"(%[[RES]])[^[[MERGE_BLOCK:.*]]]
+  "test.br"(%0)[^bb4] : (i32) -> ()
+
+^bb2:
+  %1 = "foo.def"() : () -> i32
+  "test.br"()[^bb3] : () -> ()
+
+^bb3:
+  // CHECK: "test.br"(%{{.*}})[^[[MERGE_BLOCK]]]
+  "test.br"(%1)[^bb4] : (i32) -> ()
+
+^bb4(%3: i32):
+  return %3 : i32
+}
index c530582..8c8a230 100644 (file)
@@ -342,6 +342,19 @@ TestBranchOp::getMutableSuccessorOperands(unsigned index) {
 }
 
 //===----------------------------------------------------------------------===//
+// TestProducingBranchOp
+//===----------------------------------------------------------------------===//
+
+Optional<MutableOperandRange>
+TestProducingBranchOp::getMutableSuccessorOperands(unsigned index) {
+  assert(index <= 1 && "invalid successor index");
+  if (index == 1) {
+    return getFirstOperandsMutable();
+  }
+  return getSecondOperandsMutable();
+}
+
+//===----------------------------------------------------------------------===//
 // TestDialectCanonicalizerOp
 //===----------------------------------------------------------------------===//
 
index 3f3f812..b675a55 100644 (file)
@@ -617,6 +617,15 @@ def TestBranchOp : TEST_Op<"br",
   let successors = (successor AnySuccessor:$target);
 }
 
+def TestProducingBranchOp : TEST_Op<"producing_br",
+    [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator,
+     AttrSizedOperandSegments]> {
+  let arguments = (ins Variadic<AnyType>:$firstOperands,
+                       Variadic<AnyType>:$secondOperands);
+  let results = (outs I32:$dummy);
+  let successors = (successor AnySuccessor:$first,AnySuccessor:$second);
+}
+
 def AttrSizedOperandOp : TEST_Op<"attr_sized_operands",
                                  [AttrSizedOperandSegments]> {
   let arguments = (ins