[mlir][transform][structured][python] Allow str arg in match_op_names.
authorIngo Müller <ingomueller@google.com>
Thu, 20 Jul 2023 09:58:41 +0000 (09:58 +0000)
committerIngo Müller <ingomueller@google.com>
Fri, 21 Jul 2023 09:36:55 +0000 (09:36 +0000)
Allow the `names` argument in `MatchOp.match_op_names` to be of type
`str` in addition to `Sequence[str]`. In this case, the argument is
treated as a list with one name, i.e., it is possible to write
`MatchOp.match_op_names(..., "test.dummy")` instead of
`MatchOp.match_op_names(..., ["test.dummy"])`.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D155807

mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/test/python/dialects/transform_structured_ext.py

index 1936f4b..9f623ef 100644 (file)
@@ -195,7 +195,7 @@ class MatchOp:
     def match_op_names(
         cls,
         target: Union[Operation, Value],
-        names: Sequence[str],
+        names: Union[str, Sequence[str]],
         *,
         loc=None,
         ip=None,
@@ -208,7 +208,7 @@ class MatchOp:
         cls,
         result_type: Type,
         target: Union[Operation, Value],
-        names: Sequence[str],
+        names: Union[str, Sequence[str]],
         *,
         loc=None,
         ip=None,
@@ -219,8 +219,8 @@ class MatchOp:
     def match_op_names(
         cls,
         result_type_or_target: Union[Type, Operation, Value],
-        target_or_names: Union[Operation, Value,  Sequence[str]],
-        names_or_none: Optional[Sequence[str]] = None,
+        target_or_names: Union[Operation, Value, Sequence[str], str],
+        names_or_none: Optional[Union[Sequence[str], str]] = None,
         *,
         loc=None,
         ip=None,
@@ -234,6 +234,9 @@ class MatchOp:
            target = result_type_or_target
            names = target_or_names
 
+        if isinstance(names, str):
+           names = [names]
+
         return cls(
             result_type,
             _get_op_result_or_value(target),
index 0bcfd81..1da55ed 100644 (file)
@@ -97,14 +97,28 @@ def testInterchange():
 
 
 @run
-def testMatchOpNames():
+def testMatchOpNamesString():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MatchOp.match_op_names(sequence.bodyTarget, "test.dummy")
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMatchOpNamesString
+    # CHECK: transform.structured.match ops
+    # CHECK-SAME: ["test.dummy"]
+    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+@run
+def testMatchOpNamesList():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
         transform.YieldOp()
-    # CHECK-LABEL: TEST: testMatchOpNames
+    # CHECK-LABEL: TEST: testMatchOpNamesList
     # CHECK: transform.structured.match ops
     # CHECK-SAME: ["test.dummy"]
     # CHECK-SAME: (!transform.any_op) -> !transform.any_op