From 3aba2d99e14c2e2022c72f2ea37d4cf72fa8603a Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 4 Dec 2018 13:40:11 -0800 Subject: [PATCH] Add resnet test, convert more modules (#14437) Summary: This PR add resnet to test_jit and convert more nn modules, stacked on #14533 and #14715 Pull Request resolved: https://github.com/pytorch/pytorch/pull/14437 Differential Revision: D13325871 Pulled By: wanchaol fbshipit-source-id: 6c94a988b36794a373af6541c0c262a07291f7b1 --- test/test_jit.py | 122 +++++++++++++++++++++++++++++++++++++++++++- torch/nn/modules/conv.py | 12 ++++- torch/nn/modules/pooling.py | 10 +--- 3 files changed, 132 insertions(+), 12 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 169e72c..f71ee37 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -6305,7 +6305,7 @@ a") self.assertEqual(m_orig.forward(input), m_import.forward(input)) @skipIfNoTorchVision - def test_script_module_export_resnet18(self): + def test_script_module_trace_resnet18(self): x = torch.ones(1, 3, 224, 224) m_orig = torch.jit.trace(torchvision.models.resnet18(), torch.ones(1, 3, 224, 224)) m_import = self.getExportImportCopy(m_orig) @@ -6323,6 +6323,126 @@ a") self.assertEqual(output_orig, output_import) self.assertEqual(grad_orig, grad_import) + @skipIfNoTorchVision + def test_script_module_script_resnet(self): + def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + class BasicBlock(torch.jit.ScriptModule): + expansion = 1 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + @torch.jit.script_method + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + class ResNet(torch.jit.ScriptModule): + __constants__ = ['layer1', 'layer2', 'layer3', 'layer4'] + + def __init__(self, block, layers, num_classes=1000): + super(ResNet, self).__init__() + self.inplanes = 64 + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + @torch.jit.script_method + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + resnet18 = ResNet(BasicBlock, [2, 2, 2, 2]) + + resnet18_imported = self.getExportImportCopy(resnet18) + + input = torch.randn(1, 3, 224, 224, requires_grad=True) + output_orig = resnet18(input) + output_orig.sum().backward() + grad_orig = input.grad.clone() + input.grad.zero_() + + output_import = resnet18_imported(input) + output_import.sum().backward() + grad_import = input.grad.clone() + + self.assertEqual(output_orig, output_import) + self.assertEqual(grad_orig, grad_import) + def test_script_module_export_tensor_type(self): class M(torch.jit.ScriptModule): diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 2c03991..e9f89a1 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -6,10 +6,14 @@ from .. import functional as F from .. import init from .module import Module from .utils import _single, _pair, _triple +from ..._jit_internal import weak_module, weak_script_method +@weak_module class _ConvNd(Module): + __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias'] + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias): super(_ConvNd, self).__init__() @@ -62,6 +66,7 @@ class _ConvNd(Module): return s.format(**self.__dict__) +@weak_module class Conv1d(_ConvNd): r"""Applies a 1D convolution over an input signal composed of several input planes. @@ -176,11 +181,13 @@ class Conv1d(_ConvNd): in_channels, out_channels, kernel_size, stride, padding, dilation, False, _single(0), groups, bias) + @weak_script_method def forward(self, input): return F.conv1d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) +@weak_module class Conv2d(_ConvNd): r"""Applies a 2D convolution over an input signal composed of several input planes. @@ -297,7 +304,6 @@ class Conv2d(_ConvNd): .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): kernel_size = _pair(kernel_size) @@ -308,11 +314,13 @@ class Conv2d(_ConvNd): in_channels, out_channels, kernel_size, stride, padding, dilation, False, _pair(0), groups, bias) + @weak_script_method def forward(self, input): return F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) +@weak_module class Conv3d(_ConvNd): r"""Applies a 3D convolution over an input signal composed of several input planes. @@ -424,7 +432,6 @@ class Conv3d(_ConvNd): .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): kernel_size = _triple(kernel_size) @@ -435,6 +442,7 @@ class Conv3d(_ConvNd): in_channels, out_channels, kernel_size, stride, padding, dilation, False, _triple(0), groups, bias) + @weak_script_method def forward(self, input): return F.conv3d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 3bde44b..942f8c1 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -433,6 +433,7 @@ class MaxUnpool3d(_MaxUnpoolNd): @weak_module class _AvgPoolNd(Module): + __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad'] def extra_repr(self): return 'kernel_size={}, stride={}, padding={}'.format( @@ -482,9 +483,6 @@ class AvgPool1d(_AvgPoolNd): >>> m(torch.tensor([[[1.,2,3,4,5,6,7]]])) tensor([[[ 2., 4., 6.]]]) """ - __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', - 'count_include_pad'] - def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True): super(AvgPool1d, self).__init__() @@ -552,9 +550,6 @@ class AvgPool2d(_AvgPoolNd): >>> input = torch.randn(20, 16, 50, 32) >>> output = m(input) """ - __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', - 'count_include_pad'] - def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True): super(AvgPool2d, self).__init__() @@ -628,9 +623,6 @@ class AvgPool3d(_AvgPoolNd): >>> input = torch.randn(20, 16, 50,44, 31) >>> output = m(input) """ - __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', - 'count_include_pad'] - def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True): super(AvgPool3d, self).__init__() -- 2.7.4