[nGraph] Py API get/set partial shape of parameter (#1560)
authorJan Iwaszkiewicz <jan.iwaszkiewicz@intel.com>
Fri, 31 Jul 2020 08:14:39 +0000 (10:14 +0200)
committerGitHub <noreply@github.com>
Fri, 31 Jul 2020 08:14:39 +0000 (10:14 +0200)
ngraph/python/src/pyngraph/ops/parameter.cpp
ngraph/python/tests/test_ngraph/test_basic.py

index 1290157..90dcc3b 100644 (file)
@@ -42,4 +42,12 @@ void regclass_pyngraph_op_Parameter(py::module m)
     parameter.def(py::init<const ngraph::element::Type&, const ngraph::Shape&>());
     parameter.def(py::init<const ngraph::element::Type&, const ngraph::PartialShape&>());
     //    parameter.def_property_readonly("description", &ngraph::op::Parameter::description);
+
+    parameter.def("get_partial_shape",
+                  (const ngraph::PartialShape& (ngraph::op::Parameter::*)() const) &
+                      ngraph::op::Parameter::get_partial_shape);
+    parameter.def("get_partial_shape",
+                  (ngraph::PartialShape & (ngraph::op::Parameter::*)()) &
+                      ngraph::op::Parameter::get_partial_shape);
+    parameter.def("set_partial_shape", &ngraph::op::Parameter::set_partial_shape);
 }
index c25635f..5821dc3 100644 (file)
@@ -33,6 +33,8 @@ def test_ngraph_function_api():
     model = (parameter_a + parameter_b) * parameter_c
     function = Function(model, [parameter_a, parameter_b, parameter_c], "TestFunction")
 
+    function.get_parameters()[1].set_partial_shape(PartialShape([3, 4, 5]))
+
     ordered_ops = function.get_ordered_ops()
     op_types = [op.get_type_name() for op in ordered_ops]
     assert op_types == ["Parameter", "Parameter", "Parameter", "Add", "Multiply", "Result"]
@@ -41,6 +43,7 @@ def test_ngraph_function_api():
     assert function.get_output_op(0).get_type_name() == "Result"
     assert function.get_output_element_type(0) == parameter_a.get_element_type()
     assert list(function.get_output_shape(0)) == [2, 2]
+    assert (function.get_parameters()[1].get_partial_shape()) == PartialShape([3, 4, 5])
     assert len(function.get_parameters()) == 3
     assert len(function.get_results()) == 1
     assert function.get_friendly_name() == "TestFunction"