schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective
+schedule_concatenate = _reg.schedule_concatenate
_reg.register_schedule("collapse_sum_like", _schedule_reduce)
_reg.register_schedule("transpose", schedule_injective)
_reg.register_schedule("where", schedule_broadcast)
_reg.register_schedule("stack", schedule_injective)
-_reg.register_schedule("concatenate", schedule_injective)
+_reg.register_schedule("concatenate", schedule_concatenate)
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective)
with target:
return topi.generic.schedule_injective(outputs)
+
+def schedule_concatenate(attrs, outputs, target):
+ """Generic schedule for concatinate."""
+ with target:
+ return topi.generic.schedule_concatenate(outputs)
+
+
__DEBUG_COUNTER__ = 0
def debug(expr, debug_func=None):
elif len(s[x].op.axis) >= 2:
s[x].parallel(s[x].op.axis[0])
return s
+
+@generic.schedule_concatenate.register(["arm_cpu"])
+def schedule_concatenate(outs):
+ """Schedule for concatenate op.
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of concatenate in the format
+ of an array of tensors.
+
+ Returns
+ -------
+ sch: Schedule
+ The computation schedule for the op.
+ """
+ outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+ s = tvm.create_schedule([x.op for x in outs])
+ x = outs[0]
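+ # Inline all injective (elementwise/broadcast) stages into their consumers.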
+ tvm.schedule.AutoInlineInjective(s)
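+ # Concatenate is memory-bound: fuse only the outer axes and parallelize
+ # over the fused axis, leaving the innermost axis intact so each copy
+ # stays contiguous.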
+ if len(s[x].op.axis) >= 4:
+ fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
+ s[x].parallel(fused)
+ elif len(s[x].op.axis) >= 3:
+ fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
+ s[x].parallel(fused)
+ elif len(s[x].op.axis) >= 2:
+ s[x].parallel(s[x].op.axis[0])
+ return s
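+
+# Illustrative usage sketch, not part of this patch: the schedule above is
+# reached through the generic dispatcher once the target is "arm_cpu". The
+# shapes and target string below are assumptions for the example.
+#
+#   a = tvm.placeholder((1, 64, 32, 32), name="a")
+#   b = tvm.placeholder((1, 64, 32, 32), name="b")
+#   out = topi.concatenate([a, b], axis=1)
+#   with tvm.target.create("llvm -device=arm_cpu"):
+#       s = topi.generic.schedule_concatenate([out])
+#   f = tvm.build(s, [a, b, out], "llvm -device=arm_cpu", name="concat")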
return
print("Running on target: %s" % device)
with tvm.target.create(device):
- s = topi.generic.schedule_injective(out_tensor)
+ s = topi.generic.schedule_concatenate(out_tensor)
foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
(12, 6, 7, 3),
(8, 6, 7, 3),
(2, 6, 7, 3)], 0)
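+ # Wide 2-D concatenate along axis 1: covers the two-axis branch of the new schedule.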
+ verify_concatenate([(1, 14400), (1, 2400), (1, 640), (1, 240)], 1)
def test_stack():