From ae6cfe12bb15f822d83618cea3b287dc2c6ffaa3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Tomasz=20Do=C5=82bniak?= Date: Tue, 7 Jul 2020 13:08:08 +0200 Subject: [PATCH] ONNX DequantizeLinear op (#1123) * DequantizeLinear 10 as a subgraph * Enable DequantizeLinear from opset 13 * Exclude the failing tests * Re-enable dequantize linear UTs * Validation helper --- .../src/ngraph/frontend/onnx_import/CMakeLists.txt | 6 +- .../frontend/onnx_import/op/dequantize_linear.cpp | 185 +++++++++++++++++---- .../frontend/onnx_import/op/dequantize_linear.hpp | 5 + .../src/ngraph/frontend/onnx_import/ops_bridge.cpp | 5 +- .../ngraph/frontend/onnx_import/utils/common.cpp | 20 +++ .../ngraph/frontend/onnx_import/utils/common.hpp | 10 ++ .../test/models/onnx/dequantize_linear_2.prototxt | 2 +- .../test/models/onnx/dequantize_linear_3.prototxt | 2 +- .../test/models/onnx/dequantize_linear_4.prototxt | 2 +- .../test/models/onnx/dequantize_linear_5.prototxt | 2 +- ngraph/test/onnx/onnx_import_quant.in.cpp | 8 +- ngraph/test/runtime/ie/unit_test.manifest | 14 +- ngraph/test/runtime/interpreter/unit_test.manifest | 7 - 13 files changed, 209 insertions(+), 59 deletions(-) diff --git a/ngraph/src/ngraph/frontend/onnx_import/CMakeLists.txt b/ngraph/src/ngraph/frontend/onnx_import/CMakeLists.txt index a1e5d40..82b6a18 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/CMakeLists.txt +++ b/ngraph/src/ngraph/frontend/onnx_import/CMakeLists.txt @@ -14,7 +14,7 @@ # limitations under the License. # ****************************************************************************** -set(ONNX_OPSET_VERSION 11 CACHE INTERNAL "Supported version of ONNX operator set") +set(ONNX_OPSET_VERSION 13 CACHE INTERNAL "Supported version of ONNX operator set") add_library(onnx_importer SHARED core/node.cpp @@ -86,8 +86,8 @@ add_library(onnx_importer SHARED op/cum_sum.hpp op/depth_to_space.cpp op/depth_to_space.hpp - # op/dequantize_linear.cpp - # op/dequantize_linear.hpp + op/dequantize_linear.cpp + op/dequantize_linear.hpp op/div.hpp op/dropout.hpp op/elu.cpp diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.cpp b/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.cpp index 93d4851..aebea5a 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.cpp @@ -23,9 +23,9 @@ #include "ngraph/builder/make_constant.hpp" #include "ngraph/op/convert.hpp" #include "ngraph/op/dequantize.hpp" -#include "ngraph/opsets/opset0.hpp" #include "ngraph/shape.hpp" #include "ngraph/validation_util.hpp" +#include "utils/common.hpp" namespace ngraph { @@ -33,54 +33,177 @@ namespace ngraph { namespace op { - namespace set_1 + namespace { - NodeVector dequantize_linear(const Node& node) + std::shared_ptr get_zero_point(const NodeVector& inputs) { - NodeVector inputs{node.get_ng_inputs()}; - std::shared_ptr x = inputs.at(0); - std::shared_ptr x_scale = inputs.at(1); - std::shared_ptr zero_point; - if (inputs.size() == 3 && !inputs.at(2)->is_null()) + if (inputs.size() == 3 && !inputs[2]->is_null()) { - zero_point = inputs.at(2); + auto zero_point = inputs[2]; + + if (zero_point->get_element_type() != element::f32) + { + zero_point = + std::make_shared(zero_point, element::f32); + } + + return zero_point; } else { - zero_point = - ngraph::builder::make_constant(x->get_element_type(), Shape{}, 0); + return default_opset::Constant::create(element::f32, Shape{}, {0}); } + } + } + namespace set_1 + { + NodeVector dequantize_linear(const Node& node) + { + const NodeVector inputs{node.get_ng_inputs()}; - Shape y_scale_shape = x_scale->get_shape(); - Shape y_zero_point_shape = zero_point->get_shape(); + NGRAPH_CHECK( + 2 <= inputs.size() && inputs.size() <= 3, + "The DequantizeLinear op expects 2 required and one optional input. Got: ", + inputs.size()); - // get axis twice with two default values to see if it is set - int64_t axis_0{node.get_attribute_value("axis", 0)}; - int64_t axis_1{node.get_attribute_value("axis", 1)}; + const auto x = inputs[0]; + const auto scale = inputs[1]; + const auto zero_point = get_zero_point(inputs); - const auto data_rank = x->get_output_partial_shape(0).rank(); - AxisSet axes; - // if axis attribute is set - if (axis_0 == axis_1) + common::validate_scalar_input("Dequantization scale", scale, {element::f32}); + common::validate_scalar_input("Zero point", zero_point); + + const auto converted_x = + std::make_shared(x, element::f32); + + return {std::make_shared( + std::make_shared(converted_x, zero_point), scale)}; + } + } + + namespace set_13 + { + namespace + { + void validate_scale(const std::shared_ptr scale, + const std::shared_ptr x, + const int64_t axis) { - axes.insert( - ngraph::normalize_axis(node.get_description(), axis_0, data_rank)); + const auto& scale_shape = scale->get_output_partial_shape(0); + NGRAPH_CHECK(scale_shape.rank().get_length() == 0 || + scale_shape.rank().get_length() == 1, + "Dequantization scale needs to be a scalar or a vector."); + + if (scale_shape.rank().get_length() == 1) + { + const auto& scale_dim = scale_shape[0]; + const auto& x_shape = x->get_output_partial_shape(0); + const auto& x_dim_at_axis = x_shape[axis]; + + NGRAPH_CHECK(scale_dim.same_scheme(x_dim_at_axis), + "The number of dequantization scale elements '", + scale_dim, + "' must match the input shape dimension '", + x_dim_at_axis, + " pointed to by the axis attribute: ", + axis); + } } - if (x->get_element_type() != zero_point->get_element_type()) + void validate_zero_point(const std::shared_ptr zero_point, + const std::shared_ptr x, + const int64_t axis) { - zero_point = std::make_shared( - zero_point, x->get_element_type()); + const auto& zero_point_shape = zero_point->get_output_partial_shape(0); + NGRAPH_CHECK(zero_point_shape.rank().get_length() == 0 || + zero_point_shape.rank().get_length() == 1, + "Zero point needs to be a scalar or a vector."); + + if (zero_point_shape.rank().get_length() == 1) + { + const auto& zero_point_dim = zero_point_shape[0]; + const auto& x_shape = x->get_output_partial_shape(0); + const auto& x_dim_at_axis = x_shape[axis]; + + NGRAPH_CHECK(zero_point_dim.same_scheme(x_dim_at_axis), + "The number of zero point elements '", + zero_point_dim, + "' must match the input shape dimension '", + x_dim_at_axis, + " pointed to by the axis attribute: ", + axis); + } } - return {std::make_shared( - x, x_scale, zero_point, x_scale->get_element_type(), axes)}; + std::shared_ptr + reshape_input(const std::shared_ptr input, + const int64_t axis, + const PartialShape& x_shape) + { + std::vector target_dims; + + for (size_t i = 0; i < axis; ++i) + { + target_dims.push_back(1); + } + + // copy dimension at axis from input X + if (x_shape[axis].is_static()) + { + target_dims.push_back(x_shape[axis].get_length()); + } + else + { + target_dims.push_back(0); + } + + for (size_t i = axis + 1; i < x_shape.rank().get_length(); ++i) + { + target_dims.push_back(1); + } + + const auto target_shape = default_opset::Constant::create( + element::i64, Shape{target_dims.size()}, target_dims); + + return std::make_shared(input, target_shape, true); + } } - } // namespace set_1 + NodeVector dequantize_linear(const Node& node) + { + const NodeVector inputs{node.get_ng_inputs()}; + + NGRAPH_CHECK(2 <= inputs.size() && inputs.size() <= 3, + "The DequantizeLinear op expects 2 required and one optional " + "input. Got: ", + inputs.size()); + + const auto x = inputs[0]; + auto scale = inputs[1]; + auto zero_point = get_zero_point(inputs); + + const auto x_shape = x->get_output_partial_shape(0); - } // namespace op + NGRAPH_CHECK(x_shape.rank().is_static(), + "Rank of the input data tensor has to be known (static)."); - } // namespace onnx_import + int64_t axis{node.get_attribute_value("axis", 1)}; + axis = ngraph::normalize_axis(node.get_description(), axis, x_shape.rank()); -} // namespace ngraph + validate_scale(scale, x, axis); + validate_zero_point(zero_point, x, axis); + + // these reshapes make sure that dequantization happens over the specified axis + scale = reshape_input(scale, axis, x_shape); + zero_point = reshape_input(zero_point, axis, x_shape); + + const auto converted_x = + std::make_shared(x, element::f32); + + return {std::make_shared( + std::make_shared(converted_x, zero_point), scale)}; + } + } + } + } +} diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.hpp b/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.hpp index 41dd150..dc67093 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.hpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.hpp @@ -31,6 +31,11 @@ namespace ngraph } // namespace set_1 + namespace set_13 + { + NodeVector dequantize_linear(const Node& node); + } + } // namespace op } // namespace onnx_import diff --git a/ngraph/src/ngraph/frontend/onnx_import/ops_bridge.cpp b/ngraph/src/ngraph/frontend/onnx_import/ops_bridge.cpp index 6c7c4c0..1fee117 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/ops_bridge.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/ops_bridge.cpp @@ -48,7 +48,7 @@ #include "op/cosh.hpp" #include "op/cum_sum.hpp" #include "op/depth_to_space.hpp" -// #include "op/dequantize_linear.hpp" +#include "op/dequantize_linear.hpp" #include "op/div.hpp" #include "op/dropout.hpp" #include "op/elu.hpp" @@ -278,7 +278,8 @@ namespace ngraph REGISTER_OPERATOR("Cosh", 1, cosh); REGISTER_OPERATOR("CumSum", 1, cum_sum); REGISTER_OPERATOR("DepthToSpace", 1, depth_to_space); - // REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear); + REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear); + REGISTER_OPERATOR("DequantizeLinear", 13, dequantize_linear); REGISTER_OPERATOR("Div", 1, div); REGISTER_OPERATOR("Div", 7, div); REGISTER_OPERATOR("Dropout", 1, dropout); diff --git a/ngraph/src/ngraph/frontend/onnx_import/utils/common.cpp b/ngraph/src/ngraph/frontend/onnx_import/utils/common.cpp index a803f65..ffe432d 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/utils/common.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/utils/common.cpp @@ -67,6 +67,26 @@ namespace ngraph default_opset::Constant::create(element::i64, {}, {step})); } + void validate_scalar_input(const char* input_name, + const std::shared_ptr input, + const std::set allowed_types) + { + const auto validated_input_rank = input->get_output_partial_shape(0).rank(); + + NGRAPH_CHECK( + validated_input_rank.same_scheme({0}), input_name, " needs to be a scalar."); + + if (!allowed_types.empty()) + { + const bool data_type_ok = allowed_types.count(input->get_element_type()); + NGRAPH_CHECK(data_type_ok, + "Incorrect data type of the ", + input_name, + " input: ", + input->get_element_type()); + } + } + } // namespace common } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/src/ngraph/frontend/onnx_import/utils/common.hpp b/ngraph/src/ngraph/frontend/onnx_import/utils/common.hpp index 2aaa14e..69af5b5 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/utils/common.hpp +++ b/ngraph/src/ngraph/frontend/onnx_import/utils/common.hpp @@ -127,6 +127,16 @@ namespace ngraph return shifted_square_identity(Shape{n, n}, type, 0); } + /// \brief Performs validation of an input that is expected to be a scalar. + /// \note This function throws an exception if any of the validation steps fails. + /// + /// \param[in] input_name A human-readable name of an input (used for logging) + /// \param[in] input An input node to be validated + /// \param[in] allowed_types An optional set of allowed element types for this input + void validate_scalar_input(const char* input_name, + const std::shared_ptr input, + const std::set allowed_types = {}); + } // namespace common } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/test/models/onnx/dequantize_linear_2.prototxt b/ngraph/test/models/onnx/dequantize_linear_2.prototxt index d5ffdda..4629456 100644 --- a/ngraph/test/models/onnx/dequantize_linear_2.prototxt +++ b/ngraph/test/models/onnx/dequantize_linear_2.prototxt @@ -75,5 +75,5 @@ graph { } } opset_import { - version: 10 + version: 13 } diff --git a/ngraph/test/models/onnx/dequantize_linear_3.prototxt b/ngraph/test/models/onnx/dequantize_linear_3.prototxt index 1c01ba3..112312f 100644 --- a/ngraph/test/models/onnx/dequantize_linear_3.prototxt +++ b/ngraph/test/models/onnx/dequantize_linear_3.prototxt @@ -75,5 +75,5 @@ graph { } } opset_import { - version: 10 + version: 13 } diff --git a/ngraph/test/models/onnx/dequantize_linear_4.prototxt b/ngraph/test/models/onnx/dequantize_linear_4.prototxt index 2f081f5..422046d 100644 --- a/ngraph/test/models/onnx/dequantize_linear_4.prototxt +++ b/ngraph/test/models/onnx/dequantize_linear_4.prototxt @@ -87,5 +87,5 @@ graph { } } opset_import { - version: 10 + version: 13 } diff --git a/ngraph/test/models/onnx/dequantize_linear_5.prototxt b/ngraph/test/models/onnx/dequantize_linear_5.prototxt index 0b5ee82..2d9466c 100644 --- a/ngraph/test/models/onnx/dequantize_linear_5.prototxt +++ b/ngraph/test/models/onnx/dequantize_linear_5.prototxt @@ -75,5 +75,5 @@ graph { } } opset_import { - version: 10 + version: 13 } diff --git a/ngraph/test/onnx/onnx_import_quant.in.cpp b/ngraph/test/onnx/onnx_import_quant.in.cpp index 02f2b86..3bd5087 100644 --- a/ngraph/test/onnx/onnx_import_quant.in.cpp +++ b/ngraph/test/onnx/onnx_import_quant.in.cpp @@ -156,7 +156,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_uint8) auto function = onnx_import::import_onnx_model( file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_2.prototxt")); - auto test_case = test::TestCase(function); + auto test_case = ngraph::test::TestCase(function); test_case.add_input(std::vector{0, 1, 2, 3, 0, 1, 2, 3, 0, 10, 20, 30}); // x test_case.add_input(std::vector{1.0f, 2.0f, 4.0f}); // scale @@ -174,7 +174,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_int8) auto function = onnx_import::import_onnx_model( file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_3.prototxt")); - auto test_case = test::TestCase(function); + auto test_case = ngraph::test::TestCase(function); test_case.add_input(std::vector{0, 1, 2, 3, 0, 2, 4, 6, 0, 10, 20, 30}); // x test_case.add_input(std::vector{1.0f, 2.0f, 4.0f, 8.0f}); // scale @@ -192,7 +192,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_int8_4d) auto function = onnx_import::import_onnx_model( file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_4.prototxt")); - auto test_case = test::TestCase(function); + auto test_case = ngraph::test::TestCase(function); test_case.add_input(std::vector{7, 9, 10, 10, 5, 8, 9, 1, 8, 6, 7, 9, 10, 0, 7, 10, 8, 2, 6, 0, 5, 9, 8, 1, 2, 7, 5, 3, 2, 4, 1, 3, 8, 7, @@ -216,7 +216,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_uint8_ne auto function = onnx_import::import_onnx_model( file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_5.prototxt")); - auto test_case = test::TestCase(function); + auto test_case = ngraph::test::TestCase(function); test_case.add_input(std::vector{0, 1, 2, 3, 0, 1, 2, 3, 0, 10, 20, 30}); // x test_case.add_input(std::vector{1.0f, 2.0f, 4.0f}); // scale diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest index eb905e3..e411c91 100644 --- a/ngraph/test/runtime/ie/unit_test.manifest +++ b/ngraph/test/runtime/ie/unit_test.manifest @@ -17,14 +17,12 @@ onnx_model_quantize_linear_zero_point onnx_model_quantize_linear_axis_zero onnx_model_quantize_linear_axis_negative -# Not supported ONNX op: DequantizeLinear -onnx_model_dequantize_linear -onnx_model_dequantize_linear_scalar_zero_scale_uint8 -onnx_model_dequantize_linear_scalar_zero_scale_int8 -onnx_model_dequantize_linear_1d_zero_scale_uint8 -onnx_model_dequantize_linear_1d_zero_scale_int8 -onnx_model_dequantize_linear_1d_zero_scale_int8_4d -onnx_model_dequantize_linear_1d_zero_scale_uint8_negative_axis +# DequantizeLinear: +# C++ exception with description "Unsupported precisions! +IE_CPU.onnx_model_dequantize_linear_scalar_zero_scale_int8 +IE_CPU.onnx_model_dequantize_linear_1d_zero_scale_int8 +# C++ exception with description "Input data precision not supported. Expected float. +IE_CPU.onnx_model_dequantize_linear_1d_zero_scale_int8_4d # Not supported ONNX op: QLinearConv onnx_model_quant_conv_linear diff --git a/ngraph/test/runtime/interpreter/unit_test.manifest b/ngraph/test/runtime/interpreter/unit_test.manifest index 643b0aa..b836584 100644 --- a/ngraph/test/runtime/interpreter/unit_test.manifest +++ b/ngraph/test/runtime/interpreter/unit_test.manifest @@ -91,13 +91,6 @@ INTERPRETER.onnx_model_quantize_linear INTERPRETER.onnx_model_quantize_linear_zero_point INTERPRETER.onnx_model_quantize_linear_axis_zero INTERPRETER.onnx_model_quantize_linear_axis_negative -INTERPRETER.onnx_model_dequantize_linear -INTERPRETER.onnx_model_dequantize_linear_scalar_zero_scale_uint8 -INTERPRETER.onnx_model_dequantize_linear_scalar_zero_scale_int8 -INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8 -INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_int8 -INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_int8_4d -INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8_negative_axis INTERPRETER.onnx_model_quant_conv_linear_2d INTERPRETER.onnx_model_quant_conv_linear_3d INTERPRETER.onnx_model_conv_integer -- 2.7.4