Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / test / type_prop / binary_elementwise.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "gtest/gtest.h"
18 #include "ngraph/ngraph.hpp"
19 #include "util/type_prop.hpp"
20
21 NGRAPH_SUPPRESS_DEPRECATED_START
22
23 using namespace std;
24 using namespace ngraph;
25
26 //
27 // Tests for binary elementwise ops.
28 //
29 void test_binary(std::string /* node_type */,
30                  shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y))
31 {
32     // Check for bad arguments
33     auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
34     auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
35     auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
36     auto tv0_4_2_param = make_shared<op::Parameter>(element::f32, Shape{4, 2});
37
38     auto test_binary_bad_arguments_view_shapes = [&](const shared_ptr<Node>& x,
39                                                      const shared_ptr<Node>& y) {
40         try
41         {
42             auto node = f(x, y);
43             // Should have thrown, so fail if it didn't
44             FAIL() << "Incompatible view arguments not detected.";
45         }
46         catch (const NodeValidationFailure& error)
47         {
48             EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
49         }
50         catch (...)
51         {
52             FAIL() << "Deduced type check failed for unexpected reason";
53         }
54     };
55     test_binary_bad_arguments_view_shapes(tv0_2_4_param_0, tv0_4_2_param);
56
57     auto test_binary_bad_arguments_view_element_types = [&](const shared_ptr<Node>& x,
58                                                             const shared_ptr<Node>& y) {
59         try
60         {
61             auto node = f(x, y);
62             // Should have thrown, so fail if it didn't
63             FAIL() << "Incompatible view arguments not detected.";
64         }
65         catch (const NodeValidationFailure& error)
66         {
67             EXPECT_HAS_SUBSTRING(error.what(),
68                                  std::string("Argument element types are inconsistent"));
69         }
70         catch (...)
71         {
72             FAIL() << "Deduced type check failed for unexpected reason";
73         }
74     };
75
76     test_binary_bad_arguments_view_element_types(tv0_2_4_param_0, tv0_2_4_param_2);
77
78     auto test_binary_good_arguments = [&](const shared_ptr<Node>& x, const shared_ptr<Node>& y) {
79         auto node = f(x, y);
80         EXPECT_TRUE(node->has_same_type(node->input_values()[0].get_node_shared_ptr()));
81     };
82     test_binary_good_arguments(tv0_2_4_param_0, tv0_2_4_param_1);
83 }
84
85 TEST(type_prop, add_bad_arguments)
86 {
87     test_binary("Add",
88                 [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
89                     return make_shared<op::Add>(x, y);
90                 });
91 }
92
93 TEST(type_prop, divide_bad_arguments)
94 {
95     test_binary("Divide",
96                 [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
97                     return make_shared<op::Divide>(x, y);
98                 });
99 }
100
101 TEST(type_prop, multiply_bad_arguments)
102 {
103     test_binary("Multiply",
104                 [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
105                     return make_shared<op::Multiply>(x, y);
106                 });
107 }
108
109 TEST(type_prop, subtract_bad_arguments)
110 {
111     test_binary("Subtract",
112                 [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
113                     return make_shared<op::Subtract>(x, y);
114                 });
115 }
116
117 //
118 // Tests for binary elementwise logical ops.
119 //
120 void test_binary_logical(std::string /* node_type */,
121                          shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y))
122 {
123     // Check for bad arguments
124     auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
125     auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
126     auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
127     auto tv0_2_4_param_3 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
128     auto tv0_4_2_param = make_shared<op::Parameter>(element::boolean, Shape{4, 2});
129
130     auto test_binary_bad_arguments_view_shapes = [&](const shared_ptr<Node>& x,
131                                                      const shared_ptr<Node>& y) {
132         try
133         {
134             auto node = f(x, y);
135             // Should have thrown, so fail if it didn't
136             FAIL() << "Incompatible view arguments not detected.";
137         }
138         catch (const NodeValidationFailure& error)
139         {
140             EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
141         }
142         catch (...)
143         {
144             FAIL() << "Deduced type check failed for unexpected reason";
145         }
146     };
147     test_binary_bad_arguments_view_shapes(tv0_2_4_param_0, tv0_4_2_param);
148
149     auto test_binary_differ_arguments_view_element_types = [&](const shared_ptr<Node>& x,
150                                                                const shared_ptr<Node>& y) {
151         try
152         {
153             auto node = f(x, y);
154             // Should have thrown, so fail if it didn't
155             FAIL() << "Incompatible view arguments not detected.";
156         }
157         catch (const NodeValidationFailure& error)
158         {
159             EXPECT_HAS_SUBSTRING(error.what(),
160                                  std::string("Argument element types are inconsistent"));
161         }
162         catch (...)
163         {
164             FAIL() << "Deduced type check failed for unexpected reason";
165         }
166     };
167
168     auto test_binary_non_bool_arguments_view_element_types = [&](const shared_ptr<Node>& x,
169                                                                  const shared_ptr<Node>& y) {
170         try
171         {
172             auto node = f(x, y);
173             // Should have thrown, so fail if it didn't
174             FAIL() << "Incompatible view arguments not detected.";
175         }
176         catch (const ngraph_error& error)
177         {
178             EXPECT_HAS_SUBSTRING(error.what(), "must have boolean element type");
179         }
180         catch (...)
181         {
182             FAIL() << "Deduced type check failed for unexpected reason";
183         }
184
185     };
186
187     test_binary_differ_arguments_view_element_types(tv0_2_4_param_0, tv0_2_4_param_2);
188     test_binary_differ_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_0);
189     test_binary_non_bool_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_3);
190
191     auto test_binary_good_arguments = [&](const shared_ptr<Node>& x, const shared_ptr<Node>& y) {
192         auto node = f(x, y);
193         EXPECT_TRUE(node->has_same_type(node->input_values()[0].get_node_shared_ptr()));
194     };
195     test_binary_good_arguments(tv0_2_4_param_0, tv0_2_4_param_1);
196 }
197
198 TEST(type_prop, or_bad_arguments)
199 {
200     test_binary_logical(
201         "Or", [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
202             return make_shared<op::Or>(x, y);
203         });
204 }
205
206 TEST(type_prop, xor_bad_arguments)
207 {
208     test_binary_logical(
209         "Xor", [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
210             return make_shared<op::Xor>(x, y);
211         });
212 }
213
214 template <typename T>
215 void test_binary_eltwise_numpy(const element::Type& et, const op::AutoBroadcastSpec& autob)
216 {
217     auto param1 = make_shared<op::Parameter>(et, Shape{1, 3, 6});
218     auto param2 = make_shared<op::Parameter>(et, Shape{3, 1});
219     auto param3 = make_shared<op::Parameter>(et, Shape{2, 3, 6});
220     auto param4 = make_shared<op::Parameter>(et, Shape{6});
221     EXPECT_EQ(make_shared<T>(param1, param2, autob)->get_shape(), (Shape{1, 3, 6}));
222     EXPECT_EQ(make_shared<T>(param1, param3, autob)->get_shape(), (Shape{2, 3, 6}));
223     EXPECT_EQ(make_shared<T>(param4, param3, autob)->get_shape(), (Shape{2, 3, 6}));
224
225     auto pp1 = make_shared<op::Parameter>(et, PartialShape{1, Dimension::dynamic(), 6});
226     auto pp2 = make_shared<op::Parameter>(et, PartialShape{3, 1});
227     EXPECT_EQ(make_shared<T>(pp1, pp2, autob)->get_shape(), (Shape{1, 3, 6}));
228 }
229
230 TEST(type_prop, eltwise_auto_bcast)
231 {
232     test_binary_eltwise_numpy<op::v1::Add>(element::f32, op::AutoBroadcastType::NUMPY);
233     test_binary_eltwise_numpy<op::Divide>(element::f32, op::AutoBroadcastType::NUMPY);
234     test_binary_eltwise_numpy<op::Equal>(element::f32, op::AutoBroadcastType::NUMPY);
235     test_binary_eltwise_numpy<op::Greater>(element::f32, op::AutoBroadcastType::NUMPY);
236     test_binary_eltwise_numpy<op::GreaterEq>(element::f32, op::AutoBroadcastType::NUMPY);
237     test_binary_eltwise_numpy<op::Less>(element::f32, op::AutoBroadcastType::NUMPY);
238     test_binary_eltwise_numpy<op::LessEq>(element::f32, op::AutoBroadcastType::NUMPY);
239     test_binary_eltwise_numpy<op::Maximum>(element::f32, op::AutoBroadcastType::NUMPY);
240     test_binary_eltwise_numpy<op::Minimum>(element::f32, op::AutoBroadcastType::NUMPY);
241     test_binary_eltwise_numpy<op::Multiply>(element::f32, op::AutoBroadcastType::NUMPY);
242     test_binary_eltwise_numpy<op::NotEqual>(element::f32, op::AutoBroadcastType::NUMPY);
243     test_binary_eltwise_numpy<op::Or>(element::boolean, op::AutoBroadcastType::NUMPY);
244     test_binary_eltwise_numpy<op::Power>(element::f32, op::AutoBroadcastType::NUMPY);
245     test_binary_eltwise_numpy<op::Subtract>(element::f32, op::AutoBroadcastType::NUMPY);
246     test_binary_eltwise_numpy<op::Xor>(element::boolean, op::AutoBroadcastType::NUMPY);
247 }
248
249 TEST(type_prop, comparison_good)
250 {
251     auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
252     auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
253     auto eq = make_shared<op::Equal>(tv0_2_4_param_0, tv0_2_4_param_1);
254     EXPECT_EQ(eq->get_element_type(), element::boolean);
255     EXPECT_EQ(eq->get_shape(), (Shape{2, 4}));
256 }
257
258 TEST(type_prop, binary_arithmetic_bad_argument_element_types)
259 {
260     auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
261     auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
262     try
263     {
264         auto bc = make_shared<op::Add>(tv0_2_4_param_0, tv0_2_4_param_1);
265         // Should have thrown, so fail if it didn't
266         FAIL() << "Did not detect incorrect element types for arithmetic operator";
267     }
268     catch (const NodeValidationFailure& error)
269     {
270         EXPECT_HAS_SUBSTRING(error.what(),
271                              std::string("Arguments cannot have boolean element type"));
272     }
273     catch (...)
274     {
275         FAIL() << "Deduced type check failed for unexpected reason";
276     }
277 }
278
279 TEST(type_prop, binary_elementwise_arithmetic_both_dynamic)
280 {
281     auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
282     auto b = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
283     auto add = make_shared<op::Add>(a, b);
284
285     ASSERT_TRUE(add->get_output_partial_shape(0).rank().is_dynamic());
286 }
287
288 TEST(type_prop, binary_elementwise_arithmetic_left_rank_dynamic_right_static)
289 {
290     auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
291     auto b = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
292     auto add = make_shared<op::Add>(a, b);
293
294     ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
295     ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
296 }
297
298 TEST(type_prop, binary_elementwise_arithmetic_left_static_right_rank_dynamic)
299 {
300     auto a = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
301     auto b = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
302     auto add = make_shared<op::Add>(a, b);
303
304     ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
305     ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
306 }
307
308 TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_right_rank_dynamic)
309 {
310     auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3});
311     auto b = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
312     auto add = make_shared<op::Add>(a, b);
313
314     ASSERT_TRUE(add->get_output_partial_shape(0).rank().is_static());
315     ASSERT_TRUE(add->get_output_partial_shape(0).is_dynamic());
316     ASSERT_TRUE(
317         add->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 3}));
318 }
319
320 TEST(type_prop, binary_elementwise_arithmetic_left_rank_dynamic_right_rank_static_dynamic)
321 {
322     auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
323     auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3});
324     auto add = make_shared<op::Add>(a, b);
325
326     ASSERT_TRUE(add->get_output_partial_shape(0).rank().is_static());
327     ASSERT_TRUE(add->get_output_partial_shape(0).is_dynamic());
328     ASSERT_TRUE(
329         add->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 3}));
330 }
331
332 TEST(type_prop,
333      binary_elementwise_arithmetic_left_rank_static_dynamic_right_rank_static_dynamic_result_static)
334 {
335     auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3});
336     auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
337     auto add = make_shared<op::Add>(a, b);
338
339     ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
340     ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
341 }
342
343 TEST(
344     type_prop,
345     binary_elementwise_arithmetic_left_rank_static_dynamic_right_rank_static_dynamic_result_rank_static_dynamic)
346 {
347     auto a = make_shared<op::Parameter>(
348         element::f32, PartialShape{1, Dimension::dynamic(), Dimension::dynamic()});
349     auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
350     auto add = make_shared<op::Add>(a, b);
351
352     ASSERT_TRUE(add->get_output_partial_shape(0).rank().is_static());
353     ASSERT_TRUE(add->get_output_partial_shape(0).is_dynamic());
354     ASSERT_TRUE(
355         add->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, Dimension::dynamic()}));
356 }
357
358 TEST(type_prop, binary_elementwise_arithmetic_left_static_right_rank_static_dynamic)
359 {
360     auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
361     auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
362     auto add = make_shared<op::Add>(a, b);
363
364     ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
365     ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
366 }
367
368 TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_right_static)
369 {
370     auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
371     auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
372     auto add = make_shared<op::Add>(a, b);
373
374     ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
375     ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
376 }
377
378 TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_inconsistent)
379 {
380     auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
381     auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 3});
382
383     try
384     {
385         auto add = make_shared<op::Add>(a, b);
386         FAIL() << "Inconsistent partial shapes not detected";
387     }
388     catch (const NodeValidationFailure& error)
389     {
390         EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
391     }
392     catch (...)
393     {
394         FAIL() << "Deduced type check failed for unexpected reason";
395     }
396 }
397
398 TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_inconsistent)
399 {
400     auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 3});
401     auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
402
403     try
404     {
405         auto add = make_shared<op::Add>(a, b);
406         FAIL() << "Inconsistent partial shapes not detected";
407     }
408     catch (const NodeValidationFailure& error)
409     {
410         EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
411     }
412     catch (...)
413     {
414         FAIL() << "Deduced type check failed for unexpected reason";
415     }
416 }
417
418 TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_inconsistent)
419 {
420     auto a = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 3, 3});
421     auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
422
423     try
424     {
425         auto add = make_shared<op::Add>(a, b);
426         FAIL() << "Inconsistent partial shapes not detected";
427     }
428     catch (const NodeValidationFailure& error)
429     {
430         EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
431     }
432     catch (...)
433     {
434         FAIL() << "Deduced type check failed for unexpected reason";
435     }
436 }
437
438 TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_different_rank)
439 {
440     auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
441     auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
442
443     try
444     {
445         auto add = make_shared<op::Add>(a, b);
446         FAIL() << "Inconsistent partial shapes not detected";
447     }
448     catch (const NodeValidationFailure& error)
449     {
450         EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
451     }
452     catch (...)
453     {
454         FAIL() << "Deduced type check failed for unexpected reason";
455     }
456 }
457
458 TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_different_rank)
459 {
460     auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
461     auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
462
463     try
464     {
465         auto add = make_shared<op::Add>(a, b);
466         FAIL() << "Inconsistent partial shapes not detected";
467     }
468     catch (const NodeValidationFailure& error)
469     {
470         EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
471     }
472     catch (...)
473     {
474         FAIL() << "Deduced type check failed for unexpected reason";
475     }
476 }
477
478 TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_different_rank)
479 {
480     auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3, 4});
481     auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
482
483     try
484     {
485         auto add = make_shared<op::Add>(a, b);
486         FAIL() << "Inconsistent partial shapes not detected";
487     }
488     catch (const NodeValidationFailure& error)
489     {
490         EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
491     }
492     catch (...)
493     {
494         FAIL() << "Deduced type check failed for unexpected reason";
495     }
496 }
497
498 TEST(type_prop, binary_elementwise_arithmetic_both_et_dynamic)
499 {
500     auto a = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
501     auto b = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
502     auto add = make_shared<op::Add>(a, b);
503
504     ASSERT_TRUE(add->get_output_element_type(0).is_dynamic());
505 }
506
507 TEST(type_prop, binary_elementwise_arithmetic_left_et_dynamic)
508 {
509     auto a = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
510     auto b = make_shared<op::Parameter>(element::u32, Shape{1, 2, 3, 4});
511     auto add = make_shared<op::Add>(a, b);
512
513     ASSERT_EQ(add->get_output_element_type(0), element::u32);
514 }
515
516 TEST(type_prop, binary_elementwise_arithmetic_right_et_dynamic)
517 {
518     auto a = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
519     auto b = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
520     auto add = make_shared<op::Add>(a, b);
521
522     ASSERT_EQ(add->get_output_element_type(0), element::i64);
523 }
524
525 TEST(type_prop, logic_arith_compare_partial_et)
526 {
527     auto test_arith = [](element::Type et0, element::Type et1) -> std::shared_ptr<Node> {
528         auto param0 = std::make_shared<op::Parameter>(et0, Shape{1, 2, 3});
529         auto param1 = std::make_shared<op::Parameter>(et1, Shape{1, 2, 3});
530         return std::make_shared<op::Add>(param0, param1);
531     };
532
533     auto test_compare = [](element::Type et0, element::Type et1) -> std::shared_ptr<Node> {
534         auto param0 = std::make_shared<op::Parameter>(et0, Shape{1, 2, 3});
535         auto param1 = std::make_shared<op::Parameter>(et1, Shape{1, 2, 3});
536         return std::make_shared<op::Greater>(param0, param1);
537     };
538
539     auto test_not = [](element::Type et) -> std::shared_ptr<Node> {
540         auto param = std::make_shared<op::Parameter>(et, Shape{1, 2, 3});
541         return std::make_shared<op::Not>(param);
542     };
543
544     // Arith ops:
545     //
546     // int int -> int
547     // int boo -> !
548     // int dyn -> int
549     // boo int -> !
550     // boo boo -> !
551     // boo dyn -> !
552     // dyn int -> int
553     // dyn boo -> !
554     // dyn dyn -> dyn
555     ASSERT_EQ(test_arith(element::i32, element::i32)->get_element_type(), element::i32);
556     ASSERT_ANY_THROW({ test_arith(element::i32, element::boolean); });
557     ASSERT_EQ(test_arith(element::i32, element::dynamic)->get_element_type(), element::i32);
558     ASSERT_ANY_THROW({ test_arith(element::boolean, element::i32); });
559     ASSERT_ANY_THROW({ test_arith(element::boolean, element::boolean); });
560     ASSERT_ANY_THROW({ test_arith(element::boolean, element::dynamic); });
561     ASSERT_EQ(test_arith(element::dynamic, element::i32)->get_element_type(), element::i32);
562     ASSERT_ANY_THROW({ test_arith(element::dynamic, element::boolean); });
563     ASSERT_EQ(test_arith(element::dynamic, element::dynamic)->get_element_type(), element::dynamic);
564
565     // Comparison ops:
566     //
567     // int int -> boo
568     // int boo -> !
569     // int dyn -> boo
570     // boo int -> !
571     // boo boo -> boo
572     // boo dyn -> boo
573     // dyn int -> boo
574     // dyn boo -> boo
575     // dyn dyn -> boo
576     ASSERT_EQ(test_compare(element::i32, element::i32)->get_element_type(), element::boolean);
577     ASSERT_ANY_THROW({ test_compare(element::i32, element::boolean); });
578     ASSERT_EQ(test_compare(element::i32, element::dynamic)->get_element_type(), element::boolean);
579     ASSERT_ANY_THROW({ test_compare(element::boolean, element::i32); });
580     ASSERT_EQ(test_compare(element::boolean, element::boolean)->get_element_type(),
581               element::boolean);
582     ASSERT_EQ(test_compare(element::boolean, element::dynamic)->get_element_type(),
583               element::boolean);
584     ASSERT_EQ(test_compare(element::dynamic, element::i32)->get_element_type(), element::boolean);
585     ASSERT_EQ(test_compare(element::dynamic, element::boolean)->get_element_type(),
586               element::boolean);
587     ASSERT_EQ(test_compare(element::dynamic, element::dynamic)->get_element_type(),
588               element::boolean);
589
590     // Logical negation op:
591     //
592     // Current behavior:
593     // int -> int
594     // boo -> boo
595     // dyn -> dyn
596     //
597     // TODO(amprocte): I believe the behavior should actually be:
598     // int -> !
599     // boo -> boo
600     // dyn -> boo
601     ASSERT_EQ(test_not(element::i32)->get_element_type(), element::i32);
602     ASSERT_EQ(test_not(element::boolean)->get_element_type(), element::boolean);
603     ASSERT_EQ(test_not(element::dynamic)->get_element_type(), element::dynamic);
604 }