1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "gtest/gtest.h"
19 #include "ngraph/ngraph.hpp"
20 #include "ngraph/pass/constant_folding.hpp"
21 #include "ngraph/pass/manager.hpp"
22 #include "util/all_close_f.hpp"
23 #include "util/test_tools.hpp"
25 NGRAPH_SUPPRESS_DEPRECATED_START
27 using namespace ngraph;
31 static std::vector<T> get_result_constant(std::shared_ptr<Function> f, size_t pos)
34 as_type_ptr<op::Constant>(f->get_results().at(pos)->input_value(0).get_node_shared_ptr());
35 return new_const->cast_vector<T>();
38 void range_test_check(const vector<double>& values_out, const vector<double>& values_expected)
40 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
43 void range_test_check(const vector<float>& values_out, const vector<float>& values_expected)
45 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
49 typename std::enable_if<std::is_integral<T>::value>::type
50 range_test_check(const vector<T>& values_out, const vector<T>& values_expected)
52 ASSERT_EQ(values_out, values_expected);
55 TEST(constant_folding, acosh)
57 Shape shape_in{2, 4, 1};
59 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
60 vector<float> expected;
61 for (float f : values_in)
63 expected.push_back(std::acosh(f));
65 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
66 auto acosh = make_shared<op::Acosh>(constant);
67 acosh->set_friendly_name("test");
68 auto f = make_shared<Function>(acosh, ParameterVector{});
70 pass::Manager pass_manager;
71 pass_manager.register_pass<pass::ConstantFolding>();
72 pass_manager.run_passes(f);
74 EXPECT_EQ(count_ops_of_type<op::Acosh>(f), 0);
75 EXPECT_EQ(count_ops_of_type<op::Constant>(f), 1);
76 ASSERT_EQ(f->get_results().size(), 1);
79 as_type_ptr<op::Constant>(f->get_results()[0]->input_value(0).get_node_shared_ptr());
80 EXPECT_TRUE(new_const);
81 ASSERT_EQ(new_const->get_friendly_name(), "test");
83 auto values_out = new_const->get_vector<float>();
84 EXPECT_TRUE(test::all_close_f(expected, values_out, MIN_FLOAT_TOLERANCE_BITS));
87 TEST(constant_folding, asinh)
89 Shape shape_in{2, 4, 1};
91 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
92 vector<float> expected;
93 for (float f : values_in)
95 expected.push_back(std::asinh(f));
97 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
98 auto asinh = make_shared<op::Asinh>(constant);
99 asinh->set_friendly_name("test");
100 auto f = make_shared<Function>(asinh, ParameterVector{});
102 pass::Manager pass_manager;
103 pass_manager.register_pass<pass::ConstantFolding>();
104 pass_manager.run_passes(f);
106 EXPECT_EQ(count_ops_of_type<op::Asinh>(f), 0);
107 EXPECT_EQ(count_ops_of_type<op::Constant>(f), 1);
108 ASSERT_EQ(f->get_results().size(), 1);
111 as_type_ptr<op::Constant>(f->get_results()[0]->input_value(0).get_node_shared_ptr());
112 EXPECT_TRUE(new_const);
113 ASSERT_EQ(new_const->get_friendly_name(), "test");
115 auto values_out = new_const->get_vector<float>();
116 EXPECT_TRUE(test::all_close_f(expected, values_out, MIN_FLOAT_TOLERANCE_BITS));
119 TEST(constant_folding, atanh)
121 Shape shape_in{2, 4, 1};
123 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
124 vector<float> expected;
125 for (float f : values_in)
127 expected.push_back(std::atanh(f));
129 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
130 auto atanh = make_shared<op::Atanh>(constant);
131 atanh->set_friendly_name("test");
132 auto f = make_shared<Function>(atanh, ParameterVector{});
134 pass::Manager pass_manager;
135 pass_manager.register_pass<pass::ConstantFolding>();
136 pass_manager.run_passes(f);
138 EXPECT_EQ(count_ops_of_type<op::Atanh>(f), 0);
139 EXPECT_EQ(count_ops_of_type<op::Constant>(f), 1);
140 ASSERT_EQ(f->get_results().size(), 1);
143 as_type_ptr<op::Constant>(f->get_results()[0]->input_value(0).get_node_shared_ptr());
144 EXPECT_TRUE(new_const);
145 ASSERT_EQ(new_const->get_friendly_name(), "test");
147 auto values_out = new_const->get_vector<float>();
148 EXPECT_TRUE(test::all_close_f(expected, values_out, MIN_FLOAT_TOLERANCE_BITS));
151 TEST(constant_folding, constant_squeeze)
153 Shape shape_in{2, 4, 1};
154 Shape shape_out{2, 4};
157 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
158 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
159 vector<int64_t> values_axes{2};
160 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
161 auto squeeze = make_shared<op::Squeeze>(constant, constant_axes);
162 squeeze->set_friendly_name("test");
163 auto f = make_shared<Function>(squeeze, ParameterVector{});
165 pass::Manager pass_manager;
166 pass_manager.register_pass<pass::ConstantFolding>();
167 pass_manager.run_passes(f);
169 ASSERT_EQ(count_ops_of_type<op::Squeeze>(f), 0);
170 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
173 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
174 ASSERT_TRUE(new_const);
175 ASSERT_EQ(new_const->get_friendly_name(), "test");
176 ASSERT_EQ(new_const->get_shape(), shape_out);
178 auto values_out = new_const->get_vector<float>();
179 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
182 TEST(constant_folding, constant_unsqueeze)
184 Shape shape_in{2, 4};
185 Shape shape_out{2, 4, 1, 1};
188 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
189 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
190 vector<int64_t> values_axes{2, 3};
191 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
192 auto unsqueeze = make_shared<op::Unsqueeze>(constant, constant_axes);
193 unsqueeze->set_friendly_name("test");
194 auto f = make_shared<Function>(unsqueeze, ParameterVector{});
196 pass::Manager pass_manager;
197 pass_manager.register_pass<pass::ConstantFolding>();
198 pass_manager.run_passes(f);
200 ASSERT_EQ(count_ops_of_type<op::Unsqueeze>(f), 0);
201 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
204 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
205 ASSERT_TRUE(new_const);
206 ASSERT_EQ(new_const->get_friendly_name(), "test");
207 ASSERT_EQ(new_const->get_shape(), shape_out);
209 auto values_out = new_const->get_vector<float>();
210 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
213 TEST(constant_folding, constant_reshape)
215 Shape shape_in{2, 4};
216 Shape shape_out{2, 4, 1};
218 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
219 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
220 auto reshape = make_shared<op::Reshape>(constant, AxisVector{0, 1}, shape_out);
221 reshape->set_friendly_name("test");
222 auto f = make_shared<Function>(reshape, ParameterVector{});
224 pass::Manager pass_manager;
225 pass_manager.register_pass<pass::ConstantFolding>();
226 pass_manager.run_passes(f);
228 ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
229 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
232 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
233 ASSERT_TRUE(new_const);
234 ASSERT_EQ(new_const->get_friendly_name(), "test");
235 auto values_out = new_const->get_vector<float>();
237 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
240 TEST(constant_folding, DISABLED_constant_reshape_permute)
242 Shape shape_in{2, 4};
243 Shape shape_out{4, 2};
245 vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};
246 auto constant = make_shared<op::Constant>(element::f64, shape_in, values_in);
247 auto reshape = make_shared<op::Reshape>(constant, AxisVector{1, 0}, shape_out);
248 reshape->set_friendly_name("test");
249 auto f = make_shared<Function>(reshape, ParameterVector{});
251 pass::Manager pass_manager;
252 pass_manager.register_pass<pass::ConstantFolding>();
253 pass_manager.run_passes(f);
255 ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
256 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
259 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
260 ASSERT_TRUE(new_const);
261 ASSERT_EQ(new_const->get_friendly_name(), "test");
262 auto values_out = new_const->get_vector<double>();
264 vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
265 ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
268 TEST(constant_folding, constant_broadcast_v1)
270 vector<int32_t> values_in{0, 1};
271 auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
272 vector<int64_t> shape_in{2, 4};
273 auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
274 vector<int64_t> axes_in{0};
275 auto constant_axes = make_shared<op::Constant>(element::i64, Shape{1}, axes_in);
276 auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, constant_shape, constant_axes);
277 broadcast_v1->set_friendly_name("test");
278 auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
280 pass::Manager pass_manager;
281 pass_manager.register_pass<pass::ConstantFolding>();
282 pass_manager.run_passes(f);
284 ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
285 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
288 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
289 ASSERT_TRUE(new_const);
290 ASSERT_EQ(new_const->get_friendly_name(), "test");
291 auto values_out = new_const->get_vector<int32_t>();
293 vector<int32_t> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
294 ASSERT_EQ(values_expected, values_out);
297 TEST(constant_folding, constant_broadcast_v1_with_target_shape)
299 vector<int32_t> values_in{1};
300 auto constant_in = make_shared<op::Constant>(element::i32, Shape{1, 1, 1, 1}, values_in);
301 vector<int64_t> shape_in{1, 3, 1, 1};
302 auto target_shape = make_shared<op::Constant>(element::i64, Shape{4}, shape_in);
303 auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, target_shape);
304 broadcast_v1->set_friendly_name("test");
305 auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
307 pass::Manager pass_manager;
308 pass_manager.register_pass<pass::ConstantFolding>();
309 pass_manager.run_passes(f);
311 ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
312 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
315 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
316 ASSERT_TRUE(new_const);
317 ASSERT_EQ(new_const->get_friendly_name(), "test");
318 auto values_out = new_const->get_vector<int32_t>();
320 vector<int32_t> values_expected{1, 1, 1};
321 ASSERT_EQ(values_expected, values_out);
324 TEST(constant_folding, constant_broadcast_v1_numpy)
326 vector<int32_t> values_in{0, 1};
327 auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
328 vector<int64_t> shape_in{4, 2};
329 auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
330 auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, constant_shape);
331 broadcast_v1->set_friendly_name("test");
332 auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
334 pass::Manager pass_manager;
335 pass_manager.register_pass<pass::ConstantFolding>();
336 pass_manager.run_passes(f);
338 ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
339 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
342 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
343 ASSERT_TRUE(new_const);
344 ASSERT_EQ(new_const->get_friendly_name(), "test");
345 auto values_out = new_const->get_vector<int32_t>();
347 vector<int32_t> values_expected{0, 1, 0, 1, 0, 1, 0, 1};
348 ASSERT_EQ(values_expected, values_out);
351 TEST(constant_folding, constant_unary_binary)
353 vector<int> values_a{1, 2, 3, 4};
354 vector<int> values_b{1, 2, 3, 4};
355 vector<int> values_c{-1, -1, -1, -1};
356 vector<int> values_d{1, 4, 9, 16};
357 vector<int> values_e{5, 6};
358 vector<int> values_f{0, 10};
359 vector<int> values_g{1, 4};
360 vector<char> values_h{0, 0, 1, 1};
361 vector<char> values_i{0, 1};
362 auto a = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_a);
363 auto b = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_b);
364 auto c = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_c);
365 auto d = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_d);
366 auto e = make_shared<op::Constant>(element::i32, Shape{2}, values_e);
367 auto f = make_shared<op::Constant>(element::i32, Shape{2}, values_f);
368 auto g = make_shared<op::Constant>(element::i32, Shape{2}, values_g);
369 auto h = make_shared<op::Constant>(element::boolean, Shape{2, 2}, values_h);
370 auto i = make_shared<op::Constant>(element::boolean, Shape{2}, values_i);
376 auto pow = make_shared<op::Power>(a, b);
377 auto min = make_shared<op::Minimum>(c, a);
378 auto max = make_shared<op::Maximum>(a, c);
379 auto absn = make_shared<op::Abs>(c);
380 auto neg = make_shared<op::Negative>(c);
381 auto sqrt = make_shared<op::Sqrt>(d);
382 auto add_autob_numpy = make_shared<op::Add>(a, e, op::AutoBroadcastType::NUMPY);
383 auto sub_autob_numpy = make_shared<op::Subtract>(a, e, op::AutoBroadcastType::NUMPY);
384 auto mul_autob_numpy = make_shared<op::Multiply>(a, e, op::AutoBroadcastType::NUMPY);
385 auto div_autob_numpy = make_shared<op::Divide>(a, g, op::AutoBroadcastType::NUMPY);
386 auto pow_autob_numpy = make_shared<op::Power>(a, g, op::AutoBroadcastType::NUMPY);
387 auto min_autob_numpy = make_shared<op::Minimum>(a, f, op::AutoBroadcastType::NUMPY);
388 auto max_autob_numpy = make_shared<op::Maximum>(a, f, op::AutoBroadcastType::NUMPY);
389 auto equal_autob_numpy = make_shared<op::Equal>(a, g, op::AutoBroadcastType::NUMPY);
390 auto not_equal_autob_numpy = make_shared<op::NotEqual>(a, g, op::AutoBroadcastType::NUMPY);
391 auto greater_autob_numpy = make_shared<op::Greater>(a, g, op::AutoBroadcastType::NUMPY);
392 auto greater_eq_autob_numpy = make_shared<op::GreaterEq>(a, g, op::AutoBroadcastType::NUMPY);
393 auto less_autob_numpy = make_shared<op::Less>(a, g, op::AutoBroadcastType::NUMPY);
394 auto less_eq_autob_numpy = make_shared<op::LessEq>(a, g, op::AutoBroadcastType::NUMPY);
395 auto logical_or_autob_numpy = make_shared<op::Or>(h, i, op::AutoBroadcastType::NUMPY);
396 auto logical_xor_autob_numpy = make_shared<op::Xor>(h, i, op::AutoBroadcastType::NUMPY);
398 auto neg_sqrt = make_shared<op::Sqrt>(c);
400 auto func = make_shared<Function>(NodeVector{add,
418 not_equal_autob_numpy,
420 greater_eq_autob_numpy,
423 logical_or_autob_numpy,
424 logical_xor_autob_numpy},
426 auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
428 pass::Manager pass_manager;
429 pass_manager.register_pass<pass::ConstantFolding>();
430 pass_manager.run_passes(func);
433 vector<int> add_expected{2, 4, 6, 8};
434 vector<int> sub_expected{0, 0, 0, 0};
435 vector<int> mul_expected{1, 4, 9, 16};
436 vector<int> div_expected{1, 1, 1, 1};
437 vector<int> pow_expected{1, 4, 27, 256};
438 vector<int> min_expected{-1, -1, -1, -1};
439 vector<int> max_expected{1, 2, 3, 4};
440 vector<int> abs_neg_expected{1, 1, 1, 1};
441 vector<int> sqrt_expected{1, 2, 3, 4};
442 vector<int> add_autob_numpy_expected{6, 8, 8, 10};
443 vector<int> sub_autob_numpy_expected{-4, -4, -2, -2};
444 vector<int> mul_autob_numpy_expected{5, 12, 15, 24};
445 vector<int> div_autob_numpy_expected{1, 0, 3, 1};
446 vector<int> pow_autob_numpy_expected{1, 16, 3, 256};
447 vector<int> min_autob_numpy_expected{0, 2, 0, 4};
448 vector<int> max_autob_numpy_expected{1, 10, 3, 10};
449 vector<char> equal_autob_numpy_expected{1, 0, 0, 1};
450 vector<char> not_equal_autob_numpy_expected{0, 1, 1, 0};
451 vector<char> greater_autob_numpy_expected{0, 0, 1, 0};
452 vector<char> greater_eq_autob_numpy_expected{1, 0, 1, 1};
453 vector<char> less_autob_numpy_expected{0, 1, 0, 0};
454 vector<char> less_eq_autob_numpy_expected{1, 1, 0, 1};
455 vector<char> logical_or_autob_numpy_expected{0, 1, 1, 1};
456 vector<char> logical_xor_autob_numpy_expected{0, 1, 1, 0};
458 ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
459 ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
460 ASSERT_EQ(get_result_constant<int>(func, 2), mul_expected);
461 ASSERT_EQ(get_result_constant<int>(func, 3), div_expected);
462 ASSERT_EQ(get_result_constant<int>(func, 4), pow_expected);
463 ASSERT_EQ(get_result_constant<int>(func, 5), min_expected);
464 ASSERT_EQ(get_result_constant<int>(func, 6), max_expected);
465 ASSERT_EQ(get_result_constant<int>(func, 7), abs_neg_expected);
466 ASSERT_EQ(get_result_constant<int>(func, 8), abs_neg_expected);
467 ASSERT_EQ(get_result_constant<int>(func, 9), sqrt_expected);
468 ASSERT_EQ(get_result_constant<int>(func, 10), add_autob_numpy_expected);
469 ASSERT_EQ(get_result_constant<int>(func, 11), sub_autob_numpy_expected);
470 ASSERT_EQ(get_result_constant<int>(func, 12), mul_autob_numpy_expected);
471 ASSERT_EQ(get_result_constant<int>(func, 13), div_autob_numpy_expected);
472 ASSERT_EQ(get_result_constant<int>(func, 14), pow_autob_numpy_expected);
473 ASSERT_EQ(get_result_constant<int>(func, 15), min_autob_numpy_expected);
474 ASSERT_EQ(get_result_constant<int>(func, 16), max_autob_numpy_expected);
475 ASSERT_EQ(get_result_constant<char>(func, 17), equal_autob_numpy_expected);
476 ASSERT_EQ(get_result_constant<char>(func, 18), not_equal_autob_numpy_expected);
477 ASSERT_EQ(get_result_constant<char>(func, 19), greater_autob_numpy_expected);
478 ASSERT_EQ(get_result_constant<char>(func, 20), greater_eq_autob_numpy_expected);
479 ASSERT_EQ(get_result_constant<char>(func, 21), less_autob_numpy_expected);
480 ASSERT_EQ(get_result_constant<char>(func, 22), less_eq_autob_numpy_expected);
481 ASSERT_EQ(get_result_constant<char>(func, 23), logical_or_autob_numpy_expected);
482 ASSERT_EQ(get_result_constant<char>(func, 24), logical_xor_autob_numpy_expected);
483 ASSERT_NO_THROW(pass_manager.run_passes(func_error));
486 TEST(constant_folding, const_quantize)
488 Shape input_shape{12};
489 Shape scale_offset_shape;
490 AxisSet quantization_axes;
492 auto quant_type = element::u8;
493 auto output_type = element::u8;
494 typedef uint8_t output_c_type;
496 vector<float> values_in{1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0};
497 auto constant = op::Constant::create(element::f32, input_shape, values_in);
498 auto scale = op::Constant::create(element::f32, scale_offset_shape, {2});
499 auto offset = op::Constant::create(quant_type, scale_offset_shape, {1});
500 auto mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY;
502 make_shared<op::Quantize>(constant, scale, offset, output_type, quantization_axes, mode);
503 quantize->set_friendly_name("test");
504 auto f = make_shared<Function>(quantize, ParameterVector{});
506 pass::Manager pass_manager;
507 pass_manager.register_pass<pass::ConstantFolding>();
508 pass_manager.run_passes(f);
510 ASSERT_EQ(count_ops_of_type<op::Quantize>(f), 0);
511 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
514 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
515 ASSERT_TRUE(new_const);
516 ASSERT_EQ(new_const->get_friendly_name(), "test");
517 auto values_out = new_const->get_vector<output_c_type>();
519 vector<output_c_type> values_quantize{2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5};
520 ASSERT_EQ(values_quantize, values_out);
523 TEST(constant_folding, const_convert)
525 Shape input_shape{3, 4};
527 vector<int32_t> values_in{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
528 auto constant = op::Constant::create(element::f32, input_shape, values_in);
529 auto convert = make_shared<op::Convert>(constant, element::u64);
530 convert->set_friendly_name("test");
531 auto f = make_shared<Function>(convert, ParameterVector{});
533 pass::Manager pass_manager;
534 pass_manager.register_pass<pass::ConstantFolding>();
535 pass_manager.run_passes(f);
537 ASSERT_EQ(count_ops_of_type<op::Convert>(f), 0);
538 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
541 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
542 ASSERT_TRUE(new_const);
543 ASSERT_EQ(new_const->get_friendly_name(), "test");
544 ASSERT_EQ(new_const->get_output_element_type(0), element::u64);
545 auto values_out = new_const->get_vector<uint64_t>();
547 vector<uint64_t> values_expected{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
548 ASSERT_EQ(values_expected, values_out);
551 TEST(constant_folding, shape_of_v0)
553 Shape input_shape{3, 4, 0, 22, 608, 909, 3};
555 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
556 auto shape_of = make_shared<op::v0::ShapeOf>(param);
557 shape_of->set_friendly_name("test");
558 auto f = make_shared<Function>(shape_of, ParameterVector{param});
560 pass::Manager pass_manager;
561 pass_manager.register_pass<pass::ConstantFolding>();
562 pass_manager.run_passes(f);
564 ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 0);
565 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
568 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
569 ASSERT_TRUE(new_const);
570 ASSERT_EQ(new_const->get_friendly_name(), "test");
571 ASSERT_EQ(new_const->get_output_element_type(0), element::i64);
572 auto values_out = new_const->get_vector<int64_t>();
574 ASSERT_EQ((vector<int64_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
577 TEST(constant_folding, shape_of_v3)
579 Shape input_shape{3, 4, 0, 22, 608, 909, 3};
581 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
582 auto shape_of = make_shared<op::v3::ShapeOf>(param);
583 shape_of->set_friendly_name("test");
584 auto f = make_shared<Function>(shape_of, ParameterVector{param});
586 pass::Manager pass_manager;
587 pass_manager.register_pass<pass::ConstantFolding>();
588 pass_manager.run_passes(f);
590 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 0);
591 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
594 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
595 ASSERT_TRUE(new_const);
596 ASSERT_EQ(new_const->get_friendly_name(), "test");
597 ASSERT_EQ(new_const->get_output_element_type(0), element::i64);
598 auto values_out = new_const->get_vector<int64_t>();
600 ASSERT_EQ((vector<int64_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
603 TEST(constant_folding, shape_of_i32_v3)
605 Shape input_shape{3, 4, 0, 22, 608, 909, 3};
607 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
608 auto shape_of = make_shared<op::v3::ShapeOf>(param, element::i32);
609 shape_of->set_friendly_name("test");
610 auto f = make_shared<Function>(shape_of, ParameterVector{param});
612 pass::Manager pass_manager;
613 pass_manager.register_pass<pass::ConstantFolding>();
614 pass_manager.run_passes(f);
616 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 0);
617 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
620 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
621 ASSERT_TRUE(new_const);
622 ASSERT_EQ(new_const->get_friendly_name(), "test");
623 ASSERT_EQ(new_const->get_output_element_type(0), element::i32);
624 auto values_out = new_const->get_vector<int32_t>();
626 ASSERT_EQ((vector<int32_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
629 TEST(constant_folding, shape_of_dynamic_v0)
631 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
633 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
634 auto shape_of = make_shared<op::v0::ShapeOf>(param);
635 shape_of->set_friendly_name("test");
636 auto f = make_shared<Function>(shape_of, ParameterVector{param});
638 pass::Manager pass_manager;
639 pass_manager.register_pass<pass::ConstantFolding>();
640 pass_manager.run_passes(f);
642 ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
643 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
644 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
645 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
647 auto result_as_concat =
648 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
649 ASSERT_TRUE(result_as_concat);
650 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
651 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
654 TEST(constant_folding, shape_of_dynamic_v3)
656 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
658 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
659 auto shape_of = make_shared<op::v3::ShapeOf>(param);
660 shape_of->set_friendly_name("test");
661 auto f = make_shared<Function>(shape_of, ParameterVector{param});
663 pass::Manager pass_manager;
664 pass_manager.register_pass<pass::ConstantFolding>();
665 pass_manager.run_passes(f);
667 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
668 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
669 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
670 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
672 auto result_as_concat =
673 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
674 ASSERT_TRUE(result_as_concat);
675 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
676 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
677 ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i64);
680 TEST(constant_folding, shape_of_dynamic_i32_v3)
682 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
684 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
685 auto shape_of = make_shared<op::v3::ShapeOf>(param, element::i32);
686 shape_of->set_friendly_name("test");
687 auto f = make_shared<Function>(shape_of, ParameterVector{param});
689 pass::Manager pass_manager;
690 pass_manager.register_pass<pass::ConstantFolding>();
691 pass_manager.run_passes(f);
693 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
694 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
695 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
696 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
698 auto result_as_concat =
699 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
700 ASSERT_TRUE(result_as_concat);
701 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
702 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
703 ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i32);
706 // We need to be sure that constant folding won't be calculated endlessly.
707 TEST(constant_folding, shape_of_dynamic_double_folding_v0)
709 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
711 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
712 auto shape_of = make_shared<op::v0::ShapeOf>(param);
713 shape_of->set_friendly_name("test");
714 auto f = make_shared<Function>(shape_of, ParameterVector{param});
716 pass::Manager pass_manager;
717 pass_manager.register_pass<pass::ConstantFolding>();
718 pass_manager.run_passes(f);
719 pass_manager.run_passes(f);
721 ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
722 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
723 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
724 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
726 auto result_as_concat =
727 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
728 ASSERT_TRUE(result_as_concat);
729 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
730 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
733 TEST(constant_folding, shape_of_dynamic_double_folding_v3)
735 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
737 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
738 auto shape_of = make_shared<op::v3::ShapeOf>(param);
739 shape_of->set_friendly_name("test");
740 auto f = make_shared<Function>(shape_of, ParameterVector{param});
742 pass::Manager pass_manager;
743 pass_manager.register_pass<pass::ConstantFolding>();
744 pass_manager.run_passes(f);
745 pass_manager.run_passes(f);
747 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
748 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
749 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
750 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
752 auto result_as_concat =
753 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
754 ASSERT_TRUE(result_as_concat);
755 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
756 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
759 // Constant folding will not succeed on ShapeOf if the argument rank is dynamic.
760 // We want to make sure it fails gracefully, leaving the ShapeOf op in place.
761 TEST(constant_folding, shape_of_rank_dynamic_v0)
763 PartialShape input_shape{PartialShape::dynamic()};
765 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
766 auto shape_of = make_shared<op::v0::ShapeOf>(param);
767 shape_of->set_friendly_name("test");
768 auto f = make_shared<Function>(shape_of, ParameterVector{param});
770 pass::Manager pass_manager;
771 pass_manager.register_pass<pass::ConstantFolding>();
772 pass_manager.run_passes(f);
774 ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
775 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
777 auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
778 ASSERT_EQ(result_shape_of, shape_of);
779 ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
782 TEST(constant_folding, shape_of_rank_dynamic_v3)
784 PartialShape input_shape{PartialShape::dynamic()};
786 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
787 auto shape_of = make_shared<op::v3::ShapeOf>(param);
788 shape_of->set_friendly_name("test");
789 auto f = make_shared<Function>(shape_of, ParameterVector{param});
791 pass::Manager pass_manager;
792 pass_manager.register_pass<pass::ConstantFolding>();
793 pass_manager.run_passes(f);
795 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
796 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
798 auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
799 ASSERT_EQ(result_shape_of, shape_of);
800 ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
803 TEST(constant_folding, const_reverse)
805 Shape input_shape{3, 3};
807 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
808 auto constant = op::Constant::create(element::i32, input_shape, values_in);
809 auto convert = make_shared<op::Reverse>(constant, AxisSet{1});
810 convert->set_friendly_name("test");
811 auto f = make_shared<Function>(convert, ParameterVector{});
813 pass::Manager pass_manager;
814 pass_manager.register_pass<pass::ConstantFolding>();
815 pass_manager.run_passes(f);
817 ASSERT_EQ(count_ops_of_type<op::Reverse>(f), 0);
818 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
821 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
822 ASSERT_TRUE(new_const);
823 ASSERT_EQ(new_const->get_friendly_name(), "test");
824 auto values_out = new_const->get_vector<int32_t>();
826 vector<int32_t> values_expected{3, 2, 1, 6, 5, 4, 9, 8, 7};
827 ASSERT_EQ(values_expected, values_out);
830 TEST(constant_folding, const_product)
832 Shape input_shape{3, 3};
834 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
835 auto constant = op::Constant::create(element::i32, input_shape, values_in);
836 auto convert = make_shared<op::Product>(constant, AxisSet{1});
837 convert->set_friendly_name("test");
838 auto f = make_shared<Function>(convert, ParameterVector{});
840 pass::Manager pass_manager;
841 pass_manager.register_pass<pass::ConstantFolding>();
842 pass_manager.run_passes(f);
844 ASSERT_EQ(count_ops_of_type<op::Product>(f), 0);
845 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
848 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
849 ASSERT_TRUE(new_const);
850 ASSERT_EQ(new_const->get_friendly_name(), "test");
851 auto values_out = new_const->get_vector<int32_t>();
853 vector<int32_t> values_expected{6, 120, 504};
854 ASSERT_EQ(values_expected, values_out);
857 TEST(constant_folding, const_reduceprod)
859 Shape input_shape{3, 3};
860 Shape output_shape{3};
862 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
863 auto constant = op::Constant::create(element::i32, input_shape, values_in);
865 vector<int32_t> values_axes{1};
866 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
867 auto convert = make_shared<op::v1::ReduceProd>(constant, constant_axes);
868 convert->set_friendly_name("test");
869 auto f = make_shared<Function>(convert, ParameterVector{});
871 pass::Manager pass_manager;
872 pass_manager.register_pass<pass::ConstantFolding>();
873 pass_manager.run_passes(f);
875 ASSERT_EQ(count_ops_of_type<op::v1::ReduceProd>(f), 0);
876 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
879 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
880 ASSERT_TRUE(new_const);
881 ASSERT_EQ(new_const->get_friendly_name(), "test");
882 ASSERT_EQ(new_const->get_shape(), output_shape);
884 auto values_out = new_const->get_vector<int32_t>();
886 vector<int32_t> values_expected{6, 120, 504};
888 ASSERT_EQ(values_expected, values_out);
891 TEST(constant_folding, const_reduceprod_keepdims)
893 Shape input_shape{3, 3};
894 Shape output_shape{3, 1};
896 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
897 auto constant = op::Constant::create(element::i32, input_shape, values_in);
899 vector<int32_t> values_axes{1};
900 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
901 auto convert = make_shared<op::v1::ReduceProd>(constant, constant_axes, true);
902 convert->set_friendly_name("test");
903 auto f = make_shared<Function>(convert, ParameterVector{});
905 pass::Manager pass_manager;
906 pass_manager.register_pass<pass::ConstantFolding>();
907 pass_manager.run_passes(f);
909 ASSERT_EQ(count_ops_of_type<op::v1::ReduceProd>(f), 0);
910 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
913 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
914 ASSERT_TRUE(new_const);
915 ASSERT_EQ(new_const->get_friendly_name(), "test");
916 ASSERT_EQ(new_const->get_shape(), output_shape);
918 auto values_out = new_const->get_vector<int32_t>();
920 vector<int32_t> values_expected{6, 120, 504};
922 ASSERT_EQ(values_expected, values_out);
925 TEST(constant_folding, const_sum)
927 Shape input_shape{3, 3};
929 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
930 auto constant = op::Constant::create(element::i32, input_shape, values_in);
931 auto convert = make_shared<op::Sum>(constant, AxisSet{1});
932 convert->set_friendly_name("test");
933 auto f = make_shared<Function>(convert, ParameterVector{});
935 pass::Manager pass_manager;
936 pass_manager.register_pass<pass::ConstantFolding>();
937 pass_manager.run_passes(f);
939 ASSERT_EQ(count_ops_of_type<op::Sum>(f), 0);
940 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
943 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
944 ASSERT_TRUE(new_const);
945 ASSERT_EQ(new_const->get_friendly_name(), "test");
946 auto values_out = new_const->get_vector<int32_t>();
948 vector<int32_t> values_expected{6, 15, 24};
950 ASSERT_EQ(values_expected, values_out);
953 TEST(constant_folding, const_reducesum)
955 Shape input_shape{3, 3};
956 Shape output_shape{3};
958 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
959 auto constant = op::Constant::create(element::i32, input_shape, values_in);
961 vector<int32_t> values_axes{1};
962 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
963 auto convert = make_shared<op::v1::ReduceSum>(constant, constant_axes);
964 convert->set_friendly_name("test");
965 auto f = make_shared<Function>(convert, ParameterVector{});
967 pass::Manager pass_manager;
968 pass_manager.register_pass<pass::ConstantFolding>();
969 pass_manager.run_passes(f);
971 ASSERT_EQ(count_ops_of_type<op::v1::ReduceSum>(f), 0);
972 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
975 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
976 ASSERT_TRUE(new_const);
977 ASSERT_EQ(new_const->get_friendly_name(), "test");
978 ASSERT_EQ(new_const->get_shape(), output_shape);
980 auto values_out = new_const->get_vector<int32_t>();
982 vector<int32_t> values_expected{6, 15, 24};
984 ASSERT_EQ(values_expected, values_out);
987 TEST(constant_folding, const_reducesum_keepdims)
989 Shape input_shape{3, 3};
990 Shape output_shape{3, 1};
992 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
993 auto constant = op::Constant::create(element::i32, input_shape, values_in);
995 vector<int32_t> values_axes{1};
996 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
997 auto convert = make_shared<op::v1::ReduceSum>(constant, constant_axes, true);
998 convert->set_friendly_name("test");
999 auto f = make_shared<Function>(convert, ParameterVector{});
1001 pass::Manager pass_manager;
1002 pass_manager.register_pass<pass::ConstantFolding>();
1003 pass_manager.run_passes(f);
1005 ASSERT_EQ(count_ops_of_type<op::v1::ReduceSum>(f), 0);
1006 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1009 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1010 ASSERT_TRUE(new_const);
1011 ASSERT_EQ(new_const->get_friendly_name(), "test");
1012 ASSERT_EQ(new_const->get_shape(), output_shape);
1014 auto values_out = new_const->get_vector<int32_t>();
1016 vector<int32_t> values_expected{6, 15, 24};
1018 ASSERT_EQ(values_expected, values_out);
1021 TEST(constant_folding, const_max)
1023 Shape input_shape{3, 3};
1025 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
1026 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1027 auto convert = make_shared<op::Max>(constant, AxisSet{1});
1028 convert->set_friendly_name("test");
1029 auto f = make_shared<Function>(convert, ParameterVector{});
1031 pass::Manager pass_manager;
1032 pass_manager.register_pass<pass::ConstantFolding>();
1033 pass_manager.run_passes(f);
1035 ASSERT_EQ(count_ops_of_type<op::Max>(f), 0);
1036 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1039 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1040 ASSERT_TRUE(new_const);
1041 ASSERT_EQ(new_const->get_friendly_name(), "test");
1042 auto values_out = new_const->get_vector<int32_t>();
1044 vector<int32_t> values_expected{3, 6, 9};
1046 ASSERT_EQ(values_expected, values_out);
1049 TEST(constant_folding, const_reducemax)
1051 Shape input_shape{3, 2};
1052 Shape output_shape{3};
1054 vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1055 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1056 Shape axes_shape{1};
1057 vector<int32_t> values_axes{1};
1058 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1059 auto convert = make_shared<op::v1::ReduceMax>(constant, constant_axes);
1060 convert->set_friendly_name("test");
1061 auto f = make_shared<Function>(convert, ParameterVector{});
1063 pass::Manager pass_manager;
1064 pass_manager.register_pass<pass::ConstantFolding>();
1065 pass_manager.run_passes(f);
1067 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMax>(f), 0);
1068 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1071 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1072 ASSERT_TRUE(new_const);
1073 ASSERT_EQ(new_const->get_friendly_name(), "test");
1074 ASSERT_EQ(new_const->get_shape(), output_shape);
1076 auto values_out = new_const->get_vector<int32_t>();
1078 vector<int32_t> values_expected{2, 4, 6};
1080 ASSERT_EQ(values_expected, values_out);
1083 TEST(constant_folding, const_reducemax_keepdims)
1085 Shape input_shape{3, 2};
1086 Shape output_shape{3, 1};
1088 vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1089 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1090 Shape axes_shape{1};
1091 vector<int32_t> values_axes{1};
1092 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1093 auto convert = make_shared<op::v1::ReduceMax>(constant, constant_axes, true);
1094 convert->set_friendly_name("test");
1095 auto f = make_shared<Function>(convert, ParameterVector{});
1097 pass::Manager pass_manager;
1098 pass_manager.register_pass<pass::ConstantFolding>();
1099 pass_manager.run_passes(f);
1101 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMax>(f), 0);
1102 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1105 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1106 ASSERT_TRUE(new_const);
1107 ASSERT_EQ(new_const->get_friendly_name(), "test");
1108 ASSERT_EQ(new_const->get_shape(), output_shape);
1110 auto values_out = new_const->get_vector<int32_t>();
1112 vector<int32_t> values_expected{2, 4, 6};
1114 ASSERT_EQ(values_expected, values_out);
1117 TEST(constant_folding, const_min)
1119 Shape input_shape{3, 3};
1121 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
1122 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1123 auto convert = make_shared<op::Min>(constant, AxisSet{1});
1124 convert->set_friendly_name("test");
1125 auto f = make_shared<Function>(convert, ParameterVector{});
1127 pass::Manager pass_manager;
1128 pass_manager.register_pass<pass::ConstantFolding>();
1129 pass_manager.run_passes(f);
1131 ASSERT_EQ(count_ops_of_type<op::Min>(f), 0);
1132 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1135 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1136 ASSERT_TRUE(new_const);
1137 ASSERT_EQ(new_const->get_friendly_name(), "test");
1138 auto values_out = new_const->get_vector<int32_t>();
1140 vector<int32_t> values_expected{1, 4, 7};
1142 ASSERT_EQ(values_expected, values_out);
1145 TEST(constant_folding, const_reducemin)
1147 Shape input_shape{3, 2};
1148 Shape output_shape{3};
1150 vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1151 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1152 Shape axes_shape{1};
1153 vector<int32_t> values_axes{1};
1154 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1155 auto convert = make_shared<op::v1::ReduceMin>(constant, constant_axes);
1156 convert->set_friendly_name("test");
1157 auto f = make_shared<Function>(convert, ParameterVector{});
1159 pass::Manager pass_manager;
1160 pass_manager.register_pass<pass::ConstantFolding>();
1161 pass_manager.run_passes(f);
1163 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMin>(f), 0);
1164 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1167 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1168 ASSERT_TRUE(new_const);
1169 ASSERT_EQ(new_const->get_friendly_name(), "test");
1170 ASSERT_EQ(new_const->get_shape(), output_shape);
1172 auto values_out = new_const->get_vector<int32_t>();
1174 vector<int32_t> values_expected{1, 3, 5};
1176 ASSERT_EQ(values_expected, values_out);
1179 TEST(constant_folding, const_reducemin_keepdims)
1181 Shape input_shape{3, 2};
1182 Shape output_shape{3, 1};
1184 vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1185 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1186 Shape axes_shape{1};
1187 vector<int32_t> values_axes{1};
1188 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1189 auto convert = make_shared<op::v1::ReduceMin>(constant, constant_axes, true);
1190 convert->set_friendly_name("test");
1191 auto f = make_shared<Function>(convert, ParameterVector{});
1193 pass::Manager pass_manager;
1194 pass_manager.register_pass<pass::ConstantFolding>();
1195 pass_manager.run_passes(f);
1197 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMin>(f), 0);
1198 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1201 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1202 ASSERT_TRUE(new_const);
1203 ASSERT_EQ(new_const->get_friendly_name(), "test");
1204 ASSERT_EQ(new_const->get_shape(), output_shape);
1206 auto values_out = new_const->get_vector<int32_t>();
1208 vector<int32_t> values_expected{1, 3, 5};
1210 ASSERT_EQ(values_expected, values_out);
1213 TEST(constant_folding, const_reducemean)
1215 Shape input_shape{3, 3};
1216 Shape output_shape{3};
1218 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
1219 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1220 Shape axes_shape{1};
1221 vector<int32_t> values_axes{1};
1222 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1223 auto convert = make_shared<op::v1::ReduceMean>(constant, constant_axes);
1224 convert->set_friendly_name("test");
1225 auto f = make_shared<Function>(convert, ParameterVector{});
1227 pass::Manager pass_manager;
1228 pass_manager.register_pass<pass::ConstantFolding>();
1229 pass_manager.run_passes(f);
1231 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMean>(f), 0);
1232 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1235 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1236 ASSERT_TRUE(new_const);
1237 ASSERT_EQ(new_const->get_friendly_name(), "test");
1238 ASSERT_EQ(new_const->get_shape(), output_shape);
1240 auto values_out = new_const->get_vector<int32_t>();
1242 vector<int32_t> values_expected{2, 5, 8};
1244 ASSERT_EQ(values_expected, values_out);
1247 TEST(constant_folding, const_reducemean_keepdims)
1249 Shape input_shape{3, 3};
1250 Shape output_shape{3, 1};
1252 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
1253 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1254 Shape axes_shape{1};
1255 vector<int32_t> values_axes{1};
1256 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1257 auto convert = make_shared<op::v1::ReduceMean>(constant, constant_axes, true);
1258 convert->set_friendly_name("test");
1259 auto f = make_shared<Function>(convert, ParameterVector{});
1261 pass::Manager pass_manager;
1262 pass_manager.register_pass<pass::ConstantFolding>();
1263 pass_manager.run_passes(f);
1265 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMean>(f), 0);
1266 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1269 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1270 ASSERT_TRUE(new_const);
1271 ASSERT_EQ(new_const->get_friendly_name(), "test");
1272 ASSERT_EQ(new_const->get_shape(), output_shape);
1274 auto values_out = new_const->get_vector<int32_t>();
1276 vector<int32_t> values_expected{2, 5, 8};
1278 ASSERT_EQ(values_expected, values_out);
1281 TEST(constant_folding, const_reduce_logical_and__no_keepdims)
1283 const Shape input_shape{3, 3};
1285 const vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
1286 const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1287 const auto axes = op::Constant::create(element::i64, {1}, {1});
1288 const auto convert = make_shared<op::v1::ReduceLogicalAnd>(data, axes, false);
1289 convert->set_friendly_name("test");
1290 auto f = make_shared<Function>(convert, ParameterVector{});
1292 pass::Manager pass_manager;
1293 pass_manager.register_pass<pass::ConstantFolding>();
1294 pass_manager.run_passes(f);
1296 ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1297 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1299 const auto new_const =
1300 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1301 ASSERT_TRUE(new_const);
1302 ASSERT_EQ(new_const->get_friendly_name(), "test");
1304 const Shape expected_out_shape{3};
1305 ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1307 const auto values_out = new_const->get_vector<char>();
1309 const vector<char> values_expected{0, 0, 1};
1311 ASSERT_EQ(values_expected, values_out);
1314 TEST(constant_folding, const_reduce_logical_and__keepdims)
1316 const Shape input_shape{3, 3};
1318 const vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
1319 const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1320 const auto axes = op::Constant::create(element::i64, {1}, {1});
1321 const auto convert = make_shared<op::v1::ReduceLogicalAnd>(data, axes, true);
1322 convert->set_friendly_name("test");
1323 auto f = make_shared<Function>(convert, ParameterVector{});
1325 pass::Manager pass_manager;
1326 pass_manager.register_pass<pass::ConstantFolding>();
1327 pass_manager.run_passes(f);
1329 ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1330 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1332 const auto new_const =
1333 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1334 ASSERT_TRUE(new_const);
1335 ASSERT_EQ(new_const->get_friendly_name(), "test");
1337 // the output shape is expected to have 'ones' at the positions specified in the reduction axes
1338 // in case the keep_dims attribute of ReduceLogicalAnd is set to true
1339 const Shape expected_out_shape{3, 1};
1340 ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1342 const auto values_out = new_const->get_vector<char>();
1344 const vector<char> values_expected{0, 0, 1};
1346 ASSERT_EQ(values_expected, values_out);
1349 TEST(constant_folding, const_reduce_logical_and__keepdims_3d)
1351 const Shape input_shape{2, 2, 2};
1353 const vector<char> values_in{1, 1, 0, 0, 1, 0, 0, 1};
1354 const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1355 const auto axes = op::Constant::create(element::i64, {2}, {0, 2});
1356 const auto convert = make_shared<op::v1::ReduceLogicalAnd>(data, axes, true);
1357 convert->set_friendly_name("test");
1358 auto f = make_shared<Function>(convert, ParameterVector{});
1360 pass::Manager pass_manager;
1361 pass_manager.register_pass<pass::ConstantFolding>();
1362 pass_manager.run_passes(f);
1364 ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1365 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1367 const auto new_const =
1368 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1369 ASSERT_TRUE(new_const);
1370 ASSERT_EQ(new_const->get_friendly_name(), "test");
1372 const Shape expected_out_shape{1, 2, 1};
1373 ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1375 const auto values_out = new_const->get_vector<char>();
1377 const vector<char> values_expected{0, 0};
1379 ASSERT_EQ(values_expected, values_out);
1382 TEST(constant_folding, const_reduce_logical_or__no_keepdims)
1384 const Shape input_shape{3, 3};
1386 const vector<char> values_in{1, 0, 0, 1, 0, 1, 0, 0, 0};
1387 const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1388 const auto axes = op::Constant::create(element::i64, {1}, {1});
1389 const auto convert = make_shared<op::v1::ReduceLogicalOr>(data, axes, false);
1390 convert->set_friendly_name("test");
1391 auto f = make_shared<Function>(convert, ParameterVector{});
1393 pass::Manager pass_manager;
1394 pass_manager.register_pass<pass::ConstantFolding>();
1395 pass_manager.run_passes(f);
1397 ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1398 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1400 const auto new_const =
1401 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1402 ASSERT_TRUE(new_const);
1403 ASSERT_EQ(new_const->get_friendly_name(), "test");
1405 const Shape expected_out_shape{3};
1406 ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1408 const auto values_out = new_const->get_vector<char>();
1410 const vector<char> values_expected{1, 1, 0};
1412 ASSERT_EQ(values_expected, values_out);
1415 TEST(constant_folding, const_concat)
1418 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1419 auto constant1 = op::Constant::create(element::i32, Shape{2, 1}, vector<int32_t>{7, 8});
1420 auto concat = make_shared<op::Concat>(NodeVector{constant0, constant1}, 1);
1421 concat->set_friendly_name("test");
1422 auto f = make_shared<Function>(concat, ParameterVector{});
1424 pass::Manager pass_manager;
1425 pass_manager.register_pass<pass::ConstantFolding>();
1426 pass_manager.run_passes(f);
1428 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1429 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1432 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1433 ASSERT_TRUE(new_const);
1434 ASSERT_EQ(new_const->get_friendly_name(), "test");
1435 auto values_out = new_const->get_vector<int32_t>();
1437 vector<int32_t> values_expected{1, 2, 3, 7, 4, 5, 6, 8};
1439 ASSERT_EQ(values_expected, values_out);
1442 TEST(constant_folding, const_concat_3d_single_elem)
1444 auto constant_1 = op::Constant::create(element::i32, Shape{1, 1, 1}, vector<int32_t>{1});
1445 auto constant_2 = op::Constant::create(element::i32, Shape{1, 1, 1}, vector<int32_t>{2});
1446 auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2}, 0);
1447 concat->set_friendly_name("test");
1448 auto f = make_shared<Function>(concat, ParameterVector{});
1450 pass::Manager pass_manager;
1451 pass_manager.register_pass<pass::ConstantFolding>();
1452 pass_manager.run_passes(f);
1454 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1455 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1458 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1460 ASSERT_TRUE(new_const);
1461 ASSERT_EQ(new_const->get_friendly_name(), "test");
1462 ASSERT_EQ(new_const->get_output_shape(0), (Shape{2, 1, 1}));
1464 auto values_out = new_const->get_vector<int32_t>();
1465 vector<int32_t> values_expected{1, 2};
1466 ASSERT_EQ(values_expected, values_out);
1469 TEST(constant_folding, const_concat_axis_2)
1472 op::Constant::create(element::i32, Shape{3, 1, 2}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1473 auto constant_2 = op::Constant::create(
1474 element::i32, Shape{3, 1, 4}, vector<int32_t>{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
1475 auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2}, 2);
1476 concat->set_friendly_name("test");
1477 auto f = make_shared<Function>(concat, ParameterVector{});
1479 pass::Manager pass_manager;
1480 pass_manager.register_pass<pass::ConstantFolding>();
1481 pass_manager.run_passes(f);
1483 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1484 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1487 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1489 ASSERT_TRUE(new_const);
1490 ASSERT_EQ(new_const->get_friendly_name(), "test");
1491 ASSERT_EQ(new_const->get_output_shape(0), (Shape{3, 1, 6}));
1493 auto values_out = new_const->get_vector<int32_t>();
1494 vector<int32_t> values_expected{1, 2, 7, 8, 9, 10, 3, 4, 11, 12, 13, 14, 5, 6, 15, 16, 17, 18};
1495 ASSERT_EQ(values_expected, values_out);
1498 TEST(constant_folding, const_concat_axis_1_bool_type)
1501 op::Constant::create(element::boolean, Shape{1, 1, 2}, vector<int32_t>{true, true});
1502 auto constant_2 = op::Constant::create(
1503 element::boolean, Shape{1, 2, 2}, vector<char>{true, false, true, false});
1504 auto constant_3 = op::Constant::create(
1505 element::boolean, Shape{1, 3, 2}, vector<char>{true, false, true, false, true, false});
1506 auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2, constant_3}, 1);
1507 concat->set_friendly_name("test");
1508 auto f = make_shared<Function>(concat, ParameterVector{});
1510 pass::Manager pass_manager;
1511 pass_manager.register_pass<pass::ConstantFolding>();
1512 pass_manager.run_passes(f);
1514 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1515 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1518 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1520 ASSERT_TRUE(new_const);
1521 ASSERT_EQ(new_const->get_friendly_name(), "test");
1522 ASSERT_EQ(new_const->get_output_shape(0), (Shape{1, 6, 2}));
1524 auto values_out = new_const->get_vector<char>();
1525 vector<char> values_expected{
1526 true, true, true, false, true, false, true, false, true, false, true, false};
1527 ASSERT_EQ(values_expected, values_out);
1530 TEST(constant_folding, const_not)
1533 op::Constant::create(element::boolean, Shape{2, 3}, vector<char>{0, 1, 0, 0, 1, 1});
1534 auto logical_not = make_shared<op::Not>(constant);
1535 logical_not->set_friendly_name("test");
1536 auto f = make_shared<Function>(logical_not, ParameterVector{});
1538 pass::Manager pass_manager;
1539 pass_manager.register_pass<pass::ConstantFolding>();
1540 pass_manager.run_passes(f);
1542 ASSERT_EQ(count_ops_of_type<op::Not>(f), 0);
1543 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1546 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1547 ASSERT_TRUE(new_const);
1548 ASSERT_EQ(new_const->get_friendly_name(), "test");
1549 auto values_out = new_const->get_vector<char>();
1551 vector<char> values_expected{1, 0, 1, 1, 0, 0};
1553 ASSERT_EQ(values_expected, values_out);
1556 TEST(constant_folding, const_equal)
1559 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1561 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 2, 3, 5, 6});
1562 auto eq = make_shared<op::Equal>(constant0, constant1);
1563 eq->set_friendly_name("test");
1564 auto f = make_shared<Function>(eq, ParameterVector{});
1566 pass::Manager pass_manager;
1567 pass_manager.register_pass<pass::ConstantFolding>();
1568 pass_manager.run_passes(f);
1570 ASSERT_EQ(count_ops_of_type<op::Equal>(f), 0);
1571 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1574 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1575 ASSERT_TRUE(new_const);
1576 ASSERT_EQ(new_const->get_friendly_name(), "test");
1577 auto values_out = new_const->get_vector<char>();
1579 vector<char> values_expected{1, 1, 0, 0, 1, 1};
1581 ASSERT_EQ(values_expected, values_out);
1584 TEST(constant_folding, const_not_equal)
1587 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1589 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 2, 3, 5, 6});
1590 auto eq = make_shared<op::NotEqual>(constant0, constant1);
1591 eq->set_friendly_name("test");
1592 auto f = make_shared<Function>(eq, ParameterVector{});
1594 pass::Manager pass_manager;
1595 pass_manager.register_pass<pass::ConstantFolding>();
1596 pass_manager.run_passes(f);
1598 ASSERT_EQ(count_ops_of_type<op::NotEqual>(f), 0);
1599 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1602 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1603 ASSERT_TRUE(new_const);
1604 ASSERT_EQ(new_const->get_friendly_name(), "test");
1605 auto values_out = new_const->get_vector<char>();
1607 vector<char> values_expected{0, 0, 1, 1, 0, 0};
1609 ASSERT_EQ(values_expected, values_out);
1612 TEST(constant_folding, const_greater)
1615 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1617 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1618 auto eq = make_shared<op::Greater>(constant0, constant1);
1619 eq->set_friendly_name("test");
1620 auto f = make_shared<Function>(eq, ParameterVector{});
1622 pass::Manager pass_manager;
1623 pass_manager.register_pass<pass::ConstantFolding>();
1624 pass_manager.run_passes(f);
1626 ASSERT_EQ(count_ops_of_type<op::Greater>(f), 0);
1627 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1630 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1631 ASSERT_TRUE(new_const);
1632 ASSERT_EQ(new_const->get_friendly_name(), "test");
1633 auto values_out = new_const->get_vector<char>();
1635 vector<char> values_expected{0, 0, 1, 0, 0, 1};
1637 ASSERT_EQ(values_expected, values_out);
1640 TEST(constant_folding, const_greater_eq)
1643 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1645 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1646 auto eq = make_shared<op::GreaterEq>(constant0, constant1);
1647 eq->set_friendly_name("test");
1648 auto f = make_shared<Function>(eq, ParameterVector{});
1650 pass::Manager pass_manager;
1651 pass_manager.register_pass<pass::ConstantFolding>();
1652 pass_manager.run_passes(f);
1654 ASSERT_EQ(count_ops_of_type<op::GreaterEq>(f), 0);
1655 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1658 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1659 ASSERT_TRUE(new_const);
1660 ASSERT_EQ(new_const->get_friendly_name(), "test");
1661 auto values_out = new_const->get_vector<char>();
1663 vector<char> values_expected{0, 1, 1, 0, 1, 1};
1665 ASSERT_EQ(values_expected, values_out);
1668 TEST(constant_folding, const_less)
1671 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1673 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1674 auto eq = make_shared<op::Less>(constant0, constant1);
1675 eq->set_friendly_name("test");
1676 auto f = make_shared<Function>(eq, ParameterVector{});
1678 pass::Manager pass_manager;
1679 pass_manager.register_pass<pass::ConstantFolding>();
1680 pass_manager.run_passes(f);
1682 ASSERT_EQ(count_ops_of_type<op::Less>(f), 0);
1683 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1686 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1687 ASSERT_TRUE(new_const);
1688 ASSERT_EQ(new_const->get_friendly_name(), "test");
1689 auto values_out = new_const->get_vector<char>();
1691 vector<char> values_expected{1, 0, 0, 1, 0, 0};
1693 ASSERT_EQ(values_expected, values_out);
1696 TEST(constant_folding, const_less_eq)
1699 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1701 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1702 auto eq = make_shared<op::LessEq>(constant0, constant1);
1703 eq->set_friendly_name("test");
1704 auto f = make_shared<Function>(eq, ParameterVector{});
1706 pass::Manager pass_manager;
1707 pass_manager.register_pass<pass::ConstantFolding>();
1708 pass_manager.run_passes(f);
1710 ASSERT_EQ(count_ops_of_type<op::LessEq>(f), 0);
1711 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1714 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1715 ASSERT_TRUE(new_const);
1716 ASSERT_EQ(new_const->get_friendly_name(), "test");
1717 auto values_out = new_const->get_vector<char>();
1719 vector<char> values_expected{1, 1, 0, 1, 1, 0};
1721 ASSERT_EQ(values_expected, values_out);
1724 TEST(constant_folding, const_or)
1727 op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 0, 1, 0, 1, 1});
1729 op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 1, 1, 1, 0, 1});
1730 auto eq = make_shared<op::Or>(constant0, constant1);
1731 eq->set_friendly_name("test");
1732 auto f = make_shared<Function>(eq, ParameterVector{});
1734 pass::Manager pass_manager;
1735 pass_manager.register_pass<pass::ConstantFolding>();
1736 pass_manager.run_passes(f);
1738 ASSERT_EQ(count_ops_of_type<op::Or>(f), 0);
1739 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1742 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1743 ASSERT_TRUE(new_const);
1744 ASSERT_EQ(new_const->get_friendly_name(), "test");
1745 auto values_out = new_const->get_vector<char>();
1747 vector<char> values_expected{0, 1, 1, 1, 1, 1};
1749 ASSERT_EQ(values_expected, values_out);
1752 TEST(constant_folding, const_xor)
1755 op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 0, 1, 0, 1, 1});
1757 op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 1, 1, 1, 0, 1});
1758 auto eq = make_shared<op::Xor>(constant0, constant1);
1759 eq->set_friendly_name("test");
1760 auto f = make_shared<Function>(eq, ParameterVector{});
1762 pass::Manager pass_manager;
1763 pass_manager.register_pass<pass::ConstantFolding>();
1764 pass_manager.run_passes(f);
1766 ASSERT_EQ(count_ops_of_type<op::Xor>(f), 0);
1767 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1770 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1771 ASSERT_TRUE(new_const);
1772 ASSERT_EQ(new_const->get_friendly_name(), "test");
1773 auto values_out = new_const->get_vector<char>();
1775 vector<char> values_expected{0, 1, 0, 1, 1, 0};
1777 ASSERT_EQ(values_expected, values_out);
1780 TEST(constant_folding, const_ceiling)
1782 auto constant = op::Constant::create(
1783 element::f32, Shape{2, 3}, vector<float>{0.0f, 0.1f, -0.1f, -2.5f, 2.5f, 3.0f});
1784 auto ceil = make_shared<op::Ceiling>(constant);
1785 ceil->set_friendly_name("test");
1786 auto f = make_shared<Function>(ceil, ParameterVector{});
1788 pass::Manager pass_manager;
1789 pass_manager.register_pass<pass::ConstantFolding>();
1790 pass_manager.run_passes(f);
1792 ASSERT_EQ(count_ops_of_type<op::Ceiling>(f), 0);
1793 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1796 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1797 ASSERT_TRUE(new_const);
1798 ASSERT_EQ(new_const->get_friendly_name(), "test");
1799 auto values_out = new_const->get_vector<float>();
1801 vector<float> values_expected{0.0f, 1.0f, 0.0f, -2.0f, 3.0f, 3.0f};
1803 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1806 TEST(constant_folding, const_floor)
1808 auto constant = op::Constant::create(
1809 element::f32, Shape{2, 3}, vector<float>{0.0f, 0.1f, -0.1f, -2.5f, 2.5f, 3.0f});
1810 auto floor = make_shared<op::Floor>(constant);
1811 floor->set_friendly_name("test");
1812 auto f = make_shared<Function>(floor, ParameterVector{});
1814 pass::Manager pass_manager;
1815 pass_manager.register_pass<pass::ConstantFolding>();
1816 pass_manager.run_passes(f);
1818 ASSERT_EQ(count_ops_of_type<op::Floor>(f), 0);
1819 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1822 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1823 ASSERT_TRUE(new_const);
1824 ASSERT_EQ(new_const->get_friendly_name(), "test");
1825 auto values_out = new_const->get_vector<float>();
1827 vector<float> values_expected{0.0f, 0.0f, -1.0f, -3.0f, 2.0f, 3.0f};
1829 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1832 TEST(constant_folding, const_gather)
1834 auto constant_data = op::Constant::create(
1837 vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
1838 auto constant_indices =
1839 op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
1840 size_t gather_axis = 1;
1841 auto gather = make_shared<op::v0::Gather>(constant_data, constant_indices, gather_axis);
1842 gather->set_friendly_name("test");
1843 auto f = make_shared<Function>(gather, ParameterVector{});
1845 pass::Manager pass_manager;
1846 pass_manager.register_pass<pass::ConstantFolding>();
1847 pass_manager.run_passes(f);
1849 ASSERT_EQ(count_ops_of_type<op::v0::Gather>(f), 0);
1850 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1853 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1854 ASSERT_TRUE(new_const);
1855 ASSERT_EQ(new_const->get_friendly_name(), "test");
1856 auto values_out = new_const->get_vector<float>();
1858 vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
1860 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1863 TEST(constant_folding, const_gather_v1)
1865 auto constant_data = op::Constant::create(
1868 vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
1869 auto constant_indices =
1870 op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
1871 auto constant_axis = op::Constant::create(element::i64, Shape{1}, vector<int64_t>{1});
1872 auto gather = make_shared<op::v1::Gather>(constant_data, constant_indices, constant_axis);
1873 gather->set_friendly_name("test");
1874 auto f = make_shared<Function>(gather, ParameterVector{});
1876 pass::Manager pass_manager;
1877 pass_manager.register_pass<pass::ConstantFolding>();
1878 pass_manager.run_passes(f);
1880 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1881 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1884 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1885 ASSERT_TRUE(new_const);
1886 ASSERT_EQ(new_const->get_friendly_name(), "test");
1887 auto values_out = new_const->get_vector<float>();
1889 vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
1891 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1894 TEST(constant_folding, const_gather_v1_scalar)
1896 auto constant_data = op::Constant::create(
1899 vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
1900 auto constant_indices =
1901 op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
1902 auto constant_axis = op::Constant::create(element::i64, Shape{}, vector<int64_t>{1});
1903 auto gather = make_shared<op::v1::Gather>(constant_data, constant_indices, constant_axis);
1904 gather->set_friendly_name("test");
1905 auto f = make_shared<Function>(gather, ParameterVector{});
1907 pass::Manager pass_manager;
1908 pass_manager.register_pass<pass::ConstantFolding>();
1909 pass_manager.run_passes(f);
1911 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1912 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1915 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1916 ASSERT_TRUE(new_const);
1917 ASSERT_EQ(new_const->get_friendly_name(), "test");
1918 auto values_out = new_const->get_vector<float>();
1920 vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
1922 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1925 TEST(constant_folding, const_gather_v1_subgraph)
1927 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1928 const float b_value = 3.21f;
1929 const auto B_const = op::Constant::create(element::f32, {1}, {b_value});
1930 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
1931 const int64_t axis = 0;
1932 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1934 const auto concat = make_shared<op::Concat>(NodeVector{A, B_const, C}, axis);
1936 const vector<int64_t> indices{1};
1937 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1938 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1939 gather->set_friendly_name("test");
1940 auto f = make_shared<Function>(gather, ParameterVector{A, C});
1942 pass::Manager pass_manager;
1943 pass_manager.register_pass<pass::ConstantFolding>();
1944 pass_manager.run_passes(f);
1946 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1947 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1948 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1950 const auto new_const =
1951 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1952 ASSERT_TRUE(new_const);
1953 ASSERT_EQ(new_const->get_friendly_name(), "test");
1955 const auto values_out = new_const->get_vector<float>();
1956 ASSERT_TRUE(test::all_close_f(values_out, {b_value}, MIN_FLOAT_TOLERANCE_BITS));
1959 TEST(constant_folding, const_gather_v1_subgraph_neg_axis)
1961 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1962 const float b_value = 1.23f;
1963 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
1964 const auto C_const = op::Constant::create(element::f32, {1}, {b_value});
1965 const int64_t axis = 0;
1966 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1968 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C_const}, axis);
1970 const vector<int64_t> indices{-1};
1971 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1972 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1973 gather->set_friendly_name("test");
1974 auto f = make_shared<Function>(gather, ParameterVector{A, B});
1976 pass::Manager pass_manager;
1977 pass_manager.register_pass<pass::ConstantFolding>();
1978 pass_manager.run_passes(f);
1980 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1981 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1982 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1984 const auto new_const =
1985 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1986 ASSERT_TRUE(new_const);
1987 ASSERT_EQ(new_const->get_friendly_name(), "test");
1989 const auto values_out = new_const->get_vector<float>();
1990 ASSERT_TRUE(test::all_close_f(values_out, {b_value}, MIN_FLOAT_TOLERANCE_BITS));
1993 TEST(constant_folding, const_gather_v1_subgraph_no_constant_input)
1995 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1996 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
1997 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
1998 const int64_t axis = 0;
1999 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2001 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
2003 const vector<int64_t> indices{1};
2004 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
2005 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2006 gather->set_friendly_name("test");
2007 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2009 pass::Manager pass_manager;
2010 pass_manager.register_pass<pass::ConstantFolding>();
2011 pass_manager.run_passes(f);
2013 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
2014 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
2017 TEST(constant_folding, const_gather_v1_subgraph_no_constant_input_scalar)
2019 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
2020 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
2021 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
2022 const int64_t axis = 0;
2023 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2025 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
2027 const vector<int64_t> indices{1};
2028 const auto indices_const = op::Constant::create(element::i64, {}, indices);
2029 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2030 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2032 pass::Manager pass_manager;
2033 pass_manager.register_pass<pass::ConstantFolding>();
2034 pass_manager.run_passes(f);
2036 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
2037 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
2038 ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(f), 1);
2041 TEST(constant_folding, const_gather_v1_subgraph_skip_if_non_zero_axis)
2043 const auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2});
2044 const auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2});
2045 const auto C = make_shared<op::Parameter>(element::f32, Shape{2, 2});
2046 const int64_t axis = 1;
2047 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2049 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
2051 const vector<int64_t> indices{1};
2052 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
2053 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2054 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2056 pass::Manager pass_manager;
2057 pass_manager.register_pass<pass::ConstantFolding>();
2058 pass_manager.run_passes(f);
2060 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
2061 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
2064 TEST(constant_folding, const_gather_v1_subgraph_skip_if_non_single_indices)
2066 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
2067 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
2068 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
2069 const int64_t axis = 0;
2070 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2072 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
2074 const vector<int64_t> indices{0, 1};
2075 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
2076 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2077 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2079 pass::Manager pass_manager;
2080 pass_manager.register_pass<pass::ConstantFolding>();
2081 pass_manager.run_passes(f);
2083 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
2084 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
2087 TEST(constant_folding, const_gather_v1_subgraph_skip_if_concat_output_shape_dynamic)
2089 const auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
2090 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
2091 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
2092 const int64_t axis = 0;
2093 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2095 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
2097 const vector<int64_t> indices{1};
2098 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
2099 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2100 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2102 pass::Manager pass_manager;
2103 pass_manager.register_pass<pass::ConstantFolding>();
2104 pass_manager.run_passes(f);
2106 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
2107 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
2110 TEST(constant_folding, const_gather_v1_subgraph_skip_if_not_single_input)
2112 const auto A = make_shared<op::Parameter>(element::f32, Shape{2});
2113 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
2114 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
2115 const int64_t axis = 0;
2116 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2118 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
2120 const vector<int64_t> indices{1};
2121 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
2122 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2123 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2125 pass::Manager pass_manager;
2126 pass_manager.register_pass<pass::ConstantFolding>();
2127 pass_manager.run_passes(f);
2129 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
2130 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
2133 TEST(constant_folding, const_slice)
2137 vector<int> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
2138 auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
2139 auto slice = make_shared<op::Slice>(constant, Coordinate{2}, Coordinate{15}, Strides{3});
2140 slice->set_friendly_name("test");
2142 auto f = make_shared<Function>(slice, ParameterVector{});
2144 pass::Manager pass_manager;
2145 pass_manager.register_pass<pass::ConstantFolding>();
2146 pass_manager.run_passes(f);
2148 ASSERT_EQ(count_ops_of_type<op::Slice>(f), 0);
2149 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2152 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2153 ASSERT_TRUE(new_const);
2154 ASSERT_EQ(new_const->get_friendly_name(), "test");
2155 auto values_out = new_const->get_vector<int>();
2157 vector<int> sliced_values{3, 6, 9, 12, 15};
2158 ASSERT_EQ(sliced_values, values_out);
2161 TEST(constant_folding, constant_dyn_reshape)
2163 Shape shape_in{2, 4};
2164 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
2166 Shape shape_shape{3};
2167 vector<int64_t> values_shape{2, 4, 1};
2169 auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
2170 auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
2171 auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape, false);
2172 dyn_reshape->set_friendly_name("test");
2173 auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
2175 pass::Manager pass_manager;
2176 pass_manager.register_pass<pass::ConstantFolding>();
2177 pass_manager.run_passes(f);
2179 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
2180 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2183 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2184 ASSERT_TRUE(new_const);
2185 ASSERT_EQ(new_const->get_friendly_name(), "test");
2186 auto values_out = new_const->get_vector<float>();
2188 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
2191 TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
2193 Shape shape_in{2, 4};
2194 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
2196 Shape shape_shape{3};
2197 // We're going to add these two together elementwise to get {2, 4, 1}.
2198 // This means that when ConstantFolding starts, v1::Reshape will not yet
2199 // have static output shape. But by the time the Add op is folded, the
2200 // v1::Reshape's shape should be inferrable.
2201 vector<int64_t> values_shape_a{1, 3, 0};
2202 vector<int64_t> values_shape_b{1, 1, 1};
2204 auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
2205 auto constant_shape_a = make_shared<op::Constant>(element::i64, shape_shape, values_shape_a);
2206 auto constant_shape_b = make_shared<op::Constant>(element::i64, shape_shape, values_shape_b);
2208 make_shared<op::v1::Reshape>(constant_in, constant_shape_a + constant_shape_b, false);
2209 dyn_reshape->set_friendly_name("test");
2210 auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
2212 ASSERT_TRUE(dyn_reshape->get_output_partial_shape(0).is_dynamic());
2214 pass::Manager pass_manager;
2215 pass_manager.register_pass<pass::ConstantFolding>();
2216 pass_manager.run_passes(f);
2218 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
2219 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2222 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2223 ASSERT_TRUE(new_const);
2224 ASSERT_EQ(new_const->get_friendly_name(), "test");
2225 auto values_out = new_const->get_vector<float>();
2227 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
2230 TEST(constant_folding, constant_transpose)
2232 Shape shape_in{2, 4};
2233 vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};
2235 Shape shape_perm{2};
2236 vector<int64_t> values_perm{1, 0};
2238 auto constant_in = make_shared<op::Constant>(element::f64, shape_in, values_in);
2239 auto constant_perm = make_shared<op::Constant>(element::i64, shape_perm, values_perm);
2240 auto transpose = make_shared<op::Transpose>(constant_in, constant_perm);
2241 transpose->set_friendly_name("test");
2242 auto f = make_shared<Function>(transpose, ParameterVector{});
2244 pass::Manager pass_manager;
2245 pass_manager.register_pass<pass::ConstantFolding>();
2246 pass_manager.run_passes(f);
2248 ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 0);
2249 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2252 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2253 ASSERT_TRUE(new_const);
2254 ASSERT_EQ(new_const->get_friendly_name(), "test");
2255 auto values_out = new_const->get_vector<double>();
2257 vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
2258 ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
2261 template <typename T>
2262 void range_test(T start, T stop, T step, const vector<T>& values_expected)
2264 vector<T> values_start{start};
2265 vector<T> values_stop{stop};
2266 vector<T> values_step{step};
2268 auto constant_start = make_shared<op::Constant>(element::from<T>(), Shape{}, values_start);
2269 auto constant_stop = make_shared<op::Constant>(element::from<T>(), Shape{}, values_stop);
2270 auto constant_step = make_shared<op::Constant>(element::from<T>(), Shape{}, values_step);
2271 auto range = make_shared<op::Range>(constant_start, constant_stop, constant_step);
2272 range->set_friendly_name("test");
2273 auto f = make_shared<Function>(range, ParameterVector{});
2275 pass::Manager pass_manager;
2276 pass_manager.register_pass<pass::ConstantFolding>();
2277 pass_manager.run_passes(f);
2279 ASSERT_EQ(count_ops_of_type<op::Range>(f), 0);
2280 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2283 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2284 ASSERT_TRUE(new_const);
2285 ASSERT_EQ(new_const->get_friendly_name(), "test");
2287 auto values_out = new_const->template get_vector<T>();
2289 range_test_check(values_out, values_expected);
2292 TEST(constant_folding, constant_range)
2294 range_test<int8_t>(5, 12, 2, {5, 7, 9, 11});
2295 range_test<int32_t>(5, 12, 2, {5, 7, 9, 11});
2296 range_test<int64_t>(5, 12, 2, {5, 7, 9, 11});
2297 range_test<uint64_t>(5, 12, 2, {5, 7, 9, 11});
2298 range_test<double>(5, 12, 2, {5, 7, 9, 11});
2299 range_test<float>(5, 12, 2, {5, 7, 9, 11});
2301 range_test<int32_t>(5, 12, -2, {});
2302 range_test<float>(12, 4, -2, {12, 10, 8, 6});
2305 TEST(constant_folding, constant_select)
2308 vector<char> values_selection{0, 1, 1, 0, 1, 0, 0, 1};
2309 vector<int64_t> values_t{2, 4, 6, 8, 10, 12, 14, 16};
2310 vector<int64_t> values_f{1, 3, 5, 7, 9, 11, 13, 15};
2312 auto constant_selection = make_shared<op::Constant>(element::boolean, shape, values_selection);
2313 auto constant_t = make_shared<op::Constant>(element::i64, shape, values_t);
2314 auto constant_f = make_shared<op::Constant>(element::i64, shape, values_f);
2315 auto select = make_shared<op::Select>(constant_selection, constant_t, constant_f);
2316 select->set_friendly_name("test");
2317 auto f = make_shared<Function>(select, ParameterVector{});
2319 pass::Manager pass_manager;
2320 pass_manager.register_pass<pass::ConstantFolding>();
2321 pass_manager.run_passes(f);
2323 ASSERT_EQ(count_ops_of_type<op::Select>(f), 0);
2324 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2327 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2328 ASSERT_TRUE(new_const);
2329 ASSERT_EQ(new_const->get_friendly_name(), "test");
2330 auto values_out = new_const->get_vector<int64_t>();
2332 vector<int64_t> values_expected{1, 4, 6, 7, 10, 11, 13, 16};
2333 ASSERT_EQ(values_expected, values_out);
2336 TEST(constant_folding, constant_v1_select)
2339 vector<char> values_selection{0, 1, 1, 0};
2340 vector<int64_t> values_t{1, 2, 3, 4};
2341 vector<int64_t> values_f{11, 12, 13, 14, 15, 16, 17, 18};
2343 auto constant_selection =
2344 make_shared<op::Constant>(element::boolean, Shape{4}, values_selection);
2345 auto constant_t = make_shared<op::Constant>(element::i64, Shape{4}, values_t);
2346 auto constant_f = make_shared<op::Constant>(element::i64, Shape{2, 4}, values_f);
2347 auto select = make_shared<op::v1::Select>(constant_selection, constant_t, constant_f);
2348 select->set_friendly_name("test");
2349 auto f = make_shared<Function>(select, ParameterVector{});
2351 pass::Manager pass_manager;
2352 pass_manager.register_pass<pass::ConstantFolding>();
2353 pass_manager.run_passes(f);
2355 ASSERT_EQ(count_ops_of_type<op::Select>(f), 0);
2356 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2359 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2360 ASSERT_TRUE(new_const);
2361 ASSERT_EQ(new_const->get_friendly_name(), "test");
2362 auto values_out = new_const->get_vector<int64_t>();
2364 vector<int64_t> values_expected{11, 2, 3, 14, 15, 2, 3, 18};
2365 ASSERT_EQ(values_expected, values_out);
2368 TEST(constant_folding, constant_v1_split)
2370 vector<float> data{.1f, .2f, .3f, .4f, .5f, .6f};
2371 const auto const_data = op::Constant::create(element::f32, Shape{data.size()}, data);
2372 const auto const_axis = op::Constant::create(element::i64, Shape{}, {0});
2373 const auto num_splits = 3;
2375 auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
2376 auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2378 pass::Manager pass_manager;
2379 pass_manager.register_pass<pass::ConstantFolding>();
2380 pass_manager.run_passes(f);
2382 ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2383 ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
2386 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2388 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2390 as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
2395 auto res1_values = res1->get_vector<float>();
2396 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin(), data.begin() + 2), res1_values));
2397 auto res2_values = res2->get_vector<float>();
2398 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 2, data.begin() + 4), res2_values));
2399 auto res3_values = res3->get_vector<float>();
2400 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 4, data.end()), res3_values));
2403 TEST(constant_folding, constant_v1_split_specialized)
2405 vector<float> data{.1f, .2f, .3f, .4f, .5f, .6f};
2406 const auto const_data = op::Constant::create(element::f32, Shape{data.size()}, data);
2407 const auto const_axis = op::Constant::create(element::i64, Shape{}, {0});
2408 const auto num_splits = 3;
2410 auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
2411 auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2413 pass::Manager pass_manager;
2414 pass_manager.register_pass<pass::ConstantFolding>();
2415 pass_manager.run_passes(f);
2417 ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2418 ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
2421 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2423 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2425 as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
2430 auto res1_values = res1->get_vector<float>();
2431 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin(), data.begin() + 2), res1_values));
2432 auto res2_values = res2->get_vector<float>();
2433 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 2, data.begin() + 4), res2_values));
2434 auto res3_values = res3->get_vector<float>();
2435 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 4, data.end()), res3_values));
2438 TEST(constant_folding, constant_v1_split_axis_1_4_splits)
2440 vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
2442 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2444 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2446 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2448 const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2449 const auto const_axis = op::Constant::create(element::i64, Shape{}, {1});
2450 const auto num_splits = 4;
2452 auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
2453 split_v1->set_friendly_name("test");
2454 auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2456 pass::Manager pass_manager;
2457 pass_manager.register_pass<pass::ConstantFolding>();
2458 pass_manager.run_passes(f);
2460 ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2461 ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
2464 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2466 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2468 as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
2470 as_type_ptr<op::Constant>(f->get_results().at(3)->input_value(0).get_node_shared_ptr());
2472 ASSERT_EQ(res1->get_friendly_name(), "test.0");
2474 ASSERT_EQ(res2->get_friendly_name(), "test.1");
2476 ASSERT_EQ(res3->get_friendly_name(), "test.2");
2478 ASSERT_EQ(res4->get_friendly_name(), "test.3");
2480 auto res1_values = res1->get_vector<int64_t>();
2481 ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51}),
2483 auto res2_values = res2->get_vector<int64_t>();
2484 ASSERT_EQ(vector<int64_t>({4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55}),
2486 auto res3_values = res3->get_vector<int64_t>();
2487 ASSERT_EQ(vector<int64_t>({8, 9, 10, 11, 24, 25, 26, 27, 40, 41, 42, 43, 56, 57, 58, 59}),
2489 auto res4_values = res4->get_vector<int64_t>();
2490 ASSERT_EQ(vector<int64_t>({12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63}),
2494 TEST(constant_folding, constant_v1_split_axis_1_2_splits)
2496 vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
2498 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2500 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2502 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2504 const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2505 const auto const_axis = op::Constant::create(element::i64, Shape{}, {1});
2506 const auto num_splits = 2;
2508 auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
2509 auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2511 pass::Manager pass_manager;
2512 pass_manager.register_pass<pass::ConstantFolding>();
2513 pass_manager.run_passes(f);
2515 ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2516 ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
2519 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2521 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2525 auto res1_values = res1->get_vector<int64_t>();
2526 ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
2527 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55}),
2529 auto res2_values = res2->get_vector<int64_t>();
2530 ASSERT_EQ(vector<int64_t>({8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31,
2531 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63}),
2535 TEST(constant_folding, constant_v1_variadic_split_axis_1_2_splits)
2537 vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
2539 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2541 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2543 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2545 const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2546 const auto const_axis = op::Constant::create(element::i16, Shape{}, {1});
2547 vector<int64_t> values_lengths{3, 1};
2548 auto constant_lengths =
2549 make_shared<op::Constant>(element::i64, Shape{values_lengths.size()}, values_lengths);
2551 auto variadic_split_v1 =
2552 make_shared<op::v1::VariadicSplit>(const_data, const_axis, constant_lengths);
2553 auto f = make_shared<Function>(variadic_split_v1->outputs(), ParameterVector{});
2555 pass::Manager pass_manager;
2556 pass_manager.register_pass<pass::ConstantFolding>();
2557 pass_manager.run_passes(f);
2559 ASSERT_EQ(count_ops_of_type<op::v1::VariadicSplit>(f), 0);
2560 ASSERT_EQ(count_ops_of_type<op::Constant>(f), values_lengths.size());
2563 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2565 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2569 auto res1_values = res1->get_vector<int64_t>();
2570 ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19,
2571 20, 21, 22, 23, 24, 25, 26, 27, 32, 33, 34, 35, 36, 37, 38, 39,
2572 40, 41, 42, 43, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59}),
2574 auto res2_values = res2->get_vector<int64_t>();
2575 ASSERT_EQ(vector<int64_t>({12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63}),
2579 TEST(constant_folding, constant_v1_variadic_split_axis_1_3_splits_neg_length)
2581 vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
2583 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2585 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2587 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2589 const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2590 const auto const_axis = op::Constant::create(element::i32, Shape{}, {1});
2591 vector<int64_t> values_lengths{1, 1, -1};
2592 auto constant_lengths =
2593 make_shared<op::Constant>(element::i64, Shape{values_lengths.size()}, values_lengths);
2595 auto variadic_split_v1 =
2596 make_shared<op::v1::VariadicSplit>(const_data, const_axis, constant_lengths);
2597 auto f = make_shared<Function>(variadic_split_v1->outputs(), ParameterVector{});
2599 pass::Manager pass_manager;
2600 pass_manager.register_pass<pass::ConstantFolding>();
2601 pass_manager.run_passes(f);
2603 ASSERT_EQ(count_ops_of_type<op::v1::VariadicSplit>(f), 0);
2604 ASSERT_EQ(count_ops_of_type<op::Constant>(f), values_lengths.size());
2607 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2609 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2611 as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
2616 auto res1_values = res1->get_vector<int64_t>();
2617 ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51}),
2619 auto res2_values = res2->get_vector<int64_t>();
2620 ASSERT_EQ(vector<int64_t>({4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55}),
2622 auto res3_values = res3->get_vector<int64_t>();
2623 ASSERT_EQ(vector<int64_t>({8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31,
2624 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63}),
// Constant-folds a v1::OneHot built from Constant inputs: i64 indices {0, 1, 2}, i64 depth 3,
// and f16 on/off values 1.123 and 0.321.
// NOTE(review): `axis`, `one_hot_v1` and `res` are referenced below, but the lines declaring
// them (and the expected values after the first `on_value`) are not present in this copy of
// the file - confirm against the upstream test.
2628 TEST(constant_folding, constant_v1_one_hot)
2630 vector<int64_t> indices{0, 1, 2};
2631 float16 on_value = 1.123f;
2632 float16 off_value = 0.321f;
// All four OneHot inputs are created as Constants.
2634 const auto indices_const = op::Constant::create(element::i64, Shape{3}, indices);
2635 const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
2636 const auto on_const = op::Constant::create(element::f16, Shape{}, {on_value});
2637 const auto off_const = op::Constant::create(element::f16, Shape{}, {off_value});
2641 make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
2642 auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
// Run only the ConstantFolding pass over the function.
2644 pass::Manager pass_manager;
2645 pass_manager.register_pass<pass::ConstantFolding>();
2646 pass_manager.run_passes(f);
// Folding must remove the OneHot and leave exactly one Constant.
2648 ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
2649 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2652 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
// The folded result must be 3x3 and its f16 values must match the expected vector.
2655 ASSERT_EQ((Shape{3, 3}), res->get_output_shape(0));
2656 ASSERT_EQ(vector<float16>({on_value,
2665 res->get_vector<float16>());
// Constant-folds a v1::OneHot with i16 on/off values (4 / 1) over i64 indices {0, 2, -1, 1}
// and depth 3, then checks the folded Constant has shape {4, 3}.
// NOTE(review): the indices include -1; the expected vector (mostly not visible here) is what
// pins down how that entry is encoded. `axis`, `one_hot_v1` and `res` are used below but their
// declarations are not present in this copy - confirm against the upstream test.
2668 TEST(constant_folding, constant_v1_one_hot_negative_axes)
2670 vector<int64_t> indices{0, 2, -1, 1};
2671 int16_t on_value = 4;
2672 int16_t off_value = 1;
2674 const auto indices_const = op::Constant::create(element::i64, Shape{4}, indices);
2675 const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
2676 const auto on_const = op::Constant::create(element::i16, Shape{}, {on_value});
2677 const auto off_const = op::Constant::create(element::i16, Shape{}, {off_value});
2681 make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
2682 auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
// Run only the ConstantFolding pass over the function.
2684 pass::Manager pass_manager;
2685 pass_manager.register_pass<pass::ConstantFolding>();
2686 pass_manager.run_passes(f);
// Folding must remove the OneHot and leave exactly one Constant.
2688 ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
2689 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2692 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
// One row of `depth` (3) values per index.
2695 ASSERT_EQ((Shape{4, 3}), res->get_output_shape(0));
2696 ASSERT_EQ(vector<int16_t>({on_value,
2708 res->get_vector<int16_t>());
// Constant-folds a v1::OneHot over 2x2 i64 indices {0, 2, 1, -1} with depth 3 and boolean
// on/off values. Checks that the friendly name "test" carries over to the folded Constant and
// that the result has shape {2, 2, 3}.
// NOTE(review): `axis`, `one_hot_v1` and `res` are used below, but their declarations and most
// of the expected values are not present in this copy - confirm against the upstream test.
2711 TEST(constant_folding, constant_v1_one_hot_negative_axes_2)
2713 vector<int64_t> indices{0, 2, 1, -1};
2714 auto on_value = true;
2715 auto off_value = false;
2717 const auto indices_const = op::Constant::create(element::i64, Shape{2, 2}, indices);
2718 const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
2719 const auto on_const = op::Constant::create(element::boolean, Shape{}, {on_value});
2720 const auto off_const = op::Constant::create(element::boolean, Shape{}, {off_value});
2724 make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
// The name is checked on the folded Constant further down.
2725 one_hot_v1->set_friendly_name("test");
2726 auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
// Run only the ConstantFolding pass over the function.
2728 pass::Manager pass_manager;
2729 pass_manager.register_pass<pass::ConstantFolding>();
2730 pass_manager.run_passes(f);
// Folding must remove the OneHot and leave exactly one Constant.
2732 ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
2733 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2736 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2738 ASSERT_EQ(res->get_friendly_name(), "test");
// The index shape {2, 2} gains a trailing dimension of size `depth` (3).
2740 ASSERT_EQ((Shape{2, 2, 3}), res->get_output_shape(0));
2741 ASSERT_EQ(vector<bool>({on_value,
2753 res->get_vector<bool>());
2756 TEST(constant_folding, constant_tile_1d)
2759 Shape shape_repeats{1};
2762 vector<int> values_in{0, 1};
2763 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2764 vector<int> values_repeats{2};
2765 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2766 auto tile = make_shared<op::v0::Tile>(data, repeats);
2767 tile->set_friendly_name("test");
2768 auto f = make_shared<Function>(tile, ParameterVector{});
2770 pass::Manager pass_manager;
2771 pass_manager.register_pass<pass::ConstantFolding>();
2772 pass_manager.run_passes(f);
2774 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2775 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2778 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2779 ASSERT_TRUE(new_const);
2780 ASSERT_EQ(new_const->get_friendly_name(), "test");
2781 auto values_out = new_const->get_vector<int>();
2783 vector<int> values_expected{0, 1, 0, 1};
2784 ASSERT_EQ(values_expected, values_out);
2787 TEST(constant_folding, constant_tile_3d_small_data_rank)
2790 Shape shape_repeats{3};
2791 Shape shape_out{2, 2, 4};
2793 vector<int> values_in{0, 1};
2794 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2795 vector<int> values_repeats{2, 2, 2};
2796 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2797 auto tile = make_shared<op::v0::Tile>(data, repeats);
2798 tile->set_friendly_name("test");
2799 auto f = make_shared<Function>(tile, ParameterVector{});
2801 pass::Manager pass_manager;
2802 pass_manager.register_pass<pass::ConstantFolding>();
2803 pass_manager.run_passes(f);
2805 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2806 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2809 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2810 ASSERT_TRUE(new_const);
2811 ASSERT_EQ(new_const->get_friendly_name(), "test");
2812 auto values_out = new_const->get_vector<int>();
2814 vector<int> values_expected{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1};
2815 ASSERT_EQ(values_expected, values_out);
2818 TEST(constant_folding, constant_tile_3d_few_repeats)
2820 Shape shape_in{2, 1, 3};
2821 Shape shape_repeats{2};
2822 Shape shape_out{2, 2, 3};
2824 vector<int> values_in{1, 2, 3, 4, 5, 6};
2825 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2826 vector<int> values_repeats{2, 1};
2827 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2828 auto tile = make_shared<op::v0::Tile>(data, repeats);
2829 tile->set_friendly_name("test");
2830 auto f = make_shared<Function>(tile, ParameterVector{});
2832 pass::Manager pass_manager;
2833 pass_manager.register_pass<pass::ConstantFolding>();
2834 pass_manager.run_passes(f);
2836 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2837 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2840 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2841 ASSERT_TRUE(new_const);
2842 ASSERT_EQ(new_const->get_friendly_name(), "test");
2843 auto values_out = new_const->get_vector<int>();
2845 vector<int> values_expected{1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6};
2846 ASSERT_EQ(values_expected, values_out);
2849 TEST(constant_folding, constant_tile_1d_0_repeats)
2852 Shape shape_repeats{1};
2855 vector<int> values_in{0, 1};
2856 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2857 vector<int> values_repeats{0};
2858 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2859 auto tile = make_shared<op::v0::Tile>(data, repeats);
2860 tile->set_friendly_name("test");
2861 auto f = make_shared<Function>(tile, ParameterVector{});
2863 pass::Manager pass_manager;
2864 pass_manager.register_pass<pass::ConstantFolding>();
2865 pass_manager.run_passes(f);
2867 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2868 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2871 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2872 ASSERT_TRUE(new_const);
2873 ASSERT_EQ(new_const->get_friendly_name(), "test");
2874 auto values_out = new_const->get_vector<int>();
2876 vector<int> values_expected{};
2877 ASSERT_EQ(values_expected, values_out);
2880 TEST(constant_folding, constant_tile_0_rank_data)
2883 Shape shape_repeats{1};
2886 vector<int> values_in{1};
2887 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2888 vector<int> values_repeats{4};
2889 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2890 auto tile = make_shared<op::v0::Tile>(data, repeats);
2891 tile->set_friendly_name("test");
2892 auto f = make_shared<Function>(tile, ParameterVector{});
2894 pass::Manager pass_manager;
2895 pass_manager.register_pass<pass::ConstantFolding>();
2896 pass_manager.run_passes(f);
2898 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2899 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2902 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2903 ASSERT_TRUE(new_const);
2904 ASSERT_EQ(new_const->get_friendly_name(), "test");
2905 auto values_out = new_const->get_vector<int>();
2907 vector<int> values_expected{1, 1, 1, 1};
2908 ASSERT_EQ(values_expected, values_out);
2911 TEST(constant_folding, constant_non_zero_0D)
2913 auto data = op::Constant::create(element::i32, Shape{}, {1});
2914 auto non_zero = make_shared<op::v3::NonZero>(data);
2915 non_zero->set_friendly_name("test");
2916 auto f = make_shared<Function>(non_zero, ParameterVector{});
2918 pass::Manager pass_manager;
2919 pass_manager.register_pass<pass::ConstantFolding>();
2920 pass_manager.run_passes(f);
2922 // Fold into constant with shape of {1, 1} for scalar input with
2924 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2925 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2927 const auto new_const =
2928 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2929 ASSERT_TRUE(new_const);
2930 ASSERT_EQ(new_const->get_friendly_name(), "test");
2931 const auto values_out = new_const->get_vector<int64_t>();
2933 const vector<int64_t> values_expected{0};
2934 ASSERT_EQ(values_expected, values_out);
2935 ASSERT_EQ((Shape{1, 1}), new_const->get_shape());
2938 TEST(constant_folding, constant_non_zero_1D)
2940 vector<int> values_in{0, 1, 0, 1};
2941 auto data = make_shared<op::Constant>(element::i32, Shape{4}, values_in);
2942 auto non_zero = make_shared<op::v3::NonZero>(data);
2943 non_zero->set_friendly_name("test");
2944 auto f = make_shared<Function>(non_zero, ParameterVector{});
2946 pass::Manager pass_manager;
2947 pass_manager.register_pass<pass::ConstantFolding>();
2948 pass_manager.run_passes(f);
2950 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2951 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2953 const auto new_const =
2954 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2955 ASSERT_TRUE(new_const);
2956 ASSERT_EQ(new_const->get_friendly_name(), "test");
2957 const auto values_out = new_const->get_vector<int64_t>();
2959 const vector<int64_t> values_expected{1, 3};
2960 ASSERT_EQ(values_expected, values_out);
2961 ASSERT_EQ((Shape{1, 2}), new_const->get_shape());
2964 TEST(constant_folding, constant_non_zero_int32_output_type)
2966 vector<int> values_in{0, 1, 0, 1};
2967 auto data = make_shared<op::Constant>(element::i32, Shape{4}, values_in);
2968 auto non_zero = make_shared<op::v3::NonZero>(data, element::i32);
2969 non_zero->set_friendly_name("test");
2970 auto f = make_shared<Function>(non_zero, ParameterVector{});
2972 pass::Manager pass_manager;
2973 pass_manager.register_pass<pass::ConstantFolding>();
2974 pass_manager.run_passes(f);
2976 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2977 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2979 const auto new_const =
2980 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2981 ASSERT_TRUE(new_const);
2982 ASSERT_EQ(new_const->get_friendly_name(), "test");
2983 ASSERT_EQ(element::i32, new_const->get_element_type());
2984 const auto values_out = new_const->get_vector<int32_t>();
2986 const vector<int32_t> values_expected{1, 3};
2987 ASSERT_EQ(values_expected, values_out);
2988 ASSERT_EQ((Shape{1, 2}), new_const->get_shape());
2991 TEST(constant_folding, constant_non_zero_1D_all_indices)
2993 const vector<float> values_in{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
2994 const auto data = make_shared<op::Constant>(element::f32, Shape{values_in.size()}, values_in);
2995 const auto non_zero = make_shared<op::v3::NonZero>(data);
2996 non_zero->set_friendly_name("test");
2997 auto f = make_shared<Function>(non_zero, ParameterVector{});
2999 pass::Manager pass_manager;
3000 pass_manager.register_pass<pass::ConstantFolding>();
3001 pass_manager.run_passes(f);
3003 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
3004 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3006 const auto new_const =
3007 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3008 ASSERT_TRUE(new_const);
3009 ASSERT_EQ(new_const->get_friendly_name(), "test");
3010 const auto values_out = new_const->get_vector<int64_t>();
3012 const vector<int64_t> values_expected{0, 1, 2, 3, 4, 5, 6, 7};
3013 ASSERT_EQ(values_expected, values_out);
3014 ASSERT_EQ((Shape{1, values_in.size()}), new_const->get_shape());
3017 TEST(constant_folding, constant_non_zero_2D)
3019 vector<int> values_in{1, 0, 0, 0, 1, 0, 1, 1, 0};
3020 auto data = make_shared<op::Constant>(element::i32, Shape{3, 3}, values_in);
3021 auto non_zero = make_shared<op::v3::NonZero>(data);
3022 non_zero->set_friendly_name("test");
3023 auto f = make_shared<Function>(non_zero, ParameterVector{});
3025 pass::Manager pass_manager;
3026 pass_manager.register_pass<pass::ConstantFolding>();
3027 pass_manager.run_passes(f);
3029 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
3030 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3032 const auto new_const =
3033 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3034 ASSERT_TRUE(new_const);
3035 ASSERT_EQ(new_const->get_friendly_name(), "test");
3036 const auto values_out = new_const->get_vector<int64_t>();
3038 const vector<int64_t> values_expected{0, 1, 2, 2, 0, 1, 0, 1};
3039 ASSERT_EQ(values_expected, values_out);
3040 ASSERT_EQ((Shape{2, 4}), new_const->get_shape());
3043 TEST(constant_folding, DISABLED_constant_non_zero_2D_all_indices)
3045 const vector<int8_t> values_in{1, 1, 1, 1, 1, 1, 1, 1, 1};
3046 const auto data = make_shared<op::Constant>(element::i8, Shape{3, 3}, values_in);
3047 const auto non_zero = make_shared<op::v3::NonZero>(data);
3048 non_zero->set_friendly_name("test");
3049 auto f = make_shared<Function>(non_zero, ParameterVector{});
3051 pass::Manager pass_manager;
3052 pass_manager.register_pass<pass::ConstantFolding>();
3053 pass_manager.run_passes(f);
3055 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
3056 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3058 const auto new_const =
3059 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3060 ASSERT_TRUE(new_const);
3061 ASSERT_EQ(new_const->get_friendly_name(), "test");
3062 const auto values_out = new_const->get_vector<int64_t>();
3064 const vector<int64_t> values_expected{0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2};
3065 ASSERT_EQ(values_expected, values_out);
3066 ASSERT_EQ((Shape{2, values_in.size()}), new_const->get_shape());
3069 TEST(constant_folding, DISABLED_constant_non_zero_2D_all_zeros)
3071 const vector<uint8_t> values_in{0, 0, 0, 0, 0, 0};
3072 const auto data = make_shared<op::Constant>(element::u8, Shape{2, 3}, values_in);
3073 const auto non_zero = make_shared<op::v3::NonZero>(data);
3074 non_zero->set_friendly_name("test");
3075 auto f = make_shared<Function>(non_zero, ParameterVector{});
3077 pass::Manager pass_manager;
3078 pass_manager.register_pass<pass::ConstantFolding>();
3079 pass_manager.run_passes(f);
3081 // fold into Constant with shape of {0}
3082 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
3083 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3085 const auto new_const =
3086 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3087 ASSERT_TRUE(new_const);
3088 ASSERT_EQ(new_const->get_friendly_name(), "test");
3089 ASSERT_EQ(shape_size(new_const->get_shape()), 0);
3092 TEST(constant_folding, constant_non_zero_3D)
3094 vector<int> values_in{1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0};
3095 auto data = make_shared<op::Constant>(element::i32, Shape{2, 3, 3}, values_in);
3096 auto non_zero = make_shared<op::v3::NonZero>(data);
3097 non_zero->set_friendly_name("test");
3098 auto f = make_shared<Function>(non_zero, ParameterVector{});
3100 pass::Manager pass_manager;
3101 pass_manager.register_pass<pass::ConstantFolding>();
3102 pass_manager.run_passes(f);
3104 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
3105 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3107 const auto new_const =
3108 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3109 ASSERT_TRUE(new_const);
3110 ASSERT_EQ(new_const->get_friendly_name(), "test");
3111 const auto values_out = new_const->get_vector<int64_t>();
3113 const vector<int64_t> values_expected{0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 2, 2, 2,
3114 0, 0, 0, 1, 1, 2, 0, 2, 1, 0, 1, 2, 0, 1, 2, 0, 2, 1};
3115 ASSERT_EQ(values_expected, values_out);
3116 ASSERT_EQ((Shape{3, 12}), new_const->get_shape());
3119 TEST(constant_folding, constant_scatter_elements_update_basic)
3121 const Shape data_shape{3, 3};
3122 const Shape indices_shape{2, 3};
3124 const auto data_const = op::Constant::create(
3125 element::f32, data_shape, std::vector<float>(shape_size(data_shape), 0.f));
3126 const auto indices_const =
3127 op::Constant::create(element::i32, indices_shape, {1, 0, 2, 0, 2, 1});
3128 const auto updates_const =
3129 op::Constant::create(element::f32, indices_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
3130 const auto axis_const = op::Constant::create(element::i64, Shape{}, {0});
3132 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3133 data_const, indices_const, updates_const, axis_const);
3134 scatter_elem_updt->set_friendly_name("test");
3135 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3137 pass::Manager pass_manager;
3138 pass_manager.register_pass<pass::ConstantFolding>();
3139 pass_manager.run_passes(f);
3141 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3142 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3145 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3146 ASSERT_TRUE(result_node);
3147 ASSERT_EQ(result_node->get_friendly_name(), "test");
3148 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3149 std::vector<float> expected{2.f, 1.1f, 0.0f, 1.f, 0.0f, 2.2f, 0.f, 2.1f, 1.2f};
3150 range_test_check(result_node->cast_vector<float>(), expected);
3153 TEST(constant_folding, constant_scatter_elements_update_negative_axis)
3155 const Shape data_shape{3, 3};
3156 const Shape indices_shape{2, 3};
3158 const auto data_const = op::Constant::create(
3159 element::f32, data_shape, std::vector<float>(shape_size(data_shape), 0.f));
3160 const auto indices_const =
3161 op::Constant::create(element::i32, indices_shape, {1, 0, 2, 0, 2, 1});
3162 const auto updates_const =
3163 op::Constant::create(element::f32, indices_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
3164 const auto axis_const = op::Constant::create(element::i64, Shape{}, {-1});
3166 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3167 data_const, indices_const, updates_const, axis_const);
3168 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3170 pass::Manager pass_manager;
3171 pass_manager.register_pass<pass::ConstantFolding>();
3172 pass_manager.run_passes(f);
3174 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3175 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3178 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3179 ASSERT_TRUE(result_node);
3180 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3181 std::vector<float> expected{1.1f, 1.0f, 1.2f, 2.0f, 2.2f, 2.1f, 0.0f, 0.0f, 0.0f};
3182 range_test_check(result_node->cast_vector<float>(), expected);
3185 TEST(constant_folding, constant_scatter_elements_update_1d_axis)
3187 const Shape data_shape{3, 3};
3188 const Shape indices_shape{2, 3};
3190 const auto data_const = op::Constant::create(
3191 element::f32, data_shape, std::vector<float>(shape_size(data_shape), 0.f));
3192 const auto indices_const =
3193 op::Constant::create(element::i32, indices_shape, {1, 0, 2, 0, 2, 1});
3194 const auto updates_const =
3195 op::Constant::create(element::f32, indices_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
3196 const auto axis_const = op::Constant::create(element::i64, Shape{1}, {0});
3198 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3199 data_const, indices_const, updates_const, axis_const);
3200 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3202 pass::Manager pass_manager;
3203 pass_manager.register_pass<pass::ConstantFolding>();
3204 pass_manager.run_passes(f);
3206 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3207 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3210 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3211 ASSERT_TRUE(result_node);
3212 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3213 std::vector<float> expected{2.f, 1.1f, 0.0f, 1.f, 0.0f, 2.2f, 0.f, 2.1f, 1.2f};
3214 range_test_check(result_node->cast_vector<float>(), expected);
3217 TEST(constant_folding, constant_scatter_elements_update_3d_i16)
3219 const Shape data_shape{3, 3, 3};
3220 const Shape indices_shape{2, 2, 3};
3222 const auto data_const = op::Constant::create(
3223 element::i16, data_shape, std::vector<int16_t>(shape_size(data_shape), 0));
3224 const auto indices_const =
3225 op::Constant::create(element::i16, indices_shape, {1, 0, 2, 0, 2, 1, 2, 2, 2, 0, 1, 0});
3226 const auto updates_const =
3227 op::Constant::create(element::i16, indices_shape, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
3228 const auto axis_const = op::Constant::create(element::i64, Shape{}, {1});
3230 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3231 data_const, indices_const, updates_const, axis_const);
3232 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3234 pass::Manager pass_manager;
3235 pass_manager.register_pass<pass::ConstantFolding>();
3236 pass_manager.run_passes(f);
3238 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3239 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3242 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3243 ASSERT_TRUE(result_node);
3244 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3245 std::vector<int16_t> expected{4, 2, 0, 1, 0, 6, 0, 5, 3, 10, 0, 12, 0, 11,
3246 0, 7, 8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0};
3247 range_test_check(result_node->cast_vector<int16_t>(), expected);
// Scatters a single i32 update (value 2, index 1, axis 0) into a zero-filled 3x3x3 tensor and
// checks that ConstantFolding replaces the ScatterElementsUpdate with one Constant of the same
// shape whose values match `expected`.
3250 TEST(constant_folding, constant_scatter_elements_update_one_elem)
3252 const Shape data_shape{3, 3, 3};
3253 const Shape indices_shape{1, 1, 1};
// All-zero input data, also reused as the starting point for `expected`.
3254 const auto input_data = std::vector<int32_t>(shape_size(data_shape), 0);
3256 const auto data_const = op::Constant::create(element::i32, data_shape, input_data);
3257 const auto indices_const = op::Constant::create(element::i32, indices_shape, {1});
3258 const auto updates_const = op::Constant::create(element::i32, indices_shape, {2});
3259 const auto axis_const = op::Constant::create(element::i64, Shape{}, {0});
3261 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3262 data_const, indices_const, updates_const, axis_const);
3263 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
// Run only the ConstantFolding pass over the function.
3265 pass::Manager pass_manager;
3266 pass_manager.register_pass<pass::ConstantFolding>();
3267 pass_manager.run_passes(f);
3269 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3270 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// NOTE(review): the declaration of `result_node` is not present in this copy of the file.
3273 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3274 ASSERT_TRUE(result_node);
3275 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
// Brace-initialising from a vector copies it, so `expected` starts as all zeros.
3276 std::vector<int32_t> expected{input_data};
// The single update writes 2 at coordinate (1, 0, 0).
// NOTE(review): the line that applies this to `expected` is not present in this copy; for a
// row-major 3x3x3 layout that coordinate is flat index 9 - confirm against the upstream test.
3279 range_test_check(result_node->cast_vector<int32_t>(), expected);
3282 void test_constant_folding_reshape_v1(Shape& shape_in,
3283 vector<float>& values_in,
3285 vector<int32_t> values_shape,
3286 bool zero_flag = false)
3288 auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
3289 auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
3290 auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape, zero_flag);
3291 dyn_reshape->set_friendly_name("test");
3292 auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
3294 pass::Manager pass_manager;
3295 pass_manager.register_pass<pass::ConstantFolding>();
3296 pass_manager.run_passes(f);
3298 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
3299 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3302 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3303 ASSERT_TRUE(new_const);
3304 ASSERT_EQ(new_const->get_friendly_name(), "test");
3305 auto values_out = new_const->get_vector<float>();
3307 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
3309 TEST(constant_folding, constant_dyn_reshape_v1_2d)
3311 Shape shape_in{2, 5};
3312 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
3314 test_constant_folding_reshape_v1(shape_in, values_in, {4}, {1, 1, 1, 10});
3315 test_constant_folding_reshape_v1(shape_in, values_in, {4}, {1, 1, 2, 5});
3316 test_constant_folding_reshape_v1(shape_in, values_in, {3}, {1, 2, 5});
3317 test_constant_folding_reshape_v1(shape_in, values_in, {3}, {5, 2, 1});
3320 TEST(constant_folding, constant_dyn_reshape_v1_pattern_with_negative_indices)
3322 Shape shape_in{2, 2, 2, 2};
3323 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3325 test_constant_folding_reshape_v1(shape_in, values_in, {3}, {4, -1, 2});
3326 test_constant_folding_reshape_v1(shape_in, values_in, {2}, {4, -1});
3327 test_constant_folding_reshape_v1(shape_in, values_in, {1}, {-1});
3330 TEST(constant_folding, constant_dyn_reshape_v1_pattern_with_zero_dims)
3332 Shape shape_in{2, 2, 2, 2};
3333 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3335 test_constant_folding_reshape_v1(shape_in, values_in, {4}, {2, -1, 2, 0}, true);
3336 test_constant_folding_reshape_v1(shape_in, values_in, {4}, {4, 1, 0, 2}, true);
// Marks a v1::Reshape with the "DISABLED_CONSTANT_FOLDING" runtime-info key, runs
// ConstantFolding, and checks that the Reshape and the single shape Constant are still present.
// NOTE(review): the Reshape's data input is a Parameter, not a Constant, so this test may pass
// even if the pass ignored the rt_info key. Consider feeding a Constant as the data input so
// the flag is what actually prevents folding - confirm.
3339 TEST(constant_folding, disable_constant_folding)
3341 auto input = make_shared<op::Parameter>(element::f32, Shape{1, 3});
3342 auto constant_shape = op::Constant::create(element::i64, Shape{1}, {3});
3343 auto dyn_reshape = make_shared<op::v1::Reshape>(input, constant_shape, true);
3344 auto& rt_info = dyn_reshape->get_rt_info();
// Indexing inserts the key (presumably with a default-constructed value, if rt_info is a
// std::map-like container); only the key's presence appears to matter here.
3345 rt_info["DISABLED_CONSTANT_FOLDING"];
3346 auto f = make_shared<Function>(dyn_reshape, ParameterVector{input});
// Run only the ConstantFolding pass over the function.
3348 pass::Manager pass_manager;
3349 pass_manager.register_pass<pass::ConstantFolding>();
3350 pass_manager.run_passes(f);
// Nothing may be folded: the Reshape survives and the only Constant is the shape input.
3352 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 1);
3353 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);