2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "../C/split.h"
20 #include "primitive.hpp"
24 /// @addtogroup cpp_api C++ API
26 /// @addtogroup cpp_topology Network Topology
28 /// @addtogroup cpp_primitives Primitives
31 /// @brief Performs split operation on input.
32 /// @details Splits the input data into n parts; for each part the user provides an output name and an offset.
33 /// @n The split primitive itself cannot be used directly.
34 /// @n Each output must be referenced by an id of the form "<split_prim_id>:<split_output_id>".
37 /// @n - offsets[0] < offsets[1] < offsets[2] < ...
38 /// @n - size[n] = offsets[n+1] - offsets[n];
39 /// @n - for the last output: size[n] = split_input.size - offsets[n];
40 /// @n - outputs never overlap, since each output size is derived from the offsets and the input size
41 /// @n - the split primitive id cannot be used by any other primitive (users must refer to the output ids only)
42 /// @n Breaking any of these conditions causes an exception to be thrown.
45 /// @n Splitting the input into 2 parts along the feature dimension:
46 /// @n input_size = { 2, 4, 3, 5 };
47 /// @n split_id = "split";
48 /// @n output_ids_offsets[0] = { "out0", { 0,0,0,0 } };
49 /// @n output_ids_offsets[1] = { "out1", { 0,2,0,0 } };
50 /// @n After the split there are 2 primitives, "split:out0" and "split:out1", each containing 2 feature maps (the lower and the upper half of the input features).
51 struct split : public primitive_base<split, CLDNN_PRIMITIVE_DESC(split)>
53 CLDNN_DECLARE_PRIMITIVE(split)
55 /// @brief Constructs split primitive.
56 /// @param id This primitive id.
57 /// @param input Input primitive id.
58 /// @param output_ids_offsets Pairs of output_ids and offsets
60 const primitive_id& id,
61 const primitive_id& input,
62 const std::vector<std::pair<primitive_id, tensor> >& output_ids_offsets,
63 const padding& output_padding = padding()
65 :primitive_base(id, {input}, output_padding)
66 , output_ids(_output_ids.cpp_ids)
67 , output_offsets(extract_tensor_vector(output_ids_offsets))
68 , _output_ids(extract_primitive_vector(output_ids_offsets))
69 , _output_offsets(tensor_vector_to_cldnn_vector(output_offsets))
73 /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{split}
76 , output_ids(_output_ids.cpp_ids)
77 , output_offsets(tensor_arr_to_vector(dto->output_offsets))
78 , _output_ids(dto->output_ids)
79 , _output_offsets(tensor_arr_to_cldnn_vector(dto->output_offsets))
83 /// @brief List of output ids (a view that the constructors bind to _output_ids.cpp_ids).
84 fixed_size_vector_ref output_ids;
85 /// @brief Array of tensors with the offsets of the outputs.
86 std::vector<tensor> output_offsets;
// Storage backing output_ids; update_dto() hands _output_ids.ref() to the C API descriptor.
// NOTE(review): output_ids is declared, and therefore initialized, before this member, so the
// constructors bind output_ids to _output_ids.cpp_ids before _output_ids itself is constructed.
// Presumably safe only if fixed_size_vector_ref merely stores the reference without reading it -- confirm.
// If it does hold a reference, an implicitly generated copy of split would leave output_ids
// pointing into the source object's _output_ids -- TODO confirm whether split is ever copied.
89 primitive_id_arr _output_ids;
// C-API (cldnn_tensor) copy of the output offsets; update_dto() exposes it via tensor_vector_to_arr().
90 std::vector<cldnn_tensor> _output_offsets;
92 void update_dto(dto& dto) const override
94 dto.output_ids = _output_ids.ref();
95 dto.output_offsets = tensor_vector_to_arr(_output_offsets);
98 static std::vector<primitive_id> extract_primitive_vector(const std::vector<std::pair<primitive_id, tensor> >& stor)
100 std::vector<primitive_id> res;
101 for (auto &stor_pair : stor)
102 res.push_back(stor_pair.first);
107 static std::vector<tensor> extract_tensor_vector(const std::vector<std::pair<primitive_id, tensor> >& stor)
109 std::vector<tensor> res;
110 for (auto &stor_pair : stor)
111 res.push_back(stor_pair.second);