Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / split.hpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "../C/split.h"
20 #include "primitive.hpp"
21
22 namespace cldnn
23 {
24 /// @addtogroup cpp_api C++ API
25 /// @{
26 /// @addtogroup cpp_topology Network Topology
27 /// @{
28 /// @addtogroup cpp_primitives Primitives
29 /// @{
30
31 /// @brief Performs split operation on input.
32 /// @details splits the input data into n parts, for each user provides name and offsets.
33 /// @n User cannot use split primitive directly.
34 /// @n It is needed to refer to the output ids with the name "<split_prim_id>:<split_output_id>".
35 /// @n
36 /// @n\b Assumptions 
37 /// @n - offsets1 < offsets2 < offsets3 < ...
38 /// @n - size[n] = offsets[n+1] - offsets[n];
39 /// @n - last element: size[n] = split_input.size - offsets[n];
40 /// @n - no buffer overlapping, as the output size is calculated using offset and input size
41 /// @n - split primitive id cannot be used by any other primitive (user needs to use output_ids only)
42 /// @n Breaking any of this conditions will cause exeption throw.
43 /// @n
44 /// @n\b Example:
45 /// @n Splitting output to 2 parts by the features:
46 /// @n input_size = { 2, 4, 3, 5 };
47 /// @n split_id = "split";
48 /// @n output_ids_offsets[0] = { "out0", { 0,0,0,0 } };
49 /// @n output_ids_offsets[1] = { "out1", { 0,2,0,0 } };
50 /// @n After split there would be 2 primitives: "split:out0" and "split:out1" which contain 2 feature maps (lower and upper)
51 struct split : public primitive_base<split, CLDNN_PRIMITIVE_DESC(split)>
52 {
53     CLDNN_DECLARE_PRIMITIVE(split)
54
55     /// @brief Constructs split primitive.
56     /// @param id This primitive id.
57     /// @param input Input primitive id.
58     /// @param output_ids_offsets Pairs of output_ids and offsets
59     split(
60         const primitive_id& id,
61         const primitive_id& input,
62         const std::vector<std::pair<primitive_id, tensor> >& output_ids_offsets,
63         const padding& output_padding = padding()
64     )
65         :primitive_base(id, {input}, output_padding)
66         , output_ids(_output_ids.cpp_ids)
67         , output_offsets(extract_tensor_vector(output_ids_offsets))
68         , _output_ids(extract_primitive_vector(output_ids_offsets))
69         , _output_offsets(tensor_vector_to_cldnn_vector(output_offsets))
70     {
71     }
72
73     /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{split}
74     split(const dto* dto)
75         :primitive_base(dto)
76         , output_ids(_output_ids.cpp_ids)
77         , output_offsets(tensor_arr_to_vector(dto->output_offsets))
78         , _output_ids(dto->output_ids)
79         , _output_offsets(tensor_arr_to_cldnn_vector(dto->output_offsets))
80     {
81     }
82
83     /// @brief List of output_ids.
84     fixed_size_vector_ref output_ids;
85     /// @brief Array of tensors with offsets.
86     std::vector<tensor> output_offsets;
87
88 protected:
89     primitive_id_arr _output_ids;
90     std::vector<cldnn_tensor> _output_offsets;
91
92     void update_dto(dto& dto) const override
93     {
94         dto.output_ids = _output_ids.ref();
95         dto.output_offsets = tensor_vector_to_arr(_output_offsets);
96     }
97
98     static std::vector<primitive_id> extract_primitive_vector(const std::vector<std::pair<primitive_id, tensor> >& stor)
99     {
100         std::vector<primitive_id> res;
101         for (auto &stor_pair : stor)
102             res.push_back(stor_pair.first);
103
104         return res;
105     }
106
107     static std::vector<tensor> extract_tensor_vector(const std::vector<std::pair<primitive_id, tensor> >& stor)
108     {
109         std::vector<tensor> res;
110         for (auto &stor_pair : stor)
111             res.push_back(stor_pair.second);
112
113         return res;
114     }
115 };
116 /// @}
117 /// @}
118 /// @}
119 }