input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
new_ops.push_back(input.get_node_shared_ptr());
} else if (std::is_same<T, ngraph::opset1::ReduceSum>()) {
+ // Fallback to real type because of potential data loss in case of integer AVG Pool
+ bool fallback_to_real = input.get_element_type().is_integral();
+
+ if (fallback_to_real) {
+ input = std::make_shared<ngraph::opset1::Convert>(input, ngraph::element::f32);
+ new_ops.push_back(input.get_node_shared_ptr());
+ }
+
input = std::make_shared<ngraph::opset1::AvgPool>(input,
- strides,
- pads_begin,
- pads_end,
- kernel,
- true,
- ngraph::op::RoundingType::FLOOR);
+ strides,
+ pads_begin,
+ pads_end,
+ kernel,
+ true,
+ ngraph::op::RoundingType::FLOOR);
input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
new_ops.push_back(input.get_node_shared_ptr());
input = std::make_shared<ngraph::opset1::Multiply>(input,
- ngraph::opset1::Constant::create(reduce->input(0).get_element_type(), ngraph::Shape{1}, {reduction_dims_count}));
+ ngraph::opset1::Constant::create(input.get_element_type(), ngraph::Shape{1}, {reduction_dims_count}));
input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/mul");
new_ops.push_back(input.get_node_shared_ptr());
+
+ if (fallback_to_real) {
+ input = std::make_shared<ngraph::opset1::Convert>(input, reduce->output(0).get_element_type());
+ new_ops.push_back(input.get_node_shared_ptr());
+ }
} else {
return false;
}