1// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
2// Licensed under the MIT License:
3//
4// Permission is hereby granted, free of charge, to any person obtaining a copy
5// of this software and associated documentation files (the "Software"), to deal
6// in the Software without restriction, including without limitation the rights
7// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8// copies of the Software, and to permit persons to whom the Software is
9// furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20// THE SOFTWARE.
21
22#include "serialize.h"
23#include "layout.h"
24#include <kj/debug.h>
25#include <exception>
26
27namespace capnp {
28
29UnalignedFlatArrayMessageReader::UnalignedFlatArrayMessageReader(
30 kj::ArrayPtr<const word> array, ReaderOptions options)
31 : MessageReader(options), end(array.end()) {
32 if (array.size() < 1) {
33 // Assume empty message.
34 return;
35 }
36
37 const _::WireValue<uint32_t>* table =
38 reinterpret_cast<const _::WireValue<uint32_t>*>(array.begin());
39
40 uint segmentCount = table[0].get() + 1;
41 size_t offset = segmentCount / 2u + 1u;
42
43 KJ_REQUIRE(array.size() >= offset, "Message ends prematurely in segment table.") {
44 return;
45 }
46
47 {
48 uint segmentSize = table[1].get();
49
50 KJ_REQUIRE(array.size() >= offset + segmentSize,
51 "Message ends prematurely in first segment.") {
52 return;
53 }
54
55 segment0 = array.slice(offset, offset + segmentSize);
56 offset += segmentSize;
57 }
58
59 if (segmentCount > 1) {
60 moreSegments = kj::heapArray<kj::ArrayPtr<const word>>(segmentCount - 1);
61
62 for (uint i = 1; i < segmentCount; i++) {
63 uint segmentSize = table[i + 1].get();
64
65 KJ_REQUIRE(array.size() >= offset + segmentSize, "Message ends prematurely.") {
66 moreSegments = nullptr;
67 return;
68 }
69
70 moreSegments[i - 1] = array.slice(offset, offset + segmentSize);
71 offset += segmentSize;
72 }
73 }
74
75 end = array.begin() + offset;
76}
77
78size_t expectedSizeInWordsFromPrefix(kj::ArrayPtr<const word> array) {
79 if (array.size() < 1) {
80 // All messages are at least one word.
81 return 1;
82 }
83
84 const _::WireValue<uint32_t>* table =
85 reinterpret_cast<const _::WireValue<uint32_t>*>(array.begin());
86
87 uint segmentCount = table[0].get() + 1;
88 size_t offset = segmentCount / 2u + 1u;
89
90 // If the array is too small to contain the full segment table, truncate segmentCount to just
91 // what is available.
92 segmentCount = kj::min(segmentCount, array.size() * 2 - 1u);
93
94 size_t totalSize = offset;
95 for (uint i = 0; i < segmentCount; i++) {
96 totalSize += table[i + 1].get();
97 }
98 return totalSize;
99}
100
101kj::ArrayPtr<const word> UnalignedFlatArrayMessageReader::getSegment(uint id) {
102 if (id == 0) {
103 return segment0;
104 } else if (id <= moreSegments.size()) {
105 return moreSegments[id - 1];
106 } else {
107 return nullptr;
108 }
109}
110
111kj::ArrayPtr<const word> FlatArrayMessageReader::checkAlignment(kj::ArrayPtr<const word> array) {
112 KJ_REQUIRE((uintptr_t)array.begin() % sizeof(void*) == 0,
113 "Input to FlatArrayMessageReader is not aligned. If your architecture supports unaligned "
114 "access (e.g. x86/x64/modern ARM), you may use UnalignedFlatArrayMessageReader instead, "
115 "though this may harm performance.");
116
117 return array;
118}
119
120kj::ArrayPtr<const word> initMessageBuilderFromFlatArrayCopy(
121 kj::ArrayPtr<const word> array, MessageBuilder& target, ReaderOptions options) {
122 FlatArrayMessageReader reader(array, options);
123 target.setRoot(reader.getRoot<AnyPointer>());
124 return kj::arrayPtr(reader.getEnd(), array.end());
125}
126
127kj::Array<word> messageToFlatArray(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
128 kj::Array<word> result = kj::heapArray<word>(computeSerializedSizeInWords(segments));
129
130 _::WireValue<uint32_t>* table =
131 reinterpret_cast<_::WireValue<uint32_t>*>(result.begin());
132
133 // We write the segment count - 1 because this makes the first word zero for single-segment
134 // messages, improving compression. We don't bother doing this with segment sizes because
135 // one-word segments are rare anyway.
136 table[0].set(segments.size() - 1);
137
138 for (uint i = 0; i < segments.size(); i++) {
139 table[i + 1].set(segments[i].size());
140 }
141
142 if (segments.size() % 2 == 0) {
143 // Set padding byte.
144 table[segments.size() + 1].set(0);
145 }
146
147 word* dst = result.begin() + segments.size() / 2 + 1;
148
149 for (auto& segment: segments) {
150 memcpy(dst, segment.begin(), segment.size() * sizeof(word));
151 dst += segment.size();
152 }
153
154 KJ_DASSERT(dst == result.end(), "Buffer overrun/underrun bug in code above.");
155
156 return kj::mv(result);
157}
158
159size_t computeSerializedSizeInWords(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
160 KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message.");
161
162 size_t totalSize = segments.size() / 2 + 1;
163
164 for (auto& segment: segments) {
165 totalSize += segment.size();
166 }
167
168 return totalSize;
169}
170
171// =======================================================================================
172
173InputStreamMessageReader::InputStreamMessageReader(
174 kj::InputStream& inputStream, ReaderOptions options, kj::ArrayPtr<word> scratchSpace)
175 : MessageReader(options), inputStream(inputStream), readPos(nullptr) {
176 _::WireValue<uint32_t> firstWord[2];
177
178 inputStream.read(firstWord, sizeof(firstWord));
179
180 uint segmentCount = firstWord[0].get() + 1;
181 uint segment0Size = segmentCount == 0 ? 0 : firstWord[1].get();
182
183 size_t totalWords = segment0Size;
184
185 // Reject messages with too many segments for security reasons.
186 KJ_REQUIRE(segmentCount < 512, "Message has too many segments.") {
187 segmentCount = 1;
188 segment0Size = 1;
189 break;
190 }
191
192 // Read sizes for all segments except the first. Include padding if necessary.
193 KJ_STACK_ARRAY(_::WireValue<uint32_t>, moreSizes, segmentCount & ~1, 16, 64);
194 if (segmentCount > 1) {
195 inputStream.read(moreSizes.begin(), moreSizes.size() * sizeof(moreSizes[0]));
196 for (uint i = 0; i < segmentCount - 1; i++) {
197 totalWords += moreSizes[i].get();
198 }
199 }
200
201 // Don't accept a message which the receiver couldn't possibly traverse without hitting the
202 // traversal limit. Without this check, a malicious client could transmit a very large segment
203 // size to make the receiver allocate excessive space and possibly crash.
204 KJ_REQUIRE(totalWords <= options.traversalLimitInWords,
205 "Message is too large. To increase the limit on the receiving end, see "
206 "capnp::ReaderOptions.") {
207 segmentCount = 1;
208 segment0Size = kj::min(segment0Size, options.traversalLimitInWords);
209 totalWords = segment0Size;
210 break;
211 }
212
213 if (scratchSpace.size() < totalWords) {
214 // TODO(perf): Consider allocating each segment as a separate chunk to reduce memory
215 // fragmentation.
216 ownedSpace = kj::heapArray<word>(totalWords);
217 scratchSpace = ownedSpace;
218 }
219
220 segment0 = scratchSpace.slice(0, segment0Size);
221
222 if (segmentCount > 1) {
223 moreSegments = kj::heapArray<kj::ArrayPtr<const word>>(segmentCount - 1);
224 size_t offset = segment0Size;
225
226 for (uint i = 0; i < segmentCount - 1; i++) {
227 uint segmentSize = moreSizes[i].get();
228 moreSegments[i] = scratchSpace.slice(offset, offset + segmentSize);
229 offset += segmentSize;
230 }
231 }
232
233 if (segmentCount == 1) {
234 inputStream.read(scratchSpace.begin(), totalWords * sizeof(word));
235 } else if (segmentCount > 1) {
236 readPos = scratchSpace.asBytes().begin();
237 readPos += inputStream.read(readPos, segment0Size * sizeof(word), totalWords * sizeof(word));
238 }
239}
240
241InputStreamMessageReader::~InputStreamMessageReader() noexcept(false) {
242 if (readPos != nullptr) {
243 unwindDetector.catchExceptionsIfUnwinding([&]() {
244 // Note that lazy reads only happen when we have multiple segments, so moreSegments.back() is
245 // valid.
246 const byte* allEnd = reinterpret_cast<const byte*>(moreSegments.back().end());
247 inputStream.skip(allEnd - readPos);
248 });
249 }
250}
251
252kj::ArrayPtr<const word> InputStreamMessageReader::getSegment(uint id) {
253 if (id > moreSegments.size()) {
254 return nullptr;
255 }
256
257 kj::ArrayPtr<const word> segment = id == 0 ? segment0 : moreSegments[id - 1];
258
259 if (readPos != nullptr) {
260 // May need to lazily read more data.
261 const byte* segmentEnd = reinterpret_cast<const byte*>(segment.end());
262 if (readPos < segmentEnd) {
263 // Note that lazy reads only happen when we have multiple segments, so moreSegments.back() is
264 // valid.
265 const byte* allEnd = reinterpret_cast<const byte*>(moreSegments.back().end());
266 readPos += inputStream.read(readPos, segmentEnd - readPos, allEnd - readPos);
267 }
268 }
269
270 return segment;
271}
272
273void readMessageCopy(kj::InputStream& input, MessageBuilder& target,
274 ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
275 InputStreamMessageReader message(input, options, scratchSpace);
276 target.setRoot(message.getRoot<AnyPointer>());
277}
278
279// -------------------------------------------------------------------
280
281void writeMessage(kj::OutputStream& output, kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
282 KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message.");
283
284 KJ_STACK_ARRAY(_::WireValue<uint32_t>, table, (segments.size() + 2) & ~size_t(1), 16, 64);
285
286 // We write the segment count - 1 because this makes the first word zero for single-segment
287 // messages, improving compression. We don't bother doing this with segment sizes because
288 // one-word segments are rare anyway.
289 table[0].set(segments.size() - 1);
290 for (uint i = 0; i < segments.size(); i++) {
291 table[i + 1].set(segments[i].size());
292 }
293 if (segments.size() % 2 == 0) {
294 // Set padding byte.
295 table[segments.size() + 1].set(0);
296 }
297
298 KJ_STACK_ARRAY(kj::ArrayPtr<const byte>, pieces, segments.size() + 1, 4, 32);
299 pieces[0] = table.asBytes();
300
301 for (uint i = 0; i < segments.size(); i++) {
302 pieces[i + 1] = segments[i].asBytes();
303 }
304
305 output.write(pieces);
306}
307
308// =======================================================================================
309
310StreamFdMessageReader::~StreamFdMessageReader() noexcept(false) {}
311
312void writeMessageToFd(int fd, kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
313 kj::FdOutputStream stream(fd);
314 writeMessage(stream, segments);
315}
316
317void readMessageCopyFromFd(int fd, MessageBuilder& target,
318 ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
319 kj::FdInputStream stream(fd);
320 readMessageCopy(stream, target, options, scratchSpace);
321}
322
323} // namespace capnp
324