[TOPI x86] Adding unroll_kw config option for depthwise conv2d. (#5197)
authorAnimesh Jain <anijain@umich.edu>
Fri, 3 Apr 2020 19:11:32 +0000 (12:11 -0700)
committerGitHub <noreply@github.com>
Fri, 3 Apr 2020 19:11:32 +0000 (12:11 -0700)
topi/python/topi/x86/depthwise_conv2d.py

index 5b43ced..240dee0 100644 (file)
@@ -20,7 +20,7 @@
 import tvm
 from tvm import te
 from tvm import autotvm
-from tvm.autotvm.task.space import SplitEntity
+from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 from ..nn.pad import pad
 from ..util import get_const_tuple
 from ..nn.util import get_pad_tuple
@@ -67,6 +67,7 @@ def _fallback_schedule(cfg, wkl):
     cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
     cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
     cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
+    cfg["unroll_kw"] = OtherOptionEntity(False)
 
 def depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype):
     """Compute depthwise conv2d with NCHW layout."""
@@ -133,6 +134,7 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,
     cfg.define_split("tile_ic", in_channel, num_outputs=2)
     cfg.define_split("tile_oc", out_channel, num_outputs=2)
     cfg.define_split("tile_ow", out_width, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
+    cfg.define_knob("unroll_kw", [True, False])
 
     # get workload and related schedule config
     wkl = _get_workload(
@@ -199,6 +201,8 @@ def schedule_depthwise_conv2d_NCHWc(cfg, outs):
 
 def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out, output):
     tile_ow, oc_bn = cfg["tile_ow"].size[-1], cfg["tile_oc"].size[-1]
+    unroll_kw = cfg["unroll_kw"].val
+
     # schedule pad
     if isinstance(s[data_vec].op, tvm.te.ComputeOp) \
             and "pad" in data_vec.op.tag:
@@ -229,6 +233,8 @@ def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out
     _, ic_chunk, oh, ow, ic_block = s[CC].op.axis
     kh, kw = s[CC].op.reduce_axis
     s[CC].reorder(ic_chunk, oh, kh, kw, ow, ic_block)
+    if unroll_kw:
+        s[CC].unroll(kw)
     s[CC].vectorize(ic_block)
     s[CC].unroll(ow)