1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "gtest/gtest.h"
19 #include "ngraph/ngraph.hpp"
20 #include "ngraph/pass/constant_folding.hpp"
21 #include "ngraph/pass/manager.hpp"
22 #include "util/all_close_f.hpp"
23 #include "util/test_tools.hpp"
25 NGRAPH_SUPPRESS_DEPRECATED_START
27 using namespace ngraph;
31 static std::vector<T> get_result_constant(std::shared_ptr<Function> f, size_t pos)
34 as_type_ptr<op::Constant>(f->get_results().at(pos)->input_value(0).get_node_shared_ptr());
35 return new_const->cast_vector<T>();
38 void range_test_check(const vector<double>& values_out, const vector<double>& values_expected)
40 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
43 void range_test_check(const vector<float>& values_out, const vector<float>& values_expected)
45 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
49 typename std::enable_if<std::is_integral<T>::value>::type
50 range_test_check(const vector<T>& values_out, const vector<T>& values_expected)
52 ASSERT_EQ(values_out, values_expected);
55 TEST(constant_folding, acosh)
57 Shape shape_in{2, 4, 1};
59 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
60 vector<float> expected;
61 for (float f : values_in)
63 expected.push_back(std::acosh(f));
65 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
66 auto acosh = make_shared<op::Acosh>(constant);
67 acosh->set_friendly_name("test");
68 auto f = make_shared<Function>(acosh, ParameterVector{});
70 pass::Manager pass_manager;
71 pass_manager.register_pass<pass::ConstantFolding>();
72 pass_manager.run_passes(f);
74 EXPECT_EQ(count_ops_of_type<op::Acosh>(f), 0);
75 EXPECT_EQ(count_ops_of_type<op::Constant>(f), 1);
76 ASSERT_EQ(f->get_results().size(), 1);
79 as_type_ptr<op::Constant>(f->get_results()[0]->input_value(0).get_node_shared_ptr());
80 EXPECT_TRUE(new_const);
81 ASSERT_EQ(new_const->get_friendly_name(), "test");
83 auto values_out = new_const->get_vector<float>();
84 EXPECT_TRUE(test::all_close_f(expected, values_out, MIN_FLOAT_TOLERANCE_BITS));
87 TEST(constant_folding, asinh)
89 Shape shape_in{2, 4, 1};
91 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
92 vector<float> expected;
93 for (float f : values_in)
95 expected.push_back(std::asinh(f));
97 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
98 auto asinh = make_shared<op::Asinh>(constant);
99 asinh->set_friendly_name("test");
100 auto f = make_shared<Function>(asinh, ParameterVector{});
102 pass::Manager pass_manager;
103 pass_manager.register_pass<pass::ConstantFolding>();
104 pass_manager.run_passes(f);
106 EXPECT_EQ(count_ops_of_type<op::Asinh>(f), 0);
107 EXPECT_EQ(count_ops_of_type<op::Constant>(f), 1);
108 ASSERT_EQ(f->get_results().size(), 1);
111 as_type_ptr<op::Constant>(f->get_results()[0]->input_value(0).get_node_shared_ptr());
112 EXPECT_TRUE(new_const);
113 ASSERT_EQ(new_const->get_friendly_name(), "test");
115 auto values_out = new_const->get_vector<float>();
116 EXPECT_TRUE(test::all_close_f(expected, values_out, MIN_FLOAT_TOLERANCE_BITS));
119 TEST(constant_folding, atanh)
121 Shape shape_in{2, 4, 1};
123 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
124 vector<float> expected;
125 for (float f : values_in)
127 expected.push_back(std::atanh(f));
129 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
130 auto atanh = make_shared<op::Atanh>(constant);
131 atanh->set_friendly_name("test");
132 auto f = make_shared<Function>(atanh, ParameterVector{});
134 pass::Manager pass_manager;
135 pass_manager.register_pass<pass::ConstantFolding>();
136 pass_manager.run_passes(f);
138 EXPECT_EQ(count_ops_of_type<op::Atanh>(f), 0);
139 EXPECT_EQ(count_ops_of_type<op::Constant>(f), 1);
140 ASSERT_EQ(f->get_results().size(), 1);
143 as_type_ptr<op::Constant>(f->get_results()[0]->input_value(0).get_node_shared_ptr());
144 EXPECT_TRUE(new_const);
145 ASSERT_EQ(new_const->get_friendly_name(), "test");
147 auto values_out = new_const->get_vector<float>();
148 EXPECT_TRUE(test::all_close_f(expected, values_out, MIN_FLOAT_TOLERANCE_BITS));
151 TEST(constant_folding, constant_squeeze)
153 Shape shape_in{2, 4, 1};
154 Shape shape_out{2, 4};
157 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
158 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
159 vector<int64_t> values_axes{2};
160 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
161 auto squeeze = make_shared<op::Squeeze>(constant, constant_axes);
162 squeeze->set_friendly_name("test");
163 auto f = make_shared<Function>(squeeze, ParameterVector{});
165 pass::Manager pass_manager;
166 pass_manager.register_pass<pass::ConstantFolding>();
167 pass_manager.run_passes(f);
169 ASSERT_EQ(count_ops_of_type<op::Squeeze>(f), 0);
170 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
173 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
174 ASSERT_TRUE(new_const);
175 ASSERT_EQ(new_const->get_friendly_name(), "test");
176 ASSERT_EQ(new_const->get_shape(), shape_out);
178 auto values_out = new_const->get_vector<float>();
179 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
182 TEST(constant_folding, constant_unsqueeze)
184 Shape shape_in{2, 4};
185 Shape shape_out{2, 4, 1, 1};
188 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
189 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
190 vector<int64_t> values_axes{2, 3};
191 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
192 auto unsqueeze = make_shared<op::Unsqueeze>(constant, constant_axes);
193 unsqueeze->set_friendly_name("test");
194 auto f = make_shared<Function>(unsqueeze, ParameterVector{});
196 pass::Manager pass_manager;
197 pass_manager.register_pass<pass::ConstantFolding>();
198 pass_manager.run_passes(f);
200 ASSERT_EQ(count_ops_of_type<op::Unsqueeze>(f), 0);
201 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
204 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
205 ASSERT_TRUE(new_const);
206 ASSERT_EQ(new_const->get_friendly_name(), "test");
207 ASSERT_EQ(new_const->get_shape(), shape_out);
209 auto values_out = new_const->get_vector<float>();
210 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
213 TEST(constant_folding, constant_reshape)
215 Shape shape_in{2, 4};
216 Shape shape_out{2, 4, 1};
218 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
219 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
220 auto reshape = make_shared<op::Reshape>(constant, AxisVector{0, 1}, shape_out);
221 reshape->set_friendly_name("test");
222 auto f = make_shared<Function>(reshape, ParameterVector{});
224 pass::Manager pass_manager;
225 pass_manager.register_pass<pass::ConstantFolding>();
226 pass_manager.run_passes(f);
228 ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
229 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
232 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
233 ASSERT_TRUE(new_const);
234 ASSERT_EQ(new_const->get_friendly_name(), "test");
235 auto values_out = new_const->get_vector<float>();
237 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
240 TEST(constant_folding, DISABLED_constant_reshape_permute)
242 Shape shape_in{2, 4};
243 Shape shape_out{4, 2};
245 vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};
246 auto constant = make_shared<op::Constant>(element::f64, shape_in, values_in);
247 auto reshape = make_shared<op::Reshape>(constant, AxisVector{1, 0}, shape_out);
248 reshape->set_friendly_name("test");
249 auto f = make_shared<Function>(reshape, ParameterVector{});
251 pass::Manager pass_manager;
252 pass_manager.register_pass<pass::ConstantFolding>();
253 pass_manager.run_passes(f);
255 ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
256 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
259 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
260 ASSERT_TRUE(new_const);
261 ASSERT_EQ(new_const->get_friendly_name(), "test");
262 auto values_out = new_const->get_vector<double>();
264 vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
265 ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
268 TEST(constant_folding, constant_broadcast_v1)
270 vector<int32_t> values_in{0, 1};
271 auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
272 vector<int64_t> shape_in{2, 4};
273 auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
274 vector<int64_t> axes_in{0};
275 auto constant_axes = make_shared<op::Constant>(element::i64, Shape{1}, axes_in);
276 auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, constant_shape, constant_axes);
277 broadcast_v1->set_friendly_name("test");
278 auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
280 pass::Manager pass_manager;
281 pass_manager.register_pass<pass::ConstantFolding>();
282 pass_manager.run_passes(f);
284 ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
285 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
288 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
289 ASSERT_TRUE(new_const);
290 ASSERT_EQ(new_const->get_friendly_name(), "test");
291 auto values_out = new_const->get_vector<int32_t>();
293 vector<int32_t> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
294 ASSERT_EQ(values_expected, values_out);
297 TEST(constant_folding, constant_broadcast_v1_with_target_shape)
299 vector<int32_t> values_in{1};
300 auto constant_in = make_shared<op::Constant>(element::i32, Shape{1, 1, 1, 1}, values_in);
301 vector<int64_t> shape_in{1, 3, 1, 1};
302 auto target_shape = make_shared<op::Constant>(element::i64, Shape{4}, shape_in);
303 auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, target_shape);
304 broadcast_v1->set_friendly_name("test");
305 auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
307 pass::Manager pass_manager;
308 pass_manager.register_pass<pass::ConstantFolding>();
309 pass_manager.run_passes(f);
311 ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
312 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
315 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
316 ASSERT_TRUE(new_const);
317 ASSERT_EQ(new_const->get_friendly_name(), "test");
318 auto values_out = new_const->get_vector<int32_t>();
320 vector<int32_t> values_expected{1, 1, 1};
321 ASSERT_EQ(values_expected, values_out);
324 TEST(constant_folding, constant_broadcast_v1_numpy)
326 vector<int32_t> values_in{0, 1};
327 auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
328 vector<int64_t> shape_in{4, 2};
329 auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
330 auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, constant_shape);
331 broadcast_v1->set_friendly_name("test");
332 auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
334 pass::Manager pass_manager;
335 pass_manager.register_pass<pass::ConstantFolding>();
336 pass_manager.run_passes(f);
338 ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
339 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
342 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
343 ASSERT_TRUE(new_const);
344 ASSERT_EQ(new_const->get_friendly_name(), "test");
345 auto values_out = new_const->get_vector<int32_t>();
347 vector<int32_t> values_expected{0, 1, 0, 1, 0, 1, 0, 1};
348 ASSERT_EQ(values_expected, values_out);
351 TEST(constant_folding, constant_unary_binary)
353 vector<int> values_a{1, 2, 3, 4};
354 vector<int> values_b{1, 2, 3, 4};
355 vector<int> values_c{-1, -1, -1, -1};
356 vector<int> values_d{1, 4, 9, 16};
357 vector<int> values_e{5, 6};
358 vector<int> values_f{0, 10};
359 vector<int> values_g{1, 4};
360 vector<char> values_h{0, 0, 1, 1};
361 vector<char> values_i{0, 1};
362 auto a = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_a);
363 auto b = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_b);
364 auto c = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_c);
365 auto d = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_d);
366 auto e = make_shared<op::Constant>(element::i32, Shape{2}, values_e);
367 auto f = make_shared<op::Constant>(element::i32, Shape{2}, values_f);
368 auto g = make_shared<op::Constant>(element::i32, Shape{2}, values_g);
369 auto h = make_shared<op::Constant>(element::boolean, Shape{2, 2}, values_h);
370 auto i = make_shared<op::Constant>(element::boolean, Shape{2}, values_i);
376 auto pow = make_shared<op::Power>(a, b);
377 auto min = make_shared<op::Minimum>(c, a);
378 auto max = make_shared<op::Maximum>(a, c);
379 auto absn = make_shared<op::Abs>(c);
380 auto neg = make_shared<op::Negative>(c);
381 auto sqrt = make_shared<op::Sqrt>(d);
382 auto add_autob_numpy = make_shared<op::Add>(a, e, op::AutoBroadcastType::NUMPY);
383 auto sub_autob_numpy = make_shared<op::Subtract>(a, e, op::AutoBroadcastType::NUMPY);
384 auto mul_autob_numpy = make_shared<op::Multiply>(a, e, op::AutoBroadcastType::NUMPY);
385 auto div_autob_numpy = make_shared<op::Divide>(a, g, op::AutoBroadcastType::NUMPY);
386 auto pow_autob_numpy = make_shared<op::Power>(a, g, op::AutoBroadcastType::NUMPY);
387 auto min_autob_numpy = make_shared<op::Minimum>(a, f, op::AutoBroadcastType::NUMPY);
388 auto max_autob_numpy = make_shared<op::Maximum>(a, f, op::AutoBroadcastType::NUMPY);
389 auto equal_autob_numpy = make_shared<op::Equal>(a, g, op::AutoBroadcastType::NUMPY);
390 auto not_equal_autob_numpy = make_shared<op::NotEqual>(a, g, op::AutoBroadcastType::NUMPY);
391 auto greater_autob_numpy = make_shared<op::Greater>(a, g, op::AutoBroadcastType::NUMPY);
392 auto greater_eq_autob_numpy = make_shared<op::GreaterEq>(a, g, op::AutoBroadcastType::NUMPY);
393 auto less_autob_numpy = make_shared<op::Less>(a, g, op::AutoBroadcastType::NUMPY);
394 auto less_eq_autob_numpy = make_shared<op::LessEq>(a, g, op::AutoBroadcastType::NUMPY);
395 auto logical_or_autob_numpy = make_shared<op::Or>(h, i, op::AutoBroadcastType::NUMPY);
396 auto logical_xor_autob_numpy = make_shared<op::Xor>(h, i, op::AutoBroadcastType::NUMPY);
398 auto neg_sqrt = make_shared<op::Sqrt>(c);
400 auto func = make_shared<Function>(NodeVector{add,
418 not_equal_autob_numpy,
420 greater_eq_autob_numpy,
423 logical_or_autob_numpy,
424 logical_xor_autob_numpy},
426 auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
428 pass::Manager pass_manager;
429 pass_manager.register_pass<pass::ConstantFolding>();
430 pass_manager.run_passes(func);
433 vector<int> add_expected{2, 4, 6, 8};
434 vector<int> sub_expected{0, 0, 0, 0};
435 vector<int> mul_expected{1, 4, 9, 16};
436 vector<int> div_expected{1, 1, 1, 1};
437 vector<int> pow_expected{1, 4, 27, 256};
438 vector<int> min_expected{-1, -1, -1, -1};
439 vector<int> max_expected{1, 2, 3, 4};
440 vector<int> abs_neg_expected{1, 1, 1, 1};
441 vector<int> sqrt_expected{1, 2, 3, 4};
442 vector<int> add_autob_numpy_expected{6, 8, 8, 10};
443 vector<int> sub_autob_numpy_expected{-4, -4, -2, -2};
444 vector<int> mul_autob_numpy_expected{5, 12, 15, 24};
445 vector<int> div_autob_numpy_expected{1, 0, 3, 1};
446 vector<int> pow_autob_numpy_expected{1, 16, 3, 256};
447 vector<int> min_autob_numpy_expected{0, 2, 0, 4};
448 vector<int> max_autob_numpy_expected{1, 10, 3, 10};
449 vector<char> equal_autob_numpy_expected{1, 0, 0, 1};
450 vector<char> not_equal_autob_numpy_expected{0, 1, 1, 0};
451 vector<char> greater_autob_numpy_expected{0, 0, 1, 0};
452 vector<char> greater_eq_autob_numpy_expected{1, 0, 1, 1};
453 vector<char> less_autob_numpy_expected{0, 1, 0, 0};
454 vector<char> less_eq_autob_numpy_expected{1, 1, 0, 1};
455 vector<char> logical_or_autob_numpy_expected{0, 1, 1, 1};
456 vector<char> logical_xor_autob_numpy_expected{0, 1, 1, 0};
458 ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
459 ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
460 ASSERT_EQ(get_result_constant<int>(func, 2), mul_expected);
461 ASSERT_EQ(get_result_constant<int>(func, 3), div_expected);
462 ASSERT_EQ(get_result_constant<int>(func, 4), pow_expected);
463 ASSERT_EQ(get_result_constant<int>(func, 5), min_expected);
464 ASSERT_EQ(get_result_constant<int>(func, 6), max_expected);
465 ASSERT_EQ(get_result_constant<int>(func, 7), abs_neg_expected);
466 ASSERT_EQ(get_result_constant<int>(func, 8), abs_neg_expected);
467 ASSERT_EQ(get_result_constant<int>(func, 9), sqrt_expected);
468 ASSERT_EQ(get_result_constant<int>(func, 10), add_autob_numpy_expected);
469 ASSERT_EQ(get_result_constant<int>(func, 11), sub_autob_numpy_expected);
470 ASSERT_EQ(get_result_constant<int>(func, 12), mul_autob_numpy_expected);
471 ASSERT_EQ(get_result_constant<int>(func, 13), div_autob_numpy_expected);
472 ASSERT_EQ(get_result_constant<int>(func, 14), pow_autob_numpy_expected);
473 ASSERT_EQ(get_result_constant<int>(func, 15), min_autob_numpy_expected);
474 ASSERT_EQ(get_result_constant<int>(func, 16), max_autob_numpy_expected);
475 ASSERT_EQ(get_result_constant<char>(func, 17), equal_autob_numpy_expected);
476 ASSERT_EQ(get_result_constant<char>(func, 18), not_equal_autob_numpy_expected);
477 ASSERT_EQ(get_result_constant<char>(func, 19), greater_autob_numpy_expected);
478 ASSERT_EQ(get_result_constant<char>(func, 20), greater_eq_autob_numpy_expected);
479 ASSERT_EQ(get_result_constant<char>(func, 21), less_autob_numpy_expected);
480 ASSERT_EQ(get_result_constant<char>(func, 22), less_eq_autob_numpy_expected);
481 ASSERT_EQ(get_result_constant<char>(func, 23), logical_or_autob_numpy_expected);
482 ASSERT_EQ(get_result_constant<char>(func, 24), logical_xor_autob_numpy_expected);
483 ASSERT_NO_THROW(pass_manager.run_passes(func_error));
486 TEST(constant_folding, const_quantize)
488 Shape input_shape{12};
489 Shape scale_offset_shape;
490 AxisSet quantization_axes;
492 auto quant_type = element::u8;
493 auto output_type = element::u8;
494 typedef uint8_t output_c_type;
496 vector<float> values_in{1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0};
497 auto constant = op::Constant::create(element::f32, input_shape, values_in);
498 auto scale = op::Constant::create(element::f32, scale_offset_shape, {2});
499 auto offset = op::Constant::create(quant_type, scale_offset_shape, {1});
500 auto mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY;
502 make_shared<op::Quantize>(constant, scale, offset, output_type, quantization_axes, mode);
503 quantize->set_friendly_name("test");
504 auto f = make_shared<Function>(quantize, ParameterVector{});
506 pass::Manager pass_manager;
507 pass_manager.register_pass<pass::ConstantFolding>();
508 pass_manager.run_passes(f);
510 ASSERT_EQ(count_ops_of_type<op::Quantize>(f), 0);
511 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
514 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
515 ASSERT_TRUE(new_const);
516 ASSERT_EQ(new_const->get_friendly_name(), "test");
517 auto values_out = new_const->get_vector<output_c_type>();
519 vector<output_c_type> values_quantize{2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5};
520 ASSERT_EQ(values_quantize, values_out);
523 TEST(constant_folding, const_convert)
525 Shape input_shape{3, 4};
527 vector<int32_t> values_in{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
528 auto constant = op::Constant::create(element::f32, input_shape, values_in);
529 auto convert = make_shared<op::Convert>(constant, element::u64);
530 convert->set_friendly_name("test");
531 auto f = make_shared<Function>(convert, ParameterVector{});
533 pass::Manager pass_manager;
534 pass_manager.register_pass<pass::ConstantFolding>();
535 pass_manager.run_passes(f);
537 ASSERT_EQ(count_ops_of_type<op::Convert>(f), 0);
538 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
541 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
542 ASSERT_TRUE(new_const);
543 ASSERT_EQ(new_const->get_friendly_name(), "test");
544 ASSERT_EQ(new_const->get_output_element_type(0), element::u64);
545 auto values_out = new_const->get_vector<uint64_t>();
547 vector<uint64_t> values_expected{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
548 ASSERT_EQ(values_expected, values_out);
551 TEST(constant_folding, shape_of_v0)
553 Shape input_shape{3, 4, 0, 22, 608, 909, 3};
555 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
556 auto shape_of = make_shared<op::v0::ShapeOf>(param);
557 shape_of->set_friendly_name("test");
558 auto f = make_shared<Function>(shape_of, ParameterVector{param});
560 pass::Manager pass_manager;
561 pass_manager.register_pass<pass::ConstantFolding>();
562 pass_manager.run_passes(f);
564 ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 0);
565 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
568 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
569 ASSERT_TRUE(new_const);
570 ASSERT_EQ(new_const->get_friendly_name(), "test");
571 ASSERT_EQ(new_const->get_output_element_type(0), element::i64);
572 auto values_out = new_const->get_vector<int64_t>();
574 ASSERT_EQ((vector<int64_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
577 TEST(constant_folding, shape_of_v3)
579 Shape input_shape{3, 4, 0, 22, 608, 909, 3};
581 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
582 auto shape_of = make_shared<op::v3::ShapeOf>(param);
583 shape_of->set_friendly_name("test");
584 auto f = make_shared<Function>(shape_of, ParameterVector{param});
586 pass::Manager pass_manager;
587 pass_manager.register_pass<pass::ConstantFolding>();
588 pass_manager.run_passes(f);
590 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 0);
591 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
594 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
595 ASSERT_TRUE(new_const);
596 ASSERT_EQ(new_const->get_friendly_name(), "test");
597 ASSERT_EQ(new_const->get_output_element_type(0), element::i64);
598 auto values_out = new_const->get_vector<int64_t>();
600 ASSERT_EQ((vector<int64_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
603 TEST(constant_folding, shape_of_i32_v3)
605 Shape input_shape{3, 4, 0, 22, 608, 909, 3};
607 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
608 auto shape_of = make_shared<op::v3::ShapeOf>(param, element::i32);
609 shape_of->set_friendly_name("test");
610 auto f = make_shared<Function>(shape_of, ParameterVector{param});
612 pass::Manager pass_manager;
613 pass_manager.register_pass<pass::ConstantFolding>();
614 pass_manager.run_passes(f);
616 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 0);
617 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
620 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
621 ASSERT_TRUE(new_const);
622 ASSERT_EQ(new_const->get_friendly_name(), "test");
623 ASSERT_EQ(new_const->get_output_element_type(0), element::i32);
624 auto values_out = new_const->get_vector<int32_t>();
626 ASSERT_EQ((vector<int32_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
629 TEST(constant_folding, shape_of_dynamic_v0)
631 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
633 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
634 auto shape_of = make_shared<op::v0::ShapeOf>(param);
635 shape_of->set_friendly_name("test");
636 auto f = make_shared<Function>(shape_of, ParameterVector{param});
638 pass::Manager pass_manager;
639 pass_manager.register_pass<pass::ConstantFolding>();
640 pass_manager.run_passes(f);
642 ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
643 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
644 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
645 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
647 auto result_as_concat =
648 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
649 ASSERT_TRUE(result_as_concat);
650 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
651 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
654 TEST(constant_folding, shape_of_dynamic_v3)
656 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
658 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
659 auto shape_of = make_shared<op::v3::ShapeOf>(param);
660 shape_of->set_friendly_name("test");
661 auto f = make_shared<Function>(shape_of, ParameterVector{param});
663 pass::Manager pass_manager;
664 pass_manager.register_pass<pass::ConstantFolding>();
665 pass_manager.run_passes(f);
667 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
668 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
669 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
670 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
672 auto result_as_concat =
673 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
674 ASSERT_TRUE(result_as_concat);
675 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
676 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
677 ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i64);
680 TEST(constant_folding, shape_of_dynamic_i32_v3)
682 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
684 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
685 auto shape_of = make_shared<op::v3::ShapeOf>(param, element::i32);
686 shape_of->set_friendly_name("test");
687 auto f = make_shared<Function>(shape_of, ParameterVector{param});
689 pass::Manager pass_manager;
690 pass_manager.register_pass<pass::ConstantFolding>();
691 pass_manager.run_passes(f);
693 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
694 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
695 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
696 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
698 auto result_as_concat =
699 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
700 ASSERT_TRUE(result_as_concat);
701 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
702 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
703 ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i32);
706 // We need to be sure that constant folding won't be calculated endlessly.
707 TEST(constant_folding, shape_of_dynamic_double_folding_v0)
709 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
711 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
712 auto shape_of = make_shared<op::v0::ShapeOf>(param);
713 shape_of->set_friendly_name("test");
714 auto f = make_shared<Function>(shape_of, ParameterVector{param});
716 pass::Manager pass_manager;
717 pass_manager.register_pass<pass::ConstantFolding>();
718 pass_manager.run_passes(f);
719 pass_manager.run_passes(f);
721 ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
722 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
723 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
724 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
726 auto result_as_concat =
727 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
728 ASSERT_TRUE(result_as_concat);
729 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
730 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
733 TEST(constant_folding, shape_of_dynamic_double_folding_v3)
735 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
737 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
738 auto shape_of = make_shared<op::v3::ShapeOf>(param);
739 shape_of->set_friendly_name("test");
740 auto f = make_shared<Function>(shape_of, ParameterVector{param});
742 pass::Manager pass_manager;
743 pass_manager.register_pass<pass::ConstantFolding>();
744 pass_manager.run_passes(f);
745 pass_manager.run_passes(f);
747 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
748 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
749 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
750 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
752 auto result_as_concat =
753 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
754 ASSERT_TRUE(result_as_concat);
755 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
756 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
759 // Constant folding will not succeed on ShapeOf if the argument rank is dynamic.
760 // We want to make sure it fails gracefully, leaving the ShapeOf op in place.
761 TEST(constant_folding, shape_of_rank_dynamic_v0)
763 PartialShape input_shape{PartialShape::dynamic()};
765 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
766 auto shape_of = make_shared<op::v0::ShapeOf>(param);
767 shape_of->set_friendly_name("test");
768 auto f = make_shared<Function>(shape_of, ParameterVector{param});
770 pass::Manager pass_manager;
771 pass_manager.register_pass<pass::ConstantFolding>();
772 pass_manager.run_passes(f);
774 ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
775 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
777 auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
778 ASSERT_EQ(result_shape_of, shape_of);
779 ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
782 TEST(constant_folding, shape_of_rank_dynamic_v3)
784 PartialShape input_shape{PartialShape::dynamic()};
786 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
787 auto shape_of = make_shared<op::v3::ShapeOf>(param);
788 shape_of->set_friendly_name("test");
789 auto f = make_shared<Function>(shape_of, ParameterVector{param});
791 pass::Manager pass_manager;
792 pass_manager.register_pass<pass::ConstantFolding>();
793 pass_manager.run_passes(f);
795 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
796 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
798 auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
799 ASSERT_EQ(result_shape_of, shape_of);
800 ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
803 TEST(constant_folding, const_reverse)
805 Shape input_shape{3, 3};
807 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
808 auto constant = op::Constant::create(element::i32, input_shape, values_in);
809 auto convert = make_shared<op::Reverse>(constant, AxisSet{1});
810 convert->set_friendly_name("test");
811 auto f = make_shared<Function>(convert, ParameterVector{});
813 pass::Manager pass_manager;
814 pass_manager.register_pass<pass::ConstantFolding>();
815 pass_manager.run_passes(f);
817 ASSERT_EQ(count_ops_of_type<op::Reverse>(f), 0);
818 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
821 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
822 ASSERT_TRUE(new_const);
823 ASSERT_EQ(new_const->get_friendly_name(), "test");
824 auto values_out = new_const->get_vector<int32_t>();
826 vector<int32_t> values_expected{3, 2, 1, 6, 5, 4, 9, 8, 7};
827 ASSERT_EQ(values_expected, values_out);
830 TEST(constant_folding, const_reduceprod)
832 Shape input_shape{3, 3};
833 Shape output_shape{3};
835 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
836 auto constant = op::Constant::create(element::i32, input_shape, values_in);
838 vector<int32_t> values_axes{1};
839 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
840 auto convert = make_shared<op::v1::ReduceProd>(constant, constant_axes);
841 convert->set_friendly_name("test");
842 auto f = make_shared<Function>(convert, ParameterVector{});
844 pass::Manager pass_manager;
845 pass_manager.register_pass<pass::ConstantFolding>();
846 pass_manager.run_passes(f);
848 ASSERT_EQ(count_ops_of_type<op::v1::ReduceProd>(f), 0);
849 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
852 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
853 ASSERT_TRUE(new_const);
854 ASSERT_EQ(new_const->get_friendly_name(), "test");
855 ASSERT_EQ(new_const->get_shape(), output_shape);
857 auto values_out = new_const->get_vector<int32_t>();
859 vector<int32_t> values_expected{6, 120, 504};
861 ASSERT_EQ(values_expected, values_out);
864 TEST(constant_folding, const_reduceprod_keepdims)
866 Shape input_shape{3, 3};
867 Shape output_shape{3, 1};
869 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
870 auto constant = op::Constant::create(element::i32, input_shape, values_in);
872 vector<int32_t> values_axes{1};
873 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
874 auto convert = make_shared<op::v1::ReduceProd>(constant, constant_axes, true);
875 convert->set_friendly_name("test");
876 auto f = make_shared<Function>(convert, ParameterVector{});
878 pass::Manager pass_manager;
879 pass_manager.register_pass<pass::ConstantFolding>();
880 pass_manager.run_passes(f);
882 ASSERT_EQ(count_ops_of_type<op::v1::ReduceProd>(f), 0);
883 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
886 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
887 ASSERT_TRUE(new_const);
888 ASSERT_EQ(new_const->get_friendly_name(), "test");
889 ASSERT_EQ(new_const->get_shape(), output_shape);
891 auto values_out = new_const->get_vector<int32_t>();
893 vector<int32_t> values_expected{6, 120, 504};
895 ASSERT_EQ(values_expected, values_out);
898 TEST(constant_folding, const_sum)
900 Shape input_shape{3, 3};
902 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
903 auto constant = op::Constant::create(element::i32, input_shape, values_in);
904 auto convert = make_shared<op::Sum>(constant, AxisSet{1});
905 convert->set_friendly_name("test");
906 auto f = make_shared<Function>(convert, ParameterVector{});
908 pass::Manager pass_manager;
909 pass_manager.register_pass<pass::ConstantFolding>();
910 pass_manager.run_passes(f);
912 ASSERT_EQ(count_ops_of_type<op::Sum>(f), 0);
913 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
916 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
917 ASSERT_TRUE(new_const);
918 ASSERT_EQ(new_const->get_friendly_name(), "test");
919 auto values_out = new_const->get_vector<int32_t>();
921 vector<int32_t> values_expected{6, 15, 24};
923 ASSERT_EQ(values_expected, values_out);
926 TEST(constant_folding, const_reducesum)
928 Shape input_shape{3, 3};
929 Shape output_shape{3};
931 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
932 auto constant = op::Constant::create(element::i32, input_shape, values_in);
934 vector<int32_t> values_axes{1};
935 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
936 auto convert = make_shared<op::v1::ReduceSum>(constant, constant_axes);
937 convert->set_friendly_name("test");
938 auto f = make_shared<Function>(convert, ParameterVector{});
940 pass::Manager pass_manager;
941 pass_manager.register_pass<pass::ConstantFolding>();
942 pass_manager.run_passes(f);
944 ASSERT_EQ(count_ops_of_type<op::v1::ReduceSum>(f), 0);
945 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
948 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
949 ASSERT_TRUE(new_const);
950 ASSERT_EQ(new_const->get_friendly_name(), "test");
951 ASSERT_EQ(new_const->get_shape(), output_shape);
953 auto values_out = new_const->get_vector<int32_t>();
955 vector<int32_t> values_expected{6, 15, 24};
957 ASSERT_EQ(values_expected, values_out);
960 TEST(constant_folding, const_reducesum_keepdims)
962 Shape input_shape{3, 3};
963 Shape output_shape{3, 1};
965 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
966 auto constant = op::Constant::create(element::i32, input_shape, values_in);
968 vector<int32_t> values_axes{1};
969 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
970 auto convert = make_shared<op::v1::ReduceSum>(constant, constant_axes, true);
971 convert->set_friendly_name("test");
972 auto f = make_shared<Function>(convert, ParameterVector{});
974 pass::Manager pass_manager;
975 pass_manager.register_pass<pass::ConstantFolding>();
976 pass_manager.run_passes(f);
978 ASSERT_EQ(count_ops_of_type<op::v1::ReduceSum>(f), 0);
979 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
982 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
983 ASSERT_TRUE(new_const);
984 ASSERT_EQ(new_const->get_friendly_name(), "test");
985 ASSERT_EQ(new_const->get_shape(), output_shape);
987 auto values_out = new_const->get_vector<int32_t>();
989 vector<int32_t> values_expected{6, 15, 24};
991 ASSERT_EQ(values_expected, values_out);
994 TEST(constant_folding, const_reducemax)
996 Shape input_shape{3, 2};
997 Shape output_shape{3};
999 vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1000 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1001 Shape axes_shape{1};
1002 vector<int32_t> values_axes{1};
1003 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1004 auto convert = make_shared<op::v1::ReduceMax>(constant, constant_axes);
1005 convert->set_friendly_name("test");
1006 auto f = make_shared<Function>(convert, ParameterVector{});
1008 pass::Manager pass_manager;
1009 pass_manager.register_pass<pass::ConstantFolding>();
1010 pass_manager.run_passes(f);
1012 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMax>(f), 0);
1013 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1016 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1017 ASSERT_TRUE(new_const);
1018 ASSERT_EQ(new_const->get_friendly_name(), "test");
1019 ASSERT_EQ(new_const->get_shape(), output_shape);
1021 auto values_out = new_const->get_vector<int32_t>();
1023 vector<int32_t> values_expected{2, 4, 6};
1025 ASSERT_EQ(values_expected, values_out);
1028 TEST(constant_folding, const_reducemax_keepdims)
1030 Shape input_shape{3, 2};
1031 Shape output_shape{3, 1};
1033 vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1034 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1035 Shape axes_shape{1};
1036 vector<int32_t> values_axes{1};
1037 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1038 auto convert = make_shared<op::v1::ReduceMax>(constant, constant_axes, true);
1039 convert->set_friendly_name("test");
1040 auto f = make_shared<Function>(convert, ParameterVector{});
1042 pass::Manager pass_manager;
1043 pass_manager.register_pass<pass::ConstantFolding>();
1044 pass_manager.run_passes(f);
1046 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMax>(f), 0);
1047 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1050 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1051 ASSERT_TRUE(new_const);
1052 ASSERT_EQ(new_const->get_friendly_name(), "test");
1053 ASSERT_EQ(new_const->get_shape(), output_shape);
1055 auto values_out = new_const->get_vector<int32_t>();
1057 vector<int32_t> values_expected{2, 4, 6};
1059 ASSERT_EQ(values_expected, values_out);
1062 TEST(constant_folding, const_reducemin)
1064 Shape input_shape{3, 2};
1065 Shape output_shape{3};
1067 vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1068 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1069 Shape axes_shape{1};
1070 vector<int32_t> values_axes{1};
1071 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1072 auto convert = make_shared<op::v1::ReduceMin>(constant, constant_axes);
1073 convert->set_friendly_name("test");
1074 auto f = make_shared<Function>(convert, ParameterVector{});
1076 pass::Manager pass_manager;
1077 pass_manager.register_pass<pass::ConstantFolding>();
1078 pass_manager.run_passes(f);
1080 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMin>(f), 0);
1081 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1084 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1085 ASSERT_TRUE(new_const);
1086 ASSERT_EQ(new_const->get_friendly_name(), "test");
1087 ASSERT_EQ(new_const->get_shape(), output_shape);
1089 auto values_out = new_const->get_vector<int32_t>();
1091 vector<int32_t> values_expected{1, 3, 5};
1093 ASSERT_EQ(values_expected, values_out);
1096 TEST(constant_folding, const_reducemin_keepdims)
1098 Shape input_shape{3, 2};
1099 Shape output_shape{3, 1};
1101 vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1102 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1103 Shape axes_shape{1};
1104 vector<int32_t> values_axes{1};
1105 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1106 auto convert = make_shared<op::v1::ReduceMin>(constant, constant_axes, true);
1107 convert->set_friendly_name("test");
1108 auto f = make_shared<Function>(convert, ParameterVector{});
1110 pass::Manager pass_manager;
1111 pass_manager.register_pass<pass::ConstantFolding>();
1112 pass_manager.run_passes(f);
1114 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMin>(f), 0);
1115 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1118 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1119 ASSERT_TRUE(new_const);
1120 ASSERT_EQ(new_const->get_friendly_name(), "test");
1121 ASSERT_EQ(new_const->get_shape(), output_shape);
1123 auto values_out = new_const->get_vector<int32_t>();
1125 vector<int32_t> values_expected{1, 3, 5};
1127 ASSERT_EQ(values_expected, values_out);
1130 TEST(constant_folding, const_reducemean)
1132 Shape input_shape{3, 3};
1133 Shape output_shape{3};
1135 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
1136 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1137 Shape axes_shape{1};
1138 vector<int32_t> values_axes{1};
1139 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1140 auto convert = make_shared<op::v1::ReduceMean>(constant, constant_axes);
1141 convert->set_friendly_name("test");
1142 auto f = make_shared<Function>(convert, ParameterVector{});
1144 pass::Manager pass_manager;
1145 pass_manager.register_pass<pass::ConstantFolding>();
1146 pass_manager.run_passes(f);
1148 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMean>(f), 0);
1149 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1152 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1153 ASSERT_TRUE(new_const);
1154 ASSERT_EQ(new_const->get_friendly_name(), "test");
1155 ASSERT_EQ(new_const->get_shape(), output_shape);
1157 auto values_out = new_const->get_vector<int32_t>();
1159 vector<int32_t> values_expected{2, 5, 8};
1161 ASSERT_EQ(values_expected, values_out);
1164 TEST(constant_folding, const_reducemean_keepdims)
1166 Shape input_shape{3, 3};
1167 Shape output_shape{3, 1};
1169 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
1170 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1171 Shape axes_shape{1};
1172 vector<int32_t> values_axes{1};
1173 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1174 auto convert = make_shared<op::v1::ReduceMean>(constant, constant_axes, true);
1175 convert->set_friendly_name("test");
1176 auto f = make_shared<Function>(convert, ParameterVector{});
1178 pass::Manager pass_manager;
1179 pass_manager.register_pass<pass::ConstantFolding>();
1180 pass_manager.run_passes(f);
1182 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMean>(f), 0);
1183 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1186 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1187 ASSERT_TRUE(new_const);
1188 ASSERT_EQ(new_const->get_friendly_name(), "test");
1189 ASSERT_EQ(new_const->get_shape(), output_shape);
1191 auto values_out = new_const->get_vector<int32_t>();
1193 vector<int32_t> values_expected{2, 5, 8};
1195 ASSERT_EQ(values_expected, values_out);
1198 TEST(constant_folding, const_reduce_logical_and__no_keepdims)
1200 const Shape input_shape{3, 3};
1202 const vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
1203 const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1204 const auto axes = op::Constant::create(element::i64, {1}, {1});
1205 const auto convert = make_shared<op::v1::ReduceLogicalAnd>(data, axes, false);
1206 convert->set_friendly_name("test");
1207 auto f = make_shared<Function>(convert, ParameterVector{});
1209 pass::Manager pass_manager;
1210 pass_manager.register_pass<pass::ConstantFolding>();
1211 pass_manager.run_passes(f);
1213 ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1214 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1216 const auto new_const =
1217 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1218 ASSERT_TRUE(new_const);
1219 ASSERT_EQ(new_const->get_friendly_name(), "test");
1221 const Shape expected_out_shape{3};
1222 ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1224 const auto values_out = new_const->get_vector<char>();
1226 const vector<char> values_expected{0, 0, 1};
1228 ASSERT_EQ(values_expected, values_out);
1231 TEST(constant_folding, const_reduce_logical_and__keepdims)
1233 const Shape input_shape{3, 3};
1235 const vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
1236 const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1237 const auto axes = op::Constant::create(element::i64, {1}, {1});
1238 const auto convert = make_shared<op::v1::ReduceLogicalAnd>(data, axes, true);
1239 convert->set_friendly_name("test");
1240 auto f = make_shared<Function>(convert, ParameterVector{});
1242 pass::Manager pass_manager;
1243 pass_manager.register_pass<pass::ConstantFolding>();
1244 pass_manager.run_passes(f);
1246 ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1247 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1249 const auto new_const =
1250 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1251 ASSERT_TRUE(new_const);
1252 ASSERT_EQ(new_const->get_friendly_name(), "test");
1254 // the output shape is expected to have 'ones' at the positions specified in the reduction axes
1255 // in case the keep_dims attribute of ReduceLogicalAnd is set to true
1256 const Shape expected_out_shape{3, 1};
1257 ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1259 const auto values_out = new_const->get_vector<char>();
1261 const vector<char> values_expected{0, 0, 1};
1263 ASSERT_EQ(values_expected, values_out);
1266 TEST(constant_folding, const_reduce_logical_and__keepdims_3d)
1268 const Shape input_shape{2, 2, 2};
1270 const vector<char> values_in{1, 1, 0, 0, 1, 0, 0, 1};
1271 const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1272 const auto axes = op::Constant::create(element::i64, {2}, {0, 2});
1273 const auto convert = make_shared<op::v1::ReduceLogicalAnd>(data, axes, true);
1274 convert->set_friendly_name("test");
1275 auto f = make_shared<Function>(convert, ParameterVector{});
1277 pass::Manager pass_manager;
1278 pass_manager.register_pass<pass::ConstantFolding>();
1279 pass_manager.run_passes(f);
1281 ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1282 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1284 const auto new_const =
1285 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1286 ASSERT_TRUE(new_const);
1287 ASSERT_EQ(new_const->get_friendly_name(), "test");
1289 const Shape expected_out_shape{1, 2, 1};
1290 ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1292 const auto values_out = new_const->get_vector<char>();
1294 const vector<char> values_expected{0, 0};
1296 ASSERT_EQ(values_expected, values_out);
1299 TEST(constant_folding, const_reduce_logical_or__no_keepdims)
1301 const Shape input_shape{3, 3};
1303 const vector<char> values_in{1, 0, 0, 1, 0, 1, 0, 0, 0};
1304 const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1305 const auto axes = op::Constant::create(element::i64, {1}, {1});
1306 const auto convert = make_shared<op::v1::ReduceLogicalOr>(data, axes, false);
1307 convert->set_friendly_name("test");
1308 auto f = make_shared<Function>(convert, ParameterVector{});
1310 pass::Manager pass_manager;
1311 pass_manager.register_pass<pass::ConstantFolding>();
1312 pass_manager.run_passes(f);
1314 ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1315 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1317 const auto new_const =
1318 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1319 ASSERT_TRUE(new_const);
1320 ASSERT_EQ(new_const->get_friendly_name(), "test");
1322 const Shape expected_out_shape{3};
1323 ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1325 const auto values_out = new_const->get_vector<char>();
1327 const vector<char> values_expected{1, 1, 0};
1329 ASSERT_EQ(values_expected, values_out);
1332 TEST(constant_folding, const_concat)
1335 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1336 auto constant1 = op::Constant::create(element::i32, Shape{2, 1}, vector<int32_t>{7, 8});
1337 auto concat = make_shared<op::Concat>(NodeVector{constant0, constant1}, 1);
1338 concat->set_friendly_name("test");
1339 auto f = make_shared<Function>(concat, ParameterVector{});
1341 pass::Manager pass_manager;
1342 pass_manager.register_pass<pass::ConstantFolding>();
1343 pass_manager.run_passes(f);
1345 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1346 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1349 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1350 ASSERT_TRUE(new_const);
1351 ASSERT_EQ(new_const->get_friendly_name(), "test");
1352 auto values_out = new_const->get_vector<int32_t>();
1354 vector<int32_t> values_expected{1, 2, 3, 7, 4, 5, 6, 8};
1356 ASSERT_EQ(values_expected, values_out);
1359 TEST(constant_folding, const_concat_3d_single_elem)
1361 auto constant_1 = op::Constant::create(element::i32, Shape{1, 1, 1}, vector<int32_t>{1});
1362 auto constant_2 = op::Constant::create(element::i32, Shape{1, 1, 1}, vector<int32_t>{2});
1363 auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2}, 0);
1364 concat->set_friendly_name("test");
1365 auto f = make_shared<Function>(concat, ParameterVector{});
1367 pass::Manager pass_manager;
1368 pass_manager.register_pass<pass::ConstantFolding>();
1369 pass_manager.run_passes(f);
1371 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1372 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1375 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1377 ASSERT_TRUE(new_const);
1378 ASSERT_EQ(new_const->get_friendly_name(), "test");
1379 ASSERT_EQ(new_const->get_output_shape(0), (Shape{2, 1, 1}));
1381 auto values_out = new_const->get_vector<int32_t>();
1382 vector<int32_t> values_expected{1, 2};
1383 ASSERT_EQ(values_expected, values_out);
1386 TEST(constant_folding, const_concat_axis_2)
1389 op::Constant::create(element::i32, Shape{3, 1, 2}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1390 auto constant_2 = op::Constant::create(
1391 element::i32, Shape{3, 1, 4}, vector<int32_t>{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
1392 auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2}, 2);
1393 concat->set_friendly_name("test");
1394 auto f = make_shared<Function>(concat, ParameterVector{});
1396 pass::Manager pass_manager;
1397 pass_manager.register_pass<pass::ConstantFolding>();
1398 pass_manager.run_passes(f);
1400 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1401 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1404 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1406 ASSERT_TRUE(new_const);
1407 ASSERT_EQ(new_const->get_friendly_name(), "test");
1408 ASSERT_EQ(new_const->get_output_shape(0), (Shape{3, 1, 6}));
1410 auto values_out = new_const->get_vector<int32_t>();
1411 vector<int32_t> values_expected{1, 2, 7, 8, 9, 10, 3, 4, 11, 12, 13, 14, 5, 6, 15, 16, 17, 18};
1412 ASSERT_EQ(values_expected, values_out);
1415 TEST(constant_folding, const_concat_axis_1_bool_type)
1418 op::Constant::create(element::boolean, Shape{1, 1, 2}, vector<int32_t>{true, true});
1419 auto constant_2 = op::Constant::create(
1420 element::boolean, Shape{1, 2, 2}, vector<char>{true, false, true, false});
1421 auto constant_3 = op::Constant::create(
1422 element::boolean, Shape{1, 3, 2}, vector<char>{true, false, true, false, true, false});
1423 auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2, constant_3}, 1);
1424 concat->set_friendly_name("test");
1425 auto f = make_shared<Function>(concat, ParameterVector{});
1427 pass::Manager pass_manager;
1428 pass_manager.register_pass<pass::ConstantFolding>();
1429 pass_manager.run_passes(f);
1431 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1432 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1435 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1437 ASSERT_TRUE(new_const);
1438 ASSERT_EQ(new_const->get_friendly_name(), "test");
1439 ASSERT_EQ(new_const->get_output_shape(0), (Shape{1, 6, 2}));
1441 auto values_out = new_const->get_vector<char>();
1442 vector<char> values_expected{
1443 true, true, true, false, true, false, true, false, true, false, true, false};
1444 ASSERT_EQ(values_expected, values_out);
1447 TEST(constant_folding, const_not)
1450 op::Constant::create(element::boolean, Shape{2, 3}, vector<char>{0, 1, 0, 0, 1, 1});
1451 auto logical_not = make_shared<op::Not>(constant);
1452 logical_not->set_friendly_name("test");
1453 auto f = make_shared<Function>(logical_not, ParameterVector{});
1455 pass::Manager pass_manager;
1456 pass_manager.register_pass<pass::ConstantFolding>();
1457 pass_manager.run_passes(f);
1459 ASSERT_EQ(count_ops_of_type<op::Not>(f), 0);
1460 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1463 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1464 ASSERT_TRUE(new_const);
1465 ASSERT_EQ(new_const->get_friendly_name(), "test");
1466 auto values_out = new_const->get_vector<char>();
1468 vector<char> values_expected{1, 0, 1, 1, 0, 0};
1470 ASSERT_EQ(values_expected, values_out);
1473 TEST(constant_folding, const_equal)
1476 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1478 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 2, 3, 5, 6});
1479 auto eq = make_shared<op::Equal>(constant0, constant1);
1480 eq->set_friendly_name("test");
1481 auto f = make_shared<Function>(eq, ParameterVector{});
1483 pass::Manager pass_manager;
1484 pass_manager.register_pass<pass::ConstantFolding>();
1485 pass_manager.run_passes(f);
1487 ASSERT_EQ(count_ops_of_type<op::Equal>(f), 0);
1488 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1491 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1492 ASSERT_TRUE(new_const);
1493 ASSERT_EQ(new_const->get_friendly_name(), "test");
1494 auto values_out = new_const->get_vector<char>();
1496 vector<char> values_expected{1, 1, 0, 0, 1, 1};
1498 ASSERT_EQ(values_expected, values_out);
1501 TEST(constant_folding, const_not_equal)
1504 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1506 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 2, 3, 5, 6});
1507 auto eq = make_shared<op::NotEqual>(constant0, constant1);
1508 eq->set_friendly_name("test");
1509 auto f = make_shared<Function>(eq, ParameterVector{});
1511 pass::Manager pass_manager;
1512 pass_manager.register_pass<pass::ConstantFolding>();
1513 pass_manager.run_passes(f);
1515 ASSERT_EQ(count_ops_of_type<op::NotEqual>(f), 0);
1516 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1519 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1520 ASSERT_TRUE(new_const);
1521 ASSERT_EQ(new_const->get_friendly_name(), "test");
1522 auto values_out = new_const->get_vector<char>();
1524 vector<char> values_expected{0, 0, 1, 1, 0, 0};
1526 ASSERT_EQ(values_expected, values_out);
1529 TEST(constant_folding, const_greater)
1532 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1534 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1535 auto eq = make_shared<op::Greater>(constant0, constant1);
1536 eq->set_friendly_name("test");
1537 auto f = make_shared<Function>(eq, ParameterVector{});
1539 pass::Manager pass_manager;
1540 pass_manager.register_pass<pass::ConstantFolding>();
1541 pass_manager.run_passes(f);
1543 ASSERT_EQ(count_ops_of_type<op::Greater>(f), 0);
1544 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1547 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1548 ASSERT_TRUE(new_const);
1549 ASSERT_EQ(new_const->get_friendly_name(), "test");
1550 auto values_out = new_const->get_vector<char>();
1552 vector<char> values_expected{0, 0, 1, 0, 0, 1};
1554 ASSERT_EQ(values_expected, values_out);
1557 TEST(constant_folding, const_greater_eq)
1560 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1562 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1563 auto eq = make_shared<op::GreaterEq>(constant0, constant1);
1564 eq->set_friendly_name("test");
1565 auto f = make_shared<Function>(eq, ParameterVector{});
1567 pass::Manager pass_manager;
1568 pass_manager.register_pass<pass::ConstantFolding>();
1569 pass_manager.run_passes(f);
1571 ASSERT_EQ(count_ops_of_type<op::GreaterEq>(f), 0);
1572 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1575 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1576 ASSERT_TRUE(new_const);
1577 ASSERT_EQ(new_const->get_friendly_name(), "test");
1578 auto values_out = new_const->get_vector<char>();
1580 vector<char> values_expected{0, 1, 1, 0, 1, 1};
1582 ASSERT_EQ(values_expected, values_out);
1585 TEST(constant_folding, const_less)
1588 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1590 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1591 auto eq = make_shared<op::Less>(constant0, constant1);
1592 eq->set_friendly_name("test");
1593 auto f = make_shared<Function>(eq, ParameterVector{});
1595 pass::Manager pass_manager;
1596 pass_manager.register_pass<pass::ConstantFolding>();
1597 pass_manager.run_passes(f);
1599 ASSERT_EQ(count_ops_of_type<op::Less>(f), 0);
1600 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1603 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1604 ASSERT_TRUE(new_const);
1605 ASSERT_EQ(new_const->get_friendly_name(), "test");
1606 auto values_out = new_const->get_vector<char>();
1608 vector<char> values_expected{1, 0, 0, 1, 0, 0};
1610 ASSERT_EQ(values_expected, values_out);
1613 TEST(constant_folding, const_less_eq)
1616 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1618 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1619 auto eq = make_shared<op::LessEq>(constant0, constant1);
1620 eq->set_friendly_name("test");
1621 auto f = make_shared<Function>(eq, ParameterVector{});
1623 pass::Manager pass_manager;
1624 pass_manager.register_pass<pass::ConstantFolding>();
1625 pass_manager.run_passes(f);
1627 ASSERT_EQ(count_ops_of_type<op::LessEq>(f), 0);
1628 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1631 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1632 ASSERT_TRUE(new_const);
1633 ASSERT_EQ(new_const->get_friendly_name(), "test");
1634 auto values_out = new_const->get_vector<char>();
1636 vector<char> values_expected{1, 1, 0, 1, 1, 0};
1638 ASSERT_EQ(values_expected, values_out);
1641 TEST(constant_folding, const_or)
1644 op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 0, 1, 0, 1, 1});
1646 op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 1, 1, 1, 0, 1});
1647 auto eq = make_shared<op::Or>(constant0, constant1);
1648 eq->set_friendly_name("test");
1649 auto f = make_shared<Function>(eq, ParameterVector{});
1651 pass::Manager pass_manager;
1652 pass_manager.register_pass<pass::ConstantFolding>();
1653 pass_manager.run_passes(f);
1655 ASSERT_EQ(count_ops_of_type<op::Or>(f), 0);
1656 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1659 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1660 ASSERT_TRUE(new_const);
1661 ASSERT_EQ(new_const->get_friendly_name(), "test");
1662 auto values_out = new_const->get_vector<char>();
1664 vector<char> values_expected{0, 1, 1, 1, 1, 1};
1666 ASSERT_EQ(values_expected, values_out);
1669 TEST(constant_folding, const_xor)
1672 op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 0, 1, 0, 1, 1});
1674 op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 1, 1, 1, 0, 1});
1675 auto eq = make_shared<op::Xor>(constant0, constant1);
1676 eq->set_friendly_name("test");
1677 auto f = make_shared<Function>(eq, ParameterVector{});
1679 pass::Manager pass_manager;
1680 pass_manager.register_pass<pass::ConstantFolding>();
1681 pass_manager.run_passes(f);
1683 ASSERT_EQ(count_ops_of_type<op::Xor>(f), 0);
1684 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1687 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1688 ASSERT_TRUE(new_const);
1689 ASSERT_EQ(new_const->get_friendly_name(), "test");
1690 auto values_out = new_const->get_vector<char>();
1692 vector<char> values_expected{0, 1, 0, 1, 1, 0};
1694 ASSERT_EQ(values_expected, values_out);
1697 TEST(constant_folding, const_ceiling)
1699 auto constant = op::Constant::create(
1700 element::f32, Shape{2, 3}, vector<float>{0.0f, 0.1f, -0.1f, -2.5f, 2.5f, 3.0f});
1701 auto ceil = make_shared<op::Ceiling>(constant);
1702 ceil->set_friendly_name("test");
1703 auto f = make_shared<Function>(ceil, ParameterVector{});
1705 pass::Manager pass_manager;
1706 pass_manager.register_pass<pass::ConstantFolding>();
1707 pass_manager.run_passes(f);
1709 ASSERT_EQ(count_ops_of_type<op::Ceiling>(f), 0);
1710 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1713 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1714 ASSERT_TRUE(new_const);
1715 ASSERT_EQ(new_const->get_friendly_name(), "test");
1716 auto values_out = new_const->get_vector<float>();
1718 vector<float> values_expected{0.0f, 1.0f, 0.0f, -2.0f, 3.0f, 3.0f};
1720 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1723 TEST(constant_folding, const_floor)
1725 auto constant = op::Constant::create(
1726 element::f32, Shape{2, 3}, vector<float>{0.0f, 0.1f, -0.1f, -2.5f, 2.5f, 3.0f});
1727 auto floor = make_shared<op::Floor>(constant);
1728 floor->set_friendly_name("test");
1729 auto f = make_shared<Function>(floor, ParameterVector{});
1731 pass::Manager pass_manager;
1732 pass_manager.register_pass<pass::ConstantFolding>();
1733 pass_manager.run_passes(f);
1735 ASSERT_EQ(count_ops_of_type<op::Floor>(f), 0);
1736 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1739 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1740 ASSERT_TRUE(new_const);
1741 ASSERT_EQ(new_const->get_friendly_name(), "test");
1742 auto values_out = new_const->get_vector<float>();
1744 vector<float> values_expected{0.0f, 0.0f, -1.0f, -3.0f, 2.0f, 3.0f};
1746 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1749 TEST(constant_folding, const_gather_v1)
1751 auto constant_data = op::Constant::create(
1754 vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
1755 auto constant_indices =
1756 op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
1757 auto constant_axis = op::Constant::create(element::i64, Shape{1}, vector<int64_t>{1});
1758 auto gather = make_shared<op::v1::Gather>(constant_data, constant_indices, constant_axis);
1759 gather->set_friendly_name("test");
1760 auto f = make_shared<Function>(gather, ParameterVector{});
1762 pass::Manager pass_manager;
1763 pass_manager.register_pass<pass::ConstantFolding>();
1764 pass_manager.run_passes(f);
1766 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1767 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1770 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1771 ASSERT_TRUE(new_const);
1772 ASSERT_EQ(new_const->get_friendly_name(), "test");
1773 auto values_out = new_const->get_vector<float>();
1775 vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
1777 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1780 TEST(constant_folding, const_gather_v1_scalar)
1782 auto constant_data = op::Constant::create(
1785 vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
1786 auto constant_indices =
1787 op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
1788 auto constant_axis = op::Constant::create(element::i64, Shape{}, vector<int64_t>{1});
1789 auto gather = make_shared<op::v1::Gather>(constant_data, constant_indices, constant_axis);
1790 gather->set_friendly_name("test");
1791 auto f = make_shared<Function>(gather, ParameterVector{});
1793 pass::Manager pass_manager;
1794 pass_manager.register_pass<pass::ConstantFolding>();
1795 pass_manager.run_passes(f);
1797 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1798 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1801 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1802 ASSERT_TRUE(new_const);
1803 ASSERT_EQ(new_const->get_friendly_name(), "test");
1804 auto values_out = new_const->get_vector<float>();
1806 vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
1808 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1811 TEST(constant_folding, const_gather_v1_subgraph)
1813 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1814 const float b_value = 3.21f;
1815 const auto B_const = op::Constant::create(element::f32, {1}, {b_value});
1816 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
1817 const int64_t axis = 0;
1818 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1820 const auto concat = make_shared<op::Concat>(NodeVector{A, B_const, C}, axis);
1822 const vector<int64_t> indices{1};
1823 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1824 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1825 gather->set_friendly_name("test");
1826 auto f = make_shared<Function>(gather, ParameterVector{A, C});
1828 pass::Manager pass_manager;
1829 pass_manager.register_pass<pass::ConstantFolding>();
1830 pass_manager.run_passes(f);
1832 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1833 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1834 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1836 const auto new_const =
1837 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1838 ASSERT_TRUE(new_const);
1839 ASSERT_EQ(new_const->get_friendly_name(), "test");
1841 const auto values_out = new_const->get_vector<float>();
1842 ASSERT_TRUE(test::all_close_f(values_out, {b_value}, MIN_FLOAT_TOLERANCE_BITS));
1845 TEST(constant_folding, const_gather_v1_subgraph_neg_axis)
1847 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1848 const float b_value = 1.23f;
1849 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
1850 const auto C_const = op::Constant::create(element::f32, {1}, {b_value});
1851 const int64_t axis = 0;
1852 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1854 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C_const}, axis);
1856 const vector<int64_t> indices{-1};
1857 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1858 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1859 gather->set_friendly_name("test");
1860 auto f = make_shared<Function>(gather, ParameterVector{A, B});
1862 pass::Manager pass_manager;
1863 pass_manager.register_pass<pass::ConstantFolding>();
1864 pass_manager.run_passes(f);
1866 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1867 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1868 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1870 const auto new_const =
1871 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1872 ASSERT_TRUE(new_const);
1873 ASSERT_EQ(new_const->get_friendly_name(), "test");
1875 const auto values_out = new_const->get_vector<float>();
1876 ASSERT_TRUE(test::all_close_f(values_out, {b_value}, MIN_FLOAT_TOLERANCE_BITS));
1879 TEST(constant_folding, const_gather_v1_subgraph_no_constant_input)
1881 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1882 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
1883 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
1884 const int64_t axis = 0;
1885 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1887 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
1889 const vector<int64_t> indices{1};
1890 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1891 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1892 gather->set_friendly_name("test");
1893 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
1895 pass::Manager pass_manager;
1896 pass_manager.register_pass<pass::ConstantFolding>();
1897 pass_manager.run_passes(f);
1899 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1900 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1903 TEST(constant_folding, const_gather_v1_subgraph_no_constant_input_scalar)
1905 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1906 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
1907 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
1908 const int64_t axis = 0;
1909 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1911 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
1913 const vector<int64_t> indices{1};
1914 const auto indices_const = op::Constant::create(element::i64, {}, indices);
1915 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1916 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
1918 pass::Manager pass_manager;
1919 pass_manager.register_pass<pass::ConstantFolding>();
1920 pass_manager.run_passes(f);
1922 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1923 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1924 ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(f), 1);
1927 TEST(constant_folding, const_gather_v1_subgraph_skip_if_non_zero_axis)
1929 const auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2});
1930 const auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2});
1931 const auto C = make_shared<op::Parameter>(element::f32, Shape{2, 2});
1932 const int64_t axis = 1;
1933 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1935 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
1937 const vector<int64_t> indices{1};
1938 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1939 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1940 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
1942 pass::Manager pass_manager;
1943 pass_manager.register_pass<pass::ConstantFolding>();
1944 pass_manager.run_passes(f);
1946 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
1947 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
1950 TEST(constant_folding, const_gather_v1_subgraph_skip_if_non_single_indices)
1952 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1953 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
1954 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
1955 const int64_t axis = 0;
1956 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1958 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
1960 const vector<int64_t> indices{0, 1};
1961 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1962 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1963 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
1965 pass::Manager pass_manager;
1966 pass_manager.register_pass<pass::ConstantFolding>();
1967 pass_manager.run_passes(f);
1969 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
1970 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
1973 TEST(constant_folding, const_gather_v1_subgraph_skip_if_concat_output_shape_dynamic)
1975 const auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
1976 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
1977 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
1978 const int64_t axis = 0;
1979 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1981 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
1983 const vector<int64_t> indices{1};
1984 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1985 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1986 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
1988 pass::Manager pass_manager;
1989 pass_manager.register_pass<pass::ConstantFolding>();
1990 pass_manager.run_passes(f);
1992 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
1993 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
1996 TEST(constant_folding, const_gather_v1_subgraph_skip_if_not_single_input)
1998 const auto A = make_shared<op::Parameter>(element::f32, Shape{2});
1999 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
2000 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
2001 const int64_t axis = 0;
2002 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2004 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
2006 const vector<int64_t> indices{1};
2007 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
2008 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2009 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2011 pass::Manager pass_manager;
2012 pass_manager.register_pass<pass::ConstantFolding>();
2013 pass_manager.run_passes(f);
2015 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
2016 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
2019 TEST(constant_folding, const_slice)
2023 vector<int> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
2024 auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
2025 auto slice = make_shared<op::Slice>(constant, Coordinate{2}, Coordinate{15}, Strides{3});
2026 slice->set_friendly_name("test");
2028 auto f = make_shared<Function>(slice, ParameterVector{});
2030 pass::Manager pass_manager;
2031 pass_manager.register_pass<pass::ConstantFolding>();
2032 pass_manager.run_passes(f);
2034 ASSERT_EQ(count_ops_of_type<op::Slice>(f), 0);
2035 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2038 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2039 ASSERT_TRUE(new_const);
2040 ASSERT_EQ(new_const->get_friendly_name(), "test");
2041 auto values_out = new_const->get_vector<int>();
2043 vector<int> sliced_values{3, 6, 9, 12, 15};
2044 ASSERT_EQ(sliced_values, values_out);
2047 TEST(constant_folding, constant_dyn_reshape)
2049 Shape shape_in{2, 4};
2050 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
2052 Shape shape_shape{3};
2053 vector<int64_t> values_shape{2, 4, 1};
2055 auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
2056 auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
2057 auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape, false);
2058 dyn_reshape->set_friendly_name("test");
2059 auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
2061 pass::Manager pass_manager;
2062 pass_manager.register_pass<pass::ConstantFolding>();
2063 pass_manager.run_passes(f);
2065 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
2066 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2069 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2070 ASSERT_TRUE(new_const);
2071 ASSERT_EQ(new_const->get_friendly_name(), "test");
2072 auto values_out = new_const->get_vector<float>();
2074 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
2077 TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
2079 Shape shape_in{2, 4};
2080 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
2082 Shape shape_shape{3};
2083 // We're going to add these two together elementwise to get {2, 4, 1}.
2084 // This means that when ConstantFolding starts, v1::Reshape will not yet
2085 // have static output shape. But by the time the Add op is folded, the
2086 // v1::Reshape's shape should be inferrable.
2087 vector<int64_t> values_shape_a{1, 3, 0};
2088 vector<int64_t> values_shape_b{1, 1, 1};
2090 auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
2091 auto constant_shape_a = make_shared<op::Constant>(element::i64, shape_shape, values_shape_a);
2092 auto constant_shape_b = make_shared<op::Constant>(element::i64, shape_shape, values_shape_b);
2094 make_shared<op::v1::Reshape>(constant_in, constant_shape_a + constant_shape_b, false);
2095 dyn_reshape->set_friendly_name("test");
2096 auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
2098 ASSERT_TRUE(dyn_reshape->get_output_partial_shape(0).is_dynamic());
2100 pass::Manager pass_manager;
2101 pass_manager.register_pass<pass::ConstantFolding>();
2102 pass_manager.run_passes(f);
2104 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
2105 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2108 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2109 ASSERT_TRUE(new_const);
2110 ASSERT_EQ(new_const->get_friendly_name(), "test");
2111 auto values_out = new_const->get_vector<float>();
2113 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
2116 TEST(constant_folding, constant_transpose)
2118 Shape shape_in{2, 4};
2119 vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};
2121 Shape shape_perm{2};
2122 vector<int64_t> values_perm{1, 0};
2124 auto constant_in = make_shared<op::Constant>(element::f64, shape_in, values_in);
2125 auto constant_perm = make_shared<op::Constant>(element::i64, shape_perm, values_perm);
2126 auto transpose = make_shared<op::Transpose>(constant_in, constant_perm);
2127 transpose->set_friendly_name("test");
2128 auto f = make_shared<Function>(transpose, ParameterVector{});
2130 pass::Manager pass_manager;
2131 pass_manager.register_pass<pass::ConstantFolding>();
2132 pass_manager.run_passes(f);
2134 ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 0);
2135 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2138 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2139 ASSERT_TRUE(new_const);
2140 ASSERT_EQ(new_const->get_friendly_name(), "test");
2141 auto values_out = new_const->get_vector<double>();
2143 vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
2144 ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
2147 template <typename T>
2148 void range_test(T start, T stop, T step, const vector<T>& values_expected)
2150 vector<T> values_start{start};
2151 vector<T> values_stop{stop};
2152 vector<T> values_step{step};
2154 auto constant_start = make_shared<op::Constant>(element::from<T>(), Shape{}, values_start);
2155 auto constant_stop = make_shared<op::Constant>(element::from<T>(), Shape{}, values_stop);
2156 auto constant_step = make_shared<op::Constant>(element::from<T>(), Shape{}, values_step);
2157 auto range = make_shared<op::Range>(constant_start, constant_stop, constant_step);
2158 range->set_friendly_name("test");
2159 auto f = make_shared<Function>(range, ParameterVector{});
2161 pass::Manager pass_manager;
2162 pass_manager.register_pass<pass::ConstantFolding>();
2163 pass_manager.run_passes(f);
2165 ASSERT_EQ(count_ops_of_type<op::Range>(f), 0);
2166 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2169 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2170 ASSERT_TRUE(new_const);
2171 ASSERT_EQ(new_const->get_friendly_name(), "test");
2173 auto values_out = new_const->template get_vector<T>();
2175 range_test_check(values_out, values_expected);
2178 TEST(constant_folding, constant_range)
2180 range_test<int8_t>(5, 12, 2, {5, 7, 9, 11});
2181 range_test<int32_t>(5, 12, 2, {5, 7, 9, 11});
2182 range_test<int64_t>(5, 12, 2, {5, 7, 9, 11});
2183 range_test<uint64_t>(5, 12, 2, {5, 7, 9, 11});
2184 range_test<double>(5, 12, 2, {5, 7, 9, 11});
2185 range_test<float>(5, 12, 2, {5, 7, 9, 11});
2187 range_test<int32_t>(5, 12, -2, {});
2188 range_test<float>(12, 4, -2, {12, 10, 8, 6});
2191 TEST(constant_folding, constant_select)
2194 vector<char> values_selection{0, 1, 1, 0, 1, 0, 0, 1};
2195 vector<int64_t> values_t{2, 4, 6, 8, 10, 12, 14, 16};
2196 vector<int64_t> values_f{1, 3, 5, 7, 9, 11, 13, 15};
2198 auto constant_selection = make_shared<op::Constant>(element::boolean, shape, values_selection);
2199 auto constant_t = make_shared<op::Constant>(element::i64, shape, values_t);
2200 auto constant_f = make_shared<op::Constant>(element::i64, shape, values_f);
2201 auto select = make_shared<op::Select>(constant_selection, constant_t, constant_f);
2202 select->set_friendly_name("test");
2203 auto f = make_shared<Function>(select, ParameterVector{});
2205 pass::Manager pass_manager;
2206 pass_manager.register_pass<pass::ConstantFolding>();
2207 pass_manager.run_passes(f);
2209 ASSERT_EQ(count_ops_of_type<op::Select>(f), 0);
2210 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2213 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2214 ASSERT_TRUE(new_const);
2215 ASSERT_EQ(new_const->get_friendly_name(), "test");
2216 auto values_out = new_const->get_vector<int64_t>();
2218 vector<int64_t> values_expected{1, 4, 6, 7, 10, 11, 13, 16};
2219 ASSERT_EQ(values_expected, values_out);
2222 TEST(constant_folding, constant_v1_select)
2225 vector<char> values_selection{0, 1, 1, 0};
2226 vector<int64_t> values_t{1, 2, 3, 4};
2227 vector<int64_t> values_f{11, 12, 13, 14, 15, 16, 17, 18};
2229 auto constant_selection =
2230 make_shared<op::Constant>(element::boolean, Shape{4}, values_selection);
2231 auto constant_t = make_shared<op::Constant>(element::i64, Shape{4}, values_t);
2232 auto constant_f = make_shared<op::Constant>(element::i64, Shape{2, 4}, values_f);
2233 auto select = make_shared<op::v1::Select>(constant_selection, constant_t, constant_f);
2234 select->set_friendly_name("test");
2235 auto f = make_shared<Function>(select, ParameterVector{});
2237 pass::Manager pass_manager;
2238 pass_manager.register_pass<pass::ConstantFolding>();
2239 pass_manager.run_passes(f);
2241 ASSERT_EQ(count_ops_of_type<op::Select>(f), 0);
2242 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2245 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2246 ASSERT_TRUE(new_const);
2247 ASSERT_EQ(new_const->get_friendly_name(), "test");
2248 auto values_out = new_const->get_vector<int64_t>();
2250 vector<int64_t> values_expected{11, 2, 3, 14, 15, 2, 3, 18};
2251 ASSERT_EQ(values_expected, values_out);
2254 TEST(constant_folding, constant_v1_split)
2256 vector<float> data{.1f, .2f, .3f, .4f, .5f, .6f};
2257 const auto const_data = op::Constant::create(element::f32, Shape{data.size()}, data);
2258 const auto const_axis = op::Constant::create(element::i64, Shape{}, {0});
2259 const auto num_splits = 3;
2261 auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
2262 auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2264 pass::Manager pass_manager;
2265 pass_manager.register_pass<pass::ConstantFolding>();
2266 pass_manager.run_passes(f);
2268 ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2269 ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
2272 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2274 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2276 as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
2281 auto res1_values = res1->get_vector<float>();
2282 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin(), data.begin() + 2), res1_values));
2283 auto res2_values = res2->get_vector<float>();
2284 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 2, data.begin() + 4), res2_values));
2285 auto res3_values = res3->get_vector<float>();
2286 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 4, data.end()), res3_values));
2289 TEST(constant_folding, constant_v1_split_specialized)
2291 vector<float> data{.1f, .2f, .3f, .4f, .5f, .6f};
2292 const auto const_data = op::Constant::create(element::f32, Shape{data.size()}, data);
2293 const auto const_axis = op::Constant::create(element::i64, Shape{}, {0});
2294 const auto num_splits = 3;
2296 auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
2297 auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2299 pass::Manager pass_manager;
2300 pass_manager.register_pass<pass::ConstantFolding>();
2301 pass_manager.run_passes(f);
2303 ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2304 ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
2307 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2309 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2311 as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
2316 auto res1_values = res1->get_vector<float>();
2317 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin(), data.begin() + 2), res1_values));
2318 auto res2_values = res2->get_vector<float>();
2319 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 2, data.begin() + 4), res2_values));
2320 auto res3_values = res3->get_vector<float>();
2321 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 4, data.end()), res3_values));
2324 TEST(constant_folding, constant_v1_split_axis_1_4_splits)
2326 vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
2328 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2330 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2332 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2334 const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2335 const auto const_axis = op::Constant::create(element::i64, Shape{}, {1});
2336 const auto num_splits = 4;
2338 auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
2339 split_v1->set_friendly_name("test");
2340 auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2342 pass::Manager pass_manager;
2343 pass_manager.register_pass<pass::ConstantFolding>();
2344 pass_manager.run_passes(f);
2346 ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2347 ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
2350 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2352 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2354 as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
2356 as_type_ptr<op::Constant>(f->get_results().at(3)->input_value(0).get_node_shared_ptr());
2358 ASSERT_EQ(res1->get_friendly_name(), "test.0");
2360 ASSERT_EQ(res2->get_friendly_name(), "test.1");
2362 ASSERT_EQ(res3->get_friendly_name(), "test.2");
2364 ASSERT_EQ(res4->get_friendly_name(), "test.3");
2366 auto res1_values = res1->get_vector<int64_t>();
2367 ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51}),
2369 auto res2_values = res2->get_vector<int64_t>();
2370 ASSERT_EQ(vector<int64_t>({4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55}),
2372 auto res3_values = res3->get_vector<int64_t>();
2373 ASSERT_EQ(vector<int64_t>({8, 9, 10, 11, 24, 25, 26, 27, 40, 41, 42, 43, 56, 57, 58, 59}),
2375 auto res4_values = res4->get_vector<int64_t>();
2376 ASSERT_EQ(vector<int64_t>({12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63}),
2380 TEST(constant_folding, constant_v1_split_axis_1_2_splits)
2382 vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
2384 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2386 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2388 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2390 const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2391 const auto const_axis = op::Constant::create(element::i64, Shape{}, {1});
2392 const auto num_splits = 2;
2394 auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
2395 auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2397 pass::Manager pass_manager;
2398 pass_manager.register_pass<pass::ConstantFolding>();
2399 pass_manager.run_passes(f);
2401 ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2402 ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
2405 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2407 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2411 auto res1_values = res1->get_vector<int64_t>();
2412 ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
2413 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55}),
2415 auto res2_values = res2->get_vector<int64_t>();
2416 ASSERT_EQ(vector<int64_t>({8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31,
2417 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63}),
2421 TEST(constant_folding, constant_v1_variadic_split_axis_1_2_splits)
2423 vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
2425 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2427 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2429 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2431 const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2432 const auto const_axis = op::Constant::create(element::i16, Shape{}, {1});
2433 vector<int64_t> values_lengths{3, 1};
2434 auto constant_lengths =
2435 make_shared<op::Constant>(element::i64, Shape{values_lengths.size()}, values_lengths);
2437 auto variadic_split_v1 =
2438 make_shared<op::v1::VariadicSplit>(const_data, const_axis, constant_lengths);
2439 auto f = make_shared<Function>(variadic_split_v1->outputs(), ParameterVector{});
2441 pass::Manager pass_manager;
2442 pass_manager.register_pass<pass::ConstantFolding>();
2443 pass_manager.run_passes(f);
2445 ASSERT_EQ(count_ops_of_type<op::v1::VariadicSplit>(f), 0);
2446 ASSERT_EQ(count_ops_of_type<op::Constant>(f), values_lengths.size());
2449 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2451 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2455 auto res1_values = res1->get_vector<int64_t>();
2456 ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19,
2457 20, 21, 22, 23, 24, 25, 26, 27, 32, 33, 34, 35, 36, 37, 38, 39,
2458 40, 41, 42, 43, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59}),
2460 auto res2_values = res2->get_vector<int64_t>();
2461 ASSERT_EQ(vector<int64_t>({12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63}),
2465 TEST(constant_folding, constant_v1_variadic_split_axis_1_3_splits_neg_length)
2467 vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
2469 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2471 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2473 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2475 const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2476 const auto const_axis = op::Constant::create(element::i32, Shape{}, {1});
2477 vector<int64_t> values_lengths{1, 1, -1};
2478 auto constant_lengths =
2479 make_shared<op::Constant>(element::i64, Shape{values_lengths.size()}, values_lengths);
2481 auto variadic_split_v1 =
2482 make_shared<op::v1::VariadicSplit>(const_data, const_axis, constant_lengths);
2483 auto f = make_shared<Function>(variadic_split_v1->outputs(), ParameterVector{});
2485 pass::Manager pass_manager;
2486 pass_manager.register_pass<pass::ConstantFolding>();
2487 pass_manager.run_passes(f);
2489 ASSERT_EQ(count_ops_of_type<op::v1::VariadicSplit>(f), 0);
2490 ASSERT_EQ(count_ops_of_type<op::Constant>(f), values_lengths.size());
2493 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2495 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2497 as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
2502 auto res1_values = res1->get_vector<int64_t>();
2503 ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51}),
2505 auto res2_values = res2->get_vector<int64_t>();
2506 ASSERT_EQ(vector<int64_t>({4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55}),
2508 auto res3_values = res3->get_vector<int64_t>();
2509 ASSERT_EQ(vector<int64_t>({8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31,
2510 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63}),
// Constant folding of v1::OneHot with f16 on/off values: indices {0, 1, 2} with
// depth 3 must fold into a single Constant of shape {3, 3}.
2514 TEST(constant_folding, constant_v1_one_hot)
2516 vector<int64_t> indices{0, 1, 2};
// The expected vector below is built from these same float16 objects, so the
// comparison uses the exact f16-rounded on/off values.
2517 float16 on_value = 1.123f;
2518 float16 off_value = 0.321f;
2520 const auto indices_const = op::Constant::create(element::i64, Shape{3}, indices);
2521 const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
2522 const auto on_const = op::Constant::create(element::f16, Shape{}, {on_value});
2523 const auto off_const = op::Constant::create(element::f16, Shape{}, {off_value});
// NOTE(review): the `axis` and `one_hot_v1` declarations are on lines not shown
// in this view; confirm against the full file.
2527 make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
2528 auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
2530 pass::Manager pass_manager;
2531 pass_manager.register_pass<pass::ConstantFolding>();
2532 pass_manager.run_passes(f);
// Folding must remove the OneHot op and leave exactly one Constant.
2534 ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
2535 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// `res` is the Constant feeding the Function's first Result.
2538 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2541 ASSERT_EQ((Shape{3, 3}), res->get_output_shape(0));
2542 ASSERT_EQ(vector<float16>({on_value,
2551 res->get_vector<float16>());
2554 TEST(constant_folding, constant_v1_one_hot_negative_axes)
2556 vector<int64_t> indices{0, 2, -1, 1};
2557 int16_t on_value = 4;
2558 int16_t off_value = 1;
2560 const auto indices_const = op::Constant::create(element::i64, Shape{4}, indices);
2561 const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
2562 const auto on_const = op::Constant::create(element::i16, Shape{}, {on_value});
2563 const auto off_const = op::Constant::create(element::i16, Shape{}, {off_value});
2567 make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
2568 auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
2570 pass::Manager pass_manager;
2571 pass_manager.register_pass<pass::ConstantFolding>();
2572 pass_manager.run_passes(f);
2574 ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
2575 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2578 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2581 ASSERT_EQ((Shape{4, 3}), res->get_output_shape(0));
2582 ASSERT_EQ(vector<int16_t>({on_value,
2594 res->get_vector<int16_t>());
2597 TEST(constant_folding, constant_v1_one_hot_negative_axes_2)
2599 vector<int64_t> indices{0, 2, 1, -1};
2600 auto on_value = true;
2601 auto off_value = false;
2603 const auto indices_const = op::Constant::create(element::i64, Shape{2, 2}, indices);
2604 const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
2605 const auto on_const = op::Constant::create(element::boolean, Shape{}, {on_value});
2606 const auto off_const = op::Constant::create(element::boolean, Shape{}, {off_value});
2610 make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
2611 one_hot_v1->set_friendly_name("test");
2612 auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
2614 pass::Manager pass_manager;
2615 pass_manager.register_pass<pass::ConstantFolding>();
2616 pass_manager.run_passes(f);
2618 ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
2619 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2622 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2624 ASSERT_EQ(res->get_friendly_name(), "test");
2626 ASSERT_EQ((Shape{2, 2, 3}), res->get_output_shape(0));
2627 ASSERT_EQ(vector<bool>({on_value,
2639 res->get_vector<bool>());
2642 TEST(constant_folding, constant_tile_1d)
2645 Shape shape_repeats{1};
2648 vector<int> values_in{0, 1};
2649 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2650 vector<int> values_repeats{2};
2651 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2652 auto tile = make_shared<op::v0::Tile>(data, repeats);
2653 tile->set_friendly_name("test");
2654 auto f = make_shared<Function>(tile, ParameterVector{});
2656 pass::Manager pass_manager;
2657 pass_manager.register_pass<pass::ConstantFolding>();
2658 pass_manager.run_passes(f);
2660 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2661 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2664 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2665 ASSERT_TRUE(new_const);
2666 ASSERT_EQ(new_const->get_friendly_name(), "test");
2667 auto values_out = new_const->get_vector<int>();
2669 vector<int> values_expected{0, 1, 0, 1};
2670 ASSERT_EQ(values_expected, values_out);
2673 TEST(constant_folding, constant_tile_3d_small_data_rank)
2676 Shape shape_repeats{3};
2677 Shape shape_out{2, 2, 4};
2679 vector<int> values_in{0, 1};
2680 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2681 vector<int> values_repeats{2, 2, 2};
2682 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2683 auto tile = make_shared<op::v0::Tile>(data, repeats);
2684 tile->set_friendly_name("test");
2685 auto f = make_shared<Function>(tile, ParameterVector{});
2687 pass::Manager pass_manager;
2688 pass_manager.register_pass<pass::ConstantFolding>();
2689 pass_manager.run_passes(f);
2691 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2692 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2695 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2696 ASSERT_TRUE(new_const);
2697 ASSERT_EQ(new_const->get_friendly_name(), "test");
2698 auto values_out = new_const->get_vector<int>();
2700 vector<int> values_expected{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1};
2701 ASSERT_EQ(values_expected, values_out);
2704 TEST(constant_folding, constant_tile_3d_few_repeats)
2706 Shape shape_in{2, 1, 3};
2707 Shape shape_repeats{2};
2708 Shape shape_out{2, 2, 3};
2710 vector<int> values_in{1, 2, 3, 4, 5, 6};
2711 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2712 vector<int> values_repeats{2, 1};
2713 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2714 auto tile = make_shared<op::v0::Tile>(data, repeats);
2715 tile->set_friendly_name("test");
2716 auto f = make_shared<Function>(tile, ParameterVector{});
2718 pass::Manager pass_manager;
2719 pass_manager.register_pass<pass::ConstantFolding>();
2720 pass_manager.run_passes(f);
2722 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2723 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2726 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2727 ASSERT_TRUE(new_const);
2728 ASSERT_EQ(new_const->get_friendly_name(), "test");
2729 auto values_out = new_const->get_vector<int>();
2731 vector<int> values_expected{1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6};
2732 ASSERT_EQ(values_expected, values_out);
2735 TEST(constant_folding, constant_tile_1d_0_repeats)
2738 Shape shape_repeats{1};
2741 vector<int> values_in{0, 1};
2742 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2743 vector<int> values_repeats{0};
2744 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2745 auto tile = make_shared<op::v0::Tile>(data, repeats);
2746 tile->set_friendly_name("test");
2747 auto f = make_shared<Function>(tile, ParameterVector{});
2749 pass::Manager pass_manager;
2750 pass_manager.register_pass<pass::ConstantFolding>();
2751 pass_manager.run_passes(f);
2753 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2754 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2757 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2758 ASSERT_TRUE(new_const);
2759 ASSERT_EQ(new_const->get_friendly_name(), "test");
2760 auto values_out = new_const->get_vector<int>();
2762 vector<int> values_expected{};
2763 ASSERT_EQ(values_expected, values_out);
2766 TEST(constant_folding, constant_tile_0_rank_data)
2769 Shape shape_repeats{1};
2772 vector<int> values_in{1};
2773 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2774 vector<int> values_repeats{4};
2775 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2776 auto tile = make_shared<op::v0::Tile>(data, repeats);
2777 tile->set_friendly_name("test");
2778 auto f = make_shared<Function>(tile, ParameterVector{});
2780 pass::Manager pass_manager;
2781 pass_manager.register_pass<pass::ConstantFolding>();
2782 pass_manager.run_passes(f);
2784 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2785 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2788 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2789 ASSERT_TRUE(new_const);
2790 ASSERT_EQ(new_const->get_friendly_name(), "test");
2791 auto values_out = new_const->get_vector<int>();
2793 vector<int> values_expected{1, 1, 1, 1};
2794 ASSERT_EQ(values_expected, values_out);
2797 TEST(constant_folding, constant_non_zero_0D)
2799 auto data = op::Constant::create(element::i32, Shape{}, {1});
2800 auto non_zero = make_shared<op::v3::NonZero>(data);
2801 non_zero->set_friendly_name("test");
2802 auto f = make_shared<Function>(non_zero, ParameterVector{});
2804 pass::Manager pass_manager;
2805 pass_manager.register_pass<pass::ConstantFolding>();
2806 pass_manager.run_passes(f);
2808 // Fold into constant with shape of {1, 1} for scalar input with
2810 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2811 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2813 const auto new_const =
2814 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2815 ASSERT_TRUE(new_const);
2816 ASSERT_EQ(new_const->get_friendly_name(), "test");
2817 const auto values_out = new_const->get_vector<int64_t>();
2819 const vector<int64_t> values_expected{0};
2820 ASSERT_EQ(values_expected, values_out);
2821 ASSERT_EQ((Shape{1, 1}), new_const->get_shape());
2824 TEST(constant_folding, constant_non_zero_1D)
2826 vector<int> values_in{0, 1, 0, 1};
2827 auto data = make_shared<op::Constant>(element::i32, Shape{4}, values_in);
2828 auto non_zero = make_shared<op::v3::NonZero>(data);
2829 non_zero->set_friendly_name("test");
2830 auto f = make_shared<Function>(non_zero, ParameterVector{});
2832 pass::Manager pass_manager;
2833 pass_manager.register_pass<pass::ConstantFolding>();
2834 pass_manager.run_passes(f);
2836 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2837 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2839 const auto new_const =
2840 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2841 ASSERT_TRUE(new_const);
2842 ASSERT_EQ(new_const->get_friendly_name(), "test");
2843 const auto values_out = new_const->get_vector<int64_t>();
2845 const vector<int64_t> values_expected{1, 3};
2846 ASSERT_EQ(values_expected, values_out);
2847 ASSERT_EQ((Shape{1, 2}), new_const->get_shape());
2850 TEST(constant_folding, constant_non_zero_int32_output_type)
2852 vector<int> values_in{0, 1, 0, 1};
2853 auto data = make_shared<op::Constant>(element::i32, Shape{4}, values_in);
2854 auto non_zero = make_shared<op::v3::NonZero>(data, element::i32);
2855 non_zero->set_friendly_name("test");
2856 auto f = make_shared<Function>(non_zero, ParameterVector{});
2858 pass::Manager pass_manager;
2859 pass_manager.register_pass<pass::ConstantFolding>();
2860 pass_manager.run_passes(f);
2862 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2863 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2865 const auto new_const =
2866 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2867 ASSERT_TRUE(new_const);
2868 ASSERT_EQ(new_const->get_friendly_name(), "test");
2869 ASSERT_EQ(element::i32, new_const->get_element_type());
2870 const auto values_out = new_const->get_vector<int32_t>();
2872 const vector<int32_t> values_expected{1, 3};
2873 ASSERT_EQ(values_expected, values_out);
2874 ASSERT_EQ((Shape{1, 2}), new_const->get_shape());
2877 TEST(constant_folding, constant_non_zero_1D_all_indices)
2879 const vector<float> values_in{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
2880 const auto data = make_shared<op::Constant>(element::f32, Shape{values_in.size()}, values_in);
2881 const auto non_zero = make_shared<op::v3::NonZero>(data);
2882 non_zero->set_friendly_name("test");
2883 auto f = make_shared<Function>(non_zero, ParameterVector{});
2885 pass::Manager pass_manager;
2886 pass_manager.register_pass<pass::ConstantFolding>();
2887 pass_manager.run_passes(f);
2889 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2890 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2892 const auto new_const =
2893 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2894 ASSERT_TRUE(new_const);
2895 ASSERT_EQ(new_const->get_friendly_name(), "test");
2896 const auto values_out = new_const->get_vector<int64_t>();
2898 const vector<int64_t> values_expected{0, 1, 2, 3, 4, 5, 6, 7};
2899 ASSERT_EQ(values_expected, values_out);
2900 ASSERT_EQ((Shape{1, values_in.size()}), new_const->get_shape());
2903 TEST(constant_folding, constant_non_zero_2D)
2905 vector<int> values_in{1, 0, 0, 0, 1, 0, 1, 1, 0};
2906 auto data = make_shared<op::Constant>(element::i32, Shape{3, 3}, values_in);
2907 auto non_zero = make_shared<op::v3::NonZero>(data);
2908 non_zero->set_friendly_name("test");
2909 auto f = make_shared<Function>(non_zero, ParameterVector{});
2911 pass::Manager pass_manager;
2912 pass_manager.register_pass<pass::ConstantFolding>();
2913 pass_manager.run_passes(f);
2915 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2916 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2918 const auto new_const =
2919 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2920 ASSERT_TRUE(new_const);
2921 ASSERT_EQ(new_const->get_friendly_name(), "test");
2922 const auto values_out = new_const->get_vector<int64_t>();
2924 const vector<int64_t> values_expected{0, 1, 2, 2, 0, 1, 0, 1};
2925 ASSERT_EQ(values_expected, values_out);
2926 ASSERT_EQ((Shape{2, 4}), new_const->get_shape());
2929 TEST(constant_folding, DISABLED_constant_non_zero_2D_all_indices)
2931 const vector<int8_t> values_in{1, 1, 1, 1, 1, 1, 1, 1, 1};
2932 const auto data = make_shared<op::Constant>(element::i8, Shape{3, 3}, values_in);
2933 const auto non_zero = make_shared<op::v3::NonZero>(data);
2934 non_zero->set_friendly_name("test");
2935 auto f = make_shared<Function>(non_zero, ParameterVector{});
2937 pass::Manager pass_manager;
2938 pass_manager.register_pass<pass::ConstantFolding>();
2939 pass_manager.run_passes(f);
2941 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2942 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2944 const auto new_const =
2945 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2946 ASSERT_TRUE(new_const);
2947 ASSERT_EQ(new_const->get_friendly_name(), "test");
2948 const auto values_out = new_const->get_vector<int64_t>();
2950 const vector<int64_t> values_expected{0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2};
2951 ASSERT_EQ(values_expected, values_out);
2952 ASSERT_EQ((Shape{2, values_in.size()}), new_const->get_shape());
2955 TEST(constant_folding, DISABLED_constant_non_zero_2D_all_zeros)
2957 const vector<uint8_t> values_in{0, 0, 0, 0, 0, 0};
2958 const auto data = make_shared<op::Constant>(element::u8, Shape{2, 3}, values_in);
2959 const auto non_zero = make_shared<op::v3::NonZero>(data);
2960 non_zero->set_friendly_name("test");
2961 auto f = make_shared<Function>(non_zero, ParameterVector{});
2963 pass::Manager pass_manager;
2964 pass_manager.register_pass<pass::ConstantFolding>();
2965 pass_manager.run_passes(f);
2967 // fold into Constant with shape of {0}
2968 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2969 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2971 const auto new_const =
2972 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2973 ASSERT_TRUE(new_const);
2974 ASSERT_EQ(new_const->get_friendly_name(), "test");
2975 ASSERT_EQ(shape_size(new_const->get_shape()), 0);
2978 TEST(constant_folding, constant_non_zero_3D)
2980 vector<int> values_in{1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0};
2981 auto data = make_shared<op::Constant>(element::i32, Shape{2, 3, 3}, values_in);
2982 auto non_zero = make_shared<op::v3::NonZero>(data);
2983 non_zero->set_friendly_name("test");
2984 auto f = make_shared<Function>(non_zero, ParameterVector{});
2986 pass::Manager pass_manager;
2987 pass_manager.register_pass<pass::ConstantFolding>();
2988 pass_manager.run_passes(f);
2990 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2991 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2993 const auto new_const =
2994 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2995 ASSERT_TRUE(new_const);
2996 ASSERT_EQ(new_const->get_friendly_name(), "test");
2997 const auto values_out = new_const->get_vector<int64_t>();
2999 const vector<int64_t> values_expected{0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 2, 2, 2,
3000 0, 0, 0, 1, 1, 2, 0, 2, 1, 0, 1, 2, 0, 1, 2, 0, 2, 1};
3001 ASSERT_EQ(values_expected, values_out);
3002 ASSERT_EQ((Shape{3, 12}), new_const->get_shape());
3005 TEST(constant_folding, constant_scatter_elements_update_basic)
3007 const Shape data_shape{3, 3};
3008 const Shape indices_shape{2, 3};
3010 const auto data_const = op::Constant::create(
3011 element::f32, data_shape, std::vector<float>(shape_size(data_shape), 0.f));
3012 const auto indices_const =
3013 op::Constant::create(element::i32, indices_shape, {1, 0, 2, 0, 2, 1});
3014 const auto updates_const =
3015 op::Constant::create(element::f32, indices_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
3016 const auto axis_const = op::Constant::create(element::i64, Shape{}, {0});
3018 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3019 data_const, indices_const, updates_const, axis_const);
3020 scatter_elem_updt->set_friendly_name("test");
3021 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3023 pass::Manager pass_manager;
3024 pass_manager.register_pass<pass::ConstantFolding>();
3025 pass_manager.run_passes(f);
3027 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3028 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3031 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3032 ASSERT_TRUE(result_node);
3033 ASSERT_EQ(result_node->get_friendly_name(), "test");
3034 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3035 std::vector<float> expected{2.f, 1.1f, 0.0f, 1.f, 0.0f, 2.2f, 0.f, 2.1f, 1.2f};
3036 range_test_check(result_node->cast_vector<float>(), expected);
3039 TEST(constant_folding, constant_scatter_elements_update_negative_axis)
3041 const Shape data_shape{3, 3};
3042 const Shape indices_shape{2, 3};
3044 const auto data_const = op::Constant::create(
3045 element::f32, data_shape, std::vector<float>(shape_size(data_shape), 0.f));
3046 const auto indices_const =
3047 op::Constant::create(element::i32, indices_shape, {1, 0, 2, 0, 2, 1});
3048 const auto updates_const =
3049 op::Constant::create(element::f32, indices_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
3050 const auto axis_const = op::Constant::create(element::i64, Shape{}, {-1});
3052 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3053 data_const, indices_const, updates_const, axis_const);
3054 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3056 pass::Manager pass_manager;
3057 pass_manager.register_pass<pass::ConstantFolding>();
3058 pass_manager.run_passes(f);
3060 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3061 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3064 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3065 ASSERT_TRUE(result_node);
3066 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3067 std::vector<float> expected{1.1f, 1.0f, 1.2f, 2.0f, 2.2f, 2.1f, 0.0f, 0.0f, 0.0f};
3068 range_test_check(result_node->cast_vector<float>(), expected);
3071 TEST(constant_folding, constant_scatter_elements_update_1d_axis)
3073 const Shape data_shape{3, 3};
3074 const Shape indices_shape{2, 3};
3076 const auto data_const = op::Constant::create(
3077 element::f32, data_shape, std::vector<float>(shape_size(data_shape), 0.f));
3078 const auto indices_const =
3079 op::Constant::create(element::i32, indices_shape, {1, 0, 2, 0, 2, 1});
3080 const auto updates_const =
3081 op::Constant::create(element::f32, indices_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
3082 const auto axis_const = op::Constant::create(element::i64, Shape{1}, {0});
3084 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3085 data_const, indices_const, updates_const, axis_const);
3086 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3088 pass::Manager pass_manager;
3089 pass_manager.register_pass<pass::ConstantFolding>();
3090 pass_manager.run_passes(f);
3092 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3093 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3096 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3097 ASSERT_TRUE(result_node);
3098 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3099 std::vector<float> expected{2.f, 1.1f, 0.0f, 1.f, 0.0f, 2.2f, 0.f, 2.1f, 1.2f};
3100 range_test_check(result_node->cast_vector<float>(), expected);
3103 TEST(constant_folding, constant_scatter_elements_update_3d_i16)
3105 const Shape data_shape{3, 3, 3};
3106 const Shape indices_shape{2, 2, 3};
3108 const auto data_const = op::Constant::create(
3109 element::i16, data_shape, std::vector<int16_t>(shape_size(data_shape), 0));
3110 const auto indices_const =
3111 op::Constant::create(element::i16, indices_shape, {1, 0, 2, 0, 2, 1, 2, 2, 2, 0, 1, 0});
3112 const auto updates_const =
3113 op::Constant::create(element::i16, indices_shape, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
3114 const auto axis_const = op::Constant::create(element::i64, Shape{}, {1});
3116 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3117 data_const, indices_const, updates_const, axis_const);
3118 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3120 pass::Manager pass_manager;
3121 pass_manager.register_pass<pass::ConstantFolding>();
3122 pass_manager.run_passes(f);
3124 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3125 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3128 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3129 ASSERT_TRUE(result_node);
3130 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3131 std::vector<int16_t> expected{4, 2, 0, 1, 0, 6, 0, 5, 3, 10, 0, 12, 0, 11,
3132 0, 7, 8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0};
3133 range_test_check(result_node->cast_vector<int16_t>(), expected);
3136 TEST(constant_folding, constant_scatter_elements_update_one_elem)
3138 const Shape data_shape{3, 3, 3};
3139 const Shape indices_shape{1, 1, 1};
3140 const auto input_data = std::vector<int32_t>(shape_size(data_shape), 0);
3142 const auto data_const = op::Constant::create(element::i32, data_shape, input_data);
3143 const auto indices_const = op::Constant::create(element::i32, indices_shape, {1});
3144 const auto updates_const = op::Constant::create(element::i32, indices_shape, {2});
3145 const auto axis_const = op::Constant::create(element::i64, Shape{}, {0});
3147 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3148 data_const, indices_const, updates_const, axis_const);
3149 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3151 pass::Manager pass_manager;
3152 pass_manager.register_pass<pass::ConstantFolding>();
3153 pass_manager.run_passes(f);
3155 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3156 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3159 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3160 ASSERT_TRUE(result_node);
3161 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3162 std::vector<int32_t> expected{input_data};
3163 // we have updated coordinate (1, 0, 0)
3165 range_test_check(result_node->cast_vector<int32_t>(), expected);
3168 void test_constant_folding_reshape_v1(Shape& shape_in,
3169 vector<float>& values_in,
3171 vector<int32_t> values_shape,
3172 bool zero_flag = false)
3174 auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
3175 auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
3176 auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape, zero_flag);
3177 dyn_reshape->set_friendly_name("test");
3178 auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
3180 pass::Manager pass_manager;
3181 pass_manager.register_pass<pass::ConstantFolding>();
3182 pass_manager.run_passes(f);
3184 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
3185 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3188 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3189 ASSERT_TRUE(new_const);
3190 ASSERT_EQ(new_const->get_friendly_name(), "test");
3191 auto values_out = new_const->get_vector<float>();
3193 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
3195 TEST(constant_folding, constant_dyn_reshape_v1_2d)
3197 Shape shape_in{2, 5};
3198 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
3200 test_constant_folding_reshape_v1(shape_in, values_in, {4}, {1, 1, 1, 10});
3201 test_constant_folding_reshape_v1(shape_in, values_in, {4}, {1, 1, 2, 5});
3202 test_constant_folding_reshape_v1(shape_in, values_in, {3}, {1, 2, 5});
3203 test_constant_folding_reshape_v1(shape_in, values_in, {3}, {5, 2, 1});
3206 TEST(constant_folding, constant_dyn_reshape_v1_pattern_with_negative_indices)
3208 Shape shape_in{2, 2, 2, 2};
3209 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3211 test_constant_folding_reshape_v1(shape_in, values_in, {3}, {4, -1, 2});
3212 test_constant_folding_reshape_v1(shape_in, values_in, {2}, {4, -1});
3213 test_constant_folding_reshape_v1(shape_in, values_in, {1}, {-1});
3216 TEST(constant_folding, constant_dyn_reshape_v1_pattern_with_zero_dims)
3218 Shape shape_in{2, 2, 2, 2};
3219 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3221 test_constant_folding_reshape_v1(shape_in, values_in, {4}, {2, -1, 2, 0}, true);
3222 test_constant_folding_reshape_v1(shape_in, values_in, {4}, {4, 1, 0, 2}, true);
3225 TEST(constant_folding, disable_constant_folding)
3227 auto input = make_shared<op::Parameter>(element::f32, Shape{1, 3});
3228 auto constant_shape = op::Constant::create(element::i64, Shape{1}, {3});
3229 auto dyn_reshape = make_shared<op::v1::Reshape>(input, constant_shape, true);
3230 auto& rt_info = dyn_reshape->get_rt_info();
3231 rt_info["DISABLED_CONSTANT_FOLDING"];
3232 auto f = make_shared<Function>(dyn_reshape, ParameterVector{input});
3234 pass::Manager pass_manager;
3235 pass_manager.register_pass<pass::ConstantFolding>();
3236 pass_manager.run_passes(f);
3238 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 1);
3239 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);