[ARM] Fix concat (#3061)
author hlu1 <14827759+hlu1@users.noreply.github.com>
Fri, 17 May 2019 17:29:07 +0000 (10:29 -0700)
committer Tianqi Chen <tqchen@users.noreply.github.com>
Fri, 17 May 2019 17:29:07 +0000 (10:29 -0700)
python/tvm/relay/op/_transform.py
python/tvm/relay/op/op.py
topi/python/topi/arm_cpu/injective.py
topi/tests/python/test_topi_transform.py

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 2eec6d0..95fb2ad 100644
@@ -23,6 +23,7 @@ from .op import schedule_injective, OpPattern
 
 schedule_injective = _reg.schedule_injective
 schedule_broadcast = _reg.schedule_injective
+schedule_concatenate = _reg.schedule_concatenate
 
 
 _reg.register_schedule("collapse_sum_like", _schedule_reduce)
@@ -46,7 +47,7 @@ _reg.register_schedule("take", schedule_injective)
 _reg.register_schedule("transpose", schedule_injective)
 _reg.register_schedule("where", schedule_broadcast)
 _reg.register_schedule("stack", schedule_injective)
-_reg.register_schedule("concatenate", schedule_injective)
+_reg.register_schedule("concatenate", schedule_concatenate)
 _reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
 _reg.register_schedule("gather_nd", schedule_injective)
 
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index 6ba2079..906bf25 100644
@@ -219,6 +219,13 @@ def schedule_injective(attrs, outputs, target):
     with target:
         return topi.generic.schedule_injective(outputs)
 
+
+def schedule_concatenate(attrs, outputs, target):
+    """Generic schedule for concatinate."""
+    with target:
+        return topi.generic.schedule_concatenate(outputs)
+
+
 __DEBUG_COUNTER__ = 0
 
 def debug(expr, debug_func=None):
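The wrapper mirrors schedule_injective above it: the `with target:` block is what routes the call, since a TOPI generic function resolves to the implementation registered for the current target and falls back to the generic one otherwise. A minimal sketch of that dispatch, assuming 0.6-era APIs; the placeholder shapes are arbitrary:

import tvm
import topi

a = tvm.placeholder((1, 640), name="a")
b = tvm.placeholder((1, 240), name="b")
c = topi.concatenate([a, b], axis=1)

# inside an arm_cpu target context the call resolves to the arm_cpu
# override (added in topi/python/topi/arm_cpu/injective.py below);
# under a plain llvm target it falls back to the generic schedule
with tvm.target.create("llvm -device=arm_cpu"):
    s_arm = topi.generic.schedule_concatenate(c)
with tvm.target.create("llvm"):
    s_generic = topi.generic.schedule_concatenate(c)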
diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py
index 9afdc32..028558f 100644
@@ -51,3 +51,32 @@ def schedule_injective(outs):
     elif len(s[x].op.axis) >= 2:
         s[x].parallel(s[x].op.axis[0])
     return s
+
+@generic.schedule_concatenate.register(["arm_cpu"])
+def schedule_concatenate(outs):
+    """Schedule for concatenate op.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of concatenate in the format
+          of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    x = outs[0]
+    tvm.schedule.AutoInlineInjective(s)
+    if len(s[x].op.axis) >= 4:
+        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
+        s[x].parallel(fused)
+    elif len(s[x].op.axis) >= 3:
+        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
+        s[x].parallel(fused)
+    elif len(s[x].op.axis) >= 2:
+        s[x].parallel(s[x].op.axis[0])
+    return s
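The new schedule only fuses and parallelizes the loops outside the concatenation; notably, it leaves the innermost axis untouched, since its extent differs across the concatenated inputs. A self-contained usage sketch that mirrors the test change below (0.6-era APIs; the shapes here are illustrative):

import numpy as np
import tvm
import topi

# two 4-D inputs concatenated along the last axis; the arm_cpu schedule
# fuses the three outer axes into one parallel loop (the len >= 4 branch)
shapes = [(2, 6, 7, 3), (2, 6, 7, 4)]
xs = [tvm.placeholder(shape, name="x%d" % i) for i, shape in enumerate(shapes)]
out = topi.concatenate(xs, axis=3)

with tvm.target.create("llvm -device=arm_cpu"):
    s = topi.generic.schedule_concatenate(out)

f = tvm.build(s, xs + [out], "llvm -device=arm_cpu", name="concatenate")
ctx = tvm.cpu(0)
args = [tvm.nd.array(np.random.uniform(size=shape).astype("float32"), ctx)
        for shape in shapes]
out_nd = tvm.nd.empty((2, 6, 7, 7), ctx=ctx)
f(*(args + [out_nd]))
np.testing.assert_allclose(
    out_nd.asnumpy(), np.concatenate([a.asnumpy() for a in args], axis=3))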
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index a078eac..d29fb64 100644
@@ -127,7 +127,7 @@ def verify_concatenate(shapes, axis):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(out_tensor)
+            s = topi.generic.schedule_concatenate(out_tensor)
 
         foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
         data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
@@ -476,6 +476,7 @@ def test_concatenate():
                         (12, 6, 7, 3),
                         (8, 6, 7, 3),
                         (2, 6, 7, 3)], 0)
+    verify_concatenate([(1, 14400), (1, 2400), (1, 640), (1, 240)], 1)
 
 
 def test_stack():