Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / test / constant_folding.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "gtest/gtest.h"
18
19 #include "ngraph/ngraph.hpp"
20 #include "ngraph/pass/constant_folding.hpp"
21 #include "ngraph/pass/manager.hpp"
22 #include "util/all_close_f.hpp"
23 #include "util/test_tools.hpp"
24
25 NGRAPH_SUPPRESS_DEPRECATED_START
26
27 using namespace ngraph;
28 using namespace std;
29
30 template <typename T>
31 static std::vector<T> get_result_constant(std::shared_ptr<Function> f, size_t pos)
32 {
33     auto new_const =
34         as_type_ptr<op::Constant>(f->get_results().at(pos)->input_value(0).get_node_shared_ptr());
35     return new_const->cast_vector<T>();
36 }
37
38 void range_test_check(const vector<double>& values_out, const vector<double>& values_expected)
39 {
40     ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
41 }
42
43 void range_test_check(const vector<float>& values_out, const vector<float>& values_expected)
44 {
45     ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
46 }
47
48 template <typename T>
49 typename std::enable_if<std::is_integral<T>::value>::type
50     range_test_check(const vector<T>& values_out, const vector<T>& values_expected)
51 {
52     ASSERT_EQ(values_out, values_expected);
53 }
54
55 TEST(constant_folding, acosh)
56 {
57     Shape shape_in{2, 4, 1};
58
59     vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
60     vector<float> expected;
61     for (float f : values_in)
62     {
63         expected.push_back(std::acosh(f));
64     }
65     auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
66     auto acosh = make_shared<op::Acosh>(constant);
67     auto f = make_shared<Function>(acosh, ParameterVector{});
68
69     pass::Manager pass_manager;
70     pass_manager.register_pass<pass::ConstantFolding>();
71     pass_manager.run_passes(f);
72
73     EXPECT_EQ(count_ops_of_type<op::Acosh>(f), 0);
74     EXPECT_EQ(count_ops_of_type<op::Constant>(f), 1);
75     ASSERT_EQ(f->get_results().size(), 1);
76
77     auto new_const =
78         as_type_ptr<op::Constant>(f->get_results()[0]->input_value(0).get_node_shared_ptr());
79     EXPECT_TRUE(new_const);
80
81     auto values_out = new_const->get_vector<float>();
82     EXPECT_TRUE(test::all_close_f(expected, values_out, MIN_FLOAT_TOLERANCE_BITS));
83 }
84
85 TEST(constant_folding, asinh)
86 {
87     Shape shape_in{2, 4, 1};
88
89     vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
90     vector<float> expected;
91     for (float f : values_in)
92     {
93         expected.push_back(std::asinh(f));
94     }
95     auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
96     auto asinh = make_shared<op::Asinh>(constant);
97     auto f = make_shared<Function>(asinh, ParameterVector{});
98
99     pass::Manager pass_manager;
100     pass_manager.register_pass<pass::ConstantFolding>();
101     pass_manager.run_passes(f);
102
103     EXPECT_EQ(count_ops_of_type<op::Asinh>(f), 0);
104     EXPECT_EQ(count_ops_of_type<op::Constant>(f), 1);
105     ASSERT_EQ(f->get_results().size(), 1);
106
107     auto new_const =
108         as_type_ptr<op::Constant>(f->get_results()[0]->input_value(0).get_node_shared_ptr());
109     EXPECT_TRUE(new_const);
110
111     auto values_out = new_const->get_vector<float>();
112     EXPECT_TRUE(test::all_close_f(expected, values_out, MIN_FLOAT_TOLERANCE_BITS));
113 }
114
115 TEST(constant_folding, atanh)
116 {
117     Shape shape_in{2, 4, 1};
118
119     vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
120     vector<float> expected;
121     for (float f : values_in)
122     {
123         expected.push_back(std::atanh(f));
124     }
125     auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
126     auto atanh = make_shared<op::Atanh>(constant);
127     auto f = make_shared<Function>(atanh, ParameterVector{});
128
129     pass::Manager pass_manager;
130     pass_manager.register_pass<pass::ConstantFolding>();
131     pass_manager.run_passes(f);
132
133     EXPECT_EQ(count_ops_of_type<op::Atanh>(f), 0);
134     EXPECT_EQ(count_ops_of_type<op::Constant>(f), 1);
135     ASSERT_EQ(f->get_results().size(), 1);
136
137     auto new_const =
138         as_type_ptr<op::Constant>(f->get_results()[0]->input_value(0).get_node_shared_ptr());
139     EXPECT_TRUE(new_const);
140
141     auto values_out = new_const->get_vector<float>();
142     EXPECT_TRUE(test::all_close_f(expected, values_out, MIN_FLOAT_TOLERANCE_BITS));
143 }
144
145 TEST(constant_folding, constant_squeeze)
146 {
147     Shape shape_in{2, 4, 1};
148     Shape shape_out{2, 4};
149     Shape axes_shape{1};
150
151     vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
152     auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
153     vector<int64_t> values_axes{2};
154     auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
155     auto squeeze = make_shared<op::Squeeze>(constant, constant_axes);
156     auto f = make_shared<Function>(squeeze, ParameterVector{});
157
158     pass::Manager pass_manager;
159     pass_manager.register_pass<pass::ConstantFolding>();
160     pass_manager.run_passes(f);
161
162     ASSERT_EQ(count_ops_of_type<op::Squeeze>(f), 0);
163     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
164
165     auto new_const =
166         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
167     ASSERT_TRUE(new_const);
168     ASSERT_EQ(new_const->get_shape(), shape_out);
169
170     auto values_out = new_const->get_vector<float>();
171     ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
172 }
173
174 TEST(constant_folding, constant_unsqueeze)
175 {
176     Shape shape_in{2, 4};
177     Shape shape_out{2, 4, 1, 1};
178     Shape axes_shape{2};
179
180     vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
181     auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
182     vector<int64_t> values_axes{2, 3};
183     auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
184     auto unsqueeze = make_shared<op::Unsqueeze>(constant, constant_axes);
185     auto f = make_shared<Function>(unsqueeze, ParameterVector{});
186
187     pass::Manager pass_manager;
188     pass_manager.register_pass<pass::ConstantFolding>();
189     pass_manager.run_passes(f);
190
191     ASSERT_EQ(count_ops_of_type<op::Unsqueeze>(f), 0);
192     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
193
194     auto new_const =
195         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
196     ASSERT_TRUE(new_const);
197     ASSERT_EQ(new_const->get_shape(), shape_out);
198
199     auto values_out = new_const->get_vector<float>();
200     ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
201 }
202
203 TEST(constant_folding, constant_reshape)
204 {
205     Shape shape_in{2, 4};
206     Shape shape_out{2, 4, 1};
207
208     vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
209     auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
210     auto reshape = make_shared<op::Reshape>(constant, AxisVector{0, 1}, shape_out);
211     auto f = make_shared<Function>(reshape, ParameterVector{});
212
213     pass::Manager pass_manager;
214     pass_manager.register_pass<pass::ConstantFolding>();
215     pass_manager.run_passes(f);
216
217     ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
218     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
219
220     auto new_const =
221         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
222     ASSERT_TRUE(new_const);
223     auto values_out = new_const->get_vector<float>();
224
225     ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
226 }
227
228 TEST(constant_folding, DISABLED_constant_reshape_permute)
229 {
230     Shape shape_in{2, 4};
231     Shape shape_out{4, 2};
232
233     vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};
234     auto constant = make_shared<op::Constant>(element::f64, shape_in, values_in);
235     auto reshape = make_shared<op::Reshape>(constant, AxisVector{1, 0}, shape_out);
236     auto f = make_shared<Function>(reshape, ParameterVector{});
237
238     pass::Manager pass_manager;
239     pass_manager.register_pass<pass::ConstantFolding>();
240     pass_manager.run_passes(f);
241
242     ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
243     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
244
245     auto new_const =
246         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
247     ASSERT_TRUE(new_const);
248     auto values_out = new_const->get_vector<double>();
249
250     vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
251     ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
252 }
253
254 TEST(constant_folding, constant_broadcast)
255 {
256     Shape shape_in{2};
257     Shape shape_out{2, 4};
258
259     vector<int> values_in{0, 1};
260     auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
261     auto broadcast = make_shared<op::Broadcast>(constant, shape_out, AxisSet{1});
262     auto f = make_shared<Function>(broadcast, ParameterVector{});
263
264     pass::Manager pass_manager;
265     pass_manager.register_pass<pass::ConstantFolding>();
266     pass_manager.run_passes(f);
267
268     ASSERT_EQ(count_ops_of_type<op::Broadcast>(f), 0);
269     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
270
271     auto new_const =
272         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
273     ASSERT_TRUE(new_const);
274     auto values_out = new_const->get_vector<int>();
275
276     vector<int> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
277     ASSERT_EQ(values_expected, values_out);
278 }
279
280 TEST(constant_folding, constant_broadcast_v1)
281 {
282     vector<int32_t> values_in{0, 1};
283     auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
284     vector<int64_t> shape_in{2, 4};
285     auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
286     vector<int64_t> axes_in{0};
287     auto constant_axes = make_shared<op::Constant>(element::i64, Shape{1}, axes_in);
288     auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, constant_shape, constant_axes);
289     auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
290
291     pass::Manager pass_manager;
292     pass_manager.register_pass<pass::ConstantFolding>();
293     pass_manager.run_passes(f);
294
295     ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
296     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
297
298     auto new_const =
299         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
300     ASSERT_TRUE(new_const);
301     auto values_out = new_const->get_vector<int32_t>();
302
303     vector<int32_t> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
304     ASSERT_EQ(values_expected, values_out);
305 }
306
307 TEST(constant_folding, constant_broadcast_v1_with_target_shape)
308 {
309     vector<int32_t> values_in{1};
310     auto constant_in = make_shared<op::Constant>(element::i32, Shape{1, 1, 1, 1}, values_in);
311     vector<int64_t> shape_in{1, 3, 1, 1};
312     auto target_shape = make_shared<op::Constant>(element::i64, Shape{4}, shape_in);
313     auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, target_shape);
314     auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
315
316     pass::Manager pass_manager;
317     pass_manager.register_pass<pass::ConstantFolding>();
318     pass_manager.run_passes(f);
319
320     ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
321     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
322
323     auto new_const =
324         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
325     ASSERT_TRUE(new_const);
326     auto values_out = new_const->get_vector<int32_t>();
327
328     vector<int32_t> values_expected{1, 1, 1};
329     ASSERT_EQ(values_expected, values_out);
330 }
331
332 TEST(constant_folding, constant_broadcast_v1_numpy)
333 {
334     vector<int32_t> values_in{0, 1};
335     auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
336     vector<int64_t> shape_in{4, 2};
337     auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
338     auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, constant_shape);
339     auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
340
341     pass::Manager pass_manager;
342     pass_manager.register_pass<pass::ConstantFolding>();
343     pass_manager.run_passes(f);
344
345     ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
346     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
347
348     auto new_const =
349         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
350     ASSERT_TRUE(new_const);
351     auto values_out = new_const->get_vector<int32_t>();
352
353     vector<int32_t> values_expected{0, 1, 0, 1, 0, 1, 0, 1};
354     ASSERT_EQ(values_expected, values_out);
355 }
356
357 TEST(constant_folding, constant_unary_binary)
358 {
359     vector<int> values_a{1, 2, 3, 4};
360     vector<int> values_b{1, 2, 3, 4};
361     vector<int> values_c{-1, -1, -1, -1};
362     vector<int> values_d{1, 4, 9, 16};
363     vector<int> values_e{5, 6};
364     vector<int> values_f{0, 10};
365     vector<int> values_g{1, 4};
366     vector<char> values_h{0, 0, 1, 1};
367     vector<char> values_i{0, 1};
368     auto a = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_a);
369     auto b = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_b);
370     auto c = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_c);
371     auto d = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_d);
372     auto e = make_shared<op::Constant>(element::i32, Shape{2}, values_e);
373     auto f = make_shared<op::Constant>(element::i32, Shape{2}, values_f);
374     auto g = make_shared<op::Constant>(element::i32, Shape{2}, values_g);
375     auto h = make_shared<op::Constant>(element::boolean, Shape{2, 2}, values_h);
376     auto i = make_shared<op::Constant>(element::boolean, Shape{2}, values_i);
377
378     auto add = a + b;
379     auto sub = a - b;
380     auto mul = a * b;
381     auto divn = a / b;
382     auto pow = make_shared<op::Power>(a, b);
383     auto min = make_shared<op::Minimum>(c, a);
384     auto max = make_shared<op::Maximum>(a, c);
385     auto absn = make_shared<op::Abs>(c);
386     auto neg = make_shared<op::Negative>(c);
387     auto sqrt = make_shared<op::Sqrt>(d);
388     auto add_autob_numpy = make_shared<op::Add>(a, e, op::AutoBroadcastType::NUMPY);
389     auto sub_autob_numpy = make_shared<op::Subtract>(a, e, op::AutoBroadcastType::NUMPY);
390     auto mul_autob_numpy = make_shared<op::Multiply>(a, e, op::AutoBroadcastType::NUMPY);
391     auto div_autob_numpy = make_shared<op::Divide>(a, g, op::AutoBroadcastType::NUMPY);
392     auto pow_autob_numpy = make_shared<op::Power>(a, g, op::AutoBroadcastType::NUMPY);
393     auto min_autob_numpy = make_shared<op::Minimum>(a, f, op::AutoBroadcastType::NUMPY);
394     auto max_autob_numpy = make_shared<op::Maximum>(a, f, op::AutoBroadcastType::NUMPY);
395     auto equal_autob_numpy = make_shared<op::Equal>(a, g, op::AutoBroadcastType::NUMPY);
396     auto not_equal_autob_numpy = make_shared<op::NotEqual>(a, g, op::AutoBroadcastType::NUMPY);
397     auto greater_autob_numpy = make_shared<op::Greater>(a, g, op::AutoBroadcastType::NUMPY);
398     auto greater_eq_autob_numpy = make_shared<op::GreaterEq>(a, g, op::AutoBroadcastType::NUMPY);
399     auto less_autob_numpy = make_shared<op::Less>(a, g, op::AutoBroadcastType::NUMPY);
400     auto less_eq_autob_numpy = make_shared<op::LessEq>(a, g, op::AutoBroadcastType::NUMPY);
401     auto logical_or_autob_numpy = make_shared<op::Or>(h, i, op::AutoBroadcastType::NUMPY);
402     auto logical_xor_autob_numpy = make_shared<op::Xor>(h, i, op::AutoBroadcastType::NUMPY);
403
404     auto neg_sqrt = make_shared<op::Sqrt>(c);
405
406     auto func = make_shared<Function>(NodeVector{add,
407                                                  sub,
408                                                  mul,
409                                                  divn,
410                                                  pow,
411                                                  min,
412                                                  max,
413                                                  absn,
414                                                  neg,
415                                                  sqrt,
416                                                  add_autob_numpy,
417                                                  sub_autob_numpy,
418                                                  mul_autob_numpy,
419                                                  div_autob_numpy,
420                                                  pow_autob_numpy,
421                                                  min_autob_numpy,
422                                                  max_autob_numpy,
423                                                  equal_autob_numpy,
424                                                  not_equal_autob_numpy,
425                                                  greater_autob_numpy,
426                                                  greater_eq_autob_numpy,
427                                                  less_autob_numpy,
428                                                  less_eq_autob_numpy,
429                                                  logical_or_autob_numpy,
430                                                  logical_xor_autob_numpy},
431                                       ParameterVector{});
432     auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
433
434     pass::Manager pass_manager;
435     pass_manager.register_pass<pass::ConstantFolding>();
436     pass_manager.run_passes(func);
437
438     // expected values
439     vector<int> add_expected{2, 4, 6, 8};
440     vector<int> sub_expected{0, 0, 0, 0};
441     vector<int> mul_expected{1, 4, 9, 16};
442     vector<int> div_expected{1, 1, 1, 1};
443     vector<int> pow_expected{1, 4, 27, 256};
444     vector<int> min_expected{-1, -1, -1, -1};
445     vector<int> max_expected{1, 2, 3, 4};
446     vector<int> abs_neg_expected{1, 1, 1, 1};
447     vector<int> sqrt_expected{1, 2, 3, 4};
448     vector<int> add_autob_numpy_expected{6, 8, 8, 10};
449     vector<int> sub_autob_numpy_expected{-4, -4, -2, -2};
450     vector<int> mul_autob_numpy_expected{5, 12, 15, 24};
451     vector<int> div_autob_numpy_expected{1, 0, 3, 1};
452     vector<int> pow_autob_numpy_expected{1, 16, 3, 256};
453     vector<int> min_autob_numpy_expected{0, 2, 0, 4};
454     vector<int> max_autob_numpy_expected{1, 10, 3, 10};
455     vector<char> equal_autob_numpy_expected{1, 0, 0, 1};
456     vector<char> not_equal_autob_numpy_expected{0, 1, 1, 0};
457     vector<char> greater_autob_numpy_expected{0, 0, 1, 0};
458     vector<char> greater_eq_autob_numpy_expected{1, 0, 1, 1};
459     vector<char> less_autob_numpy_expected{0, 1, 0, 0};
460     vector<char> less_eq_autob_numpy_expected{1, 1, 0, 1};
461     vector<char> logical_or_autob_numpy_expected{0, 1, 1, 1};
462     vector<char> logical_xor_autob_numpy_expected{0, 1, 1, 0};
463
464     ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
465     ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
466     ASSERT_EQ(get_result_constant<int>(func, 2), mul_expected);
467     ASSERT_EQ(get_result_constant<int>(func, 3), div_expected);
468     ASSERT_EQ(get_result_constant<int>(func, 4), pow_expected);
469     ASSERT_EQ(get_result_constant<int>(func, 5), min_expected);
470     ASSERT_EQ(get_result_constant<int>(func, 6), max_expected);
471     ASSERT_EQ(get_result_constant<int>(func, 7), abs_neg_expected);
472     ASSERT_EQ(get_result_constant<int>(func, 8), abs_neg_expected);
473     ASSERT_EQ(get_result_constant<int>(func, 9), sqrt_expected);
474     ASSERT_EQ(get_result_constant<int>(func, 10), add_autob_numpy_expected);
475     ASSERT_EQ(get_result_constant<int>(func, 11), sub_autob_numpy_expected);
476     ASSERT_EQ(get_result_constant<int>(func, 12), mul_autob_numpy_expected);
477     ASSERT_EQ(get_result_constant<int>(func, 13), div_autob_numpy_expected);
478     ASSERT_EQ(get_result_constant<int>(func, 14), pow_autob_numpy_expected);
479     ASSERT_EQ(get_result_constant<int>(func, 15), min_autob_numpy_expected);
480     ASSERT_EQ(get_result_constant<int>(func, 16), max_autob_numpy_expected);
481     ASSERT_EQ(get_result_constant<char>(func, 17), equal_autob_numpy_expected);
482     ASSERT_EQ(get_result_constant<char>(func, 18), not_equal_autob_numpy_expected);
483     ASSERT_EQ(get_result_constant<char>(func, 19), greater_autob_numpy_expected);
484     ASSERT_EQ(get_result_constant<char>(func, 20), greater_eq_autob_numpy_expected);
485     ASSERT_EQ(get_result_constant<char>(func, 21), less_autob_numpy_expected);
486     ASSERT_EQ(get_result_constant<char>(func, 22), less_eq_autob_numpy_expected);
487     ASSERT_EQ(get_result_constant<char>(func, 23), logical_or_autob_numpy_expected);
488     ASSERT_EQ(get_result_constant<char>(func, 24), logical_xor_autob_numpy_expected);
489     ASSERT_NO_THROW(pass_manager.run_passes(func_error));
490 }
491
492 TEST(constant_folding, const_dequantize)
493 {
494     Shape input_shape{12};
495     Shape scale_offset_shape;
496     AxisSet quantization_axes;
497
498     auto quant_type = element::u8;
499     auto output_type = element::f32;
500     typedef float output_c_type;
501
502     vector<uint8_t> values_in{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
503     auto constant = op::Constant::create(quant_type, input_shape, values_in);
504     auto scale = op::Constant::create(output_type, scale_offset_shape, {2});
505     auto offset = op::Constant::create(quant_type, scale_offset_shape, {1});
506     auto dequantize =
507         make_shared<op::Dequantize>(constant, scale, offset, output_type, quantization_axes);
508     auto f = make_shared<Function>(dequantize, ParameterVector{});
509
510     pass::Manager pass_manager;
511     pass_manager.register_pass<pass::ConstantFolding>();
512     pass_manager.run_passes(f);
513
514     ASSERT_EQ(count_ops_of_type<op::Dequantize>(f), 0);
515     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
516
517     auto new_const =
518         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
519     ASSERT_TRUE(new_const);
520     auto values_out = new_const->get_vector<output_c_type>();
521
522     vector<output_c_type> values_dequantize{0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12};
523     ASSERT_EQ(values_dequantize, values_out);
524 }
525
526 TEST(constant_folding, const_quantize)
527 {
528     Shape input_shape{12};
529     Shape scale_offset_shape;
530     AxisSet quantization_axes;
531
532     auto quant_type = element::u8;
533     auto output_type = element::u8;
534     typedef uint8_t output_c_type;
535
536     vector<float> values_in{1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0};
537     auto constant = op::Constant::create(element::f32, input_shape, values_in);
538     auto scale = op::Constant::create(element::f32, scale_offset_shape, {2});
539     auto offset = op::Constant::create(quant_type, scale_offset_shape, {1});
540     auto mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY;
541     auto quantize =
542         make_shared<op::Quantize>(constant, scale, offset, output_type, quantization_axes, mode);
543     auto f = make_shared<Function>(quantize, ParameterVector{});
544
545     pass::Manager pass_manager;
546     pass_manager.register_pass<pass::ConstantFolding>();
547     pass_manager.run_passes(f);
548
549     ASSERT_EQ(count_ops_of_type<op::Quantize>(f), 0);
550     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
551
552     auto new_const =
553         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
554     ASSERT_TRUE(new_const);
555     auto values_out = new_const->get_vector<output_c_type>();
556
557     vector<output_c_type> values_quantize{2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5};
558     ASSERT_EQ(values_quantize, values_out);
559 }
560
561 TEST(constant_folding, const_convert)
562 {
563     Shape input_shape{3, 4};
564
565     vector<int32_t> values_in{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
566     auto constant = op::Constant::create(element::f32, input_shape, values_in);
567     auto convert = make_shared<op::Convert>(constant, element::u64);
568     auto f = make_shared<Function>(convert, ParameterVector{});
569
570     pass::Manager pass_manager;
571     pass_manager.register_pass<pass::ConstantFolding>();
572     pass_manager.run_passes(f);
573
574     ASSERT_EQ(count_ops_of_type<op::Convert>(f), 0);
575     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
576
577     auto new_const =
578         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
579     ASSERT_TRUE(new_const);
580     ASSERT_EQ(new_const->get_output_element_type(0), element::u64);
581     auto values_out = new_const->get_vector<uint64_t>();
582
583     vector<uint64_t> values_expected{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
584     ASSERT_EQ(values_expected, values_out);
585 }
586
587 TEST(constant_folding, shape_of_v0)
588 {
589     Shape input_shape{3, 4, 0, 22, 608, 909, 3};
590
591     auto param = make_shared<op::Parameter>(element::boolean, input_shape);
592     auto shape_of = make_shared<op::v0::ShapeOf>(param);
593     auto f = make_shared<Function>(shape_of, ParameterVector{param});
594
595     pass::Manager pass_manager;
596     pass_manager.register_pass<pass::ConstantFolding>();
597     pass_manager.run_passes(f);
598
599     ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 0);
600     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
601
602     auto new_const =
603         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
604     ASSERT_TRUE(new_const);
605     ASSERT_EQ(new_const->get_output_element_type(0), element::i64);
606     auto values_out = new_const->get_vector<int64_t>();
607
608     ASSERT_EQ((vector<int64_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
609 }
610
611 TEST(constant_folding, shape_of_v3)
612 {
613     Shape input_shape{3, 4, 0, 22, 608, 909, 3};
614
615     auto param = make_shared<op::Parameter>(element::boolean, input_shape);
616     auto shape_of = make_shared<op::v3::ShapeOf>(param);
617     auto f = make_shared<Function>(shape_of, ParameterVector{param});
618
619     pass::Manager pass_manager;
620     pass_manager.register_pass<pass::ConstantFolding>();
621     pass_manager.run_passes(f);
622
623     ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 0);
624     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
625
626     auto new_const =
627         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
628     ASSERT_TRUE(new_const);
629     ASSERT_EQ(new_const->get_output_element_type(0), element::i64);
630     auto values_out = new_const->get_vector<int64_t>();
631
632     ASSERT_EQ((vector<int64_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
633 }
634
635 TEST(constant_folding, shape_of_i32_v3)
636 {
637     Shape input_shape{3, 4, 0, 22, 608, 909, 3};
638
639     auto param = make_shared<op::Parameter>(element::boolean, input_shape);
640     auto shape_of = make_shared<op::v3::ShapeOf>(param, element::i32);
641     auto f = make_shared<Function>(shape_of, ParameterVector{param});
642
643     pass::Manager pass_manager;
644     pass_manager.register_pass<pass::ConstantFolding>();
645     pass_manager.run_passes(f);
646
647     ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 0);
648     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
649
650     auto new_const =
651         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
652     ASSERT_TRUE(new_const);
653     ASSERT_EQ(new_const->get_output_element_type(0), element::i32);
654     auto values_out = new_const->get_vector<int32_t>();
655
656     ASSERT_EQ((vector<int32_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
657 }
658
659 TEST(constant_folding, shape_of_dynamic_v0)
660 {
661     PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
662
663     auto param = make_shared<op::Parameter>(element::boolean, input_shape);
664     auto shape_of = make_shared<op::v0::ShapeOf>(param);
665     auto f = make_shared<Function>(shape_of, ParameterVector{param});
666
667     pass::Manager pass_manager;
668     pass_manager.register_pass<pass::ConstantFolding>();
669     pass_manager.run_passes(f);
670
671     ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
672     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
673     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
674     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
675
676     auto result_as_concat =
677         as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
678     ASSERT_TRUE(result_as_concat);
679     ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
680 }
681
682 TEST(constant_folding, shape_of_dynamic_v3)
683 {
684     PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
685
686     auto param = make_shared<op::Parameter>(element::boolean, input_shape);
687     auto shape_of = make_shared<op::v3::ShapeOf>(param);
688     auto f = make_shared<Function>(shape_of, ParameterVector{param});
689
690     pass::Manager pass_manager;
691     pass_manager.register_pass<pass::ConstantFolding>();
692     pass_manager.run_passes(f);
693
694     ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
695     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
696     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
697     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
698
699     auto result_as_concat =
700         as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
701     ASSERT_TRUE(result_as_concat);
702     ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
703     ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i64);
704 }
705
706 TEST(constant_folding, shape_of_dynamic_i32_v3)
707 {
708     PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
709
710     auto param = make_shared<op::Parameter>(element::boolean, input_shape);
711     auto shape_of = make_shared<op::v3::ShapeOf>(param, element::i32);
712     auto f = make_shared<Function>(shape_of, ParameterVector{param});
713
714     pass::Manager pass_manager;
715     pass_manager.register_pass<pass::ConstantFolding>();
716     pass_manager.run_passes(f);
717
718     ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
719     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
720     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
721     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
722
723     auto result_as_concat =
724         as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
725     ASSERT_TRUE(result_as_concat);
726     ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
727     ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i32);
728 }
729
// Make sure constant folding does not keep refolding the graph endlessly: running the pass
// twice must leave exactly the same ops as a single run does.
731 TEST(constant_folding, shape_of_dynamic_double_folding_v0)
732 {
733     PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
734
735     auto param = make_shared<op::Parameter>(element::boolean, input_shape);
736     auto shape_of = make_shared<op::v0::ShapeOf>(param);
737     auto f = make_shared<Function>(shape_of, ParameterVector{param});
738
739     pass::Manager pass_manager;
740     pass_manager.register_pass<pass::ConstantFolding>();
741     pass_manager.run_passes(f);
742     pass_manager.run_passes(f);
743
744     ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
745     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
746     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
747     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
748
749     auto result_as_concat =
750         as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
751     ASSERT_TRUE(result_as_concat);
752     ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
753 }
754
755 TEST(constant_folding, shape_of_dynamic_double_folding_v3)
756 {
757     PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
758
759     auto param = make_shared<op::Parameter>(element::boolean, input_shape);
760     auto shape_of = make_shared<op::v3::ShapeOf>(param);
761     auto f = make_shared<Function>(shape_of, ParameterVector{param});
762
763     pass::Manager pass_manager;
764     pass_manager.register_pass<pass::ConstantFolding>();
765     pass_manager.run_passes(f);
766     pass_manager.run_passes(f);
767
768     ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
769     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
770     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
771     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
772
773     auto result_as_concat =
774         as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
775     ASSERT_TRUE(result_as_concat);
776     ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
777 }
778
779 // Constant folding will not succeed on ShapeOf if the argument rank is dynamic.
780 // We want to make sure it fails gracefully, leaving the ShapeOf op in place.
781 TEST(constant_folding, shape_of_rank_dynamic_v0)
782 {
783     PartialShape input_shape{PartialShape::dynamic()};
784
785     auto param = make_shared<op::Parameter>(element::boolean, input_shape);
786     auto shape_of = make_shared<op::v0::ShapeOf>(param);
787     auto f = make_shared<Function>(shape_of, ParameterVector{param});
788
789     pass::Manager pass_manager;
790     pass_manager.register_pass<pass::ConstantFolding>();
791     pass_manager.run_passes(f);
792
793     ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
794     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
795
796     auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
797     ASSERT_EQ(result_shape_of, shape_of);
798 }
799
800 TEST(constant_folding, shape_of_rank_dynamic_v3)
801 {
802     PartialShape input_shape{PartialShape::dynamic()};
803
804     auto param = make_shared<op::Parameter>(element::boolean, input_shape);
805     auto shape_of = make_shared<op::v3::ShapeOf>(param);
806     auto f = make_shared<Function>(shape_of, ParameterVector{param});
807
808     pass::Manager pass_manager;
809     pass_manager.register_pass<pass::ConstantFolding>();
810     pass_manager.run_passes(f);
811
812     ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
813     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
814
815     auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
816     ASSERT_EQ(result_shape_of, shape_of);
817 }
818
819 TEST(constant_folding, const_reverse)
820 {
821     Shape input_shape{3, 3};
822
823     vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
824     auto constant = op::Constant::create(element::i32, input_shape, values_in);
825     auto convert = make_shared<op::Reverse>(constant, AxisSet{1});
826     auto f = make_shared<Function>(convert, ParameterVector{});
827
828     pass::Manager pass_manager;
829     pass_manager.register_pass<pass::ConstantFolding>();
830     pass_manager.run_passes(f);
831
832     ASSERT_EQ(count_ops_of_type<op::Reverse>(f), 0);
833     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
834
835     auto new_const =
836         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
837     ASSERT_TRUE(new_const);
838     auto values_out = new_const->get_vector<int32_t>();
839
840     vector<int32_t> values_expected{3, 2, 1, 6, 5, 4, 9, 8, 7};
841     ASSERT_EQ(values_expected, values_out);
842 }
843
844 TEST(constant_folding, const_product)
845 {
846     Shape input_shape{3, 3};
847
848     vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
849     auto constant = op::Constant::create(element::i32, input_shape, values_in);
850     auto convert = make_shared<op::Product>(constant, AxisSet{1});
851     auto f = make_shared<Function>(convert, ParameterVector{});
852
853     pass::Manager pass_manager;
854     pass_manager.register_pass<pass::ConstantFolding>();
855     pass_manager.run_passes(f);
856
857     ASSERT_EQ(count_ops_of_type<op::Product>(f), 0);
858     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
859
860     auto new_const =
861         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
862     ASSERT_TRUE(new_const);
863     auto values_out = new_const->get_vector<int32_t>();
864
865     vector<int32_t> values_expected{6, 120, 504};
866     ASSERT_EQ(values_expected, values_out);
867 }
868
869 TEST(constant_folding, const_reduceprod)
870 {
871     Shape input_shape{3, 3};
872     Shape output_shape{3};
873
874     vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
875     auto constant = op::Constant::create(element::i32, input_shape, values_in);
876     Shape axes_shape{1};
877     vector<int32_t> values_axes{1};
878     auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
879     auto convert = make_shared<op::v1::ReduceProd>(constant, constant_axes);
880     auto f = make_shared<Function>(convert, ParameterVector{});
881
882     pass::Manager pass_manager;
883     pass_manager.register_pass<pass::ConstantFolding>();
884     pass_manager.run_passes(f);
885
886     ASSERT_EQ(count_ops_of_type<op::v1::ReduceProd>(f), 0);
887     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
888
889     auto new_const =
890         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
891     ASSERT_TRUE(new_const);
892     ASSERT_EQ(new_const->get_shape(), output_shape);
893
894     auto values_out = new_const->get_vector<int32_t>();
895
896     vector<int32_t> values_expected{6, 120, 504};
897
898     ASSERT_EQ(values_expected, values_out);
899 }
900
901 TEST(constant_folding, const_reduceprod_keepdims)
902 {
903     Shape input_shape{3, 3};
904     Shape output_shape{3, 1};
905
906     vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
907     auto constant = op::Constant::create(element::i32, input_shape, values_in);
908     Shape axes_shape{1};
909     vector<int32_t> values_axes{1};
910     auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
911     auto convert = make_shared<op::v1::ReduceProd>(constant, constant_axes, true);
912     auto f = make_shared<Function>(convert, ParameterVector{});
913
914     pass::Manager pass_manager;
915     pass_manager.register_pass<pass::ConstantFolding>();
916     pass_manager.run_passes(f);
917
918     ASSERT_EQ(count_ops_of_type<op::v1::ReduceProd>(f), 0);
919     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
920
921     auto new_const =
922         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
923     ASSERT_TRUE(new_const);
924     ASSERT_EQ(new_const->get_shape(), output_shape);
925
926     auto values_out = new_const->get_vector<int32_t>();
927
928     vector<int32_t> values_expected{6, 120, 504};
929
930     ASSERT_EQ(values_expected, values_out);
931 }
932
933 TEST(constant_folding, const_sum)
934 {
935     Shape input_shape{3, 3};
936
937     vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
938     auto constant = op::Constant::create(element::i32, input_shape, values_in);
939     auto convert = make_shared<op::Sum>(constant, AxisSet{1});
940     auto f = make_shared<Function>(convert, ParameterVector{});
941
942     pass::Manager pass_manager;
943     pass_manager.register_pass<pass::ConstantFolding>();
944     pass_manager.run_passes(f);
945
946     ASSERT_EQ(count_ops_of_type<op::Sum>(f), 0);
947     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
948
949     auto new_const =
950         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
951     ASSERT_TRUE(new_const);
952     auto values_out = new_const->get_vector<int32_t>();
953
954     vector<int32_t> values_expected{6, 15, 24};
955
956     ASSERT_EQ(values_expected, values_out);
957 }
958
959 TEST(constant_folding, const_reducesum)
960 {
961     Shape input_shape{3, 3};
962     Shape output_shape{3};
963
964     vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
965     auto constant = op::Constant::create(element::i32, input_shape, values_in);
966     Shape axes_shape{1};
967     vector<int32_t> values_axes{1};
968     auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
969     auto convert = make_shared<op::v1::ReduceSum>(constant, constant_axes);
970     auto f = make_shared<Function>(convert, ParameterVector{});
971
972     pass::Manager pass_manager;
973     pass_manager.register_pass<pass::ConstantFolding>();
974     pass_manager.run_passes(f);
975
976     ASSERT_EQ(count_ops_of_type<op::v1::ReduceSum>(f), 0);
977     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
978
979     auto new_const =
980         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
981     ASSERT_TRUE(new_const);
982     ASSERT_EQ(new_const->get_shape(), output_shape);
983
984     auto values_out = new_const->get_vector<int32_t>();
985
986     vector<int32_t> values_expected{6, 15, 24};
987
988     ASSERT_EQ(values_expected, values_out);
989 }
990
991 TEST(constant_folding, const_reducesum_keepdims)
992 {
993     Shape input_shape{3, 3};
994     Shape output_shape{3, 1};
995
996     vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
997     auto constant = op::Constant::create(element::i32, input_shape, values_in);
998     Shape axes_shape{1};
999     vector<int32_t> values_axes{1};
1000     auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1001     auto convert = make_shared<op::v1::ReduceSum>(constant, constant_axes, true);
1002     auto f = make_shared<Function>(convert, ParameterVector{});
1003
1004     pass::Manager pass_manager;
1005     pass_manager.register_pass<pass::ConstantFolding>();
1006     pass_manager.run_passes(f);
1007
1008     ASSERT_EQ(count_ops_of_type<op::v1::ReduceSum>(f), 0);
1009     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1010
1011     auto new_const =
1012         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1013     ASSERT_TRUE(new_const);
1014     ASSERT_EQ(new_const->get_shape(), output_shape);
1015
1016     auto values_out = new_const->get_vector<int32_t>();
1017
1018     vector<int32_t> values_expected{6, 15, 24};
1019
1020     ASSERT_EQ(values_expected, values_out);
1021 }
1022
1023 TEST(constant_folding, const_max)
1024 {
1025     Shape input_shape{3, 3};
1026
1027     vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
1028     auto constant = op::Constant::create(element::i32, input_shape, values_in);
1029     auto convert = make_shared<op::Max>(constant, AxisSet{1});
1030     auto f = make_shared<Function>(convert, ParameterVector{});
1031
1032     pass::Manager pass_manager;
1033     pass_manager.register_pass<pass::ConstantFolding>();
1034     pass_manager.run_passes(f);
1035
1036     ASSERT_EQ(count_ops_of_type<op::Max>(f), 0);
1037     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1038
1039     auto new_const =
1040         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1041     ASSERT_TRUE(new_const);
1042     auto values_out = new_const->get_vector<int32_t>();
1043
1044     vector<int32_t> values_expected{3, 6, 9};
1045
1046     ASSERT_EQ(values_expected, values_out);
1047 }
1048
1049 TEST(constant_folding, const_reducemax)
1050 {
1051     Shape input_shape{3, 2};
1052     Shape output_shape{3};
1053
1054     vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1055     auto constant = op::Constant::create(element::i32, input_shape, values_in);
1056     Shape axes_shape{1};
1057     vector<int32_t> values_axes{1};
1058     auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1059     auto convert = make_shared<op::v1::ReduceMax>(constant, constant_axes);
1060     auto f = make_shared<Function>(convert, ParameterVector{});
1061
1062     pass::Manager pass_manager;
1063     pass_manager.register_pass<pass::ConstantFolding>();
1064     pass_manager.run_passes(f);
1065
1066     ASSERT_EQ(count_ops_of_type<op::v1::ReduceMax>(f), 0);
1067     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1068
1069     auto new_const =
1070         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1071     ASSERT_TRUE(new_const);
1072     ASSERT_EQ(new_const->get_shape(), output_shape);
1073
1074     auto values_out = new_const->get_vector<int32_t>();
1075
1076     vector<int32_t> values_expected{2, 4, 6};
1077
1078     ASSERT_EQ(values_expected, values_out);
1079 }
1080
1081 TEST(constant_folding, const_reducemax_keepdims)
1082 {
1083     Shape input_shape{3, 2};
1084     Shape output_shape{3, 1};
1085
1086     vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1087     auto constant = op::Constant::create(element::i32, input_shape, values_in);
1088     Shape axes_shape{1};
1089     vector<int32_t> values_axes{1};
1090     auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1091     auto convert = make_shared<op::v1::ReduceMax>(constant, constant_axes, true);
1092     auto f = make_shared<Function>(convert, ParameterVector{});
1093
1094     pass::Manager pass_manager;
1095     pass_manager.register_pass<pass::ConstantFolding>();
1096     pass_manager.run_passes(f);
1097
1098     ASSERT_EQ(count_ops_of_type<op::v1::ReduceMax>(f), 0);
1099     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1100
1101     auto new_const =
1102         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1103     ASSERT_TRUE(new_const);
1104     ASSERT_EQ(new_const->get_shape(), output_shape);
1105
1106     auto values_out = new_const->get_vector<int32_t>();
1107
1108     vector<int32_t> values_expected{2, 4, 6};
1109
1110     ASSERT_EQ(values_expected, values_out);
1111 }
1112
1113 TEST(constant_folding, const_min)
1114 {
1115     Shape input_shape{3, 3};
1116
1117     vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
1118     auto constant = op::Constant::create(element::i32, input_shape, values_in);
1119     auto convert = make_shared<op::Min>(constant, AxisSet{1});
1120     auto f = make_shared<Function>(convert, ParameterVector{});
1121
1122     pass::Manager pass_manager;
1123     pass_manager.register_pass<pass::ConstantFolding>();
1124     pass_manager.run_passes(f);
1125
1126     ASSERT_EQ(count_ops_of_type<op::Min>(f), 0);
1127     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1128
1129     auto new_const =
1130         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1131     ASSERT_TRUE(new_const);
1132     auto values_out = new_const->get_vector<int32_t>();
1133
1134     vector<int32_t> values_expected{1, 4, 7};
1135
1136     ASSERT_EQ(values_expected, values_out);
1137 }
1138
1139 TEST(constant_folding, const_reducemin)
1140 {
1141     Shape input_shape{3, 2};
1142     Shape output_shape{3};
1143
1144     vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1145     auto constant = op::Constant::create(element::i32, input_shape, values_in);
1146     Shape axes_shape{1};
1147     vector<int32_t> values_axes{1};
1148     auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1149     auto convert = make_shared<op::v1::ReduceMin>(constant, constant_axes);
1150     auto f = make_shared<Function>(convert, ParameterVector{});
1151
1152     pass::Manager pass_manager;
1153     pass_manager.register_pass<pass::ConstantFolding>();
1154     pass_manager.run_passes(f);
1155
1156     ASSERT_EQ(count_ops_of_type<op::v1::ReduceMin>(f), 0);
1157     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1158
1159     auto new_const =
1160         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1161     ASSERT_TRUE(new_const);
1162     ASSERT_EQ(new_const->get_shape(), output_shape);
1163
1164     auto values_out = new_const->get_vector<int32_t>();
1165
1166     vector<int32_t> values_expected{1, 3, 5};
1167
1168     ASSERT_EQ(values_expected, values_out);
1169 }
1170
1171 TEST(constant_folding, const_reducemin_keepdims)
1172 {
1173     Shape input_shape{3, 2};
1174     Shape output_shape{3, 1};
1175
1176     vector<int32_t> values_in{1, 2, 3, 4, 5, 6};
1177     auto constant = op::Constant::create(element::i32, input_shape, values_in);
1178     Shape axes_shape{1};
1179     vector<int32_t> values_axes{1};
1180     auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1181     auto convert = make_shared<op::v1::ReduceMin>(constant, constant_axes, true);
1182     auto f = make_shared<Function>(convert, ParameterVector{});
1183
1184     pass::Manager pass_manager;
1185     pass_manager.register_pass<pass::ConstantFolding>();
1186     pass_manager.run_passes(f);
1187
1188     ASSERT_EQ(count_ops_of_type<op::v1::ReduceMin>(f), 0);
1189     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1190
1191     auto new_const =
1192         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1193     ASSERT_TRUE(new_const);
1194     ASSERT_EQ(new_const->get_shape(), output_shape);
1195
1196     auto values_out = new_const->get_vector<int32_t>();
1197
1198     vector<int32_t> values_expected{1, 3, 5};
1199
1200     ASSERT_EQ(values_expected, values_out);
1201 }
1202
1203 TEST(constant_folding, const_reducemean)
1204 {
1205     Shape input_shape{3, 3};
1206     Shape output_shape{3};
1207
1208     vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
1209     auto constant = op::Constant::create(element::i32, input_shape, values_in);
1210     Shape axes_shape{1};
1211     vector<int32_t> values_axes{1};
1212     auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1213     auto convert = make_shared<op::v1::ReduceMean>(constant, constant_axes);
1214     auto f = make_shared<Function>(convert, ParameterVector{});
1215
1216     pass::Manager pass_manager;
1217     pass_manager.register_pass<pass::ConstantFolding>();
1218     pass_manager.run_passes(f);
1219
1220     ASSERT_EQ(count_ops_of_type<op::v1::ReduceMean>(f), 0);
1221     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1222
1223     auto new_const =
1224         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1225     ASSERT_TRUE(new_const);
1226     ASSERT_EQ(new_const->get_shape(), output_shape);
1227
1228     auto values_out = new_const->get_vector<int32_t>();
1229
1230     vector<int32_t> values_expected{2, 5, 8};
1231
1232     ASSERT_EQ(values_expected, values_out);
1233 }
1234
1235 TEST(constant_folding, const_reducemean_keepdims)
1236 {
1237     Shape input_shape{3, 3};
1238     Shape output_shape{3, 1};
1239
1240     vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
1241     auto constant = op::Constant::create(element::i32, input_shape, values_in);
1242     Shape axes_shape{1};
1243     vector<int32_t> values_axes{1};
1244     auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
1245     auto convert = make_shared<op::v1::ReduceMean>(constant, constant_axes, true);
1246     auto f = make_shared<Function>(convert, ParameterVector{});
1247
1248     pass::Manager pass_manager;
1249     pass_manager.register_pass<pass::ConstantFolding>();
1250     pass_manager.run_passes(f);
1251
1252     ASSERT_EQ(count_ops_of_type<op::v1::ReduceMean>(f), 0);
1253     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1254
1255     auto new_const =
1256         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1257     ASSERT_TRUE(new_const);
1258     ASSERT_EQ(new_const->get_shape(), output_shape);
1259
1260     auto values_out = new_const->get_vector<int32_t>();
1261
1262     vector<int32_t> values_expected{2, 5, 8};
1263
1264     ASSERT_EQ(values_expected, values_out);
1265 }
1266
1267 TEST(constant_folding, const_reduce_logical_and__no_keepdims)
1268 {
1269     const Shape input_shape{3, 3};
1270
1271     const vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
1272     const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1273     const auto axes = op::Constant::create(element::i64, {1}, {1});
1274     const auto convert = make_shared<op::v1::ReduceLogicalAnd>(data, axes, false);
1275     auto f = make_shared<Function>(convert, ParameterVector{});
1276
1277     pass::Manager pass_manager;
1278     pass_manager.register_pass<pass::ConstantFolding>();
1279     pass_manager.run_passes(f);
1280
1281     ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1282     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1283
1284     const auto new_const =
1285         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1286     ASSERT_TRUE(new_const);
1287
1288     const Shape expected_out_shape{3};
1289     ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1290
1291     const auto values_out = new_const->get_vector<char>();
1292
1293     const vector<char> values_expected{0, 0, 1};
1294
1295     ASSERT_EQ(values_expected, values_out);
1296 }
1297
1298 TEST(constant_folding, const_reduce_logical_and__keepdims)
1299 {
1300     const Shape input_shape{3, 3};
1301
1302     const vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
1303     const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1304     const auto axes = op::Constant::create(element::i64, {1}, {1});
1305     const auto convert = make_shared<op::v1::ReduceLogicalAnd>(data, axes, true);
1306     auto f = make_shared<Function>(convert, ParameterVector{});
1307
1308     pass::Manager pass_manager;
1309     pass_manager.register_pass<pass::ConstantFolding>();
1310     pass_manager.run_passes(f);
1311
1312     ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1313     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1314
1315     const auto new_const =
1316         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1317     ASSERT_TRUE(new_const);
1318
1319     // the output shape is expected to have 'ones' at the positions specified in the reduction axes
1320     // in case the keep_dims attribute of ReduceLogicalAnd is set to true
1321     const Shape expected_out_shape{3, 1};
1322     ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1323
1324     const auto values_out = new_const->get_vector<char>();
1325
1326     const vector<char> values_expected{0, 0, 1};
1327
1328     ASSERT_EQ(values_expected, values_out);
1329 }
1330
1331 TEST(constant_folding, const_reduce_logical_and__keepdims_3d)
1332 {
1333     const Shape input_shape{2, 2, 2};
1334
1335     const vector<char> values_in{1, 1, 0, 0, 1, 0, 0, 1};
1336     const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1337     const auto axes = op::Constant::create(element::i64, {2}, {0, 2});
1338     const auto convert = make_shared<op::v1::ReduceLogicalAnd>(data, axes, true);
1339     auto f = make_shared<Function>(convert, ParameterVector{});
1340
1341     pass::Manager pass_manager;
1342     pass_manager.register_pass<pass::ConstantFolding>();
1343     pass_manager.run_passes(f);
1344
1345     ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1346     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1347
1348     const auto new_const =
1349         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1350     ASSERT_TRUE(new_const);
1351
1352     const Shape expected_out_shape{1, 2, 1};
1353     ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1354
1355     const auto values_out = new_const->get_vector<char>();
1356
1357     const vector<char> values_expected{0, 0};
1358
1359     ASSERT_EQ(values_expected, values_out);
1360 }
1361
1362 TEST(constant_folding, const_any)
1363 {
1364     Shape input_shape{3, 3};
1365
1366     vector<char> values_in{1, 0, 0, 1, 0, 1, 0, 0, 0};
1367     auto constant = op::Constant::create(element::boolean, input_shape, values_in);
1368     auto convert = make_shared<op::Any>(constant, AxisSet{1});
1369     auto f = make_shared<Function>(convert, ParameterVector{});
1370
1371     pass::Manager pass_manager;
1372     pass_manager.register_pass<pass::ConstantFolding>();
1373     pass_manager.run_passes(f);
1374
1375     ASSERT_EQ(count_ops_of_type<op::Any>(f), 0);
1376     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1377
1378     auto new_const =
1379         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1380     ASSERT_TRUE(new_const);
1381     auto values_out = new_const->get_vector<char>();
1382
1383     vector<char> values_expected{1, 1, 0};
1384
1385     ASSERT_EQ(values_expected, values_out);
1386 }
1387
1388 TEST(constant_folding, const_reduce_logical_or__no_keepdims)
1389 {
1390     const Shape input_shape{3, 3};
1391
1392     const vector<char> values_in{1, 0, 0, 1, 0, 1, 0, 0, 0};
1393     const auto data = op::Constant::create(element::boolean, input_shape, values_in);
1394     const auto axes = op::Constant::create(element::i64, {1}, {1});
1395     const auto convert = make_shared<op::v1::ReduceLogicalOr>(data, axes, false);
1396     auto f = make_shared<Function>(convert, ParameterVector{});
1397
1398     pass::Manager pass_manager;
1399     pass_manager.register_pass<pass::ConstantFolding>();
1400     pass_manager.run_passes(f);
1401
1402     ASSERT_EQ(count_ops_of_type<op::v1::ReduceLogicalAnd>(f), 0);
1403     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1404
1405     const auto new_const =
1406         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1407     ASSERT_TRUE(new_const);
1408
1409     const Shape expected_out_shape{3};
1410     ASSERT_EQ(new_const->get_shape(), expected_out_shape);
1411
1412     const auto values_out = new_const->get_vector<char>();
1413
1414     const vector<char> values_expected{1, 1, 0};
1415
1416     ASSERT_EQ(values_expected, values_out);
1417 }
1418
1419 TEST(constant_folding, const_concat)
1420 {
1421     auto constant0 =
1422         op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1423     auto constant1 = op::Constant::create(element::i32, Shape{2, 1}, vector<int32_t>{7, 8});
1424     auto concat = make_shared<op::Concat>(NodeVector{constant0, constant1}, 1);
1425     auto f = make_shared<Function>(concat, ParameterVector{});
1426
1427     pass::Manager pass_manager;
1428     pass_manager.register_pass<pass::ConstantFolding>();
1429     pass_manager.run_passes(f);
1430
1431     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1432     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1433
1434     auto new_const =
1435         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1436     ASSERT_TRUE(new_const);
1437     auto values_out = new_const->get_vector<int32_t>();
1438
1439     vector<int32_t> values_expected{1, 2, 3, 7, 4, 5, 6, 8};
1440
1441     ASSERT_EQ(values_expected, values_out);
1442 }
1443
1444 TEST(constant_folding, const_concat_3d_single_elem)
1445 {
1446     auto constant_1 = op::Constant::create(element::i32, Shape{1, 1, 1}, vector<int32_t>{1});
1447     auto constant_2 = op::Constant::create(element::i32, Shape{1, 1, 1}, vector<int32_t>{2});
1448     auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2}, 0);
1449     auto f = make_shared<Function>(concat, ParameterVector{});
1450
1451     pass::Manager pass_manager;
1452     pass_manager.register_pass<pass::ConstantFolding>();
1453     pass_manager.run_passes(f);
1454
1455     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1456     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1457
1458     auto new_const =
1459         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1460
1461     ASSERT_TRUE(new_const);
1462     ASSERT_EQ(new_const->get_output_shape(0), (Shape{2, 1, 1}));
1463
1464     auto values_out = new_const->get_vector<int32_t>();
1465     vector<int32_t> values_expected{1, 2};
1466     ASSERT_EQ(values_expected, values_out);
1467 }
1468
1469 TEST(constant_folding, const_concat_axis_2)
1470 {
1471     auto constant_1 =
1472         op::Constant::create(element::i32, Shape{3, 1, 2}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1473     auto constant_2 = op::Constant::create(
1474         element::i32, Shape{3, 1, 4}, vector<int32_t>{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
1475     auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2}, 2);
1476     auto f = make_shared<Function>(concat, ParameterVector{});
1477
1478     pass::Manager pass_manager;
1479     pass_manager.register_pass<pass::ConstantFolding>();
1480     pass_manager.run_passes(f);
1481
1482     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1483     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1484
1485     auto new_const =
1486         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1487
1488     ASSERT_TRUE(new_const);
1489     ASSERT_EQ(new_const->get_output_shape(0), (Shape{3, 1, 6}));
1490
1491     auto values_out = new_const->get_vector<int32_t>();
1492     vector<int32_t> values_expected{1, 2, 7, 8, 9, 10, 3, 4, 11, 12, 13, 14, 5, 6, 15, 16, 17, 18};
1493     ASSERT_EQ(values_expected, values_out);
1494 }
1495
1496 TEST(constant_folding, const_concat_axis_1_bool_type)
1497 {
1498     auto constant_1 =
1499         op::Constant::create(element::boolean, Shape{1, 1, 2}, vector<int32_t>{true, true});
1500     auto constant_2 = op::Constant::create(
1501         element::boolean, Shape{1, 2, 2}, vector<char>{true, false, true, false});
1502     auto constant_3 = op::Constant::create(
1503         element::boolean, Shape{1, 3, 2}, vector<char>{true, false, true, false, true, false});
1504     auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2, constant_3}, 1);
1505     auto f = make_shared<Function>(concat, ParameterVector{});
1506
1507     pass::Manager pass_manager;
1508     pass_manager.register_pass<pass::ConstantFolding>();
1509     pass_manager.run_passes(f);
1510
1511     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1512     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1513
1514     auto new_const =
1515         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1516
1517     ASSERT_TRUE(new_const);
1518     ASSERT_EQ(new_const->get_output_shape(0), (Shape{1, 6, 2}));
1519
1520     auto values_out = new_const->get_vector<char>();
1521     vector<char> values_expected{
1522         true, true, true, false, true, false, true, false, true, false, true, false};
1523     ASSERT_EQ(values_expected, values_out);
1524 }
1525
1526 TEST(constant_folding, const_not)
1527 {
1528     auto constant =
1529         op::Constant::create(element::boolean, Shape{2, 3}, vector<char>{0, 1, 0, 0, 1, 1});
1530     auto logical_not = make_shared<op::Not>(constant);
1531     auto f = make_shared<Function>(logical_not, ParameterVector{});
1532
1533     pass::Manager pass_manager;
1534     pass_manager.register_pass<pass::ConstantFolding>();
1535     pass_manager.run_passes(f);
1536
1537     ASSERT_EQ(count_ops_of_type<op::Not>(f), 0);
1538     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1539
1540     auto new_const =
1541         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1542     ASSERT_TRUE(new_const);
1543     auto values_out = new_const->get_vector<char>();
1544
1545     vector<char> values_expected{1, 0, 1, 1, 0, 0};
1546
1547     ASSERT_EQ(values_expected, values_out);
1548 }
1549
1550 TEST(constant_folding, const_equal)
1551 {
1552     auto constant0 =
1553         op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1554     auto constant1 =
1555         op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 2, 3, 5, 6});
1556     auto eq = make_shared<op::Equal>(constant0, constant1);
1557     auto f = make_shared<Function>(eq, ParameterVector{});
1558
1559     pass::Manager pass_manager;
1560     pass_manager.register_pass<pass::ConstantFolding>();
1561     pass_manager.run_passes(f);
1562
1563     ASSERT_EQ(count_ops_of_type<op::Equal>(f), 0);
1564     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1565
1566     auto new_const =
1567         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1568     ASSERT_TRUE(new_const);
1569     auto values_out = new_const->get_vector<char>();
1570
1571     vector<char> values_expected{1, 1, 0, 0, 1, 1};
1572
1573     ASSERT_EQ(values_expected, values_out);
1574 }
1575
1576 TEST(constant_folding, const_not_equal)
1577 {
1578     auto constant0 =
1579         op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1580     auto constant1 =
1581         op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 2, 3, 5, 6});
1582     auto eq = make_shared<op::NotEqual>(constant0, constant1);
1583     auto f = make_shared<Function>(eq, ParameterVector{});
1584
1585     pass::Manager pass_manager;
1586     pass_manager.register_pass<pass::ConstantFolding>();
1587     pass_manager.run_passes(f);
1588
1589     ASSERT_EQ(count_ops_of_type<op::NotEqual>(f), 0);
1590     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1591
1592     auto new_const =
1593         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1594     ASSERT_TRUE(new_const);
1595     auto values_out = new_const->get_vector<char>();
1596
1597     vector<char> values_expected{0, 0, 1, 1, 0, 0};
1598
1599     ASSERT_EQ(values_expected, values_out);
1600 }
1601
1602 TEST(constant_folding, const_greater)
1603 {
1604     auto constant0 =
1605         op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1606     auto constant1 =
1607         op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1608     auto eq = make_shared<op::Greater>(constant0, constant1);
1609     auto f = make_shared<Function>(eq, ParameterVector{});
1610
1611     pass::Manager pass_manager;
1612     pass_manager.register_pass<pass::ConstantFolding>();
1613     pass_manager.run_passes(f);
1614
1615     ASSERT_EQ(count_ops_of_type<op::Greater>(f), 0);
1616     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1617
1618     auto new_const =
1619         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1620     ASSERT_TRUE(new_const);
1621     auto values_out = new_const->get_vector<char>();
1622
1623     vector<char> values_expected{0, 0, 1, 0, 0, 1};
1624
1625     ASSERT_EQ(values_expected, values_out);
1626 }
1627
1628 TEST(constant_folding, const_greater_eq)
1629 {
1630     auto constant0 =
1631         op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1632     auto constant1 =
1633         op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1634     auto eq = make_shared<op::GreaterEq>(constant0, constant1);
1635     auto f = make_shared<Function>(eq, ParameterVector{});
1636
1637     pass::Manager pass_manager;
1638     pass_manager.register_pass<pass::ConstantFolding>();
1639     pass_manager.run_passes(f);
1640
1641     ASSERT_EQ(count_ops_of_type<op::GreaterEq>(f), 0);
1642     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1643
1644     auto new_const =
1645         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1646     ASSERT_TRUE(new_const);
1647     auto values_out = new_const->get_vector<char>();
1648
1649     vector<char> values_expected{0, 1, 1, 0, 1, 1};
1650
1651     ASSERT_EQ(values_expected, values_out);
1652 }
1653
1654 TEST(constant_folding, const_less)
1655 {
1656     auto constant0 =
1657         op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1658     auto constant1 =
1659         op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1660     auto eq = make_shared<op::Less>(constant0, constant1);
1661     auto f = make_shared<Function>(eq, ParameterVector{});
1662
1663     pass::Manager pass_manager;
1664     pass_manager.register_pass<pass::ConstantFolding>();
1665     pass_manager.run_passes(f);
1666
1667     ASSERT_EQ(count_ops_of_type<op::Less>(f), 0);
1668     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1669
1670     auto new_const =
1671         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1672     ASSERT_TRUE(new_const);
1673     auto values_out = new_const->get_vector<char>();
1674
1675     vector<char> values_expected{1, 0, 0, 1, 0, 0};
1676
1677     ASSERT_EQ(values_expected, values_out);
1678 }
1679
1680 TEST(constant_folding, const_less_eq)
1681 {
1682     auto constant0 =
1683         op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
1684     auto constant1 =
1685         op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{2, 2, 2, 5, 5, 5});
1686     auto eq = make_shared<op::LessEq>(constant0, constant1);
1687     auto f = make_shared<Function>(eq, ParameterVector{});
1688
1689     pass::Manager pass_manager;
1690     pass_manager.register_pass<pass::ConstantFolding>();
1691     pass_manager.run_passes(f);
1692
1693     ASSERT_EQ(count_ops_of_type<op::LessEq>(f), 0);
1694     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1695
1696     auto new_const =
1697         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1698     ASSERT_TRUE(new_const);
1699     auto values_out = new_const->get_vector<char>();
1700
1701     vector<char> values_expected{1, 1, 0, 1, 1, 0};
1702
1703     ASSERT_EQ(values_expected, values_out);
1704 }
1705
1706 TEST(constant_folding, const_or)
1707 {
1708     auto constant0 =
1709         op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 0, 1, 0, 1, 1});
1710     auto constant1 =
1711         op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 1, 1, 1, 0, 1});
1712     auto eq = make_shared<op::Or>(constant0, constant1);
1713     auto f = make_shared<Function>(eq, ParameterVector{});
1714
1715     pass::Manager pass_manager;
1716     pass_manager.register_pass<pass::ConstantFolding>();
1717     pass_manager.run_passes(f);
1718
1719     ASSERT_EQ(count_ops_of_type<op::Or>(f), 0);
1720     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1721
1722     auto new_const =
1723         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1724     ASSERT_TRUE(new_const);
1725     auto values_out = new_const->get_vector<char>();
1726
1727     vector<char> values_expected{0, 1, 1, 1, 1, 1};
1728
1729     ASSERT_EQ(values_expected, values_out);
1730 }
1731
1732 TEST(constant_folding, const_xor)
1733 {
1734     auto constant0 =
1735         op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 0, 1, 0, 1, 1});
1736     auto constant1 =
1737         op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 1, 1, 1, 0, 1});
1738     auto eq = make_shared<op::Xor>(constant0, constant1);
1739     auto f = make_shared<Function>(eq, ParameterVector{});
1740
1741     pass::Manager pass_manager;
1742     pass_manager.register_pass<pass::ConstantFolding>();
1743     pass_manager.run_passes(f);
1744
1745     ASSERT_EQ(count_ops_of_type<op::Xor>(f), 0);
1746     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1747
1748     auto new_const =
1749         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1750     ASSERT_TRUE(new_const);
1751     auto values_out = new_const->get_vector<char>();
1752
1753     vector<char> values_expected{0, 1, 0, 1, 1, 0};
1754
1755     ASSERT_EQ(values_expected, values_out);
1756 }
1757
1758 TEST(constant_folding, const_ceiling)
1759 {
1760     auto constant = op::Constant::create(
1761         element::f32, Shape{2, 3}, vector<float>{0.0f, 0.1f, -0.1f, -2.5f, 2.5f, 3.0f});
1762     auto ceil = make_shared<op::Ceiling>(constant);
1763     auto f = make_shared<Function>(ceil, ParameterVector{});
1764
1765     pass::Manager pass_manager;
1766     pass_manager.register_pass<pass::ConstantFolding>();
1767     pass_manager.run_passes(f);
1768
1769     ASSERT_EQ(count_ops_of_type<op::Ceiling>(f), 0);
1770     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1771
1772     auto new_const =
1773         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1774     ASSERT_TRUE(new_const);
1775     auto values_out = new_const->get_vector<float>();
1776
1777     vector<float> values_expected{0.0f, 1.0f, 0.0f, -2.0f, 3.0f, 3.0f};
1778
1779     ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1780 }
1781
1782 TEST(constant_folding, const_floor)
1783 {
1784     auto constant = op::Constant::create(
1785         element::f32, Shape{2, 3}, vector<float>{0.0f, 0.1f, -0.1f, -2.5f, 2.5f, 3.0f});
1786     auto floor = make_shared<op::Floor>(constant);
1787     auto f = make_shared<Function>(floor, ParameterVector{});
1788
1789     pass::Manager pass_manager;
1790     pass_manager.register_pass<pass::ConstantFolding>();
1791     pass_manager.run_passes(f);
1792
1793     ASSERT_EQ(count_ops_of_type<op::Floor>(f), 0);
1794     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1795
1796     auto new_const =
1797         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1798     ASSERT_TRUE(new_const);
1799     auto values_out = new_const->get_vector<float>();
1800
1801     vector<float> values_expected{0.0f, 0.0f, -1.0f, -3.0f, 2.0f, 3.0f};
1802
1803     ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1804 }
1805
1806 TEST(constant_folding, const_gather)
1807 {
1808     auto constant_data = op::Constant::create(
1809         element::f32,
1810         Shape{2, 5},
1811         vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
1812     auto constant_indices =
1813         op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
1814     size_t gather_axis = 1;
1815     auto gather = make_shared<op::v0::Gather>(constant_data, constant_indices, gather_axis);
1816     auto f = make_shared<Function>(gather, ParameterVector{});
1817
1818     pass::Manager pass_manager;
1819     pass_manager.register_pass<pass::ConstantFolding>();
1820     pass_manager.run_passes(f);
1821
1822     ASSERT_EQ(count_ops_of_type<op::v0::Gather>(f), 0);
1823     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1824
1825     auto new_const =
1826         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1827     ASSERT_TRUE(new_const);
1828     auto values_out = new_const->get_vector<float>();
1829
1830     vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
1831
1832     ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1833 }
1834
1835 TEST(constant_folding, const_gather_v1)
1836 {
1837     auto constant_data = op::Constant::create(
1838         element::f32,
1839         Shape{2, 5},
1840         vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
1841     auto constant_indices =
1842         op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
1843     auto constant_axis = op::Constant::create(element::i64, Shape{1}, vector<int64_t>{1});
1844     auto gather = make_shared<op::v1::Gather>(constant_data, constant_indices, constant_axis);
1845     auto f = make_shared<Function>(gather, ParameterVector{});
1846
1847     pass::Manager pass_manager;
1848     pass_manager.register_pass<pass::ConstantFolding>();
1849     pass_manager.run_passes(f);
1850
1851     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1852     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1853
1854     auto new_const =
1855         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1856     ASSERT_TRUE(new_const);
1857     auto values_out = new_const->get_vector<float>();
1858
1859     vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
1860
1861     ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1862 }
1863
1864 TEST(constant_folding, const_gather_v1_scalar)
1865 {
1866     auto constant_data = op::Constant::create(
1867         element::f32,
1868         Shape{2, 5},
1869         vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
1870     auto constant_indices =
1871         op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
1872     auto constant_axis = op::Constant::create(element::i64, Shape{}, vector<int64_t>{1});
1873     auto gather = make_shared<op::v1::Gather>(constant_data, constant_indices, constant_axis);
1874     auto f = make_shared<Function>(gather, ParameterVector{});
1875
1876     pass::Manager pass_manager;
1877     pass_manager.register_pass<pass::ConstantFolding>();
1878     pass_manager.run_passes(f);
1879
1880     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1881     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1882
1883     auto new_const =
1884         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1885     ASSERT_TRUE(new_const);
1886     auto values_out = new_const->get_vector<float>();
1887
1888     vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
1889
1890     ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
1891 }
1892
1893 TEST(constant_folding, const_gather_v1_subgraph)
1894 {
1895     const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1896     const float b_value = 3.21f;
1897     const auto B_const = op::Constant::create(element::f32, {1}, {b_value});
1898     const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
1899     const int64_t axis = 0;
1900     const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1901
1902     const auto concat = make_shared<op::Concat>(NodeVector{A, B_const, C}, axis);
1903
1904     const vector<int64_t> indices{1};
1905     const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1906     const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1907     auto f = make_shared<Function>(gather, ParameterVector{A, C});
1908
1909     pass::Manager pass_manager;
1910     pass_manager.register_pass<pass::ConstantFolding>();
1911     pass_manager.run_passes(f);
1912
1913     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1914     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1915     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1916
1917     const auto new_const =
1918         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1919     ASSERT_TRUE(new_const);
1920
1921     const auto values_out = new_const->get_vector<float>();
1922     ASSERT_TRUE(test::all_close_f(values_out, {b_value}, MIN_FLOAT_TOLERANCE_BITS));
1923 }
1924
1925 TEST(constant_folding, const_gather_v1_subgraph_neg_axis)
1926 {
1927     const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1928     const float b_value = 1.23f;
1929     const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
1930     const auto C_const = op::Constant::create(element::f32, {1}, {b_value});
1931     const int64_t axis = 0;
1932     const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1933
1934     const auto concat = make_shared<op::Concat>(NodeVector{A, B, C_const}, axis);
1935
1936     const vector<int64_t> indices{-1};
1937     const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1938     const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1939     auto f = make_shared<Function>(gather, ParameterVector{A, B});
1940
1941     pass::Manager pass_manager;
1942     pass_manager.register_pass<pass::ConstantFolding>();
1943     pass_manager.run_passes(f);
1944
1945     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1946     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1947     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
1948
1949     const auto new_const =
1950         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
1951     ASSERT_TRUE(new_const);
1952
1953     const auto values_out = new_const->get_vector<float>();
1954     ASSERT_TRUE(test::all_close_f(values_out, {b_value}, MIN_FLOAT_TOLERANCE_BITS));
1955 }
1956
1957 TEST(constant_folding, const_gather_v1_subgraph_no_constant_input)
1958 {
1959     const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1960     const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
1961     const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
1962     const int64_t axis = 0;
1963     const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1964
1965     const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
1966
1967     const vector<int64_t> indices{1};
1968     const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
1969     const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1970     auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
1971
1972     pass::Manager pass_manager;
1973     pass_manager.register_pass<pass::ConstantFolding>();
1974     pass_manager.run_passes(f);
1975
1976     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
1977     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
1978 }
1979
1980 TEST(constant_folding, const_gather_v1_subgraph_no_constant_input_scalar)
1981 {
1982     const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
1983     const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
1984     const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
1985     const int64_t axis = 0;
1986     const auto axis_const = op::Constant::create(element::i64, {}, {axis});
1987
1988     const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
1989
1990     const vector<int64_t> indices{1};
1991     const auto indices_const = op::Constant::create(element::i64, {}, indices);
1992     const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
1993     auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
1994
1995     pass::Manager pass_manager;
1996     pass_manager.register_pass<pass::ConstantFolding>();
1997     pass_manager.run_passes(f);
1998
1999     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
2000     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 0);
2001     ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(f), 1);
2002 }
2003
2004 TEST(constant_folding, const_gather_v1_subgraph_skip_if_non_zero_axis)
2005 {
2006     const auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2});
2007     const auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2});
2008     const auto C = make_shared<op::Parameter>(element::f32, Shape{2, 2});
2009     const int64_t axis = 1;
2010     const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2011
2012     const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
2013
2014     const vector<int64_t> indices{1};
2015     const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
2016     const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2017     auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2018
2019     pass::Manager pass_manager;
2020     pass_manager.register_pass<pass::ConstantFolding>();
2021     pass_manager.run_passes(f);
2022
2023     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
2024     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
2025 }
2026
2027 TEST(constant_folding, const_gather_v1_subgraph_skip_if_non_single_indices)
2028 {
2029     const auto A = make_shared<op::Parameter>(element::f32, Shape{1});
2030     const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
2031     const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
2032     const int64_t axis = 0;
2033     const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2034
2035     const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
2036
2037     const vector<int64_t> indices{0, 1};
2038     const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
2039     const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2040     auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2041
2042     pass::Manager pass_manager;
2043     pass_manager.register_pass<pass::ConstantFolding>();
2044     pass_manager.run_passes(f);
2045
2046     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
2047     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
2048 }
2049
2050 TEST(constant_folding, const_gather_v1_subgraph_skip_if_concat_output_shape_dynamic)
2051 {
2052     const auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
2053     const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
2054     const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
2055     const int64_t axis = 0;
2056     const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2057
2058     const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
2059
2060     const vector<int64_t> indices{1};
2061     const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
2062     const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2063     auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2064
2065     pass::Manager pass_manager;
2066     pass_manager.register_pass<pass::ConstantFolding>();
2067     pass_manager.run_passes(f);
2068
2069     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
2070     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
2071 }
2072
2073 TEST(constant_folding, const_gather_v1_subgraph_skip_if_not_single_input)
2074 {
2075     const auto A = make_shared<op::Parameter>(element::f32, Shape{2});
2076     const auto B = make_shared<op::Parameter>(element::f32, Shape{1});
2077     const auto C = make_shared<op::Parameter>(element::f32, Shape{1});
2078     const int64_t axis = 0;
2079     const auto axis_const = op::Constant::create(element::i64, {}, {axis});
2080
2081     const auto concat = make_shared<op::Concat>(NodeVector{A, B, C}, axis);
2082
2083     const vector<int64_t> indices{1};
2084     const auto indices_const = op::Constant::create(element::i64, {indices.size()}, indices);
2085     const auto gather = make_shared<op::v1::Gather>(concat, indices_const, axis_const);
2086     auto f = make_shared<Function>(gather, ParameterVector{A, B, C});
2087
2088     pass::Manager pass_manager;
2089     pass_manager.register_pass<pass::ConstantFolding>();
2090     pass_manager.run_passes(f);
2091
2092     ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
2093     ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
2094 }
2095
2096 TEST(constant_folding, const_slice)
2097 {
2098     Shape shape_in{16};
2099
2100     vector<int> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
2101     auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
2102     auto slice = make_shared<op::Slice>(constant, Coordinate{2}, Coordinate{15}, Strides{3});
2103
2104     auto f = make_shared<Function>(slice, ParameterVector{});
2105
2106     pass::Manager pass_manager;
2107     pass_manager.register_pass<pass::ConstantFolding>();
2108     pass_manager.run_passes(f);
2109
2110     ASSERT_EQ(count_ops_of_type<op::Slice>(f), 0);
2111     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2112
2113     auto new_const =
2114         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2115     ASSERT_TRUE(new_const);
2116     auto values_out = new_const->get_vector<int>();
2117
2118     vector<int> sliced_values{3, 6, 9, 12, 15};
2119     ASSERT_EQ(sliced_values, values_out);
2120 }
2121
2122 TEST(constant_folding, constant_dyn_reshape)
2123 {
2124     Shape shape_in{2, 4};
2125     vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
2126
2127     Shape shape_shape{3};
2128     vector<int64_t> values_shape{2, 4, 1};
2129
2130     auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
2131     auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
2132     auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape, false);
2133     auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
2134
2135     pass::Manager pass_manager;
2136     pass_manager.register_pass<pass::ConstantFolding>();
2137     pass_manager.run_passes(f);
2138
2139     ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
2140     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2141
2142     auto new_const =
2143         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2144     ASSERT_TRUE(new_const);
2145     auto values_out = new_const->get_vector<float>();
2146
2147     ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
2148 }
2149
2150 TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
2151 {
2152     Shape shape_in{2, 4};
2153     vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
2154
2155     Shape shape_shape{3};
2156     // We're going to add these two together elementwise to get {2, 4, 1}.
2157     // This means that when ConstantFolding starts, v1::Reshape will not yet
2158     // have static output shape. But by the time the Add op is folded, the
2159     // v1::Reshape's shape should be inferrable.
2160     vector<int64_t> values_shape_a{1, 3, 0};
2161     vector<int64_t> values_shape_b{1, 1, 1};
2162
2163     auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
2164     auto constant_shape_a = make_shared<op::Constant>(element::i64, shape_shape, values_shape_a);
2165     auto constant_shape_b = make_shared<op::Constant>(element::i64, shape_shape, values_shape_b);
2166     auto dyn_reshape =
2167         make_shared<op::v1::Reshape>(constant_in, constant_shape_a + constant_shape_b, false);
2168     auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
2169
2170     ASSERT_TRUE(dyn_reshape->get_output_partial_shape(0).is_dynamic());
2171
2172     pass::Manager pass_manager;
2173     pass_manager.register_pass<pass::ConstantFolding>();
2174     pass_manager.run_passes(f);
2175
2176     ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
2177     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2178
2179     auto new_const =
2180         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2181     ASSERT_TRUE(new_const);
2182     auto values_out = new_const->get_vector<float>();
2183
2184     ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
2185 }
2186
2187 TEST(constant_folding, constant_transpose)
2188 {
2189     Shape shape_in{2, 4};
2190     vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};
2191
2192     Shape shape_perm{2};
2193     vector<int64_t> values_perm{1, 0};
2194
2195     auto constant_in = make_shared<op::Constant>(element::f64, shape_in, values_in);
2196     auto constant_perm = make_shared<op::Constant>(element::i64, shape_perm, values_perm);
2197     auto transpose = make_shared<op::Transpose>(constant_in, constant_perm);
2198     auto f = make_shared<Function>(transpose, ParameterVector{});
2199
2200     pass::Manager pass_manager;
2201     pass_manager.register_pass<pass::ConstantFolding>();
2202     pass_manager.run_passes(f);
2203
2204     ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 0);
2205     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2206
2207     auto new_const =
2208         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2209     ASSERT_TRUE(new_const);
2210     auto values_out = new_const->get_vector<double>();
2211
2212     vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
2213     ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
2214 }
2215
2216 template <typename T>
2217 void range_test(T start, T stop, T step, const vector<T>& values_expected)
2218 {
2219     vector<T> values_start{start};
2220     vector<T> values_stop{stop};
2221     vector<T> values_step{step};
2222
2223     auto constant_start = make_shared<op::Constant>(element::from<T>(), Shape{}, values_start);
2224     auto constant_stop = make_shared<op::Constant>(element::from<T>(), Shape{}, values_stop);
2225     auto constant_step = make_shared<op::Constant>(element::from<T>(), Shape{}, values_step);
2226     auto range = make_shared<op::Range>(constant_start, constant_stop, constant_step);
2227     auto f = make_shared<Function>(range, ParameterVector{});
2228
2229     pass::Manager pass_manager;
2230     pass_manager.register_pass<pass::ConstantFolding>();
2231     pass_manager.run_passes(f);
2232
2233     ASSERT_EQ(count_ops_of_type<op::Range>(f), 0);
2234     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2235
2236     auto new_const =
2237         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2238     ASSERT_TRUE(new_const);
2239
2240     auto values_out = new_const->template get_vector<T>();
2241
2242     range_test_check(values_out, values_expected);
2243 }
2244
2245 TEST(constant_folding, constant_range)
2246 {
2247     range_test<int8_t>(5, 12, 2, {5, 7, 9, 11});
2248     range_test<int32_t>(5, 12, 2, {5, 7, 9, 11});
2249     range_test<int64_t>(5, 12, 2, {5, 7, 9, 11});
2250     range_test<uint64_t>(5, 12, 2, {5, 7, 9, 11});
2251     range_test<double>(5, 12, 2, {5, 7, 9, 11});
2252     range_test<float>(5, 12, 2, {5, 7, 9, 11});
2253
2254     range_test<int32_t>(5, 12, -2, {});
2255     range_test<float>(12, 4, -2, {12, 10, 8, 6});
2256 }
2257
2258 TEST(constant_folding, constant_select)
2259 {
2260     Shape shape{2, 4};
2261     vector<char> values_selection{0, 1, 1, 0, 1, 0, 0, 1};
2262     vector<int64_t> values_t{2, 4, 6, 8, 10, 12, 14, 16};
2263     vector<int64_t> values_f{1, 3, 5, 7, 9, 11, 13, 15};
2264
2265     auto constant_selection = make_shared<op::Constant>(element::boolean, shape, values_selection);
2266     auto constant_t = make_shared<op::Constant>(element::i64, shape, values_t);
2267     auto constant_f = make_shared<op::Constant>(element::i64, shape, values_f);
2268     auto select = make_shared<op::Select>(constant_selection, constant_t, constant_f);
2269     auto f = make_shared<Function>(select, ParameterVector{});
2270
2271     pass::Manager pass_manager;
2272     pass_manager.register_pass<pass::ConstantFolding>();
2273     pass_manager.run_passes(f);
2274
2275     ASSERT_EQ(count_ops_of_type<op::Select>(f), 0);
2276     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2277
2278     auto new_const =
2279         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2280     ASSERT_TRUE(new_const);
2281     auto values_out = new_const->get_vector<int64_t>();
2282
2283     vector<int64_t> values_expected{1, 4, 6, 7, 10, 11, 13, 16};
2284     ASSERT_EQ(values_expected, values_out);
2285 }
2286
2287 TEST(constant_folding, constant_v1_select)
2288 {
2289     Shape shape{2, 4};
2290     vector<char> values_selection{0, 1, 1, 0};
2291     vector<int64_t> values_t{1, 2, 3, 4};
2292     vector<int64_t> values_f{11, 12, 13, 14, 15, 16, 17, 18};
2293
2294     auto constant_selection =
2295         make_shared<op::Constant>(element::boolean, Shape{4}, values_selection);
2296     auto constant_t = make_shared<op::Constant>(element::i64, Shape{4}, values_t);
2297     auto constant_f = make_shared<op::Constant>(element::i64, Shape{2, 4}, values_f);
2298     auto select = make_shared<op::v1::Select>(constant_selection, constant_t, constant_f);
2299     auto f = make_shared<Function>(select, ParameterVector{});
2300
2301     pass::Manager pass_manager;
2302     pass_manager.register_pass<pass::ConstantFolding>();
2303     pass_manager.run_passes(f);
2304
2305     ASSERT_EQ(count_ops_of_type<op::Select>(f), 0);
2306     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2307
2308     auto new_const =
2309         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2310     ASSERT_TRUE(new_const);
2311     auto values_out = new_const->get_vector<int64_t>();
2312
2313     vector<int64_t> values_expected{11, 2, 3, 14, 15, 2, 3, 18};
2314     ASSERT_EQ(values_expected, values_out);
2315 }
2316
2317 TEST(constant_folding, constant_v1_split)
2318 {
2319     vector<float> data{.1f, .2f, .3f, .4f, .5f, .6f};
2320     const auto const_data = op::Constant::create(element::f32, Shape{data.size()}, data);
2321     const auto const_axis = op::Constant::create(element::i64, Shape{}, {0});
2322     const auto num_splits = 3;
2323
2324     auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
2325     auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2326
2327     pass::Manager pass_manager;
2328     pass_manager.register_pass<pass::ConstantFolding>();
2329     pass_manager.run_passes(f);
2330
2331     ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2332     ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
2333
2334     auto res1 =
2335         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2336     auto res2 =
2337         as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2338     auto res3 =
2339         as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
2340     ASSERT_TRUE(res1);
2341     ASSERT_TRUE(res2);
2342     ASSERT_TRUE(res3);
2343
2344     auto res1_values = res1->get_vector<float>();
2345     ASSERT_TRUE(test::all_close_f(vector<float>(data.begin(), data.begin() + 2), res1_values));
2346     auto res2_values = res2->get_vector<float>();
2347     ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 2, data.begin() + 4), res2_values));
2348     auto res3_values = res3->get_vector<float>();
2349     ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 4, data.end()), res3_values));
2350 }
2351
2352 TEST(constant_folding, constant_v1_split_specialized)
2353 {
2354     vector<float> data{.1f, .2f, .3f, .4f, .5f, .6f};
2355     const auto const_data = op::Constant::create(element::f32, Shape{data.size()}, data);
2356     const auto const_axis = op::Constant::create(element::i64, Shape{}, {0});
2357     const auto num_splits = 3;
2358
2359     auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
2360     auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2361
2362     pass::Manager pass_manager;
2363     pass_manager.register_pass<pass::ConstantFolding>();
2364     pass_manager.run_passes(f);
2365
2366     ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2367     ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
2368
2369     auto res1 =
2370         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2371     auto res2 =
2372         as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2373     auto res3 =
2374         as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
2375     ASSERT_TRUE(res1);
2376     ASSERT_TRUE(res2);
2377     ASSERT_TRUE(res3);
2378
2379     auto res1_values = res1->get_vector<float>();
2380     ASSERT_TRUE(test::all_close_f(vector<float>(data.begin(), data.begin() + 2), res1_values));
2381     auto res2_values = res2->get_vector<float>();
2382     ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 2, data.begin() + 4), res2_values));
2383     auto res3_values = res3->get_vector<float>();
2384     ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 4, data.end()), res3_values));
2385 }
2386
2387 TEST(constant_folding, constant_v1_split_axis_1_4_splits)
2388 {
2389     vector<int64_t> data{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
2390
2391                          16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2392
2393                          32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2394
2395                          48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2396
2397     const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2398     const auto const_axis = op::Constant::create(element::i64, Shape{}, {1});
2399     const auto num_splits = 4;
2400
2401     auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
2402     auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2403
2404     pass::Manager pass_manager;
2405     pass_manager.register_pass<pass::ConstantFolding>();
2406     pass_manager.run_passes(f);
2407
2408     ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2409     ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
2410
2411     auto res1 =
2412         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2413     auto res2 =
2414         as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2415     auto res3 =
2416         as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
2417     auto res4 =
2418         as_type_ptr<op::Constant>(f->get_results().at(3)->input_value(0).get_node_shared_ptr());
2419     ASSERT_TRUE(res1);
2420     ASSERT_TRUE(res2);
2421     ASSERT_TRUE(res3);
2422     ASSERT_TRUE(res4);
2423
2424     auto res1_values = res1->get_vector<int64_t>();
2425     ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51}),
2426               res1_values);
2427     auto res2_values = res2->get_vector<int64_t>();
2428     ASSERT_EQ(vector<int64_t>({4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55}),
2429               res2_values);
2430     auto res3_values = res3->get_vector<int64_t>();
2431     ASSERT_EQ(vector<int64_t>({8, 9, 10, 11, 24, 25, 26, 27, 40, 41, 42, 43, 56, 57, 58, 59}),
2432               res3_values);
2433     auto res4_values = res4->get_vector<int64_t>();
2434     ASSERT_EQ(vector<int64_t>({12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63}),
2435               res4_values);
2436 }
2437
2438 TEST(constant_folding, constant_v1_split_axis_1_2_splits)
2439 {
2440     vector<int64_t> data{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
2441
2442                          16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2443
2444                          32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2445
2446                          48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2447
2448     const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2449     const auto const_axis = op::Constant::create(element::i64, Shape{}, {1});
2450     const auto num_splits = 2;
2451
2452     auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
2453     auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
2454
2455     pass::Manager pass_manager;
2456     pass_manager.register_pass<pass::ConstantFolding>();
2457     pass_manager.run_passes(f);
2458
2459     ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
2460     ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
2461
2462     auto res1 =
2463         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2464     auto res2 =
2465         as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2466     ASSERT_TRUE(res1);
2467     ASSERT_TRUE(res2);
2468
2469     auto res1_values = res1->get_vector<int64_t>();
2470     ASSERT_EQ(vector<int64_t>({0,  1,  2,  3,  4,  5,  6,  7,  16, 17, 18, 19, 20, 21, 22, 23,
2471                                32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55}),
2472               res1_values);
2473     auto res2_values = res2->get_vector<int64_t>();
2474     ASSERT_EQ(vector<int64_t>({8,  9,  10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31,
2475                                40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63}),
2476               res2_values);
2477 }
2478
2479 TEST(constant_folding, constant_v1_variadic_split_axis_1_2_splits)
2480 {
2481     vector<int64_t> data{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
2482
2483                          16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2484
2485                          32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2486
2487                          48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2488
2489     const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2490     const auto const_axis = op::Constant::create(element::i16, Shape{}, {1});
2491     vector<int64_t> values_lengths{3, 1};
2492     auto constant_lengths =
2493         make_shared<op::Constant>(element::i64, Shape{values_lengths.size()}, values_lengths);
2494
2495     auto variadic_split_v1 =
2496         make_shared<op::v1::VariadicSplit>(const_data, const_axis, constant_lengths);
2497     auto f = make_shared<Function>(variadic_split_v1->outputs(), ParameterVector{});
2498
2499     pass::Manager pass_manager;
2500     pass_manager.register_pass<pass::ConstantFolding>();
2501     pass_manager.run_passes(f);
2502
2503     ASSERT_EQ(count_ops_of_type<op::v1::VariadicSplit>(f), 0);
2504     ASSERT_EQ(count_ops_of_type<op::Constant>(f), values_lengths.size());
2505
2506     auto res1 =
2507         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2508     auto res2 =
2509         as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2510     ASSERT_TRUE(res1);
2511     ASSERT_TRUE(res2);
2512
2513     auto res1_values = res1->get_vector<int64_t>();
2514     ASSERT_EQ(vector<int64_t>({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 16, 17, 18, 19,
2515                                20, 21, 22, 23, 24, 25, 26, 27, 32, 33, 34, 35, 36, 37, 38, 39,
2516                                40, 41, 42, 43, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59}),
2517               res1_values);
2518     auto res2_values = res2->get_vector<int64_t>();
2519     ASSERT_EQ(vector<int64_t>({12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63}),
2520               res2_values);
2521 }
2522
2523 TEST(constant_folding, constant_v1_variadic_split_axis_1_3_splits_neg_length)
2524 {
2525     vector<int64_t> data{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
2526
2527                          16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
2528
2529                          32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
2530
2531                          48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
2532
2533     const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
2534     const auto const_axis = op::Constant::create(element::i32, Shape{}, {1});
2535     vector<int64_t> values_lengths{1, 1, -1};
2536     auto constant_lengths =
2537         make_shared<op::Constant>(element::i64, Shape{values_lengths.size()}, values_lengths);
2538
2539     auto variadic_split_v1 =
2540         make_shared<op::v1::VariadicSplit>(const_data, const_axis, constant_lengths);
2541     auto f = make_shared<Function>(variadic_split_v1->outputs(), ParameterVector{});
2542
2543     pass::Manager pass_manager;
2544     pass_manager.register_pass<pass::ConstantFolding>();
2545     pass_manager.run_passes(f);
2546
2547     ASSERT_EQ(count_ops_of_type<op::v1::VariadicSplit>(f), 0);
2548     ASSERT_EQ(count_ops_of_type<op::Constant>(f), values_lengths.size());
2549
2550     auto res1 =
2551         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2552     auto res2 =
2553         as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
2554     auto res3 =
2555         as_type_ptr<op::Constant>(f->get_results().at(2)->input_value(0).get_node_shared_ptr());
2556     ASSERT_TRUE(res1);
2557     ASSERT_TRUE(res2);
2558     ASSERT_TRUE(res3);
2559
2560     auto res1_values = res1->get_vector<int64_t>();
2561     ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51}),
2562               res1_values);
2563     auto res2_values = res2->get_vector<int64_t>();
2564     ASSERT_EQ(vector<int64_t>({4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55}),
2565               res2_values);
2566     auto res3_values = res3->get_vector<int64_t>();
2567     ASSERT_EQ(vector<int64_t>({8,  9,  10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31,
2568                                40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63}),
2569               res3_values);
2570 }
2571
2572 TEST(constant_folding, constant_v1_one_hot)
2573 {
2574     vector<int64_t> indices{0, 1, 2};
2575     float16 on_value = 1.123f;
2576     float16 off_value = 0.321f;
2577
2578     const auto indices_const = op::Constant::create(element::i64, Shape{3}, indices);
2579     const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
2580     const auto on_const = op::Constant::create(element::f16, Shape{}, {on_value});
2581     const auto off_const = op::Constant::create(element::f16, Shape{}, {off_value});
2582     int64_t axis = 1;
2583
2584     auto one_hot_v1 =
2585         make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
2586     auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
2587
2588     pass::Manager pass_manager;
2589     pass_manager.register_pass<pass::ConstantFolding>();
2590     pass_manager.run_passes(f);
2591
2592     ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
2593     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2594
2595     auto res =
2596         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2597     ASSERT_TRUE(res);
2598
2599     ASSERT_EQ((Shape{3, 3}), res->get_output_shape(0));
2600     ASSERT_EQ(vector<float16>({on_value,
2601                                off_value,
2602                                off_value,
2603                                off_value,
2604                                on_value,
2605                                off_value,
2606                                off_value,
2607                                off_value,
2608                                on_value}),
2609               res->get_vector<float16>());
2610 }
2611
2612 TEST(constant_folding, constant_v1_one_hot_negative_axes)
2613 {
2614     vector<int64_t> indices{0, 2, -1, 1};
2615     int16_t on_value = 4;
2616     int16_t off_value = 1;
2617
2618     const auto indices_const = op::Constant::create(element::i64, Shape{4}, indices);
2619     const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
2620     const auto on_const = op::Constant::create(element::i16, Shape{}, {on_value});
2621     const auto off_const = op::Constant::create(element::i16, Shape{}, {off_value});
2622     int64_t axis = -1;
2623
2624     auto one_hot_v1 =
2625         make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
2626     auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
2627
2628     pass::Manager pass_manager;
2629     pass_manager.register_pass<pass::ConstantFolding>();
2630     pass_manager.run_passes(f);
2631
2632     ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
2633     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2634
2635     auto res =
2636         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2637     ASSERT_TRUE(res);
2638
2639     ASSERT_EQ((Shape{4, 3}), res->get_output_shape(0));
2640     ASSERT_EQ(vector<int16_t>({on_value,
2641                                off_value,
2642                                off_value,
2643                                off_value,
2644                                off_value,
2645                                on_value,
2646                                off_value,
2647                                off_value,
2648                                off_value,
2649                                off_value,
2650                                on_value,
2651                                off_value}),
2652               res->get_vector<int16_t>());
2653 }
2654
2655 TEST(constant_folding, constant_v1_one_hot_negative_axes_2)
2656 {
2657     vector<int64_t> indices{0, 2, 1, -1};
2658     auto on_value = true;
2659     auto off_value = false;
2660
2661     const auto indices_const = op::Constant::create(element::i64, Shape{2, 2}, indices);
2662     const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
2663     const auto on_const = op::Constant::create(element::boolean, Shape{}, {on_value});
2664     const auto off_const = op::Constant::create(element::boolean, Shape{}, {off_value});
2665     int64_t axis = -1;
2666
2667     auto one_hot_v1 =
2668         make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
2669     auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
2670
2671     pass::Manager pass_manager;
2672     pass_manager.register_pass<pass::ConstantFolding>();
2673     pass_manager.run_passes(f);
2674
2675     ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
2676     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2677
2678     auto res =
2679         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2680     ASSERT_TRUE(res);
2681
2682     ASSERT_EQ((Shape{2, 2, 3}), res->get_output_shape(0));
2683     ASSERT_EQ(vector<bool>({on_value,
2684                             off_value,
2685                             off_value,
2686                             off_value,
2687                             off_value,
2688                             on_value,
2689                             off_value,
2690                             on_value,
2691                             off_value,
2692                             off_value,
2693                             off_value,
2694                             off_value}),
2695               res->get_vector<bool>());
2696 }
2697
2698 TEST(constant_folding, constant_tile_1d)
2699 {
2700     Shape shape_in{2};
2701     Shape shape_repeats{1};
2702     Shape shape_out{4};
2703
2704     vector<int> values_in{0, 1};
2705     auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2706     vector<int> values_repeats{2};
2707     auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2708     auto tile = make_shared<op::v0::Tile>(data, repeats);
2709     auto f = make_shared<Function>(tile, ParameterVector{});
2710
2711     pass::Manager pass_manager;
2712     pass_manager.register_pass<pass::ConstantFolding>();
2713     pass_manager.run_passes(f);
2714
2715     ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2716     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2717
2718     auto new_const =
2719         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2720     ASSERT_TRUE(new_const);
2721     auto values_out = new_const->get_vector<int>();
2722
2723     vector<int> values_expected{0, 1, 0, 1};
2724     ASSERT_EQ(values_expected, values_out);
2725 }
2726
2727 TEST(constant_folding, constant_tile_3d_small_data_rank)
2728 {
2729     Shape shape_in{2};
2730     Shape shape_repeats{3};
2731     Shape shape_out{2, 2, 4};
2732
2733     vector<int> values_in{0, 1};
2734     auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2735     vector<int> values_repeats{2, 2, 2};
2736     auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2737     auto tile = make_shared<op::v0::Tile>(data, repeats);
2738     auto f = make_shared<Function>(tile, ParameterVector{});
2739
2740     pass::Manager pass_manager;
2741     pass_manager.register_pass<pass::ConstantFolding>();
2742     pass_manager.run_passes(f);
2743
2744     ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2745     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2746
2747     auto new_const =
2748         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2749     ASSERT_TRUE(new_const);
2750     auto values_out = new_const->get_vector<int>();
2751
2752     vector<int> values_expected{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1};
2753     ASSERT_EQ(values_expected, values_out);
2754 }
2755
2756 TEST(constant_folding, constant_tile_3d_few_repeats)
2757 {
2758     Shape shape_in{2, 1, 3};
2759     Shape shape_repeats{2};
2760     Shape shape_out{2, 2, 3};
2761
2762     vector<int> values_in{1, 2, 3, 4, 5, 6};
2763     auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2764     vector<int> values_repeats{2, 1};
2765     auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2766     auto tile = make_shared<op::v0::Tile>(data, repeats);
2767     auto f = make_shared<Function>(tile, ParameterVector{});
2768
2769     pass::Manager pass_manager;
2770     pass_manager.register_pass<pass::ConstantFolding>();
2771     pass_manager.run_passes(f);
2772
2773     ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2774     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2775
2776     auto new_const =
2777         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2778     ASSERT_TRUE(new_const);
2779     auto values_out = new_const->get_vector<int>();
2780
2781     vector<int> values_expected{1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6};
2782     ASSERT_EQ(values_expected, values_out);
2783 }
2784
2785 TEST(constant_folding, constant_tile_1d_0_repeats)
2786 {
2787     Shape shape_in{2};
2788     Shape shape_repeats{1};
2789     Shape shape_out{};
2790
2791     vector<int> values_in{0, 1};
2792     auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2793     vector<int> values_repeats{0};
2794     auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2795     auto tile = make_shared<op::v0::Tile>(data, repeats);
2796     auto f = make_shared<Function>(tile, ParameterVector{});
2797
2798     pass::Manager pass_manager;
2799     pass_manager.register_pass<pass::ConstantFolding>();
2800     pass_manager.run_passes(f);
2801
2802     ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2803     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2804
2805     auto new_const =
2806         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2807     ASSERT_TRUE(new_const);
2808     auto values_out = new_const->get_vector<int>();
2809
2810     vector<int> values_expected{};
2811     ASSERT_EQ(values_expected, values_out);
2812 }
2813
2814 TEST(constant_folding, constant_tile_0_rank_data)
2815 {
2816     Shape shape_in{};
2817     Shape shape_repeats{1};
2818     Shape shape_out{4};
2819
2820     vector<int> values_in{1};
2821     auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
2822     vector<int> values_repeats{4};
2823     auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
2824     auto tile = make_shared<op::v0::Tile>(data, repeats);
2825     auto f = make_shared<Function>(tile, ParameterVector{});
2826
2827     pass::Manager pass_manager;
2828     pass_manager.register_pass<pass::ConstantFolding>();
2829     pass_manager.run_passes(f);
2830
2831     ASSERT_EQ(count_ops_of_type<op::v0::Tile>(f), 0);
2832     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2833
2834     auto new_const =
2835         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2836     ASSERT_TRUE(new_const);
2837     auto values_out = new_const->get_vector<int>();
2838
2839     vector<int> values_expected{1, 1, 1, 1};
2840     ASSERT_EQ(values_expected, values_out);
2841 }
2842
2843 TEST(constant_folding, constant_non_zero_0D)
2844 {
2845     auto data = op::Constant::create(element::i32, Shape{}, {1});
2846     auto non_zero = make_shared<op::v3::NonZero>(data);
2847     auto f = make_shared<Function>(non_zero, ParameterVector{});
2848
2849     pass::Manager pass_manager;
2850     pass_manager.register_pass<pass::ConstantFolding>();
2851     pass_manager.run_passes(f);
2852
2853     // Fold into constant with shape of {1, 1} for scalar input with
2854     // non-zero value
2855     ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2856     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2857
2858     const auto new_const =
2859         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2860     ASSERT_TRUE(new_const);
2861     const auto values_out = new_const->get_vector<int64_t>();
2862
2863     const vector<int64_t> values_expected{0};
2864     ASSERT_EQ(values_expected, values_out);
2865     ASSERT_EQ((Shape{1, 1}), new_const->get_shape());
2866 }
2867
2868 TEST(constant_folding, constant_non_zero_1D)
2869 {
2870     vector<int> values_in{0, 1, 0, 1};
2871     auto data = make_shared<op::Constant>(element::i32, Shape{4}, values_in);
2872     auto non_zero = make_shared<op::v3::NonZero>(data);
2873     auto f = make_shared<Function>(non_zero, ParameterVector{});
2874
2875     pass::Manager pass_manager;
2876     pass_manager.register_pass<pass::ConstantFolding>();
2877     pass_manager.run_passes(f);
2878
2879     ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2880     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2881
2882     const auto new_const =
2883         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2884     ASSERT_TRUE(new_const);
2885     const auto values_out = new_const->get_vector<int64_t>();
2886
2887     const vector<int64_t> values_expected{1, 3};
2888     ASSERT_EQ(values_expected, values_out);
2889     ASSERT_EQ((Shape{1, 2}), new_const->get_shape());
2890 }
2891
2892 TEST(constant_folding, constant_non_zero_int32_output_type)
2893 {
2894     vector<int> values_in{0, 1, 0, 1};
2895     auto data = make_shared<op::Constant>(element::i32, Shape{4}, values_in);
2896     auto non_zero = make_shared<op::v3::NonZero>(data, element::i32);
2897     auto f = make_shared<Function>(non_zero, ParameterVector{});
2898
2899     pass::Manager pass_manager;
2900     pass_manager.register_pass<pass::ConstantFolding>();
2901     pass_manager.run_passes(f);
2902
2903     ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2904     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2905
2906     const auto new_const =
2907         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2908     ASSERT_TRUE(new_const);
2909     ASSERT_EQ(element::i32, new_const->get_element_type());
2910     const auto values_out = new_const->get_vector<int32_t>();
2911
2912     const vector<int32_t> values_expected{1, 3};
2913     ASSERT_EQ(values_expected, values_out);
2914     ASSERT_EQ((Shape{1, 2}), new_const->get_shape());
2915 }
2916
2917 TEST(constant_folding, constant_non_zero_1D_all_indices)
2918 {
2919     const vector<float> values_in{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
2920     const auto data = make_shared<op::Constant>(element::f32, Shape{values_in.size()}, values_in);
2921     const auto non_zero = make_shared<op::v3::NonZero>(data);
2922     auto f = make_shared<Function>(non_zero, ParameterVector{});
2923
2924     pass::Manager pass_manager;
2925     pass_manager.register_pass<pass::ConstantFolding>();
2926     pass_manager.run_passes(f);
2927
2928     ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2929     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2930
2931     const auto new_const =
2932         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2933     ASSERT_TRUE(new_const);
2934     const auto values_out = new_const->get_vector<int64_t>();
2935
2936     const vector<int64_t> values_expected{0, 1, 2, 3, 4, 5, 6, 7};
2937     ASSERT_EQ(values_expected, values_out);
2938     ASSERT_EQ((Shape{1, values_in.size()}), new_const->get_shape());
2939 }
2940
2941 TEST(constant_folding, constant_non_zero_2D)
2942 {
2943     vector<int> values_in{1, 0, 0, 0, 1, 0, 1, 1, 0};
2944     auto data = make_shared<op::Constant>(element::i32, Shape{3, 3}, values_in);
2945     auto non_zero = make_shared<op::v3::NonZero>(data);
2946     auto f = make_shared<Function>(non_zero, ParameterVector{});
2947
2948     pass::Manager pass_manager;
2949     pass_manager.register_pass<pass::ConstantFolding>();
2950     pass_manager.run_passes(f);
2951
2952     ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2953     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2954
2955     const auto new_const =
2956         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2957     ASSERT_TRUE(new_const);
2958     const auto values_out = new_const->get_vector<int64_t>();
2959
2960     const vector<int64_t> values_expected{0, 1, 2, 2, 0, 1, 0, 1};
2961     ASSERT_EQ(values_expected, values_out);
2962     ASSERT_EQ((Shape{2, 4}), new_const->get_shape());
2963 }
2964
2965 TEST(constant_folding, DISABLED_constant_non_zero_2D_all_indices)
2966 {
2967     const vector<int8_t> values_in{1, 1, 1, 1, 1, 1, 1, 1, 1};
2968     const auto data = make_shared<op::Constant>(element::i8, Shape{3, 3}, values_in);
2969     const auto non_zero = make_shared<op::v3::NonZero>(data);
2970     auto f = make_shared<Function>(non_zero, ParameterVector{});
2971
2972     pass::Manager pass_manager;
2973     pass_manager.register_pass<pass::ConstantFolding>();
2974     pass_manager.run_passes(f);
2975
2976     ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
2977     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
2978
2979     const auto new_const =
2980         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
2981     ASSERT_TRUE(new_const);
2982     const auto values_out = new_const->get_vector<int64_t>();
2983
2984     const vector<int64_t> values_expected{0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2};
2985     ASSERT_EQ(values_expected, values_out);
2986     ASSERT_EQ((Shape{2, values_in.size()}), new_const->get_shape());
2987 }
2988
2989 TEST(constant_folding, DISABLED_constant_non_zero_2D_all_zeros)
2990 {
2991     const vector<uint8_t> values_in{0, 0, 0, 0, 0, 0};
2992     const auto data = make_shared<op::Constant>(element::u8, Shape{2, 3}, values_in);
2993     const auto non_zero = make_shared<op::v3::NonZero>(data);
2994     auto f = make_shared<Function>(non_zero, ParameterVector{});
2995
2996     pass::Manager pass_manager;
2997     pass_manager.register_pass<pass::ConstantFolding>();
2998     pass_manager.run_passes(f);
2999
3000     // fold into Constant with shape of {0}
3001     ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
3002     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3003
3004     const auto new_const =
3005         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3006     ASSERT_TRUE(new_const);
3007     ASSERT_EQ(shape_size(new_const->get_shape()), 0);
3008 }
3009
3010 TEST(constant_folding, constant_non_zero_3D)
3011 {
3012     vector<int> values_in{1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0};
3013     auto data = make_shared<op::Constant>(element::i32, Shape{2, 3, 3}, values_in);
3014     auto non_zero = make_shared<op::v3::NonZero>(data);
3015     auto f = make_shared<Function>(non_zero, ParameterVector{});
3016
3017     pass::Manager pass_manager;
3018     pass_manager.register_pass<pass::ConstantFolding>();
3019     pass_manager.run_passes(f);
3020
3021     ASSERT_EQ(count_ops_of_type<op::v3::NonZero>(f), 0);
3022     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3023
3024     const auto new_const =
3025         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3026     ASSERT_TRUE(new_const);
3027     const auto values_out = new_const->get_vector<int64_t>();
3028
3029     const vector<int64_t> values_expected{0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 2, 2, 2,
3030                                           0, 0, 0, 1, 1, 2, 0, 2, 1, 0, 1, 2, 0, 1, 2, 0, 2, 1};
3031     ASSERT_EQ(values_expected, values_out);
3032     ASSERT_EQ((Shape{3, 12}), new_const->get_shape());
3033 }
3034
3035 TEST(constant_folding, constant_scatter_elements_update_basic)
3036 {
3037     const Shape data_shape{3, 3};
3038     const Shape indices_shape{2, 3};
3039
3040     const auto data_const = op::Constant::create(
3041         element::f32, data_shape, std::vector<float>(shape_size(data_shape), 0.f));
3042     const auto indices_const =
3043         op::Constant::create(element::i32, indices_shape, {1, 0, 2, 0, 2, 1});
3044     const auto updates_const =
3045         op::Constant::create(element::f32, indices_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
3046     const auto axis_const = op::Constant::create(element::i64, Shape{}, {0});
3047
3048     auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3049         data_const, indices_const, updates_const, axis_const);
3050     auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3051
3052     pass::Manager pass_manager;
3053     pass_manager.register_pass<pass::ConstantFolding>();
3054     pass_manager.run_passes(f);
3055
3056     ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3057     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3058
3059     auto result_node =
3060         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3061     ASSERT_TRUE(result_node);
3062     ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3063     std::vector<float> expected{2.f, 1.1f, 0.0f, 1.f, 0.0f, 2.2f, 0.f, 2.1f, 1.2f};
3064     range_test_check(result_node->cast_vector<float>(), expected);
3065 }
3066
3067 TEST(constant_folding, constant_scatter_elements_update_negative_axis)
3068 {
3069     const Shape data_shape{3, 3};
3070     const Shape indices_shape{2, 3};
3071
3072     const auto data_const = op::Constant::create(
3073         element::f32, data_shape, std::vector<float>(shape_size(data_shape), 0.f));
3074     const auto indices_const =
3075         op::Constant::create(element::i32, indices_shape, {1, 0, 2, 0, 2, 1});
3076     const auto updates_const =
3077         op::Constant::create(element::f32, indices_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
3078     const auto axis_const = op::Constant::create(element::i64, Shape{}, {-1});
3079
3080     auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3081         data_const, indices_const, updates_const, axis_const);
3082     auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3083
3084     pass::Manager pass_manager;
3085     pass_manager.register_pass<pass::ConstantFolding>();
3086     pass_manager.run_passes(f);
3087
3088     ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3089     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3090
3091     auto result_node =
3092         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3093     ASSERT_TRUE(result_node);
3094     ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3095     std::vector<float> expected{1.1f, 1.0f, 1.2f, 2.0f, 2.2f, 2.1f, 0.0f, 0.0f, 0.0f};
3096     range_test_check(result_node->cast_vector<float>(), expected);
3097 }
3098
3099 TEST(constant_folding, constant_scatter_elements_update_1d_axis)
3100 {
3101     const Shape data_shape{3, 3};
3102     const Shape indices_shape{2, 3};
3103
3104     const auto data_const = op::Constant::create(
3105         element::f32, data_shape, std::vector<float>(shape_size(data_shape), 0.f));
3106     const auto indices_const =
3107         op::Constant::create(element::i32, indices_shape, {1, 0, 2, 0, 2, 1});
3108     const auto updates_const =
3109         op::Constant::create(element::f32, indices_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
3110     const auto axis_const = op::Constant::create(element::i64, Shape{1}, {0});
3111
3112     auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3113         data_const, indices_const, updates_const, axis_const);
3114     auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3115
3116     pass::Manager pass_manager;
3117     pass_manager.register_pass<pass::ConstantFolding>();
3118     pass_manager.run_passes(f);
3119
3120     ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3121     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3122
3123     auto result_node =
3124         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3125     ASSERT_TRUE(result_node);
3126     ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3127     std::vector<float> expected{2.f, 1.1f, 0.0f, 1.f, 0.0f, 2.2f, 0.f, 2.1f, 1.2f};
3128     range_test_check(result_node->cast_vector<float>(), expected);
3129 }
3130
3131 TEST(constant_folding, constant_scatter_elements_update_3d_i16)
3132 {
3133     const Shape data_shape{3, 3, 3};
3134     const Shape indices_shape{2, 2, 3};
3135
3136     const auto data_const = op::Constant::create(
3137         element::i16, data_shape, std::vector<int16_t>(shape_size(data_shape), 0));
3138     const auto indices_const =
3139         op::Constant::create(element::i16, indices_shape, {1, 0, 2, 0, 2, 1, 2, 2, 2, 0, 1, 0});
3140     const auto updates_const =
3141         op::Constant::create(element::i16, indices_shape, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
3142     const auto axis_const = op::Constant::create(element::i64, Shape{}, {1});
3143
3144     auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3145         data_const, indices_const, updates_const, axis_const);
3146     auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3147
3148     pass::Manager pass_manager;
3149     pass_manager.register_pass<pass::ConstantFolding>();
3150     pass_manager.run_passes(f);
3151
3152     ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3153     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3154
3155     auto result_node =
3156         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3157     ASSERT_TRUE(result_node);
3158     ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3159     std::vector<int16_t> expected{4, 2, 0, 1, 0, 6, 0, 5, 3, 10, 0, 12, 0, 11,
3160                                   0, 7, 8, 9, 0, 0, 0, 0, 0, 0,  0, 0,  0};
3161     range_test_check(result_node->cast_vector<int16_t>(), expected);
3162 }
3163
3164 TEST(constant_folding, constant_scatter_elements_update_one_elem)
3165 {
3166     const Shape data_shape{3, 3, 3};
3167     const Shape indices_shape{1, 1, 1};
3168     const auto input_data = std::vector<int32_t>(shape_size(data_shape), 0);
3169
3170     const auto data_const = op::Constant::create(element::i32, data_shape, input_data);
3171     const auto indices_const = op::Constant::create(element::i32, indices_shape, {1});
3172     const auto updates_const = op::Constant::create(element::i32, indices_shape, {2});
3173     const auto axis_const = op::Constant::create(element::i64, Shape{}, {0});
3174
3175     auto scatter_elem_updt = make_shared<op::v3::ScatterElementsUpdate>(
3176         data_const, indices_const, updates_const, axis_const);
3177     auto f = make_shared<Function>(scatter_elem_updt, ParameterVector{});
3178
3179     pass::Manager pass_manager;
3180     pass_manager.register_pass<pass::ConstantFolding>();
3181     pass_manager.run_passes(f);
3182
3183     ASSERT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(f), 0);
3184     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3185
3186     auto result_node =
3187         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3188     ASSERT_TRUE(result_node);
3189     ASSERT_EQ(data_shape, result_node->get_output_shape(0));
3190     std::vector<int32_t> expected{input_data};
3191     // we have updated coordinate (1, 0, 0)
3192     expected.at(9) = 2;
3193     range_test_check(result_node->cast_vector<int32_t>(), expected);
3194 }
3195
3196 void test_constant_folding_reshape_v1(Shape& shape_in,
3197                                       vector<float>& values_in,
3198                                       Shape shape_shape,
3199                                       vector<int32_t> values_shape,
3200                                       bool zero_flag = false)
3201 {
3202     auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
3203     auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
3204     auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape, zero_flag);
3205     auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
3206
3207     pass::Manager pass_manager;
3208     pass_manager.register_pass<pass::ConstantFolding>();
3209     pass_manager.run_passes(f);
3210
3211     ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
3212     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3213
3214     auto new_const =
3215         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
3216     ASSERT_TRUE(new_const);
3217     auto values_out = new_const->get_vector<float>();
3218
3219     ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
3220 }
3221 TEST(constant_folding, constant_dyn_reshape_v1_2d)
3222 {
3223     Shape shape_in{2, 5};
3224     vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
3225
3226     test_constant_folding_reshape_v1(shape_in, values_in, {4}, {1, 1, 1, 10});
3227     test_constant_folding_reshape_v1(shape_in, values_in, {4}, {1, 1, 2, 5});
3228     test_constant_folding_reshape_v1(shape_in, values_in, {3}, {1, 2, 5});
3229     test_constant_folding_reshape_v1(shape_in, values_in, {3}, {5, 2, 1});
3230 }
3231
3232 TEST(constant_folding, constant_dyn_reshape_v1_pattern_with_negative_indices)
3233 {
3234     Shape shape_in{2, 2, 2, 2};
3235     vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3236
3237     test_constant_folding_reshape_v1(shape_in, values_in, {3}, {4, -1, 2});
3238     test_constant_folding_reshape_v1(shape_in, values_in, {2}, {4, -1});
3239     test_constant_folding_reshape_v1(shape_in, values_in, {1}, {-1});
3240 }
3241
3242 TEST(constant_folding, constant_dyn_reshape_v1_pattern_with_zero_dims)
3243 {
3244     Shape shape_in{2, 2, 2, 2};
3245     vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3246
3247     test_constant_folding_reshape_v1(shape_in, values_in, {4}, {2, -1, 2, 0}, true);
3248     test_constant_folding_reshape_v1(shape_in, values_in, {4}, {4, 1, 0, 2}, true);
3249 }
3250
3251 TEST(constant_folding, disable_constant_folding)
3252 {
3253     auto input = make_shared<op::Parameter>(element::f32, Shape{1, 3});
3254     auto constant_shape = op::Constant::create(element::i64, Shape{1}, {3});
3255     auto dyn_reshape = make_shared<op::v1::Reshape>(input, constant_shape, true);
3256     auto& rt_info = dyn_reshape->get_rt_info();
3257     rt_info["DISABLED_CONSTANT_FOLDING"];
3258     auto f = make_shared<Function>(dyn_reshape, ParameterVector{input});
3259
3260     pass::Manager pass_manager;
3261     pass_manager.register_pass<pass::ConstantFolding>();
3262     pass_manager.run_passes(f);
3263
3264     ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 1);
3265     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
3266 }