#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include <type_traits>
using namespace mlir;
using namespace mlir::linalg;
//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
+// Rewrite `target` (a tensor.insert_slice or tensor.parallel_insert_slice)
+// into an explicit-copy form: extract_slice of the destination, linalg.copy
+// of the source into the extracted slice, then a fresh insert of the copy
+// result. If the source already comes from a linalg.copy, that op is simply
+// reported in `results` and the IR is left unchanged.
+template <typename OpTy>
+DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
+                                 transform::ApplyToEachResultList &results,
+                                 transform::TransformState &state) {
+  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
+                                tensor::ParallelInsertSliceOp>() &&
+                "wrong op type");
-DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
-    tensor::InsertSliceOp target, transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  if (auto copySource = target.getSource().getDefiningOp<linalg::CopyOp>()) {
+  // `template` disambiguator is required: `target` has a dependent type.
+  if (auto copySource =
+          target.getSource().template getDefiningOp<linalg::CopyOp>()) {
    results.push_back(copySource);
    return DiagnosedSilenceableFailure::success();
  }
-  TrackingListener listener(state, *this);
-  IRRewriter rewriter(target->getContext(), &listener);
-  rewriter.setInsertionPoint(target);
+  // If we are inside an InParallel region, temporarily set the insertion point
+  // outside: only tensor.parallel_insert_slice ops are allowed in there.
+  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
+    rewriter.setInsertionPoint(
+        target->template getParentOfType<scf::InParallelOp>());
+  }
+
  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
      target.getLoc(), target.getDest(), target.getMixedOffsets(),
      target.getMixedSizes(), target.getMixedStrides());
  Value copied = rewriter
                     .create<linalg::CopyOp>(target.getLoc(),
                                             target.getSource(), extracted)
                     .getResult(0);
-  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+  // Reset the insertion point.
+  rewriter.setInsertionPoint(target);
+  rewriter.replaceOpWithNewOp<OpTy>(
      target, copied, target.getDest(), target.getMixedOffsets(),
      target.getMixedSizes(), target.getMixedStrides());
  return DiagnosedSilenceableFailure::success();
}
+// Transform-op entry point: dispatches to the templated `doit` above for
+// tensor.insert_slice and tensor.parallel_insert_slice targets; any other
+// op kind yields a silenceable error with a note pointing at the target.
+DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
+ Operation *targetOp, transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+
+ // Rewriter with a tracking listener so transform handles stay up to date
+ // across the replacement.
+ TrackingListener listener(state, *this);
+ IRRewriter rewriter(targetOp->getContext(), &listener);
+ rewriter.setInsertionPoint(targetOp);
+ if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
+ return doit(rewriter, target, results, state);
+ if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
+ return doit(rewriter, target, results, state);
+
+ // Unsupported target op kind: report, but allow the interpreter to recover.
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
+ diag.attachNote(targetOp->getLoc()) << "target op";
+ return diag;
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
-// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file | FileCheck %s
+// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file --allow-unregistered-dialect | FileCheck %s
// CHECK-LABEL: func @insert_slice_to_copy
// CHECK-SAME: %[[I:.*]]: tensor<2x3xf32>
transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
}
+// -----
+
+// New test: a tensor.parallel_insert_slice inside scf.forall.in_parallel is
+// rewritten to extract_slice + linalg.copy, with the copy hoisted out of the
+// in_parallel region (only parallel_insert_slice ops may live in there).
+// CHECK-LABEL: func @parallel_insert_slice_to_copy
+func.func @parallel_insert_slice_to_copy(%out : tensor<?x?xf32>, %sz0: index, %sz1: index) {
+ %0 = scf.forall (%arg0, %arg1) in (27, 8) shared_outs(%arg2 = %out) -> (tensor<?x?xf32>) {
+ // Unregistered op that produces a tensor; requires --allow-unregistered-dialect.
+ %t = "make_me_a_tensor"() : () -> (tensor<?x?xf32> )
+
+ // CHECK: tensor.extract_slice
+ // CHECK: linalg.copy
+ // CHECK: scf.forall.in_parallel
+ // CHECK: tensor.parallel_insert_slice
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %t into %arg2[0, 0] [%sz0, %sz1] [1, 1]
+ : tensor<?x?xf32> into tensor<?x?xf32>
+ }
+ }
+ return
+}
+
+// Driver: match every tensor.parallel_insert_slice and run the rewrite; the
+// trailing cast asserts the transform returned a handle to a linalg.copy.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.insert_slice_to_copy %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
+}