class MatchOp:
"""Specialization for MatchOp class."""
+ @overload
@classmethod
def match_op_names(
- MatchOp,
+ cls,
target: Union[Operation, Value],
names: Sequence[str],
+ *,
loc=None,
ip=None,
):
- pdl_operation_type = pdl.OperationType.get()
- return MatchOp(
- pdl_operation_type,
+ ...
+
+ @overload
+ @classmethod
+ def match_op_names(
+ cls,
+ result_type: Type,
+ target: Union[Operation, Value],
+ names: Sequence[str],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @classmethod
+ def match_op_names(
+ cls,
+ result_type_or_target: Union[Type, Operation, Value],
+ target_or_names: Union[Operation, Value, Sequence[str]],
+ names_or_none: Optional[Sequence[str]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(result_type_or_target, Type):
+ result_type = result_type_or_target
+ target = target_or_names
+ names = names_or_none
+ else:
+ result_type = transform.AnyOpType.get()
+ target = result_type_or_target
+ names = target_or_names
+
+ return cls(
+ result_type,
_get_op_result_or_value(target),
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
loc=loc,
@run
+def testMatchOpNames():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMatchOpNames
+ # CHECK: transform.structured.match ops
+ # CHECK-SAME: ["test.dummy"]
+ # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+@run
+def testMatchOpNamesTyped():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.MatchOp.match_op_names(
+ transform.OperationType.get("test.dummy"),
+ sequence.bodyTarget,
+ ["test.dummy"],
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMatchOpNamesTyped
+ # CHECK: transform.structured.match ops
+ # CHECK-SAME: ["test.dummy"]
+ # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">
+
+
+@run
def testMultitileSizes():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()