input_size=(),
desc='scalar',
),
+ dict(
+ fullname='Padding12_1dcircular',
+ constructor=wrap_functional(F.pad, pad=(1, 2), mode='circular'),
+ input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
+ reference_fn=lambda i, _: padding1d_circular(i, (1, 2)),
+ skip_double=TEST_WITH_ROCM,
+ pickle=False,
+ ),
+ dict(
+ fullname='Padding31_1dcircular',
+ constructor=wrap_functional(F.pad, pad=(3, 1), mode='circular'),
+ input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
+ reference_fn=lambda i, _: padding1d_circular(i, (3, 1)),
+ skip_double=TEST_WITH_ROCM,
+ pickle=False,
+ ),
+ dict(
+ fullname='Padding33_1dcircular',
+ constructor=wrap_functional(F.pad, pad=(3, 3), mode='circular'),
+ input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
+ reference_fn=lambda i, _: padding1d_circular(i, (3, 3)),
+ skip_double=TEST_WITH_ROCM,
+ pickle=False,
+ ),
+ dict(
+ fullname='Padding1221_2dcircular',
+ constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1), mode='circular'),
+ input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]),
+ reference_fn=lambda i, _: padding2d_circular(i, (1, 2, 2, 1)),
+ skip_double=TEST_WITH_ROCM,
+ pickle=False,
+ ),
+ dict(
+ fullname='Padding2322_2dcircular',
+ constructor=wrap_functional(F.pad, pad=(2, 3, 2, 2), mode='circular'),
+ input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]),
+ reference_fn=lambda i, _: padding2d_circular(i, (2, 3, 2, 2)),
+ skip_double=TEST_WITH_ROCM,
+ pickle=False,
+ ),
+ dict(
+ fullname='Padding3331_2dcircular',
+ constructor=wrap_functional(F.pad, pad=(3, 3, 3, 1), mode='circular'),
+ input_fn=lambda: torch.arange(9, out=torch.DoubleTensor()).reshape([1, 1, 3, 3]),
+ reference_fn=lambda i, _: padding2d_circular(i, (3, 3, 3, 1)),
+ skip_double=TEST_WITH_ROCM,
+ pickle=False,
+ ),
+ dict(
+ fullname='Padding122112_3dcircular',
+ constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1, 1, 2), mode='circular'),
+ input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
+ reference_fn=lambda i, _: padding3d_circular(i, (1, 2, 2, 1, 1, 2)),
+ skip_double=TEST_WITH_ROCM,
+ pickle=False,
+ ),
+ dict(
+ fullname='Padding322112_3dcircular',
+ constructor=wrap_functional(F.pad, pad=(3, 2, 2, 1, 1, 2), mode='circular'),
+ input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
+ reference_fn=lambda i, _: padding3d_circular(i, (3, 2, 2, 1, 1, 2)),
+ skip_double=TEST_WITH_ROCM,
+ pickle=False,
+ ),
+ dict(
+ fullname='Padding332122_3dcircular',
+ constructor=wrap_functional(F.pad, pad=(3, 3, 2, 1, 2, 2), mode='circular'),
+ input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
+ reference_fn=lambda i, _: padding3d_circular(i, (3, 3, 2, 1, 2, 2)),
+ skip_double=TEST_WITH_ROCM,
+ pickle=False,
+ ),
+
+ dict(
+ module_name='Conv1d',
+ constructor_args=(3, 4, 2, 2, (1,), 1, 1, True, 'circular'),
+ input_size=(2, 3, 5,),
+ cudnn=True,
+ desc='stride1_pad1circular',
+ ),
+ dict(
+ module_name='Conv1d',
+ constructor_args=(3, 4, 2, 2, (2,), 1, 1, True, 'circular'),
+ input_size=(2, 3, 5,),
+ cudnn=True,
+ desc='stride1_pad2circular',
+ ),
+ dict(
+ module_name='Conv2d',
+ constructor_args=(3, 4, (3, 3), (2, 2), (1, 2), 1, 1, True, 'circular'),
+ input_size=(2, 3, 3, 3),
+ cudnn=True,
+ desc='pad2circular'
+ ),
+ dict(
+ module_name='Conv3d',
+ constructor_args=(3, 4, 2, 2, (1, 2, 3), 1, 1, True, 'circular'),
+ input_size=(2, 3, 3, 3, 3),
+ cudnn=True,
+ desc='stride_pad1circular',
+ ),
]
output = output.to(dt)
return output
+
+def padding1d_circular(input, pad):
+ r""" input:
+ [[[0., 1., 2.],
+ [3., 4., 5.]]]
+ pad: (1, 2)
+ output:
+ [[[2., 0., 1., 2., 0., 1.],
+ [5., 3., 4., 5., 3., 4.]]]
+ """
+ return torch.cat([input[:, :, -pad[0]:], input,
+ input[:, :, 0:pad[1]]], dim=2)
+
+
+def padding2d_circular(input, pad):
+ r"""input:
+ [[[[0., 1., 2],
+ [3., 4., 5.]]]]
+ pad: (1, 2, 2, 1)
+ output:
+ [[[[2., 0., 1., 2., 0., 1.],
+ [5., 3., 4., 5., 3., 4.],
+ [2., 0., 1., 2., 0., 1.],
+ [5., 3., 4., 5., 3., 4.],
+ [2., 0., 1., 2., 0., 1.]]]]
+ """
+ input = torch.cat([input[:, :, -pad[2]:], input, input[:, :, 0:pad[3]]], dim=2)
+ return torch.cat([input[:, :, :, -pad[0]:], input, input[:, :, :, 0:pad[1]]], dim=3)
+
+
+def padding3d_circular(input, pad):
+ r"""input:
+ [[[[[ 0., 1., 2.],
+ [ 3., 4., 5.]],
+ [[ 6., 7., 8.],
+ [ 9., 10., 11.]]]]]
+ pad: (1, 2, 2, 1, 1, 2)
+ output: [[[[[ 8., 6., 7., 8., 6., 7.],
+ [11., 9., 10., 11., 9., 10.],
+ [ 8., 6., 7., 8., 6., 7.],
+ [11., 9., 10., 11., 9., 10.],
+ [ 8., 6., 7., 8., 6., 7.]],
+
+ [[ 2., 0., 1., 2., 0., 1.],
+ [ 5., 3., 4., 5., 3., 4.],
+ [ 2., 0., 1., 2., 0., 1.],
+ [ 5., 3., 4., 5., 3., 4.],
+ [ 2., 0., 1., 2., 0., 1.]],
+
+ [[ 8., 6., 7., 8., 6., 7.],
+ [11., 9., 10., 11., 9., 10.],
+ [ 8., 6., 7., 8., 6., 7.],
+ [11., 9., 10., 11., 9., 10.],
+ [ 8., 6., 7., 8., 6., 7.]],
+
+ [[ 2., 0., 1., 2., 0., 1.],
+ [ 5., 3., 4., 5., 3., 4.],
+ [ 2., 0., 1., 2., 0., 1.],
+ [ 5., 3., 4., 5., 3., 4.],
+ [ 2., 0., 1., 2., 0., 1.]],
+
+ [[ 8., 6., 7., 8., 6., 7.],
+ [11., 9., 10., 11., 9., 10.],
+ [ 8., 6., 7., 8., 6., 7.],
+ [11., 9., 10., 11., 9., 10.],
+ [ 8., 6., 7., 8., 6., 7.]]]]]
+ """
+ input = torch.cat([input[:, :, -pad[4]:], input, input[:, :, 0:pad[5]]], dim=2)
+ input = torch.cat([input[:, :, :, -pad[2]:], input, input[:, :, :, 0:pad[3]]], dim=3)
+ return torch.cat([input[:, :, :, :, -pad[0]:], input, input[:, :, :, :, 0:pad[1]]], dim=4)
+
+
loss_reference_fns = {
'KLDivLoss': kldivloss_reference,
'NLLLoss': nllloss_reference,
conv1d = _add_docstr(torch.conv1d, r"""
-conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
+conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros') -> Tensor
Applies a 1D convolution over an input signal composed of several input
planes.
bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None``
stride: the stride of the convolving kernel. Can be a single number or
a one-element tuple `(sW,)`. Default: 1
- padding: implicit zero paddings on both sides of the input. Can be a
+ padding: implicit paddings on both sides of the input. Can be a
single number or a one-element tuple `(padW,)`. Default: 0
dilation: the spacing between kernel elements. Can be a single number or
a one-element tuple `(dW,)`. Default: 1
groups: split input into groups, :math:`\text{in\_channels}` should be divisible by
the number of groups. Default: 1
+ padding_mode: the type of paddings applied to both sided can be: `zeros` or `circular`. Default: `zeros`
Examples::
""")
conv2d = _add_docstr(torch.conv2d, r"""
-conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
+conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros') -> Tensor
Applies a 2D convolution over an input image composed of several input
planes.
bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
stride: the stride of the convolving kernel. Can be a single number or a
tuple `(sH, sW)`. Default: 1
- padding: implicit zero paddings on both sides of the input. Can be a
+ padding: implicit paddings on both sides of the input. Can be a
single number or a tuple `(padH, padW)`. Default: 0
dilation: the spacing between kernel elements. Can be a single number or
a tuple `(dH, dW)`. Default: 1
groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
number of groups. Default: 1
+ padding_mode: the type of paddings applied to both sided can be: `zeros` or `circular`. Default: `zeros`
Examples::
""") # noqa: E501
conv3d = _add_docstr(torch.conv3d, r"""
-conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
+conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros') -> Tensor
Applies a 3D convolution over an input image composed of several input
planes.
bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: None
stride: the stride of the convolving kernel. Can be a single number or a
tuple `(sT, sH, sW)`. Default: 1
- padding: implicit zero paddings on both sides of the input. Can be a
+ padding: implicit paddings on both sides of the input. Can be a
single number or a tuple `(padT, padH, padW)`. Default: 0
dilation: the spacing between kernel elements. Can be a single number or
a tuple `(dT, dH, dW)`. Default: 1
groups: split input into groups, :math:`\text{in\_channels}` should be divisible by
the number of groups. Default: 1
+ padding_mode: the type of paddings applied to both sided can be: `zeros` or `circular`. Default: `zeros`
Examples::
input (Tensor): N-dimensional tensor
pad (tuple): m-elements tuple, where
:math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even.
- mode: ``'constant'``, ``'reflect'`` or ``'replicate'``.
+ mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
Default: ``'constant'``
value: fill value for ``'constant'`` padding. Default: ``0``
ret = torch._C._nn.reflection_pad1d(input, pad)
elif mode == 'replicate':
ret = torch._C._nn.replication_pad1d(input, pad)
+ elif mode == 'circular':
+ ret = pad_circular(input, pad)
else:
ret = input # TODO: remove this when jit raise supports control flow
raise NotImplementedError
ret = torch._C._nn.reflection_pad2d(input, pad)
elif mode == 'replicate':
ret = torch._C._nn.replication_pad2d(input, pad)
+ elif mode == 'circular':
+ ret = pad_circular(input, pad)
else:
ret = input # TODO: remove this when jit raise supports control flow
raise NotImplementedError
raise NotImplementedError
elif mode == 'replicate':
ret = torch._C._nn.replication_pad3d(input, pad)
+ elif mode == 'circular':
+ ret = pad_circular(input, pad)
else:
ret = input # TODO: remove this when jit raise supports control flow
raise NotImplementedError
else:
ret = input # TODO: remove this when jit raise supports control flow
raise NotImplementedError("Only 3D, 4D, 5D padding with non-constant padding are supported for now")
+
return ret
# distance
raise NotImplementedError("Input Error: Only 3D input Tensors are supported (got {}D)".format(input.dim()))
ret = input # TODO: remove when jit supports exception control flow
return ret
+
+
+@weak_script
+def pad_circular(input, padding):
+ # type: (Tensor, List[int]) -> Tensor
+ """
+ Arguments
+ :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))`
+ :param padding: (tuple): m-elem tuple where m is the degree of convolution
+ Returns
+ :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0],
+ H + 2 * padding[1]], W + 2 * padding[2]))`
+ """
+
+ input = torch.cat([input, input[:, :, 0:padding[-1]]], dim=2)
+ input = torch.cat([input[:, :, -(padding[-1] + padding[-2]):-padding[-1]], input], dim=2)
+
+ if len(padding) > 2:
+ input = torch.cat([input, input[:, :, :, 0:padding[-3]]], dim=3)
+ input = torch.cat([input[:, :, :, -(padding[-3] + padding[-4]):-padding[-3]], input], dim=3)
+
+ if len(padding) > 4:
+ input = torch.cat([input, input[:, :, :, :, 0:padding[-5]]], dim=4)
+ input = torch.cat([input[:, :, :, :, -(padding[-5] + padding[-6]):-padding[-5]], input], dim=4)
+
+ return input
@weak_module
class _ConvNd(Module):
- __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias']
+ __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias', 'padding_mode']
def __init__(self, in_channels, out_channels, kernel_size, stride,
- padding, dilation, transposed, output_padding, groups, bias):
+ padding, dilation, transposed, output_padding,
+ groups, bias, padding_mode):
super(_ConvNd, self).__init__()
if in_channels % groups != 0:
raise ValueError('in_channels must be divisible by groups')
self.transposed = transposed
self.output_padding = output_padding
self.groups = groups
+ self.padding_mode = padding_mode
if transposed:
self.weight = Parameter(torch.Tensor(
in_channels, out_channels // groups, *kernel_size))
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of
the input. Default: 0
+ padding_mode (string, optional). Accepted values `zeros` and `circular` Default: `zeros`
dilation (int or tuple, optional): Spacing between kernel
elements. Default: 1
groups (int, optional): Number of blocked connections from input
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
- padding=0, dilation=1, groups=1, bias=True):
+ padding=0, dilation=1, groups=1,
+ bias=True, padding_mode='zeros'):
kernel_size = _single(kernel_size)
stride = _single(stride)
padding = _single(padding)
dilation = _single(dilation)
super(Conv1d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
- False, _single(0), groups, bias)
+ False, _single(0), groups, bias, padding_mode)
@weak_script_method
def forward(self, input):
+ if self.padding_mode == 'circular':
+ expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
+ return F.conv1d(F.pad(input, expanded_padding, mode='circular'),
+ self.weight, self.bias, self.stride,
+ _single(0), self.dilation, self.groups)
return F.conv1d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+ padding_mode (string, optional). Accepted values `zeros` and `circular` Default: `zeros`
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
- padding=0, dilation=1, groups=1, bias=True):
+ padding=0, dilation=1, groups=1,
+ bias=True, padding_mode='zeros'):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
super(Conv2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
- False, _pair(0), groups, bias)
+ False, _pair(0), groups, bias, padding_mode)
@weak_script_method
def forward(self, input):
+ if self.padding_mode == 'circular':
+ expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
+ (self.padding[0] + 1) // 2, self.padding[0] // 2)
+ return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
+ self.weight, self.bias, self.stride,
+ _pair(0), self.dilation, self.groups)
return F.conv2d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
+ padding_mode (string, optional). Accepted values `zeros` and `circular` Default: `zeros`
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
- padding=0, dilation=1, groups=1, bias=True):
+ padding=0, dilation=1, groups=1,
+ bias=True, padding_mode='zeros'):
kernel_size = _triple(kernel_size)
stride = _triple(stride)
padding = _triple(padding)
dilation = _triple(dilation)
super(Conv3d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
- False, _triple(0), groups, bias)
+ False, _triple(0), groups, bias, padding_mode)
@weak_script_method
def forward(self, input):
+ if self.padding_mode == 'circular':
+ expanded_padding = ((self.padding[2] + 1) // 2, self.padding[2] // 2,
+ (self.padding[1] + 1) // 2, self.padding[1] // 2,
+ (self.padding[0] + 1) // 2, self.padding[0] // 2)
+ return F.conv3d(F.pad(input, expanded_padding, mode='circular'),
+ self.weight, self.bias, self.stride, _triple(0),
+ self.dilation, self.groups)
return F.conv3d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
@weak_module
class _ConvTransposeMixin(object):
__constants__ = ['stride', 'padding', 'kernel_size', 'dim_size',
- 'output_padding', 'groups', 'dilation', 'transposed', 'bias']
+ 'output_padding', 'groups', 'dilation', 'transposed',
+ 'bias', 'padding_mode']
@weak_script_method
def forward(self, input, output_size=None):
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
- padding=0, output_padding=0, groups=1, bias=True, dilation=1):
+ padding=0, output_padding=0, groups=1, bias=True,
+ dilation=1, padding_mode='zeros'):
kernel_size = _single(kernel_size)
stride = _single(stride)
padding = _single(padding)
output_padding = _single(output_padding)
super(ConvTranspose1d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
- True, output_padding, groups, bias)
+ True, output_padding, groups, bias, padding_mode)
@weak_script_method
def forward(self, input, output_size=None):
# type: (Tensor, Optional[List[int]]) -> Tensor
+ if self.padding_mode != 'zeros':
+ raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
+
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
return F.conv_transpose1d(
input, self.weight, self.bias, self.stride, self.padding,
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
- padding=0, output_padding=0, groups=1, bias=True, dilation=1):
+ padding=0, output_padding=0, groups=1, bias=True,
+ dilation=1, padding_mode='zeros'):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
output_padding = _pair(output_padding)
super(ConvTranspose2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
- True, output_padding, groups, bias)
+ True, output_padding, groups, bias, padding_mode)
@weak_script_method
def forward(self, input, output_size=None):
# type: (Tensor, Optional[List[int]]) -> Tensor
+ if self.padding_mode != 'zeros':
+ raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
+
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
+
return F.conv_transpose2d(
input, self.weight, self.bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
- padding=0, output_padding=0, groups=1, bias=True, dilation=1):
+ padding=0, output_padding=0, groups=1, bias=True,
+ dilation=1, padding_mode='zeros'):
kernel_size = _triple(kernel_size)
stride = _triple(stride)
padding = _triple(padding)
output_padding = _triple(output_padding)
super(ConvTranspose3d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
- True, output_padding, groups, bias)
+ True, output_padding, groups, bias, padding_mode)
@weak_script_method
def forward(self, input, output_size=None):
# type: (Tensor, Optional[List[int]]) -> Tensor
+ if self.padding_mode != 'zeros':
+ raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d')
+
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
+
return F.conv_transpose3d(
input, self.weight, self.bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)