From e90deaf1217d9ea0316a3ec03e199c658f5757d5 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Wed, 3 Aug 2022 11:52:21 +0000 Subject: [PATCH] [MLIR] Reconciliation of chains of unrealized casts The reconciliation pass has been improved to introduce the support for chains of casts, thus not limiting anymore the reconciliation to just consider pairs of unrealized casts. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D130711 --- mlir/include/mlir/Conversion/Passes.td | 5 +- .../ReconcileUnrealizedCasts.cpp | 93 +++++++++++++----- .../reconcile-unrealized-casts-failure.mlir | 45 +++++++++ .../reconcile-unrealized-casts.mlir | 105 +++++++++++++++++++++ 4 files changed, 224 insertions(+), 24 deletions(-) create mode 100644 mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir create mode 100644 mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 26ec204..00ca7af 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -619,8 +619,9 @@ def ReconcileUnrealizedCasts : Pass<"reconcile-unrealized-casts"> { ``` %0 = "producer.op"() : () -> !type.A %1 = unrealized_conversion_cast %0 : !type.A to !type.B - %2 = unrealized_conversion_cast %1 : !type.B to !type.A - "consumer.op"(%2) : (!type.A) -> () + %2 = unrealized_conversion_cast %1 : !type.B to !type.C + %3 = unrealized_conversion_cast %2 : !type.C to !type.A + "consumer.op"(%3) : (!type.A) -> () ``` Such situations appear when the consumer operation is converted by one pass diff --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp index 5196817..773e9b2 100644 --- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp +++ 
b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp @@ -17,37 +17,86 @@ using namespace mlir; namespace { -/// Removes `unrealized_conversion_cast`s whose results are only used by other -/// `unrealized_conversion_cast`s converting back to the original type. This -/// pattern is complementary to the folder and can be used to process operations -/// starting from the first, i.e. the usual traversal order in dialect -/// conversion. The folder, on the other hand, can only apply to the last -/// operation in a chain of conversions because it is not expected to walk -/// use-def chains. One would need to declare cast ops as dynamically illegal -/// with a complex condition in order to eliminate them using the folder alone -/// in the dialect conversion infra. +/// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types +/// the same as the input ones. +/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A` +/// represent a noop within the IR, and thus the initial input values can be +/// propagated. +/// The same does not hold for 'open' chains of casts, such as +/// `A -> B -> C`. In this last case there is no cycle among the types and thus +/// the conversion is incomplete. The same holds for 'closed' chains like +/// `A -> B -> A`, but with the result of type `B` being used by some non-cast +/// operations. +/// Bifurcations (that is when a chain starts in between of another one) are +/// also taken into consideration, and all the above considerations remain +/// valid. +/// Special corner cases such as dead casts or single casts with same input and +/// output types are also covered. 
struct UnrealizedConversionCastPassthrough : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(UnrealizedConversionCastOp op, PatternRewriter &rewriter) const override { - // Match the casts that are _only_ used by other casts, with the overall - // cast being a trivial noop: A->B->A. - auto users = op->getUsers(); - if (!llvm::all_of(users, [&](Operation *user) { - if (auto other = dyn_cast(user)) - return other.getResultTypes() == op.getInputs().getTypes() && - other.getInputs() == op.getOutputs(); - return false; - })) { - return rewriter.notifyMatchFailure(op, "live unrealized conversion cast"); + // The nodes that either are not used by any operation or have at least + // one user that is not an unrealized cast. + DenseSet exitNodes; + + // The nodes whose users are all unrealized casts + DenseSet intermediateNodes; + + // Stack used for the depth-first traversal of the use-def DAG. + SmallVector visitStack; + visitStack.push_back(op); + + while (!visitStack.empty()) { + UnrealizedConversionCastOp current = visitStack.pop_back_val(); + auto users = current->getUsers(); + bool isLive = false; + + for (Operation *user : users) { + if (auto other = dyn_cast(user)) { + if (other.getInputs() != current.getOutputs()) + return rewriter.notifyMatchFailure( + op, "mismatching values propagation"); + } else { + isLive = true; + } + + // Continue traversing the DAG of unrealized casts + if (auto other = dyn_cast(user)) + visitStack.push_back(other); + } + + // If the cast is live, then we need to check if the results of the last + // cast have the same type as the root inputs. If this is the case (e.g. + // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a + // no-op and the inputs can be forwarded. If it's not (e.g. + // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete. 
+ + bool isCycle = current.getResultTypes() == op.getInputs().getTypes(); + + if (isLive && !isCycle) + return rewriter.notifyMatchFailure(op, + "live unrealized conversion cast"); + + bool isExitNode = users.empty() || isLive; + + if (isExitNode) { + exitNodes.insert(current); + } else { + intermediateNodes.insert(current); + } } - for (Operation *user : users) - rewriter.replaceOp(user, op.getInputs()); + // Replace the sink nodes with the root input values + for (UnrealizedConversionCastOp exitNode : exitNodes) + rewriter.replaceOp(exitNode, op.getInputs()); + + // Erase all the other casts belonging to the DAG + for (UnrealizedConversionCastOp castOp : intermediateNodes) + rewriter.eraseOp(castOp); - rewriter.eraseOp(op); return success(); } }; diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir new file mode 100644 index 0000000..f5ceb29 --- /dev/null +++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir @@ -0,0 +1,45 @@ +// RUN: not mlir-opt %s -split-input-file -mlir-print-ir-after-failure -reconcile-unrealized-casts 2>&1 | FileCheck %s + +// CHECK-LABEL: @liveSingleCast +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32 +// CHECK: %[[liveCast:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32 +// CHECK: return %[[liveCast]] : i32 + +func.func @liveSingleCast(%arg0: i64) -> i32 { + %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32 + return %0 : i32 +} + +// ----- + +// CHECK-LABEL: @liveChain +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32 +// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i1 +// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i1 to i32 +// CHECK: return %[[cast1]] : i32 + +func.func @liveChain(%arg0: i64) -> i32 { + %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i1 + %1 = 
builtin.unrealized_conversion_cast %0 : i1 to i32 + return %1 : i32 +} + +// ----- + +// CHECK-LABEL: @liveBifurcation +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64 +// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32 +// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i64 +// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i1 +// CHECK: %[[extsi:.*]] = arith.extsi %[[cast2]] : i1 to i64 +// CHECK: %[[result:.*]] = arith.addi %[[cast1]], %[[extsi]] : i64 +// CHECK: return %[[result]] : i64 + +func.func @liveBifurcation(%arg0: i64) -> i64 { + %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32 + %1 = builtin.unrealized_conversion_cast %0 : i32 to i64 + %2 = builtin.unrealized_conversion_cast %0 : i32 to i1 + %3 = arith.extsi %2 : i1 to i64 + %4 = arith.addi %1, %3 : i64 + return %4 : i64 +} diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir new file mode 100644 index 0000000..d71cbba1 --- /dev/null +++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir @@ -0,0 +1,105 @@ +// RUN: mlir-opt %s -split-input-file -reconcile-unrealized-casts | FileCheck %s + +// CHECK-LABEL: @unusedCast +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64 +// CHECK: return %[[arg0]] : i64 + +func.func @unusedCast(%arg0: i64) -> i64 { + %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32 + return %arg0 : i64 +} + +// ----- + +// CHECK-LABEL: @sameTypes +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64 +// CHECK: return %[[arg0]] : i64 + +func.func @sameTypes(%arg0: i64) -> i64 { + %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i64 + return %0 : i64 +} + +// ----- + +// CHECK-LABEL: @pair +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64 +// CHECK: return %[[arg0]] : i64 + +func.func @pair(%arg0: i64) -> i64 { + %0 = builtin.unrealized_conversion_cast %arg0 : 
i64 to i32 + %1 = builtin.unrealized_conversion_cast %0 : i32 to i64 + return %1 : i64 +} + +// ----- + +// CHECK-LABEL: @symmetricChain +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64 +// CHECK: return %[[arg0]] : i64 + +func.func @symmetricChain(%arg0: i64) -> i64 { + %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32 + %1 = builtin.unrealized_conversion_cast %0 : i32 to i1 + %2 = builtin.unrealized_conversion_cast %1 : i1 to i32 + %3 = builtin.unrealized_conversion_cast %2 : i32 to i64 + return %3 : i64 +} + +// ----- + +// CHECK-LABEL: @asymmetricChain +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64 +// CHECK: return %[[arg0]] : i64 + +func.func @asymmetricChain(%arg0: i64) -> i64 { + %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32 + %1 = builtin.unrealized_conversion_cast %0 : i32 to i1 + %2 = builtin.unrealized_conversion_cast %1 : i1 to i64 + return %2 : i64 +} + +// ----- + +// CHECK-LABEL: @unusedChain +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64 +// CHECK: return %[[arg0]] : i64 + +func.func @unusedChain(%arg0: i64) -> i64 { + %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32 + %1 = builtin.unrealized_conversion_cast %0 : i32 to i1 + return %arg0 : i64 +} + +// ----- + +// CHECK-LABEL: @bifurcation +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64 +// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[arg0]] : i64 +// CHECK: return %[[result]] : i64 + +func.func @bifurcation(%arg0: i64) -> i64 { + %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32 + %1 = builtin.unrealized_conversion_cast %0 : i32 to i1 + %2 = builtin.unrealized_conversion_cast %1 : i1 to i64 + %3 = builtin.unrealized_conversion_cast %1 : i1 to i32 + %4 = builtin.unrealized_conversion_cast %3 : i32 to i64 + %5 = arith.addi %2, %4 : i64 + return %5 : i64 +} + +// ----- + +// CHECK-LABEL: @unusedBifurcation +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64 +// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[arg0]] : i64 +// CHECK: return %[[result]] : i64 + +func.func 
@unusedBifurcation(%arg0: i64) -> i64 { + %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32 + %1 = builtin.unrealized_conversion_cast %0 : i32 to i1 + %2 = builtin.unrealized_conversion_cast %1 : i1 to i64 + %3 = builtin.unrealized_conversion_cast %0 : i32 to i64 + %4 = arith.addi %arg0, %3 : i64 + return %4 : i64 +} -- 2.7.4