[Application] update yolo v2 python for building pre-training model
author Seungbaek Hong <sb92.hong@samsung.com>
Fri, 17 May 2024 08:38:52 +0000 (17:38 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Wed, 22 May 2024 23:09:41 +0000 (08:09 +0900)
To support training on large datasets, the dataset is no longer loaded into memory in advance; instead, each image is loaded on demand during training. Visualization code was also added to check whether training proceeded well.

Signed-off-by: Seungbaek Hong <sb92.hong@samsung.com>
Applications/YOLOv2/PyTorch/dataset.py
Applications/YOLOv2/PyTorch/main.py

index 8f804eff49ea3a3a1c2d7fd4de1ae446db88470c..d939e0f8a96ec45686bc13d9e0fee72c3a51d9ba 100644 (file)
@@ -22,6 +22,7 @@ from PIL import Image
 class YOLODataset(Dataset):
     def __init__(self, img_dir, ann_dir):
         super().__init__()
+        self.img_dir = img_dir
         pattern = re.compile("\/(\d+)\.")
         img_list = glob.glob(img_dir + "*")
         ann_list = glob.glob(ann_dir + "*")
@@ -30,12 +31,11 @@ class YOLODataset(Dataset):
         ann_ids = list(map(lambda x: pattern.search(x).group(1), ann_list))
         ids_list = list(set(img_ids) & set(ann_ids))
 
-        self.input_images = []
+        self.ids_list = []
         self.bbox_gt = []
         self.cls_gt = []
 
         for ids in ids_list:
-            img = np.array(Image.open(img_dir + ids + ".jpg").resize((416, 416))) / 255
             label_bbox = []
             label_cls = []
             with open(ann_dir + ids + ".txt", "rt", encoding="utf-8") as f:
@@ -47,19 +47,27 @@ class YOLODataset(Dataset):
             if len(label_cls) == 0:
                 continue
 
-            self.input_images.append(img)
+            self.ids_list.append(ids)
             self.bbox_gt.append(label_bbox)
             self.cls_gt.append(label_cls)
 
-        self.length = len(self.input_images)
-        self.input_images = np.array(self.input_images)
-        self.input_images = torch.FloatTensor(self.input_images).permute((0, 3, 1, 2))
+        self.length = len(self.ids_list)
 
     def __len__(self):
         return self.length
 
     def __getitem__(self, idx):
-        return self.input_images[idx], self.bbox_gt[idx], self.cls_gt[idx]
+        img = (
+            torch.FloatTensor(
+                np.array(
+                    Image.open(self.img_dir + self.ids_list[idx] + ".jpg").resize(
+                        (416, 416)
+                    )
+                )
+            ).permute((2, 0, 1))
+            / 255
+        )
+        return img, self.bbox_gt[idx], self.cls_gt[idx]
 
 
 ##
index cd8d277945f8d7f8a1f8584b0fe105004012eb78..6e42fa1c6b25004d2d8c6d2ed44b4c70f7df6a15 100644 (file)
 import sys
 import os
 
-from torchconverter import save_bin
-import torch
+from PIL import Image, ImageDraw
+from matplotlib import pyplot as plt
 from torch import optim
 from torch.utils.data import DataLoader
+import torch
+import numpy as np
 
 from yolo import YoloV2
 from yolo_loss import YoloV2_LOSS
 from dataset import YOLODataset, collate_db
+from torchconverter import save_bin
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -137,10 +140,9 @@ for epoch in range(epochs):
           valid loss: {epoch_valid_loss / len(valid_loader):.4f}"
     )
 
+
 ##
 # @brief bbox post process function for inference
-
-
 def post_process_for_bbox(bbox_p):
     """
     @param bbox_p shape(batch_size, cell_h x cell_w, num_anchors, 4)
@@ -175,8 +177,32 @@ def post_process_for_bbox(bbox_p):
     return bbox_p
 
 
+def visualize_bbox(img_pred, bbox_preds):
+    img_array = (img_pred.to("cpu") * 255).permute((1, 2, 0)).numpy().astype(np.uint8)
+    img = Image.fromarray(img_array)
+
+    for bbox_pred in bbox_preds:
+        bbox_pred = [int(x * 416) for x in bbox_pred]
+
+        if sum(bbox_pred) == 0:
+            continue
+
+        x_lefttop = bbox_pred[0]
+        y_lefttop = bbox_pred[1]
+        width = bbox_pred[2]
+        height = bbox_pred[3]
+
+        draw = ImageDraw.Draw(img)
+        draw.rectangle(
+            [(x_lefttop, y_lefttop), (x_lefttop + width, y_lefttop + height)]
+        )
+
+    plt.imshow(img)
+    plt.show()
+
+
 # inference example using trained model
-hypothesis = model(img).permute((0, 2, 3, 1))
+hypothesis = model(img.to(device)).permute((0, 2, 3, 1))
 hypothesis = hypothesis[0].reshape((1, out_size**2, num_anchors, 5 + num_classes))
 
 # transform output
@@ -192,4 +218,5 @@ prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(
 
 # result of inference (data range 0~1)
 iou_mask = iou_pred > 0.5
-print(bbox_pred * iou_mask, iou_pred * iou_mask, prob_pred * iou_mask)
+bbox_pred = bbox_pred * iou_mask
+visualize_bbox(img, bbox_pred.reshape(-1, 4))