[mlir][transform] SequenceOp: Top-level operations can be used as matchers
authorMatthias Springer <me@m-sp.org>
Mon, 19 Jun 2023 07:04:53 +0000 (09:04 +0200)
committerMatthias Springer <me@m-sp.org>
Mon, 19 Jun 2023 07:06:18 +0000 (09:06 +0200)
As a convenience to the user, top-level sequence ops can optionally be used as matchers: the op type is specified by the type of the block argument.

This is similar to how pass pipeline targets can be specified on the command line (`-pass-pipeline='builtin.module(func.func(...))`).

Differential Revision: https://reviews.llvm.org/D153121

14 files changed:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Tensor/fold-empty-op.mlir
mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir
mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir

index 2dbb95a..ca427ab 100644 (file)
@@ -768,6 +768,26 @@ def SequenceOp : TransformDialectOp<"sequence",
     dialect. Operand omission is only allowed for sequences not contained in
     another sequence.
 
+    The type of the block argument must match the type of the operand. If the
+    sequence is a top-level transform (without an operand), it can be used for
+    matching operations if the specified type within the top-level container
+    payload IR (including the container op itself). E.g.:
+
+    ```mlir
+    transform.sequence failures(propagate) {
+    ^bb1(%arg1: !transform.any_op):
+      // %arg1 is mapped to the top-level container of the payload IR, which is
+      // typically a module
+    }
+
+    transform.sequence failures(propagate) {
+    ^bb1(%arg1: !transform.op<"func.func>"):
+      // %arg1 is mapped to all "func.func" ops within and including the
+      // top-level container of the payload IR. Nested operations that have the
+      // specified op type are not included.
+    }
+    ```
+
     The body of the sequence terminates with an implicit or explicit
     `transform.yield` op. The operands of the terminator are returned as the
     results of the sequence op.
index a6c5629..679a860 100644 (file)
@@ -1562,7 +1562,24 @@ LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
              << " were provided to the interpreter";
     }
 
-    targets.push_back(state.getTopLevel());
+    // Top-level transforms can be used for matching. If no concrete operation
+    // type is specified, the block argument is mapped to the top-level op.
+    // Otherwise, it is mapped to all ops of the specified type within the
+    // top-level op (including the top-level op itself). Once an op is added as
+    // a target, its descendants are not explored any further.
+    BlockArgument bbArg = region.front().getArgument(0);
+    if (auto bbArgType = dyn_cast<transform::OperationType>(bbArg.getType())) {
+      state.getTopLevel()->walk<WalkOrder::PreOrder>([&](Operation *op) {
+        if (op->getName().getStringRef() == bbArgType.getOperationName()) {
+          targets.push_back(op);
+          return WalkResult::skip();
+        }
+        return WalkResult::advance();
+      });
+    } else {
+      targets.push_back(state.getTopLevel());
+    }
+
     for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
       extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
   }
index f0a6635..70788e9 100644 (file)
@@ -1,12 +1,10 @@
 // RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.tensor.fold_tensor_empty
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
@@ -67,13 +65,11 @@ func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tens
 // -----
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.tensor.fold_tensor_empty
         {fold_single_use_only = true}
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 func.func @double_use_of_tensor_empty(%arg0: index, %arg1: index)
index 066420b..7037d7b 100644 (file)
@@ -1,12 +1,10 @@
 // RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.tensor.rewrite_as_constant
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // CHECK-LABEL: func @tensor_generate_constant(
index f93ac8a..12bfe39 100644 (file)
@@ -209,10 +209,8 @@ func.func @redpar_vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<v
 }
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
index 8990709..bbb461c 100644 (file)
@@ -250,10 +250,8 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
 //       CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
index 9630ffc..8bf1bfe 100644 (file)
@@ -189,10 +189,8 @@ func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -
 //       CHECK:   return %{{.+}}
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
index 695400d..ed7d506 100644 (file)
@@ -83,10 +83,8 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
 //       CHECK:   vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.rank_reducing_subview_patterns
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
index 11389bb..a5e80c3 100644 (file)
@@ -107,12 +107,10 @@ func.func @split_vector_transfer_read_strided_2d(
 }
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // -----
@@ -170,12 +168,10 @@ func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref<?x8xf3
 // CHECK:         }
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // -----
@@ -240,10 +236,8 @@ func.func @split_vector_transfer_write_strided_2d(
 // CHECK:         }
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
index 20fee63..5c22c0c 100644 (file)
@@ -102,12 +102,10 @@ func.func @split_vector_transfer_read_strided_2d(
 }
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // -----
@@ -162,12 +160,10 @@ func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref<?x8xf3
 
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // -----
@@ -226,12 +222,10 @@ func.func @split_vector_transfer_write_strided_2d(
 // CHECK:         }
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // -----
@@ -270,10 +264,8 @@ func.func @transfer_read_within_scf_for(%A : memref<?x?xf32>, %lb : index, %ub :
 }
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
index 5e36223..5fa9e1a 100644 (file)
@@ -1,12 +1,10 @@
 // RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.fold_tensor_slice_into_transfer
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // CHECK-LABEL: func @transfer_read_of_extract_slice(
index 5011f22..e096a18 100644 (file)
@@ -239,13 +239,11 @@ func.func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : i
 
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99
     transform.apply_patterns.vector.transfer_permutation_patterns
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // -----
@@ -363,11 +361,9 @@ func.func @transfer_write_broadcast_unit_dim(
 }
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99
     transform.apply_patterns.vector.transfer_permutation_patterns
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
index fdf2dfd..f3363c5 100644 (file)
@@ -75,12 +75,10 @@ func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8
 }
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "eltwise"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // -----
@@ -100,12 +98,10 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
 
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // -----
@@ -121,12 +117,10 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
 
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "flat_transpose"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // -----
@@ -610,12 +604,10 @@ func.func @transpose210_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x8x1xf32>
 }
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose avx2_lowering_strategy = true
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // -----
@@ -690,12 +682,10 @@ func.func @transpose_shuffle16x16xf32(%arg0: vector<16x16xf32>) -> vector<16x16x
 }
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
 
 // -----
@@ -771,10 +761,8 @@ func.func @transpose021_shuffle16x16xf32(%arg0: vector<1x16x16xf32>) -> vector<1
 }
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }
index f9b5849..3b3a56a 100644 (file)
@@ -30,11 +30,9 @@ func.func @entry() {
 }
 
 transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
+^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16"
-  } : !transform.any_op
+  } : !transform.op<"func.func">
 }