[mlir] Rework subclass construction in PybindAdaptors.h
authorAlex Zinenko <zinenko@google.com>
Wed, 19 Jan 2022 11:21:42 +0000 (12:21 +0100)
committerAlex Zinenko <zinenko@google.com>
Wed, 19 Jan 2022 17:09:05 +0000 (18:09 +0100)
The constructor function was being defined without indicating its "__init__"
name, which made Pybind interpret it as a regular function rather than a
constructor. When overload resolution failed, Pybind would attempt to print the
arguments actually passed to the function, including "self", which is not
initialized since the constructor couldn't be called. This would result in
"__repr__" being called with "self" referencing an uninitialized MLIR C API
object, which in turn would cause undefined behavior when attempting to print
in C++. Even if the correct name is provided, the mechanism used by
PybindAdaptors.h to bind constructors directly as "__init__" functions taking
"self" is deprecated by Pybind. The new mechanism does not seem to have access
to a fully-constructed "self" object (i.e., the constructor in C++ takes a
`pybind11::detail::value_and_holder` that cannot be forwarded back to Python).

Instead, redefine "__new__" to perform the required checks (there is no
additional initialization needed for attributes and types as they are all
wrappers around a C++ pointer). "__new__" can call its equivalent on a
superclass without needing "self".

Bump pybind11 dependency to 2.8.0, which is the first version that allows one
to redefine "__new__".

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D117646

12 files changed:
mlir/cmake/modules/MLIRDetectPythonEnv.cmake
mlir/include/mlir/Bindings/Python/PybindAdaptors.h
mlir/python/mlir/dialects/python_test.py
mlir/python/requirements.txt
mlir/test/python/CMakeLists.txt
mlir/test/python/dialects/python_test.py
mlir/test/python/lib/PythonTestCAPI.cpp
mlir/test/python/lib/PythonTestCAPI.h
mlir/test/python/lib/PythonTestDialect.cpp
mlir/test/python/lib/PythonTestDialect.h
mlir/test/python/lib/PythonTestModule.cpp
mlir/test/python/python_test_ops.td

index 4739fdd..9b36406 100644 (file)
@@ -32,7 +32,7 @@ macro(mlir_configure_python_dev_packages)
   message(STATUS "Found python libraries: ${Python3_LIBRARIES}")
   message(STATUS "Found numpy v${Python3_NumPy_VERSION}: ${Python3_NumPy_INCLUDE_DIRS}")
   mlir_detect_pybind11_install()
-  find_package(pybind11 2.6 CONFIG REQUIRED)
+  find_package(pybind11 2.8 CONFIG REQUIRED)
   message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIR}")
   message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
                  "suffix = '${PYTHON_MODULE_SUFFIX}', "
index 0340e9c..73cc7e4 100644 (file)
@@ -314,31 +314,34 @@ public:
   /// as the mlir.ir class (otherwise, it will trigger a recursive
   /// initialization).
   mlir_attribute_subclass(py::handle scope, const char *typeClassName,
-                          IsAFunctionTy isaFunction,
-                          const py::object &superClass)
-      : pure_subclass(scope, typeClassName, superClass) {
-    // Casting constructor. Note that defining an __init__ method is special
-    // and not yet generalized on pure_subclass (it requires a somewhat
-    // different cpp_function and other requirements on chaining to super
-    // __init__ make it more awkward to do generally).
+                          IsAFunctionTy isaFunction, const py::object &superCls)
+      : pure_subclass(scope, typeClassName, superCls) {
+    // Casting constructor. Note that it is hard, if not impossible, to properly
+    // call chain to parent `__init__` in pybind11 due to its special handling
+    // for init functions that don't have a fully constructed self-reference,
+    // which makes it impossible to forward it to `__init__` of a superclass.
+    // Instead, provide a custom `__new__` and call that of a superclass, which
+    // eventually calls `__init__` of the superclass. Since attribute subclasses
+    // have no additional members, we can just return the instance thus created
+    // without amending it.
     std::string captureTypeName(
         typeClassName); // As string in case if typeClassName is not static.
-    py::cpp_function initCf(
-        [superClass, isaFunction, captureTypeName](py::object self,
-                                                   py::object otherType) {
-          MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherType);
+    py::cpp_function newCf(
+        [superCls, isaFunction, captureTypeName](py::object cls,
+                                                 py::object otherAttribute) {
+          MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute);
           if (!isaFunction(rawAttribute)) {
-            auto origRepr = py::repr(otherType).cast<std::string>();
+            auto origRepr = py::repr(otherAttribute).cast<std::string>();
             throw std::invalid_argument(
                 (llvm::Twine("Cannot cast attribute to ") + captureTypeName +
                  " (from " + origRepr + ")")
                     .str());
           }
-          superClass.attr("__init__")(self, otherType);
+          py::object self = superCls.attr("__new__")(cls, otherAttribute);
+          return self;
         },
-        py::arg("cast_from_type"), py::is_method(py::none()),
-        "Casts the passed type to this specific sub-type.");
-    thisClass.attr("__init__") = initCf;
+        py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr"));
+    thisClass.attr("__new__") = newCf;
 
     // 'isinstance' method.
     def_staticmethod(
@@ -366,17 +369,21 @@ public:
   /// as the mlir.ir class (otherwise, it will trigger a recursive
   /// initialization).
   mlir_type_subclass(py::handle scope, const char *typeClassName,
-                     IsAFunctionTy isaFunction, const py::object &superClass)
-      : pure_subclass(scope, typeClassName, superClass) {
-    // Casting constructor. Note that defining an __init__ method is special
-    // and not yet generalized on pure_subclass (it requires a somewhat
-    // different cpp_function and other requirements on chaining to super
-    // __init__ make it more awkward to do generally).
+                     IsAFunctionTy isaFunction, const py::object &superCls)
+      : pure_subclass(scope, typeClassName, superCls) {
+    // Casting constructor. Note that it is hard, if not impossible, to properly
+    // call chain to parent `__init__` in pybind11 due to its special handling
+    // for init functions that don't have a fully constructed self-reference,
+    // which makes it impossible to forward it to `__init__` of a superclass.
+    // Instead, provide a custom `__new__` and call that of a superclass, which
+    // eventually calls `__init__` of the superclass. Since type subclasses
+    // have no additional members, we can just return the instance thus created
+    // without amending it.
     std::string captureTypeName(
         typeClassName); // As string in case if typeClassName is not static.
-    py::cpp_function initCf(
-        [superClass, isaFunction, captureTypeName](py::object self,
-                                                   py::object otherType) {
+    py::cpp_function newCf(
+        [superCls, isaFunction, captureTypeName](py::object cls,
+                                                 py::object otherType) {
           MlirType rawType = py::cast<MlirType>(otherType);
           if (!isaFunction(rawType)) {
             auto origRepr = py::repr(otherType).cast<std::string>();
@@ -385,11 +392,11 @@ public:
                                          origRepr + ")")
                                             .str());
           }
-          superClass.attr("__init__")(self, otherType);
+          py::object self = superCls.attr("__new__")(cls, otherType);
+          return self;
         },
-        py::arg("cast_from_type"), py::is_method(py::none()),
-        "Casts the passed type to this specific sub-type.");
-    thisClass.attr("__init__") = initCf;
+        py::name("__new__"), py::arg("cls"), py::arg("cast_from_type"));
+    thisClass.attr("__new__") = newCf;
 
     // 'isinstance' method.
     def_staticmethod(
index 82c01d5..9f560c2 100644 (file)
@@ -3,7 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._python_test_ops_gen import *
-
+from .._mlir_libs._mlirPythonTest import TestAttr, TestType
 
 def register_python_test_dialect(context, load=True):
   from .._mlir_libs import _mlirPythonTest
index f76dcf6..0cc86af 100644 (file)
@@ -1,4 +1,3 @@
 numpy
-# Version 2.7.0 excluded: https://github.com/pybind/pybind11/issues/3136
-pybind11>=2.6.0,!=2.7.0
+pybind11>=2.8.0
 PyYAML
index c8cb474..1db957c 100644 (file)
@@ -3,6 +3,10 @@ mlir_tablegen(lib/PythonTestDialect.h.inc -gen-dialect-decls)
 mlir_tablegen(lib/PythonTestDialect.cpp.inc -gen-dialect-defs)
 mlir_tablegen(lib/PythonTestOps.h.inc -gen-op-decls)
 mlir_tablegen(lib/PythonTestOps.cpp.inc -gen-op-defs)
+mlir_tablegen(lib/PythonTestAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(lib/PythonTestAttributes.cpp.inc -gen-attrdef-defs)
+mlir_tablegen(lib/PythonTestTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(lib/PythonTestTypes.cpp.inc -gen-typedef-defs)
 add_public_tablegen_target(MLIRPythonTestIncGen)
 
 add_subdirectory(lib)
index f9da91f..9c09657 100644 (file)
@@ -225,3 +225,62 @@ def testOptionalOperandOp():
       op2 = test.OptionalOperandOp(op1)
       # CHECK: op2.input is None: False
       print(f"op2.input is None: {op2.input is None}")
+
+
+# CHECK-LABEL: TEST: testCustomAttribute
+@run
+def testCustomAttribute():
+  with Context() as ctx:
+    test.register_python_test_dialect(ctx)
+    a = test.TestAttr.get()
+    # CHECK: #python_test.test_attr
+    print(a)
+
+    # The following cast must not assert.
+    b = test.TestAttr(a)
+
+    unit = UnitAttr.get()
+    try:
+      test.TestAttr(unit)
+    except ValueError as e:
+      assert "Cannot cast attribute to TestAttr" in str(e)
+    else:
+      raise
+
+    # The following must trigger a TypeError from pybind (therefore, not
+    # checking its message) and must not crash.
+    try:
+      test.TestAttr(42, 56)
+    except TypeError:
+      pass
+    else:
+      raise
+
+
+@run
+def testCustomType():
+  with Context() as ctx:
+    test.register_python_test_dialect(ctx)
+    a = test.TestType.get()
+    # CHECK: !python_test.test_type
+    print(a)
+
+    # The following cast must not assert.
+    b = test.TestType(a)
+
+    i8 = IntegerType.get_signless(8)
+    try:
+      test.TestType(i8)
+    except ValueError as e:
+      assert "Cannot cast type to TestType" in str(e)
+    else:
+      raise
+
+    # The following must trigger a TypeError from pybind (therefore, not
+    # checking its message) and must not crash.
+    try:
+      test.TestType(42, 56)
+    except TypeError:
+      pass
+    else:
+      raise
index 474476e..e52588a 100644 (file)
@@ -9,6 +9,23 @@
 #include "PythonTestCAPI.h"
 #include "PythonTestDialect.h"
 #include "mlir/CAPI/Registration.h"
+#include "mlir/CAPI/Wrap.h"
 
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test,
                                       python_test::PythonTestDialect)
+
+bool mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr) {
+  return unwrap(attr).isa<python_test::TestAttrAttr>();
+}
+
+MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) {
+  return wrap(python_test::TestAttrAttr::get(unwrap(context)));
+}
+
+bool mlirTypeIsAPythonTestTestType(MlirType type) {
+  return unwrap(type).isa<python_test::TestTypeType>();
+}
+
+MlirType mlirPythonTestTestTypeGet(MlirContext context) {
+  return wrap(python_test::TestTypeType::get(unwrap(context)));
+}
index 627ce3f..dd49102 100644 (file)
@@ -17,6 +17,16 @@ extern "C" {
 
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test);
 
+MLIR_CAPI_EXPORTED bool
+mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirPythonTestTestAttributeGet(MlirContext context);
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
+
 #ifdef __cplusplus
 }
 #endif
index b70c033..a0ff315 100644 (file)
@@ -9,9 +9,16 @@
 #include "PythonTestDialect.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 #include "PythonTestDialect.cpp.inc"
 
+#define GET_ATTRDEF_CLASSES
+#include "PythonTestAttributes.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "PythonTestTypes.cpp.inc"
+
 #define GET_OP_CLASSES
 #include "PythonTestOps.cpp.inc"
 
@@ -21,5 +28,14 @@ void PythonTestDialect::initialize() {
 #define GET_OP_LIST
 #include "PythonTestOps.cpp.inc"
       >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "PythonTestAttributes.cpp.inc"
+      >();
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "PythonTestTypes.cpp.inc"
+      >();
 }
+
 } // namespace python_test
index e25d00c..e91cba1 100644 (file)
 #define GET_OP_CLASSES
 #include "PythonTestOps.h.inc"
 
+#define GET_ATTRDEF_CLASSES
+#include "PythonTestAttributes.h.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "PythonTestTypes.h.inc"
+
 #endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H
index 4232a86..6fb9b24 100644 (file)
@@ -10,6 +10,7 @@
 #include "mlir/Bindings/Python/PybindAdaptors.h"
 
 namespace py = pybind11;
+using namespace mlir::python::adaptors;
 
 PYBIND11_MODULE(_mlirPythonTest, m) {
   m.def(
@@ -23,4 +24,20 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
         }
       },
       py::arg("context"), py::arg("load") = true);
+
+  mlir_attribute_subclass(m, "TestAttr",
+                          mlirAttributeIsAPythonTestTestAttribute)
+      .def_classmethod(
+          "get",
+          [](py::object cls, MlirContext ctx) {
+            return cls(mlirPythonTestTestAttributeGet(ctx));
+          },
+          py::arg("cls"), py::arg("context") = py::none());
+  mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType)
+      .def_classmethod(
+          "get",
+          [](py::object cls, MlirContext ctx) {
+            return cls(mlirPythonTestTestTypeGet(ctx));
+          },
+          py::arg("cls"), py::arg("context") = py::none());
 }
index 6ee71db..a274ffa 100644 (file)
@@ -17,9 +17,36 @@ def Python_Test_Dialect : Dialect {
   let name = "python_test";
   let cppNamespace = "python_test";
 }
+
+class TestType<string name, string typeMnemonic>
+    : TypeDef<Python_Test_Dialect, name> {
+  let mnemonic = typeMnemonic;
+}
+
+class TestAttr<string name, string attrMnemonic>
+    : AttrDef<Python_Test_Dialect, name> {
+  let mnemonic = attrMnemonic;
+}
+
 class TestOp<string mnemonic, list<OpTrait> traits = []>
     : Op<Python_Test_Dialect, mnemonic, traits>;
 
+//===----------------------------------------------------------------------===//
+// Type definitions.
+//===----------------------------------------------------------------------===//
+
+def TestType : TestType<"TestType", "test_type">;
+
+//===----------------------------------------------------------------------===//
+// Attribute definitions.
+//===----------------------------------------------------------------------===//
+
+def TestAttr : TestAttr<"TestAttr", "test_attr">;
+
+//===----------------------------------------------------------------------===//
+// Operation definitions.
+//===----------------------------------------------------------------------===//
+
 def AttributedOp : TestOp<"attributed_op"> {
   let arguments = (ins I32Attr:$mandatory_i32,
                    OptionalAttr<I32Attr>:$optional_i32,