# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from typing import Optional, Sequence
+ from typing import Optional, Sequence, Union
import inspect
return self.attributes["sym_visibility"]
@property
- def name(self):
- return self.attributes["sym_name"]
+ def name(self) -> StringAttr:
+ return StringAttr(self.attributes["sym_name"])
@property
def entry_block(self):
@property
def arg_attrs(self):
- return self.attributes[ARGUMENT_ATTRIBUTE_NAME]
+ return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
@arg_attrs.setter
- def arg_attrs(self, attribute: ArrayAttr):
- self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+ def arg_attrs(self, attribute: Union[ArrayAttr, list]):
+ if isinstance(attribute, ArrayAttr):
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+ else:
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+ attribute, context=self.context)
@property
def arguments(self):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")
+
+
+class CallOp:
+ """Specialization for the call op class."""
+
+ def __init__(self,
+ calleeOrResults: Union[FuncOp, List[Type]],
+ argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
+ arguments: Optional[List] = None,
+ *,
+ loc=None,
+ ip=None):
+ """Creates an call operation.
+
+ The constructor accepts three different forms:
+
+ 1. A function op to be called followed by a list of arguments.
+ 2. A list of result types, followed by the name of the function to be
+ called as string, following by a list of arguments.
+ 3. A list of result types, followed by the name of the function to be
+ called as symbol reference attribute, followed by a list of arguments.
+
+ For example
+
+ f = builtin.FuncOp("foo", ...)
+ std.CallOp(f, [args])
+ std.CallOp([result_types], "foo", [args])
+
+ In all cases, the location and insertion point may be specified as keyword
+ arguments if not provided by the surrounding context managers.
+ """
+
+ # TODO: consider supporting constructor "overloads", e.g., through a custom
+ # or pybind-provided metaclass.
+ if isinstance(calleeOrResults, FuncOp):
+ if not isinstance(argumentsOrCallee, list):
+ raise ValueError(
+ "when constructing a call to a function, expected " +
+ "the second argument to be a list of call arguments, " +
+ f"got {type(argumentsOrCallee)}")
+ if arguments is not None:
+ raise ValueError("unexpected third argument when constructing a call" +
+ "to a function")
+
+ super().__init__(
+ calleeOrResults.type.results,
+ FlatSymbolRefAttr.get(
+ calleeOrResults.name.value,
+ context=_get_default_loc_context(loc)),
+ argumentsOrCallee,
+ loc=loc,
+ ip=ip)
+ return
+
+ if isinstance(argumentsOrCallee, list):
+ raise ValueError("when constructing a call to a function by name, " +
+ "expected the second argument to be a string or a " +
+ f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")
+
+ if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
+ super().__init__(
+ calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
+ elif isinstance(argumentsOrCallee, str):
+ super().__init__(
+ calleeOrResults,
+ FlatSymbolRefAttr.get(
+ argumentsOrCallee, context=_get_default_loc_context(loc)),
+ arguments,
+ loc=loc,
+ ip=ip)
f32 = F32Type.get()
f64 = F64Type.get()
with InsertionPoint(module.body):
- func = builtin.FuncOp("some_func", ([f32, f32], [f64, f64]))
+ func = builtin.FuncOp("some_func", ([f32, f32], [f32, f32]))
with InsertionPoint(func.add_entry_block()):
std.ReturnOp(func.arguments)
func.arg_attrs = ArrayAttr.get([
DictAttr.get({"res2": FloatAttr.get(f64, 256.0)})
])
+ other = builtin.FuncOp("other_func", ([f32, f32], []))
+ with InsertionPoint(other.add_entry_block()):
+ std.ReturnOp([])
+ other.arg_attrs = [
+ DictAttr.get({"foo": StringAttr.get("qux")}),
+ DictAttr.get()
+ ]
+
# CHECK: [{baz, foo = "bar"}, {qux = []}]
print(func.arg_attrs)
# CHECK: func @some_func(
# CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"},
# CHECK: %[[ARG1:.*]]: f32 {qux = []}) ->
- # CHECK: f64 {res1 = 4.200000e+01 : f32},
- # CHECK: f64 {res2 = 2.560000e+02 : f64})
+ # CHECK: f32 {res1 = 4.200000e+01 : f32},
+ # CHECK: f32 {res2 = 2.560000e+02 : f64})
# CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
+ #
+ # CHECK: func @other_func(
+ # CHECK: %{{.*}}: f32 {foo = "qux"},
+ # CHECK: %{{.*}}: f32)
print(module)
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
+from mlir.dialects import builtin
from mlir.dialects import std
print(c1.literal_value)
# CHECK: = constant 10 : index
+
+# CHECK-LABEL: TEST: testFunctionCalls
+@constructAndPrintInModule
+def testFunctionCalls():
+ foo = builtin.FuncOp("foo", ([], []))
+ bar = builtin.FuncOp("bar", ([], [IndexType.get()]))
+ qux = builtin.FuncOp("qux", ([], [F32Type.get()]))
+
+ with InsertionPoint(builtin.FuncOp("caller", ([], [])).add_entry_block()):
+ std.CallOp(foo, [])
+ std.CallOp([IndexType.get()], "bar", [])
+ std.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
+ std.ReturnOp([])
+
+# CHECK: func @foo()
+# CHECK: func @bar() -> index
+# CHECK: func @qux() -> f32
+# CHECK: func @caller() {
+# CHECK: call @foo() : () -> ()
+# CHECK: %0 = call @bar() : () -> index
+# CHECK: %1 = call @qux() : () -> f32
+# CHECK: return
+# CHECK: }
+