[Application] add object detection example using pytorch
[platform/core/ml/nntrainer.git] / Applications / YOLO / PyTorch / dataset.py
1 # SPDX-License-Identifier: Apache-2.0
2 # Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
3 #
4 # @file dataset.py
5 # @date 8 March 2023
6 # @brief Define dataset class for yolo
7 #
8 # @author Seungbaek Hong <sb92.hong@samsung.com>
9
10 import glob
11 import numpy as np
12 import torch
13 from torch.utils.data import Dataset
14 from torch.utils.data.dataloader import default_collate
15 from PIL import Image
16
17 ##
18 # @brief dataset class for yolo
19 # @note Need annotation text files corresponding to the name of the images.    
20 class YOLODataset(Dataset):
21     def __init__(self, img_dir, ann_dir):
22         super().__init__()
23         img_list = glob.glob(img_dir)
24         ann_list = glob.glob(ann_dir)
25         img_list.sort(), ann_list.sort()
26     
27         self.length = len(img_list)
28         self.input_images = []
29         self.bbox_gt = []
30         self.cls_gt = []
31
32         for i in range(len(img_list)):
33             img = np.array(Image.open(img_list[i])) / 255
34             label_bbox = []
35             label_cls = []
36             with open(ann_list[i], 'rt') as f:
37                 for line in f.readlines():
38                     line = [int(i) for i in line.split()]
39                     label_bbox.append(np.array(line[1:], dtype=np.float32) / 416)
40                     label_cls.append(line[0])
41                     
42             self.input_images.append(img)
43             self.bbox_gt.append(label_bbox)
44             self.cls_gt.append(label_cls)
45         
46         self.input_images = np.array(self.input_images)
47         self.input_images = torch.FloatTensor(self.input_images).permute((0, 3, 1, 2))
48         
49     def __len__(self):
50         return self.length
51     
52     def __getitem__(self, idx):
53         return self.input_images[idx], self.bbox_gt[idx], self.cls_gt[idx]
54     
55 ##
56 # @brief collate db function for yolo
57 def collate_db(batch):
58     """
59     @param batch list of batch, (img, bbox, cls)
60     @return collated list of batch, (img, bbox, cls)
61     """
62     items = list(zip(*batch))
63     items[0] = default_collate(items[0])
64     items[1] = list(items[1])
65     items[2] = list(items[2])
66     return items