[mlir] Move PyConcreteAttribute to header. NFC.
authorAlex Zinenko <zinenko@google.com>
Thu, 22 Apr 2021 13:52:01 +0000 (15:52 +0200)
committerAlex Zinenko <zinenko@google.com>
Thu, 22 Apr 2021 14:11:59 +0000 (16:11 +0200)
This allows out-of-tree users to derive PyConcreteAttribute to bind custom
attributes.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D101063

mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/Bindings/Python/IRModule.h

index b5e3c5c..0af762d 100644 (file)
@@ -27,46 +27,6 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
   return mlirStringRefCreate(s.data(), s.size());
 }
 
-/// CRTP base classes for Python attributes that subclass Attribute and should
-/// be castable from it (i.e. via something like StringAttr(attr)).
-/// By default, attribute class hierarchies are one level deep (i.e. a
-/// concrete attribute class extends PyAttribute); however, intermediate
-/// python-visible base classes can be modeled by specifying a BaseTy.
-template <typename DerivedTy, typename BaseTy = PyAttribute>
-class PyConcreteAttribute : public BaseTy {
-public:
-  // Derived classes must define statics for:
-  //   IsAFunctionTy isaFunction
-  //   const char *pyClassName
-  using ClassTy = py::class_<DerivedTy, BaseTy>;
-  using IsAFunctionTy = bool (*)(MlirAttribute);
-
-  PyConcreteAttribute() = default;
-  PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
-      : BaseTy(std::move(contextRef), attr) {}
-  PyConcreteAttribute(PyAttribute &orig)
-      : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
-
-  static MlirAttribute castFrom(PyAttribute &orig) {
-    if (!DerivedTy::isaFunction(orig)) {
-      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
-      throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
-                                             DerivedTy::pyClassName +
-                                             " (from " + origRepr + ")");
-    }
-    return orig;
-  }
-
-  static void bind(py::module &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
-    cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
-    DerivedTy::bindDerived(cls);
-  }
-
-  /// Implemented by derived classes to add methods to the Python subclass.
-  static void bindDerived(ClassTy &m) {}
-};
-
 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
index 861673a..f3f5ee5 100644 (file)
@@ -642,6 +642,46 @@ private:
   std::unique_ptr<std::string> ownedName;
 };
 
+/// CRTP base classes for Python attributes that subclass Attribute and should
+/// be castable from it (i.e. via something like StringAttr(attr)).
+/// By default, attribute class hierarchies are one level deep (i.e. a
+/// concrete attribute class extends PyAttribute); however, intermediate
+/// python-visible base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyAttribute>
+class PyConcreteAttribute : public BaseTy {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirAttribute);
+
+  PyConcreteAttribute() = default;
+  PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
+      : BaseTy(std::move(contextRef), attr) {}
+  PyConcreteAttribute(PyAttribute &orig)
+      : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
+
+  static MlirAttribute castFrom(PyAttribute &orig) {
+    if (!DerivedTy::isaFunction(orig)) {
+      auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError,
+                       llvm::Twine("Cannot cast attribute to ") +
+                           DerivedTy::pyClassName + " (from " + origRepr + ")");
+    }
+    return orig;
+  }
+
+  static void bind(pybind11::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol());
+    cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>());
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
 /// Wrapper around the generic MlirType.
 /// The lifetime of a type is bound by the PyContext that created it.
 class PyType : public BaseContextObject {