1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | // From Apache Impala (incubating) as of 2016-01-29 |
19 | |
20 | #include <cstdint> |
21 | #include <cstring> |
22 | #include <random> |
23 | #include <vector> |
24 | |
25 | #include <gtest/gtest.h> |
26 | |
27 | #include <boost/utility.hpp> // IWYU pragma: export |
28 | |
29 | #include "arrow/util/bit-stream-utils.h" |
30 | #include "arrow/util/bit-util.h" |
31 | #include "arrow/util/rle-encoding.h" |
32 | |
33 | using std::vector; |
34 | |
35 | namespace arrow { |
36 | namespace util { |
37 | |
// Largest bit width exercised by the width-sweeping tests in this file.
constexpr int MAX_WIDTH = 32;
39 | |
40 | TEST(BitArray, TestBool) { |
41 | const int len = 8; |
42 | uint8_t buffer[len]; |
43 | |
44 | BitUtil::BitWriter writer(buffer, len); |
45 | |
46 | // Write alternating 0's and 1's |
47 | for (int i = 0; i < 8; ++i) { |
48 | bool result = writer.PutValue(i % 2, 1); |
49 | EXPECT_TRUE(result); |
50 | } |
51 | writer.Flush(); |
52 | EXPECT_EQ((int)buffer[0], BOOST_BINARY(1 0 1 0 1 0 1 0)); |
53 | |
54 | // Write 00110011 |
55 | for (int i = 0; i < 8; ++i) { |
56 | bool result = false; |
57 | switch (i) { |
58 | case 0: |
59 | case 1: |
60 | case 4: |
61 | case 5: |
62 | result = writer.PutValue(false, 1); |
63 | break; |
64 | default: |
65 | result = writer.PutValue(true, 1); |
66 | break; |
67 | } |
68 | EXPECT_TRUE(result); |
69 | } |
70 | writer.Flush(); |
71 | |
72 | // Validate the exact bit value |
73 | EXPECT_EQ((int)buffer[0], BOOST_BINARY(1 0 1 0 1 0 1 0)); |
74 | EXPECT_EQ((int)buffer[1], BOOST_BINARY(1 1 0 0 1 1 0 0)); |
75 | |
76 | // Use the reader and validate |
77 | BitUtil::BitReader reader(buffer, len); |
78 | for (int i = 0; i < 8; ++i) { |
79 | bool val = false; |
80 | bool result = reader.GetValue(1, &val); |
81 | EXPECT_TRUE(result); |
82 | EXPECT_EQ(val, (i % 2) != 0); |
83 | } |
84 | |
85 | for (int i = 0; i < 8; ++i) { |
86 | bool val = false; |
87 | bool result = reader.GetValue(1, &val); |
88 | EXPECT_TRUE(result); |
89 | switch (i) { |
90 | case 0: |
91 | case 1: |
92 | case 4: |
93 | case 5: |
94 | EXPECT_EQ(val, false); |
95 | break; |
96 | default: |
97 | EXPECT_EQ(val, true); |
98 | break; |
99 | } |
100 | } |
101 | } |
102 | |
103 | // Writes 'num_vals' values with width 'bit_width' and reads them back. |
104 | void TestBitArrayValues(int bit_width, int num_vals) { |
105 | int len = static_cast<int>(BitUtil::BytesForBits(bit_width * num_vals)); |
106 | EXPECT_GT(len, 0); |
107 | const uint64_t mod = bit_width == 64 ? 1 : 1LL << bit_width; |
108 | |
109 | std::vector<uint8_t> buffer(len); |
110 | BitUtil::BitWriter writer(buffer.data(), len); |
111 | for (int i = 0; i < num_vals; ++i) { |
112 | bool result = writer.PutValue(i % mod, bit_width); |
113 | EXPECT_TRUE(result); |
114 | } |
115 | writer.Flush(); |
116 | EXPECT_EQ(writer.bytes_written(), len); |
117 | |
118 | BitUtil::BitReader reader(buffer.data(), len); |
119 | for (int i = 0; i < num_vals; ++i) { |
120 | int64_t val = 0; |
121 | bool result = reader.GetValue(bit_width, &val); |
122 | EXPECT_TRUE(result); |
123 | EXPECT_EQ(val, i % mod); |
124 | } |
125 | EXPECT_EQ(reader.bytes_left(), 0); |
126 | } |
127 | |
128 | TEST(BitArray, TestValues) { |
129 | for (int width = 1; width <= MAX_WIDTH; ++width) { |
130 | TestBitArrayValues(width, 1); |
131 | TestBitArrayValues(width, 2); |
132 | // Don't write too many values |
133 | TestBitArrayValues(width, (width < 12) ? (1 << width) : 4096); |
134 | TestBitArrayValues(width, 1024); |
135 | } |
136 | } |
137 | |
138 | // Test some mixed values |
139 | TEST(BitArray, TestMixed) { |
140 | const int len = 1024; |
141 | uint8_t buffer[len]; |
142 | bool parity = true; |
143 | |
144 | BitUtil::BitWriter writer(buffer, len); |
145 | for (int i = 0; i < len; ++i) { |
146 | bool result; |
147 | if (i % 2 == 0) { |
148 | result = writer.PutValue(parity, 1); |
149 | parity = !parity; |
150 | } else { |
151 | result = writer.PutValue(i, 10); |
152 | } |
153 | EXPECT_TRUE(result); |
154 | } |
155 | writer.Flush(); |
156 | |
157 | parity = true; |
158 | BitUtil::BitReader reader(buffer, len); |
159 | for (int i = 0; i < len; ++i) { |
160 | bool result; |
161 | if (i % 2 == 0) { |
162 | bool val; |
163 | result = reader.GetValue(1, &val); |
164 | EXPECT_EQ(val, parity); |
165 | parity = !parity; |
166 | } else { |
167 | int val; |
168 | result = reader.GetValue(10, &val); |
169 | EXPECT_EQ(val, i); |
170 | } |
171 | EXPECT_TRUE(result); |
172 | } |
173 | } |
174 | |
175 | // Validates encoding of values by encoding and decoding them. If |
176 | // expected_encoding != NULL, also validates that the encoded buffer is |
177 | // exactly 'expected_encoding'. |
178 | // if expected_len is not -1, it will validate the encoded size is correct. |
179 | void ValidateRle(const vector<int>& values, int bit_width, uint8_t* expected_encoding, |
180 | int expected_len) { |
181 | const int len = 64 * 1024; |
182 | uint8_t buffer[len]; |
183 | EXPECT_LE(expected_len, len); |
184 | |
185 | RleEncoder encoder(buffer, len, bit_width); |
186 | for (size_t i = 0; i < values.size(); ++i) { |
187 | bool result = encoder.Put(values[i]); |
188 | EXPECT_TRUE(result); |
189 | } |
190 | int encoded_len = encoder.Flush(); |
191 | |
192 | if (expected_len != -1) { |
193 | EXPECT_EQ(encoded_len, expected_len); |
194 | } |
195 | if (expected_encoding != NULL) { |
196 | EXPECT_EQ(memcmp(buffer, expected_encoding, encoded_len), 0); |
197 | } |
198 | |
199 | // Verify read |
200 | { |
201 | RleDecoder decoder(buffer, len, bit_width); |
202 | for (size_t i = 0; i < values.size(); ++i) { |
203 | uint64_t val; |
204 | bool result = decoder.Get(&val); |
205 | EXPECT_TRUE(result); |
206 | EXPECT_EQ(values[i], val); |
207 | } |
208 | } |
209 | |
210 | // Verify batch read |
211 | { |
212 | RleDecoder decoder(buffer, len, bit_width); |
213 | vector<int> values_read(values.size()); |
214 | ASSERT_EQ(values.size(), |
215 | decoder.GetBatch(values_read.data(), static_cast<int>(values.size()))); |
216 | EXPECT_EQ(values, values_read); |
217 | } |
218 | } |
219 | |
220 | // A version of ValidateRle that round-trips the values and returns false if |
221 | // the returned values are not all the same |
222 | bool CheckRoundTrip(const vector<int>& values, int bit_width) { |
223 | const int len = 64 * 1024; |
224 | uint8_t buffer[len]; |
225 | RleEncoder encoder(buffer, len, bit_width); |
226 | for (size_t i = 0; i < values.size(); ++i) { |
227 | bool result = encoder.Put(values[i]); |
228 | if (!result) { |
229 | return false; |
230 | } |
231 | } |
232 | int encoded_len = encoder.Flush(); |
233 | int out = 0; |
234 | |
235 | { |
236 | RleDecoder decoder(buffer, encoded_len, bit_width); |
237 | for (size_t i = 0; i < values.size(); ++i) { |
238 | EXPECT_TRUE(decoder.Get(&out)); |
239 | if (values[i] != out) { |
240 | return false; |
241 | } |
242 | } |
243 | } |
244 | |
245 | // Verify batch read |
246 | { |
247 | RleDecoder decoder(buffer, len, bit_width); |
248 | vector<int> values_read(values.size()); |
249 | if (static_cast<int>(values.size()) != |
250 | decoder.GetBatch(values_read.data(), static_cast<int>(values.size()))) { |
251 | return false; |
252 | } |
253 | if (values != values_read) { |
254 | return false; |
255 | } |
256 | } |
257 | |
258 | return true; |
259 | } |
260 | |
261 | TEST(Rle, SpecificSequences) { |
262 | const int len = 1024; |
263 | uint8_t expected_buffer[len]; |
264 | vector<int> values; |
265 | |
266 | // Test 50 0' followed by 50 1's |
267 | values.resize(100); |
268 | for (int i = 0; i < 50; ++i) { |
269 | values[i] = 0; |
270 | } |
271 | for (int i = 50; i < 100; ++i) { |
272 | values[i] = 1; |
273 | } |
274 | |
275 | // expected_buffer valid for bit width <= 1 byte |
276 | expected_buffer[0] = (50 << 1); |
277 | expected_buffer[1] = 0; |
278 | expected_buffer[2] = (50 << 1); |
279 | expected_buffer[3] = 1; |
280 | for (int width = 1; width <= 8; ++width) { |
281 | ValidateRle(values, width, expected_buffer, 4); |
282 | } |
283 | |
284 | for (int width = 9; width <= MAX_WIDTH; ++width) { |
285 | ValidateRle(values, width, NULL, |
286 | 2 * (1 + static_cast<int>(BitUtil::CeilDiv(width, 8)))); |
287 | } |
288 | |
289 | // Test 100 0's and 1's alternating |
290 | for (int i = 0; i < 100; ++i) { |
291 | values[i] = i % 2; |
292 | } |
293 | int num_groups = static_cast<int>(BitUtil::CeilDiv(100, 8)); |
294 | expected_buffer[0] = static_cast<uint8_t>((num_groups << 1) | 1); |
295 | for (int i = 1; i <= 100 / 8; ++i) { |
296 | expected_buffer[i] = BOOST_BINARY(1 0 1 0 1 0 1 0); |
297 | } |
298 | // Values for the last 4 0 and 1's. The upper 4 bits should be padded to 0. |
299 | expected_buffer[100 / 8 + 1] = BOOST_BINARY(0 0 0 0 1 0 1 0); |
300 | |
301 | // num_groups and expected_buffer only valid for bit width = 1 |
302 | ValidateRle(values, 1, expected_buffer, 1 + num_groups); |
303 | for (int width = 2; width <= MAX_WIDTH; ++width) { |
304 | int num_values = static_cast<int>(BitUtil::CeilDiv(100, 8)) * 8; |
305 | ValidateRle(values, width, NULL, |
306 | 1 + static_cast<int>(BitUtil::CeilDiv(width * num_values, 8))); |
307 | } |
308 | } |
309 | |
310 | // ValidateRle on 'num_vals' values with width 'bit_width'. If 'value' != -1, that value |
311 | // is used, otherwise alternating values are used. |
312 | void TestRleValues(int bit_width, int num_vals, int value = -1) { |
313 | const uint64_t mod = (bit_width == 64) ? 1 : 1LL << bit_width; |
314 | vector<int> values; |
315 | for (int v = 0; v < num_vals; ++v) { |
316 | values.push_back((value != -1) ? value : static_cast<int>(v % mod)); |
317 | } |
318 | ValidateRle(values, bit_width, NULL, -1); |
319 | } |
320 | |
321 | TEST(Rle, TestValues) { |
322 | for (int width = 1; width <= MAX_WIDTH; ++width) { |
323 | TestRleValues(width, 1); |
324 | TestRleValues(width, 1024); |
325 | TestRleValues(width, 1024, 0); |
326 | TestRleValues(width, 1024, 1); |
327 | } |
328 | } |
329 | |
330 | TEST(Rle, BitWidthZeroRepeated) { |
331 | uint8_t buffer[1]; |
332 | const int num_values = 15; |
333 | buffer[0] = num_values << 1; // repeated indicator byte |
334 | RleDecoder decoder(buffer, sizeof(buffer), 0); |
335 | uint8_t val; |
336 | for (int i = 0; i < num_values; ++i) { |
337 | bool result = decoder.Get(&val); |
338 | EXPECT_TRUE(result); |
339 | EXPECT_EQ(val, 0); // can only encode 0s with bit width 0 |
340 | } |
341 | EXPECT_FALSE(decoder.Get(&val)); |
342 | } |
343 | |
344 | TEST(Rle, BitWidthZeroLiteral) { |
345 | uint8_t buffer[1]; |
346 | const int num_groups = 4; |
347 | buffer[0] = num_groups << 1 | 1; // literal indicator byte |
348 | RleDecoder decoder = RleDecoder(buffer, sizeof(buffer), 0); |
349 | const int num_values = num_groups * 8; |
350 | uint8_t val; |
351 | for (int i = 0; i < num_values; ++i) { |
352 | bool result = decoder.Get(&val); |
353 | EXPECT_TRUE(result); |
354 | EXPECT_EQ(val, 0); // can only encode 0s with bit width 0 |
355 | } |
356 | EXPECT_FALSE(decoder.Get(&val)); |
357 | } |
358 | |
359 | // Test that writes out a repeated group and then a literal |
360 | // group but flush before finishing. |
361 | TEST(BitRle, Flush) { |
362 | vector<int> values; |
363 | for (int i = 0; i < 16; ++i) values.push_back(1); |
364 | values.push_back(0); |
365 | ValidateRle(values, 1, NULL, -1); |
366 | values.push_back(1); |
367 | ValidateRle(values, 1, NULL, -1); |
368 | values.push_back(1); |
369 | ValidateRle(values, 1, NULL, -1); |
370 | values.push_back(1); |
371 | ValidateRle(values, 1, NULL, -1); |
372 | } |
373 | |
374 | // Test some random sequences. |
375 | TEST(BitRle, Random) { |
376 | int niters = 50; |
377 | int ngroups = 1000; |
378 | int max_group_size = 16; |
379 | vector<int> values(ngroups + max_group_size); |
380 | |
381 | // prng setup |
382 | std::random_device rd; |
383 | std::uniform_int_distribution<int> dist(1, 20); |
384 | |
385 | for (int iter = 0; iter < niters; ++iter) { |
386 | // generate a seed with device entropy |
387 | uint32_t seed = rd(); |
388 | std::mt19937 gen(seed); |
389 | |
390 | bool parity = 0; |
391 | values.resize(0); |
392 | |
393 | for (int i = 0; i < ngroups; ++i) { |
394 | int group_size = dist(gen); |
395 | if (group_size > max_group_size) { |
396 | group_size = 1; |
397 | } |
398 | for (int i = 0; i < group_size; ++i) { |
399 | values.push_back(parity); |
400 | } |
401 | parity = !parity; |
402 | } |
403 | if (!CheckRoundTrip(values, BitUtil::NumRequiredBits(values.size()))) { |
404 | FAIL() << "failing seed: " << seed; |
405 | } |
406 | } |
407 | } |
408 | |
409 | // Test a sequence of 1 0's, 2 1's, 3 0's. etc |
410 | // e.g. 011000111100000 |
411 | TEST(BitRle, RepeatedPattern) { |
412 | vector<int> values; |
413 | const int min_run = 1; |
414 | const int max_run = 32; |
415 | |
416 | for (int i = min_run; i <= max_run; ++i) { |
417 | int v = i % 2; |
418 | for (int j = 0; j < i; ++j) { |
419 | values.push_back(v); |
420 | } |
421 | } |
422 | |
423 | // And go back down again |
424 | for (int i = max_run; i >= min_run; --i) { |
425 | int v = i % 2; |
426 | for (int j = 0; j < i; ++j) { |
427 | values.push_back(v); |
428 | } |
429 | } |
430 | |
431 | ValidateRle(values, 1, NULL, -1); |
432 | } |
433 | |
434 | TEST(BitRle, Overflow) { |
435 | for (int bit_width = 1; bit_width < 32; bit_width += 3) { |
436 | int len = RleEncoder::MinBufferSize(bit_width); |
437 | std::vector<uint8_t> buffer(len); |
438 | int num_added = 0; |
439 | bool parity = true; |
440 | |
441 | RleEncoder encoder(buffer.data(), len, bit_width); |
442 | // Insert alternating true/false until there is no space left |
443 | while (true) { |
444 | bool result = encoder.Put(parity); |
445 | parity = !parity; |
446 | if (!result) break; |
447 | ++num_added; |
448 | } |
449 | |
450 | int bytes_written = encoder.Flush(); |
451 | EXPECT_LE(bytes_written, len); |
452 | EXPECT_GT(num_added, 0); |
453 | |
454 | RleDecoder decoder(buffer.data(), bytes_written, bit_width); |
455 | parity = true; |
456 | uint32_t v; |
457 | for (int i = 0; i < num_added; ++i) { |
458 | bool result = decoder.Get(&v); |
459 | EXPECT_TRUE(result); |
460 | EXPECT_EQ(v != 0, parity); |
461 | parity = !parity; |
462 | } |
463 | // Make sure we get false when reading past end a couple times. |
464 | EXPECT_FALSE(decoder.Get(&v)); |
465 | EXPECT_FALSE(decoder.Get(&v)); |
466 | } |
467 | } |
468 | |
469 | } // namespace util |
470 | } // namespace arrow |
471 | |