import tvm
from tvm import te
from tvm import autotvm
-from tvm.autotvm.task.space import SplitEntity
+from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from ..nn.pad import pad
from ..util import get_const_tuple
from ..nn.util import get_pad_tuple
cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
+ cfg["unroll_kw"] = OtherOptionEntity(False)
def depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype):
"""Compute depthwise conv2d with NCHW layout."""
cfg.define_split("tile_ic", in_channel, num_outputs=2)
cfg.define_split("tile_oc", out_channel, num_outputs=2)
cfg.define_split("tile_ow", out_width, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
+ cfg.define_knob("unroll_kw", [True, False])
# get workload and related schedule config
wkl = _get_workload(
def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out, output):
tile_ow, oc_bn = cfg["tile_ow"].size[-1], cfg["tile_oc"].size[-1]
+ unroll_kw = cfg["unroll_kw"].val
+
# schedule pad
if isinstance(s[data_vec].op, tvm.te.ComputeOp) \
and "pad" in data_vec.op.tag:
_, ic_chunk, oh, ow, ic_block = s[CC].op.axis
kh, kw = s[CC].op.reduce_axis
s[CC].reorder(ic_chunk, oh, kh, kw, ow, ic_block)
+ if unroll_kw:
+ s[CC].unroll(kw)
s[CC].vectorize(ic_block)
s[CC].unroll(ow)