From: Ingo Müller Date: Tue, 18 Jul 2023 09:07:17 +0000 (+0000) Subject: [mlir][linalg][transform][python] Add type arg to MatchOp extension. X-Git-Tag: upstream/17.0.6~1190 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1dccdf7f49a0cdad7121913c900ce9cb8b6e9fdc;p=platform%2Fupstream%2Fllvm.git [mlir][linalg][transform][python] Add type arg to MatchOp extension. The extension class to MatchOp has a class method called match_op_names. The previous version of that function did not allow to specify the result type. This, however, may be useful/necessary if the op consuming the resulting handle requires a particular type (such as the bufferization.EmptyTensorToAllocTensorOp). This patch adds an overload to match_op_names that allows to specify the result type. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D155567 --- diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index b754034..6407309 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -85,17 +85,52 @@ class InterchangeOp: class MatchOp: """Specialization for MatchOp class.""" + @overload @classmethod def match_op_names( - MatchOp, + cls, target: Union[Operation, Value], names: Sequence[str], + *, loc=None, ip=None, ): - pdl_operation_type = pdl.OperationType.get() - return MatchOp( - pdl_operation_type, + ... + + @overload + @classmethod + def match_op_names( + cls, + result_type: Type, + target: Union[Operation, Value], + names: Sequence[str], + *, + loc=None, + ip=None, + ): + ... + + @classmethod + def match_op_names( + cls, + result_type_or_target: Union[Type, Operation, Value], + target_or_names: Union[Operation, Value, Sequence[str]], + names_or_none: Optional[Sequence[str]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(result_type_or_target, Type): + result_type = result_type_or_target + target = target_or_names + names = names_or_none + else: + result_type = transform.AnyOpType.get() + target = result_type_or_target + names = target_or_names + + return cls( + result_type, _get_op_result_or_value(target), ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), loc=loc, diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index 2dfae47..03a4716 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -58,6 +58,38 @@ def testInterchange(): @run +def testMatchOpNames(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"]) + transform.YieldOp() + # CHECK-LABEL: TEST: testMatchOpNames + # CHECK: transform.structured.match ops + # CHECK-SAME: ["test.dummy"] + # CHECK-SAME: (!transform.any_op) -> !transform.any_op + + +@run +def testMatchOpNamesTyped(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + structured.MatchOp.match_op_names( + transform.OperationType.get("test.dummy"), + sequence.bodyTarget, + ["test.dummy"], + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testMatchOpNamesTyped + # CHECK: transform.structured.match ops + # CHECK-SAME: ["test.dummy"] + # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy"> + + +@run def testMultitileSizes(): sequence = transform.SequenceOp( transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()