[Relay][Legalize] Legalize conv2d_transpose for NHWC (#4399)
author     Alexander Pivovarov <pivovaa@amazon.com>
           Sat, 23 Nov 2019 05:59:15 +0000 (21:59 -0800)
committer  Yao Wang <kevinthesunwy@gmail.com>
           Sat, 23 Nov 2019 05:59:15 +0000 (21:59 -0800)
python/tvm/relay/op/nn/_nn.py
python/tvm/relay/op/op_attrs.py
tests/python/relay/test_op_level2.py
topi/python/topi/nn/conv2d_transpose.py
topi/python/topi/testing/__init__.py
topi/python/topi/testing/conv2d_transpose_python.py [moved from topi/python/topi/testing/conv2d_transpose_nchw_python.py with 65% similarity]

index 54f13c6..cb2ecc9 100644
@@ -278,6 +278,26 @@ def schedule_conv2d_transpose(attrs, outs, target):
         return topi.generic.schedule_conv2d_transpose_nchw(outs)
 
 
+@reg.register_legalize("nn.conv2d_transpose")
+def legalize_conv2d_transpose(attrs, inputs, types):
+    """Legalize conv2d_transpose op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of the current transposed convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    return topi.nn.conv2d_transpose_legalize(attrs, inputs, types)
+
 reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # bias_add
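
This registration hooks conv2d_transpose into Relay's Legalize pass. A minimal
sketch of the hook firing (API paths as of this commit; relay.Module moved in
later TVM releases) - an NHWC conv2d_transpose is rewritten into
transpose -> NCHW conv2d_transpose -> transpose:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 18, 18, 3))
    w = relay.var("w")  # kernel stored HWOI, declared HWIO (see test below)
    y = relay.nn.conv2d_transpose(x, w, channels=10, kernel_size=(3, 3),
                                  data_layout="NHWC", kernel_layout="HWIO")
    mod = relay.Module.from_expr(relay.Function([x, w], y))
    mod = relay.transform.Legalize()(mod)
    print(mod)  # conv2d_transpose now runs in NCHW with kernel_layout OIHW
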
index 2de0257..35b2c05 100644
@@ -284,3 +284,8 @@ class BinaryConv2DAttrs(Attrs):
 @register_relay_attr_node
 class BinaryDenseAttrs(Attrs):
     """Attributes used in bitserial dense operators"""
+
+
+@register_relay_attr_node
+class Conv2DTransposeAttrs(Attrs):
+    """Attributes used in Transposed Conv2D operators"""
index 08c5eb0..b54efaa 100644
@@ -311,8 +311,8 @@ def test_conv2d_transpose_infer_type():
         (10, 15, 3, 3), "float32")
 
     # infer by shape of w, mixed precision
-    n, c, h, w = tvm.var("n"), 10, 10, 12
-    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    n, h, w, c = tvm.var("n"), 10, 10, 12
+    x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
     w = relay.var("w", relay.TensorType((12, 11, 5, 5), "float32"))
     y = relay.nn.conv2d_transpose(x, w,
                                   output_padding=(1, 1),
@@ -323,7 +323,7 @@ def test_conv2d_transpose_infer_type():
         (n, 15, 15, 11), "float32")
 
 
-def test_conv2d_transpose_run():
+def test_conv2d_transpose_nchw_run():
     dshape = (1, 3, 18, 18)
     kshape = (3, 10, 3, 3)
     oshape = (1, 10, 37, 37)
@@ -348,6 +348,33 @@ def test_conv2d_transpose_run():
         tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
 
+def test_conv2d_transpose_nhwc_run():
+    dshape_nhwc = (1, 18, 18, 3)
+    kshape_hwoi = (3, 3, 10, 3)
+    oshape_nhwc = (1, 37, 37, 10)
+    x = relay.var("x", shape=dshape_nhwc)
+    w = relay.var("w")
+    # relay.nn.conv2d_transpose interprets I and O in kernel_layout as
+    # swapped, so a kernel stored as HWOI is declared with kernel_layout HWIO
+    y = relay.nn.conv2d_transpose(x, w,
+                                  channels=10, kernel_size=(3, 3), strides=(2, 2),
+                                  padding=(1, 1), output_padding=(2, 2),
+                                  data_layout="NHWC", kernel_layout="HWIO")
+    func = relay.Function([x, w], y)
+    dtype = "float32"
+    data = np.random.uniform(size=dshape_nhwc).astype(dtype)
+    kernel = np.random.uniform(size=kshape_hwoi).astype(dtype)
+    # the reference implementation takes the kernel's true storage layout, HWOI
+    c_np = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', 2, 1)
+    d_np = np.zeros(shape=oshape_nhwc)
+    d_np[:, 0:c_np.shape[1], 0:c_np.shape[2], :] = c_np
+    ref_res = d_np
+
+    for target, ctx in ctx_list():
+        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+        op_res1 = intrp1.evaluate(func)(data, kernel)
+        tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
 
 def test_upsampling_infer_type():
     n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
@@ -819,7 +846,8 @@ if __name__ == "__main__":
     test_pad_infer_type()
     test_pad_run()
     test_conv2d_transpose_infer_type()
-    test_conv2d_transpose_run()
+    test_conv2d_transpose_nchw_run()
+    test_conv2d_transpose_nhwc_run()
     test_conv2d_run()
     test_conv2d_winograd()
     test_bitserial_conv2d_infer_type()
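
For reference, the 37x37 output expected by the NHWC test follows from the
standard transposed-convolution size formula; the reference result c_np is
35x35 and the test zero-extends it by output_padding. A quick sanity check
(not part of the patch):

    in_size, stride, pad, kernel, output_padding = 18, 2, 1, 3, 2
    out = (in_size - 1) * stride - 2 * pad + kernel  # 35, the size of c_np
    out += output_padding                            # 37, matches oshape_nhwc
    assert out == 37
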
index 2f3e323..a240b68 100644
@@ -18,6 +18,7 @@
 """Transposed 2D convolution operators (sometimes called Deconvolution)."""
 from __future__ import absolute_import as _abs
 import tvm
+from tvm import relay
 from .dilate import dilate
 from .pad import pad
 from .util import get_pad_tuple
@@ -102,3 +103,62 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype)
             axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
 
     return Output
+
+
+@tvm.target.generic_func
+def conv2d_transpose_legalize(attrs, inputs, types):
+    """Legalizes Transposed 2D convolution op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of the current transposed 2D convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    if attrs['data_layout'] == 'NHWC':
+        data, kernel = inputs
+        kernel_layout = attrs['kernel_layout']
+        # Convert the kernel to IOHW storage layout; note the kernel_layout
+        # attribute names I and O swapped relative to the actual storage layout
+        if kernel_layout == 'HWIO':
+            # tensor is actually stored HWOI
+            # transpose HWOI -> IOHW
+            kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
+        elif kernel_layout == 'HWOI':
+            # tensor is actually stored HWIO
+            # transpose HWIO -> IOHW
+            kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
+        elif kernel_layout == 'IOHW':
+            # tensor is actually stored OIHW
+            # transpose OIHW -> IOHW
+            kernel = relay.transpose(kernel, axes=(1, 0, 2, 3))
+        elif kernel_layout == 'OIHW':
+            # tensor is already stored IOHW
+            # nothing to do
+            pass
+        else:
+            # Skip legalization; let relay.nn.conv2d_transpose handle the case.
+            return None
+
+        # Set new attrs for conv2d_transpose.
+        new_attrs = {k: attrs[k] for k in attrs.keys()}
+        new_attrs['data_layout'] = 'NCHW'
+        # the kernel tensor now stores IOHW, so the swapped attribute is OIHW
+        new_attrs['kernel_layout'] = 'OIHW'
+
+        # Convert data to NCHW.
+        data = relay.transpose(data, axes=(0, 3, 1, 2))
+        deconv = relay.nn.conv2d_transpose(data, kernel, **new_attrs)
+        # Convert back to original NHWC layout.
+        out = relay.transpose(deconv, axes=(0, 2, 3, 1))
+        return out
+
+    return None
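
The axis permutations above are easy to sanity-check with numpy shapes (a
standalone sketch, not part of the patch): a kernel that kernel_layout='HWIO'
actually stores as HWOI lands in IOHW under axes (3, 2, 0, 1), and likewise
for the HWOI case:

    import numpy as np

    h, w, o, i = 3, 3, 10, 5
    k_hwoi = np.empty((h, w, o, i))  # declared HWIO, stored HWOI
    assert np.transpose(k_hwoi, (3, 2, 0, 1)).shape == (i, o, h, w)  # IOHW

    k_hwio = np.empty((h, w, i, o))  # declared HWOI, stored HWIO
    assert np.transpose(k_hwio, (2, 3, 0, 1)).shape == (i, o, h, w)  # IOHW
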
index 6c4a0e3..6d73400 100644
@@ -24,7 +24,7 @@ from __future__ import absolute_import as _abs
 from .conv2d_hwcn_python import conv2d_hwcn_python
 from .conv2d_nchw_python import conv2d_nchw_python
 from .conv2d_nhwc_python import conv2d_nhwc_python
-from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python
+from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
 from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
 from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
 from .dilate_python import dilate_python
@@ -73,3 +73,50 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
                     padded_a_np[n, c], w_np[c, f], mode='valid')
                 b_np[n, f] += out
     return b_np
+
+
+def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding):
+    """Transposed convolution operator in NHWC layout.
+
+    Parameters
+    ----------
+    a_nhwc : numpy.ndarray
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    weight : numpy.ndarray
+        4-D in formats HWIO, HWOI, OIHW or IOHW
+
+    weight_format : str
+        ['HWIO', 'HWOI', 'OIHW', 'IOHW']
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    Returns
+    -------
+    b_np : numpy.ndarray
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+    assert a_nhwc.ndim == 4, "a_nhwc number of dimensions should be 4"
+    assert weight.ndim == 4, "weight number of dimensions should be 4"
+
+    a_nchw = np.transpose(a_nhwc, (0, 3, 1, 2))
+
+    # conv2d_transpose_nchw_python needs kernel layout to be IOHW
+    if weight_format == 'HWIO':
+        w_iohw = np.transpose(weight, (2, 3, 0, 1))
+    elif weight_format == 'HWOI':
+        w_iohw = np.transpose(weight, (3, 2, 0, 1))
+    elif weight_format == 'OIHW':
+        w_iohw = np.transpose(weight, (1, 0, 2, 3))
+    elif weight_format == 'IOHW':
+        w_iohw = weight
+    else:
+        raise ValueError('Valid weight_formats are HWIO, HWOI, OIHW or IOHW')
+
+    res_nchw = conv2d_transpose_nchw_python(a_nchw, w_iohw, stride, padding)
+    res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1))
+    return res_nhwc
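
A usage sketch for the new reference helper (values mirror the Relay test
above; not part of the patch):

    import numpy as np
    from topi.testing import conv2d_transpose_nhwc_python

    a = np.random.uniform(size=(1, 18, 18, 3)).astype("float32")
    k = np.random.uniform(size=(3, 3, 10, 3)).astype("float32")  # stored HWOI
    b = conv2d_transpose_nhwc_python(a, k, 'HWOI', stride=2, padding=1)
    assert b.shape == (1, 35, 35, 10)  # (18 - 1) * 2 - 2 * 1 + 3 = 35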