1// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
2// Licensed under the MIT License:
3//
4// Permission is hereby granted, free of charge, to any person obtaining a copy
5// of this software and associated documentation files (the "Software"), to deal
6// in the Software without restriction, including without limitation the rights
7// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8// copies of the Software, and to permit persons to whom the Software is
9// furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20// THE SOFTWARE.
21
22#include "serialize-packed.h"
23#include <kj/debug.h>
24#include "layout.h"
25#include <vector>
26
27namespace capnp {
28
29namespace _ { // private
30
// Constructs an unpacking stream that pulls packed bytes from `inner`.
// NOTE(review): `inner` is presumably held by reference (the member declaration is not visible
// here), in which case the caller must keep the underlying stream alive -- confirm in the header.
PackedInputStream::PackedInputStream(kj::BufferedInputStream& inner): inner(inner) {}
// Empty destructor; declared noexcept(false), so it is permitted to propagate exceptions.
PackedInputStream::~PackedInputStream() noexcept(false) {}
33
// Unpacks data from the inner stream into `dst`.
//
// Packed format as decoded below: each output word (8 bytes) is introduced by a tag byte.  Bit i
// of the tag says whether output byte i is taken from the input (bit set) or is zero (bit clear).
// A tag of 0x00 is followed by one count byte N, meaning N additional all-zero words.  A tag of
// 0xff is followed by one count byte N and then N words of literal, uncompressed data.
//
// Parameters:
//   dst      - destination buffer; at most `maxBytes` are written.
//   minBytes - once at least this many bytes have been produced, the function returns as soon as
//              fewer than 10 input bytes remain in the current buffer instead of fetching more.
//   maxBytes - upper bound on bytes produced.
// Both sizes must be multiples of sizeof(word) (checked in debug builds only).
//
// Returns the number of bytes written to `dst`.  Returns 0 if `maxBytes` is 0 or if the inner
// stream has no data available at the start of the call.
size_t PackedInputStream::tryRead(void* dst, size_t minBytes, size_t maxBytes) {
  if (maxBytes == 0) {
    return 0;
  }

  KJ_DREQUIRE(minBytes % sizeof(word) == 0, "PackedInputStream reads must be word-aligned.");
  KJ_DREQUIRE(maxBytes % sizeof(word) == 0, "PackedInputStream reads must be word-aligned.");

  // `out` is the write cursor; `outEnd` is the hard limit; `outMin` is the point past which an
  // early return is allowed.
  uint8_t* __restrict__ out = reinterpret_cast<uint8_t*>(dst);
  uint8_t* const outEnd = reinterpret_cast<uint8_t*>(dst) + maxBytes;
  uint8_t* const outMin = reinterpret_cast<uint8_t*>(dst) + minBytes;

  // An empty buffer at the very start is treated as "nothing to read" rather than an error.
  kj::ArrayPtr<const byte> buffer = inner.tryGetReadBuffer();
  if (buffer.size() == 0) {
    return 0;
  }
  const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(buffer.begin());

// REFRESH_BUFFER: mark the whole current buffer as consumed and fetch the next one.  An empty
// replacement buffer in the middle of a word/run is a KJ_REQUIRE failure; the attached recovery
// block returns the number of bytes produced so far.  Note that this macro can return from the
// enclosing function.
#define REFRESH_BUFFER() \
  inner.skip(buffer.size()); \
  buffer = inner.getReadBuffer(); \
  KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") { \
    return out - reinterpret_cast<uint8_t*>(dst); \
  } \
  in = reinterpret_cast<const uint8_t*>(buffer.begin())

// These two macros are intentionally NOT #undef'd at the end of this function; skip() below
// uses them as well.
#define BUFFER_END (reinterpret_cast<const uint8_t*>(buffer.end()))
#define BUFFER_REMAINING ((size_t)(BUFFER_END - in))

  for (;;) {
    uint8_t tag;

    KJ_DASSERT((out - reinterpret_cast<uint8_t*>(dst)) % sizeof(word) == 0,
               "Output pointer should always be aligned here.");

    // The fast path below touches up to 10 input bytes without bounds checks: 1 tag byte, up to
    // 8 data bytes, and 1 run-length byte.
    if (BUFFER_REMAINING < 10) {
      if (out >= outMin) {
        // We read at least the minimum amount, so go ahead and return.
        // Only the bytes actually consumed from the current buffer are skipped.
        inner.skip(in - reinterpret_cast<const uint8_t*>(buffer.begin()));
        return out - reinterpret_cast<uint8_t*>(dst);
      }

      if (BUFFER_REMAINING == 0) {
        REFRESH_BUFFER();
        continue;
      }

      // We have at least 1, but not 10, bytes available.  We need to read slowly, doing a bounds
      // check on each byte.

      tag = *in++;

      for (uint i = 0; i < 8; i++) {
        if (tag & (1u << i)) {
          if (BUFFER_REMAINING == 0) {
            REFRESH_BUFFER();
          }
          *out++ = *in++;
        } else {
          *out++ = 0;
        }
      }

      // Tags 0x00 and 0xff are followed by a count byte, which the code below reads
      // unconditionally; make sure it is available.
      if (BUFFER_REMAINING == 0 && (tag == 0 || tag == 0xffu)) {
        REFRESH_BUFFER();
      }
    } else {
      tag = *in++;

// HANDLE_BYTE(n): emit output byte n of the word.  `*in` is read unconditionally and masked to
// zero when the tag bit is clear (-(int8_t)true is all ones, -(int8_t)false is zero); `in` only
// advances when the bit is set.
#define HANDLE_BYTE(n) \
      { \
        bool isNonzero = (tag & (1u << n)) != 0; \
        *out++ = *in & (-(int8_t)isNonzero); \
        in += isNonzero; \
      }

      HANDLE_BYTE(0);
      HANDLE_BYTE(1);
      HANDLE_BYTE(2);
      HANDLE_BYTE(3);
      HANDLE_BYTE(4);
      HANDLE_BYTE(5);
      HANDLE_BYTE(6);
      HANDLE_BYTE(7);
#undef HANDLE_BYTE
    }

    if (tag == 0) {
      // Zero-word run: the next byte is the number of additional zero words.
      KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here.");

      uint runLength = *in++ * sizeof(word);

      // The run must fit in what is left of the destination; otherwise report the failure and
      // (in the recovery block) return what was produced so far.
      KJ_REQUIRE(runLength <= outEnd - out,
                 "Packed input did not end cleanly on a segment boundary.") {
        return out - reinterpret_cast<uint8_t*>(dst);
      }
      memset(out, 0, runLength);
      out += runLength;

    } else if (tag == 0xffu) {
      // Literal run: the next byte is the number of uncompressed words that follow.
      KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here.");

      uint runLength = *in++ * sizeof(word);

      KJ_REQUIRE(runLength <= outEnd - out,
                 "Packed input did not end cleanly on a segment boundary.") {
        return out - reinterpret_cast<uint8_t*>(dst);
      }

      uint inRemaining = BUFFER_REMAINING;
      if (inRemaining >= runLength) {
        // Fast path.
        memcpy(out, in, runLength);
        out += runLength;
        in += runLength;
      } else {
        // Copy over the first buffer, then do one big read for the rest.
        memcpy(out, in, inRemaining);
        out += inRemaining;
        runLength -= inRemaining;

        // The current buffer is fully consumed; the remainder of the run is read directly from
        // the inner stream into the destination.
        inner.skip(buffer.size());
        inner.read(out, runLength);
        out += runLength;

        if (out == outEnd) {
          // Nothing from a new buffer has been consumed, so no further inner.skip() is needed.
          return maxBytes;
        } else {
          buffer = inner.getReadBuffer();
          in = reinterpret_cast<const uint8_t*>(buffer.begin());

          // Skip the bounds check below since we just did the same check above.
          continue;
        }
      }
    }

    if (out == outEnd) {
      // Destination is full: commit the consumed portion of the current buffer and finish.
      inner.skip(in - reinterpret_cast<const uint8_t*>(buffer.begin()));
      return maxBytes;
    }
  }

  KJ_FAIL_ASSERT("Can't get here.");
  return 0;  // GCC knows KJ_FAIL_ASSERT doesn't return, but Eclipse CDT still warns...

#undef REFRESH_BUFFER
}
182
// Discards `bytes` bytes of *unpacked* data by walking the packed input without producing output.
//
// `bytes` must be a multiple of sizeof(word) (checked in debug builds only); a value of 0 is a
// no-op.  The loop mirrors the tag decoding in tryRead(): every tag byte accounts for 8 unpacked
// bytes, and tags 0x00 / 0xff are followed by a count byte describing a run of additional words.
//
// This function relies on the BUFFER_END and BUFFER_REMAINING macros defined inside tryRead()
// above, which are still in effect here.
void PackedInputStream::skip(size_t bytes) {
  // We can't just read into buffers because buffers must end on block boundaries.

  if (bytes == 0) {
    return;
  }

  KJ_DREQUIRE(bytes % sizeof(word) == 0, "PackedInputStream reads must be word-aligned.");

  kj::ArrayPtr<const byte> buffer = inner.getReadBuffer();
  const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(buffer.begin());

// REFRESH_BUFFER: consume the whole current buffer and fetch the next one.  An empty replacement
// buffer is a KJ_REQUIRE failure whose recovery block returns from skip().
#define REFRESH_BUFFER() \
  inner.skip(buffer.size()); \
  buffer = inner.getReadBuffer(); \
  KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") { return; } \
  in = reinterpret_cast<const uint8_t*>(buffer.begin())

  for (;;) {
    uint8_t tag;

    // As in tryRead(), the unchecked fast path may touch up to 10 input bytes (tag, 8 data
    // bytes, count byte).
    if (BUFFER_REMAINING < 10) {
      if (BUFFER_REMAINING == 0) {
        REFRESH_BUFFER();
        continue;
      }

      // We have at least 1, but not 10, bytes available.  We need to read slowly, doing a bounds
      // check on each byte.

      tag = *in++;

      // Step over one input byte for every set bit in the tag.
      for (uint i = 0; i < 8; i++) {
        if (tag & (1u << i)) {
          if (BUFFER_REMAINING == 0) {
            REFRESH_BUFFER();
          }
          in++;
        }
      }
      bytes -= 8;

      // Tags 0x00 and 0xff are followed by a count byte that is read below; make sure it is
      // available.
      if (BUFFER_REMAINING == 0 && (tag == 0 || tag == 0xffu)) {
        REFRESH_BUFFER();
      }
    } else {
      tag = *in++;

// HANDLE_BYTE(n): advance `in` by one if bit n of the tag is set.
#define HANDLE_BYTE(n) \
      in += (tag & (1u << n)) != 0

      HANDLE_BYTE(0);
      HANDLE_BYTE(1);
      HANDLE_BYTE(2);
      HANDLE_BYTE(3);
      HANDLE_BYTE(4);
      HANDLE_BYTE(5);
      HANDLE_BYTE(6);
      HANDLE_BYTE(7);
#undef HANDLE_BYTE

      bytes -= 8;
    }

    if (tag == 0) {
      // Zero-word run: the count byte names additional zero words, which occupy no input bytes.
      KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here.");

      uint runLength = *in++ * sizeof(word);

      KJ_REQUIRE(runLength <= bytes, "Packed input did not end cleanly on a segment boundary.") {
        return;
      }

      bytes -= runLength;

    } else if (tag == 0xffu) {
      // Literal run: the count byte names uncompressed words that follow in the input.
      KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here.");

      uint runLength = *in++ * sizeof(word);

      KJ_REQUIRE(runLength <= bytes, "Packed input did not end cleanly on a segment boundary.") {
        return;
      }

      bytes -= runLength;

      uint inRemaining = BUFFER_REMAINING;
      if (inRemaining > runLength) {
        // Fast path.
        in += runLength;
      } else {
        // Forward skip to the underlying stream.
        // Skips the entire current buffer plus whatever part of the run lies beyond it.
        runLength -= inRemaining;
        inner.skip(buffer.size() + runLength);

        if (bytes == 0) {
          return;
        } else {
          buffer = inner.getReadBuffer();
          in = reinterpret_cast<const uint8_t*>(buffer.begin());

          // Skip the bounds check below since we just did the same check above.
          continue;
        }
      }
    }

    if (bytes == 0) {
      // Done: commit only the portion of the current buffer that was actually consumed.
      inner.skip(in - reinterpret_cast<const uint8_t*>(buffer.begin()));
      return;
    }
  }

  KJ_FAIL_ASSERT("Can't get here.");
}
298
299// -------------------------------------------------------------------
300
// Constructs a packing stream that emits packed bytes to `inner`.
// NOTE(review): `inner` is presumably held by reference (the member declaration is not visible
// here), in which case the caller must keep the underlying stream alive -- confirm in the header.
PackedOutputStream::PackedOutputStream(kj::BufferedOutputStream& inner)
    : inner(inner) {}
// Empty destructor; declared noexcept(false), so it is permitted to propagate exceptions.
PackedOutputStream::~PackedOutputStream() noexcept(false) {}
304
// Packs `size` bytes starting at `src` and writes the packed form to the inner stream.
//
// Each input word (8 bytes) is emitted as a tag byte whose bit i is set when input byte i is
// non-zero, followed by the non-zero bytes only.  A tag of 0x00 is followed by a count of
// additional all-zero words; a tag of 0xff is followed by a count of words that are then copied
// through uncompressed.
//
// The loop consumes input 8 bytes at a time without checking against `inEnd` inside a word, and
// the zero-run scan reads the input as uint64_t.  `size` is therefore assumed to be a multiple of
// sizeof(word) and `src` to be word-aligned; neither assumption is checked here.
void PackedOutputStream::write(const void* src, size_t size) {
  kj::ArrayPtr<byte> buffer = inner.getWriteBuffer();
  // Scratch space used once the inner stream's own buffer has fewer than 10 free bytes.
  byte slowBuffer[20];

  uint8_t* __restrict__ out = reinterpret_cast<uint8_t*>(buffer.begin());

  const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(src);
  const uint8_t* const inEnd = reinterpret_cast<const uint8_t*>(src) + size;

  while (in < inEnd) {
    if (reinterpret_cast<uint8_t*>(buffer.end()) - out < 10) {
      // Oops, we're out of space.  We need at least 10 bytes for the fast path, since we don't
      // bounds-check on every byte.

      // Write what we have so far.
      inner.write(buffer.begin(), out - reinterpret_cast<uint8_t*>(buffer.begin()));

      // Use a slow buffer into which we'll encode 10 to 20 bytes.  This should get us past the
      // output stream's buffer boundary.
      buffer = kj::arrayPtr(slowBuffer, sizeof(slowBuffer));
      out = reinterpret_cast<uint8_t*>(buffer.begin());
    }

    // Reserve the tag position; the tag value is only known after all 8 bytes are examined.
    uint8_t* tagPos = out++;

// HANDLE_BYTE(n): copy input byte n to the output unconditionally, but only keep it (advance
// `out`) when it is non-zero.  Declares `bit##n`, which is folded into the tag below.
#define HANDLE_BYTE(n) \
    uint8_t bit##n = *in != 0; \
    *out = *in; \
    out += bit##n; /* out only advances if the byte was non-zero */ \
    ++in

    HANDLE_BYTE(0);
    HANDLE_BYTE(1);
    HANDLE_BYTE(2);
    HANDLE_BYTE(3);
    HANDLE_BYTE(4);
    HANDLE_BYTE(5);
    HANDLE_BYTE(6);
    HANDLE_BYTE(7);
#undef HANDLE_BYTE

    uint8_t tag = (bit0 << 0) | (bit1 << 1) | (bit2 << 2) | (bit3 << 3)
                | (bit4 << 4) | (bit5 << 5) | (bit6 << 6) | (bit7 << 7);
    *tagPos = tag;

    if (tag == 0) {
      // An all-zero word is followed by a count of consecutive zero words (not including the
      // first one).

      // We can check a whole word at a time.  (Here is where we use the assumption that
      // `src` is word-aligned.)
      const uint64_t* inWord = reinterpret_cast<const uint64_t*>(in);

      // The count must fit in 1 byte, so limit to 255 words.
      const uint64_t* limit = reinterpret_cast<const uint64_t*>(inEnd);
      if (limit - inWord > 255) {
        limit = inWord + 255;
      }

      while (inWord < limit && *inWord == 0) {
        ++inWord;
      }

      // Write the count.
      *out++ = inWord - reinterpret_cast<const uint64_t*>(in);

      // Advance input.
      in = reinterpret_cast<const uint8_t*>(inWord);

    } else if (tag == 0xffu) {
      // An all-nonzero word is followed by a count of consecutive uncompressed words, followed
      // by the uncompressed words themselves.

      // Count the number of consecutive words in the input which have no more than a single
      // zero-byte.  We look for at least two zeros because that's the point where our compression
      // scheme becomes a net win.
      // TODO(perf):  Maybe look for three zeros?  Compressing a two-zero word is a loss if the
      //   following word has no zeros.
      const uint8_t* runStart = in;

      // The count must fit in 1 byte here as well, so the run is capped at 255 words.
      const uint8_t* limit = inEnd;
      if ((size_t)(limit - in) > 255 * sizeof(word)) {
        limit = in + 255 * sizeof(word);
      }

      while (in < limit) {
        // Check eight input bytes for zeros.
        uint c = *in++ == 0;
        c += *in++ == 0;
        c += *in++ == 0;
        c += *in++ == 0;
        c += *in++ == 0;
        c += *in++ == 0;
        c += *in++ == 0;
        c += *in++ == 0;

        if (c >= 2) {
          // Un-read the word with multiple zeros, since we'll want to compress that one.
          in -= 8;
          break;
        }
      }

      // Write the count.
      uint count = in - runStart;
      *out++ = count / sizeof(word);

      if (count <= reinterpret_cast<uint8_t*>(buffer.end()) - out) {
        // There's enough space to memcpy.
        memcpy(out, runStart, count);
        out += count;
      } else {
        // Input overruns the output buffer.  We'll give it to the output stream in one chunk
        // and let it decide what to do.
        // Flush what has been encoded so far, pass the literal run straight from the input, and
        // then start over with a fresh write buffer from the inner stream.
        inner.write(buffer.begin(), reinterpret_cast<byte*>(out) - buffer.begin());
        inner.write(runStart, in - runStart);
        buffer = inner.getWriteBuffer();
        out = reinterpret_cast<uint8_t*>(buffer.begin());
      }
    }
  }

  // Write whatever is left.
  inner.write(buffer.begin(), reinterpret_cast<byte*>(out) - buffer.begin());
}
430
431} // namespace _ (private)
432
433// =======================================================================================
434
// Reads a packed message from `inputStream`.
//
// The object initializes its own PackedInputStream base around `inputStream` and then hands that
// base (obtained by casting *this) to InputStreamMessageReader together with `options` and
// `scratchSpace`.
// NOTE(review): this presumably relies on PackedInputStream being listed before
// InputStreamMessageReader in the class's base list so that it is constructed first -- confirm
// in the header.
PackedMessageReader::PackedMessageReader(
    kj::BufferedInputStream& inputStream, ReaderOptions options, kj::ArrayPtr<word> scratchSpace)
    : PackedInputStream(inputStream),
      InputStreamMessageReader(static_cast<PackedInputStream&>(*this), options, scratchSpace) {}

// Empty destructor; declared noexcept(false), so it is permitted to propagate exceptions.
PackedMessageReader::~PackedMessageReader() noexcept(false) {}
441
// Reads a packed message from the raw file descriptor `fd`.
//
// The bases are chained: FdInputStream wraps `fd`, BufferedInputStreamWrapper wraps that
// FdInputStream base, and PackedMessageReader reads from the buffered wrapper.
// NOTE(review): since a plain int is passed to FdInputStream, the descriptor is presumably not
// closed by this object -- confirm against FdInputStream's documentation.
PackedFdMessageReader::PackedFdMessageReader(
    int fd, ReaderOptions options, kj::ArrayPtr<word> scratchSpace)
    : FdInputStream(fd),
      BufferedInputStreamWrapper(static_cast<FdInputStream&>(*this)),
      PackedMessageReader(static_cast<BufferedInputStreamWrapper&>(*this),
                          options, scratchSpace) {}

// Same as above, but the kj::AutoCloseFd is moved into the FdInputStream base, transferring the
// descriptor wrapper to this object.
PackedFdMessageReader::PackedFdMessageReader(
    kj::AutoCloseFd fd, ReaderOptions options, kj::ArrayPtr<word> scratchSpace)
    : FdInputStream(kj::mv(fd)),
      BufferedInputStreamWrapper(static_cast<FdInputStream&>(*this)),
      PackedMessageReader(static_cast<BufferedInputStreamWrapper&>(*this),
                          options, scratchSpace) {}

// Empty destructor; declared noexcept(false), so it is permitted to propagate exceptions.
PackedFdMessageReader::~PackedFdMessageReader() noexcept(false) {}
457
458void writePackedMessage(kj::BufferedOutputStream& output,
459 kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
460 _::PackedOutputStream packedOutput(output);
461 writeMessage(packedOutput, segments);
462}
463
464void writePackedMessage(kj::OutputStream& output,
465 kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
466 KJ_IF_MAYBE(bufferedOutputPtr, kj::dynamicDowncastIfAvailable<kj::BufferedOutputStream>(output)) {
467 writePackedMessage(*bufferedOutputPtr, segments);
468 } else {
469 byte buffer[8192];
470 kj::BufferedOutputStreamWrapper bufferedOutput(output, kj::arrayPtr(buffer, sizeof(buffer)));
471 writePackedMessage(bufferedOutput, segments);
472 }
473}
474
475void writePackedMessageToFd(int fd, kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
476 kj::FdOutputStream output(fd);
477 writePackedMessage(output, segments);
478}
479
480size_t computeUnpackedSizeInWords(kj::ArrayPtr<const byte> packedBytes) {
481 const byte* ptr = packedBytes.begin();
482 const byte* end = packedBytes.end();
483
484 size_t total = 0;
485 while (ptr < end) {
486 uint tag = *ptr;
487 size_t count = kj::popCount(tag);
488 total += 1;
489 KJ_REQUIRE(end - ptr >= count, "invalid packed data");
490 ptr += count + 1;
491
492 if (tag == 0) {
493 KJ_REQUIRE(ptr < end, "invalid packed data");
494 total += *ptr++;
495 } else if (tag == 0xff) {
496 KJ_REQUIRE(ptr < end, "invalid packed data");
497 size_t words = *ptr++;
498 total += words;
499 size_t bytes = words * sizeof(word);
500 KJ_REQUIRE(end - ptr >= bytes, "invalid packed data");
501 ptr += bytes;
502 }
503 }
504
505 return total;
506}
507
508} // namespace capnp
509