[Application] add object detection example using pytorch
[platform/core/ml/nntrainer.git] / Applications / YOLO / PyTorch / main.py
1 # SPDX-License-Identifier: Apache-2.0
2 # Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
3 #
4 # @file main.py
5 # @date 8 March 2023
6 # @brief Implement training for yolo
7 #
8 # @author Seungbaek Hong <sb92.hong@samsung.com>
9
10 import numpy as np
11 import torch
12 import torch.nn as nn
13 import torch.optim as optim
14 import torch.nn.functional as F
15 from torch.utils.data import DataLoader
16
17 from yolo import YoloV2_light
18 from yolo_loss import YoloV2_LOSS
19 from dataset import YOLODataset, collate_db
20
21
22 # set config
23 out_size = 13
24 num_classes = 5
25 num_anchors = 5
26
27 epochs = 1000
28 batch_size = 8
29
30 img_dir = './custom_dataset/images/*'
31 ann_dir = './custom_dataset/annotations/*'
32
33
34 # load data
35 dataset = YOLODataset(img_dir, ann_dir)
36 loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_db, shuffle=True, drop_last=True)
37
38
39 # set model, loss and optimizer
40 model = YoloV2_light(num_classes=5)
41 criterion = YoloV2_LOSS(num_classes=5)
42 optimizer = optim.Adam(model.parameters(), lr=1e-3)
43 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
44
45
46 # train model
47 best_loss = 1e+10
48 for epoch in range(epochs):
49     epoch_loss = 0
50     for idx, (img, bbox, cls) in enumerate(loader):
51         optimizer.zero_grad()
52         # model prediction
53         hypothesis = model(img).permute((0, 2, 3, 1))
54         hypothesis = hypothesis.reshape((batch_size, out_size**2, num_anchors, 5+num_classes))        
55         # split each prediction(bbox, iou, class prob)
56         bbox_pred_xy = torch.sigmoid(hypothesis[..., :2])
57         bbox_pred_wh = torch.exp(hypothesis[..., 2:4])
58         bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)        
59         iou_pred = torch.sigmoid(hypothesis[..., 4:5])        
60         score_pred = hypothesis[..., 5:].contiguous()
61         prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(score_pred.shape)
62         # calc loss
63         loss = criterion(torch.FloatTensor(bbox_pred),
64                          torch.FloatTensor(iou_pred),
65                          torch.FloatTensor(prob_pred),
66                          bbox,
67                          cls)
68         # back prop
69         loss.backward()
70         optimizer.step()  
71         scheduler.step()
72         epoch_loss += loss.item()
73         
74     if epoch_loss < best_loss:
75         best_loss = epoch_loss
76         torch.save(model.state_dict(), './best_model.pt')
77         
78     print("{}epoch, loss: {:.4f}".format(epoch, epoch_loss / len(loader)))
79
80 ##
81 # @brief bbox post process function for inference
82 def post_process_for_bbox(bbox_pred):    
83     """
84     @param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)    
85     @return bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
86     """
87     anchors = torch.FloatTensor(
88         [(1.3221, 1.73145),
89         (3.19275, 4.00944),
90         (5.05587, 8.09892),
91         (9.47112, 4.84053),
92         (11.2364, 10.0071)]
93     )
94
95     outsize = (13, 13)
96     width, height = outsize
97     
98     # restore cell pos to x, y    
99     for w in range(width):
100         for h in range(height):
101             bbox_pred[:, height*h + w, :, 0] += w
102             bbox_pred[:, height*h + w, :, 1] += h
103     bbox_pred[:, :, :, :2] /= 13
104     
105     # apply anchors to w, h
106     anchor_w = anchors[:, 0].contiguous().view(-1, 1)
107     anchor_h = anchors[:, 1].contiguous().view(-1, 1)        
108     bbox_pred[:, :, :, 2:3] *= anchor_w
109     bbox_pred[:, :, :, 3:4] *= anchor_h
110
111     return bbox_pred
112
113 # inference example using trained model
114 hypothesis = model(img).permute((0, 2, 3, 1))
115 hypothesis = hypothesis[0].reshape((1, out_size**2, num_anchors, 5+num_classes))        
116
117 # transform output
118 bbox_pred_xy = torch.sigmoid(hypothesis[..., :2])
119 bbox_pred_wh = torch.exp(hypothesis[..., 2:4])
120 bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)
121 bbox_pred = post_process_for_bbox(bbox_pred)
122 iou_pred = torch.sigmoid(hypothesis[..., 4:5])
123 score_pred = hypothesis[..., 5:].contiguous()
124 prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(score_pred.shape)
125
126 # result of inference (data range 0~1)
127 iou_mask = (iou_pred > 0.5)
128 print(bbox_pred * iou_mask, iou_pred * iou_mask, prob_pred * iou_mask)