Add CUDA conv2d for NHWC layout (#4737)
author: Alex Gladkov <gladkova@lab126.com>
Mon, 20 Jan 2020 01:18:51 +0000 (17:18 -0800)
committer: Wuwei Lin <wuwei@apache.org>
Mon, 20 Jan 2020 01:18:51 +0000 (20:18 -0500)
topi/python/topi/cuda/conv2d.py
topi/tests/python/test_topi_conv2d_nhwc.py

index 0d4a73f..12ab77b 100644 (file)
@@ -123,6 +123,8 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
         return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
     if layout == 'HWCN':
         return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
+    if layout == 'NHWC':
+        return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
     raise ValueError("not support this layout {} yet".format(layout))
 
 
@@ -162,3 +164,37 @@ def schedule_conv2d_nchw_cuda(cfg, outs):
 
     traverse_inline(s, outs[0].op, _callback)
     return s
+
+
+@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc, ["cuda", "gpu"],
+                                ["direct"])
+def schedule_conv2d_nhwc_cuda(cfg, outs):
+    """TOPI schedule for CUDA conv2d_nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    outs: Array of Tensor
+        The computation graph description of conv2d
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv2d.
+    """
+    target = tvm.target.current_target()
+    if 'cudnn' in target.libs:
+        return generic.schedule_extern(outs)
+
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == 'conv2d_nhwc':
+            schedule_direct_cuda(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
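
A minimal usage sketch (not part of the patch, mirroring the updated test below): with a CUDA-enabled build, the new layout branch in conv2d_cuda and the schedule_conv2d_nhwc_cuda registration above are reached through the generic TOPI entry points. Shapes and names here are illustrative only.

    import numpy as np
    import tvm
    import topi
    from topi.util import get_const_tuple

    # illustrative NHWC workload: 1x14x14x64 input, 3x3 kernel, 64 filters
    batch, in_size, in_channel, num_filter, kernel, stride = 1, 14, 64, 64, 3, 1
    A = tvm.placeholder((batch, in_size, in_size, in_channel), name='A')
    W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')

    with tvm.target.create('cuda'):
        # dispatches to nn.conv2d_nhwc via the new layout == 'NHWC' branch
        B = topi.nn.conv2d(A, W, (stride, stride), 'SAME', (1, 1),
                           layout='NHWC', out_dtype='float32')
        # resolves to the schedule_conv2d_nhwc_cuda template registered above
        s = topi.generic.schedule_conv2d_nhwc([B])

    func = tvm.build(s, [A, W, B], 'cuda')
    ctx = tvm.gpu(0)
    a = tvm.nd.array(np.random.uniform(size=get_const_tuple(A.shape)).astype('float32'), ctx)
    w = tvm.nd.array(np.random.uniform(size=get_const_tuple(W.shape)).astype('float32'), ctx)
    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype='float32'), ctx)
    func(a, w, b)
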
index 8c6e009..2a44d60 100644 (file)
@@ -29,7 +29,6 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
 
     A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A')
     W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
-    B = topi.nn.conv2d_nhwc(A, W, stride, padding, dilation)
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -50,6 +49,8 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
+            B = topi.nn.conv2d(A, W, (stride, stride), padding,
+                               (dilation, dilation), layout='NHWC', out_dtype=dtype)
             s = topi.generic.schedule_conv2d_nhwc([B])
         ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
@@ -59,7 +60,7 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
         func(a, w, b)
         tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['llvm']:
+    for device in ['llvm', 'cuda']:
         check_device(device)
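
Note on the cuDNN path: schedule_conv2d_nhwc_cuda keeps the existing escape hatch and returns generic.schedule_extern(outs) when the target carries the cudnn library. A hedged sketch of selecting that path (the -libs flag is standard TVM target syntax, not introduced by this patch; A and W as in the sketch above; whether cuDNN accepts a given NHWC workload depends on the pre-existing cudnn branch in conv2d_cuda, which this patch does not change):

    # requires a cuDNN-enabled TVM build
    with tvm.target.create('cuda -libs=cudnn'):
        B = topi.nn.conv2d(A, W, (1, 1), 'SAME', (1, 1),
                           layout='NHWC', out_dtype='float32')
        s = topi.generic.schedule_conv2d_nhwc([B])  # falls back to schedule_extern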