// Op attribute accessors.
NamedAttribute &getAttribute(int index) { return attributes[index]; }
+ const NamedAttribute &getAttribute(int index) const {
+ return attributes[index];
+ }
// Op operand iterators.
value_iterator operand_begin();
llvm::Optional<PyOperationRef> refOperation;
PyBlock block;
};
+/// Wrapper around the generic MlirType.
+/// The lifetime of a type is bound by the PyContext that created it.
+class PyType : public BaseContextObject {
+public:
+ PyType(PyMlirContextRef contextRef, MlirType type)
+ : BaseContextObject(std::move(contextRef)), type(type) {}
+ bool operator==(const PyType &other);
+ operator MlirType() const { return type; }
+ MlirType get() const { return type; }
+
+ /// Gets a capsule wrapping the void* within the MlirType.
+ pybind11::object getCapsule();
+
+ /// Creates a PyType from the MlirType wrapped by a capsule.
+ /// Note that PyType instances are uniqued, so the returned object
+ /// may be a pre-existing object. Ownership of the underlying MlirType
+ /// is taken by calling this function.
+ static PyType createFromCapsule(pybind11::object capsule);
+
+private:
+ MlirType type;
+};
+
+/// CRTP base classes for Python types that subclass Type and should be
+/// castable from it (i.e. via something like IntegerType(t)).
+/// By default, type class hierarchies are one level deep (i.e. a
+/// concrete type class extends PyType); however, intermediate python-visible
+/// base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyType>
+class PyConcreteType : public BaseTy {
+public:
+ // Derived classes must define statics for:
+ // IsAFunctionTy isaFunction
+ // const char *pyClassName
+ using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
+ using IsAFunctionTy = bool (*)(MlirType);
+
+ PyConcreteType() = default;
+ PyConcreteType(PyMlirContextRef contextRef, MlirType t)
+ : BaseTy(std::move(contextRef), t) {}
+ PyConcreteType(PyType &orig)
+ : PyConcreteType(orig.getContext(), castFrom(orig)) {}
+
+ static MlirType castFrom(PyType &orig) {
+ if (!DerivedTy::isaFunction(orig)) {
+ auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
+ throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
+ DerivedTy::pyClassName +
+ " (from " + origRepr + ")");
+ }
+ return orig;
+ }
+
+ static void bind(pybind11::module &m) {
+ auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
+ cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>());
+ cls.def_static("isinstance", [](PyType &otherType) -> bool {
+ return DerivedTy::isaFunction(otherType);
+ });
+ DerivedTy::bindDerived(cls);
+ }
+
+ /// Implemented by derived classes to add methods to the Python subclass.
+ static void bindDerived(ClassTy &m) {}
+};
/// Wrapper around the generic MlirAttribute.
/// The lifetime of a type is bound by the PyContext that created it.
cls.def_static("isinstance", [](PyAttribute &otherAttr) -> bool {
return DerivedTy::isaFunction(otherAttr);
});
- DerivedTy::bindDerived(cls);
- }
-
- /// Implemented by derived classes to add methods to the Python subclass.
- static void bindDerived(ClassTy &m) {}
-};
-
-/// Wrapper around the generic MlirType.
-/// The lifetime of a type is bound by the PyContext that created it.
-class PyType : public BaseContextObject {
-public:
- PyType(PyMlirContextRef contextRef, MlirType type)
- : BaseContextObject(std::move(contextRef)), type(type) {}
- bool operator==(const PyType &other);
- operator MlirType() const { return type; }
- MlirType get() const { return type; }
-
- /// Gets a capsule wrapping the void* within the MlirType.
- pybind11::object getCapsule();
-
- /// Creates a PyType from the MlirType wrapped by a capsule.
- /// Note that PyType instances are uniqued, so the returned object
- /// may be a pre-existing object. Ownership of the underlying MlirType
- /// is taken by calling this function.
- static PyType createFromCapsule(pybind11::object capsule);
-
-private:
- MlirType type;
-};
-
-/// CRTP base classes for Python types that subclass Type and should be
-/// castable from it (i.e. via something like IntegerType(t)).
-/// By default, type class hierarchies are one level deep (i.e. a
-/// concrete type class extends PyType); however, intermediate python-visible
-/// base classes can be modeled by specifying a BaseTy.
-template <typename DerivedTy, typename BaseTy = PyType>
-class PyConcreteType : public BaseTy {
-public:
- // Derived classes must define statics for:
- // IsAFunctionTy isaFunction
- // const char *pyClassName
- using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
- using IsAFunctionTy = bool (*)(MlirType);
-
- PyConcreteType() = default;
- PyConcreteType(PyMlirContextRef contextRef, MlirType t)
- : BaseTy(std::move(contextRef), t) {}
- PyConcreteType(PyType &orig)
- : PyConcreteType(orig.getContext(), castFrom(orig)) {}
-
- static MlirType castFrom(PyType &orig) {
- if (!DerivedTy::isaFunction(orig)) {
- auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
- throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
- DerivedTy::pyClassName +
- " (from " + origRepr + ")");
- }
- return orig;
- }
-
- static void bind(pybind11::module &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
- cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>());
- cls.def_static("isinstance", [](PyType &otherType) -> bool {
- return DerivedTy::isaFunction(otherType);
+ cls.def_property_readonly("type", [](PyAttribute &attr) {
+ return PyType(attr.getContext(), mlirAttributeGetType(attr));
});
DerivedTy::bindDerived(cls);
}
elif expr.scalar_index:
dim_attr = IntegerAttr.get(
IntegerType.get_signless(64), expr.scalar_index.dim)
- return linalg.IndexOp(IndexType.get(), dim_attr).result
+ return linalg.IndexOp(dim_attr).result
elif expr.scalar_apply:
try:
fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}")
def _eval_add(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
- return arith.AddFOp(lhs.type, lhs, rhs).result
+ return arith.AddFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return arith.AddIOp(lhs.type, lhs, rhs).result
+ return arith.AddIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'add' operand: {lhs}")
def _eval_exp(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
- return math.ExpOp(x.type, x).result
+ return math.ExpOp(x).result
raise NotImplementedError("Unsupported 'exp' operand: {x}")
def _eval_log(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
- return math.LogOp(x.type, x).result
+ return math.LogOp(x).result
raise NotImplementedError("Unsupported 'log' operand: {x}")
def _eval_sub(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
- return arith.SubFOp(lhs.type, lhs, rhs).result
+ return arith.SubFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return arith.SubIOp(lhs.type, lhs, rhs).result
+ return arith.SubIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'sub' operand: {lhs}")
def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
- return arith.MulFOp(lhs.type, lhs, rhs).result
+ return arith.MulFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return arith.MulIOp(lhs.type, lhs, rhs).result
+ return arith.MulIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
def _eval_max(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
- return std.MaxFOp(lhs.type, lhs, rhs).result
+ return std.MaxFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return std.MaxSIOp(lhs.type, lhs, rhs).result
+ return std.MaxSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max' operand: {lhs}")
def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
- return std.MaxFOp(lhs.type, lhs, rhs).result
+ return std.MaxFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return std.MaxUIOp(lhs.type, lhs, rhs).result
+ return std.MaxUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}")
def _eval_min(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
- return std.MinFOp(lhs.type, lhs, rhs).result
+ return std.MinFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return std.MinSIOp(lhs.type, lhs, rhs).result
+ return std.MinSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min' operand: {lhs}")
def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
- return std.MinFOp(lhs.type, lhs, rhs).result
+ return std.MinFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return std.MinUIOp(lhs.type, lhs, rhs).result
+ return std.MinUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}")
// RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include %s | FileCheck %s
include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Bindings/Python/Attributes.td"
// CHECK: @_ods_cext.register_dialect
let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
}
+// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
+def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
+ // CHECK: def __init__(self, type, *, loc=None, ip=None):
+ // CHECK: operands = []
+ // CHECK: results = []
+ // CHECK: _ods_result_type_source_attr = attributes["type"]
+ // CHECK: _ods_derived_result_type = (
+ // CHECK: _ods_ir.TypeAttr(_ods_result_type_source_attr).value
+ // CHECK: if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
+ // CHECK: _ods_result_type_source_attr.type)
+ // CHECK: results.extend([_ods_derived_result_type] * 2)
+ let arguments = (ins TypeAttr:$type);
+ let results = (outs AnyType:$res, AnyType);
+}
+
+// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
+def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
+ // CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None):
+ let arguments = (ins TypeAttr:$type);
+ let results = (outs AnyType:$res, Variadic<AnyType>);
+}
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class EmptyOp(_ods_ir.OpView):
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
+
+// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
+def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
+ // CHECK: def __init__(self, *, loc=None, ip=None):
+ // CHECK: operands = []
+ // CHECK: results = []
+ // CHECK: _ods_context = _ods_get_default_loc_context(loc)
+ // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesImpliedOp).inferReturnTypes(
+ // CHECK: operands=operands,
+ // CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
+ // CHECK: context=_ods_context,
+ // CHECK: loc=loc)
+ let results = (outs I32:$i32, F32:$f32);
+}
+
+// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
+def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
+ // CHECK: def __init__(self, *, loc=None, ip=None):
+ // CHECK: operands = []
+ // CHECK: results = []
+ // CHECK: _ods_context = _ods_get_default_loc_context(loc)
+ // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesOp).inferReturnTypes(
+ // CHECK: operands=operands,
+ // CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
+ // CHECK: context=_ods_context,
+ // CHECK: loc=loc)
+ let results = (outs AnyType, AnyType, AnyType);
+}
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.missing_names"
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
- // CHECK: results.append(i32)
- // CHECK: results.append(_gen_res_1)
- // CHECK: results.append(i64)
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
// CHECK: operands.append(_get_op_result_or_value(f32))
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
+ // CHECK: results.append(i32)
+ // CHECK: results.append(_gen_res_1)
+ // CHECK: results.append(i64)
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: @builtins.property
// CHECK: def i64(self):
// CHECK: return self.operation.results[2]
- let results = (outs I32:$i32, F32, I64:$i64);
+ let results = (outs I32:$i32, AnyFloat, I64:$i64);
}
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: return self.operation.operands[0]
let arguments = (ins AnyType:$in);
}
+// CHECK-LABEL: OPERATION_NAME = "test.same_results"
+def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
+ // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
+ // CHECK: operands = []
+ // CHECK: results = []
+ // CHECK: operands.append
+ // CHECK: results.extend([operands[0].type] * 1)
+ let arguments = (ins AnyType:$in1, AnyType:$in2);
+ let results = (outs AnyType:$res);
+}
+
+// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
+def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
+ // CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None):
+ let arguments = (ins AnyType:$in1, AnyType:$in2);
+ let results = (outs Variadic<AnyType>:$res);
+}
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicOperandSizeOp(_ods_ir.OpView):
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
- // CHECK: results.append(i64)
- // CHECK: results.append(f64)
// CHECK: operands.append(_get_op_result_or_value(i32))
// CHECK: operands.append(_get_op_result_or_value(f32))
+ // CHECK: results.append(i64)
+ // CHECK: results.append(f64)
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: @builtins.property
// CHECK: def f64(self):
// CHECK: return self.operation.results[1]
- let results = (outs I64:$i64, F64:$f64);
+ let results = (outs I64:$i64, AnyFloat:$f64);
}
// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(F32Type.get())
def emit_sqrt(arg):
- return mlir_math.SqrtOp(F32Type.get(), arg)
+ return mlir_math.SqrtOp(arg)
# CHECK-LABEL: func @emit_sqrt(
# CHECK-SAME: %[[ARG:.*]]: f32) {
test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
- op = test.InferResultsOp(
- IntegerType.get_signless(32), IntegerType.get_signless(64))
+ op = test.InferResultsOp()
dummy = test.DummyOp()
# CHECK: [Type(i32), Type(i64)]
pass
else:
assert False, "not expected dummy op class to implement the interface"
+
+
+# CHECK-LABEL: TEST: resultTypesDefinedByTraits
+@run
+def resultTypesDefinedByTraits():
+ with Context() as ctx, Location.unknown(ctx):
+ test.register_python_test_dialect(ctx)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ inferred = test.InferResultsOp()
+ same = test.SameOperandAndResultTypeOp([inferred.results[0]])
+ # CHECK-COUNT-2: i32
+ print(same.one.type)
+ print(same.two.type)
+
+ first_type_attr = test.FirstAttrDeriveTypeAttrOp(
+ inferred.results[1], TypeAttr.get(IndexType.get()))
+ # CHECK-COUNT-2: index
+ print(first_type_attr.one.type)
+ print(first_type_attr.two.type)
+
+ first_attr = test.FirstAttrDeriveAttrOp(
+ FloatAttr.get(F32Type.get(), 3.14))
+ # CHECK-COUNT-3: f32
+ print(first_attr.one.type)
+ print(first_attr.two.type)
+ print(first_attr.three.type)
+
+ implied = test.InferResultsImpliedOp()
+ # CHECK: i32
+ print(implied.integer.type)
+ # CHECK: f64
+ print(implied.flt.type)
+ # CHECK: index
+ print(implied.index.type)
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
- indexT = IndexType.get()
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(
RankedTensorType.get((12, -1), f32))
def const_shape_tensor(arg):
- return shape.ConstShapeOp(RankedTensorType.get((2,), indexT),
- DenseElementsAttr.get(np.array([10, 20])))
+ return shape.ConstShapeOp(DenseElementsAttr.get(np.array([10, 20])))
# CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
# CHECK: shape.const_shape [10, 20] : tensor<2xindex>
print(module)
-
# Create via dialects context collection.
input1 = createInput()
input2 = createInput()
- op1 = ctx.dialects.arith.AddFOp(input1.type, input1, input2)
+ op1 = ctx.dialects.arith.AddFOp(input1, input2)
# Create via an import
from mlir.dialects.arith import AddFOp
- AddFOp(input1.type, input1, op1.result)
+ AddFOp(input1, op1.result)
# CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
# CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
}];
}
+// If all result types are buildable, the InferTypeOpInterface is implied and is
+// autogenerated by C++ ODS.
+def InferResultsImpliedOp : TestOp<"infer_results_implied_op"> {
+ let results = (outs I32:$integer, F64:$flt, Index:$index);
+}
+
+def SameOperandAndResultTypeOp : TestOp<"same_operand_and_result_type_op",
+ [SameOperandsAndResultType]> {
+ let arguments = (ins Variadic<AnyType>);
+ let results = (outs AnyType:$one, AnyType:$two);
+}
+
+def FirstAttrDeriveTypeAttrOp : TestOp<"first_attr_derive_type_attr_op",
+ [FirstAttrDerivedResultType]> {
+ let arguments = (ins AnyType:$input, TypeAttr:$type);
+ let results = (outs AnyType:$one, AnyType:$two);
+}
+
+def FirstAttrDeriveAttrOp : TestOp<"first_attr_derive_attr_op",
+ [FirstAttrDerivedResultType]> {
+ let arguments = (ins AnyAttr:$iattr);
+ let results = (outs AnyType:$one, AnyType:$two, AnyType:$three);
+}
+
#endif // PYTHON_TEST_OPS
/// {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
-/// Populates `builderArgs` with the Python-compatible names of builder function
-/// arguments, first the results, then the intermixed attributes and operands in
-/// the same order as they appear in the `arguments` field of the op definition.
-/// Additionally, `operandNames` is populated with names of operands in their
-/// order of appearance.
+/// Returns true if the SameArgumentAndResultTypes trait can be used to infer
+/// result types of the given operation.
+static bool hasSameArgumentAndResultTypes(const Operator &op) {
+ return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
+ op.getNumVariableLengthResults() == 0;
+}
+
+/// Returns true if the FirstAttrDerivedResultType trait can be used to infer
+/// result types of the given operation.
+static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
+ return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
+ op.getNumVariableLengthResults() == 0;
+}
+
+/// Returns true if the InferTypeOpInterface can be used to infer result types
+/// of the given operation.
+static bool hasInferTypeInterface(const Operator &op) {
+ return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
+ op.getNumRegions() == 0;
+}
+
+/// Returns true if there is a trait or interface that can be used to infer
+/// result types of the given operation.
+static bool canInferType(const Operator &op) {
+ return hasSameArgumentAndResultTypes(op) ||
+ hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
+}
+
+/// Populates `builderArgs` with result names if the builder is expected to
+/// accept them as arguments.
static void
-populateBuilderArgs(const Operator &op,
- llvm::SmallVectorImpl<std::string> &builderArgs,
- llvm::SmallVectorImpl<std::string> &operandNames,
- llvm::SmallVectorImpl<std::string> &successorArgNames) {
+populateBuilderArgsResults(const Operator &op,
+ llvm::SmallVectorImpl<std::string> &builderArgs) {
+ if (canInferType(op))
+ return;
+
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
std::string name = op.getResultName(i).str();
if (name.empty()) {
name = sanitizeName(name);
builderArgs.push_back(name);
}
+}
+
+/// Populates `builderArgs` with the Python-compatible names of builder function
+/// arguments using intermixed attributes and operands in the same order as they
+/// appear in the `arguments` field of the op definition. Additionally,
+/// `operandNames` is populated with names of operands in their order of
+/// appearance.
+static void
+populateBuilderArgs(const Operator &op,
+ llvm::SmallVectorImpl<std::string> &builderArgs,
+ llvm::SmallVectorImpl<std::string> &operandNames,
+ llvm::SmallVectorImpl<std::string> &successorArgNames) {
+
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
std::string name = op.getArgName(i).str();
if (name.empty())
}
}
+/// Python code template for deriving the operation result types from its
+/// attribute:
+/// - {0} is the name of the attribute from which to derive the types.
+constexpr const char *deriveTypeFromAttrTemplate =
+ R"PY(_ods_result_type_source_attr = attributes["{0}"]
+_ods_derived_result_type = (
+ _ods_ir.TypeAttr(_ods_result_type_source_attr).value
+ if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
+ _ods_result_type_source_attr.type))PY";
+
+/// Python code template appending {0} type {1} times to the results list.
+constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
+
+/// Python code template for inferring the operation results using the
+/// corresponding interface:
+/// - {0} is the name of the class for which the types are inferred.
+constexpr const char *inferTypeInterfaceTemplate =
+ R"PY(_ods_context = _ods_get_default_loc_context(loc)
+results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
+ operands=operands,
+ attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
+ context=_ods_context,
+ loc=loc)
+)PY";
+
+/// Appends the given multiline string as individual strings into
+/// `builderLines`.
+static void appendLineByLine(StringRef string,
+ llvm::SmallVectorImpl<std::string> &builderLines) {
+
+ std::pair<StringRef, StringRef> split = std::make_pair(string, string);
+ do {
+ split = split.second.split('\n');
+ builderLines.push_back(split.first.str());
+ } while (!split.second.empty());
+}
+
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up op results.
static void
llvm::SmallVectorImpl<std::string> &builderLines) {
bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
+ if (hasSameArgumentAndResultTypes(op)) {
+ builderLines.push_back(llvm::formatv(
+ appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
+ return;
+ }
+
+ if (hasFirstAttrDerivedResultTypes(op)) {
+ const NamedAttribute &firstAttr = op.getAttribute(0);
+ assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
+ "from which the type is derived");
+ appendLineByLine(
+ llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
+ builderLines);
+ builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
+ "_ods_derived_result_type",
+ op.getNumResults()));
+ return;
+ }
+
+ if (hasInferTypeInterface(op)) {
+ appendLineByLine(
+ llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
+ builderLines);
+ return;
+ }
+
// For each element, find or generate a name.
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
const NamedTypeConstraint &element = op.getResult(i);
llvm::SmallVector<std::string> successorArgNames;
builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
op.getNumNativeAttributes() + op.getNumSuccessors());
+ populateBuilderArgsResults(op, builderArgs);
+ size_t numResultArgs = builderArgs.size();
populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
- populateBuilderLinesResult(
- op, llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
- builderLines);
populateBuilderLinesOperand(op, operandArgNames, builderLines);
populateBuilderLinesAttr(
- op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
+ op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs),
+ builderLines);
+ populateBuilderLinesResult(
+ op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs),
builderLines);
populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
populateBuilderRegions(op, builderArgs, builderLines);