1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 #include "opset1_upgrade.hpp"
23 #include "ngraph/builder/autobroadcast.hpp"
24 #include "ngraph/builder/reshape.hpp"
25 #include "ngraph/graph_util.hpp"
26 #include "ngraph/op/util/op_types.hpp"
27 #include "ngraph/ops.hpp"
28 #include "ngraph/provenance.hpp"
29 #include "op/avg_pool.hpp"
30 #include "op/convolution.hpp"
31 #include "op/group_conv.hpp"
34 using namespace ngraph;
// Generic helper: swaps a v0 binary elementwise op (OpV0) for its opset1
// counterpart (OpV1), carrying over the node's auto-broadcast specification
// and rewiring all consumers via replace_node.
38 template <typename OpV0, typename OpV1>
39 shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV0>& node)
41 const auto autob = node->get_autob();
42 auto replacement_node =
43 make_shared<OpV1>(node->input_value(0), node->input_value(1), autob);
44 replace_node(node, replacement_node);
45 return replacement_node;
48 // Default is that we did nothing
49 shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
// v0::Add -> v1::Add (same inputs, same broadcast spec).
50 shared_ptr<Node> op_cast(shared_ptr<op::Add> node)
52 return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
// v0::Broadcast -> opset1 broadcast built by the builder, which turns the
// v0 shape/axes attributes into the explicit inputs opset1 expects.
55 shared_ptr<Node> op_cast(shared_ptr<op::Broadcast> node)
57 auto replacement_node = ngraph::builder::opset1::make_broadcast(
58 node->input_value(0), node->get_broadcast_shape(), node->get_broadcast_axes());
// make_broadcast returns an Output<Node>; unwrap it for replace_node/return.
59 replace_node(node, replacement_node.get_node_shared_ptr());
60 return replacement_node.get_node_shared_ptr();
63 shared_ptr<Node> op_cast(shared_ptr<op::BroadcastLike> node) { return nullptr; }
// v0::Convolution -> v1::Convolution. The v0 attribute names (window
// movement/dilation strides, padding below/above) map onto the v1 names
// (strides, dilations, pads_begin/pads_end).
64 shared_ptr<Node> op_cast(shared_ptr<op::v0::Convolution> node)
66 auto strides = node->get_window_movement_strides();
67 auto dilations = node->get_window_dilation_strides();
68 auto pads_begin = node->get_padding_below();
69 auto pads_end = node->get_padding_above();
70 auto data_dilation_strides = node->get_data_dilation_strides();
71 auto auto_pad = node->get_pad_type();
// v1::Convolution has no data-dilation attribute, so conversion is only
// legal when every data dilation stride is 1.
73 bool is_dds_valid = all_of(data_dilation_strides.begin(),
74 data_dilation_strides.end(),
75 [](size_t value) { return value == 1; });
77 NGRAPH_CHECK(is_dds_valid,
78 "Unable to convert Convolution:0 to Convolution:1 with data dilation strides "
79 "other than `1`. Node: ",
82 auto replacement_node = make_shared<op::v1::Convolution>(node->input_value(0),
89 replace_node(node, replacement_node);
90 return replacement_node;
// v0::ConvolutionBackpropData -> v1::ConvolutionBackpropData. Note the v1 op
// takes (data, filters) in the opposite input order from v0, and receives the
// forward output's spatial shape as an extra Constant input.
93 shared_ptr<Node> op_cast(shared_ptr<op::v0::ConvolutionBackpropData> node)
95 auto data_batch_shape = node->get_data_batch_shape();
96 auto strides = node->get_window_movement_strides_forward();
97 auto dilations = node->get_window_dilation_strides_forward();
98 auto pads_begin = node->get_padding_below_forward();
99 auto pads_end = node->get_padding_above_forward();
100 auto data_dilation_strides = node->get_data_dilation_strides_forward();
// v1 has no data-dilation attribute; only unit strides can be converted.
102 bool is_dds_valid = all_of(data_dilation_strides.begin(),
103 data_dilation_strides.end(),
104 [](size_t value) { return value == 1; });
106 NGRAPH_CHECK(is_dds_valid,
107 "Unable to convert ConvolutionBackpropData:0 to ConvolutionBackpropData:1 "
108 "with data dilation strides "
109 "other than `1`. Node: ",
112 auto replacement_node = make_shared<op::v1::ConvolutionBackpropData>(
113 node->input_value(1), // data
114 node->input_value(0), // filters
// Drop N and C from data_batch_shape: the Constant carries only the
// spatial dimensions of the forward data batch shape.
115 op::Constant::create(
117 Shape{data_batch_shape.size() - 2},
118 vector<size_t>(data_batch_shape.begin() + 2, data_batch_shape.end())),
123 replace_node(node, replacement_node);
124 return replacement_node;
// v0::Divide -> v1::Divide, preserving both the auto-broadcast spec and the
// Python-style division flag.
127 shared_ptr<Node> op_cast(shared_ptr<op::Divide> node)
129 const auto autob = node->get_autob();
130 const bool pydiv = node->is_pythondiv();
131 auto replacement_node =
132 make_shared<op::v1::Divide>(node->input_value(0), node->input_value(1), pydiv, autob);
133 replace_node(node, replacement_node);
134 return replacement_node;
// v0::Reshape -> opset1 reshape produced by the builder from the v0 node's
// declared output shape.
137 shared_ptr<Node> op_cast(shared_ptr<op::Reshape> node)
139 shared_ptr<Node> replacement_node =
140 builder::opset1::reshape(node->input_value(0), node->get_reshape_output_shape());
141 replace_node(node, replacement_node);
142 return replacement_node;
// v0::Equal -> v1::Equal via the generic binary elementwise helper.
145 shared_ptr<Node> op_cast(shared_ptr<op::Equal> node)
147 return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
// v0::Gather -> v1::Gather: the v0 axis attribute becomes a scalar i64
// Constant supplied as the third input.
150 shared_ptr<Node> op_cast(shared_ptr<op::Gather> node)
152 int64_t axis = node->get_axis();
154 auto axis_node = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{axis});
155 auto replacement_node =
156 make_shared<op::v1::Gather>(node->input_value(0), node->input_value(1), axis_node);
157 replace_node(node, replacement_node);
158 return replacement_node;
// v0::Greater -> v1::Greater.
161 shared_ptr<Node> op_cast(shared_ptr<op::Greater> node)
163 return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
// v0::GreaterEq -> v1::GreaterEqual (renamed in opset1).
166 shared_ptr<Node> op_cast(shared_ptr<op::GreaterEq> node)
168 return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
// v0::GroupConvolution -> v1::GroupConvolution. Two cases: the v0 filters
// either already carry the groups dimension, or must be reshaped so that the
// groups count becomes a leading dimension.
171 shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolution> node)
173 auto strides = node->get_window_movement_strides();
174 auto dilations = node->get_window_dilation_strides();
175 auto pads_begin = node->get_padding_below();
176 auto pads_end = node->get_padding_above();
177 auto data_dilation_strides = node->get_data_dilation_strides();
178 auto auto_pad = node->get_pad_type();
// v1 has no data-dilation attribute; only unit strides can be converted.
180 bool is_dds_valid = all_of(data_dilation_strides.begin(),
181 data_dilation_strides.end(),
182 [](size_t value) { return value == 1; });
184 NGRAPH_CHECK(is_dds_valid,
185 "Unable to convert GroupConvolution:0 to GroupConvolution:1"
186 "with data dilation strides other than `1`. Node: ",
189 shared_ptr<Node> replacement_node;
190 if (node->has_groups_in_filters())
// Filters already shaped with a groups dimension: pass them through as-is.
192 replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
193 node->input_value(1),
// Otherwise the filter shape must be statically known so we can fold the
// groups count out of dimension 0 and prepend it as a new leading dim.
202 NGRAPH_CHECK(node->get_input_partial_shape(1).is_static(),
203 "Unable to convert GroupConvolution:0 to GroupConvolution:1"
204 "with dynamic filters shape. Node: ",
207 auto filters_shape = node->get_input_shape(1);
208 auto groups = node->get_groups();
209 filters_shape[0] /= groups;
210 filters_shape.insert(filters_shape.begin(), groups);
212 auto reshaped_filters = builder::reshape(node->input_value(1), filters_shape);
214 replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
222 replace_node(node, replacement_node);
223 return replacement_node;
// v0::GroupConvolutionBackpropData -> v1::GroupConvolutionBackpropData.
// Requires static data-batch and filter shapes: the spatial dims of the
// output shape are passed as a Constant, and the filters are reshaped to
// expose the groups count as a leading dimension.
226 shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolutionBackpropData> node)
228 const auto strides = node->get_window_movement_strides();
229 const auto dilations = node->get_window_dilation_strides();
230 const auto pads_begin = node->get_padding_below();
231 const auto pads_end = node->get_padding_above();
233 const auto data_batch_pshape = node->get_input_partial_shape(0);
234 const auto filters_pshape = node->get_input_partial_shape(1);
236 NGRAPH_CHECK(data_batch_pshape.is_static(),
237 "Unable to convert GroupConvolutionBackpropData:0 to "
238 "GroupConvolutionBackpropData:1 with dynamic data_batch shape. Node: ",
240 NGRAPH_CHECK(filters_pshape.is_static(),
241 "Unable to convert GroupConvolutionBackpropData:0 to "
242 "GroupConvolutionBackpropData:1 with dynamic filters shape. Node: ",
245 auto data_batch_shape = data_batch_pshape.to_shape();
246 // Remove N, C from output shape to preserve only spatial dimensions.
247 data_batch_shape.erase(std::begin(data_batch_shape),
248 std::next(std::begin(data_batch_shape), 2));
249 auto filters_shape = filters_pshape.to_shape();
250 auto groups = node->get_groups();
// Fold groups out of filter dim 0 and prepend it as a new leading dim.
252 filters_shape[0] /= groups;
253 filters_shape.insert(filters_shape.begin(), groups);
254 auto reshaped_filters = builder::reshape(node->input_value(1), filters_shape);
256 auto replacement_node = make_shared<op::v1::GroupConvolutionBackpropData>(
257 node->input_value(2),
259 op::Constant::create(element::i64, Shape{data_batch_shape.size()}, data_batch_shape),
264 replace_node(node, replacement_node);
265 return replacement_node;
// v0::Less -> v1::Less.
268 shared_ptr<Node> op_cast(shared_ptr<op::Less> node)
270 return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
// v0::LessEq -> v1::LessEqual (renamed in opset1).
273 shared_ptr<Node> op_cast(shared_ptr<op::LessEq> node)
275 return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
// v0::Max (reduction) -> v1::ReduceMax with keep_dims hard-coded to false,
// matching the v0 reduction's axis-dropping output shape.
278 shared_ptr<Node> op_cast(shared_ptr<op::Max> node)
280 bool keep_dims = false;
281 auto replacement_node =
282 make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims);
283 replace_node(node, replacement_node);
284 return replacement_node;
// v0::Maximum (elementwise) -> v1::Maximum.
287 shared_ptr<Node> op_cast(shared_ptr<op::Maximum> node)
289 return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
// v0::Min (reduction) -> v1::ReduceMin with keep_dims=false, matching the
// v0 reduction's axis-dropping output shape.
292 shared_ptr<Node> op_cast(shared_ptr<op::Min> node)
294 bool keep_dims = false;
295 auto replacement_node =
296 make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims);
297 replace_node(node, replacement_node);
298 return replacement_node;
// v0::Minimum (elementwise) -> v1::Minimum.
301 shared_ptr<Node> op_cast(shared_ptr<op::Minimum> node)
303 return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
// v0::Multiply -> v1::Multiply.
306 shared_ptr<Node> op_cast(shared_ptr<op::Multiply> node)
308 return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
// v0::Not -> v1::LogicalNot (renamed in opset1); unary, no broadcast spec.
311 shared_ptr<Node> op_cast(shared_ptr<op::Not> node)
313 auto replacement_node = make_shared<op::v1::LogicalNot>(node->input_value(0));
314 replace_node(node, replacement_node);
315 return replacement_node;
// v0::NotEqual -> v1::NotEqual.
318 shared_ptr<Node> op_cast(shared_ptr<op::NotEqual> node)
320 return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
// v0::OneHot -> v1::OneHot. The depth is read from the statically-known
// output dimension at the one-hot axis; depth and the fixed on/off values
// (1 and 0, i64) become Constant inputs.
323 shared_ptr<Node> op_cast(shared_ptr<op::OneHot> node)
325 const auto indices = node->input_value(0).get_node_shared_ptr();
326 const auto one_hot_axis = node->get_one_hot_axis();
328 const auto output_pshape = node->get_output_partial_shape(0);
// Depth can only be derived when the one-hot axis dimension is static.
329 NGRAPH_CHECK(output_pshape[one_hot_axis].is_static(),
330 "OneHot:v0 one hot axis dimension must be static ",
332 const auto depth = output_pshape[one_hot_axis].get_length();
333 const auto depth_node = op::Constant::create(element::i64, Shape{}, {depth});
335 const auto on_value = op::Constant::create(element::i64, Shape{}, {1});
336 const auto off_value = op::Constant::create(element::i64, Shape{}, {0});
338 auto replacement_node =
339 make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis);
340 replace_node(node, replacement_node);
341 return replacement_node;
// v0::Or -> v1::LogicalOr (renamed in opset1).
344 shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
346 return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
// v0::Pad -> v1::Pad: the below/above padding attributes become i64 Constant
// inputs; the pad value (input 1) and pad mode are carried over.
349 shared_ptr<Node> op_cast(shared_ptr<op::Pad> node)
351 auto padding_below = node->get_padding_below();
352 auto pads_begin_node =
353 make_shared<op::Constant>(element::i64, Shape{padding_below.size()}, padding_below);
354 auto padding_above = node->get_padding_above();
356 make_shared<op::Constant>(element::i64, Shape{padding_above.size()}, padding_above);
358 auto replacement_node = make_shared<op::v1::Pad>(node->input_value(0),
361 node->input_value(1),
362 node->get_pad_mode());
364 replace_node(node, replacement_node);
365 return replacement_node;
// v0::Power -> v1::Power.
368 shared_ptr<Node> op_cast(shared_ptr<op::Power> node)
370 return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
// v0::Product (reduction) -> v1::ReduceProd with keep_dims=false, matching
// the v0 reduction's axis-dropping output shape.
373 shared_ptr<Node> op_cast(shared_ptr<op::Product> node)
375 bool keep_dims = false;
376 auto replacement_node =
377 make_shared<op::v1::ReduceProd>(node->input_value(0), node->input_value(1), keep_dims);
378 replace_node(node, replacement_node);
379 return replacement_node;
// v0::Reverse -> v1::Reverse in INDEX mode: the reversed-axes attribute
// becomes an i64 Constant supplied as the second input.
382 shared_ptr<Node> op_cast(shared_ptr<op::Reverse> node)
384 // creates a Constant node from the v0::Reverse reversed_axes attribute
385 // and uses it as the second input of v1::Reverse
386 const auto reversed_axes = node->get_reversed_axes();
388 const auto reversed_axes_constant = op::Constant::create(
389 element::i64, Shape{reversed_axes.size()}, reversed_axes.to_vector());
391 const auto replacement_node = make_shared<op::v1::Reverse>(
392 node->input_value(0), reversed_axes_constant, op::v1::Reverse::Mode::INDEX);
394 replace_node(node, replacement_node);
395 return replacement_node;
// v0::Select -> v1::Select with a default-constructed (i.e. no-broadcast)
// AutoBroadcastSpec; the three inputs map across unchanged.
398 shared_ptr<Node> op_cast(shared_ptr<op::Select> node)
400 auto replacement_node = make_shared<op::v1::Select>(node->input_value(0),
401 node->input_value(1),
402 node->input_value(2),
403 op::AutoBroadcastSpec());
404 replace_node(node, replacement_node);
405 return replacement_node;
// v0::Softmax -> v1::Softmax. v1 takes a single axis attribute, so the v0
// axes input must be a static constant containing exactly one axis.
408 shared_ptr<Node> op_cast(shared_ptr<op::Softmax> node)
410 NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
411 "axes parameter is expected to be a static constant");
413 AxisSet axes = node->get_axes();
417 "Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
420 auto replacement_node =
421 make_shared<op::v1::Softmax>(node->input_value(0), axes.to_vector()[0]);
422 replace_node(node, replacement_node);
423 return replacement_node;
// v0::Slice -> v1::StridedSlice: lower/upper bounds and strides become i64
// Constant inputs, and the begin/end masks are all-zero vectors so every
// bound is taken literally.
426 shared_ptr<Node> op_cast(shared_ptr<op::Slice> node)
428 const auto data = node->input_value(0);
429 const auto begin = op::Constant::create(
430 element::i64, Shape{node->get_lower_bounds().size()}, node->get_lower_bounds());
431 const auto end = op::Constant::create(
432 element::i64, Shape{node->get_upper_bounds().size()}, node->get_upper_bounds());
433 const auto strides = op::Constant::create(
434 element::i64, Shape{node->get_strides().size()}, node->get_strides());
435 int64_t input_size = node->get_lower_bounds().size();
437 auto replacement_node = make_shared<op::v1::StridedSlice>(data,
441 vector<int64_t>(input_size, 0),
442 vector<int64_t>(input_size, 0));
444 replace_node(node, replacement_node);
445 return replacement_node;
// v0::Split -> v1::Split when all split lengths are equal, otherwise
// v1::VariadicSplit with the lengths supplied as a u64 Constant.
448 shared_ptr<Node> op_cast(shared_ptr<op::Split> node)
450 const auto& splits_vec = node->get_splits();
451 const auto first_elem = splits_vec.front();
// Even split iff every entry equals the first one.
453 const bool split_evenly =
454 std::all_of(splits_vec.begin(), splits_vec.end(), [first_elem](const size_t split) {
455 return split == first_elem;
458 std::shared_ptr<Node> replacement_node;
461 replacement_node = make_shared<op::v1::Split>(
462 node->input_value(0), node->input_value(1), splits_vec.front());
466 const auto split_lengths =
467 ngraph::op::Constant::create(element::u64, Shape{splits_vec.size()}, splits_vec);
469 replacement_node = make_shared<op::v1::VariadicSplit>(
470 node->input_value(0), node->input_value(1), split_lengths);
473 replace_node(node, replacement_node);
474 return replacement_node;
// v0::Subtract -> v1::Subtract.
477 shared_ptr<Node> op_cast(shared_ptr<op::Subtract> node)
479 return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
// v0::Sum (reduction) -> v1::ReduceSum with keep_dims=false, matching the
// v0 reduction's axis-dropping output shape.
482 shared_ptr<Node> op_cast(shared_ptr<op::Sum> node)
484 bool keep_dims = false;
485 auto replacement_node =
486 make_shared<op::v1::ReduceSum>(node->input_value(0), node->input_value(1), keep_dims);
487 replace_node(node, replacement_node);
488 return replacement_node;
// v0::TopK -> v1::TopK. Requires static k and axis inputs; the v0 SortType
// enum is mapped to the v1 string form, and k becomes a scalar i64 Constant.
// The two ops disagree on output order, so outputs are remapped on replace.
491 shared_ptr<Node> op_cast(shared_ptr<op::TopK> node)
493 NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
494 "parameter k is expected to be a static constant");
495 NGRAPH_CHECK(op::is_constant(node->input_value(2).get_node()),
496 "parameter top_k_axis is expected to be a static constant");
498 const auto k = node->get_k();
499 const auto axis = node->get_top_k_axis();
// Translate the v0 sort-type enum into the v1 string attribute.
502 switch (node->get_sort())
504 case op::TopK::SortType::SORT_INDICES: sort = "index"; break;
505 case op::TopK::SortType::SORT_VALUES: sort = "value"; break;
506 case op::TopK::SortType::NONE: sort = "none"; break;
510 if (node->get_compute_max())
519 const auto k_constant = op::Constant::create(element::i64, Shape{}, {k});
520 auto replacement_node =
521 make_shared<op::v1::TopK>(node->input_value(0), k_constant, axis, mode, sort);
523 // indices output will be 0, values 1
524 vector<int64_t> output_order{1, 0};
525 replace_node(node, replacement_node, output_order);
526 return replacement_node;
// v0::Xor -> v1::LogicalXor (renamed in opset1), keeping the broadcast spec.
529 shared_ptr<Node> op_cast(shared_ptr<op::Xor> node)
531 auto replacement_node = make_shared<op::v1::LogicalXor>(
532 node->input_value(0), node->input_value(1), node->get_autob());
533 replace_node(node, replacement_node);
534 return replacement_node;
537 using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
// Dispatch-table adapter: downcasts the generic Node to its concrete v0 type
// T, runs the matching op_cast overload, and (when provenance tracking is
// enabled) tags the upgraded node with an Opset1_Upgrade provenance marker.
539 template <typename T>
540 bool op_cast_thunk(shared_ptr<Node> node)
542 auto upgraded_node = op_cast(as_type_ptr<T>(node));
545 if (ngraph::get_provenance_enabled())
547 const std::string provenance_tag =
548 "<Opset1_Upgrade (v0 " + std::string(node->get_type_name()) + ")>";
549 upgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
// Returns the lazily-initialized singleton dispatch table. Entries are
// generated by X-macro expansion: opset0_tbl.hpp invokes NGRAPH_OP once per
// v0 op, producing a {type_info, op_cast_thunk<Op>} pair for each.
556 DispatchMap& get_dispatch_map()
558 static DispatchMap dispatch_map{
559 #define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
560 #include "opset0_tbl.hpp"
567 bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
569 bool modified = false;
570 auto& dispatch_map = get_dispatch_map();
571 auto it = dispatch_map.find(node->get_type_info());
572 if (it != dispatch_map.end())
574 modified = it->second(node);