Use GatherND-5 in the onnx_importer (#2634)
authorTomasz Dołbniak <tomasz.dolbniak@intel.com>
Fri, 16 Oct 2020 09:30:00 +0000 (11:30 +0200)
committerGitHub <noreply@github.com>
Fri, 16 Oct 2020 09:30:00 +0000 (11:30 +0200)
ngraph/frontend/onnx_import/CMakeLists.txt
ngraph/frontend/onnx_import/src/op/gather_nd.cpp
ngraph/frontend/onnx_import/src/ops_bridge.cpp
ngraph/test/runtime/ie/unit_test.manifest

index 6594408..906bfa0 100644 (file)
@@ -22,12 +22,10 @@ file(GLOB_RECURSE PUBLIC_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp)
 # Remove disabled ops
 list(REMOVE_ITEM LIBRARY_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/src/op/conv_integer.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/src/op/gather_nd.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/src/op/quant_conv.cpp
     )
 list(REMOVE_ITEM PUBLIC_HEADERS
     ${CMAKE_CURRENT_SOURCE_DIR}/include/onnx_import/op/conv_integer.hpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/include/onnx_import/op/gather_nd.hpp
     ${CMAKE_CURRENT_SOURCE_DIR}/include/onnx_import/op/quant_conv.hpp
     )
 
index 3fdc689..0d187f0 100644 (file)
@@ -17,7 +17,7 @@
 // Disabled in CMakeList
 // Update to higher opset required
 
-#include "ngraph/opsets/opset0.hpp"
+#include "onnx_import/default_opset.hpp"
 #include "onnx_import/utils/common.hpp"
 
 namespace ngraph
@@ -30,11 +30,12 @@ namespace ngraph
             {
                 OutputVector gather_nd(const Node& node)
                 {
-                    OutputVector ng_inputs{node.get_ng_inputs()};
-                    auto data = ng_inputs.at(0);
-                    auto indices = ng_inputs.at(1);
+                    const OutputVector ng_inputs{node.get_ng_inputs()};
+                    const auto data = ng_inputs.at(0);
+                    const auto indices = ng_inputs.at(1);
+                    const auto batch_dims = node.get_attribute_value<int64_t>("batch_dims", 0);
 
-                    return {std::make_shared<ngraph::opset0::GatherND>(data, indices)};
+                    return {std::make_shared<default_opset::GatherND>(data, indices, batch_dims)};
                 }
 
             } // namespace set_1
index 5d53a8d..2e896cd 100644 (file)
@@ -60,7 +60,7 @@
 #include "onnx_import/op/flatten.hpp"
 #include "onnx_import/op/floor.hpp"
 #include "onnx_import/op/gather.hpp"
-// #include "onnx_import/op/gather_nd.hpp"
+#include "onnx_import/op/gather_nd.hpp"
 #include "onnx_import/op/gemm.hpp"
 #include "onnx_import/op/global_average_pool.hpp"
 #include "onnx_import/op/global_max_pool.hpp"
@@ -343,7 +343,7 @@ namespace ngraph
             REGISTER_OPERATOR("Flatten", 1, flatten);
             REGISTER_OPERATOR("Floor", 1, floor);
             REGISTER_OPERATOR("Gather", 1, gather);
-            // REGISTER_OPERATOR("GatherND", 1, gather_nd);
+            REGISTER_OPERATOR("GatherND", 1, gather_nd);
             REGISTER_OPERATOR("Gemm", 1, gemm);
             REGISTER_OPERATOR("Gemm", 6, gemm);
             REGISTER_OPERATOR("GlobalAveragePool", 1, global_average_pool);
index f2ae030..d565eab 100644 (file)
@@ -40,7 +40,7 @@ onnx_model_matmul_integer_4d_no_zero_point
 onnx_model_qlinear_matmul
 onnx_model_qlinear_matmul_3d
 
-# Not supported ONNX op: GatherND
+# The indices input type i64 is not supported by the CPU plugin
 onnx_model_gatherND_int32
 onnx_model_gatherND_float