1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "gtest/gtest.h"
19 #include "ngraph/ngraph.hpp"
20 #include "ngraph/pass/constant_folding.hpp"
21 #include "ngraph/pass/manager.hpp"
22 #include "util/all_close_f.hpp"
23 #include "util/test_tools.hpp"
25 NGRAPH_SUPPRESS_DEPRECATED_START
27 using namespace ngraph;
// Returns the data held by the node that feeds result number `pos` of function `f`,
// converted to a std::vector<T> via cast_vector<T>().
// The feeding node is downcast with as_type_ptr<op::Constant>.
// NOTE(review): no null check is visible between that downcast and the dereference
// below, so callers are presumably expected to pass a function whose result at `pos`
// is already a Constant - confirm.
31 static std::vector<T> get_result_constant(std::shared_ptr<Function> f, size_t pos)
34 as_type_ptr<op::Constant>(f->get_results().at(pos)->input_value(0).get_node_shared_ptr());
35 return new_const->cast_vector<T>();
// Overload for double data: asserts that `values_out` and `values_expected` are close
// according to test::all_close_f with MIN_FLOAT_TOLERANCE_BITS.
38 void range_test_check(const vector<double>& values_out, const vector<double>& values_expected)
40 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
// Overload for float data: asserts that `values_out` and `values_expected` are close
// according to test::all_close_f with MIN_FLOAT_TOLERANCE_BITS.
43 void range_test_check(const vector<float>& values_out, const vector<float>& values_expected)
45 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
// Overload for integral element types: std::enable_if removes it from overload
// resolution unless std::is_integral<T> holds. Integers are compared for exact
// equality with ASSERT_EQ rather than with a tolerance.
49 typename std::enable_if<std::is_integral<T>::value>::type
50 range_test_check(const vector<T>& values_out, const vector<T>& values_expected)
52 ASSERT_EQ(values_out, values_expected);
// Verifies that ConstantFolding replaces op::Acosh applied to an f32 Constant with a
// single Constant that keeps the friendly name "test" and holds std::acosh of each input.
55 TEST(constant_folding, acosh)
57 Shape shape_in{2, 4, 1};
// Every input is >= 1, the domain of acosh, so the reference values are finite and the
// closeness check at the end compares real numbers rather than NaN.
59 vector<float> values_in{1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f};
60 vector<float> expected;
// Reference values computed with the standard library.
61 for (float f : values_in)
63 expected.push_back(std::acosh(f));
65 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
66 auto acosh = make_shared<op::Acosh>(constant);
67 acosh->set_friendly_name("test");
68 auto f = make_shared<Function>(acosh, ParameterVector{});
70 pass::Manager pass_manager;
71 pass_manager.register_pass<pass::ConstantFolding>();
72 pass_manager.run_passes(f);
// After folding no Acosh may remain and exactly one Constant must be left.
74 EXPECT_EQ(count_ops_of_type<op::Acosh>(f), 0);
75 EXPECT_EQ(count_ops_of_type<op::Constant>(f), 1);
76 ASSERT_EQ(f->get_results().size(), 1);
79 as_type_ptr<op::Constant>(f->get_results()[0]->input_value(0).get_node_shared_ptr());
// Fatal assertion: new_const is dereferenced on the following lines, so the test must
// stop here if the result is not fed by a Constant.
80 ASSERT_TRUE(new_const);
81 ASSERT_EQ(new_const->get_friendly_name(), "test");
83 auto values_out = new_const->get_vector<float>();
84 EXPECT_TRUE(test::all_close_f(expected, values_out, MIN_FLOAT_TOLERANCE_BITS));
// Verifies that ConstantFolding replaces op::Asinh applied to an f32 Constant with a
// single Constant that keeps the friendly name "test" and holds std::asinh of each input.
87 TEST(constant_folding, asinh)
89 Shape shape_in{2, 4, 1};
91 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
92 vector<float> expected;
// Reference values computed with the standard library.
93 for (float f : values_in)
95 expected.push_back(std::asinh(f));
97 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
98 auto asinh = make_shared<op::Asinh>(constant);
99 asinh->set_friendly_name("test");
100 auto f = make_shared<Function>(asinh, ParameterVector{});
102 pass::Manager pass_manager;
103 pass_manager.register_pass<pass::ConstantFolding>();
104 pass_manager.run_passes(f);
// After folding no Asinh may remain and exactly one Constant must be left.
106 EXPECT_EQ(count_ops_of_type<op::Asinh>(f), 0);
107 EXPECT_EQ(count_ops_of_type<op::Constant>(f), 1);
108 ASSERT_EQ(f->get_results().size(), 1);
111 as_type_ptr<op::Constant>(f->get_results()[0]->input_value(0).get_node_shared_ptr());
// Fatal assertion: new_const is dereferenced on the following lines, so the test must
// stop here if the result is not fed by a Constant.
112 ASSERT_TRUE(new_const);
113 ASSERT_EQ(new_const->get_friendly_name(), "test");
115 auto values_out = new_const->get_vector<float>();
116 EXPECT_TRUE(test::all_close_f(expected, values_out, MIN_FLOAT_TOLERANCE_BITS));
// Verifies that ConstantFolding replaces op::Atanh applied to an f32 Constant with a
// single Constant that keeps the friendly name "test" and holds std::atanh of each input.
119 TEST(constant_folding, atanh)
121 Shape shape_in{2, 4, 1};
// Every input lies strictly inside (-1, 1): atanh is infinite at +/-1 and NaN beyond,
// so only these values give finite reference results for the closeness check.
123 vector<float> values_in{-0.8f, -0.6f, -0.4f, -0.2f, 0.2f, 0.4f, 0.6f, 0.8f};
124 vector<float> expected;
// Reference values computed with the standard library.
125 for (float f : values_in)
127 expected.push_back(std::atanh(f));
129 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
130 auto atanh = make_shared<op::Atanh>(constant);
131 atanh->set_friendly_name("test");
132 auto f = make_shared<Function>(atanh, ParameterVector{});
134 pass::Manager pass_manager;
135 pass_manager.register_pass<pass::ConstantFolding>();
136 pass_manager.run_passes(f);
// After folding no Atanh may remain and exactly one Constant must be left.
138 EXPECT_EQ(count_ops_of_type<op::Atanh>(f), 0);
139 EXPECT_EQ(count_ops_of_type<op::Constant>(f), 1);
140 ASSERT_EQ(f->get_results().size(), 1);
143 as_type_ptr<op::Constant>(f->get_results()[0]->input_value(0).get_node_shared_ptr());
// Fatal assertion: new_const is dereferenced on the following lines, so the test must
// stop here if the result is not fed by a Constant.
144 ASSERT_TRUE(new_const);
145 ASSERT_EQ(new_const->get_friendly_name(), "test");
147 auto values_out = new_const->get_vector<float>();
148 EXPECT_TRUE(test::all_close_f(expected, values_out, MIN_FLOAT_TOLERANCE_BITS));
// Verifies that ConstantFolding folds op::Squeeze of a constant {2, 4, 1} tensor along
// axis 2 into one Constant of shape {2, 4} with the same friendly name and unchanged data.
151 TEST(constant_folding, constant_squeeze)
153 Shape shape_in{2, 4, 1};
154 Shape shape_out{2, 4};
157 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
158 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
// Axis to remove, supplied as an i64 Constant input of the Squeeze.
159 vector<int64_t> values_axes{2};
160 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
161 auto squeeze = make_shared<op::Squeeze>(constant, constant_axes);
162 squeeze->set_friendly_name("test");
163 auto f = make_shared<Function>(squeeze, ParameterVector{});
165 pass::Manager pass_manager;
166 pass_manager.register_pass<pass::ConstantFolding>();
167 pass_manager.run_passes(f);
// The Squeeze must be gone and only one Constant may remain in the function.
169 ASSERT_EQ(count_ops_of_type<op::Squeeze>(f), 0);
170 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
173 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
174 ASSERT_TRUE(new_const);
175 ASSERT_EQ(new_const->get_friendly_name(), "test");
176 ASSERT_EQ(new_const->get_shape(), shape_out);
// Squeezing only changes the shape, so the data is compared against the input values.
178 auto values_out = new_const->get_vector<float>();
179 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
// Verifies that ConstantFolding folds op::Unsqueeze of a constant {2, 4} tensor with
// axes {2, 3} into one Constant of shape {2, 4, 1, 1} with the same friendly name and
// unchanged data.
182 TEST(constant_folding, constant_unsqueeze)
184 Shape shape_in{2, 4};
185 Shape shape_out{2, 4, 1, 1};
188 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
189 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
// Axes to insert, supplied as an i64 Constant input of the Unsqueeze.
190 vector<int64_t> values_axes{2, 3};
191 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
192 auto unsqueeze = make_shared<op::Unsqueeze>(constant, constant_axes);
193 unsqueeze->set_friendly_name("test");
194 auto f = make_shared<Function>(unsqueeze, ParameterVector{});
196 pass::Manager pass_manager;
197 pass_manager.register_pass<pass::ConstantFolding>();
198 pass_manager.run_passes(f);
// The Unsqueeze must be gone and only one Constant may remain in the function.
200 ASSERT_EQ(count_ops_of_type<op::Unsqueeze>(f), 0);
201 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
204 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
205 ASSERT_TRUE(new_const);
206 ASSERT_EQ(new_const->get_friendly_name(), "test");
207 ASSERT_EQ(new_const->get_shape(), shape_out);
// Unsqueezing only changes the shape, so the data is compared against the input values.
209 auto values_out = new_const->get_vector<float>();
210 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
// Verifies that ConstantFolding folds an op::Reshape with input order {0, 1} and output
// shape {2, 4, 1} applied to a constant into one Constant with the same friendly name
// and the same element sequence as the input.
213 TEST(constant_folding, constant_reshape)
215 Shape shape_in{2, 4};
216 Shape shape_out{2, 4, 1};
218 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
219 auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
220 auto reshape = make_shared<op::Reshape>(constant, AxisVector{0, 1}, shape_out);
221 reshape->set_friendly_name("test");
222 auto f = make_shared<Function>(reshape, ParameterVector{});
224 pass::Manager pass_manager;
225 pass_manager.register_pass<pass::ConstantFolding>();
226 pass_manager.run_passes(f);
// The Reshape must be gone and only one Constant may remain in the function.
228 ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
229 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
232 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
233 ASSERT_TRUE(new_const);
234 ASSERT_EQ(new_const->get_friendly_name(), "test");
235 auto values_out = new_const->get_vector<float>();
// The test expects the data to be unchanged by this reshape.
237 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
// Verifies folding of an op::Reshape with input order {1, 0} that turns a constant
// {2, 4} f64 tensor into shape {4, 2}; the expected data is the transposed sequence.
// The DISABLED_ prefix means GoogleTest does not run this test by default.
240 TEST(constant_folding, DISABLED_constant_reshape_permute)
242 Shape shape_in{2, 4};
243 Shape shape_out{4, 2};
245 vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};
246 auto constant = make_shared<op::Constant>(element::f64, shape_in, values_in);
247 auto reshape = make_shared<op::Reshape>(constant, AxisVector{1, 0}, shape_out);
248 reshape->set_friendly_name("test");
249 auto f = make_shared<Function>(reshape, ParameterVector{});
251 pass::Manager pass_manager;
252 pass_manager.register_pass<pass::ConstantFolding>();
253 pass_manager.run_passes(f);
// The Reshape must be gone and only one Constant may remain in the function.
255 ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
256 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
259 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
260 ASSERT_TRUE(new_const);
261 ASSERT_EQ(new_const->get_friendly_name(), "test");
262 auto values_out = new_const->get_vector<double>();
// Row-major data of the 4x2 transpose of the 2x4 input.
264 vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
265 ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
// Verifies that ConstantFolding folds op::v1::Broadcast when data, target shape and
// axes mapping are all Constants: i32 data {0, 1} of shape {2}, target shape {2, 4},
// axes mapping {0}.
268 TEST(constant_folding, constant_broadcast_v1)
270 vector<int32_t> values_in{0, 1};
271 auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
272 vector<int64_t> shape_in{2, 4};
273 auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
274 vector<int64_t> axes_in{0};
275 auto constant_axes = make_shared<op::Constant>(element::i64, Shape{1}, axes_in);
276 auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, constant_shape, constant_axes);
277 broadcast_v1->set_friendly_name("test");
278 auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
280 pass::Manager pass_manager;
281 pass_manager.register_pass<pass::ConstantFolding>();
282 pass_manager.run_passes(f);
// The Broadcast and its Constant inputs must collapse into a single Constant.
284 ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
285 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
288 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
289 ASSERT_TRUE(new_const);
290 ASSERT_EQ(new_const->get_friendly_name(), "test");
291 auto values_out = new_const->get_vector<int32_t>();
// Each input element is repeated across a row of four.
293 vector<int32_t> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
294 ASSERT_EQ(values_expected, values_out);
// Verifies that ConstantFolding folds the two-input form of op::v1::Broadcast (data and
// target shape only): a {1, 1, 1, 1} i32 constant holding 1 broadcast to {1, 3, 1, 1}.
297 TEST(constant_folding, constant_broadcast_v1_with_target_shape)
299 vector<int32_t> values_in{1};
300 auto constant_in = make_shared<op::Constant>(element::i32, Shape{1, 1, 1, 1}, values_in);
301 vector<int64_t> shape_in{1, 3, 1, 1};
302 auto target_shape = make_shared<op::Constant>(element::i64, Shape{4}, shape_in);
303 auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, target_shape);
304 broadcast_v1->set_friendly_name("test");
305 auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
307 pass::Manager pass_manager;
308 pass_manager.register_pass<pass::ConstantFolding>();
309 pass_manager.run_passes(f);
// The Broadcast and its Constant inputs must collapse into a single Constant.
311 ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
312 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
315 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
316 ASSERT_TRUE(new_const);
317 ASSERT_EQ(new_const->get_friendly_name(), "test");
318 auto values_out = new_const->get_vector<int32_t>();
// Three elements, matching the element count of the {1, 3, 1, 1} target shape.
320 vector<int32_t> values_expected{1, 1, 1};
321 ASSERT_EQ(values_expected, values_out);
// Verifies that ConstantFolding folds the two-input form of op::v1::Broadcast for i32
// data {0, 1} of shape {2} and target shape {4, 2}; the expected output repeats the
// input as each of the four rows.
324 TEST(constant_folding, constant_broadcast_v1_numpy)
326 vector<int32_t> values_in{0, 1};
327 auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
328 vector<int64_t> shape_in{4, 2};
329 auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
330 auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, constant_shape);
331 broadcast_v1->set_friendly_name("test");
332 auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
334 pass::Manager pass_manager;
335 pass_manager.register_pass<pass::ConstantFolding>();
336 pass_manager.run_passes(f);
// The Broadcast and its Constant inputs must collapse into a single Constant.
338 ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
339 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
342 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
343 ASSERT_TRUE(new_const);
344 ASSERT_EQ(new_const->get_friendly_name(), "test");
345 auto values_out = new_const->get_vector<int32_t>();
347 vector<int32_t> values_expected{0, 1, 0, 1, 0, 1, 0, 1};
348 ASSERT_EQ(values_expected, values_out);
// Builds one function whose outputs are many unary and binary element-wise ops over
// small Constants, runs ConstantFolding on it, and checks every result against
// hand-computed expected data through get_result_constant. A second function holding
// Sqrt of negative integers is only required to be processed without throwing.
// NOTE(review): `add`, `sub`, `mul`, `div` and part of the NodeVector are referenced or
// implied by the assertions but are not visible in this listing.
351 TEST(constant_folding, constant_unary_binary)
// Raw data for the Constant operands declared below.
353 vector<int> values_a{1, 2, 3, 4};
354 vector<int> values_b{1, 2, 3, 4};
355 vector<int> values_c{-1, -1, -1, -1};
356 vector<int> values_d{1, 4, 9, 16};
357 vector<int> values_e{5, 6};
358 vector<int> values_f{0, 10};
359 vector<int> values_g{1, 4};
360 vector<char> values_h{0, 0, 1, 1};
361 vector<char> values_i{0, 1};
// a-d are 2x2 i32 tensors; e-g are length-2 i32 vectors used as the broadcast operand.
362 auto a = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_a);
363 auto b = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_b);
364 auto c = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_c);
365 auto d = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_d);
366 auto e = make_shared<op::Constant>(element::i32, Shape{2}, values_e);
367 auto f = make_shared<op::Constant>(element::i32, Shape{2}, values_f);
368 auto g = make_shared<op::Constant>(element::i32, Shape{2}, values_g);
// h and i are boolean operands for the logical ops.
369 auto h = make_shared<op::Constant>(element::boolean, Shape{2, 2}, values_h);
370 auto i = make_shared<op::Constant>(element::boolean, Shape{2}, values_i);
// Ops on equally shaped operands.
376 auto pow = make_shared<op::Power>(a, b);
377 auto min = make_shared<op::Minimum>(c, a);
378 auto max = make_shared<op::Maximum>(a, c);
379 auto absn = make_shared<op::Abs>(c);
380 auto neg = make_shared<op::Negative>(c);
381 auto sqrt = make_shared<op::Sqrt>(d);
// Ops that pair a 2x2 operand with a length-2 operand using NUMPY auto-broadcast.
382 auto add_autob_numpy = make_shared<op::Add>(a, e, op::AutoBroadcastType::NUMPY);
383 auto sub_autob_numpy = make_shared<op::Subtract>(a, e, op::AutoBroadcastType::NUMPY);
384 auto mul_autob_numpy = make_shared<op::Multiply>(a, e, op::AutoBroadcastType::NUMPY);
385 auto div_autob_numpy = make_shared<op::Divide>(a, g, op::AutoBroadcastType::NUMPY);
386 auto pow_autob_numpy = make_shared<op::Power>(a, g, op::AutoBroadcastType::NUMPY);
387 auto min_autob_numpy = make_shared<op::Minimum>(a, f, op::AutoBroadcastType::NUMPY);
388 auto max_autob_numpy = make_shared<op::Maximum>(a, f, op::AutoBroadcastType::NUMPY);
389 auto equal_autob_numpy = make_shared<op::Equal>(a, g, op::AutoBroadcastType::NUMPY);
390 auto not_equal_autob_numpy = make_shared<op::NotEqual>(a, g, op::AutoBroadcastType::NUMPY);
391 auto greater_autob_numpy = make_shared<op::Greater>(a, g, op::AutoBroadcastType::NUMPY);
392 auto greater_eq_autob_numpy = make_shared<op::GreaterEq>(a, g, op::AutoBroadcastType::NUMPY);
393 auto less_autob_numpy = make_shared<op::Less>(a, g, op::AutoBroadcastType::NUMPY);
394 auto less_eq_autob_numpy = make_shared<op::LessEq>(a, g, op::AutoBroadcastType::NUMPY);
395 auto logical_or_autob_numpy = make_shared<op::Or>(h, i, op::AutoBroadcastType::NUMPY);
396 auto logical_xor_autob_numpy = make_shared<op::Xor>(h, i, op::AutoBroadcastType::NUMPY);
// Square root of the all -1 tensor; it only goes into func_error below.
398 auto neg_sqrt = make_shared<op::Sqrt>(c);
// Outputs of `func`; the positions passed to get_result_constant below index this list.
400 auto func = make_shared<Function>(NodeVector{add,
418 not_equal_autob_numpy,
420 greater_eq_autob_numpy,
423 logical_or_autob_numpy,
424 logical_xor_autob_numpy},
426 auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
428 pass::Manager pass_manager;
429 pass_manager.register_pass<pass::ConstantFolding>();
430 pass_manager.run_passes(func);
// Expected data for every output of `func`, in row-major order.
433 vector<int> add_expected{2, 4, 6, 8};
434 vector<int> sub_expected{0, 0, 0, 0};
435 vector<int> mul_expected{1, 4, 9, 16};
436 vector<int> div_expected{1, 1, 1, 1};
437 vector<int> pow_expected{1, 4, 27, 256};
438 vector<int> min_expected{-1, -1, -1, -1};
439 vector<int> max_expected{1, 2, 3, 4};
// Shared by Abs(c) and Negative(c), since c is all -1.
440 vector<int> abs_neg_expected{1, 1, 1, 1};
441 vector<int> sqrt_expected{1, 2, 3, 4};
442 vector<int> add_autob_numpy_expected{6, 8, 8, 10};
443 vector<int> sub_autob_numpy_expected{-4, -4, -2, -2};
444 vector<int> mul_autob_numpy_expected{5, 12, 15, 24};
// Integer division: 2 / 4 yields 0.
445 vector<int> div_autob_numpy_expected{1, 0, 3, 1};
446 vector<int> pow_autob_numpy_expected{1, 16, 3, 256};
447 vector<int> min_autob_numpy_expected{0, 2, 0, 4};
448 vector<int> max_autob_numpy_expected{1, 10, 3, 10};
// Comparison and logical results are read back as char (0 or 1).
449 vector<char> equal_autob_numpy_expected{1, 0, 0, 1};
450 vector<char> not_equal_autob_numpy_expected{0, 1, 1, 0};
451 vector<char> greater_autob_numpy_expected{0, 0, 1, 0};
452 vector<char> greater_eq_autob_numpy_expected{1, 0, 1, 1};
453 vector<char> less_autob_numpy_expected{0, 1, 0, 0};
454 vector<char> less_eq_autob_numpy_expected{1, 1, 0, 1};
455 vector<char> logical_or_autob_numpy_expected{0, 1, 1, 1};
456 vector<char> logical_xor_autob_numpy_expected{0, 1, 1, 0};
// Each result of `func` is read back as a Constant and compared exactly.
458 ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
459 ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
460 ASSERT_EQ(get_result_constant<int>(func, 2), mul_expected);
461 ASSERT_EQ(get_result_constant<int>(func, 3), div_expected);
462 ASSERT_EQ(get_result_constant<int>(func, 4), pow_expected);
463 ASSERT_EQ(get_result_constant<int>(func, 5), min_expected);
464 ASSERT_EQ(get_result_constant<int>(func, 6), max_expected);
465 ASSERT_EQ(get_result_constant<int>(func, 7), abs_neg_expected);
466 ASSERT_EQ(get_result_constant<int>(func, 8), abs_neg_expected);
467 ASSERT_EQ(get_result_constant<int>(func, 9), sqrt_expected);
468 ASSERT_EQ(get_result_constant<int>(func, 10), add_autob_numpy_expected);
469 ASSERT_EQ(get_result_constant<int>(func, 11), sub_autob_numpy_expected);
470 ASSERT_EQ(get_result_constant<int>(func, 12), mul_autob_numpy_expected);
471 ASSERT_EQ(get_result_constant<int>(func, 13), div_autob_numpy_expected);
472 ASSERT_EQ(get_result_constant<int>(func, 14), pow_autob_numpy_expected);
473 ASSERT_EQ(get_result_constant<int>(func, 15), min_autob_numpy_expected);
474 ASSERT_EQ(get_result_constant<int>(func, 16), max_autob_numpy_expected);
475 ASSERT_EQ(get_result_constant<char>(func, 17), equal_autob_numpy_expected);
476 ASSERT_EQ(get_result_constant<char>(func, 18), not_equal_autob_numpy_expected);
477 ASSERT_EQ(get_result_constant<char>(func, 19), greater_autob_numpy_expected);
478 ASSERT_EQ(get_result_constant<char>(func, 20), greater_eq_autob_numpy_expected);
479 ASSERT_EQ(get_result_constant<char>(func, 21), less_autob_numpy_expected);
480 ASSERT_EQ(get_result_constant<char>(func, 22), less_eq_autob_numpy_expected);
481 ASSERT_EQ(get_result_constant<char>(func, 23), logical_or_autob_numpy_expected);
482 ASSERT_EQ(get_result_constant<char>(func, 24), logical_xor_autob_numpy_expected);
// Running the pass over Sqrt of negative integers must not throw.
483 ASSERT_NO_THROW(pass_manager.run_passes(func_error));
// Verifies that ConstantFolding folds op::Quantize of an f32 Constant (scalar scale 2,
// scalar u8 offset 1, ROUND_NEAREST_TOWARD_INFINITY, no quantization axes) into a single
// u8 Constant with the same friendly name.
486 TEST(constant_folding, const_quantize)
488 Shape input_shape{12};
// Default-constructed shape and axis set: scale and offset are scalars.
489 Shape scale_offset_shape;
490 AxisSet quantization_axes;
492 auto quant_type = element::u8;
493 auto output_type = element::u8;
494 typedef uint8_t output_c_type;
496 vector<float> values_in{1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0};
497 auto constant = op::Constant::create(element::f32, input_shape, values_in);
498 auto scale = op::Constant::create(element::f32, scale_offset_shape, {2});
499 auto offset = op::Constant::create(quant_type, scale_offset_shape, {1});
500 auto mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY;
502 make_shared<op::Quantize>(constant, scale, offset, output_type, quantization_axes, mode);
503 quantize->set_friendly_name("test");
504 auto f = make_shared<Function>(quantize, ParameterVector{});
506 pass::Manager pass_manager;
507 pass_manager.register_pass<pass::ConstantFolding>();
508 pass_manager.run_passes(f);
// The Quantize and its Constant inputs must collapse into a single Constant.
510 ASSERT_EQ(count_ops_of_type<op::Quantize>(f), 0);
511 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
514 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
515 ASSERT_TRUE(new_const);
516 ASSERT_EQ(new_const->get_friendly_name(), "test");
517 auto values_out = new_const->get_vector<output_c_type>();
// These values equal round(x / 2) + 1 with halves rounded upward.
519 vector<output_c_type> values_quantize{2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5};
520 ASSERT_EQ(values_quantize, values_out);
// Verifies that ConstantFolding folds op::Convert of an f32 Constant to u64 into a
// single u64 Constant with the same friendly name and the same numeric values.
523 TEST(constant_folding, const_convert)
525 Shape input_shape{3, 4};
// The source data is given as int32_t although the Constant is created as f32.
527 vector<int32_t> values_in{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
528 auto constant = op::Constant::create(element::f32, input_shape, values_in);
529 auto convert = make_shared<op::Convert>(constant, element::u64);
530 convert->set_friendly_name("test");
531 auto f = make_shared<Function>(convert, ParameterVector{});
533 pass::Manager pass_manager;
534 pass_manager.register_pass<pass::ConstantFolding>();
535 pass_manager.run_passes(f);
// The Convert must be gone and only one Constant may remain in the function.
537 ASSERT_EQ(count_ops_of_type<op::Convert>(f), 0);
538 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
541 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
542 ASSERT_TRUE(new_const);
543 ASSERT_EQ(new_const->get_friendly_name(), "test");
// The folded Constant must carry the destination element type.
544 ASSERT_EQ(new_const->get_output_element_type(0), element::u64);
545 auto values_out = new_const->get_vector<uint64_t>();
547 vector<uint64_t> values_expected{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
548 ASSERT_EQ(values_expected, values_out);
// Verifies that ConstantFolding replaces op::v0::ShapeOf of a Parameter with a fully
// static shape by an i64 Constant holding the dimensions, keeping the friendly name.
551 TEST(constant_folding, shape_of_v0)
// The input shape includes a zero-sized dimension.
553 Shape input_shape{3, 4, 0, 22, 608, 909, 3};
555 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
556 auto shape_of = make_shared<op::v0::ShapeOf>(param);
557 shape_of->set_friendly_name("test");
558 auto f = make_shared<Function>(shape_of, ParameterVector{param});
560 pass::Manager pass_manager;
561 pass_manager.register_pass<pass::ConstantFolding>();
562 pass_manager.run_passes(f);
// The ShapeOf must be gone and exactly one Constant must have been created.
564 ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 0);
565 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
568 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
569 ASSERT_TRUE(new_const);
570 ASSERT_EQ(new_const->get_friendly_name(), "test");
571 ASSERT_EQ(new_const->get_output_element_type(0), element::i64);
572 auto values_out = new_const->get_vector<int64_t>();
574 ASSERT_EQ((vector<int64_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
// Verifies that ConstantFolding replaces op::v3::ShapeOf (constructed without an
// explicit output type) of a Parameter with a fully static shape by an i64 Constant
// holding the dimensions, keeping the friendly name.
577 TEST(constant_folding, shape_of_v3)
// The input shape includes a zero-sized dimension.
579 Shape input_shape{3, 4, 0, 22, 608, 909, 3};
581 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
582 auto shape_of = make_shared<op::v3::ShapeOf>(param);
583 shape_of->set_friendly_name("test");
584 auto f = make_shared<Function>(shape_of, ParameterVector{param});
586 pass::Manager pass_manager;
587 pass_manager.register_pass<pass::ConstantFolding>();
588 pass_manager.run_passes(f);
// The ShapeOf must be gone and exactly one Constant must have been created.
590 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 0);
591 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
594 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
595 ASSERT_TRUE(new_const);
596 ASSERT_EQ(new_const->get_friendly_name(), "test");
597 ASSERT_EQ(new_const->get_output_element_type(0), element::i64);
598 auto values_out = new_const->get_vector<int64_t>();
600 ASSERT_EQ((vector<int64_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
// Verifies that ConstantFolding replaces op::v3::ShapeOf constructed with output type
// i32 by an i32 Constant holding the static dimensions, keeping the friendly name.
603 TEST(constant_folding, shape_of_i32_v3)
// The input shape includes a zero-sized dimension.
605 Shape input_shape{3, 4, 0, 22, 608, 909, 3};
607 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
608 auto shape_of = make_shared<op::v3::ShapeOf>(param, element::i32);
609 shape_of->set_friendly_name("test");
610 auto f = make_shared<Function>(shape_of, ParameterVector{param});
612 pass::Manager pass_manager;
613 pass_manager.register_pass<pass::ConstantFolding>();
614 pass_manager.run_passes(f);
// The ShapeOf must be gone and exactly one Constant must have been created.
616 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 0);
617 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
620 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
621 ASSERT_TRUE(new_const);
622 ASSERT_EQ(new_const->get_friendly_name(), "test");
// The requested i32 output type must be preserved by the folded Constant.
623 ASSERT_EQ(new_const->get_output_element_type(0), element::i32);
624 auto values_out = new_const->get_vector<int32_t>();
626 ASSERT_EQ((vector<int32_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
// Checks ConstantFolding on op::v0::ShapeOf when one of seven input dimensions is
// dynamic. The expected graph still has one ShapeOf, plus one v1::Gather, one Concat
// and eight Constants; the Concat feeds the result, carries the friendly name and has
// output shape {7}.
629 TEST(constant_folding, shape_of_dynamic_v0)
631 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
633 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
634 auto shape_of = make_shared<op::v0::ShapeOf>(param);
635 shape_of->set_friendly_name("test");
636 auto f = make_shared<Function>(shape_of, ParameterVector{param});
638 pass::Manager pass_manager;
639 pass_manager.register_pass<pass::ConstantFolding>();
640 pass_manager.run_passes(f);
// Expected op counts after the pass.
642 ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
643 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
644 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
645 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
// The node feeding the result must now be the Concat.
647 auto result_as_concat =
648 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
649 ASSERT_TRUE(result_as_concat);
650 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
651 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
// Checks ConstantFolding on op::v3::ShapeOf when one of seven input dimensions is
// dynamic. The expected graph still has one ShapeOf, plus one v1::Gather, one Concat
// and eight Constants; the Concat feeds the result, carries the friendly name, has
// output shape {7} and element type i64.
654 TEST(constant_folding, shape_of_dynamic_v3)
656 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
658 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
659 auto shape_of = make_shared<op::v3::ShapeOf>(param);
660 shape_of->set_friendly_name("test");
661 auto f = make_shared<Function>(shape_of, ParameterVector{param});
663 pass::Manager pass_manager;
664 pass_manager.register_pass<pass::ConstantFolding>();
665 pass_manager.run_passes(f);
// Expected op counts after the pass.
667 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
668 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
669 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
670 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
// The node feeding the result must now be the Concat.
672 auto result_as_concat =
673 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
674 ASSERT_TRUE(result_as_concat);
675 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
676 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
677 ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i64);
// Same scenario as the dynamic-dimension v3 ShapeOf check, but the ShapeOf is created
// with output type i32. The Concat that feeds the result must have shape {7} and
// element type i32.
680 TEST(constant_folding, shape_of_dynamic_i32_v3)
682 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
684 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
685 auto shape_of = make_shared<op::v3::ShapeOf>(param, element::i32);
686 shape_of->set_friendly_name("test");
687 auto f = make_shared<Function>(shape_of, ParameterVector{param});
689 pass::Manager pass_manager;
690 pass_manager.register_pass<pass::ConstantFolding>();
691 pass_manager.run_passes(f);
// Expected op counts after the pass.
693 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
694 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
695 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
696 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
// The node feeding the result must now be the Concat.
698 auto result_as_concat =
699 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
700 ASSERT_TRUE(result_as_concat);
701 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
702 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
703 ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i32);
// Runs the ConstantFolding pass twice over a v0::ShapeOf with one dynamic dimension and
// expects the same op counts as after a single run (1 ShapeOf, 1 Gather, 1 Concat,
// 8 Constants), i.e. the second run must not rewrite the graph again.
706 // We need to be sure that constant folding won't be calculated endlessly.
707 TEST(constant_folding, shape_of_dynamic_double_folding_v0)
709 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
711 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
712 auto shape_of = make_shared<op::v0::ShapeOf>(param);
713 shape_of->set_friendly_name("test");
714 auto f = make_shared<Function>(shape_of, ParameterVector{param});
716 pass::Manager pass_manager;
717 pass_manager.register_pass<pass::ConstantFolding>();
// Two consecutive runs over the same function.
718 pass_manager.run_passes(f);
719 pass_manager.run_passes(f);
721 ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
722 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
723 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
724 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
// The node feeding the result must be the Concat.
726 auto result_as_concat =
727 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
728 ASSERT_TRUE(result_as_concat);
729 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
730 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
// Runs the ConstantFolding pass twice over a v3::ShapeOf with one dynamic dimension and
// expects the same op counts as after a single run (1 ShapeOf, 1 Gather, 1 Concat,
// 8 Constants), i.e. the second run must not rewrite the graph again.
733 TEST(constant_folding, shape_of_dynamic_double_folding_v3)
735 PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
737 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
738 auto shape_of = make_shared<op::v3::ShapeOf>(param);
739 shape_of->set_friendly_name("test");
740 auto f = make_shared<Function>(shape_of, ParameterVector{param});
742 pass::Manager pass_manager;
743 pass_manager.register_pass<pass::ConstantFolding>();
// Two consecutive runs over the same function.
744 pass_manager.run_passes(f);
745 pass_manager.run_passes(f);
747 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
748 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
749 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
750 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
// The node feeding the result must be the Concat.
752 auto result_as_concat =
753 as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
754 ASSERT_TRUE(result_as_concat);
755 ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
756 ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
// Rank-dynamic input for v0::ShapeOf: after the pass the original ShapeOf node must
// still feed the result, no Constant may exist, and the friendly name is unchanged.
759 // Constant folding will not succeed on ShapeOf if the argument rank is dynamic.
760 // We want to make sure it fails gracefully, leaving the ShapeOf op in place.
761 TEST(constant_folding, shape_of_rank_dynamic_v0)
763 PartialShape input_shape{PartialShape::dynamic()};
765 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
766 auto shape_of = make_shared<op::v0::ShapeOf>(param);
767 shape_of->set_friendly_name("test");
768 auto f = make_shared<Function>(shape_of, ParameterVector{param});
770 pass::Manager pass_manager;
771 pass_manager.register_pass<pass::ConstantFolding>();
772 pass_manager.run_passes(f);
774 ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
775 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
// Pointer equality: the very same ShapeOf node must still be connected to the result.
777 auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
778 ASSERT_EQ(result_shape_of, shape_of);
779 ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
// Rank-dynamic input for v3::ShapeOf: after the pass the original ShapeOf node must
// still feed the result, no Constant may exist, and the friendly name is unchanged.
782 TEST(constant_folding, shape_of_rank_dynamic_v3)
784 PartialShape input_shape{PartialShape::dynamic()};
786 auto param = make_shared<op::Parameter>(element::boolean, input_shape);
787 auto shape_of = make_shared<op::v3::ShapeOf>(param);
788 shape_of->set_friendly_name("test");
789 auto f = make_shared<Function>(shape_of, ParameterVector{param});
791 pass::Manager pass_manager;
792 pass_manager.register_pass<pass::ConstantFolding>();
793 pass_manager.run_passes(f);
795 ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
796 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
// Pointer equality: the very same ShapeOf node must still be connected to the result.
798 auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
799 ASSERT_EQ(result_shape_of, shape_of);
800 ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
// Verifies that ConstantFolding folds op::Reverse along axis 1 of a constant 3x3 i32
// tensor into a single Constant with the same friendly name; each row of the expected
// data is the corresponding input row reversed.
803 TEST(constant_folding, const_reverse)
805 Shape input_shape{3, 3};
807 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
808 auto constant = op::Constant::create(element::i32, input_shape, values_in);
// Despite its name, `convert` holds the Reverse op under test.
809 auto convert = make_shared<op::Reverse>(constant, AxisSet{1});
810 convert->set_friendly_name("test");
811 auto f = make_shared<Function>(convert, ParameterVector{});
813 pass::Manager pass_manager;
814 pass_manager.register_pass<pass::ConstantFolding>();
815 pass_manager.run_passes(f);
// The Reverse must be gone and only one Constant may remain in the function.
817 ASSERT_EQ(count_ops_of_type<op::Reverse>(f), 0);
818 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
821 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
822 ASSERT_TRUE(new_const);
823 ASSERT_EQ(new_const->get_friendly_name(), "test");
824 auto values_out = new_const->get_vector<int32_t>();
826 vector<int32_t> values_expected{3, 2, 1, 6, 5, 4, 9, 8, 7};
827 ASSERT_EQ(values_expected, values_out);
830 TEST(constant_folding, const_reduceprod)
832 Shape input_shape{3, 3};
833 Shape output_shape{3};
835 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
836 auto constant = op::Constant::create(element::i32, input_shape, values_in);
838 vector<int32_t> values_axes{1};
839 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
840 auto convert = make_shared<op::v1::ReduceProd>(constant, constant_axes);
841 convert->set_friendly_name("test");
842 auto f = make_shared<Function>(convert, ParameterVector{});
844 pass::Manager pass_manager;
845 pass_manager.register_pass<pass::ConstantFolding>();
846 pass_manager.run_passes(f);
848 ASSERT_EQ(count_ops_of_type<op::v1::ReduceProd>(f), 0);
849 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
852 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
853 ASSERT_TRUE(new_const);
854 ASSERT_EQ(new_const->get_friendly_name(), "test");
855 ASSERT_EQ(new_const->get_shape(), output_shape);
857 auto values_out = new_const->get_vector<int32_t>();
859 vector<int32_t> values_expected{6, 120, 504};
861 ASSERT_EQ(values_expected, values_out);
864 TEST(constant_folding, const_reduceprod_keepdims)
866 Shape input_shape{3, 3};
867 Shape output_shape{3, 1};
869 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
870 auto constant = op::Constant::create(element::i32, input_shape, values_in);
872 vector<int32_t> values_axes{1};
873 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
874 auto convert = make_shared<op::v1::ReduceProd>(constant, constant_axes, true);
875 convert->set_friendly_name("test");
876 auto f = make_shared<Function>(convert, ParameterVector{});
878 pass::Manager pass_manager;
879 pass_manager.register_pass<pass::ConstantFolding>();
880 pass_manager.run_passes(f);
882 ASSERT_EQ(count_ops_of_type<op::v1::ReduceProd>(f), 0);
883 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
886 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
887 ASSERT_TRUE(new_const);
888 ASSERT_EQ(new_const->get_friendly_name(), "test");
889 ASSERT_EQ(new_const->get_shape(), output_shape);
891 auto values_out = new_const->get_vector<int32_t>();
893 vector<int32_t> values_expected{6, 120, 504};
895 ASSERT_EQ(values_expected, values_out);
898 TEST(constant_folding, const_sum)
900 Shape input_shape{3, 3};
902 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
903 auto constant = op::Constant::create(element::i32, input_shape, values_in);
904 auto convert = make_shared<op::Sum>(constant, AxisSet{1});
905 convert->set_friendly_name("test");
906 auto f = make_shared<Function>(convert, ParameterVector{});
908 pass::Manager pass_manager;
909 pass_manager.register_pass<pass::ConstantFolding>();
910 pass_manager.run_passes(f);
912 ASSERT_EQ(count_ops_of_type<op::Sum>(f), 0);
913 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
916 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
917 ASSERT_TRUE(new_const);
918 ASSERT_EQ(new_const->get_friendly_name(), "test");
919 auto values_out = new_const->get_vector<int32_t>();
921 vector<int32_t> values_expected{6, 15, 24};
923 ASSERT_EQ(values_expected, values_out);
926 TEST(constant_folding, const_reducesum)
928 Shape input_shape{3, 3};
929 Shape output_shape{3};
931 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
932 auto constant = op::Constant::create(element::i32, input_shape, values_in);
934 vector<int32_t> values_axes{1};
935 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
936 auto convert = make_shared<op::v1::ReduceSum>(constant, constant_axes);
937 convert->set_friendly_name("test");
938 auto f = make_shared<Function>(convert, ParameterVector{});
940 pass::Manager pass_manager;
941 pass_manager.register_pass<pass::ConstantFolding>();
942 pass_manager.run_passes(f);
944 ASSERT_EQ(count_ops_of_type<op::v1::ReduceSum>(f), 0);
945 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
948 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
949 ASSERT_TRUE(new_const);
950 ASSERT_EQ(new_const->get_friendly_name(), "test");
951 ASSERT_EQ(new_const->get_shape(), output_shape);
953 auto values_out = new_const->get_vector<int32_t>();
955 vector<int32_t> values_expected{6, 15, 24};
957 ASSERT_EQ(values_expected, values_out);
960 TEST(constant_folding, const_reducesum_keepdims)
962 Shape input_shape{3, 3};
963 Shape output_shape{3, 1};
965 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
966 auto constant = op::Constant::create(element::i32, input_shape, values_in);
968 vector<int32_t> values_axes{1};
969 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
970 auto convert = make_shared<op::v1::ReduceSum>(constant, constant_axes, true);
971 convert->set_friendly_name("test");
972 auto f = make_shared<Function>(convert, ParameterVector{});
974 pass::Manager pass_manager;
975 pass_manager.register_pass<pass::ConstantFolding>();
976 pass_manager.run_passes(f);
978 ASSERT_EQ(count_ops_of_type<op::v1::ReduceSum>(f), 0);
979 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
982 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
983 ASSERT_TRUE(new_const);
984 ASSERT_EQ(new_const->get_friendly_name(), "test");
985 ASSERT_EQ(new_const->get_shape(), output_shape);
987 auto values_out = new_const->get_vector<int32_t>();
989 vector<int32_t> values_expected{6, 15, 24};
991 ASSERT_EQ(values_expected, values_out);
994 TEST(constant_folding, const_max)
996 Shape input_shape{3, 3};
998 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
999 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1000 auto convert = make_shared<op::Max>(constant, AxisSet{1});
1001 convert->set_friendly_name("test");
1002 auto f = make_shared<Function>(convert, ParameterVector{});
1004 pass::Manager pass_manager;
1005 pass_manager.register_pass<pass::ConstantFolding>();
1006 pass_manager.run_passes(f);
1008 ASSERT_EQ(count_ops_of_type<op::Max>(f), 0);
1009 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1012 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1013 ASSERT_TRUE(new_const);
1014 ASSERT_EQ(new_const->get_friendly_name(), "test");
1015 auto values_out = new_const->get_vector<int32_t>();
1017 vector<int32_t> values_expected{3, 6, 9};
1019 ASSERT_EQ(values_expected, values_out);
1022 TEST(constant_folding, const_reducemax)
1024 Shape input_shape{3, 2};
1025 Shape output_shape{3};
1027 vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1028 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1029 Shape axes_shape{1};
1030 vector<int32_t> values_axes{1};
1031 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1032 auto convert = make_shared<op::v1::ReduceMax>(constant, constant_axes);
1033 convert->set_friendly_name("test");
1034 auto f = make_shared<Function>(convert, ParameterVector{});
1036 pass::Manager pass_manager;
1037 pass_manager.register_pass<pass::ConstantFolding>();
1038 pass_manager.run_passes(f);
1040 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMax>(f), 0);
1041 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1044 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1045 ASSERT_TRUE(new_const);
1046 ASSERT_EQ(new_const->get_friendly_name(), "test");
1047 ASSERT_EQ(new_const->get_shape(), output_shape);
1049 auto values_out = new_const->get_vector<int32_t>();
1051 vector<int32_t> values_expected{2, 4, 6};
1053 ASSERT_EQ(values_expected, values_out);
1056 TEST(constant_folding, const_reducemax_keepdims)
1058 Shape input_shape{3, 2};
1059 Shape output_shape{3, 1};
1061 vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1062 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1063 Shape axes_shape{1};
1064 vector<int32_t> values_axes{1};
1065 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1066 auto convert = make_shared<op::v1::ReduceMax>(constant, constant_axes, true);
1067 convert->set_friendly_name("test");
1068 auto f = make_shared<Function>(convert, ParameterVector{});
1070 pass::Manager pass_manager;
1071 pass_manager.register_pass<pass::ConstantFolding>();
1072 pass_manager.run_passes(f);
1074 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMax>(f), 0);
1075 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1078 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1079 ASSERT_TRUE(new_const);
1080 ASSERT_EQ(new_const->get_friendly_name(), "test");
1081 ASSERT_EQ(new_const->get_shape(), output_shape);
1083 auto values_out = new_const->get_vector<int32_t>();
1085 vector<int32_t> values_expected{2, 4, 6};
1087 ASSERT_EQ(values_expected, values_out);
1090 TEST(constant_folding, const_min)
1092 Shape input_shape{3, 3};
1094 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
1095 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1096 auto convert = make_shared<op::Min>(constant, AxisSet{1});
1097 convert->set_friendly_name("test");
1098 auto f = make_shared<Function>(convert, ParameterVector{});
1100 pass::Manager pass_manager;
1101 pass_manager.register_pass<pass::ConstantFolding>();
1102 pass_manager.run_passes(f);
1104 ASSERT_EQ(count_ops_of_type<op::Min>(f), 0);
1105 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1108 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1109 ASSERT_TRUE(new_const);
1110 ASSERT_EQ(new_const->get_friendly_name(), "test");
1111 auto values_out = new_const->get_vector<int32_t>();
1113 vector<int32_t> values_expected{1, 4, 7};
1115 ASSERT_EQ(values_expected, values_out);
1118 TEST(constant_folding, const_reducemin)
1120 Shape input_shape{3, 2};
1121 Shape output_shape{3};
1123 vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1124 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1125 Shape axes_shape{1};
1126 vector<int32_t> values_axes{1};
1127 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1128 auto convert = make_shared<op::v1::ReduceMin>(constant, constant_axes);
1129 convert->set_friendly_name("test");
1130 auto f = make_shared<Function>(convert, ParameterVector{});
1132 pass::Manager pass_manager;
1133 pass_manager.register_pass<pass::ConstantFolding>();
1134 pass_manager.run_passes(f);
1136 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMin>(f), 0);
1137 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1140 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1141 ASSERT_TRUE(new_const);
1142 ASSERT_EQ(new_const->get_friendly_name(), "test");
1143 ASSERT_EQ(new_const->get_shape(), output_shape);
1145 auto values_out = new_const->get_vector<int32_t>();
1147 vector<int32_t> values_expected{1, 3, 5};
1149 ASSERT_EQ(values_expected, values_out);
1152 TEST(constant_folding, const_reducemin_keepdims)
1154 Shape input_shape{3, 2};
1155 Shape output_shape{3, 1};
1157 vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1158 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1159 Shape axes_shape{1};
1160 vector<int32_t> values_axes{1};
1161 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1162 auto convert = make_shared<op::v1::ReduceMin>(constant, constant_axes, true);
1163 convert->set_friendly_name("test");
1164 auto f = make_shared<Function>(convert, ParameterVector{});
1166 pass::Manager pass_manager;
1167 pass_manager.register_pass<pass::ConstantFolding>();
1168 pass_manager.run_passes(f);
1170 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMin>(f), 0);
1171 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1174 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1175 ASSERT_TRUE(new_const);
1176 ASSERT_EQ(new_const->get_friendly_name(), "test");
1177 ASSERT_EQ(new_const->get_shape(), output_shape);
1179 auto values_out = new_const->get_vector<int32_t>();
1181 vector<int32_t> values_expected{1, 3, 5};
1183 ASSERT_EQ(values_expected, values_out);
1186 TEST(constant_folding, const_reducemean)
1188 Shape input_shape{3, 3};
1189 Shape output_shape{3};
1191 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
1192 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1193 Shape axes_shape{1};
1194 vector<int32_t> values_axes{1};
1195 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1196 auto convert = make_shared<op::v1::ReduceMean>(constant, constant_axes);
1197 convert->set_friendly_name("test");
1198 auto f = make_shared<Function>(convert, ParameterVector{});
1200 pass::Manager pass_manager;
1201 pass_manager.register_pass<pass::ConstantFolding>();
1202 pass_manager.run_passes(f);
1204 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMean>(f), 0);
1205 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1208 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1209 ASSERT_TRUE(new_const);
1210 ASSERT_EQ(new_const->get_friendly_name(), "test");
1211 ASSERT_EQ(new_const->get_shape(), output_shape);
1213 auto values_out = new_const->get_vector<int32_t>();
1215 vector<int32_t> values_expected{2, 5, 8};
1217 ASSERT_EQ(values_expected, values_out);
1220 TEST(constant_folding, const_reducemean_keepdims)
1222 Shape input_shape{3, 3};
1223 Shape output_shape{3, 1};
1225 vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
1226 auto constant = op::Constant::create(element::i32, input_shape, values_in);
1227 Shape axes_shape{1};
1228 vector<int32_t> values_axes{1};
1229 auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1230 auto convert = make_shared<op::v1::ReduceMean>(constant, constant_axes, true);
1231 convert->set_friendly_name("test");
1232 auto f = make_shared<Function>(convert, ParameterVector{});
1234 pass::Manager pass_manager;
1235 pass_manager.register_pass<pass::ConstantFolding>();
1236 pass_manager.run_passes(f);
1238 ASSERT_EQ(count_ops_of_type<op::v1::ReduceMean>(f), 0);
1239 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1242 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1243 ASSERT_TRUE(new_const);
1244 ASSERT_EQ(new_const->get_friendly_name(), "test");
1245 ASSERT_EQ(new_const->get_shape(), output_shape);
1247 auto values_out = new_const->get_vector<int32_t>();
1249 vector<int32_t> values_expected{2, 5, 8};
1251 ASSERT_EQ(values_expected, values_out);
1254 TEST(constant_folding, const_reduce_logical_and__no_keepdims)
1256 const Shape input_shape{3, 3};
1258 const vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
1259 const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1260 const auto axes = op::Constant::create(element::i64, {1}, {1});
1261 const auto convert = make_shared<op::v1::ReduceLogicalAnd>(data, axes, false);
1262 convert->set_friendly_name("test");
1263 auto f = make_shared<Function>(convert, ParameterVector{});
1265 pass::Manager pass_manager;
1266 pass_manager.register_pass<pass::ConstantFolding>();
1267 pass_manager.run_passes(f);
1269 ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1270 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1272 const auto new_const =
1273 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1274 ASSERT_TRUE(new_const);
1275 ASSERT_EQ(new_const->get_friendly_name(), "test");
1277 const Shape expected_out_shape{3};
1278 ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1280 const auto values_out = new_const->get_vector<char>();
1282 const vector<char> values_expected{0, 0, 1};
1284 ASSERT_EQ(values_expected, values_out);
1287 TEST(constant_folding, const_reduce_logical_and__keepdims)
1289 const Shape input_shape{3, 3};
1291 const vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
1292 const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1293 const auto axes = op::Constant::create(element::i64, {1}, {1});
1294 const auto convert = make_shared<op::v1::ReduceLogicalAnd>(data, axes, true);
1295 convert->set_friendly_name("test");
1296 auto f = make_shared<Function>(convert, ParameterVector{});
1298 pass::Manager pass_manager;
1299 pass_manager.register_pass<pass::ConstantFolding>();
1300 pass_manager.run_passes(f);
1302 ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1303 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1305 const auto new_const =
1306 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1307 ASSERT_TRUE(new_const);
1308 ASSERT_EQ(new_const->get_friendly_name(), "test");
1310 // the output shape is expected to have 'ones' at the positions specified in the reduction axes
1311 // in case the keep_dims attribute of ReduceLogicalAnd is set to true
1312 const Shape expected_out_shape{3, 1};
1313 ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1315 const auto values_out = new_const->get_vector<char>();
1317 const vector<char> values_expected{0, 0, 1};
1319 ASSERT_EQ(values_expected, values_out);
1322 TEST(constant_folding, const_reduce_logical_and__keepdims_3d)
1324 const Shape input_shape{2, 2, 2};
1326 const vector<char> values_in{1, 1, 0, 0, 1, 0, 0, 1};
1327 const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1328 const auto axes = op::Constant::create(element::i64, {2}, {0, 2});
1329 const auto convert = make_shared<op::v1::ReduceLogicalAnd>(data, axes, true);
1330 convert->set_friendly_name("test");
1331 auto f = make_shared<Function>(convert, ParameterVector{});
1333 pass::Manager pass_manager;
1334 pass_manager.register_pass<pass::ConstantFolding>();
1335 pass_manager.run_passes(f);
1337 ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1338 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1340 const auto new_const =
1341 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1342 ASSERT_TRUE(new_const);
1343 ASSERT_EQ(new_const->get_friendly_name(), "test");
1345 const Shape expected_out_shape{1, 2, 1};
1346 ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1348 const auto values_out = new_const->get_vector<char>();
1350 const vector<char> values_expected{0, 0};
1352 ASSERT_EQ(values_expected, values_out);
1355 TEST(constant_folding, const_reduce_logical_or__no_keepdims)
1357 const Shape input_shape{3, 3};
1359 const vector<char> values_in{1, 0, 0, 1, 0, 1, 0, 0, 0};
1360 const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1361 const auto axes = op::Constant::create(element::i64, {1}, {1});
1362 const auto convert = make_shared<op::v1::ReduceLogicalOr>(data, axes, false);
1363 convert->set_friendly_name("test");
1364 auto f = make_shared<Function>(convert, ParameterVector{});
1366 pass::Manager pass_manager;
1367 pass_manager.register_pass<pass::ConstantFolding>();
1368 pass_manager.run_passes(f);
1370 ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1371 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1373 const auto new_const =
1374 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1375 ASSERT_TRUE(new_const);
1376 ASSERT_EQ(new_const->get_friendly_name(), "test");
1378 const Shape expected_out_shape{3};
1379 ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1381 const auto values_out = new_const->get_vector<char>();
1383 const vector<char> values_expected{1, 1, 0};
1385 ASSERT_EQ(values_expected, values_out);
1388 TEST(constant_folding, const_concat)
1391 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1392 auto constant1 = op::Constant::create(element::i32, Shape{2, 1}, vector<int32_t>{7, 8});
1393 auto concat = make_shared<op::Concat>(NodeVector{constant0, constant1}, 1);
1394 concat->set_friendly_name("test");
1395 auto f = make_shared<Function>(concat, ParameterVector{});
1397 pass::Manager pass_manager;
1398 pass_manager.register_pass<pass::ConstantFolding>();
1399 pass_manager.run_passes(f);
1401 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1402 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1405 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1406 ASSERT_TRUE(new_const);
1407 ASSERT_EQ(new_const->get_friendly_name(), "test");
1408 auto values_out = new_const->get_vector<int32_t>();
1410 vector<int32_t> values_expected{1, 2, 3, 7, 4, 5, 6, 8};
1412 ASSERT_EQ(values_expected, values_out);
1415 TEST(constant_folding, const_concat_3d_single_elem)
1417 auto constant_1 = op::Constant::create(element::i32, Shape{1, 1, 1}, vector<int32_t>{1});
1418 auto constant_2 = op::Constant::create(element::i32, Shape{1, 1, 1}, vector<int32_t>{2});
1419 auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2}, 0);
1420 concat->set_friendly_name("test");
1421 auto f = make_shared<Function>(concat, ParameterVector{});
1423 pass::Manager pass_manager;
1424 pass_manager.register_pass<pass::ConstantFolding>();
1425 pass_manager.run_passes(f);
1427 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1428 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1431 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1433 ASSERT_TRUE(new_const);
1434 ASSERT_EQ(new_const->get_friendly_name(), "test");
1435 ASSERT_EQ(new_const->get_output_shape(0), (Shape{2, 1, 1}));
1437 auto values_out = new_const->get_vector<int32_t>();
1438 vector<int32_t> values_expected{1, 2};
1439 ASSERT_EQ(values_expected, values_out);
1442 TEST(constant_folding, const_concat_axis_2)
1445 op::Constant::create(element::i32, Shape{3, 1, 2}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1446 auto constant_2 = op::Constant::create(
1447 element::i32, Shape{3, 1, 4}, vector<int32_t>{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
1448 auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2}, 2);
1449 concat->set_friendly_name("test");
1450 auto f = make_shared<Function>(concat, ParameterVector{});
1452 pass::Manager pass_manager;
1453 pass_manager.register_pass<pass::ConstantFolding>();
1454 pass_manager.run_passes(f);
1456 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1457 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1460 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1462 ASSERT_TRUE(new_const);
1463 ASSERT_EQ(new_const->get_friendly_name(), "test");
1464 ASSERT_EQ(new_const->get_output_shape(0), (Shape{3, 1, 6}));
1466 auto values_out = new_const->get_vector<int32_t>();
1467 vector<int32_t> values_expected{1, 2, 7, 8, 9, 10, 3, 4, 11, 12, 13, 14, 5, 6, 15, 16, 17, 18};
1468 ASSERT_EQ(values_expected, values_out);
1471 TEST(constant_folding, const_concat_axis_1_bool_type)
1474 op::Constant::create(element::boolean, Shape{1, 1, 2}, vector<int32_t>{true, true});
1475 auto constant_2 = op::Constant::create(
1476 element::boolean, Shape{1, 2, 2}, vector<char>{true, false, true, false});
1477 auto constant_3 = op::Constant::create(
1478 element::boolean, Shape{1, 3, 2}, vector<char>{true, false, true, false, true, false});
1479 auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2, constant_3}, 1);
1480 concat->set_friendly_name("test");
1481 auto f = make_shared<Function>(concat, ParameterVector{});
1483 pass::Manager pass_manager;
1484 pass_manager.register_pass<pass::ConstantFolding>();
1485 pass_manager.run_passes(f);
1487 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1488 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1491 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1493 ASSERT_TRUE(new_const);
1494 ASSERT_EQ(new_const->get_friendly_name(), "test");
1495 ASSERT_EQ(new_const->get_output_shape(0), (Shape{1, 6, 2}));
1497 auto values_out = new_const->get_vector<char>();
1498 vector<char> values_expected{
1499 true, true, true, false, true, false, true, false, true, false, true, false};
1500 ASSERT_EQ(values_expected, values_out);
1503 TEST(constant_folding, const_not)
1506 op::Constant::create(element::boolean, Shape{2, 3}, vector<char>{0, 1, 0, 0, 1, 1});
1507 auto logical_not = make_shared<op::Not>(constant);
1508 logical_not->set_friendly_name("test");
1509 auto f = make_shared<Function>(logical_not, ParameterVector{});
1511 pass::Manager pass_manager;
1512 pass_manager.register_pass<pass::ConstantFolding>();
1513 pass_manager.run_passes(f);
1515 ASSERT_EQ(count_ops_of_type<op::Not>(f), 0);
1516 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1519 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1520 ASSERT_TRUE(new_const);
1521 ASSERT_EQ(new_const->get_friendly_name(), "test");
1522 auto values_out = new_const->get_vector<char>();
1524 vector<char> values_expected{1, 0, 1, 1, 0, 0};
1526 ASSERT_EQ(values_expected, values_out);
1529 TEST(constant_folding, const_equal)
1532 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1534 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 2, 3, 5, 6});
1535 auto eq = make_shared<op::Equal>(constant0, constant1);
1536 eq->set_friendly_name("test");
1537 auto f = make_shared<Function>(eq, ParameterVector{});
1539 pass::Manager pass_manager;
1540 pass_manager.register_pass<pass::ConstantFolding>();
1541 pass_manager.run_passes(f);
1543 ASSERT_EQ(count_ops_of_type<op::Equal>(f), 0);
1544 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1547 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1548 ASSERT_TRUE(new_const);
1549 ASSERT_EQ(new_const->get_friendly_name(), "test");
1550 auto values_out = new_const->get_vector<char>();
1552 vector<char> values_expected{1, 1, 0, 0, 1, 1};
1554 ASSERT_EQ(values_expected, values_out);
1557 TEST(constant_folding, const_not_equal)
1560 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1562 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 2, 3, 5, 6});
1563 auto eq = make_shared<op::NotEqual>(constant0, constant1);
1564 eq->set_friendly_name("test");
1565 auto f = make_shared<Function>(eq, ParameterVector{});
1567 pass::Manager pass_manager;
1568 pass_manager.register_pass<pass::ConstantFolding>();
1569 pass_manager.run_passes(f);
1571 ASSERT_EQ(count_ops_of_type<op::NotEqual>(f), 0);
1572 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1575 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1576 ASSERT_TRUE(new_const);
1577 ASSERT_EQ(new_const->get_friendly_name(), "test");
1578 auto values_out = new_const->get_vector<char>();
1580 vector<char> values_expected{0, 0, 1, 1, 0, 0};
1582 ASSERT_EQ(values_expected, values_out);
1585 TEST(constant_folding, const_greater)
1588 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1590 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1591 auto eq = make_shared<op::Greater>(constant0, constant1);
1592 eq->set_friendly_name("test");
1593 auto f = make_shared<Function>(eq, ParameterVector{});
1595 pass::Manager pass_manager;
1596 pass_manager.register_pass<pass::ConstantFolding>();
1597 pass_manager.run_passes(f);
1599 ASSERT_EQ(count_ops_of_type<op::Greater>(f), 0);
1600 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1603 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1604 ASSERT_TRUE(new_const);
1605 ASSERT_EQ(new_const->get_friendly_name(), "test");
1606 auto values_out = new_const->get_vector<char>();
1608 vector<char> values_expected{0, 0, 1, 0, 0, 1};
1610 ASSERT_EQ(values_expected, values_out);
1613 TEST(constant_folding, const_greater_eq)
1616 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1618 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1619 auto eq = make_shared<op::GreaterEq>(constant0, constant1);
1620 eq->set_friendly_name("test");
1621 auto f = make_shared<Function>(eq, ParameterVector{});
1623 pass::Manager pass_manager;
1624 pass_manager.register_pass<pass::ConstantFolding>();
1625 pass_manager.run_passes(f);
1627 ASSERT_EQ(count_ops_of_type<op::GreaterEq>(f), 0);
1628 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1631 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1632 ASSERT_TRUE(new_const);
1633 ASSERT_EQ(new_const->get_friendly_name(), "test");
1634 auto values_out = new_const->get_vector<char>();
1636 vector<char> values_expected{0, 1, 1, 0, 1, 1};
1638 ASSERT_EQ(values_expected, values_out);
1641 TEST(constant_folding, const_less)
1644 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1646 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1647 auto eq = make_shared<op::Less>(constant0, constant1);
1648 eq->set_friendly_name("test");
1649 auto f = make_shared<Function>(eq, ParameterVector{});
1651 pass::Manager pass_manager;
1652 pass_manager.register_pass<pass::ConstantFolding>();
1653 pass_manager.run_passes(f);
1655 ASSERT_EQ(count_ops_of_type<op::Less>(f), 0);
1656 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1659 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1660 ASSERT_TRUE(new_const);
1661 ASSERT_EQ(new_const->get_friendly_name(), "test");
1662 auto values_out = new_const->get_vector<char>();
1664 vector<char> values_expected{1, 0, 0, 1, 0, 0};
1666 ASSERT_EQ(values_expected, values_out);
1669 TEST(constant_folding, const_less_eq)
1672 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1674 op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1675 auto eq = make_shared<op::LessEq>(constant0, constant1);
1676 eq->set_friendly_name("test");
1677 auto f = make_shared<Function>(eq, ParameterVector{});
1679 pass::Manager pass_manager;
1680 pass_manager.register_pass<pass::ConstantFolding>();
1681 pass_manager.run_passes(f);
1683 ASSERT_EQ(count_ops_of_type<op::LessEq>(f), 0);
1684 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1687 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1688 ASSERT_TRUE(new_const);
1689 ASSERT_EQ(new_const->get_friendly_name(), "test");
1690 auto values_out = new_const->get_vector<char>();
1692 vector<char> values_expected{1, 1, 0, 1, 1, 0};
1694 ASSERT_EQ(values_expected, values_out);
1697 TEST(constant_folding, const_or)
1700 op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 0, 1, 0, 1, 1});
1702 op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 1, 1, 1, 0, 1});
1703 auto eq = make_shared<op::Or>(constant0, constant1);
1704 eq->set_friendly_name("test");
1705 auto f = make_shared<Function>(eq, ParameterVector{});
1707 pass::Manager pass_manager;
1708 pass_manager.register_pass<pass::ConstantFolding>();
1709 pass_manager.run_passes(f);
1711 ASSERT_EQ(count_ops_of_type<op::Or>(f), 0);
1712 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1715 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1716 ASSERT_TRUE(new_const);
1717 ASSERT_EQ(new_const->get_friendly_name(), "test");
1718 auto values_out = new_const->get_vector<char>();
1720 vector<char> values_expected{0, 1, 1, 1, 1, 1};
1722 ASSERT_EQ(values_expected, values_out);
// Constant-folds a logical Xor of two boolean Shape{2, 3} constants and verifies that the
// Xor node is replaced by a single Constant that keeps the friendly name "test" and holds
// the elementwise XOR of the inputs.
1725 TEST(constant_folding, const_xor)
// Boolean inputs are supplied as int32_t vectors of 0/1 values.
1728 op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 0, 1, 0, 1, 1});
1730 op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 1, 1, 1, 0, 1});
// NOTE: the local is called `eq`, but it holds the Xor node.
1731 auto eq = make_shared<op::Xor>(constant0, constant1);
1732 eq->set_friendly_name("test");
1733 auto f = make_shared<Function>(eq, ParameterVector{});
1735 pass::Manager pass_manager;
1736 pass_manager.register_pass<pass::ConstantFolding>();
1737 pass_manager.run_passes(f);
// After folding, no Xor may remain and the whole function must reduce to one Constant.
1739 ASSERT_EQ(count_ops_of_type<op::Xor>(f), 0);
1740 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding result 0 must be a Constant (as_type_ptr yields null otherwise).
1743 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1744 ASSERT_TRUE(new_const);
1745 ASSERT_EQ(new_const->get_friendly_name(), "test");
// Boolean results are read back as char.
1746 auto values_out = new_const->get_vector<char>();
1748 vector<char> values_expected{0, 1, 0, 1, 1, 0};
1750 ASSERT_EQ(values_expected, values_out);
// Constant-folds Ceiling over an f32 Shape{2, 3} constant that mixes zero, small positive
// and negative fractions, half values and an exact integer, then checks the folded values.
1753 TEST(constant_folding, const_ceiling)
1755 auto constant = op::Constant::create(
1756 element::f32, Shape{2, 3}, vector<float>{0.0f, 0.1f, -0.1f, -2.5f, 2.5f, 3.0f});
1757 auto ceil = make_shared<op::Ceiling>(constant);
1758 ceil->set_friendly_name("test");
1759 auto f = make_shared<Function>(ceil, ParameterVector{});
1761 pass::Manager pass_manager;
1762 pass_manager.register_pass<pass::ConstantFolding>();
1763 pass_manager.run_passes(f);
// After folding, no Ceiling may remain and the function must reduce to one Constant.
1765 ASSERT_EQ(count_ops_of_type<op::Ceiling>(f), 0);
1766 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding result 0 must be a Constant (as_type_ptr yields null otherwise).
1769 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1770 ASSERT_TRUE(new_const);
1771 ASSERT_EQ(new_const->get_friendly_name(), "test");
1772 auto values_out = new_const->get_vector<float>();
// Each input rounded up to the nearest integer.
1774 vector<float> values_expected{0.0f, 1.0f, 0.0f, -2.0f, 3.0f, 3.0f};
1776 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
// Constant-folds Floor over an f32 Shape{2, 3} constant that mixes zero, small positive
// and negative fractions, half values and an exact integer, then checks the folded values.
1779 TEST(constant_folding, const_floor)
1781 auto constant = op::Constant::create(
1782 element::f32, Shape{2, 3}, vector<float>{0.0f, 0.1f, -0.1f, -2.5f, 2.5f, 3.0f});
1783 auto floor = make_shared<op::Floor>(constant);
1784 floor->set_friendly_name("test");
1785 auto f = make_shared<Function>(floor, ParameterVector{});
1787 pass::Manager pass_manager;
1788 pass_manager.register_pass<pass::ConstantFolding>();
1789 pass_manager.run_passes(f);
// After folding, no Floor may remain and the function must reduce to one Constant.
1791 ASSERT_EQ(count_ops_of_type<op::Floor>(f), 0);
1792 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding result 0 must be a Constant (as_type_ptr yields null otherwise).
1795 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1796 ASSERT_TRUE(new_const);
1797 ASSERT_EQ(new_const->get_friendly_name(), "test");
1798 auto values_out = new_const->get_vector<float>();
// Each input rounded down to the nearest integer.
1800 vector<float> values_expected{0.0f, 0.0f, -1.0f, -3.0f, 2.0f, 3.0f};
1802 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
// Constant-folds op::v1::Gather when data, indices and axis are all constants; the axis is
// given as a one-element (Shape{1}) i64 constant with value 1.
1805 TEST(constant_folding, const_gather_v1)
1807 auto constant_data = op::Constant::create(
1810 vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
1811 auto constant_indices =
1812 op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
1813 auto constant_axis = op::Constant::create(element::i64, Shape{1}, vector<int64_t>{1});
1814 auto gather = make_shared<op::v1::Gather>(constant_data, constant_indices, constant_axis);
1815 gather->set_friendly_name("test");
1816 auto f = make_shared<Function>(gather, ParameterVector{});
1818 pass::Manager pass_manager;
1819 pass_manager.register_pass<pass::ConstantFolding>();
1820 pass_manager.run_passes(f);
// After folding, no Gather may remain and the function must reduce to one Constant.
1822 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1823 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding result 0 must be a Constant (as_type_ptr yields null otherwise).
1826 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1827 ASSERT_TRUE(new_const);
1828 ASSERT_EQ(new_const->get_friendly_name(), "test");
1829 auto values_out = new_const->get_vector<float>();
// NOTE(review): these values match picking columns 0, 3, 2, 2 from each row of a 2x5
// layout of the ten data values -- confirm against the data constant's declared shape.
1831 vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
1833 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
// Same scenario as a fully-constant op::v1::Gather fold, but the axis is passed as a scalar
// (Shape{}) i64 constant instead of a one-element tensor.
1836 TEST(constant_folding, const_gather_v1_scalar)
1838 auto constant_data = op::Constant::create(
1841 vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
1842 auto constant_indices =
1843 op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
// Scalar axis constant: Shape{} with the single value 1.
1844 auto constant_axis = op::Constant::create(element::i64, Shape{}, vector<int64_t>{1});
1845 auto gather = make_shared<op::v1::Gather>(constant_data, constant_indices, constant_axis);
1846 gather->set_friendly_name("test");
1847 auto f = make_shared<Function>(gather, ParameterVector{});
1849 pass::Manager pass_manager;
1850 pass_manager.register_pass<pass::ConstantFolding>();
1851 pass_manager.run_passes(f);
// After folding, no Gather may remain and the function must reduce to one Constant.
1853 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1854 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding result 0 must be a Constant (as_type_ptr yields null otherwise).
1857 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1858 ASSERT_TRUE(new_const);
1859 ASSERT_EQ(new_const->get_friendly_name(), "test");
1860 auto values_out = new_const->get_vector<float>();
// NOTE(review): these values match picking columns 0, 3, 2, 2 from each row of a 2x5
// layout of the ten data values -- confirm against the data constant's declared shape.
1862 vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
1864 ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
// Folds a Concat -> Gather subgraph in which the gathered element comes from a constant
// concat input: Concat{A, B_const, C} along axis 0 followed by Gather with index 1 must
// collapse to a single Constant holding B_const's value, even though A and C are Parameters.
1867 TEST(constant_folding, const_gather_v1_subgraph)
1869 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1870 const float b_value = 3.21f;
1871 const auto B_const = op::Constant::create(element::f32, {1}, {b_value});
1872 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
1873 const int64_t axis = 0;
1874 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1876 const auto concat = make_shared<op::Concat>(NodeVector{A, B_const, C}, axis);
// Index 1 addresses the second concat input, i.e. B_const.
1878 const vector<int64_t> indices{1};
1879 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1880 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1881 gather->set_friendly_name("test");
1882 auto f = make_shared<Function>(gather, ParameterVector{A, C});
1884 pass::Manager pass_manager;
1885 pass_manager.register_pass<pass::ConstantFolding>();
1886 pass_manager.run_passes(f);
// Both the Concat and the Gather must be gone, leaving exactly one Constant.
1888 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1889 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1890 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1892 const auto new_const =
1893 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1894 ASSERT_TRUE(new_const);
1895 ASSERT_EQ(new_const->get_friendly_name(), "test");
// The surviving constant carries the value that was placed in B_const.
1897 const auto values_out = new_const->get_vector<float>();
1898 ASSERT_TRUE(test::all_close_f(values_out, {b_value}, MIN_FLOAT_TOLERANCE_BITS));
// Folds a Concat -> Gather subgraph addressed with a negative index: Concat{A, B, C_const}
// along axis 0 followed by Gather with index -1 must collapse to a single Constant holding
// C_const's value.
// NOTE: despite the test name, the axis here is 0; it is the gather index that is negative.
1901 TEST(constant_folding, const_gather_v1_subgraph_neg_axis)
1903 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1904 const float b_value = 1.23f;
1905 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
// b_value is stored in the last concat input (C_const), not in B.
1906 const auto C_const = op::Constant::create(element::f32, {1}, {b_value});
1907 const int64_t axis = 0;
1908 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1910 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C_const}, axis);
// Index -1 is expected to resolve to the last concat input (see the value check below).
1912 const vector<int64_t> indices{-1};
1913 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1914 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1915 gather->set_friendly_name("test");
1916 auto f = make_shared<Function>(gather, ParameterVector{A, B});
1918 pass::Manager pass_manager;
1919 pass_manager.register_pass<pass::ConstantFolding>();
1920 pass_manager.run_passes(f);
// Both the Concat and the Gather must be gone, leaving exactly one Constant.
1922 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1923 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1924 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1926 const auto new_const =
1927 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1928 ASSERT_TRUE(new_const);
1929 ASSERT_EQ(new_const->get_friendly_name(), "test");
// The surviving constant carries the value that was placed in C_const.
1931 const auto values_out = new_const->get_vector<float>();
1932 ASSERT_TRUE(test::all_close_f(values_out, {b_value}, MIN_FLOAT_TOLERANCE_BITS));
// Concat -> Gather subgraph where every concat input is a Parameter (no constant data):
// with a single index along axis 0 the pass is still expected to remove both the Concat
// and the Gather from the function.
1935 TEST(constant_folding, const_gather_v1_subgraph_no_constant_input)
1937 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1938 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
1939 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
1940 const int64_t axis = 0;
1941 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1943 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
1945 const vector<int64_t> indices{1};
1946 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1947 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
// The friendly name is set here but not asserted on in this test.
1948 gather->set_friendly_name("test");
1949 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
1951 pass::Manager pass_manager;
1952 pass_manager.register_pass<pass::ConstantFolding>();
1953 pass_manager.run_passes(f);
// Only the absence of Concat and Gather is checked; no Constant count is asserted.
1955 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1956 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
// Concat -> Gather subgraph with only Parameter inputs and a scalar (Shape{}) index
// constant: the Concat and Gather must be removed and exactly one v0::Squeeze must be
// present in the function afterwards.
1959 TEST(constant_folding, const_gather_v1_subgraph_no_constant_input_scalar)
1961 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1962 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
1963 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
1964 const int64_t axis = 0;
1965 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1967 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
1969 const vector<int64_t> indices{1};
// Unlike the non-scalar variant, the indices constant is created with an empty shape.
1970 const auto indices_const = op::Constant::create(element::i64, {}, indices);
1971 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1972 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
1974 pass::Manager pass_manager;
1975 pass_manager.register_pass<pass::ConstantFolding>();
1976 pass_manager.run_passes(f);
1978 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1979 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
// NOTE(review): presumably the Squeeze drops the gathered axis to honour the scalar
// index -- confirm against the pass implementation.
1980 ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(f), 1);
// Negative case: when the Concat/Gather axis is 1 instead of 0, the Concat -> Gather
// subgraph must be left untouched by the pass.
1983 TEST(constant_folding, const_gather_v1_subgraph_skip_if_non_zero_axis)
1985 const auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2});
1986 const auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2});
1987 const auto C = make_shared<op::Parameter>(element::f32, Shape{2, 2});
// The same non-zero axis is used for both the Concat and the Gather.
1988 const int64_t axis = 1;
1989 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1991 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
1993 const vector<int64_t> indices{1};
1994 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1995 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1996 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
1998 pass::Manager pass_manager;
1999 pass_manager.register_pass<pass::ConstantFolding>();
2000 pass_manager.run_passes(f);
// Both nodes must still be present after the pass.
2002 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
2003 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
// Negative case: when the Gather requests more than one index ({0, 1}), the
// Concat -> Gather subgraph must be left untouched by the pass.
2006 TEST(constant_folding, const_gather_v1_subgraph_skip_if_non_single_indices)
2008 const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
2009 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
2010 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
2011 const int64_t axis = 0;
2012 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2014 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
// Two indices instead of one.
2016 const vector<int64_t> indices{0, 1};
2017 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
2018 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2019 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2021 pass::Manager pass_manager;
2022 pass_manager.register_pass<pass::ConstantFolding>();
2023 pass_manager.run_passes(f);
// Both nodes must still be present after the pass.
2025 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
2026 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
// Negative case: when one concat input has a fully dynamic shape, the Concat -> Gather
// subgraph must be left untouched by the pass.
2029 TEST(constant_folding, const_gather_v1_subgraph_skip_if_concat_output_shape_dynamic)
// A is the only input created with PartialShape::dynamic(); B and C are static Shape{1}.
2031 const auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
2032 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
2033 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
2034 const int64_t axis = 0;
2035 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2037 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
2039 const vector<int64_t> indices{1};
2040 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
2041 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2042 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2044 pass::Manager pass_manager;
2045 pass_manager.register_pass<pass::ConstantFolding>();
2046 pass_manager.run_passes(f);
// Both nodes must still be present after the pass.
2048 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
2049 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
// Negative case: the first concat input has Shape{2} while the others have Shape{1}; the
// Concat -> Gather subgraph must be left untouched by the pass.
// NOTE(review): presumably the pass bails out because a concat input spans more than one
// element along the axis, so a gather index no longer maps to one whole input -- confirm.
2052 TEST(constant_folding, const_gather_v1_subgraph_skip_if_not_single_input)
2054 const auto A = make_shared<op::Parameter>(element::f32, Shape{2});
2055 const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
2056 const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
2057 const int64_t axis = 0;
2058 const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2060 const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
2062 const vector<int64_t> indices{1};
2063 const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
2064 const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2065 auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2067 pass::Manager pass_manager;
2068 pass_manager.register_pass<pass::ConstantFolding>();
2069 pass_manager.run_passes(f);
// Both nodes must still be present after the pass.
2071 ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
2072 ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
// Constant-folds a strided Slice of an i32 constant holding the values 1..16: lower bound
// 2, upper bound 15 and stride 3 select every third element starting at position 2.
2075 TEST(constant_folding, const_slice)
2079 vector<int> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
2080 auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
2081 auto slice = make_shared<op::Slice>(constant, Coordinate{2}, Coordinate{15}, Strides{3});
2082 slice->set_friendly_name("test");
2084 auto f = make_shared<Function>(slice, ParameterVector{});
2086 pass::Manager pass_manager;
2087 pass_manager.register_pass<pass::ConstantFolding>();
2088 pass_manager.run_passes(f);
// After folding, no Slice may remain and the function must reduce to one Constant.
2090 ASSERT_EQ(count_ops_of_type<op::Slice>(f), 0);
2091 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding result 0 must be a Constant (as_type_ptr yields null otherwise).
2094 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2095 ASSERT_TRUE(new_const);
2096 ASSERT_EQ(new_const->get_friendly_name(), "test");
2097 auto values_out = new_const->get_vector<int>();
// Positions 2, 5, 8, 11 and 14 of values_in.
2099 vector<int> sliced_values{3, 6, 9, 12, 15};
2100 ASSERT_EQ(sliced_values, values_out);
// Constant-folds op::v1::Reshape when both the data (f32, Shape{2, 4}) and the target
// shape ({2, 4, 1}, given as an i64 constant) are constants.
2103 TEST(constant_folding, constant_dyn_reshape)
2105 Shape shape_in{2, 4};
2106 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
// The target shape is itself a rank-1 constant with three entries.
2108 Shape shape_shape{3};
2109 vector<int64_t> values_shape{2, 4, 1};
2111 auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
2112 auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
2113 auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape, false);
2114 dyn_reshape->set_friendly_name("test");
2115 auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
2117 pass::Manager pass_manager;
2118 pass_manager.register_pass<pass::ConstantFolding>();
2119 pass_manager.run_passes(f);
// After folding, no Reshape may remain and the function must reduce to one Constant.
2121 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
2122 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding result 0 must be a Constant (as_type_ptr yields null otherwise).
2125 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2126 ASSERT_TRUE(new_const);
2127 ASSERT_EQ(new_const->get_friendly_name(), "test");
2128 auto values_out = new_const->get_vector<float>();
// The flat element sequence is expected to be unchanged, so it is compared to the input.
2130 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
// Constant-folds op::v1::Reshape whose target shape is not a Constant node up front but
// the sum of two constants; the whole graph must still fold down to one Constant.
2133 TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
2135 Shape shape_in{2, 4};
2136 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
2138 Shape shape_shape{3};
2139 // We're going to add these two together elementwise to get {2, 4, 1}.
2140 // This means that when ConstantFolding starts, v1::Reshape will not yet
2141 // have static output shape. But by the time the Add op is folded, the
2142 // v1::Reshape's shape should be inferrable.
2143 vector<int64_t> values_shape_a{1, 3, 0};
2144 vector<int64_t> values_shape_b{1, 1, 1};
2146 auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
2147 auto constant_shape_a = make_shared<op::Constant>(element::i64, shape_shape, values_shape_a);
2148 auto constant_shape_b = make_shared<op::Constant>(element::i64, shape_shape, values_shape_b);
// The shape input is the expression constant_shape_a + constant_shape_b, not a Constant.
2150 make_shared<op::v1::Reshape>(constant_in, constant_shape_a + constant_shape_b, false);
2151 dyn_reshape->set_friendly_name("test");
2152 auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
// Precondition: before the pass runs, the Reshape output shape is still dynamic.
2154 ASSERT_TRUE(dyn_reshape->get_output_partial_shape(0).is_dynamic());
2156 pass::Manager pass_manager;
2157 pass_manager.register_pass<pass::ConstantFolding>();
2158 pass_manager.run_passes(f);
// After folding, no Reshape may remain and the function must reduce to one Constant.
2160 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
2161 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding result 0 must be a Constant (as_type_ptr yields null otherwise).
2164 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2165 ASSERT_TRUE(new_const);
2166 ASSERT_EQ(new_const->get_friendly_name(), "test");
2167 auto values_out = new_const->get_vector<float>();
// The flat element sequence is expected to be unchanged, so it is compared to the input.
2169 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
// Constant-folds Transpose of an f64 Shape{2, 4} constant with the permutation {1, 0}
// (both given as constants) and checks the transposed element order.
2172 TEST(constant_folding, constant_transpose)
2174 Shape shape_in{2, 4};
2175 vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};
// The permutation is a rank-1 i64 constant with two entries.
2177 Shape shape_perm{2};
2178 vector<int64_t> values_perm{1, 0};
2180 auto constant_in = make_shared<op::Constant>(element::f64, shape_in, values_in);
2181 auto constant_perm = make_shared<op::Constant>(element::i64, shape_perm, values_perm);
2182 auto transpose = make_shared<op::Transpose>(constant_in, constant_perm);
2183 transpose->set_friendly_name("test");
2184 auto f = make_shared<Function>(transpose, ParameterVector{});
2186 pass::Manager pass_manager;
2187 pass_manager.register_pass<pass::ConstantFolding>();
2188 pass_manager.run_passes(f);
// After folding, no Transpose may remain and the function must reduce to one Constant.
2190 ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 0);
2191 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding result 0 must be a Constant (as_type_ptr yields null otherwise).
2194 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2195 ASSERT_TRUE(new_const);
2196 ASSERT_EQ(new_const->get_friendly_name(), "test");
2197 auto values_out = new_const->get_vector<double>();
// Row-major contents of the 4x2 transpose of the 2x4 input.
2199 vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
2200 ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
// Helper shared by the Range folding tests: builds scalar start/stop/step constants of
// element type T, folds op::Range over them and compares the folded values with
// values_expected.
//
// start, stop, step  -- scalar inputs of the Range op.
// values_expected    -- the sequence the folded Constant must contain.
2203 template <typename T>
2204 void range_test(T start, T stop, T step, const vector<T>& values_expected)
2206 vector<T> values_start{start};
2207 vector<T> values_stop{stop};
2208 vector<T> values_step{step};
// The element type is derived from T; every input is a scalar (Shape{}) constant.
2210 auto constant_start = make_shared<op::Constant>(element::from<T>(), Shape{}, values_start);
2211 auto constant_stop = make_shared<op::Constant>(element::from<T>(), Shape{}, values_stop);
2212 auto constant_step = make_shared<op::Constant>(element::from<T>(), Shape{}, values_step);
2213 auto range = make_shared<op::Range>(constant_start, constant_stop, constant_step);
2214 range->set_friendly_name("test");
2215 auto f = make_shared<Function>(range, ParameterVector{});
2217 pass::Manager pass_manager;
2218 pass_manager.register_pass<pass::ConstantFolding>();
2219 pass_manager.run_passes(f);
// After folding, no Range may remain and the function must reduce to one Constant.
2221 ASSERT_EQ(count_ops_of_type<op::Range>(f), 0);
2222 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding result 0 must be a Constant (as_type_ptr yields null otherwise).
2225 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2226 ASSERT_TRUE(new_const);
2227 ASSERT_EQ(new_const->get_friendly_name(), "test");
2229 auto values_out = new_const->template get_vector<T>();
// range_test_check is overloaded: tolerance-based comparison for float/double, exact
// equality for integral T.
2231 range_test_check(values_out, values_expected);
// Runs the Range folding helper over several signed/unsigned integer and floating-point
// element types, plus an empty-result case and a descending case.
2234 TEST(constant_folding, constant_range)
// Ascending range 5, 7, 9, 11 (stop value 12 is not included) for each element type.
2236 range_test<int8_t>(5, 12, 2, {5, 7, 9, 11});
2237 range_test<int32_t>(5, 12, 2, {5, 7, 9, 11});
2238 range_test<int64_t>(5, 12, 2, {5, 7, 9, 11});
2239 range_test<uint64_t>(5, 12, 2, {5, 7, 9, 11});
2240 range_test<double>(5, 12, 2, {5, 7, 9, 11});
2241 range_test<float>(5, 12, 2, {5, 7, 9, 11});
// A negative step with start < stop is expected to yield an empty result.
2243 range_test<int32_t>(5, 12, -2, {});
// Descending range with a negative step (stop value 4 is not included).
2244 range_test<float>(12, 4, -2, {12, 10, 8, 6});
// Constant-folds op::Select over three constants that all use the same shape: a boolean
// selection mask and two i64 value tensors.
2247 TEST(constant_folding, constant_select)
// Mask entries of 1 pick from values_t, entries of 0 pick from values_f (see the
// expected values below).
2250 vector<char> values_selection{0, 1, 1, 0, 1, 0, 0, 1};
2251 vector<int64_t> values_t{2, 4, 6, 8, 10, 12, 14, 16};
2252 vector<int64_t> values_f{1, 3, 5, 7, 9, 11, 13, 15};
2254 auto constant_selection = make_shared<op::Constant>(element::boolean, shape, values_selection);
2255 auto constant_t = make_shared<op::Constant>(element::i64, shape, values_t);
2256 auto constant_f = make_shared<op::Constant>(element::i64, shape, values_f);
2257 auto select = make_shared<op::Select>(constant_selection, constant_t, constant_f);
2258 select->set_friendly_name("test");
2259 auto f = make_shared<Function>(select, ParameterVector{});
2261 pass::Manager pass_manager;
2262 pass_manager.register_pass<pass::ConstantFolding>();
2263 pass_manager.run_passes(f);
// After folding, no Select may remain and the function must reduce to one Constant.
2265 ASSERT_EQ(count_ops_of_type<op::Select>(f), 0);
2266 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding result 0 must be a Constant (as_type_ptr yields null otherwise).
2269 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2270 ASSERT_TRUE(new_const);
2271 ASSERT_EQ(new_const->get_friendly_name(), "test");
2272 auto values_out = new_const->get_vector<int64_t>();
2274 vector<int64_t> values_expected{1, 4, 6, 7, 10, 11, 13, 16};
2275 ASSERT_EQ(values_expected, values_out);
// Constant-folds op::v1::Select where the boolean mask and the "then" values have Shape{4}
// while the "else" values have Shape{2, 4}. The expected output has eight elements: the
// Shape{4} inputs are applied to each of the two rows of the Shape{2, 4} input.
2278 TEST(constant_folding, constant_v1_select)
2281 vector<char> values_selection{0, 1, 1, 0};
2282 vector<int64_t> values_t{1, 2, 3, 4};
2283 vector<int64_t> values_f{11, 12, 13, 14, 15, 16, 17, 18};
2285 auto constant_selection =
2286 make_shared<op::Constant>(element::boolean, Shape{4}, values_selection);
2287 auto constant_t = make_shared<op::Constant>(element::i64, Shape{4}, values_t);
2288 auto constant_f = make_shared<op::Constant>(element::i64, Shape{2, 4}, values_f);
2289 auto select = make_shared<op::v1::Select>(constant_selection, constant_t, constant_f);
2290 select->set_friendly_name("test");
2291 auto f = make_shared<Function>(select, ParameterVector{});
2293 pass::Manager pass_manager;
2294 pass_manager.register_pass<pass::ConstantFolding>();
2295 pass_manager.run_passes(f);
// The graph above is built with op::v1::Select, so that is the op type whose absence has
// to be verified; counting op::Select would not examine the node under test.
2297 ASSERT_EQ(count_ops_of_type<op::v1::Select>(f), 0);
2298 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding result 0 must be a Constant (as_type_ptr yields null otherwise).
2301 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2302 ASSERT_TRUE(new_const);
2303 ASSERT_EQ(new_const->get_friendly_name(), "test");
2304 auto values_out = new_const->get_vector<int64_t>();
// Mask {0, 1, 1, 0} per row: positions 1 and 2 come from values_t, the rest from values_f.
2306 vector<int64_t> values_expected{11, 2, 3, 14, 15, 2, 3, 18};
2307 ASSERT_EQ(values_expected, values_out);
// Constant-folds op::v1::Split of a six-element f32 vector into three equal parts along
// axis 0; every Split output must become its own Constant.
2310 TEST(constant_folding, constant_v1_split)
2312 vector<float> data{.1f, .2f, .3f, .4f, .5f, .6f};
2313 const auto const_data = op::Constant::create(element::f32, Shape{data.size()}, data);
2314 const auto const_axis = op::Constant::create(element::i64, Shape{}, {0});
2315 const auto num_splits = 3;
2317 auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
// All Split outputs are exposed as function results.
2318 auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2320 pass::Manager pass_manager;
2321 pass_manager.register_pass<pass::ConstantFolding>();
2322 pass_manager.run_passes(f);
// The Split must be gone and one Constant per output must remain.
2324 ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2325 ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
// Fetch the Constant feeding each of the three results.
2328 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2330 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2332 as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
// Each output must hold the matching consecutive pair of input elements.
2337 auto res1_values = res1->get_vector<float>();
2338 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin(), data.begin() + 2), res1_values));
2339 auto res2_values = res2->get_vector<float>();
2340 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 2, data.begin() + 4), res2_values));
2341 auto res3_values = res3->get_vector<float>();
2342 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 4, data.end()), res3_values));
// Constant-folds op::v1::Split of a six-element f32 vector into three equal parts along
// axis 0; every Split output must become its own Constant.
// NOTE(review): the statements here mirror the plain constant_v1_split test; it is not
// evident from this test what the "specialized" variant is meant to add -- confirm.
2345 TEST(constant_folding, constant_v1_split_specialized)
2347 vector<float> data{.1f, .2f, .3f, .4f, .5f, .6f};
2348 const auto const_data = op::Constant::create(element::f32, Shape{data.size()}, data);
2349 const auto const_axis = op::Constant::create(element::i64, Shape{}, {0});
2350 const auto num_splits = 3;
2352 auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
// All Split outputs are exposed as function results.
2353 auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2355 pass::Manager pass_manager;
2356 pass_manager.register_pass<pass::ConstantFolding>();
2357 pass_manager.run_passes(f);
// The Split must be gone and one Constant per output must remain.
2359 ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2360 ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
// Fetch the Constant feeding each of the three results.
2363 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2365 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2367 as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
// Each output must hold the matching consecutive pair of input elements.
2372 auto res1_values = res1->get_vector<float>();
2373 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin(), data.begin() + 2), res1_values));
2374 auto res2_values = res2->get_vector<float>();
2375 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 2, data.begin() + 4), res2_values));
2376 auto res3_values = res3->get_vector<float>();
2377 ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 4, data.end()), res3_values));
// Constant-folds op::v1::Split of an i64 Shape{4, 4, 4} constant (values 0..63) into four
// parts along axis 1, and checks both the per-output friendly names and the values.
2380 TEST(constant_folding, constant_v1_split_axis_1_4_splits)
2382 vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
2384 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2386 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2388 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2390 const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2391 const auto const_axis = op::Constant::create(element::i64, Shape{}, {1});
2392 const auto num_splits = 4;
2394 auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
2395 split_v1->set_friendly_name("test");
// All Split outputs are exposed as function results.
2396 auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2398 pass::Manager pass_manager;
2399 pass_manager.register_pass<pass::ConstantFolding>();
2400 pass_manager.run_passes(f);
// The Split must be gone and one Constant per output must remain.
2402 ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2403 ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
// Fetch the Constant feeding each of the four results.
2406 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2408 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2410 as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
2412 as_type_ptr<op::Constant>(f->get_results().at(3)->input_value(0).get_node_shared_ptr());
// For a multi-output op each folded constant is named "<friendly name>.<output index>".
2414 ASSERT_EQ(res1->get_friendly_name(), "test.0");
2416 ASSERT_EQ(res2->get_friendly_name(), "test.1");
2418 ASSERT_EQ(res3->get_friendly_name(), "test.2");
2420 ASSERT_EQ(res4->get_friendly_name(), "test.3");
// Output k holds, for every outer index, the four values at position k along axis 1.
2422 auto res1_values = res1->get_vector<int64_t>();
2423 ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51}),
2425 auto res2_values = res2->get_vector<int64_t>();
2426 ASSERT_EQ(vector<int64_t>({4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55}),
2428 auto res3_values = res3->get_vector<int64_t>();
2429 ASSERT_EQ(vector<int64_t>({8, 9, 10, 11, 24, 25, 26, 27, 40, 41, 42, 43, 56, 57, 58, 59}),
2431 auto res4_values = res4->get_vector<int64_t>();
2432 ASSERT_EQ(vector<int64_t>({12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63}),
// Constant-folds op::v1::Split of an i64 Shape{4, 4, 4} constant (values 0..63) into two
// halves along axis 1 and checks the values of both folded outputs.
2436 TEST(constant_folding, constant_v1_split_axis_1_2_splits)
2438 vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
2440 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2442 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2444 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2446 const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2447 const auto const_axis = op::Constant::create(element::i64, Shape{}, {1});
2448 const auto num_splits = 2;
2450 auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
// All Split outputs are exposed as function results.
2451 auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2453 pass::Manager pass_manager;
2454 pass_manager.register_pass<pass::ConstantFolding>();
2455 pass_manager.run_passes(f);
// The Split must be gone and one Constant per output must remain.
2457 ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2458 ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
// Fetch the Constant feeding each of the two results.
2461 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2463 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
// First output: positions 0-1 along axis 1 for every outer index.
2467 auto res1_values = res1->get_vector<int64_t>();
2468 ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
2469 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55}),
// Second output: positions 2-3 along axis 1 for every outer index.
2471 auto res2_values = res2->get_vector<int64_t>();
2472 ASSERT_EQ(vector<int64_t>({8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31,
2473 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63}),
// Constant-folds op::v1::VariadicSplit of an i64 Shape{4, 4, 4} constant (values 0..63)
// along axis 1 with unequal split lengths {3, 1}; the axis is supplied as an i16 scalar.
2477 TEST(constant_folding, constant_v1_variadic_split_axis_1_2_splits)
2479 vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
2481 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2483 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2485 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2487 const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
// The axis constant uses element type i16 here.
2488 const auto const_axis = op::Constant::create(element::i16, Shape{}, {1});
2489 vector<int64_t> values_lengths{3, 1};
2490 auto constant_lengths =
2491 make_shared<op::Constant>(element::i64, Shape{values_lengths.size()}, values_lengths);
2493 auto variadic_split_v1 =
2494 make_shared<op::v1::VariadicSplit>(const_data, const_axis, constant_lengths);
// All VariadicSplit outputs are exposed as function results.
2495 auto f = make_shared<Function>(variadic_split_v1->outputs(), ParameterVector{});
2497 pass::Manager pass_manager;
2498 pass_manager.register_pass<pass::ConstantFolding>();
2499 pass_manager.run_passes(f);
// The VariadicSplit must be gone and one Constant per requested length must remain.
2501 ASSERT_EQ(count_ops_of_type<op::v1::VariadicSplit>(f), 0);
2502 ASSERT_EQ(count_ops_of_type<op::Constant>(f), values_lengths.size());
// Fetch the Constant feeding each of the two results.
2505 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2507 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
// First output (length 3): positions 0-2 along axis 1 for every outer index.
2511 auto res1_values = res1->get_vector<int64_t>();
2512 ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19,
2513 20, 21, 22, 23, 24, 25, 26, 27, 32, 33, 34, 35, 36, 37, 38, 39,
2514 40, 41, 42, 43, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59}),
// Second output (length 1): position 3 along axis 1 for every outer index.
2516 auto res2_values = res2->get_vector<int64_t>();
2517 ASSERT_EQ(vector<int64_t>({12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63}),
// Constant-folds op::v1::VariadicSplit of an i64 Shape{4, 4, 4} constant (values 0..63)
// along axis 1 with split lengths {1, 1, -1}; the axis is supplied as an i32 scalar. The
// expected values show the -1 entry receiving the two remaining positions along the axis.
2521 TEST(constant_folding, constant_v1_variadic_split_axis_1_3_splits_neg_length)
2523 vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
2525 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2527 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2529 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2531 const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
// The axis constant uses element type i32 here.
2532 const auto const_axis = op::Constant::create(element::i32, Shape{}, {1});
// The last length is -1 rather than an explicit size.
2533 vector<int64_t> values_lengths{1, 1, -1};
2534 auto constant_lengths =
2535 make_shared<op::Constant>(element::i64, Shape{values_lengths.size()}, values_lengths);
2537 auto variadic_split_v1 =
2538 make_shared<op::v1::VariadicSplit>(const_data, const_axis, constant_lengths);
// All VariadicSplit outputs are exposed as function results.
2539 auto f = make_shared<Function>(variadic_split_v1->outputs(), ParameterVector{});
2541 pass::Manager pass_manager;
2542 pass_manager.register_pass<pass::ConstantFolding>();
2543 pass_manager.run_passes(f);
// The VariadicSplit must be gone and one Constant per requested length must remain.
2545 ASSERT_EQ(count_ops_of_type<op::v1::VariadicSplit>(f), 0);
2546 ASSERT_EQ(count_ops_of_type<op::Constant>(f), values_lengths.size());
// Fetch the Constant feeding each of the three results.
2549 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2551 as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2553 as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
// First output (length 1): position 0 along axis 1 for every outer index.
2558 auto res1_values = res1->get_vector<int64_t>();
2559 ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51}),
// Second output (length 1): position 1 along axis 1 for every outer index.
2561 auto res2_values = res2->get_vector<int64_t>();
2562 ASSERT_EQ(vector<int64_t>({4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55}),
// Third output (length -1): positions 2-3 along axis 1 for every outer index.
2564 auto res3_values = res3->get_vector<int64_t>();
2565 ASSERT_EQ(vector<int64_t>({8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31,
2566 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63}),
// Constant-folds op::v1::OneHot built from i64 indices {0, 1, 2}, a scalar depth of 3
// and scalar f16 on/off values.
// After the pass no OneHot may remain, exactly one Constant must be left, and that
// Constant must have shape {3, 3}; its f16 contents are then compared.
// NOTE(review): the definitions of `axis`, `one_hot_v1` and `res`, and all but the first
// element of the expected vector, are not visible in this listing -- confirm upstream.
2570 TEST(constant_folding, constant_v1_one_hot)
2572 vector<int64_t> indices{0, 1, 2};
2573 float16 on_value = 1.123f;
2574 float16 off_value = 0.321f;
2576 const auto indices_const = op::Constant::create(element::i64, Shape{3}, indices);
2577 const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
2578 const auto on_const = op::Constant::create(element::f16, Shape{}, {on_value});
2579 const auto off_const = op::Constant::create(element::f16, Shape{}, {off_value});
2583 make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
2584 auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
2586 pass::Manager pass_manager;
2587 pass_manager.register_pass<pass::ConstantFolding>();
2588 pass_manager.run_passes(f);
2590 ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
2591 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
2594 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2597 ASSERT_EQ((Shape{3, 3}), res->get_output_shape(0));
2598 ASSERT_EQ(vector<float16>({on_value,
2607 res->get_vector<float16>());
// Constant-folds op::v1::OneHot built from i64 indices {0, 2, -1, 1} (note the negative
// index), a scalar depth of 3 and scalar i16 on/off values (4 and 1).
// After the pass no OneHot may remain, exactly one Constant must be left, and that
// Constant must have shape {4, 3}; its i16 contents are then compared.
// NOTE(review): the definitions of `axis`, `one_hot_v1` and `res`, and all but the first
// element of the expected vector, are not visible in this listing; the test name suggests
// a negative axis -- confirm upstream.
2610 TEST(constant_folding, constant_v1_one_hot_negative_axes)
2612 vector<int64_t> indices{0, 2, -1, 1};
2613 int16_t on_value = 4;
2614 int16_t off_value = 1;
2616 const auto indices_const = op::Constant::create(element::i64, Shape{4}, indices);
2617 const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
2618 const auto on_const = op::Constant::create(element::i16, Shape{}, {on_value});
2619 const auto off_const = op::Constant::create(element::i16, Shape{}, {off_value});
2623 make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
2624 auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
2626 pass::Manager pass_manager;
2627 pass_manager.register_pass<pass::ConstantFolding>();
2628 pass_manager.run_passes(f);
2630 ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
2631 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
2634 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2637 ASSERT_EQ((Shape{4, 3}), res->get_output_shape(0));
2638 ASSERT_EQ(vector<int16_t>({on_value,
2650 res->get_vector<int16_t>());
// Constant-folds op::v1::OneHot built from a {2, 2} i64 index tensor {0, 2, 1, -1},
// a scalar depth of 3 and scalar boolean on/off values.
// Besides the usual op-count checks, this test verifies that the folded Constant
// inherits the friendly name "test" and has shape {2, 2, 3}.
// NOTE(review): the definitions of `axis`, `one_hot_v1` and `res`, and all but the first
// element of the expected vector, are not visible in this listing -- confirm upstream.
2653 TEST(constant_folding, constant_v1_one_hot_negative_axes_2)
2655 vector<int64_t> indices{0, 2, 1, -1};
2656 auto on_value = true;
2657 auto off_value = false;
2659 const auto indices_const = op::Constant::create(element::i64, Shape{2, 2}, indices);
2660 const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
2661 const auto on_const = op::Constant::create(element::boolean, Shape{}, {on_value});
2662 const auto off_const = op::Constant::create(element::boolean, Shape{}, {off_value});
2666 make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
// Name the op so the test can check that the name survives folding.
2667 one_hot_v1->set_friendly_name("test");
2668 auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
2670 pass::Manager pass_manager;
2671 pass_manager.register_pass<pass::ConstantFolding>();
2672 pass_manager.run_passes(f);
2674 ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
2675 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
2678 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2680 ASSERT_EQ(res->get_friendly_name(), "test");
2682 ASSERT_EQ((Shape{2, 2, 3}), res->get_output_shape(0));
2683 ASSERT_EQ(vector<bool>({on_value,
2695 res->get_vector<bool>());
// Constant-folds op::v0::Tile applied to the i32 data {0, 1} with a single repeat
// count of 2.
// After the pass no Tile may remain, exactly one Constant must be left, it must carry
// the friendly name "test", and its values must be {0, 1, 0, 1}.
// NOTE(review): the declarations of `shape_in` and `new_const` are not visible in this
// listing -- confirm against the upstream file.
2698 TEST(constant_folding, constant_tile_1d)
2701 Shape shape_repeats{1};
2704 vector<int> values_in{0, 1};
2705 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
// The repeats are supplied as a vector<int> but stored in an i64 Constant.
2706 vector<int> values_repeats{2};
2707 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2708 auto tile = make_shared<op::v0::Tile>(data, repeats);
2709 tile->set_friendly_name("test");
2710 auto f = make_shared<Function>(tile, ParameterVector{});
2712 pass::Manager pass_manager;
2713 pass_manager.register_pass<pass::ConstantFolding>();
2714 pass_manager.run_passes(f);
2716 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2717 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
2720 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2721 ASSERT_TRUE(new_const);
2722 ASSERT_EQ(new_const->get_friendly_name(), "test");
2723 auto values_out = new_const->get_vector<int>();
2725 vector<int> values_expected{0, 1, 0, 1};
2726 ASSERT_EQ(values_expected, values_out);
// Constant-folds op::v0::Tile where the repeats vector has three entries {2, 2, 2}
// while the data holds only the two values {0, 1}.
// After the pass no Tile may remain, exactly one Constant must be left, it must carry
// the friendly name "test", and it must hold the pattern {0, 1} eight times (16 values).
// NOTE(review): the declarations of `shape_in` and `new_const` are not visible in this
// listing, and `shape_out` is not referenced by any visible assertion -- confirm upstream.
2729 TEST(constant_folding, constant_tile_3d_small_data_rank)
2732 Shape shape_repeats{3};
2733 Shape shape_out{2, 2, 4};
2735 vector<int> values_in{0, 1};
2736 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
// The repeats are supplied as a vector<int> but stored in an i64 Constant.
2737 vector<int> values_repeats{2, 2, 2};
2738 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2739 auto tile = make_shared<op::v0::Tile>(data, repeats);
2740 tile->set_friendly_name("test");
2741 auto f = make_shared<Function>(tile, ParameterVector{});
2743 pass::Manager pass_manager;
2744 pass_manager.register_pass<pass::ConstantFolding>();
2745 pass_manager.run_passes(f);
2747 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2748 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
2751 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2752 ASSERT_TRUE(new_const);
2753 ASSERT_EQ(new_const->get_friendly_name(), "test");
2754 auto values_out = new_const->get_vector<int>();
2756 vector<int> values_expected{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1};
2757 ASSERT_EQ(values_expected, values_out);
// Constant-folds op::v0::Tile where the data is a {2, 1, 3} i32 tensor but only two
// repeat counts {2, 1} are supplied (fewer repeats than data dimensions).
// After the pass no Tile may remain, exactly one Constant must be left, it must carry
// the friendly name "test", and its values must be {1,2,3, 1,2,3, 4,5,6, 4,5,6}.
// NOTE(review): the declaration of `new_const` is not visible in this listing, and
// `shape_out` is not referenced by any visible assertion -- confirm upstream.
2760 TEST(constant_folding, constant_tile_3d_few_repeats)
2762 Shape shape_in{2, 1, 3};
2763 Shape shape_repeats{2};
2764 Shape shape_out{2, 2, 3};
2766 vector<int> values_in{1, 2, 3, 4, 5, 6};
2767 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
// The repeats are supplied as a vector<int> but stored in an i64 Constant.
2768 vector<int> values_repeats{2, 1};
2769 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2770 auto tile = make_shared<op::v0::Tile>(data, repeats);
2771 tile->set_friendly_name("test");
2772 auto f = make_shared<Function>(tile, ParameterVector{});
2774 pass::Manager pass_manager;
2775 pass_manager.register_pass<pass::ConstantFolding>();
2776 pass_manager.run_passes(f);
2778 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2779 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
2782 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2783 ASSERT_TRUE(new_const);
2784 ASSERT_EQ(new_const->get_friendly_name(), "test");
2785 auto values_out = new_const->get_vector<int>();
2787 vector<int> values_expected{1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6};
2788 ASSERT_EQ(values_expected, values_out);
// Constant-folds op::v0::Tile applied to the i32 data {0, 1} with a single repeat
// count of 0.
// After the pass no Tile may remain, exactly one Constant must be left, it must carry
// the friendly name "test", and it must hold no values at all.
// NOTE(review): the declarations of `shape_in` and `new_const` are not visible in this
// listing -- confirm against the upstream file.
2791 TEST(constant_folding, constant_tile_1d_0_repeats)
2794 Shape shape_repeats{1};
2797 vector<int> values_in{0, 1};
2798 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
// The repeats are supplied as a vector<int> but stored in an i64 Constant.
2799 vector<int> values_repeats{0};
2800 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2801 auto tile = make_shared<op::v0::Tile>(data, repeats);
2802 tile->set_friendly_name("test");
2803 auto f = make_shared<Function>(tile, ParameterVector{});
2805 pass::Manager pass_manager;
2806 pass_manager.register_pass<pass::ConstantFolding>();
2807 pass_manager.run_passes(f);
2809 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2810 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
2813 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2814 ASSERT_TRUE(new_const);
2815 ASSERT_EQ(new_const->get_friendly_name(), "test");
2816 auto values_out = new_const->get_vector<int>();
// Zero repeats: the expected output is empty.
2818 vector<int> values_expected{};
2819 ASSERT_EQ(values_expected, values_out);
// Constant-folds op::v0::Tile applied to single-element i32 data {1} with a single
// repeat count of 4.
// After the pass no Tile may remain, exactly one Constant must be left, it must carry
// the friendly name "test", and its values must be {1, 1, 1, 1}.
// NOTE(review): the declarations of `shape_in` and `new_const` are not visible in this
// listing; the test name suggests `shape_in` is a rank-0 shape -- confirm upstream.
2822 TEST(constant_folding, constant_tile_0_rank_data)
2825 Shape shape_repeats{1};
2828 vector<int> values_in{1};
2829 auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
// The repeats are supplied as a vector<int> but stored in an i64 Constant.
2830 vector<int> values_repeats{4};
2831 auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2832 auto tile = make_shared<op::v0::Tile>(data, repeats);
2833 tile->set_friendly_name("test");
2834 auto f = make_shared<Function>(tile, ParameterVector{});
2836 pass::Manager pass_manager;
2837 pass_manager.register_pass<pass::ConstantFolding>();
2838 pass_manager.run_passes(f);
2840 ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2841 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
2844 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2845 ASSERT_TRUE(new_const);
2846 ASSERT_EQ(new_const->get_friendly_name(), "test");
2847 auto values_out = new_const->get_vector<int>();
2849 vector<int> values_expected{1, 1, 1, 1};
2850 ASSERT_EQ(values_expected, values_out);
// Constant-folds op::v3::NonZero applied to a scalar (Shape{}) i32 constant of value 1.
// After the pass no NonZero may remain and exactly one Constant must be left; it must
// carry the friendly name "test", hold the single i64 value 0 and have shape {1, 1}.
2853 TEST(constant_folding, constant_non_zero_0D)
2855 auto data = op::Constant::create(element::i32, Shape{}, {1});
2856 auto non_zero = make_shared<op::v3::NonZero>(data);
2857 non_zero->set_friendly_name("test");
2858 auto f = make_shared<Function>(non_zero, ParameterVector{});
2860 pass::Manager pass_manager;
2861 pass_manager.register_pass<pass::ConstantFolding>();
2862 pass_manager.run_passes(f);
2864 // Fold into constant with shape of {1, 1} for scalar input with a non-zero value
2866 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2867 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2869 const auto new_const =
2870 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2871 ASSERT_TRUE(new_const);
2872 ASSERT_EQ(new_const->get_friendly_name(), "test");
2873 const auto values_out = new_const->get_vector<int64_t>();
2875 const vector<int64_t> values_expected{0};
2876 ASSERT_EQ(values_expected, values_out);
2877 ASSERT_EQ((Shape{1, 1}), new_const->get_shape());
// Constant-folds op::v3::NonZero applied to the 1-D i32 constant {0, 1, 0, 1}.
// After the pass no NonZero may remain and exactly one Constant must be left; it must
// carry the friendly name "test", hold the i64 indices {1, 3} and have shape {1, 2}.
2880 TEST(constant_folding, constant_non_zero_1D)
2882 vector<int> values_in{0, 1, 0, 1};
2883 auto data = make_shared<op::Constant>(element::i32, Shape{4}, values_in);
2884 auto non_zero = make_shared<op::v3::NonZero>(data);
2885 non_zero->set_friendly_name("test");
2886 auto f = make_shared<Function>(non_zero, ParameterVector{});
2888 pass::Manager pass_manager;
2889 pass_manager.register_pass<pass::ConstantFolding>();
2890 pass_manager.run_passes(f);
2892 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2893 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result must now be a Constant.
2895 const auto new_const =
2896 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2897 ASSERT_TRUE(new_const);
2898 ASSERT_EQ(new_const->get_friendly_name(), "test");
2899 const auto values_out = new_const->get_vector<int64_t>();
// Positions 1 and 3 of the input are the non-zero ones.
2901 const vector<int64_t> values_expected{1, 3};
2902 ASSERT_EQ(values_expected, values_out);
2903 ASSERT_EQ((Shape{1, 2}), new_const->get_shape());
// Constant-folds op::v3::NonZero constructed with an explicit element::i32 output type,
// applied to the 1-D i32 constant {0, 1, 0, 1}.
// In addition to the usual op-count and friendly-name checks, the folded Constant must
// report element type i32, hold the indices {1, 3} and have shape {1, 2}.
2906 TEST(constant_folding, constant_non_zero_int32_output_type)
2908 vector<int> values_in{0, 1, 0, 1};
2909 auto data = make_shared<op::Constant>(element::i32, Shape{4}, values_in);
// The second constructor argument requests i32 indices.
2910 auto non_zero = make_shared<op::v3::NonZero>(data, element::i32);
2911 non_zero->set_friendly_name("test");
2912 auto f = make_shared<Function>(non_zero, ParameterVector{});
2914 pass::Manager pass_manager;
2915 pass_manager.register_pass<pass::ConstantFolding>();
2916 pass_manager.run_passes(f);
2918 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2919 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result must now be a Constant.
2921 const auto new_const =
2922 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2923 ASSERT_TRUE(new_const);
2924 ASSERT_EQ(new_const->get_friendly_name(), "test");
2925 ASSERT_EQ(element::i32, new_const->get_element_type());
2926 const auto values_out = new_const->get_vector<int32_t>();
2928 const vector<int32_t> values_expected{1, 3};
2929 ASSERT_EQ(values_expected, values_out);
2930 ASSERT_EQ((Shape{1, 2}), new_const->get_shape());
// Constant-folds op::v3::NonZero applied to a 1-D f32 constant of eight 1.0f values,
// so every position is non-zero.
// After the pass no NonZero may remain and exactly one Constant must be left; it must
// carry the friendly name "test", hold the i64 indices 0..7 and have shape
// {1, values_in.size()}.
2933 TEST(constant_folding, constant_non_zero_1D_all_indices)
2935 const vector<float> values_in{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
2936 const auto data = make_shared<op::Constant>(element::f32, Shape{values_in.size()}, values_in);
2937 const auto non_zero = make_shared<op::v3::NonZero>(data);
2938 non_zero->set_friendly_name("test");
2939 auto f = make_shared<Function>(non_zero, ParameterVector{});
2941 pass::Manager pass_manager;
2942 pass_manager.register_pass<pass::ConstantFolding>();
2943 pass_manager.run_passes(f);
2945 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2946 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result must now be a Constant.
2948 const auto new_const =
2949 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2950 ASSERT_TRUE(new_const);
2951 ASSERT_EQ(new_const->get_friendly_name(), "test");
2952 const auto values_out = new_const->get_vector<int64_t>();
2954 const vector<int64_t> values_expected{0, 1, 2, 3, 4, 5, 6, 7};
2955 ASSERT_EQ(values_expected, values_out);
2956 ASSERT_EQ((Shape{1, values_in.size()}), new_const->get_shape());
// Constant-folds op::v3::NonZero applied to a {3, 3} i32 constant with four non-zero
// entries.
// After the pass no NonZero may remain and exactly one Constant must be left; it must
// carry the friendly name "test", hold the eight i64 values listed below and have
// shape {2, 4}.
2959 TEST(constant_folding, constant_non_zero_2D)
2961 vector<int> values_in{1, 0, 0, 0, 1, 0, 1, 1, 0};
2962 auto data = make_shared<op::Constant>(element::i32, Shape{3, 3}, values_in);
2963 auto non_zero = make_shared<op::v3::NonZero>(data);
2964 non_zero->set_friendly_name("test");
2965 auto f = make_shared<Function>(non_zero, ParameterVector{});
2967 pass::Manager pass_manager;
2968 pass_manager.register_pass<pass::ConstantFolding>();
2969 pass_manager.run_passes(f);
2971 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2972 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result must now be a Constant.
2974 const auto new_const =
2975 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2976 ASSERT_TRUE(new_const);
2977 ASSERT_EQ(new_const->get_friendly_name(), "test");
2978 const auto values_out = new_const->get_vector<int64_t>();
// With shape {2, 4} the first four values read as row indices {0, 1, 2, 2} and the
// last four as column indices {0, 1, 0, 1} of the non-zero entries.
2980 const vector<int64_t> values_expected{0, 1, 2, 2, 0, 1, 0, 1};
2981 ASSERT_EQ(values_expected, values_out);
2982 ASSERT_EQ((Shape{2, 4}), new_const->get_shape());
// Constant-folds op::v3::NonZero applied to a {3, 3} i8 constant filled with ones,
// so every position is non-zero.
// The folded Constant must carry the friendly name "test", hold the 18 i64 values
// listed below and have shape {2, values_in.size()}.
// The DISABLED_ prefix means GoogleTest does not run this test by default; the reason
// is not stated in the visible source.
2985 TEST(constant_folding, DISABLED_constant_non_zero_2D_all_indices)
2987 const vector<int8_t> values_in{1, 1, 1, 1, 1, 1, 1, 1, 1};
2988 const auto data = make_shared<op::Constant>(element::i8, Shape{3, 3}, values_in);
2989 const auto non_zero = make_shared<op::v3::NonZero>(data);
2990 non_zero->set_friendly_name("test");
2991 auto f = make_shared<Function>(non_zero, ParameterVector{});
2993 pass::Manager pass_manager;
2994 pass_manager.register_pass<pass::ConstantFolding>();
2995 pass_manager.run_passes(f);
2997 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2998 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result must now be a Constant.
3000 const auto new_const =
3001 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3002 ASSERT_TRUE(new_const);
3003 ASSERT_EQ(new_const->get_friendly_name(), "test");
3004 const auto values_out = new_const->get_vector<int64_t>();
// First nine values: row index of each element; last nine: column index.
3006 const vector<int64_t> values_expected{0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2};
3007 ASSERT_EQ(values_expected, values_out);
3008 ASSERT_EQ((Shape{2, values_in.size()}), new_const->get_shape());
// Constant-folds op::v3::NonZero applied to a {2, 3} u8 constant filled with zeros,
// so there is nothing to report.
// The folded Constant must carry the friendly name "test" and contain zero elements
// (only the element count of its shape is asserted, not the exact shape).
// The DISABLED_ prefix means GoogleTest does not run this test by default; the reason
// is not stated in the visible source.
3011 TEST(constant_folding, DISABLED_constant_non_zero_2D_all_zeros)
3013 const vector<uint8_t> values_in{0, 0, 0, 0, 0, 0};
3014 const auto data = make_shared<op::Constant>(element::u8, Shape{2, 3}, values_in);
3015 const auto non_zero = make_shared<op::v3::NonZero>(data);
3016 non_zero->set_friendly_name("test");
3017 auto f = make_shared<Function>(non_zero, ParameterVector{});
3019 pass::Manager pass_manager;
3020 pass_manager.register_pass<pass::ConstantFolding>();
3021 pass_manager.run_passes(f);
3023 // fold into Constant with shape of {0}
3024 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
3025 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result must now be a Constant.
3027 const auto new_const =
3028 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3029 ASSERT_TRUE(new_const);
3030 ASSERT_EQ(new_const->get_friendly_name(), "test");
3031 ASSERT_EQ(shape_size(new_const->get_shape()), 0);
// Constant-folds op::v3::NonZero applied to a {2, 3, 3} i32 constant with twelve
// non-zero entries.
// After the pass no NonZero may remain and exactly one Constant must be left; it must
// carry the friendly name "test", hold the 36 i64 values listed below and have
// shape {3, 12}.
3034 TEST(constant_folding, constant_non_zero_3D)
3036 vector<int> values_in{1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0};
3037 auto data = make_shared<op::Constant>(element::i32, Shape{2, 3, 3}, values_in);
3038 auto non_zero = make_shared<op::v3::NonZero>(data);
3039 non_zero->set_friendly_name("test");
3040 auto f = make_shared<Function>(non_zero, ParameterVector{});
3042 pass::Manager pass_manager;
3043 pass_manager.register_pass<pass::ConstantFolding>();
3044 pass_manager.run_passes(f);
3046 ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
3047 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result must now be a Constant.
3049 const auto new_const =
3050 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3051 ASSERT_TRUE(new_const);
3052 ASSERT_EQ(new_const->get_friendly_name(), "test");
3053 const auto values_out = new_const->get_vector<int64_t>();
// 3 rows of 12 values: one row of coordinates per input dimension.
3055 const vector<int64_t> values_expected{0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 2, 2, 2,
3056 0, 0, 0, 1, 1, 2, 0, 2, 1, 0, 1, 2, 0, 1, 2, 0, 2, 1};
3057 ASSERT_EQ(values_expected, values_out);
3058 ASSERT_EQ((Shape{3, 12}), new_const->get_shape());
// Constant-folds op::v3::ScatterElementsUpdate on a zero-filled {3, 3} f32 constant,
// with {2, 3} i32 indices, {2, 3} f32 updates and a scalar i64 axis of 0.
// After the pass no ScatterElementsUpdate may remain and exactly one Constant must be
// left; it must carry the friendly name "test", keep the data shape, and match the
// expected values within range_test_check's float tolerance.
// NOTE(review): the declaration of `result_node` is not visible in this listing --
// confirm against the upstream file.
3061 TEST(constant_folding, constant_scatter_elements_update_basic)
3063 const Shape data_shape{3, 3};
3064 const Shape indices_shape{2, 3};
3066 const auto data_const = op::Constant::create(
3067 element::f32, data_shape, std::vector<float>(shape_size(data_shape), 0.f));
3068 const auto indices_const =
3069 op::Constant::create(element::i32, indices_shape, {1, 0, 2, 0, 2, 1});
// The updates tensor has the same shape as the indices tensor.
3070 const auto updates_const =
3071 op::Constant::create(element::f32, indices_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
3072 const auto axis_const = op::Constant::create(element::i64, Shape{}, {0});
3074 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3075 data_const, indices_const, updates_const, axis_const);
3076 scatter_elem_updt->set_friendly_name("test");
3077 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3079 pass::Manager pass_manager;
3080 pass_manager.register_pass<pass::ConstantFolding>();
3081 pass_manager.run_passes(f);
3083 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3084 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
3087 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3088 ASSERT_TRUE(result_node);
3089 ASSERT_EQ(result_node->get_friendly_name(), "test");
3090 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3091 std::vector<float> expected{2.f, 1.1f, 0.0f, 1.f, 0.0f, 2.2f, 0.f, 2.1f, 1.2f};
3092 range_test_check(result_node->cast_vector<float>(), expected);
// Constant-folds op::v3::ScatterElementsUpdate with the same data, indices and updates
// as the basic case, but with a scalar i64 axis of -1.
// After the pass no ScatterElementsUpdate may remain and exactly one Constant must be
// left; it must keep the data shape and match the expected values within
// range_test_check's float tolerance.
// NOTE(review): the declaration of `result_node` is not visible in this listing --
// confirm against the upstream file.
3095 TEST(constant_folding, constant_scatter_elements_update_negative_axis)
3097 const Shape data_shape{3, 3};
3098 const Shape indices_shape{2, 3};
3100 const auto data_const = op::Constant::create(
3101 element::f32, data_shape, std::vector<float>(shape_size(data_shape), 0.f));
3102 const auto indices_const =
3103 op::Constant::create(element::i32, indices_shape, {1, 0, 2, 0, 2, 1});
3104 const auto updates_const =
3105 op::Constant::create(element::f32, indices_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
// Negative axis value under test.
3106 const auto axis_const = op::Constant::create(element::i64, Shape{}, {-1});
3108 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3109 data_const, indices_const, updates_const, axis_const);
3110 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3112 pass::Manager pass_manager;
3113 pass_manager.register_pass<pass::ConstantFolding>();
3114 pass_manager.run_passes(f);
3116 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3117 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
3120 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3121 ASSERT_TRUE(result_node);
3122 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
// Only the first two rows are expected to change; the third row stays zero.
3123 std::vector<float> expected{1.1f, 1.0f, 1.2f, 2.0f, 2.2f, 2.1f, 0.0f, 0.0f, 0.0f};
3124 range_test_check(result_node->cast_vector<float>(), expected);
// Constant-folds op::v3::ScatterElementsUpdate with the same data, indices and updates
// as the basic case, but with the axis given as a one-element tensor (Shape{1})
// instead of a scalar.
// The expected values are identical to those of the scalar-axis-0 case.
// NOTE(review): the declaration of `result_node` is not visible in this listing --
// confirm against the upstream file.
3127 TEST(constant_folding, constant_scatter_elements_update_1d_axis)
3129 const Shape data_shape{3, 3};
3130 const Shape indices_shape{2, 3};
3132 const auto data_const = op::Constant::create(
3133 element::f32, data_shape, std::vector<float>(shape_size(data_shape), 0.f));
3134 const auto indices_const =
3135 op::Constant::create(element::i32, indices_shape, {1, 0, 2, 0, 2, 1});
3136 const auto updates_const =
3137 op::Constant::create(element::f32, indices_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
// Axis supplied with Shape{1} rather than Shape{}.
3138 const auto axis_const = op::Constant::create(element::i64, Shape{1}, {0});
3140 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3141 data_const, indices_const, updates_const, axis_const);
3142 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3144 pass::Manager pass_manager;
3145 pass_manager.register_pass<pass::ConstantFolding>();
3146 pass_manager.run_passes(f);
3148 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3149 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
3152 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3153 ASSERT_TRUE(result_node);
3154 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3155 std::vector<float> expected{2.f, 1.1f, 0.0f, 1.f, 0.0f, 2.2f, 0.f, 2.1f, 1.2f};
3156 range_test_check(result_node->cast_vector<float>(), expected);
// Constant-folds op::v3::ScatterElementsUpdate on a zero-filled {3, 3, 3} i16 constant,
// with {2, 2, 3} i16 indices, {2, 2, 3} i16 updates (1..12) and a scalar i64 axis of 1.
// After the pass no ScatterElementsUpdate may remain and exactly one Constant must be
// left; it must keep the data shape and equal the expected 27 values exactly.
// NOTE(review): the declaration of `result_node` is not visible in this listing --
// confirm against the upstream file.
3159 TEST(constant_folding, constant_scatter_elements_update_3d_i16)
3161 const Shape data_shape{3, 3, 3};
3162 const Shape indices_shape{2, 2, 3};
3164 const auto data_const = op::Constant::create(
3165 element::i16, data_shape, std::vector<int16_t>(shape_size(data_shape), 0));
3166 const auto indices_const =
3167 op::Constant::create(element::i16, indices_shape, {1, 0, 2, 0, 2, 1, 2, 2, 2, 0, 1, 0});
3168 const auto updates_const =
3169 op::Constant::create(element::i16, indices_shape, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
3170 const auto axis_const = op::Constant::create(element::i64, Shape{}, {1});
3172 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3173 data_const, indices_const, updates_const, axis_const);
3174 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3176 pass::Manager pass_manager;
3177 pass_manager.register_pass<pass::ConstantFolding>();
3178 pass_manager.run_passes(f);
3180 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3181 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
3184 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3185 ASSERT_TRUE(result_node);
3186 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
// The last nine expected values (third slice along the first dimension) remain zero.
3187 std::vector<int16_t> expected{4, 2, 0, 1, 0, 6, 0, 5, 3, 10, 0, 12, 0, 11,
3188 0, 7, 8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0};
// Integral element type: range_test_check resolves to the exact-equality overload.
3189 range_test_check(result_node->cast_vector<int16_t>(), expected);
// Constant-folds op::v3::ScatterElementsUpdate that writes a single element: a
// zero-filled {3, 3, 3} i32 constant, {1, 1, 1} indices holding 1, {1, 1, 1} updates
// holding 2, and a scalar i64 axis of 0.
// After the pass no ScatterElementsUpdate may remain and exactly one Constant must be
// left; it must keep the data shape and equal `expected`.
// NOTE(review): the declaration of `result_node` and the statement that applies the
// single update to `expected` (after the "updated coordinate" comment) are not visible
// in this listing; as shown, `expected` is an unmodified copy of input_data -- confirm
// against the upstream file.
3192 TEST(constant_folding, constant_scatter_elements_update_one_elem)
3194 const Shape data_shape{3, 3, 3};
3195 const Shape indices_shape{1, 1, 1};
3196 const auto input_data = std::vector<int32_t>(shape_size(data_shape), 0);
3198 const auto data_const = op::Constant::create(element::i32, data_shape, input_data);
3199 const auto indices_const = op::Constant::create(element::i32, indices_shape, {1});
3200 const auto updates_const = op::Constant::create(element::i32, indices_shape, {2});
3201 const auto axis_const = op::Constant::create(element::i64, Shape{}, {0});
3203 auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3204 data_const, indices_const, updates_const, axis_const);
3205 auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3207 pass::Manager pass_manager;
3208 pass_manager.register_pass<pass::ConstantFolding>();
3209 pass_manager.run_passes(f);
3211 ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3212 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
3215 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3216 ASSERT_TRUE(result_node);
3217 ASSERT_EQ(data_shape, result_node->get_output_shape(0));
// Start from a copy of the all-zero input.
3218 std::vector<int32_t> expected{input_data};
3219 // we have updated coordinate (1, 0, 0)
3221 range_test_check(result_node->cast_vector<int32_t>(), expected);
// Shared driver for the v1::Reshape constant-folding tests.
//
// Builds an f32 Constant from (shape_in, values_in) and an i64 Constant from
// (shape_shape, values_shape), feeds both to op::v1::Reshape together with zero_flag,
// runs ConstantFolding, and asserts that:
//   * no v1::Reshape remains and exactly one Constant is left,
//   * the folded Constant carries the friendly name "test",
//   * its flat f32 contents match values_in (within MIN_FLOAT_TOLERANCE_BITS).
// The resulting shape itself is not asserted in the visible lines.
//
// shape_in     - shape of the input data constant
// values_in    - flat input data; also serves as the expected flat output
// values_shape - target-shape pattern (int32 values stored in an i64 Constant)
// zero_flag    - forwarded as the third constructor argument of v1::Reshape
//
// NOTE(review): the `shape_shape` parameter and the declaration of `new_const` are used
// below but their declaration lines are not visible in this listing -- confirm against
// the upstream file.
3224 void test_constant_folding_reshape_v1(Shape& shape_in,
3225 vector<float>& values_in,
3227 vector<int32_t> values_shape,
3228 bool zero_flag = false)
3230 auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
3231 auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
3232 auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape, zero_flag);
3233 dyn_reshape->set_friendly_name("test");
3234 auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
3236 pass::Manager pass_manager;
3237 pass_manager.register_pass<pass::ConstantFolding>();
3238 pass_manager.run_passes(f);
3240 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
3241 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
// The node feeding the single result is fetched as a Constant.
3244 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3245 ASSERT_TRUE(new_const);
3246 ASSERT_EQ(new_const->get_friendly_name(), "test");
3247 auto values_out = new_const->get_vector<float>();
// The flat data is compared against the unchanged input values.
3249 ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
// Runs the shared reshape driver on a {2, 5} f32 input (values 0..9) with four target
// patterns: {1, 1, 1, 10}, {1, 1, 2, 5}, {1, 2, 5} and {5, 2, 1}.
// The first braced argument of each call is the shape of the pattern constant, the
// second is the pattern itself; zero_flag keeps its default (false).
3251 TEST(constant_folding, constant_dyn_reshape_v1_2d)
3253 Shape shape_in{2, 5};
3254 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
3256 test_constant_folding_reshape_v1(shape_in, values_in, {4}, {1, 1, 1, 10});
3257 test_constant_folding_reshape_v1(shape_in, values_in, {4}, {1, 1, 2, 5});
3258 test_constant_folding_reshape_v1(shape_in, values_in, {3}, {1, 2, 5});
3259 test_constant_folding_reshape_v1(shape_in, values_in, {3}, {5, 2, 1});
// Runs the shared reshape driver on a {2, 2, 2, 2} f32 input (values 0..15) with
// target patterns that contain a -1 entry: {4, -1, 2}, {4, -1} and {-1}.
// zero_flag keeps its default (false).
3262 TEST(constant_folding, constant_dyn_reshape_v1_pattern_with_negative_indices)
3264 Shape shape_in{2, 2, 2, 2};
3265 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3267 test_constant_folding_reshape_v1(shape_in, values_in, {3}, {4, -1, 2});
3268 test_constant_folding_reshape_v1(shape_in, values_in, {2}, {4, -1});
3269 test_constant_folding_reshape_v1(shape_in, values_in, {1}, {-1});
// Runs the shared reshape driver on a {2, 2, 2, 2} f32 input (values 0..15) with
// target patterns that contain a 0 entry, passing zero_flag = true:
// {2, -1, 2, 0} and {4, 1, 0, 2}.
3272 TEST(constant_folding, constant_dyn_reshape_v1_pattern_with_zero_dims)
3274 Shape shape_in{2, 2, 2, 2};
3275 vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3277 test_constant_folding_reshape_v1(shape_in, values_in, {4}, {2, -1, 2, 0}, true);
3278 test_constant_folding_reshape_v1(shape_in, values_in, {4}, {4, 1, 0, 2}, true);
// Builds a v1::Reshape fed by a {1, 3} f32 Parameter and a one-element i64 shape
// constant, touches the "DISABLED_CONSTANT_FOLDING" key in the node's rt_info, runs
// ConstantFolding, and asserts that the graph still contains exactly one v1::Reshape
// and exactly one Constant.
3281 TEST(constant_folding, disable_constant_folding)
3283 auto input = make_shared<op::Parameter>(element::f32, Shape{1, 3});
3284 auto constant_shape = op::Constant::create(element::i64, Shape{1}, {3});
3285 auto dyn_reshape = make_shared<op::v1::Reshape>(input, constant_shape, true);
3286 auto& rt_info = dyn_reshape->get_rt_info();
// Indexing without assignment: presumably rt_info is map-like, so this inserts the key
// with a default value and the pass only checks for the key's presence -- TODO confirm.
3287 rt_info["DISABLED_CONSTANT_FOLDING"];
3288 auto f = make_shared<Function>(dyn_reshape, ParameterVector{input});
3290 pass::Manager pass_manager;
3291 pass_manager.register_pass<pass::ConstantFolding>();
3292 pass_manager.run_passes(f);
// Nothing may have been folded away: the Reshape and the shape Constant both remain.
3294 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 1);
3295 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);