From 1ae44cf039a0cac6789ed112971e0b8898b2bef3 Mon Sep 17 00:00:00 2001 From: Alex Gladkov Date: Thu, 23 Jan 2020 19:17:45 -0800 Subject: [PATCH] Fix Tensorflow conv3d pad bug, add non-cubic data and kernel tests (#4772) --- python/tvm/relay/frontend/tensorflow.py | 4 ++-- topi/tests/python/test_topi_conv3d_ndhwc.py | 17 +++++++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 408f88a..76bf058 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -498,7 +498,7 @@ def _conv3d(opname): pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) - attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_v[0], pad_v[1], pad_h[1]] + attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]] else: msg = 'Value {} in attribute "padding" of operator Conv is not ' \ @@ -509,7 +509,7 @@ def _conv3d(opname): attr['kernel_layout'] = 'DHWIO' if attr['data_format'] == 'NDHWC' else 'OIDHW' use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4) - channel_axis = 1 if attr['data_format'] == "NCDHW" else 3 + channel_axis = 1 if attr['data_format'] == "NCDHW" else 4 # Ignore the new attributes from TF2.0, for now. out = AttrCvt( diff --git a/topi/tests/python/test_topi_conv3d_ndhwc.py b/topi/tests/python/test_topi_conv3d_ndhwc.py index 66ccf08..242e054 100644 --- a/topi/tests/python/test_topi_conv3d_ndhwc.py +++ b/topi/tests/python/test_topi_conv3d_ndhwc.py @@ -25,10 +25,17 @@ from topi.util import get_const_tuple def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): - in_depth = in_height = in_width = in_size + if isinstance(in_size, tuple): + in_depth, in_height, in_width = in_size + else: + in_depth = in_height = in_width = in_size + if isinstance(kernel, tuple): + kernel_depth, kernel_height, kernel_width = kernel + else: + kernel_depth = kernel_height = kernel_width = kernel A = tvm.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A') - W = tvm.placeholder((kernel, kernel, kernel, in_channel, num_filter), name='W') + W = tvm.placeholder((kernel_depth, kernel_height, kernel_width, in_channel, num_filter), name='W') B = topi.nn.conv3d_ndhwc(A, W, stride, padding, dilation) a_shape = get_const_tuple(A.shape) @@ -74,6 +81,12 @@ def test_conv3d_ndhwc(): # dilation = 2 verify_conv3d_ndhwc(1, 64, 32, 64, 3, 1, "SAME", dilation=2) + verify_conv3d_ndhwc(1, 1, (20, 256, 256), 32, (1, 3, 3), (1, 2, 2), "SAME") + verify_conv3d_ndhwc(1, 1, (20, 256, 256), 32, + (1, 6, 6), (1, 2, 2), (0, 2, 2)) + verify_conv3d_ndhwc(1, 4, (20, 256, 256), 8, + (1, 5, 5), (1, 2, 2), (0, 2, 2)) + if __name__ == "__main__": test_conv3d_ndhwc() -- 2.7.4