self.assertEqual(m_orig.forward(input), m_import.forward(input))
@skipIfNoTorchVision
- def test_script_module_export_resnet18(self):
+ def test_script_module_trace_resnet18(self):
x = torch.ones(1, 3, 224, 224)
m_orig = torch.jit.trace(torchvision.models.resnet18(), torch.ones(1, 3, 224, 224))
m_import = self.getExportImportCopy(m_orig)
self.assertEqual(output_orig, output_import)
self.assertEqual(grad_orig, grad_import)
+ @skipIfNoTorchVision
+ def test_script_module_script_resnet(self):
+ # Builds a ResNet-18 out of ScriptModules (BasicBlock / ResNet below),
+ # round-trips it through self.getExportImportCopy, and checks that the
+ # original and the imported module produce equal outputs and equal
+ # gradients with respect to the input.
+ def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+ def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+ # Residual block: conv3x3 -> bn -> relu -> conv3x3 -> bn, then the
+ # (optionally downsampled) input is added back before the final relu.
+ class BasicBlock(torch.jit.ScriptModule):
+ expansion = 1
+ # NOTE(review): 'downsample' is declared constant, presumably so the
+ # `is not None` check in forward can be handled by the script
+ # compiler - confirm against the TorchScript docs.
+ __constants__ = ['downsample']
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ # Either None or the module built in ResNet._make_layer.
+ self.downsample = downsample
+ self.stride = stride
+
+ @torch.jit.script_method
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ # Replace the identity shortcut with the downsampled input when a
+ # downsample module was supplied.
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+ # Stem (conv/bn/relu/maxpool), four stages built by _make_layer,
+ # adaptive average pooling and a final linear layer.
+ class ResNet(torch.jit.ScriptModule):
+ # NOTE(review): the nn.Sequential stages are declared constant,
+ # presumably so the scripted forward can call them - confirm.
+ __constants__ = ['layer1', 'layer2', 'layer3', 'layer4']
+
+ def __init__(self, block, layers, num_classes=1000):
+ super(ResNet, self).__init__()
+ # Running input-channel count; updated by _make_layer.
+ self.inplanes = 64
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ # Initialise weights: Kaiming-normal for every Conv2d, and
+ # weight=1 / bias=0 for every BatchNorm2d.
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Builds one stage of `blocks` blocks as an nn.Sequential. Only the
+ # first block receives `stride` and the downsample module.
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ # A 1x1 conv + BatchNorm shortcut is built when the stride is not 1
+ # or the channel count changes.
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ @torch.jit.script_method
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ # Flatten every dimension after the first before the linear layer.
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+
+ return x
+
+ # Four stages of two BasicBlocks each.
+ resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
+
+ resnet18_imported = self.getExportImportCopy(resnet18)
+
+ input = torch.randn(1, 3, 224, 224, requires_grad=True)
+ output_orig = resnet18(input)
+ output_orig.sum().backward()
+ grad_orig = input.grad.clone()
+ # Clear the accumulated gradient so the next backward pass records only
+ # the imported module's gradient.
+ input.grad.zero_()
+
+ output_import = resnet18_imported(input)
+ output_import.sum().backward()
+ grad_import = input.grad.clone()
+
+ self.assertEqual(output_orig, output_import)
+ self.assertEqual(grad_orig, grad_import)
+
def test_script_module_export_tensor_type(self):
class M(torch.jit.ScriptModule):
from .. import init
from .module import Module
from .utils import _single, _pair, _triple
+from ..._jit_internal import weak_module, weak_script_method
+@weak_module
class _ConvNd(Module):
+ __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias']
+
def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding, groups, bias):
super(_ConvNd, self).__init__()
return s.format(**self.__dict__)
+@weak_module
class Conv1d(_ConvNd):
r"""Applies a 1D convolution over an input signal composed of several input
planes.
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _single(0), groups, bias)
+ # Delegates to F.conv1d, passing `input` together with this module's
+ # weight, bias, stride, padding, dilation and groups.
+ @weak_script_method
def forward(self, input):
return F.conv1d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
+@weak_module
class Conv2d(_ConvNd):
r"""Applies a 2D convolution over an input signal composed of several input
planes.
.. _link:
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
-
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
kernel_size = _pair(kernel_size)
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias)
+ # Delegates to F.conv2d, passing `input` together with this module's
+ # weight, bias, stride, padding, dilation and groups.
+ @weak_script_method
def forward(self, input):
return F.conv2d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
+@weak_module
class Conv3d(_ConvNd):
r"""Applies a 3D convolution over an input signal composed of several input
planes.
.. _link:
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
-
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
kernel_size = _triple(kernel_size)
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _triple(0), groups, bias)
+ # Delegates to F.conv3d, passing `input` together with this module's
+ # weight, bias, stride, padding, dilation and groups.
+ @weak_script_method
def forward(self, input):
return F.conv3d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
@weak_module
class _AvgPoolNd(Module):
+ __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad']
def extra_repr(self):
return 'kernel_size={}, stride={}, padding={}'.format(
>>> m(torch.tensor([[[1.,2,3,4,5,6,7]]]))
tensor([[[ 2., 4., 6.]]])
"""
- __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode',
- 'count_include_pad']
-
def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
count_include_pad=True):
super(AvgPool1d, self).__init__()
>>> input = torch.randn(20, 16, 50, 32)
>>> output = m(input)
"""
- __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode',
- 'count_include_pad']
-
def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
count_include_pad=True):
super(AvgPool2d, self).__init__()
>>> input = torch.randn(20, 16, 50,44, 31)
>>> output = m(input)
"""
- __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode',
- 'count_include_pad']
-
def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
count_include_pad=True):
super(AvgPool3d, self).__init__()