Support x86 dilation conv2d and improve multi-batch conv2d (#3308)
authorYao Wang <kevinthesunwy@gmail.com>
Mon, 10 Jun 2019 20:08:28 +0000 (13:08 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Mon, 10 Jun 2019 20:08:28 +0000 (13:08 -0700)
* Support x86 dilation conv2d and improve multi-batch conv2d

* Fix lint

topi/python/topi/x86/conv2d.py
topi/python/topi/x86/conv2d_avx_1x1.py
topi/python/topi/x86/conv2d_avx_common.py
topi/tests/python/test_topi_conv2d_NCHWc.py
tutorials/frontend/deploy_ssd_gluoncv.py

index 82d1caa..d51945e 100644 (file)
@@ -500,8 +500,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
     # we keep them for debug convenience when dumping autotvm workload
     HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
-    assert (dh, dw) == (1, 1), "Does not support dilation"
+    dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
+        else (dilation, dilation)
 
     n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
     in_channel = ic_chunk * ic_bn
@@ -514,6 +514,9 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
     num_filter = oc_chunk * oc_bn
     groups = ic_chunk // ic_chunk_group
 
+    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
+
     if cfg.is_fallback:
         _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
                             tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
@@ -521,8 +524,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
                             strides, padding, out_dtype)
 
     # output shape
-    out_height = (ih + 2 * HPAD - kernel_height) // HSTR + 1
-    out_width = (iw + 2 * WPAD - kernel_width) // WSTR + 1
+    out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1
+    out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1
     oshape = (n, oc_chunk, out_height, out_width, oc_bn)
 
     # DOPAD
@@ -548,8 +551,9 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
         ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
         ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
         return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
-                           tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
-                                            ic_f_inner * n_elems +  ic_s_inner]
+                           tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh*dilation_h,
+                                            ow*WSTR+kw*dilation_w,
+                                            ic_f_inner * n_elems + ic_s_inner]
                                    .astype(out_dtype) *
                                    kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                                           oc_block, ic_s_inner].astype(out_dtype),
@@ -575,7 +579,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
 
     # else: fp implementation
     return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
-                       tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw,
+                       tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh*dilation_h,
+                                        ow*WSTR+kw*dilation_w,
                                         ic%ic_bn].astype(out_dtype) *
                                kernel[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, oc_block],
                                axis=[ic, kh, kw]),
index 4994d45..256cea5 100644 (file)
@@ -134,7 +134,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
     A = data
     if isinstance(s[A].op, tvm.tensor.ComputeOp):
         batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
-        parallel_axis = s[A].fuse(ic_chunk, ih)
+        parallel_axis = s[A].fuse(batch, ic_chunk, ih)
         s[A].parallel(parallel_axis)
 
     C, O = conv_out, last
@@ -146,7 +146,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
     s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
     s[C].vectorize(oc_block)
 
-    parallel_axis = s[C].fuse(oc_chunk, oh_outer)
+    parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer)
     s[CC].compute_at(s[C], parallel_axis)
     if C == O:
         s[C].parallel(parallel_axis)
@@ -172,7 +172,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
         ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
         s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
 
-        parallel_axis = s[O].fuse(oc_chunk, oh_outer)
+        parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
         s[C].compute_at(s[O], parallel_axis)
         s[O].vectorize(oc_block)
         s[O].parallel(parallel_axis)
@@ -203,7 +203,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
     A = data
     if isinstance(s[A].op, tvm.tensor.ComputeOp):
         batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
-        parallel_axis = s[A].fuse(ic_chunk, ih)
+        parallel_axis = s[A].fuse(batch, ic_chunk, ih)
         s[A].parallel(parallel_axis)
 
     C, O = conv_out, last
@@ -215,7 +215,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
     s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
     s[C].vectorize(oc_block)
 
-    parallel_axis = s[C].fuse(oc_chunk, oh_outer)
+    parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer)
     s[CC].compute_at(s[C], parallel_axis)
     if C == O:
         s[C].parallel(parallel_axis)
@@ -246,7 +246,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
         ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
         s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
 
-        parallel_axis = s[O].fuse(oc_chunk, oh_outer)
+        parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
         s[C].compute_at(s[O], parallel_axis)
         s[O].vectorize(oc_block)
         s[O].parallel(parallel_axis)
index 3ab68d7..44867c9 100644 (file)
@@ -143,10 +143,10 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
     C, O = conv_out, last
     CC = s.cache_write(C, 'global')
 
-    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
+    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
     ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
     s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-    parallel_axis = s[C].fuse(oc_chunk, oh)
+    parallel_axis = s[C].fuse(batch, oc_chunk, oh)
     s[C].vectorize(oc_block)
     if C == O:
         s[C].parallel(parallel_axis)
@@ -171,7 +171,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
         batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
         ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
         s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-        parallel_axis = s[O].fuse(oc_chunk, oh)
+        parallel_axis = s[O].fuse(batch, oc_chunk, oh)
         s[C].compute_at(s[O], parallel_axis)
         s[O].vectorize(oc_block)
         s[O].parallel(parallel_axis)
@@ -214,10 +214,10 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
     C, O = conv_out, last
     CC = s.cache_write(C, 'global')
 
-    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
+    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
     ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
     s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-    parallel_axis = s[C].fuse(oc_chunk, oh)
+    parallel_axis = s[C].fuse(batch, oc_chunk, oh)
     s[C].vectorize(oc_block)
     if C == O:
         s[C].parallel(parallel_axis)
@@ -251,7 +251,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
         batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
         ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
         s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-        parallel_axis = s[O].fuse(oc_chunk, oh)
+        parallel_axis = s[O].fuse(batch, oc_chunk, oh)
         s[C].compute_at(s[O], parallel_axis)
         s[O].vectorize(oc_block)
         s[O].parallel(parallel_axis)
index 5aca0c0..26b4642 100644 (file)
@@ -49,7 +49,6 @@ def _transform_bias(bias, bn):
 
 def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
                         padding, dilation=1, add_bias=False, add_relu=False, dtype="float32"):
-    assert dilation == 1, "conv2d_NCHWc does not support dilation for now."
     print("Workload: (%d, %d, %d, %d, %d, %d, %d)" %
           (batch, in_channel, in_size, num_filter, kernel, stride, padding))
 
@@ -79,7 +78,8 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
         a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
         w_np = np.random.uniform(size=(num_filter, in_channel, kernel, kernel)).astype(dtype)
         b_np = np.random.uniform(size=(num_filter, 1, 1)).astype(dtype)
-        c_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
+        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
         if add_bias:
             c_np += b_np
         if add_relu:
@@ -149,8 +149,8 @@ def test_conv2d_NCHWc():
     verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True)
     verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
 
-    # disable dilation test since it is not supported by NCHW[x]c conv for now.
-    verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, dilation=2)
+    # dilation
+    verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, dilation=2)
 
     # batch size
     verify_conv2d_NCHWc(4, 64, 56, 64, 3, 1, 1)
index f536679..829957b 100644 (file)
@@ -47,9 +47,6 @@ from gluoncv import model_zoo, data, utils
 #
 #   To get best performance fo SSD on Intel graphics,
 #   change target argument to 'opencl -device=intel_graphics'
-#
-#   SSD with VGG as body network is not supported yet since
-#   x86 conv2d schedule doesn't support dilation.
 
 supported_model = [
     'ssd_512_resnet50_v1_voc',
@@ -57,6 +54,8 @@ supported_model = [
     'ssd_512_resnet101_v2_voc',
     'ssd_512_mobilenet1.0_voc',
     'ssd_512_mobilenet1.0_coco',
+    'ssd_300_vgg16_atrous_voc'
+    'ssd_512_vgg16_atrous_coco',
 ]
 
 model_name = supported_model[0]