Add support for custom ONNX GroupNorm operator (#2267)
authorMateusz Tabaka <mateusz.tabaka@intel.com>
Wed, 30 Sep 2020 14:17:15 +0000 (16:17 +0200)
committerGitHub <noreply@github.com>
Wed, 30 Sep 2020 14:17:15 +0000 (16:17 +0200)
ngraph/frontend/onnx_import/include/onnx_import/op/org.openvinotoolkit/group_norm.hpp [new file with mode: 0644]
ngraph/frontend/onnx_import/include/onnx_import/utils/reshape.hpp
ngraph/frontend/onnx_import/src/onnx.cpp
ngraph/frontend/onnx_import/src/op/conv.cpp
ngraph/frontend/onnx_import/src/op/org.openvinotoolkit/group_norm.cpp [new file with mode: 0644]
ngraph/frontend/onnx_import/src/ops_bridge.cpp
ngraph/frontend/onnx_import/src/utils/reshape.cpp
ngraph/test/models/onnx/group_norm.prototxt [new file with mode: 0644]
ngraph/test/onnx/onnx_import.in.cpp

diff --git a/ngraph/frontend/onnx_import/include/onnx_import/op/org.openvinotoolkit/group_norm.hpp b/ngraph/frontend/onnx_import/include/onnx_import/op/org.openvinotoolkit/group_norm.hpp
new file mode 100644 (file)
index 0000000..60e6a8b
--- /dev/null
@@ -0,0 +1,40 @@
+//*****************************************************************************
+// Copyright 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include "ngraph/node.hpp"
+#include "onnx_import/core/node.hpp"
+
+namespace ngraph
+{
+    namespace onnx_import
+    {
+        namespace op
+        {
+            namespace set_1
+            {
+                OutputVector group_norm(const Node& node);
+
+            } // namespace set_1
+
+        } // namespace op
+
+    } // namespace onnx_import
+
+} // namespace ngraph
+
+
index 242395f..3b3b4ec 100644 (file)
@@ -61,6 +61,22 @@ namespace ngraph
             ///
             Output<ngraph::Node> interpret_as_scalar(const Output<ngraph::Node>& node);
 
+            /// \brief      Reshape node from shape {C} to {1, C, 1, 1,...}
+            ///
+            /// \note       This function will reshape the input node
+            ///             with a shape of {C} into a node with Shape{1, C, 1, 1, ..}.
+            ///             The most common input to this function would be scale or bias to
+            ///             BatchNorm or bias to Conv.
+            ///
+            /// \param[in]  node            Node to reshape.
+            /// \param[in]  expected_rank   Expected rank size
+            ///
+            /// \return     Original node or a node representing a reshape of the original.
+            ///
+            Output<ngraph::Node>
+                reshape_channel_shaped_node_to_nchw(const Output<ngraph::Node>& node,
+                                                    size_t expected_rank);
+
         } // namespace  reshape
     }     // namespace onnx_import
 } // namespace ngraph
index b3e8b7e..28333fe 100644 (file)
@@ -75,7 +75,7 @@ namespace ngraph
             } // namespace error
 
             static const std::vector<std::string> legacy_ops_to_fixup = {
-                "FakeQuantize", "DetectionOutput", "Normalize", "PriorBox"};
+                "DetectionOutput", "FakeQuantize", "GroupNorm", "Normalize", "PriorBox"};
 
             // There are some models with custom OPs (list above) that has the default domain set.
             // So in order to load the models, we need overwrite the OPs' domain to the one they're
index 4fd535c..be8937b 100644 (file)
@@ -26,6 +26,7 @@
 #include "onnx_import/default_opset.hpp"
 #include "onnx_import/exceptions.hpp"
 #include "onnx_import/utils/convpool.hpp"
+#include "onnx_import/utils/reshape.hpp"
 
 namespace ngraph
 {
@@ -82,20 +83,9 @@ namespace ngraph
                     {
                         const auto rank_of_conv = ng_conv.get_partial_shape().rank().get_length();
 
-                        // reshape the bias node {M} to {1, M, 1, 1, ..., 1}
-                        // this is required by the addition operation that needs to be able
-                        // to broadcast the bias to match the shape of the convolution node
-                        std::vector<size_t> reshape_pattern_values(rank_of_conv, 1U);
-                        reshape_pattern_values[1] = bias.get_shape().front();
-                        const auto reshape_pattern =
-                            default_opset::Constant::create(element::u64,
-                                                            Shape{reshape_pattern_values.size()},
-                                                            reshape_pattern_values);
-
-                        std::shared_ptr<ngraph::Node> reshaped_bias =
-                            std::make_shared<default_opset::Reshape>(bias, reshape_pattern, false);
-
-                        return {std::make_shared<default_opset::Add>(ng_conv, reshaped_bias)};
+                        return {std::make_shared<default_opset::Add>(
+                            ng_conv,
+                            reshape::reshape_channel_shaped_node_to_nchw(bias, rank_of_conv))};
                     }
                 } // namespace
 
diff --git a/ngraph/frontend/onnx_import/src/op/org.openvinotoolkit/group_norm.cpp b/ngraph/frontend/onnx_import/src/op/org.openvinotoolkit/group_norm.cpp
new file mode 100644 (file)
index 0000000..275bac9
--- /dev/null
@@ -0,0 +1,148 @@
+//*****************************************************************************
+// Copyright 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "onnx_import/op/org.openvinotoolkit/group_norm.hpp"
+#include "ngraph/builder/reduce_ops.hpp"
+#include "ngraph/builder/split.hpp"
+#include "ngraph/node.hpp"
+#include "onnx_import/core/node.hpp"
+#include "onnx_import/default_opset.hpp"
+#include "onnx_import/utils/common.hpp"
+#include "onnx_import/utils/reshape.hpp"
+
+namespace ngraph
+{
+    namespace onnx_import
+    {
+        namespace op
+        {
+            namespace detail
+            {
+                namespace
+                {
+                    // This function creates a shape to which we need to reshape the input
+                    // before normalization.
+                    // If data shape is [N,C,H,W], the function returns
+                    // [N, num_groups, C // num_groups, H, W]
+                    std::shared_ptr<ngraph::Node>
+                        create_group_norm_shape(const Output<ngraph::Node>& data, size_t num_groups)
+                    {
+                        const auto& pshape = data.get_partial_shape();
+                        NGRAPH_CHECK(pshape.rank().is_static());
+                        size_t rank_size = pshape.rank().get_length();
+                        NGRAPH_CHECK(rank_size >= 3, "3-D and above tensors supported only");
+
+                        if (pshape.is_static())
+                        {
+                            const auto& shape = pshape.to_shape();
+                            std::vector<size_t> new_shape{
+                                shape[0], num_groups, shape[1] / num_groups};
+                            for (size_t i = 2; i < rank_size; i++)
+                            {
+                                new_shape.push_back(shape[i]);
+                            }
+                            return default_opset::Constant::create(
+                                element::i64, Shape{new_shape.size()}, new_shape);
+                        }
+
+                        auto shape = std::make_shared<default_opset::ShapeOf>(data);
+                        auto splits = builder::opset1::split(shape, rank_size);
+                        auto num_groups_const =
+                            default_opset::Constant::create(element::i64, Shape{1}, {num_groups});
+                        NodeVector new_shape{
+                            splits[0].get_node_shared_ptr(),
+                            num_groups_const,
+                            std::make_shared<default_opset::Divide>(splits[1], num_groups_const)};
+                        for (size_t i = 2; i < rank_size; i++)
+                        {
+                            new_shape.push_back(splits[i].get_node_shared_ptr());
+                        }
+                        return std::make_shared<default_opset::Concat>(new_shape, 0);
+                    }
+                } // namespace
+            }     // namespace detail
+
+            namespace set_1
+            {
+                OutputVector group_norm(const Node& node)
+                {
+                    auto inputs = node.get_ng_inputs();
+                    NGRAPH_CHECK(inputs.size() == 3,
+                                 "Invalid number of inputs. Expected 3, actual " +
+                                     std::to_string(inputs.size()));
+
+                    auto data = inputs[0];
+                    auto scale = inputs[1];
+                    auto bias = inputs[2];
+
+                    size_t num_groups =
+                        static_cast<size_t>(node.get_attribute_value<int64_t>("num_groups"));
+                    float eps = node.get_attribute_value<float>("eps", 1e-5);
+
+                    auto data_pshape = data.get_partial_shape();
+                    std::shared_ptr<ngraph::Node> data_shape_node;
+                    if (data_pshape.is_static())
+                    {
+                        auto shape = data_pshape.to_shape();
+                        data_shape_node = default_opset::Constant::create(
+                            element::u64, Shape{shape.size()}, shape);
+                    }
+                    else
+                    {
+                        data_shape_node = std::make_shared<default_opset::ShapeOf>(data);
+                    }
+                    auto data_reshaped = std::make_shared<default_opset::Reshape>(
+                        data, detail::create_group_norm_shape(data, num_groups), true);
+                    const auto reduction_axes =
+                        common::get_monotonic_range_along_node_rank(data_reshaped, 2);
+                    auto mean = std::make_shared<default_opset::ReduceMean>(
+                        data_reshaped, reduction_axes, true);
+                    auto diff = std::make_shared<default_opset::Subtract>(data_reshaped, mean);
+                    auto variance = std::make_shared<default_opset::ReduceMean>(
+                        std::make_shared<default_opset::Power>(
+                            diff, default_opset::Constant::create(element::f32, Shape{}, {2})),
+                        reduction_axes,
+                        true);
+
+                    const std::shared_ptr<ngraph::Node> eps_node =
+                        std::make_shared<default_opset::Constant>(element::f32, Shape{}, eps);
+                    const auto sqrt = std::make_shared<default_opset::Sqrt>(
+                        std::make_shared<default_opset::Add>(variance, eps_node));
+
+                    const auto& rank = data.get_partial_shape().rank();
+                    NGRAPH_CHECK(rank.is_static());
+                    auto data_rank_size = rank.get_length();
+
+                    std::shared_ptr<ngraph::Node> result =
+                        std::make_shared<default_opset::Divide>(diff, sqrt);
+                    result =
+                        std::make_shared<default_opset::Reshape>(result, data_shape_node, true);
+                    result = std::make_shared<default_opset::Multiply>(
+                        reshape::reshape_channel_shaped_node_to_nchw(scale, data_rank_size),
+                        result);
+                    result = std::make_shared<default_opset::Add>(
+                        result, reshape::reshape_channel_shaped_node_to_nchw(bias, data_rank_size));
+
+                    return {result};
+                }
+
+            } // namespace set_1
+
+        } // namespace op
+
+    } // namespace onnx_import
+
+} // namespace ngraph
index 8d127b1..fb5216a 100644 (file)
 
 #include "onnx_import/op/org.openvinotoolkit/detection_output.hpp"
 #include "onnx_import/op/org.openvinotoolkit/fake_quantize.hpp"
+#include "onnx_import/op/org.openvinotoolkit/group_norm.hpp"
 #include "onnx_import/op/org.openvinotoolkit/normalize.hpp"
 #include "onnx_import/op/org.openvinotoolkit/prior_box.hpp"
 
@@ -406,11 +407,12 @@ namespace ngraph
             REGISTER_OPERATOR("Xor", 1, logical_xor);
 
             // custom OPs
-            REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "FakeQuantize", 1, fake_quantize);
             REGISTER_OPERATOR_WITH_DOMAIN(
                 OPENVINO_ONNX_DOMAIN, "DetectionOutput", 1, detection_output);
-            REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "PriorBox", 1, prior_box);
+            REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "FakeQuantize", 1, fake_quantize);
+            REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "GroupNorm", 1, group_norm);
             REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "Normalize", 1, normalize);
+            REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "PriorBox", 1, prior_box);
         }
 
 #undef REGISTER_OPERATOR
index f8a59b8..ec8d963 100644 (file)
@@ -114,6 +114,25 @@ namespace ngraph
                 return builder::opset1::reshape(node, Shape{});
             }
 
+            Output<ngraph::Node>
+                reshape_channel_shaped_node_to_nchw(const Output<ngraph::Node>& node,
+                                                    size_t expected_rank)
+            {
+                const auto& rank = node.get_partial_shape().rank();
+                NGRAPH_CHECK(rank.is_static());
+                size_t node_rank = rank.get_length();
+                if (node_rank == 1)
+                {
+                    // reshape the node with shape {C} to {1, C, 1, 1, ..., 1}
+                    std::vector<size_t> reshape_pattern_values(expected_rank, 1U);
+                    reshape_pattern_values[1] = node.get_shape().front();
+                    const auto reshape_pattern = default_opset::Constant::create(
+                        element::u64, Shape{reshape_pattern_values.size()}, reshape_pattern_values);
+                    return std::make_shared<default_opset::Reshape>(node, reshape_pattern, false);
+                }
+                return node;
+            }
+
         } // namespace  reshape
     }     // namespace onnx_import
 } // namespace ngraph
diff --git a/ngraph/test/models/onnx/group_norm.prototxt b/ngraph/test/models/onnx/group_norm.prototxt
new file mode 100644 (file)
index 0000000..e5f43cd
--- /dev/null
@@ -0,0 +1,108 @@
+ir_version: 3
+producer_name: "nGraph ONNX Importer"
+graph {
+  node {
+    input: "data"
+    input: "gamma"
+    input: "beta"
+    output: "y"
+    op_type: "GroupNorm"
+    domain: "org.openvinotoolkit"
+    attribute {
+        name: "num_groups"
+        i: 4
+        type: INT
+    }
+    attribute {
+        name: "eps"
+        f: 1e-6
+        type: FLOAT
+    }
+  }
+  name: "group_norm_example"
+  initializer {
+    dims: 8
+    data_type: 1
+    name: "gamma"
+    raw_data: "\0\0\200?\0\0\0@\0\0@@\0\0\200@\0\0\240@\0\0\300@\0\0\340@\0\0\0A"
+  }
+  initializer {
+    dims: 8
+    data_type: 1
+    name: "beta"
+    raw_data: "\0\0\200?\0\0\0@\0\0@@\0\0\200@\0\0\240@\0\0\300@\0\0\340@\0\0\0A"
+  }
+  input {
+    name: "data"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 8
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "gamma"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "beta"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 8
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 1
+}
index 774c42e..38686bd 100644 (file)
@@ -2618,3 +2618,28 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_normalize)
     test_case.add_expected_output<float>(Shape{1, 3, 2, 2}, output);
     test_case.run();
 }
+
+NGRAPH_TEST(${BACKEND_NAME}, onnx_group_norm)
+{
+    const auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/group_norm.prototxt"));
+    auto test_case = test::TestCase<TestEngine>(function);
+    Shape shape{2, 8, 2, 2};
+    int size = shape_size(shape);
+    std::vector<float> data(size);
+    std::iota(data.begin(), data.end(), 0);
+    std::vector<float> output = {
+        -0.52752507, -0.09108937, 0.3453464, 0.78178215, 2.4364357, 3.309307,  4.1821785, 5.05505,
+        -1.5825753,  -0.27326822, 1.0360391, 2.3453465,  4.8728714, 6.618614,  8.364357,  10.1101,
+        -2.6376252,  -0.45544672, 1.726732,  3.9089108,  7.309307,  9.927921,  12.546536, 15.165151,
+        -3.6926756,  -0.6376257,  2.4174247, 5.472475,   9.745743,  13.237228, 16.728714, 20.2202,
+        -0.52752507, -0.09108937, 0.3453464, 0.78178215, 2.4364357, 3.309307,  4.1821785, 5.05505,
+        -1.5825753,  -0.27326822, 1.0360391, 2.3453465,  4.8728714, 6.618614,  8.364357,  10.1101,
+        -2.6376252,  -0.45544672, 1.726732,  3.9089108,  7.309307,  9.927921,  12.546536, 15.165151,
+        -3.6926756,  -0.6376257,  2.4174247, 5.472475,   9.745743,  13.237228, 16.728714, 20.2202,
+    };
+
+    test_case.add_input<float>(data);
+    test_case.add_expected_output<float>(shape, output);
+    test_case.run();
+}