def _get_channel_axis(layout, op_name):
- if layout == "NCHW":
+ if layout in ["NCHW", "NCDHW"]:
return 1
if layout == "NHWC":
return 3
+ if layout == "NDHWC":
+ return 4
raise tvm.error.OpAttributeInvalid(
'Value {} in attribute "layout" of operator {} is not valid.'.format(layout, op_name))
def _mx_conv(inputs, attrs):
    """Dispatch an MXNet Convolution op to the 1D/2D/3D converter.

    The spatial rank is taken from the length of the "kernel" attribute;
    anything other than 1, 2 or 3 dimensions is rejected.
    """
    ndim = len(attrs.get_int_tuple("kernel"))
    if ndim == 1:
        return _mx_conv1d(inputs, attrs)
    if ndim == 2:
        return _mx_conv2d(inputs, attrs)
    if ndim == 3:
        return _mx_conv3d(inputs, attrs)
    raise tvm.error.OpAttributeInvalid(
        '1D, 2D or 3D kernels only are supported for operator Convolution')
def _mx_conv1d(inputs, attrs):
kernel_size = attrs.get_int_tuple("kernel")
return res
+def _get_mx_conv3d_attrs(attrs):
+ kernel_size = attrs.get_int_tuple("kernel")
+ data_layout = attrs.get_str("layout", "NCDHW")
+ if "kernel_layout" in attrs.attrs:
+ kernel_layout = attrs.get_str("kernel_layout")
+ else:
+ kernel_layout = "DHWIO" if data_layout == "NDHWC" else "OIDHW"
+ new_attrs = {}
+ new_attrs["channels"] = attrs.get_int("num_filter")
+ new_attrs["kernel_size"] = kernel_size
+ new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1))
+ new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0))
+ new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1, 1))
+ new_attrs["groups"] = attrs.get_int("num_group", 1)
+ new_attrs["data_layout"] = data_layout
+ new_attrs["kernel_layout"] = kernel_layout
+ return new_attrs
+
+
def _mx_conv3d(inputs, attrs):
    """Convert an MXNet 3D Convolution to Relay conv3d, plus bias_add if present."""
    if len(attrs.get_int_tuple("kernel")) != 3:
        raise tvm.error.OpAttributeInvalid(
            'Only 3D kernels are supported for operator Convolution')

    conv_attrs = _get_mx_conv3d_attrs(attrs)
    # Resolve the bias axis up front so an invalid layout fails fast.
    bias_axis = _get_channel_axis(attrs.get_str("layout", "NCDHW"), "conv3d")
    res = _op.nn.conv3d(inputs[0], inputs[1], **conv_attrs)
    if not attrs.get_bool("no_bias", False):
        # MXNet bundles the bias with the conv; Relay adds it explicitly.
        assert len(inputs) == 3
        res = _op.nn.bias_add(res, inputs[2], axis=bias_axis)
    return res
+
+
def _mx_conv_transpose(inputs, attrs):
    """Dispatch an MXNet Deconvolution op to the 1D/2D/3D transpose converter.

    The spatial rank is taken from the length of the "kernel" attribute;
    anything other than 1, 2 or 3 dimensions is rejected.
    """
    ndim = len(attrs.get_int_tuple("kernel"))
    if ndim == 1:
        return _mx_conv1d_transpose(inputs, attrs)
    if ndim == 2:
        return _mx_conv2d_transpose(inputs, attrs)
    if ndim == 3:
        return _mx_conv3d_transpose(inputs, attrs)
    raise tvm.error.OpAttributeInvalid(
        '1D, 2D or 3D kernels only are supported for operator Convolution')
def _mx_conv1d_transpose(inputs, attrs):
return res
def _mx_conv3d_transpose(inputs, attrs):
    """Convert an MXNet 3D Deconvolution to Relay conv3d_transpose (+ bias)."""
    if "target_shape" in attrs.attrs:
        raise tvm.error.OpAttributeUnImplemented(
            'Attribute "target_shape" is not supported for operator Conv3D-transpose.')
    if len(attrs.get_int_tuple("kernel")) != 3:
        raise tvm.error.OpAttributeInvalid(
            'Non-3D kernels are not supported for operator Conv3D-transpose.')

    data_layout = attrs.get_str("layout", "NCDHW")
    bias_axis = _get_channel_axis(data_layout, "conv3d_transpose")

    # An explicitly given kernel_layout wins; otherwise derive it from the
    # data layout (channel-last data pairs with DHWIO weights).
    if "kernel_layout" in attrs.attrs:
        kernel_layout = attrs.get_str("kernel_layout")
    elif data_layout == "NDHWC":
        kernel_layout = "DHWIO"
    else:
        kernel_layout = "OIDHW"

    conv_attrs = {
        "channels": attrs.get_int("num_filter"),
        "kernel_size": attrs.get_int_tuple("kernel"),
        "strides": attrs.get_int_tuple("stride", (1, 1, 1)),
        "output_padding": attrs.get_int_tuple("adj", (0, 0, 0)),
        "padding": attrs.get_int_tuple("pad", (0, 0, 0)),
        "dilation": attrs.get_int_tuple("dilate", (1, 1, 1)),
        "groups": attrs.get_int("num_group", 1),
        "data_layout": data_layout,
        "kernel_layout": kernel_layout,
    }
    res = _op.nn.conv3d_transpose(inputs[0], inputs[1], **conv_attrs)

    # NOTE(review): the default here is no_bias=True, unlike _mx_conv3d's
    # False — presumably matching MXNet Deconvolution's default; confirm.
    if not attrs.get_bool("no_bias", True):
        assert len(inputs) == 3
        res = _op.nn.bias_add(res, inputs[2], axis=bias_axis)
    return res
+
+
def _mx_pooling(inputs, attrs):
global_pool = attrs.get_bool("global_pool", False)
pool_type = attrs.get_str("pool_type")
verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=8,
is_depthwise=True)
+ verify(data_shape=(1, 1, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2)
+ verify(data_shape=(20, 1, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2)
+ verify(data_shape=(1, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(2, 2, 2), pad=(1, 1, 1), num_filter=2)
+ verify(data_shape=(20, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2)
def test_forward_deconvolution():
def verify(data_shape, kernel_size, stride, pad, num_filter):