[Application] add object detection example using pytorch
author    Seungbaek Hong <sb92.hong@samsung.com>
          Fri, 3 Mar 2023 05:04:27 +0000 (14:04 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
          Wed, 15 Mar 2023 05:43:56 +0000 (14:43 +0900)
Add a YOLO v2 object detection example implemented in PyTorch.

It follows the structure of the original YOLO v2 model, but the
backbone is replaced with a simpler CNN so that this initial
version of YOLO is easier to support in NNTrainer.
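
To try the example, run main.py; it assumes a custom dataset with
images under ./custom_dataset/images/ and matching annotation
text files under ./custom_dataset/annotations/.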

Signed-off-by: Seungbaek Hong <sb92.hong@samsung.com>
Applications/YOLO/PyTorch/dataset.py [new file with mode: 0644]
Applications/YOLO/PyTorch/main.py [new file with mode: 0644]
Applications/YOLO/PyTorch/yolo.py [new file with mode: 0644]
Applications/YOLO/PyTorch/yolo_loss.py [new file with mode: 0644]

diff --git a/Applications/YOLO/PyTorch/dataset.py b/Applications/YOLO/PyTorch/dataset.py
new file mode 100644 (file)
index 0000000..738c3e8
--- /dev/null
@@ -0,0 +1,66 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
+#
+# @file dataset.py
+# @date 8 March 2023
+# @brief Define dataset class for yolo
+#
+# @author Seungbaek Hong <sb92.hong@samsung.com>
+
+import glob
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+from PIL import Image
+
+##
+# @brief dataset class for yolo
+# @note Each image needs an annotation text file with a matching name.
+class YOLODataset(Dataset):
+    def __init__(self, img_dir, ann_dir):
+        super().__init__()
+        img_list = glob.glob(img_dir)
+        ann_list = glob.glob(ann_dir)
+        img_list.sort()
+        ann_list.sort()
+
+        self.length = len(img_list)
+        self.input_images = []
+        self.bbox_gt = []
+        self.cls_gt = []
+
+        for i in range(len(img_list)):
+            img = np.array(Image.open(img_list[i])) / 255
+            label_bbox = []
+            label_cls = []
+            with open(ann_list[i], 'rt') as f:
+                for line in f:
+                    line = [int(v) for v in line.split()]
+                    # bbox values are pixels on a 416x416 image, normalized to 0~1
+                    label_bbox.append(np.array(line[1:], dtype=np.float32) / 416)
+                    label_cls.append(line[0])
+
+            self.input_images.append(img)
+            self.bbox_gt.append(label_bbox)
+            self.cls_gt.append(label_cls)
+        
+        self.input_images = np.array(self.input_images)
+        self.input_images = torch.FloatTensor(self.input_images).permute((0, 3, 1, 2))
+        
+    def __len__(self):
+        return self.length
+    
+    def __getitem__(self, idx):
+        return self.input_images[idx], self.bbox_gt[idx], self.cls_gt[idx]
+    
+##
+# @brief collate db function for yolo
+def collate_db(batch):
+    """
+    @param batch list of batch, (img, bbox, cls)
+    @return collated list of batch, (img, bbox, cls)
+    """
+    items = list(zip(*batch))
+    items[0] = default_collate(items[0])
+    items[1] = list(items[1])
+    items[2] = list(items[2])
+    return items
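+
+##
+# @note Annotation format (as parsed above): one object per line in the form
+# "class x y w h", where class is an integer id and the box values are pixel
+# coordinates on a 416x416 image. For illustration, a file might contain:
+#   0 208 208 104 104
+#   2 52 52 30 30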
diff --git a/Applications/YOLO/PyTorch/main.py b/Applications/YOLO/PyTorch/main.py
new file mode 100644 (file)
index 0000000..5cd826c
--- /dev/null
@@ -0,0 +1,128 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
+#
+# @file main.py
+# @date 8 March 2023
+# @brief Implement training for yolo
+#
+# @author Seungbaek Hong <sb92.hong@samsung.com>
+
+import torch
+import torch.optim as optim
+from torch.utils.data import DataLoader
+
+from yolo import YoloV2_light
+from yolo_loss import YoloV2_LOSS
+from dataset import YOLODataset, collate_db
+
+
+# set config
+out_size = 13
+num_classes = 5
+num_anchors = 5
+
+epochs = 1000
+batch_size = 8
+
+img_dir = './custom_dataset/images/*'
+ann_dir = './custom_dataset/annotations/*'
+
+
+# load data
+dataset = YOLODataset(img_dir, ann_dir)
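+# a custom collate function is needed because each image may contain a
+# different number of ground-truth boxes (see collate_db in dataset.py)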
+loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_db, shuffle=True, drop_last=True)
+
+
+# set model, loss and optimizer
+model = YoloV2_light(num_classes=num_classes)
+criterion = YoloV2_LOSS(num_classes=num_classes)
+optimizer = optim.Adam(model.parameters(), lr=1e-3)
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
+
+
+# train model
+best_loss = 1e+10
+for epoch in range(epochs):
+    epoch_loss = 0
+    for idx, (img, bbox, cls) in enumerate(loader):
+        optimizer.zero_grad()
+        # model prediction
+        hypothesis = model(img).permute((0, 2, 3, 1))
+        hypothesis = hypothesis.reshape((batch_size, out_size**2, num_anchors, 5+num_classes))        
+        # split each prediction(bbox, iou, class prob)
+        bbox_pred_xy = torch.sigmoid(hypothesis[..., :2])
+        bbox_pred_wh = torch.exp(hypothesis[..., 2:4])
+        bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)        
+        iou_pred = torch.sigmoid(hypothesis[..., 4:5])        
+        score_pred = hypothesis[..., 5:].contiguous()
+        prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(score_pred.shape)
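+        # resulting shapes: bbox_pred (batch, 13*13, num_anchors, 4),
+        # iou_pred (batch, 13*13, num_anchors, 1),
+        # prob_pred (batch, 13*13, num_anchors, num_classes)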
+        # calc loss (pass the predictions through directly; re-wrapping them
+        # in torch.FloatTensor would detach them from the autograd graph)
+        loss = criterion(bbox_pred, iou_pred, prob_pred, bbox, cls)
+        # back prop
+        loss.backward()
+        optimizer.step()
+        epoch_loss += loss.item()
+
+    # step the cosine annealing schedule once per epoch
+    scheduler.step()
+
+    if epoch_loss < best_loss:
+        best_loss = epoch_loss
+        torch.save(model.state_dict(), './best_model.pt')
+        
+    print("{}epoch, loss: {:.4f}".format(epoch, epoch_loss / len(loader)))
+
+##
+# @brief bbox post process function for inference
+def post_process_for_bbox(bbox_pred):    
+    """
+    @param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)    
+    @return bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
+    """
+    anchors = torch.FloatTensor(
+        [(1.3221, 1.73145),
+        (3.19275, 4.00944),
+        (5.05587, 8.09892),
+        (9.47112, 4.84053),
+        (11.2364, 10.0071)]
+    )
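+    # note: these are the standard YOLO v2 anchor priors in 13x13 grid-cell
+    # units; they must match the anchors used by YoloV2_LOSS during training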
+
+    outsize = (13, 13)
+    width, height = outsize
+    
+    # restore cell pos to absolute x, y on the grid, then normalize to 0~1
+    for w in range(width):
+        for h in range(height):
+            bbox_pred[:, width * h + w, :, 0] += w
+            bbox_pred[:, width * h + w, :, 1] += h
+    bbox_pred[:, :, :, 0] /= width
+    bbox_pred[:, :, :, 1] /= height
+
+    # apply anchors to w, h
+    anchor_w = anchors[:, 0].contiguous().view(-1, 1)
+    anchor_h = anchors[:, 1].contiguous().view(-1, 1)        
+    bbox_pred[:, :, :, 2:3] *= anchor_w
+    bbox_pred[:, :, :, 3:4] *= anchor_h
+
+    return bbox_pred
+
+# inference example using trained model
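+# note: for a standalone run one would typically restore the best checkpoint
+# and disable gradient tracking first, e.g.
+#   model.load_state_dict(torch.load('./best_model.pt'))
+#   model.eval()
+# here, the in-memory model and the last training batch are simply reused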
+hypothesis = model(img).permute((0, 2, 3, 1))
+hypothesis = hypothesis[0].reshape((1, out_size**2, num_anchors, 5+num_classes))        
+
+# transform output
+bbox_pred_xy = torch.sigmoid(hypothesis[..., :2])
+bbox_pred_wh = torch.exp(hypothesis[..., 2:4])
+bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)
+bbox_pred = post_process_for_bbox(bbox_pred)
+iou_pred = torch.sigmoid(hypothesis[..., 4:5])
+score_pred = hypothesis[..., 5:].contiguous()
+prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(score_pred.shape)
+
+# result of inference (data range 0~1)
+iou_mask = (iou_pred > 0.5)
+print(bbox_pred * iou_mask, iou_pred * iou_mask, prob_pred * iou_mask)
diff --git a/Applications/YOLO/PyTorch/yolo.py b/Applications/YOLO/PyTorch/yolo.py
new file mode 100644 (file)
index 0000000..28e0f4b
--- /dev/null
@@ -0,0 +1,58 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
+#
+# @file yolo.py
+# @date 8 March 2023
+# @brief Define simple yolo model, but not original darknet.
+#
+# @author Seungbaek Hong <sb92.hong@samsung.com>
+
+import torch
+import torch.nn as nn
+
+##
+# @brief define simple yolo model (not original darknet)
+class YoloV2_light(nn.Module): 
+    def __init__(self,
+                 num_classes,
+                 anchors=((1.3221, 1.73145), (3.19275, 4.00944), (5.05587, 8.09892),
+                          (9.47112, 4.84053), (11.2364, 10.0071))):
+
+        super(YoloV2_light, self).__init__()              
+        self.num_classes = num_classes
+        self.anchors = anchors
+        self.stage1_conv1 = nn.Sequential(nn.Conv2d(3, 32, 3, 1, 1), nn.BatchNorm2d(32),
+                                          nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
+        self.stage1_conv2 = nn.Sequential(nn.Conv2d(32, 64, 3, 1, 1), nn.BatchNorm2d(64),
+                                          nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
+        self.stage1_conv3 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128),
+                                          nn.LeakyReLU(0.1))
+        self.stage1_conv4 = nn.Sequential(nn.Conv2d(128, 64, 1, 1, 0), nn.BatchNorm2d(64),
+                                          nn.LeakyReLU(0.1))
+        self.stage1_conv5 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128),
+                                          nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
+        self.stage1_conv6 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256),
+                                          nn.LeakyReLU(0.1))
+        self.stage1_conv7 = nn.Sequential(nn.Conv2d(256, 128, 1, 1, 0), nn.BatchNorm2d(128),
+                                          nn.LeakyReLU(0.1))
+        self.stage1_conv8 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256),
+                                          nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
+        self.stage1_conv9 = nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512),
+                                          nn.LeakyReLU(0.1))
+        self.stage1_conv10 = nn.Sequential(nn.Conv2d(512, 256, 1, 1, 0), nn.BatchNorm2d(256),
+                                           nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
+        self.out_conv = nn.Conv2d(256, len(self.anchors) * (5 + num_classes), 1, 1, 0)
+
+    def forward(self, input):
+        output = self.stage1_conv1(input)
+        output = self.stage1_conv2(output)
+        output = self.stage1_conv3(output)
+        output = self.stage1_conv4(output)
+        output = self.stage1_conv5(output)
+        output = self.stage1_conv6(output)
+        output = self.stage1_conv7(output)
+        output = self.stage1_conv8(output)
+        output = self.stage1_conv9(output)
+        output = self.stage1_conv10(output)
+        output = self.out_conv(output)
+        return output
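+
+# shape sanity check (illustrative): five stride-2 max-pools reduce a 416x416
+# input to a 13x13 grid, and the head emits len(anchors) * (5 + num_classes)
+# channels, e.g. for num_classes=5 and the 5 default anchors:
+#   y = YoloV2_light(num_classes=5)(torch.zeros(1, 3, 416, 416))
+#   assert y.shape == (1, 50, 13, 13)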
diff --git a/Applications/YOLO/PyTorch/yolo_loss.py b/Applications/YOLO/PyTorch/yolo_loss.py
new file mode 100644 (file)
index 0000000..c3754cc
--- /dev/null
@@ -0,0 +1,203 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
+#
+# @file yolo_loss.py
+# @date 8 March 2023
+# @brief Define loss class for yolo
+#
+# @author Seungbaek Hong <sb92.hong@samsung.com>
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+##
+# @brief calculate iou between two boxes list
+def calculate_iou(bbox1, bbox2):    
+    """
+    @param bbox1 shape(numb_of_bbox, 4), it contains x, y, w, h
+    @param bbox2 shape(numb_of_bbox, 4), it contains x, y, w, h
+    @return result shape(numb_of_bbox, 1)
+    """
+    # bbox coordinates
+    b1x1, b1y1 = (bbox1[:, :2]).split(1, 1)
+    b1x2, b1y2 = (bbox1[:, :2] + (bbox1[:, 2:4])).split(1, 1)
+    b2x1, b2y1 = (bbox2[:, :2]).split(1, 1)
+    b2x2, b2y2 = (bbox2[:, :2] + (bbox2[:, 2:4])).split(1, 1)
+    
+    # box areas
+    areas1 = (b1x2 - b1x1) * (b1y2 - b1y1)
+    areas2 = (b2x2 - b2x1) * (b2y2 - b2y1)
+    
+    # intersections
+    min_x_of_max_x, max_x_of_min_x = torch.min(b1x2, b2x2), torch.max(b1x1, b2x1)
+    min_y_of_max_y, max_y_of_min_y = torch.min(b1y2, b2y2), torch.max(b1y1, b2y1)
+    intersection_width = (min_x_of_max_x - max_x_of_min_x).clamp(min=0)
+    intersection_height = (min_y_of_max_y - max_y_of_min_y).clamp(min=0)
+    intersections = intersection_width * intersection_height
+    
+    # unions        
+    unions = (areas1 + areas2) - intersections
+    
+    result = intersections / unions    
+    return result
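+
+# e.g. two 2x2 boxes whose top-left corners are offset by (1, 1) overlap in a
+# 1x1 region, giving IoU = 1 / (4 + 4 - 1) ~= 0.1429:
+#   calculate_iou(torch.FloatTensor([[0., 0., 2., 2.]]),
+#                 torch.FloatTensor([[1., 1., 2., 2.]]))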
+
+##
+# @brief find best iou and its index
+def find_best_ratio(anchors, bbox):    
+    """
+    @param anchors shape(numb_of_anchors, 2), it contains w, h
+    @param bbox shape(numb_of_bbox, 2), it contains w, h
+    @return best_match index of best match, shape(numb_of_bbox, 1)
+    """
+    b1 = np.divide(anchors[:, 0], anchors[:, 1])
+    b2 = np.divide(bbox[:, 0], bbox[:, 1])
+    similarities = np.abs(b1.reshape(-1, 1) - b2)
+    best_match = np.argmin(similarities, axis=0)
+    return best_match
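+
+# e.g. for anchors [(1, 2), (4, 2)] (aspect ratios 0.5 and 2.0), a gt box with
+# w=3, h=1 (aspect ratio 3.0) is closest to the second anchor, giving index 1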
+
+##
+# @brief loss class for yolo
+class YoloV2_LOSS(nn.Module):
+    """Yolo v2 loss"""
+    def __init__(self, num_classes, img_shape = (416, 416), outsize = (13, 13)):
+        super().__init__()
+        self.num_classes = num_classes
+        self.img_shape = img_shape
+        self.outsize = outsize
+        
+        self.anchors = torch.FloatTensor(
+            [(1.3221, 1.73145),
+            (3.19275, 4.00944),
+            (5.05587, 8.09892),
+            (9.47112, 4.84053),
+            (11.2364, 10.0071)]
+        )
+                
+        self.mse = nn.MSELoss()
+        self.bbox_loss, self.iou_loss, self.cls_loss = None, None, None       
+        
+    def forward(self, bbox_pred, iou_pred, prob_pred, bbox_gt, cls_gt):        
+        """
+        @param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
+        @param iou_pred shape(batch_size, cell_h x cell_w, num_anchors, 1)
+        @param prob_pred shape(batch_size, cell_h x cell_w, num_anchors, num_classes)
+        @param bbox_gt shape(batch_size, num_bbox, 4), data range(0~1)
+        @param cls_gt shape(batch_size, num_bbox, 1)
+        @return loss shape(1,)
+        """
+        bbox_pred = self.apply_anchors_to_bbox(bbox_pred)
+        
+        bbox_built, iou_built, cls_built, bbox_mask, iou_mask, cls_mask =\
+            self._build_target(bbox_pred, bbox_gt, cls_gt)
+        
+        self.bbox_loss = self.mse(bbox_pred * bbox_mask,
+                                  bbox_built * bbox_mask)
+        self.iou_loss = self.mse(iou_pred * iou_mask,
+                                 iou_built * iou_mask)
+        self.cls_loss = self.mse(prob_pred * cls_mask,
+                                 cls_built * cls_mask)
+        
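+        # the 5x weight on the bbox term follows the lambda_coord = 5
+        # coordinate weighting from the original YOLO loss formulation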
+        return self.bbox_loss * 5 + self.iou_loss + self.cls_loss
+        
+    def apply_anchors_to_bbox(self, bbox_pred):        
+        """
+        @param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
+        @return bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)    
+        """
+        anchor_w = self.anchors[:, 0].contiguous().view(-1, 1)
+        anchor_h = self.anchors[:, 1].contiguous().view(-1, 1)
+        bbox_pred[:, :, :, 2:3] = torch.sqrt(bbox_pred[:, :, :, 2:3] * anchor_w)
+        bbox_pred[:, :, :, 3:4] = torch.sqrt(bbox_pred[:, :, :, 3:4] * anchor_h)        
+        return bbox_pred
+    
+    def _build_target(self, bbox_pred, bbox_gt, cls_gt):
+        """
+        @param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
+        @param bbox_gt shape(batch_size, num_bbox, 4)
+        @param cls_gt shape(batch_size, num_bbox, 1)
+        @return tuple of (bbox_built, iou_built, cls_built, bbox_mask, iou_mask, cls_mask)
+        """    
+        bbox_built, bbox_mask = [], []
+        iou_built, iou_mask = [], []
+        cls_built, cls_mask = [], []
+        
+        batch_size = bbox_pred.shape[0]
+                
+        for i in range(batch_size):
+            _bbox_built, _iou_built, _cls_built,\
+                _bbox_mask, _iou_mask, _cls_mask =\
+                    self._make_target_per_sample(
+                        torch.FloatTensor(bbox_pred[i]),
+                        torch.FloatTensor(np.array(bbox_gt[i])),
+                        torch.LongTensor(cls_gt[i])
+                    )
+            
+            bbox_built.append(_bbox_built.numpy())
+            bbox_mask.append(_bbox_mask.numpy())
+            iou_built.append(_iou_built.numpy())
+            iou_mask.append(_iou_mask.numpy())
+            cls_built.append(_cls_built.numpy())
+            cls_mask.append(_cls_mask.numpy())
+
+        bbox_built, bbox_mask, iou_built, iou_mask, cls_built, cls_mask =\
+            torch.FloatTensor(np.array(bbox_built)),\
+            torch.FloatTensor(np.array(bbox_mask)),\
+            torch.FloatTensor(np.array(iou_built)),\
+            torch.FloatTensor(np.array(iou_mask)),\
+            torch.FloatTensor(np.array(cls_built)),\
+            torch.FloatTensor(np.array(cls_mask))
+                    
+        return bbox_built, iou_built, cls_built, bbox_mask, iou_mask, cls_mask
+        
+    def _make_target_per_sample(self, _bbox_pred, _bbox_gt, _cls_gt):
+        """
+        @param _bbox_pred shape(cell_h x cell_w, num_anchors, 4)
+        @param _bbox_gt shape(num_bbox, 4)
+        @param _cls_gt shape(num_bbox,)
+        @return tuple of (_bbox_built, _iou_built, _cls_built, _bbox_mask, _iou_mask, _cls_mask)
+        """
+        hw, num_anchors, _  = _bbox_pred.shape
+        
+        # set result template
+        _bbox_built = torch.zeros((hw, num_anchors, 4))
+        _bbox_mask = torch.zeros((hw, num_anchors, 1))
+        
+        _iou_built = torch.zeros((hw, num_anchors, 1))
+        _iou_mask = torch.ones((hw, num_anchors, 1)) * 0.5
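+        # 0.5 down-weights the confidence loss on the many no-object cells;
+        # matched cells are raised to full weight 1 below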
+        
+        _cls_built = torch.zeros((hw, num_anchors, self.num_classes))
+        _cls_mask = torch.zeros((hw, num_anchors, 1))
+                        
+        # find best anchors
+        _bbox_gt_wh = _bbox_gt.clone()[:, 2:]        
+        best_anchors = find_best_ratio(self.anchors, _bbox_gt_wh)
+        
+        # normalize x, y pos based on cell coordinates
+        cx = _bbox_gt[:, 0] * self.outsize[0]
+        cy = _bbox_gt[:, 1] * self.outsize[1]
+        # calculate cell pos and normalize x, y
+        cell_idx = np.floor(cy) * self.outsize[0] + np.floor(cx)
+        cell_idx = np.array(cell_idx, dtype=np.int16)
+        cx -= np.floor(cx)
+        cy -= np.floor(cy)
+                
+        # set bbox of gt
+        _bbox_built[cell_idx, best_anchors, 0] = cx 
+        _bbox_built[cell_idx, best_anchors, 1] = cy
+        _bbox_built[cell_idx, best_anchors, 2] = torch.sqrt(_bbox_gt[:, 2]) 
+        _bbox_built[cell_idx, best_anchors, 3] = torch.sqrt(_bbox_gt[:, 3]) 
+        _bbox_mask[cell_idx, best_anchors, :] = 1
+        
+        # set cls of gt       
+        _cls_built[cell_idx, best_anchors, _cls_gt] = 1
+        _cls_mask[cell_idx, best_anchors, :] = 1
+        
+        # set confidence score of gt
+        _iou_built = calculate_iou(_bbox_pred.reshape(-1, 4), _bbox_built.view(-1, 4)).detach()    
+        _iou_built = _iou_built.view(hw, num_anchors, 1)
+        _iou_mask[cell_idx, best_anchors, :] = 1
+        
+        return _bbox_built, _iou_built, _cls_built,\
+                _bbox_mask, _iou_mask, _cls_mask