[Application] Transfer learning example on Resnet-18
authorSeungbaek Hong <sb92.hong@samsung.com>
Mon, 19 Jun 2023 04:54:49 +0000 (13:54 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 23 Jun 2023 00:55:20 +0000 (09:55 +0900)
I added transfer learning option to resnet-18 example.

If this option is enabled, then load pre-trained weights
and freeze the weights of backbone(feature extractor).
(It just a simple transfer learning).

You can make pre-trained weights using save_bin function
from our pytorch resnet-18 example.

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Seungbaek Hong <sb92.hong@samsung.com>
Applications/Resnet/jni/main.cpp

index 0d43f85..10cfee8 100644 (file)
@@ -87,7 +87,8 @@ static std::string withKey(const std::string &key,
  */
 std::vector<LayerHandle> resnetBlock(const std::string &block_name,
                                      const std::string &input_name, int filters,
-                                     int kernel_size, bool downsample) {
+                                     int kernel_size, bool downsample,
+                                     bool pre_trained) {
   using ml::train::createLayer;
 
   auto scoped_name = [&block_name](const std::string &layer_name) {
@@ -97,17 +98,18 @@ std::vector<LayerHandle> resnetBlock(const std::string &block_name,
     return withKey("name", scoped_name(layer_name));
   };
 
-  auto create_conv = [&with_name, filters](const std::string &name,
-                                           int kernel_size, int stride,
-                                           const std::string &padding,
-                                           const std::string &input_layer) {
+  auto create_conv = [&with_name, filters,
+                      pre_trained](const std::string &name, int kernel_size,
+                                   int stride, const std::string &padding,
+                                   const std::string &input_layer) {
     std::vector<std::string> props{
       with_name(name),
       withKey("stride", {stride, stride}),
       withKey("filters", filters),
       withKey("kernel_size", {kernel_size, kernel_size}),
       withKey("padding", padding),
-      withKey("input_layers", input_layer)};
+      withKey("input_layers", input_layer),
+      withKey("trainable", pre_trained ? "true" : "false")};
 
     return createLayer("conv2d", props);
   };
@@ -118,7 +120,8 @@ std::vector<LayerHandle> resnetBlock(const std::string &block_name,
   LayerHandle a2 =
     createLayer("batch_normalization",
                 {with_name("a2"), withKey("activation", "relu"),
-                 withKey("momentum", "0.9"), withKey("epsilon", "0.00001")});
+                 withKey("momentum", "0.9"), withKey("epsilon", "0.00001"),
+                 withKey("trainable", pre_trained ? "true" : "false")});
   LayerHandle a3 = create_conv("a3", 3, 1, "same", scoped_name("a2"));
 
   /** skip path */
@@ -136,7 +139,9 @@ std::vector<LayerHandle> resnetBlock(const std::string &block_name,
   LayerHandle c2 =
     createLayer("batch_normalization",
                 {withKey("name", block_name), withKey("activation", "relu"),
-                 withKey("momentum", "0.9"), withKey("epsilon", "0.00001")});
+                 withKey("momentum", "0.9"), withKey("epsilon", "0.00001"),
+                 withKey("trainable", "false")});
+
   if (downsample) {
     return {b1, a1, a2, a3, c1, c2};
   } else {
@@ -149,7 +154,7 @@ std::vector<LayerHandle> resnetBlock(const std::string &block_name,
  *
  * @return vector of layers that contain full graph of resnet18
  */
-std::vector<LayerHandle> createResnet18Graph() {
+std::vector<LayerHandle> createResnet18Graph(bool pre_trained) {
   using ml::train::createLayer;
 
   std::vector<LayerHandle> layers;
@@ -157,32 +162,37 @@ std::vector<LayerHandle> createResnet18Graph() {
   layers.push_back(createLayer(
     "input", {withKey("name", "input0"), withKey("input_shape", "3:32:32")}));
 
-  layers.push_back(
-    createLayer("conv2d", {
-                            withKey("name", "conv0"),
-                            withKey("filters", 64),
-                            withKey("kernel_size", {3, 3}),
-                            withKey("stride", {1, 1}),
-                            withKey("padding", "same"),
-                            withKey("bias_initializer", "zeros"),
-                            withKey("weight_initializer", "xavier_uniform"),
-                          }));
+  layers.push_back(createLayer(
+    "conv2d", {withKey("name", "conv0"), withKey("filters", 64),
+               withKey("kernel_size", {3, 3}), withKey("stride", {1, 1}),
+               withKey("padding", "same"), withKey("bias_initializer", "zeros"),
+               withKey("weight_initializer", "xavier_uniform"),
+               withKey("trainable", pre_trained ? "true" : "false")}));
 
   layers.push_back(createLayer(
     "batch_normalization",
     {withKey("name", "first_bn_relu"), withKey("activation", "relu"),
-     withKey("momentum", "0.9"), withKey("epsilon", "0.00001")}));
+     withKey("momentum", "0.9"), withKey("epsilon", "0.00001"),
+     withKey("trainable", pre_trained ? "true" : "false")}));
 
   std::vector<std::vector<LayerHandle>> blocks;
 
-  blocks.push_back(resnetBlock("conv1_0", "first_bn_relu", 64, 3, false));
-  blocks.push_back(resnetBlock("conv1_1", "conv1_0", 64, 3, false));
-  blocks.push_back(resnetBlock("conv2_0", "conv1_1", 128, 3, true));
-  blocks.push_back(resnetBlock("conv2_1", "conv2_0", 128, 3, false));
-  blocks.push_back(resnetBlock("conv3_0", "conv2_1", 256, 3, true));
-  blocks.push_back(resnetBlock("conv3_1", "conv3_0", 256, 3, false));
-  blocks.push_back(resnetBlock("conv4_0", "conv3_1", 512, 3, true));
-  blocks.push_back(resnetBlock("conv4_1", "conv4_0", 512, 3, false));
+  blocks.push_back(
+    resnetBlock("conv1_0", "first_bn_relu", 64, 3, false, pre_trained));
+  blocks.push_back(
+    resnetBlock("conv1_1", "conv1_0", 64, 3, false, pre_trained));
+  blocks.push_back(
+    resnetBlock("conv2_0", "conv1_1", 128, 3, true, pre_trained));
+  blocks.push_back(
+    resnetBlock("conv2_1", "conv2_0", 128, 3, false, pre_trained));
+  blocks.push_back(
+    resnetBlock("conv3_0", "conv2_1", 256, 3, true, pre_trained));
+  blocks.push_back(
+    resnetBlock("conv3_1", "conv3_0", 256, 3, false, pre_trained));
+  blocks.push_back(
+    resnetBlock("conv4_0", "conv3_1", 512, 3, true, pre_trained));
+  blocks.push_back(
+    resnetBlock("conv4_1", "conv4_0", 512, 3, false, pre_trained));
 
   for (auto &block : blocks) {
     layers.insert(layers.end(), block.begin(), block.end());
@@ -201,7 +211,7 @@ std::vector<LayerHandle> createResnet18Graph() {
 }
 
 /// @todo update createResnet18 to be more generic
-ModelHandle createResnet18() {
+ModelHandle createResnet18(bool pre_trained = false) {
 /// @todo support "LOSS : cross" for TF_Lite Exporter
 #if (defined(ENABLE_TFLITE_INTERPRETER) && !defined(ENABLE_TEST))
   ModelHandle model = ml::train::createModel(ml::train::ModelType::NEURAL_NET,
@@ -211,7 +221,7 @@ ModelHandle createResnet18() {
                                              {withKey("loss", "cross")});
 #endif
 
-  for (auto &layer : createResnet18Graph()) {
+  for (auto &layer : createResnet18Graph(pre_trained)) {
     model->addLayer(layer);
   }
 
@@ -243,7 +253,12 @@ TEST(Resnet_Training, verify_accuracy) {
 void createAndRun(unsigned int epochs, unsigned int batch_size,
                   UserDataType &train_user_data,
                   UserDataType &valid_user_data) {
-  ModelHandle model = createResnet18();
+  // set option for transfer learning
+  bool transfer_learning = false;
+  std::string pretrained_bin_path = "./pretrained_resnet18.bin";
+
+  // setup model
+  ModelHandle model = createResnet18(transfer_learning);
   model->setProperty({withKey("batch_size", batch_size),
                       withKey("epochs", epochs),
                       withKey("save_path", "resnet_full.bin")});
@@ -271,6 +286,8 @@ void createAndRun(unsigned int epochs, unsigned int batch_size,
   model->setDataset(ml::train::DatasetModeType::MODE_VALID,
                     std::move(dataset_valid));
 
+  if (transfer_learning)
+    model->load(pretrained_bin_path);
   model->train();
 
 #if defined(ENABLE_TEST)