Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / one_hot.cpp
1 // Copyright (c) 2019 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15
16 #include "one_hot_inst.h"
17
18 #include "error_handler.h"
19 #include "json_object.h"
20 #include "primitive_type_base.h"
21
22
23 namespace cldnn
24 {
25     primitive_type_id one_hot_type_id()
26     {
27         static primitive_type_base<one_hot> instance;
28         return &instance;
29     }
30
31     layout one_hot_inst::calc_output_layout(one_hot_node const& node)
32     {
33         assert((bool)node.get_primitive()->output_data_type == false
34                && "Output data type forcing is not supported for one_hot_node!");
35         auto input_layout = node.input().get_output_layout();
36         auto desc = node.get_primitive();
37
38         if (desc->one_hot_axis > 3)
39         {
40             CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: one_hot_axis should be less or equal to 3.");
41         }
42
43         return{ input_layout.data_type, input_layout.format, desc->shape };
44     }
45
46     std::string one_hot_inst::to_string(one_hot_node const& node)
47     {
48         auto desc = node.get_primitive();
49         auto node_info = node.desc_to_json();
50         const auto& shape = desc->shape;
51         const auto& one_hot_axis = desc->one_hot_axis;
52         auto& input = node.input();
53
54         std::stringstream primitive_description;
55
56         json_composite one_hot_info;
57         one_hot_info.add("input id", input.id());
58         one_hot_info.add("output shape", shape.to_string());
59         one_hot_info.add("one-hot axis", one_hot_axis);
60
61         node_info->add("one_hot info", one_hot_info);
62         node_info->dump(primitive_description);
63
64         return primitive_description.str();
65     }
66
67     one_hot_inst::typed_primitive_inst(network_impl& network, one_hot_node const& node)
68         : parent(network, node)
69     {
70         auto input_layout = node.input().get_output_layout();
71
72         const auto& input_sizes = input_layout.size;
73         const auto& output_sizes = argument.shape;
74
75         std::vector<tensor::value_type> input_dims = { input_sizes.batch[0], input_sizes.feature[0],
76             input_sizes.spatial[1], input_sizes.spatial[0] };
77         std::vector<tensor::value_type> output_dims = { output_sizes.batch[0], output_sizes.feature[0],
78             output_sizes.spatial[1], output_sizes.spatial[0] };
79
80         const auto& one_hot_axis = node.get_primitive()->one_hot_axis;
81         if (input_dims[0] != 1)
82         {
83             CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: input batch size should be equal to 1.");
84         }
85
86         //bfyx format
87         for (int i = 3, j = 3; i > 0; --i, --j)
88         {
89             if (j == one_hot_axis)
90                 --j;
91             if (input_dims[i] != output_dims[j])
92             {
93                 CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: shape does not fit input size.");
94             }
95         }
96     }
97 }