[LayerV1] Delete for LayerV1
[platform/core/ml/nntrainer.git] / nntrainer / layers / bn_layer.cpp
1 /**
2  * Copyright (C) 2020 Samsung Electronics Co., Ltd. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *   http://www.apache.org/licenses/LICENSE-2.0
8  * Unless required by applicable law or agreed to in writing, software
9  * distributed under the License is distributed on an "AS IS" BASIS,
10  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11  * See the License for the specific language governing permissions and
12  * limitations under the License.
13  *
14  *
15  * @file        bn_layer.cpp
16  * @date        14 May 2020
17  * @brief       This is Batch Normalization Layer Class for Neural Network
18  * @see         https://github.com/nnstreamer/nntrainer
19  * @author      Jijoong Moon <jijoong.moon@samsung.com>
20  * @bug         No known bugs except for NYI items
21  *
22  */
23
24 #include <bn_layer.h>
25 #include <lazy_tensor.h>
26 #include <nntrainer_error.h>
27 #include <nntrainer_log.h>
28 #include <parse_util.h>
29 #include <util_func.h>
30
31 namespace nntrainer {
32
33 static constexpr size_t SINGLE_INOUT_IDX = 0;
34
35 enum BNParams { mu, var, gamma, beta, deviation };
36
37 /// @todo add multiple axis support
38 void BatchNormalizationLayer::finalize(InitLayerContext &context) {
39   if (context.getNumInputs() != 1) {
40     throw std::invalid_argument(
41       "Only one input is allowed for batch normalization layer");
42   }
43
44   std::vector<TensorDim> output_dims(1);
45
46   /** set output dimensions */
47   auto const &in_dim = context.getInputDimensions()[0];
48   context.setOutputDimensions(context.getInputDimensions());
49
50   TensorDim dim;
51
52   /// @note this logic cannot tell channel is actually 1 or it is just not used.
53   if (axis == -1)
54     axis = in_dim.channel() > 1 ? 1 : 3;
55
56   dim.setTensorDim(axis, in_dim.getTensorDim(axis));
57
58   for (int i = 0; i < 4; ++i) {
59     if (axis != i)
60       axes_to_reduce.push_back(i);
61   }
62
63   wt_idx[BNParams::mu] = context.requestWeight(dim, initializers[BNParams::mu],
64                                                WeightRegularizer::NONE, 1.0f,
65                                                "BN::moving_mean", false);
66   wt_idx[BNParams::var] = context.requestWeight(
67     dim, initializers[BNParams::var], WeightRegularizer::NONE, 1.0f,
68     "BN::moving_variance", false);
69   wt_idx[BNParams::gamma] =
70     context.requestWeight(dim, initializers[BNParams::gamma],
71                           WeightRegularizer::NONE, 1.0f, "BN::gamma", true);
72   wt_idx[BNParams::beta] =
73     context.requestWeight(dim, initializers[BNParams::beta],
74                           WeightRegularizer::NONE, 1.0f, "BN::beta", true);
75
76   wt_idx[BNParams::deviation] =
77     context.requestTensor(in_dim, "BN::deviation", false, ITERATION_LIFESPAN);
78 }
79
80 void BatchNormalizationLayer::setProperty(
81   const std::vector<std::string> &values) {
82   /// @todo: deprecate this in favor of loadProperties
83   for (unsigned int i = 0; i < values.size(); ++i) {
84     std::string key;
85     std::string value;
86     std::stringstream ss;
87
88     if (getKeyValue(values[i], key, value) != ML_ERROR_NONE) {
89       throw std::invalid_argument("Error parsing the property: " + values[i]);
90     }
91
92     if (value.empty()) {
93       ss << "value is empty: key: " << key << ", value: " << value;
94       throw std::invalid_argument(ss.str());
95     }
96
97     /// @note this calls derived setProperty if available
98     setProperty(key, value);
99   }
100 }
101
102 void BatchNormalizationLayer::setProperty(const std::string &type_str,
103                                           const std::string &value) {
104   using PropertyType = nntrainer::Layer::PropertyType;
105   int status = ML_ERROR_NONE;
106   nntrainer::Layer::PropertyType type =
107     static_cast<nntrainer::Layer::PropertyType>(parseLayerProperty(type_str));
108
109   switch (type) {
110   case PropertyType::epsilon:
111     status = setFloat(epsilon, value);
112     throw_status(status);
113     break;
114   case PropertyType::moving_mean_initializer:
115     initializers[BNParams::mu] =
116       (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
117     break;
118   case PropertyType::moving_variance_initializer:
119     initializers[BNParams::var] =
120       (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
121     break;
122   case PropertyType::beta_initializer:
123     initializers[BNParams::beta] =
124       (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
125     break;
126   case PropertyType::gamma_initializer:
127     initializers[BNParams::gamma] =
128       (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
129     break;
130   case PropertyType::momentum:
131     status = setFloat(momentum, value);
132     throw_status(status);
133     break;
134   default:
135     std::string msg =
136       "[BatchNormalizationLayer] Unknown Layer Property Key for value " +
137       std::string(value);
138     throw exception::not_supported(msg);
139   }
140 }
141
142 void BatchNormalizationLayer::forwarding(RunLayerContext &context,
143                                          bool training) {
144   Tensor &mu = context.getWeight(wt_idx[BNParams::mu]);
145   Tensor &var = context.getWeight(wt_idx[BNParams::var]);
146   Tensor &gamma = context.getWeight(wt_idx[BNParams::gamma]);
147   Tensor &beta = context.getWeight(wt_idx[BNParams::beta]);
148
149   Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
150   Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
151   Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);
152
153   if (training) {
154     /**
155      * @todo support average with preallocated tensors,
156      * and then register cmu as a temporary tensor
157      */
158     Tensor cmu = input_.average(axes_to_reduce);
159     input_.subtract(cmu, deviation);
160
161     cvar = deviation.pow(2.0f).average(axes_to_reduce);
162
163     mu.multiply_i(momentum);
164     mu.add_i(cmu, 1 - momentum);
165     var.multiply_i(momentum);
166     var.add_i(cvar, 1 - momentum);
167
168     cvar.add_i(epsilon);
169     invstd = cvar.pow(-0.5f);
170   } else {
171     deviation = input_.subtract(mu);
172     invstd = var.add(epsilon);
173     invstd.pow_i(-0.5f);
174   }
175
176   hidden_ = deviation.multiply(invstd, hidden_);
177   hidden_.multiply_i(gamma);
178   hidden_.add_i(beta);
179 }
180
181 void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
182
183   Tensor &gamma = context.getWeight(wt_idx[BNParams::gamma]);
184   Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
185   Tensor &dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
186   Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);
187
188   int N = 1;
189   const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
190   const TensorDim &in_dim = input.getDim();
191   for (auto &axis : axes_to_reduce) {
192     N *= in_dim.getTensorDim(axis);
193   }
194
195   Tensor dx_1 = gamma.multiply(invstd);
196   Tensor dx_2 = deriv.multiply(N);
197   dx_2.subtract_i(deriv.sum(axes_to_reduce));
198   dx_2.subtract_i(deviation.divide(cvar).multiply(
199     deviation.multiply(deriv).sum(axes_to_reduce)));
200
201   dx = dx_2.multiply(dx_1, dx);
202   dx.divide_i(N);
203 }
204
205 void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
206
207   Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
208   Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
209   Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
210   Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);
211
212   dbeta = deriv.sum(axes_to_reduce);
213   Tensor dev = deviation.multiply(invstd);
214   dev.multiply_i(deriv);
215   dgamma = dev.sum(axes_to_reduce);
216 }
217
218 } /* namespace nntrainer */