[test] add test case when a specific layer is non-trainable
authorSeungbaek Hong <sb92.hong@samsung.com>
Thu, 29 Dec 2022 06:40:33 +0000 (15:40 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 1 Feb 2023 00:53:30 +0000 (09:53 +0900)
Add a test case for when a specific layer is non-trainable.

- Add a test case for when the output FC layer is set to non-trainable.

**Self evaluation:**
1. Build test: [x]Passed []Failed []Skipped
2. Run test: [x]Passed []Failed []Skipped

Signed-off-by: Seungbaek Hong <sb92.hong@samsung.com>
packaging/unittest_models_v2.tar.gz
test/input_gen/genModelTests_v2.py
test/unittest/models/unittest_models.cpp

index f110be2..f0df357 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
index e50ba6a..a56f437 100644 (file)
@@ -296,12 +296,9 @@ class NonTrainableFC(torch.nn.Module):
         self.fc3 = torch.nn.Linear(10, 2)
         self.loss = torch.nn.MSELoss()
         # determine which layer to set to non-trainable
-        if idx == 1:
-            for param in self.fc1.parameters():
-                param.requires_grad = False
-        elif idx == 2:
-            for param in self.fc2.parameters():
-                param.requires_grad = False
+        fc_layer_list = [self.fc1, self.fc2, self.fc3]
+        for param in fc_layer_list[idx-1].parameters():
+            param.requires_grad = False
 
     def forward(self, inputs, labels):
         out = torch.relu(self.fc1(inputs[0]))
@@ -529,6 +526,16 @@ if __name__ == "__main__":
         label_dims=[(3,2)],
         name="non_trainable_fc_idx2"
     )
+
+    non_trainable_fc_idx3 = NonTrainableFC(idx=3)
+    record_v2(
+        non_trainable_fc_idx3,
+        iteration=2,
+        input_dims=[(3,3)],
+        input_dtype=[float],
+        label_dims=[(3,2)],
+        name="non_trainable_fc_idx3"
+    )
     
     # Function to check the created golden test file
-    inspect_file("fc_relu_decay.nnmodelgolden")
+    inspect_file("non_trainable_fc_idx3.nnmodelgolden")
index 2205dd1..f3f939e 100644 (file)
@@ -60,8 +60,9 @@ getFuncToMakeNonTrainableFc(int idx) {
 
   std::string fc1_trainable = (idx == 1) ? "trainable=false" : "trainable=true";
   std::string fc2_trainable = (idx == 2) ? "trainable=false" : "trainable=true";
+  std::string fc3_trainable = (idx == 3) ? "trainable=false" : "trainable=true";
 
-  return [fc1_trainable, fc2_trainable]() {
+  return [fc1_trainable, fc2_trainable, fc3_trainable]() {
     std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
 
     nn->setProperty({"batch_size=3"});
@@ -69,14 +70,15 @@ getFuncToMakeNonTrainableFc(int idx) {
     auto outer_graph = makeGraph({
       {"input", {"name=in", "input_shape=1:1:3"}},
       {"fully_connected",
-       {"name=fc1", "input_layers=in", "unit=10", fc1_trainable}},
-      {"activation", {"name=act1", "input_layers=fc1", "activation=relu"}},
+       {"name=fc1", "input_layers=in", "unit=10", "activation=relu",
+        fc1_trainable}},
       {"fully_connected",
-       {"name=fc2", "input_layers=act1", "unit=10", fc2_trainable}},
-      {"activation", {"name=act2", "input_layers=fc2", "activation=relu"}},
-      {"fully_connected", {"name=fc3", "input_layers=act2", "unit=2"}},
-      {"activation", {"name=act3", "input_layers=fc3", "activation=sigmoid"}},
-      {"mse", {"name=loss", "input_layers=act3"}},
+       {"name=fc2", "input_layers=fc1", "unit=10", "activation=relu",
+        fc2_trainable}},
+      {"fully_connected",
+       {"name=fc3", "input_layers=fc2", "unit=2", "activation=sigmoid",
+        fc3_trainable}},
+      {"mse", {"name=loss", "input_layers=fc3"}},
     });
 
     for (auto &node : outer_graph) {
@@ -93,6 +95,7 @@ getFuncToMakeNonTrainableFc(int idx) {
 
 static auto makeNonTrainableFcIdx1 = getFuncToMakeNonTrainableFc(1);
 static auto makeNonTrainableFcIdx2 = getFuncToMakeNonTrainableFc(2);
+static auto makeNonTrainableFcIdx3 = getFuncToMakeNonTrainableFc(3);
 
 static std::unique_ptr<NeuralNetwork> makeMolAttention() {
   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
@@ -930,6 +933,8 @@ GTEST_PARAMETER_TEST(
                  ModelTestOption::ALL_V2),
     mkModelTc_V2(makeNonTrainableFcIdx2, "non_trainable_fc_idx2",
                  ModelTestOption::ALL_V2),
+    mkModelTc_V2(makeNonTrainableFcIdx3, "non_trainable_fc_idx3",
+                 ModelTestOption::ALL_V2),
   }),
   [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {
     return std::get<1>(info.param);