From 7f1c03171ddb7d76bd476bc328afd1fd547128de Mon Sep 17 00:00:00 2001 From: lorenzo chelini Date: Thu, 21 Jul 2022 09:40:30 +0200 Subject: [PATCH] Revert "[RFC][MLIR][SCF] Enable better bufferization for `TileConsumerAndFuseProducersUsingSCFForOp`" This reverts commit 9e6585030533e901a8c24dcb05b38d3f0d10331f. --- .../lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 19 ------------------- .../tile-and-fuse-using-interface.mlir | 6 +++--- 2 files changed, 3 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index c62f27d..3bad543 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -355,23 +355,6 @@ static Optional getFusableProducer(Value v) { return v.cast(); } -// Replace iter args of the outer most loop with region args of the inner most -// one. -static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor, - PatternRewriter &rewriter) { - assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() && - "expect same number of iter args"); - Block *block = &(*innerFor.getRegion().begin()); - for (auto it : - llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) { - Value source = std::get<0>(it); - Value target = std::get<1>(it); - source.replaceUsesWithIf(target, [&](OpOperand &use) { - return use.getOwner()->getBlock() == block; - }); - } -} - FailureOr scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( TilingInterface op, PatternRewriter &rewriter) const { @@ -487,7 +470,5 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( } } } - replaceIterArgs(tileAndFuseResult.loops.front(), - tileAndFuseResult.loops.back(), rewriter); return tileAndFuseResult; } diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir index 8887269..d1ca2d2 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -23,7 +23,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor, %arg1 : tensor) -> // CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) // CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] // CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] // CHECK: %[[FILL_TILE:.+]] = linalg.fill // CHECK-SAME: outs(%[[INIT_TILE]] : // CHECK: %[[GEMM_TILE:.+]] = linalg.matmul @@ -68,7 +68,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor, %arg1 : tensor, // CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) // CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] // CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] // CHECK: %[[FILL_TILE:.+]] = linalg.fill // CHECK-SAME: outs(%[[INIT_TILE]] : // CHECK: %[[GEMM_TILE:.+]] = linalg.matmul @@ -123,7 +123,7 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor, %rhs0 : tensor, %r // CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] : // CHECK-SAME: outs(%[[FILL0_TILE]] : // CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0] -// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG]][%[[IV]], 0] +// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0] // CHECK: %[[FILL1_TILE:.+]] = linalg.fill // CHECK-SAME: outs(%[[INIT1_TILE]] : // CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul -- 2.7.4