# If groups > 1 but weight_shape[1] != 1, this is an ordinary grouped
# convolution, so leave the weight as-is; only reshape it for the
# depthwise case (groups > 1 with weight_shape[1] == 1).
if groups > 1 and weight_shape[1] == 1:
channel_multiplier = channels // groups
- new_weight_shape = (groups, channel_multiplier, weight_shape[2], weight_shape[3])
+ new_weight_shape = (groups, channel_multiplier) + tuple(weight_shape[2:])
weight = _op.transform.reshape(weight, new_weight_shape)
kernel_size = weight_shape[2:]
if isinstance(strides, _expr.Expr):
strides = _infer_shape(strides)
+ if len(kernel_size) == 1:
+ strides = (1, ) + strides
if isinstance(padding, _expr.Expr):
padding = _infer_shape(padding)
+ if len(kernel_size) == 1:
+ padding = (0, ) + padding
if isinstance(dilation, _expr.Expr):
dilation = _infer_shape(dilation)
+ if len(kernel_size) == 1:
+ dilation = (1, ) + dilation
if use_transpose:
if len(kernel_size) == 3:
data_layout = "NCHW"
kernel_layout = "OIHW"
+ if len(kernel_size) == 1:
+ data = _op.expand_dims(data, axis=2)
+ weight = _op.expand_dims(weight, axis=2)
conv_out = conv_op(data,
weight,
dilation=dilation,
groups=groups,
channels=channels,
- kernel_size=kernel_size,
+ kernel_size=[1] + kernel_size \
+ if len(kernel_size) == 1 \
+ else kernel_size,
data_layout=data_layout,
kernel_layout=kernel_layout,
out_layout="",
out_dtype="")
if use_bias:
- return _op.nn.bias_add(conv_out, bias)
+ res = _op.nn.bias_add(conv_out, bias)
else:
- return conv_out
+ res = conv_out
+ if len(kernel_size) == 1:
+ res = _op.squeeze(res, axis=[2])
+ return res
+
return _impl
def _softmax():
def test_forward_conv():
torch.set_grad_enabled(False)
- input_shape = [1, 3, 10, 10]
+ conv1d_input_shape = [1, 3, 10]
+ conv2d_input_shape = [1, 3, 10, 10]
class Conv2D1(Module):
def __init__(self):
def forward(self, *args):
return self.softmax(self.conv(args[0]))
- input_data = torch.rand(input_shape).float()
- verify_model(Conv2D1().float().eval(), input_data=input_data)
- verify_model(Conv2D2().float().eval(), input_data=input_data)
+ class Conv1D1(Module):
+ def __init__(self):
+ super(Conv1D1, self).__init__()
+ self.conv = torch.nn.Conv1d(3, 6, 7)
+ self.softmax = torch.nn.Softmax()
+
+ def forward(self, *args):
+ return self.softmax(self.conv(args[0]))
+
+ class Conv1D2(Module):
+ def __init__(self):
+ super(Conv1D2, self).__init__()
+ self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)
+ self.softmax = torch.nn.Softmax()
+
+ def forward(self, *args):
+ return self.softmax(self.conv(args[0]))
+
+ class Conv1D3(Module):
+ def __init__(self):
+ super(Conv1D3, self).__init__()
+ self.conv = torch.nn.Conv1d(3, 6, 7, groups=3, bias=False)
+ self.softmax = torch.nn.Softmax()
+
+ def forward(self, *args):
+ return self.softmax(self.conv(args[0]))
+
+ conv2d_input_data = torch.rand(conv2d_input_shape).float()
+ verify_model(Conv2D1().float().eval(), input_data=conv2d_input_data)
+ verify_model(Conv2D2().float().eval(), input_data=conv2d_input_data)
# depthwise convolution with channel multiplier 2
- verify_model(Conv2D3().float().eval(), input_data=input_data)
+ verify_model(Conv2D3().float().eval(), input_data=conv2d_input_data)
# group conv
verify_model(torch.nn.Conv2d(8, 8, kernel_size=(3, 3),
stride=(1, 1), groups=2).eval(),
input_data=torch.randn((1, 8, 16, 16)))
+ conv1d_input_data = torch.rand(conv1d_input_shape).float()
+ verify_model(Conv1D1().float().eval(), input_data=conv1d_input_data)
+ verify_model(Conv1D2().float().eval(), input_data=conv1d_input_data)
+ verify_model(Conv1D3().float().eval(), input_data=conv1d_input_data)
def test_forward_conv_transpose():
torch.set_grad_enabled(False)
- input_shape = [1, 3, 10, 10]
- input_data = torch.rand(input_shape).float()
- verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=input_data)
- verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=input_data)
+ conv2d_input_shape = [1, 3, 10, 10]
+ conv2d_input_data = torch.rand(conv2d_input_shape).float()
+ verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=conv2d_input_data)
+ verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=conv2d_input_data)
+
+ conv1d_input_shape = [1, 3, 10]
+ conv1d_input_data = torch.rand(conv1d_input_shape).float()
+ verify_model(torch.nn.ConvTranspose1d(3, 6, 7, bias=True), input_data=conv1d_input_data)
+ verify_model(torch.nn.ConvTranspose1d(3, 12, 3, bias=False), input_data=conv1d_input_data)
def test_forward_threshold():