[mlir][python] Expose fp8 types with pybind.
authorQiao Zhang <zhangqiaorjc@google.com>
Tue, 3 Jan 2023 19:06:30 +0000 (19:06 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 3 Jan 2023 19:18:46 +0000 (19:18 +0000)
Expose fp8 types with pybind.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D140746

mlir/lib/Bindings/Python/IRTypes.cpp
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
mlir/test/python/ir/builtin_types.py

index 7a41cb1..10527af 100644 (file)
@@ -102,6 +102,42 @@ public:
   }
 };
 
+/// Floating Point Type subclass - Float8E4M3FNType.
+class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
+  static constexpr const char *pyClassName = "Float8E4M3FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
+          return PyFloat8E4M3FNType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float8_e4m3fn type.");
+  }
+};
+
+/// Floating Point Type subclass - Float8M5E2Type.
+class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
+  static constexpr const char *pyClassName = "Float8E5M2Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat8E5M2TypeGet(context->get());
+          return PyFloat8E5M2Type(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float8_e5m2 type.");
+  }
+};
+
 /// Floating Point Type subclass - BF16Type.
 class PyBF16Type : public PyConcreteType<PyBF16Type> {
 public:
@@ -663,6 +699,8 @@ public:
 void mlir::python::populateIRTypes(py::module &m) {
   PyIntegerType::bind(m);
   PyIndexType::bind(m);
+  PyFloat8E4M3FNType::bind(m);
+  PyFloat8E5M2Type::bind(m);
   PyBF16Type::bind(m);
   PyF16Type::bind(m);
   PyF32Type::bind(m);
index 60bc367..505946c 100644 (file)
@@ -50,6 +50,8 @@ __all__ = [
     "DiagnosticHandler",
     "DiagnosticSeverity",
     "DictAttr",
+    "Float8E4M3FNType",
+    "Float8E5M2Type",
     "F16Type",
     "F32Type",
     "F64Type",
@@ -577,6 +579,20 @@ class DictAttr(Attribute):
     @property
     def type(self) -> Type: ...
 
+class Float8E4M3FNType(Type):
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @staticmethod
+    def get(*args, **kwargs) -> Float8E4M3FNType: ...
+    @staticmethod
+    def isinstance(arg: Any) -> bool: ...
+
+class Float8E5M2Type(Type):
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @staticmethod
+    def get(*args, **kwargs) -> Float8E5M2Type: ...
+    @staticmethod
+    def isinstance(arg: Any) -> bool: ...
+
 # TODO: Auto-generated. Audit and fix.
 class F16Type(Type):
     def __init__(self, cast_from_type: Type) -> None: ...
index 91c820f..e160216 100644 (file)
@@ -193,6 +193,10 @@ def testIndexType():
 @run
 def testFloatType():
   with Context():
+    # CHECK: float: f8E4M3FN
+    print("float:", Float8E4M3FNType.get())
+    # CHECK: float: f8E5M2
+    print("float:", Float8E5M2Type.get())
     # CHECK: float: bf16
     print("float:", BF16Type.get())
     # CHECK: float: f16