Fixes for Mask-RCNN conversion (#654)
authorEvgeny Lazarev <evgeny.lazarev@intel.com>
Thu, 28 May 2020 11:31:42 +0000 (14:31 +0300)
committerGitHub <noreply@github.com>
Thu, 28 May 2020 11:31:42 +0000 (14:31 +0300)
* Fixed ONNX Mask-RCNN conversion

* Fixed validate_and_infet_types for NMS ops: added check for number of connected inputs

* Updated NMS ops to properly handle optional input with index 2

* Fixed typo in the implementation

model-optimizer/extensions/front/onnx/mask_rcnn_conversion.py
ngraph/src/ngraph/op/non_max_suppression.cpp

index 4b1702d..ee26bad 100644 (file)
@@ -16,6 +16,8 @@
 
 import numpy as np
 
+from extensions.front.onnx.softmaxONNX_to_softmax import SoftmaxONNXFrontReplacer
+from extensions.ops.Cast import Cast
 from extensions.ops.detectionoutput_onnx import ExperimentalDetectronDetectionOutput
 from extensions.ops.parameter import Parameter
 from extensions.ops.roifeatureextractor_onnx import ExperimentalDetectronROIFeatureExtractor
@@ -29,7 +31,7 @@ from mo.ops.reshape import Reshape
 input_fpn_heads = ('486', '454', '422', '390')
 
 
-class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral):
+class ONNXMaskRCNNTransformation(FrontReplacementFromConfigFileGeneral):
     """
     This transformation performs 3 actions:
     1. Replaces a sub-graph calculating ROIAlign over FPN heads with a single ExperimentalDetectronROIFeatureExtractor
@@ -42,6 +44,11 @@ class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral)
     """
     replacement_id = 'ONNXMaskRCNNReplacement'
 
+    def run_before(self):
+        # the node "2774" which is used in this transformation is of op SoftMaxONNX. But operations of op SoftMaxONNX
+        # will be replaced with a transformation SoftmaxONNXFrontReplacer
+        return [SoftmaxONNXFrontReplacer]
+
     def transform_graph(self, graph: Graph, replacement_descriptions: dict):
         insert_ExperimentalDetectronROIFeatureExtractor2(graph)
         insert_do(graph, replacement_descriptions)
@@ -80,6 +87,9 @@ def insert_do(graph: Graph, replacement_descriptions):
     old_do_output_nodes = [Node(graph, node_id) for node_id in do_outputs]
     for old_node, new_port in zip(old_do_output_nodes, do_output_ports):
         old_node.out_port(0).get_connection().set_source(new_port)
+    # the consumer of the second output port of the ExperimentalDetectronDetectionOutput is the Mul node which second
+    # input is of type int64 so it is necessary to insert Cast to have data types match
+    do_node.out_port(1).get_connection().insert_node(Cast(graph, {'dst_type': np.int64}).create_node())
 
 
 def insert_ExperimentalDetectronROIFeatureExtractor1(graph: Graph):
index 073d718..17c9cb7 100644 (file)
@@ -61,7 +61,7 @@ shared_ptr<Node>
 {
     check_new_args_count(this, new_args);
     NODE_VALIDATION_CHECK(
-        this, new_args.size() >= 3 && new_args.size() <= 5, "Number of inputs must be 3, 4 or 5");
+        this, new_args.size() >= 2 && new_args.size() <= 5, "Number of inputs must be 2, 3, 4 or 5");
     if (new_args.size() == 5)
     {
         return make_shared<op::v1::NonMaxSuppression>(new_args.at(0),
@@ -83,7 +83,7 @@ shared_ptr<Node>
             m_box_encoding,
             m_sort_result_descending);
     }
-    else
+    else if (new_args.size() == 3)
     {
         return make_shared<op::v1::NonMaxSuppression>(
             new_args.at(0),
@@ -94,6 +94,17 @@ shared_ptr<Node>
             m_box_encoding,
             m_sort_result_descending);
     }
+    else
+    {
+        return make_shared<op::v1::NonMaxSuppression>(
+            new_args.at(0),
+            new_args.at(1),
+            op::Constant::create(element::i32, Shape{}, {0}),
+            op::Constant::create(element::f32, Shape{}, {.0f}),
+            op::Constant::create(element::f32, Shape{}, {.0f}),
+            m_box_encoding,
+            m_sort_result_descending);
+    }
 }
 
 bool ngraph::op::v1::NonMaxSuppression::visit_attributes(AttributeVisitor& visitor)
@@ -133,24 +144,30 @@ void op::v1::NonMaxSuppression::validate_and_infer_types()
                           "Expected a 3D tensor for the 'scores' input. Got: ",
                           scores_ps);
 
-    const auto max_boxes_ps = get_input_partial_shape(2);
-    NODE_VALIDATION_CHECK(this,
-                          max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()),
-                          "Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
-                          max_boxes_ps);
+    if (get_inputs().size() >= 3) {
+        const auto max_boxes_ps = get_input_partial_shape(2);
+        NODE_VALIDATION_CHECK(this,
+                              max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()),
+                              "Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
+                              max_boxes_ps);
+    }
 
-    const auto iou_threshold_ps = get_input_partial_shape(3);
-    NODE_VALIDATION_CHECK(this,
-                          iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
-                          "Expected a scalar for the 'iou_threshold' input. Got: ",
-                          iou_threshold_ps);
+    if (get_inputs().size() >= 4) {
+        const auto iou_threshold_ps = get_input_partial_shape(3);
+        NODE_VALIDATION_CHECK(this,
+                              iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
+                              "Expected a scalar for the 'iou_threshold' input. Got: ",
+                              iou_threshold_ps);
+    }
 
-    const auto score_threshold_ps = get_input_partial_shape(4);
-    NODE_VALIDATION_CHECK(this,
-                          score_threshold_ps.is_dynamic() ||
-                              is_scalar(score_threshold_ps.to_shape()),
-                          "Expected a scalar for the 'score_threshold' input. Got: ",
-                          score_threshold_ps);
+    if (get_inputs().size() >= 5) {
+        const auto score_threshold_ps = get_input_partial_shape(4);
+        NODE_VALIDATION_CHECK(this,
+                              score_threshold_ps.is_dynamic() ||
+                                  is_scalar(score_threshold_ps.to_shape()),
+                              "Expected a scalar for the 'score_threshold' input. Got: ",
+                              score_threshold_ps);
+    }
 
     const auto num_batches_boxes = boxes_ps[0];
     const auto num_batches_scores = scores_ps[0];
@@ -268,7 +285,7 @@ shared_ptr<Node>
 {
     check_new_args_count(this, new_args);
     NODE_VALIDATION_CHECK(
-        this, new_args.size() >= 3 && new_args.size() <= 5, "Number of inputs must be 3, 4 or 5");
+        this, new_args.size() >= 2 && new_args.size() <= 5, "Number of inputs must be 2, 3, 4 or 5");
     if (new_args.size() == 5)
     {
         return make_shared<op::v3::NonMaxSuppression>(new_args.at(0),
@@ -292,7 +309,7 @@ shared_ptr<Node>
             m_sort_result_descending,
             m_output_type);
     }
-    else
+    else if (new_args.size() == 3)
     {
         return make_shared<op::v3::NonMaxSuppression>(
             new_args.at(0),
@@ -301,6 +318,17 @@ shared_ptr<Node>
             op::Constant::create(element::f32, Shape{}, {.0f}),
             op::Constant::create(element::f32, Shape{}, {.0f}),
             m_box_encoding,
+            m_sort_result_descending);
+    }
+    else
+    {
+        return make_shared<op::v3::NonMaxSuppression>(
+            new_args.at(0),
+            new_args.at(1),
+            op::Constant::create(element::i32, Shape{}, {0}),
+            op::Constant::create(element::f32, Shape{}, {.0f}),
+            op::Constant::create(element::f32, Shape{}, {.0f}),
+            m_box_encoding,
             m_sort_result_descending,
             m_output_type);
     }
@@ -343,24 +371,30 @@ void op::v3::NonMaxSuppression::validate_and_infer_types()
                           "Expected a 3D tensor for the 'scores' input. Got: ",
                           scores_ps);
 
-    const auto max_boxes_ps = get_input_partial_shape(2);
-    NODE_VALIDATION_CHECK(this,
-                          max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()),
-                          "Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
-                          max_boxes_ps);
+    if (get_inputs().size() >= 3) {
+        const auto max_boxes_ps = get_input_partial_shape(2);
+        NODE_VALIDATION_CHECK(this,
+                              max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()),
+                              "Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
+                              max_boxes_ps);
+    }
 
-    const auto iou_threshold_ps = get_input_partial_shape(3);
-    NODE_VALIDATION_CHECK(this,
-                          iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
-                          "Expected a scalar for the 'iou_threshold' input. Got: ",
-                          iou_threshold_ps);
+    if (get_inputs().size() >= 4) {
+        const auto iou_threshold_ps = get_input_partial_shape(3);
+        NODE_VALIDATION_CHECK(this,
+                              iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
+                              "Expected a scalar for the 'iou_threshold' input. Got: ",
+                              iou_threshold_ps);
+    }
 
-    const auto score_threshold_ps = get_input_partial_shape(4);
-    NODE_VALIDATION_CHECK(this,
-                          score_threshold_ps.is_dynamic() ||
-                              is_scalar(score_threshold_ps.to_shape()),
-                          "Expected a scalar for the 'score_threshold' input. Got: ",
-                          score_threshold_ps);
+    if (get_inputs().size() >= 5) {
+        const auto score_threshold_ps = get_input_partial_shape(4);
+        NODE_VALIDATION_CHECK(this,
+                              score_threshold_ps.is_dynamic() ||
+                                  is_scalar(score_threshold_ps.to_shape()),
+                              "Expected a scalar for the 'score_threshold' input. Got: ",
+                              score_threshold_ps);
+    }
 
     const auto num_batches_boxes = boxes_ps[0];
     const auto num_batches_scores = scores_ps[0];