Add resnet test, convert more modules (#14437)
authorWanchao Liang <wanchaol@users.noreply.github.com>
Tue, 4 Dec 2018 21:40:11 +0000 (13:40 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 4 Dec 2018 21:42:41 +0000 (13:42 -0800)
Summary:
This PR add resnet to test_jit and convert more nn modules, stacked on #14533 and #14715
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14437

Differential Revision: D13325871

Pulled By: wanchaol

fbshipit-source-id: 6c94a988b36794a373af6541c0c262a07291f7b1

test/test_jit.py
torch/nn/modules/conv.py
torch/nn/modules/pooling.py

index 169e72c..f71ee37 100644 (file)
@@ -6305,7 +6305,7 @@ a")
         self.assertEqual(m_orig.forward(input), m_import.forward(input))
 
     @skipIfNoTorchVision
-    def test_script_module_export_resnet18(self):
+    def test_script_module_trace_resnet18(self):
         x = torch.ones(1, 3, 224, 224)
         m_orig = torch.jit.trace(torchvision.models.resnet18(), torch.ones(1, 3, 224, 224))
         m_import = self.getExportImportCopy(m_orig)
@@ -6323,6 +6323,126 @@ a")
         self.assertEqual(output_orig, output_import)
         self.assertEqual(grad_orig, grad_import)
 
+    @skipIfNoTorchVision
+    def test_script_module_script_resnet(self):
+        def conv1x1(in_planes, out_planes, stride=1):
+            """1x1 convolution"""
+            return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+        def conv3x3(in_planes, out_planes, stride=1):
+            """3x3 convolution with padding"""
+            return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                             padding=1, bias=False)
+
+        class BasicBlock(torch.jit.ScriptModule):
+            expansion = 1
+            __constants__ = ['downsample']
+
+            def __init__(self, inplanes, planes, stride=1, downsample=None):
+                super(BasicBlock, self).__init__()
+                self.conv1 = conv3x3(inplanes, planes, stride)
+                self.bn1 = nn.BatchNorm2d(planes)
+                self.relu = nn.ReLU(inplace=True)
+                self.conv2 = conv3x3(planes, planes)
+                self.bn2 = nn.BatchNorm2d(planes)
+                self.downsample = downsample
+                self.stride = stride
+
+            @torch.jit.script_method
+            def forward(self, x):
+                residual = x
+
+                out = self.conv1(x)
+                out = self.bn1(out)
+                out = self.relu(out)
+
+                out = self.conv2(out)
+                out = self.bn2(out)
+
+                if self.downsample is not None:
+                    residual = self.downsample(x)
+
+                out += residual
+                out = self.relu(out)
+
+                return out
+
+        class ResNet(torch.jit.ScriptModule):
+            __constants__ = ['layer1', 'layer2', 'layer3', 'layer4']
+
+            def __init__(self, block, layers, num_classes=1000):
+                super(ResNet, self).__init__()
+                self.inplanes = 64
+                self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                                       bias=False)
+                self.bn1 = nn.BatchNorm2d(64)
+                self.relu = nn.ReLU(inplace=True)
+                self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+                self.layer1 = self._make_layer(block, 64, layers[0])
+                self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+                self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+                self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+                self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+                self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+                for m in self.modules():
+                    if isinstance(m, nn.Conv2d):
+                        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                    elif isinstance(m, nn.BatchNorm2d):
+                        nn.init.constant_(m.weight, 1)
+                        nn.init.constant_(m.bias, 0)
+
+            def _make_layer(self, block, planes, blocks, stride=1):
+                downsample = None
+                if stride != 1 or self.inplanes != planes * block.expansion:
+                    downsample = nn.Sequential(
+                        conv1x1(self.inplanes, planes * block.expansion, stride),
+                        nn.BatchNorm2d(planes * block.expansion),
+                    )
+
+                layers = []
+                layers.append(block(self.inplanes, planes, stride, downsample))
+                self.inplanes = planes * block.expansion
+                for _ in range(1, blocks):
+                    layers.append(block(self.inplanes, planes))
+
+                return nn.Sequential(*layers)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.bn1(x)
+                x = self.relu(x)
+                x = self.maxpool(x)
+
+                x = self.layer1(x)
+                x = self.layer2(x)
+                x = self.layer3(x)
+                x = self.layer4(x)
+
+                x = self.avgpool(x)
+                x = x.view(x.size(0), -1)
+                x = self.fc(x)
+
+                return x
+
+        resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
+
+        resnet18_imported = self.getExportImportCopy(resnet18)
+
+        input = torch.randn(1, 3, 224, 224, requires_grad=True)
+        output_orig = resnet18(input)
+        output_orig.sum().backward()
+        grad_orig = input.grad.clone()
+        input.grad.zero_()
+
+        output_import = resnet18_imported(input)
+        output_import.sum().backward()
+        grad_import = input.grad.clone()
+
+        self.assertEqual(output_orig, output_import)
+        self.assertEqual(grad_orig, grad_import)
+
     def test_script_module_export_tensor_type(self):
         class M(torch.jit.ScriptModule):
 
index 2c03991..e9f89a1 100644 (file)
@@ -6,10 +6,14 @@ from .. import functional as F
 from .. import init
 from .module import Module
 from .utils import _single, _pair, _triple
+from ..._jit_internal import weak_module, weak_script_method
 
 
+@weak_module
 class _ConvNd(Module):
 
+    __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias']
+
     def __init__(self, in_channels, out_channels, kernel_size, stride,
                  padding, dilation, transposed, output_padding, groups, bias):
         super(_ConvNd, self).__init__()
@@ -62,6 +66,7 @@ class _ConvNd(Module):
         return s.format(**self.__dict__)
 
 
+@weak_module
 class Conv1d(_ConvNd):
     r"""Applies a 1D convolution over an input signal composed of several input
     planes.
@@ -176,11 +181,13 @@ class Conv1d(_ConvNd):
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _single(0), groups, bias)
 
+    @weak_script_method
     def forward(self, input):
         return F.conv1d(input, self.weight, self.bias, self.stride,
                         self.padding, self.dilation, self.groups)
 
 
+@weak_module
 class Conv2d(_ConvNd):
     r"""Applies a 2D convolution over an input signal composed of several input
     planes.
@@ -297,7 +304,6 @@ class Conv2d(_ConvNd):
     .. _link:
         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
     """
-
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True):
         kernel_size = _pair(kernel_size)
@@ -308,11 +314,13 @@ class Conv2d(_ConvNd):
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _pair(0), groups, bias)
 
+    @weak_script_method
     def forward(self, input):
         return F.conv2d(input, self.weight, self.bias, self.stride,
                         self.padding, self.dilation, self.groups)
 
 
+@weak_module
 class Conv3d(_ConvNd):
     r"""Applies a 3D convolution over an input signal composed of several input
     planes.
@@ -424,7 +432,6 @@ class Conv3d(_ConvNd):
     .. _link:
         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
     """
-
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True):
         kernel_size = _triple(kernel_size)
@@ -435,6 +442,7 @@ class Conv3d(_ConvNd):
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _triple(0), groups, bias)
 
+    @weak_script_method
     def forward(self, input):
         return F.conv3d(input, self.weight, self.bias, self.stride,
                         self.padding, self.dilation, self.groups)
index 3bde44b..942f8c1 100644 (file)
@@ -433,6 +433,7 @@ class MaxUnpool3d(_MaxUnpoolNd):
 
 @weak_module
 class _AvgPoolNd(Module):
+    __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad']
 
     def extra_repr(self):
         return 'kernel_size={}, stride={}, padding={}'.format(
@@ -482,9 +483,6 @@ class AvgPool1d(_AvgPoolNd):
         >>> m(torch.tensor([[[1.,2,3,4,5,6,7]]]))
         tensor([[[ 2.,  4.,  6.]]])
     """
-    __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode',
-                     'count_include_pad']
-
     def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
                  count_include_pad=True):
         super(AvgPool1d, self).__init__()
@@ -552,9 +550,6 @@ class AvgPool2d(_AvgPoolNd):
         >>> input = torch.randn(20, 16, 50, 32)
         >>> output = m(input)
     """
-    __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode',
-                     'count_include_pad']
-
     def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
                  count_include_pad=True):
         super(AvgPool2d, self).__init__()
@@ -628,9 +623,6 @@ class AvgPool3d(_AvgPoolNd):
         >>> input = torch.randn(20, 16, 50,44, 31)
         >>> output = m(input)
     """
-    __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode',
-                     'count_include_pad']
-
     def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
                  count_include_pad=True):
         super(AvgPool3d, self).__init__()