From 2995d29bb42729043e707161bd7a7e4e428afbcf Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 14 Oct 2021 17:19:06 +0200 Subject: [PATCH] [mlir][python] Infer result types in generated constructors whenever possible In several cases, operation result types can be unambiguously inferred from operands and attributes at operation construction time. Stop requiring the user to provide these types as arguments in the ODS-generated constructors in Python bindings. In particular, handle the SameOperandAndResultTypes and FirstAttrDerivedResultType traits as well as InferTypeOpInterface using the recently added interface support. This is a significant usability improvement for IR construction, similar to what C++ ODS provides. Depends On D111656 Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D111811 --- mlir/include/mlir/TableGen/Operator.h | 3 + mlir/lib/Bindings/Python/IRModule.h | 132 +++++++++++---------- .../mlir/dialects/linalg/opdsl/lang/emitter.py | 34 +++--- mlir/test/mlir-tblgen/op-python-bindings.td | 83 +++++++++++-- mlir/test/python/dialects/math.py | 2 +- mlir/test/python/dialects/python_test.py | 38 +++++- mlir/test/python/dialects/shape.py | 5 +- mlir/test/python/ir/dialects.py | 4 +- mlir/test/python/python_test_ops.td | 24 ++++ mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 130 ++++++++++++++++++-- 10 files changed, 344 insertions(+), 111 deletions(-) diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index e2bd3bb..e8cc3d7 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -138,6 +138,9 @@ public: // Op attribute accessors. NamedAttribute &getAttribute(int index) { return attributes[index]; } + const NamedAttribute &getAttribute(int index) const { + return attributes[index]; + } // Op operand iterators. value_iterator operand_begin(); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 59285c0..dac9486 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -601,6 +601,71 @@ private: llvm::Optional refOperation; PyBlock block; }; +/// Wrapper around the generic MlirType. +/// The lifetime of a type is bound by the PyContext that created it. +class PyType : public BaseContextObject { +public: + PyType(PyMlirContextRef contextRef, MlirType type) + : BaseContextObject(std::move(contextRef)), type(type) {} + bool operator==(const PyType &other); + operator MlirType() const { return type; } + MlirType get() const { return type; } + + /// Gets a capsule wrapping the void* within the MlirType. + pybind11::object getCapsule(); + + /// Creates a PyType from the MlirType wrapped by a capsule. + /// Note that PyType instances are uniqued, so the returned object + /// may be a pre-existing object. Ownership of the underlying MlirType + /// is taken by calling this function. + static PyType createFromCapsule(pybind11::object capsule); + +private: + MlirType type; +}; + +/// CRTP base classes for Python types that subclass Type and should be +/// castable from it (i.e. via something like IntegerType(t)). +/// By default, type class hierarchies are one level deep (i.e. a +/// concrete type class extends PyType); however, intermediate python-visible +/// base classes can be modeled by specifying a BaseTy. +template +class PyConcreteType : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + using ClassTy = pybind11::class_; + using IsAFunctionTy = bool (*)(MlirType); + + PyConcreteType() = default; + PyConcreteType(PyMlirContextRef contextRef, MlirType t) + : BaseTy(std::move(contextRef), t) {} + PyConcreteType(PyType &orig) + : PyConcreteType(orig.getContext(), castFrom(orig)) {} + + static MlirType castFrom(PyType &orig) { + if (!DerivedTy::isaFunction(orig)) { + auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); + throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + + DerivedTy::pyClassName + + " (from " + origRepr + ")"); + } + return orig; + } + + static void bind(pybind11::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); + cls.def(pybind11::init(), pybind11::keep_alive<0, 1>()); + cls.def_static("isinstance", [](PyType &otherType) -> bool { + return DerivedTy::isaFunction(otherType); + }); + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; /// Wrapper around the generic MlirAttribute. /// The lifetime of a type is bound by the PyContext that created it. @@ -685,71 +750,8 @@ public: cls.def_static("isinstance", [](PyAttribute &otherAttr) -> bool { return DerivedTy::isaFunction(otherAttr); }); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. - static void bindDerived(ClassTy &m) {} -}; - -/// Wrapper around the generic MlirType. -/// The lifetime of a type is bound by the PyContext that created it. -class PyType : public BaseContextObject { -public: - PyType(PyMlirContextRef contextRef, MlirType type) - : BaseContextObject(std::move(contextRef)), type(type) {} - bool operator==(const PyType &other); - operator MlirType() const { return type; } - MlirType get() const { return type; } - - /// Gets a capsule wrapping the void* within the MlirType. - pybind11::object getCapsule(); - - /// Creates a PyType from the MlirType wrapped by a capsule. - /// Note that PyType instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirType - /// is taken by calling this function. - static PyType createFromCapsule(pybind11::object capsule); - -private: - MlirType type; -}; - -/// CRTP base classes for Python types that subclass Type and should be -/// castable from it (i.e. via something like IntegerType(t)). -/// By default, type class hierarchies are one level deep (i.e. a -/// concrete type class extends PyType); however, intermediate python-visible -/// base classes can be modeled by specifying a BaseTy. -template -class PyConcreteType : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - using ClassTy = pybind11::class_; - using IsAFunctionTy = bool (*)(MlirType); - - PyConcreteType() = default; - PyConcreteType(PyMlirContextRef contextRef, MlirType t) - : BaseTy(std::move(contextRef), t) {} - PyConcreteType(PyType &orig) - : PyConcreteType(orig.getContext(), castFrom(orig)) {} - - static MlirType castFrom(PyType &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(pybind11::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); - cls.def(pybind11::init(), pybind11::keep_alive<0, 1>()); - cls.def_static("isinstance", [](PyType &otherType) -> bool { - return DerivedTy::isaFunction(otherType); + cls.def_property_readonly("type", [](PyAttribute &attr) { + return PyType(attr.getContext(), mlirAttributeGetType(attr)); }); DerivedTy::bindDerived(cls); } diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 1215d03..2ece9eb 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -221,7 +221,7 @@ class _BodyBuilder: elif expr.scalar_index: dim_attr = IntegerAttr.get( IntegerType.get_signless(64), expr.scalar_index.dim) - return linalg.IndexOp(IndexType.get(), dim_attr).result + return linalg.IndexOp(dim_attr).result elif expr.scalar_apply: try: fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}") @@ -303,61 +303,61 @@ class _BodyBuilder: def _eval_add(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return arith.AddFOp(lhs.type, lhs, rhs).result + return arith.AddFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.AddIOp(lhs.type, lhs, rhs).result + return arith.AddIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'add' operand: {lhs}") def _eval_exp(self, x: Value) -> Value: if _is_floating_point_type(x.type): - return math.ExpOp(x.type, x).result + return math.ExpOp(x).result raise NotImplementedError("Unsupported 'exp' operand: {x}") def _eval_log(self, x: Value) -> Value: if _is_floating_point_type(x.type): - return math.LogOp(x.type, x).result + return math.LogOp(x).result raise NotImplementedError("Unsupported 'log' operand: {x}") def _eval_sub(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return arith.SubFOp(lhs.type, lhs, rhs).result + return arith.SubFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.SubIOp(lhs.type, lhs, rhs).result + return arith.SubIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'sub' operand: {lhs}") def _eval_mul(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return arith.MulFOp(lhs.type, lhs, rhs).result + return arith.MulFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.MulIOp(lhs.type, lhs, rhs).result + return arith.MulIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'mul' operand: {lhs}") def _eval_max(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MaxFOp(lhs.type, lhs, rhs).result + return std.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MaxSIOp(lhs.type, lhs, rhs).result + return std.MaxSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max' operand: {lhs}") def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MaxFOp(lhs.type, lhs, rhs).result + return std.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MaxUIOp(lhs.type, lhs, rhs).result + return std.MaxUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}") def _eval_min(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MinFOp(lhs.type, lhs, rhs).result + return std.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MinSIOp(lhs.type, lhs, rhs).result + return std.MinSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min' operand: {lhs}") def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MinFOp(lhs.type, lhs, rhs).result + return std.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MinUIOp(lhs.type, lhs, rhs).result + return std.MinUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}") diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index b5c824e..d6dc564 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -1,6 +1,7 @@ // RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include %s | FileCheck %s include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Bindings/Python/Attributes.td" // CHECK: @_ods_cext.register_dialect @@ -176,6 +177,27 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr:$is); } +// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op" +def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> { + // CHECK: def __init__(self, type, *, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: _ods_result_type_source_attr = attributes["type"] + // CHECK: _ods_derived_result_type = ( + // CHECK: _ods_ir.TypeAttr(_ods_result_type_source_attr).value + // CHECK: if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else + // CHECK: _ods_result_type_source_attr.type) + // CHECK: results.extend([_ods_derived_result_type] * 2) + let arguments = (ins TypeAttr:$type); + let results = (outs AnyType:$res, AnyType); +} + +// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op" +def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> { + // CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None): + let arguments = (ins TypeAttr:$type); + let results = (outs AnyType:$res, Variadic); +} // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class EmptyOp(_ods_ir.OpView): @@ -191,6 +213,35 @@ def EmptyOp : TestOp<"empty">; // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) + +// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op" +def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { + // CHECK: def __init__(self, *, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: _ods_context = _ods_get_default_loc_context(loc) + // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesImpliedOp).inferReturnTypes( + // CHECK: operands=operands, + // CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), + // CHECK: context=_ods_context, + // CHECK: loc=loc) + let results = (outs I32:$i32, F32:$f32); +} + +// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op" +def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> { + // CHECK: def __init__(self, *, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: _ods_context = _ods_get_default_loc_context(loc) + // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesOp).inferReturnTypes( + // CHECK: operands=operands, + // CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), + // CHECK: context=_ods_context, + // CHECK: loc=loc) + let results = (outs AnyType, AnyType, AnyType); +} + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class MissingNamesOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.missing_names" @@ -200,12 +251,12 @@ def MissingNamesOp : TestOp<"missing_names"> { // CHECK: results = [] // CHECK: attributes = {} // CHECK: regions = None - // CHECK: results.append(i32) - // CHECK: results.append(_gen_res_1) - // CHECK: results.append(i64) // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0)) // CHECK: operands.append(_get_op_result_or_value(f32)) // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2)) + // CHECK: results.append(i32) + // CHECK: results.append(_gen_res_1) + // CHECK: results.append(i64) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, @@ -223,7 +274,7 @@ def MissingNamesOp : TestOp<"missing_names"> { // CHECK: @builtins.property // CHECK: def i64(self): // CHECK: return self.operation.results[2] - let results = (outs I32:$i32, F32, I64:$i64); + let results = (outs I32:$i32, AnyFloat, I64:$i64); } // CHECK: @_ods_cext.register_operation(_Dialect) @@ -305,6 +356,24 @@ def PythonKeywordOp : TestOp<"python_keyword"> { // CHECK: return self.operation.operands[0] let arguments = (ins AnyType:$in); } +// CHECK-LABEL: OPERATION_NAME = "test.same_results" +def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> { + // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: operands.append + // CHECK: results.extend([operands[0].type] * 1) + let arguments = (ins AnyType:$in1, AnyType:$in2); + let results = (outs AnyType:$res); +} + +// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic" +def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> { + // CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None): + let arguments = (ins AnyType:$in1, AnyType:$in2); + let results = (outs Variadic:$res); +} + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class SameVariadicOperandSizeOp(_ods_ir.OpView): @@ -361,10 +430,10 @@ def SimpleOp : TestOp<"simple"> { // CHECK: results = [] // CHECK: attributes = {} // CHECK: regions = None - // CHECK: results.append(i64) - // CHECK: results.append(f64) // CHECK: operands.append(_get_op_result_or_value(i32)) // CHECK: operands.append(_get_op_result_or_value(f32)) + // CHECK: results.append(i64) + // CHECK: results.append(f64) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, @@ -386,7 +455,7 @@ def SimpleOp : TestOp<"simple"> { // CHECK: @builtins.property // CHECK: def f64(self): // CHECK: return self.operation.results[1] - let results = (outs I64:$i64, F64:$f64); + let results = (outs I64:$i64, AnyFloat:$f64); } // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView): diff --git a/mlir/test/python/dialects/math.py b/mlir/test/python/dialects/math.py index 73246e2..e3f8829 100644 --- a/mlir/test/python/dialects/math.py +++ b/mlir/test/python/dialects/math.py @@ -16,7 +16,7 @@ def testMathOps(): with InsertionPoint(module.body): @builtin.FuncOp.from_py_func(F32Type.get()) def emit_sqrt(arg): - return mlir_math.SqrtOp(F32Type.get(), arg) + return mlir_math.SqrtOp(arg) # CHECK-LABEL: func @emit_sqrt( # CHECK-SAME: %[[ARG:.*]]: f32) { diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 3d0600e..2267b59 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -137,8 +137,7 @@ def inferReturnTypes(): test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): - op = test.InferResultsOp( - IntegerType.get_signless(32), IntegerType.get_signless(64)) + op = test.InferResultsOp() dummy = test.DummyOp() # CHECK: [Type(i32), Type(i64)] @@ -173,3 +172,38 @@ def inferReturnTypes(): pass else: assert False, "not expected dummy op class to implement the interface" + + +# CHECK-LABEL: TEST: resultTypesDefinedByTraits +@run +def resultTypesDefinedByTraits(): + with Context() as ctx, Location.unknown(ctx): + test.register_python_test_dialect(ctx) + module = Module.create() + with InsertionPoint(module.body): + inferred = test.InferResultsOp() + same = test.SameOperandAndResultTypeOp([inferred.results[0]]) + # CHECK-COUNT-2: i32 + print(same.one.type) + print(same.two.type) + + first_type_attr = test.FirstAttrDeriveTypeAttrOp( + inferred.results[1], TypeAttr.get(IndexType.get())) + # CHECK-COUNT-2: index + print(first_type_attr.one.type) + print(first_type_attr.two.type) + + first_attr = test.FirstAttrDeriveAttrOp( + FloatAttr.get(F32Type.get(), 3.14)) + # CHECK-COUNT-3: f32 + print(first_attr.one.type) + print(first_attr.two.type) + print(first_attr.three.type) + + implied = test.InferResultsImpliedOp() + # CHECK: i32 + print(implied.integer.type) + # CHECK: f64 + print(implied.flt.type) + # CHECK: index + print(implied.index.type) diff --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py index 1772026..7c1c5d6 100644 --- a/mlir/test/python/dialects/shape.py +++ b/mlir/test/python/dialects/shape.py @@ -18,15 +18,12 @@ def testConstShape(): with Context() as ctx, Location.unknown(): module = Module.create() f32 = F32Type.get() - indexT = IndexType.get() with InsertionPoint(module.body): @builtin.FuncOp.from_py_func( RankedTensorType.get((12, -1), f32)) def const_shape_tensor(arg): - return shape.ConstShapeOp(RankedTensorType.get((2,), indexT), - DenseElementsAttr.get(np.array([10, 20]))) + return shape.ConstShapeOp(DenseElementsAttr.get(np.array([10, 20]))) # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>) # CHECK: shape.const_shape [10, 20] : tensor<2xindex> print(module) - diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py index 342d93b..05e9222 100644 --- a/mlir/test/python/ir/dialects.py +++ b/mlir/test/python/ir/dialects.py @@ -82,11 +82,11 @@ def testCustomOpView(): # Create via dialects context collection. input1 = createInput() input2 = createInput() - op1 = ctx.dialects.arith.AddFOp(input1.type, input1, input2) + op1 = ctx.dialects.arith.AddFOp(input1, input2) # Create via an import from mlir.dialects.arith import AddFOp - AddFOp(input1.type, input1, op1.result) + AddFOp(input1, op1.result) # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput" # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput" diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td index 74c90a3..0f947e7e 100644 --- a/mlir/test/python/python_test_ops.td +++ b/mlir/test/python/python_test_ops.td @@ -52,4 +52,28 @@ def InferResultsOp : TestOp<"infer_results_op", [InferTypeOpInterface]> { }]; } +// If all result types are buildable, the InferTypeOpInterface is implied and is +// autogenerated by C++ ODS. +def InferResultsImpliedOp : TestOp<"infer_results_implied_op"> { + let results = (outs I32:$integer, F64:$flt, Index:$index); +} + +def SameOperandAndResultTypeOp : TestOp<"same_operand_and_result_type_op", + [SameOperandsAndResultType]> { + let arguments = (ins Variadic); + let results = (outs AnyType:$one, AnyType:$two); +} + +def FirstAttrDeriveTypeAttrOp : TestOp<"first_attr_derive_type_attr_op", + [FirstAttrDerivedResultType]> { + let arguments = (ins AnyType:$input, TypeAttr:$type); + let results = (outs AnyType:$one, AnyType:$two); +} + +def FirstAttrDeriveAttrOp : TestOp<"first_attr_derive_attr_op", + [FirstAttrDerivedResultType]> { + let arguments = (ins AnyAttr:$iattr); + let results = (outs AnyType:$one, AnyType:$two, AnyType:$three); +} + #endif // PYTHON_TEST_OPS diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index f55f050..d9ce296 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -541,16 +541,42 @@ constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py"; /// {1} is the value to add. constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py"; -/// Populates `builderArgs` with the Python-compatible names of builder function -/// arguments, first the results, then the intermixed attributes and operands in -/// the same order as they appear in the `arguments` field of the op definition. -/// Additionally, `operandNames` is populated with names of operands in their -/// order of appearance. +/// Returns true if the SameArgumentAndResultTypes trait can be used to infer +/// result types of the given operation. +static bool hasSameArgumentAndResultTypes(const Operator &op) { + return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && + op.getNumVariableLengthResults() == 0; +} + +/// Returns true if the FirstAttrDerivedResultType trait can be used to infer +/// result types of the given operation. +static bool hasFirstAttrDerivedResultTypes(const Operator &op) { + return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && + op.getNumVariableLengthResults() == 0; +} + +/// Returns true if the InferTypeOpInterface can be used to infer result types +/// of the given operation. +static bool hasInferTypeInterface(const Operator &op) { + return op.getTrait("::mlir::InferTypeOpInterface::Trait") && + op.getNumRegions() == 0; +} + +/// Returns true if there is a trait or interface that can be used to infer +/// result types of the given operation. +static bool canInferType(const Operator &op) { + return hasSameArgumentAndResultTypes(op) || + hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); +} + +/// Populates `builderArgs` with result names if the builder is expected to +/// accept them as arguments. static void -populateBuilderArgs(const Operator &op, - llvm::SmallVectorImpl &builderArgs, - llvm::SmallVectorImpl &operandNames, - llvm::SmallVectorImpl &successorArgNames) { +populateBuilderArgsResults(const Operator &op, + llvm::SmallVectorImpl &builderArgs) { + if (canInferType(op)) + return; + for (int i = 0, e = op.getNumResults(); i < e; ++i) { std::string name = op.getResultName(i).str(); if (name.empty()) { @@ -565,6 +591,19 @@ populateBuilderArgs(const Operator &op, name = sanitizeName(name); builderArgs.push_back(name); } +} + +/// Populates `builderArgs` with the Python-compatible names of builder function +/// arguments using intermixed attributes and operands in the same order as they +/// appear in the `arguments` field of the op definition. Additionally, +/// `operandNames` is populated with names of operands in their order of +/// appearance. +static void +populateBuilderArgs(const Operator &op, + llvm::SmallVectorImpl &builderArgs, + llvm::SmallVectorImpl &operandNames, + llvm::SmallVectorImpl &successorArgNames) { + for (int i = 0, e = op.getNumArgs(); i < e; ++i) { std::string name = op.getArgName(i).str(); if (name.empty()) @@ -670,6 +709,43 @@ populateBuilderLinesOperand(const Operator &op, } } +/// Python code template for deriving the operation result types from its +/// attribute: +/// - {0} is the name of the attribute from which to derive the types. +constexpr const char *deriveTypeFromAttrTemplate = + R"PY(_ods_result_type_source_attr = attributes["{0}"] +_ods_derived_result_type = ( + _ods_ir.TypeAttr(_ods_result_type_source_attr).value + if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else + _ods_result_type_source_attr.type))PY"; + +/// Python code template appending {0} type {1} times to the results list. +constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; + +/// Python code template for inferring the operation results using the +/// corresponding interface: +/// - {0} is the name of the class for which the types are inferred. +constexpr const char *inferTypeInterfaceTemplate = + R"PY(_ods_context = _ods_get_default_loc_context(loc) +results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes( + operands=operands, + attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), + context=_ods_context, + loc=loc) +)PY"; + +/// Appends the given multiline string as individual strings into +/// `builderLines`. +static void appendLineByLine(StringRef string, + llvm::SmallVectorImpl &builderLines) { + + std::pair split = std::make_pair(string, string); + do { + split = split.second.split('\n'); + builderLines.push_back(split.first.str()); + } while (!split.second.empty()); +} + /// Populates `builderLines` with additional lines that are required in the /// builder to set up op results. static void @@ -678,6 +754,32 @@ populateBuilderLinesResult(const Operator &op, llvm::SmallVectorImpl &builderLines) { bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; + if (hasSameArgumentAndResultTypes(op)) { + builderLines.push_back(llvm::formatv( + appendSameResultsTemplate, "operands[0].type", op.getNumResults())); + return; + } + + if (hasFirstAttrDerivedResultTypes(op)) { + const NamedAttribute &firstAttr = op.getAttribute(0); + assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " + "from which the type is derived"); + appendLineByLine( + llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), + builderLines); + builderLines.push_back(llvm::formatv(appendSameResultsTemplate, + "_ods_derived_result_type", + op.getNumResults())); + return; + } + + if (hasInferTypeInterface(op)) { + appendLineByLine( + llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(), + builderLines); + return; + } + // For each element, find or generate a name. for (int i = 0, e = op.getNumResults(); i < e; ++i) { const NamedTypeConstraint &element = op.getResult(i); @@ -741,14 +843,16 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { llvm::SmallVector successorArgNames; builderArgs.reserve(op.getNumOperands() + op.getNumResults() + op.getNumNativeAttributes() + op.getNumSuccessors()); + populateBuilderArgsResults(op, builderArgs); + size_t numResultArgs = builderArgs.size(); populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames); - populateBuilderLinesResult( - op, llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()), - builderLines); populateBuilderLinesOperand(op, operandArgNames, builderLines); populateBuilderLinesAttr( - op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()), + op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs), + builderLines); + populateBuilderLinesResult( + op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs), builderLines); populateBuilderLinesSuccessors(op, successorArgNames, builderLines); populateBuilderRegions(op, builderArgs, builderLines); -- 2.7.4