[NGRAPH] Fix ReduceSum decompose pass
authorAlexander Peskov <alexander.peskov@intel.com>
Mon, 31 Aug 2020 12:50:20 +0000 (15:50 +0300)
committerAlexander Peskov <alexander.peskov@intel.com>
Wed, 9 Sep 2020 09:41:31 +0000 (12:41 +0300)
Signed-off-by: Alexander Peskov <alexander.peskov@intel.com>
inference-engine/src/transformations/include/transformations/convert_reduce_to_pooling.hpp

index 94386d0..23dffdb 100644 (file)
@@ -231,21 +231,34 @@ ngraph::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
             input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
             new_ops.push_back(input.get_node_shared_ptr());
         } else if (std::is_same<T, ngraph::opset1::ReduceSum>()) {
+            // Fallback to real type because of potential data loss in case of integer AVG Pool
+            bool fallback_to_real = input.get_element_type().is_integral();
+
+            if (fallback_to_real) {
+                input = std::make_shared<ngraph::opset1::Convert>(input, ngraph::element::f32);
+                new_ops.push_back(input.get_node_shared_ptr());
+            }
+
             input = std::make_shared<ngraph::opset1::AvgPool>(input,
-                                                              strides,
-                                                              pads_begin,
-                                                              pads_end,
-                                                              kernel,
-                                                              true,
-                                                              ngraph::op::RoundingType::FLOOR);
+                    strides,
+                    pads_begin,
+                    pads_end,
+                    kernel,
+                    true,
+                    ngraph::op::RoundingType::FLOOR);
 
             input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
             new_ops.push_back(input.get_node_shared_ptr());
 
             input = std::make_shared<ngraph::opset1::Multiply>(input,
-                    ngraph::opset1::Constant::create(reduce->input(0).get_element_type(), ngraph::Shape{1}, {reduction_dims_count}));
+                    ngraph::opset1::Constant::create(input.get_element_type(), ngraph::Shape{1}, {reduction_dims_count}));
             input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/mul");
             new_ops.push_back(input.get_node_shared_ptr());
+
+            if (fallback_to_real) {
+                input = std::make_shared<ngraph::opset1::Convert>(input, reduce->output(0).get_element_type());
+                new_ops.push_back(input.get_node_shared_ptr());
+            }
         } else {
             return false;
         }