Move downgrade passes to pass folder (#1675)
[platform/upstream/dldt.git] / ngraph / test / dyn_elimination.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "gtest/gtest.h"
18
19 #include "ngraph/ngraph.hpp"
20 #include "ngraph/pass/constant_folding.hpp"
21 #include "ngraph/pass/dyn_elimination.hpp"
22 #include "ngraph/pass/manager.hpp"
23 #include "pass/opset0_downgrade.hpp"
24 #include "util/all_close_f.hpp"
25 #include "util/test_tools.hpp"
26
27 using namespace ngraph;
28 using namespace std;
29
30 TEST(dyn_elimination, transpose)
31 {
32     Shape shape_in{2, 4, 6, 8};
33     auto param = make_shared<op::Parameter>(element::boolean, shape_in);
34
35     auto constant_perm =
36         make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{2, 3, 1, 0});
37
38     auto transpose = make_shared<op::Transpose>(param, constant_perm);
39
40     auto f = make_shared<Function>(transpose, ParameterVector{param});
41
42     pass::Manager pass_manager;
43     pass_manager.register_pass<pass::DynElimination>();
44     pass_manager.run_passes(f);
45
46     ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 0);
47     ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 1);
48
49     auto new_reshape =
50         as_type_ptr<op::Reshape>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
51     ASSERT_TRUE(new_reshape);
52
53     ASSERT_EQ(new_reshape->get_input_order(), (AxisVector{2, 3, 1, 0}));
54     ASSERT_EQ(new_reshape->get_output_shape(0), (Shape{6, 8, 4, 2}));
55     ASSERT_EQ(new_reshape->get_output_element_type(0), element::boolean);
56 }
57
58 // For now, we can't handle the case where the input has dynamic shapes,
59 // because the classic Reshape op demands a Shape. Probably won't be able to
60 // deal with this until/unless we make a "StaticTranspose". Just make sure
61 // we don't crash or mangle the graph.
62 TEST(dyn_elimination, transpose_dyn_shape)
63 {
64     PartialShape shape_in{2, 4, Dimension::dynamic(), 8};
65
66     auto param = make_shared<op::Parameter>(element::boolean, shape_in);
67
68     auto constant_perm =
69         make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{2, 3, 1, 0});
70
71     auto transpose = make_shared<op::Transpose>(param, constant_perm);
72
73     auto f = make_shared<Function>(transpose, ParameterVector{param});
74
75     pass::Manager pass_manager;
76     pass_manager.register_pass<pass::DynElimination>();
77     pass_manager.run_passes(f);
78
79     ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 1);
80     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
81
82     auto new_transpose =
83         as_type_ptr<op::Transpose>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
84     ASSERT_TRUE(new_transpose);
85
86     ASSERT_EQ(new_transpose->get_output_element_type(0), element::boolean);
87     ASSERT_TRUE(new_transpose->get_output_partial_shape(0).relaxes(
88         PartialShape{Dimension::dynamic(), 8, 4, 2}));
89 }
90
91 TEST(dyn_elimination, range)
92 {
93     auto constant_start = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{0});
94     auto constant_stop = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{5});
95     auto constant_step = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{2});
96
97     auto range = make_shared<op::Range>(constant_start, constant_stop, constant_step);
98
99     ASSERT_EQ(range->get_element_type(), element::i64);
100     ASSERT_EQ(range->get_shape(), (Shape{3}));
101
102     auto f = make_shared<Function>(range, ParameterVector{});
103
104     pass::Manager pass_manager;
105     pass_manager.register_pass<pass::DynElimination>();
106     pass_manager.run_passes(f);
107
108     ASSERT_EQ(count_ops_of_type<op::Range>(f), 0);
109     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
110
111     auto replacement =
112         as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
113
114     ASSERT_NE(replacement, nullptr);
115     ASSERT_EQ(replacement->get_element_type(), element::i64);
116     ASSERT_EQ(replacement->get_shape(), (Shape{3}));
117
118     auto vals = replacement->get_vector<int64_t>();
119
120     ASSERT_EQ(vals, (vector<int64_t>{0, 2, 4}));
121 }
122
TEST(dyn_elimination, range_f64)
{
    // Same Range constant-folding as the i64 test above, but with f64 inputs
    // and a fractional step, to cover the floating-point folding path.
    auto constant_start = make_shared<op::Constant>(element::f64, Shape{}, vector<double>{-0.5});
    auto constant_stop = make_shared<op::Constant>(element::f64, Shape{}, vector<double>{2});
    auto constant_step = make_shared<op::Constant>(element::f64, Shape{}, vector<double>{0.25});

    auto range = make_shared<op::Range>(constant_start, constant_stop, constant_step);

    // Sanity-check static inference: [-0.5, 2) with step 0.25 -> 10 elements.
    ASSERT_EQ(range->get_element_type(), element::f64);
    ASSERT_EQ(range->get_shape(), (Shape{10}));

    auto f = make_shared<Function>(range, ParameterVector{});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::DynElimination>();
    pass_manager.run_passes(f);

    // The Range node must be replaced by exactly one Constant.
    ASSERT_EQ(count_ops_of_type<op::Range>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto replacement =
        as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());

    ASSERT_NE(replacement, nullptr);
    ASSERT_EQ(replacement->get_element_type(), element::f64);
    ASSERT_EQ(replacement->get_shape(), (Shape{10}));

    auto vals = replacement->get_vector<double>();

    // Floating-point values are compared with all_close_f rather than exact
    // equality to tolerate rounding in the folded computation.
    ASSERT_TRUE(test::all_close_f(
        vals, vector<double>{-0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75}));
154 }