From 65eedcebdc03052959508911417bac548009652a Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Sat, 22 Apr 2023 08:57:10 +0000 Subject: [PATCH] [mlir] detensorize: don't accidentally convert function entry blocks In the Linalg detensorize pass, dialect conversion could accidentally trigger signature conversion of the function entry block after inlining the body of a Linalg generic into it. Such a conversion is not desirable because it would break the internal validity of the function op, that is futhermore not supposed to be detensorized at the boundary. Mitigate this by creating a dummy (empty) entry block so Linalg operations are never inlined into it and the conversion is never triggered. Closes #62249. Reviewed By: silvas Differential Revision: https://reviews.llvm.org/D148983 --- mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp | 21 +++++++++++++++++++-- .../Dialect/Linalg/detensorize_entry_block.mlir | 21 +++++++++++++++++++++ .../Dialect/Linalg/detensorize_while_pure_cf.mlir | 4 ++-- 3 files changed, 42 insertions(+), 4 deletions(-) create mode 100644 mlir/test/Dialect/Linalg/detensorize_entry_block.mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 5289ed6..9012a63 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -60,7 +60,7 @@ bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) { }); } -/// A conversion patttern for detensoring `linalg.generic` ops. +/// A conversion pattern for detensoring `linalg.generic` ops. class DetensorizeGenericOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -69,7 +69,7 @@ public: ConversionPatternRewriter &rewriter) const override { Block *originalBlock = op->getBlock(); - // Gather some information about the op before inling its region. + // Gather some information about the op before inlining its region. Block *opEntryBlock = &*op.getRegion().begin(); YieldOp yieldOp = dyn_cast(op.getRegion().back().getTerminator()); @@ -476,6 +476,18 @@ struct LinalgDetensorize DenseSet blockArgsToDetensor; FunctionOpInterface funcOp = getOperation(); + // Make sure the entry block of the function doesn't contain any Linalg ops. + // Otherwise, it may lead to the signature of the block being changed by the + // dialect conversion below, which would make the function op invalid + // because its type shouldn't change. + IRRewriter rewriter(funcOp->getContext()); + Block *entryBlock = &funcOp.getFunctionBody().front(); + Block *postEntryBlock = + rewriter.splitBlock(entryBlock, entryBlock->begin()); + rewriter.setInsertionPointToStart(entryBlock); + auto branch = + rewriter.create(rewriter.getUnknownLoc(), postEntryBlock); + if (aggressiveMode.getValue()) { AggressiveDetensoringModel costModel; costModel.compute(funcOp, typeConverter, opsToDetensor, @@ -553,6 +565,11 @@ struct LinalgDetensorize if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(canonPatterns)))) signalPassFailure(); + + // Get rid of the dummy entry block we created in the beginning to work + // around dialect conversion signature rewriting. + rewriter.eraseOp(branch); + rewriter.mergeBlocks(postEntryBlock, entryBlock); } }; } // namespace diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir new file mode 100644 index 0000000..d1a8922 --- /dev/null +++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s + +#map = affine_map<() -> ()> +func.func @main(%arg0: tensor) -> tensor { + %0 = tensor.empty() : tensor + %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = []} ins(%arg0 : tensor) outs(%0 : tensor) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor + cf.br ^bb1(%1 : tensor) +^bb1(%2: tensor): // pred: ^bb0 + return %2 : tensor +} + +// CHECK-LABEL: @main +// CHECK-SAME: (%[[ARG0:.+]]: tensor) -> tensor +// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor +// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32) +// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32): +// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor +// CHECK: return %[[ELEMENTS]] : tensor diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir index 455fcfe..6d8d5fe 100644 --- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir @@ -44,8 +44,8 @@ func.func @main() -> () attributes {} { } // CHECK-LABEL: func @main -// CHECK-NEXT: arith.constant 0 : i32 -// CHECK-NEXT: arith.constant 10 +// CHECK-DAG: arith.constant 0 : i32 +// CHECK-DAG: arith.constant 10 // CHECK-NEXT: cf.br ^[[bb1:.*]](%{{.*}} : i32) // CHECK-NEXT: ^[[bb1]](%{{.*}}: i32) // CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} -- 2.7.4