--- /dev/null
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+ from ..ir import *
+ from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+class LoadOp:
+ """Specialization for the MemRef load operation."""
+
+ def __init__(self,
+ memref: Union[Operation, OpView, Value],
+ indices: Optional[Union[Operation, OpView,
+ Sequence[Value]]] = None,
+ *,
+ loc=None,
+ ip=None):
+ """Creates a memref load operation.
+
+ Args:
+ memref: the buffer to load from.
+ indices: the list of subscripts, may be empty for zero-dimensional
+ buffers.
+ loc: user-visible location of the operation.
+ ip: insertion point.
+ """
+ memref_resolved = _get_op_result_or_value(memref)
+ indices_resolved = [] if indices is None else _get_op_results_or_values(
+ indices)
+ return_type = memref_resolved.type
+ super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)
def run(f):
print("\nTEST:", f.__name__)
f()
+ return f
# CHECK-LABEL: TEST: testSubViewAccessors
+@run
def testSubViewAccessors():
ctx = Context()
module = Module.parse(
print(subview.strides[1])
-run(testSubViewAccessors)
+# CHECK-LABEL: TEST: testCustomBuidlers
+@run
+def testCustomBuidlers():
+ with Context() as ctx, Location.unknown(ctx):
+ module = Module.parse(r"""
+ func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
+ return
+ }
+ """)
+ func = module.body.operations[0]
+ func_body = func.regions[0].blocks[0]
+ with InsertionPoint.at_block_terminator(func_body):
+ memref.LoadOp(func.arguments[0], func.arguments[1:])
+
+ # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+ # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
+ print(module)