Use MVN in GroupNorm/InstanceNorm in ONNX importer (#2711)
authorMateusz Tabaka <mateusz.tabaka@intel.com>
Wed, 21 Oct 2020 10:48:53 +0000 (12:48 +0200)
committerGitHub <noreply@github.com>
Wed, 21 Oct 2020 10:48:53 +0000 (13:48 +0300)
* Use MVN in GroupNorm/InstanceNorm in ONNX importer

* Remove mosaic_8 model from xfail list

ngraph/core/src/op/mvn.cpp
ngraph/frontend/onnx_import/src/op/instance_norm.cpp
ngraph/frontend/onnx_import/src/op/org.openvinotoolkit/group_norm.cpp
ngraph/python/tests/test_onnx/test_zoo_models.py
ngraph/test/backend/fused_op.in.cpp
ngraph/test/type_prop/mvn.cpp

index 27c5914..7c7fa2a 100644 (file)
@@ -60,7 +60,6 @@ void op::MVN::validate_and_infer_types()
     if (m_reduction_axes.empty() && input_value(0).get_partial_shape().rank().is_static())
     {
         AxisSet reduction_axes;
-        reduction_axes.insert(0);
         size_t start_axis = m_across_channels ? 1 : 2;
         for (size_t i = start_axis; i < input_value(0).get_partial_shape().rank().get_length(); ++i)
         {
@@ -90,11 +89,10 @@ OutputVector op::MVN::decompose_op() const
     {
         // calculate variance
         auto variance = builder::opset1::variance(data, m_reduction_axes);
-        variance = make_shared<op::Sqrt>(variance);
         // add epsilon
         auto eps_node = op::Constant::create(
             data.get_element_type(), Output<Node>(variance).get_shape(), vector<double>{m_eps});
-        variance = variance + eps_node;
+        variance = std::make_shared<op::Sqrt>(variance + eps_node);
         variance = std::make_shared<op::Broadcast>(variance, data_shape, m_reduction_axes);
 
         return OutputVector{mean_normalization / variance};
index 84e0075..9516ea5 100644 (file)
@@ -93,21 +93,7 @@ namespace ngraph
                     const auto reduction_axes =
                         common::get_monotonic_range_along_node_rank(data, 2);
 
-                    const std::shared_ptr<ngraph::Node> eps_node =
-                        std::make_shared<default_opset::Constant>(
-                            data.get_element_type(), Shape{}, epsilon);
-
-                    auto mean =
-                        std::make_shared<default_opset::ReduceMean>(data, reduction_axes, true);
-                    auto diff = std::make_shared<default_opset::Subtract>(data, mean);
-                    auto variance = std::make_shared<default_opset::ReduceMean>(
-                        std::make_shared<default_opset::Power>(
-                            diff,
-                            default_opset::Constant::create(data.get_element_type(), Shape{}, {2})),
-                        reduction_axes,
-                        true);
-                    const auto sqrt = std::make_shared<default_opset::Sqrt>(
-                        std::make_shared<default_opset::Add>(variance, eps_node));
+                    auto mvn = std::make_shared<default_opset::MVN>(data, false, true, epsilon);
 
                     std::shared_ptr<ngraph::Node> data_shape_node;
                     if (data_pshape.is_static())
@@ -132,10 +118,9 @@ namespace ngraph
                         data_shape_node,
                         std::make_shared<default_opset::Constant>(element::i64, Shape{1}, 1));
 
-                    // scale * (data - mean) / sqrt + bias
+                    // scale * mvn + bias
                     std::shared_ptr<ngraph::Node> result =
-                        std::make_shared<default_opset::Divide>(scale, sqrt);
-                    result = std::make_shared<default_opset::Multiply>(diff, result);
+                        std::make_shared<default_opset::Multiply>(mvn, scale);
                     result = std::make_shared<default_opset::Add>(result, bias);
 
                     return {result};
index 275bac9..bdc0294 100644 (file)
@@ -106,33 +106,19 @@ namespace ngraph
                     }
                     auto data_reshaped = std::make_shared<default_opset::Reshape>(
                         data, detail::create_group_norm_shape(data, num_groups), true);
-                    const auto reduction_axes =
-                        common::get_monotonic_range_along_node_rank(data_reshaped, 2);
-                    auto mean = std::make_shared<default_opset::ReduceMean>(
-                        data_reshaped, reduction_axes, true);
-                    auto diff = std::make_shared<default_opset::Subtract>(data_reshaped, mean);
-                    auto variance = std::make_shared<default_opset::ReduceMean>(
-                        std::make_shared<default_opset::Power>(
-                            diff, default_opset::Constant::create(element::f32, Shape{}, {2})),
-                        reduction_axes,
-                        true);
-
-                    const std::shared_ptr<ngraph::Node> eps_node =
-                        std::make_shared<default_opset::Constant>(element::f32, Shape{}, eps);
-                    const auto sqrt = std::make_shared<default_opset::Sqrt>(
-                        std::make_shared<default_opset::Add>(variance, eps_node));
+
+                    auto mvn =
+                        std::make_shared<default_opset::MVN>(data_reshaped, false, true, eps);
+                    std::shared_ptr<ngraph::Node> result =
+                        std::make_shared<default_opset::Reshape>(mvn, data_shape_node, true);
 
                     const auto& rank = data.get_partial_shape().rank();
                     NGRAPH_CHECK(rank.is_static());
                     auto data_rank_size = rank.get_length();
 
-                    std::shared_ptr<ngraph::Node> result =
-                        std::make_shared<default_opset::Divide>(diff, sqrt);
-                    result =
-                        std::make_shared<default_opset::Reshape>(result, data_shape_node, true);
                     result = std::make_shared<default_opset::Multiply>(
-                        reshape::reshape_channel_shaped_node_to_nchw(scale, data_rank_size),
-                        result);
+                        result,
+                        reshape::reshape_channel_shaped_node_to_nchw(scale, data_rank_size));
                     result = std::make_shared<default_opset::Add>(
                         result, reshape::reshape_channel_shaped_node_to_nchw(bias, data_rank_size));
 
index b6a702e..1134f60 100644 (file)
@@ -155,7 +155,6 @@ if len(zoo_models) > 0:
             (xfail_issue_36533, "test_onnx_model_zoo_vision_classification_vgg_model_vgg19_bn_7_vgg19_bn_vgg19_bn_cpu"),
             (xfail_issue_36533, "test_onnx_model_zoo_vision_object_detection_segmentation_tiny_yolov2_model_tinyyolov2_7_tiny_yolov2_model_cpu"),
             (xfail_issue_36533, "test_onnx_model_zoo_vision_object_detection_segmentation_tiny_yolov2_model_tinyyolov2_8_tiny_yolov2_Model_cpu"),
-            (xfail_issue_36533, "test_onnx_model_zoo_vision_style_transfer_fast_neural_style_model_mosaic_8_mosaic_mosaic_cpu"),
             (xfail_issue_36533, "test_onnx_model_zoo_vision_classification_resnet_model_resnet18_v2_7_resnet18v2_resnet18_v2_7_cpu"),
             (xfail_issue_36533, "test_onnx_model_zoo_vision_classification_resnet_model_resnet101_v1_7_resnet101v1_resnet101_v1_7_cpu"),
             (xfail_issue_36533, "test_onnx_model_zoo_vision_classification_resnet_model_resnet152_v1_7_resnet152v1_resnet152_v1_7_cpu"),
index d160686..47bd89a 100644 (file)
@@ -1380,6 +1380,52 @@ NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization_split_channels)
     test_case.run();
 }
 
+NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization_shared_across_channel_batch_size_2)
+{
+    Shape data_shape{2, 2, 5};
+    auto data = make_shared<op::Parameter>(element::f32, data_shape);
+
+    auto mvn_func = make_shared<op::MVN>(data, true);
+    auto function = make_shared<Function>(NodeVector{mvn_func}, ParameterVector{data});
+    auto test_case = test::TestCase<TestEngine>(function);
+    // data
+    vector<float> data_vector(shape_size(data_shape));
+    iota(begin(data_vector), end(data_vector), 0);
+    test_case.add_input<float>(data_vector);
+
+    // expected result
+    test_case.add_expected_output<float>(
+        data_shape,
+        {-1.5666989f, -1.2185436f, -0.8703883f, -0.5222329f, -0.1740777f, 0.1740777f,  0.5222329f,
+         0.8703883f,  1.2185436f,  1.5666989f,  -1.5666989f, -1.2185436f, -0.8703883f, -0.5222329f,
+         -0.1740777f, 0.1740777f,  0.5222329f,  0.8703883f,  1.2185436f,  1.5666989f});
+
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization_not_shared_across_channel_batch_size_2)
+{
+    Shape data_shape{2, 2, 5};
+    auto data = make_shared<op::Parameter>(element::f32, data_shape);
+
+    auto mvn_func = make_shared<op::MVN>(data, false);
+    auto function = make_shared<Function>(NodeVector{mvn_func}, ParameterVector{data});
+    auto test_case = test::TestCase<TestEngine>(function);
+    // data
+    vector<float> data_vector(shape_size(data_shape));
+    iota(begin(data_vector), end(data_vector), 0);
+    test_case.add_input<float>(data_vector);
+
+    // expected result
+    test_case.add_expected_output<float>(
+        data_shape,
+        {-1.4142135f, -0.7071068f, 0.0000000f,  0.7071068f,  1.4142135f,  -1.4142135f, -0.7071068f,
+         0.0000000f,  0.7071068f,  1.4142135f,  -1.4142135f, -0.7071068f, 0.0000000f,  0.7071068f,
+         1.4142135f,  -1.4142135f, -0.7071068f, 0.0000000f,  0.7071068f,  1.4142135f});
+
+    test_case.run();
+}
+
 NGRAPH_TEST(${BACKEND_NAME}, grn_4d)
 {
     const Shape data_shape{1, 2, 3, 4};
index f510ec3..7b37b95 100644 (file)
@@ -34,12 +34,12 @@ TEST(type_prop, mvn_partial)
     auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
     auto mvn_func = make_shared<op::MVN>(data);
     EXPECT_EQ(mvn_func->get_element_type(), element::f32);
-    EXPECT_EQ(mvn_func->get_reduction_axes(), (AxisSet{0, 1, 2}));
+    EXPECT_EQ(mvn_func->get_reduction_axes(), (AxisSet{1, 2}));
     ASSERT_TRUE(mvn_func->get_output_partial_shape(0).same_scheme(
         (PartialShape{1, Dimension::dynamic(), 6})));
 
     // across_channels = false
-    EXPECT_EQ(make_shared<op::MVN>(data, false)->get_reduction_axes(), (AxisSet{0, 2}));
+    EXPECT_EQ(make_shared<op::MVN>(data, false)->get_reduction_axes(), (AxisSet{2}));
 
     // rank unknown
     auto mvn_partial =