Move passes to CommonOptimizations (#2442)
authorGleb Kazantaev <gleb.kazantaev@intel.com>
Thu, 1 Oct 2020 17:08:41 +0000 (20:08 +0300)
committerGitHub <noreply@github.com>
Thu, 1 Oct 2020 17:08:41 +0000 (20:08 +0300)
* Move passes to CommonOptimizations

* Updated BN tests to use ranges for constant value generation

* Added some decomposition passes into legacy conversion

* Added WA for FQReshapeFusion pass

12 files changed:
inference-engine/src/transformations/include/transformations/convert_gelu.hpp
inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
inference-engine/src/transformations/src/transformations/common_optimizations/fq_reshape_fusion.cpp
inference-engine/src/transformations/src/transformations/convert_broadcast_to_tiles.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_gelu.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp
inference-engine/src/transformations/src/transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.cpp
inference-engine/src/transformations/src/transformations/hswish_decomposition.cpp
inference-engine/src/transformations/src/transformations/reduce_l1_decomposition.cpp
inference-engine/src/transformations/src/transformations/reduce_l2_decomposition.cpp
inference-engine/tests/functional/inference_engine/transformations/fq_reshape_fusion.cpp
inference-engine/tests/ngraph_functions/src/batch_norm.cpp

index 5ee3c81790c2ac66e85fc1b0fad58abf02a585b1..0057699c99ce6df6a9d23179cacae28ec7ab8919 100644 (file)
@@ -21,13 +21,8 @@ class TRANSFORMATIONS_API ConvertGELU;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertGELU: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertGELU: public ngraph::pass::MatcherPass {
 public:
     NGRAPH_RTTI_DECLARATION;
-    ConvertGELU() : GraphRewrite() {
-        convert_gelu();
-    }
-
-private:
-    void convert_gelu();
+    ConvertGELU();
 };
index e45776a9c78c630aa8995248eaf23f6c4705bec2..fae8de5cfcbae4f87b48ab6763a22aa824472e96 100644 (file)
@@ -7,21 +7,40 @@
 #include "transformations/common_optimizations/algebraic_simplification.hpp"
 #include "transformations/common_optimizations/nop_elimination.hpp"
 #include "transformations/common_optimizations/common_optimizations.hpp"
+#include "transformations/common_optimizations/conv_mul_fusion.hpp"
+#include "transformations/common_optimizations/fq_mul_fusion.hpp"
+#include "transformations/common_optimizations/fq_reshape_fusion.hpp"
 #include "transformations/depth_to_space_fusion.hpp"
 #include "transformations/optimize_strided_slice.hpp"
-#include "transformations/convert_scatter_elements_to_scatter.hpp"
-#include "transformations/convert_pad_to_group_conv.hpp"
-#include "transformations/remove_filtering_boxes_by_size.hpp"
 #include "transformations/init_node_info.hpp"
 #include "transformations/itt.hpp"
 #include "transformations/mish_fusion.hpp"
 #include "transformations/softplus_fusion.hpp"
 #include "transformations/softplus_to_mish_fusion.hpp"
 #include "transformations/swish_fusion.hpp"
-#include "transformations/hswish_fusion.hpp"
 #include "transformations/normalize_l2_fusion.hpp"
-#include "transformations/convert_quantize_dequantize.hpp"
 #include "transformations/bidirectional_sequences_decomposition.hpp"
+#include "transformations/convert_pad_to_group_conv.hpp"
+#include "transformations/convert_divide.hpp"
+#include "transformations/convert_quantize_dequantize.hpp"
+#include "transformations/convert_mod.hpp"
+#include "transformations/convert_minimum_to_power_and_max.hpp"
+#include "transformations/convert_negative.hpp"
+#include "transformations/convert_scatter_elements_to_scatter.hpp"
+#include "transformations/convert_reduce_to_pooling.hpp"
+#include "transformations/convert_subtract.hpp"
+#include "transformations/convert_depth_to_space.hpp"
+#include "transformations/convert_space_to_depth.hpp"
+#include "transformations/convert_broadcast_to_tiles.hpp"
+#include "transformations/convert_gelu.hpp"
+#include "transformations/batch_norm_decomposition.hpp"
+#include "transformations/pull_transpose_through_fq.hpp"
+#include "transformations/lin_op_sequence_fusoin.hpp"
+#include "transformations/reduce_l1_decomposition.hpp"
+#include "transformations/reduce_l2_decomposition.hpp"
+#include "transformations/remove_filtering_boxes_by_size.hpp"
+#include "transformations/hswish_decomposition.hpp"
+#include "transformations/hswish_fusion.hpp"
 
 #include <ngraph/pass/manager.hpp>
 #include <ngraph/pass/constant_folding.hpp>
@@ -55,6 +74,43 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
     manager.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
     manager.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
 
+    auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
+    decomp->add_matcher<ngraph::pass::ReduceL1Decomposition>();
+    decomp->add_matcher<ngraph::pass::ReduceL2Decomposition>();
+    decomp->add_matcher<ngraph::pass::HSwishDecomposition>();
+    decomp->add_matcher<ngraph::pass::ConvertReduceMeanToPooling>();
+    decomp->add_matcher<ngraph::pass::ConvertReduceMaxToPooling>();
+    decomp->add_matcher<ngraph::pass::ConvertReduceSumToPooling>();
+    decomp->add_matcher<ngraph::pass::ConvertBroadcastToTiles>();
+    decomp->add_matcher<ngraph::pass::ConvertMod>();
+    decomp->add_matcher<ngraph::pass::ConvertGELU>();
+    decomp->add_matcher<ngraph::pass::ConvertMinimum>();
+    decomp->add_matcher<ngraph::pass::ConvertSubtract>();
+    decomp->add_matcher<ngraph::pass::ConvertDivide>();
+    decomp->add_matcher<ngraph::pass::ConvertNegative>();
+    decomp->add_matcher<ngraph::pass::ConvertDepthToSpace>();
+    decomp->add_matcher<ngraph::pass::ConvertSpaceToDepth>();
+    decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
+    decomp->set_name("ngraph::pass::CommonDecompositions");
+
+    // CF is required after all decompositions
+    manager.register_pass<ngraph::pass::ConstantFolding>();
+
+    // LinOpSequenceFusion must be executed after all decompositions
+    manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
+
+    manager.register_pass<ngraph::pass::ConvolutionMultiplyFusion>();
+    manager.register_pass<ngraph::pass::GroupConvolutionMultiplyFusion>();
+    manager.register_pass<ngraph::pass::ConvolutionBackpropDataMultiplyFusion>();
+    manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
+    manager.register_pass<ngraph::pass::ConstantFolding>();
+
+    auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
+    fq_fusions->add_matcher<ngraph::pass::FakeQuantizeMulFusion>();
+    fq_fusions->add_matcher<ngraph::pass::FakeQuantizeReshapeFusion>();
+    fq_fusions->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
+    fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
+
     manager.set_callback(m_transformation_callback);
     manager.run_passes(f);
     return true;
index 0f7127ca2759af190b080ef3f5b34bdd8dd505e0..32ecf1c66b35872d33ebc405017ca485851c58b1 100644 (file)
@@ -20,7 +20,14 @@ ngraph::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {
              ngraph::pattern::any_input()},
             pattern::consumers_count(1));
     const auto reshape_node_p = ngraph::pattern::wrap_type<opset4::Reshape>(
-            {fq_node_p, ngraph::pattern::any_input()});
+            {fq_node_p, ngraph::pattern::any_input()}, [](const Output<Node> & output) {
+                // WA: check that all Reshape node consumers are not GroupConvolution operations
+                const auto & target_inputs = output.get_target_inputs();
+                return std::all_of(target_inputs.begin(), target_inputs.end(),
+                        [](const Input<Node> & input){
+                            return input.get_node()->get_type_info() != opset4::GroupConvolution::type_info;
+                        });
+            });
 
     ngraph::matcher_pass_callback callback = [=](pattern::Matcher &m) {
         const auto &pattern_map = m.get_pattern_value_map();
index 895a0e94af0559573433b7da1cadd93968e84a0e..939648cf2aa8a7d08592c0c09f03e25adf334afb 100644 (file)
@@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertBroadcastToTiles, "ConvertBroadcastT
 ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
     auto broadcast = ngraph::pattern::wrap_type<ngraph::opset1::Broadcast>();
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
         auto broadcast = std::dynamic_pointer_cast<ngraph::opset1::Broadcast>(m.get_match_root());
 
         if (!broadcast) {
@@ -83,7 +83,7 @@ ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
         }
 
         auto const_node = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape {dims_count}, dims);
-        auto tile = std::make_shared<ngraph::opset1::Tile>(last_node, const_node);
+        auto tile = register_new_node<ngraph::opset1::Tile>(last_node, const_node);
         new_ops.push_back(tile);
         tile->set_friendly_name(broadcast->get_friendly_name());
 
index a69be94aa4d54c3ec4c55da91511813edbd91cf5..2a9fc54f4fab599f26ab4ca39563c3d371b6a21b 100644 (file)
@@ -12,7 +12,7 @@
 
 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGELU, "ConvertGELU", 0);
 
-void ngraph::pass::ConvertGELU::convert_gelu() {
+ngraph::pass::ConvertGELU::ConvertGELU() {
     auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{});
     auto gelu = std::make_shared<ngraph::opset2::Gelu>(input);
 
@@ -24,11 +24,11 @@ void ngraph::pass::ConvertGELU::convert_gelu() {
         auto input_type = input.get_element_type();
 
         // f(x) = 0.5 * x * (1.0 + erf( x / sqrt(2.0) )
-        auto mul = std::make_shared<ngraph::opset1::Multiply>(input, ngraph::op::Constant::create(input_type, Shape{}, {0.5}));
-        auto sq2 = std::make_shared<ngraph::opset1::Sqrt>(ngraph::op::Constant::create(input_type, Shape{}, {2.0}));
-        auto div = std::make_shared<ngraph::opset1::Divide>(input, sq2);
+        auto mul = std::make_shared<ngraph::opset1::Multiply>(input, ngraph::opset1::Constant::create(input_type, Shape{}, {0.5}));
+        auto sq2 = std::make_shared<ngraph::opset1::Sqrt>(ngraph::opset1::Constant::create(input_type, Shape{}, {2.0}));
+        auto div = register_new_node<ngraph::opset1::Divide>(input, sq2); // can be decomposed
         auto erf = std::make_shared<ngraph::opset1::Erf>(div);
-        auto add = std::make_shared<ngraph::opset1::Add>(erf, ngraph::op::Constant::create(input_type, Shape{}, {1.0}));
+        auto add = std::make_shared<ngraph::opset1::Add>(erf, ngraph::opset1::Constant::create(input_type, Shape{}, {1.0}));
         auto res = std::make_shared<ngraph::opset1::Multiply>(mul, add);
 
         res->set_friendly_name(gelu->get_friendly_name());
@@ -38,5 +38,5 @@ void ngraph::pass::ConvertGELU::convert_gelu() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(gelu, "ConvertGELU");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    register_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
 }
index 82a2e5ce6ec423305d5d6683f736a235a9a3636b..b6784268879769c408a54a6029d8e82dd84e36a3 100644 (file)
@@ -68,48 +68,26 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
     OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::ConvertOpSet1ToLegacy");
 
     ngraph::pass::Manager manager;
-    std::vector<std::shared_ptr<ngraph::pass::PassBase> > transforms;
 
     manager.register_pass<ngraph::pass::ConstantFolding>();
 
-    // the following two transformations produce ReduceSum operations so they
-    // must be executed before the ConvertReduceSumToPooling transformation
-    manager.register_pass<ngraph::pass::ReduceL1Decomposition>();
-    manager.register_pass<ngraph::pass::ReduceL2Decomposition>();
-
-    // HSwishDecomposition produce Minimum, Relu and Multiply operations
-    // so it must be executed before
-    manager.register_pass<ngraph::pass::HSwishDecomposition>();
-
-    // List if Decomposition and Conversion transformations that can be
-    // applied simultaneously in a single graph traversal
+    // Some passes before ConvertOpSet1ToLegacy can produce some of this
+    // operations. So for convenience we decompose this operations here and
+    // in CommonOptimizations.
     auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
-    decomp->add_matcher<ngraph::pass::ConvertBroadcastToTiles>();
-    decomp->add_matcher<ngraph::pass::ConvertReduceMeanToPooling>();
-    decomp->add_matcher<ngraph::pass::ConvertReduceMaxToPooling>();
-    decomp->add_matcher<ngraph::pass::ConvertReduceSumToPooling>();
     decomp->add_matcher<ngraph::pass::ConvertMod>();
     decomp->add_matcher<ngraph::pass::ConvertMinimum>();
     decomp->add_matcher<ngraph::pass::ConvertSubtract>();
     decomp->add_matcher<ngraph::pass::ConvertDivide>();
     decomp->add_matcher<ngraph::pass::ConvertNegative>();
-    decomp->add_matcher<ngraph::pass::ConvertDepthToSpace>();
-    decomp->add_matcher<ngraph::pass::ConvertSpaceToDepth>();
-    decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
-    decomp->add_matcher<ngraph::pass::ConvertMatMulToFC>();
-    decomp->add_matcher<ngraph::pass::ConvertMatMulToGemm>();
-    decomp->set_name("ngraph::pass::Decompositions");
-
-    // CF is required after all decompositions
-    manager.register_pass<ngraph::pass::ConstantFolding>();
+    decomp->set_name("ngraph::pass::LegacyDecompositions");
 
-    // LinOpSequenceFusion must be executed after all decompositions
-    manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
+    auto convert_matmul = manager.register_pass<ngraph::pass::GraphRewrite>();
+    convert_matmul->add_matcher<ngraph::pass::ConvertMatMulToFC>();
+    convert_matmul->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
+    convert_matmul->add_matcher<ngraph::pass::ConvertMatMulToGemm>();
+    convert_matmul->set_name("ngraph::pass::ConvertMatMul");
 
-    manager.register_pass<ngraph::pass::ConvolutionMultiplyFusion>();
-    manager.register_pass<ngraph::pass::GroupConvolutionMultiplyFusion>();
-    manager.register_pass<ngraph::pass::ConvolutionBackpropDataMultiplyFusion>();
-    manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
     manager.register_pass<ngraph::pass::ConstantFolding>();
 
     // Convolution/Deconvolution/FullyConnected fusions
@@ -120,18 +98,12 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
     convert_convolutions->add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
     convert_convolutions->set_name("ngraph::pass::ConvertConvolutions");
 
-    auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
-    fq_fusions->add_matcher<FakeQuantizeMulFusion>();
-    fq_fusions->add_matcher<FakeQuantizeReshapeFusion>();
-    fq_fusions->add_matcher<PullTransposeThroughFQUp>();
-    fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
-
     // Convolution/Deconvolution/FullyConnected fusions
     auto fusion = manager.register_pass<ngraph::pass::GraphRewrite>();
     fusion->add_matcher<ngraph::pass::ConvAddFusion>();
     fusion->add_matcher<ngraph::pass::DeconvAddFusion>();
     fusion->add_matcher<ngraph::pass::FullyConnectedBiasFusion>();
-    fusion->set_name("ngraph::pass::Fusions");
+    fusion->set_name("ngraph::pass::BiasFusions");
 
     // CF is required after fusions
     manager.register_pass<ngraph::pass::ConstantFolding>();
@@ -148,6 +120,7 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
     anchor->add_matcher<ngraph::pass::ConvertHardSigmoidToLegacyMatcher>();
     anchor->add_matcher<ngraph::pass::ConvertProposalToLegacyMatcher>();
     anchor->add_matcher<ngraph::pass::ConvertProposal4ToLegacyMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertBroadcastToTiles>();
     anchor->add_matcher<ngraph::pass::ConvertTileToLegacyMatcher>();
     anchor->add_matcher<ngraph::pass::ConvertLRNToLegacyMatcher>();
     anchor->add_matcher<ngraph::pass::ConvertPadToLegacyMatcher>();
@@ -170,7 +143,7 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
     anchor->add_matcher<ngraph::pass::ConvertGRUSequenceMatcher>();
     anchor->add_matcher<ngraph::pass::ConvertRNNSequenceMatcher>();
     anchor->add_matcher<ngraph::pass::ConvertLSTMSequenceMatcher>();
-    anchor->set_name("ngraph::pass::ConvertOpSet1ToLegacy");
+    anchor->set_name("ngraph::pass::LegacyConversions");
 
     // List of final conversion transformations that must to be executed
     // after previous group of transformations
index 93588e359da9261a5c56cdf64f71ee68363afb1c..80b66b3ac0de38b1a8b7d7cc70c96c1e2e67be13 100644 (file)
@@ -21,7 +21,6 @@ bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph
 
     ngraph::pass::Manager manager;
 
-    manager.register_pass<ngraph::pass::ConvertGELU>();
     manager.register_pass<ngraph::pass::ConvertSpaceToBatch>();
     manager.register_pass<ngraph::pass::ConvertBatchToSpace>();
 
index 2da0f618c4b47553e558aa4f6e3a6c7ac4c83822..e0a52fbd72aa90557caebb3019e303a57026dc30 100644 (file)
@@ -29,7 +29,7 @@ ngraph::pass::HSwishDecomposition::HSwishDecomposition() {
         auto add = std::make_shared<ngraph::opset4::Add>(hswish_node->input_value(0), add_constant);
         auto relu = std::make_shared<ngraph::opset4::Relu>(add);
         auto min_constant = ngraph::opset4::Constant::create(input_type, ngraph::Shape{}, {6.0});
-        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
+        auto min = register_new_node<ngraph::opset4::Minimum>(relu, min_constant);
         auto mul_first = std::make_shared<ngraph::opset4::Multiply>(hswish_node->input_value(0), min);
         auto mul_constant = ngraph::opset4::Constant::create(input_type, ngraph::Shape{}, {(1.0/6.0)});  // const(1/6)
         auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);
index 87f8fd0481be11fe1fd5c48f00c1fdf7fd10e9fd..6d83866a4e36c8977ffb7ed84d7dd5890f84aecc 100644 (file)
@@ -25,7 +25,7 @@ ngraph::pass::ReduceL1Decomposition::ReduceL1Decomposition() {
         }
 
         auto abs = std::make_shared<ngraph::opset4::Abs>(reduce_l1_node->input_value(0));
-        auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(abs, reduce_l1_node->input_value(1), reduce_l1_node->get_keep_dims());
+        auto reduce_sum = register_new_node<ngraph::opset4::ReduceSum>(abs, reduce_l1_node->input_value(1), reduce_l1_node->get_keep_dims());
 
         reduce_sum->set_friendly_name(m.get_match_root()->get_friendly_name());
         ngraph::copy_runtime_info(reduce_l1_node,
index f36cfeef67cd04da6dd20154d10b4489a42f15a7..87c58b5c3f20929b1ae0d78e89f34ebcf10eef2a 100644 (file)
@@ -26,7 +26,7 @@ ngraph::pass::ReduceL2Decomposition::ReduceL2Decomposition() {
 
         auto const_2 = ngraph::opset4::Constant::create(reduce_l2_node->input_value(0).get_element_type(), Shape{}, {2.0f});
         auto square = std::make_shared<ngraph::opset4::Power>(reduce_l2_node->input_value(0), const_2);
-        auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(square, reduce_l2_node->input_value(1), reduce_l2_node->get_keep_dims());
+        auto reduce_sum = register_new_node<ngraph::opset4::ReduceSum>(square, reduce_l2_node->input_value(1), reduce_l2_node->get_keep_dims());
         auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
         reduce_sum->set_friendly_name(m.get_match_root()->get_friendly_name());
         ngraph::copy_runtime_info(reduce_l2_node,
index aea06e15ace8468673df53e588102d14480aacc2..8f166131c1a7336f7405f15e696ffec0b5552f69 100644 (file)
@@ -123,3 +123,49 @@ INSTANTIATE_TEST_CASE_P(NGraph, nGraphFQReshapeFusionTests, testing::Values(
     FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 3}, {1}, {1, 1}, {1, 2, 1, 1}, {1, 2, 1, 3}, {}, {},  {}, {}, true},
     FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {6}, {}, {},  {}, {}, true}));
 }  // namespace
+
+TEST(nGraphFQReshapeFusionTests, FQReshapeGroupConvolution) {
+    auto get_function = [](const FQReshapeFusionTestCase & test_case) {
+        const auto & data =  std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, test_case.data_shape, 0);
+        auto il = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.il_shape);
+        auto ih = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ih_shape);
+        auto ol = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ol_shape);
+        auto oh = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.oh_shape);
+
+        auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(data, il, ih, ol, oh, 42);
+
+        auto reshape_pattern = std::make_shared<ngraph::opset4::Constant>(
+                ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern);
+        auto reshape = std::make_shared<ngraph::opset4::Reshape>(fq, reshape_pattern, true);
+
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.data_shape);
+        ngraph::Strides stride{1, 1};
+        ngraph::CoordinateDiff pad{0, 0};
+        auto group_conv = std::make_shared<ngraph::opset4::GroupConvolution>(input, reshape, stride, pad, pad, stride);
+
+        auto result = std::make_shared<ngraph::op::Result>(group_conv);
+        ngraph::ParameterVector params = {il, ih, ol, oh, input};
+        ngraph::ResultVector results = {result};
+        return std::make_shared<ngraph::Function>(results, params);
+    };
+
+    FQReshapeFusionTestCase params;
+    params.data_shape = {1, 2, 1, 3};
+    params.il_shape = {2, 1, 1};
+    params.ih_shape = {1};
+    params.ol_shape = {1, 1};
+    params.oh_shape = {1, 2, 1, 1};
+    params.reshape_pattern = {2, 3, 1, 1, 1};
+
+    auto f = get_function(params);
+
+    ngraph::pass::Manager manager;
+    manager.register_pass<ngraph::pass::InitNodeInfo>();
+    manager.register_pass<ngraph::pass::FakeQuantizeReshapeFusion>();
+    manager.run_passes(f);
+
+    ASSERT_NO_THROW(check_rt_info(f));
+
+    auto res = compare_functions(f, get_function(params));
+    ASSERT_TRUE(res.first) << res.second;
+}
\ No newline at end of file
index c0f960663d178862d33e1f6c33efaa4db551a614..14f4035e9e4f565333ab1d772f212d9a66ea88da 100644 (file)
@@ -15,9 +15,9 @@ std::shared_ptr<ngraph::Node> makeBatchNormInference(const ngraph::Output<Node>&
     size_t C   = data.get_shape().at(1);
     bool random = true;
     std::vector<float> values(C);
-    auto gamma = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, random);
-    auto beta  = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, random);
-    auto mean  = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, random);
+    auto gamma = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, random, 1, 0);
+    auto beta  = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, random, 1, 0);
+    auto mean  = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, random, 1, 0);
 
     // Fill the vector for variance with positive values
     std::default_random_engine gen;