1 | // Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors |
2 | // Licensed under the MIT License: |
3 | // |
4 | // Permission is hereby granted, free of charge, to any person obtaining a copy |
5 | // of this software and associated documentation files (the "Software"), to deal |
6 | // in the Software without restriction, including without limitation the rights |
7 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
8 | // copies of the Software, and to permit persons to whom the Software is |
9 | // furnished to do so, subject to the following conditions: |
10 | // |
11 | // The above copyright notice and this permission notice shall be included in |
12 | // all copies or substantial portions of the Software. |
13 | // |
14 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
15 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
16 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
17 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
18 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
19 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
20 | // THE SOFTWARE. |
21 | |
22 | #include "serialize-packed.h" |
23 | #include <kj/debug.h> |
24 | #include "layout.h" |
25 | #include <vector> |
26 | |
27 | namespace capnp { |
28 | |
29 | namespace _ { // private |
30 | |
31 | PackedInputStream::PackedInputStream(kj::BufferedInputStream& inner): inner(inner) {} |
32 | PackedInputStream::~PackedInputStream() noexcept(false) {} |
33 | |
// Unpacks data from `inner` into `dst` and returns the number of unpacked bytes written.
//
// Returns 0 immediately if `maxBytes` is 0 or if `inner.tryGetReadBuffer()` yields an empty
// buffer.  Otherwise decodes until either exactly `maxBytes` have been produced, or at least
// `minBytes` have been produced and fewer than 10 bytes remain in the current input buffer.
// Both limits must be multiples of sizeof(word); this is only checked by KJ_DREQUIRE.
//
// Input format, as decoded below: each word starts with a tag byte in which bit i is set iff
// byte i of the word is nonzero, and the nonzero bytes follow the tag.  A tag of 0 is followed
// by a one-byte count of further all-zero words.  A tag of 0xff is followed by a one-byte count
// of further words that are stored verbatim right after it.
size_t PackedInputStream::tryRead(void* dst, size_t minBytes, size_t maxBytes) {
  if (maxBytes == 0) {
    return 0;
  }

  KJ_DREQUIRE(minBytes % sizeof(word) == 0, "PackedInputStream reads must be word-aligned." );
  KJ_DREQUIRE(maxBytes % sizeof(word) == 0, "PackedInputStream reads must be word-aligned." );

  uint8_t* __restrict__ out = reinterpret_cast<uint8_t*>(dst);
  uint8_t* const outEnd = reinterpret_cast<uint8_t*>(dst) + maxBytes;
  uint8_t* const outMin = reinterpret_cast<uint8_t*>(dst) + minBytes;

  // If `inner` has nothing to offer right now, report that nothing was read.
  kj::ArrayPtr<const byte> buffer = inner.tryGetReadBuffer();
  if (buffer.size() == 0) {
    return 0;
  }
  const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(buffer.begin());

  // REFRESH_BUFFER(): mark the whole current buffer as consumed and fetch the next one.  An empty
  // next buffer means the input ended mid-word.  KJ_REQUIRE then reports "Premature end of
  // packed input."  The braced block is KJ's recovery path: it returns the count decoded so far
  // when KJ does not throw.
#define REFRESH_BUFFER() \
  inner.skip(buffer.size()); \
  buffer = inner.getReadBuffer(); \
  KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") { \
    return out - reinterpret_cast<uint8_t*>(dst); \
  } \
  in = reinterpret_cast<const uint8_t*>(buffer.begin())

  // NOTE: BUFFER_END and BUFFER_REMAINING are not #undef'd at the end of this function, because
  // PackedInputStream::skip() below uses them too.
#define BUFFER_END (reinterpret_cast<const uint8_t*>(buffer.end()))
#define BUFFER_REMAINING ((size_t)(BUFFER_END - in))

  for (;;) {
    uint8_t tag;

    KJ_DASSERT((out - reinterpret_cast<uint8_t*>(dst)) % sizeof(word) == 0,
               "Output pointer should always be aligned here." );

    // The fast path below reads a tag, up to 8 data bytes and a possible run count (10 bytes in
    // total) without bounds checks, so it is only taken when at least 10 input bytes remain.
    if (BUFFER_REMAINING < 10) {
      if (out >= outMin) {
        // We read at least the minimum amount, so go ahead and return.
        inner.skip(in - reinterpret_cast<const uint8_t*>(buffer.begin()));
        return out - reinterpret_cast<uint8_t*>(dst);
      }

      if (BUFFER_REMAINING == 0) {
        REFRESH_BUFFER();
        continue;
      }

      // We have at least 1, but not 10, bytes available. We need to read slowly, doing a bounds
      // check on each byte.

      tag = *in++;

      for (uint i = 0; i < 8; i++) {
        if (tag & (1u << i)) {
          // A nonzero data byte may live in the next buffer.
          if (BUFFER_REMAINING == 0) {
            REFRESH_BUFFER();
          }
          *out++ = *in++;
        } else {
          *out++ = 0;
        }
      }

      // Tags 0 and 0xff are followed by a run-count byte, which may also live in the next buffer.
      if (BUFFER_REMAINING == 0 && (tag == 0 || tag == 0xffu)) {
        REFRESH_BUFFER();
      }
    } else {
      tag = *in++;

      // Branch-free expansion of one byte.  `-(int8_t)isNonzero` is an all-ones mask when the
      // tag bit is set and zero otherwise, and `in` only advances past bytes that are present.
      // `*in` is read even for zero bytes, which stays in bounds because at least 10 bytes
      // remained when this word started.
#define HANDLE_BYTE(n) \
  { \
    bool isNonzero = (tag & (1u << n)) != 0; \
    *out++ = *in & (-(int8_t)isNonzero); \
    in += isNonzero; \
  }

      HANDLE_BYTE(0);
      HANDLE_BYTE(1);
      HANDLE_BYTE(2);
      HANDLE_BYTE(3);
      HANDLE_BYTE(4);
      HANDLE_BYTE(5);
      HANDLE_BYTE(6);
      HANDLE_BYTE(7);
#undef HANDLE_BYTE
    }

    if (tag == 0) {
      KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here." );

      // Count of additional all-zero words following this one, converted to bytes.
      uint runLength = *in++ * sizeof(word);

      KJ_REQUIRE(runLength <= outEnd - out,
                 "Packed input did not end cleanly on a segment boundary." ) {
        return out - reinterpret_cast<uint8_t*>(dst);
      }
      memset(out, 0, runLength);
      out += runLength;

    } else if (tag == 0xffu) {
      KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here." );

      // Count of additional words stored verbatim right after the count byte, in bytes.
      uint runLength = *in++ * sizeof(word);

      KJ_REQUIRE(runLength <= outEnd - out,
                 "Packed input did not end cleanly on a segment boundary." ) {
        return out - reinterpret_cast<uint8_t*>(dst);
      }

      uint inRemaining = BUFFER_REMAINING;
      if (inRemaining >= runLength) {
        // Fast path.
        memcpy(out, in, runLength);
        out += runLength;
        in += runLength;
      } else {
        // Copy over the first buffer, then do one big read for the rest.
        memcpy(out, in, inRemaining);
        out += inRemaining;
        runLength -= inRemaining;

        // Consume the whole current buffer, then read the rest of the run straight from `inner`
        // into the output.
        inner.skip(buffer.size());
        inner.read(out, runLength);
        out += runLength;

        if (out == outEnd) {
          return maxBytes;
        } else {
          buffer = inner.getReadBuffer();
          in = reinterpret_cast<const uint8_t*>(buffer.begin());

          // Skip the bounds check below since we just did the same check above.
          continue;
        }
      }
    }

    if (out == outEnd) {
      inner.skip(in - reinterpret_cast<const uint8_t*>(buffer.begin()));
      return maxBytes;
    }
  }

  KJ_FAIL_ASSERT("Can't get here." );
  return 0;  // GCC knows KJ_FAIL_ASSERT doesn't return, but Eclipse CDT still warns...

#undef REFRESH_BUFFER
}
182 | |
// Discards `bytes` bytes of *unpacked* data.  It walks the tags and run counts without writing
// any output.
//
// `bytes` must be a multiple of sizeof(word); this is only checked by KJ_DREQUIRE.  The call
// fails with "Premature end of packed input." if the input runs out.  It fails with "Packed
// input did not end cleanly on a segment boundary." if a run count exceeds the bytes still to
// be skipped.
//
// Uses the BUFFER_END / BUFFER_REMAINING macros defined (and left defined) in tryRead() above.
void PackedInputStream::skip(size_t bytes) {
  // We can't just read into buffers because buffers must end on block boundaries.

  if (bytes == 0) {
    return;
  }

  KJ_DREQUIRE(bytes % sizeof(word) == 0, "PackedInputStream reads must be word-aligned." );

  kj::ArrayPtr<const byte> buffer = inner.getReadBuffer();
  const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(buffer.begin());

  // REFRESH_BUFFER(): mark the whole current buffer as consumed and fetch the next one.  An empty
  // next buffer means the input ended early.  KJ_REQUIRE reports it, and its recovery block just
  // returns.
#define REFRESH_BUFFER() \
  inner.skip(buffer.size()); \
  buffer = inner.getReadBuffer(); \
  KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") { return; } \
  in = reinterpret_cast<const uint8_t*>(buffer.begin())

  for (;;) {
    uint8_t tag;

    // The unchecked fast path needs at least 10 bytes: tag, up to 8 data bytes, run count.
    if (BUFFER_REMAINING < 10) {
      if (BUFFER_REMAINING == 0) {
        REFRESH_BUFFER();
        continue;
      }

      // We have at least 1, but not 10, bytes available. We need to read slowly, doing a bounds
      // check on each byte.

      tag = *in++;

      for (uint i = 0; i < 8; i++) {
        if (tag & (1u << i)) {
          if (BUFFER_REMAINING == 0) {
            REFRESH_BUFFER();
          }
          in++;
        }
      }
      bytes -= 8;

      // Tags 0 and 0xff are followed by a run-count byte, which may live in the next buffer.
      if (BUFFER_REMAINING == 0 && (tag == 0 || tag == 0xffu)) {
        REFRESH_BUFFER();
      }
    } else {
      tag = *in++;

      // Step over one data byte iff the corresponding tag bit is set.
#define HANDLE_BYTE(n) \
  in += (tag & (1u << n)) != 0

      HANDLE_BYTE(0);
      HANDLE_BYTE(1);
      HANDLE_BYTE(2);
      HANDLE_BYTE(3);
      HANDLE_BYTE(4);
      HANDLE_BYTE(5);
      HANDLE_BYTE(6);
      HANDLE_BYTE(7);
#undef HANDLE_BYTE

      bytes -= 8;
    }

    if (tag == 0) {
      KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here." );

      // Additional all-zero words: nothing further to consume from the input.
      uint runLength = *in++ * sizeof(word);

      KJ_REQUIRE(runLength <= bytes, "Packed input did not end cleanly on a segment boundary." ) {
        return;
      }

      bytes -= runLength;

    } else if (tag == 0xffu) {
      KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here." );

      // Additional verbatim words: their bytes must be stepped over in the input.
      uint runLength = *in++ * sizeof(word);

      KJ_REQUIRE(runLength <= bytes, "Packed input did not end cleanly on a segment boundary." ) {
        return;
      }

      bytes -= runLength;

      uint inRemaining = BUFFER_REMAINING;
      if (inRemaining > runLength) {
        // Fast path.
        in += runLength;
      } else {
        // Forward skip to the underlying stream: discard the whole current buffer (which holds
        // the first `inRemaining` bytes of the run) plus the rest of the run.
        runLength -= inRemaining;
        inner.skip(buffer.size() + runLength);

        if (bytes == 0) {
          return;
        } else {
          buffer = inner.getReadBuffer();
          in = reinterpret_cast<const uint8_t*>(buffer.begin());

          // Skip the bounds check below since we just did the same check above.
          continue;
        }
      }
    }

    if (bytes == 0) {
      inner.skip(in - reinterpret_cast<const uint8_t*>(buffer.begin()));
      return;
    }
  }

  KJ_FAIL_ASSERT("Can't get here." );
}
298 | |
299 | // ------------------------------------------------------------------- |
300 | |
301 | PackedOutputStream::PackedOutputStream(kj::BufferedOutputStream& inner) |
302 | : inner(inner) {} |
303 | PackedOutputStream::~PackedOutputStream() noexcept(false) {} |
304 | |
305 | void PackedOutputStream::write(const void* src, size_t size) { |
306 | kj::ArrayPtr<byte> buffer = inner.getWriteBuffer(); |
307 | byte slowBuffer[20]; |
308 | |
309 | uint8_t* __restrict__ out = reinterpret_cast<uint8_t*>(buffer.begin()); |
310 | |
311 | const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(src); |
312 | const uint8_t* const inEnd = reinterpret_cast<const uint8_t*>(src) + size; |
313 | |
314 | while (in < inEnd) { |
315 | if (reinterpret_cast<uint8_t*>(buffer.end()) - out < 10) { |
316 | // Oops, we're out of space. We need at least 10 bytes for the fast path, since we don't |
317 | // bounds-check on every byte. |
318 | |
319 | // Write what we have so far. |
320 | inner.write(buffer.begin(), out - reinterpret_cast<uint8_t*>(buffer.begin())); |
321 | |
322 | // Use a slow buffer into which we'll encode 10 to 20 bytes. This should get us past the |
323 | // output stream's buffer boundary. |
324 | buffer = kj::arrayPtr(slowBuffer, sizeof(slowBuffer)); |
325 | out = reinterpret_cast<uint8_t*>(buffer.begin()); |
326 | } |
327 | |
328 | uint8_t* tagPos = out++; |
329 | |
330 | #define HANDLE_BYTE(n) \ |
331 | uint8_t bit##n = *in != 0; \ |
332 | *out = *in; \ |
333 | out += bit##n; /* out only advances if the byte was non-zero */ \ |
334 | ++in |
335 | |
336 | HANDLE_BYTE(0); |
337 | HANDLE_BYTE(1); |
338 | HANDLE_BYTE(2); |
339 | HANDLE_BYTE(3); |
340 | HANDLE_BYTE(4); |
341 | HANDLE_BYTE(5); |
342 | HANDLE_BYTE(6); |
343 | HANDLE_BYTE(7); |
344 | #undef HANDLE_BYTE |
345 | |
346 | uint8_t tag = (bit0 << 0) | (bit1 << 1) | (bit2 << 2) | (bit3 << 3) |
347 | | (bit4 << 4) | (bit5 << 5) | (bit6 << 6) | (bit7 << 7); |
348 | *tagPos = tag; |
349 | |
350 | if (tag == 0) { |
351 | // An all-zero word is followed by a count of consecutive zero words (not including the |
352 | // first one). |
353 | |
354 | // We can check a whole word at a time. (Here is where we use the assumption that |
355 | // `src` is word-aligned.) |
356 | const uint64_t* inWord = reinterpret_cast<const uint64_t*>(in); |
357 | |
358 | // The count must fit it 1 byte, so limit to 255 words. |
359 | const uint64_t* limit = reinterpret_cast<const uint64_t*>(inEnd); |
360 | if (limit - inWord > 255) { |
361 | limit = inWord + 255; |
362 | } |
363 | |
364 | while (inWord < limit && *inWord == 0) { |
365 | ++inWord; |
366 | } |
367 | |
368 | // Write the count. |
369 | *out++ = inWord - reinterpret_cast<const uint64_t*>(in); |
370 | |
371 | // Advance input. |
372 | in = reinterpret_cast<const uint8_t*>(inWord); |
373 | |
374 | } else if (tag == 0xffu) { |
375 | // An all-nonzero word is followed by a count of consecutive uncompressed words, followed |
376 | // by the uncompressed words themselves. |
377 | |
378 | // Count the number of consecutive words in the input which have no more than a single |
379 | // zero-byte. We look for at least two zeros because that's the point where our compression |
380 | // scheme becomes a net win. |
381 | // TODO(perf): Maybe look for three zeros? Compressing a two-zero word is a loss if the |
382 | // following word has no zeros. |
383 | const uint8_t* runStart = in; |
384 | |
385 | const uint8_t* limit = inEnd; |
386 | if ((size_t)(limit - in) > 255 * sizeof(word)) { |
387 | limit = in + 255 * sizeof(word); |
388 | } |
389 | |
390 | while (in < limit) { |
391 | // Check eight input bytes for zeros. |
392 | uint c = *in++ == 0; |
393 | c += *in++ == 0; |
394 | c += *in++ == 0; |
395 | c += *in++ == 0; |
396 | c += *in++ == 0; |
397 | c += *in++ == 0; |
398 | c += *in++ == 0; |
399 | c += *in++ == 0; |
400 | |
401 | if (c >= 2) { |
402 | // Un-read the word with multiple zeros, since we'll want to compress that one. |
403 | in -= 8; |
404 | break; |
405 | } |
406 | } |
407 | |
408 | // Write the count. |
409 | uint count = in - runStart; |
410 | *out++ = count / sizeof(word); |
411 | |
412 | if (count <= reinterpret_cast<uint8_t*>(buffer.end()) - out) { |
413 | // There's enough space to memcpy. |
414 | memcpy(out, runStart, count); |
415 | out += count; |
416 | } else { |
417 | // Input overruns the output buffer. We'll give it to the output stream in one chunk |
418 | // and let it decide what to do. |
419 | inner.write(buffer.begin(), reinterpret_cast<byte*>(out) - buffer.begin()); |
420 | inner.write(runStart, in - runStart); |
421 | buffer = inner.getWriteBuffer(); |
422 | out = reinterpret_cast<uint8_t*>(buffer.begin()); |
423 | } |
424 | } |
425 | } |
426 | |
427 | // Write whatever is left. |
428 | inner.write(buffer.begin(), reinterpret_cast<byte*>(out) - buffer.begin()); |
429 | } |
430 | |
431 | } // namespace _ (private) |
432 | |
433 | // ======================================================================================= |
434 | |
435 | PackedMessageReader::PackedMessageReader( |
436 | kj::BufferedInputStream& inputStream, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) |
437 | : PackedInputStream(inputStream), |
438 | InputStreamMessageReader(static_cast<PackedInputStream&>(*this), options, scratchSpace) {} |
439 | |
440 | PackedMessageReader::~PackedMessageReader() noexcept(false) {} |
441 | |
442 | PackedFdMessageReader::PackedFdMessageReader( |
443 | int fd, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) |
444 | : FdInputStream(fd), |
445 | BufferedInputStreamWrapper(static_cast<FdInputStream&>(*this)), |
446 | PackedMessageReader(static_cast<BufferedInputStreamWrapper&>(*this), |
447 | options, scratchSpace) {} |
448 | |
449 | PackedFdMessageReader::PackedFdMessageReader( |
450 | kj::AutoCloseFd fd, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) |
451 | : FdInputStream(kj::mv(fd)), |
452 | BufferedInputStreamWrapper(static_cast<FdInputStream&>(*this)), |
453 | PackedMessageReader(static_cast<BufferedInputStreamWrapper&>(*this), |
454 | options, scratchSpace) {} |
455 | |
456 | PackedFdMessageReader::~PackedFdMessageReader() noexcept(false) {} |
457 | |
458 | void writePackedMessage(kj::BufferedOutputStream& output, |
459 | kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) { |
460 | _::PackedOutputStream packedOutput(output); |
461 | writeMessage(packedOutput, segments); |
462 | } |
463 | |
464 | void writePackedMessage(kj::OutputStream& output, |
465 | kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) { |
466 | KJ_IF_MAYBE(bufferedOutputPtr, kj::dynamicDowncastIfAvailable<kj::BufferedOutputStream>(output)) { |
467 | writePackedMessage(*bufferedOutputPtr, segments); |
468 | } else { |
469 | byte buffer[8192]; |
470 | kj::BufferedOutputStreamWrapper bufferedOutput(output, kj::arrayPtr(buffer, sizeof(buffer))); |
471 | writePackedMessage(bufferedOutput, segments); |
472 | } |
473 | } |
474 | |
475 | void writePackedMessageToFd(int fd, kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) { |
476 | kj::FdOutputStream output(fd); |
477 | writePackedMessage(output, segments); |
478 | } |
479 | |
480 | size_t computeUnpackedSizeInWords(kj::ArrayPtr<const byte> packedBytes) { |
481 | const byte* ptr = packedBytes.begin(); |
482 | const byte* end = packedBytes.end(); |
483 | |
484 | size_t total = 0; |
485 | while (ptr < end) { |
486 | uint tag = *ptr; |
487 | size_t count = kj::popCount(tag); |
488 | total += 1; |
489 | KJ_REQUIRE(end - ptr >= count, "invalid packed data" ); |
490 | ptr += count + 1; |
491 | |
492 | if (tag == 0) { |
493 | KJ_REQUIRE(ptr < end, "invalid packed data" ); |
494 | total += *ptr++; |
495 | } else if (tag == 0xff) { |
496 | KJ_REQUIRE(ptr < end, "invalid packed data" ); |
497 | size_t words = *ptr++; |
498 | total += words; |
499 | size_t bytes = words * sizeof(word); |
500 | KJ_REQUIRE(end - ptr >= bytes, "invalid packed data" ); |
501 | ptr += bytes; |
502 | } |
503 | } |
504 | |
505 | return total; |
506 | } |
507 | |
508 | } // namespace capnp |
509 | |