1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
22 #include "ngraph/builder/autobroadcast.hpp"
23 #include "ngraph/builder/reshape.hpp"
24 #include "ngraph/graph_util.hpp"
25 #include "ngraph/node.hpp"
26 #include "ngraph/op/util/attr_types.hpp"
27 #include "ngraph/op/util/op_types.hpp"
28 #include "ngraph/ops.hpp"
29 #include "ngraph/provenance.hpp"
30 #include "ngraph/slice_plan.hpp"
31 #include "ngraph/type.hpp"
32 #include "ngraph/validation_util.hpp"
33 #include "op/avg_pool.hpp"
34 #include "op/convolution.hpp"
35 #include "op/group_conv.hpp"
36 #include "pass/implicit_broadcast_elimination.hpp"
37 #include "pass/opset0_downgrade.hpp"
39 NGRAPH_SUPPRESS_DEPRECATED_START
42 using namespace ngraph;
44 namespace opset0_downgrade
46 template <typename OpV0, typename OpV1>
47 shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV1>& node)
49 const auto input_arg0 = node->input_value(0);
50 const auto input_arg1 = node->input_value(1);
51 const auto autob = node->get_autob();
52 auto replacement_node = make_shared<OpV0>(input_arg0, input_arg1, autob);
53 replace_node(node, replacement_node);
54 return replacement_node;
// Shared helper for the Reduce* downgrades. It builds OpV0 from the v1 node's data
// input (0) and reduction-axes input (1). Callers such as the ReduceMax overload call
// replace_node on the returned node.
// When keep_dims is set, the result is wrapped in an op::Reshape that re-inserts every
// reduced axis as a size-1 dimension. That path requires constant reduction axes and
// a static output shape; both are checked below.
57 template <typename OpV0, typename OpV1>
58 shared_ptr<Node> op_cast_reduction_node(const shared_ptr<OpV1>& node)
60 auto replacement_node = make_shared<OpV0>(node->input_value(0), node->input_value(1));
61 if (node->get_keep_dims())
// Human-readable op names, presumably used in the check messages below.
63 string v1_op_name = string{node->get_type_name()} + ":v1";
64 string v0_op_name = string{OpV0{}.get_type_name()} + ":v0";
66 NGRAPH_CHECK(node->reduction_axes_constant(),
71 " if reduction axes are not constant (for keep_dims=true). Node: ",
73 auto output_pshape = replacement_node->get_output_partial_shape(0);
74 NGRAPH_CHECK(output_pshape.is_static(),
79 " if output shape is dynamic (for keep_dims=true). Node: ",
81 const auto output_shape = output_pshape.to_shape();
82 auto reshaped_output_shape = output_shape;
// Insert a 1 at each reduced axis position.
// NOTE(review): this assumes get_reduction_axes() yields axes in ascending order, so
// earlier inserts do not shift later positions — confirm.
83 for (const auto& axis : node->get_reduction_axes())
85 reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
// The Reshape keeps get_default_order(output_shape) as the axis order (presumably
// identity), so only the shape changes.
87 auto reshaped_product = make_shared<op::Reshape>(replacement_node->output(0),
88 get_default_order(output_shape),
89 reshaped_output_shape);
90 return reshaped_product;
94 return replacement_node;
98 // Default is that we did nothing
99 shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
100 shared_ptr<Node> op_cast(shared_ptr<op::v1::Add> node)
102 return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
// AvgPool:v1 -> AvgPool:v0. The pooling attributes are copied across, with two
// translations: exclude_pad is inverted into include_padding_in_avg_computation, and
// the rounding type is turned into a boolean ceil_mode.
105 shared_ptr<Node> op_cast(shared_ptr<op::v1::AvgPool> node)
107 auto const input_arg = node->input_value(0);
// NOTE(review): relies on the rounding-type enum converting to true only for CEIL — confirm.
108 const auto ceil_mode = static_cast<bool>(node->get_rounding_type());
109 const auto include_padding_in_avg_computation = !node->get_exclude_pad();
110 const auto pad_type = node->get_auto_pad();
111 const auto padding_below = node->get_pads_begin();
112 const auto padding_above = node->get_pads_end();
113 const auto window_movement_strides = node->get_strides();
114 const auto window_shape = node->get_kernel();
116 auto replacement_node = make_shared<op::v0::AvgPool>(input_arg,
118 window_movement_strides,
121 include_padding_in_avg_computation,
124 replace_node(node, replacement_node);
125 return replacement_node;
// Convolution:v1 -> Convolution:v0. Dilations, pads and auto_pad are forwarded.
// The extra Strides argument gets 1 for every spatial dimension.
// NOTE(review): that argument is presumably v0's data dilation, which v1 does not
// model — confirm against the v0::Convolution constructor.
128 shared_ptr<Node> op_cast(shared_ptr<op::v1::Convolution> node)
130 const auto data_arg = node->input_value(0);
131 const auto filters_arg = node->input_value(1);
132 const auto strides = node->get_strides();
// The spatial rank is taken from the length of the strides attribute.
133 const size_t num_spatial_dims = strides.size();
134 auto replacement_node = make_shared<op::v0::Convolution>(data_arg,
137 node->get_dilations(),
138 node->get_pads_begin(),
139 node->get_pads_end(),
140 Strides(num_spatial_dims, 1),
141 node->get_auto_pad());
142 replace_node(node, replacement_node);
143 return replacement_node;
// ConvolutionBackpropData:v1 -> ConvolutionBackpropData:v0. The v0 op is given a concrete
// output Shape as its first argument. Conversion therefore requires a static data
// batch (N) dimension, a static filters C dimension, and a fully static output shape.
146 shared_ptr<Node> op_cast(shared_ptr<op::v1::ConvolutionBackpropData> node)
148 const auto data_arg = node->input_value(0);
149 const auto filters_arg = node->input_value(1);
151 auto data_pshape = data_arg.get_partial_shape();
152 auto filters_pshape = filters_arg.get_partial_shape();
154 NGRAPH_CHECK(data_pshape.rank().is_static() && data_pshape[0].is_static() &&
155 filters_pshape.rank().is_static() && filters_pshape[1].is_static(),
156 "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
157 "if data shape N and filters shape C dimensions are not static. Node: ",
// Spatial rank = data rank minus the two leading (N, C) dimensions.
160 const size_t num_spatial_dims = data_pshape.rank().get_length() - 2;
162 const PartialShape output_pshape{node->get_output_partial_shape(0)};
163 NGRAPH_CHECK(output_pshape.is_static(),
164 "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
165 "if output shape is dynamic. Node: ",
167 Shape output_shape = output_pshape.to_shape();
169 auto replacement_node =
170 make_shared<op::v0::ConvolutionBackpropData>(output_shape,
174 node->get_dilations(),
175 node->get_pads_begin(),
176 node->get_pads_end(),
// Trailing unit strides, one per spatial dim (presumably data dilation — confirm).
177 Strides(num_spatial_dims, 1));
178 replace_node(node, replacement_node);
179 return replacement_node;
182 shared_ptr<Node> op_cast(shared_ptr<op::v1::Divide> node)
184 const auto input_arg0 = node->input_value(0);
185 const auto input_arg1 = node->input_value(1);
186 const auto autob = node->get_autob();
187 const bool pydiv = node->is_pythondiv();
188 auto replacement_node = make_shared<op::v0::Divide>(input_arg0, input_arg1, pydiv, autob);
189 replace_node(node, replacement_node);
190 return replacement_node;
193 shared_ptr<Node> op_cast(shared_ptr<op::v1::Reshape> node)
195 shared_ptr<Node> replacement_node;
197 const auto target_shape_input = node->input_value(1).get_node_shared_ptr();
198 const auto input_rank = node->get_input_partial_shape(0).rank();
199 if (op::is_constant(target_shape_input) && node->get_output_partial_shape(0).is_static() &&
200 input_rank.is_static())
202 const auto output_shape = node->get_output_shape(0);
203 replacement_node = make_shared<op::Reshape>(
204 node->input_value(0), get_default_order(input_rank.get_length()), output_shape);
208 NGRAPH_CHECK(replacement_node, "Unable to convert Reshape:v1 with dynamic shape.");
211 replace_node(node, replacement_node);
212 return replacement_node;
215 shared_ptr<Node> op_cast(shared_ptr<op::v1::Equal> node)
217 return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
220 shared_ptr<Node> op_cast(shared_ptr<op::v1::Greater> node)
222 return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
225 shared_ptr<Node> op_cast(shared_ptr<op::v1::GreaterEqual> node)
227 return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
// GroupConvolution:v1 -> GroupConvolution:v0. Dilations, pads and auto_pad are forwarded.
// The extra Strides argument gets 1 for every spatial dimension.
// NOTE(review): that argument is presumably data dilation — confirm against the
// v0::GroupConvolution constructor.
230 shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolution> node)
232 const auto data_arg = node->input_value(0);
233 const auto filters_arg = node->input_value(1);
234 const auto strides = node->get_strides();
// The spatial rank is taken from the length of the strides attribute.
235 const size_t num_spatial_dims = strides.size();
236 auto replacement_node = make_shared<op::v0::GroupConvolution>(data_arg,
239 node->get_dilations(),
240 node->get_pads_begin(),
241 node->get_pads_end(),
242 Strides(num_spatial_dims, 1),
243 node->get_auto_pad());
244 replace_node(node, replacement_node);
245 return replacement_node;
// GroupConvolutionBackpropData:v1 -> GroupConvolutionBackpropData:v0.
// Data, filters and output shapes must all be static. The v1 filters (leading dimension
// = group count) are reshaped with the group dimension folded into the next one. A
// zero-filled Constant of the output shape is passed as the first v0 input.
248 shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolutionBackpropData> node)
250 const auto data_arg = node->input_value(0);
251 const auto filters_arg = node->input_value(1);
// NOTE(review): these two messages say ":1"/":0" while the output-shape message says
// ":v1"/":v0" — consider unifying.
253 NGRAPH_CHECK(data_arg.get_partial_shape().is_static(),
254 "Unable to convert GroupConvolutionBackpropData:1 to "
255 "GroupConvolutionBackpropData:0 with dynamic data shape. Node: ",
258 NGRAPH_CHECK(filters_arg.get_partial_shape().is_static(),
259 "Unable to convert GroupConvolutionBackpropData:1 to "
260 "GroupConvolutionBackpropData:0 with dynamic filters shape. Node: ",
// The leading filters dimension is the group count.
263 auto filters_shape = filters_arg.get_shape();
264 const size_t groups = filters_shape.at(0);
266 const PartialShape output_pshape{node->get_output_partial_shape(0)};
267 NGRAPH_CHECK(output_pshape.is_static(),
268 "Unable to convert GroupConvolutionBackpropData:v1 to "
269 "GroupConvolutionBackpropData:v0 "
270 "if output_shape is dynamic. Node: ",
272 Shape output_shape = output_pshape.to_shape();
274 // Fold the group dim into the next dim: drop the leading GROUPS dim and scale C_INPUT,
275 // i.e. [GROUPS, C_INPUT, C_OUTPUT, K_D, ..., K_1] -> [GROUPS * C_INPUT, C_OUTPUT, K_D, ..., K_1]
276 filters_shape.erase(filters_shape.begin());
277 filters_shape[0] *= groups;
279 auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);
// The first v0 input is a zero Constant with the output shape and the data's element type.
// NOTE(review): presumably only its shape/type matter to the v0 op — confirm.
281 auto replacement_node = make_shared<op::v0::GroupConvolutionBackpropData>(
282 op::Constant::create(data_arg.get_element_type(), output_shape, {0}),
286 node->get_dilations(),
287 node->get_pads_begin(),
288 node->get_pads_end(),
290 replace_node(node, replacement_node);
291 return replacement_node;
294 shared_ptr<Node> op_cast(shared_ptr<op::v1::Less> node)
296 return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
299 shared_ptr<Node> op_cast(shared_ptr<op::v1::LessEqual> node)
301 return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
304 shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalNot> node)
306 auto replacement_node = make_shared<op::v0::Not>(node->input_value(0));
307 replace_node(node, replacement_node);
308 return replacement_node;
311 shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalOr> node)
313 return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
316 shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalXor> node)
318 return op_cast_binary_elementwise_node<op::v0::Xor, op::v1::LogicalXor>(node);
321 shared_ptr<Node> op_cast(shared_ptr<op::v1::Maximum> node)
323 return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
326 shared_ptr<Node> op_cast(shared_ptr<op::v1::Minimum> node)
328 return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
331 shared_ptr<Node> op_cast(shared_ptr<op::v1::Multiply> node)
333 return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
336 shared_ptr<Node> op_cast(shared_ptr<op::v1::NotEqual> node)
338 return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
341 shared_ptr<Node> op_cast(shared_ptr<op::v1::Power> node)
343 return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
346 shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMax> node)
348 auto replacement_node = op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
349 replace_node(node, replacement_node);
350 return replacement_node;
353 shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMean> node)
355 // ReduceMean = Sum / Count
356 auto sum_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceMean>(node);
358 // Count = Sum(Constant(1, shape=data.shape))
359 const auto data = node->input_value(0);
360 const auto axes = node->input_value(1);
361 const auto const_node =
362 op::v0::Constant::create(data.get_element_type(), data.get_shape(), {1});
363 std::shared_ptr<Node> count_node = std::make_shared<op::v0::Sum>(const_node, axes);
365 // Support keep_dims attribute
366 if (node->get_keep_dims())
368 // In order to keep the original dimensions we need to reshape the Count node
369 // before we use it in Divide with NUMPY broadcast
370 auto output_shape = count_node->get_shape();
371 auto reshaped_output_shape = output_shape;
372 for (const auto& axis : node->get_reduction_axes())
374 reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
376 count_node = make_shared<op::Reshape>(
377 count_node->output(0), get_default_order(output_shape), reshaped_output_shape);
380 const auto replacement_node =
381 std::make_shared<op::v0::Divide>(sum_node, count_node, op::AutoBroadcastSpec::NUMPY);
382 replace_node(node, replacement_node);
383 return replacement_node;
386 shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMin> node)
388 auto replacement_node = op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
389 replace_node(node, replacement_node);
390 return replacement_node;
393 shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceSum> node)
395 auto replacement_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
396 replace_node(node, replacement_node);
397 return replacement_node;
// Reverse:v1 -> Reverse:v0 with a fixed set of axes. The v1 axes input must be a Constant,
// because its values are read here. It is interpreted either as axis indices (INDEX mode)
// or as a per-axis boolean mask.
400 shared_ptr<Node> op_cast(shared_ptr<op::v1::Reverse> node)
402 auto axes_node = node->input_value(1).get_node_shared_ptr();
403 NGRAPH_CHECK(op::is_constant(axes_node),
404 "Unable to convert Reverse:v1 to Reverse:v0 "
405 "if reduction axes are not constant. Node: ",
407 const auto axes_node_const = as_type_ptr<op::Constant>(axes_node);
409 if (node->get_mode() == op::v1::Reverse::Mode::INDEX)
411 axes = axes_node_const->get_axis_vector_val();
// Mask mode: presumably each true entry i marks axis i for reversal — confirm.
415 auto axes_mask = axes_node_const->get_vector<bool>();
416 for (size_t i = 0; i < axes_mask.size(); ++i)
424 auto replacement_node = make_shared<op::v0::Reverse>(node->input_value(0), axes);
426 replace_node(node, replacement_node);
427 return replacement_node;
430 shared_ptr<Node> op_cast(shared_ptr<op::v1::Select> node)
432 ngraph::pass::ImplicitBroadcastElimination().run_on_node(node);
433 auto replacement_node = make_shared<op::v0::Select>(
434 node->input_value(0), node->input_value(1), node->input_value(2));
435 replace_node(node, replacement_node);
436 return replacement_node;
// StridedSlice:v1 -> Slice:v0, followed by an optional Reshape and an optional Reverse.
// The input shape must be static, and begin/end/strides must be Constants, so that
// make_slice_plan can turn the v1 masks into a concrete SlicePlan at conversion time.
439 shared_ptr<Node> op_cast(shared_ptr<op::v1::StridedSlice> node)
// Converts a v1 mask vector into an AxisSet-style argument for make_slice_plan
// (presumably the axes whose mask entry is set — confirm).
// NOTE(review): `auto i = 0` deduces int, which is then compared with the unsigned
// mask.size(); prefer size_t.
441 auto convert_mask_to_axes = [](const std::vector<int64_t>& mask) {
443 for (auto i = 0; i < mask.size(); ++i)
453 const auto input_data = node->input_value(0);
454 const auto input_data_pshape = input_data.get_partial_shape();
// NOTE(review): the check requires a fully static shape, while the message only mentions
// a static rank.
456 NGRAPH_CHECK(input_data_pshape.is_static(),
457 "Unable to convert StridedSlice:v1 to Slice:v0 "
458 "if input rank is not static. Node: ",
461 const auto begin_const =
462 as_type_ptr<op::Constant>(node->input_value(1).get_node_shared_ptr());
463 const auto end_const =
464 as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
465 const auto strides = as_type_ptr<op::Constant>(node->input_value(3).get_node_shared_ptr());
467 NGRAPH_CHECK(begin_const && end_const && strides,
468 "Unable to convert StridedSlice:v1 to Slice:v0 "
469 "if begin, end or strides are not constant. Node: ",
472 SlicePlan p = make_slice_plan(input_data_pshape.to_shape(),
473 begin_const->get_vector<int64_t>(),
474 end_const->get_vector<int64_t>(),
475 strides->get_vector<int64_t>(),
476 convert_mask_to_axes(node->get_begin_mask()),
477 convert_mask_to_axes(node->get_end_mask()),
478 convert_mask_to_axes(node->get_new_axis_mask()),
479 convert_mask_to_axes(node->get_shrink_axis_mask()),
480 convert_mask_to_axes(node->get_ellipsis_mask()));
// Slice:v0 applies the plan's begins/ends/strides.
482 shared_ptr<Node> replacement_node =
483 make_shared<op::v0::Slice>(input_data,
484 Coordinate(p.begins.begin(), p.begins.end()),
485 Coordinate(p.ends.begin(), p.ends.end()),
486 Strides(p.strides.begin(), p.strides.end()));
// A Reshape is added only when the plan's reshape-in and reshape-out shapes differ
// (presumably for new-axis / shrink-axis masks).
488 if (p.reshape_in_shape != p.reshape_out_shape)
491 make_shared<op::Reshape>(replacement_node,
492 ngraph::get_default_order(p.reshape_in_shape),
493 p.reshape_out_shape);
// A Reverse is added only when the plan lists axes to reverse.
496 if (!p.reverse_axes.empty())
498 replacement_node = make_shared<op::Reverse>(replacement_node, p.reverse_axes);
501 replace_node(node, replacement_node);
502 return replacement_node;
505 shared_ptr<Node> op_cast(shared_ptr<op::v1::Split> node)
507 const auto num_splits = node->get_num_splits();
509 auto replacement_node =
510 make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), num_splits);
512 replace_node(node, replacement_node);
513 return replacement_node;
516 shared_ptr<Node> op_cast(shared_ptr<op::v1::Subtract> node)
518 return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
// TopK:v1 -> TopK:v0. The v1 mode becomes v0's compute_max flag. The two outputs are
// swapped when the node is replaced (see output_order below).
521 shared_ptr<Node> op_cast(shared_ptr<op::v1::TopK> node)
523 const auto axis = node->get_axis();
524 const auto sort_type = node->get_sort_type();
525 const auto index_elem_type = node->get_index_element_type();
// NOTE(review): compute_max is assigned only in the MAX and MIN cases. Confirm that its
// declaration/default handling cannot leave it uninitialized for any other mode.
528 switch (node->get_mode())
530 case op::v1::TopK::Mode::MAX: compute_max = true; break;
531 case op::v1::TopK::Mode::MIN: compute_max = false; break;
535 const auto arg_node = node->input_value(0);
536 const auto k_node = node->input_value(1);
538 auto replacement_node = make_shared<op::v0::TopK>(
539 arg_node, k_node, axis, index_elem_type, compute_max, sort_type);
541 // v1 outputs are (values, indices); output_order {1, 0} swaps them on replacement, presumably because v0 orders them (indices, values)
542 vector<int64_t> output_order{1, 0};
543 replace_node(node, replacement_node, output_order);
544 return replacement_node;
// Transpose:v1 -> Reshape:v0 with an explicit axis order. Requires a static data shape and
// a constant order input. Output dim i is taken from input dim order[i].
547 shared_ptr<Node> op_cast(shared_ptr<op::v1::Transpose> node)
549 const auto data = node->input_value(0);
551 const auto data_pshape = data.get_partial_shape();
552 NGRAPH_CHECK(data_pshape.is_static(),
553 "Unable to convert Transpose:v1 to Reshape:v0 "
554 "if data shape is dynamic. Node: ",
556 const auto data_shape = data_pshape.to_shape();
558 const auto order_node = node->input_value(1).get_node_shared_ptr();
559 NGRAPH_CHECK(op::is_constant(order_node),
560 "Unable to convert Transpose:v1 to Reshape:v0 "
561 "if order node is not constant. Node: ",
563 const auto order_const = as_type_ptr<op::Constant>(order_node);
565 auto order = order_const->get_axis_vector_val();
566 Shape out_shape = data_shape;
// Identity order 0..rank-1, which leaves out_shape equal to data_shape.
// NOTE(review): if this branch handles an empty order, confirm that identity is intended;
// Transpose is commonly specified to reverse the axes when the order is empty.
569 order.resize(out_shape.size());
570 iota(begin(order), end(order), 0);
574 for (size_t i = 0; i < order.size(); ++i)
576 out_shape[i] = data_shape.at(order.at(i));
580 auto replacement_node = make_shared<op::v0::Reshape>(data, order, out_shape);
581 replace_node(node, replacement_node);
582 return replacement_node;
// VariadicSplit:v1 -> Split:v0 with explicit split lengths. The lengths must come from a
// Constant input, because they are read at conversion time.
585 shared_ptr<Node> op_cast(shared_ptr<op::v1::VariadicSplit> node)
587 const auto split_lengths = node->input_value(2).get_node_shared_ptr();
589 NGRAPH_CHECK(op::is_constant(split_lengths),
590 "Unable to convert VariadicSplit:v1 to Split:v0 "
591 "if 'split_lengths' input is not constant. Node: ",
// NOTE(review): lengths are read as int64_t and then converted to size_t. A negative
// entry (e.g. -1 for an inferred length, if VariadicSplit permits it) would wrap to a
// huge value here — confirm such inputs are rejected or resolved beforehand.
594 const auto splits = as_type_ptr<op::Constant>(split_lengths)->cast_vector<int64_t>();
595 const std::vector<size_t> splits_unsigned{splits.begin(), splits.end()};
597 auto replacement_node =
598 make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), splits_unsigned);
600 replace_node(node, replacement_node);
601 return replacement_node;
604 using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
// Type-specific entry point stored in the dispatch map. It downcasts the node to T and
// forwards it to the best-matching op_cast overload; the shared_ptr<Node> fallback
// returns nullptr. When provenance tracking is enabled, a tag naming the v1 op is added
// via add_provenance_tags_above, bounded by the original node's inputs.
606 template <typename T>
607 bool op_cast_thunk(shared_ptr<Node> node)
609 auto downgraded_node = op_cast(as_type_ptr<T>(node));
612 if (ngraph::get_provenance_enabled())
614 const std::string provenance_tag =
615 "<Opset0_Downgrade (v1 " + std::string(node->get_type_name()) + ")>";
616 downgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
// Provides the function-local static dispatch table, built once on first use. The
// NGRAPH_OP X-macro expands every op listed in ngraph/opsets/opset1_tbl.hpp into a
// {type_info, op_cast_thunk<op>} entry.
623 DispatchMap& get_dispatch_map()
625 static DispatchMap dispatch_map{
626 #define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
627 #include "ngraph/opsets/opset1_tbl.hpp"
632 } // namespace opset0_downgrade
// Downgrades one node if its type has an entry in the dispatch map. The entry's result
// is recorded as whether the graph was modified; nodes with no entry leave it false.
634 bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
636 bool modified = false;
637 auto& dispatch_map = opset0_downgrade::get_dispatch_map();
638 auto it = dispatch_map.find(node->get_type_info());
639 if (it != dispatch_map.end())
641 modified = it->second(node);