1 // Copyright (c) 2018 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "../C/index_select.h"
19 #include "primitive.hpp"
24 /// @brief Selects indices, which will be copied to the output.
26 /// @details Applies index selection along the specified dimension. The indices to be copied are specified by
30 /// @n <tt>input_sizes = (1, 2, 4, 2)</tt>
31 /// @n <tt>input_values = (a, b, c, d)</tt>
32 /// @n <tt> (e, f, g, h)</tt>
33 /// @n <tt>indices_sizes = (1, 1, 6, 1)</tt>
34 /// @n <tt>indices_values = {0, 0, 1, 1, 3, 3}</tt>
35 /// @n For axis: along_x:
36 /// @n <tt>output_sizes = (1, 2, 6, 2)</tt>
37 /// @n <tt>output_values = (a, a, b, b, d, d)</tt>
38 /// @n <tt> (e, e, f, f, h, h)</tt>
40 /// @n The resulting output has the same sizes as the input, except that the size along the selected axis is set to the indices x size.
42 /// @n@b Requirements:
43 /// @n - @c input must be a valid primitive_id whose output format is bfyx/yxfb;
44 /// @n - @c indices must be a valid primitive_id whose output layout is: (bfyx/yxfb, i32, {1, 1, indices_size, 1});
45 /// @n - @c axis must be a valid index_select_axis_name instance.
46 /// @n Violating any of these conditions will cause an exception to be thrown.
47 struct index_select : public primitive_base<index_select, CLDNN_PRIMITIVE_DESC(index_select)>
49 CLDNN_DECLARE_PRIMITIVE(index_select)
51 /// @brief Constructs index_select primitive / layer.
53 /// @param id An identifier of new primitive.
54 /// @param input An identifier of the primitive which is the input for the newly created
55 /// index_select primitive.
56 /// @param indices An identifier of the primitive whose output holds the indices, laid out along x.
57 /// @param axis Axis along which indices are selected.
58 /// @param output_padding Optional padding for output from primitive.
60 const primitive_id& id,
61 const primitive_id& input,
62 const primitive_id& indices,
63 index_select_axis_name axis = index_select_axis_name::along_b,
64 const padding& output_padding = padding()
66 : primitive_base(id, { input, indices }, output_padding)
71 /// @brief Constructs index_select primitive / layer.
73 /// @param id An identifier of new primitive.
74 /// @param input An identifier of the primitive which is the input for the newly created
75 /// index_select primitive.
76 /// @param axis Axis along which indices are selected.
77 /// @param output_padding Optional padding for output from primitive.
79 const primitive_id& id,
80 const primitive_id& input,
81 index_select_axis_name axis = index_select_axis_name::along_b,
82 const padding& output_padding = padding()
84 : primitive_base(id, { input }, output_padding)
89 /// @brief Constructs index_select primitive / layer.
91 /// @param id An identifier of new primitive.
92 /// @param input An identifier of the primitive which is the input for the newly created
93 /// index_select primitive.
94 /// @param axis Vector of axes along which indices are selected.
95 /// @param output_padding Optional padding for output from primitive.
97 const primitive_id& id,
98 const primitive_id& input,
99 const std::vector<index_select_axis_name>& axis = { index_select_axis_name::along_b },
100 const padding& output_padding = padding()
102 : primitive_base(id, { input }, output_padding)
107 /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{index_select}
108 index_select(const dto* dto)
109 : primitive_base(dto)
110 , axis(dto->axis, dto->axis + dto->axis_num)
111 , reverse(dto->reverse)
114 /// @brief A list of axes along which indices are selected.
115 std::vector<index_select_axis_name> axis;
116 /// @brief Perform index_select in reverse order along the axis/axes.
120 void update_dto(dto& dto) const override
122 dto.axis = axis.data();
123 dto.axis_num = (int)axis.size();
124 dto.reverse = reverse;