[Application] Add Resnet Pytorch example
authorDongHak Park <donghak.park@samsung.com>
Thu, 26 Jan 2023 07:17:51 +0000 (16:17 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 1 Feb 2023 00:49:57 +0000 (09:49 +0900)
Add Resnet PyTorch Example
- It support only training if user want to use this code, user need to update code with test, validation
- This Example's Dataset : CIFAR100
- This Example's network exactly same with NNtrainer Resnet18 example
- We conduct benchmark test base on this code

Signed-off-by: DongHak Park <donghak.park@samsung.com>
Applications/Resnet/PyTorch/main.py [new file with mode: 0644]

diff --git a/Applications/Resnet/PyTorch/main.py b/Applications/Resnet/PyTorch/main.py
new file mode 100644 (file)
index 0000000..f81bd7e
--- /dev/null
@@ -0,0 +1,143 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (C) 2023 DongHak Park <donghak.park@samsung.com>
+#
+# @file   main.cpp
+# @date   26 Jan 2023
+# @see    https://github.com/nnstreamer/nntrainer
+# @author Donghak Park <donghak.park@samsung.com>
+# @bug   No known bugs except for NYI items
+# @brief  This is Resnet Example for PyTorch (only training)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.utils.data import DataLoader
+from torchvision import datasets
+from torchvision.transforms import ToTensor, transforms
+
+DEVICE = "cpu"
+print(f"Using {DEVICE} device")
+print(f"PyTorch version: {torch.__version__}")
+
+EPOCH = 1
+BATCH_SIZE = 128
+IMG_SIZE = 32
+OUTPUT_SIZE = 100
+print(
+    f"Epoch: {EPOCH}, Batch size: {BATCH_SIZE}, Image size: 3x{IMG_SIZE}x{IMG_SIZE}, Output size: 1x1x{OUTPUT_SIZE}"
+)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(
+            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
+        )
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(
+                    in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False
+                ),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class ResNet(nn.Module):
+    def __init__(self, block, num_blocks, num_classes=100):
+        super(ResNet, self).__init__()
+        self.in_planes = 64
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+        self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+        out = F.avg_pool2d(out, 4)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
+
+
+def ResNet18():
+    return ResNet(BasicBlock, [2, 2, 2, 2])
+
+
+def train(dataloader, model, loss_fn, optimizer):
+    size = len(dataloader.dataset)
+
+    for batch, (X, y) in enumerate(dataloader):
+        X, y = X.to(DEVICE), y.to(DEVICE)
+
+        # Compute prediction error
+        pred = model(X)
+        loss = loss_fn(pred, y)
+
+        # Backpropagation
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        if batch % 100 == 0:
+            loss, current = loss.item(), batch * len(X)
+            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
+
+
+if __name__ == "__main__":
+    model = ResNet18().to(DEVICE)
+
+    transform_train = transforms.Compose(
+        [
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+        ]
+    )
+
+    train_dataset = datasets.CIFAR100(
+        root="./data", train=True, download=True, transform=transform_train
+    )
+
+    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
+
+    loss_fn = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+    for t in range(EPOCH):
+        print(f"\nEPOCH {t+1}\n-------------------------------")
+        train(trainloader, model, loss_fn, optimizer)
+
+    print("Training Done!")