x86: Fix overflow bug with wmemchr-sse2 and wmemchr-avx2 [BZ #27974]
authorNoah Goldstein <goldstein.w.n@gmail.com>
Wed, 9 Jun 2021 20:25:32 +0000 (16:25 -0400)
committerNoah Goldstein <goldstein.w.n@gmail.com>
Wed, 23 Jun 2021 18:13:03 +0000 (14:13 -0400)
This commit fixes the bug mentioned in the previous commit.

The previous implementations of wmemchr in these files relied
on n * sizeof(wchar_t) which was not guranteed by the standard.

The new overflow tests added in the previous commit now
pass (As well as all the other tests).

Signed-off-by: Noah Goldstein <goldstein.w.n@gmail.com>
Reviewed-by: H.J. Lu <hjl.tools@gmail.com>
sysdeps/x86_64/memchr.S
sysdeps/x86_64/multiarch/memchr-avx2.S

index beff270..3ddc465 100644 (file)
 #ifdef USE_AS_WMEMCHR
 # define MEMCHR                wmemchr
 # define PCMPEQ                pcmpeqd
+# define CHAR_PER_VEC  4
 #else
 # define MEMCHR                memchr
 # define PCMPEQ                pcmpeqb
+# define CHAR_PER_VEC  16
 #endif
 
 /* fast SSE2 version with using pmaxub and 64 byte loop */
@@ -33,15 +35,14 @@ ENTRY(MEMCHR)
        movd    %esi, %xmm1
        mov     %edi, %ecx
 
+#ifdef __ILP32__
+       /* Clear the upper 32 bits.  */
+       movl    %edx, %edx
+#endif
 #ifdef USE_AS_WMEMCHR
        test    %RDX_LP, %RDX_LP
        jz      L(return_null)
-       shl     $2, %RDX_LP
 #else
-# ifdef __ILP32__
-       /* Clear the upper 32 bits.  */
-       movl    %edx, %edx
-# endif
        punpcklbw %xmm1, %xmm1
        test    %RDX_LP, %RDX_LP
        jz      L(return_null)
@@ -60,13 +61,16 @@ ENTRY(MEMCHR)
        test    %eax, %eax
 
        jnz     L(matches_1)
-       sub     $16, %rdx
+       sub     $CHAR_PER_VEC, %rdx
        jbe     L(return_null)
        add     $16, %rdi
        and     $15, %ecx
        and     $-16, %rdi
+#ifdef USE_AS_WMEMCHR
+       shr     $2, %ecx
+#endif
        add     %rcx, %rdx
-       sub     $64, %rdx
+       sub     $(CHAR_PER_VEC * 4), %rdx
        jbe     L(exit_loop)
        jmp     L(loop_prolog)
 
@@ -77,16 +81,21 @@ L(crosscache):
        movdqa  (%rdi), %xmm0
 
        PCMPEQ  %xmm1, %xmm0
-/* Check if there is a match.  */
+       /* Check if there is a match.  */
        pmovmskb %xmm0, %eax
-/* Remove the leading bytes.  */
+       /* Remove the leading bytes.  */
        sar     %cl, %eax
        test    %eax, %eax
        je      L(unaligned_no_match)
-/* Check which byte is a match.  */
+       /* Check which byte is a match.  */
        bsf     %eax, %eax
-
+#ifdef USE_AS_WMEMCHR
+       mov     %eax, %esi
+       shr     $2, %esi
+       sub     %rsi, %rdx
+#else
        sub     %rax, %rdx
+#endif
        jbe     L(return_null)
        add     %rdi, %rax
        add     %rcx, %rax
@@ -94,15 +103,18 @@ L(crosscache):
 
        .p2align 4
 L(unaligned_no_match):
-        /* "rcx" is less than 16.  Calculate "rdx + rcx - 16" by using
+       /* "rcx" is less than 16.  Calculate "rdx + rcx - 16" by using
           "rdx - (16 - rcx)" instead of "(rdx + rcx) - 16" to void
           possible addition overflow.  */
        neg     %rcx
        add     $16, %rcx
+#ifdef USE_AS_WMEMCHR
+       shr     $2, %ecx
+#endif
        sub     %rcx, %rdx
        jbe     L(return_null)
        add     $16, %rdi
-       sub     $64, %rdx
+       sub     $(CHAR_PER_VEC * 4), %rdx
        jbe     L(exit_loop)
 
        .p2align 4
@@ -135,7 +147,7 @@ L(loop_prolog):
        test    $0x3f, %rdi
        jz      L(align64_loop)
 
-       sub     $64, %rdx
+       sub     $(CHAR_PER_VEC * 4), %rdx
        jbe     L(exit_loop)
 
        movdqa  (%rdi), %xmm0
@@ -167,11 +179,14 @@ L(loop_prolog):
        mov     %rdi, %rcx
        and     $-64, %rdi
        and     $63, %ecx
+#ifdef USE_AS_WMEMCHR
+       shr     $2, %ecx
+#endif
        add     %rcx, %rdx
 
        .p2align 4
 L(align64_loop):
-       sub     $64, %rdx
+       sub     $(CHAR_PER_VEC * 4), %rdx
        jbe     L(exit_loop)
        movdqa  (%rdi), %xmm0
        movdqa  16(%rdi), %xmm2
@@ -218,7 +233,7 @@ L(align64_loop):
 
        .p2align 4
 L(exit_loop):
-       add     $32, %edx
+       add     $(CHAR_PER_VEC * 2), %edx
        jle     L(exit_loop_32)
 
        movdqa  (%rdi), %xmm0
@@ -238,7 +253,7 @@ L(exit_loop):
        pmovmskb %xmm3, %eax
        test    %eax, %eax
        jnz     L(matches32_1)
-       sub     $16, %edx
+       sub     $CHAR_PER_VEC, %edx
        jle     L(return_null)
 
        PCMPEQ  48(%rdi), %xmm1
@@ -250,13 +265,13 @@ L(exit_loop):
 
        .p2align 4
 L(exit_loop_32):
-       add     $32, %edx
+       add     $(CHAR_PER_VEC * 2), %edx
        movdqa  (%rdi), %xmm0
        PCMPEQ  %xmm1, %xmm0
        pmovmskb %xmm0, %eax
        test    %eax, %eax
        jnz     L(matches_1)
-       sub     $16, %edx
+       sub     $CHAR_PER_VEC, %edx
        jbe     L(return_null)
 
        PCMPEQ  16(%rdi), %xmm1
@@ -293,7 +308,13 @@ L(matches32):
        .p2align 4
 L(matches_1):
        bsf     %eax, %eax
+#ifdef USE_AS_WMEMCHR
+       mov     %eax, %esi
+       shr     $2, %esi
+       sub     %rsi, %rdx
+#else
        sub     %rax, %rdx
+#endif
        jbe     L(return_null)
        add     %rdi, %rax
        ret
@@ -301,7 +322,13 @@ L(matches_1):
        .p2align 4
 L(matches16_1):
        bsf     %eax, %eax
+#ifdef USE_AS_WMEMCHR
+       mov     %eax, %esi
+       shr     $2, %esi
+       sub     %rsi, %rdx
+#else
        sub     %rax, %rdx
+#endif
        jbe     L(return_null)
        lea     16(%rdi, %rax), %rax
        ret
@@ -309,7 +336,13 @@ L(matches16_1):
        .p2align 4
 L(matches32_1):
        bsf     %eax, %eax
+#ifdef USE_AS_WMEMCHR
+       mov     %eax, %esi
+       shr     $2, %esi
+       sub     %rsi, %rdx
+#else
        sub     %rax, %rdx
+#endif
        jbe     L(return_null)
        lea     32(%rdi, %rax), %rax
        ret
@@ -317,7 +350,13 @@ L(matches32_1):
        .p2align 4
 L(matches48_1):
        bsf     %eax, %eax
+#ifdef USE_AS_WMEMCHR
+       mov     %eax, %esi
+       shr     $2, %esi
+       sub     %rsi, %rdx
+#else
        sub     %rax, %rdx
+#endif
        jbe     L(return_null)
        lea     48(%rdi, %rax), %rax
        ret
index 0d8758e..afdb956 100644 (file)
 
 # define VEC_SIZE 32
 # define PAGE_SIZE 4096
+# define CHAR_PER_VEC  (VEC_SIZE / CHAR_SIZE)
 
        .section SECTION(.text),"ax",@progbits
 ENTRY (MEMCHR)
 # ifndef USE_AS_RAWMEMCHR
        /* Check for zero length.  */
-       test    %RDX_LP, %RDX_LP
-       jz      L(null)
-# endif
-# ifdef USE_AS_WMEMCHR
-       shl     $2, %RDX_LP
-# else
 #  ifdef __ILP32__
-       /* Clear the upper 32 bits.  */
-       movl    %edx, %edx
+       /* Clear upper bits.  */
+       and     %RDX_LP, %RDX_LP
+#  else
+       test    %RDX_LP, %RDX_LP
 #  endif
+       jz      L(null)
 # endif
        /* Broadcast CHAR to YMMMATCH.  */
        vmovd   %esi, %xmm0
@@ -84,7 +82,7 @@ ENTRY (MEMCHR)
        vpmovmskb %ymm1, %eax
 # ifndef USE_AS_RAWMEMCHR
        /* If length < CHAR_PER_VEC handle special.  */
-       cmpq    $VEC_SIZE, %rdx
+       cmpq    $CHAR_PER_VEC, %rdx
        jbe     L(first_vec_x0)
 # endif
        testl   %eax, %eax
@@ -98,6 +96,10 @@ ENTRY (MEMCHR)
 L(first_vec_x0):
        /* Check if first match was before length.  */
        tzcntl  %eax, %eax
+#  ifdef USE_AS_WMEMCHR
+       /* NB: Multiply length by 4 to get byte count.  */
+       sall    $2, %edx
+#  endif
        xorl    %ecx, %ecx
        cmpl    %eax, %edx
        leaq    (%rdi, %rax), %rax
@@ -110,12 +112,12 @@ L(null):
 # endif
        .p2align 4
 L(cross_page_boundary):
-       /* Save pointer before aligning as its original value is necessary
-          for computer return address if byte is found or adjusting length
-          if it is not and this is memchr.  */
+       /* Save pointer before aligning as its original value is
+          necessary for computer return address if byte is found or
+          adjusting length if it is not and this is memchr.  */
        movq    %rdi, %rcx
-       /* Align data to VEC_SIZE - 1. ALGN_PTR_REG is rcx for memchr and
-          rdi for rawmemchr.  */
+       /* Align data to VEC_SIZE - 1. ALGN_PTR_REG is rcx for memchr
+          and rdi for rawmemchr.  */
        orq     $(VEC_SIZE - 1), %ALGN_PTR_REG
        VPCMPEQ -(VEC_SIZE - 1)(%ALGN_PTR_REG), %ymm0, %ymm1
        vpmovmskb %ymm1, %eax
@@ -124,6 +126,10 @@ L(cross_page_boundary):
           match).  */
        leaq    1(%ALGN_PTR_REG), %rsi
        subq    %RRAW_PTR_REG, %rsi
+#  ifdef USE_AS_WMEMCHR
+       /* NB: Divide bytes by 4 to get wchar_t count.  */
+       shrl    $2, %esi
+#  endif
 # endif
        /* Remove the leading bytes.  */
        sarxl   %ERAW_PTR_REG, %eax, %eax
@@ -181,6 +187,10 @@ L(cross_page_continue):
        orq     $(VEC_SIZE - 1), %rdi
        /* esi is for adjusting length to see if near the end.  */
        leal    (VEC_SIZE * 4 + 1)(%rdi, %rcx), %esi
+#  ifdef USE_AS_WMEMCHR
+       /* NB: Divide bytes by 4 to get the wchar_t count.  */
+       sarl    $2, %esi
+#  endif
 # else
        orq     $(VEC_SIZE - 1), %rdi
 L(cross_page_continue):
@@ -213,7 +223,7 @@ L(cross_page_continue):
 
 # ifndef USE_AS_RAWMEMCHR
        /* Check if at last VEC_SIZE * 4 length.  */
-       subq    $(VEC_SIZE * 4), %rdx
+       subq    $(CHAR_PER_VEC * 4), %rdx
        jbe     L(last_4x_vec_or_less_cmpeq)
        /* Align data to VEC_SIZE * 4 - 1 for the loop and readjust
           length.  */
@@ -221,6 +231,10 @@ L(cross_page_continue):
        movl    %edi, %ecx
        orq     $(VEC_SIZE * 4 - 1), %rdi
        andl    $(VEC_SIZE * 4 - 1), %ecx
+#  ifdef USE_AS_WMEMCHR
+       /* NB: Divide bytes by 4 to get the wchar_t count.  */
+       sarl    $2, %ecx
+#  endif
        addq    %rcx, %rdx
 # else
        /* Align data to VEC_SIZE * 4 - 1 for loop.  */
@@ -250,15 +264,19 @@ L(loop_4x_vec):
 
        subq    $-(VEC_SIZE * 4), %rdi
 
-       subq    $(VEC_SIZE * 4), %rdx
+       subq    $(CHAR_PER_VEC * 4), %rdx
        ja      L(loop_4x_vec)
 
-       /* Fall through into less than 4 remaining vectors of length case.
-        */
+       /* Fall through into less than 4 remaining vectors of length
+          case.  */
        VPCMPEQ (VEC_SIZE * 0 + 1)(%rdi), %ymm0, %ymm1
        vpmovmskb %ymm1, %eax
        .p2align 4
 L(last_4x_vec_or_less):
+#  ifdef USE_AS_WMEMCHR
+       /* NB: Multiply length by 4 to get byte count.  */
+       sall    $2, %edx
+#  endif
        /* Check if first VEC contained match.  */
        testl   %eax, %eax
        jnz     L(first_vec_x1_check)
@@ -355,6 +373,10 @@ L(last_vec_x2_return):
 L(last_4x_vec_or_less_cmpeq):
        VPCMPEQ (VEC_SIZE * 4 + 1)(%rdi), %ymm0, %ymm1
        vpmovmskb %ymm1, %eax
+#  ifdef USE_AS_WMEMCHR
+       /* NB: Multiply length by 4 to get byte count.  */
+       sall    $2, %edx
+#  endif
        subq    $-(VEC_SIZE * 4), %rdi
        /* Check first VEC regardless.  */
        testl   %eax, %eax