From: Matthias Springer Date: Mon, 19 Jun 2023 07:04:53 +0000 (+0200) Subject: [mlir][transform] SequenceOp: Top-level operations can be used as matchers X-Git-Tag: upstream/17.0.6~4617 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1826c728cf149ec2782df3c88960601ff214d4fd;p=platform%2Fupstream%2Fllvm.git [mlir][transform] SequenceOp: Top-level operations can be used as matchers As a convenience to the user, top-level sequence ops can optionally be used as matchers: the op type is specified by the type of the block argument. This is similar to how pass pipeline targets can be specified on the command line (`-pass-pipeline='builtin.module(func.func(...))'`). Differential Revision: https://reviews.llvm.org/D153121 --- diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 2dbb95a..ca427ab 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -768,6 +768,26 @@ def SequenceOp : TransformDialectOp<"sequence", dialect. Operand omission is only allowed for sequences not contained in another sequence. + The type of the block argument must match the type of the operand. If the + sequence is a top-level transform (without an operand), it can be used for + matching operations of the specified type within the top-level container + payload IR (including the container op itself). E.g.: + + ```mlir + transform.sequence failures(propagate) { + ^bb1(%arg1: !transform.any_op): + // %arg1 is mapped to the top-level container of the payload IR, which is + // typically a module + } + + transform.sequence failures(propagate) { + ^bb1(%arg1: !transform.op<"func.func">): + // %arg1 is mapped to all "func.func" ops within and including the + // top-level container of the payload IR. Nested operations that have the + // specified op type are not included. 
+ } ``` + The body of the sequence terminates with an implicit or explicit `transform.yield` op. The operands of the terminator are returned as the results of the sequence op. diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index a6c5629..679a860 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -1562,7 +1562,24 @@ LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( << " were provided to the interpreter"; } - targets.push_back(state.getTopLevel()); + // Top-level transforms can be used for matching. If no concrete operation + // type is specified, the block argument is mapped to the top-level op. + // Otherwise, it is mapped to all ops of the specified type within the + // top-level op (including the top-level op itself). Once an op is added as + // a target, its descendants are not explored any further. + BlockArgument bbArg = region.front().getArgument(0); + if (auto bbArgType = dyn_cast<transform::OperationType>(bbArg.getType())) { + state.getTopLevel()->walk<WalkOrder::PreOrder>([&](Operation *op) { + if (op->getName().getStringRef() == bbArgType.getOperationName()) { + targets.push_back(op); + return WalkResult::skip(); + } + return WalkResult::advance(); + }); + } else { + targets.push_back(state.getTopLevel()); + } + for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i) extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i))); } diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir index f0a6635..70788e9 100644 --- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir +++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir @@ -1,12 +1,10 @@ // RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match 
ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.tensor.fold_tensor_empty - } : !transform.any_op + } : !transform.op<"func.func"> } // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)> @@ -67,13 +65,11 @@ func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tens // ----- transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.tensor.fold_tensor_empty {fold_single_use_only = true} - } : !transform.any_op + } : !transform.op<"func.func"> } func.func @double_use_of_tensor_empty(%arg0: index, %arg1: index) diff --git a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir index 066420b..7037d7b 100644 --- a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir +++ b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir @@ -1,12 +1,10 @@ // RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.tensor.rewrite_as_constant - } : !transform.any_op + } : !transform.op<"func.func"> } // CHECK-LABEL: func @tensor_generate_constant( diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir index f93ac8a..12bfe39 100644 --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ 
b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -209,10 +209,8 @@ func.func @redpar_vecmattrans2x2(%arg0: memref>, %arg1: memref !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir index 8990709..bbb461c 100644 --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -250,10 +250,8 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc // CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32> transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction" - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir index 9630ffc..8bf1bfe 100644 --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -189,10 +189,8 @@ func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) - // CHECK: return %{{.+}} transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: 
!transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir index 695400d..ed7d506 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir @@ -83,10 +83,8 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d( // CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector, memref transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.rank_reducing_subview_patterns - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir index 11389bb..a5e80c3 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir @@ -107,12 +107,10 @@ func.func @split_vector_transfer_read_strided_2d( } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" - } : !transform.any_op + } : 
!transform.op<"func.func"> } // ----- @@ -170,12 +168,10 @@ func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -240,10 +236,8 @@ func.func @split_vector_transfer_write_strided_2d( // CHECK: } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir index 20fee63..5c22c0c 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir @@ -102,12 +102,10 @@ func.func @split_vector_transfer_read_strided_2d( } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -162,12 +160,10 @@ func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { 
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -226,12 +222,10 @@ func.func @split_vector_transfer_write_strided_2d( // CHECK: } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -270,10 +264,8 @@ func.func @transfer_read_within_scf_for(%A : memref, %lb : index, %ub : } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir index 5e362231..5fa9e1a 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir @@ -1,12 +1,10 @@ // RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { 
transform.apply_patterns.vector.fold_tensor_slice_into_transfer - } : !transform.any_op + } : !transform.op<"func.func"> } // CHECK-LABEL: func @transfer_read_of_extract_slice( diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index 5011f22..e096a18 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -239,13 +239,11 @@ func.func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : i transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99 transform.apply_patterns.vector.transfer_permutation_patterns - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -363,11 +361,9 @@ func.func @transfer_write_broadcast_unit_dim( } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99 transform.apply_patterns.vector.transfer_permutation_patterns - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir index fdf2dfd..f3363c5 100644 --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -75,12 +75,10 @@ func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> 
vector<1x1x8 } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "eltwise" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -100,12 +98,10 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> { transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -121,12 +117,10 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> { transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "flat_transpose" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -610,12 +604,10 @@ func.func @transpose210_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x8x1xf32> } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose avx2_lowering_strategy = true - } : !transform.any_op + } : 
!transform.op<"func.func"> } // ----- @@ -690,12 +682,10 @@ func.func @transpose_shuffle16x16xf32(%arg0: vector<16x16xf32>) -> vector<16x16x } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -771,10 +761,8 @@ func.func @transpose021_shuffle16x16xf32(%arg0: vector<1x16x16xf32>) -> vector<1 } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir index f9b5849..3b3a56a 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir @@ -30,11 +30,9 @@ func.func @entry() { } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" - } : !transform.any_op + } : !transform.op<"func.func"> }