sve zgemm kernel
author Bine Brank <binebrank@gmail.com>
Sun, 26 Dec 2021 07:44:05 +0000 (08:44 +0100)
committer Bine Brank <binebrank@gmail.com>
Sun, 26 Dec 2021 07:44:05 +0000 (08:44 +0100)
kernel/arm64/zgemm_kernel_sve_v1x4.S

index 0fc966f8cce56d41a34e8498e5b1747fdae0f281..1201d6dac373efe4ba3592521b65ad3533d9e0de 100644 (file)
@@ -48,6 +48,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define pCRow2         x14
 #define pCRow3         x15
 #define pA             x16
+#define lanes          x17
+
 #define alphaR         x19
 #define alphaI         x20
 
@@ -168,7 +170,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 .macro KERNELv1x4_I
        ld2d    {z0.d, z1.d}, p1/z, [pA]
-       ld2d    {z2.d, z3.d}, p1/z, [pA, lanes, lsl #4] // next one
+       ld2d    {z2.d, z3.d}, p1/z, [pA, #2, mul vl] // next one
        add     pA, pA, lanes, lsl #5    // pA += lanes*2*2*8
 
     ld1rd  z8.d, p0/z,  [pB]
@@ -561,17 +563,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
        prfm    PLDL1KEEP, [origPA]
 
        fmov    alphaR, d0
+       dup         alphaz_R, alphaR
        fmov    alphaI, d1
+       dup         alphaz_I, alphaI
 
        lsl     LDC, LDC, #4                    // ldc = ldc * 2 * 8
+    ptrue p0.d                  // create true predicate 
 
        mov     pB, origPB
 
+// Loop over N
        mov     counterJ, origN
        asr     counterJ, counterJ, #2          // J = J / 4
        cmp     counterJ, #0
        ble     .Lzgemm_kernel_L2_BEGIN
 
+/******************************************************************************/
 .Lzgemm_kernel_L4_BEGIN:
        mov     pCRow0, pC
        add     pCRow1, pCRow0, LDC
@@ -582,204 +589,112 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
        mov     pA, origPA                      // pA = start of A array
 
-.Lzgemm_kernel_L4_M4_BEGIN:
+.Lzgemm_kernel_L4_Mv1_BEGIN:
 
-       mov     counterI, origM
-       asr     counterI, counterI, #2          // counterI = counterI / 4
-       cmp     counterI, #0
-       ble     .Lzgemm_kernel_L4_M2_BEGIN
+/* Loop over M is done in an SVE fashion. This has the benefit of the last M%SVE_LEN iterations being done in a single sweep */
+    mov counterI, #0
+    whilelt p1.d, counterI, origM   
+    cntp lanes, p0, p1.d                        // lanes contain number of active SVE lanes in M dimension
 
        .align 5
-.Lzgemm_kernel_L4_M4_20:
+.Lzgemm_kernel_L4_Mv1_20:
 
        mov     pB, origPB
+    INITv1x4                     // fill with zeros
+
        asr     counterL , origK, #3
        cmp     counterL , #2
-       blt     .Lzgemm_kernel_L4_M4_32
+       blt     .Lzgemm_kernel_L4_Mv1_32
 
-       KERNEL4x4_I
-       KERNEL4x4_M2
-       KERNEL4x4_M1
-       KERNEL4x4_M2
-       KERNEL4x4_M1
-       KERNEL4x4_M2
-       KERNEL4x4_M1
-       KERNEL4x4_M2
+       KERNELv1x4_I
+       KERNELv1x4_M2
+       KERNELv1x4_M1
+       KERNELv1x4_M2
+       KERNELv1x4_M1
+       KERNELv1x4_M2
+       KERNELv1x4_M1
+       KERNELv1x4_M2
 
        subs    counterL, counterL, #2          // subtract 2
-       ble     .Lzgemm_kernel_L4_M4_22a
+       ble     .Lzgemm_kernel_L4_Mv1_22a
 
        .align 5
-.Lzgemm_kernel_L4_M4_22:
+.Lzgemm_kernel_L4_Mv1_22:
 
-       KERNEL4x4_M1
-       KERNEL4x4_M2
-       KERNEL4x4_M1
-       KERNEL4x4_M2
-       KERNEL4x4_M1
-       KERNEL4x4_M2
-       KERNEL4x4_M1
-       KERNEL4x4_M2
+       KERNELv1x4_M1
+       KERNELv1x4_M2
+       KERNELv1x4_M1
+       KERNELv1x4_M2
+       KERNELv1x4_M1
+       KERNELv1x4_M2
+       KERNELv1x4_M1
+       KERNELv1x4_M2
 
        subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L4_M4_22
+       bgt     .Lzgemm_kernel_L4_Mv1_22
 
        .align 5
-.Lzgemm_kernel_L4_M4_22a:
+.Lzgemm_kernel_L4_Mv1_22a:
 
-       KERNEL4x4_M1
-       KERNEL4x4_M2
-       KERNEL4x4_M1
-       KERNEL4x4_M2
-       KERNEL4x4_M1
-       KERNEL4x4_M2
-       KERNEL4x4_M1
-       KERNEL4x4_E
+       KERNELv1x4_M1
+       KERNELv1x4_M2
+       KERNELv1x4_M1
+       KERNELv1x4_M2
+       KERNELv1x4_M1
+       KERNELv1x4_M2
+       KERNELv1x4_M1
+       KERNELv1x4_E
 
-       b        .Lzgemm_kernel_L4_M4_44
+       b        .Lzgemm_kernel_L4_Mv1_44
 
        .align 5
-.Lzgemm_kernel_L4_M4_32:
+.Lzgemm_kernel_L4_Mv1_32:
 
        tst     counterL, #1
-       ble     .Lzgemm_kernel_L4_M4_40
+       ble     .Lzgemm_kernel_L4_Mv1_40
 
-       KERNEL4x4_I
-       KERNEL4x4_M2
-       KERNEL4x4_M1
-       KERNEL4x4_M2
-       KERNEL4x4_M1
-       KERNEL4x4_M2
-       KERNEL4x4_M1
-       KERNEL4x4_E
+       KERNELv1x4_I
+       KERNELv1x4_M2
+       KERNELv1x4_M1
+       KERNELv1x4_M2
+       KERNELv1x4_M1
+       KERNELv1x4_M2
+       KERNELv1x4_M1
+       KERNELv1x4_E
 
-       b       .Lzgemm_kernel_L4_M4_44
+       b       .Lzgemm_kernel_L4_Mv1_44
 
 
-.Lzgemm_kernel_L4_M4_40:
+.Lzgemm_kernel_L4_Mv1_40:
 
-       INIT4x4
+       INITv1x4
 
-.Lzgemm_kernel_L4_M4_44:
+.Lzgemm_kernel_L4_Mv1_44:
 
        ands    counterL , origK, #7
-       ble     .Lzgemm_kernel_L4_M4_100
+       ble     .Lzgemm_kernel_L4_Mv1_100
 
        .align 5
-.Lzgemm_kernel_L4_M4_46:
-       KERNEL4x4_SUB
+.Lzgemm_kernel_L4_Mv1_46:
+       KERNELv1x4_SUB
 
        subs    counterL, counterL, #1
-       bne     .Lzgemm_kernel_L4_M4_46
+       bne     .Lzgemm_kernel_L4_Mv1_46
 
-.Lzgemm_kernel_L4_M4_100:
+.Lzgemm_kernel_L4_Mv1_100:
        prfm    PLDL1KEEP, [pA]
        prfm    PLDL1KEEP, [pA, #64]
        prfm    PLDL1KEEP, [origPB]
 
-       SAVE4x4
-
-.Lzgemm_kernel_L4_M4_END:
-       subs    counterI, counterI, #1
-       bne     .Lzgemm_kernel_L4_M4_20
-
-.Lzgemm_kernel_L4_M2_BEGIN:
-
-       mov     counterI, origM
-       tst     counterI , #3
-       ble     .Lzgemm_kernel_L4_END
-
-       tst     counterI, #2                    // counterI = counterI / 2
-       ble     .Lzgemm_kernel_L4_M1_BEGIN
-
-.Lzgemm_kernel_L4_M2_20:
-
-       INIT2x4
-
-       mov     pB, origPB
-       asr     counterL , origK, #3            // counterL = counterL / 8
-       cmp     counterL , #0
-       ble     .Lzgemm_kernel_L4_M2_40
-
-.Lzgemm_kernel_L4_M2_22:
-
-       KERNEL2x4_SUB
-       KERNEL2x4_SUB
-       KERNEL2x4_SUB
-       KERNEL2x4_SUB
-
-       KERNEL2x4_SUB
-       KERNEL2x4_SUB
-       KERNEL2x4_SUB
-       KERNEL2x4_SUB
-
-       subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L4_M2_22
-
-
-.Lzgemm_kernel_L4_M2_40:
-
-       ands    counterL , origK, #7            // counterL = counterL % 8
-       ble     .Lzgemm_kernel_L4_M2_100
-
-.Lzgemm_kernel_L4_M2_42:
-
-       KERNEL2x4_SUB
-
-       subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L4_M2_42
-
-.Lzgemm_kernel_L4_M2_100:
-
-       SAVE2x4
-
-.Lzgemm_kernel_L4_M2_END:
-
-
-.Lzgemm_kernel_L4_M1_BEGIN:
-
-       tst     counterI, #1                    // counterI = counterI % 2
-       ble     .Lzgemm_kernel_L4_END
-
-.Lzgemm_kernel_L4_M1_20:
-
-       INIT1x4
-
-       mov     pB, origPB
-       asr     counterL , origK, #3            // counterL = counterL / 8
-       cmp     counterL , #0
-       ble     .Lzgemm_kernel_L4_M1_40
-
-.Lzgemm_kernel_L4_M1_22:
-       KERNEL1x4_SUB
-       KERNEL1x4_SUB
-       KERNEL1x4_SUB
-       KERNEL1x4_SUB
-
-       KERNEL1x4_SUB
-       KERNEL1x4_SUB
-       KERNEL1x4_SUB
-       KERNEL1x4_SUB
-
-       subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L4_M1_22
-
-
-.Lzgemm_kernel_L4_M1_40:
-
-       ands    counterL , origK, #7            // counterL = counterL % 8
-       ble     .Lzgemm_kernel_L4_M1_100
-
-.Lzgemm_kernel_L4_M1_42:
+       SAVEv1x4
 
-       KERNEL1x4_SUB
+.Lzgemm_kernel_L4_Mv1_END:
 
-       subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L4_M1_42
-
-.Lzgemm_kernel_L4_M1_100:
+    incd    counterI
+    whilelt p1.d, counterI, origM             //SVE instruction
+    cntp lanes, p0, p1.d                        // lanes contain number of active SVE lanes in M dimension
+    b.any   .Lzgemm_kernel_L4_Mv1_20   
 
-       SAVE1x4
 
 
 .Lzgemm_kernel_L4_END:
@@ -810,157 +725,61 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 
 
-.Lzgemm_kernel_L2_M4_BEGIN:
+.Lzgemm_kernel_L2_Mv1_BEGIN:
+
+    mov counterI, #0
+    whilelt p1.d, counterI, origM               //SVE instruction
+    cntp lanes, p0, p1.d
 
-       mov     counterI, origM
-       asr     counterI, counterI, #2          // counterI = counterI / 4
-       cmp     counterI,#0
-       ble     .Lzgemm_kernel_L2_M2_BEGIN
 
-.Lzgemm_kernel_L2_M4_20:
+.Lzgemm_kernel_L2_Mv1_20:
 
-       INIT4x2
+       INITv1x2
 
        mov     pB, origPB
        asr     counterL , origK, #3            // counterL = counterL / 8
        cmp     counterL,#0
-       ble     .Lzgemm_kernel_L2_M4_40
+       ble     .Lzgemm_kernel_L2_Mv1_40
        .align 5
 
-.Lzgemm_kernel_L2_M4_22:
-       KERNEL4x2_SUB
-       KERNEL4x2_SUB
-       KERNEL4x2_SUB
-       KERNEL4x2_SUB
+.Lzgemm_kernel_L2_Mv1_22:
+       KERNELv1x2_SUB
+       KERNELv1x2_SUB
+       KERNELv1x2_SUB
+       KERNELv1x2_SUB
 
-       KERNEL4x2_SUB
-       KERNEL4x2_SUB
-       KERNEL4x2_SUB
-       KERNEL4x2_SUB
+       KERNELv1x2_SUB
+       KERNELv1x2_SUB
+       KERNELv1x2_SUB
+       KERNELv1x2_SUB
 
        subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L2_M4_22
+       bgt     .Lzgemm_kernel_L2_Mv1_22
 
 
-.Lzgemm_kernel_L2_M4_40:
+.Lzgemm_kernel_L2_Mv1_40:
 
        ands    counterL , origK, #7            // counterL = counterL % 8
-       ble     .Lzgemm_kernel_L2_M4_100
-
-.Lzgemm_kernel_L2_M4_42:
-
-       KERNEL4x2_SUB
-
-       subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L2_M4_42
-
-.Lzgemm_kernel_L2_M4_100:
-
-       SAVE4x2
+       ble     .Lzgemm_kernel_L2_Mv1_100
 
-.Lzgemm_kernel_L2_M4_END:
+.Lzgemm_kernel_L2_Mv1_42:
 
-       subs    counterI, counterI, #1
-       bgt     .Lzgemm_kernel_L2_M4_20
-
-
-.Lzgemm_kernel_L2_M2_BEGIN:
-
-       mov     counterI, origM
-       tst     counterI , #3
-       ble     .Lzgemm_kernel_L2_END
-
-       tst     counterI, #2                    // counterI = counterI / 2
-       ble     .Lzgemm_kernel_L2_M1_BEGIN
-
-.Lzgemm_kernel_L2_M2_20:
-
-       INIT2x2
-
-       mov     pB, origPB
-       asr     counterL , origK, #3            // counterL = counterL / 8
-        cmp    counterL,#0
-       ble     .Lzgemm_kernel_L2_M2_40
-
-.Lzgemm_kernel_L2_M2_22:
-
-       KERNEL2x2_SUB
-       KERNEL2x2_SUB
-       KERNEL2x2_SUB
-       KERNEL2x2_SUB
-
-       KERNEL2x2_SUB
-       KERNEL2x2_SUB
-       KERNEL2x2_SUB
-       KERNEL2x2_SUB
+       KERNELv1x2_SUB
 
        subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L2_M2_22
+       bgt     .Lzgemm_kernel_L2_Mv1_42
 
+.Lzgemm_kernel_L2_Mv1_100:
 
-.Lzgemm_kernel_L2_M2_40:
-
-       ands    counterL , origK, #7            // counterL = counterL % 8
-       ble     .Lzgemm_kernel_L2_M2_100
-
-.Lzgemm_kernel_L2_M2_42:
-
-       KERNEL2x2_SUB
-
-       subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L2_M2_42
+       SAVEv1x2
 
-.Lzgemm_kernel_L2_M2_100:
+.Lzgemm_kernel_L2_Mv1_END:
 
-       SAVE2x2
 
-.Lzgemm_kernel_L2_M2_END:
-
-
-.Lzgemm_kernel_L2_M1_BEGIN:
-
-       tst     counterI, #1                    // counterI = counterI % 2
-       ble     .Lzgemm_kernel_L2_END
-
-.Lzgemm_kernel_L2_M1_20:
-
-       INIT1x2
-
-       mov     pB, origPB
-       asr     counterL , origK, #3            // counterL = counterL / 8
-        cmp     counterL, #0
-       ble     .Lzgemm_kernel_L2_M1_40
-
-.Lzgemm_kernel_L2_M1_22:
-       KERNEL1x2_SUB
-       KERNEL1x2_SUB
-       KERNEL1x2_SUB
-       KERNEL1x2_SUB
-
-       KERNEL1x2_SUB
-       KERNEL1x2_SUB
-       KERNEL1x2_SUB
-       KERNEL1x2_SUB
-
-       subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L2_M1_22
-
-
-.Lzgemm_kernel_L2_M1_40:
-
-       ands    counterL , origK, #7            // counterL = counterL % 8
-       ble     .Lzgemm_kernel_L2_M1_100
-
-.Lzgemm_kernel_L2_M1_42:
-
-       KERNEL1x2_SUB
-
-       subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L2_M1_42
-
-.Lzgemm_kernel_L2_M1_100:
-
-       SAVE1x2
+    incd    counterI
+    whilelt p1.d, counterI, origM             //SVE instruction
+    cntp lanes, p0, p1.d
+    b.any   .Lzgemm_kernel_L2_Mv1_20   
 
 
 .Lzgemm_kernel_L2_END:
@@ -981,163 +800,64 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
        mov     pA, origPA                      // pA = A
 
+.Lzgemm_kernel_L1_Mv1_BEGIN:
 
+    mov counterI, #0
+    whilelt p1.d, counterI, origM               //SVE instruction
+    cntp lanes, p0, p1.d
 
-.Lzgemm_kernel_L1_M4_BEGIN:
 
-       mov     counterI, origM
-       asr     counterI, counterI, #2          // counterI = counterI / 4
-       cmp     counterI, #0
-       ble     .Lzgemm_kernel_L1_M2_BEGIN
+.Lzgemm_kernel_L1_Mv1_20:
 
-.Lzgemm_kernel_L1_M4_20:
-
-       INIT4x1
+       INITv1x1
 
        mov     pB, origPB
        asr     counterL , origK, #3            // counterL = counterL / 8
        cmp     counterL , #0
-       ble     .Lzgemm_kernel_L1_M4_40
+       ble     .Lzgemm_kernel_L1_Mv1_40
        .align 5
 
-.Lzgemm_kernel_L1_M4_22:
-       KERNEL4x1_SUB
-       KERNEL4x1_SUB
-       KERNEL4x1_SUB
-       KERNEL4x1_SUB
+.Lzgemm_kernel_L1_Mv1_22:
+       KERNELv1x1_SUB
+       KERNELv1x1_SUB
+       KERNELv1x1_SUB
+       KERNELv1x1_SUB
 
-       KERNEL4x1_SUB
-       KERNEL4x1_SUB
-       KERNEL4x1_SUB
-       KERNEL4x1_SUB
+       KERNELv1x1_SUB
+       KERNELv1x1_SUB
+       KERNELv1x1_SUB
+       KERNELv1x1_SUB
 
        subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L1_M4_22
+       bgt     .Lzgemm_kernel_L1_Mv1_22
 
 
-.Lzgemm_kernel_L1_M4_40:
+.Lzgemm_kernel_L1_Mv1_40:
 
        ands    counterL , origK, #7            // counterL = counterL % 8
-       ble     .Lzgemm_kernel_L1_M4_100
+       ble     .Lzgemm_kernel_L1_Mv1_100
 
-.Lzgemm_kernel_L1_M4_42:
+.Lzgemm_kernel_L1_Mv1_42:
 
-       KERNEL4x1_SUB
+       KERNELv1x1_SUB
 
        subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L1_M4_42
-
-.Lzgemm_kernel_L1_M4_100:
-
-       SAVE4x1
-
-.Lzgemm_kernel_L1_M4_END:
-
-       subs    counterI, counterI, #1
-       bgt     .Lzgemm_kernel_L1_M4_20
-
+       bgt     .Lzgemm_kernel_L1_Mv1_42
 
-.Lzgemm_kernel_L1_M2_BEGIN:
+.Lzgemm_kernel_L1_Mv1_100:
 
-       mov     counterI, origM
-       tst     counterI , #3
-       ble     .Lzgemm_kernel_L1_END
+       SAVEv1x1
 
-       tst     counterI, #2                    // counterI = counterI / 2
-       ble     .Lzgemm_kernel_L1_M1_BEGIN
-
-.Lzgemm_kernel_L1_M2_20:
-
-       INIT2x1
-
-       mov     pB, origPB
-       asr     counterL , origK, #3            // counterL = counterL / 8
-       cmp     counterL , #0
-       ble     .Lzgemm_kernel_L1_M2_40
-
-.Lzgemm_kernel_L1_M2_22:
-
-       KERNEL2x1_SUB
-       KERNEL2x1_SUB
-       KERNEL2x1_SUB
-       KERNEL2x1_SUB
-
-       KERNEL2x1_SUB
-       KERNEL2x1_SUB
-       KERNEL2x1_SUB
-       KERNEL2x1_SUB
-
-       subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L1_M2_22
-
-
-.Lzgemm_kernel_L1_M2_40:
-
-       ands    counterL , origK, #7            // counterL = counterL % 8
-       ble     .Lzgemm_kernel_L1_M2_100
-
-.Lzgemm_kernel_L1_M2_42:
-
-       KERNEL2x1_SUB
-
-       subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L1_M2_42
-
-.Lzgemm_kernel_L1_M2_100:
-
-       SAVE2x1
-
-.Lzgemm_kernel_L1_M2_END:
-
-
-.Lzgemm_kernel_L1_M1_BEGIN:
-
-       tst     counterI, #1                    // counterI = counterI % 2
-       ble     .Lzgemm_kernel_L1_END
-
-.Lzgemm_kernel_L1_M1_20:
-
-       INIT1x1
-
-       mov     pB, origPB
-       asr     counterL , origK, #3            // counterL = counterL / 8
-       cmp     counterL , #0
-       ble     .Lzgemm_kernel_L1_M1_40
-
-.Lzgemm_kernel_L1_M1_22:
-       KERNEL1x1_SUB
-       KERNEL1x1_SUB
-       KERNEL1x1_SUB
-       KERNEL1x1_SUB
-
-       KERNEL1x1_SUB
-       KERNEL1x1_SUB
-       KERNEL1x1_SUB
-       KERNEL1x1_SUB
-
-       subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L1_M1_22
-
-
-.Lzgemm_kernel_L1_M1_40:
-
-       ands    counterL , origK, #7            // counterL = counterL % 8
-       ble     .Lzgemm_kernel_L1_M1_100
-
-.Lzgemm_kernel_L1_M1_42:
-
-       KERNEL1x1_SUB
-
-       subs    counterL, counterL, #1
-       bgt     .Lzgemm_kernel_L1_M1_42
-
-.Lzgemm_kernel_L1_M1_100:
-
-       SAVE1x1
+.Lzgemm_kernel_L1_Mv1_END:
 
+    incd    counterI
+    whilelt p1.d, counterI, origM             //SVE instruction
+    cntp lanes, p0, p1.d
+    b.any   .Lzgemm_kernel_L1_Mv1_20   
 
 .Lzgemm_kernel_L1_END:
 
+/******************************************************************************/
 
 .Lzgemm_kernel_L999:
        mov     x0, #0                          // set return value