Add dynamic shape checks to nGraph transformations (#2735)
authorGleb Kazantaev <gleb.kazantaev@intel.com>
Fri, 23 Oct 2020 12:39:47 +0000 (15:39 +0300)
committerGitHub <noreply@github.com>
Fri, 23 Oct 2020 12:39:47 +0000 (15:39 +0300)
* Added dynamic shape checks for BatchNormDecomposition pass

* Added dynamic shapes checks for FQTranspose fusion pass

* Added pattern::has_static_rank predicate

* Added dynamic shapes checks for BroadcastToTiles pass

* Fixed BN inputs order

* Add dynamic shape checks for DepthToSpace/SpaceToDepth passes

* Added dynamic check for ReduceToPooling pass

* Updated BN transformation

* Fix PR comments

* size_t to int64_t

* Updated reduce to pooling pattern

13 files changed:
inference-engine/src/transformations/include/transformations/op_conversions/convert_reduce_to_pooling.hpp
inference-engine/src/transformations/src/transformations/common_optimizations/pull_transpose_through_fq.cpp
inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp
inference-engine/src/transformations/src/transformations/op_conversions/convert_broadcast_to_tiles.cpp
inference-engine/src/transformations/src/transformations/op_conversions/convert_depth_to_space.cpp
inference-engine/src/transformations/src/transformations/op_conversions/convert_space_to_depth.cpp
inference-engine/tests/functional/inference_engine/transformations/batch_norm_decompositoin.cpp [new file with mode: 0644]
inference-engine/tests/functional/inference_engine/transformations/convert_broadcast_to_tiles_test.cpp [new file with mode: 0644]
inference-engine/tests/functional/inference_engine/transformations/convert_reduce_to_pooling_test.cpp
inference-engine/tests/functional/inference_engine/transformations/ngraph_depth_to_space_transform_test.cpp
inference-engine/tests/functional/inference_engine/transformations/ngraph_fq_transpose_test.cpp
ngraph/core/include/ngraph/pattern/op/pattern.hpp
ngraph/core/src/pattern/op/pattern.cpp

index 23dffdb..3e142cb 100644 (file)
@@ -47,7 +47,10 @@ public:
 class ngraph::pass::ConvertReduceMeanToPooling: public ConvertReduceBase {
 public:
     ConvertReduceMeanToPooling() {
-        auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceMean>(), "ConvertReduceMean");
+        auto m = std::make_shared<ngraph::pattern::Matcher>(
+                ngraph::pattern::wrap_type<opset1::ReduceMean>({pattern::any_input(pattern::has_static_shape()),
+                                                                pattern::wrap_type<opset1::Constant>()},
+                                                                pattern::has_static_shape()), "ConvertReduceMean");
         register_matcher(m, convert_reduce_to_pooling<opset1::ReduceMean>());
     }
 };
@@ -55,7 +58,10 @@ public:
 class ngraph::pass::ConvertReduceMaxToPooling: public ConvertReduceBase {
 public:
     ConvertReduceMaxToPooling() {
-        auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceMax>(), "ConvertReduceMax");
+        auto m = std::make_shared<ngraph::pattern::Matcher>(
+                ngraph::pattern::wrap_type<opset1::ReduceMax>({pattern::any_input(pattern::has_static_shape()),
+                                                               pattern::wrap_type<opset1::Constant>()},
+                                                               pattern::has_static_shape()), "ConvertReduceMax");
         register_matcher(m, convert_reduce_to_pooling<opset1::ReduceMax>());
     }
 };
@@ -63,7 +69,10 @@ public:
 class ngraph::pass::ConvertReduceSumToPooling: public ConvertReduceBase {
 public:
     ConvertReduceSumToPooling() {
-        auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceSum>(), "ConvertReduceSum");
+        auto m = std::make_shared<ngraph::pattern::Matcher>(
+                ngraph::pattern::wrap_type<opset1::ReduceSum>({pattern::any_input(pattern::has_static_shape()),
+                                                               pattern::wrap_type<opset1::Constant>()},
+                                                               pattern::has_static_shape()), "ConvertReduceSum");
         register_matcher(m, convert_reduce_to_pooling<opset1::ReduceSum>());
     }
 };
@@ -79,12 +88,12 @@ ngraph::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
 
         auto input = reduce->input_value(0);
 
-        auto axes_node = reduce->input_value(1).get_node_shared_ptr();
-        if (!ngraph::op::is_constant(axes_node)) {
+        auto axes_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(reduce->input_value(1).get_node_shared_ptr());
+        if (!axes_node) {
             return false;
         }
 
-        auto axes_vector = std::dynamic_pointer_cast<ngraph::opset1::Constant>(axes_node)->template cast_vector<int64_t>();
+        auto axes_vector = axes_node->template cast_vector<int64_t>();
         const auto input_rank = input.get_partial_shape().rank().get_length();
         // Transform negative axes into non-negative ones
         for (size_t i = 0; i < axes_vector.size(); ++i) {
@@ -99,10 +108,6 @@ ngraph::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
             return replace_output_update_name(reduce->output(0), input);
         }
 
-        // As this transformation requires static input shape we should guaranty it
-        if (input.get_partial_shape().is_dynamic()) {
-            return false;
-        }
         auto input_shape = input.get_shape();
 
         // If Reduce op reduces only 1 dims we replace it with Reshape
index 047fd19..aea96b9 100644 (file)
@@ -9,56 +9,42 @@
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
 NGRAPH_RTTI_DEFINITION(ngraph::pass::PullTransposeThroughFQUp, "PullTransposeThroughFQUp", 0);
 
 ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
-    auto data1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto data2 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto data3 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto data4 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto data5 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto fq = std::make_shared<ngraph::opset1::FakeQuantize>(data1, data2, data3, data4, data5, 1);
-    auto transpose_order = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
-    auto transpose = std::make_shared<ngraph::opset1::Transpose>(fq, transpose_order);
+    auto m_fq = pattern::wrap_type<opset1::FakeQuantize>({pattern::any_input(pattern::has_static_rank()),
+                                                          pattern::any_input(pattern::has_static_rank()),
+                                                          pattern::any_input(pattern::has_static_rank()),
+                                                          pattern::any_input(pattern::has_static_rank()),
+                                                          pattern::any_input(pattern::has_static_rank())},
+                                                          pattern::consumers_count(1));
+    auto m_transpose = pattern::wrap_type<opset1::Transpose>({m_fq, pattern::wrap_type<opset1::Constant>()});
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
-        auto transpose = ngraph::as_type_ptr<ngraph::opset1::Transpose>(m.get_match_root());
-        if (!transpose) {
-            return false;
-        }
-
-        auto const_node = transpose->input(1).get_source_output().get_node_shared_ptr();
-        auto const_order = ngraph::as_type_ptr<ngraph::opset1::Constant>(const_node);
-        if (!const_order) {
-            return false;
-        }
-
-        auto fq_node = transpose->input(0).get_source_output().get_node_shared_ptr();
-        auto fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(fq_node);
-        if (!fq || fq->output(0).get_target_inputs().size() != 1) {
-            return false;
-        }
+    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
+        auto & pattern_map = m.get_pattern_value_map();
+        auto transpose = pattern_map[m_transpose].get_node_shared_ptr();
+        auto fq = pattern_map[m_fq].get_node_shared_ptr();
 
-        auto input_shape = fq->input(0).get_source_output().get_shape();
+        auto input_rank = fq->input(0).get_partial_shape().rank().get_length();
 
         ngraph::NodeVector new_ops;
         ngraph::OutputVector fq_inputs;
         for (size_t i = 0; i < fq->inputs().size(); ++i) {
-            std::shared_ptr<ngraph::Node> fq_input;
-            fq_input = fq->input(i).get_source_output().get_node_shared_ptr();
-            auto fq_input_shape = fq_input->get_shape();
+            auto fq_input = fq->input_value(i);
+            auto fq_input_rank = fq_input.get_partial_shape().rank().get_length();
             std::vector<int64_t> unsqueeze_axes;
-            for (size_t j = 0; j < input_shape.size() - fq_input_shape.size(); ++j) {
+            for (size_t j = 0; j < input_rank - fq_input_rank; ++j) {
                 unsqueeze_axes.push_back(j);
             }
             if (!unsqueeze_axes.empty()) {
                 fq_input = std::make_shared<ngraph::opset1::Unsqueeze>(fq_input,
                                                                        opset1::Constant::create(element::i64, Shape{unsqueeze_axes.size()}, unsqueeze_axes));
-                new_ops.push_back(fq_input);
+                new_ops.push_back(fq_input.get_node_shared_ptr());
             }
-            fq_input = transpose->copy_with_new_inputs({fq_input, const_order});
-            ngraph::copy_runtime_info(transpose, fq_input);
+            fq_input = transpose->copy_with_new_inputs({fq_input, transpose->input_value(1)});
+            ngraph::copy_runtime_info(transpose, fq_input.get_node_shared_ptr());
             fq_inputs.push_back(fq_input);
         }
 
@@ -71,6 +57,6 @@ ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
         return true;
     };
 
-    auto m = std::make_shared<ngraph::pattern::Matcher>(transpose, "PullTransposeThroughFQUp");
+    auto m = std::make_shared<ngraph::pattern::Matcher>(m_transpose, "PullTransposeThroughFQUp");
     this->register_matcher(m, callback);
 }
index e76e302..384e916 100644 (file)
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/opsets/opset5.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
 using namespace ngraph;
 
 NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormDecomposition, "BatchNormDecomposition", 0);
 
 ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
-    Shape shape{2, 2, 1, 1};
-    auto input = make_shared<pattern::op::Label>(element::f32, shape);
-    auto mean_shape = Shape{2};
-    auto mean = make_shared<pattern::op::Label>(element::f32, mean_shape);
-    auto var_shape = Shape{2};
-    auto var = make_shared<pattern::op::Label>(element::f32, var_shape);
-    auto gamma_shape = Shape{2};
-    auto gamma = make_shared<pattern::op::Label>(element::f32, gamma_shape);
-    auto beta_shape = Shape{2};
-    auto beta = make_shared<pattern::op::Label>(element::f32, beta_shape);
-    auto bn = make_shared<opset1::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
-
-    ngraph::graph_rewrite_callback callback = [this, input, gamma, beta, mean, var](ngraph::pattern::Matcher &m) {
-        auto pattern_map = m.get_pattern_map();
-
-        auto m_input = pattern_map[input];
-        auto m_gamma = pattern_map[gamma];
-        auto m_beta = pattern_map[beta];
-        auto m_mean = pattern_map[mean];
-        auto m_var = pattern_map[var];
-
-        // TODO: check that all input shapes are static
-
+    auto bn = pattern::wrap_type<opset1::BatchNormInference>({
+        pattern::any_input(pattern::has_static_rank()),
+        pattern::any_input(pattern::has_static_shape()),
+        pattern::any_input(pattern::has_static_shape()),
+        pattern::any_input(pattern::has_static_shape()),
+        pattern::any_input(pattern::has_static_shape())
+    });
+
+    ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
         auto m_bn = dynamic_pointer_cast<opset1::BatchNormInference>(m.get_match_root());
         if (!m_bn) {
             return false;
         }
 
-        const auto& input_type = m_input->get_element_type();
+        auto m_gamma = m_bn->input_value(0);
+        auto m_beta = m_bn->input_value(1);
+        auto m_input = m_bn->input_value(2);
+        auto m_mean = m_bn->input_value(3);
+        auto m_var = m_bn->input_value(4);
+
+        const auto& input_type = m_input.get_element_type();
         // scale_add = variance + eps
         auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
         // scale = sqrt(variance + eps)
@@ -52,8 +45,10 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
         // Divide `gamma` by `sqrt(variance + eps)`
         auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
 
-        size_t dims_to_add = m_input->get_shape().size() - 2;
-        Shape input_aligned_shape = m_gamma->get_shape();
+        int64_t dims_to_add = m_input.get_partial_shape().rank().get_length() - 2;
+
+        // TODO: instead of getting full shape we can concatenate sequence of ones with ShapeOf
+        Shape input_aligned_shape = m_gamma.get_shape();
         for (size_t i = 0; i < dims_to_add; ++i)
             input_aligned_shape.push_back(1);
         auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
@@ -84,36 +79,29 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
 
 NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormV5Decomposition, "BatchNormDecomposition", 5);
 
+// TODO: this pass will be unified with BatchNormDecomposition pass
 ngraph::pass::BatchNormV5Decomposition::BatchNormV5Decomposition() {
-    Shape shape{2, 2, 1, 1};
-    auto input = make_shared<pattern::op::Label>(element::f32, shape);
-    auto mean_shape = Shape{2};
-    auto mean = make_shared<pattern::op::Label>(element::f32, mean_shape);
-    auto var_shape = Shape{2};
-    auto var = make_shared<pattern::op::Label>(element::f32, var_shape);
-    auto gamma_shape = Shape{2};
-    auto gamma = make_shared<pattern::op::Label>(element::f32, gamma_shape);
-    auto beta_shape = Shape{2};
-    auto beta = make_shared<pattern::op::Label>(element::f32, beta_shape);
-    auto bn = make_shared<opset5::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
-
-    ngraph::graph_rewrite_callback callback = [this, input, gamma, beta, mean, var](ngraph::pattern::Matcher &m) {
-        auto pattern_map = m.get_pattern_map();
-
-        auto m_input = pattern_map[input];
-        auto m_gamma = pattern_map[gamma];
-        auto m_beta = pattern_map[beta];
-        auto m_mean = pattern_map[mean];
-        auto m_var = pattern_map[var];
-
-        // TODO: check that all input shapes are static
-
+    auto bn = pattern::wrap_type<opset5::BatchNormInference>({
+        pattern::any_input(pattern::has_static_rank()),
+        pattern::any_input(pattern::has_static_shape()),
+        pattern::any_input(pattern::has_static_shape()),
+        pattern::any_input(pattern::has_static_shape()),
+        pattern::any_input(pattern::has_static_shape())
+    });
+
+    ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
         auto m_bn = dynamic_pointer_cast<opset5::BatchNormInference>(m.get_match_root());
         if (!m_bn) {
             return false;
         }
 
-        const auto& input_type = m_input->get_element_type();
+        auto m_input = m_bn->input_value(0);
+        auto m_gamma = m_bn->input_value(1);
+        auto m_beta = m_bn->input_value(2);
+        auto m_mean = m_bn->input_value(3);
+        auto m_var = m_bn->input_value(4);
+
+        const auto& input_type = m_input.get_element_type();
         // scale_add = variance + eps
         auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
         // scale = sqrt(variance + eps)
@@ -121,8 +109,10 @@ ngraph::pass::BatchNormV5Decomposition::BatchNormV5Decomposition() {
         // Divide `gamma` by `sqrt(variance + eps)`
         auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
 
-        size_t dims_to_add = m_input->get_shape().size() - 2;
-        Shape input_aligned_shape = m_gamma->get_shape();
+        int64_t dims_to_add = m_input.get_partial_shape().rank().get_length() - 2;
+
+        // TODO: instead of getting full shape we can concatenate sequence of ones with ShapeOf
+        Shape input_aligned_shape = m_gamma.get_shape();
         for (size_t i = 0; i < dims_to_add; ++i)
             input_aligned_shape.push_back(1);
         auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
index 540dd3d..b736106 100644 (file)
@@ -16,24 +16,28 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertBroadcastToTiles, "ConvertBroadcastT
 ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
     auto broadcast = ngraph::pattern::wrap_type<ngraph::opset1::Broadcast>();
 
-    ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
         auto broadcast = std::dynamic_pointer_cast<ngraph::opset1::Broadcast>(m.get_match_root());
 
         if (!broadcast) {
             return false;
         }
 
-        auto data_node = broadcast->input_value(0).get_node_shared_ptr();
+        auto data_node = broadcast->input_value(0);
+        if (data_node.get_partial_shape().is_dynamic()) {
+            return false;
+        }
+
         auto shape_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(broadcast->input_value(1).get_node_shared_ptr());
         auto axes_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(broadcast->input_value(2).get_node_shared_ptr());
-        if (!data_node || !shape_node || !axes_node) return false;
+        if (!shape_node || !axes_node) return false;
 
         auto output_shape = shape_node->cast_vector<int64_t>();
-        auto input_shape = data_node->get_shape();
+        auto input_shape = data_node.get_shape();
         int64_t cur_dim_id = output_shape.size() - 1;
         size_t dims_count = output_shape.size();
 
-        auto last_node = std::dynamic_pointer_cast<ngraph::Node>(data_node);
+        auto last_node = data_node;
 
         NodeVector new_ops;
 
@@ -61,7 +65,7 @@ ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
             auto shape_const = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape {shape.size()}, shape);
             auto reshape = std::make_shared<ngraph::opset1::Reshape>(data_node, shape_const, true);
             new_ops.push_back(reshape);
-            last_node = std::dynamic_pointer_cast<ngraph::Node>(reshape);
+            last_node = reshape;
             input_shape = shape;
         }
 
@@ -87,9 +91,8 @@ ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
         new_ops.push_back(tile);
         tile->set_friendly_name(broadcast->get_friendly_name());
 
-        last_node = std::dynamic_pointer_cast<ngraph::Node>(tile);
         ngraph::copy_runtime_info(broadcast, new_ops);
-        ngraph::replace_node(broadcast, last_node);
+        ngraph::replace_node(broadcast, tile);
         return true;
     };
 
index 07f580c..3cda6a5 100644 (file)
@@ -14,7 +14,7 @@
 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertDepthToSpace, "ConvertDepthToSpace", 0);
 
 ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
-    auto dts_node = ngraph::pattern::wrap_type<ngraph::opset1::DepthToSpace>();
+    auto dts_node = ngraph::pattern::wrap_type<ngraph::opset1::DepthToSpace>({pattern::any_input(pattern::has_static_shape())});
 
     ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
         auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
@@ -22,7 +22,7 @@ ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
             return false;
         }
 
-        auto input = dts_node->input(0).get_source_output().get_node_shared_ptr();
+        auto input = dts_node->input_value(0);
 
         /*
          * In this transformation we decompose DepthToSpace operation to the next sequence of ops:
index d00af46..35c66e8 100644 (file)
@@ -14,7 +14,7 @@
 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertSpaceToDepth, "ConvertSpaceToDepth", 0);
 
 ngraph::pass::ConvertSpaceToDepth::ConvertSpaceToDepth() {
-    auto dts = ngraph::pattern::wrap_type<ngraph::opset1::SpaceToDepth>();
+    auto dts = ngraph::pattern::wrap_type<ngraph::opset1::SpaceToDepth>({pattern::any_input(pattern::has_static_shape())});
 
     ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
         auto std_node = std::dynamic_pointer_cast<ngraph::opset1::SpaceToDepth> (m.get_match_root());
@@ -22,7 +22,7 @@ ngraph::pass::ConvertSpaceToDepth::ConvertSpaceToDepth() {
             return false;
         }
 
-        auto input = std_node->input(0).get_source_output().get_node_shared_ptr();
+        auto input = std_node->input_value(0);
 
         /*
          * In this transformation we decompose SpaceToDepth operation to the next sequence of ops:
diff --git a/inference-engine/tests/functional/inference_engine/transformations/batch_norm_decompositoin.cpp b/inference-engine/tests/functional/inference_engine/transformations/batch_norm_decompositoin.cpp
new file mode 100644 (file)
index 0000000..4dc8c3a
--- /dev/null
@@ -0,0 +1,40 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+#include <queue>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset1.hpp>
+#include <transformations/op_conversions/batch_norm_decomposition.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST(TransformationTests, BatchNormDecompositionDynamic) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
+        auto gamma = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {3});
+        auto beta = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {3});
+        auto mean = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {3});
+        auto var = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {3});
+        auto broadcast = std::make_shared<ngraph::opset1::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
+        broadcast->set_friendly_name("broadcast");
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::BatchNormDecomposition>();
+        ASSERT_NO_THROW(manager.run_passes(f));
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+}
\ No newline at end of file
diff --git a/inference-engine/tests/functional/inference_engine/transformations/convert_broadcast_to_tiles_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/convert_broadcast_to_tiles_test.cpp
new file mode 100644 (file)
index 0000000..b00606e
--- /dev/null
@@ -0,0 +1,40 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+#include <queue>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/opsets/opset3.hpp>
+#include <transformations/op_conversions/convert_broadcast_to_tiles.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+#include <ngraph/pass/manager.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST(TransformationTests, ConvertBroadcastToTilesDynamic) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
+        auto target_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, std::vector<int64_t>{3, 5, 2});
+        auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input1, target_shape);
+        broadcast->set_friendly_name("broadcast");
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertBroadcastToTiles>();
+        ASSERT_NO_THROW(manager.run_passes(f));
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+}
+
index c63b7bd..b1a9506 100644 (file)
@@ -54,8 +54,7 @@ public:
         f_ref = get_reference_function(input_shape, reduce_type, reference_params);
     }
 
-private:
-    std::shared_ptr<ngraph::Function> get_initial_function(const ngraph::PartialShape & input_shape,
+    static std::shared_ptr<ngraph::Function> get_initial_function(const ngraph::PartialShape & input_shape,
                                                            const std::vector<int64_t> & axes,
                                                            const ReduceType & reduce_type,
                                                            const bool keep_dims) {
@@ -72,7 +71,7 @@ private:
         return std::make_shared<ngraph::Function>(ngraph::NodeVector{reduce}, ngraph::ParameterVector{input});
     }
 
-    std::shared_ptr<ngraph::Function> get_reference_function(const ngraph::PartialShape & input_shape,
+    static std::shared_ptr<ngraph::Function> get_reference_function(const ngraph::PartialShape & input_shape,
                                                              const ReduceType & reduce,
                                                              const ReduceToPoolParams & params) {
         auto param = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
@@ -137,6 +136,10 @@ INSTANTIATE_TEST_CASE_P(ReduceToReshapePoolReshape, ConvertReduceToPoolingTests,
                         std::make_tuple(MAX, InputShape{2, 9},       ReduceAxes{-1},      KeepDims{true},  ReduceToPoolParams({1, 1, 9, 1}, {9, 1}, {1, 1})),
                         std::make_tuple(MAX, InputShape{2, 3, 4, 1}, ReduceAxes{1, 3, 2}, KeepDims{false}, ReduceToPoolParams({1, 1, 12, 1}, {12, 1}, {1}))));
 
-#undef MAX
-
+TEST(ConvertReduceToPooling, Negative) {
+    auto f = ConvertReduceToPoolingTests::get_initial_function(
+            ngraph::PartialShape::dynamic(), {3}, MAX, true);
+    ASSERT_NO_THROW(ngraph::pass::ConvertReduceToPooling().run_on_function(f));
+}
 
+#undef MAX
index 692e784..d823b1d 100644 (file)
@@ -181,3 +181,29 @@ TEST(TransformationTests, TestSpaceToDepthTransformDepthFirst) {
     std::vector<int64_t> shape_end_value_ref{1, 12 * 4, 1080 / 2, 1616 / 2};
     ASSERT_EQ(shape_end_value, shape_end_value_ref);
 }
+
+TEST(TransformationTests, TestSpaceToDepthDynamic) {
+    auto input = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
+    std::shared_ptr<ngraph::Function> f(nullptr);
+
+    {
+        auto space_to_depth = std::make_shared<ngraph::op::SpaceToDepth>(input, ngraph::op::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST, 2);
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{space_to_depth}, ngraph::ParameterVector{input});
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::ConvertSpaceToDepth>();
+        ASSERT_NO_THROW(m.run_passes(f));
+    }
+}
+
+TEST(TransformationTests, TestDepthToSpaceDynamic) {
+    auto input = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
+    std::shared_ptr<ngraph::Function> f(nullptr);
+
+    {
+        auto depth_to_space = std::make_shared<ngraph::op::DepthToSpace>(input, ngraph::op::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, 2);
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input});
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::ConvertDepthToSpace>();
+        ASSERT_NO_THROW(m.run_passes(f));
+    }
+}
index 7af719b..e6ce532 100644 (file)
@@ -55,3 +55,29 @@ TEST(TransformationTests, FQTransposeTest1) {
         }
     }
 }
+
+TEST(TransformationTests, FQTransposeDynamic) {
+    auto data1 = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
+    auto data2 = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {1, 2, 3});
+    auto data3 = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1, 3}, {1, 2, 3});
+    auto data4 = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1, 3}, {1, 2, 3});
+    auto data5 = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1, 3}, {1, 2, 3});
+    auto transpose_order = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
+
+    std::shared_ptr<ngraph::Function> f(nullptr);
+    {
+        auto fq = std::make_shared<ngraph::op::FakeQuantize>(data1, data2, data3, data4, data5, 1);
+        auto transpose = std::make_shared<ngraph::op::Transpose>(fq, transpose_order);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{data1});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::PullTransposeThroughFQUp>();
+        manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
+            check_rt_info(f);
+        });
+        manager.register_pass<ngraph::pass::ConstantFolding>();
+        ASSERT_NO_THROW(manager.run_passes(f));
+    }
+}
index 71c9769..0a3db96 100644 (file)
@@ -62,6 +62,9 @@ namespace ngraph
         std::function<bool(Output<Node>)> has_static_shape();
 
         NGRAPH_API
+        std::function<bool(Output<Node>)> has_static_rank();
+
+        NGRAPH_API
         std::function<bool(Output<Node>)> type_matches(const element::Type& type);
 
         NGRAPH_API
index 6a98897..fc47fe7 100644 (file)
@@ -95,6 +95,13 @@ namespace ngraph
                 [=](Output<Node> output) -> bool { return output.get_partial_shape().is_static(); };
         }
 
+        std::function<bool(Output<Node>)> has_static_rank()
+        {
+            return [=](Output<Node> output) -> bool {
+                return output.get_partial_shape().rank().is_static();
+            };
+        }
+
         std::function<bool(Output<Node>)> type_matches(const element::Type& type)
         {
             return [=](Output<Node> output) -> bool { return output.get_element_type() == type; };