1 // Copyright (c) 2019 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
16 #include "one_hot_inst.h"
18 #include "error_handler.h"
19 #include "json_object.h"
20 #include "primitive_type_base.h"
// Accessor for the primitive type id associated with the one_hot primitive.
// NOTE(review): only the signature and the static object declaration are visible in this
// view; the return statement is not shown. Presumably the function returns the address of
// `instance` - confirm against the full file.
25 primitive_type_id one_hot_type_id()
// Function-local static object of type primitive_type_base<one_hot>, created once and
// reused by every call.
27 static primitive_type_base<one_hot> instance;
31 layout one_hot_inst::calc_output_layout(one_hot_node const& node)
33 assert((bool)node.get_primitive()->output_data_type == false
34 && "Output data type forcing is not supported for one_hot_node!");
35 auto input_layout = node.input().get_output_layout();
36 auto desc = node.get_primitive();
38 if (desc->one_hot_axis > 3)
40 CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: one_hot_axis should be less or equal to 3.");
43 return{ input_layout.data_type, input_layout.format, desc->shape };
46 std::string one_hot_inst::to_string(one_hot_node const& node)
48 auto desc = node.get_primitive();
49 auto node_info = node.desc_to_json();
50 const auto& shape = desc->shape;
51 const auto& one_hot_axis = desc->one_hot_axis;
52 auto& input = node.input();
54 std::stringstream primitive_description;
56 json_composite one_hot_info;
57 one_hot_info.add("input id", input.id());
58 one_hot_info.add("output shape", shape.to_string());
59 one_hot_info.add("one-hot axis", one_hot_axis);
61 node_info->add("one_hot info", one_hot_info);
62 node_info->dump(primitive_description);
64 return primitive_description.str();
// Constructor of the one_hot primitive instance.
// Forwards `network` and `node` to the parent class, then validates that the input sizes
// are compatible with the requested output shape, reporting problems through
// CLDNN_ERROR_MESSAGE.
// NOTE(review): some short lines of this definition are not visible in this view (see the
// note at the axis check below); the comments here describe only the visible statements.
67 one_hot_inst::typed_primitive_inst(network_impl& network, one_hot_node const& node)
68 : parent(network, node)
70 auto input_layout = node.input().get_output_layout();
72 const auto& input_sizes = input_layout.size;
// `argument` is not declared in this view - presumably the primitive descriptor exposed by
// the parent class; confirm.
73 const auto& output_sizes = argument.shape;
// Flatten both tensors into vectors ordered as: batch, feature, spatial[1], spatial[0].
75 std::vector<tensor::value_type> input_dims = { input_sizes.batch[0], input_sizes.feature[0],
76 input_sizes.spatial[1], input_sizes.spatial[0] };
77 std::vector<tensor::value_type> output_dims = { output_sizes.batch[0], output_sizes.feature[0],
78 output_sizes.spatial[1], output_sizes.spatial[0] };
80 const auto& one_hot_axis = node.get_primitive()->one_hot_axis;
// The input batch dimension (index 0) is required to be exactly 1.
81 if (input_dims[0] != 1)
83 CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: input batch size should be equal to 1.");
// Walk the input dimensions from index 3 down to 1 (index 0 was checked above); `j` is the
// index of the output dimension compared against input dimension `i`.
87 for (int i = 3, j = 3; i > 0; --i, --j)
// NOTE(review): the statement controlled by this `if` is not visible in this view.
// Presumably it moves `j` past the one-hot axis so that axis is excluded from the
// comparison - confirm against the full file.
89 if (j == one_hot_axis)
// A mismatch between an input dimension and its paired output dimension is an error.
91 if (input_dims[i] != output_dims[j])
93 CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: shape does not fit input size.");