/**
 * Copyright (C) 2020 Samsung Electronics Co., Ltd. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *   http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * @brief  This is Batch Normalization Layer Class for Neural Network
 * @see    https://github.com/nnstreamer/nntrainer
 * @author Jijoong Moon <jijoong.moon@samsung.com>
 * @bug    No known bugs except for NYI items
 */

#include <bn_layer.h>
#include <lazy_tensor.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <parse_util.h>
#include <util_func.h>

namespace nntrainer {

static constexpr size_t SINGLE_INOUT_IDX = 0;
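
/**
 * BNParams indexes the entries requested from the layer context:
 * mu/var are the non-trainable moving mean/variance used at inference time,
 * gamma/beta are the trainable scale and shift, and deviation caches
 * (input - batch mean) from the forward pass for reuse in the backward pass.
 */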
enum BNParams { mu, var, gamma, beta, deviation };

/// @todo add multiple axis support
void BatchNormalizationLayer::finalize(InitLayerContext &context) {
  if (context.getNumInputs() != 1) {
    throw std::invalid_argument(
      "Only one input is allowed for batch normalization layer");
  }

  std::vector<TensorDim> output_dims(1);

  /** set output dimensions */
  auto const &in_dim = context.getInputDimensions()[0];
  context.setOutputDimensions(context.getInputDimensions());

  /// @note this logic cannot tell whether the channel is actually 1 or just not used.
  axis = in_dim.channel() > 1 ? 1 : 3;

  TensorDim dim;
  dim.setTensorDim(axis, in_dim.getTensorDim(axis));
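
  /**
   * Reduce over every axis except the normalization axis, so the batch
   * statistics (and gamma/beta) stay per-channel.
   */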
  for (int i = 0; i < 4; ++i) {
    if (axis != i)
      axes_to_reduce.push_back(i);
  }
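
  /**
   * The trailing bool of requestWeight marks trainability: the moving
   * statistics are updated manually in forwarding(), not by the optimizer,
   * so they are requested as non-trainable.
   */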
  wt_idx[BNParams::mu] = context.requestWeight(dim, initializers[BNParams::mu],
                                               WeightRegularizer::NONE, 1.0f,
                                               "BN::moving_mean", false);
  wt_idx[BNParams::var] = context.requestWeight(
    dim, initializers[BNParams::var], WeightRegularizer::NONE, 1.0f,
    "BN::moving_variance", false);
  wt_idx[BNParams::gamma] =
    context.requestWeight(dim, initializers[BNParams::gamma],
                          WeightRegularizer::NONE, 1.0f, "BN::gamma", true);
  wt_idx[BNParams::beta] =
    context.requestWeight(dim, initializers[BNParams::beta],
                          WeightRegularizer::NONE, 1.0f, "BN::beta", true);
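
  /**
   * BN::deviation is a forward-pass intermediate with ITERATION_LIFESPAN,
   * so calcDerivative()/calcGradient() can reuse it within the same
   * iteration without recomputing (input - mean).
   */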
  wt_idx[BNParams::deviation] =
    context.requestTensor(in_dim, "BN::deviation", false, ITERATION_LIFESPAN);
}

void BatchNormalizationLayer::setProperty(
  const std::vector<std::string> &values) {
  /// @todo: deprecate this in favor of loadProperties
  for (unsigned int i = 0; i < values.size(); ++i) {
    std::string key;
    std::string value;
    std::stringstream ss;

    if (getKeyValue(values[i], key, value) != ML_ERROR_NONE) {
      throw std::invalid_argument("Error parsing the property: " + values[i]);
    }

    if (value.empty()) {
      ss << "value is empty: key: " << key << ", value: " << value;
      throw std::invalid_argument(ss.str());
    }

    /// @note this calls derived setProperty if available
    setProperty(key, value);
  }
}
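
/**
 * Keys handled below: epsilon, momentum, moving_mean_initializer,
 * moving_variance_initializer, gamma_initializer, beta_initializer.
 */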
void BatchNormalizationLayer::setProperty(const std::string &type_str,
                                          const std::string &value) {
  using PropertyType = nntrainer::Layer::PropertyType;
  int status = ML_ERROR_NONE;
  nntrainer::Layer::PropertyType type =
    static_cast<nntrainer::Layer::PropertyType>(parseLayerProperty(type_str));

  switch (type) {
  case PropertyType::epsilon:
    status = setFloat(epsilon, value);
    throw_status(status);
    break;
  case PropertyType::moving_mean_initializer:
    initializers[BNParams::mu] =
      (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
    break;
  case PropertyType::moving_variance_initializer:
    initializers[BNParams::var] =
      (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
    break;
  case PropertyType::beta_initializer:
    initializers[BNParams::beta] =
      (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
    break;
  case PropertyType::gamma_initializer:
    initializers[BNParams::gamma] =
      (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
    break;
  case PropertyType::momentum:
    status = setFloat(momentum, value);
    throw_status(status);
    break;
  default: {
    std::string msg =
      "[BatchNormalizationLayer] Unknown Layer Property Key for value " +
      std::string(value);
    throw exception::not_supported(msg);
  }
  }
}
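
/**
 * Forward pass. In training mode this computes
 *   y = gamma * (x - E[x]) / sqrt(Var[x] + epsilon) + beta
 * over axes_to_reduce and updates the moving statistics; in inference mode
 * the stored moving mean/variance are used instead of batch statistics.
 */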
void BatchNormalizationLayer::forwarding(RunLayerContext &context,
                                         bool training) {
  Tensor &mu = context.getWeight(wt_idx[BNParams::mu]);
  Tensor &var = context.getWeight(wt_idx[BNParams::var]);
  Tensor &gamma = context.getWeight(wt_idx[BNParams::gamma]);
  Tensor &beta = context.getWeight(wt_idx[BNParams::beta]);

  Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
  Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
  Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);

  if (training) {
    /**
     * @todo support average with preallocated tensors,
     * and then register cmu as a temporary tensor
     */
    Tensor cmu = input_.average(axes_to_reduce);
    input_.subtract(cmu, deviation);

    cvar = deviation.pow(2.0f).average(axes_to_reduce);
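
    /** exponential moving average: s <- momentum * s + (1 - momentum) * batch_stat */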
    mu.multiply_i(momentum);
    mu.add_i(cmu, 1 - momentum);
    var.multiply_i(momentum);
    var.add_i(cvar, 1 - momentum);
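
    /** cache invstd = 1 / sqrt(var + epsilon); reused by the backward pass */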
    cvar.add_i(epsilon);
    invstd = cvar.pow(-0.5f);
  } else {
    deviation = input_.subtract(mu);
    invstd = var.add(epsilon);
    invstd.pow_i(-0.5f);
  }

  hidden_ = deviation.multiply(invstd, hidden_);
  hidden_.multiply_i(gamma);
  hidden_.add_i(beta);
}
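
/**
 * Input gradient. With N the number of elements reduced per statistic,
 *   dx = (gamma * invstd / N)
 *        * (N * dy - sum(dy) - (x - mu) / (var + eps) * sum(dy * (x - mu)))
 * which is the standard batch-norm backward over axes_to_reduce.
 */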
void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {

  Tensor &gamma = context.getWeight(wt_idx[BNParams::gamma]);
  Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
  Tensor &dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
  Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);

  int N = 1;
  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
  const TensorDim &in_dim = input.getDim();
  for (auto &axis : axes_to_reduce) {
    N *= in_dim.getTensorDim(axis);
  }
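
  /** accumulate the unscaled terms first; the common 1/N factor is applied once at the end */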
  Tensor dx_1 = gamma.multiply(invstd);
  Tensor dx_2 = deriv.multiply(N);
  dx_2.subtract_i(deriv.sum(axes_to_reduce));
  dx_2.subtract_i(deviation.divide(cvar).multiply(
    deviation.multiply(deriv).sum(axes_to_reduce)));

  dx = dx_2.multiply(dx_1, dx);
  dx.divide_i(N);
}
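
/**
 * Weight gradients: dbeta = sum(dy) and dgamma = sum(dy * x_hat), where
 * x_hat = deviation * invstd is the normalized input, summed over
 * axes_to_reduce.
 */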
void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {

  Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
  Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
  Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
  Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);

  dbeta = deriv.sum(axes_to_reduce);
  Tensor dev = deviation.multiply(invstd);
  dev.multiply_i(deriv);
  dgamma = dev.sum(axes_to_reduce);
}

} /* namespace nntrainer */
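
/**
 * Usage sketch (an assumption, not taken from this file): with nntrainer's
 * ini-style model description, the properties parsed above would be set as,
 * e.g.,
 *
 *   [bn1]
 *   Type = batch_normalization
 *   epsilon = 1.0e-4
 *   momentum = 0.99
 *
 * Exact key casing and section layout depend on the model-loader version.
 */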