Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / index_select.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "index_select_inst.h"
19 #include "primitive_type_base.h"
20 #include "error_handler.h"
21 #include "json_object.h"
22
23 namespace cldnn
24 {
25         primitive_type_id index_select_type_id()
26         {
27                 static primitive_type_base<index_select> instance;
28                 return &instance;
29         }
30
31         layout index_select_inst::calc_output_layout(index_select_node const& node)
32         {
33         assert((bool)node.get_primitive()->output_data_type == false
34                && "Output data type forcing is not supported for "
35                   "index_select_node!");
36         auto desc = node.get_primitive();
37
38         auto input_layout = node.input().get_output_layout();
39         
40         
41         int32_t output_b = input_layout.size.batch[0];
42         int32_t output_f = input_layout.size.feature[0];
43         int32_t output_x = input_layout.size.spatial[0];
44         int32_t output_y = input_layout.size.spatial[1];
45
46         if (!node.get_reverse()) {
47             auto indices_layout = node.indices().get_output_layout();
48             auto indices_size = indices_layout.size.spatial[0];
49             auto axes = node.get_axes();
50             for (size_t i = 0; i < axes.size(); i++)
51             {
52                 switch (axes[i])
53                 {
54                 case index_select_axis_name::along_b:
55                     output_b = indices_size;
56                     break;
57                 case index_select_axis_name::along_f:
58                     output_f = indices_size;
59                     break;
60                 case index_select_axis_name::along_x:
61                     output_x = indices_size;
62                     break;
63                 case index_select_axis_name::along_y:
64                     output_y = indices_size;
65                     break;
66                 default:
67                     CLDNN_ERROR_MESSAGE(node.id(), "UNSUPPORTED AXIS");
68                     break;
69                 }
70             }
71         }
72         return layout{ input_layout.data_type, input_layout.format, { output_b, output_f, output_x, output_y } };
73         }
74
75         std::string index_select_inst::to_string(index_select_node const& node)
76         {
77                 auto desc = node.get_primitive();
78                 auto node_info = node.desc_to_json();
79                 std::stringstream primitive_description;
80
81         std::string axis_str = "";
82         for (size_t i = 0; i < desc->axis.size(); i++)
83         {
84             switch (desc->axis.at(i))
85             {
86             case index_select_axis_name::along_b:
87                 axis_str += "along_b, ";
88                 break;
89             case index_select_axis_name::along_f:
90                 axis_str += "along_f, ";
91                 break;
92             case index_select_axis_name::along_y:
93                 axis_str += "along_y, ";
94                 break;
95             case index_select_axis_name::along_x:
96                 axis_str += "along_x, ";
97                 break;
98             default:
99                 axis_str += "not supported axis, ";
100                 break;
101             }
102         }
103
104         json_composite index_select_info;
105         index_select_info.add("axes", axis_str);
106
107         node_info->add("index_select_info", index_select_info);
108                 node_info->dump(primitive_description);
109
110                 return primitive_description.str();
111         }
112
113     index_select_inst::typed_primitive_inst(network_impl& network, index_select_node const& node)
114                 : parent(network, node)
115         {
116         auto& input = node.input();
117         auto input_layout = input.get_output_layout();
118         auto const node_id = node.id();
119
120         CLDNN_ERROR_NOT_PROPER_FORMAT(node_id, "input_format", input_layout.format, "supported input format", format::bfyx, format::yxfb);
121         
122         if (!node.get_reverse())
123         {
124             auto& indices = node.indices();
125             auto indices_layout = indices.get_output_layout();
126
127             CLDNN_ERROR_DATA_TYPES_MISMATCH(node_id, "indicies data_type", indices_layout.data_type, "i32 data_type ", data_types::i32, "");
128             CLDNN_ERROR_NOT_EQUAL(node_id, "indicies batch_size", indices_layout.size.batch[0], "expected size", 1, "");
129             CLDNN_ERROR_NOT_EQUAL(node_id, "indicies feature_size", indices_layout.size.feature[0], "expected size", 1, "");
130             CLDNN_ERROR_NOT_EQUAL(node_id, "indicies y_size", indices_layout.size.spatial[1], "expected size", 1, "");
131             CLDNN_ERROR_LESS_THAN(node_id, "indicies x_size", indices_layout.size.spatial[0], "expected size", 1, "");
132             CLDNN_ERROR_NOT_PROPER_FORMAT(node_id, "input_format", indices_layout.format, "supported indicies format", format::bfyx, format::yxfb);
133         }
134         }
135 }