1 # SPDX-License-Identifier: Apache-2.0
2 # Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
6 # @brief Define a simple YOLO-style model (not the original Darknet architecture).
8 # @author Seungbaek Hong <sb92.hong@samsung.com>
14 # @brief Define a simple YOLO-style model (not the original Darknet architecture).
15 class YoloV2_light(nn.Module):
19 [(1.3221, 1.73145), (3.19275, 4.00944), (5.05587, 8.09892), (9.47112, 4.84053), (11.2364, 10.0071)]):
21 super(YoloV2_light, self).__init__()
22 self.num_classes = num_classes
23 self.anchors = anchors
24 self.stage1_conv1 = nn.Sequential(nn.Conv2d(3, 32, 3, 1, 1), nn.BatchNorm2d(32),
25 nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
26 self.stage1_conv2 = nn.Sequential(nn.Conv2d(32, 64, 3, 1, 1), nn.BatchNorm2d(64),
27 nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
28 self.stage1_conv3 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128),
30 self.stage1_conv4 = nn.Sequential(nn.Conv2d(128, 64, 1, 1, 0), nn.BatchNorm2d(64),
32 self.stage1_conv5 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128),
33 nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
34 self.stage1_conv6 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256),
36 self.stage1_conv7 = nn.Sequential(nn.Conv2d(256, 128, 1, 1, 0), nn.BatchNorm2d(128),
38 self.stage1_conv8 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256),
39 nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
40 self.stage1_conv9 = nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512),
42 self.stage1_conv10 = nn.Sequential(nn.Conv2d(512, 256, 1, 1, 0), nn.BatchNorm2d(256),
43 nn.LeakyReLU(0.1), nn.MaxPool2d(2, 2))
44 self.out_conv = nn.Conv2d(256, len(self.anchors) * (5 + num_classes), 1, 1, 0)
46 def forward(self, input):
47 output = self.stage1_conv1(input)
48 output = self.stage1_conv2(output)
49 output = self.stage1_conv3(output)
50 output = self.stage1_conv4(output)
51 output = self.stage1_conv5(output)
52 output = self.stage1_conv6(output)
53 output = self.stage1_conv7(output)
54 output = self.stage1_conv8(output)
55 output = self.stage1_conv9(output)
56 output = self.stage1_conv10(output)
57 output = self.out_conv(output)