[ONNX] Enable ScatterND operator in ONNX importer (#1987)
authorKatarzyna Mitrus <katarzyna.mitrus@intel.com>
Tue, 1 Sep 2020 09:10:03 +0000 (11:10 +0200)
committerGitHub <noreply@github.com>
Tue, 1 Sep 2020 09:10:03 +0000 (12:10 +0300)
* Enable ONNX scatter_nd

* Tests update

ngraph/frontend/onnx_import/CMakeLists.txt
ngraph/frontend/onnx_import/src/op/scatter_nd.cpp
ngraph/frontend/onnx_import/src/ops_bridge.cpp
ngraph/test/models/onnx/scatter_nd_const_i32_indices.prototxt [new file with mode: 0644]
ngraph/test/models/onnx/scatter_nd_param_i64_indices.prototxt [moved from ngraph/test/models/onnx/scatter_nd.prototxt with 97% similarity]
ngraph/test/onnx/onnx_import.in.cpp
ngraph/test/runtime/ie/unit_test.manifest
ngraph/test/runtime/interpreter/unit_test.manifest

index 81880e5..e64c92d 100644 (file)
@@ -24,13 +24,11 @@ list(REMOVE_ITEM LIBRARY_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/src/op/conv_integer.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/src/op/gather_nd.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/src/op/quant_conv.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/src/op/scatter_nd.cpp
     )
 list(REMOVE_ITEM PUBLIC_HEADERS 
     ${CMAKE_CURRENT_SOURCE_DIR}/include/onnx_import/op/conv_integer.hpp
     ${CMAKE_CURRENT_SOURCE_DIR}/include/onnx_import/op/gather_nd.hpp
     ${CMAKE_CURRENT_SOURCE_DIR}/include/onnx_import/op/quant_conv.hpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/include/onnx_import/op/scatter_nd.hpp
     )
 
 set(ONNX_IMPORT_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include CACHE INTERNAL "")
index 5d13564..9dd481e 100644 (file)
@@ -19,8 +19,8 @@
 
 #include <memory>
 
-#include "ngraph/opsets/opset0.hpp"
-#include "scatter_nd.hpp"
+#include "onnx_import/default_opset.hpp"
+#include "onnx_import/op/scatter_nd.hpp"
 
 namespace ngraph
 {
@@ -37,7 +37,8 @@ namespace ngraph
                     auto indices = ng_inputs.at(1);
                     auto updates = ng_inputs.at(2);
 
-                    return {std::make_shared<opset0::ScatterND>(data, indices, updates)};
+                    return {
+                        std::make_shared<default_opset::ScatterNDUpdate>(data, indices, updates)};
                 }
 
             } // namespace set_1
index 97e0102..3f4f8ad 100644 (file)
 #include "onnx_import/op/roi_align.hpp"
 #include "onnx_import/op/round.hpp"
 #include "onnx_import/op/scatter_elements.hpp"
-// #include "onnx_import/op/scatter_nd.hpp"
+#include "onnx_import/op/scatter_nd.hpp"
 #include "onnx_import/op/selu.hpp"
 #include "onnx_import/op/shape.hpp"
 #include "onnx_import/op/shrink.hpp"
@@ -363,7 +363,7 @@ namespace ngraph
             REGISTER_OPERATOR("Round", 1, round);
             REGISTER_OPERATOR("Scatter", 1, scatter_elements);
             REGISTER_OPERATOR("ScatterElements", 1, scatter_elements);
-            // REGISTER_OPERATOR("ScatterND", 1, scatter_nd);
+            REGISTER_OPERATOR("ScatterND", 1, scatter_nd);
             REGISTER_OPERATOR("Selu", 1, selu);
             REGISTER_OPERATOR("Shape", 1, shape);
             REGISTER_OPERATOR("Shrink", 1, shrink);
diff --git a/ngraph/test/models/onnx/scatter_nd_const_i32_indices.prototxt b/ngraph/test/models/onnx/scatter_nd_const_i32_indices.prototxt
new file mode 100644 (file)
index 0000000..c9aa297
--- /dev/null
@@ -0,0 +1,64 @@
+ir_version: 3
+producer_name: "nGraph ONNX Importer"
+graph {
+  node {
+    input: "x"
+    input: "i"
+    input: "u"
+    output: "y"
+    op_type: "ScatterND"
+  }
+  name: "test_scatterND"
+  input {
+    name: "x"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  initializer {
+    dims: 4
+    dims: 1
+    data_type: 6
+    int32_data: 4
+    int32_data: 3
+    int32_data: 1
+    int32_data: 7
+    name: "i"
+  }
+  input {
+    name: "u"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 7
+}
@@ -22,7 +22,7 @@ graph {
       }
     }
   }
-   input {
+  input {
     name: "i"
     type {
       tensor_type {
@@ -38,7 +38,7 @@ graph {
       }
     }
   }
-    input {
+  input {
     name: "u"
     type {
       tensor_type {
index f832f09..9607101 100644 (file)
@@ -2176,10 +2176,10 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_mod)
     test_case.run();
 }
 
-NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatterND)
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatterND_param_i64_indices)
 {
     const auto function = onnx_import::import_onnx_model(
-        file_util::path_join(SERIALIZED_ZOO, "onnx/scatter_nd.prototxt"));
+        file_util::path_join(SERIALIZED_ZOO, "onnx/scatter_nd_param_i64_indices.prototxt"));
     auto test_case = test::TestCase<TestEngine>(function);
 
     test_case.add_input<float>({1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
@@ -2190,6 +2190,19 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatterND)
     test_case.run();
 }
 
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatterND_const_i32_indices)
+{
+    const auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/scatter_nd_const_i32_indices.prototxt"));
+    auto test_case = test::TestCase<TestEngine>(function);
+
+    test_case.add_input<float>({1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
+    test_case.add_input<float>({9.f, 10.f, 11.f, 12.f});
+    test_case.add_expected_output<float>(Shape{8}, {1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f});
+
+    test_case.run();
+}
+
 NGRAPH_TEST(${BACKEND_NAME}, onnx_model_gatherND_int32)
 {
     const auto function = onnx_import::import_onnx_model(
index dc1dc6d..d3858e9 100644 (file)
@@ -40,9 +40,6 @@ onnx_model_matmul_integer_4d_no_zero_point
 onnx_model_qlinear_matmul
 onnx_model_qlinear_matmul_3d
 
-# Not supported ONNX op: ScatterND
-onnx_model_scatterND
-
 # Not supported ONNX op: GatherND
 onnx_model_gatherND_int32
 onnx_model_gatherND_float
@@ -143,6 +140,7 @@ onnx_dyn_shapes_slice_10_4d_input_23_axes_21_steps
 onnx_dyn_shapes_slice_10_3d_input_12_axes
 onnx_top_k_opset_10
 onnx_model_one_hot_without_axis
+onnx_model_scatterND_param_i64_indices
 
 # [NOT_IMPLEMENTED] Input image format U64 is not supported yet...
 IE_CPU.fused_clamp_uint64
index 03d5eb0..eb1a78b 100644 (file)
@@ -97,7 +97,6 @@ INTERPRETER.onnx_model_conv_integer
 INTERPRETER.onnx_model_conv_integer_zero_point_zero
 INTERPRETER.onnx_model_conv_integer_no_zero_point
 INTERPRETER.onnx_model_conv_integer_pads
-INTERPRETER.onnx_model_scatterND
 INTERPRETER.onnx_model_gatherND_int32
 INTERPRETER.onnx_model_gatherND_float