[mlir][transform] ApplyPatternsOp: Register canonicalization patterns
authorMatthias Springer <me@m-sp.org>
Mon, 5 Jun 2023 08:20:24 +0000 (10:20 +0200)
committerMatthias Springer <me@m-sp.org>
Mon, 5 Jun 2023 09:37:43 +0000 (11:37 +0200)
Also support replacing payload ops with ConstantLike ops in the TrackingListener, even if the replacement op does not have the same name. (Not supported for ops with multiple results, as this would require splitting the handle.)

Differential Revision: https://reviews.llvm.org/D152127

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-pattern-application.mlir

index 57a7bd3..b674050 100644 (file)
@@ -137,7 +137,9 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
 
     Only patterns that were registered in the transform dialect's
     `PatternRegistry` are available. Additional patterns can be registered as
-    part of transform dialect extensions.
+    part of transform dialect extensions. "canonicalization" is a special set
+    of patterns that refers to all canonicalization patterns of all loaded
+    dialects.
 
     This transform only reads the target handle and modifies the payload. If a
     pattern erases or replaces a tracked op, the mapping is updated accordingly.
index d075994..20bed31 100644 (file)
@@ -57,6 +57,16 @@ void transform::TransformDialect::initialize() {
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
       >();
   initializeTypes();
+
+  // Register all canonicalization patterns.
+  getOrCreateExtraData<transform::PatternRegistry>().registerPatterns(
+      "canonicalization", [](RewritePatternSet &patterns) {
+        MLIRContext *ctx = patterns.getContext();
+        for (Dialect *dialect : ctx->getLoadedDialects())
+          dialect->getCanonicalizationPatterns(patterns);
+        for (RegisteredOperationName op : ctx->getRegisteredOperations())
+          op.getCanonicalizationPatterns(patterns, ctx);
+      });
 }
 
 Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
index 49ec075..987c848 100644 (file)
@@ -89,6 +89,11 @@ transform::TrackingListener::findReplacementOp(Operation *op,
     if (op->getName() == defOp->getName())
       return defOp;
 
+    // Replacing an op with a constant-like equivalent is a common
+    // canonicalization.
+    if (defOp->hasTrait<OpTrait::ConstantLike>())
+      return defOp;
+
     values.clear();
 
     // Skip through ops that implement FindPayloadReplacementOpInterface.
index 0df76d8..c51543e 100644 (file)
@@ -121,3 +121,23 @@ transform.sequence failures(propagate) {
   transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
   transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
 }
+
+// -----
+
+// CHECK-LABEL: func @canonicalization(
+//       CHECK:   %[[c5:.*]] = arith.constant 5 : index
+//       CHECK:   return %[[c5]]
+func.func @canonicalization(%t: tensor<5xf32>) -> index {
+  %c0 = arith.constant 0 : index
+  // expected-remark @below {{op was replaced}}
+  %dim = tensor.dim %t, %c0 : tensor<5xf32>
+  return %dim : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["tensor.dim"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns ["canonicalization"] to %1 : !transform.any_op
+  transform.test_print_remark_at_operand %0, "op was replaced" : !transform.any_op
+}