--- /dev/null
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
+#
+# @file main.py
+# @date 8 March 2023
+# @brief Implement training for YOLO
+#
+# @author Seungbaek Hong <sb92.hong@samsung.com>
+
+import torch
+import torch.optim as optim
+from torch.utils.data import DataLoader
+
+from yolo import YoloV2_light
+from yolo_loss import YoloV2_LOSS
+from dataset import YOLODataset, collate_db
+
+
+# set config
+out_size = 13
+num_classes = 5
+num_anchors = 5
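+# out_size = 13 assumes a 416x416 input: YoloV2_light downsamples by 32
+# through its five max-pool layers (416 / 32 = 13)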
+
+epochs = 1000
+batch_size = 8
+
+img_dir = './custom_dataset/images/*'
+ann_dir = './custom_dataset/annotations/*'
+
+
+# load data
+dataset = YOLODataset(img_dir, ann_dir)
+loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_db, shuffle=True, drop_last=True)
+
+
+# set model, loss and optimizer
+model = YoloV2_light(num_classes=num_classes)
+criterion = YoloV2_LOSS(num_classes=num_classes)
+optimizer = optim.Adam(model.parameters(), lr=1e-3)
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
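+# note: the scheduler is stepped once per batch in the training loop below,
+# so T_max is counted in iterations rather than epochs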
+
+
+# train model
+best_loss = 1e+10
+for epoch in range(epochs):
+ epoch_loss = 0
+ for idx, (img, bbox, cls) in enumerate(loader):
+ optimizer.zero_grad()
+ # model prediction
+ hypothesis = model(img).permute((0, 2, 3, 1))
+ hypothesis = hypothesis.reshape((batch_size, out_size**2, num_anchors, 5+num_classes))
+ # split each prediction(bbox, iou, class prob)
+ bbox_pred_xy = torch.sigmoid(hypothesis[..., :2])
+ bbox_pred_wh = torch.exp(hypothesis[..., 2:4])
+ bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)
+ iou_pred = torch.sigmoid(hypothesis[..., 4:5])
+ score_pred = hypothesis[..., 5:].contiguous()
+ prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(score_pred.shape)
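+        # YOLOv2-style decoding: (x, y) are sigmoid offsets inside a cell, (w, h) are
+        # exp factors (scaled by the anchor priors inside the loss), and class scores
+        # are softmaxed per anchor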
+        # calc loss (pass the predictions directly so gradients can flow back to the model)
+        loss = criterion(bbox_pred, iou_pred, prob_pred, bbox, cls)
+ # back prop
+ loss.backward()
+ optimizer.step()
+ scheduler.step()
+ epoch_loss += loss.item()
+
+ if epoch_loss < best_loss:
+ best_loss = epoch_loss
+ torch.save(model.state_dict(), './best_model.pt')
+
+    print("epoch: {}, loss: {:.4f}".format(epoch, epoch_loss / len(loader)))
+
+##
+# @brief bbox post process function for inference
+def post_process_for_bbox(bbox_pred):
+ """
+ @param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
+ @return bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
+ """
+ anchors = torch.FloatTensor(
+ [(1.3221, 1.73145),
+ (3.19275, 4.00944),
+ (5.05587, 8.09892),
+ (9.47112, 4.84053),
+ (11.2364, 10.0071)]
+ )
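+    # these anchor priors (in 13x13 grid-cell units) appear to be the standard
+    # YOLOv2 anchors for PASCAL VOC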
+
+ outsize = (13, 13)
+ width, height = outsize
+
+ # restore cell pos to x, y
+ for w in range(width):
+ for h in range(height):
+            bbox_pred[:, width * h + w, :, 0] += w
+            bbox_pred[:, width * h + w, :, 1] += h
+ bbox_pred[:, :, :, :2] /= 13
+
+ # apply anchors to w, h
+ anchor_w = anchors[:, 0].contiguous().view(-1, 1)
+ anchor_h = anchors[:, 1].contiguous().view(-1, 1)
+ bbox_pred[:, :, :, 2:3] *= anchor_w
+ bbox_pred[:, :, :, 3:4] *= anchor_h
+
+ return bbox_pred
+
+# inference example using the trained model (reuses the last training batch's img)
+model.eval()
+with torch.no_grad():
+    hypothesis = model(img).permute((0, 2, 3, 1))
+hypothesis = hypothesis[0].reshape((1, out_size**2, num_anchors, 5+num_classes))
+
+# transform output
+bbox_pred_xy = torch.sigmoid(hypothesis[..., :2])
+bbox_pred_wh = torch.exp(hypothesis[..., 2:4])
+bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)
+bbox_pred = post_process_for_bbox(bbox_pred)
+iou_pred = torch.sigmoid(hypothesis[..., 4:5])
+score_pred = hypothesis[..., 5:].contiguous()
+prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(score_pred.shape)
+
+# result of inference (x, y are normalized to 0~1; w, h are in grid-cell units)
+iou_mask = (iou_pred > 0.5)
+print(bbox_pred * iou_mask, iou_pred * iou_mask, prob_pred * iou_mask)
--- /dev/null
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
+#
+# @file yolo.py
+# @date 8 March 2023
+# @brief Define a simple YOLO model (not the original darknet).
+#
+# @author Seungbaek Hong <sb92.hong@samsung.com>
+
+import torch
+import torch.nn as nn
+
+##
+# @brief define a simple YOLO model (not the original darknet)
+class YoloV2_light(nn.Module):
+    def __init__(self,
+                 num_classes,
+                 anchors=[(1.3221, 1.73145), (3.19275, 4.00944), (5.05587, 8.09892),
+                          (9.47112, 4.84053), (11.2364, 10.0071)]):
+
+ super(YoloV2_light, self).__init__()
+ self.num_classes = num_classes
+ self.anchors = anchors
+ self.stage1_conv1 = nn.Sequential(nn.Conv2d(3, 32, 3, 1, 1), nn.BatchNorm2d(32),
+ nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
+ self.stage1_conv2 = nn.Sequential(nn.Conv2d(32, 64, 3, 1, 1), nn.BatchNorm2d(64),
+ nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
+ self.stage1_conv3 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128),
+ nn.LeakyReLU(0.1))
+ self.stage1_conv4 = nn.Sequential(nn.Conv2d(128, 64, 1, 1, 0), nn.BatchNorm2d(64),
+ nn.LeakyReLU(0.1))
+ self.stage1_conv5 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128),
+ nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
+ self.stage1_conv6 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256),
+ nn.LeakyReLU(0.1))
+ self.stage1_conv7 = nn.Sequential(nn.Conv2d(256, 128, 1, 1, 0), nn.BatchNorm2d(128),
+ nn.LeakyReLU(0.1))
+ self.stage1_conv8 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256),
+ nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
+ self.stage1_conv9 = nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512),
+ nn.LeakyReLU(0.1))
+ self.stage1_conv10 = nn.Sequential(nn.Conv2d(512, 256, 1, 1, 0), nn.BatchNorm2d(256),
+ nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
+ self.out_conv = nn.Conv2d(256, len(self.anchors) * (5 + num_classes), 1, 1, 0)
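+        # one (tx, ty, tw, th, objectness) tuple plus num_classes scores per anchor,
+        # i.e. len(anchors) * (5 + num_classes) output channels per cell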
+
+ def forward(self, input):
+ output = self.stage1_conv1(input)
+ output = self.stage1_conv2(output)
+ output = self.stage1_conv3(output)
+ output = self.stage1_conv4(output)
+ output = self.stage1_conv5(output)
+ output = self.stage1_conv6(output)
+ output = self.stage1_conv7(output)
+ output = self.stage1_conv8(output)
+ output = self.stage1_conv9(output)
+ output = self.stage1_conv10(output)
+ output = self.out_conv(output)
+ return output
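+
+
+if __name__ == '__main__':
+    # minimal shape check (a sketch, not part of training): with a 416x416 input
+    # the five max-pool layers downsample by 32, so the output should be
+    # (1, 50, 13, 13) for num_classes=5 and the 5 default anchors
+    model = YoloV2_light(num_classes=5).eval()
+    with torch.no_grad():
+        print(model(torch.randn(1, 3, 416, 416)).shape)  # -> torch.Size([1, 50, 13, 13])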
--- /dev/null
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
+#
+# @file yolo_loss.py
+# @date 8 March 2023
+# @brief Define the loss class for YOLO
+#
+# @author Seungbaek Hong <sb92.hong@samsung.com>
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+##
+# @brief calculate IoU between two lists of boxes
+def calculate_iou(bbox1, bbox2):
+    """
+    @param bbox1 shape(num_of_bbox, 4), it contains x, y, w, h
+    @param bbox2 shape(num_of_bbox, 4), it contains x, y, w, h
+    @return result shape(num_of_bbox, 1)
+    """
+ # bbox coordinates
+ b1x1, b1y1 = (bbox1[:, :2]).split(1, 1)
+ b1x2, b1y2 = (bbox1[:, :2] + (bbox1[:, 2:4])).split(1, 1)
+ b2x1, b2y1 = (bbox2[:, :2]).split(1, 1)
+ b2x2, b2y2 = (bbox2[:, :2] + (bbox2[:, 2:4])).split(1, 1)
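+    # (x, y) is treated as the top-left corner here, so (x + w, y + h) gives the
+    # bottom-right corner of each box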
+
+ # box areas
+ areas1 = (b1x2 - b1x1) * (b1y2 - b1y1)
+ areas2 = (b2x2 - b2x1) * (b2y2 - b2y1)
+
+ # intersections
+ min_x_of_max_x, max_x_of_min_x = torch.min(b1x2, b2x2), torch.max(b1x1, b2x1)
+ min_y_of_max_y, max_y_of_min_y = torch.min(b1y2, b2y2), torch.max(b1y1, b2y1)
+ intersection_width = (min_x_of_max_x - max_x_of_min_x).clamp(min=0)
+ intersection_height = (min_y_of_max_y - max_y_of_min_y).clamp(min=0)
+ intersections = intersection_width * intersection_height
+
+ # unions
+ unions = (areas1 + areas2) - intersections
+
+ result = intersections / unions
+ return result
+
+##
+# @brief find the best-matching anchor (by aspect ratio) for each bbox
+def find_best_ratio(anchors, bbox):
+    """
+    @param anchors shape(num_of_anchors, 2), it contains w, h
+    @param bbox shape(num_of_bbox, 2), it contains w, h
+    @return best_match indices of the best-matching anchors, shape(num_of_bbox,)
+    """
+ b1 = np.divide(anchors[:, 0], anchors[:, 1])
+ b2 = np.divide(bbox[:, 0], bbox[:, 1])
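+    # compare aspect ratios only: the anchor whose w/h ratio is closest to the
+    # gt box's w/h ratio is selected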
+ similarities = np.abs(b1.reshape(-1, 1) - b2)
+ best_match = np.argmin(similarities, axis=0)
+ return best_match
+
+##
+# @brief loss class for yolo
+class YoloV2_LOSS(nn.Module):
+ """Yolo v2 loss"""
+ def __init__(self, num_classes, img_shape = (416, 416), outsize = (13, 13)):
+ super().__init__()
+ self.num_classes = num_classes
+ self.img_shape = img_shape
+ self.outsize = outsize
+
+ self.anchors = torch.FloatTensor(
+ [(1.3221, 1.73145),
+ (3.19275, 4.00944),
+ (5.05587, 8.09892),
+ (9.47112, 4.84053),
+ (11.2364, 10.0071)]
+ )
+
+ self.mse = nn.MSELoss()
+ self.bbox_loss, self.iou_loss, self.cls_loss = None, None, None
+
+ def forward(self, bbox_pred, iou_pred, prob_pred, bbox_gt, cls_gt):
+ """
+ @param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
+        @param iou_pred shape(batch_size, cell_h x cell_w, num_anchors, 1)
+ @param prob_pred shape(batch_size, cell_h x cell_w, num_anchors, num_classes)
+ @param bbox_gt shape(batch_size, num_bbox, 4), data range(0~1)
+ @param cls_gt shape(batch_size, num_bbox, 1)
+ @return loss shape(1,)
+ """
+ bbox_pred = self.apply_anchors_to_bbox(bbox_pred)
+
+ bbox_built, iou_built, cls_built, bbox_mask, iou_mask, cls_mask =\
+ self._build_target(bbox_pred, bbox_gt, cls_gt)
+
+ self.bbox_loss = self.mse(bbox_pred * bbox_mask,
+ bbox_built * bbox_mask)
+ self.iou_loss = self.mse(iou_pred * iou_mask,
+ iou_built * iou_mask)
+ self.cls_loss = self.mse(prob_pred * cls_mask,
+ cls_built * cls_mask)
+
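+        # weight the coordinate loss by 5, roughly mirroring the lambda_coord = 5
+        # weighting used in the YOLO papers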
+ return self.bbox_loss * 5 + self.iou_loss + self.cls_loss
+
+ def apply_anchors_to_bbox(self, bbox_pred):
+ """
+ @param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
+ @return bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
+ """
+ anchor_w = self.anchors[:, 0].contiguous().view(-1, 1)
+ anchor_h = self.anchors[:, 1].contiguous().view(-1, 1)
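+        # scale the predicted w, h by the anchor priors and take the sqrt so they can
+        # be compared against the sqrt-encoded gt sizes built in _make_target_per_sample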
+ bbox_pred[:, :, :, 2:3] = torch.sqrt(bbox_pred[:, :, :, 2:3] * anchor_w)
+ bbox_pred[:, :, :, 3:4] = torch.sqrt(bbox_pred[:, :, :, 3:4] * anchor_h)
+ return bbox_pred
+
+ def _build_target(self, bbox_pred, bbox_gt, cls_gt):
+ """
+ @param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
+ @param bbox_gt shape(batch_size, num_bbox, 4)
+ @param cls_gt shape(batch_size, num_bbox, 1)
+ @return tuple of (bbox_built, iou_built, cls_built, bbox_mask, iou_mask, cls_mask)
+ """
+ bbox_built, bbox_mask = [], []
+ iou_built, iou_mask = [], []
+ cls_built, cls_mask = [], []
+
+ batch_size = bbox_pred.shape[0]
+
+ for i in range(batch_size):
+ _bbox_built, _iou_built, _cls_built,\
+ _bbox_mask, _iou_mask, _cls_mask =\
+ self._make_target_per_sample(
+                    bbox_pred[i].detach(),
+ torch.FloatTensor(np.array(bbox_gt[i])),
+ torch.LongTensor(cls_gt[i])
+ )
+
+ bbox_built.append(_bbox_built.numpy())
+ bbox_mask.append(_bbox_mask.numpy())
+ iou_built.append(_iou_built.numpy())
+ iou_mask.append(_iou_mask.numpy())
+ cls_built.append(_cls_built.numpy())
+ cls_mask.append(_cls_mask.numpy())
+
+ bbox_built, bbox_mask, iou_built, iou_mask, cls_built, cls_mask =\
+ torch.FloatTensor(np.array(bbox_built)),\
+ torch.FloatTensor(np.array(bbox_mask)),\
+ torch.FloatTensor(np.array(iou_built)),\
+ torch.FloatTensor(np.array(iou_mask)),\
+ torch.FloatTensor(np.array(cls_built)),\
+ torch.FloatTensor(np.array(cls_mask))
+
+ return bbox_built, iou_built, cls_built, bbox_mask, iou_mask, cls_mask
+
+ def _make_target_per_sample(self, _bbox_pred, _bbox_gt, _cls_gt):
+ """
+ @param _bbox_pred shape(cell_h x cell_w, num_anchors, 4)
+ @param _bbox_gt shape(num_bbox, 4)
+ @param _cls_gt shape(num_bbox,)
+ @return tuple of (_bbox_built, _iou_built, _cls_built, _bbox_mask, _iou_mask, _cls_mask)
+ """
+ hw, num_anchors, _ = _bbox_pred.shape
+
+ # set result template
+ _bbox_built = torch.zeros((hw, num_anchors, 4))
+ _bbox_mask = torch.zeros((hw, num_anchors, 1))
+
+ _iou_built = torch.zeros((hw, num_anchors, 1))
+ _iou_mask = torch.ones((hw, num_anchors, 1)) * 0.5
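+        # confidence loss weight: 0.5 for cells without an object, raised to 1 below
+        # for the responsible (cell, anchor) pairs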
+
+ _cls_built = torch.zeros((hw, num_anchors, self.num_classes))
+ _cls_mask = torch.zeros((hw, num_anchors, 1))
+
+ # find best anchors
+ _bbox_gt_wh = _bbox_gt.clone()[:, 2:]
+ best_anchors = find_best_ratio(self.anchors, _bbox_gt_wh)
+
+        # normalize x, y pos based on cell coordinates
+ cx = _bbox_gt[:, 0] * self.outsize[0]
+ cy = _bbox_gt[:, 1] * self.outsize[1]
+ # calculate cell pos and normalize x, y
+ cell_idx = np.floor(cy) * self.outsize[0] + np.floor(cx)
+ cell_idx = np.array(cell_idx, dtype=np.int16)
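+        # cell_idx flattens (row, col) = (floor(cy), floor(cx)) into a single index,
+        # matching the cell ordering of the reshaped model output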
+ cx -= np.floor(cx)
+ cy -= np.floor(cy)
+
+ # set bbox of gt
+ _bbox_built[cell_idx, best_anchors, 0] = cx
+ _bbox_built[cell_idx, best_anchors, 1] = cy
+ _bbox_built[cell_idx, best_anchors, 2] = torch.sqrt(_bbox_gt[:, 2])
+ _bbox_built[cell_idx, best_anchors, 3] = torch.sqrt(_bbox_gt[:, 3])
+ _bbox_mask[cell_idx, best_anchors, :] = 1
+
+ # set cls of gt
+ _cls_built[cell_idx, best_anchors, _cls_gt] = 1
+ _cls_mask[cell_idx, best_anchors, :] = 1
+
+ # set confidence score of gt
+ _iou_built = calculate_iou(_bbox_pred.reshape(-1, 4), _bbox_built.view(-1, 4)).detach()
+ _iou_built = _iou_built.view(hw, num_anchors, 1)
+ _iou_mask[cell_idx, best_anchors, :] = 1
+
+ return _bbox_built, _iou_built, _cls_built,\
+ _bbox_mask, _iou_mask, _cls_mask