1/* Copyright (c) 2018, Google Inc.
2 *
3 * Permission to use, copy, modify, and/or distribute this software for any
4 * purpose with or without fee is hereby granted, provided that the above
5 * copyright notice and this permission notice appear in all copies.
6 *
7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14
15#include <openssl/hrss.h>
16
17#include <assert.h>
18#include <stdio.h>
19#include <stdlib.h>
20
21#include <openssl/bn.h>
22#include <openssl/cpu.h>
23#include <openssl/hmac.h>
24#include <openssl/mem.h>
25#include <openssl/sha.h>
26
27#if defined(OPENSSL_X86) || defined(OPENSSL_X86_64)
28#include <emmintrin.h>
29#endif
30
31#if (defined(OPENSSL_ARM) || defined(OPENSSL_AARCH64)) && \
32 (defined(__ARM_NEON__) || defined(__ARM_NEON))
33#include <arm_neon.h>
34#endif
35
36#if defined(_MSC_VER)
37#define RESTRICT
38#else
39#define RESTRICT restrict
40#endif
41
42#include "../internal.h"
43#include "internal.h"
44
45// This is an implementation of [HRSS], but with a KEM transformation based on
46// [SXY]. The primary references are:
47
48// HRSS: https://eprint.iacr.org/2017/667.pdf
49// HRSSNIST:
50// https://csrc.nist.gov/CSRC/media/Projects/Post-Quantum-Cryptography/documents/round-1/submissions/NTRU_HRSS_KEM.zip
51// SXY: https://eprint.iacr.org/2017/1005.pdf
52// NTRUTN14:
53// https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf
54// NTRUCOMP:
55// https://eprint.iacr.org/2018/1174
56
57
58// Vector operations.
59//
60// A couple of functions in this file can use vector operations to meaningful
61// effect. If we're building for a target that has a supported vector unit,
62// |HRSS_HAVE_VECTOR_UNIT| will be defined and |vec_t| will be typedefed to a
63// 128-bit vector. The following functions abstract over the differences between
64// NEON and SSE2 for implementing some vector operations.
65
66// TODO: MSVC can likely also be made to work with vector operations.
67#if ((defined(__SSE__) && defined(OPENSSL_X86)) || defined(OPENSSL_X86_64)) && \
68 (defined(__clang__) || !defined(_MSC_VER))
69
70#define HRSS_HAVE_VECTOR_UNIT
71typedef __m128i vec_t;
72
// vec_capable returns one iff the current platform supports SSE2.
static int vec_capable(void) {
#if defined(__SSE2__)
  // SSE2 is guaranteed by the compilation target so no runtime check is
  // needed.
  return 1;
#else
  // Bit 26 of the first ia32cap word (CPUID.1:EDX) indicates SSE2 support.
  int has_sse2 = (OPENSSL_ia32cap_P[0] & (1 << 26)) != 0;
  return has_sse2;
#endif
}
82
// vec_add performs a pair-wise addition of the eight uint16s in |a| and |b|.
static inline vec_t vec_add(vec_t a, vec_t b) { return _mm_add_epi16(a, b); }

// vec_sub performs a pair-wise subtraction of the eight uint16s in |a| and
// |b|.
static inline vec_t vec_sub(vec_t a, vec_t b) { return _mm_sub_epi16(a, b); }

// vec_mul multiplies each uint16_t in |a| by |b| and returns the resulting
// vector. (The multiplication is truncating, i.e. mod 2^16.)
static inline vec_t vec_mul(vec_t a, uint16_t b) {
  return _mm_mullo_epi16(a, _mm_set1_epi16(b));
}

// vec_fma multiplies each uint16_t in |b| by |c|, adds the result to |a|, and
// returns the resulting vector. (All lane arithmetic is mod 2^16.)
static inline vec_t vec_fma(vec_t a, vec_t b, uint16_t c) {
  return _mm_add_epi16(a, _mm_mullo_epi16(b, _mm_set1_epi16(c)));
}
100
// vec3_rshift_word right-shifts the 24 uint16_t's in |v| by one uint16,
// shifting in a zero word at the start. (I.e. every coefficient moves up one
// position in little-endian memory order.)
static inline void vec3_rshift_word(vec_t v[3]) {
  // Intel's left and right shifting is backwards compared to the order in
  // memory because they're based on little-endian order of words (and not just
  // bytes). So the shifts in this function will be backwards from what one
  // might expect.
  const __m128i carry0 = _mm_srli_si128(v[0], 14);  // top word of v[0]
  v[0] = _mm_slli_si128(v[0], 2);

  const __m128i carry1 = _mm_srli_si128(v[1], 14);  // top word of v[1]
  v[1] = _mm_slli_si128(v[1], 2);
  v[1] |= carry0;

  v[2] = _mm_slli_si128(v[2], 2);
  v[2] |= carry1;
}
117
// vec4_rshift_word right-shifts the 32 uint16_t's in |v| by one uint16,
// shifting in a zero word at the start.
static inline void vec4_rshift_word(vec_t v[4]) {
  // Intel's left and right shifting is backwards compared to the order in
  // memory because they're based on little-endian order of words (and not just
  // bytes). So the shifts in this function will be backwards from what one
  // might expect.
  const __m128i carry0 = _mm_srli_si128(v[0], 14);  // top word of v[0]
  v[0] = _mm_slli_si128(v[0], 2);

  const __m128i carry1 = _mm_srli_si128(v[1], 14);  // top word of v[1]
  v[1] = _mm_slli_si128(v[1], 2);
  v[1] |= carry0;

  const __m128i carry2 = _mm_srli_si128(v[2], 14);  // top word of v[2]
  v[2] = _mm_slli_si128(v[2], 2);
  v[2] |= carry1;

  v[3] = _mm_slli_si128(v[3], 2);
  v[3] |= carry2;
}
138
// vec_merge_3_5 takes the final three uint16_t's from |left|, appends the
// first five from |right|, and returns the resulting vector.
static inline vec_t vec_merge_3_5(vec_t left, vec_t right) {
  // Byte-level shifts: move |left|'s top three words down to the bottom and
  // make room above them for |right|'s first five words.
  return _mm_srli_si128(left, 10) | _mm_slli_si128(right, 6);
}
144
// poly3_vec_lshift1 left-shifts the 768 bits in |a_s|, and in |a_a|, by one
// bit.
static inline void poly3_vec_lshift1(vec_t a_s[6], vec_t a_a[6]) {
  vec_t carry_s = {0};
  vec_t carry_a = {0};

  // For each vector: shift both 64-bit halves left by one, then OR back in
  // the bit that crossed the 64-bit boundary and the bit carried in from the
  // previous vector. The top bit of each vector becomes the carry into the
  // next.
  for (int i = 0; i < 6; i++) {
    vec_t next_carry_s = _mm_srli_epi64(a_s[i], 63);
    a_s[i] = _mm_slli_epi64(a_s[i], 1);
    a_s[i] |= _mm_slli_si128(next_carry_s, 8);
    a_s[i] |= carry_s;
    carry_s = _mm_srli_si128(next_carry_s, 8);

    vec_t next_carry_a = _mm_srli_epi64(a_a[i], 63);
    a_a[i] = _mm_slli_epi64(a_a[i], 1);
    a_a[i] |= _mm_slli_si128(next_carry_a, 8);
    a_a[i] |= carry_a;
    carry_a = _mm_srli_si128(next_carry_a, 8);
  }
}
165
// poly3_vec_rshift1 right-shifts the 768 bits in |a_s|, and in |a_a|, by one
// bit.
static inline void poly3_vec_rshift1(vec_t a_s[6], vec_t a_a[6]) {
  vec_t carry_s = {0};
  vec_t carry_a = {0};

  // The mirror image of |poly3_vec_lshift1|: iterate from the top vector down
  // so each vector's bottom bit can be carried into the one below it.
  for (int i = 5; i >= 0; i--) {
    const vec_t next_carry_s = _mm_slli_epi64(a_s[i], 63);
    a_s[i] = _mm_srli_epi64(a_s[i], 1);
    a_s[i] |= _mm_srli_si128(next_carry_s, 8);
    a_s[i] |= carry_s;
    carry_s = _mm_slli_si128(next_carry_s, 8);

    const vec_t next_carry_a = _mm_slli_epi64(a_a[i], 63);
    a_a[i] = _mm_srli_epi64(a_a[i], 1);
    a_a[i] |= _mm_srli_si128(next_carry_a, 8);
    a_a[i] |= carry_a;
    carry_a = _mm_slli_si128(next_carry_a, 8);
  }
}
186
// vec_broadcast_bit duplicates the least-significant bit in |a| to all bits in
// a vector and returns the result.
static inline vec_t vec_broadcast_bit(vec_t a) {
  // Move bit zero to the top of each 64-bit lane, arithmetic-shift right to
  // smear it across a 32-bit lane, then shuffle that lane (index one) into
  // all four 32-bit positions.
  return _mm_shuffle_epi32(_mm_srai_epi32(_mm_slli_epi64(a, 63), 31),
                           0b01010101);
}

// vec_broadcast_bit15 duplicates the most-significant bit of the first word in
// |a| to all bits in a vector and returns the result.
static inline vec_t vec_broadcast_bit15(vec_t a) {
  // As |vec_broadcast_bit|, but bit 15 is shifted into bit position 63 first.
  return _mm_shuffle_epi32(_mm_srai_epi32(_mm_slli_epi64(a, 63 - 15), 31),
                           0b01010101);
}
200
201// vec_get_word returns the |i|th uint16_t in |v|. (This is a macro because the
202// compiler requires that |i| be a compile-time constant.)
203#define vec_get_word(v, i) _mm_extract_epi16(v, i)
204
205#elif (defined(OPENSSL_ARM) || defined(OPENSSL_AARCH64)) && \
206 (defined(__ARM_NEON__) || defined(__ARM_NEON))
207
208#define HRSS_HAVE_VECTOR_UNIT
209typedef uint16x8_t vec_t;
210
// These functions perform the same actions as the SSE2 functions of the same
// name, above.

static int vec_capable(void) { return CRYPTO_is_NEON_capable(); }

static inline vec_t vec_add(vec_t a, vec_t b) { return a + b; }

static inline vec_t vec_sub(vec_t a, vec_t b) { return a - b; }

static inline vec_t vec_mul(vec_t a, uint16_t b) { return vmulq_n_u16(a, b); }

static inline vec_t vec_fma(vec_t a, vec_t b, uint16_t c) {
  // Lane-wise multiply-accumulate: a + b×c, mod 2^16.
  return vmlaq_n_u16(a, b, c);
}
225
static inline void vec3_rshift_word(vec_t v[3]) {
  // |vextq_u16(a, b, 7)| produces the final word of |a| followed by the first
  // seven words of |b|, which effects the word-wise shift directly.
  const uint16x8_t kZero = {0};
  v[2] = vextq_u16(v[1], v[2], 7);
  v[1] = vextq_u16(v[0], v[1], 7);
  v[0] = vextq_u16(kZero, v[0], 7);
}

static inline void vec4_rshift_word(vec_t v[4]) {
  // See |vec3_rshift_word|, above.
  const uint16x8_t kZero = {0};
  v[3] = vextq_u16(v[2], v[3], 7);
  v[2] = vextq_u16(v[1], v[2], 7);
  v[1] = vextq_u16(v[0], v[1], 7);
  v[0] = vextq_u16(kZero, v[0], 7);
}

static inline vec_t vec_merge_3_5(vec_t left, vec_t right) {
  // Words 5..7 of |left| followed by words 0..4 of |right|.
  return vextq_u16(left, right, 5);
}

static inline uint16_t vec_get_word(vec_t v, unsigned i) {
  return v[i];
}
248
249#if !defined(OPENSSL_AARCH64)
250
static inline vec_t vec_broadcast_bit(vec_t a) {
  // Shift bit zero to the top of each 16-bit lane, arithmetic-shift it back
  // down to fill the lane, then broadcast the first lane to all eight.
  a = (vec_t)vshrq_n_s16(((int16x8_t)a) << 15, 15);
  return vdupq_lane_u16(vget_low_u16(a), 0);
}

static inline vec_t vec_broadcast_bit15(vec_t a) {
  // As |vec_broadcast_bit|, but sourcing bit 15 of the first lane.
  a = (vec_t)vshrq_n_s16((int16x8_t)a, 15);
  return vdupq_lane_u16(vget_low_u16(a), 0);
}
260
static inline void poly3_vec_lshift1(vec_t a_s[6], vec_t a_a[6]) {
  vec_t carry_s = {0};
  vec_t carry_a = {0};
  const vec_t kZero = {0};

  // Shift each lane left by one, then fix up the bits that cross 16-bit lane
  // boundaries (via |vextq_u16|) and the bit carried from the previous
  // vector.
  for (int i = 0; i < 6; i++) {
    vec_t next_carry_s = a_s[i] >> 15;
    a_s[i] <<= 1;
    a_s[i] |= vextq_u16(kZero, next_carry_s, 7);
    a_s[i] |= carry_s;
    carry_s = vextq_u16(next_carry_s, kZero, 7);

    vec_t next_carry_a = a_a[i] >> 15;
    a_a[i] <<= 1;
    a_a[i] |= vextq_u16(kZero, next_carry_a, 7);
    a_a[i] |= carry_a;
    carry_a = vextq_u16(next_carry_a, kZero, 7);
  }
}
280
static inline void poly3_vec_rshift1(vec_t a_s[6], vec_t a_a[6]) {
  vec_t carry_s = {0};
  vec_t carry_a = {0};
  const vec_t kZero = {0};

  // The mirror of |poly3_vec_lshift1|: iterate from the top vector down so
  // each vector's lowest bit can be carried into the one below it.
  for (int i = 5; i >= 0; i--) {
    vec_t next_carry_s = a_s[i] << 15;
    a_s[i] >>= 1;
    a_s[i] |= vextq_u16(next_carry_s, kZero, 1);
    a_s[i] |= carry_s;
    carry_s = vextq_u16(kZero, next_carry_s, 1);

    vec_t next_carry_a = a_a[i] << 15;
    a_a[i] >>= 1;
    a_a[i] |= vextq_u16(next_carry_a, kZero, 1);
    a_a[i] |= carry_a;
    carry_a = vextq_u16(kZero, next_carry_a, 1);
  }
}
300
301#endif // !OPENSSL_AARCH64
302
303#endif // (ARM || AARCH64) && NEON
304
305// Polynomials in this scheme have N terms.
306// #define N 701
307
308// Underlying data types and arithmetic operations.
309// ------------------------------------------------
310
311// Binary polynomials.
312
313// poly2 represents a degree-N polynomial over GF(2). The words are in little-
314// endian order, i.e. the coefficient of x^0 is the LSB of the first word. The
315// final word is only partially used since N is not a multiple of the word size.
316
317// Defined in internal.h:
318// struct poly2 {
319// crypto_word_t v[WORDS_PER_POLY];
320// };
321
322OPENSSL_UNUSED static void hexdump(const void *void_in, size_t len) {
323 const uint8_t *in = (const uint8_t *)void_in;
324 for (size_t i = 0; i < len; i++) {
325 printf("%02x", in[i]);
326 }
327 printf("\n");
328}
329
330static void poly2_zero(struct poly2 *p) {
331 OPENSSL_memset(&p->v[0], 0, sizeof(crypto_word_t) * WORDS_PER_POLY);
332}
333
334// poly2_cmov sets |out| to |in| iff |mov| is all ones.
335static void poly2_cmov(struct poly2 *out, const struct poly2 *in,
336 crypto_word_t mov) {
337 for (size_t i = 0; i < WORDS_PER_POLY; i++) {
338 out->v[i] = (out->v[i] & ~mov) | (in->v[i] & mov);
339 }
340}
341
// poly2_rotr_words performs a right-rotate on |in|, writing the result to
// |out|. The shift count, |bits|, must be a non-zero multiple of the word
// size. (The rotation is over the N coefficient bits, not over the full
// word-aligned storage.)
static void poly2_rotr_words(struct poly2 *out, const struct poly2 *in,
                             size_t bits) {
  assert(bits >= BITS_PER_WORD && bits % BITS_PER_WORD == 0);
  assert(out != in);

  const size_t start = bits / BITS_PER_WORD;
  const size_t n = (N - bits) / BITS_PER_WORD;

  // The rotate is by a whole number of words so the first few words are easy:
  // just move them down.
  for (size_t i = 0; i < n; i++) {
    out->v[i] = in->v[start + i];
  }

  // Since the last word is only partially filled, however, the remainder needs
  // shifting and merging of words to take care of that.
  crypto_word_t carry = in->v[WORDS_PER_POLY - 1];

  for (size_t i = 0; i < start; i++) {
    // Each wrapped word contributes its low bits to the current output word
    // and its high bits to the next one.
    out->v[n + i] = carry | in->v[i] << BITS_IN_LAST_WORD;
    carry = in->v[i] >> (BITS_PER_WORD - BITS_IN_LAST_WORD);
  }

  out->v[WORDS_PER_POLY - 1] = carry;
}
369
// poly2_rotr_bits performs a right-rotate on |in|, writing the result to |out|.
// The shift count, |bits|, must be a power of two that is less than
// |BITS_PER_WORD|.
static void poly2_rotr_bits(struct poly2 *out, const struct poly2 *in,
                            size_t bits) {
  assert(bits <= BITS_PER_WORD / 2);
  assert(bits != 0);
  assert((bits & (bits - 1)) == 0);
  assert(out != in);

  // BITS_PER_WORD/2 is the greatest legal value of |bits|. If
  // |BITS_IN_LAST_WORD| is smaller than this then the code below doesn't work
  // because more than the last word needs to carry down in the previous one and
  // so on.
  OPENSSL_STATIC_ASSERT(
      BITS_IN_LAST_WORD >= BITS_PER_WORD / 2,
      "there are more carry bits than fit in BITS_IN_LAST_WORD");

  // The bits rotated off the bottom reappear at the top of the (partial) final
  // word, so seed the carry with the final word's low bits.
  crypto_word_t carry = in->v[WORDS_PER_POLY - 1] << (BITS_PER_WORD - bits);

  // This loop counts down; the unsigned index wraps past zero to terminate.
  for (size_t i = WORDS_PER_POLY - 2; i < WORDS_PER_POLY; i--) {
    out->v[i] = carry | in->v[i] >> bits;
    carry = in->v[i] << (BITS_PER_WORD - bits);
  }

  crypto_word_t last_word = carry >> (BITS_PER_WORD - BITS_IN_LAST_WORD) |
                            in->v[WORDS_PER_POLY - 1] >> bits;
  // Keep only the BITS_IN_LAST_WORD valid bits of the final word.
  last_word &= (UINT64_C(1) << BITS_IN_LAST_WORD) - 1;
  out->v[WORDS_PER_POLY - 1] = last_word;
}
400
// HRSS_poly2_rotr_consttime right-rotates |p| by |bits| in constant-time.
// |bits| must be at most N and the alignment bits of |p|'s final word must be
// clear on entry.
void HRSS_poly2_rotr_consttime(struct poly2 *p, size_t bits) {
  assert(bits <= N);
  assert(p->v[WORDS_PER_POLY-1] >> BITS_IN_LAST_WORD == 0);

  // Constant-time rotation is implemented by calculating the rotations of
  // powers-of-two bits and throwing away the unneeded values. 2^9 (i.e. 512) is
  // the largest power-of-two shift that we need to consider because 2^10 > N.
#define HRSS_POLY2_MAX_SHIFT 9
  size_t shift = HRSS_POLY2_MAX_SHIFT;
  OPENSSL_STATIC_ASSERT((1 << (HRSS_POLY2_MAX_SHIFT + 1)) > N,
                        "maximum shift is too small");
  OPENSSL_STATIC_ASSERT((1 << HRSS_POLY2_MAX_SHIFT) <= N,
                        "maximum shift is too large");
  struct poly2 shifted;

  // Word-granular rotations for the larger powers of two; the rotation is
  // kept only when the corresponding bit of |bits| is set. The conditional
  // move masks make this independent of the secret |bits| value.
  for (; (UINT64_C(1) << shift) >= BITS_PER_WORD; shift--) {
    poly2_rotr_words(&shifted, p, UINT64_C(1) << shift);
    poly2_cmov(p, &shifted, ~((1 & (bits >> shift)) - 1));
  }

  // Sub-word rotations for the remaining powers of two.
  for (; shift < HRSS_POLY2_MAX_SHIFT; shift--) {
    poly2_rotr_bits(&shifted, p, UINT64_C(1) << shift);
    poly2_cmov(p, &shifted, ~((1 & (bits >> shift)) - 1));
  }
#undef HRSS_POLY2_MAX_SHIFT
}
428
429// poly2_cswap exchanges the values of |a| and |b| if |swap| is all ones.
430static void poly2_cswap(struct poly2 *a, struct poly2 *b, crypto_word_t swap) {
431 for (size_t i = 0; i < WORDS_PER_POLY; i++) {
432 const crypto_word_t sum = swap & (a->v[i] ^ b->v[i]);
433 a->v[i] ^= sum;
434 b->v[i] ^= sum;
435 }
436}
437
438// poly2_fmadd sets |out| to |out| + |in| * m, where m is either
439// |CONSTTIME_TRUE_W| or |CONSTTIME_FALSE_W|.
440static void poly2_fmadd(struct poly2 *out, const struct poly2 *in,
441 crypto_word_t m) {
442 for (size_t i = 0; i < WORDS_PER_POLY; i++) {
443 out->v[i] ^= in->v[i] & m;
444 }
445}
446
447// poly2_lshift1 left-shifts |p| by one bit.
448static void poly2_lshift1(struct poly2 *p) {
449 crypto_word_t carry = 0;
450 for (size_t i = 0; i < WORDS_PER_POLY; i++) {
451 const crypto_word_t next_carry = p->v[i] >> (BITS_PER_WORD - 1);
452 p->v[i] <<= 1;
453 p->v[i] |= carry;
454 carry = next_carry;
455 }
456}
457
458// poly2_rshift1 right-shifts |p| by one bit.
459static void poly2_rshift1(struct poly2 *p) {
460 crypto_word_t carry = 0;
461 for (size_t i = WORDS_PER_POLY - 1; i < WORDS_PER_POLY; i--) {
462 const crypto_word_t next_carry = p->v[i] & 1;
463 p->v[i] >>= 1;
464 p->v[i] |= carry << (BITS_PER_WORD - 1);
465 carry = next_carry;
466 }
467}
468
// poly2_clear_top_bits clears the bits in the final word that are only for
// alignment, i.e. everything above bit BITS_IN_LAST_WORD-1.
static void poly2_clear_top_bits(struct poly2 *p) {
  p->v[WORDS_PER_POLY - 1] &= (UINT64_C(1) << BITS_IN_LAST_WORD) - 1;
}

// poly2_top_bits_are_clear returns one iff the extra bits in the final words of
// |p| are zero. (Used in assertions; not constant-time critical.)
static int poly2_top_bits_are_clear(const struct poly2 *p) {
  return (p->v[WORDS_PER_POLY - 1] &
          ~((UINT64_C(1) << BITS_IN_LAST_WORD) - 1)) == 0;
}
481
482// Ternary polynomials.
483
484// poly3 represents a degree-N polynomial over GF(3). Each coefficient is
485// bitsliced across the |s| and |a| arrays, like this:
486//
487// s | a | value
488// -----------------
489// 0 | 0 | 0
490// 0 | 1 | 1
491// 1 | 1 | -1 (aka 2)
492// 1 | 0 | <invalid>
493//
494// ('s' is for sign, and 'a' is the absolute value.)
495//
496// Once bitsliced as such, the following circuits can be used to implement
497// addition and multiplication mod 3:
498//
499// (s3, a3) = (s1, a1) × (s2, a2)
500// a3 = a1 ∧ a2
501// s3 = (s1 ⊕ s2) ∧ a3
502//
503// (s3, a3) = (s1, a1) + (s2, a2)
504// t = s1 ⊕ a2
505// s3 = t ∧ (s2 ⊕ a1)
506// a3 = (a1 ⊕ a2) ∨ (t ⊕ s2)
507//
508// (s3, a3) = (s1, a1) - (s2, a2)
509// t = a1 ⊕ a2
510// s3 = (s1 ⊕ a2) ∧ (t ⊕ s2)
511// a3 = t ∨ (s1 ⊕ s2)
512//
513// Negating a value just involves XORing s by a.
514//
515// struct poly3 {
516// struct poly2 s, a;
517// };
518
// poly3_print writes the bitsliced words of |in| ({[s words] [a words]}) to
// stdout in hex. Debugging aid only; not used in normal operation.
OPENSSL_UNUSED static void poly3_print(const struct poly3 *in) {
  struct poly3 p;
  OPENSSL_memcpy(&p, in, sizeof(p));
  // Mask off the alignment bits of the final words so they don't clutter the
  // output.
  p.s.v[WORDS_PER_POLY - 1] &= ((crypto_word_t)1 << BITS_IN_LAST_WORD) - 1;
  p.a.v[WORDS_PER_POLY - 1] &= ((crypto_word_t)1 << BITS_IN_LAST_WORD) - 1;

  printf("{[");
  for (unsigned i = 0; i < WORDS_PER_POLY; i++) {
    if (i) {
      printf(" ");
    }
    printf(BN_HEX_FMT2, p.s.v[i]);
  }
  printf("] [");
  for (unsigned i = 0; i < WORDS_PER_POLY; i++) {
    if (i) {
      printf(" ");
    }
    printf(BN_HEX_FMT2, p.a.v[i]);
  }
  printf("]}\n");
}
541
542static void poly3_zero(struct poly3 *p) {
543 poly2_zero(&p->s);
544 poly2_zero(&p->a);
545}
546
547// poly3_word_mul sets (|out_s|, |out_a) to (|s1|, |a1|) × (|s2|, |a2|).
548static void poly3_word_mul(crypto_word_t *out_s, crypto_word_t *out_a,
549 const crypto_word_t s1, const crypto_word_t a1,
550 const crypto_word_t s2, const crypto_word_t a2) {
551 *out_a = a1 & a2;
552 *out_s = (s1 ^ s2) & *out_a;
553}
554
555// poly3_word_add sets (|out_s|, |out_a|) to (|s1|, |a1|) + (|s2|, |a2|).
556static void poly3_word_add(crypto_word_t *out_s, crypto_word_t *out_a,
557 const crypto_word_t s1, const crypto_word_t a1,
558 const crypto_word_t s2, const crypto_word_t a2) {
559 const crypto_word_t t = s1 ^ a2;
560 *out_s = t & (s2 ^ a1);
561 *out_a = (a1 ^ a2) | (t ^ s2);
562}
563
564// poly3_word_sub sets (|out_s|, |out_a|) to (|s1|, |a1|) - (|s2|, |a2|).
565static void poly3_word_sub(crypto_word_t *out_s, crypto_word_t *out_a,
566 const crypto_word_t s1, const crypto_word_t a1,
567 const crypto_word_t s2, const crypto_word_t a2) {
568 const crypto_word_t t = a1 ^ a2;
569 *out_s = (s1 ^ a2) & (t ^ s2);
570 *out_a = t | (s1 ^ s2);
571}
572
// lsb_to_all replicates the least-significant bit of |v| to all bits of the
// word. This is used in bit-slicing operations to make a vector from a fixed
// value. (Constant-time: 0 maps to 0, 1 maps to all ones.)
static crypto_word_t lsb_to_all(crypto_word_t v) { return 0u - (v & 1); }
577
// poly3_mul_const sets |p| to |p|×m, where m = (ms, ma). Only the bottom bit
// of |ms| and |ma| is significant; each is broadcast to a full word mask.
static void poly3_mul_const(struct poly3 *p, crypto_word_t ms,
                            crypto_word_t ma) {
  ms = lsb_to_all(ms);
  ma = lsb_to_all(ma);

  for (size_t i = 0; i < WORDS_PER_POLY; i++) {
    poly3_word_mul(&p->s.v[i], &p->a.v[i], p->s.v[i], p->a.v[i], ms, ma);
  }
}
588
// poly3_rotr_consttime right-rotates |p| by |bits| in constant-time by
// rotating both bitsliced halves identically. |bits| must be at most N.
static void poly3_rotr_consttime(struct poly3 *p, size_t bits) {
  assert(bits <= N);
  HRSS_poly2_rotr_consttime(&p->s, bits);
  HRSS_poly2_rotr_consttime(&p->a, bits);
}
595
// poly3_fmsub sets |out| to |out| - |in|×m, where m is (ms, ma). (The inputs
// must not alias, per the RESTRICT qualifiers.)
static void poly3_fmsub(struct poly3 *RESTRICT out,
                        const struct poly3 *RESTRICT in, crypto_word_t ms,
                        crypto_word_t ma) {
  crypto_word_t product_s, product_a;
  for (size_t i = 0; i < WORDS_PER_POLY; i++) {
    poly3_word_mul(&product_s, &product_a, in->s.v[i], in->a.v[i], ms, ma);
    poly3_word_sub(&out->s.v[i], &out->a.v[i], out->s.v[i], out->a.v[i],
                   product_s, product_a);
  }
}
607
// final_bit_to_all replicates the bit in the final position of the last word
// (i.e. the coefficient of x^(N-1)) to all the bits in the word.
static crypto_word_t final_bit_to_all(crypto_word_t v) {
  return lsb_to_all(v >> (BITS_IN_LAST_WORD - 1));
}
613
// poly3_top_bits_are_clear returns one iff the extra bits in the final words
// of |p| are zero. (Used in assertions; not constant-time critical.)
OPENSSL_UNUSED static int poly3_top_bits_are_clear(const struct poly3 *p) {
  return poly2_top_bits_are_clear(&p->s) && poly2_top_bits_are_clear(&p->a);
}
619
// poly3_mod_phiN reduces |p| by Φ(N) = 1 + x + … + x^(N-1).
static void poly3_mod_phiN(struct poly3 *p) {
  // In order to reduce by Φ(N) we subtract by the value of the greatest
  // coefficient. (Subtracting c×Φ(N) is subtracting c from every coefficient,
  // which zeros the top coefficient and so leaves a degree < N-1 result.)
  const crypto_word_t factor_s = final_bit_to_all(p->s.v[WORDS_PER_POLY - 1]);
  const crypto_word_t factor_a = final_bit_to_all(p->a.v[WORDS_PER_POLY - 1]);

  for (size_t i = 0; i < WORDS_PER_POLY; i++) {
    poly3_word_sub(&p->s.v[i], &p->a.v[i], p->s.v[i], p->a.v[i], factor_s,
                   factor_a);
  }

  // The subtraction may have dirtied the alignment bits of the final words.
  poly2_clear_top_bits(&p->s);
  poly2_clear_top_bits(&p->a);
}
635
// poly3_cswap exchanges |a| and |b| iff |swap| is all ones, in constant time.
static void poly3_cswap(struct poly3 *a, struct poly3 *b, crypto_word_t swap) {
  poly2_cswap(&a->s, &b->s, swap);
  poly2_cswap(&a->a, &b->a, swap);
}

// poly3_lshift1 left-shifts |p| by one coefficient (both bitsliced halves).
static void poly3_lshift1(struct poly3 *p) {
  poly2_lshift1(&p->s);
  poly2_lshift1(&p->a);
}

// poly3_rshift1 right-shifts |p| by one coefficient (both bitsliced halves).
static void poly3_rshift1(struct poly3 *p) {
  poly2_rshift1(&p->s);
  poly2_rshift1(&p->a);
}
650
// poly3_span represents a pointer into a poly3: parallel cursors into the
// sign and absolute-value word arrays.
struct poly3_span {
  crypto_word_t *s;
  crypto_word_t *a;
};

// poly3_span_add adds |n| words of values from |a| and |b| and writes the
// result to |out|. (|out| may alias |a| or |b|; each word is read before the
// corresponding output word is written.)
static void poly3_span_add(const struct poly3_span *out,
                           const struct poly3_span *a,
                           const struct poly3_span *b, size_t n) {
  for (size_t i = 0; i < n; i++) {
    poly3_word_add(&out->s[i], &out->a[i], a->s[i], a->a[i], b->s[i], b->a[i]);
  }
}

// poly3_span_sub subtracts |n| words of |b| from |n| words of |a|, in place.
static void poly3_span_sub(const struct poly3_span *a,
                           const struct poly3_span *b, size_t n) {
  for (size_t i = 0; i < n; i++) {
    poly3_word_sub(&a->s[i], &a->a[i], a->s[i], a->a[i], b->s[i], b->a[i]);
  }
}
674
// poly3_mul_aux is a recursive function that multiplies |n| words from |a| and
// |b| and writes 2×|n| words to |out|. Each call uses 2*ceil(n/2) elements of
// |scratch| and the function recurses, except if |n| == 1, when |scratch| isn't
// used and the recursion stops. For |n| in {11, 22}, the transitive total
// amount of |scratch| needed happens to be 2n+2.
static void poly3_mul_aux(const struct poly3_span *out,
                          const struct poly3_span *scratch,
                          const struct poly3_span *a,
                          const struct poly3_span *b, size_t n) {
  if (n == 1) {
    // Base case: schoolbook multiplication of one word by one word, bit by
    // bit, producing a two-word result.
    crypto_word_t r_s_low = 0, r_s_high = 0, r_a_low = 0, r_a_high = 0;
    crypto_word_t b_s = b->s[0], b_a = b->a[0];
    const crypto_word_t a_s = a->s[0], a_a = a->a[0];

    for (size_t i = 0; i < BITS_PER_WORD; i++) {
      // Multiply (s, a) by the next value from (b_s, b_a).
      crypto_word_t m_s, m_a;
      poly3_word_mul(&m_s, &m_a, a_s, a_a, lsb_to_all(b_s), lsb_to_all(b_a));
      b_s >>= 1;
      b_a >>= 1;

      if (i == 0) {
        // Special case otherwise the code tries to shift by BITS_PER_WORD
        // below, which is undefined.
        r_s_low = m_s;
        r_a_low = m_a;
        continue;
      }

      // Shift the multiplication result to the correct position.
      const crypto_word_t m_s_low = m_s << i;
      const crypto_word_t m_s_high = m_s >> (BITS_PER_WORD - i);
      const crypto_word_t m_a_low = m_a << i;
      const crypto_word_t m_a_high = m_a >> (BITS_PER_WORD - i);

      // Add into the result.
      poly3_word_add(&r_s_low, &r_a_low, r_s_low, r_a_low, m_s_low, m_a_low);
      poly3_word_add(&r_s_high, &r_a_high, r_s_high, r_a_high, m_s_high,
                     m_a_high);
    }

    out->s[0] = r_s_low;
    out->s[1] = r_s_high;
    out->a[0] = r_a_low;
    out->a[1] = r_a_high;
    return;
  }

  // Karatsuba multiplication.
  // https://en.wikipedia.org/wiki/Karatsuba_algorithm

  // When |n| is odd, the two "halves" will have different lengths. The first
  // is always the smaller.
  const size_t low_len = n / 2;
  const size_t high_len = n - low_len;
  const struct poly3_span a_high = {&a->s[low_len], &a->a[low_len]};
  const struct poly3_span b_high = {&b->s[low_len], &b->a[low_len]};

  // Store a_1 + a_0 in the first half of |out| and b_1 + b_0 in the second
  // half. (|out| is free at this point and is reused as temporary space.)
  const struct poly3_span a_cross_sum = *out;
  const struct poly3_span b_cross_sum = {&out->s[high_len], &out->a[high_len]};
  poly3_span_add(&a_cross_sum, a, &a_high, low_len);
  poly3_span_add(&b_cross_sum, b, &b_high, low_len);
  if (high_len != low_len) {
    // Odd |n|: the high halves have one extra word that has nothing to be
    // added to; copy it across.
    a_cross_sum.s[low_len] = a_high.s[low_len];
    a_cross_sum.a[low_len] = a_high.a[low_len];
    b_cross_sum.s[low_len] = b_high.s[low_len];
    b_cross_sum.a[low_len] = b_high.a[low_len];
  }

  const struct poly3_span child_scratch = {&scratch->s[2 * high_len],
                                           &scratch->a[2 * high_len]};
  const struct poly3_span out_mid = {&out->s[low_len], &out->a[low_len]};
  const struct poly3_span out_high = {&out->s[2 * low_len],
                                      &out->a[2 * low_len]};

  // Calculate (a_1 + a_0) × (b_1 + b_0) and write to scratch buffer.
  poly3_mul_aux(scratch, &child_scratch, &a_cross_sum, &b_cross_sum, high_len);
  // Calculate a_1 × b_1.
  poly3_mul_aux(&out_high, &child_scratch, &a_high, &b_high, high_len);
  // Calculate a_0 × b_0.
  poly3_mul_aux(out, &child_scratch, a, b, low_len);

  // Subtract those last two products from the first.
  poly3_span_sub(scratch, out, low_len * 2);
  poly3_span_sub(scratch, &out_high, high_len * 2);

  // Add the middle product into the output.
  poly3_span_add(&out_mid, &out_mid, scratch, high_len * 2);
}
766
// HRSS_poly3_mul sets |*out| to |x|×|y| mod Φ(N).
void HRSS_poly3_mul(struct poly3 *out, const struct poly3 *x,
                    const struct poly3 *y) {
  // Stack buffers for the double-width product and the Karatsuba scratch
  // space (2n+2 words; see |poly3_mul_aux|).
  crypto_word_t prod_s[WORDS_PER_POLY * 2];
  crypto_word_t prod_a[WORDS_PER_POLY * 2];
  crypto_word_t scratch_s[WORDS_PER_POLY * 2 + 2];
  crypto_word_t scratch_a[WORDS_PER_POLY * 2 + 2];
  const struct poly3_span prod_span = {prod_s, prod_a};
  const struct poly3_span scratch_span = {scratch_s, scratch_a};
  // The const is cast away to fit |poly3_span|'s mutable pointers; the inputs
  // are only read.
  const struct poly3_span x_span = {(crypto_word_t *)x->s.v,
                                    (crypto_word_t *)x->a.v};
  const struct poly3_span y_span = {(crypto_word_t *)y->s.v,
                                    (crypto_word_t *)y->a.v};

  poly3_mul_aux(&prod_span, &scratch_span, &x_span, &y_span, WORDS_PER_POLY);

  // |prod| needs to be reduced mod (𝑥^n - 1), which just involves adding the
  // upper-half to the lower-half. However, N is 701, which isn't a multiple of
  // BITS_PER_WORD, so the upper-half vectors all have to be shifted before
  // being added to the lower-half.
  for (size_t i = 0; i < WORDS_PER_POLY; i++) {
    crypto_word_t v_s = prod_s[WORDS_PER_POLY + i - 1] >> BITS_IN_LAST_WORD;
    v_s |= prod_s[WORDS_PER_POLY + i] << (BITS_PER_WORD - BITS_IN_LAST_WORD);
    crypto_word_t v_a = prod_a[WORDS_PER_POLY + i - 1] >> BITS_IN_LAST_WORD;
    v_a |= prod_a[WORDS_PER_POLY + i] << (BITS_PER_WORD - BITS_IN_LAST_WORD);

    poly3_word_add(&out->s.v[i], &out->a.v[i], prod_s[i], prod_a[i], v_s, v_a);
  }

  poly3_mod_phiN(out);
}
798
799#if defined(HRSS_HAVE_VECTOR_UNIT) && !defined(OPENSSL_AARCH64)
800
// poly3_vec_cswap swaps (|a_s|, |a_a|) and (|b_s|, |b_a|) if |swap| is
// |0xff..ff|. Otherwise, |swap| must be zero. (Vectorised analogue of
// |poly3_cswap|.)
static inline void poly3_vec_cswap(vec_t a_s[6], vec_t a_a[6], vec_t b_s[6],
                                   vec_t b_a[6], const vec_t swap) {
  for (int i = 0; i < 6; i++) {
    const vec_t sum_s = swap & (a_s[i] ^ b_s[i]);
    a_s[i] ^= sum_s;
    b_s[i] ^= sum_s;

    const vec_t sum_a = swap & (a_a[i] ^ b_a[i]);
    a_a[i] ^= sum_a;
    b_a[i] ^= sum_a;
  }
}
815
// poly3_vec_fmsub subtracts (|ms|, |ma|) × (|b_s|, |b_a|) from (|a_s|, |a_a|).
// (Vectorised analogue of |poly3_fmsub|; |ms| and |ma| are full-width masks.)
static inline void poly3_vec_fmsub(vec_t a_s[6], vec_t a_a[6], vec_t b_s[6],
                                   vec_t b_a[6], const vec_t ms,
                                   const vec_t ma) {
  for (int i = 0; i < 6; i++) {
    // See the bitslice formula, above.
    const vec_t s = b_s[i];
    const vec_t a = b_a[i];
    // Multiply (s, a) by (ms, ma).
    const vec_t product_a = a & ma;
    const vec_t product_s = (s ^ ms) & product_a;

    // Subtract the product from (out_s, out_a).
    const vec_t out_s = a_s[i];
    const vec_t out_a = a_a[i];
    const vec_t t = out_a ^ product_a;
    a_s[i] = (out_s ^ product_a) & (t ^ product_s);
    a_a[i] = t | (out_s ^ product_s);
  }
}
834
// poly3_invert_vec sets |*out| to |in|^-1, i.e. such that |out|×|in| == 1 mod
// Φ(N).
static void poly3_invert_vec(struct poly3 *out, const struct poly3 *in) {
  // See the comment in |HRSS_poly3_invert| about this algorithm. In addition to
  // the changes described there, this implementation attempts to use vector
  // registers to speed up the computation. Even non-poly3 variables are held in
  // vectors where possible to minimise the amount of data movement between
  // the vector and general-purpose registers.

  // Six 128-bit vectors cover the 768 bits used to hold a 701-bit bitsliced
  // polynomial half.
  vec_t b_s[6], b_a[6], c_s[6], c_a[6], f_s[6], f_a[6], g_s[6], g_a[6];
  const vec_t kZero = {0};
  const vec_t kOne = {1};
  static const uint8_t kOneBytes[sizeof(vec_t)] = {1};
  static const uint8_t kBottomSixtyOne[sizeof(vec_t)] = {
      0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x1f};

  // b = 1, c = 0.
  memset(b_s, 0, sizeof(b_s));
  memcpy(b_a, kOneBytes, sizeof(kOneBytes));
  memset(&b_a[1], 0, 5 * sizeof(vec_t));

  memset(c_s, 0, sizeof(c_s));
  memset(c_a, 0, sizeof(c_a));

  // f = in, with the trailing vector zeroed before the partial copy.
  f_s[5] = kZero;
  memcpy(f_s, in->s.v, WORDS_PER_POLY * sizeof(crypto_word_t));
  f_a[5] = kZero;
  memcpy(f_a, in->a.v, WORDS_PER_POLY * sizeof(crypto_word_t));

  // Set g to all ones.
  memset(g_s, 0, sizeof(g_s));
  memset(g_a, 0xff, 5 * sizeof(vec_t));
  memcpy(&g_a[5], kBottomSixtyOne, sizeof(kBottomSixtyOne));

  vec_t deg_f = {N - 1}, deg_g = {N - 1}, rotation = kZero;
  vec_t k = kOne;
  vec_t f0s = {0}, f0a = {0};
  vec_t still_going;
  memset(&still_going, 0xff, sizeof(still_going));

  for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
    // Swap f and g (and b and c) when both constant terms are non-zero and
    // deg(f) < deg(g); then eliminate g's constant term from f and divide f
    // by x. All decisions are made with full-width masks so the loop's memory
    // and instruction trace is independent of the secret input.
    const vec_t s_a = vec_broadcast_bit(still_going & (f_a[0] & g_a[0]));
    const vec_t s_s =
        vec_broadcast_bit(still_going & ((f_s[0] ^ g_s[0]) & s_a));
    const vec_t should_swap =
        (s_s | s_a) & vec_broadcast_bit15(deg_f - deg_g);

    poly3_vec_cswap(f_s, f_a, g_s, g_a, should_swap);
    poly3_vec_fmsub(f_s, f_a, g_s, g_a, s_s, s_a);
    poly3_vec_rshift1(f_s, f_a);

    poly3_vec_cswap(b_s, b_a, c_s, c_a, should_swap);
    poly3_vec_fmsub(b_s, b_a, c_s, c_a, s_s, s_a);
    poly3_vec_lshift1(c_s, c_a);

    // Conditionally exchange the tracked degrees alongside the swap.
    const vec_t deg_sum = should_swap & (deg_f ^ deg_g);
    deg_f ^= deg_sum;
    deg_g ^= deg_sum;

    deg_f -= kOne;
    // Stop updating the rotation once deg(f) has gone negative.
    still_going &= ~vec_broadcast_bit15(deg_f - kOne);

    const vec_t f0_is_nonzero = vec_broadcast_bit(f_s[0] | f_a[0]);
    // |f0_is_nonzero| implies |still_going|.
    rotation ^= f0_is_nonzero & (k ^ rotation);
    k += kOne;

    // Remember the constant term of f at the recorded rotation point.
    const vec_t f0s_sum = f0_is_nonzero & (f_s[0] ^ f0s);
    f0s ^= f0s_sum;
    const vec_t f0a_sum = f0_is_nonzero & (f_a[0] ^ f0a);
    f0a ^= f0a_sum;
  }

  crypto_word_t rotation_word = vec_get_word(rotation, 0);
  // Reduce the rotation mod N without branching on it.
  rotation_word -= N & constant_time_lt_w(N, rotation_word);
  memcpy(out->s.v, b_s, WORDS_PER_POLY * sizeof(crypto_word_t));
  memcpy(out->a.v, b_a, WORDS_PER_POLY * sizeof(crypto_word_t));
  assert(poly3_top_bits_are_clear(out));
  poly3_rotr_consttime(out, rotation_word);
  // Scale by the inverse of the final constant term (self-inverse mod 3).
  poly3_mul_const(out, vec_get_word(f0s, 0), vec_get_word(f0a, 0));
  poly3_mod_phiN(out);
}
916
917#endif // HRSS_HAVE_VECTOR_UNIT
918
919// HRSS_poly3_invert sets |*out| to |in|^-1, i.e. such that |out|×|in| == 1 mod
920// Φ(N).
void HRSS_poly3_invert(struct poly3 *out, const struct poly3 *in) {
  // The vector version of this function seems slightly slower on AArch64, but
  // is useful on ARMv7 and x86-64.
#if defined(HRSS_HAVE_VECTOR_UNIT) && !defined(OPENSSL_AARCH64)
  if (vec_capable()) {
    poly3_invert_vec(out, in);
    return;
  }
#endif

  // This algorithm mostly follows algorithm 10 in the paper. Some changes:
  //   1) k should start at zero, not one. In the code below k is omitted and
  //      the loop counter, |i|, is used instead.
  //   2) The rotation count is conditionally updated to handle trailing zero
  //      coefficients.
  // The best explanation for why it works is in the "Why it works" section of
  // [NTRUTN14].

  struct poly3 c, f, g;
  OPENSSL_memcpy(&f, in, sizeof(f));

  // Set g to all ones.
  OPENSSL_memset(&g.s, 0, sizeof(struct poly2));
  OPENSSL_memset(&g.a, 0xff, sizeof(struct poly2));
  g.a.v[WORDS_PER_POLY - 1] >>= BITS_PER_WORD - BITS_IN_LAST_WORD;

  struct poly3 *b = out;
  poly3_zero(b);
  poly3_zero(&c);
  // Set b to one.
  b->a.v[0] = 1;

  // |deg_f| and |deg_g| are upper bounds on the degrees of f and g. The loop
  // below always runs a fixed number of iterations so that it is
  // constant-time; |still_going| is an all-ones/all-zeros mask recording
  // whether useful work remains.
  crypto_word_t deg_f = N - 1, deg_g = N - 1, rotation = 0;
  crypto_word_t f0s = 0, f0a = 0;
  crypto_word_t still_going = CONSTTIME_TRUE_W;

  for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
    // |s_a| is all ones iff the constant terms of both f and g are non-zero
    // (and we're still going); |s_s| additionally records whether their signs
    // differ. Together these select whether g is added to or subtracted from
    // f below.
    const crypto_word_t s_a = lsb_to_all(
        still_going & (f.a.v[0] & g.a.v[0]));
    const crypto_word_t s_s = lsb_to_all(
        still_going & ((f.s.v[0] ^ g.s.v[0]) & s_a));
    const crypto_word_t should_swap =
        (s_s | s_a) & constant_time_lt_w(deg_f, deg_g);

    // Constant-time conditional swaps keep the higher-degree polynomial in f.
    poly3_cswap(&f, &g, should_swap);
    poly3_cswap(b, &c, should_swap);

    const crypto_word_t deg_sum = should_swap & (deg_f ^ deg_g);
    deg_f ^= deg_sum;
    deg_g ^= deg_sum;
    assert(deg_g >= 1);

    // Eliminate the constant term of f (mirrored in b) and then divide f by
    // 𝑥 (compensating by multiplying c by 𝑥).
    poly3_fmsub(&f, &g, s_s, s_a);
    poly3_fmsub(b, &c, s_s, s_a);
    poly3_rshift1(&f);
    poly3_lshift1(&c);

    deg_f--;
    const crypto_word_t f0_is_nonzero =
        lsb_to_all(f.s.v[0]) | lsb_to_all(f.a.v[0]);
    // |f0_is_nonzero| implies |still_going|.
    assert(!(f0_is_nonzero && !still_going));
    still_going &= ~constant_time_is_zero_w(deg_f);

    // Remember the index of, and f's constant term at, the last iteration
    // where that constant term was non-zero. These determine the final
    // rotation and normalisation below.
    rotation = constant_time_select_w(f0_is_nonzero, i, rotation);
    f0s = constant_time_select_w(f0_is_nonzero, f.s.v[0], f0s);
    f0a = constant_time_select_w(f0_is_nonzero, f.a.v[0], f0a);
  }

  // Undo the accumulated multiplications by 𝑥 and normalise by the final
  // constant term of f.
  rotation++;
  rotation -= N & constant_time_lt_w(N, rotation);
  assert(poly3_top_bits_are_clear(out));
  poly3_rotr_consttime(out, rotation);
  poly3_mul_const(out, f0s, f0a);
  poly3_mod_phiN(out);
}
997
998// Polynomials in Q.
999
// Coefficients are reduced mod Q. (Q is clearly not prime, therefore the
// coefficients do not form a field.)
#define Q 8192

// VECS_PER_POLY is the number of 128-bit vectors needed to represent a
// polynomial. COEFFICIENTS_PER_VEC is the number of uint16_t coefficients
// held by one vector (eight, for 128-bit vectors).
#define COEFFICIENTS_PER_VEC (sizeof(vec_t) / sizeof(uint16_t))
#define VECS_PER_POLY ((N + COEFFICIENTS_PER_VEC - 1) / COEFFICIENTS_PER_VEC)
1008
1009// poly represents a polynomial with coefficients mod Q. Note that, while Q is a
1010// power of two, this does not operate in GF(Q). That would be a binary field
1011// but this is simply mod Q. Thus the coefficients are not a field.
1012//
1013// Coefficients are ordered little-endian, thus the coefficient of x^0 is the
1014// first element of the array.
struct poly {
#if defined(HRSS_HAVE_VECTOR_UNIT)
  union {
    // N + 3 = 704, which is a multiple of 64 and thus aligns things, esp for
    // the vector code. The three trailing elements are padding and are kept
    // zero by the multiplication functions.
    uint16_t v[N + 3];
    // Aliased, vector-sized view of |v| used by the SIMD multiplication code.
    vec_t vectors[VECS_PER_POLY];
  };
#else
  // Even if !HRSS_HAVE_VECTOR_UNIT, external assembly may be called that
  // requires alignment.
  alignas(16) uint16_t v[N + 3];
#endif
};
1029
1030OPENSSL_UNUSED static void poly_print(const struct poly *p) {
1031 printf("[");
1032 for (unsigned i = 0; i < N; i++) {
1033 if (i) {
1034 printf(" ");
1035 }
1036 printf("%d", p->v[i]);
1037 }
1038 printf("]\n");
1039}
1040
1041#if defined(HRSS_HAVE_VECTOR_UNIT)
1042
1043// poly_mul_vec_aux is a recursive function that multiplies |n| words from |a|
1044// and |b| and writes 2×|n| words to |out|. Each call uses 2*ceil(n/2) elements
1045// of |scratch| and the function recurses, except if |n| < 3, when |scratch|
1046// isn't used and the recursion stops. If |n| == |VECS_PER_POLY| then |scratch|
1047// needs 172 elements.
static void poly_mul_vec_aux(vec_t *restrict out, vec_t *restrict scratch,
                             const vec_t *restrict a, const vec_t *restrict b,
                             const size_t n) {
  // In [HRSS], the technique they used for polynomial multiplication is
  // described: they start with Toom-4 at the top level and then two layers of
  // Karatsuba. Karatsuba is a specific instance of the general Toom–Cook
  // decomposition, which splits an input n-ways and produces 2n-1
  // multiplications of those parts. So, starting with 704 coefficients (rounded
  // up from 701 to have more factors of two), Toom-4 gives seven
  // multiplications of degree-174 polynomials. Each round of Karatsuba (which
  // is Toom-2) increases the number of multiplications by a factor of three
  // while halving the size of the values being multiplied. So two rounds gives
  // 63 multiplications of degree-44 polynomials. Then they (I think) form
  // vectors by gathering all 63 coefficients of each power together, for each
  // input, and doing more rounds of Karatsuba on the vectors until they bottom-
  // out somewhere with schoolbook multiplication.
  //
  // I tried something like that for NEON. NEON vectors are 128 bits so hold
  // eight coefficients. I wrote a function that did Karatsuba on eight
  // multiplications at the same time, using such vectors, and a Go script that
  // decomposed from degree-704, with Karatsuba in non-transposed form, until it
  // reached multiplications of degree-44. It batched up those 81
  // multiplications into lots of eight with a single one left over (which was
  // handled directly).
  //
  // It worked, but it was significantly slower than the dumb algorithm used
  // below. Potentially that was because I misunderstood how [HRSS] did it, or
  // because Clang is bad at generating good code from NEON intrinsics on ARMv7.
  // (Which is true: the code generated by Clang for the below is pretty crap.)
  //
  // This algorithm is much simpler. It just does Karatsuba decomposition all
  // the way down and never transposes. When it gets down to degree-16 or
  // degree-24 values, they are multiplied using schoolbook multiplication and
  // vector intrinsics. The vector operations form each of the eight phase-
  // shifts of one of the inputs, point-wise multiply, and then add into the
  // result at the correct place. This means that 33% (degree-16) or 25%
  // (degree-24) of the multiplies and adds are wasted, but it does ok.

  // Base case: |a| and |b| are two vectors each, i.e. 16 coefficients.
  if (n == 2) {
    vec_t result[4];
    vec_t vec_a[3];
    static const vec_t kZero = {0};
    vec_a[0] = a[0];
    vec_a[1] = a[1];
    vec_a[2] = kZero;

    // Handle coefficients zero and eight of |b| directly: vec_a is still in
    // its unshifted phase so plain multiplies suffice.
    result[0] = vec_mul(vec_a[0], vec_get_word(b[0], 0));
    result[1] = vec_mul(vec_a[1], vec_get_word(b[0], 0));

    result[1] = vec_fma(result[1], vec_a[0], vec_get_word(b[1], 0));
    result[2] = vec_mul(vec_a[1], vec_get_word(b[1], 0));
    result[3] = kZero;

    vec3_rshift_word(vec_a);

// BLOCK multiplies the (phase-shifted) |vec_a| by coefficient |y| of |b| and
// accumulates into |result| starting at vector offset |x|.
#define BLOCK(x, y) \
  do { \
    result[x + 0] = \
        vec_fma(result[x + 0], vec_a[0], vec_get_word(b[y / 8], y % 8)); \
    result[x + 1] = \
        vec_fma(result[x + 1], vec_a[1], vec_get_word(b[y / 8], y % 8)); \
    result[x + 2] = \
        vec_fma(result[x + 2], vec_a[2], vec_get_word(b[y / 8], y % 8)); \
  } while (0)

    BLOCK(0, 1);
    BLOCK(1, 9);

    vec3_rshift_word(vec_a);

    BLOCK(0, 2);
    BLOCK(1, 10);

    vec3_rshift_word(vec_a);

    BLOCK(0, 3);
    BLOCK(1, 11);

    vec3_rshift_word(vec_a);

    BLOCK(0, 4);
    BLOCK(1, 12);

    vec3_rshift_word(vec_a);

    BLOCK(0, 5);
    BLOCK(1, 13);

    vec3_rshift_word(vec_a);

    BLOCK(0, 6);
    BLOCK(1, 14);

    vec3_rshift_word(vec_a);

    BLOCK(0, 7);
    BLOCK(1, 15);

#undef BLOCK

    memcpy(out, result, sizeof(result));
    return;
  }

  // Base case: |a| and |b| are three vectors each, i.e. 24 coefficients.
  if (n == 3) {
    vec_t result[6];
    vec_t vec_a[4];
    static const vec_t kZero = {0};
    vec_a[0] = a[0];
    vec_a[1] = a[1];
    vec_a[2] = a[2];
    vec_a[3] = kZero;

    result[0] = vec_mul(a[0], vec_get_word(b[0], 0));
    result[1] = vec_mul(a[1], vec_get_word(b[0], 0));
    result[2] = vec_mul(a[2], vec_get_word(b[0], 0));

// BLOCK_PRE handles the first use of each result element, where a plain
// multiply (rather than a fused multiply-add) initialises the top vector.
#define BLOCK_PRE(x, y) \
  do { \
    result[x + 0] = \
        vec_fma(result[x + 0], vec_a[0], vec_get_word(b[y / 8], y % 8)); \
    result[x + 1] = \
        vec_fma(result[x + 1], vec_a[1], vec_get_word(b[y / 8], y % 8)); \
    result[x + 2] = vec_mul(vec_a[2], vec_get_word(b[y / 8], y % 8)); \
  } while (0)

    BLOCK_PRE(1, 8);
    BLOCK_PRE(2, 16);

    result[5] = kZero;

    vec4_rshift_word(vec_a);

// BLOCK multiplies the (phase-shifted) |vec_a| by coefficient |y| of |b| and
// accumulates into |result| starting at vector offset |x|.
#define BLOCK(x, y) \
  do { \
    result[x + 0] = \
        vec_fma(result[x + 0], vec_a[0], vec_get_word(b[y / 8], y % 8)); \
    result[x + 1] = \
        vec_fma(result[x + 1], vec_a[1], vec_get_word(b[y / 8], y % 8)); \
    result[x + 2] = \
        vec_fma(result[x + 2], vec_a[2], vec_get_word(b[y / 8], y % 8)); \
    result[x + 3] = \
        vec_fma(result[x + 3], vec_a[3], vec_get_word(b[y / 8], y % 8)); \
  } while (0)

    BLOCK(0, 1);
    BLOCK(1, 9);
    BLOCK(2, 17);

    vec4_rshift_word(vec_a);

    BLOCK(0, 2);
    BLOCK(1, 10);
    BLOCK(2, 18);

    vec4_rshift_word(vec_a);

    BLOCK(0, 3);
    BLOCK(1, 11);
    BLOCK(2, 19);

    vec4_rshift_word(vec_a);

    BLOCK(0, 4);
    BLOCK(1, 12);
    BLOCK(2, 20);

    vec4_rshift_word(vec_a);

    BLOCK(0, 5);
    BLOCK(1, 13);
    BLOCK(2, 21);

    vec4_rshift_word(vec_a);

    BLOCK(0, 6);
    BLOCK(1, 14);
    BLOCK(2, 22);

    vec4_rshift_word(vec_a);

    BLOCK(0, 7);
    BLOCK(1, 15);
    BLOCK(2, 23);

#undef BLOCK
#undef BLOCK_PRE

    memcpy(out, result, sizeof(result));

    return;
  }

  // Karatsuba multiplication. (|n| >= 4 here: the base cases above returned.)
  // https://en.wikipedia.org/wiki/Karatsuba_algorithm

  // When |n| is odd, the two "halves" will have different lengths. The first is
  // always the smaller.
  const size_t low_len = n / 2;
  const size_t high_len = n - low_len;
  const vec_t *a_high = &a[low_len];
  const vec_t *b_high = &b[low_len];

  // Store a_1 + a_0 in the first half of |out| and b_1 + b_0 in the second
  // half.
  for (size_t i = 0; i < low_len; i++) {
    out[i] = vec_add(a_high[i], a[i]);
    out[high_len + i] = vec_add(b_high[i], b[i]);
  }
  if (high_len != low_len) {
    out[low_len] = a_high[low_len];
    out[high_len + low_len] = b_high[low_len];
  }

  vec_t *const child_scratch = &scratch[2 * high_len];
  // Calculate (a_1 + a_0) × (b_1 + b_0) and write to scratch buffer.
  poly_mul_vec_aux(scratch, child_scratch, out, &out[high_len], high_len);
  // Calculate a_1 × b_1.
  poly_mul_vec_aux(&out[low_len * 2], child_scratch, a_high, b_high, high_len);
  // Calculate a_0 × b_0.
  poly_mul_vec_aux(out, child_scratch, a, b, low_len);

  // Subtract those last two products from the first.
  for (size_t i = 0; i < low_len * 2; i++) {
    scratch[i] = vec_sub(scratch[i], vec_add(out[i], out[low_len * 2 + i]));
  }
  // When |n| is odd the middle product is two vectors longer; fix up the
  // extra elements.
  if (low_len != high_len) {
    scratch[low_len * 2] = vec_sub(scratch[low_len * 2], out[low_len * 4]);
    scratch[low_len * 2 + 1] =
        vec_sub(scratch[low_len * 2 + 1], out[low_len * 4 + 1]);
  }

  // Add the middle product into the output.
  for (size_t i = 0; i < high_len * 2; i++) {
    out[low_len + i] = vec_add(out[low_len + i], scratch[i]);
  }
}
1284
1285// poly_mul_vec sets |*out| to |x|×|y| mod (𝑥^n - 1).
static void poly_mul_vec(struct poly *out, const struct poly *x,
                         const struct poly *y) {
  // The multiplication processes whole vectors, so the three padding
  // coefficients beyond index N-1 must be zero. NOTE(review): the const is
  // cast away to do this — this assumes the underlying objects are never
  // genuinely const-qualified (writing to a truly const object would be
  // undefined behaviour); confirm against all callers.
  OPENSSL_memset((uint16_t *)&x->v[N], 0, 3 * sizeof(uint16_t));
  OPENSSL_memset((uint16_t *)&y->v[N], 0, 3 * sizeof(uint16_t));

  OPENSSL_STATIC_ASSERT(sizeof(out->v) == sizeof(vec_t) * VECS_PER_POLY,
                        "struct poly is the wrong size");
  OPENSSL_STATIC_ASSERT(alignof(struct poly) == alignof(vec_t),
                        "struct poly has incorrect alignment");

  // |scratch| is sized for |poly_mul_vec_aux| with n == VECS_PER_POLY, which
  // needs 172 elements (see that function's comment).
  vec_t prod[VECS_PER_POLY * 2];
  vec_t scratch[172];
  poly_mul_vec_aux(prod, scratch, x->vectors, y->vectors, VECS_PER_POLY);

  // |prod| needs to be reduced mod (𝑥^n - 1), which just involves adding the
  // upper-half to the lower-half. However, N is 701, which isn't a multiple of
  // the vector size, so the upper-half vectors all have to be shifted before
  // being added to the lower-half.
  vec_t *out_vecs = (vec_t *)out->v;

  for (size_t i = 0; i < VECS_PER_POLY; i++) {
    const vec_t prev = prod[VECS_PER_POLY - 1 + i];
    const vec_t this = prod[VECS_PER_POLY + i];
    out_vecs[i] = vec_add(prod[i], vec_merge_3_5(prev, this));
  }

  // Re-zero the padding coefficients clobbered by the full-vector adds above.
  OPENSSL_memset(&out->v[N], 0, 3 * sizeof(uint16_t));
}
1314
1315#endif // HRSS_HAVE_VECTOR_UNIT
1316
1317// poly_mul_novec_aux writes the product of |a| and |b| to |out|, using
1318// |scratch| as scratch space. It'll use Karatsuba if the inputs are large
1319// enough to warrant it. Each call uses 2*ceil(n/2) elements of |scratch| and
1320// the function recurses, except if |n| < 64, when |scratch| isn't used and the
1321// recursion stops. If |n| == |N| then |scratch| needs 1318 elements.
static void poly_mul_novec_aux(uint16_t *out, uint16_t *scratch,
                               const uint16_t *a, const uint16_t *b, size_t n) {
  static const size_t kSchoolbookLimit = 64;
  if (n < kSchoolbookLimit) {
    // Schoolbook multiplication: O(n²) but fast for small inputs. The
    // accumulation wraps mod 2^16; since Q divides 2^16, this is consistent
    // with the later reduction mod Q.
    OPENSSL_memset(out, 0, sizeof(uint16_t) * n * 2);
    for (size_t i = 0; i < n; i++) {
      for (size_t j = 0; j < n; j++) {
        out[i + j] += (unsigned) a[i] * b[j];
      }
    }

    return;
  }

  // Karatsuba multiplication.
  // https://en.wikipedia.org/wiki/Karatsuba_algorithm

  // When |n| is odd, the two "halves" will have different lengths. The
  // first is always the smaller.
  const size_t low_len = n / 2;
  const size_t high_len = n - low_len;
  const uint16_t *const a_high = &a[low_len];
  const uint16_t *const b_high = &b[low_len];

  // Store a_1 + a_0 in the first half of |out| and b_1 + b_0 in the second
  // half.
  for (size_t i = 0; i < low_len; i++) {
    out[i] = a_high[i] + a[i];
    out[high_len + i] = b_high[i] + b[i];
  }
  if (high_len != low_len) {
    out[low_len] = a_high[low_len];
    out[high_len + low_len] = b_high[low_len];
  }

  uint16_t *const child_scratch = &scratch[2 * high_len];
  // (a_1 + a_0) × (b_1 + b_0) into |scratch|; a_1 × b_1 and a_0 × b_0 into
  // the top and bottom of |out|.
  poly_mul_novec_aux(scratch, child_scratch, out, &out[high_len], high_len);
  poly_mul_novec_aux(&out[low_len * 2], child_scratch, a_high, b_high,
                     high_len);
  poly_mul_novec_aux(out, child_scratch, a, b, low_len);

  // Subtract the outer products from the middle product.
  for (size_t i = 0; i < low_len * 2; i++) {
    scratch[i] -= out[i] + out[low_len * 2 + i];
  }
  // When |n| is odd the middle product has two extra elements; the final one
  // is zero for the inputs seen in this call tree (asserted), so only one
  // subtraction is needed.
  if (low_len != high_len) {
    scratch[low_len * 2] -= out[low_len * 4];
    assert(out[low_len * 4 + 1] == 0);
  }

  // Add the adjusted middle product into the output.
  for (size_t i = 0; i < high_len * 2; i++) {
    out[low_len + i] += scratch[i];
  }
}
1373
1374// poly_mul_novec sets |*out| to |x|×|y| mod (𝑥^n - 1).
1375static void poly_mul_novec(struct poly *out, const struct poly *x,
1376 const struct poly *y) {
1377 uint16_t prod[2 * N];
1378 uint16_t scratch[1318];
1379 poly_mul_novec_aux(prod, scratch, x->v, y->v, N);
1380
1381 for (size_t i = 0; i < N; i++) {
1382 out->v[i] = prod[i] + prod[i + N];
1383 }
1384 OPENSSL_memset(&out->v[N], 0, 3 * sizeof(uint16_t));
1385}
1386
// poly_mul sets |*r| to |a|×|b| mod (𝑥^N - 1), dispatching to the fastest
// implementation available on this CPU.
static void poly_mul(struct poly *r, const struct poly *a,
                     const struct poly *b) {
#if defined(POLY_RQ_MUL_ASM)
  // Bit 5 of the third capability word is the AVX2 feature flag
  // (CPUID.(EAX=7,ECX=0):EBX[5]).
  const int has_avx2 = (OPENSSL_ia32cap_P[2] & (1 << 5)) != 0;
  if (has_avx2) {
    poly_Rq_mul(r->v, a->v, b->v);
    return;
  }
#endif

#if defined(HRSS_HAVE_VECTOR_UNIT)
  if (vec_capable()) {
    poly_mul_vec(r, a, b);
    return;
  }
#endif

  // Fallback, non-vector case.
  poly_mul_novec(r, a, b);
}
1407
1408// poly_mul_x_minus_1 sets |p| to |p|×(𝑥 - 1) mod (𝑥^n - 1).
1409static void poly_mul_x_minus_1(struct poly *p) {
1410 // Multiplying by (𝑥 - 1) means negating each coefficient and adding in
1411 // the value of the previous one.
1412 const uint16_t orig_final_coefficient = p->v[N - 1];
1413
1414 for (size_t i = N - 1; i > 0; i--) {
1415 p->v[i] = p->v[i - 1] - p->v[i];
1416 }
1417 p->v[0] = orig_final_coefficient - p->v[0];
1418}
1419
1420// poly_mod_phiN sets |p| to |p| mod Φ(N).
1421static void poly_mod_phiN(struct poly *p) {
1422 const uint16_t coeff700 = p->v[N - 1];
1423
1424 for (unsigned i = 0; i < N; i++) {
1425 p->v[i] -= coeff700;
1426 }
1427}
1428
1429// poly_clamp reduces each coefficient mod Q.
1430static void poly_clamp(struct poly *p) {
1431 for (unsigned i = 0; i < N; i++) {
1432 p->v[i] &= Q - 1;
1433 }
1434}
1435
1436
1437// Conversion functions
1438// --------------------
1439
1440// poly2_from_poly sets |*out| to |in| mod 2.
static void poly2_from_poly(struct poly2 *out, const struct poly *in) {
  crypto_word_t *words = out->v;
  unsigned shift = 0;
  crypto_word_t word = 0;

  // Pack the least-significant bit of each coefficient into |out|, LSB-first
  // within each word. Each new bit is inserted at the top and the word
  // shifted down, so after BITS_PER_WORD insertions the first bit is at bit
  // zero.
  for (unsigned i = 0; i < N; i++) {
    word >>= 1;
    word |= (crypto_word_t)(in->v[i] & 1) << (BITS_PER_WORD - 1);
    shift++;

    if (shift == BITS_PER_WORD) {
      *words = word;
      words++;
      word = 0;
      shift = 0;
    }
  }

  // N isn't a multiple of BITS_PER_WORD, so align the partial final word down
  // to bit zero and store it.
  word >>= BITS_PER_WORD - shift;
  *words = word;
}
1462
1463// mod3 treats |a| as a signed number and returns |a| mod 3.
static uint16_t mod3(int16_t a) {
  // Fixed-point multiply by 21845 ≈ 2^16/3 yields an approximate quotient
  // ⌊a/3⌋ for the range of inputs used here.
  const int16_t quotient = ((int32_t)a * 21845) >> 16;
  const int16_t remainder = a - quotient * 3;
  // At this point |remainder| is in {0, 1, 2, 3}; fold 3 back to 0 without
  // branching: the mask is all-zeros exactly when remainder == 3.
  return remainder & ((remainder & (remainder >> 1)) - 1);
}
1471
1472// poly3_from_poly sets |*out| to |in|.
static void poly3_from_poly(struct poly3 *out, const struct poly *in) {
  crypto_word_t *words_s = out->s.v;
  crypto_word_t *words_a = out->a.v;
  crypto_word_t s = 0;
  crypto_word_t a = 0;
  unsigned shift = 0;

  for (unsigned i = 0; i < N; i++) {
    // This duplicates the 13th bit upwards to the top of the uint16,
    // essentially treating it as a sign bit and converting into a signed int16.
    // The signed value is reduced mod 3, yielding {0, 1, 2}.
    const uint16_t v = mod3((int16_t)(in->v[i] << 3) >> 3);
    // Pack into two bitplanes: the |s| plane gets a one for the value 2
    // (i.e. -1) and the |a| plane gets a one for any non-zero value. Bits are
    // accumulated LSB-first by inserting at the top and shifting down.
    s >>= 1;
    const crypto_word_t s_bit = (crypto_word_t)(v & 2) << (BITS_PER_WORD - 2);
    s |= s_bit;
    a >>= 1;
    a |= s_bit | (crypto_word_t)(v & 1) << (BITS_PER_WORD - 1);
    shift++;

    if (shift == BITS_PER_WORD) {
      *words_s = s;
      words_s++;
      *words_a = a;
      words_a++;
      s = a = 0;
      shift = 0;
    }
  }

  // Store the partial final words, aligned down to bit zero.
  s >>= BITS_PER_WORD - shift;
  a >>= BITS_PER_WORD - shift;
  *words_s = s;
  *words_a = a;
}
1507
1508// poly3_from_poly_checked sets |*out| to |in|, which has coefficients in {0, 1,
1509// Q-1}. It returns a mask indicating whether all coefficients were found to be
1510// in that set.
static crypto_word_t poly3_from_poly_checked(struct poly3 *out,
                                             const struct poly *in) {
  crypto_word_t *words_s = out->s.v;
  crypto_word_t *words_a = out->a.v;
  crypto_word_t s = 0;
  crypto_word_t a = 0;
  unsigned shift = 0;
  crypto_word_t ok = CONSTTIME_TRUE_W;

  for (unsigned i = 0; i < N; i++) {
    const uint16_t v = in->v[i];
    // Maps {0, 1, Q-1} to {0, 1, 2} by looking only at the low two bits
    // (Q-1 ends in binary 11, which the XOR folds to 2).
    uint16_t mod3 = v & 3;
    mod3 ^= mod3 >> 1;
    // Reconstruct the only legal coefficient for this |mod3| value ({0, 1,
    // Q-1} respectively) and check, in constant time, that |v| matches it.
    const uint16_t expected = (uint16_t)((~((mod3 >> 1) - 1)) | mod3) % Q;
    ok &= constant_time_eq_w(v, expected);

    // Pack into the |s| (set for -1) and |a| (set for non-zero) bitplanes,
    // LSB-first, exactly as in |poly3_from_poly|.
    s >>= 1;
    const crypto_word_t s_bit = (crypto_word_t)(mod3 & 2)
                                << (BITS_PER_WORD - 2);
    s |= s_bit;
    a >>= 1;
    a |= s_bit | (crypto_word_t)(mod3 & 1) << (BITS_PER_WORD - 1);
    shift++;

    if (shift == BITS_PER_WORD) {
      *words_s = s;
      words_s++;
      *words_a = a;
      words_a++;
      s = a = 0;
      shift = 0;
    }
  }

  // Store the partial final words, aligned down to bit zero.
  s >>= BITS_PER_WORD - shift;
  a >>= BITS_PER_WORD - shift;
  *words_s = s;
  *words_a = a;

  return ok;
}
1553
1554static void poly_from_poly2(struct poly *out, const struct poly2 *in) {
1555 const crypto_word_t *words = in->v;
1556 unsigned shift = 0;
1557 crypto_word_t word = *words;
1558
1559 for (unsigned i = 0; i < N; i++) {
1560 out->v[i] = word & 1;
1561 word >>= 1;
1562 shift++;
1563
1564 if (shift == BITS_PER_WORD) {
1565 words++;
1566 word = *words;
1567 shift = 0;
1568 }
1569 }
1570}
1571
// poly_from_poly3 expands |in| into |out|, mapping each (s, a) bit pair to a
// coefficient in {0, 1, 0xffff} (i.e. {0, 1, -1}).
static void poly_from_poly3(struct poly *out, const struct poly3 *in) {
  const crypto_word_t *words_s = in->s.v;
  const crypto_word_t *words_a = in->a.v;
  // The |s| words are inverted so that, below, (word_s & 1) - 1 yields 0xffff
  // when the original s bit was set and zero otherwise.
  crypto_word_t word_s = ~(*words_s);
  crypto_word_t word_a = *words_a;
  unsigned shift = 0;

  for (unsigned i = 0; i < N; i++) {
    // s set → 0xffff; s clear → the coefficient is just the |a| bit.
    out->v[i] = (uint16_t)(word_s & 1) - 1;
    out->v[i] |= word_a & 1;
    word_s >>= 1;
    word_a >>= 1;
    shift++;

    if (shift == BITS_PER_WORD) {
      words_s++;
      words_a++;
      word_s = ~(*words_s);
      word_a = *words_a;
      shift = 0;
    }
  }
}
1595
1596// Polynomial inversion
1597// --------------------
1598
1599// poly_invert_mod2 sets |*out| to |in^-1| (i.e. such that |*out|×|in| = 1 mod
1600// Φ(N)), all mod 2. This isn't useful in itself, but is part of doing inversion
1601// mod Q.
static void poly_invert_mod2(struct poly *out, const struct poly *in) {
  // This algorithm follows algorithm 10 in the paper. (Although, in contrast to
  // the paper, k should start at zero, not one, and the rotation count needs
  // to handle trailing zero coefficients.) The best explanation for why it
  // works is in the "Why it works" section of [NTRUTN14].

  struct poly2 b, c, f, g;
  poly2_from_poly(&f, in);
  OPENSSL_memset(&b, 0, sizeof(b));
  b.v[0] = 1;
  OPENSSL_memset(&c, 0, sizeof(c));

  // Set g to all ones.
  OPENSSL_memset(&g, 0xff, sizeof(struct poly2));
  g.v[WORDS_PER_POLY - 1] >>= BITS_PER_WORD - BITS_IN_LAST_WORD;

  // The loop runs a fixed number of iterations so that it is constant-time;
  // |still_going| is an all-ones/all-zeros mask recording whether useful work
  // remains.
  crypto_word_t deg_f = N - 1, deg_g = N - 1, rotation = 0;
  crypto_word_t still_going = CONSTTIME_TRUE_W;

  for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
    // |s| is all ones iff the constant term of f is set (and we're still
    // going); in that case g is added into f below to clear that term.
    const crypto_word_t s = still_going & lsb_to_all(f.v[0]);
    const crypto_word_t should_swap = s & constant_time_lt_w(deg_f, deg_g);
    poly2_cswap(&f, &g, should_swap);
    poly2_cswap(&b, &c, should_swap);
    const crypto_word_t deg_sum = should_swap & (deg_f ^ deg_g);
    deg_f ^= deg_sum;
    deg_g ^= deg_sum;
    assert(deg_g >= 1);
    poly2_fmadd(&f, &g, s);
    poly2_fmadd(&b, &c, s);

    // Divide f by 𝑥 (compensating by multiplying c by 𝑥).
    poly2_rshift1(&f);
    poly2_lshift1(&c);

    deg_f--;
    const crypto_word_t f0_is_nonzero = lsb_to_all(f.v[0]);
    // |f0_is_nonzero| implies |still_going|.
    assert(!(f0_is_nonzero && !still_going));
    // Remember the index of the last iteration where f's constant term was
    // non-zero: that determines the final rotation.
    rotation = constant_time_select_w(f0_is_nonzero, i, rotation);
    still_going &= ~constant_time_is_zero_w(deg_f);
  }

  // Undo the accumulated multiplications by 𝑥.
  rotation++;
  rotation -= N & constant_time_lt_w(N, rotation);
  assert(poly2_top_bits_are_clear(&b));
  HRSS_poly2_rotr_consttime(&b, rotation);
  poly_from_poly2(out, &b);
}
1650
1651// poly_invert sets |*out| to |in^-1| (i.e. such that |*out|×|in| = 1 mod Φ(N)).
static void poly_invert(struct poly *out, const struct poly *in) {
  // Inversion mod Q, which is done based on the result of inverting mod
  // 2. See [NTRUTN14] paper, bottom of page two.
  struct poly a, *b, tmp;

  // a = -in.
  for (unsigned i = 0; i < N; i++) {
    a.v[i] = -in->v[i];
  }

  // b = in^-1 mod 2.
  b = out;
  poly_invert_mod2(b, in);

  // We are working mod Q=2**13 and we need to iterate ceil(log_2(13))
  // times, which is four.
  for (unsigned i = 0; i < 4; i++) {
    // Newton iteration: b ← b×(2 - in×b), which doubles the number of
    // correct low-order bits each time. Since |a| is -in, in×b = -(a×b) and
    // so 2 - in×b = a×b + 2.
    poly_mul(&tmp, &a, b);
    tmp.v[0] += 2;
    poly_mul(b, b, &tmp);
  }
}
1674
1675// Marshal and unmarshal functions for various basic types.
1676// --------------------------------------------------------
1677
// POLY_BYTES is the number of bytes needed to serialise the first 700
// coefficients at 13 bits each: ceil(700 × 13 / 8) = 1138.
#define POLY_BYTES 1138
1679
1680// poly_marshal serialises all but the final coefficient of |in| to |out|.
static void poly_marshal(uint8_t out[POLY_BYTES], const struct poly *in) {
  const uint16_t *p = in->v;

  // Coefficients are 13 bits wide, so eight of them pack into exactly 13
  // bytes. This loop covers the first 8×⌊N/8⌋ = 696 coefficients.
  for (size_t i = 0; i < N / 8; i++) {
    out[0] = p[0];
    out[1] = (0x1f & (p[0] >> 8)) | ((p[1] & 0x07) << 5);
    out[2] = p[1] >> 3;
    out[3] = (3 & (p[1] >> 11)) | ((p[2] & 0x3f) << 2);
    out[4] = (0x7f & (p[2] >> 6)) | ((p[3] & 0x01) << 7);
    out[5] = p[3] >> 1;
    out[6] = (0xf & (p[3] >> 9)) | ((p[4] & 0x0f) << 4);
    out[7] = p[4] >> 4;
    out[8] = (1 & (p[4] >> 12)) | ((p[5] & 0x7f) << 1);
    out[9] = (0x3f & (p[5] >> 7)) | ((p[6] & 0x03) << 6);
    out[10] = p[6] >> 2;
    out[11] = (7 & (p[6] >> 10)) | ((p[7] & 0x1f) << 3);
    out[12] = p[7] >> 5;

    p += 8;
    out += 13;
  }

  // There are four remaining values. (The final, 701st, coefficient is not
  // serialised.) Their 52 bits fill the last seven bytes, leaving the top
  // four bits of the final byte zero.
  out[0] = p[0];
  out[1] = (0x1f & (p[0] >> 8)) | ((p[1] & 0x07) << 5);
  out[2] = p[1] >> 3;
  out[3] = (3 & (p[1] >> 11)) | ((p[2] & 0x3f) << 2);
  out[4] = (0x7f & (p[2] >> 6)) | ((p[3] & 0x01) << 7);
  out[5] = p[3] >> 1;
  out[6] = 0xf & (p[3] >> 9);
}
1712
1713// poly_unmarshal parses the output of |poly_marshal| and sets |out| such that
1714// all but the final coefficients match, and the final coefficient is calculated
1715// such that evaluating |out| at one results in zero. It returns one on success
1716// or zero if |in| is an invalid encoding.
static int poly_unmarshal(struct poly *out, const uint8_t in[POLY_BYTES]) {
  uint16_t *p = out->v;

  // Unpack ⌊N/8⌋ = 87 groups of eight 13-bit coefficients from 13 bytes each,
  // mirroring |poly_marshal|.
  for (size_t i = 0; i < N / 8; i++) {
    p[0] = (uint16_t)(in[0]) | (uint16_t)(in[1] & 0x1f) << 8;
    p[1] = (uint16_t)(in[1] >> 5) | (uint16_t)(in[2]) << 3 |
           (uint16_t)(in[3] & 3) << 11;
    p[2] = (uint16_t)(in[3] >> 2) | (uint16_t)(in[4] & 0x7f) << 6;
    p[3] = (uint16_t)(in[4] >> 7) | (uint16_t)(in[5]) << 1 |
           (uint16_t)(in[6] & 0xf) << 9;
    p[4] = (uint16_t)(in[6] >> 4) | (uint16_t)(in[7]) << 4 |
           (uint16_t)(in[8] & 1) << 12;
    p[5] = (uint16_t)(in[8] >> 1) | (uint16_t)(in[9] & 0x3f) << 7;
    p[6] = (uint16_t)(in[9] >> 6) | (uint16_t)(in[10]) << 2 |
           (uint16_t)(in[11] & 7) << 10;
    p[7] = (uint16_t)(in[11] >> 3) | (uint16_t)(in[12]) << 5;

    p += 8;
    in += 13;
  }

  // There are four coefficients remaining.
  p[0] = (uint16_t)(in[0]) | (uint16_t)(in[1] & 0x1f) << 8;
  p[1] = (uint16_t)(in[1] >> 5) | (uint16_t)(in[2]) << 3 |
         (uint16_t)(in[3] & 3) << 11;
  p[2] = (uint16_t)(in[3] >> 2) | (uint16_t)(in[4] & 0x7f) << 6;
  p[3] = (uint16_t)(in[4] >> 7) | (uint16_t)(in[5]) << 1 |
         (uint16_t)(in[6] & 0xf) << 9;

  // Sign-extend each 13-bit coefficient to 16 bits, so that values close to Q
  // become small negatives (e.g. Q-1 becomes 0xffff).
  for (unsigned i = 0; i < N - 1; i++) {
    out->v[i] = (int16_t)(out->v[i] << 3) >> 3;
  }

  // There are four unused bits in the last byte. We require them to be zero.
  if ((in[6] & 0xf0) != 0) {
    return 0;
  }

  // Set the final coefficient as specified in [HRSSNIST] 1.9.2 step 6: all
  // coefficients must sum to zero (mod 2^16), so that evaluating the
  // polynomial at one yields zero.
  uint32_t sum = 0;
  for (size_t i = 0; i < N - 1; i++) {
    sum += out->v[i];
  }

  out->v[N - 1] = (uint16_t)(0u - sum);

  return 1;
}
1765
1766// mod3_from_modQ maps {0, 1, Q-1, 65535} -> {0, 1, 2, 2}. Note that |v| may
1767// have an invalid value when processing attacker-controlled inputs.
static uint16_t mod3_from_modQ(uint16_t v) {
  // Only the low two bits distinguish the expected inputs: both Q-1 and
  // 65535 end in binary 11, which the XOR folds down to 2.
  const uint16_t low_bits = v & 3;
  return low_bits ^ (low_bits >> 1);
}
1772
1773// poly_marshal_mod3 marshals |in| to |out| where the coefficients of |in| are
1774// all in {0, 1, Q-1, 65535} and |in| is mod Φ(N). (Note that coefficients may
1775// have invalid values when processing attacker-controlled inputs.)
1776static void poly_marshal_mod3(uint8_t out[HRSS_POLY3_BYTES],
1777 const struct poly *in) {
1778 const uint16_t *coeffs = in->v;
1779
1780 // Only 700 coefficients are marshaled because in[700] must be zero.
1781 assert(coeffs[N-1] == 0);
1782
1783 for (size_t i = 0; i < HRSS_POLY3_BYTES; i++) {
1784 const uint16_t coeffs0 = mod3_from_modQ(coeffs[0]);
1785 const uint16_t coeffs1 = mod3_from_modQ(coeffs[1]);
1786 const uint16_t coeffs2 = mod3_from_modQ(coeffs[2]);
1787 const uint16_t coeffs3 = mod3_from_modQ(coeffs[3]);
1788 const uint16_t coeffs4 = mod3_from_modQ(coeffs[4]);
1789 out[i] = coeffs0 + coeffs1 * 3 + coeffs2 * 9 + coeffs3 * 27 + coeffs4 * 81;
1790 coeffs += 5;
1791 }
1792}
1793
1794// HRSS-specific functions
1795// -----------------------
1796
1797// poly_short_sample samples a vector of values in {0xffff (i.e. -1), 0, 1}.
1798// This is the same action as the algorithm in [HRSSNIST] section 1.8.1, but
1799// with HRSS-SXY the sampling algorithm is now a private detail of the
1800// implementation (previously it had to match between two parties). This
1801// function uses that freedom to implement a flatter distribution of values.
1802static void poly_short_sample(struct poly *out,
1803 const uint8_t in[HRSS_SAMPLE_BYTES]) {
1804 OPENSSL_STATIC_ASSERT(HRSS_SAMPLE_BYTES == N - 1,
1805 "HRSS_SAMPLE_BYTES incorrect");
1806 for (size_t i = 0; i < N - 1; i++) {
1807 uint16_t v = mod3(in[i]);
1808 // Map {0, 1, 2} -> {0, 1, 0xffff}
1809 v |= ((v >> 1) ^ 1) - 1;
1810 out->v[i] = v;
1811 }
1812 out->v[N - 1] = 0;
1813}
1814
1815// poly_short_sample_plus performs the T+ sample as defined in [HRSSNIST],
1816// section 1.8.2.
static void poly_short_sample_plus(struct poly *out,
                                   const uint8_t in[HRSS_SAMPLE_BYTES]) {
  poly_short_sample(out, in);

  // sum (and the product in the for loop) will overflow. But that's fine
  // because |sum| is bound by +/- (N-2), and N < 2^15 so it works out.
  uint16_t sum = 0;
  for (unsigned i = 0; i < N - 2; i++) {
    sum += (unsigned) out->v[i] * out->v[i + 1];
  }

  // If the sum is negative, flip the sign of even-positioned coefficients. (See
  // page 8 of [HRSS].)
  // After the arithmetic shift, |sum| is all-ones if the value was negative
  // and zero otherwise, making |scale| 0xffff (-1) or 1 respectively.
  sum = ((int16_t) sum) >> 15;
  const uint16_t scale = sum | (~sum & 1);
  for (unsigned i = 0; i < N; i += 2) {
    out->v[i] = (unsigned) out->v[i] * scale;
  }
}
1836
// poly_lift computes the function discussed in [HRSS], appendix B.
static void poly_lift(struct poly *out, const struct poly *a) {
  // We wish to calculate a/(𝑥-1) mod Φ(N) over GF(3), where Φ(N) is the
  // Nth cyclotomic polynomial, i.e. 1 + 𝑥 + … + 𝑥^700 (since N is prime).

  // 1/(𝑥-1) has a fairly basic structure that we can exploit to speed this up:
  //
  // R.<x> = PolynomialRing(GF(3)…)
  // inv = R.cyclotomic_polynomial(1).inverse_mod(R.cyclotomic_polynomial(n))
  // list(inv)[:15]
  // [1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2]
  //
  // This three-element pattern of coefficients repeats for the whole
  // polynomial.
  //
  // Next define the overbar operator such that z̅ = z[0] +
  // reverse(z[1:]). (Index zero of a polynomial here is the coefficient
  // of the constant term. So index one is the coefficient of 𝑥 and so
  // on.)
  //
  // A less odd way to define this is to see that z̅ negates the indexes,
  // so z̅[0] = z[-0], z̅[1] = z[-1] and so on.
  //
  // The use of z̅ is that, when working mod (𝑥^701 - 1), vz[0] = <v,
  // z̅>, vz[1] = <v, 𝑥z̅>, …. (Where <a, b> is the inner product: the sum
  // of the point-wise products.) Although we calculated the inverse mod
  // Φ(N), we can work mod (𝑥^N - 1) and reduce mod Φ(N) at the end.
  // (That's because (𝑥^N - 1) is a multiple of Φ(N).)
  //
  // When working mod (𝑥^N - 1), multiplication by 𝑥 is a right-rotation
  // of the list of coefficients.
  //
  // Thus we can consider what the pattern of z̅, 𝑥z̅, 𝑥^2z̅, … looks like:
  //
  // def reverse(xs):
  //   suffix = list(xs[1:])
  //   suffix.reverse()
  //   return [xs[0]] + suffix
  //
  // def rotate(xs):
  //   return [xs[-1]] + xs[:-1]
  //
  // zoverbar = reverse(list(inv) + [0])
  // xzoverbar = rotate(reverse(list(inv) + [0]))
  // x2zoverbar = rotate(rotate(reverse(list(inv) + [0])))
  //
  // zoverbar[:15]
  // [1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1]
  // xzoverbar[:15]
  // [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
  // x2zoverbar[:15]
  // [2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
  //
  // (For a formula for z̅, see lemma two of appendix B.)
  //
  // After the first three elements have been taken care of, all then have
  // a repeating three-element cycle. The next value (𝑥^3z̅) involves
  // three rotations of the first pattern, thus the three-element cycle
  // lines up. However, the discontinuity in the first three elements
  // obviously moves to a different position. Consider the difference
  // between 𝑥^3z̅ and z̅:
  //
  // [x-y for (x,y) in zip(zoverbar, x3zoverbar)][:15]
  // [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  //
  // This pattern of differences is the same for all elements, although it
  // obviously moves right with the rotations.
  //
  // From this, we reach algorithm eight of appendix B.

  // Handle the first three elements of the inner products.
  out->v[0] = a->v[0] + a->v[2];
  out->v[1] = a->v[1];
  out->v[2] = -a->v[0] + a->v[2];

  // s0, s1, s2 are added into out->v[0], out->v[1], and out->v[2],
  // respectively. We do not compute s1 because it's just -(s0 + s1).
  uint16_t s0 = 0, s2 = 0;
  for (size_t i = 3; i < 699; i += 3) {
    s0 += -a->v[i] + a->v[i + 2];
    // s1 += a->v[i] - a->v[i + 1];
    s2 += a->v[i + 1] - a->v[i + 2];
  }

  // Handle the fact that the three-element pattern doesn't fill the
  // polynomial exactly (since 701 isn't a multiple of three).
  s0 -= a->v[699];
  // s1 += a->v[699] - a->v[700];
  s2 += a->v[700];

  // Note that s0 + s1 + s2 = 0.
  out->v[0] += s0;
  out->v[1] -= (s0 + s2);  // = s1
  out->v[2] += s2;

  // Calculate the remaining inner products by taking advantage of the
  // fact that the pattern repeats every three cycles and the pattern of
  // differences moves with the rotation.
  for (size_t i = 3; i < N; i++) {
    // Each inner product is the previous-but-two inner product adjusted by
    // the [0, 1, 1, 1, 0, …] difference pattern, shifted to position i.
    out->v[i] = (out->v[i - 3] - (a->v[i - 2] + a->v[i - 1] + a->v[i]));
  }

  // Reduce mod Φ(N) by subtracting a multiple of out[700] from every
  // element and convert to mod Q. (See above about adding twice as
  // subtraction.)
  // v is the coefficient of 𝑥^700, which must be eliminated by the reduction.
  const crypto_word_t v = out->v[700];
  for (unsigned i = 0; i < N; i++) {
    const uint16_t vi_mod3 = mod3(out->v[i] - v);
    // Map {0, 1, 2} to {0, 1, 0xffff}.
    out->v[i] = (~((vi_mod3 >> 1) - 1)) | vi_mod3;
  }

  // Finally multiply by (𝑥-1) to complete the lift.
  poly_mul_x_minus_1(out);
}
1951
struct public_key {
  // ph is the public polynomial: p (i.e. 3) × h. It's computed in
  // |HRSS_generate_key| as pg_phi1 / f.
  struct poly ph;
};
1955
struct private_key {
  // f is the short, private polynomial (in the mod-3 representation) and
  // f_inverse is its inverse, as computed by |HRSS_poly3_invert|.
  struct poly3 f, f_inverse;
  // ph_inverse is the inverse of the public value |ph|; |HRSS_decap| uses it
  // to recover |r| from the ciphertext.
  struct poly ph_inverse;
  // hmac_key keys the HMAC that derives the shared key when decapsulation
  // fails (implicit rejection); see |HRSS_decap|.
  uint8_t hmac_key[32];
};
1961
1962// public_key_from_external converts an external public key pointer into an
1963// internal one. Externally the alignment is only specified to be eight bytes
1964// but we need 16-byte alignment. We could annotate the external struct with
1965// that alignment but we can only assume that malloced pointers are 8-byte
1966// aligned in any case. (Even if the underlying malloc returns values with
1967// 16-byte alignment, |OPENSSL_malloc| will store an 8-byte size prefix and mess
1968// that up.)
1969static struct public_key *public_key_from_external(
1970 struct HRSS_public_key *ext) {
1971 OPENSSL_STATIC_ASSERT(
1972 sizeof(struct HRSS_public_key) >= sizeof(struct public_key) + 15,
1973 "HRSS public key too small");
1974
1975 uintptr_t p = (uintptr_t)ext;
1976 p = (p + 15) & ~15;
1977 return (struct public_key *)p;
1978}
1979
1980// private_key_from_external does the same thing as |public_key_from_external|,
1981// but for private keys. See the comment on that function about alignment
1982// issues.
1983static struct private_key *private_key_from_external(
1984 struct HRSS_private_key *ext) {
1985 OPENSSL_STATIC_ASSERT(
1986 sizeof(struct HRSS_private_key) >= sizeof(struct private_key) + 15,
1987 "HRSS private key too small");
1988
1989 uintptr_t p = (uintptr_t)ext;
1990 p = (p + 15) & ~15;
1991 return (struct private_key *)p;
1992}
1993
// HRSS_generate_key generates a public/private keypair from |in|, which must
// contain two samples' worth of entropy followed by 32 bytes of HMAC key
// material (used for implicit rejection in |HRSS_decap|).
void HRSS_generate_key(
    struct HRSS_public_key *out_pub, struct HRSS_private_key *out_priv,
    const uint8_t in[HRSS_SAMPLE_BYTES + HRSS_SAMPLE_BYTES + 32]) {
  struct public_key *pub = public_key_from_external(out_pub);
  struct private_key *priv = private_key_from_external(out_priv);

  OPENSSL_memcpy(priv->hmac_key, in + 2 * HRSS_SAMPLE_BYTES,
                 sizeof(priv->hmac_key));

  // f is the private, short polynomial. It's stored, with its inverse, in the
  // mod-3 representation.
  struct poly f;
  poly_short_sample_plus(&f, in);
  poly3_from_poly(&priv->f, &f);
  HRSS_poly3_invert(&priv->f_inverse, &priv->f);

  // pg_phi1 is p (i.e. 3) × g × Φ(1) (i.e. 𝑥-1).
  struct poly pg_phi1;
  poly_short_sample_plus(&pg_phi1, in + HRSS_SAMPLE_BYTES);
  for (unsigned i = 0; i < N; i++) {
    pg_phi1.v[i] *= 3;
  }
  poly_mul_x_minus_1(&pg_phi1);

  struct poly pfg_phi1;
  poly_mul(&pfg_phi1, &f, &pg_phi1);

  // A single inversion of f × pg_phi1 yields both |ph| and |ph_inverse|
  // below, avoiding a second, expensive |poly_invert| call.
  struct poly pfg_phi1_inverse;
  poly_invert(&pfg_phi1_inverse, &pfg_phi1);

  // ph = pg_phi1² / (f × pg_phi1) = pg_phi1 / f.
  poly_mul(&pub->ph, &pfg_phi1_inverse, &pg_phi1);
  poly_mul(&pub->ph, &pub->ph, &pg_phi1);
  poly_clamp(&pub->ph);

  // ph_inverse = f² / (f × pg_phi1) = f / pg_phi1.
  poly_mul(&priv->ph_inverse, &pfg_phi1_inverse, &f);
  poly_mul(&priv->ph_inverse, &priv->ph_inverse, &f);
  poly_clamp(&priv->ph_inverse);
}
2030
// kSharedKey is the domain-separation label hashed into the shared key. Note
// that the trailing NUL is included, since sizeof (not strlen) is passed to
// |SHA256_Update|.
static const char kSharedKey[] = "shared key";
2032
// HRSS_encap encapsulates to |in_pub|: the two samples' worth of entropy in
// |in| determine the message |m| and blinding value |r|, the ciphertext is
// r × ph + lift(m), and the shared key is
// SHA-256("shared key" || m || r || ciphertext).
void HRSS_encap(uint8_t out_ciphertext[POLY_BYTES],
                uint8_t out_shared_key[32],
                const struct HRSS_public_key *in_pub,
                const uint8_t in[HRSS_SAMPLE_BYTES + HRSS_SAMPLE_BYTES]) {
  const struct public_key *pub =
      public_key_from_external((struct HRSS_public_key *)in_pub);
  struct poly m, r, m_lifted;
  poly_short_sample(&m, in);
  poly_short_sample(&r, in + HRSS_SAMPLE_BYTES);
  poly_lift(&m_lifted, &m);

  // The ciphertext polynomial is r × ph + lift(m).
  struct poly prh_plus_m;
  poly_mul(&prh_plus_m, &r, &pub->ph);
  for (unsigned i = 0; i < N; i++) {
    prh_plus_m.v[i] += m_lifted.v[i];
  }

  poly_marshal(out_ciphertext, &prh_plus_m);

  // The shared key is a hash over m, r, and the ciphertext, binding them
  // together as required by the [SXY] transform.
  uint8_t m_bytes[HRSS_POLY3_BYTES], r_bytes[HRSS_POLY3_BYTES];
  poly_marshal_mod3(m_bytes, &m);
  poly_marshal_mod3(r_bytes, &r);

  SHA256_CTX hash_ctx;
  SHA256_Init(&hash_ctx);
  SHA256_Update(&hash_ctx, kSharedKey, sizeof(kSharedKey));
  SHA256_Update(&hash_ctx, m_bytes, sizeof(m_bytes));
  SHA256_Update(&hash_ctx, r_bytes, sizeof(r_bytes));
  SHA256_Update(&hash_ctx, out_ciphertext, POLY_BYTES);
  SHA256_Final(out_shared_key, &hash_ctx);
}
2064
// HRSS_decap decapsulates |ciphertext|. The shared key is first set to
// HMAC(hmac_key, ciphertext); only if decryption and the re-encryption checks
// all succeed is it replaced, in constant time, by the real shared key. Thus
// invalid ciphertexts are rejected implicitly and this function is infallible
// from the caller's point of view.
void HRSS_decap(uint8_t out_shared_key[HRSS_KEY_BYTES],
                const struct HRSS_private_key *in_priv,
                const uint8_t *ciphertext, size_t ciphertext_len) {
  const struct private_key *priv =
      private_key_from_external((struct HRSS_private_key *)in_priv);

  // This is HMAC, expanded inline rather than using the |HMAC| function so that
  // we can avoid dealing with possible allocation failures and so keep this
  // function infallible.
  uint8_t masked_key[SHA256_CBLOCK];
  OPENSSL_STATIC_ASSERT(sizeof(priv->hmac_key) <= sizeof(masked_key),
                        "HRSS HMAC key larger than SHA-256 block size");
  for (size_t i = 0; i < sizeof(priv->hmac_key); i++) {
    masked_key[i] = priv->hmac_key[i] ^ 0x36;
  }
  OPENSSL_memset(masked_key + sizeof(priv->hmac_key), 0x36,
                 sizeof(masked_key) - sizeof(priv->hmac_key));

  SHA256_CTX hash_ctx;
  SHA256_Init(&hash_ctx);
  SHA256_Update(&hash_ctx, masked_key, sizeof(masked_key));
  SHA256_Update(&hash_ctx, ciphertext, ciphertext_len);
  uint8_t inner_digest[SHA256_DIGEST_LENGTH];
  SHA256_Final(inner_digest, &hash_ctx);

  // Re-mask the key for the outer hash: 0x36 ^ (0x5c ^ 0x36) == 0x5c.
  for (size_t i = 0; i < sizeof(priv->hmac_key); i++) {
    masked_key[i] ^= (0x5c ^ 0x36);
  }
  OPENSSL_memset(masked_key + sizeof(priv->hmac_key), 0x5c,
                 sizeof(masked_key) - sizeof(priv->hmac_key));

  SHA256_Init(&hash_ctx);
  SHA256_Update(&hash_ctx, masked_key, sizeof(masked_key));
  SHA256_Update(&hash_ctx, inner_digest, sizeof(inner_digest));
  OPENSSL_STATIC_ASSERT(HRSS_KEY_BYTES == SHA256_DIGEST_LENGTH,
                        "HRSS shared key length incorrect");
  SHA256_Final(out_shared_key, &hash_ctx);

  struct poly c;
  // If the ciphertext is publicly invalid then a random shared key is still
  // returned to simplify the logic of the caller, but this path is not
  // constant time.
  if (ciphertext_len != HRSS_CIPHERTEXT_BYTES ||
      !poly_unmarshal(&c, ciphertext)) {
    return;
  }

  // Decrypt: m = (c × f) × f⁻¹ over GF(3).
  struct poly f, cf;
  struct poly3 cf3, m3;
  poly_from_poly3(&f, &priv->f);
  poly_mul(&cf, &c, &f);
  poly3_from_poly(&cf3, &cf);
  // Note that cf3 is not reduced mod Φ(N). That reduction is deferred.
  HRSS_poly3_mul(&m3, &cf3, &priv->f_inverse);

  struct poly m, m_lifted;
  poly_from_poly3(&m, &m3);
  poly_lift(&m_lifted, &m);

  // Recover r = (c - lift(m)) / ph.
  struct poly r;
  for (unsigned i = 0; i < N; i++) {
    r.v[i] = c.v[i] - m_lifted.v[i];
  }
  poly_mul(&r, &r, &priv->ph_inverse);
  poly_mod_phiN(&r);
  poly_clamp(&r);

  // The recovered |r| must be a valid short polynomial; otherwise the
  // ciphertext is rejected (implicitly, via |ok|).
  struct poly3 r3;
  crypto_word_t ok = poly3_from_poly_checked(&r3, &r);

  // [NTRUCOMP] section 5.1 includes ReEnc2 and a proof that it's valid. Rather
  // than do an expensive |poly_mul|, it rebuilds |c'| from |c - lift(m)|
  // (called |b|) with:
  //   t = (−b(1)/N) mod Q
  //   c' = b + tΦ(N) + lift(m) mod Q
  //
  // When polynomials are transmitted, the final coefficient is omitted and
  // |poly_unmarshal| sets it such that f(1) == 0. Thus c(1) == 0. Also,
  // |poly_lift| multiplies the result by (x-1) and therefore evaluating a
  // lifted polynomial at 1 is also zero. Thus lift(m)(1) == 0 and so
  // (c - lift(m))(1) == 0.
  //
  // Although we defer the reduction above, |b| is conceptually reduced mod
  // Φ(N). In order to do that reduction one subtracts |c[N-1]| from every
  // coefficient. Therefore b(1) = -c[N-1]×N. The value of |t|, above, then is
  // just recovering |c[N-1]|, and adding tΦ(N) is simply undoing the reduction.
  // Therefore b + tΦ(N) + lift(m) = c by construction and we don't need to
  // recover |c| at all so long as we do the checks in
  // |poly3_from_poly_checked|.
  //
  // The |poly_marshal| here then is just confirming that |poly_unmarshal| is
  // strict and could be omitted.

  uint8_t expected_ciphertext[HRSS_CIPHERTEXT_BYTES];
  OPENSSL_STATIC_ASSERT(HRSS_CIPHERTEXT_BYTES == POLY_BYTES,
                        "ciphertext is the wrong size");
  assert(ciphertext_len == sizeof(expected_ciphertext));
  poly_marshal(expected_ciphertext, &c);

  uint8_t m_bytes[HRSS_POLY3_BYTES];
  uint8_t r_bytes[HRSS_POLY3_BYTES];
  poly_marshal_mod3(m_bytes, &m);
  poly_marshal_mod3(r_bytes, &r);

  ok &= constant_time_is_zero_w(CRYPTO_memcmp(ciphertext, expected_ciphertext,
                                              sizeof(expected_ciphertext)));

  // Compute the real shared key exactly as |HRSS_encap| does.
  uint8_t shared_key[32];
  SHA256_Init(&hash_ctx);
  SHA256_Update(&hash_ctx, kSharedKey, sizeof(kSharedKey));
  SHA256_Update(&hash_ctx, m_bytes, sizeof(m_bytes));
  SHA256_Update(&hash_ctx, r_bytes, sizeof(r_bytes));
  SHA256_Update(&hash_ctx, expected_ciphertext, sizeof(expected_ciphertext));
  SHA256_Final(shared_key, &hash_ctx);

  // Constant-time select: the real shared key is returned only if every check
  // passed; otherwise the HMAC value computed above is left in place.
  for (unsigned i = 0; i < sizeof(shared_key); i++) {
    out_shared_key[i] =
        constant_time_select_8(ok, shared_key[i], out_shared_key[i]);
  }
}
2185
2186void HRSS_marshal_public_key(uint8_t out[HRSS_PUBLIC_KEY_BYTES],
2187 const struct HRSS_public_key *in_pub) {
2188 const struct public_key *pub =
2189 public_key_from_external((struct HRSS_public_key *)in_pub);
2190 poly_marshal(out, &pub->ph);
2191}
2192
2193int HRSS_parse_public_key(struct HRSS_public_key *out,
2194 const uint8_t in[HRSS_PUBLIC_KEY_BYTES]) {
2195 struct public_key *pub = public_key_from_external(out);
2196 if (!poly_unmarshal(&pub->ph, in)) {
2197 return 0;
2198 }
2199 OPENSSL_memset(&pub->ph.v[N], 0, 3 * sizeof(uint16_t));
2200 return 1;
2201}
2202