//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/builder/split.hpp"

#include <algorithm>
#include <memory>
#include <stdexcept>
#include <vector>

#include "ngraph/op/slice.hpp"
#include "ngraph/opsets/opset1.hpp"

NGRAPH_SUPPRESS_DEPRECATED_START

using namespace ngraph;
27 inline size_t get_valid_array_index(size_t idx, size_t axis_size)
29 return std::min(idx, axis_size);
32 std::shared_ptr<op::Slice> make_ng_slice(const Output<Node>& output,
33 const std::vector<size_t>& axes,
34 const std::vector<size_t>& starts,
35 const std::vector<size_t>& ends)
37 std::vector<size_t> upper_bounds{output.get_shape()};
38 std::vector<size_t> lower_bounds(upper_bounds.size());
39 for (size_t index{0}; index < axes.size(); ++index)
41 size_t axis{axes.at(index)};
42 lower_bounds.at(axis) =
43 get_valid_array_index(starts.at(index), output.get_shape().at(axis));
44 upper_bounds.at(axis) =
45 get_valid_array_index(ends.at(index), output.get_shape().at(axis));
47 return std::static_pointer_cast<op::Slice>(
48 std::make_shared<op::Slice>(output, lower_bounds, upper_bounds)
49 ->add_provenance_group_members_above({output}));
54 builder::split(const Output<Node>& value, const std::vector<size_t>& length_parts, size_t axis)
56 size_t start_index{0};
58 for (const auto& length_part : length_parts)
60 size_t end_index{start_index + length_part};
61 outputs.push_back(make_ng_slice(value, {axis}, {start_index}, {end_index}));
62 start_index = end_index;
67 OutputVector builder::split(const Output<Node>& value, size_t split_parts, int axis)
69 size_t axis_to_split{static_cast<size_t>(axis)};
72 axis_to_split = value.get_shape().size() + axis;
75 size_t length_axis_to_split{value.get_shape().at(axis_to_split)};
76 std::vector<size_t> length_parts(split_parts, length_axis_to_split / split_parts);
77 return split(value, length_parts, axis_to_split);
80 OutputVector builder::opset1::split(const Output<Node>& value,
81 const std::vector<size_t>& split_lengths,
84 const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
85 const auto split_lengths_node =
86 ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths);
87 const auto variadic_split =
88 std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
90 return variadic_split->outputs();
93 OutputVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
95 const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
96 const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);
98 return split->outputs();