[nGraph] Public Py API to get function from cnnnetwork (#1567)
authorJan Iwaszkiewicz <jan.iwaszkiewicz@intel.com>
Tue, 4 Aug 2020 07:39:37 +0000 (09:39 +0200)
committerGitHub <noreply@github.com>
Tue, 4 Aug 2020 07:39:37 +0000 (09:39 +0200)
inference-engine/ie_bridges/python/tests/test_NGraph.py
ngraph/python/src/ngraph/__init__.py
ngraph/python/src/ngraph/helpers.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/test_onnx_import.py
ngraph/python/tests/test_onnx/utils/onnx_helpers.py

index 490db12..9de972d 100644 (file)
@@ -1,6 +1,8 @@
 from openvino.inference_engine import IENetwork
+
 try:
-    from ngraph.impl.op import Parameter, Relu
+    import ngraph as ng
+    from ngraph.impl.op import Parameter
     from ngraph.impl import Function, Shape, Type
     ngraph_available=True
 except:
@@ -12,28 +14,25 @@ import pytest
 if not ngraph_available:
     pytest.skip("NGraph is not installed, skip", allow_module_level=True)
 
-@pytest.mark.skip(reason="nGraph python API has been removed in 2020.2 LTS release")
 def test_CreateIENetworkFromNGraph():
     element_type = Type.f32
     param = Parameter(element_type, Shape([1, 3, 22, 22]))
-    relu = Relu(param)
+    relu = ng.relu(param)
     func = Function([relu], [param], 'test')
     caps = Function.to_capsule(func)
     cnnNetwork = IENetwork(caps)
     assert cnnNetwork != None
-    assert cnnNetwork.get_function() != None
+    assert ng.function_from_cnn(cnnNetwork) != None
     assert len(cnnNetwork.layers) == 2
 
-@pytest.mark.skip(reason="nGraph python API has been removed in 2020.2 LTS release")
 def test_GetIENetworkFromNGraph():
     element_type = Type.f32
     param = Parameter(element_type, Shape([1, 3, 22, 22]))
-    relu = Relu(param)
+    relu = ng.relu(param)
     func = Function([relu], [param], 'test')
     caps = Function.to_capsule(func)
     cnnNetwork = IENetwork(caps)
     assert cnnNetwork != None
-    assert cnnNetwork.get_function() != None
-    caps2 = cnnNetwork.get_function()
-    func2 = Function.from_capsule(caps2)
+    assert ng.function_from_cnn(cnnNetwork) != None
+    func2 = ng.function_from_cnn(cnnNetwork)
     assert func2 != None
index 1b3d18b..ebb8554 100644 (file)
@@ -24,6 +24,7 @@ except DistributionNotFound:
     __version__ = "0.0.0.dev0"
 
 from ngraph.impl import Node
+from ngraph.helpers import function_from_cnn
 
 from ngraph.opset4 import absolute
 from ngraph.opset4 import absolute as abs
diff --git a/ngraph/python/src/ngraph/helpers.py b/ngraph/python/src/ngraph/helpers.py
new file mode 100644 (file)
index 0000000..226b4dd
--- /dev/null
@@ -0,0 +1,26 @@
+# ******************************************************************************
+# Copyright 2017-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+"""nGraph helper functions."""
+
+from ngraph.impl import Function
+from openvino.inference_engine import IENetwork
+
+
+def function_from_cnn(cnn_network: IENetwork) -> Function:
+    """Get nGraph function from Inference Engine CNN network."""
+    capsule = cnn_network._get_function_capsule()
+    ng_function = Function.from_capsule(capsule)
+    return ng_function
index 18009d5..c71dbc7 100644 (file)
 import os
 
 import numpy as np
+import ngraph as ng
 import onnx
 from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
 from openvino.inference_engine import IECore
 
-from ngraph.impl import Function
 from tests.runtime import get_runtime
 from tests.test_onnx.utils.onnx_helpers import import_onnx_model
 
@@ -31,8 +31,7 @@ def test_import_onnx_function():
     ie = IECore()
     ie_network = ie.read_network(model=model_path)
 
-    capsule = ie_network._get_function_capsule()
-    ng_function = Function.from_capsule(capsule)
+    ng_function = ng.function_from_cnn(ie_network)
 
     dtype = np.float32
     value_a = np.array([1.0], dtype=dtype)
index de7e744..cc98c3f 100644 (file)
@@ -18,6 +18,7 @@ import onnx
 from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
 from openvino.inference_engine import IECore
 
+import ngraph as ng
 from ngraph.impl import Function
 
 
@@ -37,6 +38,5 @@ def import_onnx_model(model: onnx.ModelProto) -> Function:
     ie = IECore()
     ie_network = ie.read_network(model=model_byte_string, weights=b"", init_from_buffer=True)
 
-    capsule = ie_network._get_function_capsule()
-    ng_function = Function.from_capsule(capsule)
+    ng_function = ng.function_from_cnn(ie_network)
     return ng_function