From 36be9e403174e2f5d492bc0fd207fd32b2246665 Mon Sep 17 00:00:00 2001 From: Gleb Kazantaev Date: Mon, 15 Jun 2020 10:39:55 +0300 Subject: [PATCH] Fix NopElimination (#891) --- ngraph/src/ngraph/pass/nop_elimination.cpp | 9 ++++----- ngraph/test/nop_elimination.cpp | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/ngraph/src/ngraph/pass/nop_elimination.cpp b/ngraph/src/ngraph/pass/nop_elimination.cpp index 3846fde..fc196dd 100644 --- a/ngraph/src/ngraph/pass/nop_elimination.cpp +++ b/ngraph/src/ngraph/pass/nop_elimination.cpp @@ -331,8 +331,6 @@ static bool eliminate_squeeze(const std::shared_ptr& node) auto squeeze = as_type_ptr(node); auto input = squeeze->input_value(0).get_node_shared_ptr(); - auto data_shape = input->input_value(0).get_partial_shape(); - auto unsqueeze = as_type_ptr(input); auto replace_squeeze_only = [&](const vector& axes) { auto axes_const = opset3::Constant::create(element::i64, Shape{axes.size()}, axes); auto new_sq = make_shared(input->input_value(0), axes_const); @@ -343,8 +341,9 @@ static bool eliminate_squeeze(const std::shared_ptr& node) return false; }; // eliminate redundant unsqueeze->squeeze - if (unsqueeze) + if (auto unsqueeze = as_type_ptr(input)) { + auto data_shape = input->input(0).get_partial_shape(); if (ngraph::compare_constants(unsqueeze->input_value(1).get_node_shared_ptr(), squeeze->input_value(1).get_node_shared_ptr())) { @@ -383,9 +382,9 @@ static bool eliminate_squeeze(const std::shared_ptr& node) return false; } // eliminate redundant squeeze->squeeze - auto squeeze_i = as_type_ptr(input); - if (squeeze_i) + if (auto squeeze_i = as_type_ptr(input)) { + auto data_shape = input->input(0).get_partial_shape(); if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) { return false; diff --git a/ngraph/test/nop_elimination.cpp b/ngraph/test/nop_elimination.cpp index 8b4aa57..b95c6a1 100644 --- a/ngraph/test/nop_elimination.cpp +++ b/ngraph/test/nop_elimination.cpp @@ -18,6 +18,7 @@ #include "gtest/gtest.h" #include "ngraph/ngraph.hpp" +#include "ngraph/opsets/opset1.hpp" #include "ngraph/pass/manager.hpp" #include "ngraph/pass/nop_elimination.hpp" #include "util/all_close.hpp" @@ -728,6 +729,23 @@ TEST(nop_elimination, unsqueeze_reshape_elimination) check_usecase(Shape{1, 6, 2}, std::vector{3}); } +TEST(nop_elimination, squeeze_unsqueeze_elimination_negative) +{ + auto check_usecase = [](const Shape& shape, const std::vector& indices_val) { + auto indices = op::Constant::create(element::i64, Shape{indices_val.size()}, indices_val); + auto input = make_shared(element::f32, shape); + auto squeeze = make_shared(input, indices); + auto baseline_f = make_shared(squeeze, ParameterVector{input}); + auto optimized_f = clone_function(*baseline_f); + EXPECT_TRUE((compare_pass_int(baseline_f, optimized_f))); + + ASSERT_EQ(count_ops_of_type(baseline_f), 1); + ASSERT_EQ(count_ops_of_type(optimized_f), 1); + }; + + check_usecase(Shape{1, 1, 1}, std::vector{0, 1, 2}); +} + TEST(nop_elimination, topk_convert_elimination) { auto check_usecase = []() { -- 2.7.4