1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
22 #include "ngraph/builder/autobroadcast.hpp"
23 #include "ngraph/builder/reshape.hpp"
24 #include "ngraph/graph_util.hpp"
25 #include "ngraph/node.hpp"
26 #include "ngraph/op/util/attr_types.hpp"
27 #include "ngraph/op/util/op_types.hpp"
28 #include "ngraph/ops.hpp"
29 #include "ngraph/provenance.hpp"
30 #include "ngraph/slice_plan.hpp"
31 #include "ngraph/type.hpp"
32 #include "ngraph/validation_util.hpp"
33 #include "op/avg_pool.hpp"
34 #include "op/convolution.hpp"
35 #include "op/group_conv.hpp"
36 #include "pass/implicit_broadcast_elimination.hpp"
37 #include "pass/opset0_downgrade.hpp"
40 using namespace ngraph;
44 template <typename OpV0, typename OpV1>
45 shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV1>& node)
47 const auto input_arg0 = node->input_value(0);
48 const auto input_arg1 = node->input_value(1);
49 const auto autob = node->get_autob();
50 auto replacement_node = make_shared<OpV0>(input_arg0, input_arg1, autob);
51 replace_node(node, replacement_node);
52 return replacement_node;
// Generic downgrade for v1 reduction ops (ReduceMax/Min/Sum/Prod -> v0
// Max/Min/Sum/Product). The v0 ops always drop the reduced axes, so when the
// v1 node has keep_dims=true a Reshape is appended that reinserts each reduced
// axis as a size-1 dimension.
// NOTE(review): this excerpt is elided — several NGRAPH_CHECK message/argument
// lines and braces are missing; comments annotate only what is visible.
55 template <typename OpV0, typename OpV1>
56 shared_ptr<Node> op_cast_reduction_node(const shared_ptr<OpV1>& node)
58 auto replacement_node = make_shared<OpV0>(node->input_value(0), node->input_value(1));
59 if (node->get_keep_dims())
// Build "<Type>:v1" / "<Type>:v0" names, presumably for the error messages below.
61 string v1_op_name = string{node->get_type_name()} + ":v1";
62 string v0_op_name = string{OpV0{}.get_type_name()} + ":v0";
// keep_dims=true requires constant reduction axes and a static output shape,
// because the compensating Reshape needs concrete dimensions.
64 NGRAPH_CHECK(node->reduction_axes_constant(),
69 " if reduction axes are not constant (for keep_dims=true). Node: ",
71 auto output_pshape = replacement_node->get_output_partial_shape(0);
72 NGRAPH_CHECK(output_pshape.is_static(),
77 " if output shape is dynamic (for keep_dims=true). Node: ",
79 const auto output_shape = output_pshape.to_shape();
80 auto reshaped_output_shape = output_shape;
// Reinsert every reduced axis as dim=1 to emulate keep_dims.
81 for (const auto& axis : node->get_reduction_axes())
83 reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
85 auto reshaped_product = make_shared<op::Reshape>(replacement_node->output(0),
86 get_default_order(output_shape),
87 reshaped_output_shape);
88 return reshaped_product;
// keep_dims=false: the bare v0 reduction already has the right shape.
92 return replacement_node;
96 // Default is that we did nothing
97 shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
98 shared_ptr<Node> op_cast(shared_ptr<op::v1::Add> node)
100 return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
// Downgrade v1::AvgPool to v0::AvgPool by unpacking the v1 attributes into
// the positional v0 constructor arguments.
103 shared_ptr<Node> op_cast(shared_ptr<op::v1::AvgPool> node)
105 auto const input_arg = node->input_value(0);
// v1 rounding_type enum collapses to the v0 bool ceil_mode (non-zero => ceil).
106 const auto ceil_mode = static_cast<bool>(node->get_rounding_type());
// v1 "exclude_pad" is the logical inverse of the v0 flag.
107 const auto include_padding_in_avg_computation = !node->get_exclude_pad();
108 const auto pad_type = node->get_auto_pad();
109 const auto padding_below = node->get_pads_begin();
110 const auto padding_above = node->get_pads_end();
111 const auto window_movement_strides = node->get_strides();
112 const auto window_shape = node->get_kernel();
114 auto replacement_node = make_shared<op::v0::AvgPool>(input_arg,
116 window_movement_strides,
// NOTE(review): several constructor-argument lines (window_shape, paddings,
// pad_type, ceil_mode) appear elided from this excerpt — verify against the
// v0::AvgPool constructor signature.
119 include_padding_in_avg_computation,
122 replace_node(node, replacement_node);
123 return replacement_node;
// Downgrade v1::Broadcast to v0::Broadcast. v0 can only *add* axes, so input
// dimensions of size 1 that v1 would broadcast must first be squeezed away.
// Requires: static argument shape, constant target shape, and computable
// broadcast axes.
126 shared_ptr<Node> op_cast(shared_ptr<op::v1::Broadcast> node)
128 auto arg = node->input_value(0);
129 auto arg_pshape = arg.get_partial_shape();
130 auto arg_rank = arg_pshape.rank();
131 auto target_shape_input = node->input_value(1);
133 shared_ptr<Node> replacement_node;
135 NGRAPH_CHECK(arg_pshape.is_static(),
136 "Unable to convert Broadcast:v1 to Broadcast:v0 "
137 "if argument shape is not static. Node: ",
139 const auto& arg_shape = arg_pshape.to_shape();
// Target shape must be constant-foldable; the output shape then equals it.
141 NGRAPH_CHECK(op::is_constant(target_shape_input.get_node()));
142 auto target_shape = node->get_output_shape(0);
// .first is the "axes are known" flag of the (bool, AxisSet) pair.
143 NGRAPH_CHECK(node->get_broadcast_axes().first);
145 // (Re)construct axes_mapping.
146 AxisSet broadcast_axes = node->get_broadcast_axes().second;
147 std::vector<size_t> axes_mapping{
148 ngraph::builder::opset1::get_axes_mapping(target_shape, broadcast_axes)};
150 Output<Node> squeezed_arg = arg;
151 // Collect axes to squeeze. Broadcast v0 "adds" new axes, thus we have to squeeze
152 // the empty ones (dim:=1), which would be broadcasted by Broadcast v1.
153 std::vector<size_t> empty_axes;
154 for (size_t a{0}; a < axes_mapping.size(); ++a)
156 if (arg_shape.at(a) == 1 && target_shape.at(axes_mapping.at(a)) != 1)
158 empty_axes.push_back(a);
161 // Check if arg_shape contains some more empty dimensions marked to broadcast.
162 // If axes_mapping size is less than arg_shape size, then some of arg dimensions may
163 // be equal to one and marked to broadcast.
164 if (axes_mapping.size() < arg_shape.size())
166 for (size_t a{axes_mapping.size()}; a < arg_shape.size(); ++a)
168 if (arg_shape.at(a) == 1)
170 empty_axes.push_back(a);
174 if (!empty_axes.empty())
176 squeezed_arg = builder::squeeze(arg, empty_axes);
// NOTE(review): the "replacement_node =" line preceding this make_shared call
// appears elided from this excerpt.
180 make_shared<op::v0::Broadcast>(squeezed_arg, target_shape, broadcast_axes);
181 replace_node(node, replacement_node);
182 return replacement_node;
// Downgrade v1::Convolution to v0::Convolution. v0 additionally takes data
// dilation strides, which v1 has no notion of — they are fixed to all-ones.
185 shared_ptr<Node> op_cast(shared_ptr<op::v1::Convolution> node)
187 const auto data_arg = node->input_value(0);
188 const auto filters_arg = node->input_value(1);
189 const auto strides = node->get_strides();
// Strides has one entry per spatial dim; used to size the dummy data dilation.
190 const size_t num_spatial_dims = strides.size();
191 auto replacement_node = make_shared<op::v0::Convolution>(data_arg,
// NOTE(review): the filters_arg/strides argument lines appear elided here.
194 node->get_dilations(),
195 node->get_pads_begin(),
196 node->get_pads_end(),
// Data dilation of 1 in every spatial dimension == no data dilation.
197 Strides(num_spatial_dims, 1),
198 node->get_auto_pad());
199 replace_node(node, replacement_node);
200 return replacement_node;
// Downgrade v1::ConvolutionBackpropData to v0. The v0 op needs the concrete
// output (data-batch) shape up front, so the node's output shape must be
// static; likewise data-batch N and filters C dims must be static.
203 shared_ptr<Node> op_cast(shared_ptr<op::v1::ConvolutionBackpropData> node)
205 const auto data_arg = node->input_value(0);
206 const auto filters_arg = node->input_value(1);
208 auto data_pshape = data_arg.get_partial_shape();
209 auto filters_pshape = filters_arg.get_partial_shape();
211 NGRAPH_CHECK(data_pshape.rank().is_static() && data_pshape[0].is_static() &&
212 filters_pshape.rank().is_static() && filters_pshape[1].is_static(),
213 "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
214 "if data shape N and filters shape C dimensions are not static. Node: ",
// Rank minus batch and channel dims gives the spatial dimension count.
217 const size_t num_spatial_dims = data_pshape.rank().get_length() - 2;
219 const PartialShape output_pshape{node->get_output_partial_shape(0)};
220 NGRAPH_CHECK(output_pshape.is_static(),
221 "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
222 "if output shape is dynamic. Node: ",
224 Shape output_shape = output_pshape.to_shape();
226 auto replacement_node =
227 make_shared<op::v0::ConvolutionBackpropData>(output_shape,
// NOTE(review): the filters/data/strides argument lines appear elided here.
231 node->get_dilations(),
232 node->get_pads_begin(),
233 node->get_pads_end(),
// v1 has no data dilation; pass all-ones to the v0 constructor.
234 Strides(num_spatial_dims, 1));
235 replace_node(node, replacement_node);
236 return replacement_node;
239 shared_ptr<Node> op_cast(shared_ptr<op::v1::Divide> node)
241 const auto input_arg0 = node->input_value(0);
242 const auto input_arg1 = node->input_value(1);
243 const auto autob = node->get_autob();
244 const bool pydiv = node->is_pythondiv();
245 auto replacement_node = make_shared<op::v0::Divide>(input_arg0, input_arg1, pydiv, autob);
246 replace_node(node, replacement_node);
247 return replacement_node;
// Downgrade v1::Reshape to v0::Reshape. Possible only when the target shape
// is constant, the output shape is static, and the input rank is known —
// otherwise the check below fires.
250 shared_ptr<Node> op_cast(shared_ptr<op::v1::Reshape> node)
252 shared_ptr<Node> replacement_node;
254 const auto target_shape_input = node->input_value(1).get_node_shared_ptr();
255 const auto input_rank = node->get_input_partial_shape(0).rank();
256 if (op::is_constant(target_shape_input) && node->get_output_partial_shape(0).is_static() &&
257 input_rank.is_static())
259 const auto output_shape = node->get_output_shape(0);
// v0::Reshape wants an explicit input-axis order; identity order suffices.
260 replacement_node = make_shared<op::Reshape>(
261 node->input_value(0), get_default_order(input_rank.get_length()), output_shape);
// replacement_node stays null on the dynamic path, tripping this check.
265 NGRAPH_CHECK(replacement_node, "Unable to convert Reshape:v1 with dynamic shape.");
268 replace_node(node, replacement_node);
269 return replacement_node;
272 shared_ptr<Node> op_cast(shared_ptr<op::v1::Equal> node)
274 return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
// Downgrade v1::Gather to v0::Gather. v1 takes the axis as a (third) input;
// v0 takes it as a constructor argument, so the axis must be a constant i64.
277 shared_ptr<Node> op_cast(shared_ptr<op::v1::Gather> node)
279 auto axis_node = as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
// as_type_ptr yields null if input 2 is not a Constant.
281 NGRAPH_CHECK(axis_node,
282 "Unable to convert Gather:v1 to Gather:v0 if axis is not constant. Node: ",
286 axis_node->get_element_type() == element::i64,
287 "Unable to convert Gather:v1 to Gather:v0 with axis other type than int64. Node: ",
// Scalar axis: take the first (only) element of the constant's data.
290 int64_t axis = axis_node->get_vector<int64_t>()[0];
292 auto replacement_node =
293 make_shared<op::v0::Gather>(node->input_value(0), node->input_value(1), axis);
294 replace_node(node, replacement_node);
295 return replacement_node;
298 shared_ptr<Node> op_cast(shared_ptr<op::v1::Greater> node)
300 return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
303 shared_ptr<Node> op_cast(shared_ptr<op::v1::GreaterEqual> node)
305 return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
// Downgrade v1::GroupConvolution to v0::GroupConvolution. Like Convolution,
// v0 takes data dilation strides that v1 lacks — fixed here to all-ones.
308 shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolution> node)
310 const auto data_arg = node->input_value(0);
311 const auto filters_arg = node->input_value(1);
312 const auto strides = node->get_strides();
313 const size_t num_spatial_dims = strides.size();
314 auto replacement_node = make_shared<op::v0::GroupConvolution>(data_arg,
// NOTE(review): filters/strides argument lines appear elided from this excerpt.
317 node->get_dilations(),
318 node->get_pads_begin(),
319 node->get_pads_end(),
// All-ones data dilation == none.
320 Strides(num_spatial_dims, 1),
321 node->get_auto_pad());
322 replace_node(node, replacement_node);
323 return replacement_node;
// Downgrade v1::GroupConvolutionBackpropData to v0. Requires static data,
// filters, and output shapes. The v1 filters layout
// [GROUPS, C_IN, C_OUT, K...] is reshaped into the flattened v0 layout, and
// the v0 op's output-shape argument is supplied as a zero Constant of the
// right shape.
326 shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolutionBackpropData> node)
328 const auto data_arg = node->input_value(0);
329 const auto filters_arg = node->input_value(1);
331 NGRAPH_CHECK(data_arg.get_partial_shape().is_static(),
332 "Unable to convert GroupConvolutionBackpropData:1 to "
333 "GroupConvolutionBackpropData:0 with dynamic data shape. Node: ",
336 NGRAPH_CHECK(filters_arg.get_partial_shape().is_static(),
337 "Unable to convert GroupConvolutionBackpropData:1 to "
338 "GroupConvolutionBackpropData:0 with dynamic filters shape. Node: ",
341 auto filters_shape = filters_arg.get_shape();
// Leading filters dim is the group count in the v1 layout.
342 const size_t groups = filters_shape.at(0);
344 const PartialShape output_pshape{node->get_output_partial_shape(0)};
345 NGRAPH_CHECK(output_pshape.is_static(),
346 "Unable to convert GroupConvolutionBackpropData:v1 to "
347 "GroupConvolutionBackpropData:v0 "
348 "if output_shape is dynamic. Node: ",
350 Shape output_shape = output_pshape.to_shape();
352 // Convert filters data layout from [GROUPS, C_INPUT, C_OUTPUT, K_D, ..., K_1]
353 // into [C x M/group x k1 x k2 x ... x kn]
// Drop the GROUPS dim and fold it into the (new) leading channel dim.
354 filters_shape.erase(filters_shape.begin());
355 filters_shape[0] *= groups;
357 auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);
359 auto replacement_node = make_shared<op::v0::GroupConvolutionBackpropData>(
// Placeholder zero tensor conveys only the target output shape to v0.
360 op::Constant::create(data_arg.get_element_type(), output_shape, {0}),
// NOTE(review): the reshaped_filters/data/strides argument lines appear
// elided from this excerpt.
364 node->get_dilations(),
365 node->get_pads_begin(),
366 node->get_pads_end(),
368 replace_node(node, replacement_node);
369 return replacement_node;
372 shared_ptr<Node> op_cast(shared_ptr<op::v1::Less> node)
374 return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
377 shared_ptr<Node> op_cast(shared_ptr<op::v1::LessEqual> node)
379 return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
382 shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalNot> node)
384 auto replacement_node = make_shared<op::v0::Not>(node->input_value(0));
385 replace_node(node, replacement_node);
386 return replacement_node;
389 shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalOr> node)
391 return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
394 shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalXor> node)
396 return op_cast_binary_elementwise_node<op::v0::Xor, op::v1::LogicalXor>(node);
399 shared_ptr<Node> op_cast(shared_ptr<op::v1::Maximum> node)
401 return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
404 shared_ptr<Node> op_cast(shared_ptr<op::v1::Minimum> node)
406 return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
409 shared_ptr<Node> op_cast(shared_ptr<op::v1::Multiply> node)
411 return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
414 shared_ptr<Node> op_cast(shared_ptr<op::v1::NotEqual> node)
416 return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
// Downgrade v1::OneHot. v0::OneHot emits only 0/1 values, so on/off values
// are applied arithmetically afterwards:
//   result = one_hot * (on - off) + off
// Requires a constant depth input and a static output shape.
419 shared_ptr<Node> op_cast(shared_ptr<op::v1::OneHot> node)
421 const auto indices = node->input_value(0);
422 const auto depth = node->input_value(1).get_node();
423 auto on_value = node->input_value(2);
424 auto off_value = node->input_value(3);
425 const auto axis = node->get_axis();
427 NGRAPH_CHECK(op::is_constant(depth), "depth input must be constant", *node);
428 const auto output_pshape = node->get_output_partial_shape(0);
429 NGRAPH_CHECK(output_pshape.is_static(), "output shape must be static", *node);
430 const auto output_shape = output_pshape.to_shape();
// Convert the 0/1 one-hot tensor to the on/off value element type before
// combining it with on_value/off_value.
432 auto one_hot = std::make_shared<ngraph::op::Convert>(
433 std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
434 on_value.get_element_type());
// Numpy-broadcast all three operands to a common shape; index 0 is one_hot.
436 auto broadcasted_values = builder::numpy_broadcast_outputs({one_hot, on_value, off_value});
437 on_value = broadcasted_values[1];
438 off_value = broadcasted_values[2];
// Overloaded elementwise operators build the select-like arithmetic graph.
440 auto replacement_node = one_hot * (on_value - off_value) + off_value;
442 replace_node(node, replacement_node);
443 return replacement_node;
// Downgrade v1::Pad to v0::Pad. The pad value input is optional in v1
// (4 inputs when present); when absent a scalar zero constant is used.
446 shared_ptr<Node> op_cast(shared_ptr<op::v1::Pad> node)
448 const auto pad_arg = node->input_value(0);
449 Output<Node> pad_value;
450 if (node->get_input_size() == 4)
452 pad_value = node->input_value(3);
// NOTE(review): the else-branch "pad_value =" line appears elided; this
// Constant is presumably its right-hand side (scalar 0 of the input's type).
457 make_shared<op::Constant>(pad_arg.get_element_type(), Shape{}, vector<float>{0.f});
459 auto replacement_node = make_shared<op::v0::Pad>(
460 pad_arg, pad_value, node->get_pads_begin(), node->get_pads_end(), node->get_pad_mode());
462 replace_node(node, replacement_node);
463 return replacement_node;
466 shared_ptr<Node> op_cast(shared_ptr<op::v1::Power> node)
468 return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
471 shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMax> node)
473 auto replacement_node = op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
474 replace_node(node, replacement_node);
475 return replacement_node;
// Downgrade v1::ReduceMean, which has no direct v0 counterpart:
// mean = Sum(data, axes) / Sum(ones_like(data), axes).
478 shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMean> node)
480 // ReduceMean = Sum / Count
// Numerator: v0::Sum (keep_dims handled inside the reduction helper).
481 auto sum_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceMean>(node);
483 // Count = Sum(Constant(1, shape=data.shape))
// Requires a static data shape for the all-ones constant — TODO confirm
// callers guarantee this.
484 const auto data = node->input_value(0);
485 const auto axes = node->input_value(1);
486 const auto const_node =
487 op::v0::Constant::create(data.get_element_type(), data.get_shape(), {1});
488 std::shared_ptr<Node> count_node = std::make_shared<op::v0::Sum>(const_node, axes);
490 // Support keep_dims attribute
491 if (node->get_keep_dims())
493 // In order to keep the original dimensions we need to reshape the Count node
494 // before we use it in Divide with NUMPY broadcast
495 auto output_shape = count_node->get_shape();
496 auto reshaped_output_shape = output_shape;
// Reinsert each reduced axis as dim=1, mirroring op_cast_reduction_node.
497 for (const auto& axis : node->get_reduction_axes())
499 reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
501 count_node = make_shared<op::Reshape>(
502 count_node->output(0), get_default_order(output_shape), reshaped_output_shape);
// NUMPY broadcast lets the (possibly reshaped) count divide the sum.
505 const auto replacement_node =
506 std::make_shared<op::v0::Divide>(sum_node, count_node, op::AutoBroadcastSpec::NUMPY);
507 replace_node(node, replacement_node);
508 return replacement_node;
511 shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMin> node)
513 auto replacement_node = op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
514 replace_node(node, replacement_node);
515 return replacement_node;
518 shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceProd> node)
520 auto replacement_node = op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node);
521 replace_node(node, replacement_node);
522 return replacement_node;
525 shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceSum> node)
527 auto replacement_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
528 replace_node(node, replacement_node);
529 return replacement_node;
// Downgrade v1::Reverse to v0::Reverse. v1 encodes the reversed axes either
// as explicit indices (INDEX mode) or as a boolean mask (MASK mode); v0 only
// takes an AxisSet, so the mask form is converted first. Axes input must be
// constant.
532 shared_ptr<Node> op_cast(shared_ptr<op::v1::Reverse> node)
534 auto axes_node = node->input_value(1).get_node_shared_ptr();
535 NGRAPH_CHECK(op::is_constant(axes_node),
536 "Unable to convert Reverse:v1 to Reverse:v0 "
537 "if reduction axes are not constant. Node: ",
539 const auto axes_node_const = as_type_ptr<op::Constant>(axes_node);
// NOTE(review): the declaration of `axes` (an AxisSet, presumably) appears
// elided from this excerpt.
541 if (node->get_mode() == op::v1::Reverse::Mode::INDEX)
543 axes = axes_node_const->get_axis_vector_val();
// MASK mode: each true position in the mask marks an axis to reverse.
547 auto axes_mask = axes_node_const->get_vector<bool>();
548 for (size_t i = 0; i < axes_mask.size(); ++i)
// NOTE(review): the loop body that inserts set-mask indices into `axes`
// appears elided from this excerpt.
556 auto replacement_node = make_shared<op::v0::Reverse>(node->input_value(0), axes);
558 replace_node(node, replacement_node);
559 return replacement_node;
562 shared_ptr<Node> op_cast(shared_ptr<op::v1::Select> node)
564 ngraph::pass::ImplicitBroadcastElimination().run_on_node(node);
565 auto replacement_node = make_shared<op::v0::Select>(
566 node->input_value(0), node->input_value(1), node->input_value(2));
567 replace_node(node, replacement_node);
568 return replacement_node;
// Downgrade v1::StridedSlice to v0::Slice (plus optional Reshape/Reverse).
// make_slice_plan folds begin/end/stride plus the five bit-masks into plain
// coordinates; new/shrink axes become a Reshape and negative strides become a
// Reverse. Requires static input shape and constant begin/end/strides.
571 shared_ptr<Node> op_cast(shared_ptr<op::v1::StridedSlice> node)
// Converts a 0/1 mask vector into the set of axes whose bit is set.
573 auto convert_mask_to_axes = [](const std::vector<int64_t>& mask) {
// NOTE(review): `auto i = 0` compares signed int against mask.size()
// (size_t) — harmless here but triggers -Wsign-compare; prefer size_t.
575 for (auto i = 0; i < mask.size(); ++i)
// NOTE(review): the lambda's AxisSet accumulation and return appear elided
// from this excerpt.
585 const auto input_data = node->input_value(0);
586 const auto input_data_pshape = input_data.get_partial_shape();
588 NGRAPH_CHECK(input_data_pshape.is_static(),
589 "Unable to convert StridedSlice:v1 to Slice:v0 "
590 "if input rank is not static. Node: ",
593 const auto begin_const =
594 as_type_ptr<op::Constant>(node->input_value(1).get_node_shared_ptr());
595 const auto end_const =
596 as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
597 const auto strides = as_type_ptr<op::Constant>(node->input_value(3).get_node_shared_ptr());
599 NGRAPH_CHECK(begin_const && end_const && strides,
600 "Unable to convert StridedSlice:v1 to Slice:v0 "
601 "if begin, end or strides are not constant. Node: ",
// Normalize all slicing semantics into one SlicePlan.
604 SlicePlan p = make_slice_plan(input_data_pshape.to_shape(),
605 begin_const->get_vector<int64_t>(),
606 end_const->get_vector<int64_t>(),
607 strides->get_vector<int64_t>(),
608 convert_mask_to_axes(node->get_begin_mask()),
609 convert_mask_to_axes(node->get_end_mask()),
610 convert_mask_to_axes(node->get_new_axis_mask()),
611 convert_mask_to_axes(node->get_shrink_axis_mask()),
612 convert_mask_to_axes(node->get_ellipsis_mask()));
614 shared_ptr<Node> replacement_node =
615 make_shared<op::v0::Slice>(input_data,
616 Coordinate(p.begins.begin(), p.begins.end()),
617 Coordinate(p.ends.begin(), p.ends.end()),
618 Strides(p.strides.begin(), p.strides.end()));
// new_axis/shrink_axis masks manifest as a shape change -> append a Reshape.
620 if (p.reshape_in_shape != p.reshape_out_shape)
623 make_shared<op::Reshape>(replacement_node,
624 ngraph::get_default_order(p.reshape_in_shape),
625 p.reshape_out_shape);
// Negative strides manifest as reversed axes -> append a Reverse.
628 if (!p.reverse_axes.empty())
630 replacement_node = make_shared<op::Reverse>(replacement_node, p.reverse_axes);
633 replace_node(node, replacement_node);
634 return replacement_node;
637 shared_ptr<Node> op_cast(shared_ptr<op::v1::Split> node)
639 const auto num_splits = node->get_num_splits();
641 auto replacement_node =
642 make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), num_splits);
644 replace_node(node, replacement_node);
645 return replacement_node;
648 shared_ptr<Node> op_cast(shared_ptr<op::v1::Subtract> node)
650 return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
// Downgrade v1::TopK to v0::TopK. The v1 MAX/MIN mode enum maps to the v0
// compute_max bool, and the two outputs swap order (v1 emits values then
// indices; v0 the reverse), handled by the output_order remap below.
653 shared_ptr<Node> op_cast(shared_ptr<op::v1::TopK> node)
655 const auto axis = node->get_axis();
656 const auto sort_type = node->get_sort_type();
657 const auto index_elem_type = node->get_index_element_type();
// NOTE(review): the declaration of `compute_max` and the switch's default
// case appear elided from this excerpt.
660 switch (node->get_mode())
662 case op::v1::TopK::Mode::MAX: compute_max = true; break;
663 case op::v1::TopK::Mode::MIN: compute_max = false; break;
667 const auto arg_node = node->input_value(0);
668 const auto k_node = node->input_value(1);
670 auto replacement_node = make_shared<op::v0::TopK>(
671 arg_node, k_node, axis, index_elem_type, compute_max, sort_type);
673 // values output will be 0, indices 1
674 vector<int64_t> output_order{1, 0};
675 replace_node(node, replacement_node, output_order);
676 return replacement_node;
// Downgrade v1::Transpose to v0::Reshape: a transpose is a Reshape whose
// input-axis order is the permutation. Requires static data shape and a
// constant order input.
679 shared_ptr<Node> op_cast(shared_ptr<op::v1::Transpose> node)
681 const auto data = node->input_value(0);
683 const auto data_pshape = data.get_partial_shape();
684 NGRAPH_CHECK(data_pshape.is_static(),
685 "Unable to convert Transpose:v1 to Reshape:v0 "
686 "if data shape is dynamic. Node: ",
688 const auto data_shape = data_pshape.to_shape();
690 const auto order_node = node->input_value(1).get_node_shared_ptr();
691 NGRAPH_CHECK(op::is_constant(order_node),
692 "Unable to convert Transpose:v1 to Reshape:v0 "
693 "if order node is not constant. Node: ",
695 const auto order_const = as_type_ptr<op::Constant>(order_node);
697 auto order = order_const->get_axis_vector_val();
698 Shape out_shape = data_shape;
// NOTE(review): the guard around this branch appears elided — presumably an
// empty order means "identity/default permutation", hence iota below.
701 order.resize(out_shape.size());
702 iota(begin(order), end(order), 0);
// Permute the output shape according to the (possibly defaulted) order.
706 for (size_t i = 0; i < order.size(); ++i)
708 out_shape[i] = data_shape.at(order.at(i));
712 auto replacement_node = make_shared<op::v0::Reshape>(data, order, out_shape);
713 replace_node(node, replacement_node);
714 return replacement_node;
// Downgrade v1::VariadicSplit to v0::Split, which accepts an explicit vector
// of per-output lengths. split_lengths must be a constant.
// NOTE(review): the int64->size_t conversion below would mangle the special
// "-1" (infer-this-length) value v1 allows — presumably callers never pass
// it here; verify.
717 shared_ptr<Node> op_cast(shared_ptr<op::v1::VariadicSplit> node)
719 const auto split_lengths = node->input_value(2).get_node_shared_ptr();
721 NGRAPH_CHECK(op::is_constant(split_lengths),
722 "Unable to convert VariadicSplit:v1 to Split:v0 "
723 "if 'split_lengths' input is not constant. Node: ",
726 const auto splits = as_type_ptr<op::Constant>(split_lengths)->cast_vector<int64_t>();
727 const std::vector<size_t> splits_unsigned{splits.begin(), splits.end()};
729 auto replacement_node =
730 make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), splits_unsigned);
732 replace_node(node, replacement_node);
733 return replacement_node;
736 using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
// Type-erasing adapter stored in the DispatchMap: downcasts the node to its
// concrete v1 type, runs the matching op_cast overload, and (when provenance
// tracking is on) tags the replacement subgraph.
738 template <typename T>
739 bool op_cast_thunk(shared_ptr<Node> node)
// Non-null iff an op_cast overload actually replaced the node.
741 auto downgraded_node = op_cast(as_type_ptr<T>(node));
// NOTE(review): the null-check guarding this provenance block, and the
// true/false returns, appear elided from this excerpt.
744 if (ngraph::get_provenance_enabled())
746 const std::string provenance_tag =
747 "<Opset0_Downgrade (v1 " + std::string(node->get_type_name()) + ")>";
748 downgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
// Lazily builds the singleton dispatch table: one op_cast_thunk entry per op
// in opset1, generated by expanding the NGRAPH_OP X-macro over opset1_tbl.hpp.
755 DispatchMap& get_dispatch_map()
757 static DispatchMap dispatch_map{
758 #define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
759 #include "ngraph/opsets/opset1_tbl.hpp"
// NOTE(review): the closing #undef NGRAPH_OP / brace / return lines appear
// elided from this excerpt.
// Pass entry point: looks the node's type up in the dispatch table and, if a
// downgrade thunk exists, runs it. Returns whether the graph was modified.
766 bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
768 bool modified = false;
769 auto& dispatch_map = get_dispatch_map();
770 auto it = dispatch_map.find(node->get_type_info());
771 if (it != dispatch_map.end())
773 modified = it->second(node);