auto squeeze = as_type_ptr<opset3::Squeeze>(node);
auto input = squeeze->input_value(0).get_node_shared_ptr();
- auto data_shape = input->input_value(0).get_partial_shape();
- auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(input);
auto replace_squeeze_only = [&](const vector<int64_t>& axes) {
auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
return false;
};
// eliminate redundant unsqueeze->squeeze
- if (unsqueeze)
+ if (auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(input))
{
+ auto data_shape = input->input(0).get_partial_shape();
if (ngraph::compare_constants(unsqueeze->input_value(1).get_node_shared_ptr(),
squeeze->input_value(1).get_node_shared_ptr()))
{
return false;
}
// eliminate redundant squeeze->squeeze
- auto squeeze_i = as_type_ptr<opset3::Squeeze>(input);
- if (squeeze_i)
+ if (auto squeeze_i = as_type_ptr<opset3::Squeeze>(input))
{
+ auto data_shape = input->input(0).get_partial_shape();
if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic())
{
return false;
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
+#include "ngraph/opsets/opset1.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "util/all_close.hpp"
check_usecase(Shape{1, 6, 2}, std::vector<int64_t>{3});
}
+TEST(nop_elimination, squeeze_unsqueeze_elimination_negative)
+{
+ auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& indices_val) {
+ auto indices = op::Constant::create(element::i64, Shape{indices_val.size()}, indices_val);
+ auto input = make_shared<op::Parameter>(element::f32, shape);
+ auto squeeze = make_shared<ngraph::opset1::Squeeze>(input, indices);
+ auto baseline_f = make_shared<Function>(squeeze, ParameterVector{input});
+ auto optimized_f = clone_function(*baseline_f);
+ EXPECT_TRUE((compare_pass_int<pass::NopElimination, float>(baseline_f, optimized_f)));
+
+ ASSERT_EQ(count_ops_of_type<ngraph::opset1::Squeeze>(baseline_f), 1);
+ ASSERT_EQ(count_ops_of_type<ngraph::opset1::Squeeze>(optimized_f), 1);
+ };
+
+ check_usecase(Shape{1, 1, 1}, std::vector<int64_t>{0, 1, 2});
+}
+
TEST(nop_elimination, topk_convert_elimination)
{
auto check_usecase = []() {