Change the meaning of conv3d_transpose output_padding to match conv{1,2}d_transpose...
authorabergeron <abergeron@gmail.com>
Wed, 5 Aug 2020 05:15:21 +0000 (01:15 -0400)
committerGitHub <noreply@github.com>
Wed, 5 Aug 2020 05:15:21 +0000 (01:15 -0400)
* Change the meaning of output_padding to correspond to conv{1,2}d_transpose

* Fix long lines

* Fix the relay test

* Add missing doc.

* fix size ordering problem

python/tvm/relay/op/strategy/generic.py
python/tvm/topi/cuda/conv3d_transpose_ncdhw.py
python/tvm/topi/nn/conv3d_transpose.py
python/tvm/topi/testing/conv3d_transpose_ncdhw_python.py
python/tvm/topi/x86/conv3d_transpose.py
tests/python/relay/test_op_level2.py
tests/python/topi/python/test_topi_conv3d_transpose_ncdhw.py

index bc54577..69c9bd7 100644 (file)
@@ -364,15 +364,12 @@ def wrap_compute_conv3d_transpose(topi_compute):
         """Compute definition of conv3d_transpose"""
         padding = get_const_tuple(attrs.padding)
         strides = get_const_tuple(attrs.strides)
+        output_padding = get_const_tuple(attrs.output_padding)
         out_dtype = attrs.out_dtype
         out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                      else out_dtype)
         out = topi_compute(
-            inputs[0], inputs[1], strides, padding, out_dtype)
-        output_padding = get_const_tuple(attrs.output_padding)
-        out = topi.nn.pad(out,
-                          [0, 0, 0, 0, 0],
-                          [0, 0, output_padding[0], output_padding[1], output_padding[2]])
+            inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
         return [out]
     return compute_conv3d_transpose
 
index bcad3e4..d6ca9bc 100644 (file)
@@ -26,7 +26,8 @@ from .conv3d_direct import schedule_direct_conv3d_cuda
 
 
 @autotvm.register_topi_compute("conv3d_transpose_ncdhw.cuda")
-def conv3d_transpose_ncdhw(cfg, data, kernel, stride, padding, out_dtype):
+def conv3d_transpose_ncdhw(cfg, data, kernel, stride, padding, out_dtype,
+                           output_padding):
     """Transposed 3D convolution ncdhw forward operator.
 
     Parameters
@@ -43,6 +44,8 @@ def conv3d_transpose_ncdhw(cfg, data, kernel, stride, padding, out_dtype):
         Padding size, or ['VALID', 'SAME']
     out_dtype: str
         The output type. This is used in mixed precision
+    output_padding : tuple of three ints
+        Used to disambiguate output shape
 
     Returns
     -------
@@ -52,24 +55,27 @@ def conv3d_transpose_ncdhw(cfg, data, kernel, stride, padding, out_dtype):
     batch, inp_channels, inp_depth, inp_height, inp_width = get_const_tuple(data.shape)
     _, out_channels, kernel_depth, kernel_height, kernel_width = get_const_tuple(kernel.shape)
     stride_depth, stride_height, stride_width = stride
+    outpad_depth, outpad_height, outpad_width = output_padding
+    assert (outpad_height < stride_height and outpad_width < stride_width and
+            outpad_depth < stride_depth)
     cfg.stride = stride
     pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = nn.get_pad_tuple3d(
         padding, (kernel_depth, kernel_height, kernel_width))
 
     out_depth = (inp_depth - 1) * stride_depth + \
-        kernel_depth - pad_front - pad_back
+        kernel_depth - pad_front - pad_back + outpad_depth
     pad_front = kernel_depth - 1 - pad_front
     pad_back = kernel_depth - 1 - pad_back
     dilated_depth = stride_depth * (inp_depth - 1) + 1
 
     out_width = (inp_width - 1) * stride_width + \
-        kernel_width - pad_left - pad_right
+        kernel_width - pad_left - pad_right + outpad_width
     pad_left = kernel_width - 1 - pad_left
     pad_right = kernel_width - 1 - pad_right
     dilated_width = stride_width * (inp_width - 1) + 1
 
     out_height = (inp_height - 1) * stride_height + \
-        kernel_height - pad_top - pad_bottom
+        kernel_height - pad_top - pad_bottom + outpad_height
     pad_top = kernel_height - 1 - pad_top
     pad_bottom = kernel_height - 1 - pad_bottom
     dilated_height = stride_height * (inp_height - 1) + 1
index 29b9e53..cd57264 100644 (file)
@@ -25,7 +25,7 @@ from .util import get_pad_tuple3d
 from ..util import simplify
 
 
-def conv3d_transpose_ncdhw(Input, Filter, strides, padding, out_dtype):
+def conv3d_transpose_ncdhw(Input, Filter, strides, padding, out_dtype, output_padding):
     """Transposed 3D convolution ncdhw forward operator.
 
     Parameters
@@ -45,31 +45,37 @@ def conv3d_transpose_ncdhw(Input, Filter, strides, padding, out_dtype):
     out_dtype : str
         The output data type. This is used for mixed precision.
 
+    output_padding : tuple of ints
+        Used to get the right output shape for gradients
+
     Returns
     -------
     Output : tvm.te.Tensor
         5-D with shape [batch, out_channel, out_depth, out_height, out_width]
     """
-    return declaration_conv3d_transpose_impl(Input, Filter, strides, padding, out_dtype)
+    return declaration_conv3d_transpose_impl(Input, Filter, strides, padding,
+                                             out_dtype, output_padding)
 
 
-def conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype):
+def conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype, output_padding):
     """Preprocess data and kernel to make the compute pattern
        of conv3d_transpose the same as conv3d"""
     batch, in_c, in_d, in_h, in_w = data.shape
     _, out_c, filter_d, filter_h, filter_w = kernel.shape
     stride_d, stride_h, stride_w = strides
+    opad_d, opad_h, opad_w = output_padding
+    assert opad_d < stride_d and opad_h < stride_h and opad_w < stride_w
     # dilate data
     data_dilate = dilate(data, [1, 1, stride_d, stride_h, stride_w], name='data_dilate')
     # pad data
     fpad_front, fpad_top, fpad_left, fpad_back, fpad_bottom, fpad_right = get_pad_tuple3d(
         padding, (filter_d, filter_h, filter_w))
     bpad_front = filter_d - 1 - fpad_front
-    bpad_back = filter_d - 1 - fpad_back
+    bpad_back = filter_d - 1 - fpad_back + opad_d
     bpad_top = filter_h - 1 - fpad_top
-    bpad_bottom = filter_h - 1 - fpad_bottom
+    bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
     bpad_left = filter_w - 1 - fpad_left
-    bpad_right = filter_w - 1 - fpad_right
+    bpad_right = filter_w - 1 - fpad_right + opad_w
     data_pad = pad(data_dilate, \
                    [0, 0, bpad_front, bpad_top, bpad_left], \
                    [0, 0, bpad_back, bpad_bottom, bpad_right], \
@@ -82,10 +88,10 @@ def conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype)
     return data_pad, kernel_transform
 
 
-def declaration_conv3d_transpose_impl(data, kernel, strides, padding, out_dtype):
+def declaration_conv3d_transpose_impl(data, kernel, strides, padding, out_dtype, output_padding):
     """Implementation of conv3d transpose"""
     data_pad, kernel_transform = \
-        conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype)
+        conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype, output_padding)
     batch, in_c, in_d, in_h, in_w = data_pad.shape
     out_c, _, filter_d, filter_h, filter_w = kernel_transform.shape
     stride_d, stride_h, stride_w = strides
index 8d03397..711f04b 100644 (file)
@@ -21,7 +21,7 @@ import tvm.topi.testing
 from tvm.topi.nn.util import get_pad_tuple3d
 
 
-def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding):
+def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding):
     """Transposed 3d convolution operator in NCDHW layout.
 
     Parameters
@@ -38,6 +38,9 @@ def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding):
     padding : int or str
         Padding size
 
+    output_padding : int or list/tuple of three ints
+        Used to disambiguate output shape.
+
     Returns
     -------
     b_np : np.ndarray
@@ -49,6 +52,11 @@ def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding):
         stride_d = stride_h = stride_w = stride
     else:
         stride_d, stride_h, stride_w = stride
+    if isinstance(output_padding, int):
+        opad_d = opad_h = opad_w = output_padding
+    else:
+        opad_d, opad_h, opad_w = output_padding
+    assert opad_d < stride_d and opad_h < stride_h and opad_w < stride_w
 
     # dilate stage
     dilated_a_np = tvm.topi.testing.dilate_python(a_np, [1, 1, stride_d, stride_h, stride_w])
@@ -58,11 +66,11 @@ def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding):
         padding, (filter_d, filter_h, filter_w))
 
     bpad_front = filter_d - 1 - fpad_front
-    bpad_back = filter_d - 1 - fpad_back
+    bpad_back = filter_d - 1 - fpad_back + opad_d
     bpad_top = filter_h - 1 - fpad_top
-    bpad_bottom = filter_h - 1 - fpad_bottom
+    bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
     bpad_left = filter_w - 1 - fpad_left
-    bpad_right = filter_w - 1 - fpad_right
+    bpad_right = filter_w - 1 - fpad_right + opad_w
 
     padded_a_np = np.zeros((batch,
                             in_c,
@@ -70,7 +78,7 @@ def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding):
                             dilated_a_np.shape[3]+bpad_top+bpad_bottom,
                             dilated_a_np.shape[4]+bpad_left+bpad_right))
 
-    padded_a_np[:, :, bpad_front:dilated_a_np.shape[2]+bpad_back,
+    padded_a_np[:, :, bpad_front:dilated_a_np.shape[2]+bpad_front,
                 bpad_top:dilated_a_np.shape[3]+bpad_top,
                 bpad_left:dilated_a_np.shape[4]+bpad_left] = dilated_a_np
 
index ad035d3..698702a 100644 (file)
@@ -23,9 +23,10 @@ from ..util import traverse_inline
 from .. import nn
 from .conv3d import conv3d_ncdhw, schedule_conv3d_ncdhw
 
-def conv3d_transpose_ncdhw(data, kernel, strides, padding, out_dtype):
+def conv3d_transpose_ncdhw(data, kernel, strides, padding, out_dtype, output_padding):
     data_pad, kernel_transform = \
-        nn.conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype)
+        nn.conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding,
+                                             out_dtype, output_padding)
 
     # reuse conv3d_ncdhw implementation
     return conv3d_ncdhw(data_pad, kernel_transform, (1, 1, 1),
index b26d6e4..6258d8c 100644 (file)
@@ -663,8 +663,7 @@ def test_conv3d_transpose_ncdhw_run():
 
     data = np.random.uniform(size=dshape).astype(dtype)
     kernel = np.random.uniform(size=kshape).astype(dtype)
-
-    ref_res = tvm.topi.testing.conv3d_transpose_ncdhw_python(data, kernel, 1, 1)
+    ref_res = tvm.topi.testing.conv3d_transpose_ncdhw_python(data, kernel, 1, 1, 0)
 
     for target, ctx in ctx_list():
         intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
index 6ee386e..8e98120 100644 (file)
@@ -32,7 +32,7 @@ _conv3d_transpose_ncdhw_implement = {
     "gpu": (topi.cuda.conv3d_transpose_ncdhw, topi.cuda.schedule_conv3d_transpose_ncdhw),
 }
 
-def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding):
     in_depth, in_height, in_width = in_size
     kernel_depth, kernel_height, kernel_width = kernel
     stride_depth, stride_height, stride_width = stride
@@ -49,7 +49,7 @@ def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = tvm.topi.testing.conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding)
+        b_np = tvm.topi.testing.conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding)
         c_np = np.maximum(b_np, 0)
         return a_np, w_np, b_np, c_np
 
@@ -66,7 +66,7 @@ def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel
             B = fcompute(A, W,
                          [stride_depth, stride_height, stride_width],
                          [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right],
-                         A.dtype)
+                         A.dtype, output_padding)
             C = topi.nn.relu(B)
             s1 = fschedule([B])
             s2 = fschedule([C])
@@ -86,15 +86,18 @@ def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel
 
 
 def test_conv3d_transpose_ncdhw():
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 1,  (1, 1, 1), (1, 1, 1), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 2, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (2, 2, 2), (1, 1, 1, 1, 1, 1))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (2, 2, 2), (2, 2, 2), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 32, (5, 5, 5), (1, 1, 1), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 1,  (1, 1, 1), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 2, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0), (2, 2, 2))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0), (1, 0, 2))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (2, 2, 2), (2, 2, 2), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 32, (5, 5, 5), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1), (1, 1, 1))
 
 if __name__ == "__main__":
     test_conv3d_transpose_ncdhw()