From 1365ff74cb7d9f15feafdd4fbe2996d2f9e42a5e Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 15 May 2023 12:28:21 +0000 Subject: [PATCH] [mlir] allow repeated payload in structured.fuse_into_containing Structured fusion proceeds by iteratively finding the next suitable producer to be fused into the loop. Therefore, it shouldn't matter if the same producer is listed multiple times (e.g., it is used as multiple operands). Adjust the implementation of the transform op to support this case. Also fix the checking code in the interpreter to actually respect the TransformOpInterface indication that repeated payload is allowed; it seems to have been accidentally dropped in one of the refactorings. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D150561 --- .../Linalg/TransformOps/LinalgTransformOps.td | 9 +++--- .../Linalg/TransformOps/LinalgTransformOps.cpp | 9 ++++-- .../Dialect/Transform/IR/TransformInterfaces.cpp | 4 +++ .../Linalg/transform-op-fuse-into-containing.mlir | 36 ++++++++++++++++++++++ 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index cdeabb7..c7bc376 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -143,7 +143,8 @@ def FuseOp : Op, + [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "Fuse a producer into a containing operation."; @@ -160,7 +161,7 @@ def FuseIntoContainingOp : producer op handle may be associated with multiple payload ops. This transform fuses producers one-by-one, always picking an unspecified producer that has at least one use inside the containing op among the - producers. + producers. A producer can be listed multiple times in the handle. 
Note: If a producer has multiple uses inside the containing op, it is currently tiled and/or cloned multiple times into the containing op. @@ -176,8 +177,8 @@ def FuseIntoContainingOp : containing op. I.e., "producers" that are not consumed within the containing op are rejected by this operation. - This operation reads and frees the producer handle. - This operation reads the containing op handle. + This operation consumes the producer handle. + This operation only reads the containing op handle. }]; let arguments = (ins PDL_Operation:$producer_op, diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index afef599..0703ca3 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -571,6 +571,11 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, return fusedOp; } +bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() { + // Allow repeated handles since we are fusing everything anyway. + return true; +} + DiagnosedSilenceableFailure transform::FuseIntoContainingOp::apply(transform::TransformResults &results, transform::TransformState &state) { @@ -591,8 +596,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, // Helper function to find the next producer that should be fused. Take any // producer that has a use inside the containing op. 
- SmallVector remainingProducers(producerOps.begin(), - producerOps.end()); + SetVector remainingProducers(producerOps.begin(), + producerOps.end()); auto getNextProducer = [&]() -> FailureOr { for (const auto &it : enumerate(remainingProducers)) { Operation *producerOp = it.value(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index bad1d74f..5685187e 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -724,6 +724,10 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { FULL_LDBG("--handle not consumed -> SKIP\n"); continue; } + if (transform.allowsRepeatedHandleOperands()) { + FULL_LDBG("--op allows repeated handles -> SKIP\n"); + continue; + } FULL_LDBG("--handle is consumed\n"); Type operandType = operand.get().getType(); diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir index d6b3ff3..537ee86 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -247,3 +247,39 @@ module { transform.structured.fuse_into_containing_op %0 into %1 } } + +// ----- + +module { + // CHECK-LABEL: func.func @fuse_repeated + func.func @fuse_repeated(%fill: tensor<2xf32>, %output: tensor<2xf32>) -> tensor<2xf32> { + %c0 = arith.constant 0.0 : f32 + %0 = linalg.fill ins(%c0 : f32) outs(%fill : tensor<2xf32>) -> tensor<2xf32> + + // CHECK: scf.forall + %1 = scf.forall (%i) in (2) shared_outs(%arg1 = %output) -> (tensor<2xf32>) { + %2 = tensor.extract_slice %0[%i][1][1] : tensor<2xf32> to tensor<1xf32> + %3 = tensor.extract_slice %arg1[%i][1][1] : tensor<2xf32> to tensor<1xf32> + // CHECK: %[[FUSED:.+]] = linalg.fill + // CHECK: elemwise_unary ins(%[[FUSED]] + %4 = linalg.elemwise_unary ins(%2 : tensor<1xf32>) outs(%3 : 
tensor<1xf32>) -> tensor<1xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %4 into %arg1[%i][1][1] : tensor<1xf32> into tensor<2xf32> + } + } + + return %1 : tensor<2xf32> + } + + transform.sequence failures(propagate) { + ^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !pdl.operation + %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !pdl.operation + + // Create a new handle that points to `linalg.fill` twice. + %2 = transform.merge_handles %0, %0 : !pdl.operation + + // It shouldn't be a problem to fuse this handle. + transform.structured.fuse_into_containing_op %2 into %1 + } +} -- 2.7.4