Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / test / runtime / pass / dyn_elimination.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <numeric>
18
19 #include "dyn_elimination.hpp"
20 #include "ngraph/op/broadcast.hpp"
21 #include "ngraph/op/range.hpp"
22 #include "ngraph/op/replace_slice.hpp"
23 #include "ngraph/op/reshape.hpp"
24 #include "ngraph/op/slice.hpp"
25 #include "ngraph/op/transpose.hpp"
26 #include "ngraph/pattern/matcher.hpp"
27 #include "ngraph/pattern/op/label.hpp"
28 #include "ngraph/runtime/reference/range.hpp"
29 #include "ngraph/slice_plan.hpp"
30
31 NGRAPH_SUPPRESS_DEPRECATED_START
32
33 using namespace std;
34 using namespace ngraph;
35
36 pass::DynElimination::DynElimination()
37     : GraphRewrite()
38 {
39     construct_transpose();
40     construct_range();
41 }
42
43 void pass::DynElimination::construct_transpose()
44 {
45     auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
46     auto perm_arg_label =
47         make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
48
49     auto transpose = make_shared<op::Transpose>(data_arg_label, perm_arg_label);
50
51     auto transpose_callback = [data_arg_label, perm_arg_label](pattern::Matcher& m) {
52         auto pattern_map = m.get_pattern_map();
53
54         auto data_arg = pattern_map[data_arg_label];
55         auto perm_arg = static_pointer_cast<op::Constant>(pattern_map[perm_arg_label]);
56
57         // TODO(amprocte): Can't handle the case where data shape is dynamic, because static
58         // Reshape requries the exact output shape to be declared. See if we can come up with a
59         // workaround.
60         if (data_arg->get_output_partial_shape(0).is_dynamic())
61         {
62             return false;
63         }
64
65         auto& data_shape = data_arg->get_output_shape(0);
66
67         NGRAPH_CHECK(perm_arg->get_output_partial_shape(0).rank().compatible(1));
68         NGRAPH_CHECK(perm_arg->get_output_element_type(0).compatible(element::i64));
69
70         if (perm_arg->get_output_element_type(0).is_dynamic() ||
71             perm_arg->get_output_partial_shape(0).is_dynamic())
72         {
73             return false;
74         }
75
76         auto perm = perm_arg->get_axis_vector_val();
77
78         auto output_shape = ngraph::apply_permutation(data_shape, perm);
79
80         auto replacement = std::make_shared<op::Reshape>(data_arg, perm, output_shape);
81
82         replace_node(m.get_match_root(), replacement);
83         return true;
84     };
85
86     auto transpose_matcher = make_shared<pattern::Matcher>(transpose, "DynElimination.Transpose");
87     NGRAPH_SUPPRESS_DEPRECATED_START
88     add_matcher(transpose_matcher, transpose_callback, all_pass_property_off);
89     NGRAPH_SUPPRESS_DEPRECATED_END
90 }
91
92 template <typename T>
93 std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
94                                                      const Shape& shape,
95                                                      const std::shared_ptr<op::Constant>& start_arg,
96                                                      const std::shared_ptr<op::Constant>& step_arg)
97 {
98     std::vector<T> elements(shape_size(shape));
99     std::vector<T> start_vec = start_arg->get_vector<T>();
100     std::vector<T> step_vec = step_arg->get_vector<T>();
101
102     NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
103
104     runtime::reference::range<T>(start_vec.data(), step_vec.data(), shape, elements.data());
105
106     return make_shared<op::Constant>(et, shape, elements);
107 }
108
109 void pass::DynElimination::construct_range()
110 {
111     auto start_arg_label =
112         make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
113     auto stop_arg_label =
114         make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
115     auto step_arg_label =
116         make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
117
118     auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label);
119
120     auto range_callback = [start_arg_label, stop_arg_label, step_arg_label](pattern::Matcher& m) {
121         auto pattern_map = m.get_pattern_map();
122
123         auto start_arg = static_pointer_cast<op::Constant>(pattern_map[start_arg_label]);
124         auto step_arg = static_pointer_cast<op::Constant>(pattern_map[step_arg_label]);
125         auto range_node = static_pointer_cast<op::Range>(m.get_match_root());
126
127         NGRAPH_CHECK(start_arg->get_output_partial_shape(0).rank().compatible(0) &&
128                      step_arg->get_output_partial_shape(0).rank().compatible(0));
129
130         auto et = range_node->get_output_element_type(0);
131         auto shape = range_node->get_output_shape(0);
132
133         std::shared_ptr<op::Constant> replacement;
134
135 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
136 #pragma GCC diagnostic push
137 #pragma GCC diagnostic error "-Wswitch"
138 #pragma GCC diagnostic error "-Wswitch-enum"
139 #endif
140         switch (et)
141         {
142         case element::Type_t::bf16:
143             replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
144             break;
145         case element::Type_t::f16:
146             replacement = make_range_replacement<float16>(et, shape, start_arg, step_arg);
147             break;
148         case element::Type_t::f32:
149             replacement = make_range_replacement<float>(et, shape, start_arg, step_arg);
150             break;
151         case element::Type_t::f64:
152             replacement = make_range_replacement<double>(et, shape, start_arg, step_arg);
153             break;
154         case element::Type_t::i8:
155             replacement = make_range_replacement<int8_t>(et, shape, start_arg, step_arg);
156             break;
157         case element::Type_t::i16:
158             replacement = make_range_replacement<int16_t>(et, shape, start_arg, step_arg);
159             break;
160         case element::Type_t::i32:
161             replacement = make_range_replacement<int32_t>(et, shape, start_arg, step_arg);
162             break;
163         case element::Type_t::i64:
164             replacement = make_range_replacement<int64_t>(et, shape, start_arg, step_arg);
165             break;
166         case element::Type_t::u8:
167             replacement = make_range_replacement<uint8_t>(et, shape, start_arg, step_arg);
168             break;
169         case element::Type_t::u16:
170             replacement = make_range_replacement<uint16_t>(et, shape, start_arg, step_arg);
171             break;
172         case element::Type_t::u32:
173             replacement = make_range_replacement<uint32_t>(et, shape, start_arg, step_arg);
174             break;
175         case element::Type_t::u64:
176             replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg);
177             break;
178         case element::Type_t::u1:
179         case element::Type_t::undefined:
180         case element::Type_t::dynamic:
181         case element::Type_t::boolean:
182             NGRAPH_CHECK(false, "Internal nGraph error: unsupported element type: ", et);
183             break;
184         }
185 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
186 #pragma GCC diagnostic pop
187 #endif
188
189         replace_node(range_node, replacement);
190         return true;
191     };
192
193     auto range_matcher = make_shared<pattern::Matcher>(range_pat, "DynElimination.Range");
194     NGRAPH_SUPPRESS_DEPRECATED_START
195     add_matcher(range_matcher, range_callback, all_pass_property_off);
196     NGRAPH_SUPPRESS_DEPRECATED_END
197 }