From 83b582d51b742ad4a3e2b10e55058508b0e1ebc6 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 27 Dec 2022 06:14:58 -0800 Subject: [PATCH] [mlir][Linalg] Properly propagate transform result in ScalarizeOp --- .../Linalg/TransformOps/LinalgTransformOps.cpp | 20 ++++++++++++-------- mlir/test/Dialect/Linalg/transform-op-scalarize.mlir | 10 +++++++++- mlir/test/Dialect/Linalg/transform-ops.mlir | 2 +- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 347c530..5660891 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -67,14 +67,14 @@ DiagnosedSilenceableFailure transform::DecomposeOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { -#define DOWNSCALE(trans) \ - { \ - FailureOr res = tryApply(target); \ - if (succeeded(res)) { \ - results.push_back(*res); \ - return DiagnosedSilenceableFailure::success(); \ - } \ - } +#define DOWNSCALE(trans) \ + { \ + FailureOr res = tryApply(target); \ + if (succeeded(res)) { \ + results.push_back(*res); \ + return DiagnosedSilenceableFailure::success(); \ + } \ + } #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b)) @@ -986,6 +986,10 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, if (failed(maybeTilingResult)) return emitDefaultDefiniteFailure(target); + if (target->getNumResults()) + rewriter.replaceOp(target, maybeTilingResult->replacements); + else + rewriter.eraseOp(target); results.append(maybeTilingResult->tiledOps); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir index 89c8d32..fbf083c 100644 --- a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir @@ -5,8 +5,16 @@ func.func @scalarize(%arg0: tensor<24x12xf32>, %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { // The op is first tiled by 10 in the first dimension, which creates a // dynamic size, and then scalarized, which brings the dimension to static 1. - // CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x12 + // CHECK: %[[RES_LOOP_1:.*]] = scf.for {{.*}} -> (tensor<24x25xf32>) + // CHECK: %[[RES_LOOP_2:.*]] = scf.for {{.*}} -> (tensor) + // CHECK: %[[MM:.*]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x12 + // CHECK: %[[INS_2:.*]] = tensor.insert_slice %[[MM]] into %{{.*}} [1, 25] [1, 1] : tensor<1x25xf32> into tensor + // CHECK: scf.yield %[[INS_2]] : tensor + // CHECK: %[[INS_1:.*]] = tensor.insert_slice %[[RES_LOOP_2]] into %{{.*}}, 25] [1, 1] : tensor into tensor<24x25xf32> + // CHECK: scf.yield %[[INS_1]] : tensor<24x25xf32> %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + + // CHECK: return %[[RES_LOOP_1]] : tensor<24x25xf32> func.return %0 : tensor<24x25xf32> } diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir index 898cce7..64cf3fb 100644 --- a/mlir/test/Dialect/Linalg/transform-ops.mlir +++ b/mlir/test/Dialect/Linalg/transform-ops.mlir @@ -8,7 +8,7 @@ transform.sequence failures(propagate) { //===----------------------------------------------------------------------===// // Check that operations are registered correctly through the extension -// mechanism. Their syntax is generated and requries no additional testing since +// mechanism. Their syntax is generated and requires no additional testing since // we test the generator. //===----------------------------------------------------------------------===// -- 2.7.4