From: Seungbaek Hong
Date: Thu, 29 Dec 2022 06:40:33 +0000 (+0900)
Subject: [test] add test case when a specific layer is non-trainable
X-Git-Tag: accepted/tizen/unified/20230425.130129~73
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a32be7c5c3af908cda2863ca721c9b6aa0bcd14f;p=platform%2Fcore%2Fml%2Fnntrainer.git

[test] add test case when a specific layer is non-trainable

Add a test case in which a specific layer is non-trainable.
- Add a case in which the output fc layer (fc3) is set to
  non-trainable, and regenerate the golden data accordingly.

**Self evaluation:**
1. Build test: [x]Passed []Failed []Skipped
2. Run test: [x]Passed []Failed []Skipped

Signed-off-by: Seungbaek Hong
---

diff --git a/packaging/unittest_models_v2.tar.gz b/packaging/unittest_models_v2.tar.gz
index f110be2..f0df357 100644
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
diff --git a/test/input_gen/genModelTests_v2.py b/test/input_gen/genModelTests_v2.py
index e50ba6a..a56f437 100644
--- a/test/input_gen/genModelTests_v2.py
+++ b/test/input_gen/genModelTests_v2.py
@@ -296,12 +296,9 @@ class NonTrainableFC(torch.nn.Module):
         self.fc3 = torch.nn.Linear(10, 2)
         self.loss = torch.nn.MSELoss()
         # determine which layer to set to non-trainable
-        if idx == 1:
-            for param in self.fc1.parameters():
-                param.requires_grad = False
-        elif idx == 2:
-            for param in self.fc2.parameters():
-                param.requires_grad = False
+        fc_layer_list = [self.fc1, self.fc2, self.fc3]
+        for param in fc_layer_list[idx-1].parameters():
+            param.requires_grad = False
 
     def forward(self, inputs, labels):
         out = torch.relu(self.fc1(inputs[0]))
@@ -529,6 +526,16 @@ if __name__ == "__main__":
         label_dims=[(3,2)],
         name="non_trainable_fc_idx2"
     )
+
+    non_trainable_fc_idx3 = NonTrainableFC(idx=3)
+    record_v2(
+        non_trainable_fc_idx3,
+        iteration=2,
+        input_dims=[(3,3)],
+        input_dtype=[float],
+        label_dims=[(3,2)],
+        name="non_trainable_fc_idx3"
+    )
 
     # Function to check the created golden test file
-    inspect_file("fc_relu_decay.nnmodelgolden")
+    inspect_file("non_trainable_fc_idx3.nnmodelgolden")
diff --git a/test/unittest/models/unittest_models.cpp b/test/unittest/models/unittest_models.cpp
index 2205dd1..f3f939e 100644
--- a/test/unittest/models/unittest_models.cpp
+++ b/test/unittest/models/unittest_models.cpp
@@ -60,8 +60,9 @@ getFuncToMakeNonTrainableFc(int idx) {
   std::string fc1_trainable = (idx == 1) ? "trainable=false" : "trainable=true";
   std::string fc2_trainable = (idx == 2) ? "trainable=false" : "trainable=true";
+  std::string fc3_trainable = (idx == 3) ? "trainable=false" : "trainable=true";
 
-  return [fc1_trainable, fc2_trainable]() {
+  return [fc1_trainable, fc2_trainable, fc3_trainable]() {
     std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
     nn->setProperty({"batch_size=3"});
@@ -69,14 +70,15 @@ getFuncToMakeNonTrainableFc(int idx) {
     auto outer_graph = makeGraph({
       {"input", {"name=in", "input_shape=1:1:3"}},
       {"fully_connected",
-       {"name=fc1", "input_layers=in", "unit=10", fc1_trainable}},
-      {"activation", {"name=act1", "input_layers=fc1", "activation=relu"}},
+       {"name=fc1", "input_layers=in", "unit=10", "activation=relu",
+        fc1_trainable}},
       {"fully_connected",
-       {"name=fc2", "input_layers=act1", "unit=10", fc2_trainable}},
-      {"activation", {"name=act2", "input_layers=fc2", "activation=relu"}},
-      {"fully_connected", {"name=fc3", "input_layers=act2", "unit=2"}},
-      {"activation", {"name=act3", "input_layers=fc3", "activation=sigmoid"}},
-      {"mse", {"name=loss", "input_layers=act3"}},
+       {"name=fc2", "input_layers=fc1", "unit=10", "activation=relu",
+        fc2_trainable}},
+      {"fully_connected",
+       {"name=fc3", "input_layers=fc2", "unit=2", "activation=sigmoid",
+        fc3_trainable}},
+      {"mse", {"name=loss", "input_layers=fc3"}},
     });
 
     for (auto &node : outer_graph) {
@@ -93,6 +95,7 @@ getFuncToMakeNonTrainableFc(int idx) {
 
 static auto makeNonTrainableFcIdx1 = getFuncToMakeNonTrainableFc(1);
 static auto makeNonTrainableFcIdx2 = getFuncToMakeNonTrainableFc(2);
+static auto makeNonTrainableFcIdx3 = getFuncToMakeNonTrainableFc(3);
 
 static std::unique_ptr<NeuralNetwork> makeMolAttention() {
   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
@@ -930,6 +933,8 @@ GTEST_PARAMETER_TEST(
                  ModelTestOption::ALL_V2),
     mkModelTc_V2(makeNonTrainableFcIdx2, "non_trainable_fc_idx2",
                  ModelTestOption::ALL_V2),
+    mkModelTc_V2(makeNonTrainableFcIdx3, "non_trainable_fc_idx3",
+                 ModelTestOption::ALL_V2),
   }),
   [](const testing::TestParamInfo &info) {
     return std::get<1>(info.param);