${CMAKE_CURRENT_SOURCE_DIR}/src/op/conv_integer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/op/gather_nd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/op/quant_conv.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/src/op/scatter_nd.cpp
)
list(REMOVE_ITEM PUBLIC_HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/include/onnx_import/op/conv_integer.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/onnx_import/op/gather_nd.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/onnx_import/op/quant_conv.hpp
- ${CMAKE_CURRENT_SOURCE_DIR}/include/onnx_import/op/scatter_nd.hpp
)
set(ONNX_IMPORT_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include CACHE INTERNAL "")
#include <memory>
-#include "ngraph/opsets/opset0.hpp"
-#include "scatter_nd.hpp"
+#include "onnx_import/default_opset.hpp"
+#include "onnx_import/op/scatter_nd.hpp"
namespace ngraph
{
auto indices = ng_inputs.at(1);
auto updates = ng_inputs.at(2);
- return {std::make_shared<opset0::ScatterND>(data, indices, updates)};
+ return {
+ std::make_shared<default_opset::ScatterNDUpdate>(data, indices, updates)};
}
} // namespace set_1
#include "onnx_import/op/roi_align.hpp"
#include "onnx_import/op/round.hpp"
#include "onnx_import/op/scatter_elements.hpp"
-// #include "onnx_import/op/scatter_nd.hpp"
+#include "onnx_import/op/scatter_nd.hpp"
#include "onnx_import/op/selu.hpp"
#include "onnx_import/op/shape.hpp"
#include "onnx_import/op/shrink.hpp"
REGISTER_OPERATOR("Round", 1, round);
REGISTER_OPERATOR("Scatter", 1, scatter_elements);
REGISTER_OPERATOR("ScatterElements", 1, scatter_elements);
- // REGISTER_OPERATOR("ScatterND", 1, scatter_nd);
+ REGISTER_OPERATOR("ScatterND", 1, scatter_nd);
REGISTER_OPERATOR("Selu", 1, selu);
REGISTER_OPERATOR("Shape", 1, shape);
REGISTER_OPERATOR("Shrink", 1, shrink);
--- /dev/null
+ir_version: 3
+producer_name: "nGraph ONNX Importer"
+graph {
+ node {
+ input: "x"
+ input: "i"
+ input: "u"
+ output: "y"
+ op_type: "ScatterND"
+ }
+ name: "test_scatterND"
+ input {
+ name: "x"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 8
+ }
+ }
+ }
+ }
+ }
+ initializer {
+ dims: 4
+ dims: 1
+ data_type: 6
+ int32_data: 4
+ int32_data: 3
+ int32_data: 1
+ int32_data: 7
+ name: "i"
+ }
+ input {
+ name: "u"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 4
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "y"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 8
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 7
+}
}
}
}
- input {
+ input {
name: "i"
type {
tensor_type {
}
}
}
- input {
+ input {
name: "u"
type {
tensor_type {
test_case.run();
}
-NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatterND)
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatterND_param_i64_indices)
{
const auto function = onnx_import::import_onnx_model(
- file_util::path_join(SERIALIZED_ZOO, "onnx/scatter_nd.prototxt"));
+ file_util::path_join(SERIALIZED_ZOO, "onnx/scatter_nd_param_i64_indices.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
test_case.run();
}
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatterND_const_i32_indices)
+{
+ const auto function = onnx_import::import_onnx_model(
+ file_util::path_join(SERIALIZED_ZOO, "onnx/scatter_nd_const_i32_indices.prototxt"));
+ auto test_case = test::TestCase<TestEngine>(function);
+
+ test_case.add_input<float>({1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
+ test_case.add_input<float>({9.f, 10.f, 11.f, 12.f});
+ test_case.add_expected_output<float>(Shape{8}, {1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f});
+
+ test_case.run();
+}
+
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_gatherND_int32)
{
const auto function = onnx_import::import_onnx_model(
onnx_model_qlinear_matmul
onnx_model_qlinear_matmul_3d
-# Not supported ONNX op: ScatterND
-onnx_model_scatterND
-
# Not supported ONNX op: GatherND
onnx_model_gatherND_int32
onnx_model_gatherND_float
onnx_dyn_shapes_slice_10_3d_input_12_axes
onnx_top_k_opset_10
onnx_model_one_hot_without_axis
+onnx_model_scatterND_param_i64_indices
# [NOT_IMPLEMENTED] Input image format U64 is not supported yet...
IE_CPU.fused_clamp_uint64
INTERPRETER.onnx_model_conv_integer_zero_point_zero
INTERPRETER.onnx_model_conv_integer_no_zero_point
INTERPRETER.onnx_model_conv_integer_pads
-INTERPRETER.onnx_model_scatterND
INTERPRETER.onnx_model_gatherND_int32
INTERPRETER.onnx_model_gatherND_float