1 # SPDX-License-Identifier: Apache-2.0
2 # Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
6 # @brief Define dataset class for yolo
8 # @author Seungbaek Hong <sb92.hong@samsung.com>
13 from torch.utils.data import Dataset
14 from torch.utils.data.dataloader import default_collate
# @brief dataset class for yolo
# @note Every image needs a matching annotation text file. Images and
#       annotations are paired by sorted path order, so their file names
#       must sort identically (e.g. img_001.jpg <-> img_001.txt).
class YOLODataset(Dataset):
    """Dataset that eagerly loads YOLO images and annotations into memory.

    Each annotation file holds one object per line as whitespace-separated
    integers: the class id followed by the bounding-box values.
    NOTE(review): the bbox convention (corner vs. center format) is not
    visible here -- confirm against the consumer of ``bbox_gt``.
    """

    def __init__(self, img_dir, ann_dir, img_size=416):
        """Load every image/annotation pair matched by the given glob patterns.

        @param img_dir glob pattern matching the image files
        @param ann_dir glob pattern matching the annotation text files
        @param img_size divisor used to normalize the bbox values; defaults
               to 416 (presumably the square input resolution -- confirm)
        @raise ValueError if no image matches ``img_dir`` or the number of
               images and annotation files differ
        """
        super().__init__()
        img_list = sorted(glob.glob(img_dir))
        ann_list = sorted(glob.glob(ann_dir))

        # Fail early with a clear message instead of an opaque permute /
        # IndexError further down, or silently mis-paired labels.
        if not img_list:
            raise ValueError(f"no images match pattern {img_dir!r}")
        if len(img_list) != len(ann_list):
            raise ValueError(
                f"found {len(img_list)} images but {len(ann_list)} annotation "
                f"files; every image needs exactly one annotation file"
            )

        self.length = len(img_list)
        self.input_images = []
        self.bbox_gt = []
        self.cls_gt = []

        for img_path, ann_path in zip(img_list, ann_list):
            # Use a context manager so the underlying file handle is closed
            # as soon as the pixels are loaded.
            with Image.open(img_path) as im:
                img = np.array(im) / 255  # scale pixel values to [0, 1]
            label_bbox, label_cls = self._read_annotation(ann_path, img_size)

            self.input_images.append(img)
            self.bbox_gt.append(label_bbox)
            self.cls_gt.append(label_cls)

        # Stack to (N, H, W, C), then reorder to (N, C, H, W) as float32.
        # np.stack requires all images to share the same shape.
        self.input_images = torch.from_numpy(
            np.stack(self.input_images)
        ).float().permute((0, 3, 1, 2))

    @staticmethod
    def _read_annotation(ann_path, img_size):
        """Parse one annotation file into bbox arrays and class ids.

        @param ann_path path of the annotation text file
        @param img_size divisor applied to the bbox values
        @return (list of float32 bbox arrays, list of int class ids)
        """
        label_bbox = []
        label_cls = []
        with open(ann_path, 'rt', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:  # tolerate blank lines (e.g. trailing newlines)
                    continue
                values = [int(v) for v in fields]
                label_bbox.append(
                    np.array(values[1:], dtype=np.float32) / img_size
                )
                label_cls.append(values[0])
        return label_bbox, label_cls

    def __len__(self):
        """Return the number of images in the dataset."""
        return self.length

    def __getitem__(self, idx):
        """Return ``(image, bbox list, class-id list)`` for sample ``idx``."""
        return self.input_images[idx], self.bbox_gt[idx], self.cls_gt[idx]
56 # @brief collate db function for yolo
57 def collate_db(batch):
59 @param batch list of batch, (img, bbox, cls)
60 @return collated list of batch, (img, bbox, cls)
62 items = list(zip(*batch))
63 items[0] = default_collate(items[0])
64 items[1] = list(items[1])
65 items[2] = list(items[2])