def _default_func(self, *args, **kwargs):
assert callable(self.fcompute) and callable(self.fschedule)
out = self.fcompute(*args, **kwargs)
- arg_bufs = [out] + self.get_inputs(out)
+ arg_bufs = [out] + self._get_inputs(out)
s = self.fschedule([out])
return s, arg_bufs
- def get_inputs(self, out):
+ @staticmethod
+ def _get_inputs(out):
inputs = []
queue = [out]
+ hash_set = set()
while queue:
t = queue.pop(0)
if isinstance(t.op, tensor.PlaceholderOp):
inputs.append(t)
else:
- queue.extend(t.op.input_tensors)
+ input_tensors = [t for t in t.op.input_tensors if t not in hash_set]
+ queue.extend(input_tensors)
+ hash_set.update(input_tensors)
return inputs
def _register_task_compute(name, func=None):
[(b*bnb+bb) % nW * m + nu], tvm.tir.const(0, data_pad.dtype)), name='d')
if autotvm.GLOBAL_SCOPE.in_tuning:
- VC = cfg['tile_k'].size[-1]
- kvshape = (KH + tile_size - 1, KW + tile_size - 1, tvm.tir.indexdiv(CO, VC), CI, VC)
+ kvshape = (alpha, alpha, CO // bna, CI, bna)
U = tvm.te.placeholder(kvshape, kernel.dtype, name="U")
else:
# transform kernel