From: Parichay Kapoor
Date: Mon, 5 Jul 2021 07:07:49 +0000 (+0900)
Subject: [batchnorm] Update to LayerV2
X-Git-Tag: accepted/tizen/unified/20210829.234903~164
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=07bdd3e50c9322b483f462742b5bbe22c3f795dd;p=platform%2Fcore%2Fml%2Fnntrainer.git

[batchnorm] Update to LayerV2

Update batch norm to the layer v2 style. Add the corresponding common
unit tests, and enable the modelfile and models unit tests.

Signed-off-by: Parichay Kapoor
---
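For reviewers less familiar with the v2 interface this patch targets: a v1 layer owned its tensors directly (weights, net_input, net_hidden), while a v2 layer only declares what it needs in finalize() through InitLayerContext and then fetches everything by index from RunLayerContext in forwarding()/calcGradient()/calcDerivative(). A minimal sketch of that flow, using only the context calls that appear in this patch; the layer name and members below are illustrative, not part of the patch:

    /* Illustrative v2-style layer skeleton; not part of this patch. */
    class ExampleScaleLayer : public Layer {
    public:
      void finalize(InitLayerContext &context) override {
        auto const &in_dim = context.getInputDimensions()[0];
        context.setOutputDimensions(context.getInputDimensions());

        /* declare a trainable weight; the context returns an index to use later */
        scale_idx = context.requestWeight(in_dim, WeightInitializer::WEIGHT_ONES,
                                          WeightRegularizer::NONE, 1.0f,
                                          "scale", true);
      }

      void forwarding(RunLayerContext &context, bool training) override {
        Tensor &scale = context.getWeight(scale_idx);
        Tensor &input_ = context.getInput(0);
        Tensor &hidden_ = context.getOutput(0);

        /* write the result into the output tensor provided by the context */
        input_.multiply(scale, hidden_);
      }

    private:
      unsigned int scale_idx;
    };

Batch norm below follows exactly this pattern, with four requested weights (moving mean, moving variance, gamma, beta) tracked through weight_idx.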
diff --git a/nntrainer/layers/bn_layer.cpp b/nntrainer/layers/bn_layer.cpp
index d631076..413af45 100644
--- a/nntrainer/layers/bn_layer.cpp
+++ b/nntrainer/layers/bn_layer.cpp
@@ -21,7 +21,6 @@
  *
  */
 
-#include
 #include
 #include
 #include
@@ -32,117 +31,127 @@ namespace nntrainer {
 
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
 enum BNParams { mu, var, gamma, beta };
 
 /// @todo add multiple axis support
-int BatchNormalizationLayer::initialize(Manager &manager) {
-  int status = ML_ERROR_NONE;
-
-  if (getNumInputs() != 1) {
+void BatchNormalizationLayer::finalize(InitLayerContext &context) {
+  if (context.getNumInputs() != 1) {
     throw std::invalid_argument(
       "Only one input is allowed for batch normalization layer");
   }
 
-  output_dim[0] = input_dim[0];
+  std::vector<TensorDim> output_dims(1);
+
+  /** set output dimensions */
+  auto const &in_dim = context.getInputDimensions()[0];
+  context.setOutputDimensions(context.getInputDimensions());
+
   TensorDim dim;
 
   /// @note this logic cannot tell channel is actually 1 or it is just not used.
   if (axis == -1)
-    axis = input_dim[0].channel() > 1 ? 1 : 3;
 
-  dim.setTensorDim(axis, input_dim[0].getTensorDim(axis));
+    axis = in_dim.channel() > 1 ? 1 : 3;
+
+  dim.setTensorDim(axis, in_dim.getTensorDim(axis));
 
   for (int i = 0; i < 4; ++i) {
     if (axis != i)
       axes_to_reduce.push_back(i);
   }
 
-  weights.clear();
-  if (weights.empty()) {
-    weights.reserve(4);
-    weights.emplace_back(dim, initializers[BNParams::mu],
-                         WeightRegularizer::NONE, 1.0f, false, false,
-                         "BN::moving_mean");
-    weights.emplace_back(dim, initializers[BNParams::var],
-                         WeightRegularizer::NONE, 1.0f, false, false,
-                         "BN::moving_variance");
-    weights.emplace_back(dim, initializers[BNParams::gamma],
-                         WeightRegularizer::NONE, 1.0f, true, false,
-                         "BN::gamma");
-    weights.emplace_back(dim, initializers[BNParams::beta],
-                         WeightRegularizer::NONE, 1.0f, true, false,
-                         "BN::beta");
-    manager.trackWeights(weights);
-  } else {
-    weights[BNParams::mu].reset(dim, initializers[BNParams::mu],
-                                WeightRegularizer::NONE, 1.0f, false);
-    weights[BNParams::var].reset(dim, initializers[BNParams::var],
-                                 WeightRegularizer::NONE, 1.0f, false);
-    weights[BNParams::gamma].reset(dim, initializers[BNParams::gamma],
-                                   WeightRegularizer::NONE, 1.0f, true);
-    weights[BNParams::beta].reset(dim, initializers[BNParams::beta],
-                                  WeightRegularizer::NONE, 1.0f, true);
-  }
+  weight_idx[BNParams::mu] = context.requestWeight(
+    dim, initializers[BNParams::mu], WeightRegularizer::NONE, 1.0f,
+    "BN::moving_mean", false);
+  weight_idx[BNParams::var] = context.requestWeight(
+    dim, initializers[BNParams::var], WeightRegularizer::NONE, 1.0f,
+    "BN::moving_variance", false);
+  weight_idx[BNParams::gamma] =
+    context.requestWeight(dim, initializers[BNParams::gamma],
+                          WeightRegularizer::NONE, 1.0f, "BN::gamma", true);
+  weight_idx[BNParams::beta] =
+    context.requestWeight(dim, initializers[BNParams::beta],
+                          WeightRegularizer::NONE, 1.0f, "BN::beta", true);
+}
+
+void BatchNormalizationLayer::setProperty(
+  const std::vector<std::string> &values) {
+  /// @todo: deprecate this in favor of loadProperties
+  for (unsigned int i = 0; i < values.size(); ++i) {
+    std::string key;
+    std::string value;
+    std::stringstream ss;
+
+    if (getKeyValue(values[i], key, value) != ML_ERROR_NONE) {
+      throw std::invalid_argument("Error parsing the property: " + values[i]);
+    }
+
+    if (value.empty()) {
+      ss << "value is empty: key: " << key << ", value: " << value;
+      throw std::invalid_argument(ss.str());
+    }
 
-  return status;
+    /// @note this calls derived setProperty if available
+    setProperty(key, value);
+  }
 }
 
-void BatchNormalizationLayer::setProperty(const PropertyType type,
+void BatchNormalizationLayer::setProperty(const std::string &type_str,
                                           const std::string &value) {
+  using PropertyType = LayerV1::PropertyType;
   int status = ML_ERROR_NONE;
 
+  LayerV1::PropertyType type =
+    static_cast<PropertyType>(parseLayerProperty(type_str));
+
   switch (type) {
   case PropertyType::epsilon:
-    if (!value.empty()) {
-      status = setFloat(epsilon, value);
-      throw_status(status);
-    }
+    status = setFloat(epsilon, value);
+    throw_status(status);
     break;
   case PropertyType::moving_mean_initializer:
-    if (!value.empty()) {
-      initializers[BNParams::mu] =
-        (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
-    }
+    initializers[BNParams::mu] =
+      (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
     break;
   case PropertyType::moving_variance_initializer:
-    if (!value.empty()) {
-      initializers[BNParams::var] =
-        (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
-    }
+    initializers[BNParams::var] =
+      (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
    break;
   case PropertyType::beta_initializer:
-    if (!value.empty()) {
-      initializers[BNParams::beta] =
-        (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
-    }
+    initializers[BNParams::beta] =
+      (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
     break;
   case PropertyType::gamma_initializer:
-    if (!value.empty()) {
-      initializers[BNParams::gamma] =
-        (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
-    }
+    initializers[BNParams::gamma] =
+      (WeightInitializer)parseType(value, TOKEN_WEIGHT_INIT);
     break;
   case PropertyType::momentum:
-    if (!value.empty()) {
-      status = setFloat(momentum, value);
-      throw_status(status);
-    }
+    status = setFloat(momentum, value);
+    throw_status(status);
     break;
   default:
-    LayerV1::setProperty(type, value);
-    break;
+    std::string msg =
+      "[BatchNormalizationLayer] Unknown Layer Property Key for value " +
+      std::string(value);
+    throw exception::not_supported(msg);
   }
 }
 
-void BatchNormalizationLayer::forwarding(bool training) {
-  Tensor &mu = weightAt(BNParams::mu).getVariableRef();
-  Tensor &var = weightAt(BNParams::var).getVariableRef();
-  Tensor &gamma = weightAt(BNParams::gamma).getVariableRef();
-  Tensor &beta = weightAt(BNParams::beta).getVariableRef();
+void BatchNormalizationLayer::forwarding(RunLayerContext &context,
+                                         bool training) {
+  Tensor &mu = context.getWeight(weight_idx[BNParams::mu]);
+  Tensor &var = context.getWeight(weight_idx[BNParams::var]);
+  Tensor &gamma = context.getWeight(weight_idx[BNParams::gamma]);
+  Tensor &beta = context.getWeight(weight_idx[BNParams::beta]);
 
-  Tensor &input_ = net_input[0]->getVariableRef();
-  Tensor &hidden_ = net_hidden[0]->getVariableRef();
+  Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
+  Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
 
   if (training) {
+    /**
+     * @todo support average with preallocated tensors,
+     * and then register cmu as a temporary tensor
+     */
     Tensor cmu = input_.average(axes_to_reduce);
     deviation = input_.subtract(cmu);
 
@@ -166,14 +175,17 @@ void BatchNormalizationLayer::forwarding(bool training) {
   hidden_.add_i(beta);
 }
 
-void BatchNormalizationLayer::calcDerivative() {
+void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
 
-  Tensor &gamma = weightAt(BNParams::gamma).getVariableRef();
-  Tensor &deriv = net_hidden[0]->getGradientRef();
+  Tensor &gamma = context.getWeight(weight_idx[BNParams::gamma]);
+  Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
+  Tensor &dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
 
   int N = 1;
+  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
+  const TensorDim &in_dim = input.getDim();
 
   for (auto &axis : axes_to_reduce) {
-    N *= input_dim[0].getTensorDim(axis);
+    N *= in_dim.getTensorDim(axis);
   }
 
   Tensor dx_1 = gamma.multiply(invstd);
@@ -182,16 +194,15 @@ void BatchNormalizationLayer::calcDerivative() {
   dx_2.subtract_i(deviation.divide(cvar).multiply(
     deviation.multiply(deriv).sum(axes_to_reduce)));
 
-  Tensor &dx = net_input[0]->getGradientRef();
   dx = dx_2.multiply(dx_1, dx);
   dx.divide_i(N);
 }
 
-void BatchNormalizationLayer::calcGradient() {
+void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
 
-  Tensor &dgamma = weightAt(BNParams::gamma).getGradientRef();
-  Tensor &dbeta = weightAt(BNParams::beta).getGradientRef();
-  Tensor &deriv = net_hidden[0]->getGradientRef();
+  Tensor &dgamma = context.getWeightGrad(weight_idx[BNParams::gamma]);
+  Tensor &dbeta = context.getWeightGrad(weight_idx[BNParams::beta]);
+  Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
 
   dbeta = deriv.sum(axes_to_reduce);
   Tensor dev = deviation.multiply(invstd);
@@ -199,12 +210,4 @@ void BatchNormalizationLayer::calcGradient() {
   dgamma = dev.sum(axes_to_reduce);
 }
 
-void BatchNormalizationLayer::copy(std::shared_ptr<LayerV1> l) {
-  LayerV1::copy(l);
-
-  std::shared_ptr<BatchNormalizationLayer> from =
-    std::static_pointer_cast<BatchNormalizationLayer>(l);
-  this->cvar.copy(from->cvar);
-}
-
 } /* namespace nntrainer */
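For reference while reviewing calcDerivative()/calcGradient(): assuming, as the member names suggest, that forwarding left deviation = x - mu, cvar = sigma^2 + epsilon and invstd = 1/sqrt(cvar), the sums above correspond to the standard batch-norm backward formulas, reduced over the N elements of axes_to_reduce. Writing x_hat_i = (x_i - mu) * invstd:

    \frac{\partial L}{\partial \beta} = \sum_i \frac{\partial L}{\partial y_i}, \qquad
    \frac{\partial L}{\partial \gamma} = \sum_i \frac{\partial L}{\partial y_i}\,\hat{x}_i

    \frac{\partial L}{\partial x_i} = \frac{\gamma \cdot \mathrm{invstd}}{N}
      \Big( N \frac{\partial L}{\partial y_i}
            - \sum_j \frac{\partial L}{\partial y_j}
            - \hat{x}_i \sum_j \frac{\partial L}{\partial y_j}\,\hat{x}_j \Big)

Here dx_1 is gamma * invstd and dx_2 is the bracketed term (its initialization falls in the part of the hunk the diff elides); the visible subtract_i line is the last term, since deviation / cvar * sum(deviation * deriv) equals x_hat * sum(deriv * x_hat). The final divide_i(N) supplies the 1/N factor.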
diff --git a/nntrainer/layers/bn_layer.h b/nntrainer/layers/bn_layer.h
index 572c5c9..6c39934 100644
--- a/nntrainer/layers/bn_layer.h
+++ b/nntrainer/layers/bn_layer.h
@@ -28,8 +28,7 @@
 #include
 #include
 
-#include
-#include
+#include
 
 namespace nntrainer {
 
@@ -37,21 +36,19 @@ namespace nntrainer {
  * @class   BatchNormalizationLayer
  * @brief   Batch Normalization Layer
  */
-class BatchNormalizationLayer : public LayerV1 {
+class BatchNormalizationLayer : public Layer {
 public:
   /**
    * @brief     Constructor of Batch Normalization Layer
    */
-  template <typename... Args>
   BatchNormalizationLayer(
     int axis = -1, float momentum = 0.99, float epsilon = 0.001,
     WeightInitializer moving_mean_initializer = WeightInitializer::WEIGHT_ZEROS,
     WeightInitializer moving_variance_initializer =
       WeightInitializer::WEIGHT_ZEROS,
     WeightInitializer gamma_initializer = WeightInitializer::WEIGHT_ONES,
-    WeightInitializer beta_initializer = WeightInitializer::WEIGHT_ONES,
-    Args... args) :
-    LayerV1(args...),
+    WeightInitializer beta_initializer = WeightInitializer::WEIGHT_ONES) :
+    Layer(),
     epsilon(epsilon),
     momentum(momentum),
     axis(axis),
@@ -76,53 +73,57 @@ public:
   BatchNormalizationLayer &operator=(BatchNormalizationLayer &&rhs) = default;
 
   /**
-   * @copydoc Layer::forwarding(bool training)
+   * @copydoc Layer::finalize(InitLayerContext &context)
    */
-  void forwarding(bool training = true) override;
+  void finalize(InitLayerContext &context) override;
 
   /**
-   * @copydoc Layer::calcDerivative()
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
    */
-  void calcDerivative() override;
+  void forwarding(RunLayerContext &context, bool training) override;
 
   /**
-   * @copydoc Layer::calcGradient()
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
    */
-  void calcGradient() override;
+  void calcDerivative(RunLayerContext &context) override;
 
   /**
-   * @brief     copy layer
-   * @param[in] l layer to copy
+   * @copydoc Layer::calcGradient(RunLayerContext &context)
    */
-  void copy(std::shared_ptr<LayerV1> l) override;
+  void calcGradient(RunLayerContext &context) override;
 
   /**
-   * @brief     initialize layer
-   * @retval #ML_ERROR_NONE Successful.
-   * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
+   * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
    */
-  int initialize(Manager &manager) override;
+  void exportTo(Exporter &exporter,
+                const ExportMethods &method) const override {
+    Layer::exportTo(exporter, method);
+  }
 
   /**
    * @copydoc Layer::getType()
    */
   const std::string getType() const override {
     return BatchNormalizationLayer::type;
-  }
+  };
 
   /**
-   * @copydoc Layer::supportInPlace()
+   * @copydoc Layer::supportBackwarding()
    */
-  bool supportInPlace() const override { return true; }
+  bool supportBackwarding() const { return true; }
 
-  using LayerV1::setProperty;
+  using Layer::setProperty;
 
   /**
    * @copydoc Layer::setProperty(const PropertyType type, const std::string
    * &value)
    */
-  void setProperty(const PropertyType type,
-                   const std::string &value = "") override;
+  void setProperty(const std::vector<std::string> &values) override;
+
+  /**
+   * @copydoc Layer::supportInPlace()
+   */
+  bool supportInPlace() const override { return true; }
 
   inline static const std::string type = "batch_normalization";
 
@@ -139,6 +140,17 @@ private:
   std::vector<unsigned int> axes_to_reduce;      /**< target axes to reduce */
   std::array<WeightInitializer, 4> initializers; /**< weight initializers */
+  std::array<unsigned int, 4> weight_idx; /**< indices of the weights */
+
+  /**
+   * @brief setProperty by type and value separated
+   * @param[in] type property type to be passed
+   * @param[in] value value to be passed
+   * @exception exception::not_supported when property type is not valid for
+   * the particular layer
+   * @exception std::invalid_argument invalid argument
+   */
+  void setProperty(const std::string &type, const std::string &value);
 };
 
 } // namespace nntrainer
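A caller-side illustration of the string-based property path declared above (a hypothetical snippet, not part of the patch): properties arrive as "key=value" strings, are split by getKeyValue(), and are then dispatched to the per-key overload, so an unrecognized key is now rejected by the layer itself instead of falling through to LayerV1:

    /* Hypothetical usage; the keys mirror the PropertyType cases handled above,
     * and the factory call assumes the createLayer<> helper used in the new
     * unit test returns an object exposing setProperty(). */
    auto bn = nntrainer::createLayer<nntrainer::BatchNormalizationLayer>();
    bn->setProperty({"epsilon=0.001", "momentum=0.99"});
    // bn->setProperty({"not_a_property=1"}); /* would be rejected as unsupported */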
diff --git a/nntrainer/layers/fc_layer.cpp b/nntrainer/layers/fc_layer.cpp
index 8920c14..25ae338 100644
--- a/nntrainer/layers/fc_layer.cpp
+++ b/nntrainer/layers/fc_layer.cpp
@@ -22,10 +22,10 @@
  */
 
 #include
-#include
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -67,7 +67,7 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
 
 void FullyConnectedLayer::exportTo(Exporter &exporter,
                                    const ExportMethods &method) const {
-  Layer::exportTo(exporter, method);
+  LayerImpl::exportTo(exporter, method);
   exporter.saveResult(fc_props, method, this);
 }
 
diff --git a/nntrainer/layers/fc_layer.h b/nntrainer/layers/fc_layer.h
index 2856eee..23839e5 100644
--- a/nntrainer/layers/fc_layer.h
+++ b/nntrainer/layers/fc_layer.h
@@ -96,8 +96,9 @@ public:
   inline static const std::string type = "fully_connected";
 
 private:
-  std::tuple<props::Unit> fc_props;
-  std::array<unsigned int, 2> weight_idx;
+  std::tuple<props::Unit>
+    fc_props; /**< fc layer properties : unit - number of output neurons */
+  std::array<unsigned int, 2> weight_idx; /**< indices of the weights */
 };
 
 } // namespace nntrainer
diff --git a/nntrainer/layers/input_layer.cpp b/nntrainer/layers/input_layer.cpp
index 0e7e577..b7823fa 100644
--- a/nntrainer/layers/input_layer.cpp
+++ b/nntrainer/layers/input_layer.cpp
@@ -68,7 +68,7 @@ void InputLayer::setProperty(const std::string &type_str,
     } break;
   default:
     std::string msg =
-      "[Layer] Unknown Layer Property Key for value " + std::string(value);
+      "[InputLayer] Unknown Layer Property Key for value " + std::string(value);
     throw exception::not_supported(msg);
   }
 }
diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h
index 5aed2fc..acd2aa0 100644
--- a/nntrainer/layers/layer_context.h
+++ b/nntrainer/layers/layer_context.h
@@ -328,6 +328,16 @@ public:
   }
 
   /**
+   * @brief check if the weight has gradient
+   *
+   * @param idx Identifier of the weight
+   * @return true if weight has gradient, else false
+   */
+  bool weightHasGradient(unsigned int idx) const {
+    return weights[idx]->getTrainable();
+  }
+
+  /**
    * @brief Get the Output tensor object
    *
    * @param idx Identifier of the output
diff --git a/nntrainer/layers/layer_node.h b/nntrainer/layers/layer_node.h
index f58b139..34f1197 100644
--- a/nntrainer/layers/layer_node.h
+++ b/nntrainer/layers/layer_node.h
@@ -466,8 +466,14 @@ public:
    */
   Weight getWeightWrapper(unsigned int idx) {
     if (layerv1 == nullptr) {
-      return Weight(run_context.getWeight(idx), run_context.getWeightGrad(idx),
-                    run_context.getWeightName(idx));
+      if (run_context.weightHasGradient(idx)) {
+        return Weight(run_context.getWeight(idx),
+                      run_context.getWeightGrad(idx),
+                      run_context.getWeightName(idx));
+      } else {
+        return Weight(run_context.getWeight(idx), Tensor(),
+                      run_context.getWeightName(idx));
+      }
     } else {
       return getLayer()->getWeightsRef()[idx];
     }
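The two hunks above are what let batch norm's statistics flow through the generic weight path: moving mean and moving variance are requested with trainable set to false, so asking the run context for their gradient would be invalid. A summary of how the four BN weights map through getWeightWrapper() after this change (an illustrative comment, not part of the patch):

    /*
     * weight               requestWeight(..., trainable)   getWeightWrapper() returns
     * BN::moving_mean      false                            Weight(var, Tensor(), name)
     * BN::moving_variance  false                            Weight(var, Tensor(), name)
     * BN::gamma            true                             Weight(var, grad,     name)
     * BN::beta             true                             Weight(var, grad,     name)
     */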
diff --git a/nntrainer/tensor/tensor.h b/nntrainer/tensor/tensor.h
index 84a2338..28e6576 100644
--- a/nntrainer/tensor/tensor.h
+++ b/nntrainer/tensor/tensor.h
@@ -19,6 +19,7 @@
  * @author Jijoong Moon
  * @bug    No known bugs except for NYI items
  *
+ * @todo   deprecate new tensor allocation for out of place operations.
  */
 
 #ifndef __TENSOR_H__
diff --git a/nntrainer/tensor/var_grad.h b/nntrainer/tensor/var_grad.h
index c4fa320..c3e32e0 100644
--- a/nntrainer/tensor/var_grad.h
+++ b/nntrainer/tensor/var_grad.h
@@ -83,10 +83,13 @@ public:
            const std::string &n = "") :
     dim(v.getDim()),
     var(std::make_shared<Tensor>(v.getSharedDataTensor(dim, 0, false))),
-    grad(std::make_shared<Tensor>(g.getSharedDataTensor(dim, 0, false))),
+    grad(std::make_shared<Tensor>()),
     need_gradient(!g.uninitialized()),
     alloc_now(v.isAllocated()),
-    name(n) {}
+    name(n) {
+    if (trainable)
+      grad = std::make_shared<Tensor>(g.getSharedDataTensor(dim, 0, false));
+  }
 
   /**
    * @brief Copy constructor for Var_Grad
diff --git a/test/unittest/layers/meson.build b/test/unittest/layers/meson.build
index ef13489..3d7a64e 100644
--- a/test/unittest/layers/meson.build
+++ b/test/unittest/layers/meson.build
@@ -5,6 +5,7 @@ test_target = [
   'unittest_layers_input.cpp',
   'unittest_layers_loss.cpp',
   'unittest_layers_fully_connected.cpp',
+  'unittest_layers_batch_normalization.cpp',
 ]
 
 exe = executable(
diff --git a/test/unittest/layers/unittest_layers_batch_normalization.cpp b/test/unittest/layers/unittest_layers_batch_normalization.cpp
new file mode 100644
index 0000000..3f257c1
--- /dev/null
+++ b/test/unittest/layers/unittest_layers_batch_normalization.cpp
@@ -0,0 +1,24 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor
+ *
+ * @file unittest_layers_batch_normalization.cpp
+ * @date 15 June 2021
+ * @brief Batch Normalization Layer Test
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor
+ * @bug No known bugs except for NYI items
+ */
+#include
+
+#include
+
+#include
+#include
+
+auto semantic_bn = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::BatchNormalizationLayer>,
+  nntrainer::BatchNormalizationLayer::type, {}, {}, 0, false);
+
+INSTANTIATE_TEST_CASE_P(BatchNormalization, LayerSemantics,
+                        ::testing::Values(semantic_bn));
diff --git a/test/unittest/unittest_nntrainer_modelfile.cpp b/test/unittest/unittest_nntrainer_modelfile.cpp
index badbf40..5cb3b2b 100644
--- a/test/unittest/unittest_nntrainer_modelfile.cpp
+++ b/test/unittest/unittest_nntrainer_modelfile.cpp
@@ -315,8 +315,8 @@ INSTANTIATE_TEST_CASE_P(
     mkIniTc("basic5_p", {nw_base_cross, adam, input, out+"input_layers=inputlayer"}, SUCCESS),
     mkIniTc("basic6_p", {nw_base_cross, sgd, input, out+"input_layers=inputlayer"}, SUCCESS),
     mkIniTc("basic_act_p", {nw_base_cross, sgd, input + "-Activation", act_relu+"input_layers=inputlayer", out+"input_layers=activation_relu" }, SUCCESS),
-    // mkIniTc("basic_bn_p", {nw_base_cross, sgd, input + "-Activation", batch_normal+"input_layers=inputlayer", act_relu+"input_layers=bn", out+"input_layers=activation_relu" }, SUCCESS),
-    // mkIniTc("basic_bn2_p", {nw_base_cross, sgd, input + "-Activation", batch_normal + "Activation = relu"+"input_layers=inputlayer", out+"input_layers=bn" }, SUCCESS),
+    mkIniTc("basic_bn_p", {nw_base_cross, sgd, input + "-Activation", batch_normal+"input_layers=inputlayer", act_relu+"input_layers=bn", out+"input_layers=activation_relu" }, SUCCESS),
+    mkIniTc("basic_bn2_p", {nw_base_cross, sgd, input + "-Activation", batch_normal + "Activation = relu"+"input_layers=inputlayer", out+"input_layers=bn" }, SUCCESS),
     mkIniTc("basic_dataset_p", {nw_base_cross, adam, dataset, input, out+"input_layers=inputlayer"}, SUCCESS),
     mkIniTc("basic_dataset2_p", {nw_base_cross, sgd, input, out+"input_layers=inputlayer", dataset}, SUCCESS),
     mkIniTc("basic_dataset3_p", {dataset, nw_base_cross, sgd, input, out+"input_layers=inputlayer"}, SUCCESS),
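For context on the re-enabled modelfile cases: basic_bn_p builds an input -> batch normalization -> relu activation -> output graph purely from INI sections. A rough, illustrative sketch of the INI such a case expands to (section names follow the input_layers references in the test strings above; the property values here are placeholders, the real ones come from the test helpers):

    ; illustrative only - the real sections come from the unittest helper strings
    [inputlayer]
    Type = input

    [bn]
    Type = batch_normalization
    input_layers = inputlayer

    [activation_relu]
    Type = activation
    Activation = relu
    input_layers = bn

    [outputlayer]
    Type = fully_connected
    input_layers = activation_relu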
diff --git a/test/unittest/unittest_nntrainer_models.cpp b/test/unittest/unittest_nntrainer_models.cpp
index 94a631b..481a1bc 100644
--- a/test/unittest/unittest_nntrainer_models.cpp
+++ b/test/unittest/unittest_nntrainer_models.cpp
@@ -1276,9 +1276,9 @@ INSTANTIATE_TEST_CASE_P(
   nntrainerModelAutoTests, nntrainerModelTest, ::testing::Values(
     mkModelTc(fc_sigmoid_mse, "3:1:1:10", 1),
     mkModelTc(fc_sigmoid_cross, "3:1:1:10", 1),
-    mkModelTc(fc_relu_mse, "3:1:1:2", 1)
-    // mkModelTc(fc_bn_sigmoid_cross, "3:1:1:10", 10),
-    // mkModelTc(fc_bn_sigmoid_mse, "3:1:1:10", 10),
+    mkModelTc(fc_relu_mse, "3:1:1:2", 1),
+    mkModelTc(fc_bn_sigmoid_cross, "3:1:1:10", 10),
+    mkModelTc(fc_bn_sigmoid_mse, "3:1:1:10", 10)
     // mkModelTc(mnist_conv_cross, "3:1:1:10", 10),
     // mkModelTc(mnist_conv_cross_one_input, "1:1:1:10", 10),