lib: zstd: Improve decode performance
authorDongwoo Lee <dwoo08.lee@samsung.com>
Thu, 4 Jun 2020 04:54:44 +0000 (13:54 +0900)
committerHoegeun Kwon <hoegeun.kwon@samsung.com>
Thu, 3 Aug 2023 08:46:03 +0000 (17:46 +0900)
To speed up decode performance, optimizations are brought from zstd
github repository and ported as kernel-style.

Since the low-level algorithm is preferred in Linux due to
compression/decompression performance (default level 3), the
optimizations for the low levels were chosen as follows:

  [1] lib: zstd: Speed up single segment zstd_fast by 5%
      (https://github.com/facebook/zstd/pull/1562/commits/95624b77e477752b3c380c22be7bcf67f06c9934)
  [2] perf improvements for zstd decode
      (https://github.com/facebook/zstd/pull/1668/commits/29d1e81bbdfc21085529623e7bc5abcb3e1627ae)
  [3] updated double_fast complementary insertion
      (https://github.com/facebook/zstd/pull/1681/commits/d1327738c277643f09c972a407083ad73c8ecf7b)
  [4] Improvements in zstd decode performance
      (https://github.com/facebook/zstd/pull/1756/commits/b83059958246dfcb5b91af9c187fad8c706869a0)
  [5] Optimize decompression and fix wildcopy overread
      (https://github.com/facebook/zstd/pull/1804/commits/efd37a64eaff5a0a26ae2566fdb45dc4a0c91673)
  [6] Improve ZSTD_highbit32's codegen
      (https://github.com/facebook/zstd/commit/a07da7b0db682c170a330a8c21585be3d68275fa)
  [7] Optimize decompression speed for gcc and clang (#1892)
      (https://github.com/facebook/zstd/commit/718f00ff6fe42db7e6ba09a7f7992b3e85283f77)
  [8] Fix performance regression on aarch64 with clang
      (https://github.com/facebook/zstd/pull/1973/commits/cb2abc3dbe010113d9e00ca3b612bf61983145a2)

Change-Id: Ia2cf120879a415988dbbc2fce59a994915c8c77c
Signed-off-by: Dongwoo Lee <dwoo08.lee@samsung.com>
lib/zstd/bitstream.h
lib/zstd/compress.c
lib/zstd/decompress.c
lib/zstd/huf_decompress.c
lib/zstd/zstd_internal.h
lib/zstd/zstd_opt.h

index 5d6343c..282471b 100644 (file)
@@ -145,7 +145,7 @@ ZSTD_STATIC size_t BIT_readBitsFast(BIT_DStream_t *bitD, unsigned nbBits);
 /*-**************************************************************
 *  Internal functions
 ****************************************************************/
-ZSTD_STATIC unsigned BIT_highbit32(register U32 val) { return 31 - __builtin_clz(val); }
+ZSTD_STATIC unsigned BIT_highbit32(register U32 val) { return __builtin_clz(val) ^ 31; }
 
 /*=====    Local Constants   =====*/
 static const unsigned BIT_mask[] = {0,       1,       3,       7,      0xF,      0x1F,     0x3F,     0x7F,      0xFF,
@@ -334,6 +334,24 @@ ZSTD_STATIC size_t BIT_readBitsFast(BIT_DStream_t *bitD, U32 nbBits)
        return value;
 }
 
+/*! BIT_reloadDStreamFast() :
+ *  Similar to BIT_reloadDStream(), but with two differences:
+ *  1. bitsConsumed <= sizeof(bitD->bitContainer)*8 must hold!
+ *  2. Returns BIT_DStream_overflow when bitD->ptr gets within sizeof(bitD->bitContainer)
+ *     bytes of bitD->start; at this point you must use BIT_reloadDStream() to reload.
+ */
+ZSTD_STATIC BIT_DStream_status BIT_reloadDStreamFast(BIT_DStream_t *bitD)
+{
+       if (unlikely(bitD->ptr < bitD->start + sizeof(bitD->bitContainer)))
+               return BIT_DStream_overflow;
+
+       bitD->ptr -= bitD->bitsConsumed >> 3;
+       bitD->bitsConsumed &= 7;
+       bitD->bitContainer = ZSTD_readLEST(bitD->ptr);
+
+       return BIT_DStream_unfinished;
+}
+
 /*! BIT_reloadDStream() :
 *   Refill `bitD` from buffer previously set in BIT_initDStream() .
 *   This function is safe, it guarantees it will not read beyond src buffer.
@@ -345,10 +363,7 @@ ZSTD_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t *bitD)
                return BIT_DStream_overflow;
 
        if (bitD->ptr >= bitD->start + sizeof(bitD->bitContainer)) {
-               bitD->ptr -= bitD->bitsConsumed >> 3;
-               bitD->bitsConsumed &= 7;
-               bitD->bitContainer = ZSTD_readLEST(bitD->ptr);
-               return BIT_DStream_unfinished;
+               return BIT_reloadDStreamFast(bitD);
        }
        if (bitD->ptr == bitD->start) {
                if (bitD->bitsConsumed < sizeof(bitD->bitContainer) * 8)
index b080264..4016b5f 100644 (file)
@@ -853,15 +853,46 @@ ZSTD_STATIC size_t ZSTD_compressSequences(ZSTD_CCtx *zc, void *dst, size_t dstCa
        return cSize;
 }
 
+/*! ZSTD_safecopyLiterals() :
+ *  memcpy() function that won't read more than WILDCOPY_OVERLENGTH bytes past ilimit_w.
+ *  Only called when the sequence ends past ilimit_w, so it only needs to be optimized for single
+ *  large copies.
+ */
+static void ZSTD_safecopyLiterals(BYTE *op, BYTE const *ip, BYTE const * const iend, BYTE const *ilimit_w)
+{
+       if (ip <= ilimit_w) {
+               ZSTD_wildcopy(op, ip, ilimit_w - ip, ZSTD_no_overlap);
+
+               op += ilimit_w - ip;
+               ip = ilimit_w;
+       }
+       while (ip < iend)
+               *op++ = *ip++;
+}
+
 /*! ZSTD_storeSeq() :
        Store a sequence (literal length, literals, offset code and match length code) into seqStore_t.
        `offsetCode` : distance to match, or 0 == repCode.
        `matchCode` : matchLength - MINMATCH
 */
-ZSTD_STATIC void ZSTD_storeSeq(seqStore_t *seqStorePtr, size_t litLength, const void *literals, U32 offsetCode, size_t matchCode)
+ZSTD_STATIC void ZSTD_storeSeq(seqStore_t *seqStorePtr, size_t litLength, const void *literals, const BYTE *litLimit, U32 offsetCode, size_t matchCode)
 {
+       BYTE const * const litLimit_w = litLimit - WILDCOPY_OVERLENGTH;
+       BYTE const * const litEnd = literals + litLength;
        /* copy Literals */
-       ZSTD_wildcopy(seqStorePtr->lit, literals, litLength);
+       /* We are guaranteed at least 8 bytes of literals space because of HASH_READ_SIZE and
+        * MINMATCH.
+        */
+       if (litEnd <= litLimit_w) {
+               /* Common case we can use wildcopy.
+                * First copy 16 bytes, because literals are likely short.
+                */
+               ZSTD_copy16(seqStorePtr->lit, literals);
+               if (litLength > 16)
+                       ZSTD_wildcopy(seqStorePtr->lit+16, literals+16, (ptrdiff_t)litLength-16, ZSTD_no_overlap);
+       } else {
+               ZSTD_safecopyLiterals(seqStorePtr->lit, literals, litEnd, litLimit_w);
+       }
        seqStorePtr->lit += litLength;
 
        /* literal Length */
@@ -1010,9 +1041,12 @@ void ZSTD_compressBlock_fast_generic(ZSTD_CCtx *cctx, const void *src, size_t sr
        U32 *const hashTable = cctx->hashTable;
        U32 const hBits = cctx->params.cParams.hashLog;
        seqStore_t *seqStorePtr = &(cctx->seqStore);
+       size_t const stepSize = 2;
        const BYTE *const base = cctx->base;
        const BYTE *const istart = (const BYTE *)src;
-       const BYTE *ip = istart;
+       /* We check ip0 (ip + 0) and ip1 (ip + 1) each loop */
+       const BYTE *ip0 = istart;
+       const BYTE *ip1;
        const BYTE *anchor = istart;
        const U32 lowestIndex = cctx->dictLimit;
        const BYTE *const lowest = base + lowestIndex;
@@ -1022,9 +1056,10 @@ void ZSTD_compressBlock_fast_generic(ZSTD_CCtx *cctx, const void *src, size_t sr
        U32 offsetSaved = 0;
 
        /* init */
-       ip += (ip == lowest);
+       ip0 += (ip0 == lowest);
+       ip1 = ip0 + 1;
        {
-               U32 const maxRep = (U32)(ip - lowest);
+               U32 const maxRep = (U32)(ip0 - lowest);
                if (offset_2 > maxRep)
                        offsetSaved = offset_2, offset_2 = 0;
                if (offset_1 > maxRep)
@@ -1032,58 +1067,88 @@ void ZSTD_compressBlock_fast_generic(ZSTD_CCtx *cctx, const void *src, size_t sr
        }
 
        /* Main Search Loop */
-       while (ip < ilimit) { /* < instead of <=, because repcode check at (ip+1) */
+       while (ip1 < ilimit) { /* < instead of <=, because check at (ip0+2) */
                size_t mLength;
-               size_t const h = ZSTD_hashPtr(ip, hBits, mls);
-               U32 const curr = (U32)(ip - base);
-               U32 const matchIndex = hashTable[h];
-               const BYTE *match = base + matchIndex;
-               hashTable[h] = curr; /* update hash table */
-
-               if ((offset_1 > 0) & (ZSTD_read32(ip + 1 - offset_1) == ZSTD_read32(ip + 1))) {
-                       mLength = ZSTD_count(ip + 1 + 4, ip + 1 + 4 - offset_1, iend) + 4;
-                       ip++;
-                       ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, 0, mLength - MINMATCH);
-               } else {
-                       U32 offset;
-                       if ((matchIndex <= lowestIndex) || (ZSTD_read32(match) != ZSTD_read32(ip))) {
-                               ip += ((ip - anchor) >> g_searchStrength) + 1;
-                               continue;
-                       }
-                       mLength = ZSTD_count(ip + 4, match + 4, iend) + 4;
-                       offset = (U32)(ip - match);
-                       while (((ip > anchor) & (match > lowest)) && (ip[-1] == match[-1])) {
-                               ip--;
-                               match--;
-                               mLength++;
-                       } /* catch up */
-                       offset_2 = offset_1;
-                       offset_1 = offset;
+               BYTE const *ip2 = ip0 + 2;
+               size_t const h0 = ZSTD_hashPtr(ip0, hBits, mls);
+               U32 const val0 = ZSTD_read32(ip0);
+               size_t const h1 = ZSTD_hashPtr(ip1, hBits, mls);
+               U32 const val1 = ZSTD_read32(ip1);
+               U32 const current0 = (U32)(ip0 - base);
+               U32 const current1 = (U32)(ip1 - base);
+               U32 const matchIndex0 = hashTable[h0];
+               U32 const matchIndex1 = hashTable[h1];
+               BYTE const *repMatch = ip2 - offset_1;
+               const BYTE *match0 = base + matchIndex0;
+               const BYTE *match1 = base + matchIndex1;
+               U32 offcode;
+
+               hashTable[h0] = current0;
+               hashTable[h1] = current1;
+
+               if ((offset_1 > 0) & (ZSTD_read32(repMatch) == ZSTD_read32(ip2))) {
+                       mLength = ip2[-1] == repMatch[-1] ? 1 : 0;
+                       ip0 = ip2 - mLength;
+                       match0 = repMatch - mLength;
+                       offcode = 0;
+                       goto _match;
+               }
+               if ((matchIndex0 > lowestIndex) && ZSTD_read32(match0) == val0) {
+                       goto _offset;
+               }
+               if ((matchIndex1 > lowestIndex) && ZSTD_read32(match1) == val1) {
+                       ip0 = ip1;
+                       match0 = match1;
+                       goto _offset;
+               }
+               {
+                       size_t const step = ((ip0 - anchor) >> (g_searchStrength - 1)) + stepSize;
 
-                       ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, offset + ZSTD_REP_MOVE, mLength - MINMATCH);
+                       ip0 += step;
+                       ip1 += step;
+                       continue;
                }
+_offset:       /* Requires: ip0, match0 */
+               /* Compute the offset code */
+               offset_2 = offset_1;
+               offset_1 = (U32)(ip0 - match0);
+               offcode = offset_1 + ZSTD_REP_MOVE;
+               mLength = 0;
+               /* Count the backwards match length */
+               while (((ip0 > anchor) & (match0 > lowest)) && (ip0[-1] == match0[-1])) {
+                       ip0--;
+                       match0--;
+                       mLength++;
+               } /* catch up */
+
+_match:                /* Requires: ip0, match0, offcode */
+               /* Count the forward length */
+               mLength += ZSTD_count(ip0 + mLength + 4, match0 + mLength + 4, iend) + 4;
+               ZSTD_storeSeq(seqStorePtr, ip0-anchor, anchor, iend, offcode, mLength - MINMATCH);
 
                /* match found */
-               ip += mLength;
-               anchor = ip;
+               ip0 += mLength;
+               anchor = ip0;
+               ip1 = ip0 + 1;
 
-               if (ip <= ilimit) {
+               if (ip0 <= ilimit) {
                        /* Fill Table */
-                       hashTable[ZSTD_hashPtr(base + curr + 2, hBits, mls)] = curr + 2; /* here because curr+2 could be > iend-8 */
-                       hashTable[ZSTD_hashPtr(ip - 2, hBits, mls)] = (U32)(ip - 2 - base);
+                       hashTable[ZSTD_hashPtr(base + current0 + 2, hBits, mls)] = current0 + 2; /* here because curr+2 could be > iend-8 */
+                       hashTable[ZSTD_hashPtr(ip0 - 2, hBits, mls)] = (U32)(ip0 - 2 - base);
                        /* check immediate repcode */
-                       while ((ip <= ilimit) && ((offset_2 > 0) & (ZSTD_read32(ip) == ZSTD_read32(ip - offset_2)))) {
+                       while ((ip0 <= ilimit) && ((offset_2 > 0) & (ZSTD_read32(ip0) == ZSTD_read32(ip0 - offset_2)))) {
                                /* store sequence */
-                               size_t const rLength = ZSTD_count(ip + 4, ip + 4 - offset_2, iend) + 4;
+                               size_t const rLength = ZSTD_count(ip0 + 4, ip0 + 4 - offset_2, iend) + 4;
                                {
                                        U32 const tmpOff = offset_2;
                                        offset_2 = offset_1;
                                        offset_1 = tmpOff;
                                } /* swap offset_2 <=> offset_1 */
-                               hashTable[ZSTD_hashPtr(ip, hBits, mls)] = (U32)(ip - base);
-                               ZSTD_storeSeq(seqStorePtr, 0, anchor, 0, rLength - MINMATCH);
-                               ip += rLength;
-                               anchor = ip;
+                               hashTable[ZSTD_hashPtr(ip0, hBits, mls)] = (U32)(ip0 - base);
+                               ip0 += rLength;
+                               ip1 = ip0 + 1;
+                               ZSTD_storeSeq(seqStorePtr, 0, anchor, iend, 0, rLength - MINMATCH);
+                               anchor = ip0;
                                continue; /* faster when present ... (?) */
                        }
                }
@@ -1150,7 +1215,7 @@ static void ZSTD_compressBlock_fast_extDict_generic(ZSTD_CCtx *ctx, const void *
                        const BYTE *repMatchEnd = repIndex < dictLimit ? dictEnd : iend;
                        mLength = ZSTD_count_2segments(ip + 1 + EQUAL_READ32, repMatch + EQUAL_READ32, iend, repMatchEnd, lowPrefixPtr) + EQUAL_READ32;
                        ip++;
-                       ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, 0, mLength - MINMATCH);
+                       ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, iend, 0, mLength - MINMATCH);
                } else {
                        if ((matchIndex < lowestIndex) || (ZSTD_read32(match) != ZSTD_read32(ip))) {
                                ip += ((ip - anchor) >> g_searchStrength) + 1;
@@ -1169,7 +1234,7 @@ static void ZSTD_compressBlock_fast_extDict_generic(ZSTD_CCtx *ctx, const void *
                                offset = curr - matchIndex;
                                offset_2 = offset_1;
                                offset_1 = offset;
-                               ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, offset + ZSTD_REP_MOVE, mLength - MINMATCH);
+                               ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, iend, offset + ZSTD_REP_MOVE, mLength - MINMATCH);
                        }
                }
 
@@ -1194,7 +1259,7 @@ static void ZSTD_compressBlock_fast_extDict_generic(ZSTD_CCtx *ctx, const void *
                                        U32 tmpOffset = offset_2;
                                        offset_2 = offset_1;
                                        offset_1 = tmpOffset; /* swap offset_2 <=> offset_1 */
-                                       ZSTD_storeSeq(seqStorePtr, 0, anchor, 0, repLength2 - MINMATCH);
+                                       ZSTD_storeSeq(seqStorePtr, 0, anchor, iend, 0, repLength2 - MINMATCH);
                                        hashTable[ZSTD_hashPtr(ip, hBits, mls)] = curr2;
                                        ip += repLength2;
                                        anchor = ip;
@@ -1294,7 +1359,7 @@ void ZSTD_compressBlock_doubleFast_generic(ZSTD_CCtx *cctx, const void *src, siz
                if ((offset_1 > 0) & (ZSTD_read32(ip + 1 - offset_1) == ZSTD_read32(ip + 1))) { /* note : by construction, offset_1 <= curr */
                        mLength = ZSTD_count(ip + 1 + 4, ip + 1 + 4 - offset_1, iend) + 4;
                        ip++;
-                       ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, 0, mLength - MINMATCH);
+                       ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, iend, 0, mLength - MINMATCH);
                } else {
                        U32 offset;
                        if ((matchIndexL > lowestIndex) && (ZSTD_read64(matchLong) == ZSTD_read64(ip))) {
@@ -1336,7 +1401,7 @@ void ZSTD_compressBlock_doubleFast_generic(ZSTD_CCtx *cctx, const void *src, siz
                        offset_2 = offset_1;
                        offset_1 = offset;
 
-                       ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, offset + ZSTD_REP_MOVE, mLength - MINMATCH);
+                       ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, iend, offset + ZSTD_REP_MOVE, mLength - MINMATCH);
                }
 
                /* match found */
@@ -1345,10 +1410,14 @@ void ZSTD_compressBlock_doubleFast_generic(ZSTD_CCtx *cctx, const void *src, siz
 
                if (ip <= ilimit) {
                        /* Fill Table */
-                       hashLong[ZSTD_hashPtr(base + curr + 2, hBitsL, 8)] = hashSmall[ZSTD_hashPtr(base + curr + 2, hBitsS, mls)] =
-                           curr + 2; /* here because curr+2 could be > iend-8 */
-                       hashLong[ZSTD_hashPtr(ip - 2, hBitsL, 8)] = hashSmall[ZSTD_hashPtr(ip - 2, hBitsS, mls)] = (U32)(ip - 2 - base);
+                       {
+                               U32 const insert_idx = curr + 2;
 
+                               hashLong[ZSTD_hashPtr(base + insert_idx, hBitsL, 8)] = insert_idx;
+                               hashLong[ZSTD_hashPtr(ip - 2, hBitsL, 8)] = (U32)(ip - 2 - base);
+                               hashSmall[ZSTD_hashPtr(base + insert_idx, hBitsS, mls)] = insert_idx;
+                               hashSmall[ZSTD_hashPtr(ip - 1, hBitsS, mls)] = (U32)(ip - 1 - base);
+                       }
                        /* check immediate repcode */
                        while ((ip <= ilimit) && ((offset_2 > 0) & (ZSTD_read32(ip) == ZSTD_read32(ip - offset_2)))) {
                                /* store sequence */
@@ -1360,7 +1429,7 @@ void ZSTD_compressBlock_doubleFast_generic(ZSTD_CCtx *cctx, const void *src, siz
                                } /* swap offset_2 <=> offset_1 */
                                hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = (U32)(ip - base);
                                hashLong[ZSTD_hashPtr(ip, hBitsL, 8)] = (U32)(ip - base);
-                               ZSTD_storeSeq(seqStorePtr, 0, anchor, 0, rLength - MINMATCH);
+                               ZSTD_storeSeq(seqStorePtr, 0, anchor, iend, 0, rLength - MINMATCH);
                                ip += rLength;
                                anchor = ip;
                                continue; /* faster when present ... (?) */
@@ -1437,7 +1506,7 @@ static void ZSTD_compressBlock_doubleFast_extDict_generic(ZSTD_CCtx *ctx, const
                        const BYTE *repMatchEnd = repIndex < dictLimit ? dictEnd : iend;
                        mLength = ZSTD_count_2segments(ip + 1 + 4, repMatch + 4, iend, repMatchEnd, lowPrefixPtr) + 4;
                        ip++;
-                       ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, 0, mLength - MINMATCH);
+                       ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, iend, 0, mLength - MINMATCH);
                } else {
                        if ((matchLongIndex > lowestIndex) && (ZSTD_read64(matchLong) == ZSTD_read64(ip))) {
                                const BYTE *matchEnd = matchLongIndex < dictLimit ? dictEnd : iend;
@@ -1452,7 +1521,7 @@ static void ZSTD_compressBlock_doubleFast_extDict_generic(ZSTD_CCtx *ctx, const
                                } /* catch up */
                                offset_2 = offset_1;
                                offset_1 = offset;
-                               ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, offset + ZSTD_REP_MOVE, mLength - MINMATCH);
+                               ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, iend, offset + ZSTD_REP_MOVE, mLength - MINMATCH);
 
                        } else if ((matchIndex > lowestIndex) && (ZSTD_read32(match) == ZSTD_read32(ip))) {
                                size_t const h3 = ZSTD_hashPtr(ip + 1, hBitsL, 8);
@@ -1485,7 +1554,7 @@ static void ZSTD_compressBlock_doubleFast_extDict_generic(ZSTD_CCtx *ctx, const
                                }
                                offset_2 = offset_1;
                                offset_1 = offset;
-                               ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, offset + ZSTD_REP_MOVE, mLength - MINMATCH);
+                               ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, iend, offset + ZSTD_REP_MOVE, mLength - MINMATCH);
 
                        } else {
                                ip += ((ip - anchor) >> g_searchStrength) + 1;
@@ -1499,10 +1568,14 @@ static void ZSTD_compressBlock_doubleFast_extDict_generic(ZSTD_CCtx *ctx, const
 
                if (ip <= ilimit) {
                        /* Fill Table */
-                       hashSmall[ZSTD_hashPtr(base + curr + 2, hBitsS, mls)] = curr + 2;
-                       hashLong[ZSTD_hashPtr(base + curr + 2, hBitsL, 8)] = curr + 2;
-                       hashSmall[ZSTD_hashPtr(ip - 2, hBitsS, mls)] = (U32)(ip - 2 - base);
-                       hashLong[ZSTD_hashPtr(ip - 2, hBitsL, 8)] = (U32)(ip - 2 - base);
+                       {
+                               U32 const insert_idx = curr + 2;
+
+                               hashLong[ZSTD_hashPtr(base + insert_idx, hBitsL, 8)] = insert_idx;
+                               hashLong[ZSTD_hashPtr(ip - 2, hBitsL, 8)] = (U32)(ip - 2 - base);
+                               hashSmall[ZSTD_hashPtr(base + insert_idx, hBitsS, mls)] = insert_idx;
+                               hashSmall[ZSTD_hashPtr(ip - 1, hBitsS, mls)] = (U32)(ip - 1 - base);
+                       }
                        /* check immediate repcode */
                        while (ip <= ilimit) {
                                U32 const curr2 = (U32)(ip - base);
@@ -1516,7 +1589,7 @@ static void ZSTD_compressBlock_doubleFast_extDict_generic(ZSTD_CCtx *ctx, const
                                        U32 tmpOffset = offset_2;
                                        offset_2 = offset_1;
                                        offset_1 = tmpOffset; /* swap offset_2 <=> offset_1 */
-                                       ZSTD_storeSeq(seqStorePtr, 0, anchor, 0, repLength2 - MINMATCH);
+                                       ZSTD_storeSeq(seqStorePtr, 0, anchor, iend, 0, repLength2 - MINMATCH);
                                        hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = curr2;
                                        hashLong[ZSTD_hashPtr(ip, hBitsL, 8)] = curr2;
                                        ip += repLength2;
@@ -2016,7 +2089,7 @@ void ZSTD_compressBlock_lazy_generic(ZSTD_CCtx *ctx, const void *src, size_t src
 _storeSequence:
                {
                        size_t const litLength = start - anchor;
-                       ZSTD_storeSeq(seqStorePtr, litLength, anchor, (U32)offset, matchLength - MINMATCH);
+                       ZSTD_storeSeq(seqStorePtr, litLength, anchor, iend, (U32)offset, matchLength - MINMATCH);
                        anchor = ip = start + matchLength;
                }
 
@@ -2027,7 +2100,7 @@ _storeSequence:
                        offset = offset_2;
                        offset_2 = offset_1;
                        offset_1 = (U32)offset; /* swap repcodes */
-                       ZSTD_storeSeq(seqStorePtr, 0, anchor, 0, matchLength - MINMATCH);
+                       ZSTD_storeSeq(seqStorePtr, 0, anchor, iend, 0, matchLength - MINMATCH);
                        ip += matchLength;
                        anchor = ip;
                        continue; /* faster when present ... (?) */
@@ -2210,7 +2283,7 @@ void ZSTD_compressBlock_lazy_extDict_generic(ZSTD_CCtx *ctx, const void *src, si
        /* store sequence */
        _storeSequence : {
                size_t const litLength = start - anchor;
-               ZSTD_storeSeq(seqStorePtr, litLength, anchor, (U32)offset, matchLength - MINMATCH);
+               ZSTD_storeSeq(seqStorePtr, litLength, anchor, iend, (U32)offset, matchLength - MINMATCH);
                anchor = ip = start + matchLength;
        }
 
@@ -2228,7 +2301,7 @@ void ZSTD_compressBlock_lazy_extDict_generic(ZSTD_CCtx *ctx, const void *src, si
                                        offset = offset_2;
                                        offset_2 = offset_1;
                                        offset_1 = (U32)offset; /* swap offset history */
-                                       ZSTD_storeSeq(seqStorePtr, 0, anchor, 0, matchLength - MINMATCH);
+                                       ZSTD_storeSeq(seqStorePtr, 0, anchor, iend, 0, matchLength - MINMATCH);
                                        ip += matchLength;
                                        anchor = ip;
                                        continue; /* faster when present ... (?) */
index 66cd487..70b5e42 100644 (file)
@@ -876,40 +876,113 @@ typedef struct {
        uPtrDiff gotoDict;
 } seqState_t;
 
+/*! ZSTD_overlapCopy8() :
+ *  Copies 8 bytes from ip to op and updates op and ip where ip <= op.
+ *  If the offset is < 8 then the offset is spread to at least 8 bytes.
+ *
+ *  Precondition: *ip <= *op
+ *  Postcondition: *op - *ip >= 8
+ */
+static inline __attribute__((always_inline)) void ZSTD_overlapCopy8(BYTE **op, BYTE const **ip, size_t offset)
+{
+       if (offset < 8) {
+               /* close range match, overlap */
+               static const U32 dec32table[] = { 0, 1, 2, 1, 4, 4, 4, 4 };   /* added */
+               static const int dec64table[] = { 8, 8, 8, 7, 8, 9, 10, 11 };   /* subtracted */
+               int const sub2 = dec64table[offset];
+               (*op)[0] = (*ip)[0];
+               (*op)[1] = (*ip)[1];
+               (*op)[2] = (*ip)[2];
+               (*op)[3] = (*ip)[3];
+               *ip += dec32table[offset];
+               ZSTD_copy4(*op+4, *ip);
+               *ip -= sub2;
+       } else {
+               ZSTD_copy8(*op, *ip);
+       }
+       *ip += 8;
+       *op += 8;
+}
+
+/*! ZSTD_safecopy() :
+ *  Specialized version of memcpy() that is allowed to READ up to WILDCOPY_OVERLENGTH past the input buffer
+ *  and write up to 16 bytes past oend_w (op >= oend_w is allowed).
+ *  This function is only called in the uncommon case where the sequence is near the end of the block. It
+ *  should be fast for a single long sequence, but can be slow for several short sequences.
+ *
+ *  @param ovtype controls the overlap detection
+ *         - ZSTD_no_overlap: The source and destination are guaranteed to be at least WILDCOPY_VECLEN bytes apart.
+ *         - ZSTD_overlap_src_before_dst: The src and dst may overlap and may be any distance apart.
+ *           The src buffer must be before the dst buffer.
+ */
+static void ZSTD_safecopy(BYTE *op, BYTE *const oend_w, BYTE const *ip, ptrdiff_t length, enum ZSTD_overlap_e ovtype)
+{
+       ptrdiff_t const diff = op - ip;
+       BYTE *const oend = op + length;
+
+       if (length < 8) {
+               /* Handle short lengths. */
+               while (op < oend)
+                       *op++ = *ip++;
+               return;
+       }
+       if (ovtype == ZSTD_overlap_src_before_dst) {
+               /* Copy 8 bytes and ensure the offset >= 8 when there can be overlap. */
+               ZSTD_overlapCopy8(&op, &ip, diff);
+       }
+
+       if (oend <= oend_w) {
+               /* No risk of overwrite. */
+               ZSTD_wildcopy(op, ip, length, ovtype);
+               return;
+       }
+       if (op <= oend_w) {
+               /* Wildcopy until we get close to the end. */
+               ZSTD_wildcopy(op, ip, oend_w - op, ovtype);
+               ip += oend_w - op;
+               op = oend_w;
+       }
+       /* Handle the leftovers. */
+       while (op < oend)
+               *op++ = *ip++;
+}
+
+/* ZSTD_execSequenceEnd():
+ * This version handles cases that are near the end of the output buffer. It requires
+ * more careful checks to make sure there is no overflow. By separating out these hard
+ * and unlikely cases, we can speed up the common cases.
+ *
+ * NOTE: This function needs to be fast for a single long sequence, but doesn't need
+ * to be optimized for many small sequences, since those fall into ZSTD_execSequence().
+ */
 FORCE_NOINLINE
-size_t ZSTD_execSequenceLast7(BYTE *op, BYTE *const oend, seq_t sequence, const BYTE **litPtr, const BYTE *const litLimit, const BYTE *const base,
-                             const BYTE *const vBase, const BYTE *const dictEnd)
+size_t ZSTD_execSequenceEnd(BYTE *op, BYTE *const oend, seq_t sequence, const BYTE **litPtr, const BYTE *const litLimit, const BYTE *const prefixStart,
+                           const BYTE *const virtualStart, const BYTE *const dictEnd)
 {
        BYTE *const oLitEnd = op + sequence.litLength;
        size_t const sequenceLength = sequence.litLength + sequence.matchLength;
        BYTE *const oMatchEnd = op + sequenceLength; /* risk : address space overflow (32-bits) */
-       BYTE *const oend_w = oend - WILDCOPY_OVERLENGTH;
        const BYTE *const iLitEnd = *litPtr + sequence.litLength;
        const BYTE *match = oLitEnd - sequence.offset;
+       BYTE *const oend_w = oend - WILDCOPY_OVERLENGTH;
 
        /* check */
        if (oMatchEnd > oend)
                return ERROR(dstSize_tooSmall); /* last match must start at a minimum distance of WILDCOPY_OVERLENGTH from oend */
        if (iLitEnd > litLimit)
                return ERROR(corruption_detected); /* over-read beyond lit buffer */
-       if (oLitEnd <= oend_w)
-               return ERROR(GENERIC); /* Precondition */
 
        /* copy literals */
-       if (op < oend_w) {
-               ZSTD_wildcopy(op, *litPtr, oend_w - op);
-               *litPtr += oend_w - op;
-               op = oend_w;
-       }
-       while (op < oLitEnd)
-               *op++ = *(*litPtr)++;
+       ZSTD_safecopy(op, oend_w, *litPtr, sequence.litLength, ZSTD_no_overlap);
+       op = oLitEnd;
+       *litPtr = iLitEnd;
 
        /* copy Match */
-       if (sequence.offset > (size_t)(oLitEnd - base)) {
+       if (sequence.offset > (size_t)(oLitEnd - prefixStart)) {
                /* offset beyond prefix */
-               if (sequence.offset > (size_t)(oLitEnd - vBase))
+               if (sequence.offset > (size_t)(oLitEnd - virtualStart))
                        return ERROR(corruption_detected);
-               match = dictEnd - (base - match);
+               match = dictEnd - (prefixStart - match);
                if (match + sequence.matchLength <= dictEnd) {
                        memmove(oLitEnd, match, sequence.matchLength);
                        return sequenceLength;
@@ -920,15 +993,16 @@ size_t ZSTD_execSequenceLast7(BYTE *op, BYTE *const oend, seq_t sequence, const
                        memmove(oLitEnd, match, length1);
                        op = oLitEnd + length1;
                        sequence.matchLength -= length1;
-                       match = base;
+                       match = prefixStart;
                }
        }
-       while (op < oMatchEnd)
-               *op++ = *match++;
+       ZSTD_safecopy(op, oend_w, match, sequence.matchLength, ZSTD_overlap_src_before_dst);
        return sequenceLength;
 }
 
-static seq_t ZSTD_decodeSequence(seqState_t *seqState)
+enum ZSTD_prefetch_e { ZSTD_p_noPrefetch = 0, ZSTD_p_prefetch = 1 };
+
+static seq_t ZSTD_decodeSequence(seqState_t *seqState, int const longOffsets, const enum ZSTD_prefetch_e prefetch)
 {
        seq_t seq;
 
@@ -955,30 +1029,47 @@ static seq_t ZSTD_decodeSequence(seqState_t *seqState)
        /* sequence */
        {
                size_t offset;
-               if (!ofCode)
-                       offset = 0;
-               else {
-                       offset = OF_base[ofCode] + BIT_readBitsFast(&seqState->DStream, ofBits); /* <=  (ZSTD_WINDOWLOG_MAX-1) bits */
-                       if (ZSTD_32bits())
-                               BIT_reloadDStream(&seqState->DStream);
-               }
 
-               if (ofCode <= 1) {
-                       offset += (llCode == 0);
-                       if (offset) {
-                               size_t temp = (offset == 3) ? seqState->prevOffset[0] - 1 : seqState->prevOffset[offset];
-                               temp += !temp; /* 0 is not valid; input is corrupted; force offset to 1 */
-                               if (offset != 1)
-                                       seqState->prevOffset[2] = seqState->prevOffset[1];
-                               seqState->prevOffset[1] = seqState->prevOffset[0];
-                               seqState->prevOffset[0] = offset = temp;
+               if (ofCode > 1) {
+                       if (longOffsets) {
+                               int const extraBits = ofBits - MIN(ofBits, STREAM_ACCUMULATOR_MIN);
+                               offset = OF_base[ofCode] + (BIT_readBitsFast(&seqState->DStream, ofBits - extraBits) << extraBits);
+                               if (ZSTD_32bits() || extraBits)
+                                       BIT_reloadDStream(&seqState->DStream);
+                               if (extraBits)
+                                       offset += BIT_readBitsFast(&seqState->DStream, extraBits);
                        } else {
-                               offset = seqState->prevOffset[0];
+                               offset = OF_base[ofCode] + BIT_readBitsFast(&seqState->DStream, ofBits); /* <=  (ZSTD_WINDOWLOG_MAX-1) bits */
+                               if (ZSTD_32bits())
+                                       BIT_reloadDStream(&seqState->DStream);
                        }
-               } else {
+
                        seqState->prevOffset[2] = seqState->prevOffset[1];
                        seqState->prevOffset[1] = seqState->prevOffset[0];
                        seqState->prevOffset[0] = offset;
+
+               } else {
+                       U32 const ll0 = (llCode == 0);
+
+                       if (likely((ofCode == 0))) {
+                               if (likely(!ll0))
+                                       offset = seqState->prevOffset[0];
+                               else {
+                                       offset = seqState->prevOffset[1];
+                                       seqState->prevOffset[1] = seqState->prevOffset[0];
+                                       seqState->prevOffset[0] = offset;
+                               }
+                       } else {
+                               offset = OF_base[ofCode] + ll0 + BIT_readBitsFast(&seqState->DStream, 1);
+                               {
+                                       size_t temp = (offset == 3) ? seqState->prevOffset[0] - 1 : seqState->prevOffset[offset];
+                                       temp += !temp; /* 0 is not valid; input is corrupted; force offset to 1 */
+                                       if (offset != 1)
+                                               seqState->prevOffset[2] = seqState->prevOffset[1];
+                                       seqState->prevOffset[1] = seqState->prevOffset[0];
+                                       seqState->prevOffset[0] = offset = temp;
+                               }
+                       }
                }
                seq.offset = offset;
        }
@@ -991,15 +1082,30 @@ static seq_t ZSTD_decodeSequence(seqState_t *seqState)
        if (ZSTD_32bits() || (totalBits > 64 - 7 - (LLFSELog + MLFSELog + OffFSELog)))
                BIT_reloadDStream(&seqState->DStream);
 
-       /* ANS state update */
+       seq.match = NULL;
+
+       if (prefetch == ZSTD_p_prefetch) {
+               size_t const pos = seqState->pos + seq.litLength;
+               const BYTE *const matchBase = (seq.offset > pos) ? seqState->gotoDict + seqState->base : seqState->base;
+               seq.match = matchBase + pos - seq.offset;  /* note : this operation can overflow when seq.offset is really too large, which can only happen when input is corrupted.
+                                                           * No consequence though : no memory access will occur, offset is only used for prefetching */
+               seqState->pos = pos + seq.matchLength;
+       }
+
+       /* ANS state update
+        * gcc-9.0.0 does 2.5% worse with ZSTD_updateFseStateWithDInfo().
+        * clang-9.2.0 does 7% worse with ZSTD_updateFseState().
+        * Naturally it seems like ZSTD_updateFseStateWithDInfo() should be the
+        * better option, so it is the default for other compilers. But, if you
+        * measure that it is worse, please put up a pull request.
+        */
+
        FSE_updateState(&seqState->stateLL, &seqState->DStream); /* <=  9 bits */
        FSE_updateState(&seqState->stateML, &seqState->DStream); /* <=  9 bits */
        if (ZSTD_32bits())
                BIT_reloadDStream(&seqState->DStream);             /* <= 18 bits */
        FSE_updateState(&seqState->stateOffb, &seqState->DStream); /* <=  8 bits */
 
-       seq.match = NULL;
-
        return seq;
 }
 
@@ -1014,26 +1120,24 @@ size_t ZSTD_execSequence(BYTE *op, BYTE *const oend, seq_t sequence, const BYTE
        const BYTE *const iLitEnd = *litPtr + sequence.litLength;
        const BYTE *match = oLitEnd - sequence.offset;
 
-       /* check */
-       if (oMatchEnd > oend)
-               return ERROR(dstSize_tooSmall); /* last match must start at a minimum distance of WILDCOPY_OVERLENGTH from oend */
-       if (iLitEnd > litLimit)
-               return ERROR(corruption_detected); /* over-read beyond lit buffer */
-       if (oLitEnd > oend_w)
-               return ZSTD_execSequenceLast7(op, oend, sequence, litPtr, litLimit, base, vBase, dictEnd);
-
-       /* copy Literals */
-       ZSTD_copy8(op, *litPtr);
-       if (sequence.litLength > 8)
-               ZSTD_wildcopy(op + 8, (*litPtr) + 8,
-                             sequence.litLength - 8); /* note : since oLitEnd <= oend-WILDCOPY_OVERLENGTH, no risk of overwrite beyond oend */
+       /* Errors and uncommon cases handled here. */
+       if (unlikely(iLitEnd > litLimit || oMatchEnd > oend_w))
+               return ZSTD_execSequenceEnd(op, oend, sequence, litPtr, litLimit, base, vBase, dictEnd);
+
+       /* Copy Literals:
+        * Split out litLength <= 16 since it is nearly always true. +1.6% on gcc-9.
+        * We likely don't need the full 32-byte wildcopy.
+        */
+       ZSTD_copy16(op, *litPtr);
+       if (unlikely(sequence.litLength > 16))
+               ZSTD_wildcopy(op + 16, (*litPtr) + 16, sequence.litLength - 16, ZSTD_no_overlap);
        op = oLitEnd;
-       *litPtr = iLitEnd; /* update for next sequence */
+       *litPtr = iLitEnd;   /* update for next sequence */
 
        /* copy Match */
        if (sequence.offset > (size_t)(oLitEnd - base)) {
                /* offset beyond prefix */
-               if (sequence.offset > (size_t)(oLitEnd - vBase))
+               if (unlikely(sequence.offset > (size_t)(oLitEnd - vBase)))
                        return ERROR(corruption_detected);
                match = dictEnd + (match - base);
                if (match + sequence.matchLength <= dictEnd) {
@@ -1047,45 +1151,27 @@ size_t ZSTD_execSequence(BYTE *op, BYTE *const oend, seq_t sequence, const BYTE
                        op = oLitEnd + length1;
                        sequence.matchLength -= length1;
                        match = base;
-                       if (op > oend_w || sequence.matchLength < MINMATCH) {
-                               U32 i;
-                               for (i = 0; i < sequence.matchLength; ++i)
-                                       op[i] = match[i];
-                               return sequenceLength;
-                       }
                }
        }
-       /* Requirement: op <= oend_w && sequence.matchLength >= MINMATCH */
 
-       /* match within prefix */
-       if (sequence.offset < 8) {
-               /* close range match, overlap */
-               static const U32 dec32table[] = {0, 1, 2, 1, 4, 4, 4, 4};   /* added */
-               static const int dec64table[] = {8, 8, 8, 7, 8, 9, 10, 11}; /* subtracted */
-               int const sub2 = dec64table[sequence.offset];
-               op[0] = match[0];
-               op[1] = match[1];
-               op[2] = match[2];
-               op[3] = match[3];
-               match += dec32table[sequence.offset];
-               ZSTD_copy4(op + 4, match);
-               match -= sub2;
-       } else {
-               ZSTD_copy8(op, match);
+       /* Nearly all offsets are >= WILDCOPY_VECLEN bytes, which means we can use wildcopy
+        * without overlap checking.
+        */
+       if (likely(sequence.offset >= WILDCOPY_VECLEN)) {
+               /* We bet on a full wildcopy for matches, since we expect matches to be
+                * longer than literals (in general). In silesia, ~10% of matches are longer
+                * than 16 bytes.
+                */
+               ZSTD_wildcopy(op, match, (ptrdiff_t)sequence.matchLength, ZSTD_no_overlap);
+               return sequenceLength;
        }
-       op += 8;
-       match += 8;
 
-       if (oMatchEnd > oend - (16 - MINMATCH)) {
-               if (op < oend_w) {
-                       ZSTD_wildcopy(op, match, oend_w - op);
-                       match += oend_w - op;
-                       op = oend_w;
-               }
-               while (op < oMatchEnd)
-                       *op++ = *match++;
-       } else {
-               ZSTD_wildcopy(op, match, (ptrdiff_t)sequence.matchLength - 8); /* works even if matchLength < 8 */
+       /* Copy 8 bytes and spread the offset to be >= 8. */
+       ZSTD_overlapCopy8(&op, &match, sequence.offset);
+
+       /* If the match length is > 8 bytes, then continue with the wildcopy. */
+       if (sequence.matchLength > 8) {
+               ZSTD_wildcopy(op, match, (ptrdiff_t)sequence.matchLength - 8, ZSTD_overlap_src_before_dst);
        }
        return sequenceLength;
 }
@@ -1094,7 +1180,7 @@ static size_t ZSTD_decompressSequences(ZSTD_DCtx *dctx, void *dst, size_t maxDst
 {
        const BYTE *ip = (const BYTE *)seqStart;
        const BYTE *const iend = ip + seqSize;
-       BYTE *const ostart = (BYTE * const)dst;
+       BYTE *const ostart = (BYTE *const)dst;
        BYTE *const oend = ostart + maxDstSize;
        BYTE *op = ostart;
        const BYTE *litPtr = dctx->litPtr;
@@ -1115,6 +1201,7 @@ static size_t ZSTD_decompressSequences(ZSTD_DCtx *dctx, void *dst, size_t maxDst
        /* Regen sequences */
        if (nbSeq) {
                seqState_t seqState;
+               size_t error = 0;
                dctx->fseEntropy = 1;
                {
                        U32 i;
@@ -1126,18 +1213,29 @@ static size_t ZSTD_decompressSequences(ZSTD_DCtx *dctx, void *dst, size_t maxDst
                FSE_initDState(&seqState.stateOffb, &seqState.DStream, dctx->OFTptr);
                FSE_initDState(&seqState.stateML, &seqState.DStream, dctx->MLTptr);
 
-               for (; (BIT_reloadDStream(&(seqState.DStream)) <= BIT_DStream_completed) && nbSeq;) {
-                       nbSeq--;
-                       {
-                               seq_t const sequence = ZSTD_decodeSequence(&seqState);
-                               size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litEnd, base, vBase, dictEnd);
-                               if (ZSTD_isError(oneSeqSize))
-                                       return oneSeqSize;
+               for ( ; ; ) {
+                       seq_t const sequence = ZSTD_decodeSequence(&seqState, 0, ZSTD_p_noPrefetch);
+                       size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litEnd, base, vBase, dictEnd);
+
+                       BIT_reloadDStream(&(seqState.DStream));
+                       /* gcc and clang both don't like early returns in this loop.
+                        * gcc doesn't like early breaks either.
+                        * Instead save an error and report it at the end.
+                        * When there is an error, don't increment op, so we don't
+                        * overwrite.
+                        */
+                       if (unlikely(ZSTD_isError(oneSeqSize)))
+                               error = oneSeqSize;
+                       else
                                op += oneSeqSize;
-                       }
+
+                       if (unlikely(!--nbSeq))
+                               break;
                }
 
                /* check if reached exact end */
+               if (ZSTD_isError(error))
+                       return error;
                if (nbSeq)
                        return ERROR(corruption_detected);
                /* save reps for next block */
@@ -1160,196 +1258,20 @@ static size_t ZSTD_decompressSequences(ZSTD_DCtx *dctx, void *dst, size_t maxDst
        return op - ostart;
 }
 
-FORCE_INLINE seq_t ZSTD_decodeSequenceLong_generic(seqState_t *seqState, int const longOffsets)
-{
-       seq_t seq;
-
-       U32 const llCode = FSE_peekSymbol(&seqState->stateLL);
-       U32 const mlCode = FSE_peekSymbol(&seqState->stateML);
-       U32 const ofCode = FSE_peekSymbol(&seqState->stateOffb); /* <= maxOff, by table construction */
-
-       U32 const llBits = LL_bits[llCode];
-       U32 const mlBits = ML_bits[mlCode];
-       U32 const ofBits = ofCode;
-       U32 const totalBits = llBits + mlBits + ofBits;
-
-       static const U32 LL_base[MaxLL + 1] = {0,  1,  2,  3,  4,  5,  6,  7,  8,    9,     10,    11,    12,    13,     14,     15,     16,     18,
-                                              20, 22, 24, 28, 32, 40, 48, 64, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000, 0x10000};
-
-       static const U32 ML_base[MaxML + 1] = {3,  4,  5,  6,  7,  8,  9,  10,   11,    12,    13,    14,    15,     16,     17,     18,     19,     20,
-                                              21, 22, 23, 24, 25, 26, 27, 28,   29,    30,    31,    32,    33,     34,     35,     37,     39,     41,
-                                              43, 47, 51, 59, 67, 83, 99, 0x83, 0x103, 0x203, 0x403, 0x803, 0x1003, 0x2003, 0x4003, 0x8003, 0x10003};
-
-       static const U32 OF_base[MaxOff + 1] = {0,       1,     1,      5,      0xD,      0x1D,      0x3D,      0x7D,      0xFD,     0x1FD,
-                                               0x3FD,   0x7FD,    0xFFD,    0x1FFD,   0x3FFD,   0x7FFD,    0xFFFD,    0x1FFFD,   0x3FFFD,  0x7FFFD,
-                                               0xFFFFD, 0x1FFFFD, 0x3FFFFD, 0x7FFFFD, 0xFFFFFD, 0x1FFFFFD, 0x3FFFFFD, 0x7FFFFFD, 0xFFFFFFD};
-
-       /* sequence */
-       {
-               size_t offset;
-               if (!ofCode)
-                       offset = 0;
-               else {
-                       if (longOffsets) {
-                               int const extraBits = ofBits - MIN(ofBits, STREAM_ACCUMULATOR_MIN);
-                               offset = OF_base[ofCode] + (BIT_readBitsFast(&seqState->DStream, ofBits - extraBits) << extraBits);
-                               if (ZSTD_32bits() || extraBits)
-                                       BIT_reloadDStream(&seqState->DStream);
-                               if (extraBits)
-                                       offset += BIT_readBitsFast(&seqState->DStream, extraBits);
-                       } else {
-                               offset = OF_base[ofCode] + BIT_readBitsFast(&seqState->DStream, ofBits); /* <=  (ZSTD_WINDOWLOG_MAX-1) bits */
-                               if (ZSTD_32bits())
-                                       BIT_reloadDStream(&seqState->DStream);
-                       }
-               }
-
-               if (ofCode <= 1) {
-                       offset += (llCode == 0);
-                       if (offset) {
-                               size_t temp = (offset == 3) ? seqState->prevOffset[0] - 1 : seqState->prevOffset[offset];
-                               temp += !temp; /* 0 is not valid; input is corrupted; force offset to 1 */
-                               if (offset != 1)
-                                       seqState->prevOffset[2] = seqState->prevOffset[1];
-                               seqState->prevOffset[1] = seqState->prevOffset[0];
-                               seqState->prevOffset[0] = offset = temp;
-                       } else {
-                               offset = seqState->prevOffset[0];
-                       }
-               } else {
-                       seqState->prevOffset[2] = seqState->prevOffset[1];
-                       seqState->prevOffset[1] = seqState->prevOffset[0];
-                       seqState->prevOffset[0] = offset;
-               }
-               seq.offset = offset;
-       }
-
-       seq.matchLength = ML_base[mlCode] + ((mlCode > 31) ? BIT_readBitsFast(&seqState->DStream, mlBits) : 0); /* <=  16 bits */
-       if (ZSTD_32bits() && (mlBits + llBits > 24))
-               BIT_reloadDStream(&seqState->DStream);
-
-       seq.litLength = LL_base[llCode] + ((llCode > 15) ? BIT_readBitsFast(&seqState->DStream, llBits) : 0); /* <=  16 bits */
-       if (ZSTD_32bits() || (totalBits > 64 - 7 - (LLFSELog + MLFSELog + OffFSELog)))
-               BIT_reloadDStream(&seqState->DStream);
-
-       {
-               size_t const pos = seqState->pos + seq.litLength;
-               seq.match = seqState->base + pos - seq.offset; /* single memory segment */
-               if (seq.offset > pos)
-                       seq.match += seqState->gotoDict; /* separate memory segment */
-               seqState->pos = pos + seq.matchLength;
-       }
-
-       /* ANS state update */
-       FSE_updateState(&seqState->stateLL, &seqState->DStream); /* <=  9 bits */
-       FSE_updateState(&seqState->stateML, &seqState->DStream); /* <=  9 bits */
-       if (ZSTD_32bits())
-               BIT_reloadDStream(&seqState->DStream);             /* <= 18 bits */
-       FSE_updateState(&seqState->stateOffb, &seqState->DStream); /* <=  8 bits */
-
-       return seq;
-}
-
-static seq_t ZSTD_decodeSequenceLong(seqState_t *seqState, unsigned const windowSize)
+static seq_t ZSTD_decodeSequenceLong(seqState_t *seqState, unsigned const windowSize, enum ZSTD_prefetch_e const prefetch)
 {
        if (ZSTD_highbit32(windowSize) > STREAM_ACCUMULATOR_MIN) {
-               return ZSTD_decodeSequenceLong_generic(seqState, 1);
+               return ZSTD_decodeSequence(seqState, 1, prefetch);
        } else {
-               return ZSTD_decodeSequenceLong_generic(seqState, 0);
+               return ZSTD_decodeSequence(seqState, 0, prefetch);
        }
 }
 
-FORCE_INLINE
-size_t ZSTD_execSequenceLong(BYTE *op, BYTE *const oend, seq_t sequence, const BYTE **litPtr, const BYTE *const litLimit, const BYTE *const base,
-                            const BYTE *const vBase, const BYTE *const dictEnd)
-{
-       BYTE *const oLitEnd = op + sequence.litLength;
-       size_t const sequenceLength = sequence.litLength + sequence.matchLength;
-       BYTE *const oMatchEnd = op + sequenceLength; /* risk : address space overflow (32-bits) */
-       BYTE *const oend_w = oend - WILDCOPY_OVERLENGTH;
-       const BYTE *const iLitEnd = *litPtr + sequence.litLength;
-       const BYTE *match = sequence.match;
-
-       /* check */
-       if (oMatchEnd > oend)
-               return ERROR(dstSize_tooSmall); /* last match must start at a minimum distance of WILDCOPY_OVERLENGTH from oend */
-       if (iLitEnd > litLimit)
-               return ERROR(corruption_detected); /* over-read beyond lit buffer */
-       if (oLitEnd > oend_w)
-               return ZSTD_execSequenceLast7(op, oend, sequence, litPtr, litLimit, base, vBase, dictEnd);
-
-       /* copy Literals */
-       ZSTD_copy8(op, *litPtr);
-       if (sequence.litLength > 8)
-               ZSTD_wildcopy(op + 8, (*litPtr) + 8,
-                             sequence.litLength - 8); /* note : since oLitEnd <= oend-WILDCOPY_OVERLENGTH, no risk of overwrite beyond oend */
-       op = oLitEnd;
-       *litPtr = iLitEnd; /* update for next sequence */
-
-       /* copy Match */
-       if (sequence.offset > (size_t)(oLitEnd - base)) {
-               /* offset beyond prefix */
-               if (sequence.offset > (size_t)(oLitEnd - vBase))
-                       return ERROR(corruption_detected);
-               if (match + sequence.matchLength <= dictEnd) {
-                       memmove(oLitEnd, match, sequence.matchLength);
-                       return sequenceLength;
-               }
-               /* span extDict & currPrefixSegment */
-               {
-                       size_t const length1 = dictEnd - match;
-                       memmove(oLitEnd, match, length1);
-                       op = oLitEnd + length1;
-                       sequence.matchLength -= length1;
-                       match = base;
-                       if (op > oend_w || sequence.matchLength < MINMATCH) {
-                               U32 i;
-                               for (i = 0; i < sequence.matchLength; ++i)
-                                       op[i] = match[i];
-                               return sequenceLength;
-                       }
-               }
-       }
-       /* Requirement: op <= oend_w && sequence.matchLength >= MINMATCH */
-
-       /* match within prefix */
-       if (sequence.offset < 8) {
-               /* close range match, overlap */
-               static const U32 dec32table[] = {0, 1, 2, 1, 4, 4, 4, 4};   /* added */
-               static const int dec64table[] = {8, 8, 8, 7, 8, 9, 10, 11}; /* subtracted */
-               int const sub2 = dec64table[sequence.offset];
-               op[0] = match[0];
-               op[1] = match[1];
-               op[2] = match[2];
-               op[3] = match[3];
-               match += dec32table[sequence.offset];
-               ZSTD_copy4(op + 4, match);
-               match -= sub2;
-       } else {
-               ZSTD_copy8(op, match);
-       }
-       op += 8;
-       match += 8;
-
-       if (oMatchEnd > oend - (16 - MINMATCH)) {
-               if (op < oend_w) {
-                       ZSTD_wildcopy(op, match, oend_w - op);
-                       match += oend_w - op;
-                       op = oend_w;
-               }
-               while (op < oMatchEnd)
-                       *op++ = *match++;
-       } else {
-               ZSTD_wildcopy(op, match, (ptrdiff_t)sequence.matchLength - 8); /* works even if matchLength < 8 */
-       }
-       return sequenceLength;
-}
-
 static size_t ZSTD_decompressSequencesLong(ZSTD_DCtx *dctx, void *dst, size_t maxDstSize, const void *seqStart, size_t seqSize)
 {
        const BYTE *ip = (const BYTE *)seqStart;
        const BYTE *const iend = ip + seqSize;
-       BYTE *const ostart = (BYTE * const)dst;
+       BYTE *const ostart = (BYTE *const)dst;
        BYTE *const oend = ostart + maxDstSize;
        BYTE *op = ostart;
        const BYTE *litPtr = dctx->litPtr;
@@ -1394,16 +1316,18 @@ static size_t ZSTD_decompressSequencesLong(ZSTD_DCtx *dctx, void *dst, size_t ma
 
                /* prepare in advance */
                for (seqNb = 0; (BIT_reloadDStream(&seqState.DStream) <= BIT_DStream_completed) && seqNb < seqAdvance; seqNb++) {
-                       sequences[seqNb] = ZSTD_decodeSequenceLong(&seqState, windowSize);
+                       sequences[seqNb] = ZSTD_decodeSequenceLong(&seqState, windowSize, ZSTD_p_prefetch);
+                       ZSTD_PREFETCH(sequences[seqNb].match);
+                       ZSTD_PREFETCH(sequences[seqNb].match + sequences[seqNb].matchLength - 1);
                }
                if (seqNb < seqAdvance)
                        return ERROR(corruption_detected);
 
                /* decode and decompress */
                for (; (BIT_reloadDStream(&(seqState.DStream)) <= BIT_DStream_completed) && seqNb < nbSeq; seqNb++) {
-                       seq_t const sequence = ZSTD_decodeSequenceLong(&seqState, windowSize);
+                       seq_t const sequence = ZSTD_decodeSequenceLong(&seqState, windowSize, ZSTD_p_prefetch);
                        size_t const oneSeqSize =
-                           ZSTD_execSequenceLong(op, oend, sequences[(seqNb - ADVANCED_SEQS) & STOSEQ_MASK], &litPtr, litEnd, base, vBase, dictEnd);
+                           ZSTD_execSequence(op, oend, sequences[(seqNb - ADVANCED_SEQS) & STOSEQ_MASK], &litPtr, litEnd, base, vBase, dictEnd);
                        if (ZSTD_isError(oneSeqSize))
                                return oneSeqSize;
                        ZSTD_PREFETCH(sequence.match);
@@ -1416,7 +1340,7 @@ static size_t ZSTD_decompressSequencesLong(ZSTD_DCtx *dctx, void *dst, size_t ma
                /* finish queue */
                seqNb -= seqAdvance;
                for (; seqNb < nbSeq; seqNb++) {
-                       size_t const oneSeqSize = ZSTD_execSequenceLong(op, oend, sequences[seqNb & STOSEQ_MASK], &litPtr, litEnd, base, vBase, dictEnd);
+                       size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequences[seqNb & STOSEQ_MASK], &litPtr, litEnd, base, vBase, dictEnd);
                        if (ZSTD_isError(oneSeqSize))
                                return oneSeqSize;
                        op += oneSeqSize;
@@ -1566,7 +1490,7 @@ size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize)
 static size_t ZSTD_decompressFrame(ZSTD_DCtx *dctx, void *dst, size_t dstCapacity, const void **srcPtr, size_t *srcSizePtr)
 {
        const BYTE *ip = (const BYTE *)(*srcPtr);
-       BYTE *const ostart = (BYTE * const)dst;
+       BYTE *const ostart = (BYTE *const)dst;
        BYTE *const oend = ostart + dstCapacity;
        BYTE *op = ostart;
        size_t remainingSize = *srcSizePtr;
index 6526482..6753153 100644 (file)
@@ -254,6 +254,7 @@ static size_t HUF_decompress4X2_usingDTable_internal(void *dst, size_t dstSize,
                const BYTE *const istart = (const BYTE *)cSrc;
                BYTE *const ostart = (BYTE *)dst;
                BYTE *const oend = ostart + dstSize;
+               BYTE *const olimit = oend - 3;
                const void *const dtPtr = DTable + 1;
                const HUF_DEltX2 *const dt = (const HUF_DEltX2 *)dtPtr;
 
@@ -278,7 +279,7 @@ static size_t HUF_decompress4X2_usingDTable_internal(void *dst, size_t dstSize,
                BYTE *op2 = opStart2;
                BYTE *op3 = opStart3;
                BYTE *op4 = opStart4;
-               U32 endSignal;
+               U32 endSignal = 1;
                DTableDesc const dtd = HUF_getDTableDesc(DTable);
                U32 const dtLog = dtd.tableLog;
 
@@ -306,8 +307,7 @@ static size_t HUF_decompress4X2_usingDTable_internal(void *dst, size_t dstSize,
                }
 
                /* 16-32 symbols per loop (4-8 symbols per stream) */
-               endSignal = BIT_reloadDStream(&bitD1) | BIT_reloadDStream(&bitD2) | BIT_reloadDStream(&bitD3) | BIT_reloadDStream(&bitD4);
-               for (; (endSignal == BIT_DStream_unfinished) && (op4 < (oend - 7));) {
+               for ( ; (endSignal) & (op4 < olimit); ) {
                        HUF_DECODE_SYMBOLX2_2(op1, &bitD1);
                        HUF_DECODE_SYMBOLX2_2(op2, &bitD2);
                        HUF_DECODE_SYMBOLX2_2(op3, &bitD3);
@@ -324,7 +324,10 @@ static size_t HUF_decompress4X2_usingDTable_internal(void *dst, size_t dstSize,
                        HUF_DECODE_SYMBOLX2_0(op2, &bitD2);
                        HUF_DECODE_SYMBOLX2_0(op3, &bitD3);
                        HUF_DECODE_SYMBOLX2_0(op4, &bitD4);
-                       endSignal = BIT_reloadDStream(&bitD1) | BIT_reloadDStream(&bitD2) | BIT_reloadDStream(&bitD3) | BIT_reloadDStream(&bitD4);
+                       endSignal &= BIT_reloadDStreamFast(&bitD1) == BIT_DStream_unfinished;
+                       endSignal &= BIT_reloadDStreamFast(&bitD2) == BIT_DStream_unfinished;
+                       endSignal &= BIT_reloadDStreamFast(&bitD3) == BIT_DStream_unfinished;
+                       endSignal &= BIT_reloadDStreamFast(&bitD4) == BIT_DStream_unfinished;
                }
 
                /* check corruption */
@@ -713,6 +716,7 @@ static size_t HUF_decompress4X4_usingDTable_internal(void *dst, size_t dstSize,
                const BYTE *const istart = (const BYTE *)cSrc;
                BYTE *const ostart = (BYTE *)dst;
                BYTE *const oend = ostart + dstSize;
+               BYTE *const olimit = oend - (sizeof(size_t) - 1);
                const void *const dtPtr = DTable + 1;
                const HUF_DEltX4 *const dt = (const HUF_DEltX4 *)dtPtr;
 
@@ -737,7 +741,7 @@ static size_t HUF_decompress4X4_usingDTable_internal(void *dst, size_t dstSize,
                BYTE *op2 = opStart2;
                BYTE *op3 = opStart3;
                BYTE *op4 = opStart4;
-               U32 endSignal;
+               U32 endSignal = 1;
                DTableDesc const dtd = HUF_getDTableDesc(DTable);
                U32 const dtLog = dtd.tableLog;
 
@@ -765,8 +769,7 @@ static size_t HUF_decompress4X4_usingDTable_internal(void *dst, size_t dstSize,
                }
 
                /* 16-32 symbols per loop (4-8 symbols per stream) */
-               endSignal = BIT_reloadDStream(&bitD1) | BIT_reloadDStream(&bitD2) | BIT_reloadDStream(&bitD3) | BIT_reloadDStream(&bitD4);
-               for (; (endSignal == BIT_DStream_unfinished) & (op4 < (oend - (sizeof(bitD4.bitContainer) - 1)));) {
+               for ( ; (endSignal) & (op4 < olimit); ) {
                        HUF_DECODE_SYMBOLX4_2(op1, &bitD1);
                        HUF_DECODE_SYMBOLX4_2(op2, &bitD2);
                        HUF_DECODE_SYMBOLX4_2(op3, &bitD3);
@@ -783,8 +786,11 @@ static size_t HUF_decompress4X4_usingDTable_internal(void *dst, size_t dstSize,
                        HUF_DECODE_SYMBOLX4_0(op2, &bitD2);
                        HUF_DECODE_SYMBOLX4_0(op3, &bitD3);
                        HUF_DECODE_SYMBOLX4_0(op4, &bitD4);
-
-                       endSignal = BIT_reloadDStream(&bitD1) | BIT_reloadDStream(&bitD2) | BIT_reloadDStream(&bitD3) | BIT_reloadDStream(&bitD4);
+                       endSignal = likely(
+                                       (BIT_reloadDStreamFast(&bitD1) == BIT_DStream_unfinished)
+                                       & (BIT_reloadDStreamFast(&bitD2) == BIT_DStream_unfinished)
+                                       & (BIT_reloadDStreamFast(&bitD3) == BIT_DStream_unfinished)
+                                       & (BIT_reloadDStreamFast(&bitD4) == BIT_DStream_unfinished));
                }
 
                /* check corruption */
index dac7533..53950d2 100644 (file)
@@ -126,34 +126,77 @@ static const U32 OF_defaultNormLog = OF_DEFAULTNORMLOG;
 /*-*******************************************
 *  Shared functions to include for inlining
 *********************************************/
-ZSTD_STATIC void ZSTD_copy8(void *dst, const void *src) {
-       /*
-        * zstd relies heavily on gcc being able to analyze and inline this
-        * memcpy() call, since it is called in a tight loop. Preboot mode
-        * is compiled in freestanding mode, which stops gcc from analyzing
-        * memcpy(). Use __builtin_memcpy() to tell gcc to analyze this as a
-        * regular memcpy().
-        */
-       __builtin_memcpy(dst, src, 8);
+
+FORCE_INLINE void ZSTD_copy8(void *dst, const void *src)
+{
+       memcpy(dst, src, 8);
+}
+
+FORCE_INLINE void ZSTD_copy16(void *dst, const void *src)
+{
+       memcpy(dst, src, 16);
 }
+
+enum ZSTD_overlap_e {
+       ZSTD_no_overlap,
+       ZSTD_overlap_src_before_dst,
+       /*  ZSTD_overlap_dst_before_src, */
+};
+
+#define WILDCOPY_OVERLENGTH 32
+#define WILDCOPY_VECLEN 16
+
 /*! ZSTD_wildcopy() :
-*   custom version of memcpy(), can copy up to 7 bytes too many (8 bytes if length==0) */
-#define WILDCOPY_OVERLENGTH 8
-ZSTD_STATIC void ZSTD_wildcopy(void *dst, const void *src, ptrdiff_t length)
+ *  Custom version of memcpy(), can over read/write up to WILDCOPY_OVERLENGTH bytes (if length==0)
+ *  @param ovtype controls the overlap detection
+ *         - ZSTD_no_overlap: The source and destination are guaranteed to be at least WILDCOPY_VECLEN bytes apart.
+ *         - ZSTD_overlap_src_before_dst: The src and dst may overlap, but they MUST be at least 8 bytes apart.
+ *           The src buffer must be before the dst buffer.
+ */
+FORCE_INLINE void ZSTD_wildcopy(void *dst, const void *src, ptrdiff_t length, enum ZSTD_overlap_e const ovtype)
 {
-       const BYTE* ip = (const BYTE*)src;
-       BYTE* op = (BYTE*)dst;
-       BYTE* const oend = op + length;
-#if defined(GCC_VERSION) && GCC_VERSION >= 70000 && GCC_VERSION < 70200
-       /*
-        * Work around https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81388.
-        * Avoid the bad case where the loop only runs once by handling the
-        * special case separately. This doesn't trigger the bug because it
-        * doesn't involve pointer/integer overflow.
-        */
-       if (length <= 8)
-               return ZSTD_copy8(dst, src);
-#endif
+       ptrdiff_t diff = (BYTE *)dst - (const BYTE *)src;
+       const BYTE *ip = (const BYTE *)src;
+       BYTE *op = (BYTE *)dst;
+       BYTE *const oend = op + length;
+
+       if (ovtype == ZSTD_overlap_src_before_dst && diff < WILDCOPY_VECLEN) {
+               /* Handle short offset copies */
+               do {
+                       ZSTD_copy8(op, ip);
+                       op += 8;
+                       ip += 8;
+               } while (op < oend);
+       } else {
+               ZSTD_copy16(op, ip);
+               op += 16;
+               ip += 16;
+               ZSTD_copy16(op, ip);
+               op += 16;
+               ip += 16;
+               if (op >= oend)
+                       return;
+               do {
+                       ZSTD_copy16(op, ip);
+                       op += 16;
+                       ip += 16;
+                       ZSTD_copy16(op, ip);
+                       op += 16;
+                       ip += 16;
+               } while (op < oend);
+       }
+}
+
+/*! ZSTD_wildcopy8() :
+ *  The same as ZSTD_wildcopy(), but it can only overwrite 8 bytes, and works for
+ *  overlapping buffers that are at least 8 bytes apart.
+ */
+ZSTD_STATIC void ZSTD_wildcopy8(void *dst, const void *src, ptrdiff_t length)
+{
+       const BYTE *ip = (const BYTE *)src;
+       BYTE *op = (BYTE *)dst;
+       BYTE *const oend = (BYTE *)op + length;
+
        do {
                ZSTD_copy8(op, ip);
                op += 8;
@@ -253,7 +296,7 @@ void ZSTD_stackFree(void *opaque, void *address);
 
 /*======  common function  ======*/
 
-ZSTD_STATIC U32 ZSTD_highbit32(U32 val) { return 31 - __builtin_clz(val); }
+ZSTD_STATIC U32 ZSTD_highbit32(U32 val) { return __builtin_clz(val) ^ 31; }
 
 /* hidden functions */
 
index 55e1b4c..d731c42 100644 (file)
@@ -676,7 +676,7 @@ _storeSequence: /* cur, last_pos, best_mlen, best_off have to be set */
                        }
 
                        ZSTD_updatePrice(seqStorePtr, litLength, anchor, offset, mlen - MINMATCH);
-                       ZSTD_storeSeq(seqStorePtr, litLength, anchor, offset, mlen - MINMATCH);
+                       ZSTD_storeSeq(seqStorePtr, litLength, anchor, iend, offset, mlen - MINMATCH);
                        anchor = ip = ip + mlen;
                }
        } /* for (cur=0; cur < last_pos; ) */
@@ -991,7 +991,7 @@ _storeSequence: /* cur, last_pos, best_mlen, best_off have to be set */
                        }
 
                        ZSTD_updatePrice(seqStorePtr, litLength, anchor, offset, mlen - MINMATCH);
-                       ZSTD_storeSeq(seqStorePtr, litLength, anchor, offset, mlen - MINMATCH);
+                       ZSTD_storeSeq(seqStorePtr, litLength, anchor, iend, offset, mlen - MINMATCH);
                        anchor = ip = ip + mlen;
                }
        } /* for (cur=0; cur < last_pos; ) */