[Arm] parallel batch axis (#3931)
authorYizhi Liu <liuyizhi@apache.org>
Wed, 11 Sep 2019 18:10:48 +0000 (02:10 +0800)
committerHaichen Shen <shenhaichen@gmail.com>
Wed, 11 Sep 2019 18:10:47 +0000 (11:10 -0700)
* support LLVM trunk

* guard with USE_LLVM in if condition for c++14

* GREATER_EQUAL -> GREATER

* [Arm] parallel batch axis

topi/python/topi/arm_cpu/conv2d.py

index 77b37ed..55d6bba 100644 (file)
@@ -280,13 +280,15 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
     s[conv].compute_at(s[last], ow)
 
     # mark parallel
-    s[last].parallel(co)
+    p = s[last].fuse(n, co)
+    s[last].parallel(p)
 
     if data_vec.op.name == 'data_vec_undilated':
-        _, h, _, _, _, _, _, _ = s[data_vec].op.axis
+        n, h, _, _, _, _, _, _ = s[data_vec].op.axis
     else:
-        _, h, _, _, _, _ = s[data_vec].op.axis
-    s[data_vec].parallel(h)
+        n, h, _, _, _, _ = s[data_vec].op.axis
+    p = s[data_vec].fuse(n, h)
+    s[data_vec].parallel(p)
 
     if kernel_vec.op.name == 'kernel_vec':
         co, _, _, _, _ = s[kernel_vec].op.axis
@@ -470,8 +472,9 @@ def _schedule_winograd(cfg, s, output, last):
     # output
     n, co, h, w = s[last].op.axis
     co, coi = cfg['tile_k'].apply(s, last, co)
-    s[M].compute_at(s[last], co)
-    s[last].parallel(co)
+    p = s[last].fuse(n, co)
+    s[M].compute_at(s[last], p)
+    s[last].parallel(p)
 
     MM = s.cache_read(M, 'global', [Y])
     m = get_const_int(V.shape[0]) + 1 - 3