// Copyright (c) 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

///////////////////////////////////////////////////////////////////////////////////////////////////
#include "arg_max_min_inst.h"
#include "primitive_type_base.h"
#include "sliding_window_utils.h"
#include "error_handler.h"
#include "json_object.h"

#include <limits>
#include <string>
27 primitive_type_id arg_max_min::type_id() {
28 static primitive_type_base<arg_max_min> instance;
32 layout arg_max_min_inst::calc_output_layout(arg_max_min_node const& node) {
33 auto desc = node.get_primitive();
34 auto input_layout = node.input().get_output_layout();
35 auto output_data_type = desc->output_data_type ? *desc->output_data_type : input_layout.data_type;
36 auto size_check = [&](size_t tensor_size) {
38 // lowest integer not representable in floating point type = 2^(mantissa_bits + 1) + 1
39 // https://stackoverflow.com/questions/3793838/which-is-the-first-integer-that-an-ieee-754-float-is-incapable-of-representing-e
40 if (output_data_type == data_types::f32) {
41 max_size = (1 << std::numeric_limits<float>::digits);
42 } else if (output_data_type == data_types::f16) {
43 // mantissa_bits for fp16 = 10
45 } else if (output_data_type == data_types::u8) {
46 max_size = std::numeric_limits<uint8_t>::max();
48 max_size = std::numeric_limits<size_t>::max();
51 if (tensor_size > max_size) {
52 CLDNN_ERROR_GREATER_THAN(node.id(),
53 "Reduced tensor size",
55 "Maximum output data type value",
57 "Current output data type is unable to hold maximum index of a tensor.");
60 auto format = input_layout.format;
61 if (desc->with_axis) {
64 size_check(input_layout.size.spatial[0]);
65 if (format == cldnn::format::bfzyx)
66 return layout{output_data_type,
68 tensor{input_layout.size.batch[0],
69 input_layout.size.feature[0],
71 input_layout.size.spatial[1],
72 input_layout.size.spatial[2]}};
74 return layout{output_data_type,
76 tensor{input_layout.size.batch[0],
77 input_layout.size.feature[0],
79 input_layout.size.spatial[1]}};
81 size_check(input_layout.size.spatial[1]);
82 if (format == cldnn::format::bfzyx)
83 return layout{output_data_type,
85 tensor{input_layout.size.batch[0],
86 input_layout.size.feature[0],
87 input_layout.size.spatial[0],
89 input_layout.size.spatial[2]}};
91 return layout{output_data_type,
93 tensor{input_layout.size.batch[0],
94 input_layout.size.feature[0],
95 input_layout.size.spatial[0],
96 (int32_t)desc->top_k}};
97 case arg_max_min::feature:
98 size_check(input_layout.size.feature[0]);
99 if (format == cldnn::format::bfzyx)
100 return layout{output_data_type,
102 tensor{input_layout.size.batch[0],
103 (int32_t)desc->top_k,
104 input_layout.size.spatial[0],
105 input_layout.size.spatial[1],
106 input_layout.size.spatial[2]}};
108 return layout{output_data_type,
110 tensor{input_layout.size.batch[0],
111 (int32_t)desc->top_k,
112 input_layout.size.spatial[0],
113 input_layout.size.spatial[1]}};
114 case arg_max_min::batch:
115 size_check(input_layout.size.batch[0]);
116 if (format == cldnn::format::bfzyx)
117 return layout{output_data_type,
119 tensor{(int32_t)desc->top_k,
120 input_layout.size.feature[0],
121 input_layout.size.spatial[0],
122 input_layout.size.spatial[1],
123 input_layout.size.spatial[2]}};
125 return layout{output_data_type,
127 tensor{(int32_t)desc->top_k,
128 input_layout.size.feature[0],
129 input_layout.size.spatial[0],
130 input_layout.size.spatial[1]}};
132 size_check(input_layout.size.spatial[2]);
133 return layout{output_data_type,
135 tensor{input_layout.size.batch[0],
136 input_layout.size.feature[0],
137 input_layout.size.spatial[0],
138 input_layout.size.spatial[1],
139 (int32_t)desc->top_k}};
144 size_check(input_layout.size.feature[0] * input_layout.size.spatial[0] * input_layout.size.spatial[1]);
145 return layout{output_data_type,
147 tensor{input_layout.size.batch[0], 1, (int32_t)desc->top_k, 1}};
150 std::string arg_max_min_inst::to_string(arg_max_min_node const& node) {
151 auto desc = node.get_primitive();
152 auto node_info = node.desc_to_json();
153 auto axis = desc->with_axis ? "true" : "false";
154 auto out_type = desc->output_type ? "max" : "min";
156 std::stringstream primitive_description;
158 json_composite conv_info;
159 conv_info.add("top_k", desc->top_k);
160 conv_info.add("with axis", axis);
162 conv_info.add("axis", desc->axis);
163 conv_info.add("output type", out_type);
164 node_info->add("arg_max_min info", conv_info);
165 node_info->dump(primitive_description);
167 return primitive_description.str();
170 arg_max_min_inst::typed_primitive_inst(network_impl& network, arg_max_min_node const& node) : parent(network, node) {}