From 2b474c8a47560fbba44f0ebd11ac4e9e5cf7e23d Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Tue, 11 Aug 2020 06:16:11 +0200 Subject: [PATCH] Fixed access to the data of FP16 IRs with nGraph Python API (#1707) --- ngraph/python/src/pyngraph/ops/constant.cpp | 32 +++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/ngraph/python/src/pyngraph/ops/constant.cpp b/ngraph/python/src/pyngraph/ops/constant.cpp index f8ce67c..61b222e 100644 --- a/ngraph/python/src/pyngraph/ops/constant.cpp +++ b/ngraph/python/src/pyngraph/ops/constant.cpp @@ -46,11 +46,24 @@ py::buffer_info _get_buffer_info(const ngraph::op::Constant& c) return py::buffer_info( const_cast(c.get_data_ptr()), /* Pointer to buffer */ static_cast(c.get_element_type().size()), /* Size of one scalar */ - py::format_descriptor::format(), /* Python struct-style format - descriptor */ - static_cast(shape.size()), /* Number of dimensions */ - std::vector{shape.begin(), shape.end()}, /* Buffer dimensions */ - _get_byte_strides(shape) /* Strides (in bytes) for each index */ + py::format_descriptor::format(), /* Python struct-style format descriptor */ + static_cast(shape.size()), /* Number of dimensions */ + std::vector{shape.begin(), shape.end()}, /* Buffer dimensions */ + _get_byte_strides(shape) /* Strides (in bytes) for each index */ + ); +} + +template <> +py::buffer_info _get_buffer_info(const ngraph::op::Constant& c) +{ + ngraph::Shape shape = c.get_shape(); + return py::buffer_info( + const_cast(c.get_data_ptr()), /* Pointer to buffer */ + static_cast(c.get_element_type().size()), /* Size of one scalar */ + std::string(1, 'H'), /* Python struct-style format descriptor */ + static_cast(shape.size()), /* Number of dimensions */ + std::vector{shape.begin(), shape.end()}, /* Buffer dimensions */ + _get_byte_strides(shape) /* Strides (in bytes) for each index */ ); } @@ -61,6 +74,9 @@ void regclass_pyngraph_op_Constant(py::module m) constant.doc() = "ngraph.impl.op.Constant wraps ngraph::op::Constant"; constant.def( py::init&>()); + constant.def(py::init&>()); constant.def( py::init&>()); constant.def( @@ -97,6 +113,10 @@ void regclass_pyngraph_op_Constant(py::module m) { return _get_buffer_info(self); } + else if (element_type == ngraph::element::f16) + { + return _get_buffer_info(self); + } else if (element_type == ngraph::element::f32) { return _get_buffer_info(self); @@ -139,7 +159,7 @@ void regclass_pyngraph_op_Constant(py::module m) } else { - throw std::runtime_error("Unsupproted data type!"); + throw std::runtime_error("Unsupported data type!"); } }); } -- 2.7.4