/* ******************************************************************
 * Common functions of New Generation Entropy library
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * You can contact the author at :
 * - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy
 * - Public forum : https://groups.google.com/forum/#!forum/lz4c
 *
 * This source code is licensed under both the BSD-style license (found in the
 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
 * in the COPYING file in the root directory of this source tree).
 * You may select, at your option, one of the above-listed licenses.
****************************************************************** */

/* *************************************
* Dependencies
***************************************/
#include "mem.h"
#include "error_private.h"       /* ERR_*, ERROR */
#define FSE_STATIC_LINKING_ONLY  /* FSE_MIN_TABLELOG */
#include "fse.h"
#include "huf.h"
#include "bits.h"                /* ZSTD_highbit32, ZSTD_countTrailingZeros32 */


/*=== Version ===*/
unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; }


/*=== Error Management ===*/
unsigned FSE_isError(size_t code) { return ERR_isError(code); }
const char* FSE_getErrorName(size_t code) { return ERR_getErrorName(code); }

unsigned HUF_isError(size_t code) { return ERR_isError(code); }
const char* HUF_getErrorName(size_t code) { return ERR_getErrorName(code); }


/*-**************************************************************
* FSE NCount encoding-decoding
****************************************************************/
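/* FSE_readNCount_body() :
 * Decodes the normalized count header written by FSE_writeNCount().
 * In broad strokes (the loop below is the authoritative reference) :
 * - the first 4 bits store tableLog - FSE_MIN_TABLELOG ;
 * - each normalized count is then stored on a variable number of bits,
 *   sized by the current `threshold` and the remaining probability mass ;
 * - a count of 0 is followed by 2-bit repeat codes : each code adds up to 3
 *   more zero-count symbols, and the value 3 means another code follows.
 * The decoder reads the stream 32 bits at a time via MEM_readLE32(),
 * hence the hbSize >= 8 requirement (smaller inputs go through a padded copy). */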
FORCE_INLINE_TEMPLATE
size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
                           const void* headerBuffer, size_t hbSize)
{
    const BYTE* const istart = (const BYTE*) headerBuffer;
    const BYTE* const iend = istart + hbSize;
    const BYTE* ip = istart;
    int nbBits;
    int remaining;
    int threshold;
    U32 bitStream;
    int bitCount;
    unsigned charnum = 0;
    unsigned const maxSV1 = *maxSVPtr + 1;
    int previous0 = 0;

    if (hbSize < 8) {
        /* This function only works when hbSize >= 8 */
        char buffer[8] = {0};
        ZSTD_memcpy(buffer, headerBuffer, hbSize);
        {   size_t const countSize = FSE_readNCount(normalizedCounter, maxSVPtr, tableLogPtr,
                                                    buffer, sizeof(buffer));
            if (FSE_isError(countSize)) return countSize;
            if (countSize > hbSize) return ERROR(corruption_detected);
            return countSize;
    }   }
    assert(hbSize >= 8);

    /* init */
    ZSTD_memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0]));   /* all symbols not present in NCount have a frequency of 0 */
    bitStream = MEM_readLE32(ip);
    nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG;   /* extract tableLog */
    if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge);
    bitStream >>= 4;
    bitCount = 4;
    *tableLogPtr = nbBits;
    remaining = (1<<nbBits)+1;
    threshold = 1<<nbBits;
    nbBits++;

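    /* Main decoding loop : each iteration reads one normalized count.
     * `remaining` tracks the probability mass still to distribute (starting at
     * 2^tableLog + 1) and `threshold` is the power of two sizing the next
     * bitfield ; both shrink as counts are consumed. When the previous count
     * was 0, the `previous0` path first expands runs of zero-count symbols. */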
    for (;;) {
        if (previous0) {
            /* Count the number of repeats. Each time the
             * 2-bit repeat code is 0b11 there is another
             * repeat.
             * Avoid UB by setting the high bit to 1.
             */
            int repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1;
            while (repeats >= 12) {
                charnum += 3 * 12;
                if (LIKELY(ip <= iend-7)) {
                    ip += 3;
                } else {
                    bitCount -= (int)(8 * (iend - 7 - ip));
                    bitCount &= 31;
                    ip = iend - 7;
                }
                bitStream = MEM_readLE32(ip) >> bitCount;
                repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1;
            }
            charnum += 3 * repeats;
            bitStream >>= 2 * repeats;
            bitCount += 2 * repeats;

            /* Add the final repeat which isn't 0b11. */
            assert((bitStream & 3) < 3);
            charnum += bitStream & 3;
            bitCount += 2;

            /* This is an error, but break and return an error
             * at the end, because returning out of a loop makes
             * it harder for the compiler to optimize.
             */
            if (charnum >= maxSV1) break;

            /* We don't need to set the normalized count to 0
             * because we already memset the whole buffer to 0.
             */

            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                assert((bitCount >> 3) <= 3); /* For first condition to work */
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
        }
        {
            int const max = (2*threshold-1) - remaining;
            int count;

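            /* Counts that fall in the smaller range are stored on nbBits-1
             * bits ; the others need the full nbBits, with `max` marking the
             * boundary between the two encodings. */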
            if ((bitStream & (threshold-1)) < (U32)max) {
                count = bitStream & (threshold-1);
                bitCount += nbBits-1;
            } else {
                count = bitStream & (2*threshold-1);
                if (count >= threshold) count -= max;
                bitCount += nbBits;
            }

            count--;   /* extra accuracy */
            /* When it matters (small blocks), this is a
             * predictable branch, because we don't use -1.
             */
            if (count >= 0) {
                remaining -= count;
            } else {
                assert(count == -1);
                remaining += count;
            }
            normalizedCounter[charnum++] = (short)count;
            previous0 = !count;

            assert(threshold > 1);
            if (remaining < threshold) {
                /* This branch can be folded into the
                 * threshold update condition because we
                 * know that threshold > 1.
                 */
                if (remaining <= 1) break;
                nbBits = ZSTD_highbit32(remaining) + 1;
                threshold = 1 << (nbBits - 1);
            }
            if (charnum >= maxSV1) break;

            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
    }   }
    if (remaining != 1) return ERROR(corruption_detected);
    /* Only possible when there are too many zeros. */
    if (charnum > maxSV1) return ERROR(maxSymbolValue_tooSmall);
    if (bitCount > 32) return ERROR(corruption_detected);
    *maxSVPtr = charnum-1;

    ip += (bitCount+7)>>3;
    return ip-istart;
}

/* Avoids the FORCE_INLINE of the _body() function. */
static size_t FSE_readNCount_body_default(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}

#if DYNAMIC_BMI2
BMI2_TARGET_ATTRIBUTE static size_t FSE_readNCount_body_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}
#endif

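/* FSE_readNCount_bmi2() :
 * Same as FSE_readNCount(), but lets the caller request the BMI2-specialized
 * body ; the runtime dispatch only exists when the build enables DYNAMIC_BMI2,
 * otherwise the `bmi2` parameter is ignored. */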
size_t FSE_readNCount_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize, int bmi2)
{
#if DYNAMIC_BMI2
    if (bmi2) {
        return FSE_readNCount_body_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
    }
#endif
    (void)bmi2;
    return FSE_readNCount_body_default(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}

size_t FSE_readNCount(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize, /* bmi2 */ 0);
}
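
/* Illustrative usage sketch (not part of the library ; `header` / `headerSize`
 * and buffer sizes are examples only) : reading a normalized count table from
 * a compressed header.
 *
 *   short ncount[256];               // normalized counts, one per symbol
 *   unsigned maxSymbolValue = 255;   // in : capacity ; out : last symbol present
 *   unsigned tableLog;               // out : accuracy log of the table
 *   size_t const hSize = FSE_readNCount(ncount, &maxSymbolValue, &tableLog,
 *                                       header, headerSize);
 *   if (FSE_isError(hSize)) { ... }  // corrupted or oversized header
 *   // `hSize` bytes of `header` were consumed ; an FSE decoding table can now
 *   // be built from `ncount`, `maxSymbolValue` and `tableLog`.
 */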


/*! HUF_readStats() :
    Read compact Huffman tree, saved by HUF_writeCTable().
    `huffWeight` is destination buffer.
    `rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX U32.
    @return : size read from `src` , or an error Code .
    Note : Needed by HUF_readCTable() and HUF_readDTableX?() .
*/
size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize)
{
    U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32];
    return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* flags */ 0);
}

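/* HUF_readStats_body() :
 * Header layout, as parsed below : the first byte `iSize` selects the mode.
 * - iSize >= 128 : "special" header ; weights are stored raw, 4 bits each,
 *   (iSize - 127) of them, packed two per byte.
 * - iSize < 128  : weights are FSE-compressed in the next iSize bytes.
 * The last weight is never stored : it is implied, because the sum of 2^(w-1)
 * over all non-zero weights must complete a power of two (2^tableLog). */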
FORCE_INLINE_TEMPLATE size_t
HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                   U32* nbSymbolsPtr, U32* tableLogPtr,
                   const void* src, size_t srcSize,
                   void* workSpace, size_t wkspSize,
                   int bmi2)
{
    U32 weightTotal;
    const BYTE* ip = (const BYTE*) src;
    size_t iSize;
    size_t oSize;

    if (!srcSize) return ERROR(srcSize_wrong);
    iSize = ip[0];
    /* ZSTD_memset(huffWeight, 0, hwSize); */  /* is not necessary, even though some analyzers complain ... */

    if (iSize >= 128) {  /* special header */
        oSize = iSize - 127;
        iSize = ((oSize+1)/2);
        if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
        if (oSize >= hwSize) return ERROR(corruption_detected);
        ip += 1;
        {   U32 n;
            for (n=0; n<oSize; n+=2) {
                huffWeight[n]   = ip[n/2] >> 4;
                huffWeight[n+1] = ip[n/2] & 15;
    }   }   }
    else {  /* header compressed with FSE (normal case) */
        if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
        /* max (hwSize-1) values decoded, as last one is implied */
        oSize = FSE_decompress_wksp_bmi2(huffWeight, hwSize-1, ip+1, iSize, 6, workSpace, wkspSize, bmi2);
        if (FSE_isError(oSize)) return oSize;
    }

    /* collect weight stats */
    ZSTD_memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32));
    weightTotal = 0;
    {   U32 n; for (n=0; n<oSize; n++) {
            if (huffWeight[n] > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
            rankStats[huffWeight[n]]++;
            weightTotal += (1 << huffWeight[n]) >> 1;
    }   }
    if (weightTotal == 0) return ERROR(corruption_detected);

    /* get last non-null symbol weight (implied, total must be 2^n) */
    {   U32 const tableLog = ZSTD_highbit32(weightTotal) + 1;
        if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
        *tableLogPtr = tableLog;
        /* determine last weight */
        {   U32 const total = 1 << tableLog;
            U32 const rest = total - weightTotal;
            U32 const verif = 1 << ZSTD_highbit32(rest);
            U32 const lastWeight = ZSTD_highbit32(rest) + 1;
            if (verif != rest) return ERROR(corruption_detected);    /* last value must be a clean power of 2 */
            huffWeight[oSize] = (BYTE)lastWeight;
            rankStats[lastWeight]++;
    }   }

    /* check tree construction validity */
    if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected);   /* by construction : at least 2 elts of rank 1, must be even */

    /* results */
    *nbSymbolsPtr = (U32)(oSize+1);
    return iSize+1;
}

/* Avoids the FORCE_INLINE of the _body() function. */
static size_t HUF_readStats_body_default(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                                         U32* nbSymbolsPtr, U32* tableLogPtr,
                                         const void* src, size_t srcSize,
                                         void* workSpace, size_t wkspSize)
{
    return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 0);
}

#if DYNAMIC_BMI2
static BMI2_TARGET_ATTRIBUTE size_t HUF_readStats_body_bmi2(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                                                            U32* nbSymbolsPtr, U32* tableLogPtr,
                                                            const void* src, size_t srcSize,
                                                            void* workSpace, size_t wkspSize)
{
    return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 1);
}
#endif

size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                          U32* nbSymbolsPtr, U32* tableLogPtr,
                          const void* src, size_t srcSize,
                          void* workSpace, size_t wkspSize,
                          int flags)
{
#if DYNAMIC_BMI2
    if (flags & HUF_flags_bmi2) {
        return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
    }
#endif
    (void)flags;
    return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
}
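
/* Illustrative usage sketch (not part of the library ; `src` / `srcSize` and
 * buffer names are examples only) : recovering the weight table from a
 * compressed Huffman header.
 *
 *   BYTE weights[HUF_SYMBOLVALUE_MAX + 1];
 *   U32  rankStats[HUF_TABLELOG_MAX + 1];
 *   U32  nbSymbols, tableLog;
 *   size_t const hSize = HUF_readStats(weights, sizeof(weights), rankStats,
 *                                      &nbSymbols, &tableLog, src, srcSize);
 *   if (HUF_isError(hSize)) { ... }   // corrupted header
 *   // weights[0..nbSymbols-1] and rankStats[] now describe the Huffman tree ;
 *   // hSize is the number of bytes consumed from `src`.
 */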