Enable NGRAPH_DEPRECATED (#1617)
[platform/upstream/dldt.git] / ngraph / test / runtime / pass / dyn_elimination.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <numeric>
18
19 #include "dyn_elimination.hpp"
20 #include "ngraph/op/broadcast.hpp"
21 #include "ngraph/op/range.hpp"
22 #include "ngraph/op/replace_slice.hpp"
23 #include "ngraph/op/reshape.hpp"
24 #include "ngraph/op/reverse.hpp"
25 #include "ngraph/op/slice.hpp"
26 #include "ngraph/op/transpose.hpp"
27 #include "ngraph/pattern/matcher.hpp"
28 #include "ngraph/pattern/op/label.hpp"
29 #include "ngraph/runtime/reference/range.hpp"
30 #include "ngraph/slice_plan.hpp"
31
32 using namespace std;
33 using namespace ngraph;
34
35 pass::DynElimination::DynElimination()
36     : GraphRewrite()
37 {
38     construct_transpose();
39     construct_range();
40 }
41
42 void pass::DynElimination::construct_transpose()
43 {
44     auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
45     auto perm_arg_label =
46         make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
47
48     auto transpose = make_shared<op::Transpose>(data_arg_label, perm_arg_label);
49
50     auto transpose_callback = [data_arg_label, perm_arg_label](pattern::Matcher& m) {
51         auto pattern_map = m.get_pattern_map();
52
53         auto data_arg = pattern_map[data_arg_label];
54         auto perm_arg = static_pointer_cast<op::Constant>(pattern_map[perm_arg_label]);
55
56         // TODO(amprocte): Can't handle the case where data shape is dynamic, because static
57         // Reshape requries the exact output shape to be declared. See if we can come up with a
58         // workaround.
59         if (data_arg->get_output_partial_shape(0).is_dynamic())
60         {
61             return false;
62         }
63
64         auto& data_shape = data_arg->get_output_shape(0);
65
66         NGRAPH_CHECK(perm_arg->get_output_partial_shape(0).rank().compatible(1));
67         NGRAPH_CHECK(perm_arg->get_output_element_type(0).compatible(element::i64));
68
69         if (perm_arg->get_output_element_type(0).is_dynamic() ||
70             perm_arg->get_output_partial_shape(0).is_dynamic())
71         {
72             return false;
73         }
74
75         auto perm = perm_arg->get_axis_vector_val();
76
77         auto output_shape = ngraph::apply_permutation(data_shape, perm);
78
79         auto replacement = std::make_shared<op::Reshape>(data_arg, perm, output_shape);
80
81         replace_node(m.get_match_root(), replacement);
82         return true;
83     };
84
85     auto transpose_matcher = make_shared<pattern::Matcher>(transpose, "DynElimination.Transpose");
86     NGRAPH_SUPPRESS_DEPRECATED_START
87     add_matcher(transpose_matcher, transpose_callback, all_pass_property_off);
88     NGRAPH_SUPPRESS_DEPRECATED_END
89 }
90
91 template <typename T>
92 std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
93                                                      const Shape& shape,
94                                                      const std::shared_ptr<op::Constant>& start_arg,
95                                                      const std::shared_ptr<op::Constant>& step_arg)
96 {
97     std::vector<T> elements(shape_size(shape));
98     std::vector<T> start_vec = start_arg->get_vector<T>();
99     std::vector<T> step_vec = step_arg->get_vector<T>();
100
101     NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
102
103     runtime::reference::range<T>(start_vec.data(), step_vec.data(), shape, elements.data());
104
105     return make_shared<op::Constant>(et, shape, elements);
106 }
107
108 void pass::DynElimination::construct_range()
109 {
110     auto start_arg_label =
111         make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
112     auto stop_arg_label =
113         make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
114     auto step_arg_label =
115         make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
116
117     auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label);
118
119     auto range_callback = [start_arg_label, stop_arg_label, step_arg_label](pattern::Matcher& m) {
120         auto pattern_map = m.get_pattern_map();
121
122         auto start_arg = static_pointer_cast<op::Constant>(pattern_map[start_arg_label]);
123         auto step_arg = static_pointer_cast<op::Constant>(pattern_map[step_arg_label]);
124         auto range_node = static_pointer_cast<op::Range>(m.get_match_root());
125
126         NGRAPH_CHECK(start_arg->get_output_partial_shape(0).rank().compatible(0) &&
127                      step_arg->get_output_partial_shape(0).rank().compatible(0));
128
129         auto et = range_node->get_output_element_type(0);
130         auto shape = range_node->get_output_shape(0);
131
132         std::shared_ptr<op::Constant> replacement;
133
134 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
135 #pragma GCC diagnostic push
136 #pragma GCC diagnostic error "-Wswitch"
137 #pragma GCC diagnostic error "-Wswitch-enum"
138 #endif
139         switch (et)
140         {
141         case element::Type_t::bf16:
142             replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
143             break;
144         case element::Type_t::f16:
145             replacement = make_range_replacement<float16>(et, shape, start_arg, step_arg);
146             break;
147         case element::Type_t::f32:
148             replacement = make_range_replacement<float>(et, shape, start_arg, step_arg);
149             break;
150         case element::Type_t::f64:
151             replacement = make_range_replacement<double>(et, shape, start_arg, step_arg);
152             break;
153         case element::Type_t::i8:
154             replacement = make_range_replacement<int8_t>(et, shape, start_arg, step_arg);
155             break;
156         case element::Type_t::i16:
157             replacement = make_range_replacement<int16_t>(et, shape, start_arg, step_arg);
158             break;
159         case element::Type_t::i32:
160             replacement = make_range_replacement<int32_t>(et, shape, start_arg, step_arg);
161             break;
162         case element::Type_t::i64:
163             replacement = make_range_replacement<int64_t>(et, shape, start_arg, step_arg);
164             break;
165         case element::Type_t::u8:
166             replacement = make_range_replacement<uint8_t>(et, shape, start_arg, step_arg);
167             break;
168         case element::Type_t::u16:
169             replacement = make_range_replacement<uint16_t>(et, shape, start_arg, step_arg);
170             break;
171         case element::Type_t::u32:
172             replacement = make_range_replacement<uint32_t>(et, shape, start_arg, step_arg);
173             break;
174         case element::Type_t::u64:
175             replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg);
176             break;
177         case element::Type_t::u1:
178         case element::Type_t::undefined:
179         case element::Type_t::dynamic:
180         case element::Type_t::boolean:
181             NGRAPH_CHECK(false, "Internal nGraph error: unsupported element type: ", et);
182             break;
183         }
184 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
185 #pragma GCC diagnostic pop
186 #endif
187
188         replace_node(range_node, replacement);
189         return true;
190     };
191
192     auto range_matcher = make_shared<pattern::Matcher>(range_pat, "DynElimination.Range");
193     NGRAPH_SUPPRESS_DEPRECATED_START
194     add_matcher(range_matcher, range_callback, all_pass_property_off);
195     NGRAPH_SUPPRESS_DEPRECATED_END
196 }