From e6a36123db78eea265bafbe0276ad6abf09218c1 Mon Sep 17 00:00:00 2001 From: Vladimir Gavrilov Date: Thu, 3 Sep 2020 19:34:35 +0300 Subject: [PATCH] Reverted conversion of Resize-10 to Interpolate-4 in ONNX Importer. (#2048) * Reverted conversion of Resize-10 to Interpolate-4. Now Resize with opset version < 11 generates Interpolate-1 again. * Corrected tests. --- ngraph/frontend/onnx_import/src/op/resize.cpp | 53 ++++++++++++++++++---- ngraph/test/onnx/onnx_import.in.cpp | 4 +- ngraph/test/runtime/interpreter/unit_test.manifest | 3 ++ 3 files changed, 49 insertions(+), 11 deletions(-) diff --git a/ngraph/frontend/onnx_import/src/op/resize.cpp b/ngraph/frontend/onnx_import/src/op/resize.cpp index 9fad34b..601f507 100644 --- a/ngraph/frontend/onnx_import/src/op/resize.cpp +++ b/ngraph/frontend/onnx_import/src/op/resize.cpp @@ -224,6 +224,44 @@ namespace ngraph return scales; } + + OutputVector build_resize(const Node& node, + const std::shared_ptr& output_shape, + const AxisSet& axes) + { + const auto mode = node.get_attribute_value("mode", "nearest"); + + std::unordered_set supported_modes = {"nearest", "linear"}; + bool is_mode_supported = + (std::find(supported_modes.begin(), supported_modes.end(), mode) != + supported_modes.end()); + + if (!is_mode_supported) + { + std::string supported_modes_str = ""; + for (const auto& mode_name : supported_modes) + { + supported_modes_str += (mode_name + ", "); + } + CHECK_VALID_NODE(node, + is_mode_supported, + mode, + " - this type of interpolation mode is not supported." + " Choose one of the following modes: ", + supported_modes_str); + } + + auto attrs = ngraph::op::v0::InterpolateAttrs(); + attrs.axes = axes; + attrs.mode = mode; + attrs.align_corners = false; + + const auto inputs = node.get_ng_inputs(); + const auto& data = inputs.at(0); + + return { + std::make_shared(data, output_shape, attrs)}; + } } namespace set_11 @@ -279,20 +317,17 @@ namespace ngraph const auto& data_shape = data.get_partial_shape(); const auto& scales_shape = scales.get_partial_shape(); - auto attrs = get_resize_attrs(node); - if (attrs.mode == InterpolateMode::linear_onnx) - { - attrs.coordinate_transformation_mode = Transform_mode::asymmetric; - } - CHECK_VALID_NODE( node, (scales_shape.is_static() || data_shape.rank().is_static()), - " Data rank or shape of Scales input is required to be static."); + " Data rank or shape of scales input is required to be static."); + + size_t axes_size = scales_shape.is_static() ? scales_shape[0].get_length() + : data_shape.rank().get_length(); const auto output_shape = calculate_output_shape_based_on_scales(data, scales); - return {std::make_shared( - data, output_shape, scales, attrs)}; + return build_resize( + node, output_shape, AxisSet(common::get_monotonic_range(axes_size))); } } // namespace set_1 diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp index 9607101..860e894 100644 --- a/ngraph/test/onnx/onnx_import.in.cpp +++ b/ngraph/test/onnx/onnx_import.in.cpp @@ -1077,8 +1077,8 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_resize10_import_only) Shape expected_output_shape{4, 6, 6, 4}; EXPECT_EQ(resize_fn->get_output_size(), 1); EXPECT_EQ(resize_fn->get_output_shape(0), expected_output_shape); - EXPECT_EQ(count_ops_of_type(resize_fn), 1); - EXPECT_EQ(count_ops_of_type(resize_fn), 2); + EXPECT_EQ(count_ops_of_type(resize_fn), 1); + EXPECT_EQ(count_ops_of_type(resize_fn), 1); } NGRAPH_TEST(${BACKEND_NAME}, onnx_resize10_down_scales_const_nearest) diff --git a/ngraph/test/runtime/interpreter/unit_test.manifest b/ngraph/test/runtime/interpreter/unit_test.manifest index eb1a78b..cb84aca 100644 --- a/ngraph/test/runtime/interpreter/unit_test.manifest +++ b/ngraph/test/runtime/interpreter/unit_test.manifest @@ -14,6 +14,9 @@ reduce_sum_keep_large_1d_to_scalar INTERPRETER.onnx_resize11_scales_nearest_asymmetric_floor_dynamic_sizes INTERPRETER.onnx_resize11_scales_down_linear INTERPRETER.interpolate_down_scales_const_linear +INTERPRETER.onnx_resize10_up_scales_const_nearest +INTERPRETER.onnx_resize10_up_scales_const_linear +INTERPRETER.onnx_resize10_down_scales_const_nearest # Failed in MacOS: INTERPRETER.onnx_resize11_sizes_nearest_asymmetric_floor -- 2.7.4