2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "arg_max_min_inst.h"
19 #include "primitive_type_base.h"
20 #include "sliding_window_utils.h"
21 #include "error_handler.h"
22 #include "json_object.h"
26 primitive_type_id arg_max_min_type_id()
28 static primitive_type_base<arg_max_min> instance;
32 layout arg_max_min_inst::calc_output_layout(arg_max_min_node const& node)
34 assert((bool)node.get_primitive()->output_data_type == false
35 && "Output data type forcing is not supported for "
37 auto desc = node.get_primitive();
39 auto input_layout = node.input().get_output_layout();
44 return layout{ data_types::f32, format::bfyx, tensor{ input_layout.size.batch[0], input_layout.size.feature[0], (int32_t)desc->top_k, input_layout.size.spatial[1] }};
46 return layout{ data_types::f32, format::bfyx, tensor{ input_layout.size.batch[0], input_layout.size.feature[0], input_layout.size.spatial[0], (int32_t)desc->top_k }};
47 case arg_max_min::feature:
48 return layout{ data_types::f32, format::bfyx, tensor{ input_layout.size.batch[0], (int32_t)desc->top_k, input_layout.size.spatial[0], input_layout.size.spatial[1] }};
49 case arg_max_min::batch:
50 return layout{ data_types::f32, format::bfyx, tensor{ (int32_t)desc->top_k, input_layout.size.feature[0], input_layout.size.spatial[0], input_layout.size.spatial[1] }};
56 return layout{ data_types::f32, input_layout.format, tensor{ input_layout.size.batch[0], 1, (int32_t)desc->top_k, 1 } };
59 std::string arg_max_min_inst::to_string(arg_max_min_node const& node)
61 auto desc = node.get_primitive();
62 auto node_info = node.desc_to_json();
63 auto axis = desc->with_axis ? "true" : "false";
64 auto out_type = desc->output_type ? "max" : "min";
66 std::stringstream primitive_description;
68 json_composite conv_info;
69 conv_info.add("top_k", desc->top_k);
70 conv_info.add("with axis", axis);
72 conv_info.add("axis", desc->axis);
73 conv_info.add("output type", out_type);
74 node_info->add("arg_max_min info", conv_info);
75 node_info->dump(primitive_description);
77 return primitive_description.str();
80 arg_max_min_inst::typed_primitive_inst(network_impl& network, arg_max_min_node const& node)
81 : parent(network, node)