From: Parichay Kapoor
Date: Thu, 27 Jan 2022 06:36:46 +0000 (+0900)
Subject: [test] unittest for weight decay
X-Git-Tag: accepted/tizen/unified/20220323.062643~26
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=01b39fc208579e77b116cffe29e95714638b841b;p=platform%2Fcore%2Fml%2Fnntrainer.git

[test] unittest for weight decay

Added a unit test for weight decay, along with a fix: the decay is now
applied to the gradient rather than directly to the weight, and the
missing decay arguments are passed when requesting the reverse LSTM
weights and biases.

Signed-off-by: Parichay Kapoor
---

diff --git a/nntrainer/layers/lstm.cpp b/nntrainer/layers/lstm.cpp
index 78bbcb3..e696e75 100644
--- a/nntrainer/layers/lstm.cpp
+++ b/nntrainer/layers/lstm.cpp
@@ -200,9 +200,9 @@ void LSTMLayer::finalize(InitLayerContext &context) {
   // weight_ih ( input to hidden ) : [ 1, 1, feature_size, NUM_GATE * unit ]
   // -> i, f, g, o
   const TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
-  wt_idx[LSTMParams::weight_ih] =
-    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, weight_decay, "weight_ih", true);
+  wt_idx[LSTMParams::weight_ih] = context.requestWeight(
+    weight_ih_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_ih", true);
   // weight_hh ( hidden to hidden ) : [ 1, 1, unit, NUM_GATE * unit ] -> i,
   // f, g, o
   const TensorDim weight_hh_dim({unit, NUM_GATE * unit});
@@ -220,12 +220,12 @@ void LSTMLayer::finalize(InitLayerContext &context) {
     } else {
       // bias_ih ( input bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
       const TensorDim bias_ih_dim({NUM_GATE * unit});
-      wt_idx[LSTMParams::bias_ih] =
-        context.requestWeight(bias_ih_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, bias_decay, "bias_ih", true);
+      wt_idx[LSTMParams::bias_ih] = context.requestWeight(
+        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_ih", true);
       // bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
       wt_idx[LSTMParams::bias_hh] = context.requestWeight(
-        bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
         bias_decay, "bias_hh", true);
     }
   }
@@ -257,14 +257,14 @@ void LSTMLayer::finalize(InitLayerContext &context) {
     const TensorDim reverse_weight_ih_dim({feature_size, NUM_GATE * unit});
     wt_idx[LSTMParams::reverse_weight_ih] = context.requestWeight(
       reverse_weight_ih_dim, weight_initializer, weight_regularizer,
-      weight_regularizer_constant, "reverse_weight_ih", true);
+      weight_regularizer_constant, weight_decay, "reverse_weight_ih", true);
     // reverse_weight_hh ( hidden to hidden ) : [ 1, 1, unit, NUM_GATE *
     // unit ]
     // -> i, f, g, o
     const TensorDim reverse_weight_hh_dim({unit, NUM_GATE * unit});
     wt_idx[LSTMParams::reverse_weight_hh] = context.requestWeight(
       reverse_weight_hh_dim, weight_initializer, weight_regularizer,
-      weight_regularizer_constant, "reverse_weight_hh", true);
+      weight_regularizer_constant, weight_decay, "reverse_weight_hh", true);
     if (!disable_bias) {
       if (integrate_bias) {
         // reverse_bias_h ( input bias, hidden bias are integrate to 1 bias
@@ -272,20 +272,20 @@ void LSTMLayer::finalize(InitLayerContext &context) {
         const TensorDim reverse_bias_h_dim({NUM_GATE * unit});
         wt_idx[LSTMParams::reverse_bias_h] = context.requestWeight(
           reverse_bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
-          "reverse_bias_h", true);
+          bias_decay, "reverse_bias_h", true);
       } else {
         // reverse_bias_ih ( input bias ) : [ 1, 1, 1, NUM_GATE * unit ] ->
         // i, f, g, o
         const TensorDim reverse_bias_ih_dim({NUM_GATE * unit});
         wt_idx[LSTMParams::reverse_bias_ih] =
           context.requestWeight(
             reverse_bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
-            "reverse_bias_ih", true);
+            bias_decay, "reverse_bias_ih", true);
         // reverse_bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] ->
         // i, f, g, o
         const TensorDim reverse_bias_hh_dim({NUM_GATE * unit});
         wt_idx[LSTMParams::reverse_bias_hh] = context.requestWeight(
           reverse_bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
-          "reverse_bias_hh", true);
+          bias_decay, "reverse_bias_hh", true);
       }
     }
diff --git a/nntrainer/tensor/weight.h b/nntrainer/tensor/weight.h
index a8b0525..2679611 100644
--- a/nntrainer/tensor/weight.h
+++ b/nntrainer/tensor/weight.h
@@ -295,9 +295,9 @@ private:
   std::vector opt_vars; /**< optimizer variables */

   /**
-   * @brief Apply the gradient to the weight
+   * @brief Apply the weight decay to the weight
    */
-  void applyWeightDecay() { var->add_i(*var.get(), -decay); }
+  void applyWeightDecay() { grad->add_i(*var.get(), decay); }
 };

 } // namespace nntrainer
diff --git a/packaging/unittest_models_v2.tar.gz b/packaging/unittest_models_v2.tar.gz
index 4480c3f..40a7894 100644
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
diff --git a/test/input_gen/genModelTests_v2.py b/test/input_gen/genModelTests_v2.py
index 1f050ec..9ebd599 100644
--- a/test/input_gen/genModelTests_v2.py
+++ b/test/input_gen/genModelTests_v2.py
@@ -72,6 +72,36 @@ class MolAttention(torch.nn.Module):
         return (output, kappa), loss


+class FCRelu(torch.nn.Module):
+    def __init__(self, decay=False):
+        super().__init__()
+        self.fc = torch.nn.Linear(3, 10)
+        self.fc1 = torch.nn.Linear(10, 2)
+        self.loss = torch.nn.MSELoss()
+        self.decay = decay
+
+    def forward(self, inputs, labels):
+        out = torch.relu(self.fc(inputs[0]))
+        out = torch.sigmoid(self.fc1(out))
+        loss = self.loss(out, labels[0])
+        return out, loss
+
+    def getOptimizer(self):
+        if not self.decay:
+            return torch.optim.SGD(self.parameters(), lr=0.1)
+        else:
+            decay_params = []
+            non_decay_params = []
+            for name, params in self.named_parameters():
+                if name == 'fc.weight' or name == 'fc1.bias':
+                    decay_params.append(params)
+                else:
+                    non_decay_params.append(params)
+            return torch.optim.SGD([
+                {'params': non_decay_params},
+                {'params': decay_params, 'weight_decay': 0.9}], lr=0.1)
+
+
 if __name__ == "__main__":
     record_v2(
         ReduceMeanLast(),
@@ -99,4 +129,15 @@ if __name__ == "__main__":
         name="mol_attention",
     )

-    # inspect_file("mol_attention_masked.nnmodelgolden")
+    fc_relu_decay = FCRelu(decay=True)
+    record_v2(
+        fc_relu_decay,
+        iteration=2,
+        input_dims=[(3,3)],
+        input_dtype=[float],
+        label_dims=[(3,2)],
+        name="fc_relu_decay",
+        optimizer=fc_relu_decay.getOptimizer()
+    )
+
+    inspect_file("fc_relu_decay.nnmodelgolden")
diff --git a/test/input_gen/recorder_v2.py b/test/input_gen/recorder_v2.py
index 049a33d..5a58b9d 100644
--- a/test/input_gen/recorder_v2.py
+++ b/test/input_gen/recorder_v2.py
@@ -62,7 +62,7 @@ def _rand_like(shapes, scale=1, dtype=None):
 # @param label_dims dimensions to record including batch (list of tuple)
 # @param name golden name
 def record_v2(model, iteration, input_dims, label_dims, name, clip=False,
-              input_dtype=None, input_label_reader=None):
+              input_dtype=None, input_label_reader=None, optimizer=None):
     ## file format is as below
     # [ ...]
     # Each iteration contains
@@ -74,7 +74,8 @@ def record_v2(model, iteration, input_dims, label_dims, name, clip=False,
     if os.path.isfile(file_name):
         print("Warning: the file %s is being truncated and overwritten" % file_name)

-    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+    if optimizer == None:
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

     def record_iteration(write_fn):
         if input_label_reader != None:
diff --git a/test/unittest/layers/unittest_layers_fully_connected.cpp b/test/unittest/layers/unittest_layers_fully_connected.cpp
index afb9aa2..76ad965 100644
--- a/test/unittest/layers/unittest_layers_fully_connected.cpp
+++ b/test/unittest/layers/unittest_layers_fully_connected.cpp
@@ -31,7 +31,11 @@ auto fc_basic_single_batch = LayerGoldenTestParamType(
   nntrainer::createLayer, {"unit=4"}, "1:1:1:10",
   "fc_single_batch.nnlayergolden",
   LayerGoldenTestParamOptions::DEFAULT);
+auto fc_basic_no_decay = LayerGoldenTestParamType(
+  nntrainer::createLayer,
+  {"unit=5", "weight_decay=0.0", "bias_decay=0.0"}, "3:1:1:10",
+  "fc_plain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);

 INSTANTIATE_TEST_CASE_P(FullyConnected, LayerGoldenTest,
-                        ::testing::Values(fc_basic_plain,
-                                          fc_basic_single_batch));
+                        ::testing::Values(fc_basic_plain, fc_basic_single_batch,
+                                          fc_basic_no_decay));
diff --git a/test/unittest/models/unittest_models.cpp b/test/unittest/models/unittest_models.cpp
index 5512208..257c884 100644
--- a/test/unittest/models/unittest_models.cpp
+++ b/test/unittest/models/unittest_models.cpp
@@ -29,6 +29,7 @@ static std::string fc_base = "type = Fully_connected";
 static std::string red_mean_base = "type = reduce_mean";
 static IniSection sgd_base("optimizer", "Type = sgd");
 static IniSection constant_loss("loss", "type = constant_derivative");
+static IniSection act_base("activation", "Type = Activation");

 IniWrapper reduce_mean_last("reduce_mean_last",
                             {
@@ -40,6 +41,15 @@ IniWrapper reduce_mean_last("reduce_mean_last",
                               constant_loss,
                             });

+IniWrapper fc_relu_decay(
+  "fc_relu_decay",
+  {nn_base + "Loss=mse | batch_size = 3", sgd_base + "learning_rate = 0.1",
+   IniSection("input") + "type=input" + "input_shape = 1:1:3",
+   IniSection("dense") + fc_base + "unit = 10" + "weight_decay=0.9",
+   IniSection("act") + act_base + "Activation = relu",
+   IniSection("dense_1") + fc_base + "unit = 2" + "bias_decay=0.9",
+   IniSection("act_1") + act_base + "Activation = sigmoid"});
+
 static std::unique_ptr makeMolAttention() {
   std::unique_ptr nn(new NeuralNetwork());
   nn->setProperty({"batch_size=3"});
@@ -96,6 +106,8 @@ INSTANTIATE_TEST_CASE_P(
                  ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeMolAttentionMasked, "mol_attention_masked",
                  ModelTestOption::COMPARE_RUN_V2),
+    mkModelIniTc(fc_relu_decay, DIM_UNUSED, NOT_USED_,
+                 ModelTestOption::COMPARE_V2),
   }),
   [](const testing::TestParamInfo &info) {
     return std::get<1>(info.param);
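
Note on the weight.h change above: applyWeightDecay() now folds decay * weight into
the gradient (grad->add_i(*var.get(), decay)) instead of shrinking the weight in
place, which matches how torch.optim.SGD applies weight_decay and therefore what the
fc_relu_decay golden data generated by genModelTests_v2.py encodes. A minimal PyTorch
sketch of that update rule, with made-up values chosen purely for illustration (not
taken from the test):

    import torch

    # Illustrative tensor only; the actual test uses FCRelu with weight_decay=0.9, lr=0.1.
    w = torch.nn.Parameter(torch.tensor([1.0, -2.0, 0.5]))
    lr, decay = 0.1, 0.9

    # torch.optim.SGD folds decay * w into the gradient before taking the step.
    opt = torch.optim.SGD([w], lr=lr, weight_decay=decay)
    w.grad = torch.tensor([0.3, 0.1, -0.2])

    # Same rule as the fixed applyWeightDecay(): grad += decay * var, then w -= lr * grad.
    expected = (w - lr * (w.grad + decay * w)).detach()
    opt.step()

    print(torch.allclose(w.detach(), expected))  # True

Applying the decay through the gradient makes the shrinkage proportional to the
learning rate (lr * decay * w per step), which is what the recorded golden values
expect; the previous in-place update subtracted decay * w from the weight regardless
of the learning rate.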