From 08dbed8a5725087a7be71c293e4d86ca0c4c0510 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 19 Aug 2021 11:23:36 +0900 Subject: [PATCH] [mlir][linalg] Canonicalize dim ops of tiled_loop block args E.g.: ``` %y = ... : tensor<...> linalg.tiled_loop ... ins(%x = %y : tensor<...>) { tensor.dim %x, %c0 : tensor<...> } ``` is rewritten to: ``` %y = ... : tensor<...> linalg.tiled_loop ... ins(%x = %y : tensor<...>) { tensor.dim %y, %c0 : tensor<...> } ``` Differential Revision: https://reviews.llvm.org/D108272 --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 55 +++++++++++++++++++++- mlir/test/Dialect/Linalg/canonicalize.mlir | 27 +++++++++++ .../test/Dialect/Linalg/fusion-tensor-pattern.mlir | 8 ++-- 3 files changed, 85 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index a58d91e..6b65a9e 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2064,6 +2064,57 @@ struct TiledLoopInputsFolder : public OpRewritePattern { } }; +/// Fold dim(x) where `x` is an input/output argument of a TiledLoopOp block +/// to dim(y) where `y` is the initial input/output value of the argument. +/// +/// E.g.: +/// %y = ... : tensor<...> +/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) { +/// tensor.dim %x, %c0 : tensor<...> +/// } +/// +/// is folded to: +/// %y = ... : tensor<...> +/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) { +/// tensor.dim %y, %c0 : tensor<...> +/// } +template +struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy dimOp, + PatternRewriter &rewriter) const final { + auto src = dimOp.source().template dyn_cast(); + if (!src) + return failure(); + auto loopOp = dyn_cast( + src.getOwner()->getParent()->getParentOp()); + if (!loopOp) + return failure(); + + auto inputArgs = loopOp.getRegionInputArgs(); + auto it1 = llvm::find(inputArgs, src); + if (it1 != inputArgs.end()) { + rewriter.updateRootInPlace(dimOp, [&] { + dimOp.sourceMutable().assign(loopOp.inputs()[it1 - inputArgs.begin()]); + }); + return success(); + } + + auto outputArgs = loopOp.getRegionOutputArgs(); + auto it2 = llvm::find(outputArgs, src); + if (it2 != outputArgs.end()) { + rewriter.updateRootInPlace(dimOp, [&] { + dimOp.sourceMutable().assign( + loopOp.outputs()[it2 - outputArgs.begin()]); + }); + return success(); + } + + return failure(); + } +}; + // Folds away TiledLoopOp output tensors when the following conditions are met: // * result of `linalg.tiled_loop` has no uses // * output tensor is the argument of `linalg.yield` @@ -2167,7 +2218,9 @@ struct TiledLoopResultsFolder : public OpRewritePattern { void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert, + DimOfTiledLoopInsOutsFolder>(context); } LogicalResult TiledLoopOp::fold(ArrayRef, diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 34dacd5..41a4bfe 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -919,3 +919,30 @@ func @dim_of_pad_tensor(%arg0: tensor, %arg1: tensor, %r = tensor.dim %0, %c0 : tensor return %r : index } + +// ----- + +// CHECK-LABEL: func @dim_of_tiled_loop_input( +// CHECK-SAME: %[[arg0:.*]]: tensor, %[[arg1:.*]]: tensor, %[[arg2:.*]]: tensor +// CHECK: %[[c0:.*]] = constant 0 : index +// CHECK: linalg.tiled_loop +// CHECK: %[[dim:.*]] = tensor.dim %[[arg1]], %[[c0]] +// CHECK: index_cast %[[dim]] +func @dim_of_tiled_loop_input(%arg0: tensor, %arg1: tensor, %arg2: tensor) + -> tensor { + %c0 = constant 0 : index + %c1 = constant 1 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg0, %c1 : tensor + %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0) + to (%d0, %d1) step (%c1, %c1) + ins (%in0 = %arg0 : tensor, %in1 = %arg1 : tensor) + outs (%out1 = %arg2 : tensor) { + %inner_dim = tensor.dim %in1, %c0 : tensor + %cast1 = std.index_cast %inner_dim : index to i32 + %cast2 = std.sitofp %cast1 : i32 to f32 + %fill = linalg.fill(%cast2, %out1) : f32, tensor -> tensor + linalg.yield %fill : tensor + } + return %r : tensor +} diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir index de7c0b7..01dab48 100644 --- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir @@ -121,8 +121,8 @@ module { // TLOOP: %[[AB_SUB:.*]] = linalg.matmul // TLOOP-SAME: ins(%[[A_SUB]], %[[B_]] : {{.*}}) outs(%[[AB_INIT_SUB]] -// TLOOP: %[[DIM_B_1:.*]] = tensor.dim %[[B_]], %[[C1]] : [[TY]] -// TLOOP: %[[DIM_C_1:.*]] = tensor.dim %[[C_]], %[[C1]] : [[TY]] +// TLOOP: %[[DIM_B_1:.*]] = tensor.dim %[[B]], %[[C1]] : [[TY]] +// TLOOP: %[[DIM_C_1:.*]] = tensor.dim %[[C]], %[[C1]] : [[TY]] // TLOOP: %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) = // TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]]) @@ -300,7 +300,7 @@ module { // TLOOP-SAME: %[[C0_F32_:.*]] = %[[C0_F32]] // TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) { -// TLOOP: %[[DIM_A__1:.*]] = tensor.dim %[[A_]], %[[C1]] : [[TY]] +// TLOOP: %[[DIM_A__1:.*]] = tensor.dim %[[A]], %[[C1]] : [[TY]] // TLOOP: %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0] // TLOOP: %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]] // TLOOP: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]] @@ -371,7 +371,7 @@ module { // TLOOP-SAME: %[[C0_F32_:.*]] = %[[C0_F32]] // TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) { -// TLOOP: %[[DIM_A__1:.*]] = tensor.dim %[[A_]], %[[C1]] : [[TY]] +// TLOOP: %[[DIM_A__1:.*]] = tensor.dim %[[A]], %[[C1]] : [[TY]] // TLOOP: %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0] // TLOOP: %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]] // TLOOP: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]] -- 2.7.4