[mlir][Linalg] Support tensor.parallel_insert_slice in transform.insert_slice_to_copy
author Nicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 14 Apr 2023 12:43:07 +0000 (05:43 -0700)
committer Nicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 14 Apr 2023 13:11:29 +0000 (06:11 -0700)
Differential Revision: https://reviews.llvm.org/D148333

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir

index af76660..9366ce7 100644 (file)
@@ -2030,7 +2030,7 @@ def InsertSliceToCopyOp :
   ];
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::tensor::InsertSliceOp target,
+        ::mlir::Operation *target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
index 39f7802..8e66702 100644 (file)
@@ -38,6 +38,7 @@
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include <type_traits>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -3214,18 +3215,27 @@ transform::HoistRedundantTensorSubsetsOp::applyToOne(
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
+template <typename OpTy>
+DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
+                                 transform::ApplyToEachResultList &results,
+                                 transform::TransformState &state) {
+  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
+                                tensor::ParallelInsertSliceOp>() &&
+                "wrong op type");
 
-DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
-    tensor::InsertSliceOp target, transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  if (auto copySource = target.getSource().getDefiningOp<linalg::CopyOp>()) {
+  if (auto copySource =
+          target.getSource().template getDefiningOp<linalg::CopyOp>()) {
     results.push_back(copySource);
     return DiagnosedSilenceableFailure::success();
   }
 
-  TrackingListener listener(state, *this);
-  IRRewriter rewriter(target->getContext(), &listener);
-  rewriter.setInsertionPoint(target);
+  // If we are inside an InParallel region, temporarily set the insertion point
+  // outside: only tensor.parallel_insert_slice ops are allowed in there.
+  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
+    rewriter.setInsertionPoint(
+        target->template getParentOfType<scf::InParallelOp>());
+  }
+
   Value extracted = rewriter.create<tensor::ExtractSliceOp>(
       target.getLoc(), target.getDest(), target.getMixedOffsets(),
       target.getMixedSizes(), target.getMixedStrides());
@@ -3233,7 +3243,9 @@ DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
                      .create<linalg::CopyOp>(target.getLoc(),
                                              target.getSource(), extracted)
                      .getResult(0);
-  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+  // Reset the insertion point.
+  rewriter.setInsertionPoint(target);
+  rewriter.replaceOpWithNewOp<OpTy>(
       target, copied, target.getDest(), target.getMixedOffsets(),
       target.getMixedSizes(), target.getMixedStrides());
 
@@ -3241,6 +3253,25 @@ DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Applies the insert_slice_to_copy rewrite to a single target op.
+///
+/// `targetOp` is dispatched on its concrete type: tensor::InsertSliceOp and
+/// tensor::ParallelInsertSliceOp are forwarded to the templated `doit` helper,
+/// which appends to `results`. Any other op produces a silenceable error with
+/// a note attached at the location of the offending op.
+DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
+    Operation *targetOp, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+
+  // The rewriter is constructed with a TrackingListener built from the
+  // transform state and this transform op.
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(targetOp->getContext(), &listener);
+  // Default insertion point is the target itself; `doit` moves it for
+  // tensor.parallel_insert_slice targets.
+  rewriter.setInsertionPoint(targetOp);
+  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
+    return doit(rewriter, target, results, state);
+  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
+    return doit(rewriter, target, results, state);
+
+  // Neither supported op type matched: report a silenceable failure and point
+  // at the unsupported op.
+  DiagnosedSilenceableFailure diag =
+      emitSilenceableError()
+      << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
+  diag.attachNote(targetOp->getLoc()) << "target op";
+  return diag;
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
index 7c2461c..e6b2d2b 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file | FileCheck %s
+// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file --allow-unregistered-dialect | FileCheck %s
 
 // CHECK-LABEL: func @insert_slice_to_copy
     // CHECK-SAME: %[[I:.*]]: tensor<2x3xf32>
@@ -108,3 +108,30 @@ transform.sequence failures(propagate) {
   transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
 }
 
+// -----
+
+// Test case for tensor.parallel_insert_slice targets: the FileCheck directives
+// below expect a tensor.extract_slice and a linalg.copy to appear ahead of
+// scf.forall.in_parallel, with the tensor.parallel_insert_slice still nested
+// inside it.
+// CHECK-LABEL: func @parallel_insert_slice_to_copy
+func.func @parallel_insert_slice_to_copy(%out : tensor<?x?xf32>, %sz0: index, %sz1: index) {
+  %0 = scf.forall (%arg0, %arg1) in (27, 8) shared_outs(%arg2 = %out) -> (tensor<?x?xf32>) {
+    // Opaque producer of the slice source. The op is unregistered, which is why
+    // the run line of this file passes --allow-unregistered-dialect.
+    %t = "make_me_a_tensor"() : () -> (tensor<?x?xf32> )
+
+    //      CHECK: tensor.extract_slice
+    //      CHECK: linalg.copy
+    //      CHECK: scf.forall.in_parallel
+    //      CHECK:   tensor.parallel_insert_slice
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %t into %arg2[0, 0] [%sz0, %sz1] [1, 1] 
+        : tensor<?x?xf32> into tensor<?x?xf32>
+    }
+  }
+  return
+}
+
+// Transform script for the test above: match the tensor.parallel_insert_slice
+// ops under the root handle, apply insert_slice_to_copy to them, then cast the
+// returned handle to the linalg.copy op type.
+// NOTE(review): presumably the cast fails (and the failure propagates) when
+// the returned handle does not point at linalg.copy ops -- confirm.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.insert_slice_to_copy %0
+    : (!transform.any_op) -> !transform.any_op
+  transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
+}