Fix NopElimination (#891)
authorGleb Kazantaev <gleb.kazantaev@intel.com>
Mon, 15 Jun 2020 07:39:55 +0000 (10:39 +0300)
committerGitHub <noreply@github.com>
Mon, 15 Jun 2020 07:39:55 +0000 (10:39 +0300)
ngraph/src/ngraph/pass/nop_elimination.cpp
ngraph/test/nop_elimination.cpp

index 3846fde..fc196dd 100644 (file)
@@ -331,8 +331,6 @@ static bool eliminate_squeeze(const std::shared_ptr<Node>& node)
 
     auto squeeze = as_type_ptr<opset3::Squeeze>(node);
     auto input = squeeze->input_value(0).get_node_shared_ptr();
-    auto data_shape = input->input_value(0).get_partial_shape();
-    auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(input);
     auto replace_squeeze_only = [&](const vector<int64_t>& axes) {
         auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
         auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
@@ -343,8 +341,9 @@ static bool eliminate_squeeze(const std::shared_ptr<Node>& node)
         return false;
     };
     // eliminate redundant unsqueeze->squeeze
-    if (unsqueeze)
+    if (auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(input))
     {
+        auto data_shape = input->input(0).get_partial_shape();
         if (ngraph::compare_constants(unsqueeze->input_value(1).get_node_shared_ptr(),
                                       squeeze->input_value(1).get_node_shared_ptr()))
         {
@@ -383,9 +382,9 @@ static bool eliminate_squeeze(const std::shared_ptr<Node>& node)
         return false;
     }
     // eliminate redundant squeeze->squeeze
-    auto squeeze_i = as_type_ptr<opset3::Squeeze>(input);
-    if (squeeze_i)
+    if (auto squeeze_i = as_type_ptr<opset3::Squeeze>(input))
     {
+        auto data_shape = input->input(0).get_partial_shape();
         if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic())
         {
             return false;
index 8b4aa57..b95c6a1 100644 (file)
@@ -18,6 +18,7 @@
 
 #include "gtest/gtest.h"
 #include "ngraph/ngraph.hpp"
+#include "ngraph/opsets/opset1.hpp"
 #include "ngraph/pass/manager.hpp"
 #include "ngraph/pass/nop_elimination.hpp"
 #include "util/all_close.hpp"
@@ -728,6 +729,23 @@ TEST(nop_elimination, unsqueeze_reshape_elimination)
     check_usecase(Shape{1, 6, 2}, std::vector<int64_t>{3});
 }
 
+TEST(nop_elimination, squeeze_unsqueeze_elimination_negative)
+{
+    auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& indices_val) {
+        auto indices = op::Constant::create(element::i64, Shape{indices_val.size()}, indices_val);
+        auto input = make_shared<op::Parameter>(element::f32, shape);
+        auto squeeze = make_shared<ngraph::opset1::Squeeze>(input, indices);
+        auto baseline_f = make_shared<Function>(squeeze, ParameterVector{input});
+        auto optimized_f = clone_function(*baseline_f);
+        EXPECT_TRUE((compare_pass_int<pass::NopElimination, float>(baseline_f, optimized_f)));
+
+        ASSERT_EQ(count_ops_of_type<ngraph::opset1::Squeeze>(baseline_f), 1);
+        ASSERT_EQ(count_ops_of_type<ngraph::opset1::Squeeze>(optimized_f), 1);
+    };
+
+    check_usecase(Shape{1, 1, 1}, std::vector<int64_t>{0, 1, 2});
+}
+
 TEST(nop_elimination, topk_convert_elimination)
 {
     auto check_usecase = []() {