[mlir][transform][python] Add extended ApplyPatternsOp.
authorIngo Müller <ingomueller@google.com>
Mon, 17 Jul 2023 10:26:33 +0000 (10:26 +0000)
committerIngo Müller <ingomueller@google.com>
Thu, 20 Jul 2023 14:20:50 +0000 (14:20 +0000)
This patch adds a mixin for ApplyPatternsOp to _transform_ops_ext.py
with syntactic sugar for construction such ops. Curiously, the op did
not have any constructors yet, probably because its tablegen definition
said to skip the default builders. The new constructor is thus quite
straightforward. The commit also adds a refined `region` property which
returns the first block of the single region.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D155435

mlir/python/mlir/dialects/_transform_ops_ext.py
mlir/test/python/dialects/transform.py

index 87f8d39..0db2e3b 100644 (file)
@@ -29,6 +29,32 @@ class CastOp:
     )
 
 
+class ApplyPatternsOp:
+
+  def __init__(
+      self,
+      target: Union[Operation, Value, OpView],
+      *,
+      loc=None,
+      ip=None,
+  ):
+    operands = []
+    operands.append(_get_op_result_or_value(target))
+    super().__init__(
+        self.build_generic(attributes={},
+                           results=[],
+                           operands=operands,
+                           successors=None,
+                           regions=None,
+                           loc=loc,
+                           ip=ip))
+    self.regions[0].blocks.append()
+
+  @property
+  def patterns(self) -> Block:
+    return self.regions[0].blocks[0]
+
+
 class testGetParentOp:
 
   def __init__(
index 668e004..3e7e29a 100644 (file)
@@ -172,6 +172,37 @@ def testMergeHandlesOp():
 
 
 @run
+def testApplyPatternsOpCompact():
+  sequence = transform.SequenceOp(
+      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+  )
+  with InsertionPoint(sequence.body):
+    with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
+      transform.ApplyCanonicalizationPatternsOp()
+    transform.YieldOp()
+    # CHECK-LABEL: TEST: testApplyPatternsOpCompact
+    # CHECK: apply_patterns to
+    # CHECK: transform.apply_patterns.canonicalization
+    # CHECK: !transform.any_op
+
+
+@run
+def testApplyPatternsOpWithType():
+  sequence = transform.SequenceOp(
+      transform.FailurePropagationMode.PROPAGATE, [],
+      transform.OperationType.get('test.dummy')
+  )
+  with InsertionPoint(sequence.body):
+    with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
+      transform.ApplyCanonicalizationPatternsOp()
+    transform.YieldOp()
+    # CHECK-LABEL: TEST: testApplyPatternsOp
+    # CHECK: apply_patterns to
+    # CHECK: transform.apply_patterns.canonicalization
+    # CHECK: !transform.op<"test.dummy">
+
+
+@run
 def testReplicateOp():
     with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
     with InsertionPoint(with_pdl.body):