Fixed Precision Conversion for GenericIE operation type (#1917)
authorGleb Kazantaev <gleb.kazantaev@intel.com>
Wed, 26 Aug 2020 08:31:40 +0000 (11:31 +0300)
committerGitHub <noreply@github.com>
Wed, 26 Aug 2020 08:31:40 +0000 (11:31 +0300)
* Fixed Precision Conversion for GenericIE operation type

* changed logic for setting output type in GenericIE

inference-engine/src/inference_engine/generic_ie.cpp
inference-engine/src/transformations/src/transformations/convert_precision.cpp

index 9e83e41..2373779 100644 (file)
@@ -72,7 +72,17 @@ std::shared_ptr<ngraph::Node> ngraph::op::GenericIE::clone_with_new_inputs(const
 }
 
 void ngraph::op::GenericIE::validate_and_infer_types() {
-    // Try to find extension with shape inference inplementation and apply it
+    // This function returns precision based on existing precision and
+    // precision that was set in outputs vector
+    auto get_precision = [this](const size_t index) -> element::Type {
+        if (index >= get_output_size() ||
+            get_output_element_type(index) == element::dynamic ||
+            get_output_element_type(index) == element::undefined) {
+            return InferenceEngine::details::convertPrecision(outputs[index].precision);
+        }
+        return get_output_element_type(index);
+    };
+    // Try to find extension with shape inference implementation and apply it
     for (const auto& ext : extensions) {
         IE_SUPPRESS_DEPRECATED_START
         InferenceEngine::IShapeInferImpl::Ptr impl;
@@ -89,10 +99,8 @@ void ngraph::op::GenericIE::validate_and_infer_types() {
 
             if (!this_input_shape.is_static()) {
                 // Set dynamic output shapes if input shapes are not defined
-                for (size_t i = 0; i < outputs.size(); i++) {
-                    const auto& port = outputs[i];
-                    auto type = InferenceEngine::details::convertPrecision(port.precision);
-                    set_output_type(i, type, PartialShape::dynamic());
+                for (size_t output_index = 0; output_index < outputs.size(); output_index++) {
+                    set_output_type(output_index, get_precision(output_index), PartialShape::dynamic());
                 }
                 return;
             }
@@ -131,13 +139,9 @@ void ngraph::op::GenericIE::validate_and_infer_types() {
 
         if (ret != InferenceEngine::StatusCode::OK || outShapes.size() != outputs.size()) continue;
 
-        for (size_t i = 0; i < outputs.size(); i++) {
-            const auto& port = outputs[i];
-            ngraph::Shape outShape(outShapes[i]);
-            auto type = InferenceEngine::details::convertPrecision(port.precision);
-            set_output_type(i, type, PartialShape(outShape));
+        for (size_t output_index = 0; output_index < outputs.size(); output_index++) {
+            set_output_type(output_index, get_precision(output_index), Shape(outShapes[output_index]));
         }
-
         return;
     }
 
@@ -146,11 +150,8 @@ void ngraph::op::GenericIE::validate_and_infer_types() {
     if (initialized < 1) {
         if (outputs.size())
             set_output_size(outputs.size());
-        for (size_t i = 0; i < outputs.size(); i++) {
-            const auto& port = outputs[i];
-            ngraph::Shape outShape(port.dims);
-            auto type = InferenceEngine::details::convertPrecision(port.precision);
-            set_output_type(i, type, PartialShape(outShape));
+        for (size_t output_index = 0; output_index < outputs.size(); output_index++) {
+            set_output_type(output_index, get_precision(output_index), Shape(outputs[output_index].dims));
         }
         initialized++;
     } else if (reshape) {
index 830085f..477011d 100644 (file)
@@ -127,9 +127,10 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Fun
                     break;
                 }
 
-                // If node type in map and convert can be fused into node we skip Convert creation
+                // Check that node type exists in map and we can fuse type into node
                 if (type_to_fuse.count(node->get_type_info()) &&
                     type_to_fuse.at(node->get_type_info())(node, m_to, output.get_index())) {
+                    // We need to break if original node was replaced
                     break;
                 }
             }
@@ -261,7 +262,8 @@ bool fuse_type_to_bucketize(std::shared_ptr<ngraph::Node> & node, ngraph::elemen
 
 bool fuse_type_to_generic_ie(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
     node->set_output_type(idx, to, node->output(idx).get_partial_shape());
-    return true;
+    // return false as we do not replace original node
+    return false;
 }
 
 bool fuse_type_to_shapeof_v0(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {