2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "../C/batch_norm.h"
20 #include "primitive.hpp"
24 /// @addtogroup cpp_api C++ API
26 /// @addtogroup cpp_topology Network Topology
28 /// @addtogroup cpp_primitives Primitives
31 /// @brief Batch normalization primitive.
32 /// @details Performs batch normalization as described in
33 /// "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" by Ioffe, Szegedy
34 /// @n See: http://arxiv.org/abs/1502.03167
37 /// @n global stats can be computed as:
38 /// @n out[i] = (in[i] - mean[b]) / sqrt(variance[b] + epsilon)
40 struct batch_norm : public primitive_base<batch_norm, CLDNN_PRIMITIVE_DESC(batch_norm)>
42 CLDNN_DECLARE_PRIMITIVE(batch_norm)
44 /// @brief Constructs batch normalization primitive.
45 /// @param id This primitive id.
46 /// @param input Input primitive id.
47 /// @param mean Primitive id containing mean data.
48 /// @param variance Primitive id containing variance.
49 /// @param epsilon Epsilon.
51 const primitive_id& id,
52 const primitive_id& input,
53 const primitive_id& mean,
54 const primitive_id& variance,
56 const padding& output_padding = padding()
58 :primitive_base(id, {input}, output_padding)
66 /// @brief Constructs batch normalization primitive with mean and variance calculation (used for training).
67 /// @param id This primitive id.
68 /// @param input Input primitive id.
69 /// @param epsilon Epsilon.
70 /// @param inv_variance Primitive id containing inverted variance calculated in this primitive. For inference leave empty.
72 const primitive_id& id,
73 const primitive_id& input,
75 const primitive_id& inv_variance = "",
76 const padding& output_padding = padding()
78 :primitive_base(id, { input }, output_padding)
81 , inv_variance(inv_variance)
86 /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{batch_norm}
87 batch_norm(const dto* dto)
90 , variance(dto->variance)
91 , inv_variance(dto->inv_variance)
92 , epsilon(dto->epsilon)
96 /// @brief Primitive id containing mean data.
98 /// @brief Primitive id containing variance.
99 primitive_id variance;
100 /// @brief Primitive id containing inverted variance used in future gradient computing.
101 primitive_id inv_variance;
106 std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
108 if (!mean.empty() && !variance.empty())
109 return{ mean, variance };
110 else if (!inv_variance.empty())
111 return{ inv_variance };
116 void update_dto(dto& dto) const override
118 dto.mean = mean.c_str();
119 dto.variance = variance.c_str();
120 dto.inv_variance = inv_variance.c_str();
121 dto.epsilon = epsilon;