Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / batch_norm.hpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "../C/batch_norm.h"
20 #include "primitive.hpp"
21
22 namespace cldnn
23 {
24 /// @addtogroup cpp_api C++ API
25 /// @{
26 /// @addtogroup cpp_topology Network Topology
27 /// @{
28 /// @addtogroup cpp_primitives Primitives
29 /// @{
30
31 /// @brief Batch normalization primitive.
32 /// @details Performs batch normalization as discribed in
33 /// "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" by Ioffe, Szegedy
34 /// @n See: http://arxiv.org/abs/1502.03167
35 /// 
36 /// <b>Algorithm:</b>
37 /// @n global stats can be computed as:
38 /// @n out[i] = (in[i] - mean[b]) / sqrt(variance[b] + epsilon)
39
40 struct batch_norm : public primitive_base<batch_norm, CLDNN_PRIMITIVE_DESC(batch_norm)>
41 {
42     CLDNN_DECLARE_PRIMITIVE(batch_norm)
43
44     /// @brief Constructs batch normalization primitive.
45     /// @param id This primitive id.
46     /// @param input Input primitive id.
47     /// @param mean Primitive id containing mean data.
48     /// @param variance Primitive id containing variance.
49     /// @param epsilon Epsilon.
50     batch_norm(
51         const primitive_id& id,
52         const primitive_id& input,
53         const primitive_id& mean,
54         const primitive_id& variance,
55         float epsilon,
56         const padding& output_padding = padding()
57     )
58         :primitive_base(id, {input}, output_padding)
59         , mean(mean)
60         , variance(variance)
61         , inv_variance("")
62         , epsilon(epsilon)
63     {
64     }
65
66     /// @brief Constructs batch normalization primitive with mean and variance calculation (used for training).
67     /// @param id This primitive id.
68     /// @param input Input primitive id.
69     /// @param epsilon Epsilon.
70     /// @param inv_variance Primitive id containing inverted variance calculated in this primitive. For inference leave empty.
71     batch_norm(
72         const primitive_id& id,
73         const primitive_id& input,
74         float epsilon,
75         const primitive_id& inv_variance = "",
76         const padding& output_padding = padding()
77     )
78         :primitive_base(id, { input }, output_padding)
79         , mean("")
80         , variance("")
81         , inv_variance(inv_variance)
82         , epsilon(epsilon)
83     {
84     }
85
86     /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{batch_norm}
87     batch_norm(const dto* dto)
88         :primitive_base(dto)
89         , mean(dto->mean)
90         , variance(dto->variance)
91         , inv_variance(dto->inv_variance)
92         , epsilon(dto->epsilon)
93     {
94     }
95
96     /// @brief Primitive id containing mean data.
97     primitive_id mean;
98     /// @brief Primitive id containing variance.
99     primitive_id variance;
100     /// @brief Primitive id containing inverted variance used in future gradient computing.
101     primitive_id inv_variance;
102     /// @brief Epsilon.
103     float epsilon;
104
105 protected:
106     std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override 
107     { 
108         if (!mean.empty() && !variance.empty())
109             return{ mean, variance };
110         else if (!inv_variance.empty())
111             return{ inv_variance };
112         else
113             return{};
114     }
115
116     void update_dto(dto& dto) const override
117     {
118         dto.mean = mean.c_str();
119         dto.variance = variance.c_str();
120         dto.inv_variance = inv_variance.c_str();
121         dto.epsilon = epsilon;
122     }
123 };
124 /// @}
125 /// @}
126 /// @}
127 }