[AutoTVM][BugFix] Fix autotvm on the conv2d_nchw_winograd.mali operator (#6130)
authorYanming Wang <yanmingwang01@gmail.com>
Fri, 24 Jul 2020 23:00:09 +0000 (16:00 -0700)
committerGitHub <noreply@github.com>
Fri, 24 Jul 2020 23:00:09 +0000 (16:00 -0700)
* [AutoTVM] Fix conv2d_nchw_winograd.mali

* Fix pylint error

Co-authored-by: Yanming Wang <yanmwang@amazon.com>
python/tvm/autotvm/task/task.py
topi/python/topi/mali/conv2d.py

index b7cd6f2..3942599 100644 (file)
@@ -216,19 +216,23 @@ class TaskTemplate(object):
     def _default_func(self, *args, **kwargs):
         assert callable(self.fcompute) and callable(self.fschedule)
         out = self.fcompute(*args, **kwargs)
-        arg_bufs = [out] + self.get_inputs(out)
+        arg_bufs = [out] + self._get_inputs(out)
         s = self.fschedule([out])
         return s, arg_bufs
 
-    def get_inputs(self, out):
+    @staticmethod
+    def _get_inputs(out):
         inputs = []
         queue = [out]
+        hash_set = set()
         while queue:
             t = queue.pop(0)
             if isinstance(t.op, tensor.PlaceholderOp):
                 inputs.append(t)
             else:
-                queue.extend(t.op.input_tensors)
+                input_tensors = [t for t in t.op.input_tensors if t not in hash_set]
+                queue.extend(input_tensors)
+                hash_set.update(input_tensors)
         return inputs
 
 def _register_task_compute(name, func=None):
index ed19326..f2b26ee 100644 (file)
@@ -276,8 +276,7 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til
             [(b*bnb+bb) % nW * m + nu], tvm.tir.const(0, data_pad.dtype)), name='d')
 
     if autotvm.GLOBAL_SCOPE.in_tuning:
-        VC = cfg['tile_k'].size[-1]
-        kvshape = (KH + tile_size - 1, KW + tile_size - 1, tvm.tir.indexdiv(CO, VC), CI, VC)
+        kvshape = (alpha, alpha, CO // bna, CI, bna)
         U = tvm.te.placeholder(kvshape, kernel.dtype, name="U")
     else:
         # transform kernel