From 9437bf418a7fdb9a1079f416dd28bb7107161d74 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 21 Mar 2023 01:04:04 -0700 Subject: [PATCH] [mlir][Linalg][Transform] Fix effect on RewriteInDestinationPassingStyleOp that did not consume its operand --- .../Linalg/TransformOps/LinalgTransformOps.td | 20 ++++++++++---- .../Linalg/TransformOps/LinalgTransformOps.cpp | 31 ++++++++++------------ 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 712abf3..c58e955 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -83,8 +83,10 @@ def BufferizeToAllocationOp : Op { + [FunctionalStyleTransformOpTrait, + MemoryEffectsOpInterface, + TransformOpInterface, + TransformEachOpTrait]> { let description = [{ Decomposes named complex operations, such as higher-dimensional (depthwise) convolutions, into combinations of lower-dimensional equivalents @@ -932,9 +934,10 @@ def ScalarizeOp : Op]> { + [FunctionalStyleTransformOpTrait, + MemoryEffectsOpInterface, + TransformOpInterface, + TransformEachOpTrait]> { let description = [{ Rewrite a supported tensor operation that is not in destination-passing style into a form that is in destination-passing style. @@ -963,6 +966,13 @@ def RewriteInDestinationPassingStyleOp : Op< $target attr-dict `:` functional-type($target, results) }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 407b8d2..d98eb3b 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2000,24 +2000,21 @@ transform::ScalarizeOp::applyToOne(LinalgOp target, //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::RewriteInDestinationPassingStyleOp::apply( - transform::TransformResults &results, transform::TransformState &state) { +transform::RewriteInDestinationPassingStyleOp::applyToOne( + Operation *target, transform::ApplyToEachResultList &results, + transform::TransformState &state) { SmallVector res; - ArrayRef targetOps = state.getPayloadOps(getTarget()); - for (Operation *target : targetOps) { - IRRewriter rewriter(target->getContext()); - rewriter.setInsertionPoint(target); - FailureOr maybeResult = - TypeSwitch>(target) - .Case( - [&rewriter](auto op) { - return rewriteInDestinationPassingStyle(rewriter, op); - }); - if (failed(maybeResult)) - return emitDefaultSilenceableFailure(target); - res.push_back(*maybeResult); - } - results.set(getResult().cast(), res); + IRRewriter rewriter(target->getContext()); + rewriter.setInsertionPoint(target); + FailureOr maybeResult = + TypeSwitch>(target) + .Case( + [&rewriter](auto op) { + return rewriteInDestinationPassingStyle(rewriter, op); + }); + if (failed(maybeResult)) + return emitDefaultSilenceableFailure(target); + results.push_back(*maybeResult); return DiagnosedSilenceableFailure::success(); } -- 2.7.4