Fixed access to the data of FP16 IRs with nGraph Python API (#1707)
authorJan Iwaszkiewicz <jan.iwaszkiewicz@intel.com>
Tue, 11 Aug 2020 04:16:11 +0000 (06:16 +0200)
committerGitHub <noreply@github.com>
Tue, 11 Aug 2020 04:16:11 +0000 (07:16 +0300)
ngraph/python/src/pyngraph/ops/constant.cpp

index f8ce67c..61b222e 100644 (file)
@@ -46,11 +46,24 @@ py::buffer_info _get_buffer_info(const ngraph::op::Constant& c)
     return py::buffer_info(
         const_cast<void*>(c.get_data_ptr()),               /* Pointer to buffer */
         static_cast<ssize_t>(c.get_element_type().size()), /* Size of one scalar */
-        py::format_descriptor<T>::format(),                /* Python struct-style format
-                                                              descriptor */
-        static_cast<ssize_t>(shape.size()),                /* Number of dimensions */
-        std::vector<ssize_t>{shape.begin(), shape.end()},  /* Buffer dimensions */
-        _get_byte_strides<T>(shape)                        /* Strides (in bytes) for each index */
+        py::format_descriptor<T>::format(), /* Python struct-style format descriptor */
+        static_cast<ssize_t>(shape.size()), /* Number of dimensions */
+        std::vector<ssize_t>{shape.begin(), shape.end()}, /* Buffer dimensions */
+        _get_byte_strides<T>(shape)                       /* Strides (in bytes) for each index */
+        );
+}
+
+template <>
+py::buffer_info _get_buffer_info<ngraph::float16>(const ngraph::op::Constant& c)
+{
+    ngraph::Shape shape = c.get_shape();
+    return py::buffer_info(
+        const_cast<void*>(c.get_data_ptr()),               /* Pointer to buffer */
+        static_cast<ssize_t>(c.get_element_type().size()), /* Size of one scalar */
+        std::string(1, 'H'),                /* Python struct-style format descriptor */
+        static_cast<ssize_t>(shape.size()), /* Number of dimensions */
+        std::vector<ssize_t>{shape.begin(), shape.end()}, /* Buffer dimensions */
+        _get_byte_strides<ngraph::float16>(shape)         /* Strides (in bytes) for each index */
         );
 }
 
@@ -61,6 +74,9 @@ void regclass_pyngraph_op_Constant(py::module m)
     constant.doc() = "ngraph.impl.op.Constant wraps ngraph::op::Constant";
     constant.def(
         py::init<const ngraph::element::Type&, const ngraph::Shape&, const std::vector<char>&>());
+    constant.def(py::init<const ngraph::element::Type&,
+                          const ngraph::Shape&,
+                          const std::vector<ngraph::float16>&>());
     constant.def(
         py::init<const ngraph::element::Type&, const ngraph::Shape&, const std::vector<float>&>());
     constant.def(
@@ -97,6 +113,10 @@ void regclass_pyngraph_op_Constant(py::module m)
         {
             return _get_buffer_info<char>(self);
         }
+        else if (element_type == ngraph::element::f16)
+        {
+            return _get_buffer_info<ngraph::float16>(self);
+        }
         else if (element_type == ngraph::element::f32)
         {
             return _get_buffer_info<float>(self);
@@ -139,7 +159,7 @@ void regclass_pyngraph_op_Constant(py::module m)
         }
         else
         {
-            throw std::runtime_error("Unsupproted data type!");
+            throw std::runtime_error("Unsupported data type!");
         }
     });
 }