--- /dev/null
+//*****************************************************************************
+// Copyright 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include "ngraph/node.hpp"
+#include "onnx_import/core/node.hpp"
+
+namespace ngraph
+{
+ namespace onnx_import
+ {
+ namespace op
+ {
+            namespace set_1
+            {
+                /// \brief Creates an nGraph subgraph implementing the OpenVINO
+                ///        GroupNorm custom op (inputs: data, scale, bias).
+                OutputVector group_norm(const Node& node);
+
+            } // namespace set_1
+
+ } // namespace op
+
+ } // namespace onnx_import
+
+} // namespace ngraph
+
+// namespace ngraph
///
Output<ngraph::Node> interpret_as_scalar(const Output<ngraph::Node>& node);
+ /// \brief Reshape node from shape {C} to {1, C, 1, 1,...}
+ ///
+ /// \note This function will reshape the input node
+ /// with a shape of {C} into a node with Shape{1, C, 1, 1, ..}.
+ /// The most common input to this function would be scale or bias to
+ /// BatchNorm or bias to Conv.
+ ///
+ /// \param[in] node Node to reshape.
+        /// \param[in] expected_rank  Rank of the produced shape, i.e. the number
+        ///                           of dimensions in the {1, C, 1, ...} output
+ ///
+ /// \return Original node or a node representing a reshape of the original.
+ ///
+ Output<ngraph::Node>
+ reshape_channel_shaped_node_to_nchw(const Output<ngraph::Node>& node,
+ size_t expected_rank);
+
} // namespace reshape
} // namespace onnx_import
} // namespace ngraph
} // namespace error
static const std::vector<std::string> legacy_ops_to_fixup = {
- "FakeQuantize", "DetectionOutput", "Normalize", "PriorBox"};
+ "DetectionOutput", "FakeQuantize", "GroupNorm", "Normalize", "PriorBox"};
// There are some models with custom OPs (list above) that has the default domain set.
// So in order to load the models, we need overwrite the OPs' domain to the one they're
#include "onnx_import/default_opset.hpp"
#include "onnx_import/exceptions.hpp"
#include "onnx_import/utils/convpool.hpp"
+#include "onnx_import/utils/reshape.hpp"
namespace ngraph
{
{
const auto rank_of_conv = ng_conv.get_partial_shape().rank().get_length();
- // reshape the bias node {M} to {1, M, 1, 1, ..., 1}
- // this is required by the addition operation that needs to be able
- // to broadcast the bias to match the shape of the convolution node
- std::vector<size_t> reshape_pattern_values(rank_of_conv, 1U);
- reshape_pattern_values[1] = bias.get_shape().front();
- const auto reshape_pattern =
- default_opset::Constant::create(element::u64,
- Shape{reshape_pattern_values.size()},
- reshape_pattern_values);
-
- std::shared_ptr<ngraph::Node> reshaped_bias =
- std::make_shared<default_opset::Reshape>(bias, reshape_pattern, false);
-
- return {std::make_shared<default_opset::Add>(ng_conv, reshaped_bias)};
+ return {std::make_shared<default_opset::Add>(
+ ng_conv,
+ reshape::reshape_channel_shaped_node_to_nchw(bias, rank_of_conv))};
}
} // namespace
--- /dev/null
+//*****************************************************************************
+// Copyright 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "onnx_import/op/org.openvinotoolkit/group_norm.hpp"
+#include "ngraph/builder/reduce_ops.hpp"
+#include "ngraph/builder/split.hpp"
+#include "ngraph/node.hpp"
+#include "onnx_import/core/node.hpp"
+#include "onnx_import/default_opset.hpp"
+#include "onnx_import/utils/common.hpp"
+#include "onnx_import/utils/reshape.hpp"
+
+namespace ngraph
+{
+ namespace onnx_import
+ {
+ namespace op
+ {
+ namespace detail
+ {
+ namespace
+ {
+                    // Builds the shape to which the input is reshaped before
+                    // normalization.
+                    // If data shape is [N,C,H,W], the function returns
+                    // [N, num_groups, C // num_groups, H, W].
+                    //
+                    // \param[in] data        Input tensor; rank must be static and >= 3.
+                    // \param[in] num_groups  Number of groups the channel axis splits into.
+                    //
+                    // \return A Constant (static shape) or a Concat of dimension
+                    //         nodes (dynamic shape) describing the grouped layout.
+                    std::shared_ptr<ngraph::Node>
+                        create_group_norm_shape(const Output<ngraph::Node>& data, size_t num_groups)
+                    {
+                        const auto& pshape = data.get_partial_shape();
+                        NGRAPH_CHECK(pshape.rank().is_static());
+                        size_t rank_size = pshape.rank().get_length();
+                        NGRAPH_CHECK(rank_size >= 3, "3-D and above tensors supported only");
+
+                        if (pshape.is_static())
+                        {
+                            const auto& shape = pshape.to_shape();
+                            // The channel count must split evenly into groups;
+                            // otherwise the integer division below truncates and the
+                            // subsequent reshape would silently drop elements.
+                            NGRAPH_CHECK(num_groups > 0 && shape[1] % num_groups == 0,
+                                         "Channel dimension must be divisible by num_groups");
+                            std::vector<size_t> new_shape{
+                                shape[0], num_groups, shape[1] / num_groups};
+                            for (size_t i = 2; i < rank_size; i++)
+                            {
+                                new_shape.push_back(shape[i]);
+                            }
+                            return default_opset::Constant::create(
+                                element::i64, Shape{new_shape.size()}, new_shape);
+                        }
+
+                        // Dynamic shape: assemble [N, num_groups, C/num_groups, ...]
+                        // from the runtime dimensions of the input.
+                        auto shape = std::make_shared<default_opset::ShapeOf>(data);
+                        auto splits = builder::opset1::split(shape, rank_size);
+                        auto num_groups_const =
+                            default_opset::Constant::create(element::i64, Shape{1}, {num_groups});
+                        NodeVector new_shape{
+                            splits[0].get_node_shared_ptr(),
+                            num_groups_const,
+                            std::make_shared<default_opset::Divide>(splits[1], num_groups_const)};
+                        for (size_t i = 2; i < rank_size; i++)
+                        {
+                            new_shape.push_back(splits[i].get_node_shared_ptr());
+                        }
+                        return std::make_shared<default_opset::Concat>(new_shape, 0);
+                    }
+ }
+ } // detail
+
+ namespace set_1
+ {
+                /// \brief Builds an nGraph subgraph for the OpenVINO GroupNorm custom op.
+                ///
+                /// \param[in] node ONNX node carrying 3 inputs (data, scale, bias)
+                ///                 and attributes "num_groups" (required) and "eps"
+                ///                 (optional, defaults to 1e-5f).
+                ///
+                /// \return OutputVector with the single normalized output.
+                OutputVector group_norm(const Node& node)
+                {
+                    auto inputs = node.get_ng_inputs();
+                    NGRAPH_CHECK(inputs.size() == 3,
+                                 "Invalid number of inputs. Expected 3, actual " +
+                                     std::to_string(inputs.size()));
+
+                    auto data = inputs[0];
+                    auto scale = inputs[1];
+                    auto bias = inputs[2];
+
+                    size_t num_groups =
+                        static_cast<size_t>(node.get_attribute_value<int64_t>("num_groups"));
+                    // Float literal avoids a double->float narrowing conversion.
+                    float eps = node.get_attribute_value<float>("eps", 1e-5f);
+
+                    // Remember the original shape so the normalized result can be
+                    // reshaped back after reduction over the grouped layout.
+                    auto data_pshape = data.get_partial_shape();
+                    std::shared_ptr<ngraph::Node> data_shape_node;
+                    if (data_pshape.is_static())
+                    {
+                        auto shape = data_pshape.to_shape();
+                        // i64 matches the element type produced by ShapeOf in the
+                        // dynamic branch below.
+                        data_shape_node = default_opset::Constant::create(
+                            element::i64, Shape{shape.size()}, shape);
+                    }
+                    else
+                    {
+                        data_shape_node = std::make_shared<default_opset::ShapeOf>(data);
+                    }
+                    // Reshape [N, C, ...] -> [N, num_groups, C/num_groups, ...] and
+                    // reduce over every axis past the group axis.
+                    auto data_reshaped = std::make_shared<default_opset::Reshape>(
+                        data, detail::create_group_norm_shape(data, num_groups), true);
+                    const auto reduction_axes =
+                        common::get_monotonic_range_along_node_rank(data_reshaped, 2);
+                    auto mean = std::make_shared<default_opset::ReduceMean>(
+                        data_reshaped, reduction_axes, true);
+                    auto diff = std::make_shared<default_opset::Subtract>(data_reshaped, mean);
+                    auto variance = std::make_shared<default_opset::ReduceMean>(
+                        std::make_shared<default_opset::Power>(
+                            diff, default_opset::Constant::create(element::f32, Shape{}, {2})),
+                        reduction_axes,
+                        true);
+
+                    const std::shared_ptr<ngraph::Node> eps_node =
+                        std::make_shared<default_opset::Constant>(element::f32, Shape{}, eps);
+                    const auto sqrt = std::make_shared<default_opset::Sqrt>(
+                        std::make_shared<default_opset::Add>(variance, eps_node));
+
+                    const auto& rank = data.get_partial_shape().rank();
+                    NGRAPH_CHECK(rank.is_static());
+                    auto data_rank_size = rank.get_length();
+
+                    // (x - mean) / sqrt(var + eps), restored to the input shape,
+                    // then scaled and shifted per channel.
+                    std::shared_ptr<ngraph::Node> result =
+                        std::make_shared<default_opset::Divide>(diff, sqrt);
+                    result =
+                        std::make_shared<default_opset::Reshape>(result, data_shape_node, true);
+                    result = std::make_shared<default_opset::Multiply>(
+                        reshape::reshape_channel_shaped_node_to_nchw(scale, data_rank_size),
+                        result);
+                    result = std::make_shared<default_opset::Add>(
+                        result, reshape::reshape_channel_shaped_node_to_nchw(bias, data_rank_size));
+
+                    return {result};
+                }
+
+ } // namespace set_1
+
+ } // namespace op
+
+ } // namespace onnx_import
+
+} // namespace ngraph
#include "onnx_import/op/org.openvinotoolkit/detection_output.hpp"
#include "onnx_import/op/org.openvinotoolkit/fake_quantize.hpp"
+#include "onnx_import/op/org.openvinotoolkit/group_norm.hpp"
#include "onnx_import/op/org.openvinotoolkit/normalize.hpp"
#include "onnx_import/op/org.openvinotoolkit/prior_box.hpp"
REGISTER_OPERATOR("Xor", 1, logical_xor);
// custom OPs
- REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "FakeQuantize", 1, fake_quantize);
REGISTER_OPERATOR_WITH_DOMAIN(
OPENVINO_ONNX_DOMAIN, "DetectionOutput", 1, detection_output);
- REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "PriorBox", 1, prior_box);
+ REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "FakeQuantize", 1, fake_quantize);
+ REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "GroupNorm", 1, group_norm);
REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "Normalize", 1, normalize);
+ REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "PriorBox", 1, prior_box);
}
#undef REGISTER_OPERATOR
return builder::opset1::reshape(node, Shape{});
}
+            /// \brief Reshapes a 1-D node of shape {C} to {1, C, 1, ..., 1} with
+            ///        `expected_rank` dimensions so it broadcasts in NCHW layout.
+            ///
+            /// \param[in] node           Node to reshape; returned unchanged unless
+            ///                           its (static) rank is exactly 1.
+            /// \param[in] expected_rank  Rank of the produced shape; must be >= 2.
+            ///
+            /// \return Original node or a Reshape of it.
+            Output<ngraph::Node>
+                reshape_channel_shaped_node_to_nchw(const Output<ngraph::Node>& node,
+                                                    size_t expected_rank)
+            {
+                const auto& rank = node.get_partial_shape().rank();
+                NGRAPH_CHECK(rank.is_static());
+                size_t node_rank = rank.get_length();
+                if (node_rank == 1)
+                {
+                    // Writing index 1 below is out of bounds for rank < 2.
+                    NGRAPH_CHECK(expected_rank >= 2, "Expected rank must be at least 2");
+                    // reshape the node with shape {C} to {1, C, 1, 1, ..., 1}
+                    std::vector<size_t> reshape_pattern_values(expected_rank, 1U);
+                    reshape_pattern_values[1] = node.get_shape().front();
+                    const auto reshape_pattern = default_opset::Constant::create(
+                        element::u64, Shape{reshape_pattern_values.size()}, reshape_pattern_values);
+                    return std::make_shared<default_opset::Reshape>(node, reshape_pattern, false);
+                }
+                return node;
+            }
+
} // namespace reshape
} // namespace onnx_import
} // namespace ngraph
--- /dev/null
+ir_version: 3
+producer_name: "nGraph ONNX Importer"
+graph {
+ node {
+ input: "data"
+ input: "gamma"
+ input: "beta"
+ output: "y"
+ op_type: "GroupNorm"
+ domain: "org.openvinotoolkit"
+ attribute {
+ name: "num_groups"
+ i: 4
+ type: INT
+ }
+ attribute {
+ name: "eps"
+ f: 1e-6
+ type: FLOAT
+ }
+ }
+ name: "group_norm_example"
+ initializer {
+ dims: 8
+ data_type: 1
+ name: "gamma"
+ raw_data: "\0\0\200?\0\0\0@\0\0@@\0\0\200@\0\0\240@\0\0\300@\0\0\340@\0\0\0A"
+ }
+ initializer {
+ dims: 8
+ data_type: 1
+ name: "beta"
+ raw_data: "\0\0\200?\0\0\0@\0\0@@\0\0\200@\0\0\240@\0\0\300@\0\0\340@\0\0\0A"
+ }
+ input {
+ name: "data"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 8
+ }
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 2
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "gamma"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 8
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "beta"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 8
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "y"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 8
+ }
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 2
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 1
+}
test_case.add_expected_output<float>(Shape{1, 3, 2, 2}, output);
test_case.run();
}
+
+// End-to-end check of the GroupNorm custom-op importer: the model file
+// declares num_groups=4, eps=1e-6 and gamma/beta initializers whose
+// raw_data decodes to the floats 1..8.
+NGRAPH_TEST(${BACKEND_NAME}, onnx_group_norm)
+{
+    const auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/group_norm.prototxt"));
+    auto test_case = test::TestCase<TestEngine>(function);
+    Shape shape{2, 8, 2, 2};
+    int size = shape_size(shape);
+    std::vector<float> data(size);
+    // Deterministic input: 0, 1, 2, ..., 31.
+    std::iota(data.begin(), data.end(), 0);
+    std::vector<float> output = {
+        -0.52752507, -0.09108937, 0.3453464, 0.78178215, 2.4364357, 3.309307, 4.1821785, 5.05505,
+        -1.5825753, -0.27326822, 1.0360391, 2.3453465, 4.8728714, 6.618614, 8.364357, 10.1101,
+        -2.6376252, -0.45544672, 1.726732, 3.9089108, 7.309307, 9.927921, 12.546536, 15.165151,
+        -3.6926756, -0.6376257, 2.4174247, 5.472475, 9.745743, 13.237228, 16.728714, 20.2202,
+        -0.52752507, -0.09108937, 0.3453464, 0.78178215, 2.4364357, 3.309307, 4.1821785, 5.05505,
+        -1.5825753, -0.27326822, 1.0360391, 2.3453465, 4.8728714, 6.618614, 8.364357, 10.1101,
+        -2.6376252, -0.45544672, 1.726732, 3.9089108, 7.309307, 9.927921, 12.546536, 15.165151,
+        -3.6926756, -0.6376257, 2.4174247, 5.472475, 9.745743, 13.237228, 16.728714, 20.2202,
+    };
+
+    test_case.add_input<float>(data);
+    test_case.add_expected_output<float>(shape, output);
+    test_case.run();
+}