[mlir][python] Add custom constructor for memref load
authorAlex Zinenko <zinenko@google.com>
Wed, 13 Oct 2021 13:20:31 +0000 (15:20 +0200)
committerAlex Zinenko <zinenko@google.com>
Wed, 13 Oct 2021 15:11:02 +0000 (17:11 +0200)
The type can be inferred trivially, but it is currently done as string
stitching between ODS and C++ and is not easily exposed to Python.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D111712

mlir/python/mlir/dialects/_memref_ops_ext.py [new file with mode: 0644]
mlir/test/python/dialects/memref.py

diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
new file mode 100644 (file)
index 0000000..cb25ef1
--- /dev/null
@@ -0,0 +1,37 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+  from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+class LoadOp:
+  """Specialization for the MemRef load operation."""
+
+  def __init__(self,
+               memref: Union[Operation, OpView, Value],
+               indices: Optional[Union[Operation, OpView,
+                                       Sequence[Value]]] = None,
+               *,
+               loc=None,
+               ip=None):
+    """Creates a memref load operation.
+
+    Args:
+      memref: the buffer to load from.
+      indices: the list of subscripts, may be empty for zero-dimensional
+        buffers.
+      loc: user-visible location of the operation.
+      ip: insertion point.
+    """
+    memref_resolved = _get_op_result_or_value(memref)
+    indices_resolved = [] if indices is None else _get_op_results_or_values(
+        indices)
+    return_type = memref_resolved.type
+    super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)
index 240fb9c221e9e9bb2ca27d34c94fd871c7d521ed..e421f9b2fde953f0177d13f4ef10f545de9ae40d 100644 (file)
@@ -8,9 +8,11 @@ import mlir.dialects.memref as memref
 def run(f):
   print("\nTEST:", f.__name__)
   f()
+  return f
 
 
 # CHECK-LABEL: TEST: testSubViewAccessors
+@run
 def testSubViewAccessors():
   ctx = Context()
   module = Module.parse(
@@ -52,4 +54,20 @@ def testSubViewAccessors():
   print(subview.strides[1])
 
 
-run(testSubViewAccessors)
+# CHECK-LABEL: TEST: testCustomBuidlers
+@run
+def testCustomBuidlers():
+  with Context() as ctx, Location.unknown(ctx):
+    module = Module.parse(r"""
+      func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
+        return
+      }
+    """)
+    func = module.body.operations[0]
+    func_body = func.regions[0].blocks[0]
+    with InsertionPoint.at_block_terminator(func_body):
+      memref.LoadOp(func.arguments[0], func.arguments[1:])
+
+    # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+    # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
+    print(module)