#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
return true;
}
+namespace {
+/// Unsafely exposes an internal protected method of TransformState::Extension
+/// as public.
+///
+/// MUST NOT be used directly.
+class UnsafeOpReplacementStateExtension : public TransformState::Extension {
+public:
+ UnsafeOpReplacementStateExtension(TransformState &state)
+ : TransformState::Extension(state) {}
+
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ UnsafeOpReplacementStateExtension)
+
+ LogicalResult doReplacePayloadOp(Operation *op, Operation *replacement) {
+ return replacePayloadOp(op, replacement);
+ }
+};
+} // namespace
+
+/// Replaces `payload` with `replacement` in all handles stored in the state.
+/// MUST NOT be used except for the case immediately below.
+static void forciblyReplaceReferencedPayloadOperation(TransformState &state,
+ Operation *payload,
+ Operation *replacement) {
+ UnsafeOpReplacementStateExtension extension(state);
+ // This may return failure if the payload is not associated with any handle,
+ // ignore that.
+ (void)extension.doReplacePayloadOp(payload, replacement);
+}
+
DiagnosedSilenceableFailure
transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
+ // Update handles associated with the containing op so we don't need to
+ // invalidate them. This is a hack to support better composability between
+ // tiling and fusion while a proper mechanism is being investigated.
+ //
+ // DO NOT replicate this elsewhere unless you understand what you are doing.
+ forciblyReplaceReferencedPayloadOperation(state, *containingOps.begin(),
+ containingOp);
+
results.set(cast<OpResult>(getFusedOp()), fusedOps);
results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
return DiagnosedSilenceableFailure::success();
void transform::FuseIntoContainingOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getProducerOp(), effects);
- consumesHandle(getContainingOp(), effects);
+ onlyReadsHandle(getContainingOp(), effects);
producesHandle(getResults(), effects);
modifiesPayload(effects);
}
// CHECK: transform.structured.scalarize
%0 = transform.structured.scalarize %arg0 : (!transform.any_op) -> !transform.any_op
}
+
+// Check that the second argument of `fuse_into_containing_op` is not consumed
+// (if it had been, we would have seen a diagnostic about multiple consumers).
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %0:2 = transform.structured.fuse_into_containing_op %arg1 into %loop
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %1:2 = transform.structured.fuse_into_containing_op %arg2 into %loop
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+}