class YOLODataset(Dataset):
def __init__(self, img_dir, ann_dir):
super().__init__()
+ self.img_dir = img_dir
pattern = re.compile("\/(\d+)\.")
img_list = glob.glob(img_dir + "*")
ann_list = glob.glob(ann_dir + "*")
ann_ids = list(map(lambda x: pattern.search(x).group(1), ann_list))
ids_list = list(set(img_ids) & set(ann_ids))
- self.input_images = []
+ self.ids_list = []
self.bbox_gt = []
self.cls_gt = []
for ids in ids_list:
- img = np.array(Image.open(img_dir + ids + ".jpg").resize((416, 416))) / 255
label_bbox = []
label_cls = []
with open(ann_dir + ids + ".txt", "rt", encoding="utf-8") as f:
if len(label_cls) == 0:
continue
- self.input_images.append(img)
+ self.ids_list.append(ids)
self.bbox_gt.append(label_bbox)
self.cls_gt.append(label_cls)
- self.length = len(self.input_images)
- self.input_images = np.array(self.input_images)
- self.input_images = torch.FloatTensor(self.input_images).permute((0, 3, 1, 2))
+ self.length = len(self.ids_list)
def __len__(self):
return self.length
def __getitem__(self, idx):
- return self.input_images[idx], self.bbox_gt[idx], self.cls_gt[idx]
+ img = (
+ torch.FloatTensor(
+ np.array(
+ Image.open(self.img_dir + self.ids_list[idx] + ".jpg").resize(
+ (416, 416)
+ )
+ )
+ ).permute((2, 0, 1))
+ / 255
+ )
+ return img, self.bbox_gt[idx], self.cls_gt[idx]
##
import sys
import os
-from torchconverter import save_bin
-import torch
+from PIL import Image, ImageDraw
+from matplotlib import pyplot as plt
from torch import optim
from torch.utils.data import DataLoader
+import torch
+import numpy as np
from yolo import YoloV2
from yolo_loss import YoloV2_LOSS
from dataset import YOLODataset, collate_db
+from torchconverter import save_bin
device = "cuda" if torch.cuda.is_available() else "cpu"
valid loss: {epoch_valid_loss / len(valid_loader):.4f}"
)
+
##
# @brief bbox post process function for inference
-
-
def post_process_for_bbox(bbox_p):
"""
@param bbox_p shape(batch_size, cell_h x cell_w, num_anchors, 4)
return bbox_p
+def visualize_bbox(img_pred, bbox_preds):
+ img_array = (img_pred.to("cpu") * 255).permute((1, 2, 0)).numpy().astype(np.uint8)
+ img = Image.fromarray(img_array)
+
+ for bbox_pred in bbox_preds:
+ bbox_pred = [int(x * 416) for x in bbox_pred]
+
+ if sum(bbox_pred) == 0:
+ continue
+
+ x_lefttop = bbox_pred[0]
+ y_lefttop = bbox_pred[1]
+ width = bbox_pred[2]
+ height = bbox_pred[3]
+
+ draw = ImageDraw.Draw(img)
+ draw.rectangle(
+ [(x_lefttop, y_lefttop), (x_lefttop + width, y_lefttop + height)]
+ )
+
+ plt.imshow(img)
+ plt.show()
+
+
# inference example using trained model
-hypothesis = model(img).permute((0, 2, 3, 1))
+hypothesis = model(img.to(device)).permute((0, 2, 3, 1))
hypothesis = hypothesis[0].reshape((1, out_size**2, num_anchors, 5 + num_classes))
# transform output
# result of inference (data range 0~1)
iou_mask = iou_pred > 0.5
-print(bbox_pred * iou_mask, iou_pred * iou_mask, prob_pred * iou_mask)
+bbox_pred = bbox_pred * iou_mask
+visualize_bbox(img, bbox_pred.reshape(-1, 4))