[MLIR] Reconciliation of chains of unrealized casts
authorMichele Scuttari <michele.scuttari@outlook.com>
Wed, 3 Aug 2022 11:52:21 +0000 (11:52 +0000)
committerAlex Zinenko <zinenko@google.com>
Wed, 3 Aug 2022 11:57:20 +0000 (11:57 +0000)
The reconciliation pass has been improved to introduce the support for chains of casts, thus not limiting anymore the reconciliation to just consider pairs of unrealized casts.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D130711

mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir [new file with mode: 0644]
mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir [new file with mode: 0644]

index 26ec204..00ca7af 100644 (file)
@@ -619,8 +619,9 @@ def ReconcileUnrealizedCasts : Pass<"reconcile-unrealized-casts"> {
     ```
     %0 = "producer.op"() : () -> !type.A
     %1 = unrealized_conversion_cast %0 : !type.A to !type.B
-    %2 = unrealized_conversion_cast %1 : !type.B to !type.A
-    "consumer.op"(%2) : (!type.A) -> ()
+    %2 = unrealized_conversion_cast %1 : !type.B to !type.C
+    %3 = unrealized_conversion_cast %2 : !type.C to !type.A
+    "consumer.op"(%3) : (!type.A) -> ()
     ```
 
     Such situations appear when the consumer operation is converted by one pass
index 5196817..773e9b2 100644 (file)
@@ -17,37 +17,86 @@ using namespace mlir;
 
 namespace {
 
-/// Removes `unrealized_conversion_cast`s whose results are only used by other
-/// `unrealized_conversion_cast`s converting back to the original type. This
-/// pattern is complementary to the folder and can be used to process operations
-/// starting from the first, i.e. the usual traversal order in dialect
-/// conversion. The folder, on the other hand, can only apply to the last
-/// operation in a chain of conversions because it is not expected to walk
-/// use-def chains. One would need to declare cast ops as dynamically illegal
-/// with a complex condition in order to eliminate them using the folder alone
-/// in the dialect conversion infra.
+/// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types
+/// the same as the input ones.
+/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
+/// represent a noop within the IR, and thus the initial input values can be
+/// propagated.
+/// The same does not hold for 'open' chains chains of casts, such as
+/// `A -> B -> C`. In this last case there is no cycle among the types and thus
+/// the conversion is incomplete. The same hold for 'closed' chains like
+/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
+/// operations.
+/// Bifurcations (that is when a chain starts in between of another one) are
+/// also taken into considerations, and all the above considerations remain
+/// valid.
+/// Special corner cases such as dead casts or single casts with same input and
+/// output types are also covered.
 struct UnrealizedConversionCastPassthrough
     : public OpRewritePattern<UnrealizedConversionCastOp> {
   using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
                                 PatternRewriter &rewriter) const override {
-    // Match the casts that are _only_ used by other casts, with the overall
-    // cast being a trivial noop: A->B->A.
-    auto users = op->getUsers();
-    if (!llvm::all_of(users, [&](Operation *user) {
-          if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
-            return other.getResultTypes() == op.getInputs().getTypes() &&
-                   other.getInputs() == op.getOutputs();
-          return false;
-        })) {
-      return rewriter.notifyMatchFailure(op, "live unrealized conversion cast");
+    // The nodes that either are not used by any operation or have at least
+    // one user that is not an unrealized cast.
+    DenseSet<UnrealizedConversionCastOp> exitNodes;
+
+    // The nodes whose users are all unrealized casts
+    DenseSet<UnrealizedConversionCastOp> intermediateNodes;
+
+    // Stack used for the depth-first traversal of the use-def DAG.
+    SmallVector<UnrealizedConversionCastOp, 2> visitStack;
+    visitStack.push_back(op);
+
+    while (!visitStack.empty()) {
+      UnrealizedConversionCastOp current = visitStack.pop_back_val();
+      auto users = current->getUsers();
+      bool isLive = false;
+
+      for (Operation *user : users) {
+        if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
+          if (other.getInputs() != current.getOutputs())
+            return rewriter.notifyMatchFailure(
+                op, "mismatching values propagation");
+        } else {
+          isLive = true;
+        }
+
+        // Continue traversing the DAG of unrealized casts
+        if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
+          visitStack.push_back(other);
+      }
+
+      // If the cast is live, then we need to check if the results of the last
+      // cast have the same type of the root inputs. It this is the case (e.g.
+      // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
+      // no-op and the inputs can be forwarded. If it's not (e.g.
+      // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
+
+      bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
+
+      if (isLive && !isCycle)
+        return rewriter.notifyMatchFailure(op,
+                                           "live unrealized conversion cast");
+
+      bool isExitNode = users.empty() || isLive;
+
+      if (isExitNode) {
+        exitNodes.insert(current);
+      } else {
+        intermediateNodes.insert(current);
+      }
     }
 
-    for (Operation *user : users)
-      rewriter.replaceOp(user, op.getInputs());
+    // Replace the sink nodes with the root input values
+    for (UnrealizedConversionCastOp exitNode : exitNodes)
+      rewriter.replaceOp(exitNode, op.getInputs());
+
+    // Erase all the other casts belonging to the DAG
+    for (UnrealizedConversionCastOp castOp : intermediateNodes)
+      rewriter.eraseOp(castOp);
 
-    rewriter.eraseOp(op);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir
new file mode 100644 (file)
index 0000000..f5ceb29
--- /dev/null
@@ -0,0 +1,45 @@
+// RUN: not mlir-opt %s -split-input-file -mlir-print-ir-after-failure -reconcile-unrealized-casts 2>&1 | FileCheck %s
+
+// CHECK-LABEL: @liveSingleCast
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[liveCast:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: return %[[liveCast]] : i32
+
+func.func @liveSingleCast(%arg0: i64) -> i32 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i1
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i1 to i32
+// CHECK: return %[[cast1]] : i32
+
+func.func @liveChain(%arg0: i64) -> i32 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i1
+    %1 = builtin.unrealized_conversion_cast %0 : i1 to i32
+    return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveBifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i64
+// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i1
+// CHECK: %[[extsi:.*]] = arith.extsi %[[cast2]] : i1 to i64
+// CHECK: %[[result:.*]] = arith.addi %[[cast1]], %[[extsi]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @liveBifurcation(%arg0: i64) -> i64 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
+    %2 = builtin.unrealized_conversion_cast %0 : i32 to i1
+    %3 = arith.extsi %2 : i1 to i64
+    %4 = arith.addi %1, %3 : i64
+    return %4 : i64
+}
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
new file mode 100644 (file)
index 0000000..d71cbba
--- /dev/null
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -split-input-file -reconcile-unrealized-casts | FileCheck %s
+
+// CHECK-LABEL: @unusedCast
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @unusedCast(%arg0: i64) -> i64 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @sameTypes
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @sameTypes(%arg0: i64) -> i64 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i64
+    return %0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @pair
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @pair(%arg0: i64) -> i64 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
+    return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @symmetricChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @symmetricChain(%arg0: i64) -> i64 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+    %2 = builtin.unrealized_conversion_cast %1 : i1 to i32
+    %3 = builtin.unrealized_conversion_cast %2 : i32 to i64
+    return %3 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @asymmetricChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @asymmetricChain(%arg0: i64) -> i64 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+    %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+    return %2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @unusedChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @unusedChain(%arg0: i64) -> i64 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+    return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @bifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[arg0]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @bifurcation(%arg0: i64) -> i64 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+    %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+    %3 = builtin.unrealized_conversion_cast %1 : i1 to i32
+    %4 = builtin.unrealized_conversion_cast %3 : i32 to i64
+    %5 = arith.addi %2, %4 : i64
+    return %5 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @unusedBifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[arg0]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @unusedBifurcation(%arg0: i64) -> i64 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+    %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+    %3 = builtin.unrealized_conversion_cast %0 : i32 to i64
+    %4 = arith.addi %arg0, %3 : i64
+    return %4 : i64
+}