From: Yanming Wang
Date: Fri, 24 Jul 2020 23:00:09 +0000 (-0700)
Subject: [AutoTVM][BugFix] Fix autotvm on the conv2d_nchw_winograd.mali operator (#6130)
X-Git-Tag: upstream/0.7.0~358
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=659f12bf0755ebca78ede0f931b3e01880d67c15;p=platform%2Fupstream%2Ftvm.git

[AutoTVM][BugFix] Fix autotvm on the conv2d_nchw_winograd.mali operator (#6130)

* [AutoTVM] Fix conv2d_nchw_winograd.mali

* Fix pylint error

Co-authored-by: Yanming Wang
---

diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py
index b7cd6f2..3942599 100644
--- a/python/tvm/autotvm/task/task.py
+++ b/python/tvm/autotvm/task/task.py
@@ -216,19 +216,23 @@ class TaskTemplate(object):
     def _default_func(self, *args, **kwargs):
         assert callable(self.fcompute) and callable(self.fschedule)
         out = self.fcompute(*args, **kwargs)
-        arg_bufs = [out] + self.get_inputs(out)
+        arg_bufs = [out] + self._get_inputs(out)
         s = self.fschedule([out])
         return s, arg_bufs
 
-    def get_inputs(self, out):
+    @staticmethod
+    def _get_inputs(out):
         inputs = []
         queue = [out]
+        hash_set = set()
         while queue:
             t = queue.pop(0)
             if isinstance(t.op, tensor.PlaceholderOp):
                 inputs.append(t)
             else:
-                queue.extend(t.op.input_tensors)
+                input_tensors = [t for t in t.op.input_tensors if t not in hash_set]
+                queue.extend(input_tensors)
+                hash_set.update(input_tensors)
         return inputs
 
 def _register_task_compute(name, func=None):

diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py
index ed19326..f2b26ee 100644
--- a/topi/python/topi/mali/conv2d.py
+++ b/topi/python/topi/mali/conv2d.py
@@ -276,8 +276,7 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til
              [(b*bnb+bb) % nW * m + nu], tvm.tir.const(0, data_pad.dtype)), name='d')
 
     if autotvm.GLOBAL_SCOPE.in_tuning:
-        VC = cfg['tile_k'].size[-1]
-        kvshape = (KH + tile_size - 1, KW + tile_size - 1, tvm.tir.indexdiv(CO, VC), CI, VC)
+        kvshape = (alpha, alpha, CO // bna, CI, bna)
         U = tvm.te.placeholder(kvshape, kernel.dtype, name="U")
     else:
         # transform kernel
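
The task.py hunk replaces an unbounded breadth-first walk of the tensor graph with one that tracks already-queued tensors, so a tensor consumed by several downstream ops (as in the Winograd pipeline, where one intermediate feeds multiple stages) is expanded only once. Below is a minimal, self-contained sketch of that traversal, assuming the tvm.te API; the diamond-shaped graph, the tensor names A through E, and the helper name get_placeholder_inputs are illustrative assumptions, not part of the commit.

# Minimal sketch of the deduplicated BFS (names and graph are hypothetical).
from tvm import te

# A diamond-shaped graph: B feeds both C and D, which merge into E.
A = te.placeholder((16,), name="A")
B = te.compute((16,), lambda i: A[i] * 2, name="B")
C = te.compute((16,), lambda i: B[i] + 1, name="C")
D = te.compute((16,), lambda i: B[i] - 1, name="D")
E = te.compute((16,), lambda i: C[i] + D[i], name="E")

def get_placeholder_inputs(out):
    """BFS from `out` to its placeholder inputs, visiting each tensor once."""
    inputs = []
    queue = [out]
    visited = set()
    while queue:
        t = queue.pop(0)
        if isinstance(t.op, te.PlaceholderOp):
            inputs.append(t)
        else:
            fresh = [x for x in t.op.input_tensors if x not in visited]
            queue.extend(fresh)
            visited.update(fresh)
    return inputs

# Without the visited set, B is expanded twice and A is reported twice;
# with it, each placeholder appears exactly once in the collected inputs.
print(get_placeholder_inputs(E))  # -> a single entry for placeholder A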