[MXNET]conv3d and conv3d_transpose addedx (#5814)
authorSamuel <siju.samuel@huawei.com>
Mon, 15 Jun 2020 12:47:53 +0000 (18:17 +0530)
committerGitHub <noreply@github.com>
Mon, 15 Jun 2020 12:47:53 +0000 (21:47 +0900)
python/tvm/relay/frontend/mxnet.py
tests/python/frontend/mxnet/test_forward.py

index 847623b..1d8842d 100644 (file)
@@ -87,10 +87,12 @@ def _mx_fully_connected(inputs, attrs):
 
 
 def _get_channel_axis(layout, op_name):
-    if layout == "NCHW":
+    if layout in ["NCHW", "NCDHW"]:
         return 1
     if layout == "NHWC":
         return 3
+    if layout == "NDHWC":
+        return 4
     raise tvm.error.OpAttributeInvalid(
         'Value {} in attribute "layout" of operator {} is not valid.'.format(layout, op_name))
 
@@ -149,13 +151,15 @@ def _mx_zeros(inputs, attrs):
 
 def _mx_conv(inputs, attrs):
     kernel_size = attrs.get_int_tuple("kernel")
-    if len(kernel_size) == 2:
+    if len(kernel_size) == 3:
+        return _mx_conv3d(inputs, attrs)
+    elif len(kernel_size) == 2:
         return _mx_conv2d(inputs, attrs)
     elif len(kernel_size) == 1:
         return _mx_conv1d(inputs, attrs)
     else:
         raise tvm.error.OpAttributeInvalid(
-            '1D or 2D kernels only are supported for operator Convolution')
+            '1D, 2D or 3D kernels only are supported for operator Convolution')
 
 def _mx_conv1d(inputs, attrs):
     kernel_size = attrs.get_int_tuple("kernel")
@@ -226,15 +230,53 @@ def _mx_conv2d(inputs, attrs):
     return res
 
 
+def _get_mx_conv3d_attrs(attrs):
+    kernel_size = attrs.get_int_tuple("kernel")
+    data_layout = attrs.get_str("layout", "NCDHW")
+    if "kernel_layout" in attrs.attrs:
+        kernel_layout = attrs.get_str("kernel_layout")
+    else:
+        kernel_layout = "DHWIO" if data_layout == "NDHWC" else "OIDHW"
+    new_attrs = {}
+    new_attrs["channels"] = attrs.get_int("num_filter")
+    new_attrs["kernel_size"] = kernel_size
+    new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1))
+    new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0))
+    new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1, 1))
+    new_attrs["groups"] = attrs.get_int("num_group", 1)
+    new_attrs["data_layout"] = data_layout
+    new_attrs["kernel_layout"] = kernel_layout
+    return new_attrs
+
+
+def _mx_conv3d(inputs, attrs):
+    kernel_size = attrs.get_int_tuple("kernel")
+    data_layout = attrs.get_str("layout", "NCDHW")
+    if len(kernel_size) != 3:
+        raise tvm.error.OpAttributeInvalid(
+            'Only 3D kernels are supported for operator Convolution')
+
+    new_attrs = _get_mx_conv3d_attrs(attrs)
+    channel_axis = _get_channel_axis(data_layout, "conv3d")
+    use_bias = not attrs.get_bool("no_bias", False)
+    res = _op.nn.conv3d(inputs[0], inputs[1], **new_attrs)
+    if use_bias:
+        assert len(inputs) == 3
+        res = _op.nn.bias_add(res, inputs[2], axis=channel_axis)
+    return res
+
+
 def _mx_conv_transpose(inputs, attrs):
     kernel_size = attrs.get_int_tuple("kernel")
-    if len(kernel_size) == 2:
+    if len(kernel_size) == 3:
+        return _mx_conv3d_transpose(inputs, attrs)
+    elif len(kernel_size) == 2:
         return _mx_conv2d_transpose(inputs, attrs)
     elif len(kernel_size) == 1:
         return _mx_conv1d_transpose(inputs, attrs)
     else:
         raise tvm.error.OpAttributeInvalid(
-            '1D or 2D kernels only are supported for operator Convolution')
+            '1D, 2D or 3D kernels only are supported for operator Convolution')
 
 
 def _mx_conv1d_transpose(inputs, attrs):
@@ -300,6 +342,41 @@ def _mx_conv2d_transpose(inputs, attrs):
     return res
 
 
+def _mx_conv3d_transpose(inputs, attrs):
+    if "target_shape" in attrs.attrs:
+        raise tvm.error.OpAttributeUnImplemented(
+            'Attribute "target_shape" is not supported for operator Conv3D-transpose.')
+    kernel_size = attrs.get_int_tuple("kernel")
+    if len(kernel_size) != 3:
+        raise tvm.error.OpAttributeInvalid(
+            'Non-3D kernels are not supported for operator Conv3D-transpose.')
+    data_layout = attrs.get_str("layout", "NCDHW")
+    channel_axis = _get_channel_axis(data_layout, "conv3d_transpose")
+
+    if "kernel_layout" in attrs.attrs:
+        kernel_layout = attrs.get_str("kernel_layout")
+    else:
+        kernel_layout = "DHWIO" if data_layout == "NDHWC" else "OIDHW"
+
+    new_attrs = {}
+    new_attrs["channels"] = attrs.get_int("num_filter")
+    new_attrs["kernel_size"] = kernel_size
+    new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1))
+    new_attrs["output_padding"] = attrs.get_int_tuple("adj", (0, 0, 0))
+    new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0))
+    new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1, 1))
+    new_attrs["groups"] = attrs.get_int("num_group", 1)
+    new_attrs["data_layout"] = data_layout
+    new_attrs["kernel_layout"] = kernel_layout
+    use_bias = not attrs.get_bool("no_bias", True)
+    res = _op.nn.conv3d_transpose(inputs[0], inputs[1], **new_attrs)
+
+    if use_bias:
+        assert len(inputs) == 3
+        res = _op.nn.bias_add(res, inputs[2], axis=channel_axis)
+    return res
+
+
 def _mx_pooling(inputs, attrs):
     global_pool = attrs.get_bool("global_pool", False)
     pool_type = attrs.get_str("pool_type")
index 463b50f..00c077f 100644 (file)
@@ -1033,6 +1033,10 @@ def test_forward_convolution():
     verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
     verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=8,
            is_depthwise=True)
+    verify(data_shape=(1, 1, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2)
+    verify(data_shape=(20, 1, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2)
+    verify(data_shape=(1, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(2, 2, 2), pad=(1, 1, 1), num_filter=2)
+    verify(data_shape=(20, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2)
 
 def test_forward_deconvolution():
     def verify(data_shape, kernel_size, stride, pad, num_filter):