Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / arg_max_min.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "arg_max_min_inst.h"
19 #include "primitive_type_base.h"
20 #include "sliding_window_utils.h"
21 #include "error_handler.h"
22 #include "json_object.h"
23 #include <string>
24 #include <limits>
25
26 namespace cldnn {
27 primitive_type_id arg_max_min::type_id() {
28     static primitive_type_base<arg_max_min> instance;
29     return &instance;
30 }
31
32 layout arg_max_min_inst::calc_output_layout(arg_max_min_node const& node) {
33     auto desc = node.get_primitive();
34     auto input_layout = node.input().get_output_layout();
35     auto output_data_type = desc->output_data_type ? *desc->output_data_type : input_layout.data_type;
36     auto size_check = [&](size_t tensor_size) {
37         size_t max_size;
38         // lowest integer not representable in floating point type = 2^(mantissa_bits + 1) + 1
39         // https://stackoverflow.com/questions/3793838/which-is-the-first-integer-that-an-ieee-754-float-is-incapable-of-representing-e
40         if (output_data_type == data_types::f32) {
41             max_size = (1 << std::numeric_limits<float>::digits);
42         } else if (output_data_type == data_types::f16) {
43             // mantissa_bits for fp16 = 10
44             max_size = (1 << 11);
45         } else if (output_data_type == data_types::u8) {
46             max_size = std::numeric_limits<uint8_t>::max();
47         } else {
48             max_size = std::numeric_limits<size_t>::max();
49         }
50
51         if (tensor_size > max_size) {
52             CLDNN_ERROR_GREATER_THAN(node.id(),
53                                      "Reduced tensor size",
54                                      tensor_size,
55                                      "Maximum output data type value",
56                                      max_size,
57                                      "Current output data type is unable to hold maximum index of a tensor.");
58         }
59     };
60     auto format = input_layout.format;
61     if (desc->with_axis) {
62         switch (desc->axis) {
63             case arg_max_min::x:
64                 size_check(input_layout.size.spatial[0]);
65                 if (format == cldnn::format::bfzyx)
66                     return layout{output_data_type,
67                                   format::bfzyx,
68                                   tensor{input_layout.size.batch[0],
69                                          input_layout.size.feature[0],
70                                          (int32_t)desc->top_k,
71                                          input_layout.size.spatial[1],
72                                          input_layout.size.spatial[2]}};
73                 else
74                     return layout{output_data_type,
75                                   format,
76                                   tensor{input_layout.size.batch[0],
77                                          input_layout.size.feature[0],
78                                          (int32_t)desc->top_k,
79                                          input_layout.size.spatial[1]}};
80             case arg_max_min::y:
81                 size_check(input_layout.size.spatial[1]);
82                 if (format == cldnn::format::bfzyx)
83                     return layout{output_data_type,
84                                   format::bfzyx,
85                                   tensor{input_layout.size.batch[0],
86                                          input_layout.size.feature[0],
87                                          input_layout.size.spatial[0],
88                                          (int32_t)desc->top_k,
89                                          input_layout.size.spatial[2]}};
90                 else
91                     return layout{output_data_type,
92                                   format,
93                                   tensor{input_layout.size.batch[0],
94                                          input_layout.size.feature[0],
95                                          input_layout.size.spatial[0],
96                                          (int32_t)desc->top_k}};
97             case arg_max_min::feature:
98                 size_check(input_layout.size.feature[0]);
99                 if (format == cldnn::format::bfzyx)
100                     return layout{output_data_type,
101                                   format::bfzyx,
102                                   tensor{input_layout.size.batch[0],
103                                          (int32_t)desc->top_k,
104                                          input_layout.size.spatial[0],
105                                          input_layout.size.spatial[1],
106                                          input_layout.size.spatial[2]}};
107                 else
108                     return layout{output_data_type,
109                                   format,
110                                   tensor{input_layout.size.batch[0],
111                                          (int32_t)desc->top_k,
112                                          input_layout.size.spatial[0],
113                                          input_layout.size.spatial[1]}};
114             case arg_max_min::batch:
115                 size_check(input_layout.size.batch[0]);
116                 if (format == cldnn::format::bfzyx)
117                     return layout{output_data_type,
118                                   format::bfzyx,
119                                   tensor{(int32_t)desc->top_k,
120                                          input_layout.size.feature[0],
121                                          input_layout.size.spatial[0],
122                                          input_layout.size.spatial[1],
123                                          input_layout.size.spatial[2]}};
124                 else
125                     return layout{output_data_type,
126                                   format,
127                                   tensor{(int32_t)desc->top_k,
128                                          input_layout.size.feature[0],
129                                          input_layout.size.spatial[0],
130                                          input_layout.size.spatial[1]}};
131             case arg_max_min::z:
132                 size_check(input_layout.size.spatial[2]);
133                 return layout{output_data_type,
134                               format::bfzyx,
135                               tensor{input_layout.size.batch[0],
136                                      input_layout.size.feature[0],
137                                      input_layout.size.spatial[0],
138                                      input_layout.size.spatial[1],
139                                      (int32_t)desc->top_k}};
140             default:
141                 break;
142         }
143     }
144     size_check(input_layout.size.feature[0] * input_layout.size.spatial[0] * input_layout.size.spatial[1]);
145     return layout{output_data_type,
146                   input_layout.format,
147                   tensor{input_layout.size.batch[0], 1, (int32_t)desc->top_k, 1}};
148 }
149
150 std::string arg_max_min_inst::to_string(arg_max_min_node const& node) {
151     auto desc = node.get_primitive();
152     auto node_info = node.desc_to_json();
153     auto axis = desc->with_axis ? "true" : "false";
154     auto out_type = desc->output_type ? "max" : "min";
155
156     std::stringstream primitive_description;
157
158     json_composite conv_info;
159     conv_info.add("top_k", desc->top_k);
160     conv_info.add("with axis", axis);
161     if (desc->with_axis)
162         conv_info.add("axis", desc->axis);
163     conv_info.add("output type", out_type);
164     node_info->add("arg_max_min info", conv_info);
165     node_info->dump(primitive_description);
166
167     return primitive_description.str();
168 }
169
170 arg_max_min_inst::typed_primitive_inst(network_impl& network, arg_max_min_node const& node) : parent(network, node) {}
171 }  // namespace cldnn