Use auto-tuner to improve conv2d_gemm performance (#6117)
author Giuseppe Rossini <giuseppe.rossini@arm.com>
Wed, 26 Aug 2020 06:35:22 +0000 (07:35 +0100)
committer GitHub <noreply@github.com>
Wed, 26 Aug 2020 06:35:22 +0000 (14:35 +0800)
* Use auto-tuner to improve conv2d_gemm performance

The following tuning entities have been introduced (a tuning sketch follows the list):
- Unrolling and vectorizing the input matrix transform
- Reordering the GEMM loops to exploit parallel threads
- Unrolling the `gemm_quantized` intrinsic
- Interleaving multiply and accumulate instructions inside the `gemm_quantized` intrinsic
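
For context, a minimal tuning sketch (not part of this patch; `mod`, `params`, the log path, and the RPC tracker key are placeholders) showing how these knobs would be explored and then applied like any other AutoTVM template on an AArch64 target:

    # Hypothetical driver: extract the quantized conv2d tasks, tune them,
    # then compile with the best records found.
    import tvm
    from tvm import autotvm, relay

    target = "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon"
    log_file = "conv2d_gemm_aarch64.log"  # placeholder log path

    # `mod` and `params` are assumed to come from a quantized (int8/uint8) model.
    tasks = autotvm.task.extract_from_program(mod["main"], params=params, target=target)

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.RPCRunner("aarch64-board",           # placeholder tracker key
                                 host="127.0.0.1", port=9190))

    for task in tasks:
        tuner = autotvm.tuner.XGBTuner(task)
        tuner.tune(n_trial=min(1000, len(task.config_space)),
                   measure_option=measure_option,
                   callbacks=[autotvm.callback.log_to_file(log_file)])

    # Workloads without tuning records fall back to the defaults defined in this patch.
    with autotvm.apply_history_best(log_file):
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build(mod, target=target, params=params)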

Change-Id: Icd3ab005663f78a80672e71ef368f6d0efa4a401

* Rebasing

Change-Id: Id27b6de705b16b93df8e885868961fa0321497be

* Fix python linting

Change-Id: I77d880424c3e7ce9de67c970ddb2cf2a92b52f79

* Fusing the batch dimension into the inner dimensions before parallelizing

Change-Id: Ic58d1138ab96d58d12f5855f0e1044f10d9e6e9b

python/tvm/topi/arm_cpu/conv2d_gemm.py
python/tvm/topi/arm_cpu/conv2d_int8.py
python/tvm/topi/arm_cpu/tensor_intrin.py

python/tvm/topi/arm_cpu/conv2d_gemm.py
index c8e1a5a..62f013a 100644 (file)
 import tvm
 from tvm import te
 from tvm.topi import nn
-from ..util import get_const_tuple
+from tvm.autotvm.task.space import AnnotateEntity, ReorderEntity, OtherOptionEntity
+from ..util import get_const_tuple, get_const_int
 from ..nn.util import get_pad_tuple
-from .tensor_intrin import gemv_quantized, gemv_quantized_impl
+from .tensor_intrin import gemm_quantized, gemm_quantized_impl
 
 def is_aarch64_arm():
     """ Checks whether we are compiling for an AArch64 target. """
@@ -38,15 +39,15 @@ def compute_conv2d_gemm_without_weight_transform(cfg,
     executing GEMM and transforming the output back"""
     batches, IH, IW, IC = get_const_tuple(data.shape)
 
-    KH, KW = kernel_size
-    OC = output_channels
+    KH, KW = get_const_tuple(kernel_size)
+    OC = get_const_int(output_channels)
 
     K_AREA = KH * KW
 
     if isinstance(dilation, int):
         dilation_h = dilation_w = dilation
     else:
-        dilation_h, dilation_w = dilation
+        dilation_h, dilation_w = get_const_tuple(dilation)
 
     dilated_kernel_h = (KH - 1) * dilation_h + 1
     dilated_kernel_w = (KW - 1) * dilation_w + 1
@@ -126,6 +127,28 @@ def compute_conv2d_gemm_without_weight_transform(cfg,
     out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z),
                      name='conv2d_gemm_output')
 
+
+    # Configuration space
+    x, y = cfg.axis(M_padded // 4), cfg.axis(K_padded // 16)
+    cfg.define_reorder('reorder_gemm',
+                       [x, y],
+                       policy='candidate',
+                       candidate=[[x, y],
+                                  [y, x]])
+
+    outer_loop, inner_loop = cfg.axis(4), cfg.axis(16)
+    cfg.define_annotate("A_interleaved_unroll_vec",
+                        [outer_loop, inner_loop],
+                        policy="try_unroll_vec")
+    cfg.define_knob('gemm_quantized_unroll', [True, False])
+    cfg.define_knob('gemm_quantized_interleave', [True, False])
+
+    # Fallback configuration
+    if cfg.is_fallback:
+        cfg['reorder_gemm'] = ReorderEntity([0, 1])
+        cfg['A_interleaved_unroll_vec'] = AnnotateEntity(["unroll", "vec"])
+        cfg['gemm_quantized_unroll'] = OtherOptionEntity(False)
+        cfg['gemm_quantized_interleave'] = OtherOptionEntity(True)
     return out
 
 # Schedules
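
A standalone sketch (not in the patch) of the search space that the four definitions above create; it assumes a bare ConfigSpace stands in for the cfg object AutoTVM passes to the template, and uses 4 and 16 as placeholder extents for M_padded // 4 and K_padded // 16:

    # Rebuild the same four tuning entities on a fresh ConfigSpace to count
    # the candidate configurations the tuner will explore.
    from tvm.autotvm.task.space import ConfigSpace

    cfg = ConfigSpace()
    x, y = cfg.axis(4), cfg.axis(16)   # stand-ins for M_padded // 4, K_padded // 16
    cfg.define_reorder('reorder_gemm', [x, y],
                       policy='candidate', candidate=[[x, y], [y, x]])

    outer_loop, inner_loop = cfg.axis(4), cfg.axis(16)
    cfg.define_annotate('A_interleaved_unroll_vec',
                        [outer_loop, inner_loop], policy='try_unroll_vec')
    cfg.define_knob('gemm_quantized_unroll', [True, False])
    cfg.define_knob('gemm_quantized_interleave', [True, False])

    # Roughly 2 (loop orders) * 9 (none/unroll/vec per annotated axis)
    # * 2 (unroll knob) * 2 (interleave knob) = 72 candidates, assuming
    # 'try_unroll_vec' offers none/unroll/vec for each axis.
    print(len(cfg))
    print(cfg.get(0))   # one concrete ConfigEntity from the space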
@@ -150,15 +173,22 @@ def schedule_conv2d_gemm(cfg, s, out, final_out):
         n_outer, n_inner = s[data_im2col].split(n, 16)
         s[data_im2col].unroll(n_outer)
         s[data_im2col].vectorize(n_inner)
+        b_m_fused = s[data_im2col].fuse(b, m)
+        s[data_im2col].parallel(b_m_fused)
     else:
         s[data_im2col].compute_inline()
 
     # Computation(through tensorize)
     b, xo, yo, xi, yi = C_interleaved.op.axis
-    s[C_interleaved].reorder(xo, yo, yi, xi)
-    s[C_interleaved].parallel(xo)
-    s[A_interleaved].compute_at(s[C_interleaved], xo)
-    s[A_interleaved].vectorize(A_interleaved.op.axis[4])
+    outer_gemm, inner_gemm = cfg['reorder_gemm'].apply(s, C_interleaved, [xo, yo])
+    s[C_interleaved].reorder(yi, xi)
+    b_outer_gemm_fused = s[C_interleaved].fuse(b, outer_gemm)
+    s[C_interleaved].parallel(b_outer_gemm_fused)
+    s[A_interleaved].compute_at(s[C_interleaved], b_outer_gemm_fused)
+    _, _, _, outer_A_interleaved, inner_A_interleaved = A_interleaved.op.axis
+    cfg['A_interleaved_unroll_vec'].apply(s,
+                                          A_interleaved,
+                                          [outer_A_interleaved, inner_A_interleaved])
 
     in_type = A_interleaved.dtype
     out_type = C.dtype
@@ -166,10 +196,16 @@ def schedule_conv2d_gemm(cfg, s, out, final_out):
         K = A_interleaved_input.shape[2]
         _, M, N = C.shape
         assert in_type in ['int8', 'uint8'], "Only int8 and uint8 gemm are supported"
-
-        gem_v_dotprod = gemv_quantized(M, N, K, in_type, out_type)
-        s[C_interleaved].pragma(xo, "import_llvm", gemv_quantized_impl(M, N, in_type))
-        s[C_interleaved].tensorize(yi, gem_v_dotprod)
+        unroll = cfg['gemm_quantized_unroll'].val
+        interleave = cfg['gemm_quantized_interleave'].val
+        gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type)
+        s[C_interleaved].pragma(b_outer_gemm_fused, "import_llvm", gemm_quantized_impl(M,
+                                                                                       N,
+                                                                                       K,
+                                                                                       unroll,
+                                                                                       interleave,
+                                                                                       in_type))
+        s[C_interleaved].tensorize(yi, gemm)
 
     # Output transform
     if out != final_out:
@@ -177,6 +213,4 @@ def schedule_conv2d_gemm(cfg, s, out, final_out):
         _, inner = s[out].split(c, 4)
         s[C].compute_at(s[out], inner)
         s[out].vectorize(inner)
-
-
     return s
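
The fallback entities defined earlier reproduce the schedule as it was before this patch. A small sketch (toy compute, not from the patch) of what ReorderEntity and AnnotateEntity do when applied, mirroring how cfg['reorder_gemm'] and cfg['A_interleaved_unroll_vec'] are applied in schedule_conv2d_gemm:

    # Toy 2-D compute to show how the fallback entities act on a schedule.
    import tvm
    from tvm import te
    from tvm.autotvm.task.space import ReorderEntity, AnnotateEntity

    A = te.placeholder((64, 64), name="A")
    B = te.compute((64, 64), lambda i, j: A[i, j] + 1, name="B")
    s = te.create_schedule(B.op)
    i, j = B.op.axis

    # ReorderEntity([0, 1]) keeps the original loop order, i.e. the order the
    # hand-written schedule used before the reorder became tunable.
    outer, inner = ReorderEntity([0, 1]).apply(s, B, [i, j])

    # AnnotateEntity(["unroll", "vec"]) unrolls the first axis and vectorizes
    # the second, like the old hard-coded unroll/vectorize pair on A_interleaved.
    AnnotateEntity(["unroll", "vec"]).apply(s, B, [outer, inner])

    print(tvm.lower(s, [A, B], simple_mode=True))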
python/tvm/topi/arm_cpu/conv2d_int8.py
index 89a37fa..9a6e8cc 100644 (file)
@@ -140,8 +140,10 @@ def schedule_conv2d_NHWC_quantized(cfg, outs):
     # Vectorize the output and then inline all the rest
     out = outs[0]
     n, h, w, c = out.op.axis
+    n_h_fused = s[out].fuse(n, h)
     outer, inner = s[out].split(c, 4)
     s[out].vectorize(inner)
+    s[out].parallel(n_h_fused)
 
     def _callback(op):
         """Traverse operators from computation graph"""
python/tvm/topi/arm_cpu/tensor_intrin.py
index 270bfbe..52e67ad 100644 (file)
@@ -21,7 +21,186 @@ import tvm
 from tvm import te
 from tvm.contrib import util, clang
 
-def gemv_quantized_impl(M, N, data_type='uint8'):
+def gemm_quantized_4_4_batched():
+    return """
+           // First half
+           // Higher part of a0 * {b0,b1,b2,b3}
+           "umull v8.8h, v0.8b, v4.8b\\n"
+           "umull v9.8h, v0.8b, v5.8b\\n"
+           "umull v10.8h, v0.8b, v6.8b\\n"
+           "umull v11.8h, v0.8b, v7.8b\\n"
+
+           // Higher part of a1 * {b0,b1,b2,b3}
+           "umull v12.8h, v1.8b, v4.8b\\n"
+           "umull v13.8h, v1.8b, v5.8b\\n"
+           "umull v14.8h, v1.8b, v6.8b\\n"
+           "umull v15.8h, v1.8b, v7.8b\\n"
+
+           // Accumulate
+           "uadalp v16.4s, v8.8h\\n"
+           "uadalp v17.4s, v9.8h\\n"
+           "uadalp v18.4s, v10.8h\\n"
+           "uadalp v19.4s, v11.8h\\n"
+           "uadalp v20.4s, v12.8h\\n"
+           "uadalp v21.4s, v13.8h\\n"
+           "uadalp v22.4s, v14.8h\\n"
+           "uadalp v23.4s, v15.8h\\n"
+
+           // Lower part of a0 * {b0,b1,b2,b3}
+           "umull2 v8.8h, v0.16b, v4.16b\\n"
+           "umull2 v9.8h, v0.16b, v5.16b\\n"
+           "umull2 v10.8h, v0.16b, v6.16b\\n"
+           "umull2 v11.8h, v0.16b, v7.16b\\n"
+
+           // Lower part of a1 * {b0,b1,b2,b3}
+           "umull2 v12.8h, v1.16b, v4.16b\\n"
+           "umull2 v13.8h, v1.16b, v5.16b\\n"
+           "umull2 v14.8h, v1.16b, v6.16b\\n"
+           "umull2 v15.8h, v1.16b, v7.16b\\n"
+
+            // Accumulate again
+           "uadalp v16.4s, v8.8h\\n"
+           "uadalp v17.4s, v9.8h\\n"
+           "uadalp v18.4s, v10.8h\\n"
+           "uadalp v19.4s, v11.8h\\n"
+           "uadalp v20.4s, v12.8h\\n"
+           "uadalp v21.4s, v13.8h\\n"
+           "uadalp v22.4s, v14.8h\\n"
+           "uadalp v23.4s, v15.8h\\n"
+
+           // Second half
+           // Lower part of a2 * {b0,b1,b2,b3}
+           "umull v8.8h, v2.8b, v4.8b\\n"
+           "umull v9.8h, v2.8b, v5.8b\\n"
+           "umull v10.8h, v2.8b, v6.8b\\n"
+           "umull v11.8h, v2.8b, v7.8b\\n"
+
+           // Lower part of a3 * {b0,b1,b2,b3}
+           "umull v12.8h, v3.8b, v4.8b\\n"
+           "umull v13.8h, v3.8b, v5.8b\\n"
+           "umull v14.8h, v3.8b, v6.8b\\n"
+           "umull v15.8h, v3.8b, v7.8b\\n"
+
+           // Accumulate
+           "uadalp v24.4s, v8.8h\\n"
+           "uadalp v25.4s, v9.8h\\n"
+           "uadalp v26.4s, v10.8h\\n"
+           "uadalp v27.4s, v11.8h\\n"
+           "uadalp v28.4s, v12.8h\\n"
+           "uadalp v29.4s, v13.8h\\n"
+           "uadalp v30.4s, v14.8h\\n"
+           "uadalp v31.4s, v15.8h\\n"
+
+           // Higher part of a2 * {b0,b1,b2,b3}
+           "umull2 v8.8h, v2.16b, v4.16b\\n"
+           "umull2 v9.8h, v2.16b, v5.16b\\n"
+           "umull2 v10.8h, v2.16b, v6.16b\\n"
+           "umull2 v11.8h, v2.16b, v7.16b\\n"
+
+           // Higher part of a3 * {b0,b1,b2,b3}
+           "umull2 v12.8h, v3.16b, v4.16b\\n"
+           "umull2 v13.8h, v3.16b, v5.16b\\n"
+           "umull2 v14.8h, v3.16b, v6.16b\\n"
+           "umull2 v15.8h, v3.16b, v7.16b\\n"
+
+           // Accumulate again
+           "uadalp v24.4s, v8.8h\\n"
+           "uadalp v25.4s, v9.8h\\n"
+           "uadalp v26.4s, v10.8h\\n"
+           "uadalp v27.4s, v11.8h\\n"
+           "uadalp v28.4s, v12.8h\\n"
+           "uadalp v29.4s, v13.8h\\n"
+           "uadalp v30.4s, v14.8h\\n"
+           "uadalp v31.4s, v15.8h\\n"
+    """
+
+def gemm_quantized_4_4_interleaved():
+    return """
+             // First half
+             // Higher part of a0 * {b0,b1,b2,b3} and accumulate
+             "umull v8.8h, v0.8b, v4.8b\\n"
+             "uadalp v16.4s, v8.8h\\n"
+             "umull v9.8h, v0.8b, v5.8b\\n"
+             "uadalp v17.4s, v9.8h\\n"
+             "umull v10.8h, v0.8b, v6.8b\\n"
+             "uadalp v18.4s, v10.8h\\n"
+             "umull v11.8h, v0.8b, v7.8b\\n"
+             "uadalp v19.4s, v11.8h\\n"
+
+             // Higher part of a1 * {b0,b1,b2,b3} and accumulate
+             "umull v12.8h, v1.8b, v4.8b\\n"
+             "uadalp v20.4s, v12.8h\\n"
+             "umull v13.8h, v1.8b, v5.8b\\n"
+             "uadalp v21.4s, v13.8h\\n"
+             "umull v14.8h, v1.8b, v6.8b\\n"
+             "uadalp v22.4s, v14.8h\\n"
+             "umull v15.8h, v1.8b, v7.8b\\n"
+             "uadalp v23.4s, v15.8h\\n"
+
+             // Lower part of a0 * {b0,b1,b2,b3} and accumulate
+             "umull2 v8.8h, v0.16b, v4.16b\\n"
+             "uadalp v16.4s, v8.8h\\n"
+             "umull2 v9.8h, v0.16b, v5.16b\\n"
+             "uadalp v17.4s, v9.8h\\n"
+             "umull2 v10.8h, v0.16b, v6.16b\\n"
+             "uadalp v18.4s, v10.8h\\n"
+             "umull2 v11.8h, v0.16b, v7.16b\\n"
+             "uadalp v19.4s, v11.8h\\n"
+
+             // Lower part of a1 * {b0,b1,b2,b3} and accumulate
+             "umull2 v12.8h, v1.16b, v4.16b\\n"
+             "uadalp v20.4s, v12.8h\\n"
+             "umull2 v13.8h, v1.16b, v5.16b\\n"
+             "uadalp v21.4s, v13.8h\\n"
+             "umull2 v14.8h, v1.16b, v6.16b\\n"
+             "uadalp v22.4s, v14.8h\\n"
+             "umull2 v15.8h, v1.16b, v7.16b\\n"
+             "uadalp v23.4s, v15.8h\\n"
+
+             // Second half
+             // Higher part of a2 * {b0,b1,b2,b3} and accumulate
+             "umull v8.8h, v2.8b, v4.8b\\n"
+             "uadalp v24.4s, v8.8h\\n"
+             "umull v9.8h, v2.8b, v5.8b\\n"
+             "uadalp v25.4s, v9.8h\\n"
+             "umull v10.8h, v2.8b, v6.8b\\n"
+             "uadalp v26.4s, v10.8h\\n"
+             "umull v11.8h, v2.8b, v7.8b\\n"
+             "uadalp v27.4s, v11.8h\\n"
+
+             // Higher part of a3 * {b0,b1,b2,b3} and accumulate
+             "umull v12.8h, v3.8b, v4.8b\\n"
+             "uadalp v28.4s, v12.8h\\n"
+             "umull v13.8h, v3.8b, v5.8b\\n"
+             "uadalp v29.4s, v13.8h\\n"
+             "umull v14.8h, v3.8b, v6.8b\\n"
+             "uadalp v30.4s, v14.8h\\n"
+             "umull v15.8h, v3.8b, v7.8b\\n"
+             "uadalp v31.4s, v15.8h\\n"
+
+             // Lower part of a2 * {b0,b1,b2,b3} and accumulate
+             "umull2 v8.8h, v2.16b, v4.16b\\n"
+             "uadalp v24.4s, v8.8h\\n"
+             "umull2 v9.8h, v2.16b, v5.16b\\n"
+             "uadalp v25.4s, v9.8h\\n"
+             "umull2 v10.8h, v2.16b, v6.16b\\n"
+             "uadalp v26.4s, v10.8h\\n"
+             "umull2 v11.8h, v2.16b, v7.16b\\n"
+             "uadalp v27.4s, v11.8h\\n"
+
+             // Lower part of a3 * {b0,b1,b2,b3} and accumulate
+             "umull2 v12.8h, v3.16b, v4.16b\\n"
+             "uadalp v28.4s, v12.8h\\n"
+             "umull2 v13.8h, v3.16b, v5.16b\\n"
+             "uadalp v29.4s, v13.8h\\n"
+             "umull2 v14.8h, v3.16b, v6.16b\\n"
+             "uadalp v30.4s, v14.8h\\n"
+             "umull2 v15.8h, v3.16b, v7.16b\\n"
+             "uadalp v31.4s, v15.8h\\n"
+    """
+
+
+def gemm_quantized_impl(M, N, K, unroll, interleave, data_type='uint8'):
     """ Assembly implementation of a blocked gemv. Given
     a block a of shape (4, k) and a block b' of shape (4, k)
     produces the output block c = a*b of shape (4,4) """
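
To make the instruction sequences above easier to follow: umull multiplies the low 8 bytes of two vector registers into eight 16-bit products, umull2 does the same for the high 8 bytes, and uadalp pairwise-adds adjacent 16-bit lanes into 32-bit accumulators. A NumPy sketch (not part of the patch) modelling one K=16 step of the 4x4 block and checking it against a plain a @ b.T; for larger K the assembly repeats this step, either as a counted loop or, when the unroll knob is set, as K // 16 unrolled copies:

    import numpy as np

    def umull(a, b):
        # Widening multiply of the low 8 bytes: uint8 x uint8 -> eight uint16 lanes.
        return a[:8].astype(np.uint16) * b[:8].astype(np.uint16)

    def umull2(a, b):
        # Widening multiply of the high 8 bytes.
        return a[8:].astype(np.uint16) * b[8:].astype(np.uint16)

    def uadalp(acc, x):
        # Pairwise add adjacent 16-bit lanes and accumulate into four 32-bit lanes.
        return acc + x.reshape(4, 2).sum(axis=1, dtype=np.uint32)

    rng = np.random.default_rng(0)
    a = rng.integers(0, 256, size=(4, 16), dtype=np.uint8)  # 4x16 block of A (v0..v3)
    b = rng.integers(0, 256, size=(4, 16), dtype=np.uint8)  # 4x16 block of B' (v4..v7)

    # v16..v31 each hold four 32-bit partial sums for one (i, j) output element.
    acc = np.zeros((4, 4, 4), dtype=np.uint32)
    for i in range(4):
        for j in range(4):
            acc[i, j] = uadalp(acc[i, j], umull(a[i], b[j]))
            acc[i, j] = uadalp(acc[i, j], umull2(a[i], b[j]))

    # The "Final additions" (addp) in the assembly collapse the four partial
    # sums of each accumulator into one int32 output element.
    c = acc.sum(axis=2).astype(np.int32)
    assert (c == a.astype(np.int32) @ b.astype(np.int32).T).all()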
@@ -30,13 +209,21 @@ def gemv_quantized_impl(M, N, data_type='uint8'):
     stepB = min(4, N)
     assert data_type in ['uint8', 'int8'], 'Only uint8/int8 supported for this implementation'
 
-    cc_code = """
-          extern "C" int gemv_{0}_{0}_int32_{1}_{2}(int *c_buffer,
-                                                    unsigned char *a_buffer,
-                                                    unsigned char *b_buffer,
-                                                    int K, int m, int n)
-              """.format(data_type, stepA, stepB)
+    signature = """extern "C" int gemm_quantized_{0}_{0}_int32_{1}_{2}""".format(data_type,
+                                                                                 stepA,
+                                                                                 stepB)
+    if unroll:
+        signature += ("_" + str(K))
 
+    if interleave:
+        signature += ("_interleaved")
+
+    signature += """(int *c_buffer,
+                      unsigned char *a_buffer,
+                      unsigned char *b_buffer,
+                      int K, int m, int n)"""
+
+    cc_code = signature
     cc_code += """
     {
             unsigned char * a_ptr = a_buffer;
@@ -65,141 +252,58 @@ def gemv_quantized_impl(M, N, data_type='uint8'):
             "1:"
     """
 
-    cc_code += ' "ldr q0, [%[a_ptr]]\\n" '
+    main_loop = ' "ldr q0, [%[a_ptr]]\\n" '
 
     if M > 1:
-        cc_code += ' "ldr q1, [%[a_ptr], #16]\\n" '
+        main_loop += ' "ldr q1, [%[a_ptr], #16]\\n" '
     else:
-        cc_code += ' "movi v1.4s, #0\\n" '
+        main_loop += ' "movi v1.4s, #0\\n" '
 
     if M > 2:
-        cc_code += ' "ldr q2, [%[a_ptr], #32]\\n" '
+        main_loop += ' "ldr q2, [%[a_ptr], #32]\\n" '
     else:
-        cc_code += ' "movi v2.4s, #0\\n" '
+        main_loop += ' "movi v2.4s, #0\\n" '
 
     if M > 3:
-        cc_code += ' "ldr q3, [%[a_ptr], #48]\\n" '
+        main_loop += ' "ldr q3, [%[a_ptr], #48]\\n" '
     else:
-        cc_code += ' "movi v3.4s, #0\\n" '
+        main_loop += ' "movi v3.4s, #0\\n" '
 
-    cc_code += ' "ldr q4, [%[b_ptr]]\\n" '
+    main_loop += ' "ldr q4, [%[b_ptr]]\\n" '
 
     if N > 1:
-        cc_code += ' "ldr q5, [%[b_ptr], #16]\\n" '
+        main_loop += ' "ldr q5, [%[b_ptr], #16]\\n" '
 
     if N > 2:
-        cc_code += ' "ldr q6, [%[b_ptr], #32]\\n" '
+        main_loop += ' "ldr q6, [%[b_ptr], #32]\\n" '
 
     if N > 3:
-        cc_code += ' "ldr q7, [%[b_ptr], #48]\\n" '
+        main_loop += ' "ldr q7, [%[b_ptr], #48]\\n" '
+
+    # Main computation can interleave multiply/accumulate instructions
+    # or schedule them in batches (first all multiplies then all accumulates)
+    if interleave:
+        main_loop += gemm_quantized_4_4_interleaved()
+    else:
+        main_loop += gemm_quantized_4_4_batched()
 
-    cc_code += """
-                // First half
-                // Higher part of a0 * {b0,b1,b2,b3}
-                "umull v8.8h, v0.8b, v4.8b\\n"
-                "umull v9.8h, v0.8b, v5.8b\\n"
-                "umull v10.8h, v0.8b, v6.8b\\n"
-                "umull v11.8h, v0.8b, v7.8b\\n"
-
-                // Higher part of a1 * {b0,b1,b2,b3}
-                "umull v12.8h, v1.8b, v4.8b\\n"
-                "umull v13.8h, v1.8b, v5.8b\\n"
-                "umull v14.8h, v1.8b, v6.8b\\n"
-                "umull v15.8h, v1.8b, v7.8b\\n"
-
-                // Accumulate
-                "uadalp v16.4s, v8.8h\\n"
-                "uadalp v17.4s, v9.8h\\n"
-                "uadalp v18.4s, v10.8h\\n"
-                "uadalp v19.4s, v11.8h\\n"
-                "uadalp v20.4s, v12.8h\\n"
-                "uadalp v21.4s, v13.8h\\n"
-                "uadalp v22.4s, v14.8h\\n"
-                "uadalp v23.4s, v15.8h\\n"
-
-                // Lower part of a0 * {b0,b1,b2,b3}
-                "umull2 v8.8h, v0.16b, v4.16b\\n"
-                "umull2 v9.8h, v0.16b, v5.16b\\n"
-                "umull2 v10.8h, v0.16b, v6.16b\\n"
-                "umull2 v11.8h, v0.16b, v7.16b\\n"
-
-                // Lower part of a1 * {b0,b1,b2,b3}
-                "umull2 v12.8h, v1.16b, v4.16b\\n"
-                "umull2 v13.8h, v1.16b, v5.16b\\n"
-                "umull2 v14.8h, v1.16b, v6.16b\\n"
-                "umull2 v15.8h, v1.16b, v7.16b\\n"
-
-                 // Accumulate again
-                "uadalp v16.4s, v8.8h\\n"
-                "uadalp v17.4s, v9.8h\\n"
-                "uadalp v18.4s, v10.8h\\n"
-                "uadalp v19.4s, v11.8h\\n"
-                "uadalp v20.4s, v12.8h\\n"
-                "uadalp v21.4s, v13.8h\\n"
-                "uadalp v22.4s, v14.8h\\n"
-                "uadalp v23.4s, v15.8h\\n"
-
-                // Second half
-
-                // Lower part of a2 * {b0,b1,b2,b3}
-                "umull v8.8h, v2.8b, v4.8b\\n"
-                "umull v9.8h, v2.8b, v5.8b\\n"
-                "umull v10.8h, v2.8b, v6.8b\\n"
-                "umull v11.8h, v2.8b, v7.8b\\n"
-
-                // Lower part of a3 * {b0,b1,b2,b3}
-                "umull v12.8h, v3.8b, v4.8b\\n"
-                "umull v13.8h, v3.8b, v5.8b\\n"
-                "umull v14.8h, v3.8b, v6.8b\\n"
-                "umull v15.8h, v3.8b, v7.8b\\n"
-
-                // Accumulate
-                "uadalp v24.4s, v8.8h\\n"
-                "uadalp v25.4s, v9.8h\\n"
-                "uadalp v26.4s, v10.8h\\n"
-                "uadalp v27.4s, v11.8h\\n"
-                "uadalp v28.4s, v12.8h\\n"
-                "uadalp v29.4s, v13.8h\\n"
-                "uadalp v30.4s, v14.8h\\n"
-                "uadalp v31.4s, v15.8h\\n"
-
-                // Higher part of a2 * {b0,b1,b2,b3}
-                "umull2 v8.8h, v2.16b, v4.16b\\n"
-                "umull2 v9.8h, v2.16b, v5.16b\\n"
-                "umull2 v10.8h, v2.16b, v6.16b\\n"
-                "umull2 v11.8h, v2.16b, v7.16b\\n"
-
-                // Higher part of a3 * {b0,b1,b2,b3}
-                "umull2 v12.8h, v3.16b, v4.16b\\n"
-                "umull2 v13.8h, v3.16b, v5.16b\\n"
-                "umull2 v14.8h, v3.16b, v6.16b\\n"
-                "umull2 v15.8h, v3.16b, v7.16b\\n"
-
-                // Accumulate again
-                "uadalp v24.4s, v8.8h\\n"
-                "uadalp v25.4s, v9.8h\\n"
-                "uadalp v26.4s, v10.8h\\n"
-                "uadalp v27.4s, v11.8h\\n"
-                "uadalp v28.4s, v12.8h\\n"
-                "uadalp v29.4s, v13.8h\\n"
-                "uadalp v30.4s, v14.8h\\n"
-                "uadalp v31.4s, v15.8h\\n"
-    """
     blockA = min(64, M * 16)
     blockB = min(64, N * 16)
-
-    cc_code += """
-                // Increment pointers and decrement k
-                "add %[a_ptr], %[a_ptr], #{0}\\n"
-                "add %[b_ptr], %[b_ptr], #{1}\\n"
-                "subs %w[k], %w[k], #1\\n"
-    """.format(blockA, blockB)
-
-    stepC = min(4, N)
-
+    main_loop += """// Increment pointers
+                    "add %[a_ptr], %[a_ptr], #{0}\\n"
+                    "add %[b_ptr], %[b_ptr], #{1}\\n" """.format(blockA, blockB)
+
+    if unroll:
+        k = int(K//16)
+        for l in range(0, k):
+            cc_code += main_loop
+    else:
+        cc_code += main_loop
+        cc_code += """
+                    "subs %w[k], %w[k], #1\\n"
+                    "cbnz %w[k], 1b\\n"
+                   """
     cc_code += """
-                "cbnz %w[k], 1b\\n"
-
                 // Final additions
 
                 // v16 contains the four partial sums of a[0, 0:K].*b[0,0:K], let's call them (a,b,c,d)
@@ -237,6 +341,7 @@ def gemv_quantized_impl(M, N, data_type='uint8'):
                 "str q16, [%[c_ptr]]\\n"
             """
 
+    stepC = min(4, N)
     if M > 1:
         cc_code += ' "str q20, [%[c_ptr], #{0}]\\n" '.format(stepC * 4)
 
@@ -272,7 +377,7 @@ def gemv_quantized_impl(M, N, data_type='uint8'):
     return ll_code
 
 
-def gemv_quantized(M, N, K, in_type, out_type):
+def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
     """
     Use integer ARM v8 instructions in order to produce a block c of 4x4 elements
     given two 4xK blocks a and b' (where b' is a Kx4 block transposed). The final
@@ -331,23 +436,17 @@ def gemv_quantized(M, N, K, in_type, out_type):
             cc = outs[0]
             stepA = min(4, M)
             stepB = min(4, N)
-
-            if in_type == 'int8':
-                ib.emit(tvm.tir.call_extern("int32",
-                                            "gemv_int8_int8_int32_{0}_{1}".format(stepA, stepB),
-                                            outs[0].access_ptr("w"),
-                                            a_buffer.access_ptr("r"),
-                                            b_buffer.access_ptr("r"),
-                                            K))
-            else:
-                ib.emit(tvm.tir.call_extern("int32",
-                                            "gemv_uint8_uint8_int32_{0}_{1}".format(stepA, stepB),
-                                            c_buffer.access_ptr("w"),
-                                            a_buffer.access_ptr("r"),
-                                            b_buffer.access_ptr("r"),
-                                            K,
-                                            C.shape[0],  # m, very useful for debug
-                                            C.shape[1]))  # n, very useful for debug
+            intrin_name = "gemm_quantized_{0}_{0}_int32_{1}_{2}".format(in_type, stepA, stepB)
+            if unroll:
+                intrin_name += ("_" + str(K))
+            if interleave:
+                intrin_name += "_interleaved"
+            ib.emit(tvm.tir.call_extern("int32",
+                                        intrin_name,
+                                        outs[0].access_ptr("w"),
+                                        a_buffer.access_ptr("r"),
+                                        b_buffer.access_ptr("r"),
+                                        K))
             return ib.get()
 
         # body, reset, update
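
Both gemm_quantized_impl (which defines the C/assembly function) and gemm_quantized (which emits the call through tvm.tir.call_extern) derive the extern symbol from the same (dtype, stepA, stepB, unroll, interleave) tuple, so the two names always match. A small sketch (hypothetical helper, not in the patch) reproducing that naming rule:

    # Hypothetical helper mirroring the symbol-naming logic used above.
    def extern_symbol(M, N, K, unroll, interleave, in_type="uint8"):
        step_a, step_b = min(4, M), min(4, N)
        name = "gemm_quantized_{0}_{0}_int32_{1}_{2}".format(in_type, step_a, step_b)
        if unroll:
            name += "_" + str(K)
        if interleave:
            name += "_interleaved"
        return name

    print(extern_symbol(4, 4, 256, unroll=True, interleave=True))
    # -> gemm_quantized_uint8_uint8_int32_4_4_256_interleaved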