[mlir][Linalg][Transform] Fix effect on RewriteInDestinationPassingStyleOp that did...
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 21 Mar 2023 08:04:04 +0000 (01:04 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 21 Mar 2023 09:17:45 +0000 (02:17 -0700)
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

index 712abf3..c58e955 100644 (file)
@@ -83,8 +83,10 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
 //===----------------------------------------------------------------------===//
 
 def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait]> {
+    [FunctionalStyleTransformOpTrait, 
+     MemoryEffectsOpInterface,
+     TransformOpInterface, 
+     TransformEachOpTrait]> {
   let description = [{
     Decomposes named complex operations, such as higher-dimensional
     (depthwise) convolutions, into combinations of lower-dimensional equivalents
@@ -932,9 +934,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
 
 def RewriteInDestinationPassingStyleOp : Op<
     Transform_Dialect, "structured.rewrite_in_destination_passing_style",
-    [MemoryEffectsOpInterface,
-     NavigationTransformOpTrait,
-     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+    [FunctionalStyleTransformOpTrait, 
+     MemoryEffectsOpInterface,
+     TransformOpInterface, 
+     TransformEachOpTrait]> {
   let description = [{
     Rewrite a supported tensor operation that is not in destination-passing style
     into a form that is in destination-passing style.
@@ -963,6 +966,13 @@ def RewriteInDestinationPassingStyleOp : Op<
     $target attr-dict
     `:` functional-type($target, results)
   }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
index 407b8d2..d98eb3b 100644 (file)
@@ -2000,24 +2000,21 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform::RewriteInDestinationPassingStyleOp::apply(
-    transform::TransformResults &results, transform::TransformState &state) {
+transform::RewriteInDestinationPassingStyleOp::applyToOne(
+    Operation *target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
   SmallVector<Operation *> res;
-  ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
-  for (Operation *target : targetOps) {
-    IRRewriter rewriter(target->getContext());
-    rewriter.setInsertionPoint(target);
-    FailureOr<Operation *> maybeResult =
-        TypeSwitch<Operation *, FailureOr<Operation *>>(target)
-            .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
-                [&rewriter](auto op) {
-                  return rewriteInDestinationPassingStyle(rewriter, op);
-                });
-    if (failed(maybeResult))
-      return emitDefaultSilenceableFailure(target);
-    res.push_back(*maybeResult);
-  }
-  results.set(getResult().cast<OpResult>(), res);
+  IRRewriter rewriter(target->getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<Operation *> maybeResult =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+              [&rewriter](auto op) {
+                return rewriteInDestinationPassingStyle(rewriter, op);
+              });
+  if (failed(maybeResult))
+    return emitDefaultSilenceableFailure(target);
+  results.push_back(*maybeResult);
   return DiagnosedSilenceableFailure::success();
 }