[mlir] Move PyConcreteType to header. NFC.
authorJohn Demme <john.demme@microsoft.com>
Wed, 28 Apr 2021 23:16:45 +0000 (16:16 -0700)
committerJohn Demme <john.demme@microsoft.com>
Wed, 28 Apr 2021 23:40:56 +0000 (16:40 -0700)
This allows out-of-tree users to derive PyConcreteType to bind custom
types.

The Type version of https://reviews.llvm.org/D101063/new/

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D101496

mlir/lib/Bindings/Python/IRModule.h
mlir/lib/Bindings/Python/IRTypes.cpp

index ff3faee..292080d 100644 (file)
@@ -705,6 +705,49 @@ private:
   MlirType type;
 };
 
+/// CRTP base classes for Python types that subclass Type and should be
+/// castable from it (i.e. via something like IntegerType(t)).
+/// By default, type class hierarchies are one level deep (i.e. a
+/// concrete type class extends PyType); however, intermediate python-visible
+/// base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyType>
+class PyConcreteType : public BaseTy {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirType);
+
+  PyConcreteType() = default;
+  PyConcreteType(PyMlirContextRef contextRef, MlirType t)
+      : BaseTy(std::move(contextRef), t) {}
+  PyConcreteType(PyType &orig)
+      : PyConcreteType(orig.getContext(), castFrom(orig)) {}
+
+  static MlirType castFrom(PyType &orig) {
+    if (!DerivedTy::isaFunction(orig)) {
+      auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
+                                             DerivedTy::pyClassName +
+                                             " (from " + origRepr + ")");
+    }
+    return orig;
+  }
+
+  static void bind(pybind11::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>());
+    cls.def_static("isinstance", [](PyType &otherType) -> bool {
+      return DerivedTy::isaFunction(otherType);
+    });
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
 /// Wrapper around the generic MlirValue.
 /// Values are managed completely by the operation that resulted in their
 /// definition. For op result value, this is the operation that defines the
index 421df4d..b6875c7 100644 (file)
@@ -28,49 +28,6 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) {
          mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
 }
 
-/// CRTP base classes for Python types that subclass Type and should be
-/// castable from it (i.e. via something like IntegerType(t)).
-/// By default, type class hierarchies are one level deep (i.e. a
-/// concrete type class extends PyType); however, intermediate python-visible
-/// base classes can be modeled by specifying a BaseTy.
-template <typename DerivedTy, typename BaseTy = PyType>
-class PyConcreteType : public BaseTy {
-public:
-  // Derived classes must define statics for:
-  //   IsAFunctionTy isaFunction
-  //   const char *pyClassName
-  using ClassTy = py::class_<DerivedTy, BaseTy>;
-  using IsAFunctionTy = bool (*)(MlirType);
-
-  PyConcreteType() = default;
-  PyConcreteType(PyMlirContextRef contextRef, MlirType t)
-      : BaseTy(std::move(contextRef), t) {}
-  PyConcreteType(PyType &orig)
-      : PyConcreteType(orig.getContext(), castFrom(orig)) {}
-
-  static MlirType castFrom(PyType &orig) {
-    if (!DerivedTy::isaFunction(orig)) {
-      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
-      throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
-                                             DerivedTy::pyClassName +
-                                             " (from " + origRepr + ")");
-    }
-    return orig;
-  }
-
-  static void bind(py::module &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
-    cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
-    cls.def_static("isinstance", [](PyType &otherType) -> bool {
-      return DerivedTy::isaFunction(otherType);
-    });
-    DerivedTy::bindDerived(cls);
-  }
-
-  /// Implemented by derived classes to add methods to the Python subclass.
-  static void bindDerived(ClassTy &m) {}
-};
-
 class PyIntegerType : public PyConcreteType<PyIntegerType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;