[Application] Update YOLO v2 model to be similar to the original model
author     Seungbaek Hong <sb92.hong@samsung.com>
           Thu, 30 Mar 2023 10:32:01 +0000 (19:32 +0900)
committer  Jijoong Moon <jijoong.moon@samsung.com>
           Tue, 30 May 2023 01:55:51 +0000 (10:55 +0900)
The YOLO v2 model was updated to be similar to the original YOLO v2 model.

This model is intended to follow the original YOLO v2 paper as
closely as possible, but for now average pooling is used temporarily
in place of the re-organization module.

Once the average pooling is replaced with the re-organization module,
the rest of the model will match the original YOLO v2 paper (see the
sketch below).
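
For reference, below is a minimal sketch of the re-organization
(space-to-depth) module from the original paper; the class name Reorg
and its stride parameter are illustrative, not part of this commit:

```python
import torch
import torch.nn as nn

class Reorg(nn.Module):
    """Stack each stride x stride spatial block into the channel axis."""

    def __init__(self, stride=2):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        b, c, h, w = x.size()
        s = self.stride
        # (b, c, h, w) -> (b, c*s*s, h/s, w/s): unlike average pooling,
        # no spatial information is discarded.
        x = x.view(b, c, h // s, s, w // s, s)
        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
        return x.view(b, c * s * s, h // s, w // s)
```

Note that swapping the average pooling for such a module would
multiply the passthrough branch's channels by stride^2 (64 -> 256), so
the concatenated input of conv_out1 would grow from 1088 to 1280
channels.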

The model structure was updated in both the PyTorch version and the
NNTrainer version, and it was verified that the same results are
obtained when the NNTrainer version loads weights trained in PyTorch
(a sketch of such a check follows).
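
As an illustration, a minimal sketch of that cross-framework check,
with model and img as in main.py's inference example and assuming the
NNTrainer output has been dumped to a raw float file (the file name
and tolerance are illustrative):

```python
import numpy as np
import torch

model.eval()  # the trained PyTorch YoloV2 from main.py
with torch.no_grad():
    torch_out = model(img).numpy()

# raw float32 dump produced by the NNTrainer version (name assumed)
nntrainer_out = np.fromfile("nntrainer_output.bin", dtype=np.float32)
nntrainer_out = nntrainer_out.reshape(torch_out.shape)

# identical weights should match within floating-point tolerance
assert np.allclose(torch_out, nntrainer_out, atol=1e-4)
```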

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Seungbaek Hong <sb92.hong@samsung.com>
Applications/YOLO/PyTorch/main.py
Applications/YOLO/PyTorch/yolo.py
Applications/YOLO/jni/main.cpp

diff --git a/Applications/YOLO/PyTorch/main.py b/Applications/YOLO/PyTorch/main.py
index 236ca6a..0c1b1be 100644
@@ -14,7 +14,7 @@ import torch.optim as optim
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
-from yolo import YoloV2_light
+from yolo import YoloV2
 from yolo_loss import YoloV2_LOSS
 from dataset import YOLODataset, collate_db
 
@@ -34,7 +34,7 @@ from torchconverter import save_bin
 
 # set config
 out_size = 13
-num_classes = 5
+num_classes = 4
 num_anchors = 5
 
 epochs = 1000
@@ -51,14 +51,12 @@ train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=colla
 valid_dataset = YOLODataset(valid_img_dir, valid_ann_dir)
 valid_loader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate_db, shuffle=False, drop_last=False)
 
-
 # set model, loss and optimizer
-model = YoloV2_light(num_classes=5)
-criterion = YoloV2_LOSS(num_classes=5)
+model = YoloV2(num_classes=num_classes)
+criterion = YoloV2_LOSS(num_classes=num_classes)
 optimizer = optim.Adam(model.parameters(), lr=1e-3)
 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
 
-
 # train model
 best_loss = 1e+10
 for epoch in range(epochs):
@@ -69,12 +67,12 @@ for epoch in range(epochs):
         optimizer.zero_grad()
         # model prediction
         hypothesis = model(img).permute((0, 2, 3, 1))
-        hypothesis = hypothesis.reshape((batch_size, out_size**2, num_anchors, 5+num_classes))        
+        hypothesis = hypothesis.reshape((batch_size, out_size**2, num_anchors, 5+num_classes))
         # split each prediction(bbox, iou, class prob)
         bbox_pred_xy = torch.sigmoid(hypothesis[..., :2])
         bbox_pred_wh = torch.exp(hypothesis[..., 2:4])
-        bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)        
-        iou_pred = torch.sigmoid(hypothesis[..., 4:5])        
+        bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)
+        iou_pred = torch.sigmoid(hypothesis[..., 4:5])
         score_pred = hypothesis[..., 5:].contiguous()
         prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(score_pred.shape)
         # calc loss
@@ -120,9 +118,9 @@ for epoch in range(epochs):
 
 ##
 # @brief bbox post process function for inference
-def post_process_for_bbox(bbox_pred):    
+def post_process_for_bbox(bbox_pred):
     """
-    @param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)    
+    @param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
     @return bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
     """
     anchors = torch.FloatTensor(
@@ -136,7 +134,7 @@ def post_process_for_bbox(bbox_pred):
     outsize = (13, 13)
     width, height = outsize
     
-    # restore cell pos to x, y    
+    # restore cell pos to x, y
     for w in range(width):
         for h in range(height):
             bbox_pred[:, height*h + w, :, 0] += w
@@ -145,7 +143,7 @@ def post_process_for_bbox(bbox_pred):
     
     # apply anchors to w, h
     anchor_w = anchors[:, 0].contiguous().view(-1, 1)
-    anchor_h = anchors[:, 1].contiguous().view(-1, 1)        
+    anchor_h = anchors[:, 1].contiguous().view(-1, 1)
     bbox_pred[:, :, :, 2:3] *= anchor_w
     bbox_pred[:, :, :, 3:4] *= anchor_h
 
@@ -153,7 +151,7 @@ def post_process_for_bbox(bbox_pred):
 
 # inference example using trained model
 hypothesis = model(img).permute((0, 2, 3, 1))
-hypothesis = hypothesis[0].reshape((1, out_size**2, num_anchors, 5+num_classes))        
+hypothesis = hypothesis[0].reshape((1, out_size**2, num_anchors, 5+num_classes))
 
 # transform output
 bbox_pred_xy = torch.sigmoid(hypothesis[..., :2])
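
(For reference, a hedged sketch of how the inference example above is
typically completed, mirroring the transforms used in the training
loop; the 0.5 confidence threshold is an illustrative assumption, not
part of this commit:)

```python
bbox_pred_wh = torch.exp(hypothesis[..., 2:4])
bbox_pred = post_process_for_bbox(torch.cat((bbox_pred_xy, bbox_pred_wh), 3))
iou_pred = torch.sigmoid(hypothesis[..., 4:5])
score_pred = hypothesis[..., 5:].contiguous()
prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(score_pred.shape)

# keep anchors whose objectness exceeds a threshold (0.5 is illustrative)
mask = iou_pred.squeeze(-1) > 0.5
boxes, classes = bbox_pred[mask], prob_pred[mask].argmax(-1)
```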
diff --git a/Applications/YOLO/PyTorch/yolo.py b/Applications/YOLO/PyTorch/yolo.py
index f986f23..e31e772 100644
@@ -11,48 +11,95 @@ import torch
 import torch.nn as nn
 
 ##
-# @brief define simple yolo model (not original darknet)
-class YoloV2_light(nn.Module): 
-    def __init__(self, 
-                 num_classes,
-                 anchors=\
-                 [(1.3221, 1.73145), (3.19275, 4.00944), (5.05587, 8.09892), (9.47112, 4.84053), (11.2364, 10.0071)]):
+# @brief define yolo model (except for re-organization module)
+class YoloV2(nn.Module): 
+    def __init__(self, num_classes, num_anchors=5):
         
-        super(YoloV2_light, self).__init__()              
+        super(YoloV2, self).__init__()              
         self.num_classes = num_classes
-        self.anchors = anchors
-        self.stage1_conv1 = nn.Sequential(nn.Conv2d(3, 32, 3, 1, 1), nn.BatchNorm2d(32, eps=1e-3),
-                                          nn.LeakyReLU(), nn.MaxPool2d(2, 2))
-        self.stage1_conv2 = nn.Sequential(nn.Conv2d(32, 64, 3, 1, 1), nn.BatchNorm2d(64, eps=1e-3),
-                                          nn.LeakyReLU(), nn.MaxPool2d(2, 2))
-        self.stage1_conv3 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128, eps=1e-3),
-                                          nn.LeakyReLU())
-        self.stage1_conv4 = nn.Sequential(nn.Conv2d(128, 64, 1, 1, 0), nn.BatchNorm2d(64, eps=1e-3),
-                                          nn.LeakyReLU())
-        self.stage1_conv5 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128, eps=1e-3),
-                                          nn.LeakyReLU(), nn.MaxPool2d(2, 2))
-        self.stage1_conv6 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256, eps=1e-3),
-                                          nn.LeakyReLU())
-        self.stage1_conv7 = nn.Sequential(nn.Conv2d(256, 128, 1, 1, 0), nn.BatchNorm2d(128, eps=1e-3),
-                                          nn.LeakyReLU())
-        self.stage1_conv8 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256, eps=1e-3),
-                                          nn.LeakyReLU(), nn.MaxPool2d(2, 2))
-        self.stage1_conv9 = nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512, eps=1e-3),
+        self.num_anchors = num_anchors
+        self.conv1 = nn.Sequential(nn.Conv2d(3, 32, 3, 1, 1), nn.BatchNorm2d(32, eps=1e-3),
+                                   nn.LeakyReLU(), nn.MaxPool2d(2, 2))
+        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, 3, 1, 1), nn.BatchNorm2d(64, eps=1e-3),
+                                   nn.LeakyReLU(), nn.MaxPool2d(2, 2))
+        self.conv3 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128, eps=1e-3),
+                                   nn.LeakyReLU())
+        self.conv4 = nn.Sequential(nn.Conv2d(128, 64, 1, 1, 0), nn.BatchNorm2d(64, eps=1e-3),
+                                   nn.LeakyReLU())
+        self.conv5 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128, eps=1e-3),
+                                   nn.LeakyReLU(), nn.MaxPool2d(2, 2))
+        self.conv6 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256, eps=1e-3),
+                                   nn.LeakyReLU())
+        self.conv7 = nn.Sequential(nn.Conv2d(256, 128, 1, 1, 0), nn.BatchNorm2d(128, eps=1e-3),
+                                   nn.LeakyReLU())
+        self.conv8 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256, eps=1e-3),
+                                   nn.LeakyReLU(), nn.MaxPool2d(2, 2))
+        self.conv9 = nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512, eps=1e-3),
+                                   nn.LeakyReLU())
+        self.conv10 = nn.Sequential(nn.Conv2d(512, 256, 1, 1, 0), nn.BatchNorm2d(256, eps=1e-3),
+                                    nn.LeakyReLU())
+        self.conv11 = nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512, eps=1e-3),
+                                    nn.LeakyReLU())
+        self.conv12 = nn.Sequential(nn.Conv2d(512, 256, 1, 1, 0), nn.BatchNorm2d(256, eps=1e-3),
+                                    nn.LeakyReLU())
+        self.conv13 = nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512, eps=1e-3),
+                                           nn.LeakyReLU())
+
+        self.conv_b = nn.Sequential(nn.Conv2d(512, 64, 1, 1, 0), nn.BatchNorm2d(64, eps=1e-3),
+                                    nn.LeakyReLU())
+        self.avgpool_b = nn.AvgPool2d(2, 2)
+
+        self.maxpool_a = nn.MaxPool2d(2, 2)
+        self.conv_a1 = nn.Sequential(nn.Conv2d(512, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
+                                   nn.LeakyReLU())
+        self.conv_a2 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1, 0), nn.BatchNorm2d(512, eps=1e-3),
+                                   nn.LeakyReLU())
+        self.conv_a3 = nn.Sequential(nn.Conv2d(512, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
+                                   nn.LeakyReLU())
+        self.conv_a4 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1, 0), nn.BatchNorm2d(512, eps=1e-3),
+                                   nn.LeakyReLU())
+        self.conv_a5 = nn.Sequential(nn.Conv2d(512, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
+                                   nn.LeakyReLU())        
+        self.conv_a6 = nn.Sequential(nn.Conv2d(1024, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
+                                   nn.LeakyReLU())
+        self.conv_a7 = nn.Sequential(nn.Conv2d(1024, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
+                                            nn.LeakyReLU())
+
+        self.conv_out1 = nn.Sequential(nn.Conv2d(1088, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
                                           nn.LeakyReLU())
-        self.stage1_conv10 = nn.Sequential(nn.Conv2d(512, 256, 1, 1, 0), nn.BatchNorm2d(256, eps=1e-3),
-                                           nn.LeakyReLU(), nn.MaxPool2d(2, 2))
-        self.out_conv = nn.Conv2d(256, len(self.anchors) * (5 + num_classes), 1, 1, 0)
+
+        self.conv_out2 = nn.Conv2d(1024, self.num_anchors * (5 + num_classes), 1, 1, 0)
 
     def forward(self, input):
-        output = self.stage1_conv1(input)
-        output = self.stage1_conv2(output)
-        output = self.stage1_conv3(output)
-        output = self.stage1_conv4(output)
-        output = self.stage1_conv5(output)
-        output = self.stage1_conv6(output)
-        output = self.stage1_conv7(output)
-        output = self.stage1_conv8(output)
-        output = self.stage1_conv9(output)
-        output = self.stage1_conv10(output)
-        output = self.out_conv(output)
+        output = self.conv1(input)
+        output = self.conv2(output)
+        output = self.conv3(output)
+        output = self.conv4(output)
+        output = self.conv5(output)
+        output = self.conv6(output)
+        output = self.conv7(output)
+        output = self.conv8(output)
+        output = self.conv9(output)
+        output = self.conv10(output)
+        output = self.conv11(output)
+        output = self.conv12(output)
+        output = self.conv13(output)
+
+        residual = output
+
+        output_a = self.maxpool_a(output)
+        output_a = self.conv_a1(output_a)
+        output_a = self.conv_a2(output_a)
+        output_a = self.conv_a3(output_a)
+        output_a = self.conv_a4(output_a)
+        output_a = self.conv_a5(output_a)
+        output_a = self.conv_a6(output_a)
+        output_a = self.conv_a7(output_a)
+
+        output_b = self.conv_b(residual)
+        output_b = self.avgpool_b(output_b)
+
+        output = torch.cat((output_a, output_b), 1)
+        output = self.conv_out1(output)
+        output = self.conv_out2(output)
         return output
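
(As a quick sanity check, not part of the commit, the updated model
can be smoke-tested for its expected output shape; the 416x416 input
size follows from the 13x13 output grid used above:)

```python
import torch
from yolo import YoloV2

model = YoloV2(num_classes=4)
out = model(torch.randn(1, 3, 416, 416))
# 5 anchors x (5 + 4 classes) = 45 channels on a 13x13 grid
assert out.shape == (1, 45, 13, 13)
```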
diff --git a/Applications/YOLO/jni/main.cpp b/Applications/YOLO/jni/main.cpp
index 4ef8735..650cab3 100644
@@ -29,6 +29,8 @@ using LayerHandle = std::shared_ptr<ml::train::Layer>;
 using ModelHandle = std::unique_ptr<ml::train::Model>;
 using UserDataType = std::unique_ptr<nntrainer::util::DirDataLoader>;
 
+const int num_classes = 4;
+
 int trainData_cb(float **input, float **label, bool *last, void *user_data) {
   auto data = reinterpret_cast<nntrainer::util::DirDataLoader *>(user_data);
 
@@ -175,21 +177,51 @@ ModelHandle YOLO() {
   blocks.push_back(yoloBlock("conv7", "conv6", 128, 1, false));
   blocks.push_back(yoloBlock("conv8", "conv7", 256, 3, true));
   blocks.push_back(yoloBlock("conv9", "conv8", 512, 3, false));
-  blocks.push_back(yoloBlock("conv10", "conv9", 256, 1, true));
+  blocks.push_back(yoloBlock("conv10", "conv9", 256, 1, false));
+  blocks.push_back(yoloBlock("conv11", "conv10", 512, 3, false));
+  blocks.push_back(yoloBlock("conv12", "conv11", 256, 1, false));
+  blocks.push_back(yoloBlock("conv13", "conv12", 512, 3, false));
+
+  blocks.push_back({createLayer(
+    "pooling2d", {withKey("name", "conv_a_pool"), withKey("stride", {2, 2}),
+                  withKey("pooling", "max"), withKey("pool_size", {2, 2}),
+                  withKey("input_layers", "conv13")})});
+  blocks.push_back(yoloBlock("conv_a1", "conv_a_pool", 1024, 3, false));
+  blocks.push_back(yoloBlock("conv_a2", "conv_a1", 512, 1, false));
+  blocks.push_back(yoloBlock("conv_a3", "conv_a2", 1024, 3, false));
+  blocks.push_back(yoloBlock("conv_a4", "conv_a3", 512, 1, false));
+  blocks.push_back(yoloBlock("conv_a5", "conv_a4", 1024, 3, false));
+  blocks.push_back(yoloBlock("conv_a6", "conv_a5", 1024, 3, false));
+  blocks.push_back(yoloBlock("conv_a7", "conv_a6", 1024, 3, false));
+
+  blocks.push_back(yoloBlock("conv_b", "conv13", 64, 1, false));
+  // todo: conv_b_pool layer will be replaced with re-organization custom layer
+  blocks.push_back({createLayer(
+    "pooling2d", {withKey("name", "conv_b_pool"), withKey("stride", {2, 2}),
+                  withKey("pooling", "average"), withKey("pool_size", {2, 2}),
+                  withKey("input_layers", "conv_b")})});
+
+  blocks.push_back(
+    {createLayer("concat", {withKey("name", "concat"),
+                            withKey("input_layers", "conv_a7, conv_b_pool"),
+                            withKey("axis", 1)})});
+
+  blocks.push_back(yoloBlock("conv_out1", "concat", 1024, 3, false));
+
+  blocks.push_back(
+    {createLayer("conv2d", {
+                             withKey("name", "conv_out2"),
+                             withKey("filters", (5 + num_classes) * 5),
+                             withKey("kernel_size", {1, 1}),
+                             withKey("stride", {1, 1}),
+                             withKey("padding", "same"),
+                             withKey("input_layers", "conv_out1"),
+                           })});
 
   for (auto &block : blocks) {
     layers.insert(layers.end(), block.begin(), block.end());
   }
 
-  layers.push_back(createLayer("conv2d", {
-                                           withKey("name", "conv_out"),
-                                           withKey("filters", 50),
-                                           withKey("kernel_size", {1, 1}),
-                                           withKey("stride", {1, 1}),
-                                           withKey("padding", "same"),
-                                           withKey("input_layers", "conv10"),
-                                         }));
-
   for (auto layer : layers) {
     model->addLayer(layer);
   }