1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
7 #include "common_test_utils/test_common.hpp"
13 #include <ngraph/function.hpp>
14 #include <ngraph/opsets/opset2.hpp>
15 #include <ngraph/opsets/opset3.hpp>
16 #include <ngraph/pass/manager.hpp>
17 #include <ngraph/pass/constant_folding.hpp>
18 #include <ngraph/builder/autobroadcast.hpp>
19 #include <transformations/common_optimizations/algebraic_simplification.hpp>
20 #include <transformations/utils/utils.hpp>
21 #include <transformations/init_node_info.hpp>
23 #include "common_test_utils/ngraph_test_utils.hpp"
25 using namespace ngraph;
// Negative test: AlgebraicSimplification must NOT rewrite these Add chains.
// (a + 2) + 2 and (b + |a|) + |a| are not identity patterns (the constant is
// 2, not 0), so every Result must still be fed by its original producer node.
// NOTE(review): interior lines are elided in this view (e.g. the `shape`
// declaration and the closing braces are not visible); code kept byte-identical.
28 TEST(algebraic_simplification, add_negative_tests) {
30 auto type = element::f32;
31 pass::Manager pass_manager;
32 pass_manager.register_pass<pass::AlgebraicSimplification>();
34 auto a = make_shared<op::Parameter>(type, shape);
35 auto b = make_shared<op::Parameter>(type, shape);
36 auto c = make_shared<op::Parameter>(type, shape);
37 auto abs_a = make_shared<op::Abs>(a);
38 auto iconst2 = ngraph::make_constant_from_string("2", type, shape);
39 auto add_a_0 = a + iconst2;
40 auto add_a_0_0 = add_a_0 + iconst2;
41 auto add_b_0 = b + abs_a;
42 auto add_b_0_0 = add_b_0 + abs_a;
44 auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
45 ParameterVector{a, b, c});
46 pass_manager.run_passes(f);
// After the pass, each Result's input must still be the original node, i.e.
// the pass changed nothing.
48 auto expected = ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0};
49 auto results = f->get_results();
50 for (size_t i = 0; i < results.size(); i++) {
51 ASSERT_EQ(expected.at(i), results.at(i)->input_value(0).get_node_shared_ptr());
// Negative test: AlgebraicSimplification must NOT rewrite these Multiply
// chains. (a * 2) * 2 and (b * |a|) * |a| are not identity patterns (the
// constant is 2, not 1), so the graph must come out of the pass unchanged.
// NOTE(review): interior lines are elided in this view (e.g. the `shape`
// declaration and the closing braces are not visible); code kept byte-identical.
55 TEST(algebraic_simplification, multiply_negative_tests) {
57 auto type = element::f32;
58 pass::Manager pass_manager;
59 pass_manager.register_pass<pass::AlgebraicSimplification>();
61 auto a = make_shared<op::Parameter>(type, shape);
62 auto b = make_shared<op::Parameter>(type, shape);
63 auto c = make_shared<op::Parameter>(type, shape);
64 auto abs_a = make_shared<op::Abs>(a);
65 auto iconst2 = ngraph::make_constant_from_string("2", type, shape);
66 auto add_a_0 = a * iconst2;
67 auto add_a_0_0 = add_a_0 * iconst2;
68 auto add_b_0 = b * abs_a;
69 auto add_b_0_0 = add_b_0 * abs_a;
71 auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
72 ParameterVector{a, b, c});
73 pass_manager.run_passes(f);
// After the pass, each Result's input must still be the original node.
75 auto expected = ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0};
76 auto results = f->get_results();
77 for (size_t i = 0; i < results.size(); i++) {
78 ASSERT_EQ(expected.at(i), results.at(i)->input_value(0).get_node_shared_ptr());
// Negative test: a Product reduction over a broadcast constant of ones must
// NOT be simplified by AlgebraicSimplification — the Product node must still
// feed the Result after the pass runs.
// NOTE(review): the test's closing brace is elided in this view.
82 TEST(algebraic_simplification, multiply_prod_negative) {
83 auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
84 auto broadcast = builder::opset1::make_broadcast(fconst1, Shape{2, 5}, AxisSet{1});
85 auto prod_fconst1 = std::make_shared<op::Product>(broadcast, AxisSet{0, 1});
87 pass::Manager pass_manager;
88 pass_manager.register_pass<pass::AlgebraicSimplification>();
90 auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst1}, ParameterVector{});
91 pass_manager.run_passes(f);
// The Result's producer must still be the original Product node.
92 auto f_prod = f->get_results().at(0)->input_value(0).get_node_shared_ptr();
93 ASSERT_EQ(f_prod, prod_fconst1);
// Negative test: a Sum reduction over a broadcast constant of ones must NOT
// be simplified by AlgebraicSimplification — the Sum node must still feed
// the Result after the pass runs.
// NOTE(review): the test's closing brace is elided in this view.
96 TEST(algebraic_simplification, multiply_sum_negative) {
97 auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
98 auto broadcast = builder::opset1::make_broadcast(fconst1, Shape{2, 5}, AxisSet{1});
99 auto sum_fconst1 = std::make_shared<op::Sum>(broadcast, AxisSet{0, 1});
101 pass::Manager pass_manager;
102 pass_manager.register_pass<pass::AlgebraicSimplification>();
104 auto f = std::make_shared<Function>(ngraph::NodeVector{sum_fconst1}, ParameterVector{});
105 pass_manager.run_passes(f);
// The Result's producer must still be the original Sum node.
106 auto f_sum = f->get_results().at(0)->input_value(0).get_node_shared_ptr();
107 ASSERT_EQ(f_sum, sum_fconst1);
// Negative test: three slices cover all 96 rows of `a`, but they are
// concatenated in REVERSED order (slice3, slice2, slice1), so the Concat is
// not equivalent to `a` and must not be removed by the pass.
// NOTE(review): the test's closing brace is elided in this view.
110 TEST(algebraic_simplification, concat_parameter_slices_reversed) {
111 auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
112 auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
113 auto slice2 = make_shared<op::Slice>(a, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
114 auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
116 size_t concat_axis = 0;
117 auto concat = make_shared<op::Concat>(NodeVector{slice3, slice2, slice1}, concat_axis);
119 pass::Manager pass_manager;
120 pass_manager.register_pass<pass::AlgebraicSimplification>();
122 auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
123 pass_manager.run_passes(f);
// The Concat must survive the pass unchanged.
124 ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
// Negative test: the three slices cover only rows 0..30 of the 96-row
// parameter, so the Concat does not reconstruct `a` and must not be removed.
// NOTE(review): the test's closing brace is elided in this view.
127 TEST(algebraic_simplification, concat_parameter_slices_element_count) {
128 auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
129 // slicing 30 elements out of 96; should trigger a check that some elements are missing
130 auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{10, 100}, Strides{1, 1});
131 auto slice2 = make_shared<op::Slice>(a, Coordinate{10, 0}, Coordinate{20, 100}, Strides{1, 1});
132 auto slice3 = make_shared<op::Slice>(a, Coordinate{20, 0}, Coordinate{30, 100}, Strides{1, 1});
134 size_t concat_axis = 0;
135 auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
137 pass::Manager pass_manager;
138 pass_manager.register_pass<pass::AlgebraicSimplification>();
140 auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
141 pass_manager.run_passes(f);
// The Concat must survive the pass unchanged.
142 ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
// Negative test: the slices cover all 96 rows contiguously and in order, but
// with non-uniform sizes (38 / 26 / 32 rows). Per this test's expectation the
// pass only removes the Concat for uniform slices, so it must stay in place.
// NOTE(review): the test's closing brace is elided in this view.
145 TEST(algebraic_simplification, concat_parameter_non_uniform_slices) {
146 auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
147 auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{38, 100}, Strides{1, 1});
148 auto slice2 = make_shared<op::Slice>(a, Coordinate{38, 0}, Coordinate{64, 100}, Strides{1, 1});
149 auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
151 size_t concat_axis = 0;
152 auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
154 pass::Manager pass_manager;
155 pass_manager.register_pass<pass::AlgebraicSimplification>();
157 auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
158 pass_manager.run_passes(f);
// The Concat must survive the pass unchanged.
159 ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
// Negative test: the concatenated slices are taken from different source
// values (`goe1` vs `goe2`), so the Concat cannot be collapsed into a single
// tensor and must remain after the pass.
// NOTE(review): the declarations of `goe1`, `goe2` and the `slice1..3`
// variables are elided in this view — presumably GetOutputElement-style
// wrappers around `a`; verify against the full file. Code kept byte-identical.
162 TEST(algebraic_simplification, concat_different_inputs) {
163 auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
167 make_shared<op::Slice>(goe1, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
169 make_shared<op::Slice>(goe2, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
171 make_shared<op::Slice>(goe1, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
173 size_t concat_axis = 0;
174 auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
176 pass::Manager pass_manager;
177 pass_manager.register_pass<pass::AlgebraicSimplification>();
179 auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
180 pass_manager.run_passes(f);
// The Concat must survive the pass unchanged.
181 ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
// Negative test: Log(|a| / b) contains no Exp, so the log-simplification
// rewrite must not fire; the Log node must remain the input of `neg_inner`.
// The chain of Negative ops gives the Log a downstream consumer to check.
// NOTE(review): the test's closing brace is elided in this view.
184 TEST(algebraic_simplification, log_no_exp) {
185 auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
186 auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
187 auto abs_a = make_shared<op::Abs>(a);
188 auto div = abs_a / b;
189 auto log_div = make_shared<op::Log>(div);
191 auto neg_inner = make_shared<op::Negative>(log_div);
192 auto neg2 = make_shared<op::Negative>(neg_inner);
193 auto neg3 = make_shared<op::Negative>(neg2);
194 auto neg4 = make_shared<op::Negative>(neg3);
196 pass::Manager pass_manager;
197 pass_manager.register_pass<pass::AlgebraicSimplification>();
199 auto f = std::make_shared<Function>(ngraph::NodeVector{neg4}, ParameterVector{a, b});
200 pass_manager.run_passes(f);
// The Log must still be neg_inner's input, i.e. it was not rewritten.
201 ASSERT_EQ(neg_inner->input_value(0).get_node_shared_ptr(), log_div);
// Negative test: Log(Exp(a) * b) uses Multiply rather than Divide, so the
// log-simplification rewrite must not fire; the Log node must remain the
// input of `neg_inner`.
// NOTE(review): the test's closing brace is elided in this view.
204 TEST(algebraic_simplification, log_no_divide) {
205 auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
206 auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
207 auto exp_a = make_shared<op::Exp>(a);
208 auto mul = exp_a * b;
209 auto log_mul = make_shared<op::Log>(mul);
211 auto neg_inner = make_shared<op::Negative>(log_mul);
212 auto neg2 = make_shared<op::Negative>(neg_inner);
213 auto neg3 = make_shared<op::Negative>(neg2);
214 auto neg4 = make_shared<op::Negative>(neg3);
216 pass::Manager pass_manager;
217 pass_manager.register_pass<pass::AlgebraicSimplification>();
219 auto f = std::make_shared<Function>(ngraph::NodeVector{neg4}, ParameterVector{a, b});
220 pass_manager.run_passes(f);
// The Log must still be neg_inner's input, i.e. it was not rewritten.
221 ASSERT_EQ(neg_inner->input_value(0).get_node_shared_ptr(), log_mul);
// Checks the declared pass properties: AlgebraicSimplification must report
// that it does not change the dynamic state of the graph.
// NOTE(review): at least one line (original line 226) and the closing brace
// are elided in this view; code kept byte-identical.
224 TEST(algebraic_simplification, pass_property) {
225 auto pass = std::make_shared<ngraph::pass::AlgebraicSimplification>();
227 ASSERT_FALSE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
// Verifies the transpose-to-reshape simplification. For each use case a
// Transpose with permutation `perm_val` is built, the graph is cloned, and
// AlgebraicSimplification + ConstantFolding run on the clone. `num` is the
// expected number of Transpose ops remaining: 0 when the permutation only
// moves size-1 dimensions (replaceable by a Reshape, as the num==0 cases
// below suggest), 1 when a genuine data permutation must be kept.
// NOTE(review): several lambda lines are elided in this view (parts of the
// parameter list, the i32/i64 branch, and the multiout branch); the visible
// flags are: i32 — index element type for the permutation constant,
// multiout — route the input through a multi-output TopK node.
// Code kept byte-identical.
230 TEST(algebraic_simplification, replace_transpose_with_reshape) {
231 auto check_usecase = [](const PartialShape& shape,
232 const std::vector<int64_t>& perm_val,
// Monotonically increasing case id used only for failure messages.
236 static size_t id = 0;
237 auto casename = string("usecase #") + to_string(++id);
239 shared_ptr<Node> perm;
241 std::vector<int32_t> perm_val_i32(perm_val.begin(), perm_val.end());
243 op::Constant::create<int32_t>(element::i32, Shape{perm_val.size()}, perm_val_i32);
245 perm = op::Constant::create<int64_t>(element::i64, Shape{perm_val.size()}, perm_val);
247 auto param = make_shared<op::Parameter>(element::f32, shape);
250 auto last_dim = shape.rank().get_length() - 1;
251 A1 = make_shared<op::v0::TopK>(param, last_dim, element::i32);
253 A1 = make_shared<op::v0::Abs>(param);
255 auto transpose = make_shared<op::v1::Transpose>((multiout ? A1->output(0) : A1), perm);
256 auto transpose1 = make_shared<op::v0::Abs>(transpose);
257 auto baseline_f = make_shared<Function>(transpose1, ParameterVector{param});
258 auto optimized_f = clone_function(*baseline_f);
260 pass::Manager pass_manager;
261 pass_manager.register_pass<pass::Validate>();
262 pass_manager.register_pass<pass::AlgebraicSimplification>();
263 pass_manager.register_pass<pass::ConstantFolding>();
264 pass_manager.run_passes(optimized_f);
// The optimized function must keep the same output rank as the baseline.
266 auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
267 auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
268 EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
269 ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
// Baseline: exactly one Transpose, no Reshape. Optimized: `num` Transposes;
// when num == 0 the Transpose must have been replaced by exactly one Reshape.
271 ASSERT_EQ(count_ops_of_type<op::v1::Transpose>(baseline_f), 1);
272 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 0);
273 ASSERT_EQ(count_ops_of_type<op::v1::Transpose>(optimized_f), num);
274 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), (num ? 0 : 1));
// Exercise all combinations of index type and multi-output producer.
277 for (auto& i32 : {true, false})
278 for (auto& multiout : {true, false}) {
279 check_usecase(Shape{1, 3}, vector<int64_t>{1, 0}, i32, multiout, 0);
280 check_usecase(Shape{2, 3, 1}, vector<int64_t>{2, 0, 1}, i32, multiout, 0);
281 check_usecase(Shape{10, 20, 1, 1}, vector<int64_t>{0, 2, 3, 1}, i32, multiout, 0);
282 check_usecase(Shape{10, 1, 1, 20}, vector<int64_t>{0, 3, 1, 2}, i32, multiout, 0);
283 check_usecase(Shape{10, 20, 1, 2}, vector<int64_t>{0, 2, 1, 3}, i32, multiout, 0);
284 check_usecase(Shape{10, 1, 1, 1, 20}, vector<int64_t>{0, 4, 1, 2, 3}, i32, multiout, 0);
285 check_usecase(Shape{10, 20, 1, 1, 1}, vector<int64_t>{0, 2, 3, 4, 1}, i32, multiout, 0);
286 check_usecase(Shape{10, 1, 1, 1, 1}, vector<int64_t>{1, 4, 2, 3, 0}, i32, multiout, 0);
287 check_usecase(Shape{10, 1, 1, 1, 1}, vector<int64_t>{4, 2, 0, 1, 3}, i32, multiout, 0);
288 check_usecase(Shape{10, 20, 1, 2}, vector<int64_t>{0, 2, 3, 1}, i32, multiout, 1);
289 check_usecase(Shape{10, 20, 1, 2}, vector<int64_t>{0, 3, 1, 2}, i32, multiout, 1);
290 check_usecase(Shape{10, 20}, vector<int64_t>{1, 0}, i32, multiout, 1);
// Dynamic-shape cases (some argument lines elided in this view).
292 check_usecase(PartialShape{Dimension::dynamic(), 20, 1, 1},
299 check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 20, 1, 1},
300 vector<int64_t>{0, 1, 3, 2, 4},
304 check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 20, 1, 1},
305 vector<int64_t>{0, 2, 1, 4, 3},
312 // The following Gather test exercises the case where the Gather is a no-op
313 // (it fetches the whole input tensor) and is therefore removed by the
314 // `simplify_gather` step of the algebraic_simplification pass.
// Verifies Gather elimination: when a Gather with constant indices/axis
// selects the entire input tensor it is a no-op, and the pass should connect
// the Gather's input directly to the Gather's users. `num` is the expected
// number of Gather ops remaining after the pass (0 = removed, 1 = kept).
// NOTE(review): several lambda lines are elided in this view (parts of the
// parameter list and the i32/i64 and multiout branch headers); visible flags:
// i32 — index element type of the indices/axis constants, multiout — route
// the input through a multi-output TopK node. Code kept byte-identical.
316 TEST(algebraic_simplification, gather_3d_indices_constant_axis_1) {
317 auto check_usecase = [](const PartialShape& pshape,
320 const std::vector<int64_t>& indices_val,
// Monotonically increasing case id used only for failure messages.
323 static size_t id = 0;
324 auto casename = string("usecase #") + to_string(++id);
326 shared_ptr<Node> indices;
327 shared_ptr<Node> axis;
329 std::vector<int32_t> indices_val_i32(indices_val.begin(), indices_val.end());
330 indices = op::Constant::create<int32_t>(
331 element::i32, Shape{indices_val.size()}, indices_val_i32);
332 axis = op::Constant::create<int32_t>(element::i32, Shape{}, {(int32_t)axis_val});
335 op::Constant::create<int64_t>(element::i64, Shape{indices_val.size()}, indices_val);
336 axis = op::Constant::create<int64_t>(element::i64, Shape{}, {axis_val});
339 auto A = make_shared<op::Parameter>(element::f32, pshape);
342 auto last_dim = pshape.rank().get_length() - 1;
343 A1 = make_shared<op::v0::TopK>(A, last_dim, element::i32);
345 A1 = make_shared<op::v0::Abs>(A);
347 auto G = make_shared<op::v1::Gather>((multiout ? A1->output(0) : A1), indices, axis);
349 auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(G), ParameterVector{A});
350 auto optimized_f = clone_function(*baseline_f);
352 pass::Manager pass_manager;
353 pass_manager.register_pass<pass::Validate>();
354 pass_manager.register_pass<pass::AlgebraicSimplification>();
355 pass_manager.run_passes(optimized_f);
// The optimized function must keep the same output rank as the baseline.
357 auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
358 auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
359 EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
360 ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
362 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(baseline_f), 1) << casename;
363 // the pass should short cut the Gather i/p with the gather users
364 // since we are fetching the whole tensor using gather op
365 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(optimized_f), num) << casename;
// Exercise all combinations of index type and multi-output producer.
// The last case gathers {0, 0} (duplication), so the Gather must remain.
367 for (auto& i32 : {true, false})
368 for (auto& multiout : {true, false}) {
369 check_usecase(PartialShape{1, 3, 2}, i32, multiout, std::vector<int64_t>{1}, 0, 0);
370 check_usecase(PartialShape{3, 2, 1}, i32, multiout, std::vector<int64_t>{0, 1}, 1, 0);
371 check_usecase(PartialShape{3, 2, 1}, i32, multiout, std::vector<int64_t>{1}, 2, 0);
372 check_usecase(PartialShape{1, 16}, i32, multiout, std::vector<int64_t>{0, 0}, 0, 1);
376 TEST(algebraic_simplification, gather_shapeof) {
377 auto check_usecase = [](const PartialShape& pshape,
378 bool is_scalar_index,
383 const std::vector<int64_t>& indices_val,
385 static size_t id = 0;
386 auto casename = string("usecase #") + to_string(++id);
388 shared_ptr<Node> indices;
389 shared_ptr<Node> axis;
391 std::vector<int32_t> indices_val_i32(indices_val.begin(), indices_val.end());
392 indices = is_scalar_index
393 ? op::Constant::create<int32_t>(element::i32, Shape{}, indices_val_i32)
394 : op::Constant::create<int32_t>(
395 element::i32, Shape{indices_val.size()}, indices_val_i32);
396 axis = op::Constant::create<int32_t>(element::i32, Shape{}, {(int32_t)axis_val});
398 indices = is_scalar_index
399 ? op::Constant::create<int64_t>(element::i64, Shape{}, indices_val)
400 : op::Constant::create<int64_t>(
401 element::i64, Shape{indices_val.size()}, indices_val);
402 axis = op::Constant::create<int64_t>(element::i64, Shape{}, {axis_val});
405 auto dims_1 = std::vector<Dimension>(pshape);
406 dims_1.push_back(11);
407 dims_1.push_back(13);
408 auto pshape_1 = PartialShape(dims_1);
409 auto A = make_shared<op::Parameter>(element::f32, pshape);
410 auto AA = make_shared<op::Parameter>(element::f64, pshape_1);
413 A1 = make_shared<TestOpMultiOut>(A, AA);
415 A1 = make_shared<op::v0::Abs>(A);
417 auto B = make_shared<op::v1::Gather>(
418 (multiout ? (multiout_1 ? A1->output(1) : A1->output(0)) : A1), indices, axis);
421 B1 = make_shared<op::v0::ShapeOf>(B);
423 B1 = make_shared<op::v3::ShapeOf>(B);
425 auto baseline_f = make_shared<Function>(
426 make_shared<op::v0::Abs>(B1), (multiout ? ParameterVector{A, AA} : ParameterVector{A}));
427 auto optimized_f = clone_function(*baseline_f);
429 pass::Manager pass_manager;
430 pass_manager.register_pass<pass::Validate>();
431 pass_manager.register_pass<pass::AlgebraicSimplification>();
432 pass_manager.run_passes(optimized_f);
434 ASSERT_EQ(baseline_f->get_results().at(0)->get_element_type(),
435 optimized_f->get_results().at(0)->get_element_type());
437 auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
438 auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
439 EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
440 EXPECT_TRUE(ps.same_scheme(ps_r)) << casename;
442 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(baseline_f), 1) << casename;
444 auto last_node = optimized_f->get_results()[0]->input_value(0).get_node_shared_ptr();
445 if (is_scalar_index) {
446 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(optimized_f), 1) << casename;
447 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(optimized_f), 1) << casename;
449 as_type_ptr<op::v1::Gather>(last_node->input_value(0).get_node_shared_ptr()))
452 ASSERT_EQ(count_ops_of_type<op::v0::Concat>(optimized_f), 1) << casename;
454 as_type_ptr<op::v0::Concat>(last_node->input_value(0).get_node_shared_ptr()))
459 for (auto& opset2 : {true, false})
460 for (auto& i32 : {true, false})
461 for (auto& multiout : {true, false})
462 for (auto& multiout_1 : {true, false}) {
463 check_usecase(PartialShape{2, 3, 2, 1},
469 std::vector<int64_t>{0},
471 check_usecase(PartialShape{2, Dimension::dynamic(), 2, 1},
477 std::vector<int64_t>{0},
480 for (auto& opset2 : {true, false})
481 for (auto& i32 : {true, false})
482 for (auto& multiout : {true, false})
483 for (auto& multiout_1 : {true, false}) {
484 check_usecase(PartialShape{2, 3, 2, 1},
490 std::vector<int64_t>{0, 2},
492 check_usecase(PartialShape{2, Dimension::dynamic(), 2, 1},
498 std::vector<int64_t>{0, 2},