Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / src / builder / split.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
#include "ngraph/builder/split.hpp"

#include <stdexcept>

#include "ngraph/op/slice.hpp"
#include "ngraph/opsets/opset1.hpp"
20
21 NGRAPH_SUPPRESS_DEPRECATED_START
22
23 using namespace ngraph;
24
25 namespace
26 {
27     inline size_t get_valid_array_index(size_t idx, size_t axis_size)
28     {
29         return std::min(idx, axis_size);
30     }
31
32     std::shared_ptr<op::Slice> make_ng_slice(const Output<Node>& output,
33                                              const std::vector<size_t>& axes,
34                                              const std::vector<size_t>& starts,
35                                              const std::vector<size_t>& ends)
36     {
37         std::vector<size_t> upper_bounds{output.get_shape()};
38         std::vector<size_t> lower_bounds(upper_bounds.size());
39         for (size_t index{0}; index < axes.size(); ++index)
40         {
41             size_t axis{axes.at(index)};
42             lower_bounds.at(axis) =
43                 get_valid_array_index(starts.at(index), output.get_shape().at(axis));
44             upper_bounds.at(axis) =
45                 get_valid_array_index(ends.at(index), output.get_shape().at(axis));
46         }
47         return std::static_pointer_cast<op::Slice>(
48             std::make_shared<op::Slice>(output, lower_bounds, upper_bounds)
49                 ->add_provenance_group_members_above({output}));
50     }
51 }
52
53 OutputVector
54     builder::split(const Output<Node>& value, const std::vector<size_t>& length_parts, size_t axis)
55 {
56     size_t start_index{0};
57     OutputVector outputs;
58     for (const auto& length_part : length_parts)
59     {
60         size_t end_index{start_index + length_part};
61         outputs.push_back(make_ng_slice(value, {axis}, {start_index}, {end_index}));
62         start_index = end_index;
63     }
64     return outputs;
65 }
66
67 OutputVector builder::split(const Output<Node>& value, size_t split_parts, int axis)
68 {
69     size_t axis_to_split{static_cast<size_t>(axis)};
70     if (axis < 0)
71     {
72         axis_to_split = value.get_shape().size() + axis;
73     }
74
75     size_t length_axis_to_split{value.get_shape().at(axis_to_split)};
76     std::vector<size_t> length_parts(split_parts, length_axis_to_split / split_parts);
77     return split(value, length_parts, axis_to_split);
78 }
79
80 OutputVector builder::opset1::split(const Output<Node>& value,
81                                     const std::vector<size_t>& split_lengths,
82                                     int64_t axis)
83 {
84     const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
85     const auto split_lengths_node =
86         ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths);
87     const auto variadic_split =
88         std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
89
90     return variadic_split->outputs();
91 }
92
93 OutputVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
94 {
95     const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
96     const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);
97
98     return split->outputs();
99 }