Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / arg_max_min.cpp
/*
// Copyright (c) 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "arg_max_min_inst.h"
19 #include "primitive_type_base.h"
20 #include "sliding_window_utils.h"
21 #include "error_handler.h"
22 #include "json_object.h"
23
24 namespace cldnn
25 {
26         primitive_type_id arg_max_min_type_id()
27         {
28                 static primitive_type_base<arg_max_min> instance;
29                 return &instance;
30         }
31
32         layout arg_max_min_inst::calc_output_layout(arg_max_min_node const& node)
33         {
34         assert((bool)node.get_primitive()->output_data_type == false
35                && "Output data type forcing is not supported for "
36                   "arg_max_min_node!");
37         auto desc = node.get_primitive();
38
39                 auto input_layout = node.input().get_output_layout();
40
41                 if (desc->with_axis){
42                         switch (desc->axis){
43                                 case arg_max_min::x:
44                                         return layout{ data_types::f32, format::bfyx, tensor{ input_layout.size.batch[0], input_layout.size.feature[0], (int32_t)desc->top_k, input_layout.size.spatial[1] }};
45                                 case arg_max_min::y:
46                                         return layout{ data_types::f32, format::bfyx, tensor{ input_layout.size.batch[0], input_layout.size.feature[0], input_layout.size.spatial[0], (int32_t)desc->top_k }};
47                                 case arg_max_min::feature:
48                                         return layout{ data_types::f32, format::bfyx, tensor{ input_layout.size.batch[0], (int32_t)desc->top_k, input_layout.size.spatial[0], input_layout.size.spatial[1] }};
49                                 case arg_max_min::batch:
50                                         return layout{ data_types::f32, format::bfyx, tensor{ (int32_t)desc->top_k, input_layout.size.feature[0], input_layout.size.spatial[0], input_layout.size.spatial[1] }};
51                                 default:
52                                         break;
53                         }
54                 }
55
56                 return layout{ data_types::f32, input_layout.format, tensor{ input_layout.size.batch[0], 1, (int32_t)desc->top_k, 1 } };
57         }
58
59         std::string arg_max_min_inst::to_string(arg_max_min_node const& node)
60         {
61                 auto desc = node.get_primitive();
62                 auto node_info = node.desc_to_json();
63                 auto axis = desc->with_axis ? "true" : "false";
64                 auto out_type = desc->output_type ? "max" : "min";
65
66                 std::stringstream primitive_description;
67
68                 json_composite conv_info;
69                 conv_info.add("top_k", desc->top_k);
70                 conv_info.add("with axis", axis);
71                 if (desc->with_axis)
72                         conv_info.add("axis", desc->axis);
73                 conv_info.add("output type", out_type);
74                 node_info->add("arg_max_min info", conv_info);
75                 node_info->dump(primitive_description);
76
77                 return primitive_description.str();
78         }
79
	// Instance constructor: delegates entirely to the typed base class;
	// arg_max_min performs no additional per-instance setup or validation.
	arg_max_min_inst::typed_primitive_inst(network_impl& network, arg_max_min_node const& node)
		: parent(network, node)
	{
	}
84 }