The reconciliation pass has been improved to introduce the support for chains of casts, thus not limiting anymore the reconciliation to just consider pairs of unrealized casts.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D130711
```
%0 = "producer.op"() : () -> !type.A
%1 = unrealized_conversion_cast %0 : !type.A to !type.B
- %2 = unrealized_conversion_cast %1 : !type.B to !type.A
- "consumer.op"(%2) : (!type.A) -> ()
+ %2 = unrealized_conversion_cast %1 : !type.B to !type.C
+ %3 = unrealized_conversion_cast %2 : !type.C to !type.A
+ "consumer.op"(%3) : (!type.A) -> ()
```
Such situations appear when the consumer operation is converted by one pass
namespace {
-/// Removes `unrealized_conversion_cast`s whose results are only used by other
-/// `unrealized_conversion_cast`s converting back to the original type. This
-/// pattern is complementary to the folder and can be used to process operations
-/// starting from the first, i.e. the usual traversal order in dialect
-/// conversion. The folder, on the other hand, can only apply to the last
-/// operation in a chain of conversions because it is not expected to walk
-/// use-def chains. One would need to declare cast ops as dynamically illegal
-/// with a complex condition in order to eliminate them using the folder alone
-/// in the dialect conversion infra.
+/// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types
+/// the same as the input ones.
+/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
+/// represent a noop within the IR, and thus the initial input values can be
+/// propagated.
+/// The same does not hold for 'open' chains chains of casts, such as
+/// `A -> B -> C`. In this last case there is no cycle among the types and thus
+/// the conversion is incomplete. The same hold for 'closed' chains like
+/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
+/// operations.
+/// Bifurcations (that is when a chain starts in between of another one) are
+/// also taken into considerations, and all the above considerations remain
+/// valid.
+/// Special corner cases such as dead casts or single casts with same input and
+/// output types are also covered.
struct UnrealizedConversionCastPassthrough
: public OpRewritePattern<UnrealizedConversionCastOp> {
using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
PatternRewriter &rewriter) const override {
- // Match the casts that are _only_ used by other casts, with the overall
- // cast being a trivial noop: A->B->A.
- auto users = op->getUsers();
- if (!llvm::all_of(users, [&](Operation *user) {
- if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
- return other.getResultTypes() == op.getInputs().getTypes() &&
- other.getInputs() == op.getOutputs();
- return false;
- })) {
- return rewriter.notifyMatchFailure(op, "live unrealized conversion cast");
+ // The nodes that either are not used by any operation or have at least
+ // one user that is not an unrealized cast.
+ DenseSet<UnrealizedConversionCastOp> exitNodes;
+
+ // The nodes whose users are all unrealized casts
+ DenseSet<UnrealizedConversionCastOp> intermediateNodes;
+
+ // Stack used for the depth-first traversal of the use-def DAG.
+ SmallVector<UnrealizedConversionCastOp, 2> visitStack;
+ visitStack.push_back(op);
+
+ while (!visitStack.empty()) {
+ UnrealizedConversionCastOp current = visitStack.pop_back_val();
+ auto users = current->getUsers();
+ bool isLive = false;
+
+ for (Operation *user : users) {
+ if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
+ if (other.getInputs() != current.getOutputs())
+ return rewriter.notifyMatchFailure(
+ op, "mismatching values propagation");
+ } else {
+ isLive = true;
+ }
+
+ // Continue traversing the DAG of unrealized casts
+ if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
+ visitStack.push_back(other);
+ }
+
+ // If the cast is live, then we need to check if the results of the last
+ // cast have the same type of the root inputs. It this is the case (e.g.
+ // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
+ // no-op and the inputs can be forwarded. If it's not (e.g.
+ // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
+
+ bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
+
+ if (isLive && !isCycle)
+ return rewriter.notifyMatchFailure(op,
+ "live unrealized conversion cast");
+
+ bool isExitNode = users.empty() || isLive;
+
+ if (isExitNode) {
+ exitNodes.insert(current);
+ } else {
+ intermediateNodes.insert(current);
+ }
}
- for (Operation *user : users)
- rewriter.replaceOp(user, op.getInputs());
+ // Replace the sink nodes with the root input values
+ for (UnrealizedConversionCastOp exitNode : exitNodes)
+ rewriter.replaceOp(exitNode, op.getInputs());
+
+ // Erase all the other casts belonging to the DAG
+ for (UnrealizedConversionCastOp castOp : intermediateNodes)
+ rewriter.eraseOp(castOp);
- rewriter.eraseOp(op);
return success();
}
};
--- /dev/null
+// RUN: not mlir-opt %s -split-input-file -mlir-print-ir-after-failure -reconcile-unrealized-casts 2>&1 | FileCheck %s
+
+// CHECK-LABEL: @liveSingleCast
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[liveCast:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: return %[[liveCast]] : i32
+
+func.func @liveSingleCast(%arg0: i64) -> i32 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+ return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i1
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i1 to i32
+// CHECK: return %[[cast1]] : i32
+
+func.func @liveChain(%arg0: i64) -> i32 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i1
+ %1 = builtin.unrealized_conversion_cast %0 : i1 to i32
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveBifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i64
+// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i1
+// CHECK: %[[extsi:.*]] = arith.extsi %[[cast2]] : i1 to i64
+// CHECK: %[[result:.*]] = arith.addi %[[cast1]], %[[extsi]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @liveBifurcation(%arg0: i64) -> i64 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+ %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
+ %2 = builtin.unrealized_conversion_cast %0 : i32 to i1
+ %3 = arith.extsi %2 : i1 to i64
+ %4 = arith.addi %1, %3 : i64
+ return %4 : i64
+}
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -reconcile-unrealized-casts | FileCheck %s
+
+// CHECK-LABEL: @unusedCast
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @unusedCast(%arg0: i64) -> i64 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+ return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @sameTypes
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @sameTypes(%arg0: i64) -> i64 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i64
+ return %0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @pair
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @pair(%arg0: i64) -> i64 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+ %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
+ return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @symmetricChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @symmetricChain(%arg0: i64) -> i64 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+ %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+ %2 = builtin.unrealized_conversion_cast %1 : i1 to i32
+ %3 = builtin.unrealized_conversion_cast %2 : i32 to i64
+ return %3 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @asymmetricChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @asymmetricChain(%arg0: i64) -> i64 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+ %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+ %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+ return %2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @unusedChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @unusedChain(%arg0: i64) -> i64 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+ %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+ return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @bifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[arg0]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @bifurcation(%arg0: i64) -> i64 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+ %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+ %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+ %3 = builtin.unrealized_conversion_cast %1 : i1 to i32
+ %4 = builtin.unrealized_conversion_cast %3 : i32 to i64
+ %5 = arith.addi %2, %4 : i64
+ return %5 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @unusedBifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[arg0]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @unusedBifurcation(%arg0: i64) -> i64 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+ %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+ %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+ %3 = builtin.unrealized_conversion_cast %0 : i32 to i64
+ %4 = arith.addi %arg0, %3 : i64
+ return %4 : i64
+}