return py::buffer_info(
const_cast<void*>(c.get_data_ptr()), /* Pointer to buffer */
static_cast<ssize_t>(c.get_element_type().size()), /* Size of one scalar */
- py::format_descriptor<T>::format(), /* Python struct-style format
- descriptor */
- static_cast<ssize_t>(shape.size()), /* Number of dimensions */
- std::vector<ssize_t>{shape.begin(), shape.end()}, /* Buffer dimensions */
- _get_byte_strides<T>(shape) /* Strides (in bytes) for each index */
+ py::format_descriptor<T>::format(), /* Python struct-style format descriptor */
+ static_cast<ssize_t>(shape.size()), /* Number of dimensions */
+ std::vector<ssize_t>{shape.begin(), shape.end()}, /* Buffer dimensions */
+ _get_byte_strides<T>(shape) /* Strides (in bytes) for each index */
+ );
+}
+
+template <>
+py::buffer_info _get_buffer_info<ngraph::float16>(const ngraph::op::Constant& c)
+{
+ ngraph::Shape shape = c.get_shape();
+ return py::buffer_info(
+ const_cast<void*>(c.get_data_ptr()), /* Pointer to buffer */
+ static_cast<ssize_t>(c.get_element_type().size()), /* Size of one scalar */
+ std::string(1, 'H'), /* Python struct-style format descriptor */
+ static_cast<ssize_t>(shape.size()), /* Number of dimensions */
+ std::vector<ssize_t>{shape.begin(), shape.end()}, /* Buffer dimensions */
+ _get_byte_strides<ngraph::float16>(shape) /* Strides (in bytes) for each index */
);
}
constant.doc() = "ngraph.impl.op.Constant wraps ngraph::op::Constant";
constant.def(
py::init<const ngraph::element::Type&, const ngraph::Shape&, const std::vector<char>&>());
+ constant.def(py::init<const ngraph::element::Type&,
+ const ngraph::Shape&,
+ const std::vector<ngraph::float16>&>());
constant.def(
py::init<const ngraph::element::Type&, const ngraph::Shape&, const std::vector<float>&>());
constant.def(
{
return _get_buffer_info<char>(self);
}
+ else if (element_type == ngraph::element::f16)
+ {
+ return _get_buffer_info<ngraph::float16>(self);
+ }
else if (element_type == ngraph::element::f32)
{
return _get_buffer_info<float>(self);
}
else
{
- throw std::runtime_error("Unsupproted data type!");
+ throw std::runtime_error("Unsupported data type!");
}
});
}