1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 #include "opset1_upgrade.hpp"
23 #include "ngraph/builder/autobroadcast.hpp"
24 #include "ngraph/builder/reshape.hpp"
25 #include "ngraph/graph_util.hpp"
26 #include "ngraph/op/util/op_types.hpp"
27 #include "ngraph/ops.hpp"
28 #include "ngraph/provenance.hpp"
29 #include "op/avg_pool.hpp"
30 #include "op/convolution.hpp"
31 #include "op/group_conv.hpp"
33 NGRAPH_SUPPRESS_DEPRECATED_START
36 using namespace ngraph;
38 namespace opset1_upgrade
40 template <typename OpV0, typename OpV1>
41 shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV0>& node)
43 const auto autob = node->get_autob();
44 auto replacement_node =
45 make_shared<OpV1>(node->input_value(0), node->input_value(1), autob);
46 replace_node(node, replacement_node);
47 return replacement_node;
// Default is that we did nothing
51 shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
52 shared_ptr<Node> op_cast(shared_ptr<op::Add> node)
54 return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
57 shared_ptr<Node> op_cast(shared_ptr<op::v0::Convolution> node)
59 auto strides = node->get_window_movement_strides();
60 auto dilations = node->get_window_dilation_strides();
61 auto pads_begin = node->get_padding_below();
62 auto pads_end = node->get_padding_above();
63 auto data_dilation_strides = node->get_data_dilation_strides();
64 auto auto_pad = node->get_pad_type();
66 bool is_dds_valid = all_of(data_dilation_strides.begin(),
67 data_dilation_strides.end(),
68 [](size_t value) { return value == 1; });
70 NGRAPH_CHECK(is_dds_valid,
71 "Unable to convert Convolution:0 to Convolution:1 with data dilation strides "
72 "other than `1`. Node: ",
75 auto replacement_node = make_shared<op::v1::Convolution>(node->input_value(0),
82 replace_node(node, replacement_node);
83 return replacement_node;
86 shared_ptr<Node> op_cast(shared_ptr<op::v0::ConvolutionBackpropData> node)
88 auto data_batch_shape = node->get_data_batch_shape();
89 auto strides = node->get_window_movement_strides_forward();
90 auto dilations = node->get_window_dilation_strides_forward();
91 auto pads_begin = node->get_padding_below_forward();
92 auto pads_end = node->get_padding_above_forward();
93 auto data_dilation_strides = node->get_data_dilation_strides_forward();
95 bool is_dds_valid = all_of(data_dilation_strides.begin(),
96 data_dilation_strides.end(),
97 [](size_t value) { return value == 1; });
99 NGRAPH_CHECK(is_dds_valid,
100 "Unable to convert ConvolutionBackpropData:0 to ConvolutionBackpropData:1 "
101 "with data dilation strides "
102 "other than `1`. Node: ",
105 auto replacement_node = make_shared<op::v1::ConvolutionBackpropData>(
106 node->input_value(1), // data
107 node->input_value(0), // filters
108 op::Constant::create(
110 Shape{data_batch_shape.size() - 2},
111 vector<size_t>(data_batch_shape.begin() + 2, data_batch_shape.end())),
116 replace_node(node, replacement_node);
117 return replacement_node;
120 shared_ptr<Node> op_cast(shared_ptr<op::Divide> node)
122 const auto autob = node->get_autob();
123 const bool pydiv = node->is_pythondiv();
124 auto replacement_node =
125 make_shared<op::v1::Divide>(node->input_value(0), node->input_value(1), pydiv, autob);
126 replace_node(node, replacement_node);
127 return replacement_node;
130 shared_ptr<Node> op_cast(shared_ptr<op::Reshape> node)
132 shared_ptr<Node> replacement_node =
133 builder::opset1::reshape(node->input_value(0), node->get_reshape_output_shape());
134 replace_node(node, replacement_node);
135 return replacement_node;
138 shared_ptr<Node> op_cast(shared_ptr<op::Equal> node)
140 return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
143 shared_ptr<Node> op_cast(shared_ptr<op::Gather> node)
145 int64_t axis = node->get_axis();
147 auto axis_node = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{axis});
148 auto replacement_node =
149 make_shared<op::v1::Gather>(node->input_value(0), node->input_value(1), axis_node);
150 replace_node(node, replacement_node);
151 return replacement_node;
154 shared_ptr<Node> op_cast(shared_ptr<op::Greater> node)
156 return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
159 shared_ptr<Node> op_cast(shared_ptr<op::GreaterEq> node)
161 return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
164 shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolution> node)
166 auto strides = node->get_window_movement_strides();
167 auto dilations = node->get_window_dilation_strides();
168 auto pads_begin = node->get_padding_below();
169 auto pads_end = node->get_padding_above();
170 auto data_dilation_strides = node->get_data_dilation_strides();
171 auto auto_pad = node->get_pad_type();
173 bool is_dds_valid = all_of(data_dilation_strides.begin(),
174 data_dilation_strides.end(),
175 [](size_t value) { return value == 1; });
177 NGRAPH_CHECK(is_dds_valid,
178 "Unable to convert GroupConvolution:0 to GroupConvolution:1"
179 "with data dilation strides other than `1`. Node: ",
182 shared_ptr<Node> replacement_node;
183 if (node->has_groups_in_filters())
185 replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
186 node->input_value(1),
195 NGRAPH_CHECK(node->get_input_partial_shape(1).is_static(),
196 "Unable to convert GroupConvolution:0 to GroupConvolution:1"
197 "with dynamic filters shape. Node: ",
200 auto filters_shape = node->get_input_shape(1);
201 auto groups = node->get_groups();
202 filters_shape[0] /= groups;
203 filters_shape.insert(filters_shape.begin(), groups);
205 auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);
207 replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
215 replace_node(node, replacement_node);
216 return replacement_node;
219 shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolutionBackpropData> node)
221 const auto strides = node->get_window_movement_strides();
222 const auto dilations = node->get_window_dilation_strides();
223 const auto pads_begin = node->get_padding_below();
224 const auto pads_end = node->get_padding_above();
226 const auto data_batch_pshape = node->get_input_partial_shape(0);
227 const auto filters_pshape = node->get_input_partial_shape(1);
229 NGRAPH_CHECK(data_batch_pshape.is_static(),
230 "Unable to convert GroupConvolutionBackpropData:0 to "
231 "GroupConvolutionBackpropData:1 with dynamic data_batch shape. Node: ",
233 NGRAPH_CHECK(filters_pshape.is_static(),
234 "Unable to convert GroupConvolutionBackpropData:0 to "
235 "GroupConvolutionBackpropData:1 with dynamic filters shape. Node: ",
238 auto data_batch_shape = data_batch_pshape.to_shape();
239 // Remove N, C from output shape to preserve only spatial dimentions.
240 data_batch_shape.erase(std::begin(data_batch_shape),
241 std::next(std::begin(data_batch_shape), 2));
242 auto filters_shape = filters_pshape.to_shape();
243 auto groups = node->get_groups();
245 filters_shape[0] /= groups;
246 filters_shape.insert(filters_shape.begin(), groups);
247 auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);
249 auto replacement_node = make_shared<op::v1::GroupConvolutionBackpropData>(
250 node->input_value(2),
252 op::Constant::create(element::i64, Shape{data_batch_shape.size()}, data_batch_shape),
257 replace_node(node, replacement_node);
258 return replacement_node;
261 shared_ptr<Node> op_cast(shared_ptr<op::Less> node)
263 return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
266 shared_ptr<Node> op_cast(shared_ptr<op::LessEq> node)
268 return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
271 shared_ptr<Node> op_cast(shared_ptr<op::Max> node)
273 bool keep_dims = false;
274 auto replacement_node =
275 make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims);
276 replace_node(node, replacement_node);
277 return replacement_node;
280 shared_ptr<Node> op_cast(shared_ptr<op::Maximum> node)
282 return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
285 shared_ptr<Node> op_cast(shared_ptr<op::Min> node)
287 bool keep_dims = false;
288 auto replacement_node =
289 make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims);
290 replace_node(node, replacement_node);
291 return replacement_node;
294 shared_ptr<Node> op_cast(shared_ptr<op::Minimum> node)
296 return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
299 shared_ptr<Node> op_cast(shared_ptr<op::Multiply> node)
301 return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
304 shared_ptr<Node> op_cast(shared_ptr<op::Not> node)
306 auto replacement_node = make_shared<op::v1::LogicalNot>(node->input_value(0));
307 replace_node(node, replacement_node);
308 return replacement_node;
311 shared_ptr<Node> op_cast(shared_ptr<op::NotEqual> node)
313 return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
316 shared_ptr<Node> op_cast(shared_ptr<op::OneHot> node)
318 const auto indices = node->input_value(0).get_node_shared_ptr();
319 const auto one_hot_axis = node->get_one_hot_axis();
321 const auto output_pshape = node->get_output_partial_shape(0);
322 NGRAPH_CHECK(output_pshape[one_hot_axis].is_static(),
323 "OneHot:v0 one hot axis dimension must be static ",
325 const auto depth = output_pshape[one_hot_axis].get_length();
326 const auto depth_node = op::Constant::create(element::i64, Shape{}, {depth});
328 const auto on_value = op::Constant::create(element::i64, Shape{}, {1});
329 const auto off_value = op::Constant::create(element::i64, Shape{}, {0});
331 auto replacement_node =
332 make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis);
333 replace_node(node, replacement_node);
334 return replacement_node;
337 shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
339 return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
342 shared_ptr<Node> op_cast(shared_ptr<op::Power> node)
344 return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
347 shared_ptr<Node> op_cast(shared_ptr<op::Product> node)
349 bool keep_dims = false;
350 auto replacement_node =
351 make_shared<op::v1::ReduceProd>(node->input_value(0), node->input_value(1), keep_dims);
352 replace_node(node, replacement_node);
353 return replacement_node;
356 shared_ptr<Node> op_cast(shared_ptr<op::Reverse> node)
358 // creates a Constant node from the v0::Reverse reversed_axes attribute
359 // and uses it as the second input of v1::Reverse
360 const auto reversed_axes = node->get_reversed_axes();
362 const auto reversed_axes_constant = op::Constant::create(
363 element::i64, Shape{reversed_axes.size()}, reversed_axes.to_vector());
365 const auto replacement_node = make_shared<op::v1::Reverse>(
366 node->input_value(0), reversed_axes_constant, op::v1::Reverse::Mode::INDEX);
368 replace_node(node, replacement_node);
369 return replacement_node;
372 shared_ptr<Node> op_cast(shared_ptr<op::Select> node)
374 auto replacement_node = make_shared<op::v1::Select>(node->input_value(0),
375 node->input_value(1),
376 node->input_value(2),
377 op::AutoBroadcastSpec());
378 replace_node(node, replacement_node);
379 return replacement_node;
382 shared_ptr<Node> op_cast(shared_ptr<op::Softmax> node)
384 NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
385 "axes parameter is expected to be a static constant");
387 AxisSet axes = node->get_axes();
391 "Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
394 auto replacement_node =
395 make_shared<op::v1::Softmax>(node->input_value(0), axes.to_vector()[0]);
396 replace_node(node, replacement_node);
397 return replacement_node;
400 shared_ptr<Node> op_cast(shared_ptr<op::Slice> node)
402 const auto data = node->input_value(0);
403 const auto begin = op::Constant::create(
404 element::i64, Shape{node->get_lower_bounds().size()}, node->get_lower_bounds());
405 const auto end = op::Constant::create(
406 element::i64, Shape{node->get_upper_bounds().size()}, node->get_upper_bounds());
407 const auto strides = op::Constant::create(
408 element::i64, Shape{node->get_strides().size()}, node->get_strides());
409 int64_t input_size = node->get_lower_bounds().size();
411 auto replacement_node = make_shared<op::v1::StridedSlice>(data,
415 vector<int64_t>(input_size, 0),
416 vector<int64_t>(input_size, 0));
418 replace_node(node, replacement_node);
419 return replacement_node;
422 shared_ptr<Node> op_cast(shared_ptr<op::Split> node)
424 const auto& splits_vec = node->get_splits();
425 const auto first_elem = splits_vec.front();
427 const bool split_evenly =
428 std::all_of(splits_vec.begin(), splits_vec.end(), [first_elem](const size_t split) {
429 return split == first_elem;
432 std::shared_ptr<Node> replacement_node;
435 replacement_node = make_shared<op::v1::Split>(
436 node->input_value(0), node->input_value(1), splits_vec.front());
440 const auto split_lengths =
441 ngraph::op::Constant::create(element::u64, Shape{splits_vec.size()}, splits_vec);
443 replacement_node = make_shared<op::v1::VariadicSplit>(
444 node->input_value(0), node->input_value(1), split_lengths);
447 replace_node(node, replacement_node);
448 return replacement_node;
451 shared_ptr<Node> op_cast(shared_ptr<op::Subtract> node)
453 return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
456 shared_ptr<Node> op_cast(shared_ptr<op::Sum> node)
458 bool keep_dims = false;
459 auto replacement_node =
460 make_shared<op::v1::ReduceSum>(node->input_value(0), node->input_value(1), keep_dims);
461 replace_node(node, replacement_node);
462 return replacement_node;
465 shared_ptr<Node> op_cast(shared_ptr<op::TopK> node)
467 NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
468 "parameter k is expected to be a static constant");
469 NGRAPH_CHECK(op::is_constant(node->input_value(2).get_node()),
470 "parameter top_k_axis is expected to be a static constant");
472 const auto k = node->get_k();
473 const auto axis = node->get_top_k_axis();
476 switch (node->get_sort())
478 case op::TopK::SortType::SORT_INDICES: sort = "index"; break;
479 case op::TopK::SortType::SORT_VALUES: sort = "value"; break;
480 case op::TopK::SortType::NONE: sort = "none"; break;
484 if (node->get_compute_max())
493 const auto k_constant = op::Constant::create(element::i64, Shape{}, {k});
494 auto replacement_node =
495 make_shared<op::v1::TopK>(node->input_value(0), k_constant, axis, mode, sort);
497 // indices output will be 0, values 1
498 vector<int64_t> output_order{1, 0};
499 replace_node(node, replacement_node, output_order);
500 return replacement_node;
503 shared_ptr<Node> op_cast(shared_ptr<op::Xor> node)
505 auto replacement_node = make_shared<op::v1::LogicalXor>(
506 node->input_value(0), node->input_value(1), node->get_autob());
507 replace_node(node, replacement_node);
508 return replacement_node;
511 using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
513 template <typename T>
514 bool op_cast_thunk(shared_ptr<Node> node)
516 auto upgraded_node = op_cast(as_type_ptr<T>(node));
519 if (ngraph::get_provenance_enabled())
521 const std::string provenance_tag =
522 "<Opset1_Upgrade (v0 " + std::string(node->get_type_name()) + ")>";
523 upgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
530 DispatchMap& get_dispatch_map()
532 NGRAPH_SUPPRESS_DEPRECATED_START
533 static DispatchMap dispatch_map{
534 #define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
535 #include "opset0_tbl.hpp"
539 NGRAPH_SUPPRESS_DEPRECATED_END
541 } // namespace opset1_upgrade
543 bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
545 bool modified = false;
546 auto& dispatch_map = opset1_upgrade::get_dispatch_map();
547 auto it = dispatch_map.find(node->get_type_info());
548 if (it != dispatch_map.end())
550 modified = it->second(node);