Remove obsoleted Min, Max operators (#2832)
[platform/upstream/dldt.git] / ngraph / test / runtime / pass / opset1_upgrade.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 #include "opset1_upgrade.hpp"
17
18 #include <functional>
19 #include <iterator>
20 #include <limits>
21 #include <numeric>
22
23 #include "ngraph/builder/autobroadcast.hpp"
24 #include "ngraph/builder/reshape.hpp"
25 #include "ngraph/graph_util.hpp"
26 #include "ngraph/op/util/op_types.hpp"
27 #include "ngraph/ops.hpp"
28 #include "ngraph/provenance.hpp"
29 #include "op/avg_pool.hpp"
30 #include "op/convolution.hpp"
31 #include "op/group_conv.hpp"
32
33 NGRAPH_SUPPRESS_DEPRECATED_START
34
35 using namespace std;
36 using namespace ngraph;
37
38 namespace opset1_upgrade
39 {
40     template <typename OpV0, typename OpV1>
41     shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV0>& node)
42     {
43         const auto autob = node->get_autob();
44         auto replacement_node =
45             make_shared<OpV1>(node->input_value(0), node->input_value(1), autob);
46         replace_node(node, replacement_node);
47         return replacement_node;
48     }
49
    // Fallback overload: no upgrade is registered for this node type, so we
    // did nothing — returning nullptr tells the dispatcher "not modified".
    shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
52     shared_ptr<Node> op_cast(shared_ptr<op::Add> node)
53     {
54         return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
55     }
56
57     shared_ptr<Node> op_cast(shared_ptr<op::v0::Convolution> node)
58     {
59         auto strides = node->get_window_movement_strides();
60         auto dilations = node->get_window_dilation_strides();
61         auto pads_begin = node->get_padding_below();
62         auto pads_end = node->get_padding_above();
63         auto data_dilation_strides = node->get_data_dilation_strides();
64         auto auto_pad = node->get_pad_type();
65
66         bool is_dds_valid = all_of(data_dilation_strides.begin(),
67                                    data_dilation_strides.end(),
68                                    [](size_t value) { return value == 1; });
69
70         NGRAPH_CHECK(is_dds_valid,
71                      "Unable to convert Convolution:0 to Convolution:1 with data dilation strides "
72                      "other than `1`. Node: ",
73                      *node);
74
75         auto replacement_node = make_shared<op::v1::Convolution>(node->input_value(0),
76                                                                  node->input_value(1),
77                                                                  strides,
78                                                                  pads_begin,
79                                                                  pads_end,
80                                                                  dilations,
81                                                                  auto_pad);
82         replace_node(node, replacement_node);
83         return replacement_node;
84     }
85
    // v0::ConvolutionBackpropData -> v1::ConvolutionBackpropData.
    // NOTE(review): per the inline comments below, the two tensor inputs swap
    // position between opsets (v0 input 1 becomes v1's data, v0 input 0 its
    // filters) — confirm against the v0/v1 op specifications.
    shared_ptr<Node> op_cast(shared_ptr<op::v0::ConvolutionBackpropData> node)
    {
        auto data_batch_shape = node->get_data_batch_shape();
        auto strides = node->get_window_movement_strides_forward();
        auto dilations = node->get_window_dilation_strides_forward();
        auto pads_begin = node->get_padding_below_forward();
        auto pads_end = node->get_padding_above_forward();
        auto data_dilation_strides = node->get_data_dilation_strides_forward();

        // v1 has no data-dilation attribute, so every stride must be 1.
        bool is_dds_valid = all_of(data_dilation_strides.begin(),
                                   data_dilation_strides.end(),
                                   [](size_t value) { return value == 1; });

        NGRAPH_CHECK(is_dds_valid,
                     "Unable to convert ConvolutionBackpropData:0 to ConvolutionBackpropData:1 "
                     "with data dilation strides "
                     "other than `1`. Node: ",
                     *node);

        // The output-shape constant keeps only the spatial dimensions: the
        // leading N and C entries are dropped via `begin() + 2`.
        auto replacement_node = make_shared<op::v1::ConvolutionBackpropData>(
            node->input_value(1), // data
            node->input_value(0), // filters
            op::Constant::create(
                element::i64,
                Shape{data_batch_shape.size() - 2},
                vector<size_t>(data_batch_shape.begin() + 2, data_batch_shape.end())),
            strides,
            pads_begin,
            pads_end,
            dilations);
        replace_node(node, replacement_node);
        return replacement_node;
    }
119
120     shared_ptr<Node> op_cast(shared_ptr<op::Divide> node)
121     {
122         const auto autob = node->get_autob();
123         const bool pydiv = node->is_pythondiv();
124         auto replacement_node =
125             make_shared<op::v1::Divide>(node->input_value(0), node->input_value(1), pydiv, autob);
126         replace_node(node, replacement_node);
127         return replacement_node;
128     }
129
130     shared_ptr<Node> op_cast(shared_ptr<op::Reshape> node)
131     {
132         shared_ptr<Node> replacement_node =
133             builder::opset1::reshape(node->input_value(0), node->get_reshape_output_shape());
134         replace_node(node, replacement_node);
135         return replacement_node;
136     }
137
138     shared_ptr<Node> op_cast(shared_ptr<op::Equal> node)
139     {
140         return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
141     }
142
143     shared_ptr<Node> op_cast(shared_ptr<op::Greater> node)
144     {
145         return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
146     }
147
148     shared_ptr<Node> op_cast(shared_ptr<op::GreaterEq> node)
149     {
150         return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
151     }
152
153     shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolution> node)
154     {
155         auto strides = node->get_window_movement_strides();
156         auto dilations = node->get_window_dilation_strides();
157         auto pads_begin = node->get_padding_below();
158         auto pads_end = node->get_padding_above();
159         auto data_dilation_strides = node->get_data_dilation_strides();
160         auto auto_pad = node->get_pad_type();
161
162         bool is_dds_valid = all_of(data_dilation_strides.begin(),
163                                    data_dilation_strides.end(),
164                                    [](size_t value) { return value == 1; });
165
166         NGRAPH_CHECK(is_dds_valid,
167                      "Unable to convert GroupConvolution:0 to GroupConvolution:1"
168                      "with data dilation strides other than `1`. Node: ",
169                      *node);
170
171         shared_ptr<Node> replacement_node;
172         if (node->has_groups_in_filters())
173         {
174             replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
175                                                                      node->input_value(1),
176                                                                      strides,
177                                                                      pads_begin,
178                                                                      pads_end,
179                                                                      dilations,
180                                                                      auto_pad);
181         }
182         else
183         {
184             NGRAPH_CHECK(node->get_input_partial_shape(1).is_static(),
185                          "Unable to convert GroupConvolution:0 to GroupConvolution:1"
186                          "with dynamic filters shape. Node: ",
187                          *node);
188
189             auto filters_shape = node->get_input_shape(1);
190             auto groups = node->get_groups();
191             filters_shape[0] /= groups;
192             filters_shape.insert(filters_shape.begin(), groups);
193
194             auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);
195
196             replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
197                                                                      reshaped_filters,
198                                                                      strides,
199                                                                      pads_begin,
200                                                                      pads_end,
201                                                                      dilations,
202                                                                      auto_pad);
203         }
204         replace_node(node, replacement_node);
205         return replacement_node;
206     }
207
    // v0::GroupConvolutionBackpropData -> v1::GroupConvolutionBackpropData.
    // Requires static data_batch and filters shapes: the data_batch shape is
    // turned into a constant output-shape input (spatial dims only), and the
    // filters are reshaped to carry the group count as a leading dimension.
    shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolutionBackpropData> node)
    {
        const auto strides = node->get_window_movement_strides();
        const auto dilations = node->get_window_dilation_strides();
        const auto pads_begin = node->get_padding_below();
        const auto pads_end = node->get_padding_above();

        const auto data_batch_pshape = node->get_input_partial_shape(0);
        const auto filters_pshape = node->get_input_partial_shape(1);

        NGRAPH_CHECK(data_batch_pshape.is_static(),
                     "Unable to convert GroupConvolutionBackpropData:0 to "
                     "GroupConvolutionBackpropData:1 with dynamic data_batch shape. Node: ",
                     *node);
        NGRAPH_CHECK(filters_pshape.is_static(),
                     "Unable to convert GroupConvolutionBackpropData:0 to "
                     "GroupConvolutionBackpropData:1 with dynamic filters shape. Node: ",
                     *node);

        auto data_batch_shape = data_batch_pshape.to_shape();
        // Remove N, C from output shape to preserve only spatial dimensions.
        data_batch_shape.erase(std::begin(data_batch_shape),
                               std::next(std::begin(data_batch_shape), 2));
        auto filters_shape = filters_pshape.to_shape();
        auto groups = node->get_groups();

        // [C_OUT, ...] -> [GROUPS, C_OUT / GROUPS, ...] so v1 can read the
        // group count from the filters tensor instead of an attribute.
        filters_shape[0] /= groups;
        filters_shape.insert(filters_shape.begin(), groups);
        auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);

        // NOTE(review): input_value(2) of the v0 op is forwarded as the v1
        // data input — presumably the output-delta tensor; confirm against
        // the v0 op's input ordering.
        auto replacement_node = make_shared<op::v1::GroupConvolutionBackpropData>(
            node->input_value(2),
            reshaped_filters,
            op::Constant::create(element::i64, Shape{data_batch_shape.size()}, data_batch_shape),
            strides,
            pads_begin,
            pads_end,
            dilations);
        replace_node(node, replacement_node);
        return replacement_node;
    }
249
250     shared_ptr<Node> op_cast(shared_ptr<op::Less> node)
251     {
252         return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
253     }
254
255     shared_ptr<Node> op_cast(shared_ptr<op::LessEq> node)
256     {
257         return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
258     }
259
260     shared_ptr<Node> op_cast(shared_ptr<op::Maximum> node)
261     {
262         return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
263     }
264
265     shared_ptr<Node> op_cast(shared_ptr<op::Minimum> node)
266     {
267         return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
268     }
269
270     shared_ptr<Node> op_cast(shared_ptr<op::Multiply> node)
271     {
272         return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
273     }
274
275     shared_ptr<Node> op_cast(shared_ptr<op::Not> node)
276     {
277         auto replacement_node = make_shared<op::v1::LogicalNot>(node->input_value(0));
278         replace_node(node, replacement_node);
279         return replacement_node;
280     }
281
282     shared_ptr<Node> op_cast(shared_ptr<op::NotEqual> node)
283     {
284         return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
285     }
286
287     shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
288     {
289         return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
290     }
291
292     shared_ptr<Node> op_cast(shared_ptr<op::Power> node)
293     {
294         return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
295     }
296
297     shared_ptr<Node> op_cast(shared_ptr<op::Reverse> node)
298     {
299         // creates a Constant node from the v0::Reverse reversed_axes attribute
300         // and uses it as the second input of v1::Reverse
301         const auto reversed_axes = node->get_reversed_axes();
302
303         const auto reversed_axes_constant = op::Constant::create(
304             element::i64, Shape{reversed_axes.size()}, reversed_axes.to_vector());
305
306         const auto replacement_node = make_shared<op::v1::Reverse>(
307             node->input_value(0), reversed_axes_constant, op::v1::Reverse::Mode::INDEX);
308
309         replace_node(node, replacement_node);
310         return replacement_node;
311     }
312
313     shared_ptr<Node> op_cast(shared_ptr<op::Select> node)
314     {
315         auto replacement_node = make_shared<op::v1::Select>(node->input_value(0),
316                                                             node->input_value(1),
317                                                             node->input_value(2),
318                                                             op::AutoBroadcastSpec());
319         replace_node(node, replacement_node);
320         return replacement_node;
321     }
322
    // v0::Softmax (axes as a set, read from input 1) -> v1::Softmax (single
    // axis attribute). Only a single, statically-known axis is representable.
    shared_ptr<Node> op_cast(shared_ptr<op::Softmax> node)
    {
        // The axes must be a compile-time constant because v1 stores the axis
        // as an attribute, not an input.
        NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
                     "axes parameter is expected to be a static constant");

        AxisSet axes = node->get_axes();

        // v1::Softmax normalizes along exactly one axis.
        NGRAPH_CHECK(
            axes.size() == 1,
            "Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
            *node);

        auto replacement_node =
            make_shared<op::v1::Softmax>(node->input_value(0), axes.to_vector()[0]);
        replace_node(node, replacement_node);
        return replacement_node;
    }
340
341     shared_ptr<Node> op_cast(shared_ptr<op::Slice> node)
342     {
343         const auto data = node->input_value(0);
344         const auto begin = op::Constant::create(
345             element::i64, Shape{node->get_lower_bounds().size()}, node->get_lower_bounds());
346         const auto end = op::Constant::create(
347             element::i64, Shape{node->get_upper_bounds().size()}, node->get_upper_bounds());
348         const auto strides = op::Constant::create(
349             element::i64, Shape{node->get_strides().size()}, node->get_strides());
350         int64_t input_size = node->get_lower_bounds().size();
351
352         auto replacement_node = make_shared<op::v1::StridedSlice>(data,
353                                                                   begin,
354                                                                   end,
355                                                                   strides,
356                                                                   vector<int64_t>(input_size, 0),
357                                                                   vector<int64_t>(input_size, 0));
358
359         replace_node(node, replacement_node);
360         return replacement_node;
361     }
362
363     shared_ptr<Node> op_cast(shared_ptr<op::Split> node)
364     {
365         const auto& splits_vec = node->get_splits();
366         const auto first_elem = splits_vec.front();
367
368         const bool split_evenly =
369             std::all_of(splits_vec.begin(), splits_vec.end(), [first_elem](const size_t split) {
370                 return split == first_elem;
371             });
372
373         std::shared_ptr<Node> replacement_node;
374         if (split_evenly)
375         {
376             replacement_node = make_shared<op::v1::Split>(
377                 node->input_value(0), node->input_value(1), splits_vec.front());
378         }
379         else
380         {
381             const auto split_lengths =
382                 ngraph::op::Constant::create(element::u64, Shape{splits_vec.size()}, splits_vec);
383
384             replacement_node = make_shared<op::v1::VariadicSplit>(
385                 node->input_value(0), node->input_value(1), split_lengths);
386         }
387
388         replace_node(node, replacement_node);
389         return replacement_node;
390     }
391
392     shared_ptr<Node> op_cast(shared_ptr<op::Subtract> node)
393     {
394         return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
395     }
396
397     shared_ptr<Node> op_cast(shared_ptr<op::Sum> node)
398     {
399         bool keep_dims = false;
400         auto replacement_node =
401             make_shared<op::v1::ReduceSum>(node->input_value(0), node->input_value(1), keep_dims);
402         replace_node(node, replacement_node);
403         return replacement_node;
404     }
405
    // v0::TopK (k/axis as constant inputs, enum sort, bool max) -> v1::TopK
    // (scalar k input; axis, string mode and string sort as attributes).
    shared_ptr<Node> op_cast(shared_ptr<op::TopK> node)
    {
        // k (input 1) and top_k_axis (input 2) must be compile-time constants
        // because v1 encodes the axis as an attribute.
        NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
                     "parameter k is expected to be a static constant");
        NGRAPH_CHECK(op::is_constant(node->input_value(2).get_node()),
                     "parameter top_k_axis is expected to be a static constant");

        const auto k = node->get_k();
        const auto axis = node->get_top_k_axis();

        // Translate the v0 SortType enum into v1's string attribute.
        std::string sort;
        switch (node->get_sort())
        {
        case op::TopK::SortType::SORT_INDICES: sort = "index"; break;
        case op::TopK::SortType::SORT_VALUES: sort = "value"; break;
        case op::TopK::SortType::NONE: sort = "none"; break;
        }

        // max/min selection: bool in v0, string mode in v1.
        std::string mode;
        if (node->get_compute_max())
        {
            mode = "max";
        }
        else
        {
            mode = "min";
        }

        // v1 takes k as a scalar i64 constant input.
        const auto k_constant = op::Constant::create(element::i64, Shape{}, {k});
        auto replacement_node =
            make_shared<op::v1::TopK>(node->input_value(0), k_constant, axis, mode, sort);

        // The two opsets expose their outputs in opposite order; output_order
        // remaps consumers accordingly when rewiring.
        // indices output will be 0, values 1
        vector<int64_t> output_order{1, 0};
        replace_node(node, replacement_node, output_order);
        return replacement_node;
    }
443
444     shared_ptr<Node> op_cast(shared_ptr<op::Xor> node)
445     {
446         auto replacement_node = make_shared<op::v1::LogicalXor>(
447             node->input_value(0), node->input_value(1), node->get_autob());
448         replace_node(node, replacement_node);
449         return replacement_node;
450     }
451
    // Maps a node's static type info to the upgrade thunk that handles it.
    using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;

    // Downcasts `node` to the concrete v0 type T and dispatches to the
    // matching op_cast overload. Returns true iff a replacement was produced
    // (the fallback overload returns nullptr, yielding false here).
    template <typename T>
    bool op_cast_thunk(shared_ptr<Node> node)
    {
        auto upgraded_node = op_cast(as_type_ptr<T>(node));
        if (upgraded_node)
        {
            // Tag the upgraded subgraph so it stays traceable back to the
            // original v0 op when provenance tracking is enabled.
            if (ngraph::get_provenance_enabled())
            {
                const std::string provenance_tag =
                    "<Opset1_Upgrade (v0 " + std::string(node->get_type_name()) + ")>";
                upgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
            }
            return true;
        }
        return false;
    }
470
    // Returns the lazily-initialized (function-local static) dispatch table,
    // stamping out one {type_info, thunk} entry per op listed in
    // opset0_tbl.hpp via the NGRAPH_OP x-macro.
    DispatchMap& get_dispatch_map()
    {
        NGRAPH_SUPPRESS_DEPRECATED_START
        static DispatchMap dispatch_map{
#define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
#include "opset0_tbl.hpp"
#undef NGRAPH_OP
        };
        return dispatch_map;
        // NOTE(review): this END marker sits after the return statement —
        // presumably it expands to diagnostic pragmas only, so placement is
        // harmless; confirm against the macro's definition.
        NGRAPH_SUPPRESS_DEPRECATED_END
    }
482 } // namespace opset1_upgrade
483
484 bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
485 {
486     bool modified = false;
487     auto& dispatch_map = opset1_upgrade::get_dispatch_map();
488     auto it = dispatch_map.find(node->get_type_info());
489     if (it != dispatch_map.end())
490     {
491         modified = it->second(node);
492     }
493     return modified;
494 }