2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "index_select_inst.h"
19 #include "primitive_type_base.h"
20 #include "error_handler.h"
21 #include "json_object.h"
25 primitive_type_id index_select_type_id()
27 static primitive_type_base<index_select> instance;
31 layout index_select_inst::calc_output_layout(index_select_node const& node)
33 assert((bool)node.get_primitive()->output_data_type == false
34 && "Output data type forcing is not supported for "
35 "index_select_node!");
36 auto desc = node.get_primitive();
38 auto input_layout = node.input().get_output_layout();
41 int32_t output_b = input_layout.size.batch[0];
42 int32_t output_f = input_layout.size.feature[0];
43 int32_t output_x = input_layout.size.spatial[0];
44 int32_t output_y = input_layout.size.spatial[1];
46 if (!node.get_reverse()) {
47 auto indices_layout = node.indices().get_output_layout();
48 auto indices_size = indices_layout.size.spatial[0];
49 auto axes = node.get_axes();
50 for (size_t i = 0; i < axes.size(); i++)
54 case index_select_axis_name::along_b:
55 output_b = indices_size;
57 case index_select_axis_name::along_f:
58 output_f = indices_size;
60 case index_select_axis_name::along_x:
61 output_x = indices_size;
63 case index_select_axis_name::along_y:
64 output_y = indices_size;
67 CLDNN_ERROR_MESSAGE(node.id(), "UNSUPPORTED AXIS");
72 return layout{ input_layout.data_type, input_layout.format, { output_b, output_f, output_x, output_y } };
75 std::string index_select_inst::to_string(index_select_node const& node)
77 auto desc = node.get_primitive();
78 auto node_info = node.desc_to_json();
79 std::stringstream primitive_description;
81 std::string axis_str = "";
82 for (size_t i = 0; i < desc->axis.size(); i++)
84 switch (desc->axis.at(i))
86 case index_select_axis_name::along_b:
87 axis_str += "along_b, ";
89 case index_select_axis_name::along_f:
90 axis_str += "along_f, ";
92 case index_select_axis_name::along_y:
93 axis_str += "along_y, ";
95 case index_select_axis_name::along_x:
96 axis_str += "along_x, ";
99 axis_str += "not supported axis, ";
104 json_composite index_select_info;
105 index_select_info.add("axes", axis_str);
107 node_info->add("index_select_info", index_select_info);
108 node_info->dump(primitive_description);
110 return primitive_description.str();
113 index_select_inst::typed_primitive_inst(network_impl& network, index_select_node const& node)
114 : parent(network, node)
116 auto& input = node.input();
117 auto input_layout = input.get_output_layout();
118 auto const node_id = node.id();
120 CLDNN_ERROR_NOT_PROPER_FORMAT(node_id, "input_format", input_layout.format, "supported input format", format::bfyx, format::yxfb);
122 if (!node.get_reverse())
124 auto& indices = node.indices();
125 auto indices_layout = indices.get_output_layout();
127 CLDNN_ERROR_DATA_TYPES_MISMATCH(node_id, "indicies data_type", indices_layout.data_type, "i32 data_type ", data_types::i32, "");
128 CLDNN_ERROR_NOT_EQUAL(node_id, "indicies batch_size", indices_layout.size.batch[0], "expected size", 1, "");
129 CLDNN_ERROR_NOT_EQUAL(node_id, "indicies feature_size", indices_layout.size.feature[0], "expected size", 1, "");
130 CLDNN_ERROR_NOT_EQUAL(node_id, "indicies y_size", indices_layout.size.spatial[1], "expected size", 1, "");
131 CLDNN_ERROR_LESS_THAN(node_id, "indicies x_size", indices_layout.size.spatial[0], "expected size", 1, "");
132 CLDNN_ERROR_NOT_PROPER_FORMAT(node_id, "input_format", indices_layout.format, "supported indicies format", format::bfyx, format::yxfb);