def FuseIntoContainingOp :
Op<Transform_Dialect, "structured.fuse_into_containing_op",
- [DeclareOpInterfaceMethods<TransformOpInterface>,
+ [DeclareOpInterfaceMethods<TransformOpInterface,
+ ["allowsRepeatedHandleOperands"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Fuse a producer into a containing operation.";
producer op handle may be associated with multiple payload ops. This
transform fuses producers one-by-one, always picking an unspecified producer
that has at least one use inside the containing op among the
- producers.
+ producers. A producer can be listed multiple times in the handle.
Note: If a producer has multiple uses inside the containing op, it is
currently tiled and/or cloned multiple times into the containing op.
containing op. I.e., "producers" that are not consumed within the containing
op are rejected by this operation.
- This operation reads and frees the producer handle.
- This operation reads the containing op handle.
+ This operation consumes the producer handle.
+ This operation only reads the containing op handle.
}];
let arguments = (ins PDL_Operation:$producer_op,
return fusedOp;
}
+/// TransformOpInterface hook: reports that this op accepts repeated handle
+/// operands. Always returns true.
+bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
+ // Allow repeated handles since we are fusing everything anyway. The op
+ // description states that a producer can be listed multiple times in the
+ // producer handle.
+ // NOTE(review): `apply` collects the producers into a SetVector; presumably
+ // that is what makes a repeated producer harmless -- confirm.
+ return true;
+}
+
DiagnosedSilenceableFailure
transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
// Helper function to find the next producer that should be fused. Take any
// producer that has a use inside the containing op.
- SmallVector<Operation *> remainingProducers(producerOps.begin(),
- producerOps.end());
+ SetVector<Operation *> remainingProducers(producerOps.begin(),
+ producerOps.end());
auto getNextProducer = [&]() -> FailureOr<Operation *> {
for (const auto &it : enumerate(remainingProducers)) {
Operation *producerOp = it.value();
FULL_LDBG("--handle not consumed -> SKIP\n");
continue;
}
+ if (transform.allowsRepeatedHandleOperands()) {
+ FULL_LDBG("--op allows repeated handles -> SKIP\n");
+ continue;
+ }
FULL_LDBG("--handle is consumed\n");
Type operandType = operand.get().getType();
transform.structured.fuse_into_containing_op %0 into %1
}
}
+
+// -----
+
+// Test case: the producer handle given to fuse_into_containing_op refers to
+// the same linalg.fill twice (it is built by merging a handle with itself).
+// The fill is expected to end up inside the scf.forall body, feeding the
+// elemwise_unary op.
+module {
+ // CHECK-LABEL: func.func @fuse_repeated
+ func.func @fuse_repeated(%fill: tensor<2xf32>, %output: tensor<2xf32>) -> tensor<2xf32> {
+ %c0 = arith.constant 0.0 : f32
+ // Producer: defined outside the loop, consumed inside it through %2.
+ %0 = linalg.fill ins(%c0 : f32) outs(%fill : tensor<2xf32>) -> tensor<2xf32>
+
+ // CHECK: scf.forall
+ %1 = scf.forall (%i) in (2) shared_outs(%arg1 = %output) -> (tensor<2xf32>) {
+ // The only use of the producer inside the containing op.
+ %2 = tensor.extract_slice %0[%i][1][1] : tensor<2xf32> to tensor<1xf32>
+ %3 = tensor.extract_slice %arg1[%i][1][1] : tensor<2xf32> to tensor<1xf32>
+ // CHECK: %[[FUSED:.+]] = linalg.fill
+ // CHECK: elemwise_unary ins(%[[FUSED]]
+ %4 = linalg.elemwise_unary ins(%2 : tensor<1xf32>) outs(%3 : tensor<1xf32>) -> tensor<1xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %4 into %arg1[%i][1][1] : tensor<1xf32> into tensor<2xf32>
+ }
+ }
+
+ return %1 : tensor<2xf32>
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ // %0 matches the producer (linalg.fill), %1 the containing op (scf.forall).
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !pdl.operation
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !pdl.operation
+
+ // Create a new handle that points to `linalg.fill` twice.
+ %2 = transform.merge_handles %0, %0 : !pdl.operation
+
+ // Fusing must succeed even though the producer handle lists the same
+ // payload op twice.
+ transform.structured.fuse_into_containing_op %2 into %1
+ }
+}