Tests and docs for registering custom ONNX operators (#2416)
authorMateusz Tabaka <mateusz.tabaka@intel.com>
Mon, 12 Oct 2020 04:36:19 +0000 (06:36 +0200)
committerGitHub <noreply@github.com>
Mon, 12 Oct 2020 04:36:19 +0000 (07:36 +0300)
* Add tests, examples and documentation changes for custom ONNX operators registration mechanism

* Change snippet paths

* fix CoreThreadingTests.ReadNetwork - data race in ops_bridge

* Make TemplateExtension::Operation externally visible

* changes after review

* apply code format

* use std::int64_t

* forward declare get_attribute_value specializations

* introduce unregister_operator in onnx_importer

* onnx_custom_op - lock mem first then take a buffer

* func tests - create template_extension via make_so_pointer

* fix build with NGRAPH_ONNX_IMPORT_ENABLE=OFF

* remove exports from Operation and Extension

* Move multithreaded AddExtension test to a different directory so it can be excluded when NGRAPH_ONNX_IMPORT_ENABLE=OFF

* Don't include Extension tests if ENABLE_MKL_DNN=OFF

* fix excluding onnx_reader tests

* include extension tests only if mkl is enabled

* add comment on empty blob

* use register_operator conditionally in template_extension

* fix docs after review

* create static library from onnx_custom_op

* add additional test for unregister_operator

* move model example after register step

* revert changes to unit tests

* update ngraphConfig.cmake.in header

* add headers to onnx_custom_op

* changes to docs CMakeLists

* remove redundant onnx_importer dependency

* remove extension directory from func tests

* make onnx_importer a component of ngraph package

* docs fixes

* update header of ngraph/cmake/share/ngraphConfig.cmake.in with ngraph_onnx_importer_FOUND

25 files changed:
docs/CMakeLists.txt
docs/IE_DG/Extensibility_DG/Building.md
docs/IE_DG/Extensibility_DG/Custom_ONNX_Ops.md [new file with mode: 0644]
docs/IE_DG/Extensibility_DG/Extension.md
docs/doxygen/ie_docs.xml
docs/onnx_custom_op/CMakeLists.txt [new file with mode: 0644]
docs/onnx_custom_op/onnx_custom_op.cpp [new file with mode: 0644]
docs/onnx_custom_op/onnx_custom_op.hpp [new file with mode: 0644]
docs/template_extension/CMakeLists.txt
docs/template_extension/extension.cpp
docs/template_extension/extension.hpp
inference-engine/cmake/add_ie_target.cmake
inference-engine/tests/functional/inference_engine/CMakeLists.txt
inference-engine/tests/functional/inference_engine/extension.cpp [new file with mode: 0644]
inference-engine/tests/ie_test_utils/common_test_utils/CMakeLists.txt
ngraph/CMakeLists.txt
ngraph/cmake/share/ngraphConfig.cmake.in
ngraph/core/CMakeLists.txt
ngraph/frontend/onnx_import/CMakeLists.txt
ngraph/frontend/onnx_import/include/onnx_import/core/node.hpp
ngraph/frontend/onnx_import/include/onnx_import/onnx_utils.hpp
ngraph/frontend/onnx_import/include/onnx_import/ops_bridge.hpp
ngraph/frontend/onnx_import/src/onnx_utils.cpp
ngraph/frontend/onnx_import/src/ops_bridge.cpp
ngraph/test/onnx/onnx_import.in.cpp

index 0da74ed..501564f 100644 (file)
@@ -17,6 +17,9 @@ if(NOT ENABLE_DOCKER)
         set(InferenceEngine_DIR ${CMAKE_BINARY_DIR})
     endif()
 
+    if (NGRAPH_ONNX_IMPORT_ENABLE)
+        add_subdirectory(onnx_custom_op)
+    endif()
     add_subdirectory(template_extension)
 
     set(all_docs_targets
index 8d33678..d1f62cb 100644 (file)
@@ -4,7 +4,7 @@ Inference Engine build infrastructure provides the Inference Engine Package for
 
 To build an extension library, use the following CMake script:
 
-@snippet CMakeLists.txt cmake:extension
+@snippet template_extension/CMakeLists.txt cmake:extension
 
 This CMake script finds the Inference Engine and nGraph using the `find_package` CMake command.
 
diff --git a/docs/IE_DG/Extensibility_DG/Custom_ONNX_Ops.md b/docs/IE_DG/Extensibility_DG/Custom_ONNX_Ops.md
new file mode 100644 (file)
index 0000000..2f42efc
--- /dev/null
@@ -0,0 +1,57 @@
+# Custom ONNX operators {#openvino_docs_IE_DG_Extensibility_DG_Custom_ONNX_Ops}
+
+The ONNX importer provides a mechanism to register custom ONNX operators based on predefined or user-defined nGraph operations.
+The function responsible for registering a new operator is called `ngraph::onnx_import::register_operator` and is defined in `onnx_import/onnx_utils.hpp`.
+
+## Registering custom ONNX operator based on predefined nGraph operations
+
+The steps below explain how to register a custom ONNX operator, for example, CustomRelu, in a domain called com.example.
+CustomRelu is defined as follows:
+```
+x >= 0 => f(x) = x * alpha
+x < 0  => f(x) = x * beta
+```
+where alpha, beta are float constants.
+
+1. Include headers:
+@snippet onnx_custom_op/main.cpp onnx_custom_op:headers
+
+2. Register the CustomRelu operator in the ONNX importer:
+@snippet onnx_custom_op/main.cpp onnx_custom_op:register_operator
+The `register_operator` function takes four arguments: op_type, opset version, domain, and a function object.
+The function object is a user-defined function that takes `ngraph::onnx_import::Node` as an input and based on that, returns a graph with nGraph operations.
+The `ngraph::onnx_import::Node` class represents a node in an ONNX model. It provides functions to fetch input node(s) (`get_ng_inputs`), fetch an attribute value (`get_attribute_value`) and many more (please refer to `onnx_import/core/node.hpp` for the full class declaration).
+New operator registration must happen before the ONNX model is read, for example, if an ONNX model uses the 'CustomRelu' operator, `register_operator("CustomRelu", ...)` must be called before InferenceEngine::Core::ReadNetwork.
+Re-registering ONNX operators within the same process is supported. During registration of the existing operator, a warning is printed.
+
+The example below demonstrates an example model that requires the previously created 'CustomRelu' operator:
+@snippet onnx_custom_op/main.cpp onnx_custom_op:model
+
+
+For a reference on how to create a graph with nGraph operations, visit [nGraph tutorial](../nGraphTutorial.md).
+For a complete list of predefined nGraph operators, visit [available operations sets](../../ops/opset.md).
+
+If the operator is no longer needed, it can be unregistered by calling `unregister_operator`. The function takes three arguments: `op_type`, `version`, and `domain`.
+@snippet onnx_custom_op/main.cpp onnx_custom_op:unregister_operator
+
+## Registering custom ONNX operator based on custom nGraph operations
+
+The same principles apply when registering a custom ONNX operator based on custom nGraph operations.
+This example shows how to register a custom ONNX operator based on the `Operation` class presented in [this tutorial](AddingNGraphOps.md), which is used in [TemplateExtension](Extension.md).
+@snippet extension.cpp extension:ctor
+
+Here, the `register_operator` function is called in Extension's constructor, which makes sure that it is called before InferenceEngine::Core::ReadNetwork (since InferenceEngine::Core::AddExtension must be called before a model with custom operator is read).
+
+The example below demonstrates how to unregister an operator in the Extension's destructor:
+@snippet extension.cpp extension:dtor
+Note that it is mandatory to unregister a custom ONNX operator if it is defined in a dynamic shared library.
+
+## Requirements for building with CMake
+
+A program that uses the `register_operator` functionality requires (in addition to the Inference Engine) the `ngraph` and `onnx_importer` libraries.
+The `onnx_importer` is a component of the `ngraph` package, so `find_package(ngraph REQUIRED COMPONENTS onnx_importer)` is sufficient to find both.
+The `ngraph` package exposes two variables (`${NGRAPH_LIBRARIES}` and `${ONNX_IMPORTER_LIBRARIES}`), which reference `ngraph` and `onnx_importer` libraries.
+Those variables need to be passed to the `target_link_libraries` command in the CMakeLists.txt file.
+
+See below CMakeLists.txt for reference:
+@snippet onnx_custom_op/CMakeLists.txt cmake:onnx_custom_op
index 1eb84bb..3bc96f9 100644 (file)
@@ -23,3 +23,4 @@ Implement the  InferenceEngine::IExtension::getOpSets method if the extension co
 Read the [guide about custom operations](AddingNGraphOps.md) for more information.
 
 To understand how integrate execution kernels to the extension library, read the [guide about development of custom CPU kernels](CPU_Kernel.md).
+To understand how to register a custom ONNX operator in the extension library, read the [guide about custom ONNX operators](Custom_ONNX_Ops.md).
index 2d0d8d1..120b5e1 100644 (file)
                         <tab type="user" title="GPU Kernels Extensibility" url="@ref openvino_docs_IE_DG_Extensibility_DG_GPU_Kernel"/>
                         <tab type="user" title="VPU Kernels Extensibility" url="@ref openvino_docs_IE_DG_Extensibility_DG_VPU_Kernel"/>
                         <tab type="user" title="Build Extension Library Using CMake" url="@ref openvino_docs_IE_DG_Extensibility_DG_Building"/>
+                        <tab type="user" title="Custom ONNX operators" url="@ref openvino_docs_IE_DG_Extensibility_DG_Custom_ONNX_Ops"/>
                     </tab>
                     <tab type="user" title="Integrate the Inference Engine with Your Application" url="@ref openvino_docs_IE_DG_Integrate_with_customer_application_new_API"/>
                     <tab type="user" title="[DEPRECATED] Migration from Inference Engine Plugin API to Core API" url="@ref openvino_docs_IE_DG_Migration_CoreAPI"/>
diff --git a/docs/onnx_custom_op/CMakeLists.txt b/docs/onnx_custom_op/CMakeLists.txt
new file mode 100644 (file)
index 0000000..0f8decc
--- /dev/null
@@ -0,0 +1,15 @@
+# Copyright (C) 2020 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+# [cmake:onnx_custom_op]
+set(CMAKE_CXX_STANDARD 11)
+
+set(TARGET_NAME "onnx_custom_op")
+
+find_package(ngraph REQUIRED COMPONENTS onnx_importer)
+
+add_library(${TARGET_NAME} STATIC onnx_custom_op.cpp)
+
+target_link_libraries(${TARGET_NAME} PUBLIC ${NGRAPH_LIBRARIES} ${ONNX_IMPORTER_LIBRARIES})
+# [cmake:onnx_custom_op]
diff --git a/docs/onnx_custom_op/onnx_custom_op.cpp b/docs/onnx_custom_op/onnx_custom_op.cpp
new file mode 100644 (file)
index 0000000..42ebbd5
--- /dev/null
@@ -0,0 +1,118 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+//! [onnx_custom_op:headers]
+// onnx_import/onnx_utils.hpp provides the ngraph::onnx_import::register_operator function, which registers an operator in the ONNX importer's set.
+#include <onnx_import/onnx_utils.hpp>
+// ngraph/opsets/opset5.hpp provides the declaration of predefined nGraph operator set
+#include <ngraph/opsets/opset5.hpp>
+//! [onnx_custom_op:headers]
+
+
+std::string custom_relu_model() {
+//! [onnx_custom_op:model]
+    return R"ONNX(
+ir_version: 3
+producer_name: "nGraph ONNX Importer"
+graph {
+  node {
+    input: "in"
+    output: "out"
+    name: "customrelu"
+    op_type: "CustomRelu"
+    domain: "com.example"
+    attribute {
+        name: "alpha"
+        type: FLOAT
+        f: 2
+    }
+    attribute {
+        name: "beta"
+        type: FLOAT
+        f: 3
+    }
+  }
+  name: "custom relu graph"
+  input {
+    name: "in"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "out"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  domain: "com.example"
+  version: 1
+}
+)ONNX";
+//! [onnx_custom_op:model]
+}
+
+
+void register_custom_relu_operator() {
+    // CustomRelu is defined as follows:
+    // x >= 0 => f(x) = x * alpha
+    // x < 0  => f(x) = x * beta
+
+//! [onnx_custom_op:register_operator]
+    ngraph::onnx_import::register_operator(
+        "CustomRelu", 1, "com.example", [](const ngraph::onnx_import::Node& onnx_node) -> ngraph::OutputVector {
+            namespace opset = ngraph::opset5;
+
+            ngraph::OutputVector ng_inputs{onnx_node.get_ng_inputs()};
+            const ngraph::Output<ngraph::Node>& data = ng_inputs.at(0);
+            // create constant node with a single element that's equal to zero
+            std::shared_ptr<ngraph::Node> zero_node = opset::Constant::create(data.get_element_type(), ngraph::Shape{}, {0});
+            // create a negative map for 'data' node, 1 for negative values , 0 for positive values or zero
+            // then convert it from boolean type to `data.get_element_type()`
+            std::shared_ptr<ngraph::Node> negative_map = std::make_shared<opset::Convert>(
+                std::make_shared<opset::Less>(data, zero_node), data.get_element_type());
+            // create a positive map for 'data' node, 0 for negative values , 1 for positive values or zero
+            // then convert it from boolean type to `data.get_element_type()`
+            std::shared_ptr<ngraph::Node> positive_map = std::make_shared<opset::Convert>(
+                std::make_shared<opset::GreaterEqual>(data, zero_node), data.get_element_type());
+
+            // fetch alpha and beta attributes from ONNX node
+            float alpha = onnx_node.get_attribute_value<float>("alpha", 1); // if 'alpha' attribute is not provided in the model, then the default value is 1
+            float beta = onnx_node.get_attribute_value<float>("beta");
+            // create constant node with a single element 'alpha' with type f32
+            std::shared_ptr<ngraph::Node> alpha_node = opset::Constant::create(ngraph::element::f32, ngraph::Shape{}, {alpha});
+            // create constant node with a single element 'beta' with type f32
+            std::shared_ptr<ngraph::Node> beta_node = opset::Constant::create(ngraph::element::f32, ngraph::Shape{}, {beta});
+
+            return {
+                std::make_shared<opset::Add>(
+                    std::make_shared<opset::Multiply>(alpha_node, std::make_shared<opset::Multiply>(data, positive_map)),
+                    std::make_shared<opset::Multiply>(beta_node, std::make_shared<opset::Multiply>(data, negative_map))
+                )
+            };
+    });
+//! [onnx_custom_op:register_operator]
+}
+
+void unregister_custom_relu_operator() {
+//! [onnx_custom_op:unregister_operator]
+    ngraph::onnx_import::unregister_operator("CustomRelu", 1, "com.example");
+//! [onnx_custom_op:unregister_operator]
+}
diff --git a/docs/onnx_custom_op/onnx_custom_op.hpp b/docs/onnx_custom_op/onnx_custom_op.hpp
new file mode 100644 (file)
index 0000000..a5189c6
--- /dev/null
@@ -0,0 +1,11 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <string>
+
+std::string custom_relu_model();
+void register_custom_relu_operator();
+void unregister_custom_relu_operator();
index 427ab21..4133f03 100644 (file)
@@ -3,18 +3,22 @@
 #
 
 # [cmake:extension]
+set(CMAKE_CXX_STANDARD 11)
+
 set(TARGET_NAME "template_extension")
 
-find_package(ngraph REQUIRED)
+find_package(ngraph REQUIRED OPTIONAL_COMPONENTS onnx_importer)
 find_package(InferenceEngine REQUIRED)
 
 file(GLOB_RECURSE SRC *.cpp)
 
 add_library(${TARGET_NAME} SHARED ${SRC})
 
-target_include_directories(${TARGET_NAME} PRIVATE ${InferenceEngine_INCLUDE_DIRS})
-
 target_compile_definitions(${TARGET_NAME} PRIVATE IMPLEMENT_INFERENCE_EXTENSION_API)
-target_link_libraries(${TARGET_NAME} PRIVATE IE::inference_engine_legacy IE::inference_engine
-                                             ${NGRAPH_LIBRARIES})
+set(LIBRARIES IE::inference_engine_legacy IE::inference_engine ${NGRAPH_LIBRARIES})
+if (ngraph_onnx_importer_FOUND)
+    list(APPEND LIBRARIES ${ONNX_IMPORTER_LIBRARIES})
+    target_compile_definitions(${TARGET_NAME} PRIVATE NGRAPH_ONNX_IMPORT_ENABLED)
+endif()
+target_link_libraries(${TARGET_NAME} PRIVATE ${LIBRARIES})
 # [cmake:extension]
index 481b7c0..3a4d373 100644 (file)
@@ -5,6 +5,9 @@
 #include "cpu_kernel.hpp"
 #include "op.hpp"
 #include <ngraph/ngraph.hpp>
+#ifdef NGRAPH_ONNX_IMPORT_ENABLED
+#include <onnx_import/onnx_utils.hpp>
+#endif
 
 #include <map>
 #include <memory>
 
 using namespace TemplateExtension;
 
+
+//! [extension:ctor]
+Extension::Extension() {
+#ifdef NGRAPH_ONNX_IMPORT_ENABLED
+    ngraph::onnx_import::register_operator(
+        Operation::type_info.name, 1, "custom_domain", [](const ngraph::onnx_import::Node& node) -> ngraph::OutputVector {
+            ngraph::OutputVector ng_inputs{node.get_ng_inputs()};
+            int64_t add = node.get_attribute_value<int64_t>("add");
+            return {std::make_shared<Operation>(ng_inputs.at(0), add)};
+    });
+#endif
+}
+//! [extension:ctor]
+
+//! [extension:dtor]
+Extension::~Extension() {
+#ifdef NGRAPH_ONNX_IMPORT_ENABLED
+    ngraph::onnx_import::unregister_operator(Operation::type_info.name, 1, "custom_domain");
+#endif
+}
+//! [extension:dtor]
+
 //! [extension:GetVersion]
 void Extension::GetVersion(const InferenceEngine::Version *&versionInfo) const noexcept {
     static InferenceEngine::Version ExtensionDescription = {
index e74c6c4..fa7463b 100644 (file)
@@ -17,7 +17,8 @@ namespace TemplateExtension {
 
 class Extension : public InferenceEngine::IExtension {
 public:
-    Extension() = default;
+    Extension();
+    ~Extension();
     void GetVersion(const InferenceEngine::Version*& versionInfo) const noexcept override;
     void Unload() noexcept override {}
     void Release() noexcept override { delete this; }
index 0dfc1ce..30fc8bf 100644 (file)
@@ -13,7 +13,7 @@ addIeTarget(
    ROOT ${CMAKE_CURRENT_SOURCE_DIR}
    ADDITIONAL_SOURCE_DIRS
         /some/additional/sources
-   EXCLUDED_SOURCE_DIRS
+   EXCLUDED_SOURCE_PATHS
         ${CMAKE_CURRENT_SOURCE_DIR}/unnecessary_sources/
    INCLUDES
         ${SDL_INCLUDES}
@@ -47,7 +47,7 @@ function(addIeTarget)
         DEFINES                       # extra preprocessor definitions
         ADDITIONAL_SOURCE_DIRS        # list of directories which will be used to recursive search of source files in addition to ROOT
         OBJECT_FILES                  # list of object files to be additionally built into the target
-        EXCLUDED_SOURCE_DIRS          # list of directories excluded from the global recursive search of source files
+        EXCLUDED_SOURCE_PATHS         # list of paths excluded from the global recursive search of source files
         LINK_LIBRARIES_WHOLE_ARCHIVE  # list of static libraries to link, each object file should be used and not discarded
         LINK_FLAGS                    # list of extra commands to linker
         EXPORT_DEPENDENCIES           # list of the dependencies to be exported with the target through the developer package
@@ -76,10 +76,10 @@ function(addIeTarget)
     file(GLOB_RECURSE sources  ${sourceSearch})
 
     # remove unnecessary directories
-    if (ARG_EXCLUDED_SOURCE_DIRS)
-        list(FILTER includes EXCLUDE REGEX "${ARG_EXCLUDED_SOURCE_DIRS}/*")
-        list(FILTER sources EXCLUDE REGEX "${ARG_EXCLUDED_SOURCE_DIRS}/*")
-    endif()
+    foreach(excludedDir ${ARG_EXCLUDED_SOURCE_PATHS})
+        list(FILTER includes EXCLUDE REGEX "${excludedDir}*")
+        list(FILTER sources EXCLUDE REGEX "${excludedDir}*")
+    endforeach()
 
     source_group("include" FILES ${includes})
     source_group("src"     FILES ${sources})
index 3005af4..2322c29 100644 (file)
@@ -5,25 +5,40 @@
 
 set(TARGET_NAME ieFuncTests)
 
+set(INCLUDES ${IE_MAIN_SOURCE_DIR}/src/inference_engine)
+set(LINK_LIBRARIES
+    gmock
+    funcTestUtils
+    ngraphFunctions
+    inference_engine_transformations
+)
+set(DEPENDENCIES
+    mock_engine
+    inference_engine_ir_reader
+    inference_engine_ir_v7_reader
+    template_extension
+)
+
+if (NGRAPH_ONNX_IMPORT_ENABLE)
+    list(APPEND INCLUDES "${OpenVINO_MAIN_SOURCE_DIR}/docs/onnx_custom_op")
+    list(APPEND LINK_LIBRARIES onnx_custom_op)
+    list(APPEND DEPENDENCIES onnx_custom_op)
+else()
+    set(EXCLUDED_SOURCE_PATHS "${CMAKE_CURRENT_SOURCE_DIR}/onnx_reader")
+endif()
+
+if (NOT NGRAPH_ONNX_IMPORT_ENABLE OR NOT ENABLE_MKL_DNN)
+    set(EXCLUDED_SOURCE_PATHS ${EXCLUDED_SOURCE_PATHS} "${CMAKE_CURRENT_SOURCE_DIR}/extension.cpp")
+endif()
+
 addIeTargetTest(
         NAME ${TARGET_NAME}
         ROOT ${CMAKE_CURRENT_SOURCE_DIR}
-        INCLUDES
-            # TODO: remove after removing `cnn_network_ngraph_imp.hpp`
-            ${IE_MAIN_SOURCE_DIR}/src/inference_engine
-        EXCLUDED_SOURCE_DIRS
-            ${CMAKE_CURRENT_SOURCE_DIR}/extension_lib
-        LINK_LIBRARIES
-            gmock
-            funcTestUtils
-            ngraphFunctions
-            inference_engine_transformations
+        INCLUDES ${INCLUDES}
+        EXCLUDED_SOURCE_PATHS ${EXCLUDED_SOURCE_PATHS}
+        LINK_LIBRARIES ${LINK_LIBRARIES}
         ADD_CPPLINT
-        DEPENDENCIES
-            template_extension
-            mock_engine
-            inference_engine_ir_reader
-            inference_engine_ir_v7_reader
+        DEPENDENCIES ${DEPENDENCIES}
         LABELS
             IE
 )
diff --git a/inference-engine/tests/functional/inference_engine/extension.cpp b/inference-engine/tests/functional/inference_engine/extension.cpp
new file mode 100644 (file)
index 0000000..551fdf9
--- /dev/null
@@ -0,0 +1,422 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+#include <ie_core.hpp>
+#include <ngraph/ngraph.hpp>
+#include <onnx_import/onnx_utils.hpp>
+#include <file_utils.h>
+#include <common_test_utils/test_assertions.hpp>
+#include <functional_test_utils/test_model/test_model.hpp>
+#include <onnx_custom_op.hpp>
+
+
+class CustomAbsKernel : public InferenceEngine::ILayerExecImpl {
+    public:
+        explicit CustomAbsKernel(const std::shared_ptr<ngraph::Node>& node): node(node) {}
+
+        InferenceEngine::StatusCode
+        init(InferenceEngine::LayerConfig& /*config*/, InferenceEngine::ResponseDesc* /*resp*/) noexcept override {
+            return InferenceEngine::StatusCode::OK;
+        }
+
+        InferenceEngine::StatusCode getSupportedConfigurations(std::vector<InferenceEngine::LayerConfig>& conf,
+                                                               InferenceEngine::ResponseDesc* /*resp*/) noexcept override {
+            InferenceEngine::LayerConfig layerConfig;
+            layerConfig.dynBatchSupport = true;
+
+            if (node->outputs().size() != 1 && node->inputs().size() != 1)
+                return InferenceEngine::GENERAL_ERROR;
+
+            InferenceEngine::DataConfig cfg;
+            cfg.constant = false;
+            cfg.inPlace = 0;
+
+            InferenceEngine::SizeVector order;
+            auto partialShape = node->get_output_partial_shape(0);
+            if (partialShape.is_dynamic())
+                return InferenceEngine::GENERAL_ERROR;
+
+            auto shape = node->get_output_shape(0);
+            for (size_t i = 0; i < shape.size(); i++) {
+                order.push_back(i);
+            }
+            cfg.desc = InferenceEngine::TensorDesc(InferenceEngine::Precision::FP32,
+                                                   shape, {shape, order});
+            layerConfig.outConfs.push_back(cfg);
+            layerConfig.inConfs.push_back(cfg);
+            conf.push_back(layerConfig);
+            return InferenceEngine::OK;
+        }
+
+        InferenceEngine::StatusCode
+        execute(std::vector<InferenceEngine::Blob::Ptr>& inputs, std::vector<InferenceEngine::Blob::Ptr>& outputs,
+                InferenceEngine::ResponseDesc* /*resp*/) noexcept override {
+            for (size_t i = 0; i < inputs.size(); i++) {
+                InferenceEngine::MemoryBlob::CPtr minput = InferenceEngine::as<InferenceEngine::MemoryBlob>(inputs[i]);
+                InferenceEngine::MemoryBlob::Ptr moutput = InferenceEngine::as<InferenceEngine::MemoryBlob>(outputs[i]);
+                if (!moutput || !minput) {
+                    return InferenceEngine::StatusCode::PARAMETER_MISMATCH;
+                }
+                // locked memory holder should be alive all time while access to its buffer happens
+                auto minputHolder = minput->rmap();
+                auto moutputHolder = moutput->wmap();
+
+                auto inputData = minputHolder.as<const float *>();
+                auto outputData = moutputHolder.as<float  *>();
+                for (size_t j = 0; j < minput->size(); j++) {
+                    outputData[j] = inputData[j] < 0 ? (-inputData[j] * 2) : inputData[j];
+                }
+            }
+            return InferenceEngine::StatusCode::OK;
+        }
+
+
+
+    private:
+        const std::shared_ptr<ngraph::Node> node;
+};
+
+class CustomAbs : public ngraph::op::Op {
+public:
+    static constexpr ngraph::NodeTypeInfo type_info{"CustomAbs", 100500};
+    const ngraph::NodeTypeInfo& get_type_info() const override { return type_info;  }
+    CustomAbs() = default;
+    CustomAbs(const ngraph::Output<ngraph::Node>& arg): ngraph::op::Op({arg}) {
+        constructor_validate_and_infer_types();
+    }
+    void validate_and_infer_types() override {
+        set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
+    }
+    std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
+        return std::make_shared<CustomAbs>(new_args.at(0));
+    }
+    bool visit_attributes(ngraph::AttributeVisitor& visitor) override {
+        return true;
+    }
+};
+
+constexpr ngraph::NodeTypeInfo CustomAbs::type_info;
+
+class CustomAbsExtension : public InferenceEngine::IExtension {
+    public:
+        CustomAbsExtension() {
+        }
+
+        void GetVersion(const InferenceEngine::Version*& versionInfo) const noexcept override {}
+
+        void Release() noexcept override { delete this; }
+
+        void Unload() noexcept override {}
+
+        std::map<std::string, ngraph::OpSet> getOpSets() override {
+            std::map<std::string, ngraph::OpSet> opsets;
+            ngraph::OpSet opset;
+            opset.insert<CustomAbs>();
+            opsets["custom_opset"] = opset;
+            return opsets;
+        }
+
+        std::vector<std::string> getImplTypes(const std::shared_ptr<ngraph::Node>& node) override {
+            if (node->description() != CustomAbs::type_info.name)
+                return {};
+            return {"CPU"};
+        }
+
+        InferenceEngine::ILayerImpl::Ptr getImplementation(const std::shared_ptr<ngraph::Node>& node, const std::string& implType) override {
+            return std::make_shared<CustomAbsKernel>(node);
+        }
+};
+
+void infer_model(InferenceEngine::Core& ie, const std::string& model, const std::vector<float>& input_values, const std::vector<float>& expected) {
+    InferenceEngine::Blob::CPtr weights;
+    auto network = ie.ReadNetwork(model, weights);
+    auto function = network.getFunction();
+
+    auto network_inputs = network.getInputsInfo();
+    auto network_outputs = network.getOutputsInfo();
+    auto exe_network = ie.LoadNetwork(network, "CPU");
+    auto inference_req = exe_network.CreateInferRequest();
+    const auto& input = network_inputs.begin();
+    const auto& input_info = input->second;
+
+    auto blob = std::make_shared<InferenceEngine::TBlob<float>>(input_info->getTensorDesc());
+    blob->allocate();
+    ASSERT_EQ(input_values.size(), blob->size());
+    float* blob_buffer = blob->wmap().template as<float*>();
+    std::copy(input_values.begin(), input_values.end(), blob_buffer);
+    inference_req.SetBlob(input->first, blob);
+
+    inference_req.Infer();
+
+    auto output = network_outputs.begin();
+    InferenceEngine::MemoryBlob::CPtr computed = InferenceEngine::as<InferenceEngine::MemoryBlob>(inference_req.GetBlob(output->first));
+    const auto computed_data = computed->rmap();
+    const auto* computed_data_buffer = computed_data.template as<const float*>();
+    std::vector<float> computed_values(computed_data_buffer,
+                                   computed_data_buffer + computed->size());
+    ASSERT_EQ(expected, computed_values);
+}
+
+
+TEST(Extension, OnnxModelWithCustomAbs) {
+    std::string model = R"V0G0N(
+ir_version: 3
+producer_name: "nGraph ONNX Importer"
+graph {
+  node {
+    input: "A"
+    output: "Y"
+    name: "customrelu"
+    op_type: "CustomAbs"
+    domain: "custom_domain"
+  }
+  name: "test_graph"
+  input {
+    name: "A"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 10
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "Y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 10
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 1
+  domain: "custom_domain"
+}
+)V0G0N";
+
+    std::vector<float> input_values{1, -2, 3, -4, 5, -6, 7, -8, 9, -10};
+    std::vector<float> expected{1, 4, 3, 8, 5, 12, 7, 16, 9, 20};
+    InferenceEngine::Core ie;
+    ie.AddExtension(std::make_shared<CustomAbsExtension>());
+    ngraph::onnx_import::register_operator(
+        CustomAbs::type_info.name, 1, "custom_domain", [](const ngraph::onnx_import::Node& node) -> ngraph::OutputVector {
+            ngraph::OutputVector ng_inputs{node.get_ng_inputs()};
+            return {std::make_shared<CustomAbs>(ng_inputs.at(0))};
+    });
+
+    infer_model(ie, model, input_values, expected);
+    ngraph::onnx_import::unregister_operator(CustomAbs::type_info.name, 1, "custom_domain");
+}
+
+
+TEST(Extension, XmlModelWithCustomAbs) {
+    std::string model = R"V0G0N(
+<net name="Network" version="10">
+    <layers>
+        <layer name="in1" type="Parameter" id="0" version="opset1">
+            <data element_type="f32" shape="10"/>
+            <output>
+                <port id="0" precision="FP32">
+                    <dim>10</dim>
+                </port>
+            </output>
+        </layer>
+        <layer name="activation" id="1" type="CustomAbs" version="custom_opset">
+            <input>
+                <port id="1" precision="FP32">
+                    <dim>10</dim>
+                </port>
+            </input>
+            <output>
+                <port id="2" precision="FP32">
+                    <dim>10</dim>
+                </port>
+            </output>
+        </layer>
+        <layer name="output" type="Result" id="2" version="opset1">
+            <input>
+                <port id="0" precision="FP32">
+                    <dim>10</dim>
+                </port>
+            </input>
+        </layer>
+    </layers>
+    <edges>
+        <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
+        <edge from-layer="1" from-port="2" to-layer="2" to-port="0"/>
+    </edges>
+</net>
+)V0G0N";
+
+    std::vector<float> input_values{1, -2, 3, -4, 5, -6, 7, -8, 9, -10};
+    std::vector<float> expected{1, 4, 3, 8, 5, 12, 7, 16, 9, 20};
+    InferenceEngine::Core ie;
+    ie.AddExtension(std::make_shared<CustomAbsExtension>());
+    infer_model(ie, model, input_values, expected);
+}
+
+
+static std::string get_extension_path() {
+    return FileUtils::makeSharedLibraryName<char>({},
+            std::string("template_extension") + IE_BUILD_POSTFIX);
+}
+
+
+TEST(Extension, XmlModelWithExtensionFromDSO) {
+    std::string model = R"V0G0N(
+<net name="Network" version="10">
+    <layers>
+        <layer name="in1" type="Parameter" id="0" version="opset1">
+            <data element_type="f32" shape="2,2,2,1"/>
+            <output>
+                <port id="0" precision="FP32">
+                    <dim>2</dim>
+                    <dim>2</dim>
+                    <dim>2</dim>
+                    <dim>1</dim>
+                </port>
+            </output>
+        </layer>
+        <layer name="operation" id="1" type="Template" version="custom_opset">
+            <data  add="11"/>
+            <input>
+                <port id="1" precision="FP32">
+                    <dim>2</dim>
+                    <dim>2</dim>
+                    <dim>2</dim>
+                    <dim>1</dim>
+                </port>
+            </input>
+            <output>
+                <port id="2" precision="FP32">
+                    <dim>2</dim>
+                    <dim>2</dim>
+                    <dim>2</dim>
+                    <dim>1</dim>
+                </port>
+            </output>
+        </layer>
+        <layer name="output" type="Result" id="2" version="opset1">
+            <input>
+                <port id="0" precision="FP32">
+                    <dim>2</dim>
+                    <dim>2</dim>
+                    <dim>2</dim>
+                    <dim>1</dim>
+                </port>
+            </input>
+        </layer>
+    </layers>
+    <edges>
+        <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
+        <edge from-layer="1" from-port="2" to-layer="2" to-port="0"/>
+    </edges>
+</net>
+)V0G0N";
+
+    std::vector<float> input_values{1, 2, 3, 4, 5, 6, 7, 8};
+    std::vector<float> expected{12, 13, 14, 15, 16, 17, 18, 19};
+    InferenceEngine::Core ie;
+    ie.AddExtension(InferenceEngine::make_so_pointer<InferenceEngine::IExtension>(get_extension_path()));
+    infer_model(ie, model, input_values, expected);
+}
+
+
+TEST(Extension, OnnxModelWithExtensionFromDSO) {
+    std::string model = R"V0G0N(
+ir_version: 3
+producer_name: "nGraph ONNX Importer"
+graph {
+  node {
+    input: "A"
+    output: "Y"
+    name: "operation"
+    op_type: "Template"
+    domain: "custom_domain"
+    attribute {
+        name: "add"
+        type: INT
+        i: 11
+    }
+  }
+  name: "test_graph"
+  input {
+    name: "A"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "Y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 1
+  domain: "com.example"
+}
+)V0G0N";
+
+    std::vector<float> input_values{1, 2, 3, 4, 5, 6, 7, 8};
+    std::vector<float> expected{12, 13, 14, 15, 16, 17, 18, 19};
+    InferenceEngine::Core ie;
+    ie.AddExtension(InferenceEngine::make_so_pointer<InferenceEngine::IExtension>(get_extension_path()));
+    infer_model(ie, model, input_values, expected);
+}
+
+
+TEST(Extension, OnnxModelWithCustomReluDocsExample) {
+    std::vector<float> input_values{0, -1, 2, -3, 4, -5, 6, -7};
+    std::vector<float> expected{0, -3, 4, -9, 8, -15, 12, -21};
+
+    register_custom_relu_operator();
+    InferenceEngine::Core ie;
+    infer_model(ie, custom_relu_model(), input_values, expected);
+    unregister_custom_relu_operator();
+}
index 76895ef..ed514df 100644 (file)
@@ -49,7 +49,7 @@ function(add_common_utils ADD_TARGET_NAME)
             NAME ${ADD_TARGET_NAME}
             TYPE STATIC
             ROOT ${CMAKE_CURRENT_SOURCE_DIR}
-            EXCLUDED_SOURCE_DIRS
+            EXCLUDED_SOURCE_PATHS
                 ${CMAKE_CURRENT_SOURCE_DIR}/gtest
             ADD_CPPLINT
             DEVELOPER_PACKAGE
index 4eb969f..886ec53 100644 (file)
@@ -476,22 +476,10 @@ endif()
 
 add_subdirectory(core)
 
-add_subdirectory(frontend)
-
-if (NGRAPH_TEST_UTIL_ENABLE)
-    include(cmake/external_gtest.cmake)
-endif()
-
-add_subdirectory(test)
-
-if (NGRAPH_PYTHON_BUILD_ENABLE)
-    add_subdirectory(python)
-endif()
-
 if (NGRAPH_EXPORT_TARGETS_ENABLE)
     include(CMakePackageConfigHelpers)
-
-    export(TARGETS ngraph NAMESPACE ngraph:: FILE "${CMAKE_CURRENT_BINARY_DIR}/ngraphTargets.cmake")
+    set(NGRAPH_TARGETS_FILE "${CMAKE_CURRENT_BINARY_DIR}/ngraphTargets.cmake")
+    export(TARGETS ngraph NAMESPACE ngraph:: FILE "${NGRAPH_TARGETS_FILE}")
 
     install(EXPORT ngraphTargets
         FILE ngraphTargets.cmake
@@ -513,6 +501,18 @@ if (NGRAPH_EXPORT_TARGETS_ENABLE)
         COMPONENT ngraph)
 endif()
 
+add_subdirectory(frontend)
+
+if (NGRAPH_TEST_UTIL_ENABLE)
+    include(cmake/external_gtest.cmake)
+endif()
+
+add_subdirectory(test)
+
+if (NGRAPH_PYTHON_BUILD_ENABLE)
+    add_subdirectory(python)
+endif()
+
 install(DIRECTORY
     ${CMAKE_CURRENT_SOURCE_DIR}/licenses
     DESTINATION "${NGRAPH_COMPONENT_PREFIX}."
index dbe30ba..3c25413 100644 (file)
 #
 # This will define the following variables:
 #
-#   ngraph_FOUND        - True if the system has the nGraph library
-#   NGRAPH_LIBRARIES    - nGraph libraries
-#   ngraph::ngraph      - nGraph core target
+#   ngraph_FOUND               - True if the system has the nGraph library
+#   NGRAPH_LIBRARIES           - nGraph libraries
+#   ngraph::ngraph             - nGraph core target
+#   ngraph_onnx_importer_FOUND - True if the system has onnx_importer library
+#   ONNX_IMPORTER_LIBRARIES    - ONNX importer libraries
+#   ngraph::onnx_importer      - ONNX importer target
 #
 #
 
@@ -33,4 +36,8 @@ if(NOT TARGET ngraph)
 endif()
 
 set(NGRAPH_LIBRARIES ngraph::ngraph)
+set(ngraph_onnx_importer_FOUND @NGRAPH_ONNX_IMPORT_ENABLE@)
+if(ngraph_onnx_importer_FOUND)
+    set(ONNX_IMPORTER_LIBRARIES ngraph::onnx_importer)
+endif()
 check_required_components(ngraph)
index feb8184..c4ae969 100644 (file)
@@ -87,7 +87,7 @@ elseif(APPLE)
 endif()
 
 # Defines macro in C++ to load backend plugin
-target_include_directories(ngraph PUBLIC $<BUILD_INTERFACE:${NGRAPH_INCLUDE_PATH}> $<INSTALL_INTERFACE:include>)
+target_include_directories(ngraph PUBLIC $<BUILD_INTERFACE:${NGRAPH_INCLUDE_PATH}> $<INSTALL_INTERFACE:${NGRAPH_INSTALL_INCLUDE}>)
 target_include_directories(ngraph PRIVATE ${NGRAPH_INCLUDE_DIR}
                                           ${NGRAPH_INCLUDE_DIR}/op
                                           ${NGRAPH_INCLUDE_DIR}/op/util
index ac2e546..6377c77 100644 (file)
@@ -41,6 +41,7 @@ source_group("include" FILES ${PUBLIC_HEADERS})
 
 # Create shared library
 add_library(onnx_importer SHARED ${LIBRARY_SRC} ${PUBLIC_HEADERS})
+add_library(ngraph::onnx_importer ALIAS onnx_importer)
 
 if(COMMAND ie_faster_build)
     ie_faster_build(onnx_importer
@@ -56,8 +57,9 @@ set_target_properties(onnx_importer PROPERTIES
                       C_VISIBILITY_PRESET hidden
                       VISIBILITY_INLINES_HIDDEN ON
                       POSITION_INDEPENDENT_CODE ON)
+set(ONNX_INSTALL_INCLUDE "${NGRAPH_INSTALL_INCLUDE}/ngraph/frontend")
 target_include_directories(onnx_importer SYSTEM PUBLIC $<BUILD_INTERFACE:${ONNX_IMPORT_INCLUDE_DIR}>
-                                                       $<INSTALL_INTERFACE:include/ngraph/frontend/>)
+                                                       $<INSTALL_INTERFACE:${ONNX_INSTALL_INCLUDE}>)
 target_include_directories(onnx_importer SYSTEM PRIVATE ${NGRAPH_INCLUDE_PATH} ${NGRAPH_INCLUDE_PATH}/ngraph
         ${ONNX_INCLUDE_DIR} ${ONNX_PROTO_INCLUDE_DIR} ${Protobuf_INCLUDE_DIRS})
 target_include_directories(onnx_importer PRIVATE ${ONNX_IMPORT_INCLUDE_DIR}/onnx_import/core
@@ -83,10 +85,15 @@ install(TARGETS onnx_importer EXPORT ngraphTargets
         ARCHIVE DESTINATION ${NGRAPH_INSTALL_LIB} COMPONENT ngraph
         LIBRARY DESTINATION ${NGRAPH_INSTALL_LIB} COMPONENT ngraph)
 
-    install(DIRECTORY ${ONNX_IMPORT_INCLUDE_DIR}/onnx_import
-    DESTINATION ${NGRAPH_INSTALL_INCLUDE}/ngraph/frontend/
+install(DIRECTORY ${ONNX_IMPORT_INCLUDE_DIR}/onnx_import
+    DESTINATION ${ONNX_INSTALL_INCLUDE}
     COMPONENT ngraph
     FILES_MATCHING
         PATTERN "*.hpp"
         PATTERN "*.h"
 )
+
+
+if (NGRAPH_EXPORT_TARGETS_ENABLE)
+    export(TARGETS onnx_importer NAMESPACE ngraph:: APPEND FILE "${NGRAPH_TARGETS_FILE}")
+endif()
index 81ca491..c2f140e 100644 (file)
@@ -51,6 +51,8 @@ namespace ngraph
 
         // forward declaration
         class Graph;
+        class Subgraph;
+        class Tensor;
 
         class ONNX_IMPORTER_API Node
         {
@@ -96,6 +98,114 @@ namespace ngraph
             std::unique_ptr<Impl, void (*)(Impl*)> m_pimpl;
         };
 
+        template <>
+        ONNX_IMPORTER_API float Node::get_attribute_value(const std::string& name,
+                                                          float default_value) const;
+
+        template <>
+        ONNX_IMPORTER_API double Node::get_attribute_value(const std::string& name,
+                                                           double default_value) const;
+
+        template <>
+        ONNX_IMPORTER_API std::int64_t Node::get_attribute_value(const std::string& name,
+                                                                 std::int64_t default_value) const;
+
+        template <>
+        ONNX_IMPORTER_API std::string Node::get_attribute_value(const std::string& name,
+                                                                std::string default_value) const;
+
+        template <>
+        ONNX_IMPORTER_API Tensor Node::get_attribute_value(const std::string& name,
+                                                           Tensor default_value) const;
+
+        template <>
+        ONNX_IMPORTER_API Graph Node::get_attribute_value(const std::string& name,
+                                                          Graph default_value) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<float>
+            Node::get_attribute_value(const std::string& name,
+                                      std::vector<float> default_value) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<double>
+            Node::get_attribute_value(const std::string& name,
+                                      std::vector<double> default_value) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<std::int64_t>
+            Node::get_attribute_value(const std::string& name,
+                                      std::vector<std::int64_t> default_value) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<std::size_t>
+            Node::get_attribute_value(const std::string& name,
+                                      std::vector<std::size_t> default_value) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<std::string>
+            Node::get_attribute_value(const std::string& name,
+                                      std::vector<std::string> default_value) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<Tensor>
+            Node::get_attribute_value(const std::string& name,
+                                      std::vector<Tensor> default_value) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<Graph>
+            Node::get_attribute_value(const std::string& name,
+                                      std::vector<Graph> default_value) const;
+
+        template <>
+        ONNX_IMPORTER_API float Node::get_attribute_value(const std::string& name) const;
+
+        template <>
+        ONNX_IMPORTER_API double Node::get_attribute_value(const std::string& name) const;
+
+        template <>
+        ONNX_IMPORTER_API std::int64_t Node::get_attribute_value(const std::string& name) const;
+
+        template <>
+        ONNX_IMPORTER_API std::size_t Node::get_attribute_value(const std::string& name) const;
+
+        template <>
+        ONNX_IMPORTER_API std::string Node::get_attribute_value(const std::string& name) const;
+
+        template <>
+        ONNX_IMPORTER_API Tensor Node::get_attribute_value(const std::string& name) const;
+
+        template <>
+        ONNX_IMPORTER_API Subgraph Node::get_attribute_value(const std::string& name) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<float>
+            Node::get_attribute_value(const std::string& name) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<double>
+            Node::get_attribute_value(const std::string& name) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<std::int64_t>
+            Node::get_attribute_value(const std::string& name) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<std::size_t>
+            Node::get_attribute_value(const std::string& name) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<std::string>
+            Node::get_attribute_value(const std::string& name) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<Tensor>
+            Node::get_attribute_value(const std::string& name) const;
+
+        template <>
+        ONNX_IMPORTER_API std::vector<Graph>
+            Node::get_attribute_value(const std::string& name) const;
+
         inline std::ostream& operator<<(std::ostream& outs, const Node& node)
         {
             return (outs << "<Node(" << node.op_type() << "): " << node.get_description() << ">");
index 795f26c..bce6cff 100644 (file)
@@ -44,6 +44,17 @@ namespace ngraph
                                const std::string& domain,
                                Operator fn);
 
+        /// \brief      Unregisters ONNX custom operator.
+        ///             The function unregisters previously registered operator.
+        ///
+        /// \param      name      The ONNX operator name.
+        /// \param      version   The ONNX operator set version.
+        /// \param      domain    The domain the ONNX operator is registered to.
+        ONNX_IMPORTER_API
+        void unregister_operator(const std::string& name,
+                                 std::int64_t version,
+                                 const std::string& domain);
+
     } // namespace onnx_import
 
 } // namespace ngraph
index df6407f..33ac405 100644 (file)
@@ -18,6 +18,7 @@
 
 #include <cstdint>
 #include <map>
+#include <mutex>
 #include <string>
 #include <unordered_map>
 
@@ -84,6 +85,13 @@ namespace ngraph
                 instance()._register_operator(name, version, domain, std::move(fn));
             }
 
+            /// \brief      Removes an operator previously registered for
+            ///             (name, version, domain); forwards to the
+            ///             singleton's _unregister_operator.
+            ///
+            /// \param      name     The ONNX operator name.
+            /// \param      version  The ONNX operator set version.
+            /// \param      domain   The domain the operator was registered to.
+            static void unregister_operator(const std::string& name,
+                                            std::int64_t version,
+                                            const std::string& domain)
+            {
+                instance()._unregister_operator(name, version, domain);
+            }
+
             static bool is_operator_registered(const std::string& name,
                                                std::int64_t version,
                                                const std::string& domain)
@@ -122,11 +130,16 @@ namespace ngraph
                                     std::int64_t version,
                                     const std::string& domain,
                                     Operator fn);
+            void _unregister_operator(const std::string& name,
+                                      std::int64_t version,
+                                      const std::string& domain);
             OperatorSet _get_operator_set(const std::string& domain, std::int64_t version);
 
             bool _is_operator_registered(const std::string& name,
                                          std::int64_t version,
                                          const std::string& domain);
+
+            std::mutex lock;
         };
 
         const std::string OPENVINO_ONNX_DOMAIN = "org.openvinotoolkit";
index 06258ef..059a7b2 100644 (file)
@@ -29,6 +29,13 @@ namespace ngraph
             OperatorsBridge::register_operator(name, version, domain, std::move(fn));
         }
 
+        // Public onnx_import API entry point: delegates to the
+        // OperatorsBridge singleton (the Doxygen contract lives in the
+        // corresponding public header).
+        void unregister_operator(const std::string& name,
+                                 std::int64_t version,
+                                 const std::string& domain)
+        {
+            OperatorsBridge::unregister_operator(name, version, domain);
+        }
+
     } // namespace onnx_import
 
 } // namespace ngraph
index fb5216a..5d53a8d 100644 (file)
@@ -180,6 +180,8 @@ namespace ngraph
                                                  const std::string& domain,
                                                  Operator fn)
         {
+            std::lock_guard<std::mutex> guard(lock);
+
             auto it = m_map[domain][name].find(version);
             if (it == std::end(m_map[domain][name]))
             {
@@ -194,9 +196,49 @@ namespace ngraph
             }
         }
 
+        // Removes the operator registered for (domain, name, version).
+        // If any level of the lookup is missing, logs via NGRAPH_ERR and
+        // returns without modifying the registry. After erasing the version
+        // entry, empty name/domain maps are pruned so subsequent lookups
+        // (e.g. is_operator_registered) see a consistent registry.
+        // Thread-safe: guarded by the same mutex as _register_operator.
+        void OperatorsBridge::_unregister_operator(const std::string& name,
+                                                   std::int64_t version,
+                                                   const std::string& domain)
+        {
+            std::lock_guard<std::mutex> guard(lock);
+
+            auto domain_it = m_map.find(domain);
+            if (domain_it == m_map.end())
+            {
+                NGRAPH_ERR << "unregister_operator: domain '" + domain +
+                                  "' was not registered before";
+                return;
+            }
+            auto name_it = domain_it->second.find(name);
+            if (name_it == domain_it->second.end())
+            {
+                NGRAPH_ERR << "unregister_operator: operator '" + name +
+                                  "' was not registered before";
+                return;
+            }
+            auto version_it = name_it->second.find(version);
+            if (version_it == name_it->second.end())
+            {
+                NGRAPH_ERR << "unregister_operator: operator '" + name + "' with version " +
+                                  std::to_string(version) + " was not registered before";
+                return;
+            }
+            // NOTE(review): name_it->second.erase(version_it) would avoid the
+            // redundant m_map[domain][name] re-lookups below, since the
+            // iterators are already held.
+            m_map[domain][name].erase(version_it);
+            if (m_map[domain][name].empty())
+            {
+                m_map[domain].erase(name);
+                if (m_map[domain].empty())
+                {
+                    m_map.erase(domain);
+                }
+            }
+        }
+
         OperatorSet OperatorsBridge::_get_operator_set(const std::string& domain,
                                                        std::int64_t version)
         {
+            std::lock_guard<std::mutex> guard(lock);
+
             OperatorSet result;
 
             auto dm = m_map.find(domain);
@@ -227,6 +269,7 @@ namespace ngraph
                                                       std::int64_t version,
                                                       const std::string& domain)
         {
+            std::lock_guard<std::mutex> guard(lock);
             // search for domain
             auto dm_map = m_map.find(domain);
             if (dm_map == std::end(m_map))
index 38686bd..e254423 100644 (file)
@@ -273,6 +273,40 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_custom_op)
     test_case.run();
 }
 
+// Registers a custom "AddQ" operator, checks that a model using it imports
+// and runs, then unregisters it and checks that re-importing the same model
+// fails with ngraph_error reporting the unknown operator.
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_custom_op_register_unregister)
+{
+    // Translator: build an ngraph Add node from the ONNX node's two inputs.
+    onnx_import::register_operator(
+        "AddQ", 1, "com.intel.ai", [](const onnx_import::Node& node) -> OutputVector {
+            OutputVector ng_inputs{node.get_ng_inputs()};
+            return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
+        });
+
+    auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/custom_operator.prototxt"));
+
+    auto test_case = test::TestCase<TestEngine>(function);
+    test_case.add_input<float>({1.f, 2.f, 3.f, 4.f});
+    test_case.add_expected_output<float>({3.f, 6.f, 9.f, 12.f});
+    test_case.run();
+
+    // After unregistering, the importer must reject the model: it should
+    // throw ngraph_error whose message names the unknown-operator check.
+    onnx_import::unregister_operator("AddQ", 1, "com.intel.ai");
+    try
+    {
+        auto function = onnx_import::import_onnx_model(
+            file_util::path_join(SERIALIZED_ZOO, "onnx/custom_operator.prototxt"));
+        FAIL() << "Expected ngraph::ngraph_error";
+    }
+    catch (ngraph::ngraph_error const& err)
+    {
+        std::string what{err.what()};
+        EXPECT_NE(what.find("Check 'unknown_operators.empty()' failed"), std::string::npos);
+    }
+    catch (...)
+    {
+        // Any other exception type is a failure, not just a missing throw.
+        FAIL() << "Expected ngraph::ngraph_error";
+    }
+}
+
 NGRAPH_TEST(${BACKEND_NAME}, onnx_model_custom_op_default_domain)
 {
     onnx_import::register_operator(