1 // Copyright (c) 2019 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "../C/one_hot.h"
19 #include "primitive.hpp"
24 /// @addtogroup cpp_api C++ API
26 /// @addtogroup cpp_topology Network Topology
28 /// @addtogroup cpp_primitives Primitives
31 /// @brief Creates a one-hot encoding of the input.
32 /// @details Creates a one-hot encoding of the input, putting the new one-hot axis in the position
33 /// @n specified by the @p one_hot_axis input, using the @p shape tensor as size reference.
34 /// @n The size of @p shape must be appropriate for adding a one-hot axis to input. For example,
35 /// @n <tt>input_sizes = (1, in_f, in_y, in_x)</tt>
37 /// @n <tt>one_hot_axis = 2</tt>
38 /// @n would insert the one-hot axis in the Y dimension, requiring
39 /// @n <tt>shape = (in_f, in_y, one-hot_limit, in_x)</tt>
40 /// @n The output values would then be determined by input as
41 /// @n <tt>output[f, y, i, x] = (input[0, f, y, x] == i) ? 1 : 0;</tt>
42 /// @n Since determining whether the input is appropriate (that the one-hot axis
43 /// @n has enough space to fully encode all inputs) requires scanning the whole
44 /// @n input, the primitive doesn't check for that, instead producing all-zeros
45 /// @n output axes for inputs below 0 or greater than the limit set by @p shape.
49 /// @n - @p one_hot_axis must be within (inclusive) range 0 - 3.
50 /// @n - @p shape must fit input sizes (see example above).
51 /// @n - input batch size must be equal to 1.
53 /// @n Breaking any of these conditions will cause an exception to be thrown.
54 struct one_hot : public primitive_base<one_hot, CLDNN_PRIMITIVE_DESC(one_hot)>
56 CLDNN_DECLARE_PRIMITIVE(one_hot)
// NOTE(review): this excerpt appears to be missing lines (brace-only lines, the line naming the
// constructor, and everything that handles `shape`). The comments below describe only what is
// visible here -- confirm against the complete header before relying on them.
58 /// @brief Constructs one-hot primitive / layer.
60 /// @param id An identifier of the new primitive.
61 /// @param input An identifier of the primitive which is the input for the newly created
62 /// one-hot primitive.
63 /// @param shape Size of the output primitive.
64 /// @param one_hot_axis One-hot axis position (0-based, from left to right) in shape.
65 /// @param output_padding Optional padding for output from primitive; defaults to @c padding().
// NOTE(review): @p shape is documented above, but no matching parameter line is visible in this
// excerpt -- presumably it sits between `input` and `one_hot_axis`; verify.
67 const primitive_id& id,
68 const primitive_id& input,
70 const uint16_t& one_hot_axis,
71 const padding& output_padding = padding()
// The base class receives the id, a single-element input list and the output padding.
73 : primitive_base(id, { input }, output_padding),
// Stores the requested axis in the member of the same name.
75 one_hot_axis(one_hot_axis)
79 /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{one_hot}
/// @param dto Descriptor to copy from; it is forwarded to the base class and its
/// @c one_hot_axis field is read to initialize the member of the same name.
80 one_hot(const dto* dto)
81 : primitive_base(dto),
83 one_hot_axis(dto->one_hot_axis)
87 /// @brief Output size reference.
// NOTE(review): the member this comment documents is not visible in this excerpt.
89 /// @brief One-hot axis position in output shape (0-based, from left to right).
90 uint16_t one_hot_axis;
/// @brief Writes this primitive's @c one_hot_axis into @p dto.
/// @param dto Descriptor to update.
93 void update_dto(dto& dto) const override
96 dto.one_hot_axis = one_hot_axis;