Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / index_select.hpp
1 // Copyright (c) 2018 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 ///////////////////////////////////////////////////////////////////////////////////////////////////
16 #pragma once
17
18 #include "../C/index_select.h"
19 #include "primitive.hpp"
20
21
22 namespace cldnn
23 {
24 /// @brief Select index, which will be copied to the output..
25 ///
26 /// @details Applies index selecting along specified dimension. The indices, which will be copied are specifed by 
27 ///          by @c indices.
28 /// @n
29 /// @n Example:
30 /// @n      <tt>input_sizes  = (1, 2, 4, 2)</tt>
31 /// @n      <tt>input_values = (a, b, c, d)</tt>
32 /// @n      <tt>               (e, f, g, h)</tt>
33 /// @n      <tt>indices_sizes  = (1, 1, 6, 1)</tt>
34 /// @n      <tt>indices_values = {0, 0, 1, 1, 3, 3}</tt>                  
35 /// @n  For axis: along_x:
36 /// @n      <tt>output_sizes  = (1, 2, 6, 2)</tt>
37 /// @n      <tt>output_values = (a, a, b, b, d, d)</tt>
38 /// @n      <tt>                (e, e, f, f, h, h)</tt>
39 /// @n
40 /// @n The resulting output will have sizes equal to input_size with changed concrete tensor size to inidices x size.
41 /// @n
42 /// @n@b Requirements:
43 /// @n - @c input must be a valid primitive_id, which output's format is bfyx/yxfb;
44 /// @n - @c indices must be a valid primitive_id, which output's layout is: (bfyx/yxfb, i32, {1, 1, indicies_size, 1})
45 /// @n - @c axis - valid index_select_axis_name instance. 
46 /// @n Breaking any of this conditions will cause exeption throw.
47 struct index_select : public primitive_base<index_select, CLDNN_PRIMITIVE_DESC(index_select)>
48 {
49     CLDNN_DECLARE_PRIMITIVE(index_select)
50
51     /// @brief Constructs index_select primitive / layer.
52     ///
53     /// @param id                 An identifier of new primitive.
54     /// @param input              An identifier of primitive, which is an input for newly created
55     ///                           index_select primitive.
56     /// @param indicies           An identifer of primitive, which have indices in memory distributed along x. 
57     /// @param axis               Axis of index selecting.
58     /// @param output_padding     Optional padding for output from primitive.
59     index_select(
60         const primitive_id& id,
61         const primitive_id& input,
62         const primitive_id& indices,
63         index_select_axis_name axis = index_select_axis_name::along_b,
64         const padding& output_padding = padding()
65     )
66         : primitive_base(id, { input, indices }, output_padding)
67         , axis( { axis } )
68         , reverse(false)
69     {}
70
71     /// @brief Constructs index_select primitive / layer.
72     ///
73     /// @param id                 An identifier of new primitive.
74     /// @param input              An identifier of primitive, which is an input for newly created
75     ///                           index_select primitive.
76     /// @param axis               Axis of index selecting.
77     /// @param output_padding     Optional padding for output from primitive.
78     index_select(
79         const primitive_id& id,
80         const primitive_id& input,
81         index_select_axis_name axis = index_select_axis_name::along_b,
82         const padding& output_padding = padding()
83     )
84         : primitive_base(id, { input }, output_padding)
85         , axis( { axis } )
86         , reverse(true)
87     {}
88
89     /// @brief Constructs index_select primitive / layer.
90     ///
91     /// @param id                 An identifier of new primitive.
92     /// @param input              An identifier of primitive, which is an input for newly created
93     ///                           index_select primitive.
94     /// @param axis               Vector of axes of index selecting.
95     /// @param output_padding     Optional padding for output from primitive.
96     index_select(
97         const primitive_id& id,
98         const primitive_id& input,
99         const std::vector<index_select_axis_name>& axis = { index_select_axis_name::along_b },
100         const padding& output_padding = padding()
101     )
102         : primitive_base(id, { input }, output_padding)
103         , axis(axis)
104         , reverse(true)
105     {}
106
107     /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{broadcast}
108     index_select(const dto* dto)
109         : primitive_base(dto)
110         , axis(dto->axis, dto->axis + dto->axis_num)
111         , reverse(dto->reverse)
112     {}
113
114     /// @brief A list of axes of index selecting
115     std::vector<index_select_axis_name> axis;
116     /// @brief Do index_select in reverse order on axis/axes.
117     bool reverse;
118
119 protected:
120     void update_dto(dto& dto) const override
121     {
122         dto.axis = axis.data();
123         dto.axis_num = (int)axis.size();
124         dto.reverse = reverse;
125     }
126 };
127 /// @}
128 /// @}
129 /// @}
130 }