1// Copyright (c) 2017 Cloudflare, Inc.; Sandstorm Development Group, Inc.; and contributors
2// Licensed under the MIT License:
3//
4// Permission is hereby granted, free of charge, to any person obtaining a copy
5// of this software and associated documentation files (the "Software"), to deal
6// in the Software without restriction, including without limitation the rights
7// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8// copies of the Software, and to permit persons to whom the Software is
9// furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20// THE SOFTWARE.
21
22#include "encoding.h"
23#include "vector.h"
24#include "debug.h"
25
26namespace kj {
27
28namespace {
29
// Jumps to the `error` label of the enclosing scope when `cond` is true. `cond` is marked as
// unlikely for branch layout. The expansion is wrapped in do/while(false) so that it behaves as a
// single statement: a bare `if` here would silently capture an `else` written after the macro.
#define GOTO_ERROR_IF(cond) do { if (KJ_UNLIKELY(cond)) goto error; } while (false)
31
32inline void addChar32(Vector<char16_t>& vec, char32_t u) {
33 // Encode as surrogate pair.
34 u -= 0x10000;
35 vec.add(0xd800 | (u >> 10));
36 vec.add(0xdc00 | (u & 0x03ff));
37}
38
39inline void addChar32(Vector<char32_t>& vec, char32_t u) {
40 vec.add(u);
41}
42
43template <typename T>
44EncodingResult<Array<T>> encodeUtf(ArrayPtr<const char> text, bool nulTerminate) {
45 Vector<T> result(text.size() + nulTerminate);
46 bool hadErrors = false;
47
48 size_t i = 0;
49 while (i < text.size()) {
50 byte c = text[i++];
51 if (c < 0x80) {
52 // 0xxxxxxx -- ASCII
53 result.add(c);
54 continue;
55 } else if (KJ_UNLIKELY(c < 0xc0)) {
56 // 10xxxxxx -- malformed continuation byte
57 goto error;
58 } else if (c < 0xe0) {
59 // 110xxxxx -- 2-byte
60 byte c2;
61 GOTO_ERROR_IF(i == text.size() || ((c2 = text[i]) & 0xc0) != 0x80); ++i;
62 char16_t u = (static_cast<char16_t>(c & 0x1f) << 6)
63 | (static_cast<char16_t>(c2 & 0x3f) );
64
65 // Disallow overlong sequence.
66 GOTO_ERROR_IF(u < 0x80);
67
68 result.add(u);
69 continue;
70 } else if (c < 0xf0) {
71 // 1110xxxx -- 3-byte
72 byte c2, c3;
73 GOTO_ERROR_IF(i == text.size() || ((c2 = text[i]) & 0xc0) != 0x80); ++i;
74 GOTO_ERROR_IF(i == text.size() || ((c3 = text[i]) & 0xc0) != 0x80); ++i;
75 char16_t u = (static_cast<char16_t>(c & 0x0f) << 12)
76 | (static_cast<char16_t>(c2 & 0x3f) << 6)
77 | (static_cast<char16_t>(c3 & 0x3f) );
78
79 // Disallow overlong sequence.
80 GOTO_ERROR_IF(u < 0x0800);
81
82 // Flag surrogate pair code points as errors, but allow them through.
83 if (KJ_UNLIKELY((u & 0xf800) == 0xd800)) {
84 if (result.size() > 0 &&
85 (u & 0xfc00) == 0xdc00 &&
86 (result.back() & 0xfc00) == 0xd800) {
87 // Whoops, the *previous* character was also an invalid surrogate, and if we add this
88 // one too, they'll form a valid surrogate pair. If we allowed this, then it would mean
89 // invalid UTF-8 round-tripped to UTF-16 and back could actually change meaning entirely.
90 // OTOH, the reason we allow dangling surrogates is to allow invalid UTF-16 to round-trip
91 // to UTF-8 without loss, but if the original UTF-16 had a valid surrogate pair, it would
92 // have been encoded as a valid single UTF-8 codepoint, not as separate UTF-8 codepoints
93 // for each surrogate.
94 goto error;
95 }
96
97 hadErrors = true;
98 }
99
100 result.add(u);
101 continue;
102 } else if (c < 0xf8) {
103 // 11110xxx -- 4-byte
104 byte c2, c3, c4;
105 GOTO_ERROR_IF(i == text.size() || ((c2 = text[i]) & 0xc0) != 0x80); ++i;
106 GOTO_ERROR_IF(i == text.size() || ((c3 = text[i]) & 0xc0) != 0x80); ++i;
107 GOTO_ERROR_IF(i == text.size() || ((c4 = text[i]) & 0xc0) != 0x80); ++i;
108 char32_t u = (static_cast<char32_t>(c & 0x07) << 18)
109 | (static_cast<char32_t>(c2 & 0x3f) << 12)
110 | (static_cast<char32_t>(c3 & 0x3f) << 6)
111 | (static_cast<char32_t>(c4 & 0x3f) );
112
113 // Disallow overlong sequence.
114 GOTO_ERROR_IF(u < 0x10000);
115
116 // Unicode ends at U+10FFFF
117 GOTO_ERROR_IF(u >= 0x110000);
118
119 addChar32(result, u);
120 continue;
121 } else {
122 // 5-byte and 6-byte sequences are not legal as they'd result in codepoints outside the
123 // range of Unicode.
124 goto error;
125 }
126
127 error:
128 result.add(0xfffd);
129 hadErrors = true;
130 // Ignore all continuation bytes.
131 while (i < text.size() && (text[i] & 0xc0) == 0x80) {
132 ++i;
133 }
134 }
135
136 if (nulTerminate) result.add(0);
137
138 return { result.releaseAsArray(), hadErrors };
139}
140
141} // namespace
142
143EncodingResult<Array<char16_t>> encodeUtf16(ArrayPtr<const char> text, bool nulTerminate) {
144 return encodeUtf<char16_t>(text, nulTerminate);
145}
146
147EncodingResult<Array<char32_t>> encodeUtf32(ArrayPtr<const char> text, bool nulTerminate) {
148 return encodeUtf<char32_t>(text, nulTerminate);
149}
150
151EncodingResult<String> decodeUtf16(ArrayPtr<const char16_t> utf16) {
152 Vector<char> result(utf16.size() + 1);
153 bool hadErrors = false;
154
155 size_t i = 0;
156 while (i < utf16.size()) {
157 char16_t u = utf16[i++];
158
159 if (u < 0x80) {
160 result.add(u);
161 continue;
162 } else if (u < 0x0800) {
163 result.addAll<std::initializer_list<char>>({
164 static_cast<char>(((u >> 6) ) | 0xc0),
165 static_cast<char>(((u ) & 0x3f) | 0x80)
166 });
167 continue;
168 } else if ((u & 0xf800) == 0xd800) {
169 // surrogate pair
170 char16_t u2;
171 if (KJ_UNLIKELY(i == utf16.size() // missing second half
172 || (u & 0x0400) != 0 // first half in wrong range
173 || ((u2 = utf16[i]) & 0xfc00) != 0xdc00)) { // second half in wrong range
174 hadErrors = true;
175 goto threeByte;
176 }
177 ++i;
178
179 char32_t u32 = (((u & 0x03ff) << 10) | (u2 & 0x03ff)) + 0x10000;
180 result.addAll<std::initializer_list<char>>({
181 static_cast<char>(((u32 >> 18) ) | 0xf0),
182 static_cast<char>(((u32 >> 12) & 0x3f) | 0x80),
183 static_cast<char>(((u32 >> 6) & 0x3f) | 0x80),
184 static_cast<char>(((u32 ) & 0x3f) | 0x80)
185 });
186 continue;
187 } else {
188 threeByte:
189 result.addAll<std::initializer_list<char>>({
190 static_cast<char>(((u >> 12) ) | 0xe0),
191 static_cast<char>(((u >> 6) & 0x3f) | 0x80),
192 static_cast<char>(((u ) & 0x3f) | 0x80)
193 });
194 continue;
195 }
196 }
197
198 result.add(0);
199 return { String(result.releaseAsArray()), hadErrors };
200}
201
202EncodingResult<String> decodeUtf32(ArrayPtr<const char32_t> utf16) {
203 Vector<char> result(utf16.size() + 1);
204 bool hadErrors = false;
205
206 size_t i = 0;
207 while (i < utf16.size()) {
208 char32_t u = utf16[i++];
209
210 if (u < 0x80) {
211 result.add(u);
212 continue;
213 } else if (u < 0x0800) {
214 result.addAll<std::initializer_list<char>>({
215 static_cast<char>(((u >> 6) ) | 0xc0),
216 static_cast<char>(((u ) & 0x3f) | 0x80)
217 });
218 continue;
219 } else if (u < 0x10000) {
220 if (KJ_UNLIKELY((u & 0xfffff800) == 0xd800)) {
221 // no surrogates allowed in utf-32
222 hadErrors = true;
223 }
224 result.addAll<std::initializer_list<char>>({
225 static_cast<char>(((u >> 12) ) | 0xe0),
226 static_cast<char>(((u >> 6) & 0x3f) | 0x80),
227 static_cast<char>(((u ) & 0x3f) | 0x80)
228 });
229 continue;
230 } else {
231 GOTO_ERROR_IF(u >= 0x110000); // outside Unicode range
232 result.addAll<std::initializer_list<char>>({
233 static_cast<char>(((u >> 18) ) | 0xf0),
234 static_cast<char>(((u >> 12) & 0x3f) | 0x80),
235 static_cast<char>(((u >> 6) & 0x3f) | 0x80),
236 static_cast<char>(((u ) & 0x3f) | 0x80)
237 });
238 continue;
239 }
240
241 error:
242 result.addAll(StringPtr(u8"\ufffd"));
243 hadErrors = true;
244 }
245
246 result.add(0);
247 return { String(result.releaseAsArray()), hadErrors };
248}
249
250namespace {
251
252#if __GNUC__ >= 8 && !__clang__
253// GCC 8's new class-memaccess warning rightly dislikes the following hacks, but we're really sure
254// we want to allow them so disable the warning.
255#pragma GCC diagnostic push
256#pragma GCC diagnostic ignored "-Wclass-memaccess"
257#endif
258
259template <typename To, typename From>
260Array<To> coerceTo(Array<From>&& array) {
261 static_assert(sizeof(To) == sizeof(From), "incompatible coercion");
262 Array<wchar_t> result;
263 memcpy(&result, &array, sizeof(array));
264 memset(&array, 0, sizeof(array));
265 return result;
266}
267
268template <typename To, typename From>
269ArrayPtr<To> coerceTo(ArrayPtr<From> array) {
270 static_assert(sizeof(To) == sizeof(From), "incompatible coercion");
271 return arrayPtr(reinterpret_cast<To*>(array.begin()), array.size());
272}
273
274template <typename To, typename From>
275EncodingResult<Array<To>> coerceTo(EncodingResult<Array<From>>&& result) {
276 return { coerceTo<To>(Array<From>(kj::mv(result))), result.hadErrors };
277}
278
279#if __GNUC__ >= 8 && !__clang__
280#pragma GCC diagnostic pop
281#endif
282
283template <size_t s>
284struct WideConverter;
285
286template <>
287struct WideConverter<sizeof(char)> {
288 typedef char Type;
289
290 static EncodingResult<Array<char>> encode(ArrayPtr<const char> text, bool nulTerminate) {
291 auto result = heapArray<char>(text.size() + nulTerminate);
292 memcpy(result.begin(), text.begin(), text.size());
293 if (nulTerminate) result.back() = 0;
294 return { kj::mv(result), false };
295 }
296
297 static EncodingResult<kj::String> decode(ArrayPtr<const char> text) {
298 return { kj::heapString(text), false };
299 }
300};
301
302template <>
303struct WideConverter<sizeof(char16_t)> {
304 typedef char16_t Type;
305
306 static inline EncodingResult<Array<char16_t>> encode(
307 ArrayPtr<const char> text, bool nulTerminate) {
308 return encodeUtf16(text, nulTerminate);
309 }
310
311 static inline EncodingResult<kj::String> decode(ArrayPtr<const char16_t> text) {
312 return decodeUtf16(text);
313 }
314};
315
316template <>
317struct WideConverter<sizeof(char32_t)> {
318 typedef char32_t Type;
319
320 static inline EncodingResult<Array<char32_t>> encode(
321 ArrayPtr<const char> text, bool nulTerminate) {
322 return encodeUtf32(text, nulTerminate);
323 }
324
325 static inline EncodingResult<kj::String> decode(ArrayPtr<const char32_t> text) {
326 return decodeUtf32(text);
327 }
328};
329
330} // namespace
331
332EncodingResult<Array<wchar_t>> encodeWideString(ArrayPtr<const char> text, bool nulTerminate) {
333 return coerceTo<wchar_t>(WideConverter<sizeof(wchar_t)>::encode(text, nulTerminate));
334}
335EncodingResult<String> decodeWideString(ArrayPtr<const wchar_t> wide) {
336 using Converter = WideConverter<sizeof(wchar_t)>;
337 return Converter::decode(coerceTo<const Converter::Type>(wide));
338}
339
340// =======================================================================================
341
342namespace {
343
const char HEX_DIGITS[] = "0123456789abcdef";
// Maps integer in the range [0,16) to a lowercase hex digit.

const char HEX_DIGITS_URI[] = "0123456789ABCDEF";
// Uppercase variant used for percent-encoding. RFC 3986 section 2.1 says "For consistency, URI
// producers and normalizers should use uppercase hexadecimal digits for all percent-encodings."
350
351static Maybe<uint> tryFromHexDigit(char c) {
352 if ('0' <= c && c <= '9') {
353 return c - '0';
354 } else if ('a' <= c && c <= 'f') {
355 return c - ('a' - 10);
356 } else if ('A' <= c && c <= 'F') {
357 return c - ('A' - 10);
358 } else {
359 return nullptr;
360 }
361}
362
363static Maybe<uint> tryFromOctDigit(char c) {
364 if ('0' <= c && c <= '7') {
365 return c - '0';
366 } else {
367 return nullptr;
368 }
369}
370
371} // namespace
372
373String encodeHex(ArrayPtr<const byte> input) {
374 return strArray(KJ_MAP(b, input) {
375 return heapArray<char>({HEX_DIGITS[b/16], HEX_DIGITS[b%16]});
376 }, "");
377}
378
379EncodingResult<Array<byte>> decodeHex(ArrayPtr<const char> text) {
380 auto result = heapArray<byte>(text.size() / 2);
381 bool hadErrors = text.size() % 2;
382
383 for (auto i: kj::indices(result)) {
384 byte b = 0;
385 KJ_IF_MAYBE(d1, tryFromHexDigit(text[i*2])) {
386 b = *d1 << 4;
387 } else {
388 hadErrors = true;
389 }
390 KJ_IF_MAYBE(d2, tryFromHexDigit(text[i*2+1])) {
391 b |= *d2;
392 } else {
393 hadErrors = true;
394 }
395 result[i] = b;
396 }
397
398 return { kj::mv(result), hadErrors };
399}
400
401String encodeUriComponent(ArrayPtr<const byte> bytes) {
402 Vector<char> result(bytes.size() + 1);
403 for (byte b: bytes) {
404 if (('A' <= b && b <= 'Z') ||
405 ('a' <= b && b <= 'z') ||
406 ('0' <= b && b <= '9') ||
407 b == '-' || b == '_' || b == '.' || b == '!' || b == '~' || b == '*' || b == '\'' ||
408 b == '(' || b == ')') {
409 result.add(b);
410 } else {
411 result.add('%');
412 result.add(HEX_DIGITS_URI[b/16]);
413 result.add(HEX_DIGITS_URI[b%16]);
414 }
415 }
416 result.add('\0');
417 return String(result.releaseAsArray());
418}
419
420String encodeUriFragment(ArrayPtr<const byte> bytes) {
421 Vector<char> result(bytes.size() + 1);
422 for (byte b: bytes) {
423 if (('?' <= b && b <= '_') || // covers A-Z
424 ('a' <= b && b <= '~') || // covers a-z
425 ('&' <= b && b <= ';') || // covers 0-9
426 b == '!' || b == '=' || b == '#' || b == '$') {
427 result.add(b);
428 } else {
429 result.add('%');
430 result.add(HEX_DIGITS_URI[b/16]);
431 result.add(HEX_DIGITS_URI[b%16]);
432 }
433 }
434 result.add('\0');
435 return String(result.releaseAsArray());
436}
437
438String encodeUriPath(ArrayPtr<const byte> bytes) {
439 Vector<char> result(bytes.size() + 1);
440 for (byte b: bytes) {
441 if (('@' <= b && b <= '[') || // covers A-Z
442 ('a' <= b && b <= 'z') ||
443 ('0' <= b && b <= ';') || // covers 0-9
444 ('&' <= b && b <= '.') ||
445 b == '_' || b == '!' || b == '=' || b == ']' ||
446 b == '^' || b == '|' || b == '~' || b == '$') {
447 result.add(b);
448 } else {
449 result.add('%');
450 result.add(HEX_DIGITS_URI[b/16]);
451 result.add(HEX_DIGITS_URI[b%16]);
452 }
453 }
454 result.add('\0');
455 return String(result.releaseAsArray());
456}
457
458String encodeUriUserInfo(ArrayPtr<const byte> bytes) {
459 Vector<char> result(bytes.size() + 1);
460 for (byte b: bytes) {
461 if (('A' <= b && b <= 'Z') ||
462 ('a' <= b && b <= 'z') ||
463 ('0' <= b && b <= '9') ||
464 ('&' <= b && b <= '.') ||
465 b == '_' || b == '!' || b == '~' || b == '$') {
466 result.add(b);
467 } else {
468 result.add('%');
469 result.add(HEX_DIGITS_URI[b/16]);
470 result.add(HEX_DIGITS_URI[b%16]);
471 }
472 }
473 result.add('\0');
474 return String(result.releaseAsArray());
475}
476
477String encodeWwwForm(ArrayPtr<const byte> bytes) {
478 Vector<char> result(bytes.size() + 1);
479 for (byte b: bytes) {
480 if (('A' <= b && b <= 'Z') ||
481 ('a' <= b && b <= 'z') ||
482 ('0' <= b && b <= '9') ||
483 b == '-' || b == '_' || b == '.' || b == '*') {
484 result.add(b);
485 } else if (b == ' ') {
486 result.add('+');
487 } else {
488 result.add('%');
489 result.add(HEX_DIGITS_URI[b/16]);
490 result.add(HEX_DIGITS_URI[b%16]);
491 }
492 }
493 result.add('\0');
494 return String(result.releaseAsArray());
495}
496
497EncodingResult<Array<byte>> decodeBinaryUriComponent(
498 ArrayPtr<const char> text, DecodeUriOptions options) {
499 Vector<byte> result(text.size() + options.nulTerminate);
500 bool hadErrors = false;
501
502 const char* ptr = text.begin();
503 const char* end = text.end();
504 while (ptr < end) {
505 if (*ptr == '%') {
506 ++ptr;
507
508 if (ptr == end) {
509 hadErrors = true;
510 } else KJ_IF_MAYBE(d1, tryFromHexDigit(*ptr)) {
511 byte b = *d1;
512 ++ptr;
513 if (ptr == end) {
514 hadErrors = true;
515 } else KJ_IF_MAYBE(d2, tryFromHexDigit(*ptr)) {
516 b = (b << 4) | *d2;
517 ++ptr;
518 } else {
519 hadErrors = true;
520 }
521 result.add(b);
522 } else {
523 hadErrors = true;
524 }
525 } else if (options.plusToSpace && *ptr == '+') {
526 ++ptr;
527 result.add(' ');
528 } else {
529 result.add(*ptr++);
530 }
531 }
532
533 if (options.nulTerminate) result.add(0);
534 return { result.releaseAsArray(), hadErrors };
535}
536
537// =======================================================================================
538
539String encodeCEscape(ArrayPtr<const byte> bytes) {
540 Vector<char> escaped(bytes.size());
541
542 for (byte b: bytes) {
543 switch (b) {
544 case '\a': escaped.addAll(StringPtr("\\a")); break;
545 case '\b': escaped.addAll(StringPtr("\\b")); break;
546 case '\f': escaped.addAll(StringPtr("\\f")); break;
547 case '\n': escaped.addAll(StringPtr("\\n")); break;
548 case '\r': escaped.addAll(StringPtr("\\r")); break;
549 case '\t': escaped.addAll(StringPtr("\\t")); break;
550 case '\v': escaped.addAll(StringPtr("\\v")); break;
551 case '\'': escaped.addAll(StringPtr("\\\'")); break;
552 case '\"': escaped.addAll(StringPtr("\\\"")); break;
553 case '\\': escaped.addAll(StringPtr("\\\\")); break;
554 default:
555 if (b < 0x20 || b == 0x7f) {
556 // Use octal escape, not hex, because hex escapes technically have no length limit and
557 // so can create ambiguity with subsequent characters.
558 escaped.add('\\');
559 escaped.add(HEX_DIGITS[b / 64]);
560 escaped.add(HEX_DIGITS[(b / 8) % 8]);
561 escaped.add(HEX_DIGITS[b % 8]);
562 } else {
563 escaped.add(b);
564 }
565 break;
566 }
567 }
568
569 escaped.add(0);
570 return String(escaped.releaseAsArray());
571}
572
573EncodingResult<Array<byte>> decodeBinaryCEscape(ArrayPtr<const char> text, bool nulTerminate) {
574 Vector<byte> result(text.size() + nulTerminate);
575 bool hadErrors = false;
576
577 size_t i = 0;
578 while (i < text.size()) {
579 char c = text[i++];
580 if (c == '\\') {
581 if (i == text.size()) {
582 hadErrors = true;
583 continue;
584 }
585 char c2 = text[i++];
586 switch (c2) {
587 case 'a' : result.add('\a'); break;
588 case 'b' : result.add('\b'); break;
589 case 'f' : result.add('\f'); break;
590 case 'n' : result.add('\n'); break;
591 case 'r' : result.add('\r'); break;
592 case 't' : result.add('\t'); break;
593 case 'v' : result.add('\v'); break;
594 case '\'': result.add('\''); break;
595 case '\"': result.add('\"'); break;
596 case '\\': result.add('\\'); break;
597
598 case '0':
599 case '1':
600 case '2':
601 case '3':
602 case '4':
603 case '5':
604 case '6':
605 case '7': {
606 uint value = c2 - '0';
607 for (uint j = 0; j < 2 && i < text.size(); j++) {
608 KJ_IF_MAYBE(d, tryFromOctDigit(text[i])) {
609 ++i;
610 value = (value << 3) | *d;
611 } else {
612 break;
613 }
614 }
615 if (value >= 0x100) hadErrors = true;
616 result.add(value);
617 break;
618 }
619
620 case 'x': {
621 uint value = 0;
622 while (i < text.size()) {
623 KJ_IF_MAYBE(d, tryFromHexDigit(text[i])) {
624 ++i;
625 value = (value << 4) | *d;
626 } else {
627 break;
628 }
629 }
630 if (value >= 0x100) hadErrors = true;
631 result.add(value);
632 break;
633 }
634
635 case 'u': {
636 char16_t value = 0;
637 for (uint j = 0; j < 4; j++) {
638 if (i == text.size()) {
639 hadErrors = true;
640 break;
641 } else KJ_IF_MAYBE(d, tryFromHexDigit(text[i])) {
642 ++i;
643 value = (value << 4) | *d;
644 } else {
645 hadErrors = true;
646 break;
647 }
648 }
649 auto utf = decodeUtf16(arrayPtr(&value, 1));
650 if (utf.hadErrors) hadErrors = true;
651 result.addAll(utf.asBytes());
652 break;
653 }
654
655 case 'U': {
656 char32_t value = 0;
657 for (uint j = 0; j < 8; j++) {
658 if (i == text.size()) {
659 hadErrors = true;
660 break;
661 } else KJ_IF_MAYBE(d, tryFromHexDigit(text[i])) {
662 ++i;
663 value = (value << 4) | *d;
664 } else {
665 hadErrors = true;
666 break;
667 }
668 }
669 auto utf = decodeUtf32(arrayPtr(&value, 1));
670 if (utf.hadErrors) hadErrors = true;
671 result.addAll(utf.asBytes());
672 break;
673 }
674
675 default:
676 result.add(c2);
677 }
678 } else {
679 result.add(c);
680 }
681 }
682
683 if (nulTerminate) result.add(0);
684 return { result.releaseAsArray(), hadErrors };
685}
686
687// =======================================================================================
688// This code is derived from libb64 which has been placed in the public domain.
689// For details, see http://sourceforge.net/projects/libb64
690
691// -------------------------------------------------------------------
692// Encoder
693
namespace {

// Position within the current 3-byte input group: A = expecting the first byte, B = the second,
// C = the third.
enum base64_encodestep {
  step_A, step_B, step_C
};

// Encoder state carried between calls so that input may be supplied in several pieces.
struct base64_encodestate {
  base64_encodestep step;  // which byte of the group comes next
  char result;             // bits left over from the previous byte, not yet emitted
  int stepcount;           // complete 4-character groups written on the current line
};

const int CHARS_PER_LINE = 72;

// Resets `state_in` to the start of a fresh encoding.
void base64_init_encodestate(base64_encodestate* state_in) {
  *state_in = base64_encodestate{step_A, 0, 0};
}

// Maps a 6-bit value to its base64 character; values above 63 map to '='.
char base64_encode_value(char value_in) {
  static const char* alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
  return value_in > 63 ? '=' : alphabet[static_cast<int>(value_in)];
}

// Encodes `length_in` bytes from `plaintext_in` into `code_out` and returns the number of
// characters written. Bits that do not yet fill a whole output character are kept in `state_in`
// for the next call or for base64_encode_blockend(). When `breakLines` is true, a '\n' is written
// after every CHARS_PER_LINE output characters.
int base64_encode_block(const char* plaintext_in, int length_in,
                        char* code_out, base64_encodestate* state_in, bool breakLines) {
  const char* src = plaintext_in;
  const char* const srcEnd = plaintext_in + length_in;
  char* dst = code_out;
  char carry = state_in->result;
  base64_encodestep phase = state_in->step;

  while (src != srcEnd) {
    const char fragment = *src++;

    switch (phase) {
      case step_A:
        // First byte: its top six bits form a character, the low two are carried.
        carry = (fragment & 0x0fc) >> 2;
        *dst++ = base64_encode_value(carry);
        carry = (fragment & 0x003) << 4;
        phase = step_B;
        break;

      case step_B:
        // Second byte: its top four bits complete a character, the low four are carried.
        carry |= (fragment & 0x0f0) >> 4;
        *dst++ = base64_encode_value(carry);
        carry = (fragment & 0x00f) << 2;
        phase = step_C;
        break;

      case step_C:
        // Third byte: its top two bits complete a character and the low six form another.
        carry |= (fragment & 0x0c0) >> 6;
        *dst++ = base64_encode_value(carry);
        carry = (fragment & 0x03f) >> 0;
        *dst++ = base64_encode_value(carry);

        ++(state_in->stepcount);
        if (breakLines && state_in->stepcount == CHARS_PER_LINE / 4) {
          *dst++ = '\n';
          state_in->stepcount = 0;
        }
        phase = step_A;
        break;
    }
  }

  // Input exhausted: remember where we stopped.
  state_in->result = carry;
  state_in->step = phase;
  return dst - code_out;
}

// Flushes any carried bits as a final padded group and returns the number of characters written.
// When `breakLines` is true and the current line is non-empty, a closing '\n' is added.
int base64_encode_blockend(char* code_out, base64_encodestate* state_in, bool breakLines) {
  char* dst = code_out;

  if (state_in->step == step_B) {
    // Only one byte in the last group: one more character plus two pad characters.
    *dst++ = base64_encode_value(state_in->result);
    *dst++ = '=';
    *dst++ = '=';
    ++state_in->stepcount;
  } else if (state_in->step == step_C) {
    // Two bytes in the last group: one more character plus one pad character.
    *dst++ = base64_encode_value(state_in->result);
    *dst++ = '=';
    ++state_in->stepcount;
  }

  if (breakLines && state_in->stepcount > 0) {
    *dst++ = '\n';
  }

  return dst - code_out;
}

}  // namespace
803
804String encodeBase64(ArrayPtr<const byte> input, bool breakLines) {
805 /* set up a destination buffer large enough to hold the encoded data */
806 // equivalent to ceil(input.size() / 3) * 4
807 auto numChars = (input.size() + 2) / 3 * 4;
808 if (breakLines) {
809 // Add space for newline characters.
810 uint lineCount = numChars / CHARS_PER_LINE;
811 if (numChars % CHARS_PER_LINE > 0) {
812 // Partial line.
813 ++lineCount;
814 }
815 numChars = numChars + lineCount;
816 }
817 auto output = heapString(numChars);
818 /* keep track of our encoded position */
819 char* c = output.begin();
820 /* store the number of bytes encoded by a single call */
821 int cnt = 0;
822 size_t total = 0;
823 /* we need an encoder state */
824 base64_encodestate s;
825
826 /*---------- START ENCODING ----------*/
827 /* initialise the encoder state */
828 base64_init_encodestate(&s);
829 /* gather data from the input and send it to the output */
830 cnt = base64_encode_block((const char *)input.begin(), input.size(), c, &s, breakLines);
831 c += cnt;
832 total += cnt;
833
834 /* since we have encoded the entire input string, we know that
835 there is no more input data; finalise the encoding */
836 cnt = base64_encode_blockend(c, &s, breakLines);
837 c += cnt;
838 total += cnt;
839 /*---------- STOP ENCODING ----------*/
840
841 KJ_ASSERT(total == output.size(), total, output.size());
842
843 return output;
844}
845
846// -------------------------------------------------------------------
847// Decoder
848
namespace {

// Position within the current 4-character input group.
enum base64_decodestep {
  step_a, step_b, step_c, step_d
};

struct base64_decodestate {
  bool hadErrors = false;
  size_t nPaddingBytesSeen = 0;
  // Output state. `nPaddingBytesSeen` is not guaranteed to be correct if `hadErrors` is true. It is
  // included in the state purely to preserve the streaming capability of the algorithm while still
  // checking for errors correctly (consider chunk 1 = "abc=", chunk 2 = "d").

  base64_decodestep step = step_a;
  char plainchar = 0;
  // Resume point: the position in the group and the partially assembled output byte.
};

int base64_decode_value(char value_in) {
  // Returns either the fragment value or: -1 on whitespace, -2 on padding, -3 on invalid input.
  //
  // Note that the original libb64 implementation used -1 for invalid input, -2 on padding -- this
  // new scheme allows for some simpler error checks in steps A and B.

  static const signed char table[] = {
    -3,-3,-3,-3,-3,-3,-3,-3,  -3,-1,-1,-3,-1,-1,-3,-3,
    -3,-3,-3,-3,-3,-3,-3,-3,  -3,-3,-3,-3,-3,-3,-3,-3,
    -1,-3,-3,-3,-3,-3,-3,-3,  -3,-3,-3,62,-3,-3,-3,63,
    52,53,54,55,56,57,58,59,  60,61,-3,-3,-3,-2,-3,-3,
    -3, 0, 1, 2, 3, 4, 5, 6,   7, 8, 9,10,11,12,13,14,
    15,16,17,18,19,20,21,22,  23,24,25,-3,-3,-3,-3,-3,
    -3,26,27,28,29,30,31,32,  33,34,35,36,37,38,39,40,
    41,42,43,44,45,46,47,48,  49,50,51,-3,-3,-3,-3,-3,

    -3,-3,-3,-3,-3,-3,-3,-3,  -3,-3,-3,-3,-3,-3,-3,-3,
    -3,-3,-3,-3,-3,-3,-3,-3,  -3,-3,-3,-3,-3,-3,-3,-3,
    -3,-3,-3,-3,-3,-3,-3,-3,  -3,-3,-3,-3,-3,-3,-3,-3,
    -3,-3,-3,-3,-3,-3,-3,-3,  -3,-3,-3,-3,-3,-3,-3,-3,
    -3,-3,-3,-3,-3,-3,-3,-3,  -3,-3,-3,-3,-3,-3,-3,-3,
    -3,-3,-3,-3,-3,-3,-3,-3,  -3,-3,-3,-3,-3,-3,-3,-3,
    -3,-3,-3,-3,-3,-3,-3,-3,  -3,-3,-3,-3,-3,-3,-3,-3,
    -3,-3,-3,-3,-3,-3,-3,-3,  -3,-3,-3,-3,-3,-3,-3,-3,
  };
  static_assert(sizeof(table) == 256, "base64 decoding table size error");
  return table[static_cast<unsigned char>(value_in)];
}

// Decodes `length_in` characters from `code_in` into `plaintext_out` and returns the number of
// complete bytes produced. Whitespace is skipped. Invalid characters and misplaced or excess
// padding set `state_in->hadErrors` but do not stop decoding. The byte at the returned length may
// also be written as scratch space for a partially assembled value.
int base64_decode_block(const char* code_in, const int length_in,
                        char* plaintext_out, base64_decodestate* state_in) {
  const char* cursor = code_in;
  const char* const inputEnd = code_in + length_in;
  char* out = plaintext_out;
  base64_decodestep phase = state_in->step;

  if (phase != step_a) {
    // Resuming mid-group: restore the partially assembled byte.
    *out = state_in->plainchar;
  }

  for (;;) {
    if (cursor == inputEnd) {
      // Out of input: record the resume point, with step-specific error checks.
      state_in->step = phase;
      switch (phase) {
        case step_a:
          state_in->plainchar = '\0';
          break;
        case step_b:
          state_in->plainchar = *out;
          // It is always an error to suspend from step B, because we don't have enough bits yet.
          // TODO(someday): This actually breaks the streaming use case, if base64_decode_block() is
          //   to be called multiple times. We'll fix it if we ever care to support streaming.
          state_in->hadErrors = true;
          break;
        case step_c:
          state_in->plainchar = *out;
          // It is an error to complete from step C if we have seen incomplete padding.
          // TODO(someday): This actually breaks the streaming use case, if base64_decode_block() is
          //   to be called multiple times. We'll fix it if we ever care to support streaming.
          if (!state_in->hadErrors) {
            state_in->hadErrors = state_in->nPaddingBytesSeen == 1;
          }
          break;
        case step_d:
          state_in->plainchar = *out;
          break;
      }
      return out - plaintext_out;
    }

    const signed char fragment = static_cast<signed char>(base64_decode_value(*cursor++));

    // Once `hadErrors` is set, the checks below are skipped entirely, including the padding
    // counter increments embedded in them.
    switch (phase) {
      case step_a:
        // It is an error to see invalid or padding bytes in step A.
        if (!state_in->hadErrors) {
          state_in->hadErrors = fragment < -1;
        }
        if (fragment < 0) continue;
        *out = (fragment & 0x03f) << 2;
        phase = step_b;
        break;

      case step_b:
        // It is an error to see invalid or padding bytes in step B.
        if (!state_in->hadErrors) {
          state_in->hadErrors = fragment < -1;
        }
        if (fragment < 0) continue;
        *out++ |= (fragment & 0x030) >> 4;
        *out    = (fragment & 0x00f) << 4;
        phase = step_c;
        break;

      case step_c:
        // It is an error to see invalid bytes or more than two padding bytes in step C.
        if (!state_in->hadErrors) {
          state_in->hadErrors =
              fragment < -2 || (fragment == -2 && ++state_in->nPaddingBytesSeen > 2);
        }
        if (fragment < 0) continue;
        // It is an error to continue from step C after having seen any padding.
        if (!state_in->hadErrors) {
          state_in->hadErrors = state_in->nPaddingBytesSeen > 0;
        }
        *out++ |= (fragment & 0x03c) >> 2;
        *out    = (fragment & 0x003) << 6;
        phase = step_d;
        break;

      case step_d:
        // It is an error to see invalid bytes or more than one padding byte in step D.
        if (!state_in->hadErrors) {
          state_in->hadErrors =
              fragment < -2 || (fragment == -2 && ++state_in->nPaddingBytesSeen > 1);
        }
        if (fragment < 0) continue;
        // It is an error to continue from step D after having seen padding bytes.
        if (!state_in->hadErrors) {
          state_in->hadErrors = state_in->nPaddingBytesSeen > 0;
        }
        *out++ |= (fragment & 0x03f);
        phase = step_a;
        break;
    }
  }
}

}  // namespace
986
987EncodingResult<Array<byte>> decodeBase64(ArrayPtr<const char> input) {
988 base64_decodestate state;
989
990 auto output = heapArray<byte>((input.size() * 6 + 7) / 8);
991
992 size_t n = base64_decode_block(input.begin(), input.size(),
993 reinterpret_cast<char*>(output.begin()), &state);
994
995 if (n < output.size()) {
996 auto copy = heapArray<byte>(n);
997 memcpy(copy.begin(), output.begin(), n);
998 output = kj::mv(copy);
999 }
1000
1001 return EncodingResult<Array<byte>>(kj::mv(output), state.hadErrors);
1002}
1003
1004} // namespace kj
1005