[Arm64/Unix] Revise JIT_MemSet (dotnet/coreclr#11217)
authorSteve MacLean <sdmaclea@qti.qualcomm.com>
Wed, 3 May 2017 22:52:42 +0000 (18:52 -0400)
committerJan Kotas <jkotas@microsoft.com>
Wed, 3 May 2017 22:52:42 +0000 (15:52 -0700)
* [Arm64/Unix] Revise JIT_MemSet

Use DC ZVA
Use stp
Correctly handle short set lengths
Simplify code & pseudo code
Use uint*_t to make pseudo code more readable

Commit migrated from https://github.com/dotnet/coreclr/commit/cac4e2e82e63c3d892c935262cfb2b04090acae9

src/coreclr/src/vm/arm64/crthelpers.S

index 51774ae..c8b108c 100644 (file)
 //
 //void JIT_MemSet(void *dst, int val, SIZE_T count)
 //
-//    uintptr_t valEx = (unsigned char)val;
+//    uint64_t valEx = (unsigned char)val;
 //    valEx = valEx | valEx << 8;
 //    valEx = valEx | valEx << 16;
 //    valEx = valEx | valEx << 32;
 //
+//    size_t dc_zva_size = 4ULL << DCZID_EL0.BS;
+//
+//    uint64_t use_dc_zva = (val == 0) && !DCZID_EL0.p ? count / (2 * dc_zva_size) : 0; // ~Minimum size (assumes worst case alignment)
+//
 //    // If not aligned then make it 8-byte aligned   
-//    if(((uintptr_t)dst&0x7) != 0)
+//    if(((uint64_t)dst&0xf) != 0)
 //    {
-//        if(((uintptr_t)dst&0x3) == 0)
+//        // Calculate alignment we can do without exceeding count
+//        // Use math to avoid introducing more unpredictable branches
+//        // Due to inherent mod in lsr, ~7 is used instead of ~0 to handle count == 0
+//        // Note logic will fail is count >= (1 << 61).  But this exceeds max physical memory for arm64
+//        uint8_t align = (dst & 0x7) & (~uint64_t(7) >> (countLeadingZeros(count) mod 64))
+//
+//        if(align&0x1)
+//        {
+//            *(unit8_t*)dst = (unit8_t)valEx;
+//            dst = (unit8_t*)dst + 1;
+//            count-=1;
+//        }
+//
+//        if(align&0x2)
+//        {
+//            *(unit16_t*)dst = (unit16_t)valEx;
+//            dst = (unit16_t*)dst + 1;
+//            count-=2;
+//        }
+//
+//        if(align&0x4)
 //        {
-//            *(UINT*)dst = (UINT)valEx;
-//            dst = (UINT*)dst + 1;
+//            *(unit32_t*)dst = (unit32_t)valEx;
+//            dst = (unit32_t*)dst + 1;
 //            count-=4;
 //        }
-//        else if(((uintptr_t)dst&0x1) == 0)
+//    }
+//
+//    if(use_dc_zva)
+//    {
+//        // If not aligned then make it aligned to dc_zva_size
+//        if(dst&0x8)
 //        {
-//            while(count > 0 && ((uintptr_t)dst&0x7) != 0)
-//            {
-//                *(short*)dst = (short)valEx;
-//                dst = (short*)dst + 1;
-//                count-=2;
-//            }
+//            *(uint64_t*)dst = (uint64_t)valEx;
+//            dst = (uint64_t*)dst + 1;
+//            count-=8;
 //        }
-//        else
+//
+//        while(dst & (dc_zva_size - 1))
+//        {
+//            *(uint64_t*)dst = valEx;
+//            dst = (uint64_t*)dst + 1;
+//            *(uint64_t*)dst = valEx;
+//            dst = (uint64_t*)dst + 1;
+//            count-=16;
+//        }
+//
+//        count -= dc_zva_size;
+//
+//        while(count >= 0)
 //        {
-//            while(count > 0 && ((uintptr_t)dst&0x7) != 0)
-//            {
-//                *(char*)dst = (char)valEx;
-//                dst = (char*)dst + 1;
-//                count--;
-//            }
+//            dc_zva(dst);
+//            dst = (uint8_t*)dst + dc_zva_size;
+//            count-=dc_zva_size;
 //        }
+//
+//        count += dc_zva_size;
 //    }
 //
-//    while(count >= 8)
+//    count-=16;
+//
+//    while(count >= 0)
 //    {
-//        *(uintptr_t*)dst = valEx;
-//        dst = (uintptr_t*)dst + 1;
-//        count-=8;
+//        *(uint64_t*)dst = valEx;
+//        dst = (uint64_t*)dst + 1;
+//        *(uint64_t*)dst = valEx;
+//        dst = (uint64_t*)dst + 1;
+//        count-=16;
+//    }
+//
+//    if(count & 8)
+//    {
+//        *(uint64_t*)dst = valEx;
+//        dst = (uint64_t*)dst + 1;
 //    }
 //
 //    if(count & 4)
 //    {
-//        *(UINT*)dst = (UINT)valEx;
-//        dst = (UINT*)dst + 1;
+//        *(uint32_t*)dst = (uint32_t)valEx;
+//        dst = (uint32_t*)dst + 1;
 //    }
 //
 //    if(count & 2)
 //    {
-//        *(short*)dst = (short)valEx;
-//        dst = (short*)dst + 1;
+//        *(uint16_t*)dst = (uint16_t)valEx;
+//        dst = (uint16_t*)dst + 1;
 //    }
 //
 //    if(count & 1)
 //    {
-//        *(char*)dst = (char)valEx;
+//        *(uint8_t*)dst = (uint8_t)valEx;
 //    }
 //
 //
 // as C++ method.
 
 LEAF_ENTRY JIT_MemSet, _TEXT
-    uxtb        w8,w1
-    sxtw        x8,w8
-    orr         x8,x8,x8, lsl #8
-    orr         x8,x8,x8, lsl #0x10
-    orr         x9,x8,x8, lsl #0x20
-    and         x8,x0,#7
-    cbz         x8,LOCAL_LABEL(JIT_MemSet_0x7c)
-    and         x8,x0,#3
-    cbnz        x8,LOCAL_LABEL(JIT_MemSet_0x38)
-    str         w9,[x0]
-    add         x0,x0,#4
-    mov         x8,#-4
-    add         x2,x2,x8
-    b           LOCAL_LABEL(JIT_MemSet_0x7c)
-LOCAL_LABEL(JIT_MemSet_0x38):
-    cbz         x2,LOCAL_LABEL(JIT_MemSet_0x7c)
-    tbnz        x0,#0,LOCAL_LABEL(JIT_MemSet_0x60)
-LOCAL_LABEL(JIT_MemSet_0x40):
-    and         x8,x0,#7
-    cbz         x8,LOCAL_LABEL(JIT_MemSet_0x7c)
-    strh        w9,[x0]
-    add         x0,x0,#2
-    mov         x8,#-2
-    add         x2,x2,x8
-    cbnz        x2,LOCAL_LABEL(JIT_MemSet_0x40)
-    b           LOCAL_LABEL(JIT_MemSet_0x7c)
-LOCAL_LABEL(JIT_MemSet_0x60):
-    and         x8,x0,#7
-    cbz         x8,LOCAL_LABEL(JIT_MemSet_0x7c)
-    strb        w9,[x0]
-    add         x0,x0,#1
-    mov         x8,#-1
-    add         x2,x2,x8
-    cbnz        x2,LOCAL_LABEL(JIT_MemSet_0x60)
-LOCAL_LABEL(JIT_MemSet_0x7c):
-    cmp         x2,#8
-    blo         LOCAL_LABEL(JIT_MemSet_0xb8)
-    lsr         x8,x2,#3
-    mov         x11,x8
-    mov         x10,x0
-    add         x8,x10,x11, lsl #3
+    ands        w8, w1, #0xff
+    mrs         x3, DCZID_EL0                      // x3 = DCZID_EL0
+    mov         x6, #4
+    lsr         x11, x2, #3                        // x11 = count >> 3
+
+    orr         w8, w8, w8, lsl #8
+    and         x5, x3, #0xf                       // x5 = dczid_el0.bs
+    csel        x11, x11, xzr, eq                  // x11 = (val == 0) ? count >> 3 : 0
+    tst         x3, (1 << 4)
+
+    orr         w8, w8, w8, lsl #0x10
+    csel        x11, x11, xzr, eq                  // x11 = (val == 0) && !DCZID_EL0.p ? count >> 3 : 0
+    ands        x3, x0, #7                         // x3 = dst & 7
+    lsl         x9, x6, x5                         // x9 = size
+
+    orr         x8, x8, x8, lsl #0x20
+    lsr         x11, x11, x5                       // x11 = (val == 0) && !DCZID_EL0.p ? count >> (3 + DCZID_EL0.bs) : 0
+    sub         x10, x9, #1                        // x10 = mask
+
+    b.eq        LOCAL_LABEL(JIT_MemSet_0x80)
+
+    movn        x4, #7
+    clz         x5, x2
+    lsr         x4, x4, x5
+    and         x3, x3, x4
+
+    tbz         x3, #0, LOCAL_LABEL(JIT_MemSet_0x2c)
+    strb        w8, [x0], #1
+    sub         x2, x2, #1
+LOCAL_LABEL(JIT_MemSet_0x2c):
+    tbz         x3, #1, LOCAL_LABEL(JIT_MemSet_0x5c)
+    strh        w8, [x0], #2
+    sub         x2, x2, #2
+LOCAL_LABEL(JIT_MemSet_0x5c):
+    tbz         x3, #2, LOCAL_LABEL(JIT_MemSet_0x80)
+    str         w8, [x0], #4
+    sub         x2, x2, #4
+LOCAL_LABEL(JIT_MemSet_0x80):
+    cbz         x11, LOCAL_LABEL(JIT_MemSet_0x9c)
+    tbz         x0, #3, LOCAL_LABEL(JIT_MemSet_0x84)
+    str         x8, [x0], #8
+    sub         x2, x2, #8
+
+    b           LOCAL_LABEL(JIT_MemSet_0x85)
+LOCAL_LABEL(JIT_MemSet_0x84):
+    stp         x8, x8, [x0], #16
+    sub         x2, x2, #16
+LOCAL_LABEL(JIT_MemSet_0x85):
+    tst         x0, x10
+    b.ne        LOCAL_LABEL(JIT_MemSet_0x84)
+
+    b           LOCAL_LABEL(JIT_MemSet_0x8a)
+LOCAL_LABEL(JIT_MemSet_0x88):
+    dc          zva, x0
+    add         x0, x0, x9
+LOCAL_LABEL(JIT_MemSet_0x8a):
+    subs        x2, x2, x9
+    b.ge        LOCAL_LABEL(JIT_MemSet_0x88)
+
+LOCAL_LABEL(JIT_MemSet_0x8c):
+    add         x2, x2, x9
+
 LOCAL_LABEL(JIT_MemSet_0x9c):
-    cmp         x10,x8
-    beq         LOCAL_LABEL(JIT_MemSet_0xac)
-    str         x9,[x10],#8
-    b           LOCAL_LABEL(JIT_MemSet_0x9c)
-LOCAL_LABEL(JIT_MemSet_0xac):
-    mov         x8,#-8
-    madd        x2,x11,x8,x2
-    add         x0,x0,x11, lsl #3
-LOCAL_LABEL(JIT_MemSet_0xb8):
-    tbz         x2,#2,LOCAL_LABEL(JIT_MemSet_0xc4)
-    str         w9,[x0]
-    add         x0,x0,#4
-LOCAL_LABEL(JIT_MemSet_0xc4):
-    tbz         x2,#1,LOCAL_LABEL(JIT_MemSet_0xd0)
-    strh        w9,[x0]
-    add         x0,x0,#2
-LOCAL_LABEL(JIT_MemSet_0xd0):
-    tbz         x2,#0,LOCAL_LABEL(JIT_MemSet_0xd8)
-    strb        w9,[x0]
-LOCAL_LABEL(JIT_MemSet_0xd8):
+    b           LOCAL_LABEL(JIT_MemSet_0xa8)
+LOCAL_LABEL(JIT_MemSet_0xa0):
+    stp         x8, x8, [x0], #16
+LOCAL_LABEL(JIT_MemSet_0xa8):
+    subs        x2, x2, #16
+    b.ge        LOCAL_LABEL(JIT_MemSet_0xa0)
+
+LOCAL_LABEL(JIT_MemSet_0xb0):
+    tbz         x2, #3, LOCAL_LABEL(JIT_MemSet_0xb4)
+    str         x8, [x0], #8
+LOCAL_LABEL(JIT_MemSet_0xb4):
+    tbz         x2, #2, LOCAL_LABEL(JIT_MemSet_0xc8)
+    str         w8, [x0], #4
+LOCAL_LABEL(JIT_MemSet_0xc8):
+    tbz         x2, #1, LOCAL_LABEL(JIT_MemSet_0xdc)
+    strh        w8, [x0], #2
+LOCAL_LABEL(JIT_MemSet_0xdc):
+    tbz         x2, #0, LOCAL_LABEL(JIT_MemSet_0xe8)
+    strb        w8, [x0]
+LOCAL_LABEL(JIT_MemSet_0xe8):
     ret         lr
 LEAF_END_MARKED JIT_MemSet, _TEXT