From 20c926e0797e074bfb946d2c8ce002888ebc2bcd Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Thu, 26 Nov 2020 17:08:49 +0100 Subject: [PATCH] [mlir][DialectConversion] Do not prematurely drop unused cast operations The rewrite logic has an optimization to drop a cast operation after rewriting block arguments if the cast operation has no users. This is unsafe as there might be a pending rewrite that replaced the cast operation itself and hence would trigger a second free. Instead, do not remove the casts and leave it up to a later canonicalization to do so. Differential Revision: https://reviews.llvm.org/D92184 --- mlir/lib/Transforms/Utils/DialectConversion.cpp | 5 ----- mlir/test/Transforms/test-legalizer.mlir | 5 +++-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c423103..0a1a6b7 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -364,11 +364,6 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { // If the argument is still used, replace it with the generated cast. if (!origArg.use_empty()) origArg.replaceAllUsesWith(mapping.lookupOrDefault(castValue)); - - // If all users of the cast were removed, we can drop it. Otherwise, keep - // the operation alive and let the user handle any remaining usages. - if (castValue.use_empty() && castValue.getDefiningOp()) - castValue.getDefiningOp()->erase(); } } } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 878d903..376f0c0 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -36,8 +36,9 @@ func @remap_call_1_to_1(%arg0: i64) { // CHECK-LABEL: func @remap_input_1_to_N({{.*}}f16, {{.*}}f16) func @remap_input_1_to_N(%arg0: f32) -> f32 { - // CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> () - "test.return"(%arg0) : (f32) -> () + // CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32 + // CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> () + "test.return"(%arg0) : (f32) -> () } // CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16) -- 2.7.4