From fd80873fcabad6e2eaf9540aa8c6676e6e09e962 Mon Sep 17 00:00:00 2001 From: Mateusz Tabaka Date: Wed, 30 Sep 2020 16:17:15 +0200 Subject: [PATCH] Add support for custom ONNX GroupNorm operator (#2267) --- .../op/org.openvinotoolkit/group_norm.hpp | 40 ++++++ .../include/onnx_import/utils/reshape.hpp | 16 +++ ngraph/frontend/onnx_import/src/onnx.cpp | 2 +- ngraph/frontend/onnx_import/src/op/conv.cpp | 18 +-- .../src/op/org.openvinotoolkit/group_norm.cpp | 148 +++++++++++++++++++++ ngraph/frontend/onnx_import/src/ops_bridge.cpp | 6 +- ngraph/frontend/onnx_import/src/utils/reshape.cpp | 19 +++ ngraph/test/models/onnx/group_norm.prototxt | 108 +++++++++++++++ ngraph/test/onnx/onnx_import.in.cpp | 25 ++++ 9 files changed, 365 insertions(+), 17 deletions(-) create mode 100644 ngraph/frontend/onnx_import/include/onnx_import/op/org.openvinotoolkit/group_norm.hpp create mode 100644 ngraph/frontend/onnx_import/src/op/org.openvinotoolkit/group_norm.cpp create mode 100644 ngraph/test/models/onnx/group_norm.prototxt diff --git a/ngraph/frontend/onnx_import/include/onnx_import/op/org.openvinotoolkit/group_norm.hpp b/ngraph/frontend/onnx_import/include/onnx_import/op/org.openvinotoolkit/group_norm.hpp new file mode 100644 index 0000000..60e6a8b --- /dev/null +++ b/ngraph/frontend/onnx_import/include/onnx_import/op/org.openvinotoolkit/group_norm.hpp @@ -0,0 +1,40 @@ +//***************************************************************************** +// Copyright 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include "ngraph/node.hpp" +#include "onnx_import/core/node.hpp" + +namespace ngraph +{ + namespace onnx_import + { + namespace op + { + namespace set_1 + { + OutputVector group_norm(const Node& node); + + } // namespace set_1 + + } // namespace op + + } // namespace onnx_import + +} // namespace ngraph + +// namespace ngraph diff --git a/ngraph/frontend/onnx_import/include/onnx_import/utils/reshape.hpp b/ngraph/frontend/onnx_import/include/onnx_import/utils/reshape.hpp index 242395f..3b3b4ec 100644 --- a/ngraph/frontend/onnx_import/include/onnx_import/utils/reshape.hpp +++ b/ngraph/frontend/onnx_import/include/onnx_import/utils/reshape.hpp @@ -61,6 +61,22 @@ namespace ngraph /// Output interpret_as_scalar(const Output& node); + /// \brief Reshape node from shape {C} to {1, C, 1, 1,...} + /// + /// \note This function will reshape the input node + /// with a shape of {C} into a node with Shape{1, C, 1, 1, ..}. + /// The most common input to this function would be scale or bias to + /// BatchNorm or bias to Conv. + /// + /// \param[in] node Node to reshape. + /// \param[in] expected_rank Expected rank size + /// + /// \return Original node or a node representing a reshape of the original. + /// + Output + reshape_channel_shaped_node_to_nchw(const Output& node, + size_t expected_rank); + } // namespace reshape } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/onnx.cpp b/ngraph/frontend/onnx_import/src/onnx.cpp index b3e8b7e..28333fe 100644 --- a/ngraph/frontend/onnx_import/src/onnx.cpp +++ b/ngraph/frontend/onnx_import/src/onnx.cpp @@ -75,7 +75,7 @@ namespace ngraph } // namespace error static const std::vector legacy_ops_to_fixup = { - "FakeQuantize", "DetectionOutput", "Normalize", "PriorBox"}; + "DetectionOutput", "FakeQuantize", "GroupNorm", "Normalize", "PriorBox"}; // There are some models with custom OPs (list above) that has the default domain set. // So in order to load the models, we need overwrite the OPs' domain to the one they're diff --git a/ngraph/frontend/onnx_import/src/op/conv.cpp b/ngraph/frontend/onnx_import/src/op/conv.cpp index 4fd535c..be8937b 100644 --- a/ngraph/frontend/onnx_import/src/op/conv.cpp +++ b/ngraph/frontend/onnx_import/src/op/conv.cpp @@ -26,6 +26,7 @@ #include "onnx_import/default_opset.hpp" #include "onnx_import/exceptions.hpp" #include "onnx_import/utils/convpool.hpp" +#include "onnx_import/utils/reshape.hpp" namespace ngraph { @@ -82,20 +83,9 @@ namespace ngraph { const auto rank_of_conv = ng_conv.get_partial_shape().rank().get_length(); - // reshape the bias node {M} to {1, M, 1, 1, ..., 1} - // this is required by the addition operation that needs to be able - // to broadcast the bias to match the shape of the convolution node - std::vector reshape_pattern_values(rank_of_conv, 1U); - reshape_pattern_values[1] = bias.get_shape().front(); - const auto reshape_pattern = - default_opset::Constant::create(element::u64, - Shape{reshape_pattern_values.size()}, - reshape_pattern_values); - - std::shared_ptr reshaped_bias = - std::make_shared(bias, reshape_pattern, false); - - return {std::make_shared(ng_conv, reshaped_bias)}; + return {std::make_shared( + ng_conv, + reshape::reshape_channel_shaped_node_to_nchw(bias, rank_of_conv))}; } } // namespace diff --git a/ngraph/frontend/onnx_import/src/op/org.openvinotoolkit/group_norm.cpp b/ngraph/frontend/onnx_import/src/op/org.openvinotoolkit/group_norm.cpp new file mode 100644 index 0000000..275bac9 --- /dev/null +++ b/ngraph/frontend/onnx_import/src/op/org.openvinotoolkit/group_norm.cpp @@ -0,0 +1,148 @@ +//***************************************************************************** +// Copyright 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include "onnx_import/op/org.openvinotoolkit/group_norm.hpp" +#include "ngraph/builder/reduce_ops.hpp" +#include "ngraph/builder/split.hpp" +#include "ngraph/node.hpp" +#include "onnx_import/core/node.hpp" +#include "onnx_import/default_opset.hpp" +#include "onnx_import/utils/common.hpp" +#include "onnx_import/utils/reshape.hpp" + +namespace ngraph +{ + namespace onnx_import + { + namespace op + { + namespace detail + { + namespace + { + // This function creates a shape to which we need to reshape the input + // before normalization. + // If data shape is [N,C,H,W], the function returns + // [N, num_groups, C // num_groups, H, W] + std::shared_ptr + create_group_norm_shape(const Output& data, size_t num_groups) + { + const auto& pshape = data.get_partial_shape(); + NGRAPH_CHECK(pshape.rank().is_static()); + size_t rank_size = pshape.rank().get_length(); + NGRAPH_CHECK(rank_size >= 3, "3-D and above tensors supported only"); + + if (pshape.is_static()) + { + const auto& shape = pshape.to_shape(); + std::vector new_shape{ + shape[0], num_groups, shape[1] / num_groups}; + for (size_t i = 2; i < rank_size; i++) + { + new_shape.push_back(shape[i]); + } + return default_opset::Constant::create( + element::i64, Shape{new_shape.size()}, new_shape); + } + + auto shape = std::make_shared(data); + auto splits = builder::opset1::split(shape, rank_size); + auto num_groups_const = + default_opset::Constant::create(element::i64, Shape{1}, {num_groups}); + NodeVector new_shape{ + splits[0].get_node_shared_ptr(), + num_groups_const, + std::make_shared(splits[1], num_groups_const)}; + for (size_t i = 2; i < rank_size; i++) + { + new_shape.push_back(splits[i].get_node_shared_ptr()); + } + return std::make_shared(new_shape, 0); + } + } + } // detail + + namespace set_1 + { + OutputVector group_norm(const Node& node) + { + auto inputs = node.get_ng_inputs(); + NGRAPH_CHECK(inputs.size() == 3, + "Invalid number of inputs. Expected 3, actual " + + std::to_string(inputs.size())); + + auto data = inputs[0]; + auto scale = inputs[1]; + auto bias = inputs[2]; + + size_t num_groups = + static_cast(node.get_attribute_value("num_groups")); + float eps = node.get_attribute_value("eps", 1e-5); + + auto data_pshape = data.get_partial_shape(); + std::shared_ptr data_shape_node; + if (data_pshape.is_static()) + { + auto shape = data_pshape.to_shape(); + data_shape_node = default_opset::Constant::create( + element::u64, Shape{shape.size()}, shape); + } + else + { + data_shape_node = std::make_shared(data); + } + auto data_reshaped = std::make_shared( + data, detail::create_group_norm_shape(data, num_groups), true); + const auto reduction_axes = + common::get_monotonic_range_along_node_rank(data_reshaped, 2); + auto mean = std::make_shared( + data_reshaped, reduction_axes, true); + auto diff = std::make_shared(data_reshaped, mean); + auto variance = std::make_shared( + std::make_shared( + diff, default_opset::Constant::create(element::f32, Shape{}, {2})), + reduction_axes, + true); + + const std::shared_ptr eps_node = + std::make_shared(element::f32, Shape{}, eps); + const auto sqrt = std::make_shared( + std::make_shared(variance, eps_node)); + + const auto& rank = data.get_partial_shape().rank(); + NGRAPH_CHECK(rank.is_static()); + auto data_rank_size = rank.get_length(); + + std::shared_ptr result = + std::make_shared(diff, sqrt); + result = + std::make_shared(result, data_shape_node, true); + result = std::make_shared( + reshape::reshape_channel_shaped_node_to_nchw(scale, data_rank_size), + result); + result = std::make_shared( + result, reshape::reshape_channel_shaped_node_to_nchw(bias, data_rank_size)); + + return {result}; + } + + } // namespace set_1 + + } // namespace op + + } // namespace onnx_import + +} // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/ops_bridge.cpp b/ngraph/frontend/onnx_import/src/ops_bridge.cpp index 8d127b1..fb5216a 100644 --- a/ngraph/frontend/onnx_import/src/ops_bridge.cpp +++ b/ngraph/frontend/onnx_import/src/ops_bridge.cpp @@ -145,6 +145,7 @@ #include "onnx_import/op/org.openvinotoolkit/detection_output.hpp" #include "onnx_import/op/org.openvinotoolkit/fake_quantize.hpp" +#include "onnx_import/op/org.openvinotoolkit/group_norm.hpp" #include "onnx_import/op/org.openvinotoolkit/normalize.hpp" #include "onnx_import/op/org.openvinotoolkit/prior_box.hpp" @@ -406,11 +407,12 @@ namespace ngraph REGISTER_OPERATOR("Xor", 1, logical_xor); // custom OPs - REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "FakeQuantize", 1, fake_quantize); REGISTER_OPERATOR_WITH_DOMAIN( OPENVINO_ONNX_DOMAIN, "DetectionOutput", 1, detection_output); - REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "PriorBox", 1, prior_box); + REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "FakeQuantize", 1, fake_quantize); + REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "GroupNorm", 1, group_norm); REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "Normalize", 1, normalize); + REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "PriorBox", 1, prior_box); } #undef REGISTER_OPERATOR diff --git a/ngraph/frontend/onnx_import/src/utils/reshape.cpp b/ngraph/frontend/onnx_import/src/utils/reshape.cpp index f8a59b8..ec8d963 100644 --- a/ngraph/frontend/onnx_import/src/utils/reshape.cpp +++ b/ngraph/frontend/onnx_import/src/utils/reshape.cpp @@ -114,6 +114,25 @@ namespace ngraph return builder::opset1::reshape(node, Shape{}); } + Output + reshape_channel_shaped_node_to_nchw(const Output& node, + size_t expected_rank) + { + const auto& rank = node.get_partial_shape().rank(); + NGRAPH_CHECK(rank.is_static()); + size_t node_rank = rank.get_length(); + if (node_rank == 1) + { + // reshape the node with shape {C} to {1, C, 1, 1, ..., 1} + std::vector reshape_pattern_values(expected_rank, 1U); + reshape_pattern_values[1] = node.get_shape().front(); + const auto reshape_pattern = default_opset::Constant::create( + element::u64, Shape{reshape_pattern_values.size()}, reshape_pattern_values); + return std::make_shared(node, reshape_pattern, false); + } + return node; + } + } // namespace reshape } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/test/models/onnx/group_norm.prototxt b/ngraph/test/models/onnx/group_norm.prototxt new file mode 100644 index 0000000..e5f43cd --- /dev/null +++ b/ngraph/test/models/onnx/group_norm.prototxt @@ -0,0 +1,108 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "data" + input: "gamma" + input: "beta" + output: "y" + op_type: "GroupNorm" + domain: "org.openvinotoolkit" + attribute { + name: "num_groups" + i: 4 + type: INT + } + attribute { + name: "eps" + f: 1e-6 + type: FLOAT + } + } + name: "group_norm_example" + initializer { + dims: 8 + data_type: 1 + name: "gamma" + raw_data: "\0\0\200?\0\0\0@\0\0@@\0\0\200@\0\0\240@\0\0\300@\0\0\340@\0\0\0A" + } + initializer { + dims: 8 + data_type: 1 + name: "beta" + raw_data: "\0\0\200?\0\0\0@\0\0@@\0\0\200@\0\0\240@\0\0\300@\0\0\340@\0\0\0A" + } + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "gamma" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "beta" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 1 +} diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp index 774c42e..38686bd 100644 --- a/ngraph/test/onnx/onnx_import.in.cpp +++ b/ngraph/test/onnx/onnx_import.in.cpp @@ -2618,3 +2618,28 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_normalize) test_case.add_expected_output(Shape{1, 3, 2, 2}, output); test_case.run(); } + +NGRAPH_TEST(${BACKEND_NAME}, onnx_group_norm) +{ + const auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/group_norm.prototxt")); + auto test_case = test::TestCase(function); + Shape shape{2, 8, 2, 2}; + int size = shape_size(shape); + std::vector data(size); + std::iota(data.begin(), data.end(), 0); + std::vector output = { + -0.52752507, -0.09108937, 0.3453464, 0.78178215, 2.4364357, 3.309307, 4.1821785, 5.05505, + -1.5825753, -0.27326822, 1.0360391, 2.3453465, 4.8728714, 6.618614, 8.364357, 10.1101, + -2.6376252, -0.45544672, 1.726732, 3.9089108, 7.309307, 9.927921, 12.546536, 15.165151, + -3.6926756, -0.6376257, 2.4174247, 5.472475, 9.745743, 13.237228, 16.728714, 20.2202, + -0.52752507, -0.09108937, 0.3453464, 0.78178215, 2.4364357, 3.309307, 4.1821785, 5.05505, + -1.5825753, -0.27326822, 1.0360391, 2.3453465, 4.8728714, 6.618614, 8.364357, 10.1101, + -2.6376252, -0.45544672, 1.726732, 3.9089108, 7.309307, 9.927921, 12.546536, 15.165151, + -3.6926756, -0.6376257, 2.4174247, 5.472475, 9.745743, 13.237228, 16.728714, 20.2202, + }; + + test_case.add_input(data); + test_case.add_expected_output(shape, output); + test_case.run(); +} -- 2.7.4