[mlir][python] Infer result types in generated constructors whenever possible
authorAlex Zinenko <zinenko@google.com>
Thu, 14 Oct 2021 15:19:06 +0000 (17:19 +0200)
committerAlex Zinenko <zinenko@google.com>
Mon, 25 Oct 2021 10:50:44 +0000 (12:50 +0200)
In several cases, operation result types can be unambiguously inferred from
operands and attributes at operation construction time. Stop requiring the user
to provide these types as arguments in the ODS-generated constructors in Python
bindings. In particular, handle the SameOperandAndResultTypes and
FirstAttrDerivedResultType traits as well as InferTypeOpInterface using the
recently added interface support. This is a significant usability improvement
for IR construction, similar to what C++ ODS provides.

Depends On D111656

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D111811

mlir/include/mlir/TableGen/Operator.h
mlir/lib/Bindings/Python/IRModule.h
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/test/python/dialects/math.py
mlir/test/python/dialects/python_test.py
mlir/test/python/dialects/shape.py
mlir/test/python/ir/dialects.py
mlir/test/python/python_test_ops.td
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

index e2bd3bb..e8cc3d7 100644 (file)
@@ -138,6 +138,9 @@ public:
 
   // Op attribute accessors.
   NamedAttribute &getAttribute(int index) { return attributes[index]; }
+  const NamedAttribute &getAttribute(int index) const {
+    return attributes[index];
+  }
 
   // Op operand iterators.
   value_iterator operand_begin();
index 59285c0..dac9486 100644 (file)
@@ -601,6 +601,71 @@ private:
   llvm::Optional<PyOperationRef> refOperation;
   PyBlock block;
 };
+/// Wrapper around the generic MlirType.
+/// The lifetime of a type is bound by the PyContext that created it.
+class PyType : public BaseContextObject {
+public:
+  PyType(PyMlirContextRef contextRef, MlirType type)
+      : BaseContextObject(std::move(contextRef)), type(type) {}
+  bool operator==(const PyType &other);
+  operator MlirType() const { return type; }
+  MlirType get() const { return type; }
+
+  /// Gets a capsule wrapping the void* within the MlirType.
+  pybind11::object getCapsule();
+
+  /// Creates a PyType from the MlirType wrapped by a capsule.
+  /// Note that PyType instances are uniqued, so the returned object
+  /// may be a pre-existing object. Ownership of the underlying MlirType
+  /// is taken by calling this function.
+  static PyType createFromCapsule(pybind11::object capsule);
+
+private:
+  MlirType type;
+};
+
+/// CRTP base classes for Python types that subclass Type and should be
+/// castable from it (i.e. via something like IntegerType(t)).
+/// By default, type class hierarchies are one level deep (i.e. a
+/// concrete type class extends PyType); however, intermediate python-visible
+/// base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyType>
+class PyConcreteType : public BaseTy {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirType);
+
+  PyConcreteType() = default;
+  PyConcreteType(PyMlirContextRef contextRef, MlirType t)
+      : BaseTy(std::move(contextRef), t) {}
+  PyConcreteType(PyType &orig)
+      : PyConcreteType(orig.getContext(), castFrom(orig)) {}
+
+  static MlirType castFrom(PyType &orig) {
+    if (!DerivedTy::isaFunction(orig)) {
+      auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
+                                             DerivedTy::pyClassName +
+                                             " (from " + origRepr + ")");
+    }
+    return orig;
+  }
+
+  static void bind(pybind11::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
+    cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>());
+    cls.def_static("isinstance", [](PyType &otherType) -> bool {
+      return DerivedTy::isaFunction(otherType);
+    });
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
 
 /// Wrapper around the generic MlirAttribute.
 /// The lifetime of a type is bound by the PyContext that created it.
@@ -685,71 +750,8 @@ public:
     cls.def_static("isinstance", [](PyAttribute &otherAttr) -> bool {
       return DerivedTy::isaFunction(otherAttr);
     });
-    DerivedTy::bindDerived(cls);
-  }
-
-  /// Implemented by derived classes to add methods to the Python subclass.
-  static void bindDerived(ClassTy &m) {}
-};
-
-/// Wrapper around the generic MlirType.
-/// The lifetime of a type is bound by the PyContext that created it.
-class PyType : public BaseContextObject {
-public:
-  PyType(PyMlirContextRef contextRef, MlirType type)
-      : BaseContextObject(std::move(contextRef)), type(type) {}
-  bool operator==(const PyType &other);
-  operator MlirType() const { return type; }
-  MlirType get() const { return type; }
-
-  /// Gets a capsule wrapping the void* within the MlirType.
-  pybind11::object getCapsule();
-
-  /// Creates a PyType from the MlirType wrapped by a capsule.
-  /// Note that PyType instances are uniqued, so the returned object
-  /// may be a pre-existing object. Ownership of the underlying MlirType
-  /// is taken by calling this function.
-  static PyType createFromCapsule(pybind11::object capsule);
-
-private:
-  MlirType type;
-};
-
-/// CRTP base classes for Python types that subclass Type and should be
-/// castable from it (i.e. via something like IntegerType(t)).
-/// By default, type class hierarchies are one level deep (i.e. a
-/// concrete type class extends PyType); however, intermediate python-visible
-/// base classes can be modeled by specifying a BaseTy.
-template <typename DerivedTy, typename BaseTy = PyType>
-class PyConcreteType : public BaseTy {
-public:
-  // Derived classes must define statics for:
-  //   IsAFunctionTy isaFunction
-  //   const char *pyClassName
-  using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
-  using IsAFunctionTy = bool (*)(MlirType);
-
-  PyConcreteType() = default;
-  PyConcreteType(PyMlirContextRef contextRef, MlirType t)
-      : BaseTy(std::move(contextRef), t) {}
-  PyConcreteType(PyType &orig)
-      : PyConcreteType(orig.getContext(), castFrom(orig)) {}
-
-  static MlirType castFrom(PyType &orig) {
-    if (!DerivedTy::isaFunction(orig)) {
-      auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
-      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
-                                             DerivedTy::pyClassName +
-                                             " (from " + origRepr + ")");
-    }
-    return orig;
-  }
-
-  static void bind(pybind11::module &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
-    cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>());
-    cls.def_static("isinstance", [](PyType &otherType) -> bool {
-      return DerivedTy::isaFunction(otherType);
+    cls.def_property_readonly("type", [](PyAttribute &attr) {
+      return PyType(attr.getContext(), mlirAttributeGetType(attr));
     });
     DerivedTy::bindDerived(cls);
   }
index 1215d03..2ece9eb 100644 (file)
@@ -221,7 +221,7 @@ class _BodyBuilder:
     elif expr.scalar_index:
       dim_attr = IntegerAttr.get(
           IntegerType.get_signless(64), expr.scalar_index.dim)
-      return linalg.IndexOp(IndexType.get(), dim_attr).result
+      return linalg.IndexOp(dim_attr).result
     elif expr.scalar_apply:
       try:
         fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}")
@@ -303,61 +303,61 @@ class _BodyBuilder:
 
   def _eval_add(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
-      return arith.AddFOp(lhs.type, lhs, rhs).result
+      return arith.AddFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.AddIOp(lhs.type, lhs, rhs).result
+      return arith.AddIOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'add' operand: {lhs}")
 
   def _eval_exp(self, x: Value) -> Value:
     if _is_floating_point_type(x.type):
-      return math.ExpOp(x.type, x).result
+      return math.ExpOp(x).result
     raise NotImplementedError("Unsupported 'exp' operand: {x}")
 
   def _eval_log(self, x: Value) -> Value:
     if _is_floating_point_type(x.type):
-      return math.LogOp(x.type, x).result
+      return math.LogOp(x).result
     raise NotImplementedError("Unsupported 'log' operand: {x}")
 
   def _eval_sub(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
-      return arith.SubFOp(lhs.type, lhs, rhs).result
+      return arith.SubFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.SubIOp(lhs.type, lhs, rhs).result
+      return arith.SubIOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'sub' operand: {lhs}")
 
   def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
-      return arith.MulFOp(lhs.type, lhs, rhs).result
+      return arith.MulFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.MulIOp(lhs.type, lhs, rhs).result
+      return arith.MulIOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
 
   def _eval_max(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
-      return std.MaxFOp(lhs.type, lhs, rhs).result
+      return std.MaxFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return std.MaxSIOp(lhs.type, lhs, rhs).result
+      return std.MaxSIOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'max' operand: {lhs}")
 
   def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
-      return std.MaxFOp(lhs.type, lhs, rhs).result
+      return std.MaxFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return std.MaxUIOp(lhs.type, lhs, rhs).result
+      return std.MaxUIOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}")
 
   def _eval_min(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
-      return std.MinFOp(lhs.type, lhs, rhs).result
+      return std.MinFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return std.MinSIOp(lhs.type, lhs, rhs).result
+      return std.MinSIOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'min' operand: {lhs}")
 
   def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
-      return std.MinFOp(lhs.type, lhs, rhs).result
+      return std.MinFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return std.MinUIOp(lhs.type, lhs, rhs).result
+      return std.MinUIOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}")
 
 
index b5c824e..d6dc564 100644 (file)
@@ -1,6 +1,7 @@
 // RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include %s | FileCheck %s
 
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Bindings/Python/Attributes.td"
 
 // CHECK: @_ods_cext.register_dialect
@@ -176,6 +177,27 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
 }
 
+// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
+def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
+  // CHECK: def __init__(self, type, *, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   _ods_result_type_source_attr = attributes["type"]
+  // CHECK:   _ods_derived_result_type = (
+  // CHECK:       _ods_ir.TypeAttr(_ods_result_type_source_attr).value
+  // CHECK:       if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
+  // CHECK:       _ods_result_type_source_attr.type)
+  // CHECK:   results.extend([_ods_derived_result_type] * 2)
+  let arguments = (ins TypeAttr:$type);
+  let results = (outs AnyType:$res, AnyType);
+}
+
+// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
+def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
+  // CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None):
+  let arguments = (ins TypeAttr:$type);
+  let results = (outs AnyType:$res, Variadic<AnyType>);
+}
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class EmptyOp(_ods_ir.OpView):
@@ -191,6 +213,35 @@ def EmptyOp : TestOp<"empty">;
   // CHECK:     attributes=attributes, results=results, operands=operands,
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
+
+// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
+def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
+  // CHECK:  def __init__(self, *, loc=None, ip=None):
+  // CHECK:    operands = []
+  // CHECK:    results = []
+  // CHECK:    _ods_context = _ods_get_default_loc_context(loc)
+  // CHECK:    results = _ods_ir.InferTypeOpInterface(InferResultTypesImpliedOp).inferReturnTypes(
+  // CHECK:        operands=operands,
+  // CHECK:        attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
+  // CHECK:        context=_ods_context,
+  // CHECK:        loc=loc)
+  let results = (outs I32:$i32, F32:$f32);
+}
+
+// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
+def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
+  // CHECK:  def __init__(self, *, loc=None, ip=None):
+  // CHECK:    operands = []
+  // CHECK:    results = []
+  // CHECK:    _ods_context = _ods_get_default_loc_context(loc)
+  // CHECK:    results = _ods_ir.InferTypeOpInterface(InferResultTypesOp).inferReturnTypes(
+  // CHECK:        operands=operands,
+  // CHECK:        attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
+  // CHECK:        context=_ods_context,
+  // CHECK:        loc=loc)
+  let results = (outs AnyType, AnyType, AnyType);
+}
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.missing_names"
@@ -200,12 +251,12 @@ def MissingNamesOp : TestOp<"missing_names"> {
   // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
-  // CHECK:   results.append(i32)
-  // CHECK:   results.append(_gen_res_1)
-  // CHECK:   results.append(i64)
   // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_0))
   // CHECK:   operands.append(_get_op_result_or_value(f32))
   // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_2))
+  // CHECK:   results.append(i32)
+  // CHECK:   results.append(_gen_res_1)
+  // CHECK:   results.append(i64)
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -223,7 +274,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
   // CHECK: @builtins.property
   // CHECK: def i64(self):
   // CHECK:   return self.operation.results[2]
-  let results = (outs I32:$i32, F32, I64:$i64);
+  let results = (outs I32:$i32, AnyFloat, I64:$i64);
 }
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
@@ -305,6 +356,24 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK:   return self.operation.operands[0]
   let arguments = (ins AnyType:$in);
 }
+// CHECK-LABEL: OPERATION_NAME = "test.same_results"
+def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
+  // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
+  // CHECK: operands = []
+  // CHECK: results = []
+  // CHECK: operands.append
+  // CHECK: results.extend([operands[0].type] * 1)
+  let arguments = (ins AnyType:$in1, AnyType:$in2);
+  let results = (outs AnyType:$res);
+}
+
+// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
+def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
+  // CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None):
+  let arguments = (ins AnyType:$in1, AnyType:$in2);
+  let results = (outs Variadic<AnyType>:$res);
+}
+
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SameVariadicOperandSizeOp(_ods_ir.OpView):
@@ -361,10 +430,10 @@ def SimpleOp : TestOp<"simple"> {
   // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
-  // CHECK:   results.append(i64)
-  // CHECK:   results.append(f64)
   // CHECK:   operands.append(_get_op_result_or_value(i32))
   // CHECK:   operands.append(_get_op_result_or_value(f32))
+  // CHECK:   results.append(i64)
+  // CHECK:   results.append(f64)
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -386,7 +455,7 @@ def SimpleOp : TestOp<"simple"> {
   // CHECK: @builtins.property
   // CHECK: def f64(self):
   // CHECK:   return self.operation.results[1]
-  let results = (outs I64:$i64, F64:$f64);
+  let results = (outs I64:$i64, AnyFloat:$f64);
 }
 
 // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
index 73246e2..e3f8829 100644 (file)
@@ -16,7 +16,7 @@ def testMathOps():
     with InsertionPoint(module.body):
       @builtin.FuncOp.from_py_func(F32Type.get())
       def emit_sqrt(arg):
-        return mlir_math.SqrtOp(F32Type.get(), arg)
+        return mlir_math.SqrtOp(arg)
 
     # CHECK-LABEL: func @emit_sqrt(
     # CHECK-SAME:                  %[[ARG:.*]]: f32) {
index 3d0600e..2267b59 100644 (file)
@@ -137,8 +137,7 @@ def inferReturnTypes():
     test.register_python_test_dialect(ctx)
     module = Module.create()
     with InsertionPoint(module.body):
-      op = test.InferResultsOp(
-          IntegerType.get_signless(32), IntegerType.get_signless(64))
+      op = test.InferResultsOp()
       dummy = test.DummyOp()
 
     # CHECK: [Type(i32), Type(i64)]
@@ -173,3 +172,38 @@ def inferReturnTypes():
       pass
     else:
       assert False, "not expected dummy op class to implement the interface"
+
+
+# CHECK-LABEL: TEST: resultTypesDefinedByTraits
+@run
+def resultTypesDefinedByTraits():
+  with Context() as ctx, Location.unknown(ctx):
+    test.register_python_test_dialect(ctx)
+    module = Module.create()
+    with InsertionPoint(module.body):
+      inferred = test.InferResultsOp()
+      same = test.SameOperandAndResultTypeOp([inferred.results[0]])
+      # CHECK-COUNT-2: i32
+      print(same.one.type)
+      print(same.two.type)
+
+      first_type_attr = test.FirstAttrDeriveTypeAttrOp(
+          inferred.results[1], TypeAttr.get(IndexType.get()))
+      # CHECK-COUNT-2: index
+      print(first_type_attr.one.type)
+      print(first_type_attr.two.type)
+
+      first_attr = test.FirstAttrDeriveAttrOp(
+          FloatAttr.get(F32Type.get(), 3.14))
+      # CHECK-COUNT-3: f32
+      print(first_attr.one.type)
+      print(first_attr.two.type)
+      print(first_attr.three.type)
+
+      implied = test.InferResultsImpliedOp()
+      # CHECK: i32
+      print(implied.integer.type)
+      # CHECK: f64
+      print(implied.flt.type)
+      # CHECK: index
+      print(implied.index.type)
index 1772026..7c1c5d6 100644 (file)
@@ -18,15 +18,12 @@ def testConstShape():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f32 = F32Type.get()
-    indexT = IndexType.get()
     with InsertionPoint(module.body):
       @builtin.FuncOp.from_py_func(
           RankedTensorType.get((12, -1), f32))
       def const_shape_tensor(arg):
-        return shape.ConstShapeOp(RankedTensorType.get((2,), indexT),
-                                  DenseElementsAttr.get(np.array([10, 20])))
+        return shape.ConstShapeOp(DenseElementsAttr.get(np.array([10, 20])))
 
     # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
     # CHECK: shape.const_shape [10, 20] : tensor<2xindex>
     print(module)
-
index 342d93b..05e9222 100644 (file)
@@ -82,11 +82,11 @@ def testCustomOpView():
       # Create via dialects context collection.
       input1 = createInput()
       input2 = createInput()
-      op1 = ctx.dialects.arith.AddFOp(input1.type, input1, input2)
+      op1 = ctx.dialects.arith.AddFOp(input1, input2)
 
       # Create via an import
       from mlir.dialects.arith import AddFOp
-      AddFOp(input1.type, input1, op1.result)
+      AddFOp(input1, op1.result)
 
   # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
   # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
index 74c90a3..0f947e7 100644 (file)
@@ -52,4 +52,28 @@ def InferResultsOp : TestOp<"infer_results_op", [InferTypeOpInterface]> {
   }];
 }
 
+// If all result types are buildable, the InferTypeOpInterface is implied and is
+// autogenerated by C++ ODS.
+def InferResultsImpliedOp : TestOp<"infer_results_implied_op"> {
+  let results = (outs I32:$integer, F64:$flt, Index:$index);
+}
+
+def SameOperandAndResultTypeOp : TestOp<"same_operand_and_result_type_op",
+                                        [SameOperandsAndResultType]> {
+  let arguments = (ins Variadic<AnyType>);
+  let results = (outs AnyType:$one, AnyType:$two);
+}
+
+def FirstAttrDeriveTypeAttrOp : TestOp<"first_attr_derive_type_attr_op",
+                               [FirstAttrDerivedResultType]> {
+  let arguments = (ins AnyType:$input, TypeAttr:$type);
+  let results = (outs AnyType:$one, AnyType:$two);
+}
+
+def FirstAttrDeriveAttrOp : TestOp<"first_attr_derive_attr_op",
+                               [FirstAttrDerivedResultType]> {
+  let arguments = (ins AnyAttr:$iattr);
+  let results = (outs AnyType:$one, AnyType:$two, AnyType:$three);
+}
+
 #endif // PYTHON_TEST_OPS
index f55f050..d9ce296 100644 (file)
@@ -541,16 +541,42 @@ constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";
 ///   {1} is the value to add.
 constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
 
-/// Populates `builderArgs` with the Python-compatible names of builder function
-/// arguments, first the results, then the intermixed attributes and operands in
-/// the same order as they appear in the `arguments` field of the op definition.
-/// Additionally, `operandNames` is populated with names of operands in their
-/// order of appearance.
+/// Returns true if the SameArgumentAndResultTypes trait can be used to infer
+/// result types of the given operation.
+static bool hasSameArgumentAndResultTypes(const Operator &op) {
+  return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
+         op.getNumVariableLengthResults() == 0;
+}
+
+/// Returns true if the FirstAttrDerivedResultType trait can be used to infer
+/// result types of the given operation.
+static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
+  return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
+         op.getNumVariableLengthResults() == 0;
+}
+
+/// Returns true if the InferTypeOpInterface can be used to infer result types
+/// of the given operation.
+static bool hasInferTypeInterface(const Operator &op) {
+  return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
+         op.getNumRegions() == 0;
+}
+
+/// Returns true if there is a trait or interface that can be used to infer
+/// result types of the given operation.
+static bool canInferType(const Operator &op) {
+  return hasSameArgumentAndResultTypes(op) ||
+         hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
+}
+
+/// Populates `builderArgs` with result names if the builder is expected to
+/// accept them as arguments.
 static void
-populateBuilderArgs(const Operator &op,
-                    llvm::SmallVectorImpl<std::string> &builderArgs,
-                    llvm::SmallVectorImpl<std::string> &operandNames,
-                    llvm::SmallVectorImpl<std::string> &successorArgNames) {
+populateBuilderArgsResults(const Operator &op,
+                           llvm::SmallVectorImpl<std::string> &builderArgs) {
+  if (canInferType(op))
+    return;
+
   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
     std::string name = op.getResultName(i).str();
     if (name.empty()) {
@@ -565,6 +591,19 @@ populateBuilderArgs(const Operator &op,
     name = sanitizeName(name);
     builderArgs.push_back(name);
   }
+}
+
+/// Populates `builderArgs` with the Python-compatible names of builder function
+/// arguments using intermixed attributes and operands in the same order as they
+/// appear in the `arguments` field of the op definition. Additionally,
+/// `operandNames` is populated with names of operands in their order of
+/// appearance.
+static void
+populateBuilderArgs(const Operator &op,
+                    llvm::SmallVectorImpl<std::string> &builderArgs,
+                    llvm::SmallVectorImpl<std::string> &operandNames,
+                    llvm::SmallVectorImpl<std::string> &successorArgNames) {
+
   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
     std::string name = op.getArgName(i).str();
     if (name.empty())
@@ -670,6 +709,43 @@ populateBuilderLinesOperand(const Operator &op,
   }
 }
 
+/// Python code template for deriving the operation result types from its
+/// attribute:
+///   - {0} is the name of the attribute from which to derive the types.
+constexpr const char *deriveTypeFromAttrTemplate =
+    R"PY(_ods_result_type_source_attr = attributes["{0}"]
+_ods_derived_result_type = (
+    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
+    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
+    _ods_result_type_source_attr.type))PY";
+
+/// Python code template appending {0} type {1} times to the results list.
+constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
+
+/// Python code template for inferring the operation results using the
+/// corresponding interface:
+///   - {0} is the name of the class for which the types are inferred.
+constexpr const char *inferTypeInterfaceTemplate =
+    R"PY(_ods_context = _ods_get_default_loc_context(loc)
+results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
+    operands=operands,
+    attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
+    context=_ods_context,
+    loc=loc)
+)PY";
+
+/// Appends the given multiline string as individual strings into
+/// `builderLines`.
+static void appendLineByLine(StringRef string,
+                             llvm::SmallVectorImpl<std::string> &builderLines) {
+
+  std::pair<StringRef, StringRef> split = std::make_pair(string, string);
+  do {
+    split = split.second.split('\n');
+    builderLines.push_back(split.first.str());
+  } while (!split.second.empty());
+}
+
 /// Populates `builderLines` with additional lines that are required in the
 /// builder to set up op results.
 static void
@@ -678,6 +754,32 @@ populateBuilderLinesResult(const Operator &op,
                            llvm::SmallVectorImpl<std::string> &builderLines) {
   bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
 
+  if (hasSameArgumentAndResultTypes(op)) {
+    builderLines.push_back(llvm::formatv(
+        appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
+    return;
+  }
+
+  if (hasFirstAttrDerivedResultTypes(op)) {
+    const NamedAttribute &firstAttr = op.getAttribute(0);
+    assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
+                                      "from which the type is derived");
+    appendLineByLine(
+        llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
+        builderLines);
+    builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
+                                         "_ods_derived_result_type",
+                                         op.getNumResults()));
+    return;
+  }
+
+  if (hasInferTypeInterface(op)) {
+    appendLineByLine(
+        llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
+        builderLines);
+    return;
+  }
+
   // For each element, find or generate a name.
   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
     const NamedTypeConstraint &element = op.getResult(i);
@@ -741,14 +843,16 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
   llvm::SmallVector<std::string> successorArgNames;
   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
                       op.getNumNativeAttributes() + op.getNumSuccessors());
+  populateBuilderArgsResults(op, builderArgs);
+  size_t numResultArgs = builderArgs.size();
   populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
 
-  populateBuilderLinesResult(
-      op, llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
-      builderLines);
   populateBuilderLinesOperand(op, operandArgNames, builderLines);
   populateBuilderLinesAttr(
-      op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
+      op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs),
+      builderLines);
+  populateBuilderLinesResult(
+      op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs),
       builderLines);
   populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
   populateBuilderRegions(op, builderArgs, builderLines);