Add nGraph-ONNX tests (#1215)
authorMichał Karzyński <4430709+postrational@users.noreply.github.com>
Fri, 10 Jul 2020 09:53:56 +0000 (11:53 +0200)
committerGitHub <noreply@github.com>
Fri, 10 Jul 2020 09:53:56 +0000 (11:53 +0200)
41 files changed:
.ci/openvino-onnx/Dockerfile [new file with mode: 0644]
inference-engine/ie_bridges/python/src/openvino/inference_engine/ie_api_impl.cpp
inference-engine/thirdparty/CMakeLists.txt
ngraph/python/examples/basic.py [deleted file]
ngraph/python/requirements_test.txt [moved from ngraph/python/test_requirements.txt with 86% similarity]
ngraph/python/setup.py
ngraph/python/src/ngraph/impl/onnx_import/__init__.py [deleted file]
ngraph/python/src/ngraph/ops.py
ngraph/python/src/ngraph/utils/node_factory.py
ngraph/python/src/ngraph/utils/tensor_iterator_types.py
ngraph/python/src/pyngraph/function.cpp
ngraph/python/src/pyngraph/onnx_import/onnx_import.cpp [deleted file]
ngraph/python/src/pyngraph/onnx_import/onnx_import.hpp [deleted file]
ngraph/python/tests/conftest.py
ngraph/python/tests/runtime.py
ngraph/python/tests/test_ngraph/test_basic.py
ngraph/python/tests/test_ngraph/test_dyn_attributes.py
ngraph/python/tests/test_ngraph/test_normalization.py
ngraph/python/tests/test_ngraph/test_ops.py
ngraph/python/tests/test_ngraph/test_ops_binary.py
ngraph/python/tests/test_ngraph/test_ops_fused.py
ngraph/python/tests/test_ngraph/test_ops_matmul.py
ngraph/python/tests/test_ngraph/test_pooling.py
ngraph/python/tests/test_onnx/models/add_abc.onnx [new file with mode: 0644]
ngraph/python/tests/test_onnx/test_onnx_import.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/test_ops_batchnorm.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/test_ops_binary.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/test_ops_convpool.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/test_ops_logical.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/test_ops_matmul.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/test_ops_nonlinear.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/test_ops_reduction.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/test_ops_reshape.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/test_ops_unary.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/test_ops_variadic.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/test_zoo_models.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/utils/__init__.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/utils/model_zoo_tester.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/utils/onnx_backend.py [new file with mode: 0644]
ngraph/python/tests/test_onnx/utils/onnx_helpers.py [new file with mode: 0644]
ngraph/python/tox.ini

diff --git a/.ci/openvino-onnx/Dockerfile b/.ci/openvino-onnx/Dockerfile
new file mode 100644 (file)
index 0000000..c8e70e0
--- /dev/null
@@ -0,0 +1,84 @@
+FROM ubuntu:20.04
+
+LABEL version=2020.07.09.1
+
+ARG http_proxy
+ARG https_proxy
+ENV http_proxy ${http_proxy}
+ENV https_proxy ${https_proxy}
+
+ENV CI=true
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED 1
+
+# Install base dependencies
+RUN apt-get update && apt-get install -y locales && apt-get clean autoclean && apt-get autoremove -y
+
+# Set the locale to en_US.UTF-8
+RUN locale-gen en_US.UTF-8
+ENV LANG en_US.UTF-8
+ENV LANGUAGE en_US:en
+ENV LC_ALL en_US.UTF-8
+
+RUN apt-get update && apt-get -y --no-install-recommends install \
+# OpenVINO dependencies
+        autoconf \
+        automake \
+        build-essential \
+        cmake \
+        curl \
+        git \
+        libtool \
+        ocl-icd-opencl-dev \
+        pkg-config \
+        unzip \
+        wget \
+# Python dependencies
+        python3 \
+        python3-pip \
+        python3-dev \
+        python3-virtualenv \
+        cython3 \
+        tox \
+# ONNX dependencies
+        git-lfs \
+        protobuf-compiler \
+        libprotobuf-dev && \
+    apt-get clean autoclean && \
+    apt-get autoremove -y
+
+# Build OpenVINO
+COPY . /openvino/
+WORKDIR /openvino/build
+RUN cmake .. \
+    -DCMAKE_BUILD_TYPE=Release \
+    -DENABLE_VPU=OFF \
+    -DENABLE_GNA=OFF \
+    -DENABLE_OPENCV=OFF \
+    -DENABLE_CPPLINT=OFF \
+    -DENABLE_TESTS=OFF \
+    -DENABLE_BEH_TESTS=OFF \
+    -DENABLE_FUNCTIONAL_TESTS=OFF \
+    -DENABLE_MKL_DNN=ON \
+    -DENABLE_CLDNN=OFF \
+    -DENABLE_PROFILING_ITT=OFF \
+    -DENABLE_SAMPLES=OFF \
+    -DENABLE_SPEECH_DEMO=OFF \
+    -DENABLE_PYTHON=ON \
+    -DPYTHON_EXECUTABLE=/usr/bin/python3 \
+    -DNGRAPH_ONNX_IMPORT_ENABLE=ON \
+    -DNGRAPH_IE_ENABLE=ON \
+    -DNGRAPH_INTERPRETER_ENABLE=ON \
+    -DNGRAPH_DEBUG_ENABLE=OFF \
+    -DNGRAPH_DYNAMIC_COMPONENTS_ENABLE=ON \
+    -DCMAKE_INSTALL_PREFIX=/openvino/dist
+RUN make -j $(nproc) install
+
+# Run tests via tox
+WORKDIR /openvino/ngraph/python
+ENV NGRAPH_CPP_BUILD_PATH=/openvino/dist
+ENV LD_LIBRARY_PATH=/openvino/dist/lib
+ENV NGRAPH_ONNX_IMPORT_ENABLE=TRUE
+ENV PYTHONPATH=/openvino/bin/intel64/Debug/lib/python_api/python3.8:${PYTHONPATH}
+RUN git clone --recursive https://github.com/pybind/pybind11.git
+RUN tox
index 2ba4cb7..a417b31 100644 (file)
@@ -542,8 +542,12 @@ InferenceEnginePython::IECore::readNetwork(const std::string& modelPath, const s
 
 InferenceEnginePython::IENetwork
 InferenceEnginePython::IECore::readNetwork(const std::string& model, uint8_t *bin, size_t bin_size) {
-    InferenceEngine::TensorDesc tensorDesc(InferenceEngine::Precision::U8, { bin_size }, InferenceEngine::Layout::C);
-    auto weights_blob = InferenceEngine::make_shared_blob<uint8_t>(tensorDesc, bin, bin_size);
+    InferenceEngine::Blob::Ptr weights_blob;
+    if(bin_size!=0)
+    {
+        InferenceEngine::TensorDesc tensorDesc(InferenceEngine::Precision::U8, { bin_size }, InferenceEngine::Layout::C);
+        weights_blob = InferenceEngine::make_shared_blob<uint8_t>(tensorDesc, bin, bin_size);
+    }
     InferenceEngine::CNNNetwork net = actual.ReadNetwork(model, weights_blob);
     return IENetwork(std::make_shared<InferenceEngine::CNNNetwork>(net));
 }
index f94453e..b27397d 100644 (file)
@@ -5,6 +5,10 @@
 if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=unknown-warning-option -Wno-error=inconsistent-missing-override -Wno-error=pass-failed")
   set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-error=unknown-warning-option -Wno-error=inconsistent-missing-override -Wno-error=pass-failed")
+elseif(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.1)
+    # On g++ 9.3.0 (Ubuntu 20.04) the ADE library raises "redundant-move" warnings
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=redundant-move")
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-error=redundant-move")
 endif()
 
 function(build_with_lto)
diff --git a/ngraph/python/examples/basic.py b/ngraph/python/examples/basic.py
deleted file mode 100644 (file)
index 3bc4f40..0000000
+++ /dev/null
@@ -1,48 +0,0 @@
-# ******************************************************************************
-# Copyright 2017-2020 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ******************************************************************************
-"""Usage example for the ngraph Pythonic API."""
-
-import numpy as np
-import ngraph as ng
-
-A = ng.parameter(shape=[2, 2], name="A", dtype=np.float32)
-B = ng.parameter(shape=[2, 2], name="B")
-C = ng.parameter(shape=[2, 2], name="C")
-# >>> print(A)
-# <Parameter: 'A' ([2, 2], float)>
-
-model = (A + B) * C
-# >>> print(model)
-# <Multiply: 'Multiply_14' ([2, 2])>
-
-runtime = ng.runtime(backend_name="CPU")
-# >>> print(runtime)
-# <Runtime: Backend='CPU'>
-
-computation = runtime.computation(model, A, B, C)
-# >>> print(computation)
-# <Computation: Multiply_14(A, B, C)>
-
-value_a = np.array([[1, 2], [3, 4]], dtype=np.float32)
-value_b = np.array([[5, 6], [7, 8]], dtype=np.float32)
-value_c = np.array([[9, 10], [11, 12]], dtype=np.float32)
-
-result = computation(value_a, value_b, value_c)
-# >>> print(result)
-# [[ 54.  80.]
-#  [110. 144.]]
-
-print("Result = ", result)
similarity index 86%
rename from ngraph/python/test_requirements.txt
rename to ngraph/python/requirements_test.txt
index 0b12492..81d516c 100644 (file)
@@ -2,8 +2,9 @@ flake8
 flake8-comprehensions
 flake8-docstrings
 flake8-quotes
+onnx
 pydocstyle
 pytest
+retrying
 tox
 wheel
-zipp==0.5.0
index b98aefc..7b32b94 100644 (file)
@@ -27,7 +27,6 @@ __version__ = os.environ.get("NGRAPH_VERSION", "0.0.0.dev0")
 PYNGRAPH_ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
 PYNGRAPH_SRC_DIR = os.path.join(PYNGRAPH_ROOT_DIR, "src")
 NGRAPH_DEFAULT_INSTALL_DIR = os.environ.get("HOME")
-NGRAPH_ONNX_IMPORT_ENABLE = os.environ.get("NGRAPH_ONNX_IMPORT_ENABLE")
 NGRAPH_PYTHON_DEBUG = os.environ.get("NGRAPH_PYTHON_DEBUG")
 
 
@@ -232,9 +231,6 @@ library_dirs = [NGRAPH_CPP_LIBRARY_DIR]
 libraries = [NGRAPH_CPP_LIBRARY_NAME, ONNX_IMPORTER_CPP_LIBRARY_NAME]
 
 extra_compile_args = []
-if NGRAPH_ONNX_IMPORT_ENABLE in ["TRUE", "ON", True]:
-    extra_compile_args.append("-DNGRAPH_ONNX_IMPORT_ENABLE")
-
 extra_link_args = []
 
 data_files = [
@@ -243,6 +239,7 @@ data_files = [
         [
             os.path.join(NGRAPH_CPP_LIBRARY_DIR, library)
             for library in os.listdir(NGRAPH_CPP_LIBRARY_DIR)
+            if os.path.isfile(os.path.join(NGRAPH_CPP_LIBRARY_DIR, library))
         ],
     ),
     (
@@ -255,15 +252,6 @@ data_files = [
     ("", [os.path.join(NGRAPH_CPP_DIST_DIR, "LICENSE")],),
 ]
 
-if NGRAPH_ONNX_IMPORT_ENABLE in ["TRUE", "ON", True]:
-    onnx_sources = [
-        "pyngraph/onnx_import/onnx_import.cpp",
-    ]
-    onnx_sources = [PYNGRAPH_SRC_DIR + "/" + source for source in onnx_sources]
-    sources = sources + onnx_sources
-
-    packages.append("ngraph.impl.onnx_import")
-
 ext_modules = [
     Extension(
         "_pyngraph",
@@ -374,7 +362,7 @@ setup(
     url="https://github.com/openvinotoolkit/openvino",
     license="License :: OSI Approved :: Apache Software License",
     ext_modules=ext_modules,
-    package_dir={'': 'src'},
+    package_dir={"": "src"},
     packages=packages,
     cmdclass={"build_ext": BuildExt},
     data_files=data_files,
diff --git a/ngraph/python/src/ngraph/impl/onnx_import/__init__.py b/ngraph/python/src/ngraph/impl/onnx_import/__init__.py
deleted file mode 100644 (file)
index ad2a417..0000000
+++ /dev/null
@@ -1,24 +0,0 @@
-# ******************************************************************************
-# Copyright 2017-2020 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ******************************************************************************
-"""
-Package: ngraph
-Low level wrappers for the nGraph c++ api in ngraph::onnx_import.
-"""
-
-# flake8: noqa
-
-from _pyngraph import import_onnx_model
-from _pyngraph import import_onnx_model_file
index 200742a..facd1b4 100644 (file)
@@ -3481,7 +3481,6 @@ def tensor_iterator(
 
     :returns:   Node representing TensorIterator operation.
     """
-
     attributes = {
         "body": graph_body.serialize(),
         "slice_input_desc": [desc.serialize() for desc in slice_input_desc],
@@ -3491,7 +3490,7 @@ def tensor_iterator(
         "concat_output_desc": [desc.serialize() for desc in concat_output_desc],
     }
 
-    return _get_node_factory().create('TensorIterator', as_nodes(*inputs), attributes)
+    return _get_node_factory().create("TensorIterator", as_nodes(*inputs), attributes)
 
 
 @nameable_op
@@ -3515,7 +3514,8 @@ def read_value(init_value: NodeInput, variable_id: str, name: Optional[str] = No
     :param name:         Optional name for output node.
     :return: ReadValue node
     """
-    return _get_node_factory().create("ReadValue", [as_node(init_value)], {"variable_id": variable_id})
+    return _get_node_factory().create("ReadValue", [as_node(init_value)],
+                                      {"variable_id": variable_id})
 
 
 @nameable_op
index 73aebb8..cff8eb1 100644 (file)
@@ -58,14 +58,14 @@ class NodeFactory(object):
 
         # Setup helper members for caching attribute values.
         # The cache would be lazily populated at first access attempt.
-        setattr(node, "_attr_cache", dict())
-        setattr(node, "_attr_cache_valid", bool(False))
+        node._attr_cache = {}
+        node._attr_cache_valid = False
 
         return node
 
     @staticmethod
     def _normalize_attr_name(attr_name: str, prefix: str) -> str:
-        """Normalizes attribute name.
+        """Normalize attribute name.
 
         :param      attr_name:  The attribute name.
         :param      prefix:     The prefix to attach to attribute name.
@@ -79,7 +79,7 @@ class NodeFactory(object):
 
     @classmethod
     def _normalize_attr_name_getter(cls, attr_name: str) -> str:
-        """Normalizes atr name to be suitable for getter function name.
+        """Normalize atr name to be suitable for getter function name.
 
         :param      attr_name:  The attribute name to normalize
 
@@ -89,7 +89,7 @@ class NodeFactory(object):
 
     @classmethod
     def _normalize_attr_name_setter(cls, attr_name: str) -> str:
-        """Normalizes atr name to be suitable for setter function name.
+        """Normalize attribute name to be suitable for setter function name.
 
         :param      attr_name:  The attribute name to normalize
 
@@ -99,7 +99,7 @@ class NodeFactory(object):
 
     @staticmethod
     def _get_node_attr_value(node: Node, attr_name: str) -> Any:
-        """Gets provided node attribute value.
+        """Get provided node attribute value.
 
         :param      node:       The node we retrieve attribute value from.
         :param      attr_name:  The attribute name.
@@ -113,7 +113,7 @@ class NodeFactory(object):
 
     @staticmethod
     def _set_node_attr_value(node: Node, attr_name: str, value: Any) -> None:
-        """Sets the node attribute value.
+        """Set the node attribute value.
 
         :param      node:       The node we change attribute value for.
         :param      attr_name:  The attribute name.
index cae8721..f4e1e15 100644 (file)
@@ -29,6 +29,7 @@ class GraphBody(object):
         self.results = results
 
     def serialize(self) -> dict:
+        """Serialize GraphBody as a dictionary."""
         return {
             "parameters": self.parameters,
             "results": self.results,
@@ -43,6 +44,7 @@ class TensorIteratorInputDesc(object):
         self.body_parameter_idx = body_parameter_idx
 
     def serialize(self) -> dict:
+        """Serialize TensorIteratorInputDesc as a dictionary."""
         return {
             "input_idx": self.input_idx,
             "body_parameter_idx": self.body_parameter_idx,
@@ -70,6 +72,7 @@ class TensorIteratorSliceInputDesc(TensorIteratorInputDesc):
         self.axis = axis
 
     def serialize(self) -> dict:
+        """Serialize TensorIteratorSliceInputDesc as a dictionary."""
         output = super().serialize()
         output["start"] = self.start
         output["stride"] = self.stride
@@ -90,6 +93,7 @@ class TensorIteratorMergedInputDesc(TensorIteratorInputDesc):
         self.body_value_idx = body_value_idx
 
     def serialize(self) -> dict:
+        """Serialize TensorIteratorMergedInputDesc as a dictionary."""
         output = super().serialize()
         output["body_value_idx"] = self.body_value_idx
         return output
@@ -110,6 +114,7 @@ class TensorIteratorOutputDesc(object):
         self.output_idx = output_idx
 
     def serialize(self) -> dict:
+        """Serialize TensorIteratorOutputDesc as a dictionary."""
         return {
             "body_value_idx": self.body_value_idx,
             "output_idx": self.output_idx,
@@ -124,6 +129,7 @@ class TensorIteratorBodyOutputDesc(TensorIteratorOutputDesc):
         self.iteration = iteration
 
     def serialize(self) -> dict:
+        """Serialize TensorIteratorBodyOutputDesc as a dictionary."""
         output = super().serialize()
         output["iteration"] = self.iteration
         return output
@@ -150,6 +156,7 @@ class TensorIteratorConcatOutputDesc(TensorIteratorOutputDesc):
         self.axis = axis
 
     def serialize(self) -> dict:
+        """Serialize TensorIteratorConcatOutputDesc as a dictionary."""
         output = super().serialize()
         output["start"] = self.start
         output["stride"] = self.stride
index 05a2ef3..9e37bfc 100644 (file)
@@ -78,7 +78,7 @@ void regclass_pyngraph_Function(py::module m)
             throw std::runtime_error("The provided capsule does not contain an ngraph::Function");
         }
     });
-    function.def("to_capsule", [](std::shared_ptr<ngraph::Function>& ngraph_function) {
+    function.def_static("to_capsule", [](std::shared_ptr<ngraph::Function>& ngraph_function) {
         // create a shared pointer on the heap before putting it in the capsule
         // this secures the lifetime of the object transferred by the capsule
         auto* sp_copy = new std::shared_ptr<ngraph::Function>(ngraph_function);
diff --git a/ngraph/python/src/pyngraph/onnx_import/onnx_import.cpp b/ngraph/python/src/pyngraph/onnx_import/onnx_import.cpp
deleted file mode 100644 (file)
index 6a23a1c..0000000
+++ /dev/null
@@ -1,48 +0,0 @@
-//*****************************************************************************
-// Copyright 2017-2020 Intel Corporation
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//*****************************************************************************
-
-#if defined(NGRAPH_ONNX_IMPORT_ENABLE)
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
-
-#include <istream>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "ngraph/frontend/onnx_import/onnx.hpp"
-#include "ngraph/function.hpp"
-#include "pyngraph/onnx_import/onnx_import.hpp"
-
-namespace py = pybind11;
-
-static std::shared_ptr<ngraph::Function> import_onnx_model(const std::string& model_proto)
-{
-    std::istringstream iss(model_proto, std::ios_base::binary | std::ios_base::in);
-    return ngraph::onnx_import::import_onnx_model(iss);
-}
-
-static std::shared_ptr<ngraph::Function> import_onnx_model_file(const std::string& filename)
-{
-    return ngraph::onnx_import::import_onnx_model(filename);
-}
-
-void regmodule_pyngraph_onnx_import(py::module mod)
-{
-    mod.def("import_onnx_model", &import_onnx_model);
-    mod.def("import_onnx_model_file", &import_onnx_model_file);
-}
-#endif
diff --git a/ngraph/python/src/pyngraph/onnx_import/onnx_import.hpp b/ngraph/python/src/pyngraph/onnx_import/onnx_import.hpp
deleted file mode 100644 (file)
index af03d2a..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-//*****************************************************************************
-// Copyright 2017-2020 Intel Corporation
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//*****************************************************************************
-
-#pragma once
-#if defined(NGRAPH_ONNX_IMPORT_ENABLE)
-
-#include <pybind11/pybind11.h>
-
-namespace py = pybind11;
-
-void regmodule_pyngraph_onnx_import(py::module m);
-#endif
index 7d4cb6c..eab2ca0 100644 (file)
@@ -38,6 +38,7 @@ def pytest_configure(config):
     config.addinivalue_line("markers", "skip_on_hddl: Skip test on HDDL")
     config.addinivalue_line("markers", "skip_on_myriad: Skip test on MYRIAD")
     config.addinivalue_line("markers", "skip_on_hetero: Skip test on HETERO")
+    config.addinivalue_line("markers", "onnx_coverage: Collect ONNX operator coverage")
 
 
 def pytest_collection_modifyitems(config, items):
index ec53144..f06a5e3 100644 (file)
@@ -21,7 +21,7 @@ import numpy as np
 from openvino.inference_engine import IECore, IENetwork
 
 from ngraph.exceptions import UserInputError
-from ngraph.impl import Function, Node, PartialShape, Shape, serialize, util
+from ngraph.impl import Function, Node, serialize
 from ngraph.utils.types import NumericData
 import tests
 
@@ -43,7 +43,7 @@ class Runtime(object):
 
     def __init__(self, backend_name: str) -> None:
         self.backend_name = backend_name
-        log.debug("Creating Inference Engine for .".format(backend_name))
+        log.debug("Creating Inference Engine for %s" % backend_name)
         self.backend = IECore()
         assert backend_name in self.backend.available_devices, (
             'The requested device "' + backend_name + '" is not supported!'
index e84d002..c1caa3e 100644 (file)
@@ -269,5 +269,5 @@ def test_backend_config():
 def test_result():
     node = [[11, 10], [1, 8], [3, 4]]
 
-    result = util.run_op_node([node], ng.ops.result)
+    result = run_op_node([node], ng.ops.result)
     assert np.allclose(result, node)
index 8b6fb8a..cecbb8b 100644 (file)
@@ -97,16 +97,16 @@ def test_dynamic_get_attribute_value(int_dtype, fp_dtype):
     assert node.get_num_classes() == int_dtype(85)
     assert node.get_background_label_id() == int_dtype(13)
     assert node.get_top_k() == int_dtype(16)
-    assert node.get_variance_encoded_in_target() == True
+    assert node.get_variance_encoded_in_target()
     assert np.all(np.equal(node.get_keep_top_k(), np.array([64, 32, 16, 8], dtype=int_dtype)))
     assert node.get_code_type() == "pytorch.some_parameter_name"
-    assert node.get_share_location() == False
+    assert not node.get_share_location()
     assert np.isclose(node.get_nms_threshold(), fp_dtype(0.645))
     assert np.isclose(node.get_confidence_threshold(), fp_dtype(0.111))
-    assert node.get_clip_after_nms() == True
-    assert node.get_clip_before_nms() == False
-    assert node.get_decrease_label_id() == True
-    assert node.get_normalized() == True
+    assert node.get_clip_after_nms()
+    assert not node.get_clip_before_nms()
+    assert node.get_decrease_label_id()
+    assert node.get_normalized()
     assert node.get_input_height() == int_dtype(86)
     assert node.get_input_width() == int_dtype(79)
     assert np.isclose(node.get_objectness_score(), fp_dtype(0.77))
@@ -165,9 +165,9 @@ def test_dynamic_set_attribute_value(int_dtype, fp_dtype):
     assert node.get_min_size() == int_dtype(123)
     assert np.allclose(node.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=fp_dtype))
     assert np.allclose(node.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=fp_dtype))
-    assert node.get_clip_before_nms() == True
-    assert node.get_clip_after_nms() == True
-    assert node.get_normalize() == True
+    assert node.get_clip_before_nms()
+    assert node.get_clip_after_nms()
+    assert node.get_normalize()
     assert np.isclose(node.get_box_size_scale(), fp_dtype(1.34))
     assert np.isclose(node.get_box_coordinate_scale(), fp_dtype(0.88))
     assert node.get_framework() == "OpenVINO"
index 9d765df..b31cee5 100644 (file)
@@ -18,6 +18,7 @@ import numpy as np
 
 import ngraph as ng
 from tests.runtime import get_runtime
+from tests.test_ngraph.util import run_op_node
 
 
 def test_lrn():
@@ -33,7 +34,7 @@ def test_lrn():
         np.array(
             [
                 [[[0.0], [0.05325444]], [[0.03402646], [0.01869806]], [[0.06805293], [0.03287071]]],
-                [[[0.00509002], [0.00356153]], [[0.00174719], [0.0012555]], [[0.00322708], [0.00235574]],],
+                [[[0.00509002], [0.00356153]], [[0.00174719], [0.0012555]], [[0.00322708], [0.00235574]]],
             ],
             dtype=np.float32,
         ),
@@ -48,7 +49,7 @@ def test_lrn():
         np.array(
             [
                 [[[0.0], [0.35355338]], [[0.8944272], [1.0606602]], [[1.7888544], [1.767767]]],
-                [[[0.93704253], [0.97827977]], [[1.2493901], [1.2577883]], [[1.5617375], [1.5372968]],],
+                [[[0.93704253], [0.97827977]], [[1.2493901], [1.2577883]], [[1.5617375], [1.5372968]]],
             ],
             dtype=np.float32,
         ),
@@ -95,7 +96,7 @@ def test_lrn_factory():
         ],
         dtype=np.float32,
     )
-    result = util.run_op_node([x, axis], ng.ops.lrn, alpha, beta, bias, nsize)
+    result = run_op_node([x, axis], ng.ops.lrn, alpha, beta, bias, nsize)
 
     assert np.allclose(result, excepted)
 
@@ -109,5 +110,5 @@ def test_batch_norm_inference():
     epsilon = 9.99e-06
     excepted = [[2.0, 6.0, 12.0], [-2.0, -6.0, -12.0]]
 
-    result = util.run_op_node([data, gamma, beta, mean, variance], ng.ops.batch_norm_inference, epsilon)
+    result = run_op_node([data, gamma, beta, mean, variance], ng.ops.batch_norm_inference, epsilon)
     assert np.allclose(result, excepted)
index 350f75d..cf8a6fb 100644 (file)
@@ -14,8 +14,6 @@
 # limitations under the License.
 # ******************************************************************************
 # flake8: noqa
-from __future__ import absolute_import
-
 import numpy as np
 
 import ngraph as ng
index d89b435..245d70e 100644 (file)
@@ -96,7 +96,7 @@ def test_binary_op_with_scalar(ng_api_helper, numpy_function):
 
 @pytest.mark.parametrize(
     "ng_api_helper,numpy_function",
-    [(ng.logical_and, np.logical_and), (ng.logical_or, np.logical_or), (ng.logical_xor, np.logical_xor),],
+    [(ng.logical_and, np.logical_and), (ng.logical_or, np.logical_or), (ng.logical_xor, np.logical_xor)],
 )
 def test_binary_logical_op(ng_api_helper, numpy_function):
     runtime = get_runtime()
@@ -118,7 +118,7 @@ def test_binary_logical_op(ng_api_helper, numpy_function):
 
 @pytest.mark.parametrize(
     "ng_api_helper,numpy_function",
-    [(ng.logical_and, np.logical_and), (ng.logical_or, np.logical_or), (ng.logical_xor, np.logical_xor),],
+    [(ng.logical_and, np.logical_and), (ng.logical_or, np.logical_or), (ng.logical_xor, np.logical_xor)],
 )
 def test_binary_logical_op_with_scalar(ng_api_helper, numpy_function):
     runtime = get_runtime()
index feacfce..cc35f6f 100644 (file)
@@ -136,7 +136,7 @@ def test_depth_to_space():
 
     result = computation(data_value)
     expected = np.array(
-        [[[[0, 6, 1, 7, 2, 8], [12, 18, 13, 19, 14, 20], [3, 9, 4, 10, 5, 11], [15, 21, 16, 22, 17, 23],]]],
+        [[[[0, 6, 1, 7, 2, 8], [12, 18, 13, 19, 14, 20], [3, 9, 4, 10, 5, 11], [15, 21, 16, 22, 17, 23]]]],
         dtype=np.float32,
     )
     assert np.allclose(result, expected)
@@ -284,13 +284,13 @@ def test_squeeze_operator():
 
     data_shape = [1, 2, 1, 3, 1, 1]
     parameter_data = ng.parameter(data_shape, name="Data", dtype=np.float32)
-    data_value = np.arange(6.0, dtype=np.float32).reshape(1, 2, 1, 3, 1, 1)
+    data_value = np.arange(6.0, dtype=np.float32).reshape([1, 2, 1, 3, 1, 1])
     axes = [2, 4]
     model = ng.squeeze(parameter_data, axes)
     computation = runtime.computation(model, parameter_data)
 
     result = computation(data_value)
-    expected = np.arange(6.0, dtype=np.float32).reshape(1, 2, 3, 1)
+    expected = np.arange(6.0, dtype=np.float32).reshape([1, 2, 3, 1])
     assert np.allclose(result, expected)
 
 
@@ -365,14 +365,14 @@ def test_unsqueeze():
     computation = runtime.computation(model, parameter_data)
 
     result = computation(data_value)
-    expected = np.arange(60.0, dtype=np.float32).reshape(1, 3, 4, 5, 1)
+    expected = np.arange(60.0, dtype=np.float32).reshape([1, 3, 4, 5, 1])
     assert np.allclose(result, expected)
 
 
 def test_grn_operator():
     runtime = get_runtime()
 
-    data_value = np.arange(start=1.0, stop=25.0, dtype=np.float32).reshape(1, 2, 3, 4)
+    data_value = np.arange(start=1.0, stop=25.0, dtype=np.float32).reshape([1, 2, 3, 4])
     bias = np.float32(1e-6)
 
     data_shape = [1, 2, 3, 4]
@@ -574,9 +574,6 @@ def test_space_to_depth_operator():
     ).reshape(1, 8, 2, 2)
     assert np.allclose(result, expected)
 
-
-    runtime = get_runtime()
-
     batch_size = 2
     input_size = 3
     hidden_size = 3
index 17e1b3c..952053a 100644 (file)
@@ -17,7 +17,6 @@ import numpy as np
 import pytest
 
 import ngraph as ng
-from tests.runtime import get_runtime
 from tests.test_ngraph.util import run_op_node
 
 
index baecc90..61c2778 100644 (file)
@@ -281,7 +281,7 @@ def test_max_pool_same_lower_auto_pads():
     result = comp(data)
 
     expected = np.array(
-        [[[[0.5, 1.5, 2.5, 3.5], [4.5, 5.5, 6.5, 7.5], [8.5, 9.5, 10.5, 11.5], [12.5, 13.5, 14.5, 15.5],]]],
+        [[[[0.5, 1.5, 2.5, 3.5], [4.5, 5.5, 6.5, 7.5], [8.5, 9.5, 10.5, 11.5], [12.5, 13.5, 14.5, 15.5]]]],
         dtype=np.float32,
     )
     assert np.allclose(result, expected)
diff --git a/ngraph/python/tests/test_onnx/models/add_abc.onnx b/ngraph/python/tests/test_onnx/models/add_abc.onnx
new file mode 100644 (file)
index 0000000..5c2da5d
--- /dev/null
@@ -0,0 +1,24 @@
+\b\ 3\12\13ngraph ONNXImporter:\86\ 1
+\19
+\ 1A
+\ 1B\12\ 1X\1a add_node1"\ 3Add
+\19
+\ 1X
+\ 1C\12\ 1Y\1a add_node2"\ 3Add\12
+test_graphZ\ f
+\ 1A\12
+
+\b\b\ 1\12\ 4
+\ 2\b\ 1Z\ f
+\ 1B\12
+
+\b\b\ 1\12\ 4
+\ 2\b\ 1Z\ f
+\ 1C\12
+
+\b\b\ 1\12\ 4
+\ 2\b\ 1b\ f
+\ 1Y\12
+
+\b\b\ 1\12\ 4
+\ 2\b\ 1B\ 2\10\ 4
\ No newline at end of file
diff --git a/ngraph/python/tests/test_onnx/test_onnx_import.py b/ngraph/python/tests/test_onnx/test_onnx_import.py
new file mode 100644 (file)
index 0000000..18009d5
--- /dev/null
@@ -0,0 +1,68 @@
+# ******************************************************************************
+# Copyright 2017-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+
+import os
+
+import numpy as np
+import onnx
+from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
+from openvino.inference_engine import IECore
+
+from ngraph.impl import Function
+from tests.runtime import get_runtime
+from tests.test_onnx.utils.onnx_helpers import import_onnx_model
+
+
+def test_import_onnx_function():
+    model_path = os.path.join(os.path.dirname(__file__), "models/add_abc.onnx")
+    ie = IECore()
+    ie_network = ie.read_network(model=model_path)
+
+    capsule = ie_network._get_function_capsule()
+    ng_function = Function.from_capsule(capsule)
+
+    dtype = np.float32
+    value_a = np.array([1.0], dtype=dtype)
+    value_b = np.array([2.0], dtype=dtype)
+    value_c = np.array([3.0], dtype=dtype)
+
+    runtime = get_runtime()
+    computation = runtime.computation(ng_function)
+    result = computation(value_a, value_b, value_c)
+    assert np.allclose(result, np.array([6], dtype=dtype))
+
+
+def test_simple_graph():
+    node1 = make_node("Add", ["A", "B"], ["X"], name="add_node1")
+    node2 = make_node("Add", ["X", "C"], ["Y"], name="add_node2")
+    graph = make_graph(
+        [node1, node2],
+        "test_graph",
+        [
+            make_tensor_value_info("A", onnx.TensorProto.FLOAT, [1]),
+            make_tensor_value_info("B", onnx.TensorProto.FLOAT, [1]),
+            make_tensor_value_info("C", onnx.TensorProto.FLOAT, [1]),
+        ],
+        [make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1])],
+    )
+    model = make_model(graph, producer_name="ngraph ONNX Importer")
+
+    ng_model_function = import_onnx_model(model)
+
+    runtime = get_runtime()
+    computation = runtime.computation(ng_model_function)
+    assert np.array_equal(computation(1, 2, 3)[0], np.array([6.0], dtype=np.float32))
+    assert np.array_equal(computation(4, 5, 6)[0], np.array([15.0], dtype=np.float32))
diff --git a/ngraph/python/tests/test_onnx/test_ops_batchnorm.py b/ngraph/python/tests/test_onnx/test_ops_batchnorm.py
new file mode 100644 (file)
index 0000000..f941e8a
--- /dev/null
@@ -0,0 +1,97 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+
+import numpy as np
+import onnx
+
+from tests.test_onnx.utils import run_node
+
+
+def make_batch_norm_node(**node_attributes):
+    return onnx.helper.make_node(
+        "BatchNormalization", inputs=["X", "scale", "B", "mean", "var"], outputs=["Y"], **node_attributes
+    )
+
+
+def test_batch_norm_test_node():
+    data = np.arange(48).reshape((1, 3, 4, 4)).astype(np.float32)
+    scale = np.ones((3,)).astype(np.float32)  # Gamma
+    bias = np.zeros((3,)).astype(np.float32)  # Beta
+    mean = np.mean(data, axis=(0, 2, 3))
+    var = np.var(data, axis=(0, 2, 3))
+
+    expected_output = np.array(
+        [
+            [
+                [
+                    [-1.62694025, -1.41001487, -1.19308949, -0.97616416],
+                    [-0.75923878, -0.54231346, -0.32538807, -0.10846269],
+                    [0.10846269, 0.32538807, 0.54231334, 0.75923872],
+                    [0.9761641, 1.19308949, 1.41001487, 1.62694025],
+                ],
+                [
+                    [-1.62694049, -1.41001511, -1.19308972, -0.97616434],
+                    [-0.7592392, -0.54231358, -0.32538843, -0.10846281],
+                    [0.10846233, 0.32538795, 0.5423131, 0.75923872],
+                    [0.97616386, 1.19308949, 1.41001463, 1.62694025],
+                ],
+                [
+                    [-1.62694025, -1.41001511, -1.19308949, -0.97616434],
+                    [-0.75923872, -0.54231358, -0.32538795, -0.10846233],
+                    [0.10846233, 0.32538795, 0.54231358, 0.7592392],
+                    [0.97616386, 1.19308949, 1.41001511, 1.62694073],
+                ],
+            ]
+        ],
+        dtype=np.float32,
+    )
+
+    node = make_batch_norm_node()
+    result = run_node(node, [data, scale, bias, mean, var])[0]
+    assert np.allclose(result, expected_output, rtol=1e-04, atol=1e-08)
+
+    scale = np.broadcast_to(0.1, (3,)).astype(np.float32)  # Gamma
+    bias = np.broadcast_to(1, (3,)).astype(np.float32)  # Beta
+
+    expected_output = np.array(
+        [
+            [
+                [
+                    [0.83730596, 0.85899848, 0.88069105, 0.90238357],
+                    [0.92407608, 0.94576865, 0.96746117, 0.98915374],
+                    [1.01084626, 1.03253877, 1.05423129, 1.07592392],
+                    [1.09761643, 1.11930895, 1.14100146, 1.16269398],
+                ],
+                [
+                    [0.83730596, 0.85899854, 0.88069105, 0.90238357],
+                    [0.92407608, 0.94576865, 0.96746117, 0.98915374],
+                    [1.01084626, 1.03253877, 1.05423141, 1.07592392],
+                    [1.09761643, 1.11930895, 1.14100146, 1.16269398],
+                ],
+                [
+                    [0.83730596, 0.85899848, 0.88069105, 0.90238357],
+                    [0.92407614, 0.94576865, 0.96746117, 0.98915374],
+                    [1.01084626, 1.03253877, 1.05423141, 1.07592392],
+                    [1.09761643, 1.11930895, 1.14100146, 1.16269398],
+                ],
+            ]
+        ],
+        dtype=np.float32,
+    )
+
+    node = make_batch_norm_node()
+    result = run_node(node, [data, scale, bias, mean, var])[0]
+    assert np.allclose(result, expected_output, rtol=1e-04, atol=1e-08)
diff --git a/ngraph/python/tests/test_onnx/test_ops_binary.py b/ngraph/python/tests/test_onnx/test_ops_binary.py
new file mode 100644 (file)
index 0000000..3156d5e
--- /dev/null
@@ -0,0 +1,148 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+import onnx
+import pytest
+from onnx.helper import make_graph, make_model, make_tensor_value_info
+
+from tests.test_onnx.utils import run_model
+
+
+def import_and_compute(op_type, input_data_left, input_data_right, opset=7, **node_attributes):
+    inputs = [np.array(input_data_left), np.array(input_data_right)]
+    onnx_node = onnx.helper.make_node(op_type, inputs=["x", "y"], outputs=["z"], **node_attributes)
+
+    input_tensors = [
+        make_tensor_value_info(name, onnx.TensorProto.FLOAT, value.shape)
+        for name, value in zip(onnx_node.input, inputs)
+    ]
+    output_tensors = [make_tensor_value_info(name, onnx.TensorProto.FLOAT, ()) for name in onnx_node.output]
+
+    graph = make_graph([onnx_node], "compute_graph", input_tensors, output_tensors)
+    model = make_model(graph, producer_name="ngraph ONNX Importer")
+    model.opset_import[0].version = opset
+    return run_model(model, inputs)[0]
+
+
+def test_add_opset4():
+    assert np.array_equal(import_and_compute("Add", 1, 2, opset=4), np.array(3, dtype=np.float32))
+
+    assert np.array_equal(import_and_compute("Add", [1], [2], opset=4), np.array([3], dtype=np.float32))
+
+    assert np.array_equal(
+        import_and_compute("Add", [1, 2], [3, 4], opset=4), np.array([4, 6], dtype=np.float32)
+    )
+
+    assert np.array_equal(
+        import_and_compute("Add", [1, 2, 3], [4, 5, 6], opset=4), np.array([5, 7, 9], dtype=np.float32)
+    )
+
+    assert np.array_equal(
+        import_and_compute("Add", [[1, 2, 3], [4, 5, 6]], [7, 8, 9], broadcast=1, opset=4),
+        np.array([[8, 10, 12], [11, 13, 15]], dtype=np.float32),
+    )
+
+    # shape(A) = (2, 3, 4, 5), shape(B) = (,), i.e. B is a scalar
+    left_operand = np.ones((2, 3, 4, 5)).astype(np.float32)
+    assert np.array_equal(import_and_compute("Add", left_operand, 8, broadcast=1, opset=4), left_operand + 8)
+
+    # shape(A) = (2, 3, 4, 5), shape(B) = (5,)
+    left_operand = np.ones((2, 3, 4, 5), dtype=np.float32)
+    right_operand = np.random.rand(5).astype(np.float32)
+    import_and_compute("Add", left_operand, right_operand, broadcast=1, opset=4)
+
+    # shape(A) = (2, 3, 4, 5), shape(B) = (4, 5)
+    left_operand = np.ones((2, 3, 4, 5), dtype=np.float32)
+    right_operand = np.random.rand(4, 5).astype(np.float32)
+    assert np.array_equal(
+        import_and_compute("Add", left_operand, right_operand, broadcast=1, opset=4),
+        left_operand + right_operand,
+    )
+
+    # shape(A) = (2, 3, 4, 5), shape(B) = (3, 4), with axis=1
+    left_operand = np.ones((2, 3, 4, 5), dtype=np.float32)
+    right_operand = np.random.rand(3, 4).astype(np.float32)
+    assert np.array_equal(
+        import_and_compute("Add", left_operand, right_operand, broadcast=1, axis=1, opset=4),
+        left_operand + right_operand.reshape(1, 3, 4, 1),
+    )
+
+    # shape(A) = (2, 3, 4, 5), shape(B) = (2), with axis=0
+    left_operand = np.ones((2, 3, 4, 5), dtype=np.float32)
+    right_operand = np.random.rand(2).astype(np.float32)
+    assert np.array_equal(
+        import_and_compute("Add", left_operand, right_operand, broadcast=1, axis=0, opset=4),
+        left_operand + right_operand.reshape(2, 1, 1, 1),
+    )
+
+
+@pytest.mark.parametrize(
+    "left_shape,right_shape",
+    [
+        ((1,), (1,)),
+        ((256, 256, 3), (3,)),
+        ((5, 4), (1,)),
+        ((5, 4), (4,)),
+        ((15, 3, 5), (3, 5)),
+        ((15, 3, 5), (15, 1, 5)),
+        ((15, 3, 5), (3, 1)),
+        ((8, 1, 6, 1), (7, 1, 5)),
+    ],
+)
+def test_add_opset7(left_shape, right_shape):
+    """Test Add-7 operator, which uses numpy-style broadcasting."""
+    left_input = np.ones(left_shape)
+    right_input = np.ones(right_shape)
+    assert np.array_equal(import_and_compute("Add", left_input, right_input), left_input + right_input)
+
+
+def test_sub():
+    assert np.array_equal(import_and_compute("Sub", 20, 1), np.array(19, dtype=np.float32))
+
+    assert np.array_equal(import_and_compute("Sub", [20], [1]), np.array([19], dtype=np.float32))
+
+    assert np.array_equal(import_and_compute("Sub", [20, 19], [1, 2]), np.array([19, 17], dtype=np.float32))
+
+    assert np.array_equal(
+        import_and_compute("Sub", [[1, 2, 3], [4, 5, 6]], [7, 8, 9], broadcast=1),
+        np.array([[-6, -6, -6], [-3, -3, -3]], dtype=np.float32),
+    )
+
+
+def test_mul():
+    assert np.array_equal(import_and_compute("Mul", 2, 3), np.array(6, dtype=np.float32))
+
+    assert np.array_equal(import_and_compute("Mul", [2], [3]), np.array([6], dtype=np.float32))
+
+    assert np.array_equal(import_and_compute("Mul", [2, 3], [4, 5]), np.array([8, 15], dtype=np.float32))
+
+    assert np.array_equal(
+        import_and_compute("Mul", [[1, 2, 3], [4, 5, 6]], [7, 8, 9], broadcast=1),
+        np.array([[7, 16, 27], [28, 40, 54]], dtype=np.float32),
+    )
+
+
+def test_div():
+    assert np.array_equal(import_and_compute("Div", 6, 3), np.array(2, dtype=np.float32))
+
+    assert np.array_equal(import_and_compute("Div", [6], [3]), np.array([2], dtype=np.float32))
+
+    assert np.array_equal(import_and_compute("Div", [6, 8], [3, 2]), np.array([2, 4], dtype=np.float32))
+
+    assert np.array_equal(
+        import_and_compute("Div", [[10, 20, 30], [40, 50, 60]], [2, 5, 6], broadcast=1),
+        np.array([[5, 4, 5], [20, 10, 10]], dtype=np.float32),
+    )
diff --git a/ngraph/python/tests/test_onnx/test_ops_convpool.py b/ngraph/python/tests/test_onnx/test_ops_convpool.py
new file mode 100644 (file)
index 0000000..bb55ccc
--- /dev/null
@@ -0,0 +1,417 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+import onnx
+import pytest
+from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
+
+from tests.runtime import get_runtime
+from tests.test_onnx.utils import get_node_model, import_onnx_model, run_model, run_node
+
+
+@pytest.fixture
+def ndarray_1x1x4x4():
+    return np.array(
+        [[11, 12, 13, 14], [15, 16, 17, 18], [19, 20, 21, 22], [23, 24, 25, 26]], dtype=np.float32
+    ).reshape([1, 1, 4, 4])
+
+
+def make_onnx_model_for_conv_op(x_shape, weights_shape, transpose=False, **attributes):
+    output_shape = ()  # We don't need output shape to be accurate for these tests
+
+    if transpose:
+        node_op = "ConvTranspose"
+    else:
+        node_op = "Conv"
+
+    node = make_node(node_op, ["X", "weight"], ["Y"], name="test_node", **attributes)
+    graph = make_graph(
+        [node],
+        "test_graph",
+        [
+            make_tensor_value_info("X", onnx.TensorProto.FLOAT, x_shape),
+            make_tensor_value_info("weight", onnx.TensorProto.FLOAT, weights_shape),
+        ],
+        [make_tensor_value_info("Y", onnx.TensorProto.FLOAT, output_shape)],
+    )
+    model = make_model(graph, producer_name="ngraph ONNXImporter")
+    return model
+
+
+def import_and_compute_conv(x, weights, transpose=False, **attributes):
+    x, weights = np.array(x), np.array(weights)
+    onnx_model = make_onnx_model_for_conv_op(x.shape, weights.shape, transpose=transpose, **attributes)
+    ng_model_function = import_onnx_model(onnx_model)
+    computation = get_runtime().computation(ng_model_function)
+    return computation(x, weights)[0]
+
+
+def test_2d_conv():
+    # x should have shape N(batch) x C x H x W
+    input_x = np.array(
+        [
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        ],
+        dtype=np.float32,
+    ).reshape(1, 1, 9, 9)
+
+    # filter weights should have shape M x C x kH x kW
+    input_filter = np.array([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]], dtype=np.float32).reshape(
+        [1, 1, 3, 3]
+    )
+
+    # convolution with padding=1 should produce 9 x 9 output:
+    result = import_and_compute_conv(input_x, input_filter, pads=(1, 1, 1, 1), strides=(1, 1))
+    assert np.array_equal(
+        result,
+        np.array(
+            [
+                [
+                    [
+                        [0.0, -15.0, -15.0, 15.0, 15.0, 0.0, 0.0, 0.0, 0.0],
+                        [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+                        [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+                        [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+                        [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+                        [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+                        [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+                        [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+                        [0.0, -15.0, -15.0, 15.0, 15.0, 0.0, 0.0, 0.0, 0.0],
+                    ]
+                ]
+            ],
+            dtype=np.float32,
+        ),
+    )
+
+    # convolution with padding=0 should produce 7 x 7 output:
+    result = import_and_compute_conv(input_x, input_filter, pads=(0, 0, 0, 0), strides=(1, 1))
+    assert np.array_equal(
+        result,
+        np.array(
+            [
+                [
+                    [
+                        [-20, -20, 20, 20, 0, 0, 0],
+                        [-20, -20, 20, 20, 0, 0, 0],
+                        [-20, -20, 20, 20, 0, 0, 0],
+                        [-20, -20, 20, 20, 0, 0, 0],
+                        [-20, -20, 20, 20, 0, 0, 0],
+                        [-20, -20, 20, 20, 0, 0, 0],
+                        [-20, -20, 20, 20, 0, 0, 0],
+                    ]
+                ]
+            ],
+            dtype=np.float32,
+        ),
+    )
+
+    # convolution with strides=2 should produce 4 x 4 output:
+    result = import_and_compute_conv(input_x, input_filter, pads=(0, 0, 0, 0), strides=(2, 2))
+    assert np.array_equal(
+        result,
+        np.array(
+            [
+                [
+                    [
+                        [-20.0, 20.0, 0.0, 0.0],
+                        [-20.0, 20.0, 0.0, 0.0],
+                        [-20.0, 20.0, 0.0, 0.0],
+                        [-20.0, 20.0, 0.0, 0.0],
+                    ]
+                ]
+            ],
+            dtype=np.float32,
+        ),
+    )
+
+    # convolution with dilations=2 should produce 5 x 5 output:
+    result = import_and_compute_conv(input_x, input_filter, dilations=(2, 2))
+    assert np.array_equal(
+        result,
+        np.array(
+            [
+                [
+                    [
+                        [0, 0, 20, 20, 0],
+                        [0, 0, 20, 20, 0],
+                        [0, 0, 20, 20, 0],
+                        [0, 0, 20, 20, 0],
+                        [0, 0, 20, 20, 0],
+                    ]
+                ]
+            ],
+            dtype=np.float32,
+        ),
+    )
+
+
+def test_3d_conv():
+    # x should have shape N(batch) x C x H x W x D
+    input_x = np.array(
+        [
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        ],
+        dtype=np.float32,
+    ).reshape([1, 1, 9, 9, 1])
+    input_x = np.broadcast_to(input_x, (1, 1, 9, 9, 4))
+
+    # filter weights should have shape M x C x kH x kW x kD
+    input_filter = np.array([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]], dtype=np.float32).reshape(
+        [1, 1, 3, 3, 1]
+    )
+    input_filter = np.broadcast_to(input_filter, (1, 1, 3, 3, 3))
+
+    # convolution with padding=0 should produce 7 x 7 x 2 output:
+    result = import_and_compute_conv(
+        input_x, input_filter, dilations=(1, 1, 1), pads=(0, 0, 0, 0, 0, 0), strides=(1, 1, 1)
+    )
+
+    assert np.array_equal(
+        np.moveaxis(result.squeeze(), (0, 1, 2), (1, 2, 0)),
+        np.array(
+            [
+                [
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                ],
+                [
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                    [-60.0, -60.0, 60.0, 60.0, 0.0, 0.0, 0.0],
+                ],
+            ],
+            dtype=np.float32,
+        ),
+    )
+
+
+def test_2d_conv_transpose():
+    # x should have shape N(batch) x C x H x W
+    input_x = np.array(
+        [
+            [0.0, -15.0, -15.0, 15.0, 15.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, -20.0, -20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, -15.0, -15.0, 15.0, 15.0, 0.0, 0.0, 0.0, 0.0],
+        ],
+        dtype=np.float32,
+    ).reshape([1, 1, 9, 9])
+
+    # filter weights should have shape M x C x kH x kW
+    input_filter = np.array([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]], dtype=np.float32).reshape(
+        [1, 1, 3, 3]
+    )
+
+    # deconvolution with padding=1 should produce 9 x 9 output:
+    result = import_and_compute_conv(input_x, input_filter, transpose=True, pads=(1, 1, 1, 1), strides=(1, 1))
+
+    assert np.array_equal(
+        result.reshape([9, 9]),
+        np.array(
+            [
+                [-50.0, -50.0, 100.0, 100.0, -50.0, -50.0, 0.0, 0.0, 0.0],
+                [-75.0, -75.0, 150.0, 150.0, -75.0, -75.0, 0.0, 0.0, 0.0],
+                [-80.0, -80.0, 160.0, 160.0, -80.0, -80.0, 0.0, 0.0, 0.0],
+                [-80.0, -80.0, 160.0, 160.0, -80.0, -80.0, 0.0, 0.0, 0.0],
+                [-80.0, -80.0, 160.0, 160.0, -80.0, -80.0, 0.0, 0.0, 0.0],
+                [-80.0, -80.0, 160.0, 160.0, -80.0, -80.0, 0.0, 0.0, 0.0],
+                [-80.0, -80.0, 160.0, 160.0, -80.0, -80.0, 0.0, 0.0, 0.0],
+                [-75.0, -75.0, 150.0, 150.0, -75.0, -75.0, 0.0, 0.0, 0.0],
+                [-50.0, -50.0, 100.0, 100.0, -50.0, -50.0, 0.0, 0.0, 0.0],
+            ],
+            dtype=np.float32,
+        ),
+    )
+
+
+def test_pad_opset_1():
+    x = np.ones((2, 2), dtype=np.float32)
+    y = np.pad(x, pad_width=1, mode="constant")
+
+    model = get_node_model("Pad", x, paddings=[1, 1, 1, 1])
+    ng_results = run_model(model, [x])
+    assert np.array_equal(ng_results, [y])
+
+    x = np.random.randn(1, 3, 4, 5).astype(np.float32)
+    y = np.pad(x, pad_width=((0, 0), (0, 0), (1, 2), (3, 4)), mode="constant")
+
+    model = get_node_model("Pad", x, mode="constant", paddings=[0, 0, 1, 3, 0, 0, 2, 4])
+    ng_results = run_model(model, [x])
+    assert np.array_equal(ng_results, [y])
+
+    # incorrect paddings rank
+    x = np.ones((2, 2), dtype=np.float32)
+    model = get_node_model("Pad", x, paddings=[0, 1, 1, 3, 1, 2])
+    with pytest.raises(RuntimeError):
+        run_model(model, [x])
+
+    # no paddings arttribute
+    model = get_node_model("Pad", x)
+    with pytest.raises(RuntimeError):
+        import_onnx_model(model)
+
+
+def test_pad_opset_2():
+    x = np.ones((2, 2), dtype=np.float32)
+    y = np.pad(x, pad_width=1, mode="constant")
+
+    model = get_node_model("Pad", x, opset=2, pads=[1, 1, 1, 1])
+    ng_results = run_model(model, [x])
+    assert np.array_equal(ng_results, [y])
+
+    x = np.random.randn(1, 3, 4, 5).astype(np.float32)
+    y = np.pad(x, pad_width=((0, 0), (0, 0), (1, 2), (3, 4)), mode="constant")
+
+    model = get_node_model("Pad", x, opset=2, mode="constant", pads=[0, 0, 1, 3, 0, 0, 2, 4])
+    ng_results = run_model(model, [x])
+    assert np.array_equal(ng_results, [y])
+
+    # incorrect pads rank
+    x = np.ones((2, 2), dtype=np.float32)
+    model = get_node_model("Pad", x, opset=2, pads=[0, 1, 1, 3, 1, 2])
+    with pytest.raises(RuntimeError):
+        run_model(model, [x])
+
+
+# Error of validate layer: B with type: Pad. Cannot parse parameter pads_begin
+# from IR for layer B. Value -1,0 cannot be casted to int.
+def test_pad_negative_values_begin():
+    x = np.ones((2, 2), dtype=np.float32)
+
+    # Axis 1 begin
+    model = get_node_model("Pad", x, opset=2, pads=[-1, 0, 0, 0])
+    ng_result = run_model(model, [x])[0]
+    assert np.array_equal(ng_result, np.array([[1, 1]]))
+
+    # Axis 2 begin
+    model = get_node_model("Pad", x, opset=2, pads=[0, -1, 0, 0])
+    ng_result = run_model(model, [x])[0]
+    assert np.array_equal(ng_result, np.array([[1], [1]]))
+
+
+# Error of validate layer: B with type: Pad. Cannot parse parameter pads_begin
+# from IR for layer B. Value -1,0 cannot be casted to int.
+def test_pad_negative_values_end():
+    x = np.ones((2, 2), dtype=np.float32)
+
+    # Axis 1 end
+    model = get_node_model("Pad", x, opset=2, pads=[0, 0, -1, 0])
+    ng_result = run_model(model, [x])[0]
+    assert np.array_equal(ng_result, np.array([[1.0, 1.0]]))
+
+    # Axis 2 end
+    model = get_node_model("Pad", x, opset=2, pads=[0, 0, 0, -1])
+    ng_result = run_model(model, [x])[0]
+    assert np.array_equal(ng_result, np.array([[1], [1]]))
+
+
+def test_pool_average(ndarray_1x1x4x4):
+    x = ndarray_1x1x4x4
+    node = onnx.helper.make_node(
+        "AveragePool", inputs=["x"], outputs=["y"], kernel_shape=(2, 2), strides=(2, 2)
+    )
+    y = np.array([[13.5, 15.5], [21.5, 23.5]], dtype=np.float32).reshape([1, 1, 2, 2])
+    ng_results = run_node(node, [x])
+    assert np.array_equal(ng_results, [y])
+
+    node = onnx.helper.make_node(
+        "AveragePool", inputs=["x"], outputs=["y"], kernel_shape=(2, 2), strides=(2, 2), pads=(1, 1, 1, 1)
+    )
+    y = np.array([[11, 12.5, 14], [17, 18.5, 20], [23, 24.5, 26]], dtype=np.float32).reshape([1, 1, 3, 3])
+    ng_results = run_node(node, [x])
+    assert np.array_equal(ng_results, [y])
+
+
+def test_pool_average_3d(ndarray_1x1x4x4):
+    x = np.broadcast_to(ndarray_1x1x4x4, (1, 1, 4, 4, 4))
+    node = onnx.helper.make_node(
+        "AveragePool", inputs=["x"], outputs=["y"], kernel_shape=(2, 2, 2), strides=(2, 2, 2)
+    )
+    y = np.array([[[13.5, 15.5], [21.5, 23.5]], [[13.5, 15.5], [21.5, 23.5]]], dtype=np.float32).reshape(
+        [1, 1, 2, 2, 2]
+    )
+    ng_results = run_node(node, [x])
+    assert np.array_equal(ng_results, [y])
+
+
+def test_pool_max(ndarray_1x1x4x4):
+    node = onnx.helper.make_node("MaxPool", inputs=["x"], outputs=["y"], kernel_shape=(2, 2), strides=(2, 2))
+
+    x = ndarray_1x1x4x4
+    y = np.array([[16, 18], [24, 26]], dtype=np.float32).reshape([1, 1, 2, 2])
+
+    ng_results = run_node(node, [x])
+    assert np.array_equal(ng_results, [y])
+
+
+def test_pool_global_max(ndarray_1x1x4x4):
+    node = onnx.helper.make_node("GlobalMaxPool", inputs=["x"], outputs=["y"])
+
+    x = ndarray_1x1x4x4
+    y = np.array([26], dtype=np.float32).reshape([1, 1, 1, 1])
+
+    ng_results = run_node(node, [x])
+    assert np.array_equal(ng_results, [y])
+
+
+def test_pool_global_average(ndarray_1x1x4x4):
+    node = onnx.helper.make_node("GlobalAveragePool", inputs=["x"], outputs=["y"])
+
+    x = ndarray_1x1x4x4
+    y = np.array([18.5], dtype=np.float32).reshape([1, 1, 1, 1])
+
+    ng_results = run_node(node, [x])
+    assert np.array_equal(ng_results, [y])
+
+
+def test_pool_global_average_3d(ndarray_1x1x4x4):
+    x = np.broadcast_to(ndarray_1x1x4x4, (1, 1, 4, 4, 4))
+
+    node = onnx.helper.make_node("GlobalAveragePool", inputs=["x"], outputs=["y"])
+    y = np.array([18.5], dtype=np.float32).reshape([1, 1, 1, 1, 1])
+    ng_results = run_node(node, [x])
+    assert np.array_equal(ng_results, [y])
diff --git a/ngraph/python/tests/test_onnx/test_ops_logical.py b/ngraph/python/tests/test_onnx/test_ops_logical.py
new file mode 100644 (file)
index 0000000..f1963ba
--- /dev/null
@@ -0,0 +1,57 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+import onnx
+import pytest
+
+# [PARAMETER_MISMATCH] Failed to set Blob with precision FP32
+from tests.test_onnx.utils import run_node
+
+
+@pytest.mark.parametrize(
+    "onnx_op, numpy_func, data_type",
+    [
+        ("And", np.logical_and, np.bool),
+        ("Or", np.logical_or, np.bool),
+        ("Xor", np.logical_xor, np.bool),
+        ("Equal", np.equal, np.int32),
+        ("Greater", np.greater, np.int32),
+        ("Less", np.less, np.int32),
+    ],
+)
+def test_logical(onnx_op, numpy_func, data_type):
+    node = onnx.helper.make_node(onnx_op, inputs=["A", "B"], outputs=["C"], broadcast=1)
+
+    input_a = np.array([[0, 1, -1], [0, 1, -1], [0, 1, -1]]).astype(data_type)
+    input_b = np.array([[0, 0, 0], [1, 1, 1], [-1, -1, -1]]).astype(data_type)
+    expected_output = numpy_func(input_a, input_b)
+    ng_results = run_node(node, [input_a, input_b], opset_version=4)
+    assert np.array_equal(ng_results, [expected_output])
+
+    input_a = np.array([[0, 1, -1], [0, 1, -1], [0, 1, -1]]).astype(data_type)
+    input_b = np.array(1).astype(data_type)
+    expected_output = numpy_func(input_a, input_b)
+    ng_results = run_node(node, [input_a, input_b], opset_version=4)
+    assert np.array_equal(ng_results, [expected_output])
+
+
+def test_logical_not():
+    input_data = np.array([[False, True, True], [False, True, False], [False, False, True]])
+    expected_output = np.logical_not(input_data)
+
+    node = onnx.helper.make_node("Not", inputs=["X"], outputs=["Y"])
+    ng_results = run_node(node, [input_data])
+    assert np.array_equal(ng_results, [expected_output])
diff --git a/ngraph/python/tests/test_onnx/test_ops_matmul.py b/ngraph/python/tests/test_onnx/test_ops_matmul.py
new file mode 100644 (file)
index 0000000..34061ae
--- /dev/null
@@ -0,0 +1,182 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+import onnx
+from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
+
+from tests.runtime import get_runtime
+from tests.test_onnx.utils import import_onnx_model
+
+
+def make_onnx_model_for_matmul_op(input_left, input_right):
+    output_shape = np.matmul(input_left, input_right).shape
+    node = make_node("MatMul", ["X", "Y"], ["Z"], name="test_node")
+    graph = make_graph(
+        [node],
+        "test_graph",
+        [
+            make_tensor_value_info("X", onnx.TensorProto.FLOAT, input_left.shape),
+            make_tensor_value_info("Y", onnx.TensorProto.FLOAT, input_right.shape),
+        ],
+        [make_tensor_value_info("Z", onnx.TensorProto.FLOAT, output_shape)],
+    )
+    model = make_model(graph, producer_name="ngraph ONNXImporter")
+    return model
+
+
+def import_and_compute_matmul(input_left, input_right):
+    input_data_left = np.array(input_left)
+    input_data_right = np.array(input_right)
+    onnx_model = make_onnx_model_for_matmul_op(input_data_left, input_data_right)
+    transformer = get_runtime()
+    ng_model_function = import_onnx_model(onnx_model)
+    computation = transformer.computation(ng_model_function)
+    return computation(input_data_left, input_data_right)[0]
+
+
+def numpy_gemm(input_a, input_b, input_c, alpha=1, beta=1, trans_a=False, trans_b=False, broadcast=False):
+    input_a, input_b, input_c = np.array(input_a), np.array(input_b), np.array(input_c)
+    if trans_a:
+        input_a = input_a.T
+    if trans_b:
+        input_b = input_b.T
+
+    return (alpha * np.dot(input_a, input_b)) + (beta * input_c)
+
+
+def make_onnx_model_for_gemm_op(input_a, input_b, input_c, **kwargs):
+    input_a_for_output = input_a
+    input_b_for_output = input_b
+    if kwargs.get("transA"):
+        input_a_for_output = input_a.T
+    if kwargs.get("transB"):
+        input_b_for_output = input_b.T
+
+    output_shape = np.dot(input_a_for_output, input_b_for_output).shape
+    node = make_node("Gemm", ["A", "B", "C"], ["Y"], name="test_node", **kwargs)
+    graph = make_graph(
+        [node],
+        "test_graph",
+        [
+            make_tensor_value_info("A", onnx.TensorProto.FLOAT, input_a.shape),
+            make_tensor_value_info("B", onnx.TensorProto.FLOAT, input_b.shape),
+            make_tensor_value_info("C", onnx.TensorProto.FLOAT, input_c.shape),
+        ],
+        [make_tensor_value_info("Y", onnx.TensorProto.FLOAT, output_shape)],
+    )
+    model = make_model(graph, producer_name="ngraph ONNXImporter")
+    return model
+
+
+def import_and_compute_gemm(input_a, input_b, input_c, **kwargs):
+    input_a, input_b, input_c = np.array(input_a), np.array(input_b), np.array(input_c)
+
+    if kwargs.get("trans_a"):
+        kwargs["transA"] = kwargs["trans_a"]
+        del kwargs["trans_a"]
+
+    if kwargs.get("trans_b"):
+        kwargs["transB"] = kwargs["trans_b"]
+        del kwargs["trans_b"]
+
+    onnx_model = make_onnx_model_for_gemm_op(input_a, input_b, input_c, **kwargs)
+    transformer = get_runtime()
+    ng_model_function = import_onnx_model(onnx_model)
+    computation = transformer.computation(ng_model_function)
+    return computation(input_a, input_b, input_c)[0]
+
+
+def test_op_matmul():
+    # vector @ vector
+    data = ([1, 2], [1, 3])
+    assert np.array_equal(import_and_compute_matmul(*data), np.matmul(*data))
+
+    data = ([1, 2, 3], [[4], [5], [6]])
+    assert np.array_equal(import_and_compute_matmul(*data), np.matmul(*data))
+
+    data = ([[1, 2, 3]], [1, 2, 3])
+    assert np.array_equal(import_and_compute_matmul(*data), np.matmul(*data))
+
+    # vector @ matrix
+    data = ([1, 2, 3], [[4, 5], [6, 7], [8, 9]])
+    assert np.array_equal(import_and_compute_matmul(*data), np.matmul(*data))
+
+    # matrix @ vector
+    data = ([[1, 2, 3], [4, 5, 6]], [[7], [8], [9]])
+    assert np.array_equal(import_and_compute_matmul(*data), np.matmul(*data))
+
+    # matrix @ matrix
+    data = ([[1, 2], [3, 4]], [[5, 6], [7, 8]])
+    assert np.array_equal(import_and_compute_matmul(*data), np.matmul(*data))
+
+    data = ([[1, 2, 3], [4, 5, 6]], [[7, 8], [9, 10], [11, 12]])
+    assert np.array_equal(import_and_compute_matmul(*data), np.matmul(*data))
+
+    data = ([[1, 2], [3, 4], [5, 6]], [[7, 8, 9], [10, 11, 12]])
+    assert np.array_equal(import_and_compute_matmul(*data), np.matmul(*data))
+
+
+def test_op_matmul_3d():
+    # 3d tensor @ 3d tensor
+    data = ([[[1, 2], [3, 4]], [[1, 2], [3, 4]]], [[[5, 6], [7, 8]], [[5, 6], [7, 8]]])
+    assert np.array_equal(import_and_compute_matmul(*data), np.matmul(*data))
+
+    data = (np.ones((5, 2, 3)), (np.ones((5, 3, 2)) + 2))
+    assert np.array_equal(import_and_compute_matmul(*data), np.matmul(*data))
+
+
+def test_gemm():
+    data = ([1, 2], [1, 3], [1, 4])
+    assert np.array_equal(import_and_compute_gemm(*data), numpy_gemm(*data))
+
+    data = ([1, 2], [1, 3], 1)
+    assert np.array_equal(import_and_compute_gemm(*data), numpy_gemm(*data))
+
+    data = ([1, 2], [1, 3], [1])
+    assert np.array_equal(import_and_compute_gemm(*data), numpy_gemm(*data))
+
+    data = ([1, 2], [1, 3], [1, 4])
+    kwargs = {"alpha": 7, "beta": 9}
+    assert np.array_equal(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data, **kwargs))
+
+    data = ([1, 2, 3, 4], [1, 3, 5, 7], [1, 4])
+    kwargs = {"alpha": 7, "beta": 9}
+    assert np.array_equal(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data, **kwargs))
+
+
+def test_gemm_transpositions():
+    data = ([1, 2], [1, 3], [1, 4])
+    kwargs = {"trans_a": True, "trans_b": True}
+    assert np.array_equal(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data, **kwargs))
+
+    data = ([[1, 2], [1, 2]], [[1, 3], [1, 3]], [4, 1])
+    kwargs = {"trans_a": True, "trans_b": True, "alpha": 7, "beta": 9}
+    assert np.array_equal(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data, **kwargs))
+
+    data = ([[1, 2]], [[1, 3]], 1)
+    kwargs = {"trans_b": True, "alpha": 7, "beta": 9}
+    assert np.array_equal(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data, **kwargs))
+
+    data = ([[1], [2]], [[1], [3]], 1)
+    kwargs = {"trans_a": True, "alpha": 7, "beta": 9}
+    assert np.array_equal(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data, **kwargs))
+
+
+def test_gemm_flatten():
+    # input_a.shape is (4,1,1)
+    data = ([[[1]], [[2]], [[3]], [[4]]], [1, 3, 5, 7], [1, 4])
+    kwargs = {"alpha": 7, "beta": 9}
+    assert np.array_equal(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data, **kwargs))
diff --git a/ngraph/python/tests/test_onnx/test_ops_nonlinear.py b/ngraph/python/tests/test_onnx/test_ops_nonlinear.py
new file mode 100644 (file)
index 0000000..9ed74ed
--- /dev/null
@@ -0,0 +1,115 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+import onnx
+import pytest
+
+from tests.test_onnx.utils import run_node
+
+
+def import_and_compute(op_type, input_data, **node_attrs):
+    data_inputs = [np.array(input_data)]
+    node = onnx.helper.make_node(op_type, inputs=["x"], outputs=["y"], **node_attrs)
+    return run_node(node, data_inputs).pop()
+
+
+def assert_onnx_import_equals_callable(onnx_op_type, python_function, data, **kwargs):
+    data = np.array(data, dtype=np.float32)
+    assert np.allclose(import_and_compute(onnx_op_type, data, **kwargs), python_function(data, **kwargs))
+
+
+def test_sigmoid():
+    def sigmoid(x):
+        return 1 / (1 + np.exp(-x))
+
+    assert_onnx_import_equals_callable("Sigmoid", sigmoid, [-2, -1.0, 0.0, 1.0, 2.0])
+    assert_onnx_import_equals_callable("Sigmoid", sigmoid, [0.0])
+    assert_onnx_import_equals_callable("Sigmoid", sigmoid, [-2, -1.0, 0.0, 1.0, 2.0])
+
+
+def test_tanh():
+    assert_onnx_import_equals_callable("Tanh", np.tanh, [-2, -1.0, 0.0, 1.0, 2.0])
+    assert_onnx_import_equals_callable("Tanh", np.tanh, [0.0])
+    assert_onnx_import_equals_callable("Tanh", np.tanh, [-2, -1.0, 0.0, 1.0, 2.0])
+
+
+def test_relu():
+    def relu(x):
+        return np.maximum(x, 0)
+
+    assert_onnx_import_equals_callable("Relu", relu, [-2, -1.0, 0.0, 1.0, 2.0])
+    assert_onnx_import_equals_callable("Relu", relu, [0.0])
+    assert_onnx_import_equals_callable("Relu", relu, [-0.9, -0.8, -0.7, -0.4, -0.3, -0.2, -0.1])
+    assert_onnx_import_equals_callable("Relu", relu, [[1, 2, 3], [4, 5, 6]])
+    assert_onnx_import_equals_callable("Relu", relu, [[-3, -2, -1], [1, 2, 3]])
+
+
+def test_leaky_relu():
+    def leaky_relu(x, alpha=0.01):
+        return np.maximum(alpha * x, x)
+
+    assert_onnx_import_equals_callable("LeakyRelu", leaky_relu, [-2, -1.0, 0.0, 1.0, 2.0], alpha=0.5)
+    assert_onnx_import_equals_callable("LeakyRelu", leaky_relu, [0.0])
+    assert_onnx_import_equals_callable(
+        "LeakyRelu", leaky_relu, [-0.9, -0.8, -0.7, -0.4, -0.3, -0.2, -0.1], alpha=1.0
+    )
+    assert_onnx_import_equals_callable("LeakyRelu", leaky_relu, [[1, 2, 3], [4, 5, 6]], alpha=0.2)
+    assert_onnx_import_equals_callable("LeakyRelu", leaky_relu, [[-3, -2, -1], [1, 2, 3]])
+
+
+@pytest.mark.parametrize(
+    "x,slope",
+    [
+        ([-2, -1.0, 0.0, 1.0, 2.0], 0.5),
+        ([0.0], 1),
+        ([-0.9, -0.8, -0.7, -0.4, -0.3, -0.2, -0.1], 1),
+        ([[1, 2, 3], [4, 5, 6]], 0.5),
+        ([[-3, -2, -1], [1, 2, 3]], 1),
+    ],
+)
+def test_parametric_relu(x, slope):
+    def parametic_relu(x, slope):
+        return np.where(x < 0, slope * x, x)
+
+    x, slope = np.array(x).astype(np.float32), np.array(slope).astype(np.float32)
+    expected_output = parametic_relu(x, slope)
+    node = onnx.helper.make_node("PRelu", inputs=["x", "slope"], outputs=["y"])
+    output = run_node(node, [x, slope]).pop()
+    assert np.allclose(output, expected_output)
+
+
+def test_selu():
+    # f(x) = gamma * (alpha * exp(x) - alpha) for x <= 0, y = gamma * x for x > 0
+    def selu(x, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875):
+        return np.where(x <= 0, gamma * (alpha * np.exp(x) - alpha), gamma * x)
+
+    assert_onnx_import_equals_callable("Selu", selu, [-2, -1.0, 0.0, 1.0, 2.0])
+    assert_onnx_import_equals_callable("Selu", selu, [0.0])
+    assert_onnx_import_equals_callable("Selu", selu, [-0.9, -0.8, -0.7, -0.4, -0.3, -0.2, -0.1])
+    assert_onnx_import_equals_callable("Selu", selu, [[1, 2, 3], [4, 5, 6]])
+    assert_onnx_import_equals_callable("Selu", selu, [-2, -1.0, 0.0, 1.0, 2.0], gamma=0.5, alpha=0.5)
+
+
+def test_elu():
+    # f(x) = alpha * (exp(x) - 1) for x < 0, f(x) = x for x >= 0
+    def elu(x, alpha=1):
+        return np.where(x < 0, alpha * (np.exp(x) - 1), x)
+
+    assert_onnx_import_equals_callable("Elu", elu, [-2, -1.0, 0.0, 1.0, 2.0])
+    assert_onnx_import_equals_callable("Elu", elu, [0.0])
+    assert_onnx_import_equals_callable("Elu", elu, [-0.9, -0.8, -0.7, -0.4, -0.3, -0.2, -0.1])
+    assert_onnx_import_equals_callable("Elu", elu, [[1, 2, 3], [4, 5, 6]])
+    assert_onnx_import_equals_callable("Elu", elu, [-2, -1.0, 0.0, 1.0, 2.0], alpha=0.5)
diff --git a/ngraph/python/tests/test_onnx/test_ops_reduction.py b/ngraph/python/tests/test_onnx/test_ops_reduction.py
new file mode 100644 (file)
index 0000000..da23f38
--- /dev/null
@@ -0,0 +1,452 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+import onnx
+import pytest
+
+from tests.test_onnx.utils import run_node
+
+
+def import_and_compute(op_type, input_data, **node_attrs):
+    data_inputs = [np.array(input_data)]
+    node = onnx.helper.make_node(op_type, inputs=["x"], outputs=["y"], **node_attrs)
+    return run_node(node, data_inputs).pop()
+
+
+def test_reduce_max():
+    data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
+
+    assert np.array_equal(import_and_compute("ReduceMax", data, keepdims=0), np.max(data, keepdims=False))
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(0,), keepdims=0), np.max(data, keepdims=False, axis=(0,))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(1,), keepdims=0), np.max(data, keepdims=False, axis=(1,))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(2,), keepdims=0), np.max(data, keepdims=False, axis=(2,))
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(0, 1), keepdims=0),
+        np.max(data, keepdims=False, axis=(0, 1)),
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(0, 2), keepdims=0),
+        np.max(data, keepdims=False, axis=(0, 2)),
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(1, 2), keepdims=0),
+        np.max(data, keepdims=False, axis=(1, 2)),
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(0, 1, 2), keepdims=0),
+        np.max(data, keepdims=False, axis=(0, 1, 2)),
+    )
+
+
+def test_reduce_max_keepdims():
+    data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
+
+    assert np.array_equal(import_and_compute("ReduceMax", data), np.max(data, keepdims=True))
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(0,)), np.max(data, keepdims=True, axis=(0,))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(1,)), np.max(data, keepdims=True, axis=(1,))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(2,)), np.max(data, keepdims=True, axis=(2,))
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(0, 1)), np.max(data, keepdims=True, axis=(0, 1))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(0, 2)), np.max(data, keepdims=True, axis=(0, 2))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(1, 2)), np.max(data, keepdims=True, axis=(1, 2))
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceMax", data, axes=(0, 1, 2)), np.max(data, keepdims=True, axis=(0, 1, 2))
+    )
+
+
+def test_reduce_min():
+    data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
+
+    assert np.array_equal(import_and_compute("ReduceMin", data), np.min(data, keepdims=True))
+    assert np.array_equal(import_and_compute("ReduceMin", data, keepdims=0), np.min(data, keepdims=False))
+
+    assert np.array_equal(
+        import_and_compute("ReduceMin", data, axes=(1,)), np.min(data, keepdims=True, axis=(1,))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMin", data, axes=(1,), keepdims=0), np.min(data, keepdims=False, axis=(1,))
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceMin", data, axes=(0, 2)), np.min(data, keepdims=True, axis=(0, 2))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMin", data, axes=(0, 2), keepdims=0),
+        np.min(data, keepdims=False, axis=(0, 2)),
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceMin", data, axes=(0, 1, 2)), np.min(data, keepdims=True, axis=(0, 1, 2))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMin", data, axes=(0, 1, 2), keepdims=0),
+        np.min(data, keepdims=False, axis=(0, 1, 2)),
+    )
+
+
+def test_reduce_mean():
+    data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
+
+    assert np.array_equal(import_and_compute("ReduceMean", data), np.mean(data, keepdims=True))
+    assert np.array_equal(import_and_compute("ReduceMean", data, keepdims=0), np.mean(data, keepdims=False))
+
+    assert np.array_equal(
+        import_and_compute("ReduceMean", data, axes=(1,)), np.mean(data, keepdims=True, axis=(1,))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMean", data, axes=(1,), keepdims=0),
+        np.mean(data, keepdims=False, axis=(1,)),
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceMean", data, axes=(0, 2)), np.mean(data, keepdims=True, axis=(0, 2))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMean", data, axes=(0, 2), keepdims=0),
+        np.mean(data, keepdims=False, axis=(0, 2)),
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceMean", data, axes=(0, 1, 2)), np.mean(data, keepdims=True, axis=(0, 1, 2))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceMean", data, axes=(0, 1, 2), keepdims=0),
+        np.mean(data, keepdims=False, axis=(0, 1, 2)),
+    )
+
+
+def test_reduce_sum():
+    data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
+
+    assert np.array_equal(import_and_compute("ReduceSum", data, keepdims=0), np.sum(data, keepdims=False))
+
+    assert np.array_equal(
+        import_and_compute("ReduceSum", data, axes=(1,), keepdims=0), np.sum(data, keepdims=False, axis=(1,))
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceSum", data, axes=(0, 2), keepdims=0),
+        np.sum(data, keepdims=False, axis=(0, 2)),
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceSum", data, axes=(0, 1, 2), keepdims=0),
+        np.sum(data, keepdims=False, axis=(0, 1, 2)),
+    )
+
+
+def test_reduce_sum_keepdims():
+    data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
+
+    assert np.array_equal(import_and_compute("ReduceSum", data), np.sum(data, keepdims=True))
+
+    assert np.array_equal(
+        import_and_compute("ReduceSum", data, axes=(1,)), np.sum(data, keepdims=True, axis=(1,))
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceSum", data, axes=(0, 2)), np.sum(data, keepdims=True, axis=(0, 2))
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceSum", data, axes=(0, 1, 2)), np.sum(data, keepdims=True, axis=(0, 1, 2))
+    )
+
+
+def test_reduce_prod():
+    data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
+
+    assert np.array_equal(import_and_compute("ReduceProd", data), np.prod(data, keepdims=True))
+    assert np.array_equal(import_and_compute("ReduceProd", data, keepdims=0), np.prod(data, keepdims=False))
+
+    assert np.array_equal(
+        import_and_compute("ReduceProd", data, axes=(1,)), np.prod(data, keepdims=True, axis=(1,))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceProd", data, axes=(1,), keepdims=0),
+        np.prod(data, keepdims=False, axis=(1,)),
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceProd", data, axes=(0, 2)), np.prod(data, keepdims=True, axis=(0, 2))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceProd", data, axes=(0, 2), keepdims=0),
+        np.prod(data, keepdims=False, axis=(0, 2)),
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceProd", data, axes=(0, 1, 2)), np.prod(data, keepdims=True, axis=(0, 1, 2))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceProd", data, axes=(0, 1, 2), keepdims=0),
+        np.prod(data, keepdims=False, axis=(0, 1, 2)),
+    )
+
+
+@pytest.mark.parametrize("reduction_axes", [(0,), (0, 2), (0, 1, 2)])
+def test_reduce_l1(reduction_axes):
+    shape = [2, 4, 3, 2]
+    np.random.seed(133391)
+    input_data = np.random.uniform(-100, 100, shape).astype(np.float32)
+
+    expected = np.sum(np.abs(input_data), keepdims=True, axis=reduction_axes)
+    node = onnx.helper.make_node("ReduceL1", inputs=["x"], outputs=["y"], axes=reduction_axes)
+    ng_result = np.array(run_node(node, [input_data]).pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+    expected = np.sum(np.abs(input_data), keepdims=False, axis=reduction_axes)
+    node = onnx.helper.make_node("ReduceL1", inputs=["x"], outputs=["y"], keepdims=0, axes=reduction_axes)
+    ng_result = np.array(run_node(node, [input_data]).pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+
+def test_reduce_l1_default_axes():
+    shape = [2, 4, 3, 2]
+    np.random.seed(133391)
+    input_data = np.random.uniform(-100, 100, shape).astype(np.float32)
+
+    expected = np.sum(np.abs(input_data), keepdims=True)
+    node = onnx.helper.make_node("ReduceL1", inputs=["x"], outputs=["y"])
+    ng_result = np.array(run_node(node, [input_data]).pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+    expected = np.sum(np.abs(input_data), keepdims=False)
+    node = onnx.helper.make_node("ReduceL1", inputs=["x"], outputs=["y"], keepdims=0)
+    ng_result = np.array(run_node(node, [input_data]).pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+
+@pytest.mark.parametrize("reduction_axes", [(0,), (0, 2), (0, 1, 2)])
+def test_reduce_l2(reduction_axes):
+    shape = [2, 4, 3, 2]
+    np.random.seed(133391)
+    input_data = np.random.uniform(-100, 100, shape).astype(np.float32)
+
+    expected = np.sqrt(np.sum(np.square(input_data), keepdims=True, axis=reduction_axes))
+    node = onnx.helper.make_node("ReduceL2", inputs=["x"], outputs=["y"], axes=reduction_axes)
+    raw_result = run_node(node, [input_data])
+    ng_result = np.array(raw_result.pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+    expected = np.sqrt(np.sum(np.square(input_data), keepdims=False, axis=reduction_axes))
+    node = onnx.helper.make_node("ReduceL2", inputs=["x"], outputs=["y"], keepdims=0, axes=reduction_axes)
+    ng_result = np.array(run_node(node, [input_data]).pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+
+def test_reduce_l2_default_axes():
+    shape = [2, 4, 3, 2]
+    np.random.seed(133391)
+    input_data = np.random.uniform(-100, 100, shape).astype(np.float32)
+
+    expected = np.sqrt(np.sum(np.square(input_data), keepdims=True))
+    node = onnx.helper.make_node("ReduceL2", inputs=["x"], outputs=["y"])
+    ng_result = np.array(run_node(node, [input_data]).pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+    expected = np.sqrt(np.sum(np.square(input_data), keepdims=False))
+    node = onnx.helper.make_node("ReduceL2", inputs=["x"], outputs=["y"], keepdims=0)
+    ng_result = np.array(run_node(node, [input_data]).pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+
+@pytest.mark.parametrize("reduction_axes", [(0,), (0, 2), (0, 1, 2)])
+def test_reduce_log_sum(reduction_axes):
+    shape = [2, 4, 3, 2]
+    np.random.seed(133391)
+    input_data = np.random.uniform(0, 1, shape).astype(np.float32)
+
+    expected = np.log(np.sum(input_data, keepdims=True, axis=reduction_axes))
+    node = onnx.helper.make_node("ReduceLogSum", inputs=["x"], outputs=["y"], axes=reduction_axes)
+    ng_result = run_node(node, [input_data]).pop()
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+    expected = np.log(np.sum(input_data, keepdims=False, axis=reduction_axes))
+    node = onnx.helper.make_node("ReduceLogSum", inputs=["x"], outputs=["y"], keepdims=0, axes=reduction_axes)
+    ng_result = run_node(node, [input_data]).pop()
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+
+def test_reduce_log_sum_default_axes():
+    shape = [2, 4, 3, 2]
+    np.random.seed(133391)
+    input_data = np.random.uniform(0, 1, shape).astype(np.float32)
+
+    expected = np.log(np.sum(input_data, keepdims=True))
+    node = onnx.helper.make_node("ReduceLogSum", inputs=["x"], outputs=["y"])
+    ng_result = np.array(run_node(node, [input_data]).pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+    expected = np.log(np.sum(input_data, keepdims=False))
+    node = onnx.helper.make_node("ReduceLogSum", inputs=["x"], outputs=["y"], keepdims=0)
+    ng_result = np.array(run_node(node, [input_data]).pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+
+def test_reduce_log_sum_exp():
+    def logsumexp(data, axis=None, keepdims=True):
+        return np.log(np.sum(np.exp(data), axis=axis, keepdims=keepdims))
+
+    data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
+
+    assert np.array_equal(import_and_compute("ReduceLogSumExp", data), logsumexp(data, keepdims=True))
+    assert np.array_equal(
+        import_and_compute("ReduceLogSumExp", data, keepdims=0), logsumexp(data, keepdims=False)
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceLogSumExp", data, axes=(1,)), logsumexp(data, keepdims=True, axis=(1,))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceLogSumExp", data, axes=(1,), keepdims=0),
+        logsumexp(data, keepdims=False, axis=(1,)),
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceLogSumExp", data, axes=(0, 2)), logsumexp(data, keepdims=True, axis=(0, 2))
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceLogSumExp", data, axes=(0, 2), keepdims=0),
+        logsumexp(data, keepdims=False, axis=(0, 2)),
+    )
+
+    assert np.array_equal(
+        import_and_compute("ReduceLogSumExp", data, axes=(0, 1, 2)),
+        logsumexp(data, keepdims=True, axis=(0, 1, 2)),
+    )
+    assert np.array_equal(
+        import_and_compute("ReduceLogSumExp", data, axes=(0, 1, 2), keepdims=0),
+        logsumexp(data, keepdims=False, axis=(0, 1, 2)),
+    )
+
+
+@pytest.mark.parametrize("reduction_axes", [(0,), (0, 2), (0, 1, 2)])
+def test_reduce_sum_square(reduction_axes):
+    shape = [2, 4, 3, 2]
+    np.random.seed(133391)
+    input_data = np.random.uniform(-100, 100, shape).astype(np.float32)
+
+    expected = np.sum(np.square(input_data), keepdims=True, axis=reduction_axes)
+    node = onnx.helper.make_node("ReduceSumSquare", inputs=["x"], outputs=["y"], axes=reduction_axes)
+    ng_result = np.array(run_node(node, [input_data]).pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+    expected = np.sum(np.square(input_data), keepdims=False, axis=reduction_axes)
+    node = onnx.helper.make_node(
+        "ReduceSumSquare", inputs=["x"], outputs=["y"], keepdims=0, axes=reduction_axes
+    )
+    ng_result = np.array(run_node(node, [input_data]).pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+
+def test_reduce_sum_square_default_axes():
+    shape = [2, 4, 3, 2]
+    np.random.seed(133391)
+    input_data = np.random.uniform(-100, 100, shape).astype(np.float32)
+
+    expected = np.sum(np.square(input_data), keepdims=True)
+    node = onnx.helper.make_node("ReduceSumSquare", inputs=["x"], outputs=["y"])
+    ng_result = np.array(run_node(node, [input_data]).pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+    expected = np.sum(np.square(input_data), keepdims=False)
+    node = onnx.helper.make_node("ReduceSumSquare", inputs=["x"], outputs=["y"], keepdims=0)
+    ng_result = np.array(run_node(node, [input_data]).pop())
+    assert np.array_equal(expected.shape, ng_result.shape)
+    assert np.allclose(expected, ng_result)
+
+
+def test_reduce_argmin():
+    def argmin(ndarray, axis, keepdims=False):
+        res = np.argmin(ndarray, axis=axis)
+        if keepdims:
+            res = np.expand_dims(res, axis=axis)
+        return res
+
+    data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
+
+    assert np.array_equal(import_and_compute("ArgMin", data, axis=0), argmin(data, keepdims=True, axis=0))
+    assert np.array_equal(
+        import_and_compute("ArgMin", data, axis=0, keepdims=0), argmin(data, keepdims=False, axis=0)
+    )
+    assert np.array_equal(import_and_compute("ArgMin", data, axis=1), argmin(data, keepdims=True, axis=1))
+    assert np.array_equal(
+        import_and_compute("ArgMin", data, axis=1, keepdims=0), argmin(data, keepdims=False, axis=1)
+    )
+    assert np.array_equal(import_and_compute("ArgMin", data, axis=2), argmin(data, keepdims=True, axis=2))
+    assert np.array_equal(
+        import_and_compute("ArgMin", data, axis=2, keepdims=0), argmin(data, keepdims=False, axis=2)
+    )
+
+
+def test_reduce_argmax():
+    def argmax(ndarray, axis, keepdims=False):
+        res = np.argmax(ndarray, axis=axis)
+        if keepdims:
+            res = np.expand_dims(res, axis=axis)
+        return res
+
+    data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
+
+    assert np.array_equal(import_and_compute("ArgMax", data, axis=0), argmax(data, keepdims=True, axis=0))
+    assert np.array_equal(
+        import_and_compute("ArgMax", data, axis=0, keepdims=0), argmax(data, keepdims=False, axis=0)
+    )
+    assert np.array_equal(import_and_compute("ArgMax", data, axis=1), argmax(data, keepdims=True, axis=1))
+    assert np.array_equal(
+        import_and_compute("ArgMax", data, axis=1, keepdims=0), argmax(data, keepdims=False, axis=1)
+    )
+    assert np.array_equal(import_and_compute("ArgMax", data, axis=2), argmax(data, keepdims=True, axis=2))
+    assert np.array_equal(
+        import_and_compute("ArgMax", data, axis=2, keepdims=0), argmax(data, keepdims=False, axis=2)
+    )
diff --git a/ngraph/python/tests/test_onnx/test_ops_reshape.py b/ngraph/python/tests/test_onnx/test_ops_reshape.py
new file mode 100644 (file)
index 0000000..c18aad7
--- /dev/null
@@ -0,0 +1,351 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+import onnx
+import pytest
+from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
+
+from tests.runtime import get_runtime
+from tests.test_onnx.utils import (
+    all_arrays_equal,
+    get_node_model,
+    import_onnx_model,
+    run_model,
+    run_node,
+)
+
+
+def test_reshape():
+    input_data = np.arange(2560).reshape([16, 4, 4, 10])
+    reshape_node = onnx.helper.make_node("Reshape", inputs=["x"], outputs=["y"], shape=(256, 10))
+    expected_output = input_data.reshape([256, 10])
+
+    ng_results = run_node(reshape_node, [input_data], opset_version=4)
+    assert np.array_equal(ng_results, [expected_output])
+
+
+def test_reshape_opset5():
+    original_shape = [2, 3, 4]
+    test_cases = {
+        "reordered_dims": np.array([4, 2, 3], dtype=np.int64),
+        "reduced_dims": np.array([3, 8], dtype=np.int64),
+        "extended_dims": np.array([3, 2, 2, 2], dtype=np.int64),
+        "one_dim": np.array([24], dtype=np.int64),
+        "negative_dim": np.array([6, -1, 2], dtype=np.int64),
+    }
+    input_data = np.random.random_sample(original_shape).astype(np.float32)
+
+    for _, shape in test_cases.items():
+        const_node = make_node(
+            "Constant",
+            inputs=[],
+            outputs=["const_shape"],
+            value=onnx.helper.make_tensor(
+                name="const_tensor", data_type=onnx.TensorProto.INT64, dims=shape.shape, vals=shape.flatten()
+            ),
+        )
+        reshape_node = onnx.helper.make_node("Reshape", inputs=["data", "const_shape"], outputs=["reshaped"])
+
+        graph = make_graph(
+            [const_node, reshape_node],
+            "test_graph",
+            [make_tensor_value_info("data", onnx.TensorProto.FLOAT, input_data.shape)],
+            [make_tensor_value_info("reshaped", onnx.TensorProto.FLOAT, ())],
+        )
+
+        model = make_model(graph, producer_name="ngraph ONNX Importer")
+        model.opset_import[0].version = 5
+        ng_model_function = import_onnx_model(model)
+        runtime = get_runtime()
+        computation = runtime.computation(ng_model_function)
+        ng_results = computation(input_data)
+        expected_output = np.reshape(input_data, shape)
+        assert np.array_equal(ng_results[0], expected_output)
+
+
+def test_reshape_opset5_param_err():
+    original_shape = [2, 3, 4]
+    output_shape = np.array([4, 2, 3], dtype=np.int64)
+    input_data = np.random.random_sample(original_shape).astype(np.float32)
+    reshape_node = onnx.helper.make_node("Reshape", inputs=["x", "y"], outputs=["z"])
+    ng_result = run_node(reshape_node, [input_data, output_shape], opset_version=5)
+    assert ng_result[0].shape == output_shape
+
+
+@pytest.mark.parametrize(
+    "axis,expected_output",
+    [
+        (0, np.arange(120).reshape(1, 120)),
+        (1, np.arange(120).reshape(2, 60)),
+        (2, np.arange(120).reshape(6, 20)),
+        (3, np.arange(120).reshape(24, 5)),
+        (4, np.arange(120).reshape(120, 1)),
+    ],
+)
+def test_flatten(axis, expected_output):
+    data = np.arange(120).reshape([2, 3, 4, 5])
+    node = onnx.helper.make_node("Flatten", inputs=["x"], outputs=["y"], axis=axis)
+    ng_results = run_node(node, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+
+def test_flatten_exception():
+    data = np.arange(120).reshape([2, 3, 4, 5])
+    node = onnx.helper.make_node("Flatten", inputs=["x"], outputs=["y"], axis=5)
+
+    with pytest.raises(RuntimeError):
+        run_node(node, [data])
+
+
+def test_transpose():
+    data = np.arange(120).reshape([2, 3, 4, 5])
+
+    node = onnx.helper.make_node("Transpose", inputs=["x"], outputs=["y"])
+    expected_output = data.T
+    ng_results = run_node(node, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+    node = onnx.helper.make_node("Transpose", inputs=["x"], outputs=["y"], perm=(3, 1, 0, 2))
+    expected_output = np.transpose(data, axes=(3, 1, 0, 2))
+    ng_results = run_node(node, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+
+def test_slice_opset1():
+    data = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
+
+    expected_output = np.array([[5, 6, 7]])
+    model = get_node_model("Slice", data, axes=[0, 1], starts=[1, 0], ends=[2, 3])
+    ng_results = run_model(model, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+    expected_output = np.array([[2, 3, 4]])
+    model = get_node_model("Slice", data, starts=[0, 1], ends=[-1, 1000])
+    ng_results = run_model(model, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+    data = np.random.randn(20, 10, 5).astype(np.float32)
+    expected_output = data[0:3, 0:10]
+    model = get_node_model("Slice", data, axes=[0, 1], starts=[0, 0], ends=[3, 10])
+    ng_results = run_model(model, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+    # default axes
+    data = np.random.randn(20, 10, 5).astype(np.float32)
+    expected_output = data[:, :, 3:4]
+    model = get_node_model("Slice", data, starts=[0, 0, 3], ends=[20, 10, 4])
+    ng_results = run_model(model, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+    # end out of bounds
+    data = np.random.randn(20, 10, 5).astype(np.float32)
+    expected_output = data[:, 1:1000]
+    model = get_node_model("Slice", data, axes=[1], starts=[1], ends=[1000])
+    ng_results = run_model(model, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+    # negative value
+    data = np.random.randn(20, 10, 5).astype(np.float32)
+    expected_output = data[:, 0:-1]
+    model = get_node_model("Slice", data, axes=[1], starts=[0], ends=[-1])
+    ng_results = run_model(model, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+    # start ouf of bounds
+    data = np.random.randn(20, 10, 5).astype(np.float32)
+    expected_output = data[:, 1000:1000]
+    model = get_node_model("Slice", data, axes=[1], starts=[1000], ends=[1000])
+    ng_results = run_model(model, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+
+def test_concat():
+    a = np.array([[1, 2], [3, 4]])
+    b = np.array([[5, 6]])
+
+    node = onnx.helper.make_node("Concat", inputs=["x"], outputs=["z"], axis=0)
+    ng_results = run_node(node, [a])
+    assert np.array_equal(ng_results, [a])
+
+    expected_output = np.concatenate((a, b), axis=0)
+    node = onnx.helper.make_node("Concat", inputs=["x", "y"], outputs=["z"], axis=0)
+    ng_results = run_node(node, [a, b])
+    assert np.array_equal(ng_results, [expected_output])
+
+    a = np.array([[1, 2], [3, 4]])
+    b = np.array([[5, 6]]).T
+    expected_output = np.concatenate((a, b), axis=1)
+    node = onnx.helper.make_node("Concat", inputs=["x", "y"], outputs=["z"], axis=1)
+    ng_results = run_node(node, [a, b])
+    assert np.array_equal(ng_results, [expected_output])
+
+    test_cases = {
+        "1d": ([1, 2], [3, 4]),
+        "2d": ([[1, 2], [3, 4]], [[5, 6], [7, 8]]),
+        "3d": ([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[[9, 10], [11, 12]], [[13, 14], [15, 16]]]),
+    }
+
+    for _, values in test_cases.items():
+        values = [np.asarray(v) for v in values]
+        for i in range(len(values[0].shape)):
+            in_args = ["value" + str(k) for k in range(len(values))]
+            node = onnx.helper.make_node("Concat", inputs=list(in_args), outputs=["output"], axis=i,)
+            expected_output = np.concatenate(values, i)
+            ng_results = run_node(node, list(values))
+            assert np.array_equal(ng_results, [expected_output])
+
+
+def test_squeeze():
+    data = np.arange(6).reshape([1, 2, 3, 1])
+    expected_output = data.reshape([2, 3])
+
+    node = onnx.helper.make_node("Squeeze", inputs=["x"], outputs=["y"], axes=[0, 3])
+    ng_results = run_node(node, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+    data = np.random.randn(1, 3, 4, 5).astype(np.float32)
+    expected_output = np.squeeze(data, axis=0)
+    node = onnx.helper.make_node("Squeeze", inputs=["x"], outputs=["y"], axes=[0])
+    ng_results = run_node(node, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+
+def test_unsqueeze():
+    data = np.random.randn(3, 4, 5).astype(np.float32)
+    expected_output = np.expand_dims(data, axis=0)
+    node = onnx.helper.make_node("Unsqueeze", inputs=["x"], outputs=["y"], axes=[0])
+    ng_results = run_node(node, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+    expected_output = np.reshape(data, [1, 3, 4, 5, 1])
+    node = onnx.helper.make_node("Unsqueeze", inputs=["x"], outputs=["y"], axes=[0, 4])
+    ng_results = run_node(node, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+    expected_output = np.reshape(data, [1, 3, 1, 4, 5])
+    node = onnx.helper.make_node("Unsqueeze", inputs=["x"], outputs=["y"], axes=[0, 2])
+    ng_results = run_node(node, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+
+@pytest.mark.parametrize(
+    "node, expected_output",
+    [
+        # Split into 2 equal parts along axis=0
+        (
+            onnx.helper.make_node("Split", inputs=["x"], outputs=["y", "z"], axis=0),
+            [np.array([[0, 1, 2, 3]]), np.array([[4, 5, 6, 7]])],
+        ),
+        # Default, split along axis=0 into 2 equal parts
+        (
+            onnx.helper.make_node("Split", inputs=["x"], outputs=["y", "z"]),
+            [np.array([[0, 1, 2, 3]]), np.array([[4, 5, 6, 7]])],
+        ),
+        # Split into 2 equal parts along axis=1
+        (
+            onnx.helper.make_node("Split", inputs=["x"], outputs=["a", "b"], axis=1),
+            [np.array([[0, 1], [4, 5]]), np.array([[2, 3], [6, 7]])],
+        ),
+        # Split into 4 equal parts along axis=1
+        (
+            onnx.helper.make_node("Split", inputs=["x"], outputs=["a", "b", "c", "d"], axis=1),
+            [np.array([[0], [4]]), np.array([[1], [5]]), np.array([[2], [6]]), np.array([[3], [7]])],
+        ),
+        # Split into 2 unequal parts along axis=1
+        (
+            onnx.helper.make_node("Split", inputs=["x"], outputs=["a", "b"], axis=1, split=(3, 1)),
+            [np.array([[0, 1, 2], [4, 5, 6]]), np.array([[3], [7]])],
+        ),
+    ],
+)
+def test_split_2d(node, expected_output):
+    data = np.arange(8).reshape(2, 4)
+    ng_results = run_node(node, [data])
+    assert all_arrays_equal(ng_results, expected_output)
+
+
+def test_split_1d():
+    # 1D
+    data = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).astype(np.float32)
+
+    node = onnx.helper.make_node("Split", inputs=["input"], outputs=["z", "w"], axis=0)
+    expected_outputs = [
+        np.array([1.0, 2.0, 3.0]).astype(np.float32),
+        np.array([4.0, 5.0, 6.0]).astype(np.float32),
+    ]
+    ng_results = run_node(node, [data])
+    assert all_arrays_equal(ng_results, expected_outputs)
+
+    node = onnx.helper.make_node("Split", inputs=["input"], outputs=["y", "z", "w"], axis=0, split=[2, 3, 1])
+    expected_outputs = [
+        np.array([1.0, 2.0]).astype(np.float32),
+        np.array([3.0, 4.0, 5.0]).astype(np.float32),
+        np.array([6.0]).astype(np.float32),
+    ]
+    ng_results = run_node(node, [data])
+    assert all_arrays_equal(ng_results, expected_outputs)
+
+    # Default values
+    data = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).astype(np.float32)
+
+    node = onnx.helper.make_node("Split", inputs=["input"], outputs=["y", "z", "w"])
+    expected_outputs = [
+        np.array([1.0, 2.0]).astype(np.float32),
+        np.array([3.0, 4.0]).astype(np.float32),
+        np.array([5.0, 6.0]).astype(np.float32),
+    ]
+    ng_results = run_node(node, [data])
+    assert all_arrays_equal(ng_results, expected_outputs)
+
+    node = onnx.helper.make_node("Split", inputs=["input"], outputs=["y", "z"], split=[2, 4])
+    expected_outputs = [
+        np.array([1.0, 2.0]).astype(np.float32),
+        np.array([3.0, 4.0, 5.0, 6.0]).astype(np.float32),
+    ]
+    ng_results = run_node(node, [data])
+    assert all_arrays_equal(ng_results, expected_outputs)
+
+
+def test_depth_to_space():
+    b, c, h, w = shape = (2, 8, 3, 3)
+    blocksize = 2
+    data = np.random.random_sample(shape).astype(np.float32)
+    tmp = np.reshape(data, [b, blocksize, blocksize, c // (blocksize ** 2), h, w])
+    tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2])
+    expected_output = np.reshape(tmp, [b, c // (blocksize ** 2), h * blocksize, w * blocksize])
+
+    node = onnx.helper.make_node("DepthToSpace", inputs=["x"], outputs=["y"], blocksize=blocksize)
+    ng_results = run_node(node, [data])
+    assert np.array_equal(ng_results, [expected_output])
+
+    # (1, 4, 2, 3) input tensor
+    data = np.array(
+        [
+            [
+                [[0, 1, 2], [3, 4, 5]],
+                [[6, 7, 8], [9, 10, 11]],
+                [[12, 13, 14], [15, 16, 17]],
+                [[18, 19, 20], [21, 22, 23]],
+            ]
+        ]
+    ).astype(np.float32)
+    # (1, 1, 4, 6) output tensor
+    expected_output = np.array(
+        [[[[0, 6, 1, 7, 2, 8], [12, 18, 13, 19, 14, 20], [3, 9, 4, 10, 5, 11], [15, 21, 16, 22, 17, 23]]]]
+    ).astype(np.float32)
+
+    ng_results = run_node(node, [data])
+    assert np.array_equal(ng_results, [expected_output])
diff --git a/ngraph/python/tests/test_onnx/test_ops_unary.py b/ngraph/python/tests/test_onnx/test_ops_unary.py
new file mode 100644 (file)
index 0000000..a696e28
--- /dev/null
@@ -0,0 +1,530 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+import onnx
+import onnx.mapping
+import pytest
+from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
+
+from ngraph.exceptions import NgraphTypeError
+from tests.runtime import get_runtime
+from tests.test_onnx.utils import get_node_model, import_onnx_model, run_model, run_node
+
+
+@pytest.mark.parametrize(
+    "input_data",
+    [
+        np.array([-4, 0, 5, -10]),
+        np.array([[-4, 0, 5, -10], [-4, 0, 5, -10]]),
+        np.array([[[1, 2], [-3, 4]], [[1, -2], [3, 4]]]),
+    ],
+)
+def test_abs(input_data):
+    expected_output = np.abs(input_data)
+    node = onnx.helper.make_node("Abs", inputs=["x"], outputs=["y"])
+    ng_results = run_node(node, [input_data])
+    assert np.array_equal(ng_results, [expected_output])
+
+
+@pytest.mark.parametrize(
+    "input_data",
+    [
+        np.array([4, 0, 5, 10]),
+        np.array([[4, 0, 5, 10], [4, 0, 5, 10]]),
+        np.array([[[1, 2], [3, 4]], [[1, 2], [3, 4]]]),
+    ],
+)
+def test_sqrt(input_data):
+    input_data = input_data.astype(np.float32)
+    expected_output = np.sqrt(input_data)
+    node = onnx.helper.make_node("Sqrt", inputs=["x"], outputs=["y"])
+    ng_results = run_node(node, [input_data])
+    assert np.allclose(ng_results, [expected_output])
+
+
+@pytest.mark.parametrize(
+    "input_data",
+    [
+        np.array([4, 0, 5, 10]),
+        np.array([[4, 0, 5, 10], [4, 0, 5, 10]]),
+        np.array([[[1, 2], [3, 4]], [[1, 2], [3, 4]]]),
+    ],
+)
+def test_exp(input_data):
+    input_data = input_data.astype(np.float32)
+    expected_output = np.exp(input_data)
+    node = onnx.helper.make_node("Exp", inputs=["x"], outputs=["y"])
+    ng_results = run_node(node, [input_data])
+    assert np.allclose(ng_results, [expected_output])
+
+
+@pytest.mark.parametrize(
+    "input_data",
+    [
+        np.array([4, 2, 5, 10]),
+        np.array([[4, 1, 5, 10], [4, 2, 5, 10]]),
+        np.array([[[1, 2], [3, 4]], [[1, 2], [3, 4]]]),
+    ],
+)
+def test_log(input_data):
+    input_data = input_data.astype(np.float32)
+    expected_output = np.log(input_data)
+    node = onnx.helper.make_node("Log", inputs=["x"], outputs=["y"])
+    ng_results = run_node(node, [input_data])
+    assert np.allclose(ng_results, [expected_output])
+
+
+@pytest.mark.parametrize(
+    "input_data",
+    [
+        np.array([-4, 0, 5, -10]),
+        np.array([[-4, 0, 5, -10], [-4, 0, 5, -10]]),
+        np.array([[[1, 2], [-3, 4]], [[1, -2], [3, 4]]]),
+    ],
+)
+def test_neg(input_data):
+    expected_output = np.negative(input_data)
+    node = onnx.helper.make_node("Neg", inputs=["x"], outputs=["y"])
+    ng_results = run_node(node, [input_data])
+    assert np.array_equal(ng_results, [expected_output])
+
+
+@pytest.mark.parametrize(
+    "input_data",
+    [
+        np.array([-4.2, 0.43, 5.99, -10.01]),
+        np.array([[-4.5, 0.99, 5.01, -10.00], [-4.5, 0.5, 5.1, 10.01]]),
+        np.array([[[1, 2], [-3, 4]], [[1, -2], [3, 4]]]) / 6,
+    ],
+)
+def test_floor(input_data):
+    expected_output = np.floor(input_data)
+    node = onnx.helper.make_node("Floor", inputs=["x"], outputs=["y"])
+    ng_results = run_node(node, [input_data])
+    assert np.array_equal(ng_results, [expected_output])
+
+
+@pytest.mark.parametrize(
+    "input_data",
+    [
+        np.array([-4.2, 0, 5.99, -10.01]),
+        np.array([[-4.5, 0.99, 5.01, -10.00], [-4.5, 0.5, 5.1, 10.01]]),
+        np.array([[[1, 2], [-3, 4]], [[1, -2], [3, 4]]]) / 6,
+    ],
+)
+def test_ceil(input_data):
+    expected_output = np.ceil(input_data)
+    node = onnx.helper.make_node("Ceil", inputs=["x"], outputs=["y"])
+    ng_results = run_node(node, [input_data])
+    assert np.array_equal(ng_results, [expected_output])
+
+
+@pytest.mark.parametrize(
+    "min_value, max_value",
+    [(np.finfo(np.float32).min, np.finfo(np.float32).max), (-0.5, 0.5), (0.0, np.finfo(np.float32).max)],
+)
+def test_clip(min_value, max_value):
+    np.random.seed(133391)
+    input_data = np.float32(-100.0) + np.random.randn(3, 4, 5).astype(np.float32) * np.float32(200.0)
+    model = get_node_model("Clip", input_data, opset=10, min=float(min_value), max=float(max_value))
+    result = run_model(model, [input_data])
+    expected = np.clip(input_data, min_value, max_value)
+    assert np.allclose(result, [expected])
+
+
+def test_clip_default():
+    np.random.seed(133391)
+    input_data = -100.0 + np.random.randn(3, 4, 5).astype(np.float32) * 200.0
+
+    model = get_node_model("Clip", input_data, opset=10, min=0.0)
+    result = run_model(model, [input_data])
+    expected = np.clip(input_data, np.float32(0.0), np.finfo(np.float32).max)
+    assert np.allclose(result, [expected])
+
+    model = get_node_model("Clip", input_data, opset=10, max=0.0)
+    result = run_model(model, [input_data])
+    expected = np.clip(input_data, np.finfo(np.float32).min, np.float32(0.0))
+    assert np.allclose(result, [expected])
+
+
+@pytest.mark.parametrize(
+    "input_data",
+    [
+        np.array([-4.2, 1, 5.99, -10.01]),
+        np.array([[-4.5, 0.99, 5.01, -10.00], [-4.5, 0.5, 5.1, 10.01]]),
+        np.array([[[1, 2], [-3, 4]], [[1, -2], [3, 4]]]) / 6,
+    ],
+)
+def test_reciprocal(input_data):
+    expected_output = np.reciprocal(input_data)
+    node = onnx.helper.make_node("Reciprocal", inputs=["x"], outputs=["y"])
+    ng_results = run_node(node, [input_data])
+    assert np.allclose(ng_results, [expected_output])
+
+
+@pytest.mark.parametrize("axis, dim1, dim2", [(0, 1, 60), (1, 3, 20), (2, 12, 5)])
+def test_hardmax(axis, dim1, dim2):
+    def hardmax_2d(data):
+        return np.eye(data.shape[1], dtype=data.dtype)[np.argmax(data, axis=1)]
+
+    np.random.seed(133391)
+    data = np.random.rand(3, 4, 5).astype(np.float32)
+    expected = hardmax_2d(data.reshape(dim1, dim2)).reshape(3, 4, 5)
+    node = onnx.helper.make_node("Hardmax", inputs=["x"], outputs=["y"], axis=axis)
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+
+def test_hardmax_special_cases():
+    def hardmax_2d(data):
+        return np.eye(data.shape[1], dtype=data.dtype)[np.argmax(data, axis=1)]
+
+    np.random.seed(133391)
+    data = np.random.rand(3, 4, 5).astype(np.float32)
+
+    # default axis=1
+    expected = hardmax_2d(data.reshape(3, 20)).reshape(3, 4, 5)
+    node = onnx.helper.make_node("Hardmax", inputs=["x"], outputs=["y"])
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+    expected = hardmax_2d(data.reshape(12, 5)).reshape(3, 4, 5)
+    node = onnx.helper.make_node("Hardmax", inputs=["x"], outputs=["y"], axis=-1)
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+    with pytest.raises(RuntimeError):
+        node = onnx.helper.make_node("Hardmax", inputs=["x"], outputs=["y"], axis=3)
+        ng_results = run_node(node, [data])
+
+    # For multiple occurrences of the maximal values, the first occurrence is selected
+    # for one-hot output
+    data = np.array([[3, 3, 3, 1]]).astype(np.float32)
+    expected = np.array([[1, 0, 0, 0]]).astype(np.float32)
+    node = onnx.helper.make_node("Hardmax", inputs=["x"], outputs=["y"])
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+
+def test_hardsigmoid():
+    def hardsigmoid(data, alpha=0.2, beta=0.5):
+        return np.clip(alpha * data + beta, 0, 1)
+
+    np.random.seed(133391)
+    alpha = np.random.rand()
+    beta = np.random.rand()
+    data = np.random.rand(3, 4, 5).astype(np.float32)
+
+    expected = hardsigmoid(data, alpha, beta)
+    node = onnx.helper.make_node("HardSigmoid", inputs=["x"], outputs=["y"], alpha=alpha, beta=beta)
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+    expected = hardsigmoid(data)
+    node = onnx.helper.make_node("HardSigmoid", inputs=["x"], outputs=["y"])
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+
+def test_softmax():
+    def softmax_2d(x):
+        max_x = np.max(x, axis=1).reshape((-1, 1))
+        exp_x = np.exp(x - max_x)
+        return exp_x / np.sum(exp_x, axis=1).reshape((-1, 1))
+
+    np.random.seed(133391)
+    data = np.random.randn(3, 4, 5).astype(np.float32)
+
+    node = onnx.helper.make_node("Softmax", inputs=["x"], outputs=["y"], axis=0)
+    expected = softmax_2d(data.reshape(1, 60)).reshape(3, 4, 5)
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+    node = onnx.helper.make_node("Softmax", inputs=["x"], outputs=["y"], axis=1)
+    expected = softmax_2d(data.reshape(3, 20)).reshape(3, 4, 5)
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+    # default axis is 1
+    node = onnx.helper.make_node("Softmax", inputs=["x"], outputs=["y"])
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+    node = onnx.helper.make_node("Softmax", inputs=["x"], outputs=["y"], axis=2)
+    expected = softmax_2d(data.reshape(12, 5)).reshape(3, 4, 5)
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+    node = onnx.helper.make_node("Softmax", inputs=["x"], outputs=["y"], axis=-1)
+    expected = softmax_2d(data.reshape(12, 5)).reshape(3, 4, 5)
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+    with pytest.raises(RuntimeError):
+        node = onnx.helper.make_node("Softmax", inputs=["x"], outputs=["y"], axis=3)
+        ng_results = run_node(node, [data])
+
+
+def test_logsoftmax():
+    def logsoftmax_2d(x):
+        max_x = np.max(x, axis=1).reshape((-1, 1))
+        exp_x = np.exp(x - max_x)
+        return x - max_x - np.log(np.sum(exp_x, axis=1).reshape((-1, 1)))
+
+    np.random.seed(133391)
+    data = np.random.randn(3, 4, 5).astype(np.float32)
+
+    node = onnx.helper.make_node("LogSoftmax", inputs=["x"], outputs=["y"], axis=0)
+    expected = logsoftmax_2d(data.reshape(1, 60)).reshape(3, 4, 5)
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+    node = onnx.helper.make_node("LogSoftmax", inputs=["x"], outputs=["y"], axis=1)
+    expected = logsoftmax_2d(data.reshape(3, 20)).reshape(3, 4, 5)
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+    # default axis is 1
+    node = onnx.helper.make_node("LogSoftmax", inputs=["x"], outputs=["y"])
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+    node = onnx.helper.make_node("LogSoftmax", inputs=["x"], outputs=["y"], axis=2)
+    expected = logsoftmax_2d(data.reshape(12, 5)).reshape(3, 4, 5)
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+    with pytest.raises(RuntimeError):
+        node = onnx.helper.make_node("LogSoftmax", inputs=["x"], outputs=["y"], axis=3)
+        ng_results = run_node(node, [data])
+
+
+def test_softplus():
+    def softplus(x):
+        return np.log(np.exp(x) + 1)
+
+    np.random.seed(133391)
+    data = np.random.randn(3, 4, 5).astype(np.float32)
+
+    node = onnx.helper.make_node("Softplus", inputs=["x"], outputs=["y"])
+    expected = softplus(data)
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+
+def test_softsign():
+    def softsign(x):
+        return x / (1 + np.abs(x))
+
+    np.random.seed(133391)
+    data = np.random.randn(3, 4, 5).astype(np.float32)
+
+    node = onnx.helper.make_node("Softsign", inputs=["x"], outputs=["y"])
+    expected = softsign(data)
+    ng_results = run_node(node, [data])
+    assert np.allclose(ng_results, [expected])
+
+
+def test_identity():
+    np.random.seed(133391)
+    shape = [2, 4]
+    input_data = np.random.randn(*shape).astype(np.float32)
+
+    identity_node = make_node("Identity", inputs=["x"], outputs=["y"])
+    ng_results = run_node(identity_node, [input_data])
+    assert np.array_equal(ng_results, [input_data])
+
+    node1 = make_node("Add", inputs=["A", "B"], outputs=["add1"], name="add_node1")
+    node2 = make_node("Identity", inputs=["add1"], outputs=["identity1"], name="identity_node1")
+    node3 = make_node("Abs", inputs=["identity1"], outputs=["Y"], name="abs_node1")
+
+    graph = make_graph(
+        [node1, node2, node3],
+        "test_graph",
+        [
+            make_tensor_value_info("A", onnx.TensorProto.FLOAT, shape),
+            make_tensor_value_info("B", onnx.TensorProto.FLOAT, shape),
+        ],
+        [make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape)],
+    )
+    model = make_model(graph, producer_name="ngraph ONNX Importer")
+    ng_model_function = import_onnx_model(model)
+    runtime = get_runtime()
+    computation = runtime.computation(ng_model_function)
+    ng_results = computation(input_data, input_data)
+    expected_result = np.abs(input_data + input_data)
+
+    assert np.array_equal(ng_results[0], expected_result)
+
+
+@pytest.mark.parametrize("val_type, input_data", [(np.dtype(bool), np.zeros((2, 2), dtype=int))])
+def test_cast_to_bool(val_type, input_data):
+    expected = np.array(input_data, dtype=val_type)
+
+    model = get_node_model("Cast", input_data, opset=6, to=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[val_type])
+    result = run_model(model, [input_data])
+    assert np.allclose(result, expected)
+
+
+@pytest.mark.parametrize(
+    "val_type, range_start, range_end, in_dtype",
+    [
+        pytest.param(np.dtype(np.float32), -8, 8, np.dtype(np.int32)),
+        pytest.param(np.dtype(np.float64), -16383, 16383, np.dtype(np.int64)),
+    ],
+)
+def test_cast_to_float(val_type, range_start, range_end, in_dtype):
+    np.random.seed(133391)
+    input_data = np.random.randint(range_start, range_end, size=(2, 2), dtype=in_dtype)
+    expected = np.array(input_data, dtype=val_type)
+
+    model = get_node_model("Cast", input_data, opset=6, to=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[val_type])
+    result = run_model(model, [input_data])
+    assert np.allclose(result, expected)
+
+
+@pytest.mark.parametrize(
+    "val_type", [np.dtype(np.int8), np.dtype(np.int16), np.dtype(np.int32), np.dtype(np.int64)]
+)
+def test_cast_to_int(val_type):
+    np.random.seed(133391)
+    input_data = np.ceil(-8 + np.random.rand(2, 3, 4) * 16)
+    expected = np.array(input_data, dtype=val_type)
+
+    model = get_node_model("Cast", input_data, opset=6, to=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[val_type])
+    result = run_model(model, [input_data])
+    assert np.allclose(result, expected)
+
+
+@pytest.mark.parametrize(
+    "val_type", [np.dtype(np.uint8), np.dtype(np.uint16), np.dtype(np.uint32), np.dtype(np.uint64)]
+)
+def test_cast_to_uint(val_type):
+    np.random.seed(133391)
+    input_data = np.ceil(np.random.rand(2, 3, 4) * 16)
+    expected = np.array(input_data, dtype=val_type)
+
+    model = get_node_model("Cast", input_data, opset=6, to=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[val_type])
+    result = run_model(model, [input_data])
+    assert np.allclose(result, expected)
+
+
+def test_cast_errors():
+    np.random.seed(133391)
+    input_data = np.ceil(np.random.rand(2, 3, 4) * 16)
+
+    # missing 'to' attribute
+    node = onnx.helper.make_node("Cast", inputs=["A"], outputs=["B"])
+    input_tensors = [
+        make_tensor_value_info(name, onnx.TensorProto.FLOAT, value.shape)
+        for name, value in zip(node.input, [input_data])
+    ]
+    output_tensors = [
+        make_tensor_value_info(name, onnx.TensorProto.FLOAT16, value.shape)
+        for name, value in zip(node.output, ())
+    ]  # type: ignore
+
+    graph = make_graph([node], "compute_graph", input_tensors, output_tensors)
+    model = make_model(graph, producer_name="NgraphBackend")
+    with pytest.raises(RuntimeError):
+        import_onnx_model(model)
+
+    # unsupported data type representation
+    node = onnx.helper.make_node("Cast", inputs=["A"], outputs=["B"], to=1.2345)
+    input_tensors = [
+        make_tensor_value_info(name, onnx.TensorProto.FLOAT, value.shape)
+        for name, value in zip(node.input, [input_data])
+    ]
+    output_tensors = [
+        make_tensor_value_info(name, onnx.TensorProto.INT32, value.shape)
+        for name, value in zip(node.output, ())
+    ]  # type: ignore
+
+    graph = make_graph([node], "compute_graph", input_tensors, output_tensors)
+    model = make_model(graph, producer_name="NgraphBackend")
+    with pytest.raises(RuntimeError):
+        import_onnx_model(model)
+
+    # unsupported input tensor data type:
+    node = onnx.helper.make_node("Cast", inputs=["A"], outputs=["B"], to=onnx.TensorProto.INT32)
+    input_tensors = [
+        make_tensor_value_info(name, onnx.TensorProto.COMPLEX64, value.shape)
+        for name, value in zip(node.input, [input_data])
+    ]
+    output_tensors = [
+        make_tensor_value_info(name, onnx.TensorProto.INT32, value.shape)
+        for name, value in zip(node.output, ())
+    ]  # type: ignore
+
+    graph = make_graph([node], "compute_graph", input_tensors, output_tensors)
+    model = make_model(graph, producer_name="NgraphBackend")
+    with pytest.raises((RuntimeError, NgraphTypeError)):
+        import_onnx_model(model)
+
+    # unsupported output tensor data type:
+    node = onnx.helper.make_node("Cast", inputs=["A"], outputs=["B"], to=onnx.TensorProto.COMPLEX128)
+    input_tensors = [
+        make_tensor_value_info(name, onnx.TensorProto.FLOAT, value.shape)
+        for name, value in zip(node.input, [input_data])
+    ]
+    output_tensors = [
+        make_tensor_value_info(name, onnx.TensorProto.COMPLEX128, value.shape)
+        for name, value in zip(node.output, ())
+    ]  # type: ignore
+
+    graph = make_graph([node], "compute_graph", input_tensors, output_tensors)
+    model = make_model(graph, producer_name="NgraphBackend")
+    with pytest.raises(RuntimeError):
+        import_onnx_model(model)
+
+
+@pytest.mark.parametrize("value_type", [np.float32, np.float64])
+def test_constant(value_type):
+    values = np.random.randn(5, 5).astype(value_type)
+    node = onnx.helper.make_node(
+        "Constant",
+        inputs=[],
+        outputs=["values"],
+        value=onnx.helper.make_tensor(
+            name="const_tensor",
+            data_type=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(value_type)],
+            dims=values.shape,
+            vals=values.flatten(),
+        ),
+    )
+
+    ng_results = run_node(node, [])
+    assert np.allclose(ng_results, [values])
+
+
+# See https://github.com/onnx/onnx/issues/1190
+@pytest.mark.xfail(reason="ONNX#1190 numpy.float16 not supported by ONNX make_node", strict=True)
+def test_constant_err():
+    values = np.random.randn(5, 5).astype(np.float16)
+    node = onnx.helper.make_node(
+        "Constant",
+        inputs=[],
+        outputs=["values"],
+        value=onnx.helper.make_tensor(
+            name="const_tensor",
+            data_type=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(np.float16)],
+            dims=values.shape,
+            vals=values.flatten(),
+        ),
+    )
+
+    ng_results = run_node(node, [])
+    assert np.allclose(ng_results, [values])
diff --git a/ngraph/python/tests/test_onnx/test_ops_variadic.py b/ngraph/python/tests/test_onnx/test_ops_variadic.py
new file mode 100644 (file)
index 0000000..e1238b7
--- /dev/null
@@ -0,0 +1,41 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+from functools import reduce
+
+import numpy as np
+import onnx
+import pytest
+
+from tests.test_onnx.utils import run_node
+
+
+@pytest.mark.parametrize("onnx_op,numpy_func", [("Sum", np.add), ("Min", np.minimum), ("Max", np.maximum)])
+def test_variadic(onnx_op, numpy_func):
+    data = [np.array([1, 2, 3]), np.array([4, 5, 6]), np.array([7, 8, 9])]
+    node = onnx.helper.make_node(onnx_op, inputs=["data_0", "data_1", "data_2"], outputs=["y"])
+    expected_output = reduce(numpy_func, data)
+
+    ng_results = run_node(node, data)
+    assert np.array_equal(ng_results, [expected_output])
+
+
+def test_mean():
+    data = [np.array([1, 2, 3]), np.array([4, 5, 6]), np.array([7, 8, 9])]
+    node = onnx.helper.make_node("Mean", inputs=["data_0", "data_1", "data_2"], outputs=["y"])
+    expected_output = reduce(np.add, data) / len(data)
+
+    ng_results = run_node(node, data)
+    assert np.array_equal(ng_results, [expected_output])
diff --git a/ngraph/python/tests/test_onnx/test_zoo_models.py b/ngraph/python/tests/test_onnx/test_zoo_models.py
new file mode 100644 (file)
index 0000000..a91d877
--- /dev/null
@@ -0,0 +1,575 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+# ## Prepare a list of models from the ONNX Model Zoo
+#
+# from pathlib import Path
+# from operator import itemgetter
+# import re
+#
+# MODELS_ROOT_DIR = '/path/to/onnx/models'
+# zoo_models = []
+# for path in Path(MODELS_ROOT_DIR).rglob('*.tar.gz'):
+#     match = re.search('.*onnx\/models\/(.*\/model\/(.+)-(\d+)\.tar\.gz)', str(path))
+#     url = match.group(1)
+#     model_name = match.group(2)
+#     opset = match.group(3)
+#     zoo_models.append({'model_name': '{}_opset{}'.format(model_name.replace('-', '_'), opset), 'url': url})
+#
+# sorted(zoo_models, key=itemgetter('model_name'))
+import tests
+from tests.test_onnx.utils import OpenVinoOnnxBackend
+from tests.test_onnx.utils.model_zoo_tester import ModelZooTestRunner
+
+_GITHUB_MODELS_LTS = "https://media.githubusercontent.com/media/onnx/models/master/"
+
+zoo_models = [
+    {
+        "model_name": "FasterRCNN_opset10",
+        "url": _GITHUB_MODELS_LTS
+        + "vision/object_detection_segmentation/faster-rcnn/model/FasterRCNN-10.tar.gz",
+    },
+    {
+        "model_name": "MaskRCNN_opset10",
+        "url": _GITHUB_MODELS_LTS + "vision/object_detection_segmentation/mask-rcnn/model/MaskRCNN-10.tar.gz",
+    },
+    {
+        "model_name": "ResNet101_DUC_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/object_detection_segmentation/duc/model/ResNet101-DUC-7.tar.gz",
+    },
+    {
+        "model_name": "arcfaceresnet100_opset8",
+        "url": _GITHUB_MODELS_LTS + "vision/body_analysis/arcface/model/arcfaceresnet100-8.tar.gz",
+    },
+    {
+        "model_name": "bertsquad_opset10",
+        "url": _GITHUB_MODELS_LTS + "text/machine_comprehension/bert-squad/model/bertsquad-10.tar.gz",
+    },
+    {
+        "model_name": "bertsquad_opset8",
+        "url": _GITHUB_MODELS_LTS + "text/machine_comprehension/bert-squad/model/bertsquad-8.tar.gz",
+    },
+    {
+        "model_name": "bidaf_opset9",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS
+        + "text/machine_comprehension/bidirectional_attention_flow/model/bidaf-9.tar.gz",
+    },
+    {
+        "model_name": "bvlcalexnet_opset3",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/alexnet/model/bvlcalexnet-3.tar.gz",
+    },
+    {
+        "model_name": "bvlcalexnet_opset6",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/alexnet/model/bvlcalexnet-6.tar.gz",
+    },
+    {
+        "model_name": "bvlcalexnet_opset7",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/alexnet/model/bvlcalexnet-7.tar.gz",
+    },
+    {
+        "model_name": "bvlcalexnet_opset8",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/alexnet/model/bvlcalexnet-8.tar.gz",
+    },
+    {
+        "model_name": "bvlcalexnet_opset9",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/alexnet/model/bvlcalexnet-9.tar.gz",
+    },
+    {
+        "model_name": "caffenet_opset3",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/caffenet/model/caffenet-3.tar.gz",
+    },
+    {
+        "model_name": "caffenet_opset6",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/caffenet/model/caffenet-6.tar.gz",
+    },
+    {
+        "model_name": "caffenet_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/caffenet/model/caffenet-7.tar.gz",
+    },
+    {
+        "model_name": "caffenet_opset8",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/caffenet/model/caffenet-8.tar.gz",
+    },
+    {
+        "model_name": "caffenet_opset9",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/caffenet/model/caffenet-9.tar.gz",
+    },
+    {
+        "model_name": "candy_opset8",
+        "url": _GITHUB_MODELS_LTS + "vision/style_transfer/fast_neural_style/model/candy-8.tar.gz",
+    },
+    {
+        "model_name": "candy_opset9",
+        "url": _GITHUB_MODELS_LTS + "vision/style_transfer/fast_neural_style/model/candy-9.tar.gz",
+    },
+    {
+        "model_name": "densenet_opset3",
+        "atol": 1e-07,
+        "rtol": 0.002,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/densenet-121/model/densenet-3.tar.gz",
+    },
+    {
+        "model_name": "densenet_opset6",
+        "atol": 1e-07,
+        "rtol": 0.002,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/densenet-121/model/densenet-6.tar.gz",
+    },
+    {
+        "model_name": "densenet_opset7",
+        "atol": 1e-07,
+        "rtol": 0.002,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/densenet-121/model/densenet-7.tar.gz",
+    },
+    {
+        "model_name": "densenet_opset8",
+        "atol": 1e-07,
+        "rtol": 0.002,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/densenet-121/model/densenet-8.tar.gz",
+    },
+    {
+        "model_name": "densenet_opset9",
+        "atol": 1e-07,
+        "rtol": 0.002,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/densenet-121/model/densenet-9.tar.gz",
+    },
+    {
+        "model_name": "emotion_ferplus_opset2",
+        "url": _GITHUB_MODELS_LTS + "vision/body_analysis/emotion_ferplus/model/emotion-ferplus-2.tar.gz",
+    },
+    {
+        "model_name": "emotion_ferplus_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/body_analysis/emotion_ferplus/model/emotion-ferplus-7.tar.gz",
+    },
+    {
+        "model_name": "emotion_ferplus_opset8",
+        "url": _GITHUB_MODELS_LTS + "vision/body_analysis/emotion_ferplus/model/emotion-ferplus-8.tar.gz",
+    },
+    {
+        "model_name": "googlenet_opset3",
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/googlenet/model/googlenet-3.tar.gz",
+    },
+    {
+        "model_name": "googlenet_opset6",
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/googlenet/model/googlenet-6.tar.gz",
+    },
+    {
+        "model_name": "googlenet_opset7",
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/googlenet/model/googlenet-7.tar.gz",
+    },
+    {
+        "model_name": "googlenet_opset8",
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/googlenet/model/googlenet-8.tar.gz",
+    },
+    {
+        "model_name": "googlenet_opset9",
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/googlenet/model/googlenet-9.tar.gz",
+    },
+    {
+        "model_name": "gpt2_opset10",
+        "url": _GITHUB_MODELS_LTS + "text/machine_comprehension/gpt-2/model/gpt2-10.tar.gz",
+    },
+    {
+        "model_name": "inception_v1_opset3",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-3.tar.gz",
+    },
+    {
+        "model_name": "inception_v1_opset6",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-6.tar.gz",
+    },
+    {
+        "model_name": "inception_v1_opset7",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-7.tar.gz",
+    },
+    {
+        "model_name": "inception_v1_opset8",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-8.tar.gz",
+    },
+    {
+        "model_name": "inception_v1_opset9",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-9.tar.gz",
+    },
+    {
+        "model_name": "inception_v2_opset3",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-3.tar.gz",
+    },
+    {
+        "model_name": "inception_v2_opset6",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-6.tar.gz",
+    },
+    {
+        "model_name": "inception_v2_opset7",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-7.tar.gz",
+    },
+    {
+        "model_name": "inception_v2_opset8",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-8.tar.gz",
+    },
+    {
+        "model_name": "inception_v2_opset9",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS
+        + "vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-9.tar.gz",
+    },
+    {
+        "model_name": "mnist_opset1",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/mnist/model/mnist-1.tar.gz",
+    },
+    {
+        "model_name": "mnist_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/mnist/model/mnist-7.tar.gz",
+    },
+    {
+        "model_name": "mnist_opset8",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/mnist/model/mnist-8.tar.gz",
+    },
+    {
+        "model_name": "mobilenetv2_opset7",
+        "atol": 1e-07,
+        "rtol": 0.002,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/mobilenet/model/mobilenetv2-7.tar.gz",
+    },
+    {
+        "model_name": "mosaic_opset8",
+        "url": _GITHUB_MODELS_LTS + "vision/style_transfer/fast_neural_style/model/mosaic-8.tar.gz",
+    },
+    {
+        "model_name": "mosaic_opset9",
+        "url": _GITHUB_MODELS_LTS + "vision/style_transfer/fast_neural_style/model/mosaic-9.tar.gz",
+    },
+    {
+        "model_name": "pointilism_opset8",
+        "url": _GITHUB_MODELS_LTS + "vision/style_transfer/fast_neural_style/model/pointilism-8.tar.gz",
+    },
+    {
+        "model_name": "pointilism_opset9",
+        "url": _GITHUB_MODELS_LTS + "vision/style_transfer/fast_neural_style/model/pointilism-9.tar.gz",
+    },
+    {
+        "model_name": "rain_princess_opset8",
+        "url": _GITHUB_MODELS_LTS + "vision/style_transfer/fast_neural_style/model/rain-princess-8.tar.gz",
+    },
+    {
+        "model_name": "rain_princess_opset9",
+        "url": _GITHUB_MODELS_LTS + "vision/style_transfer/fast_neural_style/model/rain-princess-9.tar.gz",
+    },
+    {
+        "model_name": "rcnn_ilsvrc13_opset3",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/rcnn_ilsvrc13/model/rcnn-ilsvrc13-3.tar.gz",
+    },
+    {
+        "model_name": "rcnn_ilsvrc13_opset6",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/rcnn_ilsvrc13/model/rcnn-ilsvrc13-6.tar.gz",
+    },
+    {
+        "model_name": "rcnn_ilsvrc13_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/rcnn_ilsvrc13/model/rcnn-ilsvrc13-7.tar.gz",
+    },
+    {
+        "model_name": "rcnn_ilsvrc13_opset8",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/rcnn_ilsvrc13/model/rcnn-ilsvrc13-8.tar.gz",
+    },
+    {
+        "model_name": "rcnn_ilsvrc13_opset9",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/rcnn_ilsvrc13/model/rcnn-ilsvrc13-9.tar.gz",
+    },
+    {
+        "model_name": "resnet101_v1_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet101-v1-7.tar.gz",
+    },
+    {
+        "model_name": "resnet101_v2_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet101-v2-7.tar.gz",
+    },
+    {
+        "model_name": "resnet152_v1_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet152-v1-7.tar.gz",
+    },
+    {
+        "model_name": "resnet152_v2_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet152-v2-7.tar.gz",
+    },
+    {
+        "model_name": "resnet18_v1_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet18-v1-7.tar.gz",
+    },
+    {
+        "model_name": "resnet18_v2_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet18-v2-7.tar.gz",
+    },
+    {
+        "model_name": "resnet34_v1_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet34-v1-7.tar.gz",
+    },
+    {
+        "model_name": "resnet34_v2_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet34-v2-7.tar.gz",
+    },
+    {
+        "model_name": "resnet50_caffe2_v1_opset3",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet50-caffe2-v1-3.tar.gz",
+    },
+    {
+        "model_name": "resnet50_caffe2_v1_opset6",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet50-caffe2-v1-6.tar.gz",
+    },
+    {
+        "model_name": "resnet50_caffe2_v1_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet50-caffe2-v1-7.tar.gz",
+    },
+    {
+        "model_name": "resnet50_caffe2_v1_opset8",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet50-caffe2-v1-8.tar.gz",
+    },
+    {
+        "model_name": "resnet50_caffe2_v1_opset9",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet50-caffe2-v1-9.tar.gz",
+    },
+    {
+        "model_name": "resnet50_v1_opset7",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet50-v1-7.tar.gz",
+    },
+    {
+        "model_name": "resnet50_v2_opset7",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/resnet/model/resnet50-v2-7.tar.gz",
+    },
+    {
+        "model_name": "shufflenet_opset3",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/shufflenet/model/shufflenet-3.tar.gz",
+    },
+    {
+        "model_name": "shufflenet_opset6",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/shufflenet/model/shufflenet-6.tar.gz",
+    },
+    {
+        "model_name": "shufflenet_opset7",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/shufflenet/model/shufflenet-7.tar.gz",
+    },
+    {
+        "model_name": "shufflenet_opset8",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/shufflenet/model/shufflenet-8.tar.gz",
+    },
+    {
+        "model_name": "shufflenet_opset9",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/shufflenet/model/shufflenet-9.tar.gz",
+    },
+    {
+        "model_name": "shufflenet_v2_opset10",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/shufflenet/model/shufflenet-v2-10.tar.gz",
+    },
+    {
+        "model_name": "squeezenet1.0_opset3",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/squeezenet/model/squeezenet1.0-3.tar.gz",
+    },
+    {
+        "model_name": "squeezenet1.0_opset6",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/squeezenet/model/squeezenet1.0-6.tar.gz",
+    },
+    {
+        "model_name": "squeezenet1.0_opset7",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/squeezenet/model/squeezenet1.0-7.tar.gz",
+    },
+    {
+        "model_name": "squeezenet1.0_opset8",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/squeezenet/model/squeezenet1.0-8.tar.gz",
+    },
+    {
+        "model_name": "squeezenet1.0_opset9",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/squeezenet/model/squeezenet1.0-9.tar.gz",
+    },
+    {
+        "model_name": "squeezenet1.1_opset7",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/classification/squeezenet/model/squeezenet1.1-7.tar.gz",
+    },
+    {
+        "model_name": "ssd_opset10",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/object_detection_segmentation/ssd/model/ssd-10.tar.gz",
+    },
+    {
+        "model_name": "super_resolution_opset10",
+        "url": _GITHUB_MODELS_LTS
+        + "vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.tar.gz",
+    },
+    {
+        "model_name": "tiny_yolov3_opset11",
+        "url": _GITHUB_MODELS_LTS
+        + "vision/object_detection_segmentation/tiny-yolov3/model/tiny-yolov3-11.tar.gz",
+    },
+    {
+        "model_name": "tinyyolov2_opset1",
+        "url": _GITHUB_MODELS_LTS
+        + "vision/object_detection_segmentation/tiny-yolov2/model/tinyyolov2-1.tar.gz",
+    },
+    {
+        "model_name": "tinyyolov2_opset7",
+        "url": _GITHUB_MODELS_LTS
+        + "vision/object_detection_segmentation/tiny-yolov2/model/tinyyolov2-7.tar.gz",
+    },
+    {
+        "model_name": "tinyyolov2_opset8",
+        "url": _GITHUB_MODELS_LTS
+        + "vision/object_detection_segmentation/tiny-yolov2/model/tinyyolov2-8.tar.gz",
+    },
+    {
+        "model_name": "udnie_opset8",
+        "url": _GITHUB_MODELS_LTS + "vision/style_transfer/fast_neural_style/model/udnie-8.tar.gz",
+    },
+    {
+        "model_name": "udnie_opset9",
+        "url": _GITHUB_MODELS_LTS + "vision/style_transfer/fast_neural_style/model/udnie-9.tar.gz",
+    },
+    {
+        "model_name": "vgg16_bn_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/vgg/model/vgg16-bn-7.tar.gz",
+    },
+    {
+        "model_name": "vgg16_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/vgg/model/vgg16-7.tar.gz",
+    },
+    {
+        "model_name": "vgg19_bn_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/vgg/model/vgg19-bn-7.tar.gz",
+    },
+    {
+        "model_name": "vgg19_caffe2_opset3",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/vgg/model/vgg19-caffe2-3.tar.gz",
+    },
+    {
+        "model_name": "vgg19_caffe2_opset6",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/vgg/model/vgg19-caffe2-6.tar.gz",
+    },
+    {
+        "model_name": "vgg19_caffe2_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/vgg/model/vgg19-caffe2-7.tar.gz",
+    },
+    {
+        "model_name": "vgg19_caffe2_opset8",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/vgg/model/vgg19-caffe2-8.tar.gz",
+    },
+    {
+        "model_name": "vgg19_caffe2_opset9",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/vgg/model/vgg19-caffe2-9.tar.gz",
+    },
+    {
+        "model_name": "vgg19_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/vgg/model/vgg19-7.tar.gz",
+    },
+    {
+        "model_name": "yolov3_opset10",
+        "atol": 1e-07,
+        "rtol": 0.001,
+        "url": _GITHUB_MODELS_LTS + "vision/object_detection_segmentation/yolov3/model/yolov3-10.tar.gz",
+    },
+    {
+        "model_name": "zfnet512_opset3",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/zfnet-512/model/zfnet512-3.tar.gz",
+    },
+    {
+        "model_name": "zfnet512_opset6",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/zfnet-512/model/zfnet512-6.tar.gz",
+    },
+    {
+        "model_name": "zfnet512_opset7",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/zfnet-512/model/zfnet512-7.tar.gz",
+    },
+    {
+        "model_name": "zfnet512_opset8",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/zfnet-512/model/zfnet512-8.tar.gz",
+    },
+    {
+        "model_name": "zfnet512_opset9",
+        "url": _GITHUB_MODELS_LTS + "vision/classification/zfnet-512/model/zfnet512-9.tar.gz",
+    },
+]
+
+# Set backend device name to be used instead of hardcoded by ONNX BackendTest class ones.
+OpenVinoOnnxBackend.backend_name = tests.BACKEND_NAME
+
+# import all test cases at global scope to make them visible to pytest
+backend_test = ModelZooTestRunner(OpenVinoOnnxBackend, zoo_models, __name__)
+test_cases = backend_test.test_cases["OnnxBackendZooModelTest"]
+
+del test_cases
+globals().update(backend_test.enable_report().test_cases)
diff --git a/ngraph/python/tests/test_onnx/utils/__init__.py b/ngraph/python/tests/test_onnx/utils/__init__.py
new file mode 100644 (file)
index 0000000..54ba480
--- /dev/null
@@ -0,0 +1,105 @@
+# ******************************************************************************
+# Copyright 2017-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+
+from string import ascii_uppercase
+from typing import Any, Dict, Iterable, List, Optional, Text
+
+import numpy as np
+import onnx
+import pytest
+from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
+
+import tests
+from tests.runtime import get_runtime
+from tests.test_onnx.utils.onnx_backend import OpenVinoOnnxBackend
+from tests.test_onnx.utils.onnx_helpers import import_onnx_model
+
+
+def xfail_test(*backends, reason="Mark the test as expected to fail"):
+    return pytest.mark.xfail(condition=tests.BACKEND_NAME in backends, reason=reason, strict=True)
+
+
+def run_node(onnx_node, data_inputs, **kwargs):
+    # type: (onnx.NodeProto, List[np.ndarray], Dict[Text, Any]) -> List[np.ndarray]
+    """
+    Convert ONNX node to ngraph node and perform computation on input data.
+
+    :param onnx_node: ONNX NodeProto describing a computation node
+    :param data_inputs: list of numpy ndarrays with input data
+    :return: list of numpy ndarrays with computed output
+    """
+    OpenVinoOnnxBackend.backend_name = tests.BACKEND_NAME
+    return OpenVinoOnnxBackend.run_node(onnx_node, data_inputs, **kwargs)
+
+
+def run_model(onnx_model, data_inputs):
+    # type: (onnx.ModelProto, List[np.ndarray]) -> List[np.ndarray]
+    """
+    Convert ONNX model to an ngraph model and perform computation on input data.
+
+    :param onnx_model: ONNX ModelProto describing an ONNX model
+    :param data_inputs: list of numpy ndarrays with input data
+    :return: list of numpy ndarrays with computed output
+    """
+    ng_model_function = import_onnx_model(onnx_model)
+    runtime = get_runtime()
+    computation = runtime.computation(ng_model_function)
+    return computation(*data_inputs)
+
+
+def get_node_model(op_type, *input_data, opset=1, num_outputs=1, **node_attributes):
+    # type: (str, *Any, Optional[int], Optional[int], **Any) -> onnx.ModelProto
+    """Generate model with single requested node.
+
+    Input and output Tensor data type is the same.
+
+    :param op_type: The ONNX node operation.
+    :param input_data: Optional list of input arguments for node.
+    :param opset: The ONNX operation set version to use. Default to 4.
+    :param num_outputs: The number of node outputs.
+    :param node_attributes: Optional dictionary of node attributes.
+    :return: Generated model with single node for requested ONNX operation.
+    """
+    node_inputs = [np.array(data) for data in input_data]
+    num_inputs = len(node_inputs)
+    node_input_names = [ascii_uppercase[idx] for idx in range(num_inputs)]
+    node_output_names = [ascii_uppercase[num_inputs + idx] for idx in range(num_outputs)]
+    onnx_node = make_node(op_type, node_input_names, node_output_names, **node_attributes)
+
+    input_tensors = [
+        make_tensor_value_info(name, onnx.TensorProto.FLOAT, value.shape)
+        for name, value in zip(onnx_node.input, node_inputs)
+    ]
+    output_tensors = [
+        make_tensor_value_info(name, onnx.TensorProto.FLOAT, ()) for name in onnx_node.output
+    ]  # type: ignore
+
+    graph = make_graph([onnx_node], "compute_graph", input_tensors, output_tensors)
+    model = make_model(graph, producer_name="Ngraph ONNX Importer")
+    model.opset_import[0].version = opset
+    return model
+
+
+def all_arrays_equal(first_list, second_list):
+    # type: (Iterable[np.ndarray], Iterable[np.ndarray]) -> bool
+    """
+    Check that all numpy ndarrays in `first_list` are equal to all numpy ndarrays in `second_list`.
+
+    :param first_list: iterable containing numpy ndarray objects
+    :param second_list: another iterable containing numpy ndarray objects
+    :return: True if all ndarrays are equal, otherwise False
+    """
+    return all(map(lambda pair: np.array_equal(*pair), zip(first_list, second_list)))
diff --git a/ngraph/python/tests/test_onnx/utils/model_zoo_tester.py b/ngraph/python/tests/test_onnx/utils/model_zoo_tester.py
new file mode 100644 (file)
index 0000000..335e34e
--- /dev/null
@@ -0,0 +1,143 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+
+import glob
+import os
+import shutil
+import tarfile
+import tempfile
+from collections import defaultdict
+from typing import Dict, List, Optional, Pattern, Set, Text, Type
+
+import onnx.backend.test
+from onnx.backend.base import Backend
+from onnx.backend.test.case.test_case import TestCase as OnnxTestCase
+from onnx.backend.test.runner import TestItem
+from retrying import retry
+from six.moves.urllib.request import urlopen, urlretrieve
+
+
+class ModelZooTestRunner(onnx.backend.test.BackendTest):
+    def __init__(self, backend, zoo_models, parent_module=None):
+        # type: (Type[Backend], List[Dict[str,str]], Optional[str]) -> None
+        self.backend = backend
+        self._parent_module = parent_module
+        self._include_patterns = set()  # type: Set[Pattern[Text]]
+        self._exclude_patterns = set()  # type: Set[Pattern[Text]]
+        self._test_items = defaultdict(dict)  # type: Dict[Text, Dict[Text, TestItem]]
+
+        for zoo_model in zoo_models:
+            test_name = "test_{}".format(zoo_model["model_name"])
+
+            test_case = OnnxTestCase(
+                name=test_name,
+                url=zoo_model["url"],
+                model_name=zoo_model["model_name"],
+                model_dir=None,
+                model=None,
+                data_sets=None,
+                kind="OnnxBackendRealModelTest",
+                rtol=zoo_model.get("rtol", 0.001),
+                atol=zoo_model.get("atol", 1e-07),
+            )
+            self._add_model_test(test_case, "Zoo")
+
+    @staticmethod
+    @retry
+    def _get_etag_for_url(url):  # type: (str) -> str
+        request = urlopen(url)
+        return request.info().get("ETag")
+
+    @staticmethod
+    def _read_etag_file(model_dir):  # type: (str) -> str
+        etag_file_path = os.path.join(model_dir, "source_tar_etag")
+        if os.path.exists(etag_file_path):
+            return open(etag_file_path).read()
+
+    @staticmethod
+    def _write_etag_file(model_dir, etag_value):  # type: (str, str) -> None
+        etag_file_path = os.path.join(model_dir, "source_tar_etag")
+        open(etag_file_path, "w").write(etag_value)
+
+    @staticmethod
+    def _backup_old_version(model_dir):  # type: (str) -> None
+        if os.path.exists(model_dir):
+            backup_index = 0
+            while True:
+                dest = "{}.old.{}".format(model_dir, backup_index)
+                if os.path.exists(dest):
+                    backup_index += 1
+                    continue
+                shutil.move(model_dir, dest)
+                break
+
+    @classmethod
+    @retry
+    def prepare_model_data(cls, model_test):  # type: (OnnxTestCase) -> Text
+        onnx_home = os.path.expanduser(os.getenv("ONNX_HOME", os.path.join("~", ".onnx")))
+        models_dir = os.getenv("ONNX_MODELS", os.path.join(onnx_home, "models"))
+        model_dir = os.path.join(models_dir, model_test.model_name)  # type: Text
+        current_version_etag = ModelZooTestRunner._get_etag_for_url(model_test.url)
+
+        # If model already exists, check if it's the latest version by verifying cached Etag value
+        if os.path.exists(os.path.join(model_dir, "model.onnx")):
+            if not current_version_etag or current_version_etag == ModelZooTestRunner._read_etag_file(
+                model_dir
+            ):
+                return model_dir
+
+            # If model does exist, but is not current, backup directory
+            ModelZooTestRunner._backup_old_version(model_dir)
+
+        # Download and extract model and data
+        download_file = tempfile.NamedTemporaryFile(delete=False)
+        temp_clean_dir = tempfile.mkdtemp()
+
+        try:
+            download_file.close()
+            print("\nStart downloading model {} from {}".format(model_test.model_name, model_test.url))
+            urlretrieve(model_test.url, download_file.name)
+            print("Done")
+
+            with tempfile.TemporaryDirectory() as temp_extract_dir:
+                with tarfile.open(download_file.name) as tar_file:
+                    tar_file.extractall(temp_extract_dir)
+
+                # Move model `.onnx` file from temp_extract_dir to temp_clean_dir
+                model_files = glob.glob(temp_extract_dir + "/**/*.onnx", recursive=True)
+                assert len(model_files) > 0, "Model file not found for {}".format(model_test.name)
+                model_file = model_files[0]
+                shutil.move(model_file, temp_clean_dir + "/model.onnx")
+
+                # Move extracted test data sets to temp_clean_dir
+                test_data_sets = glob.glob(temp_extract_dir + "/**/test_data_set_*", recursive=True)
+                test_data_sets.extend(glob.glob(temp_extract_dir + "/**/test_data_*.npz", recursive=True))
+                for test_data_set in test_data_sets:
+                    shutil.move(test_data_set, temp_clean_dir)
+
+                # Save Etag value to Etag file
+                ModelZooTestRunner._write_etag_file(temp_clean_dir, current_version_etag)
+
+                # Move temp_clean_dir to ultimate destination
+                shutil.move(temp_clean_dir, model_dir)
+
+        except Exception as e:
+            print("Failed to prepare data for model {}: {}".format(model_test.model_name, e))
+            os.remove(temp_clean_dir)
+            raise
+        finally:
+            os.remove(download_file.name)
+        return model_dir
diff --git a/ngraph/python/tests/test_onnx/utils/onnx_backend.py b/ngraph/python/tests/test_onnx/utils/onnx_backend.py
new file mode 100644 (file)
index 0000000..7513fc4
--- /dev/null
@@ -0,0 +1,131 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+"""
+ONNX Backend implementation.
+
+See ONNX documentation for details:
+https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md
+"""
+
+from typing import Any, Dict, List, Optional, Sequence, Text, Tuple
+
+import numpy
+import onnx
+from onnx.backend.base import Backend, BackendRep
+from onnx.helper import make_graph, make_model, make_tensor_value_info
+
+from ngraph.impl import Function
+from tests.runtime import get_runtime
+from tests.test_onnx.utils.onnx_helpers import import_onnx_model, np_dtype_to_tensor_type
+
+
+class OpenVinoOnnxBackendRep(BackendRep):
+    def __init__(self, ng_model_function, device="CPU"):  # type: (List[Function], str) -> None
+        super().__init__()
+        self.device = device
+        self.ng_model_function = ng_model_function
+        self.runtime = get_runtime()
+        self.computation = self.runtime.computation(ng_model_function)
+
+    def run(self, inputs, **kwargs):  # type: (Any, **Any) -> Tuple[Any, ...]
+        """Run computation on model."""
+        return self.computation(*inputs)
+
+
+class OpenVinoOnnxBackend(Backend):
+    @classmethod
+    def is_compatible(
+        cls,
+        model,  # type: onnx.ModelProto
+        device="CPU",  # type: Text
+        **kwargs  # type: Any
+    ):  # type: (...) -> bool
+        # Return whether the model is compatible with the backend.
+        try:
+            import_onnx_model(model)
+            return True
+        except Exception:
+            return False
+
+    @classmethod
+    def prepare(
+        cls,
+        onnx_model,  # type: onnx.ModelProto
+        device="CPU",  # type: Text
+        **kwargs  # type: Any
+    ):  # type: (...) -> OpenVinoOnnxBackendRep
+        onnx.checker.check_model(onnx_model)
+        super().prepare(onnx_model, device, **kwargs)
+        ng_model_function = import_onnx_model(onnx_model)
+        return OpenVinoOnnxBackendRep(ng_model_function, cls.backend_name)
+
+    @classmethod
+    def run_model(
+        cls,
+        model,  # type: onnx.ModelProto
+        inputs,  # type: Any
+        device="CPU",  # type: Text
+        **kwargs  # type: Any
+    ):  # type: (...) -> Tuple[Any, ...]
+        cls.prepare(model, device, **kwargs).run()
+
+    @classmethod
+    def run_node(
+        cls,
+        node,  # type: onnx.NodeProto
+        inputs,  # type: Any
+        device="CPU",  # type: Text
+        outputs_info=None,  # type: Optional[Sequence[Tuple[numpy.dtype, Tuple[int, ...]]]]
+        **kwargs  # type: Dict[Text, Any]
+    ):  # type: (...) -> Optional[Tuple[Any, ...]]
+        """Prepare and run a computation on an ONNX node."""
+        # default values for input/output tensors
+        input_tensor_types = [np_dtype_to_tensor_type(node_input.dtype) for node_input in inputs]
+        output_tensor_types = [onnx.TensorProto.FLOAT for idx in range(len(node.output))]
+        output_tensor_shapes = [()]  # type: List[Tuple[int, ...]]
+
+        if outputs_info is not None:
+            output_tensor_types = [np_dtype_to_tensor_type(dtype) for (dtype, shape) in outputs_info]
+            output_tensor_shapes = [shape for (dtype, shape) in outputs_info]
+
+        input_tensors = [
+            make_tensor_value_info(name, tensor_type, value.shape)
+            for name, value, tensor_type in zip(node.input, inputs, input_tensor_types)
+        ]
+        output_tensors = [
+            make_tensor_value_info(name, tensor_type, shape)
+            for name, shape, tensor_type in zip(node.output, output_tensor_shapes, output_tensor_types)
+        ]
+
+        graph = make_graph([node], "compute_graph", input_tensors, output_tensors)
+        model = make_model(graph, producer_name="OpenVinoOnnxBackend")
+        if "opset_version" in kwargs:
+            model.opset_import[0].version = kwargs["opset_version"]
+        return cls.prepare(model, device).run(inputs)
+
+    @classmethod
+    def supports_device(cls, device):  # type: (Text) -> bool
+        """Check whether the backend is compiled with particular device support.
+
+        In particular it's used in the testing suite.
+        """
+        return device != "CUDA"
+
+
+prepare = OpenVinoOnnxBackend.prepare
+run_model = OpenVinoOnnxBackend.run_model
+run_node = OpenVinoOnnxBackend.run_node
+supports_device = OpenVinoOnnxBackend.supports_device
diff --git a/ngraph/python/tests/test_onnx/utils/onnx_helpers.py b/ngraph/python/tests/test_onnx/utils/onnx_helpers.py
new file mode 100644 (file)
index 0000000..de7e744
--- /dev/null
@@ -0,0 +1,42 @@
+# ******************************************************************************
+# Copyright 2018-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+import onnx
+from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
+from openvino.inference_engine import IECore
+
+from ngraph.impl import Function
+
+
+def np_dtype_to_tensor_type(data_type: np.dtype) -> int:
+    """Return TensorProto type for provided numpy dtype.
+
+    :param data_type: Numpy data type object.
+    :return: TensorProto.DataType enum value for corresponding type.
+    """
+    return NP_TYPE_TO_TENSOR_TYPE[data_type]
+
+
+def import_onnx_model(model: onnx.ModelProto) -> Function:
+    onnx.checker.check_model(model)
+    model_byte_string = model.SerializeToString()
+
+    ie = IECore()
+    ie_network = ie.read_network(model=model_byte_string, weights=b"", init_from_buffer=True)
+
+    capsule = ie_network._get_function_capsule()
+    ng_function = Function.from_capsule(capsule)
+    return ng_function
index d09c07a..7b3a5e7 100644 (file)
@@ -6,7 +6,7 @@ skipdist=True
 skip_install=True
 deps =
   -rrequirements.txt
-  -rtest_requirements.txt
+  -rrequirements_test.txt
   mypy
   flake8-bugbear
 setenv =
@@ -15,15 +15,17 @@ setenv =
   LD_LIBRARY_PATH = {env:LD_LIBRARY_PATH:{homedir}/ngraph_dist/lib}
   DYLD_LIBRARY_PATH = {env:DYLD_LIBRARY_PATH:{homedir}/ngraph_dist/lib}
   PYBIND_HEADERS_PATH = {env:PYBIND_HEADERS_PATH:}
-  NGRAPH_BACKEND = {env:NGRAPH_BACKEND:"INTERPRETER"}
+  NGRAPH_BACKEND = {env:NGRAPH_BACKEND:"CPU"}
+  PYTHONPATH = {env:PYTHONPATH}
 commands=
   {envbindir}/python setup.py bdist_wheel
   {envbindir}/pip install --no-index --pre --find-links=dist/ ngraph-core
-  flake8 {posargs:src/ examples/ setup.py}
-  flake8 --ignore=D100,D101,D102,D103,D104,D105,D107,W503 test/  # ignore lack of docs in tests
-  mypy --config-file=tox.ini {posargs:src/ examples/}
-  pytest --backend={env:NGRAPH_BACKEND} {posargs:test/}
-
+  flake8 {posargs:src/ setup.py}
+  flake8 --ignore=D100,D101,D102,D103,D104,D105,D107,W503 tests/  # ignore lack of docs in tests
+  mypy --config-file=tox.ini {posargs:src/}
+  ; TODO: uncomment the line below when all test are ready (and delete the following line)
+  ; pytest --backend={env:NGRAPH_BACKEND} {posargs:tests/}
+  pytest --backend={env:NGRAPH_BACKEND} tests/test_ngraph/test_core.py tests/test_onnx/test_onnx_import.py
 
 [testenv:devenv]
 envdir = devenv
@@ -32,7 +34,7 @@ deps = -rrequirements.txt
 
 [flake8]
 inline-quotes = "
-max-line-length=100
+max-line-length=110
 max-complexity=7
 # ignore:
 # D100 - Missing docstring in public module