include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
+include "mlir/IR/RegionKindInterface.td"
def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
}];
}
+def ReplaceOp : Op<Transform_Dialect, "structured.replace",
+ [IsolatedFromAbove, DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>] # GraphRegionNoTerminator.traits> {
+ let description = [{
+ Replace all `target` payload ops with the single op that is contained in
+ this op's region. All targets must have zero arguments and must be isolated
+ from above.
+
+ This op is for debugging/experiments only.
+
+ #### Return modes
+
+ This operation consumes the `target` handle.
+ }];
+
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$replacement);
+ let regions = (region SizedRegion<1>:$bodyRegion);
+ let assemblyFormat = "$target attr-dict-with-keyword regions";
+ let hasVerifier = 1;
+}
+
def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringSet.h"
}
//===----------------------------------------------------------------------===//
+// ReplaceOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ReplaceOp::apply(TransformResults &transformResults,
+ TransformState &state) {
+ ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
+
+ // Check for invalid targets.
+ for (Operation *target : payload) {
+ if (target->getNumOperands() > 0)
+ return emitDefiniteFailure() << "expected target without operands";
+ if (!target->hasTrait<IsIsolatedFromAbove>() && target->getNumRegions() > 0)
+ return emitDefiniteFailure()
+ << "expected target that is isloated from above";
+ }
+
+ // Clone and replace.
+ IRRewriter rewriter(getContext());
+ Operation *pattern = &getBodyRegion().front().front();
+ SmallVector<Operation *> replacements;
+ for (Operation *target : payload) {
+ if (getOperation()->isAncestor(target))
+ continue;
+ rewriter.setInsertionPoint(target);
+ Operation *replacement = rewriter.clone(*pattern);
+ rewriter.replaceOp(target, replacement->getResults());
+ replacements.push_back(replacement);
+ }
+ transformResults.set(getReplacement().cast<OpResult>(), replacements);
+ return DiagnosedSilenceableFailure(success());
+}
+
+void transform::ReplaceOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTarget(), effects);
+ producesHandle(getReplacement(), effects);
+ modifiesPayload(effects);
+}
+
+LogicalResult transform::ReplaceOp::verify() {
+ if (!getBodyRegion().hasOneBlock())
+ return emitOpError() << "expected one block";
+ if (std::distance(getBodyRegion().front().begin(),
+ getBodyRegion().front().end()) != 1)
+ return emitOpError() << "expected one operation in block";
+ Operation *replacement = &getBodyRegion().front().front();
+ if (replacement->getNumOperands() > 0)
+ return replacement->emitOpError()
+ << "expected replacement without operands";
+ if (!replacement->hasTrait<IsIsolatedFromAbove>() &&
+ replacement->getNumRegions() > 0)
+ return replacement->emitOpError()
+ << "expect op that is isolated from above";
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// ScalarizeOp
//===----------------------------------------------------------------------===//
--- /dev/null
+// RUN: mlir-opt -test-transform-dialect-interpreter %s -allow-unregistered-dialect -verify-diagnostics --split-input-file | FileCheck %s
+
+// CHECK: func.func @foo() {
+// CHECK: "dummy_op"() : () -> ()
+// CHECK: }
+// CHECK-NOT: func.func @bar
+func.func @bar() {
+ "another_op"() : () -> ()
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ transform.structured.replace %0 {
+ func.func @foo() {
+ "dummy_op"() : () -> ()
+ }
+ }
+}
+
+// -----
+
+func.func @bar(%arg0: i1) {
+ "another_op"(%arg0) : (i1) -> ()
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["another_op"]} in %arg1
+ // expected-error @+1 {{expected target without operands}}
+ transform.structured.replace %0 {
+ "dummy_op"() : () -> ()
+ }
+}
+
+// -----
+
+func.func @bar() {
+ "another_op"() : () -> ()
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["another_op"]} in %arg1
+ transform.structured.replace %0 {
+ ^bb0(%a: i1):
+ // expected-error @+1 {{expected replacement without operands}}
+ "dummy_op"(%a) : (i1) -> ()
+ }
+}