Also support replacing payload ops with ConstantLike ops in the TrackingListener, even if the replacement op does not have the same name. (Not supported for ops with multiple results, as this would require splitting the handle.)
Differential Revision: https://reviews.llvm.org/D152127
Only patterns that were registered in the transform dialect's
`PatternRegistry` are available. Additional patterns can be registered as
- part of transform dialect extensions.
+ part of transform dialect extensions. "canonicalization" is a special set
+ of patterns that refers to all canonicalization patterns of all loaded
+ dialects.
This transform only reads the target handle and modifies the payload. If a
pattern erases or replaces a tracked op, the mapping is updated accordingly.
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
>();
initializeTypes();
+
+ // Register all canonicalization patterns.
+ getOrCreateExtraData<transform::PatternRegistry>().registerPatterns(
+ "canonicalization", [](RewritePatternSet &patterns) {
+ MLIRContext *ctx = patterns.getContext();
+ for (Dialect *dialect : ctx->getLoadedDialects())
+ dialect->getCanonicalizationPatterns(patterns);
+ for (RegisteredOperationName op : ctx->getRegisteredOperations())
+ op.getCanonicalizationPatterns(patterns, ctx);
+ });
}
Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
if (op->getName() == defOp->getName())
return defOp;
+ // Replacing an op with a constant-like equivalent is a common
+ // canonicalization.
+ if (defOp->hasTrait<OpTrait::ConstantLike>())
+ return defOp;
+
values.clear();
// Skip through ops that implement FindPayloadReplacementOpInterface.
transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
}
+
+// -----
+
+// CHECK-LABEL: func @canonicalization(
+// CHECK: %[[c5:.*]] = arith.constant 5 : index
+// CHECK: return %[[c5]]
+func.func @canonicalization(%t: tensor<5xf32>) -> index {
+ %c0 = arith.constant 0 : index
+ // expected-remark @below {{op was replaced}}
+ %dim = tensor.dim %t, %c0 : tensor<5xf32>
+ return %dim : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["tensor.dim"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns ["canonicalization"] to %1 : !transform.any_op
+ transform.test_print_remark_at_operand %0, "op was replaced" : !transform.any_op
+}