1 # SPDX-License-Identifier: Apache-2.0
2 # Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
6 # @brief Implement training (and an inference example) for YOLO
8 # @author Seungbaek Hong <sb92.hong@samsung.com>
13 import torch.optim as optim
14 import torch.nn.functional as F
15 from torch.utils.data import DataLoader
17 from yolo import YoloV2_light
18 from yolo_loss import YoloV2_LOSS
19 from dataset import YOLODataset, collate_db
30 img_dir = './custom_dataset/images/*'  # wildcard path to the image files (presumably expanded by YOLODataset — confirm)
31 ann_dir = './custom_dataset/annotations/*'  # wildcard path to the annotation files (presumably expanded by YOLODataset — confirm)
35 dataset = YOLODataset(img_dir, ann_dir)
36 loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_db, shuffle=True, drop_last=True)  # drop_last keeps every batch at exactly batch_size, which the fixed-size reshape of the model output in the training loop relies on
39 # set model, loss, optimizer and learning-rate scheduler
40 model = YoloV2_light(num_classes=5)  # NOTE(review): class count hard-coded; must match the num_classes used when reshaping the model output
41 criterion = YoloV2_LOSS(num_classes=5)  # NOTE(review): hard-coded; keep in sync with the model's class count
42 optimizer = optim.Adam(model.parameters(), lr=1e-3)
43 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)  # cosine decay of lr down to eta_min over T_max=10 scheduler steps; NOTE(review): no scheduler.step() call is visible in this listing — confirm it is called once per epoch
48 for epoch in range(epochs):  # train; whenever the summed epoch loss improves, the weights are saved to ./best_model.pt
50 for idx, (img, bbox, cls) in enumerate(loader):
53 hypothesis = model(img).permute((0, 2, 3, 1))  # move axis 1 to the end; assumes the model returns (N, C, H, W) — TODO confirm
54 hypothesis = hypothesis.reshape((batch_size, out_size**2, num_anchors, 5+num_classes))  # -> (batch, cells, anchors, 4 bbox + 1 iou + class scores); only valid for full batches
55 # split each prediction (bbox, iou, class prob)
56 bbox_pred_xy = torch.sigmoid(hypothesis[..., :2])  # x, y offsets squashed into (0, 1)
57 bbox_pred_wh = torch.exp(hypothesis[..., 2:4])  # w, h as positive scale factors (presumably multiplied by anchors inside the loss — confirm)
58 bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)
59 iou_pred = torch.sigmoid(hypothesis[..., 4:5])  # confidence in (0, 1); slicing 4:5 keeps a trailing axis of size 1
60 score_pred = hypothesis[..., 5:].contiguous()
61 prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(score_pred.shape)  # softmax over the class axis, then restore the original shape
63 loss = criterion(torch.FloatTensor(bbox_pred),  # NOTE(review): these are already float tensors; the legacy torch.FloatTensor(...) (CPU float type) wrapper looks redundant — consider passing the tensors directly
64 torch.FloatTensor(iou_pred),
65 torch.FloatTensor(prob_pred),  # NOTE(review): the remaining arguments of this call (and presumably the backward/optimizer step) are not shown in this listing
72 epoch_loss += loss.item()
74 if epoch_loss < best_loss:  # best model is selected on training loss; no validation set is visible here
75 best_loss = epoch_loss
76 torch.save(model.state_dict(), './best_model.pt')
78 print("{}epoch, loss: {:.4f}".format(epoch, epoch_loss / len(loader)))  # mean loss per batch, assuming epoch_loss is reset each epoch
81 # @brief bbox post process function for inference: add each cell's offset to x, y, normalize by the grid size, and scale w, h by the anchors
82 def post_process_for_bbox(bbox_pred):
84 @param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
85 @return bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
87 anchors = torch.FloatTensor(  # presumably (num_anchors, 2) rows of anchor (w, h); the values are not in this listing
96 width, height = outsize  # NOTE(review): `outsize` is not defined in view and differs from the integer `out_size` used elsewhere — confirm
98 # restore cell pos to x, y (modifies bbox_pred in place)
99 for w in range(width):
100 for h in range(height):
101 bbox_pred[:, height*h + w, :, 0] += w  # NOTE(review): cells flattened from (H, W) have index width*h + w; height*h + w only agrees when width == height
102 bbox_pred[:, height*h + w, :, 1] += h
103 bbox_pred[:, :, :, :2] /= 13  # NOTE(review): hard-coded 13-cell grid; presumably x should be divided by width and y by height
105 # apply anchors to w, h
106 anchor_w = anchors[:, 0].contiguous().view(-1, 1)  # (num_anchors, 1) so it broadcasts along the anchor axis of bbox_pred
107 anchor_h = anchors[:, 1].contiguous().view(-1, 1)
108 bbox_pred[:, :, :, 2:3] *= anchor_w
109 bbox_pred[:, :, :, 3:4] *= anchor_h
113 # inference example using trained model
114 hypothesis = model(img).permute((0, 2, 3, 1))  # NOTE(review): `img` is not rebound in this listing, so it is the last training batch; no model.eval()/torch.no_grad() is visible here
115 hypothesis = hypothesis[0].reshape((1, out_size**2, num_anchors, 5+num_classes))  # keep only the first image, as a batch of 1
118 bbox_pred_xy = torch.sigmoid(hypothesis[..., :2])
119 bbox_pred_wh = torch.exp(hypothesis[..., 2:4])
120 bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)
121 bbox_pred = post_process_for_bbox(bbox_pred)  # add cell offsets / normalize x, y and apply anchors to w, h
122 iou_pred = torch.sigmoid(hypothesis[..., 4:5])
123 score_pred = hypothesis[..., 5:].contiguous()
124 prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(score_pred.shape)
126 # result of inference (x, y in 0~1 for a 13x13 grid; the range of w, h depends on the anchor values — confirm)
127 iou_mask = (iou_pred > 0.5)  # keep predictions whose confidence exceeds 0.5
128 print(bbox_pred * iou_mask, iou_pred * iou_mask, prob_pred * iou_mask)  # masked entries are zeroed, not removed; the size-1 mask axis broadcasts over the last dimension