[TOPI][ARM] Improve injective schedule (#2801)
author     hlu1 <14827759+hlu1@users.noreply.github.com>
           Mon, 18 Mar 2019 23:16:13 +0000 (16:16 -0700)
committer  Lianmin Zheng <mercy_zheng@sjtu.edu.cn>
           Mon, 18 Mar 2019 23:16:13 +0000 (07:16 +0800)
topi/python/topi/arm_cpu/__init__.py
topi/python/topi/arm_cpu/injective.py [new file with mode: 0755]
topi/tests/python/test_topi_resize.py
topi/tests/python/test_topi_upsampling.py

index 8d78f67ac0b67a094cd3f720aa5c5bcc40cff9ee..3e888de55feca4ada8c408370940019a68d075fc 100644 (file)
@@ -4,3 +4,4 @@ from . import conv2d
 from . import depthwise_conv2d
 from . import conv2d_transpose
 from . import bitserial_conv2d
+from . import injective
diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py
new file mode 100755 (executable)
index 0000000..09ea86c
--- /dev/null
@@ -0,0 +1,37 @@
+# pylint: disable=invalid-name, unused-variable
+"""Schedule for pooling operators"""
+import tvm
+from .. import generic
+
+@generic.schedule_injective.register(["arm_cpu"])
+def schedule_injective(outs):
+    """ARM CPU schedule for injective op.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of injective ops in the format
+          of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    x = outs[0]
+    if list(s[x].op.axis):
+        # vectorize the innermost axis; skipped for 0-d outputs, which have no axis to split
+        (io, ii) = s[x].split(list(s[x].op.axis)[-1], 8)
+        s[x].vectorize(ii)
+    tvm.schedule.AutoInlineInjective(s)
+    if len(s[x].op.axis) >= 4:
+        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
+        s[x].parallel(fused)
+    elif len(s[x].op.axis) >= 3:
+        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
+        s[x].parallel(fused)
+    elif len(s[x].op.axis) >= 2:
+        s[x].parallel(s[x].op.axis[0])
+    return s
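For context, a minimal sketch of how this schedule might be exercised; the
element-wise add workload and the target string are illustrative assumptions,
not part of this change:

    import tvm
    import topi

    # Hypothetical element-wise workload; shape and names are assumptions.
    n = 1024
    A = tvm.placeholder((n,), name="A")
    B = tvm.placeholder((n,), name="B")
    C = topi.add(A, B)

    # "llvm -device=arm_cpu" carries the "arm_cpu" key, so the generic
    # dispatcher picks the schedule registered above: the innermost axis is
    # split by 8 and vectorized, and the outer axes are fused and parallelized.
    with tvm.target.create("llvm -device=arm_cpu"):
        s = topi.generic.schedule_injective([C])

    print(tvm.lower(s, [A, B, C], simple_mode=True))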
index 6926a3a2a73c673047bd0d769ca6cd99ce730b13..80966b15ddbe9a5e1d7002a0d624980b62605b42 100644 (file)
@@ -5,6 +5,8 @@ import topi
 import topi.testing
 import math
 
+from common import get_all_backend
+
 def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False):
 
     if layout == 'NCHW':
@@ -40,7 +42,7 @@ def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, ou
 
         tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3)
 
-    for device in ['llvm', 'cuda', 'vulkan', 'nvptx']:
+    for device in get_all_backend():
         check_device(device)
 
 def test_resize():
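This test (and test_topi_upsampling below) now takes its device list from the
shared helper in the tests' common module instead of a hard-coded list. A
rough sketch of what get_all_backend() is assumed to return; the authoritative
list lives in topi/tests/python/common.py and may differ:

    def get_all_backend():
        """Return target strings for every backend the TOPI tests exercise."""
        return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx',
                'llvm -device=arm_cpu', 'opencl -device=mali']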
index 8b0ba519736a41374532f006ba6c985726e6c872..60f6e5655ffff055bf3b7c9a9c53d69bd19c09f4 100644 (file)
@@ -5,6 +5,8 @@ import topi
 import topi.testing
 import math
 
+from common import get_all_backend
+
 def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="NEAREST_NEIGHBOR"):
 
 
@@ -45,7 +47,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH
 
         tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
 
-    for device in ['llvm', 'cuda', 'vulkan', 'nvptx']:
+    for device in get_all_backend():
         check_device(device)
 
 def test_upsampling():