From: Parichay Kapoor Date: Thu, 3 Dec 2020 08:38:28 +0000 (+0900) Subject: [optimizer] Move optimizer out of layer X-Git-Tag: accepted/tizen/unified/20201217.124219~13 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5afcb52dd90ef0c9fb203ab6c9a721822de7b52f;p=platform%2Fcore%2Fml%2Fnntrainer.git [optimizer] Move optimizer out of layer This patch moves the optimizer out of the layer. Backwarding now only calculates the derivatives and the gradients but does not apply them; applying the gradients is done by the model. Layers still support the applyGradient operation but require the optimizer as an argument. This decouples layers from optimizers so that each can operate independently. **Self evaluation:** 1. Build test: [x]Passed [ ]Failed [ ]Skipped 2. Run test: [x]Passed [ ]Failed [ ]Skipped Signed-off-by: Parichay Kapoor --- diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp index 12f8b36..dc4e6a0 100644 --- a/nntrainer/graph/network_graph.cpp +++ b/nntrainer/graph/network_graph.cpp @@ -535,19 +535,21 @@ void NetworkGraph::backwarding(sharedConstTensors output, int iteration) { i--) { LayerNode &layer_node = Sorted[i]; if (istrequal(layer_node.layer->getType(), nntrainer::LossLayer::type)) { - layer_node.layer->backwarding(iteration, output); + layer_node.layer->backwarding(output); } else { - layer_node.layer->backwarding(iteration); + layer_node.layer->backwarding(); } + optimizer->apply_gradients(layer_node.layer->getWeightsRef(), iteration); } /** The last trainable layer need not calculate the derivatives */ - // Oder is matter here. 1. calcGradient 2.Derivative & Gradient + // Order is matter here. 1. calcGradient 2.Derivative & Gradient Sorted[skip_non_trainable_layers + 1].layer->calcGradient(); #ifdef ENABLE_TEST Sorted[skip_non_trainable_layers + 1].layer->calcDerivative(); #endif - Sorted[skip_non_trainable_layers + 1].layer->applyGradient(iteration); + optimizer->apply_gradients( + Sorted[skip_non_trainable_layers + 1].layer->getWeightsRef(), iteration); } std::vector NetworkGraph::getInputDimension() { diff --git a/nntrainer/graph/network_graph.h b/nntrainer/graph/network_graph.h index 604b58a..b40db74 100644 --- a/nntrainer/graph/network_graph.h +++ b/nntrainer/graph/network_graph.h @@ -197,6 +197,8 @@ public: */ void backwarding(sharedConstTensors input, int iteration); + void setOptimizer(std::shared_ptr opt) { optimizer = opt; } + /** * @brief getter of ordered graph * @retval ordered LayerNode list @@ -240,6 +242,8 @@ private: skip_non_trainable_layers; /**< denotes the number of non-trainable layers at the start of the graph */ + std::shared_ptr optimizer; + /** * @brief Calculate the number of non-trainable layers at the start */ diff --git a/nntrainer/layers/layer.cpp b/nntrainer/layers/layer.cpp index 15089f2..69254c2 100644 --- a/nntrainer/layers/layer.cpp +++ b/nntrainer/layers/layer.cpp @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include @@ -37,12 +37,6 @@ void Layer::setActivation(ActivationType acti) { activation_type = acti; } -int Layer::setOptimizer(std::shared_ptr opt) { - this->opt = opt; - this->opt->addOptimizerVariable(weights); - return ML_ERROR_NONE; -} - int Layer::checkValidation() { int status = ML_ERROR_NONE; @@ -82,8 +76,6 @@ void Layer::copy(std::shared_ptr l) { for (auto const &w : weights) weights.push_back(w.clone()); - // TODO: fix this #630 - this->opt = l->opt; this->input_dim = l->input_dim; this->output_dim = l->output_dim; this->input.copy(l->input); @@ -119,9 +111,10 @@ 
sharedConstTensors Layer::forwarding_with_val(sharedConstTensors input) { return out; } -sharedConstTensors Layer::backwarding_with_val(int iteration, - sharedConstTensors deriv, - sharedConstTensors in) { +sharedConstTensors +Layer::backwarding_with_val(int iteration, sharedConstTensors deriv, + sharedConstTensors in, + std::shared_ptr optimizer) { for (unsigned int i = 0; i < num_outputs; ++i) { net_hidden[i]->var = deriv[i]->clone(); @@ -133,11 +126,12 @@ sharedConstTensors Layer::backwarding_with_val(int iteration, // TODO Need to fix to use LossLayer::type instead of "loss". But cyclic // includes! if (istrequal(getType(), "loss")) { - backwarding(iteration, in); + backwarding(in); } else { - backwarding(iteration, deriv); + backwarding(deriv); } + applyGradient(iteration, optimizer); nntrainer::sharedConstTensors out; for (unsigned int i = 0; i < num_inputs; ++i) { @@ -151,22 +145,12 @@ void Layer::read(std::ifstream &file) { for (auto &weight : weights) { weight.getVariableRef().read(file); } - if (opt) - opt->read(file); } void Layer::save(std::ofstream &file) { for (auto &weight : weights) { weight.getVariableRef().save(file); } - if (opt) - opt->save(file); -} - -void Layer::applyGradient(unsigned int iteration) { - if (trainable && !weights.empty()) { - opt->apply_gradients(weights, iteration); - } } int Layer::setProperty(std::vector values) { diff --git a/nntrainer/layers/layer_internal.h b/nntrainer/layers/layer_internal.h index ebdbdb7..fd7bec3 100644 --- a/nntrainer/layers/layer_internal.h +++ b/nntrainer/layers/layer_internal.h @@ -134,22 +134,26 @@ public: * @brief Apply the gradient for the layer * @param[in] iteration Iteration value for the Optimizer */ - virtual void applyGradient(unsigned int iteration); + virtual void applyGradient(unsigned int iteration, + std::shared_ptr optimizer) { + if (optimizer) + optimizer->apply_gradients(weights, iteration); + } /** * @brief Back Propagate the derivative to the previous layer * @param[in] in List of Derivative Tensor from the next layer * @retval Derivative List of Tensor for the previous layer */ - virtual void backwarding(int iteration, sharedConstTensors in = {}) { + virtual void backwarding(sharedConstTensors in = {}) { calcGradient(in); calcDerivative(in); - applyGradient(iteration); } - virtual sharedConstTensors backwarding_with_val(int iteration, - sharedConstTensors deriv, - sharedConstTensors in = {}); + virtual sharedConstTensors + backwarding_with_val(int iteration, sharedConstTensors deriv, + sharedConstTensors in = {}, + std::shared_ptr optimizer = nullptr); /** * @brief read layer Weight & Bias data from file @@ -189,21 +193,6 @@ public: const std::string &value = ""); /** - * @brief Optimizer Setter - * @param[in] opt Optimizer - * @retval #ML_ERROR_NONE Successful. - * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. - */ - int setOptimizer(std::shared_ptr opt); - - /** - * @brief Get the Optimizer object - * - * @return std::shared_ptr optimizer - */ - std::shared_ptr getOptimizer() { return opt; } - - /** * @brief Activation Type Getter * @retval Activation Type. 
*/ @@ -334,6 +323,12 @@ public: std::vector getDerivatives(); + /** + * @brief Get reference to the weights + * @retval Reference of the list of weights in the layer + */ + std::vector &getWeightsRef() { return weights; } + #ifdef ENABLE_TEST void resizeNetInput(unsigned int size) { net_input.resize(size); } @@ -410,12 +405,6 @@ protected: std::vector output_dim; /** - * @brief Optimizer for this layer - */ - // TODO: fix with #630 - std::shared_ptr opt; - - /** * @brief Loss value added by this layer */ float loss; diff --git a/nntrainer/models/neuralnet.cpp b/nntrainer/models/neuralnet.cpp index d1bb54d..2a1204a 100644 --- a/nntrainer/models/neuralnet.cpp +++ b/nntrainer/models/neuralnet.cpp @@ -220,12 +220,7 @@ int NeuralNetwork::initialize() { status = l.initialize(manager); NN_RETURN_STATUS(); - if (istrequal(cur_type, BatchNormalizationLayer::type) || - istrequal(cur_type, Conv2DLayer::type) || - istrequal(cur_type, FullyConnectedLayer::type)) { - status = l.setOptimizer(opt); - NN_RETURN_STATUS(); - } + opt->addOptimizerVariable(l.getWeightsRef()); } for (unsigned int i = 0; i < model_graph.Sorted.back().layer->num_outputs; @@ -239,6 +234,8 @@ int NeuralNetwork::initialize() { manager.initialize(); + model_graph.setOptimizer(opt); + initialized = true; return status; } diff --git a/nntrainer/optimizers/adam.h b/nntrainer/optimizers/adam.h index 42abeea..e32e0fe 100644 --- a/nntrainer/optimizers/adam.h +++ b/nntrainer/optimizers/adam.h @@ -39,8 +39,7 @@ public: * @copydoc apply_gradient(Weight &weight, int tensor_idx, double updated_lr, * int iteration) */ - void apply_gradient(Weight &weight, double updated_lr, - int iteration); + void apply_gradient(Weight &weight, double updated_lr, int iteration); /** * @copydoc Optimizer::getType() @@ -81,7 +80,6 @@ public: static const std::string type; private: - double beta1; /** momentum for grad */ double beta2; /** momentum for grad**2 */ double epsilon; /** epsilon to protect overflow */ diff --git a/nntrainer/optimizers/optimizer.cpp b/nntrainer/optimizers/optimizer.cpp index 8cabef3..d624ed8 100644 --- a/nntrainer/optimizers/optimizer.cpp +++ b/nntrainer/optimizers/optimizer.cpp @@ -34,9 +34,7 @@ namespace nntrainer { -int Optimizer::initialize() { - return ML_ERROR_NONE; -} +int Optimizer::initialize() { return ML_ERROR_NONE; } double Optimizer::getLearningRate(int iteration) { double ll = learning_rate; @@ -51,6 +49,9 @@ double Optimizer::getLearningRate(int iteration) { void Optimizer::apply_gradients(std::vector &weight_list, int iteration) { + if (weight_list.empty()) + return; + double ll = getLearningRate(iteration); for (auto &weight : weight_list) { diff --git a/nntrainer/optimizers/optimizer_internal.h b/nntrainer/optimizers/optimizer_internal.h index a23e5d3..3193f47 100644 --- a/nntrainer/optimizers/optimizer_internal.h +++ b/nntrainer/optimizers/optimizer_internal.h @@ -140,20 +140,6 @@ public: */ virtual void checkValidation(); -protected: - /** - * @brief get Learning Rate for the given iteration - * @param[in] iteration Iteration for the learning rate - * @retval Learning rate - */ - virtual double getLearningRate(int iteration); - - float learning_rate; /** learning rate */ - float decay_rate; /** decay rate for learning rate */ - unsigned int decay_steps; /** decay steps for learning rate */ - bool continue_train; /** Continue training with previous tensors for adam */ - -private: /** * @brief initialize optimizer. * @retval #ML_ERROR_NONE Successful. 
@@ -169,6 +155,20 @@ private: */ virtual void addOptimizerVariable(std::vector ¶ms) {} +protected: + /** + * @brief get Learning Rate for the given iteration + * @param[in] iteration Iteration for the learning rate + * @retval Learning rate + */ + virtual double getLearningRate(int iteration); + + float learning_rate; /** learning rate */ + float decay_rate; /** decay rate for learning rate */ + unsigned int decay_steps; /** decay steps for learning rate */ + bool continue_train; /** Continue training with previous tensors for adam */ + +private: /** * @brief apply gradient to the given weight * @param[in] weight Weight and gradient set to be updated diff --git a/nntrainer/optimizers/sgd.cpp b/nntrainer/optimizers/sgd.cpp index f299ad4..c63f223 100644 --- a/nntrainer/optimizers/sgd.cpp +++ b/nntrainer/optimizers/sgd.cpp @@ -17,8 +17,7 @@ namespace nntrainer { const std::string SGD::type = "sgd"; -void SGD::apply_gradient(Weight &weight, double updated_lr, - int iteration) { +void SGD::apply_gradient(Weight &weight, double updated_lr, int iteration) { Tensor &x = weight.getVariableRef(); const Tensor &x_grad = weight.getGradientRef(); x.add_i(x_grad, -updated_lr); diff --git a/nntrainer/optimizers/sgd.h b/nntrainer/optimizers/sgd.h index da5c9bc..b3cc1d2 100644 --- a/nntrainer/optimizers/sgd.h +++ b/nntrainer/optimizers/sgd.h @@ -34,8 +34,7 @@ public: * @copydoc apply_gradient(Weight &weight, double updated_lr, * int iteration) */ - void apply_gradient(Weight &weight, double updated_lr, - int iteration); + void apply_gradient(Weight &weight, double updated_lr, int iteration); /** * @copydoc Optimizer::getType() diff --git a/nntrainer/tensor/weight.h b/nntrainer/tensor/weight.h index e196cd8..ea34858 100644 --- a/nntrainer/tensor/weight.h +++ b/nntrainer/tensor/weight.h @@ -174,14 +174,12 @@ public: * @param idx Index of the optimizer variable to get * @retval Reference of the optimizer variable */ - Tensor &getOptimizerVariableRef(unsigned int idx) { - return opt_vars[idx]; - } + Tensor &getOptimizerVariableRef(unsigned int idx) { return opt_vars[idx]; } private: WeightInitializer initializer; /**< initializer for this variable */ - std::vector opt_vars; /**< optimizer variables */ + std::vector opt_vars; /**< optimizer variables */ }; } // namespace nntrainer diff --git a/test/unittest/unittest_nntrainer_layers.cpp b/test/unittest/unittest_nntrainer_layers.cpp index 203a36e..25f7b71 100644 --- a/test/unittest/unittest_nntrainer_layers.cpp +++ b/test/unittest/unittest_nntrainer_layers.cpp @@ -159,14 +159,18 @@ protected: input_str.push_back((*i).str()); } - std::shared_ptr op; - EXPECT_NO_THROW(op = nntrainer::createOptimizer(type)); + EXPECT_NO_THROW(opt = nntrainer::createOptimizer(type)); - status = op->setProperty(input_str); + status = opt->setProperty(input_str); EXPECT_EQ(status, ML_ERROR_NONE); - status = layer.setOptimizer(op); + + status = opt->initialize(); EXPECT_EQ(status, ML_ERROR_NONE); + EXPECT_NO_THROW(opt->addOptimizerVariable(layer.getWeightsRef())); + // status = layer.setOptimizer(op); + // EXPECT_EQ(status, ML_ERROR_NONE); + return status; } @@ -207,6 +211,7 @@ protected: nntrainer::Tensor out; float local_tolerance = tolerance; nntrainer::Manager manager; + std::shared_ptr opt; }; class nntrainer_InputLayer @@ -604,7 +609,7 @@ protected: if (layers.size() && nntrainer::istrequal(layers.back()->getType(), nntrainer::LossLayer::type)) { if (with_loss) { - EXPECT_NO_THROW(layers.back()->backwarding(1, {label})); + EXPECT_NO_THROW(layers.back()->backwarding({label})); 
back_out = MAKE_SHARED_TENSOR(layers.back()->getDerivatives()[0]); } else { back_out = def_derivative; @@ -615,10 +620,11 @@ protected: } for (; idx >= 0; --idx) - EXPECT_NO_THROW(back_out = - layers[idx]->backwarding_with_val(1, {back_out})[0]); + EXPECT_NO_THROW(back_out = layers[idx]->backwarding_with_val( + 1, {back_out}, {}, opt)[0]); - EXPECT_NO_THROW(back_out = layer.backwarding_with_val(1, {back_out})[0]); + EXPECT_NO_THROW(back_out = + layer.backwarding_with_val(1, {back_out}, {}, opt)[0]); matchOutput(*back_out.get(), file_dx); loadUpdatedWeightsGradients(file_uw, file_g); @@ -681,7 +687,7 @@ TEST_F(nntrainer_FullyConnectedLayer_TFmatch, forwarding_backwarding_00_p) { nntrainer::Tensor result; EXPECT_NO_THROW(result = *layer.backwarding_with_val( - 1, {MAKE_SHARED_TENSOR(derivatives)})[0]); + 1, {MAKE_SHARED_TENSOR(derivatives)}, {}, opt)[0]); matchOutput(result, "tc_fc_1_goldenFCGradientAdam.out"); @@ -967,8 +973,8 @@ TEST_F(nntrainer_BatchNormalizationLayer, forward_backward_training_01_p) { nntrainer::Tensor backward_in(layer.getOutputDimension()[0]); loadFile("tc_bn_fc_1_goldenBNLayerBackwardDxIn.out", backward_in); - nntrainer::Tensor backward_result = - *layer.backwarding_with_val(1, {MAKE_SHARED_TENSOR(backward_in)})[0]; + nntrainer::Tensor backward_result = *layer.backwarding_with_val( + 1, {MAKE_SHARED_TENSOR(backward_in)}, {}, opt)[0]; matchOutput(backward_result, "tc_bn_fc_1_goldenBNLayerBackwardDx.out"); } @@ -1006,8 +1012,8 @@ TEST_F(nntrainer_BatchNormalizationLayer_Conv, forward_backward_training_01_p) { nntrainer::Tensor backward_in(layer.getOutputDimension()[0]); loadFile("tc_bn_conv_1_goldenBNLayerBackwardDxIn.out", backward_in); - nntrainer::Tensor backward_result = - *layer.backwarding_with_val(1, {MAKE_SHARED_TENSOR(backward_in)})[0]; + nntrainer::Tensor backward_result = *layer.backwarding_with_val( + 1, {MAKE_SHARED_TENSOR(backward_in)}, {}, opt)[0]; matchOutput(backward_result, "tc_bn_conv_1_goldenBNLayerBackwardDx.out"); } @@ -1048,8 +1054,8 @@ TEST_F(nntrainer_BatchNormalizationLayer_Conv2, nntrainer::Tensor backward_in(layer.getOutputDimension()[0]); loadFile("tc_bn_conv_2_goldenBNLayerBackwardDxIn.out", backward_in); - nntrainer::Tensor backward_result = - *layer.backwarding_with_val(1, {MAKE_SHARED_TENSOR(backward_in)})[0]; + nntrainer::Tensor backward_result = *layer.backwarding_with_val( + 1, {MAKE_SHARED_TENSOR(backward_in)}, {}, opt)[0]; matchOutput(backward_result, "tc_bn_conv_2_goldenBNLayerBackwardDx.out"); } @@ -1184,7 +1190,7 @@ TEST_F(nntrainer_Conv2DLayer, backwarding_01_p) { } EXPECT_NO_THROW(result = *layer.backwarding_with_val( - 1, {MAKE_SHARED_TENSOR(derivatives)})[0]); + 1, {MAKE_SHARED_TENSOR(derivatives)}, {}, opt)[0]); auto param_data = layer.getWeights(); const float *weight_grad = param_data[0].getGradient().getData(); @@ -1221,7 +1227,7 @@ TEST_F(nntrainer_Conv2DLayer, backwarding_02_p) { derivatives.getData()[i] = 1.0; } EXPECT_NO_THROW(result = *layer.backwarding_with_val( - 1, {MAKE_SHARED_TENSOR(derivatives)})[0]); + 1, {MAKE_SHARED_TENSOR(derivatives)}, {}, opt)[0]); auto param_data = layer.getWeights(); const float *weight_grad = param_data[0].getGradient().getData(); @@ -1240,7 +1246,7 @@ TEST_F(nntrainer_Conv2DLayer, backwarding_02_p) { EXPECT_NO_THROW(out = *layer.forwarding_with_val({MAKE_SHARED_TENSOR(in)})[0]); EXPECT_NO_THROW(result = *layer.backwarding_with_val( - 0, {MAKE_SHARED_TENSOR(derivatives)})[0]); + 0, {MAKE_SHARED_TENSOR(derivatives)}, {}, opt)[0]); } /// @fixme: the output value of this test is around 
+/- 1.0e+07 which can't @@ -1285,8 +1291,6 @@ TEST_F(nntrainer_Conv2DLayer, DISABLED_backwarding_03_p) { EXPECT_NO_THROW(op = nntrainer::createOptimizer(nntrainer::OptType::SGD)); status = op->setProperty({"learning_rate=1.0"}); EXPECT_EQ(status, ML_ERROR_NONE); - status = layer1.setOptimizer(op); - EXPECT_EQ(status, ML_ERROR_NONE); nntrainer::Conv2DLayer layer2; status = layer2.setProperty( @@ -1305,8 +1309,6 @@ TEST_F(nntrainer_Conv2DLayer, DISABLED_backwarding_03_p) { EXPECT_NO_THROW(op2 = nntrainer::createOptimizer(nntrainer::OptType::SGD)); status = op2->setProperty({"learning_rate=1.0"}); EXPECT_EQ(status, ML_ERROR_NONE); - status = layer2.setOptimizer(op2); - EXPECT_EQ(status, ML_ERROR_NONE); setOptimizer(nntrainer::OptType::SGD, "learning_rate=1.0"); @@ -1330,10 +1332,10 @@ TEST_F(nntrainer_Conv2DLayer, DISABLED_backwarding_03_p) { nntrainer::Tensor result2; EXPECT_NO_THROW(result2 = *layer2.backwarding_with_val( - 1, {MAKE_SHARED_TENSOR(derivatives)})[0]); + 1, {MAKE_SHARED_TENSOR(derivatives)}, {}, opt)[0]); - EXPECT_NO_THROW( - result = *layer1.backwarding_with_val(1, {MAKE_SHARED_TENSOR(result2)})[0]); + EXPECT_NO_THROW(result = *layer1.backwarding_with_val( + 1, {MAKE_SHARED_TENSOR(result2)}, {}, opt)[0]); /** Compare second conv */ auto param_data = layer2.getWeights(); @@ -1377,7 +1379,7 @@ TEST_F(nntrainer_Conv2DLayer, backwarding_04_p) { derivatives.getData()[i] = 1.0; } EXPECT_NO_THROW(result = *layer.backwarding_with_val( - 1, {MAKE_SHARED_TENSOR(derivatives)})[0]); + 1, {MAKE_SHARED_TENSOR(derivatives)}, {}, opt)[0]); auto param_data = layer.getWeights(); const float *weight_grad = param_data[0].getGradient().getData(); @@ -1521,8 +1523,8 @@ TEST_F(nntrainer_Pooling2DLayer, backwarding_01_p) { grad.getData()[i] = 1.0; } - EXPECT_NO_THROW( - in = *layer.backwarding_with_val(1, {MAKE_SHARED_TENSOR(grad)})[0]); + EXPECT_NO_THROW(in = *layer.backwarding_with_val( + 1, {MAKE_SHARED_TENSOR(grad)}, {}, opt)[0]); matchOutput(in, "tc_pooling2d_1_goldenPooling2DmaxGrad.out"); } @@ -1543,7 +1545,7 @@ TEST_F(nntrainer_Pooling2DLayer, backwarding_02_p) { grad->getData()[i] = 1.0; } - EXPECT_NO_THROW(in = *layer.backwarding_with_val(1, {grad})[0]); + EXPECT_NO_THROW(in = *layer.backwarding_with_val(1, {grad}, {}, opt)[0]); matchOutput(in, "tc_pooling2d_1_goldenPooling2DaverageGrad.out"); } @@ -1565,8 +1567,8 @@ TEST_F(nntrainer_Pooling2DLayer, backwarding_03_p) { grad.getData()[i] = 1.0; } - EXPECT_NO_THROW( - in = *layer.backwarding_with_val(1, {MAKE_SHARED_TENSOR(grad)})[0]); + EXPECT_NO_THROW(in = *layer.backwarding_with_val( + 1, {MAKE_SHARED_TENSOR(grad)}, {}, opt)[0]); matchOutput(in, "tc_pooling2d_1_goldenPooling2Dglobal_maxGrad.out"); } @@ -1587,8 +1589,8 @@ TEST_F(nntrainer_Pooling2DLayer, backwarding_04_p) { grad.getData()[i] = 1.0; } - EXPECT_NO_THROW( - in = *layer.backwarding_with_val(1, {MAKE_SHARED_TENSOR(grad)})[0]); + EXPECT_NO_THROW(in = *layer.backwarding_with_val( + 1, {MAKE_SHARED_TENSOR(grad)}, {}, opt)[0]); matchOutput(in, "tc_pooling2d_1_goldenPooling2Dglobal_averageGrad.out"); } @@ -1647,7 +1649,7 @@ TEST_F(nntrainer_FlattenLayer, backwarding_01_p) { loadFile("tc_pooling2d_1_goldenPooling2Dmax.out", out); EXPECT_NO_THROW( - in = *layer.backwarding_with_val(1, {MAKE_SHARED_TENSOR(out)})[0]); + in = *layer.backwarding_with_val(1, {MAKE_SHARED_TENSOR(out)}, {}, opt)[0]); EXPECT_EQ(in.getDim(), nntrainer::TensorDim(1, 2, 4, 4)); matchOutput(in, "tc_pooling2d_1_goldenPooling2Dmax.out"); @@ -1666,7 +1668,7 @@ TEST_F(nntrainer_FlattenLayer, backwarding_02_p) { 
loadFile("tc_pooling2d_2_goldenPooling2Dmax.out", out); EXPECT_NO_THROW( - in = *layer.backwarding_with_val(1, {MAKE_SHARED_TENSOR(out)})[0]); + in = *layer.backwarding_with_val(1, {MAKE_SHARED_TENSOR(out)}, {}, opt)[0]); EXPECT_EQ(in.getDim(), nntrainer::TensorDim(2, 2, 4, 4)); matchOutput(in, "tc_pooling2d_2_goldenPooling2Dmax.out"); @@ -1726,8 +1728,7 @@ TEST(nntrainer_LossLayer, backward_loss_unknown_n) { std::make_unique(); layer.setInputBuffer(0, in_buffer); layer.setOutputBuffer(0, out_buffer); - EXPECT_THROW(layer.backwarding(1, {MAKE_SHARED_TENSOR(a)}), - std::runtime_error); + EXPECT_THROW(layer.backwarding({MAKE_SHARED_TENSOR(a)}), std::runtime_error); } TEST(nntrainer_LossLayer, forward_loss_forward_entropy_n) { @@ -1760,8 +1761,7 @@ TEST(nntrainer_LossLayer, backward_loss_backward_entropy_n) { std::make_unique(); layer.setInputBuffer(0, in_buffer); layer.setOutputBuffer(0, out_buffer); - EXPECT_THROW(layer.backwarding(1, {MAKE_SHARED_TENSOR(a)}), - std::runtime_error); + EXPECT_THROW(layer.backwarding({MAKE_SHARED_TENSOR(a)}), std::runtime_error); } /**