[TOPI] Tunable Template for Conv2D HWCN on CUDA (#4168)
authorCody Hao Yu <comaniac0422@gmail.com>
Thu, 24 Oct 2019 19:03:15 +0000 (12:03 -0700)
committerLeyuan Wang <laurawly@gmail.com>
Thu, 24 Oct 2019 19:03:15 +0000 (12:03 -0700)
* support conv2d HWCN in AutoTVM and Relay

* fix lint

* fix comments and unit tests

python/tvm/autotvm/task/task.py
python/tvm/autotvm/task/topi_integration.py
python/tvm/relay/op/nn/_nn.py
src/pass/vectorize_loop.cc
topi/python/topi/cuda/conv2d_hwcn.py
topi/python/topi/generic/nn.py
topi/python/topi/nn/conv2d.py
topi/tests/python/test_topi_conv2d_hwcn.py
topi/tests/python/test_topi_conv2d_nchw.py

index e0db275..4f3cc90 100644 (file)
@@ -226,7 +226,7 @@ def args_to_workload(x, topi_compute_func=None):
     elif x is None:
         workload = 0
     else:
-        raise RuntimeError('Do not support type "%s" in argument. Consider to use'
+        raise RuntimeError('Do not support type "%s" in argument. Consider to use '
                            'primitive types only' % type(x))
     return (get_func_name(topi_compute_func), ) + workload  if topi_compute_func else workload
 
index 09f08ad..ac4683d 100644 (file)
@@ -176,9 +176,12 @@ class TaskExtractEnv:
             args = deserialize_args(args)
             A, W = args[:2]
             layout = args[-2]
-            assert layout == 'NCHW', "only support NCHW currently"
+            assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently"
             C = topi.nn.conv2d(*args, **kwargs)
-            s = topi.generic.schedule_conv2d_nchw([C])
+            if layout == 'NCHW':
+                s = topi.generic.schedule_conv2d_nchw([C])
+            else:
+                s = topi.generic.schedule_conv2d_hwcn([C])
             return s, [A, W, C]
 
         @register("topi_nn_depthwise_conv2d_nchw")
index b857234..0043ffa 100644 (file)
@@ -153,14 +153,14 @@ def compute_conv2d(attrs, inputs, out_type, target):
     out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                  else out_dtype)
 
-    assert layout in ["NCHW", "NHWC", "NCHW4c"]
+    assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"]
     (dilation_h, dilation_w) = dilation
     if dilation_h < 1 or dilation_w < 1:
         raise ValueError("dilation should be positive value")
 
     def _get_out_depth():
         weight_shape = get_const_tuple(inputs[1].shape)
-        if kernel_layout == "HWOI":
+        if kernel_layout.startswith("HW"):
             return weight_shape[2] * weight_shape[3]
         return weight_shape[0] * weight_shape[1]
 
@@ -192,11 +192,13 @@ def schedule_conv2d(attrs, outs, target):
     with target:
         if groups == 1 and layout == "NCHW":
             return topi.generic.schedule_conv2d_nchw(outs)
-        if groups == 1 and layout == "NCHW4c":
+        elif groups == 1 and layout == "NCHW4c":
             return topi.generic.schedule_conv2d_nchw(outs)
-        if groups == 1 and layout == "NHWC":
+        elif groups == 1 and layout == "NHWC":
             return topi.generic.schedule_conv2d_nhwc(outs)
-        if groups != 1:
+        elif groups == 1 and layout == "HWCN":
+            return topi.generic.schedule_conv2d_hwcn(outs)
+        elif groups != 1:
             # collect in_channels to distinguish depthwise and group conv2d
             op = _find_conv2d_op(outs[0].op)
             assert op is not None
index 10db37d..1870330 100644 (file)
@@ -368,7 +368,6 @@ class Vectorizer : public IRMutator {
     CHECK(!op->extent.type().is_vector());
     Expr extent = Mutate(op->extent);
     if (extent.type().is_vector()) {
-      LOG(WARNING) << "Detect vectorized extent type, scalarizing...";
       return Scalarize(s);
     }
     Stmt body = Mutate(op->body);
@@ -386,7 +385,6 @@ class Vectorizer : public IRMutator {
     CHECK(!op->condition.type().is_vector());
     Expr condition = this->Mutate(op->condition);
     if (condition.type().is_vector()) {
-      LOG(WARNING) << "Detect vector condition in Vectorized Loop, scalarizing...";
       return Scalarize(s);
     }
     Stmt then_case = this->Mutate(op->then_case);
index 5d101b9..18a624a 100644 (file)
 # pylint: disable=invalid-name, too-many-locals, too-many-statements
 """Schedule for conv2d_hwcn with auto fusion"""
 import tvm
-from .. import tag
+from tvm import autotvm
+from tvm.autotvm.task.space import SplitEntity
 
-def schedule_conv2d_hwcn(outs):
+from .. import generic, tag
+
+
+@autotvm.register_topi_schedule(generic.schedule_conv2d_hwcn, ["cuda", "gpu"], ["direct"])
+def schedule_conv2d_hwcn(cfg, outs):
     """Schedule for conv2d_hwcn and any element-wise operations.
 
     Parameters
@@ -51,36 +56,44 @@ def schedule_conv2d_hwcn(outs):
             sch[B].set_scope("local")
             BL = B
 
-        tile = 8
-        num_thread = 8
-        block_factor = tile * num_thread
-        step = 8
-        vthread = 2
+        hi, wi, fi, ni = sch[Out].op.axis
 
-        block_x = tvm.thread_axis("blockIdx.x")
-        block_y = tvm.thread_axis("blockIdx.y")
-        block_z = tvm.thread_axis("blockIdx.z")
-        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
-        thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
-        thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
+        # Create tuning space
+        n_thread_cand = [1, 2, 4, 8, 16, 32]
+        vthread_cand = [1, 2, 4, 8]
+
+        cfg.define_split(
+            'tile_fi',
+            fi,
+            num_outputs=4,
+            filter=lambda x:
+            (x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
+        cfg.define_split(
+            'tile_ni',
+            ni,
+            num_outputs=4,
+            filter=lambda x:
+            (x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
+
+        if cfg.is_fallback:
+            cfg['tile_fi'] = SplitEntity([-1, 2, 8, 4])
+            cfg['tile_ni'] = SplitEntity([-1, 2, 8, 4])
+
+        # Scheduling
+        step = 8
 
-        hi, wi, fi, ni = sch[Out].op.axis
         bz = sch[Out].fuse(hi, wi)
-        by, fi = sch[Out].split(fi, factor=block_factor)
-        bx, ni = sch[Out].split(ni, factor=block_factor)
-        tyz, fi = sch[Out].split(fi, nparts=vthread)
-        txz, ni = sch[Out].split(ni, nparts=vthread)
-        ty, fi = sch[Out].split(fi, nparts=num_thread)
-        tx, ni = sch[Out].split(ni, nparts=num_thread)
+        by, tyz, ty, fi = cfg['tile_fi'].apply(sch, Out, fi)
+        bx, txz, tx, ni = cfg['tile_ni'].apply(sch, Out, ni)
         sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
-        sch[Out].bind(bz, block_z)
-        sch[Out].bind(by, block_y)
-        sch[Out].bind(bx, block_x)
-        sch[Out].bind(tyz, thread_yz)
-        sch[Out].bind(txz, thread_xz)
-        sch[Out].bind(ty, thread_y)
-        sch[Out].bind(tx, thread_x)
+
+        sch[Out].bind(bz, tvm.thread_axis('blockIdx.z'))
+        sch[Out].bind(by, tvm.thread_axis('blockIdx.y'))
+        sch[Out].bind(bx, tvm.thread_axis('blockIdx.x'))
+        sch[Out].bind(tyz, tvm.thread_axis('vthread'))
+        sch[Out].bind(txz, tvm.thread_axis('vthread'))
+        sch[Out].bind(ty, tvm.thread_axis('threadIdx.y'))
+        sch[Out].bind(tx, tvm.thread_axis('threadIdx.x'))
 
         # Schedule BL local write
         sch[BL].compute_at(sch[Out], tx)
@@ -98,21 +111,21 @@ def schedule_conv2d_hwcn(outs):
         sch[WL].compute_at(sch[BL], rci)
         # Schedule for A's shared memory load
         yi, xi, ci, ni = sch[AA].op.axis
-        ty, ci = sch[AA].split(ci, nparts=num_thread)
-        tx, ni = sch[AA].split(ni, nparts=num_thread)
+        ty, ci = sch[AA].split(ci, nparts=cfg['tile_fi'].size[2])
+        tx, ni = sch[AA].split(ni, nparts=cfg['tile_ni'].size[2])
         _, ni = sch[AA].split(ni, factor=4)
         sch[AA].reorder(ty, tx, yi, xi, ci, ni)
-        sch[AA].bind(ty, thread_y)
-        sch[AA].bind(tx, thread_x)
+        sch[AA].bind(ty, tvm.thread_axis('threadIdx.y'))
+        sch[AA].bind(tx, tvm.thread_axis('threadIdx.x'))
         sch[AA].vectorize(ni)
         # Schedule for W's shared memory load
         yi, xi, ci, fi = sch[WW].op.axis
-        ty, ci = sch[WW].split(ci, nparts=num_thread)
-        tx, fi = sch[WW].split(fi, nparts=num_thread)
+        ty, ci = sch[WW].split(ci, nparts=cfg['tile_fi'].size[2])
+        tx, fi = sch[WW].split(fi, nparts=cfg['tile_ni'].size[2])
         _, fi = sch[WW].split(fi, factor=4)
         sch[WW].reorder(ty, tx, yi, xi, ci, fi)
-        sch[WW].bind(ty, thread_y)
-        sch[WW].bind(tx, thread_x)
+        sch[WW].bind(ty, tvm.thread_axis('threadIdx.y'))
+        sch[WW].bind(tx, tvm.thread_axis('threadIdx.x'))
         sch[WW].vectorize(fi)
 
     scheduled_ops = []
index c2cb2b2..4043cb7 100644 (file)
@@ -35,6 +35,24 @@ def _default_schedule(outs, auto_inline):
 
 
 @tvm.target.generic_func
+def schedule_conv2d_hwcn(outs):
+    """Schedule for conv2d_hwcn
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of conv2d_hwcn
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
+@tvm.target.generic_func
 def schedule_conv2d_nchw(outs):
     """Schedule for conv2d_nchw
 
index ffae4b2..130632f 100644 (file)
@@ -64,9 +64,9 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
     # default declaration
     if layout == 'NCHW':
         return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
-    if layout == 'HWCN':
+    elif layout == 'HWCN':
         return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype)
-    if layout == 'NHWC':
+    elif layout == 'NHWC':
         return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype)
     raise ValueError("not support this layout {} yet".format(layout))
 
index 297df82..35423a6 100644 (file)
@@ -29,24 +29,25 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
 
     A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
     W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
-    B = topi.nn.conv2d_hwcn(A, W, stride, padding, dilation)
-    C = topi.nn.relu(B)
-    s1 = topi.cuda.schedule_conv2d_hwcn([B])
-    s2 = topi.cuda.schedule_conv2d_hwcn([C])
+    B = tvm.placeholder((1, num_filter, 1), name='bias')
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
+    b_shape = get_const_tuple(B.shape)
     dtype = A.dtype
 
     @memoize("topi.tests.test_topi_conv2d_hwcn.verify_hwcn")
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
+        b_np = np.random.uniform(size=b_shape).astype(dtype)
         dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
-        b_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
-        c_np = np.maximum(b_np, 0)
-        return a_np, w_np, b_np, c_np
-    a_np, w_np, b_np, c_np = get_ref_data()
+        c1_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
+        c2_np = c1_np + b_np
+        c3_np = np.maximum(c2_np, 0)
+        return a_np, w_np, b_np, c1_np, c2_np, c3_np
+
+    a_np, w_np, b_np, c1_np, c2_np, c3_np = get_ref_data()
 
     def check_device(device):
         ctx = tvm.context(device, 0)
@@ -54,16 +55,32 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            t_conv = topi.nn.conv2d(A, W, stride, padding, dilation, layout='HWCN')
+            t_bias = topi.add(t_conv, B)
+            t_relu = topi.nn.relu(t_bias)
+            s1 = topi.generic.schedule_conv2d_hwcn([t_conv])
+            s2 = topi.generic.schedule_conv2d_hwcn([t_bias])
+            s3 = topi.generic.schedule_conv2d_hwcn([t_relu])
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
-        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
-        func1 = tvm.build(s1, [A, W, B], device)
-        func2 = tvm.build(s2, [A, W, C], device)
-        func1(a, w, b)
-        func2(a, w, c)
-        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
-        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+        b = tvm.nd.array(b_np, ctx)
+
+        conv_out = tvm.nd.array(
+            np.zeros(get_const_tuple(t_conv.shape), dtype=t_conv.dtype), ctx)
+        bias_out = tvm.nd.array(
+            np.zeros(get_const_tuple(t_bias.shape), dtype=t_bias.dtype), ctx)
+        relu_out = tvm.nd.array(
+            np.zeros(get_const_tuple(t_relu.shape), dtype=t_relu.dtype), ctx)
+        func1 = tvm.build(s1, [A, W, t_conv], device)
+        func2 = tvm.build(s2, [A, W, B, t_bias], device)
+        func3 = tvm.build(s3, [A, W, B, t_relu], device)
+        func1(a, w, conv_out)
+        func2(a, w, b, bias_out)
+        func3(a, w, b, relu_out)
+        tvm.testing.assert_allclose(conv_out.asnumpy(), c1_np, rtol=1e-5)
+        tvm.testing.assert_allclose(bias_out.asnumpy(), c2_np, rtol=1e-5)
+        tvm.testing.assert_allclose(relu_out.asnumpy(), c3_np, rtol=1e-5)
 
     for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
         check_device(device)
index ca1cef2..d7c39a9 100644 (file)
@@ -48,7 +48,6 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
         dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
         c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
         if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
             c_np += b_np
         if add_relu:
             c_np = np.maximum(c_np, 0)