Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / one_hot.hpp
1 // Copyright (c) 2019 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 ///////////////////////////////////////////////////////////////////////////////////////////////////
16 #pragma once
17
18 #include "../C/one_hot.h"
19 #include "primitive.hpp"
20
21
22 namespace cldnn
23 {
24     /// @addtogroup cpp_api C++ API
25     /// @{
26     /// @addtogroup cpp_topology Network Topology
27     /// @{
28     /// @addtogroup cpp_primitives Primitives
29     /// @{
30
31     /// @brief Creates a one-hot encoding of the input.
32     /// @details Creates a one-hot encoding of the input, putting the new one-hot axis in the position
33     /// @n       specified by the @p one_hot_axis input, using the @p shape tensor as size reference.
34     /// @n       The size of @p shape must be appropriate for adding a one-hot axis to input. For example,
35     /// @n      <tt>input_sizes = (1, in_f, in_y, in_x)</tt> 
36     /// @n expanded with 
37     /// @n      <tt>one_hot_axis = 2</tt> 
38     /// @n would insert the one-hot axis in the Y dimension, requiring
39     /// @n      <tt>shape = (in_f, in_y, one-hot_limit, in_x)</tt> 
40     /// @n The output values would then be determined by input as
41     /// @n      <tt>output[f, y, i, x] = (input[0, f, y, x] == i) ? 1 : 0;</tt>
42     /// @n Since determining whether the input is appropriate (that the one-hot axis
43     /// @n has enough space to fully encode all inputs) requires scanning the whole
44     /// @n input, the primitive doesn't check for that, instead producing all-zeros
45     /// @n output axes for inputs below 0 and greater than the limit set by
46     /// @n @p shape.
47     /// @n
48     /// @n@b Requirements
49     /// @n - @p one_hot_axis must be within (inclusive) range 0 - 3.
50     /// @n - @p shape must fit input sizes (see example above).
51     /// @n - input batch size must be equal to 1.
52     /// @n
53     /// @n Breaking any of this conditions will cause exception throw.
54     struct one_hot : public primitive_base<one_hot, CLDNN_PRIMITIVE_DESC(one_hot)>
55     {
56         CLDNN_DECLARE_PRIMITIVE(one_hot)
57
58             /// @brief Constructs one-hot primitive / layer.
59             ///
60             /// @param id              An identifier of new primitive.
61             /// @param input           An identifier of primitive which is an input for newly created
62             ///                        one-hot primitive.
63             /// @param shape           Size of the output primitive.
64             /// @param one_hot_axis    One-hot axis position (0-based, from left to right) in shape.
65             /// @param output_padding  Optional padding for output from primitive.
66             one_hot(
67                 const primitive_id& id,
68                 const primitive_id& input,
69                 const tensor& shape,
70                 const uint16_t& one_hot_axis,
71                 const padding& output_padding = padding()
72             )
73             : primitive_base(id, { input }, output_padding),
74             shape(shape),
75             one_hot_axis(one_hot_axis)
76         {
77         }
78
79         /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{one_hot}
80         one_hot(const dto* dto)
81             : primitive_base(dto),
82             shape(dto->shape),
83             one_hot_axis(dto->one_hot_axis)
84         {
85         }
86
87         /// @brief Output size reference.
88         tensor shape;
89         /// @brief One-hot axis position in output shape (0-based, from left to right).
90         uint16_t one_hot_axis;
91
92     protected:
93         void update_dto(dto& dto) const override
94         {
95             dto.shape = shape;
96             dto.one_hot_axis = one_hot_axis;
97
98         }
99     };
100     /// @}
101     /// @}
102     /// @}
103 }