[IE][VPU] Fix NMS DTS (#2880)
authorAndrew Bakalin <andrew.bakalin@intel.com>
Thu, 5 Nov 2020 10:33:16 +0000 (13:33 +0300)
committerGitHub <noreply@github.com>
Thu, 5 Nov 2020 10:33:16 +0000 (13:33 +0300)
Add a new constructor to fix absent NMS-5 inputs that will be introduced after #2450 will be merged.

inference-engine/src/vpu/common/include/vpu/ngraph/operations/static_shape_non_maximum_suppression.hpp
inference-engine/src/vpu/common/src/ngraph/operations/static_shape_non_maximum_suppression.cpp
inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_non_max_suppression.cpp

index acc0b89..a31115d 100644 (file)
@@ -6,6 +6,7 @@
 
 #include <ngraph/node.hpp>
 #include <legacy/ngraph_ops/nms_ie.hpp>
+#include <ngraph/opsets/opset5.hpp>
 
 #include <memory>
 #include <vector>
@@ -17,6 +18,8 @@ public:
     static constexpr NodeTypeInfo type_info{"StaticShapeNonMaxSuppression", 0};
     const NodeTypeInfo& get_type_info() const override { return type_info; }
 
+    explicit StaticShapeNonMaxSuppression(const ngraph::opset5::NonMaxSuppression& nms);
+
     StaticShapeNonMaxSuppression(const Output<Node>& boxes,
                                  const Output<Node>& scores,
                                  const Output<Node>& maxOutputBoxesPerClass,
index 94e1a96..18c8b3b 100644 (file)
@@ -12,6 +12,18 @@ namespace ngraph { namespace vpu { namespace op {
 
 constexpr NodeTypeInfo StaticShapeNonMaxSuppression::type_info;
 
+StaticShapeNonMaxSuppression::StaticShapeNonMaxSuppression(const ngraph::opset5::NonMaxSuppression& nms)
+        : StaticShapeNonMaxSuppression(
+        nms.input_value(0),
+        nms.input_value(1),
+        nms.get_input_size() > 2 ? nms.input_value(2) : ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}),
+        nms.get_input_size() > 3 ? nms.input_value(3) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}),
+        nms.get_input_size() > 4 ? nms.input_value(4) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}),
+        nms.get_input_size() > 5 ? nms.input_value(5) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}),
+        nms.get_box_encoding() == ngraph::opset5::NonMaxSuppression::BoxEncodingType::CENTER ? 1 : 0,
+        nms.get_sort_result_descending(),
+        nms.get_output_type()) {}
+
 StaticShapeNonMaxSuppression::StaticShapeNonMaxSuppression(
         const Output<Node>& boxes,
         const Output<Node>& scores,
index 02145ca..1a36939 100644 (file)
@@ -21,16 +21,7 @@ void dynamicToStaticNonMaxSuppression(std::shared_ptr<ngraph::Node> node) {
     VPU_THROW_UNLESS(nms, "dynamicToStaticNonMaxSuppression transformation for {} of type {} expects {} as node for replacement",
                      node->get_friendly_name(), node->get_type_info(), ngraph::opset5::NonMaxSuppression::type_info);
 
-    auto staticShapeNMS = std::make_shared<ngraph::vpu::op::StaticShapeNonMaxSuppression>(
-            nms->input_value(0),
-            nms->input_value(1),
-            nms->input_value(2),
-            nms->input_value(3),
-            nms->input_value(4),
-            nms->input_value(5),
-            nms->get_box_encoding() == ngraph::opset5::NonMaxSuppression::BoxEncodingType::CENTER ? 1 : 0,
-            nms->get_sort_result_descending(),
-            nms->get_output_type());
+    auto staticShapeNMS = std::make_shared<ngraph::vpu::op::StaticShapeNonMaxSuppression>(*nms);
 
     auto dsrIndices = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
             staticShapeNMS->output(0), staticShapeNMS->output(2));