[IE][VPU]: GatherND DTS transformation (#3025)
authorAndrew Bakalin <andrew.bakalin@intel.com>
Thu, 12 Nov 2020 14:30:41 +0000 (17:30 +0300)
committerGitHub <noreply@github.com>
Thu, 12 Nov 2020 14:30:41 +0000 (17:30 +0300)
* Implement GatherND DTS
* Introduce tests on DTS
* Introduce tests on DSR+GatherND

21 files changed:
inference-engine/src/vpu/common/include/vpu/ngraph/transformations/dynamic_to_static_shape_gather_nd.hpp [new file with mode: 0644]
inference-engine/src/vpu/common/include/vpu/ngraph/utilities.hpp
inference-engine/src/vpu/common/src/ngraph/operations/static_shape_broadcast.cpp
inference-engine/src/vpu/common/src/ngraph/operations/static_shape_reshape.cpp
inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape.cpp
inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_binary_elementwise.cpp
inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_concat.cpp
inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_gather.cpp
inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_gather_nd.cpp [new file with mode: 0644]
inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_matmul.cpp
inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_roialign.cpp
inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_strided_slice.cpp
inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_topk.cpp
inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_variadic_split.cpp
inference-engine/src/vpu/common/src/ngraph/utilities.cpp
inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_gather.cpp
inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_gather_nd.cpp [new file with mode: 0644]
inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_topk.cpp
inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_variadic_split.cpp
inference-engine/tests/functional/plugin/myriad/shared_tests_instances/skip_tests_config.cpp
inference-engine/tests/functional/plugin/myriad/subgraph_tests/dsr_gather_nd.cpp [new file with mode: 0644]

diff --git a/inference-engine/src/vpu/common/include/vpu/ngraph/transformations/dynamic_to_static_shape_gather_nd.hpp b/inference-engine/src/vpu/common/include/vpu/ngraph/transformations/dynamic_to_static_shape_gather_nd.hpp
new file mode 100644 (file)
index 0000000..6474a76
--- /dev/null
@@ -0,0 +1,13 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "ngraph/node.hpp"
+
+namespace vpu {
+
+void dynamicToStaticShapeGatherND(std::shared_ptr<ngraph::Node> node);
+
+}  // namespace vpu
index 520cb76..dd1fd14 100644 (file)
@@ -5,10 +5,15 @@
 #pragma once
 
 #include "ngraph/node.hpp"
+#include "ngraph/type/element_type.hpp"
+
+namespace vpu {
 
 std::vector<std::int64_t> evaluateTargetShape(const ngraph::Output<ngraph::Node>& value);
 
-namespace vpu {
+std::shared_ptr<ngraph::Node> shapeToConstant(const ngraph::element::Type& type, const ngraph::Shape& shape);
+
+std::shared_ptr<ngraph::Node> gatherShapeElements(const ngraph::Output<ngraph::Node>&, int startIndex, size_t elemCount);
 
 void printTo(std::ostream& stream, const ngraph::NodeTypeInfo& object);
 
index 8f9c553..cf9693c 100644 (file)
@@ -53,7 +53,7 @@ void StaticShapeBroadcast::validate_and_infer_types() {
         // Try to evaluate output shape. After some transformations further, we may not be able
         // to evaluate the target shape again, then we will leave the evaluated shape unchanged.
         // For example, EliminateShapeOfAfterDSR remove ShapeOf and pass the second input of DSR.
-        const auto evaluatedDimensionValues = evaluateTargetShape(input_value(1));
+        const auto evaluatedDimensionValues = ::vpu::evaluateTargetShape(input_value(1));
         NODE_VALIDATION_CHECK(this, !evaluatedDimensionValues.empty(), "StaticShapeBroadcast (", get_friendly_name(), ") can't evaluate output shape");
 
         const auto evaluatedTargetShape = ngraph::PartialShape(evaluatedDimensionValues);
index 8cf6b38..52b206e 100644 (file)
@@ -40,7 +40,7 @@ void StaticShapeReshape::validate_and_infer_types() {
 
     const auto& inputShape = get_input_shape(0);
 
-    auto outputDimensionsValues = evaluateTargetShape(targetShape);
+    auto outputDimensionsValues = ::vpu::evaluateTargetShape(targetShape);
     NODE_VALIDATION_CHECK(this, !outputDimensionsValues.empty(), "StaticShapeReshape (", get_friendly_name(), ") can't evaluate output shape");
 
     for (std::size_t i = 0; i < outputDimensionsValues.size(); ++i) {
index 407a1a7..eef8903 100644 (file)
@@ -8,6 +8,7 @@
 #include "vpu/ngraph/transformations/dynamic_to_static_shape_broadcast.hpp"
 #include "vpu/ngraph/transformations/dynamic_to_static_shape_concat.hpp"
 #include "vpu/ngraph/transformations/dynamic_to_static_shape_gather.hpp"
+#include "vpu/ngraph/transformations/dynamic_to_static_shape_gather_nd.hpp"
 #include "vpu/ngraph/transformations/dynamic_to_static_shape_matmul.hpp"
 #include "vpu/ngraph/transformations/dynamic_to_static_shape_non_max_suppression.hpp"
 #include "vpu/ngraph/transformations/dynamic_to_static_shape_nonzero.hpp"
@@ -124,6 +125,7 @@ const Transformations& getDefaultTransformations() {
         {ngraph::opset3::Broadcast::type_info,         dynamicToStaticShapeBroadcast},
         {ngraph::opset3::MatMul::type_info,            dynamicToStaticShapeMatMul},
         {ngraph::opset5::Split::type_info,             dynamicToStaticShapeSplit},
+        {ngraph::opset5::GatherND::type_info,          dynamicToStaticShapeGatherND},
 
         // reduction
         {ngraph::opset3::ReduceLogicalAnd::type_info, dynamicToStaticShapeReduce},
index 396c933..c6d7c76 100644 (file)
@@ -21,14 +21,6 @@ void dynamicToStaticShapeBinaryEltwise(std::shared_ptr<ngraph::Node> eltwise) {
 
     const auto copied = eltwise->copy_with_new_inputs(eltwise->input_values());
 
-    auto shapeToConstant = [&eltwise](const ngraph::Output<ngraph::Node>& output,
-                                      const ngraph::element::Type& elemType) -> std::shared_ptr<ngraph::opset3::Constant> {
-        VPU_THROW_UNLESS(output.get_partial_shape().is_static(),
-            "DynamicToStaticShape transformation for {} of type {} expects static shape on inputs without DSR",
-            eltwise->get_friendly_name(), eltwise->get_type_info());
-        return ngraph::opset3::Constant::create(elemType, {output.get_shape().size()}, output.get_shape());
-    };
-
     const auto lhsDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(eltwise->input_value(0).get_node_shared_ptr());
     const auto rhsDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(eltwise->input_value(1).get_node_shared_ptr());
 
@@ -42,8 +34,8 @@ void dynamicToStaticShapeBinaryEltwise(std::shared_ptr<ngraph::Node> eltwise) {
     }
     const auto shapeElementType = lhsDSR ? lhsDSR->get_input_element_type(1) : rhsDSR->get_input_element_type(1);
 
-    auto lhsInput = lhsDSR ? lhsDSR->input_value(1) : shapeToConstant(eltwise->input_value(0), shapeElementType);
-    auto rhsInput = rhsDSR ? rhsDSR->input_value(1) : shapeToConstant(eltwise->input_value(1), shapeElementType);
+    auto lhsInput = lhsDSR ? lhsDSR->input_value(1) : shapeToConstant(shapeElementType, eltwise->get_input_shape(0));
+    auto rhsInput = rhsDSR ? rhsDSR->input_value(1) : shapeToConstant(shapeElementType, eltwise->get_input_shape(1));
 
     const auto diff = std::abs(lhsRank.get_length() - rhsRank.get_length());
     if (diff) {
index 4417970..ecab8ec 100644 (file)
@@ -45,11 +45,6 @@ void dynamicToStaticShapeConcat(std::shared_ptr<ngraph::Node> target) {
     const auto dataRank = firstDSRInputNode->get_output_partial_shape(0).rank().get_length();
     const auto axis = ngraph::as_type_ptr<ngraph::opset3::Concat>(target)->get_concatenation_axis();
 
-    const auto shapeToConstant = [&shapeDataType, &dataRank](const ngraph::Shape& shape) {
-        return ngraph::opset3::Constant::create(
-                shapeDataType, {static_cast<size_t>(dataRank)}, shape)->output(0);
-    };
-
     const auto getShapeFromDSR = [&target, &shapeDataType](const ngraph::Output<ngraph::Node>& dsrOutput) {
         const auto dsrNode = dsrOutput.get_node_shared_ptr();
         const auto dsrShapeInputValue = dsrNode->input_value(1);
@@ -68,12 +63,11 @@ void dynamicToStaticShapeConcat(std::shared_ptr<ngraph::Node> target) {
         return shapeAccumulatorOp->output(0);
     };
 
-    const auto divideDimsByNumOfInputsExceptAxis = [&target, &dataRank, &axis,
-                                                    &shapeDataType, &shapeToConstant](
+    const auto divideDimsByNumOfInputsExceptAxis = [&target, &dataRank, &axis, &shapeDataType](
             const ngraph::Output<ngraph::Node>& shape) {
         ngraph::Shape dividerValues(dataRank, target->get_input_size());
         dividerValues[axis] = 1;
-        const auto divider = shapeToConstant(dividerValues);
+        const auto divider = shapeToConstant(shapeDataType, dividerValues);
         const auto divide = std::make_shared<ngraph::opset3::Divide>(shape, divider);
         return divide->output(0);
     };
@@ -103,7 +97,7 @@ void dynamicToStaticShapeConcat(std::shared_ptr<ngraph::Node> target) {
     }
 
     if (!staticInputs.empty()) {
-        const auto accumulatedStaticShape = shapeToConstant(getAdditionalShapeFromStatic(staticInputs));
+        const auto accumulatedStaticShape = shapeToConstant(shapeDataType, getAdditionalShapeFromStatic(staticInputs));
         accumulatedShape = sumOfShapes(accumulatedShape, accumulatedStaticShape);
     }
 
index 0103294..8a84089 100644 (file)
@@ -25,14 +25,6 @@ void dynamicToStaticShapeGather(std::shared_ptr<ngraph::Node> target) {
     VPU_THROW_UNLESS(axis != std::numeric_limits<int64_t>::max() && axis >= 0,
             "dynamicToStaticShapeGather: Unsupported Gather axis {} for node {}", axis, gather);
 
-    auto shapeToConstant = [&gather](const ngraph::Output<ngraph::Node>& output,
-                                     const ngraph::element::Type& elemType) -> std::shared_ptr<ngraph::opset3::Constant> {
-        VPU_THROW_UNLESS(output.get_partial_shape().is_static(),
-                         "DynamicToStaticShape transformation for {} of type {} expects static shape on inputs without DSR",
-                         gather->get_friendly_name(), gather->get_type_info());
-        return ngraph::opset3::Constant::create(elemType, {output.get_shape().size()}, output.get_shape());
-    };
-
     const auto dataDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(gather->input_value(0).get_node_shared_ptr());
     const auto idxDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(gather->input_value(1).get_node_shared_ptr());
 
@@ -46,8 +38,8 @@ void dynamicToStaticShapeGather(std::shared_ptr<ngraph::Node> target) {
     }
     const auto shapeElementType = idxDSR ? idxDSR->get_input_element_type(1) : dataDSR->get_input_element_type(1);
 
-    const auto data_shape = dataDSR ? dataDSR->input_value(1) : shapeToConstant(gather->input_value(0), shapeElementType);
-    const auto indices_shape = idxDSR ? idxDSR->input_value(1) : shapeToConstant(gather->input_value(1), shapeElementType);
+    const auto data_shape = dataDSR ? dataDSR->input_value(1) : shapeToConstant(shapeElementType, gather->get_input_shape(0));
+    const auto indices_shape = idxDSR ? idxDSR->input_value(1) : shapeToConstant(shapeElementType, gather->get_input_shape(1));
 
     const auto copied = target->clone_with_new_inputs(target->input_values());
 
@@ -60,24 +52,12 @@ void dynamicToStaticShapeGather(std::shared_ptr<ngraph::Node> target) {
     const auto indices_rank_value = indices_rank[0].get_length();
     ngraph::OutputVector output_dims;
     if (axis) {
-        std::vector<int64_t> first_data_shape_part_indices(axis);
-        std::iota(first_data_shape_part_indices.begin(), first_data_shape_part_indices.end(), 0);
-        const auto first_data_shape_part = std::make_shared<ngraph::opset3::Gather>(
-                data_shape,
-                ngraph::opset3::Constant::create(shapeElementType, {first_data_shape_part_indices.size()}, first_data_shape_part_indices),
-                ngraph::opset3::Constant::create(shapeElementType, {1}, {0}));
-        output_dims.push_back(first_data_shape_part);
+        output_dims.push_back(gatherShapeElements(data_shape, 0, axis));
     }
     if (indices_rank_value)
         output_dims.push_back(indices_shape);
     if (axis + 1 < data_rank_value) {
-        std::vector<int64_t> second_data_shape_part_indices(data_rank_value - axis - 1);
-        std::iota(second_data_shape_part_indices.begin(), second_data_shape_part_indices.end(), axis + 1);
-        const auto second_data_shape_part = std::make_shared<ngraph::opset3::Gather>(
-                data_shape,
-                ngraph::opset3::Constant::create(shapeElementType, {second_data_shape_part_indices.size()}, second_data_shape_part_indices),
-                ngraph::opset3::Constant::create(shapeElementType, {1}, {0}));
-        output_dims.push_back(second_data_shape_part);
+        output_dims.push_back(gatherShapeElements(data_shape, axis + 1, data_rank_value - axis - 1));
     }
 
     const auto output_shape = std::make_shared<ngraph::opset3::Concat>(output_dims, 0);
diff --git a/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_gather_nd.cpp b/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_gather_nd.cpp
new file mode 100644 (file)
index 0000000..81c6d61
--- /dev/null
@@ -0,0 +1,79 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "vpu/ngraph/transformations/dynamic_to_static_shape_gather.hpp"
+
+#include "vpu/ngraph/operations/dynamic_shape_resolver.hpp"
+#include "vpu/ngraph/utilities.hpp"
+#include <vpu/utils/error.hpp>
+#include "ngraph/graph_util.hpp"
+
+#include "ngraph/opsets/opset5.hpp"
+
+#include <numeric>
+
+namespace vpu {
+
+void dynamicToStaticShapeGatherND(std::shared_ptr<ngraph::Node> target) {
+    const auto gatherND = ngraph::as_type_ptr<ngraph::opset5::GatherND>(target);
+    VPU_THROW_UNLESS(gatherND, "dynamicToStaticShapeGatherND transformation is not applicable for {}, it should be {} instead",
+                     target, ngraph::opset5::GatherND::type_info);
+
+    const auto dataDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(gatherND->input_value(0).get_node_shared_ptr());
+    const auto indicesDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(gatherND->input_value(1).get_node_shared_ptr());
+
+    VPU_THROW_UNLESS(dataDSR || indicesDSR, "dynamicToStaticShapeGatherND transformation for {} of type {} expects at least one DSR as input",
+                     gatherND->get_friendly_name(), gatherND->get_type_info());
+    if (dataDSR && indicesDSR) {
+        VPU_THROW_UNLESS(dataDSR->get_input_element_type(1) == indicesDSR->get_input_element_type(1),
+                         "dynamicToStaticShapeGatherND transformation for {} of type {} expects equal shapes data types, actual {} vs {}",
+                         gatherND->get_friendly_name(), gatherND->get_type_info(),
+                         dataDSR->get_input_element_type(1), indicesDSR->get_input_element_type(1));
+    }
+    const auto shapeElementType = indicesDSR ? indicesDSR->get_input_element_type(1) : dataDSR->get_input_element_type(1);
+
+    const auto dataShape = dataDSR ? dataDSR->input_value(1) : shapeToConstant(shapeElementType, gatherND->get_input_shape(0));
+    const auto indicesShape = indicesDSR ? indicesDSR->input_value(1) : shapeToConstant(shapeElementType, gatherND->get_input_shape(1));
+
+    const auto dataShapeRank = ngraph::shape_size(dataShape.get_shape());
+    const auto indicesShapeRank = ngraph::shape_size(indicesShape.get_shape());
+
+    const auto batchDims = static_cast<int64_t>(gatherND->get_batch_dims());
+    VPU_THROW_UNLESS(batchDims >= 0 && batchDims < std::min(dataShapeRank, indicesShapeRank),
+                     "dynamicToStaticShapeGatherND: node {} has unsupported batch_dims which is expected to be"
+                     " in [0; min({}, {})), but {} was provided", gatherND->get_friendly_name(), dataShapeRank, indicesShapeRank, batchDims);
+
+    std::shared_ptr<ngraph::Node> outputShape;
+
+    if (batchDims > 0) {
+        outputShape = std::make_shared<ngraph::opset5::ReduceProd>(
+            gatherShapeElements(indicesShape, 0, batchDims),
+            ngraph::opset5::Constant::create(ngraph::element::i64, {}, {0}),
+            true);
+    }
+
+    if (indicesShapeRank - batchDims - 1 > 0) {
+        const auto indicesShapePart = gatherShapeElements(indicesShape, batchDims, indicesShapeRank - batchDims - 1);
+        outputShape = outputShape ? std::make_shared<ngraph::opset5::Concat>(ngraph::NodeVector{outputShape, indicesShapePart}, 0) : indicesShapePart;
+    }
+
+    const auto lastIndicesDim = gatherND->get_input_partial_shape(1)[indicesShapeRank - 1].get_length();
+    if (batchDims + lastIndicesDim < dataShapeRank) {
+        const auto dataShapePart = gatherShapeElements(
+            dataShape,
+            lastIndicesDim + batchDims,
+            dataShapeRank - batchDims - lastIndicesDim);
+        outputShape = outputShape ? std::make_shared<ngraph::opset5::Concat>(ngraph::NodeVector{outputShape, dataShapePart}, 0) : dataShapePart;
+    }
+
+    VPU_THROW_UNLESS(outputShape, "dynamicToStaticShapeGatherND: node {} has empty output shape", gatherND->get_friendly_name());
+
+    const auto copied = target->clone_with_new_inputs(target->input_values());
+
+    auto outDSR = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(copied, outputShape);
+    outDSR->set_friendly_name(target->get_friendly_name());
+    ngraph::replace_node(target, std::move(outDSR));
+}
+
+}  // namespace vpu
index a04c4e7..c416062 100644 (file)
@@ -38,14 +38,6 @@ void dynamicToStaticShapeMatMul(std::shared_ptr<ngraph::Node> target) {
     VPU_THROW_UNLESS(matmul, "dynamicToStaticShapeMatMul transformation is not applicable for {}, it should be {} instead",
             target, ngraph::opset3::MatMul::type_info);
 
-    auto shapeToConstant = [&target](const ngraph::Output<ngraph::Node>& output,
-                                     const ngraph::element::Type& elementType) -> std::shared_ptr<ngraph::opset3::Constant> {
-        VPU_THROW_UNLESS(output.get_partial_shape().is_static(),
-                         "DynamicToStaticShape transformation for {} of type {} expects static shape on inputs without DSR",
-                         target->get_friendly_name(), target->get_type_info());
-        return ngraph::opset3::Constant::create(elementType, {output.get_shape().size()}, output.get_shape());
-    };
-
     const auto a_input_DSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(target->input_value(0).get_node_shared_ptr());
     const auto b_input_DSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(target->input_value(1).get_node_shared_ptr());
 
@@ -60,8 +52,10 @@ void dynamicToStaticShapeMatMul(std::shared_ptr<ngraph::Node> target) {
 
     const auto shapeElementType = a_input_DSR ? a_input_DSR->get_input_element_type(1) : b_input_DSR->get_input_element_type(1);
 
-    ngraph::Output<ngraph::Node> a_input_shape = a_input_DSR ? a_input_DSR->input_value(1) : shapeToConstant(target->input_value(0), shapeElementType);
-    ngraph::Output<ngraph::Node> b_input_shape = b_input_DSR ? b_input_DSR->input_value(1) : shapeToConstant(target->input_value(1), shapeElementType);
+    ngraph::Output<ngraph::Node> a_input_shape = a_input_DSR ? a_input_DSR->input_value(1) :
+            shapeToConstant(shapeElementType, target->get_input_shape(0));
+    ngraph::Output<ngraph::Node> b_input_shape = b_input_DSR ? b_input_DSR->input_value(1) :
+            shapeToConstant(shapeElementType, target->get_input_shape(1));
 
     const auto& a_rank = a_input_shape.get_partial_shape();
     const auto& b_rank = b_input_shape.get_partial_shape();
@@ -77,12 +71,7 @@ void dynamicToStaticShapeMatMul(std::shared_ptr<ngraph::Node> target) {
     if (max_rank_value > 2) {
         // batch broadcasting
         const auto max_shape = std::make_shared<ngraph::opset3::Maximum>(a_input_shape, b_input_shape);
-        std::vector<int64_t> indices_value(max_rank_value - 2);
-        std::iota(indices_value.begin(), indices_value.end(), 0);
-        const auto indices = ngraph::opset3::Constant::create(ngraph::element::i64, {indices_value.size()}, indices_value);
-        const auto axis = ngraph::opset3::Constant::create(ngraph::element::i64, {}, std::vector<int64_t>{0});
-        const auto batch_dims = std::make_shared<ngraph::opset3::Gather>(max_shape, indices, axis);
-        output_dims.push_back(batch_dims);
+        output_dims.push_back(gatherShapeElements(max_shape, 0, max_rank_value - 2));
     }
     const auto input_channels = std::make_shared<ngraph::opset3::Gather>(
             a_input_shape,
index 9a140b4..30624e4 100644 (file)
@@ -21,21 +21,16 @@ void dynamicToStaticShapeROIAlign(std::shared_ptr<ngraph::Node> target) {
         "dynamicToStaticShapeROIAlign transformation is not applicable for {}, it should be {} instead",
         target, ngraph::opset3::ROIAlign::type_info);
 
-    auto shapeToConstant = [&roi_align](const ngraph::Output<ngraph::Node> & output) -> std::shared_ptr<ngraph::opset3::Constant> {
-        VPU_THROW_UNLESS(output.get_partial_shape().is_static(),
-                         "DynamicToStaticShape transformation for {} of type {} expects static shape on inputs without DSR",
-                         roi_align->get_friendly_name(), roi_align->get_type_info());
-        return ngraph::opset3::Constant::create(ngraph::element::i64, {output.get_shape().size()}, output.get_shape());
-    };
-
     const auto dataDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(roi_align->input_value(0).get_node_shared_ptr());
     const auto num_roisDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(roi_align->input_value(2).get_node_shared_ptr());
 
     VPU_THROW_UNLESS(dataDSR || num_roisDSR, "DynamicToStaticShape transformation for {} of type {} expects at least one DSR as input",
                      roi_align->get_friendly_name(), roi_align->get_type_info());
 
-    auto input_0_shape = dataDSR ? dataDSR->input_value(1) : shapeToConstant(roi_align->input_value(0));
-    auto num_rois = num_roisDSR ? num_roisDSR->input_value(1) : shapeToConstant(roi_align->input_value(2));
+    const auto shapeElementType = dataDSR ? dataDSR->get_input_element_type(1) : num_roisDSR->get_input_element_type(1);
+
+    auto input_0_shape = dataDSR ? dataDSR->input_value(1) : shapeToConstant(shapeElementType, roi_align->get_input_shape(0));
+    auto num_rois = num_roisDSR ? num_roisDSR->input_value(1) : shapeToConstant(shapeElementType, roi_align->get_input_shape(2));
 
     const auto c_index = std::make_shared<ngraph::opset3::Constant>(ngraph::element::i64, ngraph::Shape{1}, std::vector<int64_t>{1});
     const auto c_axis = std::make_shared<ngraph::opset3::Constant>(ngraph::element::i64, ngraph::Shape{1}, std::vector<int64_t>{0});
index 6542b8c..8852d45 100644 (file)
@@ -109,14 +109,7 @@ std::shared_ptr<ngraph::Node> calculate_output_shape(
     }
 
     if (output_dimensions.size() < inputShapeRank) {
-        std::vector<std::int64_t> indices(inputShapeRank - output_dimensions.size());
-        std::iota(indices.begin(), indices.end(), static_cast<std::int64_t>(output_dimensions.size()));
-
-        const auto tail = std::make_shared<ngraph::opset3::Gather>(
-            input_shape,
-            ngraph::opset3::Constant::create(ngraph::element::i64, {indices.size()}, indices),
-            ngraph::opset3::Constant::create(shape_type, {}, {0}));
-        output_dimensions.push_back(tail);
+        output_dimensions.push_back(gatherShapeElements(input_shape, output_dimensions.size(), inputShapeRank - output_dimensions.size()));
     }
 
     VPU_THROW_UNLESS(output_dimensions.size() == inputShapeRank,
index b97edb1..0f0be4a 100644 (file)
@@ -32,22 +32,10 @@ void dynamicToStaticShapeTopK(std::shared_ptr<ngraph::Node> target) {
     const auto data_rank_value = data_rank.get_length();
     ngraph::OutputVector first_shape_part, second_shape_part;
     if (axis) {
-        std::vector<int64_t> first_data_shape_part_indices(axis);
-        std::iota(first_data_shape_part_indices.begin(), first_data_shape_part_indices.end(), 0);
-        const auto first_data_shape_part = std::make_shared<ngraph::opset3::Gather>(
-                data_shape,
-                ngraph::opset3::Constant::create(ngraph::element::i64, {first_data_shape_part_indices.size()}, first_data_shape_part_indices),
-                ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0}));
-        first_shape_part.push_back(first_data_shape_part);
+        first_shape_part.push_back(gatherShapeElements(data_shape, 0, axis));
     }
     if (axis + 1 < data_rank_value) {
-        std::vector<int64_t> second_data_shape_part_indices(data_rank_value - axis - 1);
-        std::iota(second_data_shape_part_indices.begin(), second_data_shape_part_indices.end(), axis + 1);
-        const auto second_data_shape_part = std::make_shared<ngraph::opset3::Gather>(
-                data_shape,
-                ngraph::opset3::Constant::create(ngraph::element::i64, {second_data_shape_part_indices.size()}, second_data_shape_part_indices),
-                ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0}));
-        second_shape_part.push_back(second_data_shape_part);
+        second_shape_part.push_back(gatherShapeElements(data_shape, axis + 1, data_rank_value - axis - 1));
     }
 
     auto k_0d = target->get_input_node_shared_ptr(1);
index f93625d..339702d 100644 (file)
@@ -44,22 +44,10 @@ void dynamicToStaticShapeVariadicSplit(std::shared_ptr<ngraph::Node> target) {
     const auto data_rank_value = data_rank.get_length();
     ngraph::OutputVector first_shape_part, second_shape_part;
     if (axis) {
-        std::vector<int64_t> first_data_shape_part_indices(axis);
-        std::iota(first_data_shape_part_indices.begin(), first_data_shape_part_indices.end(), 0);
-        const auto first_data_shape_part = std::make_shared<ngraph::opset3::Gather>(
-                data_shape,
-                ngraph::opset3::Constant::create(ngraph::element::i64, {first_data_shape_part_indices.size()}, first_data_shape_part_indices),
-                ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0}));
-        first_shape_part.push_back(first_data_shape_part);
+        first_shape_part.push_back(gatherShapeElements(data_shape, 0, axis));
     }
     if (axis + 1 < data_rank_value) {
-        std::vector<int64_t> second_data_shape_part_indices(data_rank_value - axis - 1);
-        std::iota(second_data_shape_part_indices.begin(), second_data_shape_part_indices.end(), axis + 1);
-        const auto second_data_shape_part = std::make_shared<ngraph::opset3::Gather>(
-                data_shape,
-                ngraph::opset3::Constant::create(ngraph::element::i64, {second_data_shape_part_indices.size()}, second_data_shape_part_indices),
-                ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0}));
-        second_shape_part.push_back(second_data_shape_part);
+        second_shape_part.push_back(gatherShapeElements(data_shape, axis + 1, data_rank_value - axis - 1));
     }
     for (auto i = 0; i < split_lengths.size(); ++i) {
         const auto dim = ngraph::opset3::Constant::create(data_shape->get_element_type(), {1}, {split_lengths[i]});
index 3e3de8f..9bf6c21 100644 (file)
@@ -5,8 +5,14 @@
 #include "vpu/ngraph/utilities.hpp"
 
 #include "ngraph/opsets/opset3.hpp"
+#include "ngraph/opsets/opset5.hpp"
 #include "ngraph/evaluator.hpp"
 
+#include <numeric>
+
+namespace vpu {
+namespace {
+
 ngraph::HostTensorVector evaluateShapeOf(ngraph::Node* node, const ngraph::HostTensorVector&) {
     auto shapeOf = ngraph::as_type<ngraph::opset3::ShapeOf>(node);
     const auto inputValue = shapeOf->input_value(0);
@@ -39,6 +45,8 @@ ngraph::HostTensorVector evaluateOp(ngraph::Node* node, const ngraph::HostTensor
     return outputTensors;
 }
 
+} // namespace
+
 std::vector<std::int64_t> evaluateTargetShape(const ngraph::Output<ngraph::Node>& value) {
     static ngraph::Evaluator<ngraph::HostTensorPtr>::op_handler_map handlers = {
         {ngraph::opset3::ShapeOf::type_info,  evaluateShapeOf},
@@ -61,7 +69,19 @@ std::vector<std::int64_t> evaluateTargetShape(const ngraph::Output<ngraph::Node>
     return {shapeConstNode->cast_vector<std::int64_t>()};
 }
 
-namespace vpu {
+std::shared_ptr<ngraph::Node> shapeToConstant(const ngraph::element::Type& type, const ngraph::Shape& shape) {
+    return ngraph::opset5::Constant::create(type, {shape.size()}, shape);
+}
+
+std::shared_ptr<ngraph::Node> gatherShapeElements(const ngraph::Output<ngraph::Node>& shape, int startIndex, size_t elemCount) {
+    std::vector<int64_t> shapePart(elemCount);
+    std::iota(shapePart.begin(), shapePart.end(), startIndex);
+
+    return std::make_shared<ngraph::opset5::Gather>(
+        shape,
+        ngraph::opset5::Constant::create(ngraph::element::i64, {elemCount}, shapePart),
+        ngraph::opset5::Constant::create(ngraph::element::i64, {}, {0}));
+}
 
 void printTo(std::ostream& stream, const ngraph::NodeTypeInfo& object) {
     stream << object.name << " ver. " << object.version;
index cbdae00..412b99c 100644 (file)
@@ -15,6 +15,7 @@
 #include <ngraph_functions/utils/ngraph_helpers.hpp>
 #include <vpu/ngraph/transformations/dynamic_to_static_shape.hpp>
 #include <vpu/utils/error.hpp>
+#include <vpu/ngraph/utilities.hpp>
 
 namespace {
 
@@ -102,24 +103,15 @@ protected:
         const auto indices_shape = ngraph::opset3::Constant::create(dims->get_element_type(), {gather_setup.index_shape.size()}, gather_setup.index_shape);
         ngraph::OutputVector output_dims;
         if (gather_setup.first_split_point) {
-            std::vector<int64_t> idxs(gather_setup.first_split_point);
-            std::iota(idxs.begin(), idxs.end(), 0);
-            output_dims.push_back(
-                    std::make_shared<ngraph::opset3::Gather>(
-                            dims,
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {idxs.size()}, idxs),
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0})));
+            output_dims.push_back(vpu::gatherShapeElements(dims, 0, gather_setup.first_split_point));
         }
         if (!gather_setup.index_shape.empty())
             output_dims.push_back(indices_shape);
         if (gather_setup.first_split_point + 1 < gather_setup.data_shape.size()) {
-            std::vector<int64_t> idxs(gather_setup.data_shape.size() - gather_setup.second_split_point);
-            std::iota(idxs.begin(), idxs.end(), gather_setup.second_split_point);
-            output_dims.push_back(
-                    std::make_shared<ngraph::opset3::Gather>(
-                            dims,
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {idxs.size()}, idxs),
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0})));
+            output_dims.push_back(vpu::gatherShapeElements(
+                dims,
+                gather_setup.second_split_point,
+                gather_setup.data_shape.size() - gather_setup.second_split_point));
         }
         const auto output_shape = std::make_shared<ngraph::opset3::Concat>(output_dims, 0);
         const auto dsr1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(node, output_shape);
@@ -192,24 +184,15 @@ protected:
 
         ngraph::OutputVector output_dims;
         if (gather_setup.first_split_point) {
-            std::vector<int64_t> idxs(gather_setup.first_split_point);
-            std::iota(idxs.begin(), idxs.end(), 0);
-            output_dims.push_back(
-                    std::make_shared<ngraph::opset3::Gather>(
-                            data_shape,
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {idxs.size()}, idxs),
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0})));
+            output_dims.push_back(vpu::gatherShapeElements(data_shape, 0, gather_setup.first_split_point));
         }
         if (!gather_setup.index_shape.empty())
             output_dims.push_back(dims);
         if (gather_setup.first_split_point + 1 < gather_setup.data_shape.size()) {
-            std::vector<int64_t> idxs(gather_setup.data_shape.size() - gather_setup.second_split_point);
-            std::iota(idxs.begin(), idxs.end(), gather_setup.second_split_point);
-            output_dims.push_back(
-                    std::make_shared<ngraph::opset3::Gather>(
-                            data_shape,
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {idxs.size()}, idxs),
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0})));
+            output_dims.push_back(vpu::gatherShapeElements(
+                data_shape,
+                gather_setup.second_split_point,
+                gather_setup.data_shape.size() - gather_setup.second_split_point));
         }
         const auto output_shape = std::make_shared<ngraph::opset3::Concat>(output_dims, 0);
         const auto dsr1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(node, output_shape);
@@ -284,24 +267,15 @@ protected:
 
         ngraph::OutputVector output_dims;
         if (gather_setup.first_split_point) {
-            std::vector<int64_t> idxs(gather_setup.first_split_point);
-            std::iota(idxs.begin(), idxs.end(), 0);
-            output_dims.push_back(
-                    std::make_shared<ngraph::opset3::Gather>(
-                            data_dims,
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {idxs.size()}, idxs),
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0})));
+            output_dims.push_back(vpu::gatherShapeElements(data_dims, 0, gather_setup.first_split_point));
         }
         if (!gather_setup.index_shape.empty())
             output_dims.push_back(indices_dims);
         if (gather_setup.first_split_point + 1 < gather_setup.data_shape.size()) {
-            std::vector<int64_t> idxs(gather_setup.data_shape.size() - gather_setup.second_split_point);
-            std::iota(idxs.begin(), idxs.end(), gather_setup.second_split_point);
-            output_dims.push_back(
-                    std::make_shared<ngraph::opset3::Gather>(
-                            data_dims,
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {idxs.size()}, idxs),
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0})));
+            output_dims.push_back(vpu::gatherShapeElements(
+                data_dims,
+                gather_setup.second_split_point,
+                gather_setup.data_shape.size() - gather_setup.second_split_point));
         }
         const auto output_shape = std::make_shared<ngraph::opset3::Concat>(output_dims, 0);
         const auto dsr1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(node, output_shape);
diff --git a/inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_gather_nd.cpp b/inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_gather_nd.cpp
new file mode 100644 (file)
index 0000000..fc9a085
--- /dev/null
@@ -0,0 +1,191 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <common_test_utils/test_common.hpp>
+
+#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
+#include <vpu/ngraph/transformations/dynamic_to_static_shape.hpp>
+#include <vpu/ngraph/transformations/dynamic_to_static_shape_gather_nd.hpp>
+#include <vpu/ngraph/utilities.hpp>
+
+#include <ngraph_functions/utils/ngraph_helpers.hpp>
+
+#include <ngraph/shape.hpp>
+#include <ngraph/type/element_type.hpp>
+#include <ngraph/opsets/opset5.hpp>
+
+#include <numeric>
+
+namespace {
+
+using DataType = ngraph::element::Type_t;
+using DataDims = ngraph::Shape;
+
+using GatherNDInputsSetup = std::tuple<
+    ngraph::ParameterVector, // function parameters
+    std::shared_ptr<ngraph::Node>, // data node
+    std::shared_ptr<ngraph::Node>, // indices node
+    std::shared_ptr<ngraph::Node>, // data shape
+    std::shared_ptr<ngraph::Node>>; // indices shape
+
+enum class GatherNDTestMode {
+    DYNAMIC_DATA,
+    DYNAMIC_INDICES,
+    ALL_DYNAMIC
+};
+
+struct GatherNDTestCase {
+    int64_t batchDims;
+    ngraph::Shape dataShape, indicesShape;
+};
+
+
+const auto combinations = testing::Combine(
+        testing::Values(
+                ngraph::element::f16,
+                ngraph::element::f32,
+                ngraph::element::i32,
+                ngraph::element::i64),
+        testing::Values(
+                ngraph::element::i32,
+                ngraph::element::i64),
+        testing::Values(
+                GatherNDTestCase{0, {1000, 256, 10, 15}, {25, 125, 3}},
+                GatherNDTestCase{2, {30, 2, 100, 35}, {30, 2, 3, 1}},
+                GatherNDTestCase{0, {3, 3}, {2, 2}},
+                GatherNDTestCase{0, {3, 5}, {2, 1}},
+                GatherNDTestCase{0, {4, 3, 6}, {2, 1, 2}},
+                GatherNDTestCase{1, {2, 2, 2, 2}, {2, 1}},
+                GatherNDTestCase{2, {2, 2, 2, 2}, {2, 2, 1}},
+                GatherNDTestCase{0, {1, 22743, 4}, {0, 2}}),
+        testing::Values(GatherNDTestMode::DYNAMIC_DATA, GatherNDTestMode::DYNAMIC_INDICES, GatherNDTestMode::ALL_DYNAMIC));
+
+class DynamicToStaticShapeGatherND : public CommonTestUtils::TestsCommon,
+                                     public testing::WithParamInterface<std::tuple<DataType, DataType, GatherNDTestCase, GatherNDTestMode>> {
+public:
+    void SetUp() override {
+        const auto& parameters = GetParam();
+        const auto& dataType = std::get<0>(parameters);
+        const auto& indicesType = std::get<1>(parameters);
+        const auto& gatherNDSetup = std::get<2>(parameters);
+        const auto& gatherNDTestMode = std::get<3>(parameters);
+
+        ngraph::helpers::CompareFunctions(*transform(dataType, indicesType, gatherNDSetup, gatherNDTestMode),
+                                          *reference(dataType, indicesType, gatherNDSetup, gatherNDTestMode));
+    }
+
+protected:
+    GatherNDInputsSetup setupGatherNDInputs(
+            const ngraph::element::Type_t& dataType,
+            const ngraph::element::Type_t& indicesType,
+            const GatherNDTestCase& gatherNDSetup,
+            const GatherNDTestMode& testMode) const {
+       ngraph::ParameterVector params = {std::make_shared<ngraph::opset5::Parameter>(dataType, gatherNDSetup.dataShape),
+                                         std::make_shared<ngraph::opset5::Parameter>(indicesType, gatherNDSetup.indicesShape)};
+
+       std::shared_ptr<ngraph::Node> inputNode, indicesNode, inputShape, indicesShape;
+
+        if (testMode == GatherNDTestMode::DYNAMIC_DATA || testMode == GatherNDTestMode::ALL_DYNAMIC) {
+            params.push_back(std::make_shared<ngraph::opset5::Parameter>(ngraph::element::i64, ngraph::Shape{gatherNDSetup.dataShape.size()}));
+            inputNode = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(params[0], params.back());
+            inputShape = params.back();
+        } else {
+            inputNode = params[0];
+            inputShape = ngraph::opset5::Constant::create(ngraph::element::i64, {gatherNDSetup.dataShape.size()}, gatherNDSetup.dataShape);
+        }
+
+        if (testMode == GatherNDTestMode::DYNAMIC_INDICES || testMode == GatherNDTestMode::ALL_DYNAMIC) {
+            params.push_back(std::make_shared<ngraph::opset5::Parameter>(ngraph::element::i64, ngraph::Shape{gatherNDSetup.indicesShape.size()}));
+            indicesNode = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(params[1], params.back());
+            indicesShape = params.back();
+        } else {
+            indicesNode = params[1];
+            indicesShape = ngraph::opset5::Constant::create(ngraph::element::i64, {gatherNDSetup.indicesShape.size()}, gatherNDSetup.indicesShape);
+        }
+
+        return GatherNDInputsSetup{params, inputNode, indicesNode, inputShape, indicesShape};
+    }
+
+    std::shared_ptr<const ngraph::Function> transform(
+            const ngraph::element::Type_t& dataType,
+            const ngraph::element::Type_t& indicesType,
+            const GatherNDTestCase& gatherNDSetup,
+            const GatherNDTestMode& testMode) const {
+        const auto gatherNDInputsSetup = setupGatherNDInputs(dataType, indicesType, gatherNDSetup, testMode);
+        const auto& dataNode = std::get<1>(gatherNDInputsSetup);
+        const auto& indicesNode = std::get<2>(gatherNDInputsSetup);
+
+        const auto node = std::make_shared<ngraph::opset5::GatherND>(dataNode, indicesNode, gatherNDSetup.batchDims);
+
+        auto outputShape = node->get_output_partial_shape(0);
+        const auto function = std::make_shared<ngraph::Function>(
+                ngraph::NodeVector{node},
+                std::get<0>(gatherNDInputsSetup),
+                "Actual");
+        node->set_output_type(0, dataType, ngraph::PartialShape::dynamic(1));
+
+        const auto transformations = vpu::Transformations{{node->type_info, vpu::dynamicToStaticShapeGatherND}};
+        vpu::DynamicToStaticShape(transformations).run_on_function(function);
+        return function;
+    }
+
+    std::shared_ptr<const ngraph::Function> reference(
+            const ngraph::element::Type_t& dataType,
+            const ngraph::element::Type_t& indicesType,
+            const GatherNDTestCase& gatherNDSetup,
+            const GatherNDTestMode& testMode) const {
+        const auto gatherNDInputsSetup = setupGatherNDInputs(dataType, indicesType, gatherNDSetup, testMode);
+        const auto& dataNode = std::get<1>(gatherNDInputsSetup);
+        const auto& indicesNode = std::get<2>(gatherNDInputsSetup);
+        const auto& dataShape = std::get<3>(gatherNDInputsSetup);
+        const auto& indicesShape = std::get<4>(gatherNDInputsSetup);
+
+        const auto node = std::make_shared<ngraph::opset5::GatherND>(dataNode, indicesNode, gatherNDSetup.batchDims);
+
+        const auto dataDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(dataNode);
+        const auto indicesDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(indicesNode);
+
+        const auto dataShapeRank = gatherNDSetup.dataShape.size();
+        const auto indicesShapeRank = gatherNDSetup.indicesShape.size();
+
+        std::shared_ptr<ngraph::Node> outputShape;
+
+        if (gatherNDSetup.batchDims > 0) {
+            outputShape = std::make_shared<ngraph::opset5::ReduceProd>(
+                vpu::gatherShapeElements(indicesShape, 0, gatherNDSetup.batchDims),
+                ngraph::opset5::Constant::create(ngraph::element::i64, {}, {0}),
+                true);
+        }
+
+        if (indicesShapeRank - gatherNDSetup.batchDims - 1 > 0) {
+            const auto indicesShapePart = vpu::gatherShapeElements(
+                indicesShape,
+                gatherNDSetup.batchDims,
+                indicesShapeRank - gatherNDSetup.batchDims - 1);
+            outputShape = outputShape ? std::make_shared<ngraph::opset5::Concat>(ngraph::NodeVector{outputShape, indicesShapePart}, 0) : indicesShapePart;
+        }
+
+        const auto lastIndicesDim = node->get_input_partial_shape(1)[indicesShapeRank - 1].get_length();
+        if (gatherNDSetup.batchDims + lastIndicesDim < dataShapeRank) {
+            const auto dataShapePart = vpu::gatherShapeElements(
+                dataShape,
+                lastIndicesDim + gatherNDSetup.batchDims,
+                dataShapeRank - gatherNDSetup.batchDims - lastIndicesDim);
+            outputShape = outputShape ? std::make_shared<ngraph::opset5::Concat>(ngraph::NodeVector{outputShape, dataShapePart}, 0) : dataShapePart;
+        }
+
+        const auto outputDsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(node, outputShape);
+        return std::make_shared<ngraph::Function>(
+                ngraph::NodeVector{outputDsr},
+                std::get<0>(gatherNDInputsSetup),
+                "Expected");
+    }
+};
+
+TEST_P(DynamicToStaticShapeGatherND, CompareFunctions) {
+}
+
+INSTANTIATE_TEST_CASE_P(smoke_NGraph, DynamicToStaticShapeGatherND, combinations);
+
+} // namespace
index 8b1d3aa..d57c556 100644 (file)
@@ -16,6 +16,7 @@
 #include <vpu/ngraph/transformations/dynamic_to_static_shape.hpp>
 #include <vpu/utils/error.hpp>
 #include <vpu/ngraph/operations/static_shape_topk.hpp>
+#include <vpu/ngraph/utilities.hpp>
 
 namespace {
 
@@ -98,22 +99,13 @@ protected:
 
         ngraph::OutputVector first_shape_part, second_shape_part;
         if (topk_setup.first_split_point) {
-            std::vector<int64_t> idxs(topk_setup.first_split_point);
-            std::iota(idxs.begin(), idxs.end(), 0);
-            first_shape_part.push_back(
-                    std::make_shared<ngraph::opset3::Gather>(
-                            dims,
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {idxs.size()}, idxs),
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0})));
+            first_shape_part.push_back(vpu::gatherShapeElements(dims, 0, topk_setup.first_split_point));
         }
         if (topk_setup.first_split_point + 1 < topk_setup.data_shape.size()) {
-            std::vector<int64_t> idxs(topk_setup.data_shape.size() - topk_setup.second_split_point);
-            std::iota(idxs.begin(), idxs.end(), topk_setup.second_split_point);
-            second_shape_part.push_back(
-                    std::make_shared<ngraph::opset3::Gather>(
-                            dims,
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {idxs.size()}, idxs),
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0})));
+            second_shape_part.push_back(vpu::gatherShapeElements(
+                dims,
+                topk_setup.second_split_point,
+                topk_setup.data_shape.size() - topk_setup.second_split_point));
         }
         ngraph::OutputVector results, converted;
         ngraph::Output<ngraph::Node> k_0D = k;
@@ -208,22 +200,13 @@ protected:
 
         ngraph::OutputVector first_shape_part, second_shape_part;
         if (topk_setup.first_split_point) {
-            std::vector<int64_t> idxs(topk_setup.first_split_point);
-            std::iota(idxs.begin(), idxs.end(), 0);
-            first_shape_part.push_back(
-                    std::make_shared<ngraph::opset3::Gather>(
-                            dims,
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {idxs.size()}, idxs),
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0})));
+            first_shape_part.push_back(vpu::gatherShapeElements(dims, 0, topk_setup.first_split_point));
         }
         if (topk_setup.first_split_point + 1 < topk_setup.data_shape.size()) {
-            std::vector<int64_t> idxs(topk_setup.data_shape.size() - topk_setup.second_split_point);
-            std::iota(idxs.begin(), idxs.end(), topk_setup.second_split_point);
-            second_shape_part.push_back(
-                    std::make_shared<ngraph::opset3::Gather>(
-                            dims,
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {idxs.size()}, idxs),
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0})));
+            second_shape_part.push_back(vpu::gatherShapeElements(
+                dims,
+                topk_setup.second_split_point,
+                topk_setup.data_shape.size() - topk_setup.second_split_point));
         }
         ngraph::Output<ngraph::Node> k_0D = k;
         if (node->get_input_element_type(1)!= ngraph::element::i64) {
index 79e913d..4ca9e9c 100644 (file)
@@ -15,6 +15,7 @@
 #include <ngraph_functions/utils/ngraph_helpers.hpp>
 #include <vpu/ngraph/transformations/dynamic_to_static_shape.hpp>
 #include <vpu/utils/error.hpp>
+#include <vpu/ngraph/utilities.hpp>
 
 namespace {
 
@@ -102,22 +103,13 @@ protected:
 
         ngraph::OutputVector first_shape_part, second_shape_part;
         if (variadic_split_setup.first_split_point) {
-            std::vector<int64_t> idxs(variadic_split_setup.first_split_point);
-            std::iota(idxs.begin(), idxs.end(), 0);
-            first_shape_part.push_back(
-                    std::make_shared<ngraph::opset3::Gather>(
-                            dims,
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {idxs.size()}, idxs),
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0})));
+            first_shape_part.push_back(vpu::gatherShapeElements(dims, 0, variadic_split_setup.first_split_point));
         }
         if (variadic_split_setup.first_split_point + 1 < variadic_split_setup.data_shape.size()) {
-            std::vector<int64_t> idxs(variadic_split_setup.data_shape.size() - variadic_split_setup.second_split_point);
-            std::iota(idxs.begin(), idxs.end(), variadic_split_setup.second_split_point);
-            second_shape_part.push_back(
-                    std::make_shared<ngraph::opset3::Gather>(
-                            dims,
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {idxs.size()}, idxs),
-                            ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {0})));
+            second_shape_part.push_back(vpu::gatherShapeElements(
+                dims,
+                variadic_split_setup.second_split_point,
+                variadic_split_setup.data_shape.size() - variadic_split_setup.second_split_point));
         }
         ngraph::NodeVector results;
         for (auto i = 0; i < variadic_split_setup.split_lengths.size(); ++i) {
index f1d64f0..759ea15 100644 (file)
@@ -29,5 +29,7 @@ std::vector<std::string> disabledTestPatterns() {
         R"(.*(ConstantResultSubgraphTest).*)",
         // TODO: Issue: 42828
         R"(.*DSR_NonMaxSuppression.*NBoxes=(5|20|200).*)",
+        // TODO: Issue: 42721
+        R"(.*(DSR_GatherND).*)",
     };
 }
diff --git a/inference-engine/tests/functional/plugin/myriad/subgraph_tests/dsr_gather_nd.cpp b/inference-engine/tests/functional/plugin/myriad/subgraph_tests/dsr_gather_nd.cpp
new file mode 100644 (file)
index 0000000..4821654
--- /dev/null
@@ -0,0 +1,153 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "dsr_tests_common.hpp"
+
+#include <functional_test_utils/layer_test_utils.hpp>
+#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
+
+namespace {
+
+using namespace LayerTestsUtils::vpu;
+
+const std::vector<ngraph::element::Type> dataTypeVector = {
+    ngraph::element::f16,
+    ngraph::element::f32,
+    ngraph::element::i32,
+};
+
+const std::vector<ngraph::element::Type> idxTypeVector = {
+    ngraph::element::i32,
+};
+
+struct GatherNDTestCase {
+    DataShapeWithUpperBound dataShape;
+    DataShapeWithUpperBound indicesShape;
+    int64_t batchDims;
+};
+
+using GatherNDParameters = std::tuple<
+    DataType,                     // data type
+    DataType,                     // indices type
+    GatherNDTestCase,             // GatherND parameters
+    LayerTestsUtils::TargetDevice // device name
+>;
+
+class DSR_GatherNDBase : public testing::WithParamInterface<GatherNDParameters>,
+                         public DSR_TestsCommon {
+protected:
+    std::set<std::string> m_indicesInputNames;
+
+    InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo& info) const override {
+        const auto& name = info.name();
+        if (m_indicesInputNames.count(name)) {
+            const auto& parameters = GetParam();
+            const auto& gatherSetup = std::get<2>(parameters);
+            const auto& lastIndicesDim = gatherSetup.indicesShape.shape.back();
+
+            const auto endValue = std::min_element(gatherSetup.dataShape.shape.begin() + gatherSetup.batchDims,
+                 gatherSetup.dataShape.shape.begin() + gatherSetup.batchDims + lastIndicesDim);
+
+            return FuncTestUtils::createAndFillBlob(info.getTensorDesc(), *endValue, 0);
+        }
+        return DSR_TestsCommon::GenerateInput(info);
+    }
+
+    void SetUp() override {
+        DSR_TestsCommon::SetUp();
+        SetRefMode(LayerTestsUtils::RefMode::INTERPRETER);
+    }
+};
+
+class DSR_GatherNDDynamicDataStaticIdx : public DSR_GatherNDBase {
+protected:
+    std::shared_ptr<ngraph::Node> createTestedOp() override {
+        const auto& parameters = GetParam();
+        const auto& inDataType = std::get<0>(parameters);
+        const auto& idxType = std::get<1>(parameters);
+        const auto& gatherSetup = std::get<2>(parameters);
+        targetDevice = std::get<3>(parameters);
+
+        const auto inputDataSubgraph = createInputSubgraphWithDSR(inDataType, gatherSetup.dataShape);
+
+        const auto indicesParam = createParameter(idxType, gatherSetup.indicesShape.shape);
+        m_indicesInputNames.insert(indicesParam->get_friendly_name());
+
+        return std::make_shared<ngraph::opset5::GatherND>(inputDataSubgraph, indicesParam, gatherSetup.batchDims);
+    }
+};
+
+TEST_P(DSR_GatherNDDynamicDataStaticIdx, CompareWithReference) {
+    Run();
+}
+
+INSTANTIATE_TEST_CASE_P(smoke_DynamicGatherData, DSR_GatherNDDynamicDataStaticIdx, testing::Combine(
+    testing::ValuesIn(dataTypeVector),
+    testing::ValuesIn(idxTypeVector),
+    testing::Values(
+          GatherNDTestCase{DataShapeWithUpperBound{{1, 1000, 4}, {1, 22734, 4}}, DataShapeWithUpperBound{{300, 2}, {}}, 0},
+          GatherNDTestCase{DataShapeWithUpperBound{{1, 500, 4}, {1, 22734, 4}}, DataShapeWithUpperBound{{300, 2}, {}}, 0}),
+    testing::Values(CommonTestUtils::DEVICE_MYRIAD)));
+
+
+class DSR_GatherNDStaticDataDynamicIdx : public DSR_GatherNDBase {
+protected:
+    std::shared_ptr<ngraph::Node> createTestedOp() override {
+        const auto& parameters = GetParam();
+        const auto& inDataType = std::get<0>(parameters);
+        const auto& idxType = std::get<1>(parameters);
+        const auto& gatherSetup = std::get<2>(parameters);
+        targetDevice = std::get<3>(parameters);
+
+        const auto dataParam = createParameter(inDataType, gatherSetup.dataShape.shape);
+        const auto inputIdxSubgraph = createInputSubgraphWithDSR(idxType, gatherSetup.indicesShape);
+        m_indicesInputNames.insert(inputIdxSubgraph->get_input_node_shared_ptr(0)->get_friendly_name());
+
+        return std::make_shared<ngraph::opset5::GatherND>(dataParam, inputIdxSubgraph, gatherSetup.batchDims);
+    }
+};
+
+TEST_P(DSR_GatherNDStaticDataDynamicIdx, CompareWithReference) {
+    Run();
+}
+
+INSTANTIATE_TEST_CASE_P(smoke_DynamicGatherIdx, DSR_GatherNDStaticDataDynamicIdx, testing::Combine(
+    testing::ValuesIn(dataTypeVector),
+    testing::ValuesIn(idxTypeVector),
+    testing::Values(
+        GatherNDTestCase{DataShapeWithUpperBound{{1, 22734, 4}, {}}, DataShapeWithUpperBound{{100, 2}, {300, 2}}, 0},
+        GatherNDTestCase{DataShapeWithUpperBound{{1, 22734, 4}, {}}, DataShapeWithUpperBound{{1, 2}, {300, 2}}, 0}),
+    testing::Values(CommonTestUtils::DEVICE_MYRIAD)));
+
+
+class DSR_GatherNDDynamicDataDynamicIdx : public DSR_GatherNDBase {
+protected:
+    std::shared_ptr<ngraph::Node> createTestedOp() override {
+        const auto& parameters = GetParam();
+        const auto& inDataType = std::get<0>(parameters);
+        const auto& idxType = std::get<1>(parameters);
+        const auto& gatherSetup = std::get<2>(parameters);
+        targetDevice = std::get<3>(parameters);
+
+        const auto inputDataSubgraph = createInputSubgraphWithDSR(inDataType, gatherSetup.dataShape);
+        const auto inputIdxSubgraph = createInputSubgraphWithDSR(idxType, gatherSetup.indicesShape);
+        m_indicesInputNames.insert(inputIdxSubgraph->get_input_node_shared_ptr(0)->get_friendly_name());
+
+        return std::make_shared<ngraph::opset5::GatherND>(inputDataSubgraph, inputIdxSubgraph, gatherSetup.batchDims);
+    }
+};
+
+TEST_P(DSR_GatherNDDynamicDataDynamicIdx, CompareWithReference) {
+    Run();
+}
+
+INSTANTIATE_TEST_CASE_P(smoke_DynamicGather, DSR_GatherNDDynamicDataDynamicIdx, testing::Combine(
+    testing::ValuesIn(dataTypeVector),
+    testing::ValuesIn(idxTypeVector),
+    testing::Values(
+            GatherNDTestCase{DataShapeWithUpperBound{{1, 1000, 4}, {1, 22734, 4}}, DataShapeWithUpperBound{{100, 2}, {300, 2}}, 0},
+            GatherNDTestCase{DataShapeWithUpperBound{{1, 500, 4}, {1, 22734, 4}}, DataShapeWithUpperBound{{1, 2}, {300, 2}}, 0}),
+    testing::Values(CommonTestUtils::DEVICE_MYRIAD)));
+
+}  // namespace