From fa0d044c4499535fb7960a5b7053bd043ad09e52 Mon Sep 17 00:00:00 2001
From: Alexander Belyaev
Date: Wed, 28 Apr 2021 19:48:27 +0200
Subject: [PATCH] [mlir] Fix canonicalization of tiled_loop if not all op
 results fold.

The current canonicalization did not remap the op results correctly and
attempted to erase tiledLoop, which is incorrect if not all tensor
results are folded.
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 22 ++++++--
 mlir/test/Dialect/Linalg/canonicalize.mlir | 86 ++++++++++++++----------
 2 files changed, 58 insertions(+), 50 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 17ecab1..750d818 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2094,7 +2094,7 @@ namespace {
 
 static constexpr int64_t kNoMatch = -1;
 
-// Folds away TiledLoopOp input tensors if they have no uses within the body.
+// Folds away TiledLoopOp inputs if they have no uses within the body.
 //
 // Example:
 //
@@ -2117,7 +2117,7 @@ struct TiledLoopInputsFolder : public OpRewritePattern<TiledLoopOp> {
       Value in, bbArg;
       size_t index = en.index();
       std::tie(in, bbArg) = en.value();
-      if (!in.getType().isa<TensorType>() || !bbArg.use_empty()) {
+      if (!bbArg.use_empty()) {
         oldInputIdToNew[index] = newInputs.size();
         newInputs.push_back(in);
       }
@@ -2142,7 +2142,7 @@ struct TiledLoopInputsFolder : public OpRewritePattern<TiledLoopOp> {
         OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
     for (auto &op : *tiledLoop.getBody())
       innerBuilder.clone(op, bvm);
-    rewriter.eraseOp(tiledLoop);
+    rewriter.replaceOp(tiledLoop, newTiledLoop.getResults());
 
     return success();
   }
@@ -2184,6 +2184,10 @@ struct TiledLoopResultsFolder : public OpRewritePattern<TiledLoopOp> {
     // Store ids of the corresponding old and new output operands.
     SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(),
                                              kNoMatch);
+    // Store ids of the corresponding old and new results.
+    SmallVector<int64_t, 2> oldResultIdToNew(tiledLoop.getNumResults(),
+                                             kNoMatch);
+    SmallVector<Value, 2> resultReplacement(tiledLoop.getNumResults());
     for (auto en : llvm::enumerate(
              llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) {
       size_t index = en.index();
@@ -2199,6 +2203,8 @@ struct TiledLoopResultsFolder : public OpRewritePattern<TiledLoopOp> {
       Value yieldArg = yieldOp.getOperand(resultId);
       if (yieldArg != outRegionArg || !result.use_empty()) {
         oldOutputIdToNew[index] = newOutputOperands.size();
+        oldResultIdToNew[resultId] = newYieldArgs.size();
+        resultReplacement[resultId] = out;
         newOutputOperands.push_back(out);
         newYieldArgs.push_back(yieldArg);
       }
@@ -2228,8 +2234,14 @@ struct TiledLoopResultsFolder : public OpRewritePattern<TiledLoopOp> {
         OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
     for (auto &op : tiledLoop.getBody()->without_terminator())
       innerBuilder.clone(op, bvm);
-    innerBuilder.create<linalg::YieldOp>(loc, newYieldArgs);
-    rewriter.eraseOp(tiledLoop);
+    innerBuilder.create<linalg::YieldOp>(
+        loc, llvm::to_vector<2>(llvm::map_range(
+                 newYieldArgs, [&](Value arg) { return bvm.lookup(arg); })));
+
+    for (const auto &en : llvm::enumerate(oldResultIdToNew))
+      if (en.value() != kNoMatch)
+        resultReplacement[en.index()] = newTiledLoop.getResult(en.value());
+    rewriter.replaceOp(tiledLoop, resultReplacement);
 
     return success();
   }
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index e66ee38..244b78f 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -867,75 +867,71 @@ func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32>
 }
 
-// -----
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+// -----
 
-func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
-                  %C: memref<192x192xf32>) -> ()
+func private @foo(%A: memref<48xf32>, %B: tensor<48xf32>,
+                  %C: memref<48xf32>) -> (tensor<48xf32>)
 
-func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
-                              %C: memref<192x192xf32>,
-                              %C_tensor: tensor<192x192xf32>) {
-  %cst = constant 0.000000e+00 : f32
-  %c24 = constant 24 : index
-  %c16 = constant 16 : index
+func @fold_tiled_loop_results(%A: memref<48xf32>, %B: tensor<48xf32>,
+                              %C: memref<48xf32>, %C_tensor: tensor<48xf32>) -> tensor<48xf32> {
   %c0 = constant 0 : index
-  %c192 = constant 192 : index
-  %useless = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
-      step (%c24, %c16)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
-      outs (%CT_ = %C_tensor: tensor<192x192xf32>,
-            %C_ = %C: memref<192x192xf32>) {
-    call @foo(%A_, %B_, %C_)
-      : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
-    linalg.yield %CT_ : tensor<192x192xf32>
+  %c24 = constant 24 : index
+  %c48 = constant 48 : index
+  %useful, %useless = linalg.tiled_loop (%i) = (%c0) to (%c48) step (%c24)
+      ins (%A_ = %A: memref<48xf32>)
+      outs (%B_ = %B: tensor<48xf32>,
+            %CT_ = %C_tensor: tensor<48xf32>,
+            %C_ = %C: memref<48xf32>) {
+    %result = call @foo(%A_, %B_, %C_)
+      : (memref<48xf32>, tensor<48xf32>, memref<48xf32>)-> (tensor<48xf32>)
+    linalg.yield %result, %CT_ : tensor<48xf32>, tensor<48xf32>
   }
-  return
+  return %useful : tensor<48xf32>
 }
 
 // CHECK-LABEL: func @fold_tiled_loop_results(
-// CHECK-SAME:   %[[A:.*]]: [[TY:.*]], %[[B:.*]]: [[TY]], %[[C:.*]]: [[TY]],
-// CHECK-SAME:   %[[C_TENSOR:.*]]: tensor<{{.*}}>) {
-// CHECK: %[[C24:.*]] = constant 24 : index
-// CHECK: %[[C16:.*]] = constant 16 : index
+// CHECK-SAME:   %[[A:.*]]: [[BUF_TY:memref<48xf32>]], %[[B:.*]]: [[TY:tensor<48xf32>]],
+// CHECK-SAME:   %[[C:.*]]: [[BUF_TY]], %[[C_TENSOR:.*]]: [[TY]]) -> [[TY]] {
+
 // CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[C192:.*]] = constant 192 : index
+// CHECK: %[[C24:.*]] = constant 24 : index
+// CHECK: %[[C48:.*]] = constant 48 : index
 
 // CHECK-NOT: %{{.*}} = linalg.tiled_loop
-// CHECK: linalg.tiled_loop (%{{.*}}, %{{.*}}) = (%[[C0]], %[[C0]])
-// CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]])
-// CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: memref<192x192xf32>, %[[B_:.*]] = %[[B]]: memref<192x192xf32>)
-// CHECK-SAME: outs (%[[C_:.*]] = %[[C]]: memref<192x192xf32>) {
-// CHECK-NEXT: call @foo(%[[A_]], %[[B_]], %[[C_]])
-// CHECK-NEXT: linalg.yield
+// CHECK: %[[RESULT:.*]] = linalg.tiled_loop (%{{.*}}) = (%[[C0]])
+// CHECK-SAME: to (%[[C48]]) step (%[[C24]])
+// CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: [[BUF_TY]])
+// CHECK-SAME: outs (%[[B_:.*]] = %[[B]]: [[TY]], %[[C_:.*]] = %[[C]]: [[BUF_TY]]) {
+// CHECK-NEXT: %[[RES:.*]] = call @foo(%[[A_]], %[[B_]], %[[C_]])
+// CHECK-NEXT: linalg.yield %[[RES]] :
 
-// -----
+// CHECK: return %[[RESULT]]
 
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+// -----
 
-func private @foo(%A: memref<192xf32>) -> ()
+func private @foo(%A: memref<192xf32>, %B: tensor<192xf32>) -> tensor<192xf32>
 
-func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>) {
+func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,
+                             %B_tensor: tensor<192xf32>) -> tensor<192xf32> {
   %c0 = constant 0 : index
   %c24 = constant 24 : index
   %c192 = constant 192 : index
-  linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
-      ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>) {
-    call @foo(%A_) : (memref<192xf32>)-> ()
-    linalg.yield
+  %result = linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
+      ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>)
+      outs (%BT_ = %B_tensor: tensor<192xf32>) {
+    %0 = call @foo(%A_, %BT_) : (memref<192xf32>, tensor<192xf32>) -> tensor<192xf32>
+    linalg.yield %0 : tensor<192xf32>
   }
-  return
+  return %result : tensor<192xf32>
 }
 
 // CHECK-LABEL: func @fold_tiled_loop_inputs
-// CHECK: linalg.tiled_loop
+// CHECK: %[[RESULT:.*]] = linalg.tiled_loop
 // CHECK-SAME: ins (%{{.*}} = %{{.*}}: memref<192xf32>)
+// CHECK: return %[[RESULT]]
+
 
 // -----
 
 func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
--
2.7.4
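
P.S. A minimal sketch of the case the new result remapping handles, reusing
the tiled_loop from the first test above (this illustration is explanatory and
not part of the applied diff). Only %useless folds: its yield operand is the
unchanged output block argument %CT_ and the result has no uses, while
%useful is still live:

    %useful, %useless = linalg.tiled_loop (%i) = (%c0) to (%c48) step (%c24)
        ins (%A_ = %A: memref<48xf32>)
        outs (%B_ = %B: tensor<48xf32>,
              %CT_ = %C_tensor: tensor<48xf32>,
              %C_ = %C: memref<48xf32>) {
      %result = call @foo(%A_, %B_, %C_)
          : (memref<48xf32>, tensor<48xf32>, memref<48xf32>)-> (tensor<48xf32>)
      linalg.yield %result, %CT_ : tensor<48xf32>, tensor<48xf32>
    }

The rewritten loop keeps only the %B_ and %C_ outputs and thus has a single
tensor result. The old code called rewriter.eraseOp(tiledLoop), which leaves
the live %useful without a replacement; the new code fills resultReplacement
so that %useful maps to the new loop's result (the slot for %useless is
irrelevant, since that result has no uses) and hands the mapping to
rewriter.replaceOp.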