Remove old transformation passes (#1463)
[platform/upstream/dldt.git] / ngraph / test / runtime / pass / dyn_elimination.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <numeric>
18
19 #include "dyn_elimination.hpp"
20 #include "ngraph/op/broadcast.hpp"
21 #include "ngraph/op/range.hpp"
22 #include "ngraph/op/replace_slice.hpp"
23 #include "ngraph/op/reshape.hpp"
24 #include "ngraph/op/reverse.hpp"
25 #include "ngraph/op/slice.hpp"
26 #include "ngraph/op/transpose.hpp"
27 #include "ngraph/pattern/matcher.hpp"
28 #include "ngraph/pattern/op/label.hpp"
29 #include "ngraph/runtime/reference/range.hpp"
30 #include "ngraph/slice_plan.hpp"
31
32 using namespace std;
33 using namespace ngraph;
34
35 pass::DynElimination::DynElimination()
36     : GraphRewrite()
37 {
38     construct_transpose();
39     construct_range();
40 }
41
42 void pass::DynElimination::construct_transpose()
43 {
44     auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
45     auto perm_arg_label =
46         make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
47
48     auto transpose = make_shared<op::Transpose>(data_arg_label, perm_arg_label);
49
50     auto transpose_callback = [data_arg_label, perm_arg_label](pattern::Matcher& m) {
51         auto pattern_map = m.get_pattern_map();
52
53         auto data_arg = pattern_map[data_arg_label];
54         auto perm_arg = static_pointer_cast<op::Constant>(pattern_map[perm_arg_label]);
55
56         // TODO(amprocte): Can't handle the case where data shape is dynamic, because static
57         // Reshape requries the exact output shape to be declared. See if we can come up with a
58         // workaround.
59         if (data_arg->get_output_partial_shape(0).is_dynamic())
60         {
61             return false;
62         }
63
64         auto& data_shape = data_arg->get_output_shape(0);
65
66         NGRAPH_CHECK(perm_arg->get_output_partial_shape(0).rank().compatible(1));
67         NGRAPH_CHECK(perm_arg->get_output_element_type(0).compatible(element::i64));
68
69         if (perm_arg->get_output_element_type(0).is_dynamic() ||
70             perm_arg->get_output_partial_shape(0).is_dynamic())
71         {
72             return false;
73         }
74
75         auto perm = perm_arg->get_axis_vector_val();
76
77         auto output_shape = ngraph::apply_permutation(data_shape, perm);
78
79         auto replacement = std::make_shared<op::Reshape>(data_arg, perm, output_shape);
80
81         replace_node(m.get_match_root(), replacement);
82         return true;
83     };
84
85     auto transpose_matcher = make_shared<pattern::Matcher>(transpose, "DynElimination.Transpose");
86     add_matcher(transpose_matcher, transpose_callback, all_pass_property_off);
87 }
88
89 template <typename T>
90 std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
91                                                      const Shape& shape,
92                                                      const std::shared_ptr<op::Constant>& start_arg,
93                                                      const std::shared_ptr<op::Constant>& step_arg)
94 {
95     std::vector<T> elements(shape_size(shape));
96     std::vector<T> start_vec = start_arg->get_vector<T>();
97     std::vector<T> step_vec = step_arg->get_vector<T>();
98
99     NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
100
101     runtime::reference::range<T>(start_vec.data(), step_vec.data(), shape, elements.data());
102
103     return make_shared<op::Constant>(et, shape, elements);
104 }
105
106 void pass::DynElimination::construct_range()
107 {
108     auto start_arg_label =
109         make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
110     auto stop_arg_label =
111         make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
112     auto step_arg_label =
113         make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
114
115     auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label);
116
117     auto range_callback = [start_arg_label, stop_arg_label, step_arg_label](pattern::Matcher& m) {
118         auto pattern_map = m.get_pattern_map();
119
120         auto start_arg = static_pointer_cast<op::Constant>(pattern_map[start_arg_label]);
121         auto step_arg = static_pointer_cast<op::Constant>(pattern_map[step_arg_label]);
122         auto range_node = static_pointer_cast<op::Range>(m.get_match_root());
123
124         NGRAPH_CHECK(start_arg->get_output_partial_shape(0).rank().compatible(0) &&
125                      step_arg->get_output_partial_shape(0).rank().compatible(0));
126
127         auto et = range_node->get_output_element_type(0);
128         auto shape = range_node->get_output_shape(0);
129
130         std::shared_ptr<op::Constant> replacement;
131
132 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
133 #pragma GCC diagnostic push
134 #pragma GCC diagnostic error "-Wswitch"
135 #pragma GCC diagnostic error "-Wswitch-enum"
136 #endif
137         switch (et)
138         {
139         case element::Type_t::bf16:
140             replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
141             break;
142         case element::Type_t::f16:
143             replacement = make_range_replacement<float16>(et, shape, start_arg, step_arg);
144             break;
145         case element::Type_t::f32:
146             replacement = make_range_replacement<float>(et, shape, start_arg, step_arg);
147             break;
148         case element::Type_t::f64:
149             replacement = make_range_replacement<double>(et, shape, start_arg, step_arg);
150             break;
151         case element::Type_t::i8:
152             replacement = make_range_replacement<int8_t>(et, shape, start_arg, step_arg);
153             break;
154         case element::Type_t::i16:
155             replacement = make_range_replacement<int16_t>(et, shape, start_arg, step_arg);
156             break;
157         case element::Type_t::i32:
158             replacement = make_range_replacement<int32_t>(et, shape, start_arg, step_arg);
159             break;
160         case element::Type_t::i64:
161             replacement = make_range_replacement<int64_t>(et, shape, start_arg, step_arg);
162             break;
163         case element::Type_t::u8:
164             replacement = make_range_replacement<uint8_t>(et, shape, start_arg, step_arg);
165             break;
166         case element::Type_t::u16:
167             replacement = make_range_replacement<uint16_t>(et, shape, start_arg, step_arg);
168             break;
169         case element::Type_t::u32:
170             replacement = make_range_replacement<uint32_t>(et, shape, start_arg, step_arg);
171             break;
172         case element::Type_t::u64:
173             replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg);
174             break;
175         case element::Type_t::u1:
176         case element::Type_t::undefined:
177         case element::Type_t::dynamic:
178         case element::Type_t::boolean:
179             NGRAPH_CHECK(false, "Internal nGraph error: unsupported element type: ", et);
180             break;
181         }
182 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
183 #pragma GCC diagnostic pop
184 #endif
185
186         replace_node(range_node, replacement);
187         return true;
188     };
189
190     auto range_matcher = make_shared<pattern::Matcher>(range_pat, "DynElimination.Range");
191     add_matcher(range_matcher, range_callback, all_pass_property_off);
192 }