from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import nn, generic
-from ..util import equal_const_int, get_const_tuple, traverse_inline
+from ..util import get_const_tuple, traverse_inline
@autotvm.task.register_topi_compute(nn.conv2d_transpose_nchw, ['cuda', 'gpu'], "direct")
-def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
+def conv2d_transpose_nchw_cuda(cfg, data, kernel, stride, padding, out_dtype):
"""Transposed 2D convolution nchw forward operator.
Parameters
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
- batch, in_c, in_h, in_w = get_const_tuple(Input.shape)
- _, out_c, filter_h, filter_w = get_const_tuple(Filter.shape)
- stride_h, stride_w = strides
-
- # attach stride info to config, this is used in schedule space definition
- cfg.stride = strides
-
- # padding stage
- fpad_top, fpad_left, fpad_bottom, fpad_right = nn.get_pad_tuple(padding, (filter_h, filter_w))
- bpad_top = filter_h - 1 - fpad_top
- bpad_bottom = filter_h - 1 - fpad_bottom
- bpad_left = filter_w - 1 - fpad_left
- bpad_right = filter_w - 1 - fpad_right
-
- # padding stage
- FirstPad = nn.pad(Input,
- [0, 0, (bpad_top + stride_h - 1) // stride_h,
- (bpad_left + stride_w - 1) // stride_w],
- [0, 0, (bpad_bottom + stride_h - 1) // stride_h,
- (bpad_right + stride_w - 1) // stride_w], name='FirstPad')
-
- idxdiv = tvm.indexdiv
- idxmod = tvm.indexmod
- # remove extra padding introduced by dilatation
- border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h)
- border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w)
-
- # dilation stage
- data = FirstPad
- strides = [1, 1, stride_h, stride_w]
- n = len(data.shape)
-
- def _dilate(*indices):
- not_zero = []
- index_tuple = []
- for i in range(n):
- if not equal_const_int(strides[i], 1):
- index_tuple.append(idxdiv(indices[i], strides[i]))
- not_zero.append(idxmod(indices[i], strides[i]).equal(0))
- else:
- index_tuple.append(indices[i])
- if not_zero:
- not_zero = tvm.all(*not_zero)
- return tvm.if_then_else(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
- return data(*index_tuple)
-
- # convolution stage
- out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
- out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
- dc = tvm.reduce_axis((0, in_c), name='dc')
- dh = tvm.reduce_axis((0, filter_h), name='dh')
- dw = tvm.reduce_axis((0, filter_w), name='dw')
-
- Output = tvm.compute(
- (batch, out_c, out_h, out_w),
+ batch, inp_channels, inp_height, inp_width = get_const_tuple(data.shape)
+ _, out_channels, kernel_height, kernel_width = get_const_tuple(kernel.shape)
+ stride_height, stride_width = stride
+ cfg.stride = stride
+ pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
+ padding, (kernel_height, kernel_width))
+
+ out_width = (inp_width - 1) * stride_width + \
+ kernel_width - pad_left - pad_right
+ pad_left = kernel_width - 1 - pad_left
+ pad_right = kernel_width - 1 - pad_right
+ dilated_width = stride_width * (inp_width - 1) + 1
+
+ out_height = (inp_height - 1) * stride_height + \
+ kernel_height - pad_top - pad_bottom
+ pad_top = kernel_height - 1 - pad_top
+ pad_bottom = kernel_height - 1 - pad_bottom
+ dilated_height = stride_height * (inp_height - 1) + 1
+
+ # compute pad
+ data = tvm.compute(
+ (batch, inp_channels,
+ pad_top + dilated_height + pad_bottom,
+ pad_left + dilated_width + pad_right),
+ lambda n, c, y, x: tvm.if_then_else(
+ tvm.all(x >= pad_left,
+ x < pad_left + dilated_width,
+ tvm.indexmod(x - pad_left, stride_width).equal(0),
+ y >= pad_top,
+ y < pad_top + dilated_height,
+ tvm.indexmod(y - pad_top, stride_height).equal(0)),
+ data[n, c,
+ tvm.indexdiv(y - pad_top, stride_height),
+ tvm.indexdiv(x - pad_left, stride_width)],
+ tvm.const(0., "float32")),
+ name='data_pad')
+
+ # compute transposed conv
+ dc = tvm.reduce_axis((0, inp_channels), name='dc')
+ dh = tvm.reduce_axis((0, kernel_height), name='dh')
+ dw = tvm.reduce_axis((0, kernel_width), name='dw')
+ data_out = tvm.compute(
+ (batch, out_channels, out_height, out_width),
lambda b, c, h, w: tvm.sum(
- _dilate(b, dc, h + dh + border_h, w + dw + border_w).astype(out_dtype) *
- Filter[dc, c, filter_h - 1 - dh, filter_w - 1 - dw].astype(out_dtype),
+ data[b, dc, h + dh, w + dw].astype(out_dtype) *
+ kernel[dc,
+ c,
+ kernel_height - 1 - dh,
+ kernel_width - 1 - dw].astype(out_dtype),
axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
- return Output
+ return data_out
@autotvm.task.register_topi_schedule(generic.schedule_conv2d_transpose_nchw,
['cuda', 'gpu'], 'direct')
else:
cfg["tile_n"] = SplitEntity([1, 1, 1, 1])
# split F (output channel dimension)
- cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
+ if F > 1:
+ cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
# split Y (height dimension)
y_split_factor = 1
for candidate in range(5, 17):
cfg.define_knob("unroll_explicit", [0, 1])
if cfg.is_fallback:
- ko = int(kernel.shape[1])
- kh = int(kernel.shape[2])
- kw = int(kernel.shape[3])
- stride_h, stride_w = cfg.stride
- # Workaround to make CUDA compilation work. Issue #4470
- # TODO make _fallback_schedule work for all kernel/strides combinations
- # after issue #4470 is resolved
- do_fallback = True
- if ko == 1:
- do_fallback = False
- elif (kh, kw) == (1, 1):
- do_fallback = True
- elif (stride_h, stride_w) == (2, 2):
- do_fallback = False
- elif (kh, kw) == (stride_h, stride_w):
- do_fallback = False
-
- if do_fallback:
- N, F, Y, X = get_const_tuple(conv.shape)
- _fallback_schedule(N, F, Y, X)
+ N, F, Y, X = get_const_tuple(conv.shape)
+ _fallback_schedule(N, F, Y, X)
##### space definition end #####
from common import get_all_backend
def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
- in_height = in_width = in_size
+ in_height, in_width = in_size
+ kernel_height, kernel_width = kernel
+ stride_height, stride_width = stride
+ pad_top, pad_left, pad_bottom, pad_right = padding
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
- W = tvm.placeholder((in_channel, num_filter, kernel, kernel), name='W')
+ W = tvm.placeholder((in_channel, num_filter, kernel_height, kernel_width), name='W')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
- B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], [padding, padding], A.dtype)
+ B = topi.nn.conv2d_transpose_nchw(A, W,
+ [stride_height, stride_width],
+ [pad_top, pad_left, pad_bottom, pad_right],
+ A.dtype)
C = topi.nn.relu(B)
s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
func2(a, w, c)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
-
for device in get_all_backend():
check_device(device)
def test_conv2d_transpose_nchw():
- verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 1, 0)
- verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1)
- verify_conv2d_transpose_nchw(1, 3, 224, 32, 2, 2, 0)
- verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 1, 0)
- verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 2, 1)
-
+ verify_conv2d_transpose_nchw(1, 3, (224, 224), 1, (1, 1), (1, 1), (0, 0, 0, 0))
+ verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0))
+ verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (3, 3), (0, 0, 0, 0))
+ verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0))
+ verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (2, 2), (1, 1, 1, 1))
+ verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (2, 2), (2, 2), (0, 0, 0, 0))
+ verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (1, 1), (0, 0, 0, 0))
+ verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (2, 2), (1, 1, 1, 1))
+ verify_conv2d_transpose_nchw(16, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0))
+ verify_conv2d_transpose_nchw(16, 512, (8, 1), 128, (31, 1), (2, 1), (14, 0, 15, 0))
if __name__ == "__main__":
test_conv2d_transpose_nchw()