1/*
2 * This file is part of the MicroPython project, http://micropython.org/
3 *
4 * The MIT License (MIT)
5 *
6 * Copyright (c) 2013, 2014 Damien P. George
7 *
8 * Permission is hereby granted, free of charge, to any person obtaining a copy
9 * of this software and associated documentation files (the "Software"), to deal
10 * in the Software without restriction, including without limitation the rights
11 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 * copies of the Software, and to permit persons to whom the Software is
13 * furnished to do so, subject to the following conditions:
14 *
15 * The above copyright notice and this permission notice shall be included in
16 * all copies or substantial portions of the Software.
17 *
18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 * THE SOFTWARE.
25 */
26
27#include <string.h>
28#include <assert.h>
29
30#include "py/mpz.h"
31
32#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
33
// Number of bits stored in each digit.
#define DIG_SIZE (MPZ_DIG_SIZE)
// Mask selecting the low DIG_SIZE bits, ie the largest value a digit can hold.
#define DIG_MASK ((MPZ_LONG_1 << DIG_SIZE) - 1)
// The most significant bit of a digit.
#define DIG_MSB (MPZ_LONG_1 << (DIG_SIZE - 1))
// One more than the largest digit value, ie the base of the digit representation.
#define DIG_BASE (MPZ_LONG_1 << DIG_SIZE)
38
39/*
40 mpz is an arbitrary precision integer type with a public API.
41
42 mpn functions act on non-negative integers represented by an array of generalised
43 digits (eg a word per digit). You also need to specify separately the length of the
44 array. There is no public API for mpn. Rather, the functions are used by mpz to
45 implement its features.
46
47 Integer values are stored little endian (first digit is first in memory).
48
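   Definition of normalise: the most significant digit (the last one in memory) is
   non-zero, so the value zero is represented by a length of 0.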
50*/
51
52STATIC size_t mpn_remove_trailing_zeros(mpz_dig_t *oidig, mpz_dig_t *idig) {
53 for (--idig; idig >= oidig && *idig == 0; --idig) {
54 }
55 return idig + 1 - oidig;
56}
57
58/* compares i with j
59 returns sign(i - j)
60 assumes i, j are normalised
61*/
62STATIC int mpn_cmp(const mpz_dig_t *idig, size_t ilen, const mpz_dig_t *jdig, size_t jlen) {
63 if (ilen < jlen) {
64 return -1;
65 }
66 if (ilen > jlen) {
67 return 1;
68 }
69
70 for (idig += ilen, jdig += ilen; ilen > 0; --ilen) {
71 mpz_dbl_dig_signed_t cmp = (mpz_dbl_dig_t)*(--idig) - (mpz_dbl_dig_t)*(--jdig);
72 if (cmp < 0) {
73 return -1;
74 }
75 if (cmp > 0) {
76 return 1;
77 }
78 }
79
80 return 0;
81}
82
83/* computes i = j << n
84 returns number of digits in i
85 assumes enough memory in i; assumes normalised j; assumes n > 0
86 can have i, j pointing to same memory
87*/
88STATIC size_t mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, size_t jlen, mp_uint_t n) {
89 mp_uint_t n_whole = (n + DIG_SIZE - 1) / DIG_SIZE;
90 mp_uint_t n_part = n % DIG_SIZE;
91 if (n_part == 0) {
92 n_part = DIG_SIZE;
93 }
94
95 // start from the high end of the digit arrays
96 idig += jlen + n_whole - 1;
97 jdig += jlen - 1;
98
99 // shift the digits
100 mpz_dbl_dig_t d = 0;
101 for (size_t i = jlen; i > 0; i--, idig--, jdig--) {
102 d |= *jdig;
103 *idig = (d >> (DIG_SIZE - n_part)) & DIG_MASK;
104 d <<= DIG_SIZE;
105 }
106
107 // store remaining bits
108 *idig = (d >> (DIG_SIZE - n_part)) & DIG_MASK;
109 idig -= n_whole - 1;
110 memset(idig, 0, (n_whole - 1) * sizeof(mpz_dig_t));
111
112 // work out length of result
113 jlen += n_whole;
114 while (jlen != 0 && idig[jlen - 1] == 0) {
115 jlen--;
116 }
117
118 // return length of result
119 return jlen;
120}
121
122/* computes i = j >> n
123 returns number of digits in i
124 assumes enough memory in i; assumes normalised j; assumes n > 0
125 can have i, j pointing to same memory
126*/
127STATIC size_t mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, size_t jlen, mp_uint_t n) {
128 mp_uint_t n_whole = n / DIG_SIZE;
129 mp_uint_t n_part = n % DIG_SIZE;
130
131 if (n_whole >= jlen) {
132 return 0;
133 }
134
135 jdig += n_whole;
136 jlen -= n_whole;
137
138 for (size_t i = jlen; i > 0; i--, idig++, jdig++) {
139 mpz_dbl_dig_t d = *jdig;
140 if (i > 1) {
141 d |= (mpz_dbl_dig_t)jdig[1] << DIG_SIZE;
142 }
143 d >>= n_part;
144 *idig = d & DIG_MASK;
145 }
146
147 if (idig[-1] == 0) {
148 jlen--;
149 }
150
151 return jlen;
152}
153
154/* computes i = j + k
155 returns number of digits in i
156 assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
157 can have i, j, k pointing to same memory
158*/
159STATIC size_t mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen) {
160 mpz_dig_t *oidig = idig;
161 mpz_dbl_dig_t carry = 0;
162
163 jlen -= klen;
164
165 for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
166 carry += (mpz_dbl_dig_t)*jdig + (mpz_dbl_dig_t)*kdig;
167 *idig = carry & DIG_MASK;
168 carry >>= DIG_SIZE;
169 }
170
171 for (; jlen > 0; --jlen, ++idig, ++jdig) {
172 carry += *jdig;
173 *idig = carry & DIG_MASK;
174 carry >>= DIG_SIZE;
175 }
176
177 if (carry != 0) {
178 *idig++ = carry;
179 }
180
181 return idig - oidig;
182}
183
184/* computes i = j - k
185 returns number of digits in i
186 assumes enough memory in i; assumes normalised j, k; assumes j >= k
187 can have i, j, k pointing to same memory
188*/
189STATIC size_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen) {
190 mpz_dig_t *oidig = idig;
191 mpz_dbl_dig_signed_t borrow = 0;
192
193 jlen -= klen;
194
195 for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
196 borrow += (mpz_dbl_dig_t)*jdig - (mpz_dbl_dig_t)*kdig;
197 *idig = borrow & DIG_MASK;
198 borrow >>= DIG_SIZE;
199 }
200
201 for (; jlen > 0; --jlen, ++idig, ++jdig) {
202 borrow += *jdig;
203 *idig = borrow & DIG_MASK;
204 borrow >>= DIG_SIZE;
205 }
206
207 return mpn_remove_trailing_zeros(oidig, idig);
208}
209
210#if MICROPY_OPT_MPZ_BITWISE
211
212/* computes i = j & k
213 returns number of digits in i
214 assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen (jlen argument not needed)
215 can have i, j, k pointing to same memory
216*/
217STATIC size_t mpn_and(mpz_dig_t *idig, const mpz_dig_t *jdig, const mpz_dig_t *kdig, size_t klen) {
218 mpz_dig_t *oidig = idig;
219
220 for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
221 *idig = *jdig & *kdig;
222 }
223
224 return mpn_remove_trailing_zeros(oidig, idig);
225}
226
227#endif
228
/* bitwise AND where at least one operand may be negative, working on magnitudes
   i = -((-j) & (-k))       = ~((~j + 1) & (~k + 1)) + 1
   i =  (j & (-k))          =  (j & (~k + 1))        =  (  j       &  (~k + 1))
   i =  ((-j) & k)          =  ((~j + 1) & k)        =  ((~j + 1)  &    k     )
   computes general form:
   i = (im ^ (((j ^ jm) + jc) & ((k ^ km) + kc))) + ic  where Xm = Xc == 0 ? 0 : DIG_MASK
   carryi, carryj, carryk are the initial carries ic, jc, kc; a non-zero carry also
   selects inversion (mask DIG_MASK) for the corresponding operand
   returns number of digits in i
   assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
   can have i, j, k pointing to same memory
*/
STATIC size_t mpn_and_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen,
    mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
    mpz_dig_t *oidig = idig;
    // a non-zero initial carry means the operand is inverted digit by digit
    mpz_dig_t imask = (0 == carryi) ? 0 : DIG_MASK;
    mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
    mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;

    for (; jlen > 0; ++idig, ++jdig) {
        carryj += *jdig ^ jmask;
        // Both lengths are decremented here (this is also the loop's jlen step).
        // Once k runs out, klen wraps around to a huge value, the comparison
        // fails, and k is treated as a zero digit (which inverts to kmask).
        carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
        carryi += ((carryj & carryk) ^ imask) & DIG_MASK;
        *idig = carryi & DIG_MASK;
        // keep only the carry-out of each accumulator for the next digit
        carryk >>= DIG_SIZE;
        carryj >>= DIG_SIZE;
        carryi >>= DIG_SIZE;
    }

    // a leftover carry on the result extends it by one digit
    if (0 != carryi) {
        *idig++ = carryi;
    }

    return mpn_remove_trailing_zeros(oidig, idig);
}
261
262#if MICROPY_OPT_MPZ_BITWISE
263
264/* computes i = j | k
265 returns number of digits in i
266 assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
267 can have i, j, k pointing to same memory
268*/
269STATIC size_t mpn_or(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen) {
270 mpz_dig_t *oidig = idig;
271
272 jlen -= klen;
273
274 for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
275 *idig = *jdig | *kdig;
276 }
277
278 for (; jlen > 0; --jlen, ++idig, ++jdig) {
279 *idig = *jdig;
280 }
281
282 return idig - oidig;
283}
284
285#endif
286
287/* i = -((-j) | (-k)) = ~((~j + 1) | (~k + 1)) + 1
288 i = -(j | (-k)) = -(j | (~k + 1)) = ~( j | (~k + 1)) + 1
289 i = -((-j) | k) = -((~j + 1) | k) = ~((~j + 1) | k ) + 1
290 computes general form:
291 i = ~(((j ^ jm) + jc) | ((k ^ km) + kc)) + 1 where Xm = Xc == 0 ? 0 : DIG_MASK
292 returns number of digits in i
293 assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
294 can have i, j, k pointing to same memory
295*/
296
297#if MICROPY_OPT_MPZ_BITWISE
298
// Variant of mpn_or_neg built when MICROPY_OPT_MPZ_BITWISE is enabled: the
// result carry always starts at 1 and the result is always inverted, so only
// the carries for j and k are passed in. A non-zero carryj/carryk selects
// inversion (mask DIG_MASK) of the corresponding operand's digits.
// Returns the number of digits written to idig after normalisation.
STATIC size_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen,
    mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
    mpz_dig_t *oidig = idig;
    mpz_dbl_dig_t carryi = 1;
    mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
    mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;

    for (; jlen > 0; ++idig, ++jdig) {
        carryj += *jdig ^ jmask;
        // Both lengths are decremented here (this is also the loop's jlen step).
        // Once k runs out, klen wraps around to a huge value, the comparison
        // fails, and k is treated as a zero digit (which inverts to kmask).
        carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
        carryi += ((carryj | carryk) ^ DIG_MASK) & DIG_MASK;
        *idig = carryi & DIG_MASK;
        // keep only the carry-out of each accumulator for the next digit
        carryk >>= DIG_SIZE;
        carryj >>= DIG_SIZE;
        carryi >>= DIG_SIZE;
    }

    // At least one of j,k must be negative so the above for-loop runs at least
    // once.  For carryi to be non-zero here it must be equal to 1 at the end of
    // each iteration of the loop.  So the accumulation of carryi must overflow
    // each time, ie carryi += 0xff..ff.  So carryj|carryk must be 0 in the
    // DIG_MASK bits on each iteration.  But considering all cases of signs of
    // j,k one sees that this is not possible.
    assert(carryi == 0);

    return mpn_remove_trailing_zeros(oidig, idig);
}
326
327#else
328
// Variant of mpn_or_neg built when MICROPY_OPT_MPZ_BITWISE is disabled: the
// initial result carry is supplied by the caller, and a non-zero carryi,
// carryj or carryk selects inversion (mask DIG_MASK) of the result or of the
// corresponding operand's digits.
// Returns the number of digits written to idig after normalisation.
STATIC size_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen,
    mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
    mpz_dig_t *oidig = idig;
    mpz_dig_t imask = (0 == carryi) ? 0 : DIG_MASK;
    mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
    mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;

    for (; jlen > 0; ++idig, ++jdig) {
        carryj += *jdig ^ jmask;
        // Both lengths are decremented here (this is also the loop's jlen step).
        // Once k runs out, klen wraps around to a huge value, the comparison
        // fails, and k is treated as a zero digit (which inverts to kmask).
        carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
        carryi += ((carryj | carryk) ^ imask) & DIG_MASK;
        *idig = carryi & DIG_MASK;
        // keep only the carry-out of each accumulator for the next digit
        carryk >>= DIG_SIZE;
        carryj >>= DIG_SIZE;
        carryi >>= DIG_SIZE;
    }

    // See comment in above mpn_or_neg for why carryi must be 0.
    assert(carryi == 0);

    return mpn_remove_trailing_zeros(oidig, idig);
}
351
352#endif
353
354#if MICROPY_OPT_MPZ_BITWISE
355
356/* computes i = j ^ k
357 returns number of digits in i
358 assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
359 can have i, j, k pointing to same memory
360*/
361STATIC size_t mpn_xor(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen) {
362 mpz_dig_t *oidig = idig;
363
364 jlen -= klen;
365
366 for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
367 *idig = *jdig ^ *kdig;
368 }
369
370 for (; jlen > 0; --jlen, ++idig, ++jdig) {
371 *idig = *jdig;
372 }
373
374 return mpn_remove_trailing_zeros(oidig, idig);
375}
376
377#endif
378
/* bitwise XOR where at least one operand may be negative, working on magnitudes
   i = (-j) ^ (-k)    = ~(j - 1) ^ ~(k - 1)                   = (j - 1) ^ (k - 1)
   i = -(j ^ (-k))    = -(j ^ ~(k - 1)) = ~(j ^ ~(k - 1)) + 1 = (j ^ (k - 1)) + 1
   i = -((-j) ^ k)    = -(~(j - 1) ^ k) = ~(~(j - 1) ^ k) + 1 = ((j - 1) ^ k) + 1
   computes general form:
   i = ((j - 1 + jc) ^ (k - 1 + kc)) + ic
   carryi, carryj, carryk are the initial carries ic, jc, kc
   returns number of digits in i
   assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
   can have i, j, k pointing to same memory
*/
STATIC size_t mpn_xor_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen,
    mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
    mpz_dig_t *oidig = idig;

    for (; jlen > 0; ++idig, ++jdig) {
        // adding DIG_MASK is the same as subtracting 1 modulo the digit base;
        // the bits above DIG_SIZE carry over to the next digit
        carryj += *jdig + DIG_MASK;
        // Both lengths are decremented here (this is also the loop's jlen step).
        // Once k runs out, klen wraps around to a huge value, the comparison
        // fails, and k is treated as a zero digit.
        carryk += (--klen <= --jlen) ? (*kdig++ + DIG_MASK) : DIG_MASK;
        carryi += (carryj ^ carryk) & DIG_MASK;
        *idig = carryi & DIG_MASK;
        // keep only the carry-out of each accumulator for the next digit
        carryk >>= DIG_SIZE;
        carryj >>= DIG_SIZE;
        carryi >>= DIG_SIZE;
    }

    // a leftover carry on the result extends it by one digit
    if (0 != carryi) {
        *idig++ = carryi;
    }

    return mpn_remove_trailing_zeros(oidig, idig);
}
408
409/* computes i = i * d1 + d2
410 returns number of digits in i
411 assumes enough memory in i; assumes normalised i; assumes dmul != 0
412*/
413STATIC size_t mpn_mul_dig_add_dig(mpz_dig_t *idig, size_t ilen, mpz_dig_t dmul, mpz_dig_t dadd) {
414 mpz_dig_t *oidig = idig;
415 mpz_dbl_dig_t carry = dadd;
416
417 for (; ilen > 0; --ilen, ++idig) {
418 carry += (mpz_dbl_dig_t)*idig * (mpz_dbl_dig_t)dmul; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2
419 *idig = carry & DIG_MASK;
420 carry >>= DIG_SIZE;
421 }
422
423 if (carry != 0) {
424 *idig++ = carry;
425 }
426
427 return idig - oidig;
428}
429
430/* computes i = j * k
431 returns number of digits in i
432 assumes enough memory in i; assumes i is zeroed; assumes normalised j, k
433 can have j, k point to same memory
434*/
435STATIC size_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, size_t jlen, mpz_dig_t *kdig, size_t klen) {
436 mpz_dig_t *oidig = idig;
437 size_t ilen = 0;
438
439 for (; klen > 0; --klen, ++idig, ++kdig) {
440 mpz_dig_t *id = idig;
441 mpz_dbl_dig_t carry = 0;
442
443 size_t jl = jlen;
444 for (mpz_dig_t *jd = jdig; jl > 0; --jl, ++jd, ++id) {
445 carry += (mpz_dbl_dig_t)*id + (mpz_dbl_dig_t)*jd * (mpz_dbl_dig_t)*kdig; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2
446 *id = carry & DIG_MASK;
447 carry >>= DIG_SIZE;
448 }
449
450 if (carry != 0) {
451 *id++ = carry;
452 }
453
454 ilen = id - oidig;
455 }
456
457 return ilen;
458}
459
/* natural_div - quo * den + new_num = old_num (ie num is replaced with rem)
   num_dig/num_len: numerator on entry, remainder on exit (length updated in place)
   den_dig/den_len: denominator, read-only
   quo_dig/quo_len: receives the quotient and its length
   assumes den != 0
   assumes num_dig has enough memory to be extended by 1 digit
   assumes quo_dig has enough memory (as many digits as num)
   assumes quo_dig is filled with zeros
*/
STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_dig, size_t den_len, mpz_dig_t *quo_dig, size_t *quo_len) {
    // remember the array starts; num_dig and quo_dig are moved during the division
    mpz_dig_t *orig_num_dig = num_dig;
    mpz_dig_t *orig_quo_dig = quo_dig;
    mpz_dig_t norm_shift = 0;
    mpz_dbl_dig_t lead_den_digit;

    // handle simple cases
    {
        int cmp = mpn_cmp(num_dig, *num_len, den_dig, den_len);
        if (cmp == 0) {
            // num == den: quotient is 1 and remainder is 0
            *num_len = 0;
            quo_dig[0] = 1;
            *quo_len = 1;
            return;
        } else if (cmp < 0) {
            // numerator remains the same
            *quo_len = 0;
            return;
        }
    }

    // We need to normalise the denominator (leading bit of leading digit is 1)
    // so that the division routine works.  Since the denominator memory is
    // read-only we do the normalisation on the fly, each time a digit of the
    // denominator is needed.  We need to know is how many bits to shift by.

    // count number of leading zeros in leading digit of denominator
    {
        mpz_dig_t d = den_dig[den_len - 1];
        while ((d & DIG_MSB) == 0) {
            d <<= 1;
            ++norm_shift;
        }
    }

    // now need to shift numerator by same amount as denominator
    // first, increase length of numerator in case we need more room to shift
    num_dig[*num_len] = 0;
    ++(*num_len);
    // shift the numerator left by norm_shift bits, from the least significant
    // digit upwards, carrying the bits shifted out of each digit into the next
    for (mpz_dig_t *num = num_dig, carry = 0; num < num_dig + *num_len; ++num) {
        mpz_dig_t n = *num;
        *num = ((n << norm_shift) | carry) & DIG_MASK;
        carry = (mpz_dbl_dig_t)n >> (DIG_SIZE - norm_shift);
    }

    // cache the leading digit of the denominator
    lead_den_digit = (mpz_dbl_dig_t)den_dig[den_len - 1] << norm_shift;
    if (den_len >= 2) {
        // bring in the top bits of the next digit down, as shifted in by normalisation
        lead_den_digit |= (mpz_dbl_dig_t)den_dig[den_len - 2] >> (DIG_SIZE - norm_shift);
    }

    // point num_dig to last digit in numerator
    num_dig += *num_len - 1;

    // calculate number of digits in quotient
    *quo_len = *num_len - den_len;

    // point to last digit to store for quotient
    quo_dig += *quo_len - 1;

    // keep going while we have enough digits to divide
    while (*num_len > den_len) {
        // the two most significant remaining digits of the numerator
        mpz_dbl_dig_t quo = ((mpz_dbl_dig_t)*num_dig << DIG_SIZE) | num_dig[-1];

        // get approximate quotient
        quo /= lead_den_digit;

        // Multiply quo by den and subtract from num to get remainder.
        // Must be careful with overflow of the borrow variable.  Both
        // borrow and low_digs are signed values and need signed right-shift,
        // but x is unsigned and may take a full-range value.
        const mpz_dig_t *d = den_dig;
        mpz_dbl_dig_t d_norm = 0;
        mpz_dbl_dig_signed_t borrow = 0;
        for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
            // Get the next digit in (den).
            d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
            // Multiply the next digit in (quo * den).
            mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
            // Compute the low DIG_MASK bits of the next digit in (num - quo * den)
            mpz_dbl_dig_signed_t low_digs = (borrow & DIG_MASK) + *n - (x & DIG_MASK);
            // Store the digit result for (num).
            *n = low_digs & DIG_MASK;
            // Compute the borrow, shifted right before summing to avoid overflow.
            borrow = (borrow >> DIG_SIZE) - (x >> DIG_SIZE) + (low_digs >> DIG_SIZE);
        }

        // At this point we have either:
        //
        //   1. quo was the correct value and the most-sig-digit of num is exactly
        //      cancelled by borrow (borrow + *num_dig == 0).  In this case there is
        //      nothing more to do.
        //
        //   2. quo was too large, we subtracted too many den from num, and the
        //      most-sig-digit of num is less than needed (borrow + *num_dig < 0).
        //      In this case we must reduce quo and add back den to num until the
        //      carry from this operation cancels out the borrow.
        //
        borrow += *num_dig;
        for (; borrow != 0; --quo) {
            // add one (normalised) den back into num; each pass lowers quo by 1
            d = den_dig;
            d_norm = 0;
            mpz_dbl_dig_t carry = 0;
            for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
                d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
                carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
                *n = carry & DIG_MASK;
                carry >>= DIG_SIZE;
            }
            borrow += carry;
        }

        // store this digit of the quotient
        *quo_dig = quo & DIG_MASK;
        --quo_dig;

        // move down to next digit of numerator
        --num_dig;
        --(*num_len);
    }

    // unnormalise numerator (remainder now)
    // shift right by norm_shift bits, from the most significant digit downwards
    for (mpz_dig_t *num = orig_num_dig + *num_len - 1, carry = 0; num >= orig_num_dig; --num) {
        mpz_dig_t n = *num;
        *num = ((n >> norm_shift) | carry) & DIG_MASK;
        carry = (mpz_dbl_dig_t)n << (DIG_SIZE - norm_shift);
    }

    // strip trailing zeros

    while (*quo_len > 0 && orig_quo_dig[*quo_len - 1] == 0) {
        --(*quo_len);
    }

    while (*num_len > 0 && orig_num_dig[*num_len - 1] == 0) {
        --(*num_len);
    }
}
604
// Smallest number of digits that mpz_need_dig will ever request for a buffer.
#define MIN_ALLOC (2)
606
607void mpz_init_zero(mpz_t *z) {
608 z->neg = 0;
609 z->fixed_dig = 0;
610 z->alloc = 0;
611 z->len = 0;
612 z->dig = NULL;
613}
614
615void mpz_init_from_int(mpz_t *z, mp_int_t val) {
616 mpz_init_zero(z);
617 mpz_set_from_int(z, val);
618}
619
620void mpz_init_fixed_from_int(mpz_t *z, mpz_dig_t *dig, size_t alloc, mp_int_t val) {
621 z->neg = 0;
622 z->fixed_dig = 1;
623 z->alloc = alloc;
624 z->len = 0;
625 z->dig = dig;
626 mpz_set_from_int(z, val);
627}
628
629void mpz_deinit(mpz_t *z) {
630 if (z != NULL && !z->fixed_dig) {
631 m_del(mpz_dig_t, z->dig, z->alloc);
632 }
633}
634
#if 0
these functions are unused

// Allocates a new mpz object and initialises it to zero.
mpz_t *mpz_zero(void) {
    mpz_t *z = m_new_obj(mpz_t);
    mpz_init_zero(z);
    return z;
}

// Allocates a new mpz object holding the value val.
mpz_t *mpz_from_int(mp_int_t val) {
    mpz_t *z = mpz_zero();
    mpz_set_from_int(z, val);
    return z;
}

// Allocates a new mpz object holding val; is_signed is passed through to
// mpz_set_from_ll.
mpz_t *mpz_from_ll(long long val, bool is_signed) {
    mpz_t *z = mpz_zero();
    mpz_set_from_ll(z, val, is_signed);
    return z;
}

#if MICROPY_PY_BUILTINS_FLOAT
// Allocates a new mpz object set from the float val via mpz_set_from_float.
mpz_t *mpz_from_float(mp_float_t val) {
    mpz_t *z = mpz_zero();
    mpz_set_from_float(z, val);
    return z;
}
#endif

// Allocates a new mpz object set from the given string via mpz_set_from_str.
mpz_t *mpz_from_str(const char *str, size_t len, bool neg, unsigned int base) {
    mpz_t *z = mpz_zero();
    mpz_set_from_str(z, str, len, neg, base);
    return z;
}
#endif
670
671STATIC void mpz_free(mpz_t *z) {
672 if (z != NULL) {
673 m_del(mpz_dig_t, z->dig, z->alloc);
674 m_del_obj(mpz_t, z);
675 }
676}
677
678STATIC void mpz_need_dig(mpz_t *z, size_t need) {
679 if (need < MIN_ALLOC) {
680 need = MIN_ALLOC;
681 }
682
683 if (z->dig == NULL || z->alloc < need) {
684 // if z has fixed digit buffer there's not much we can do as the caller will
685 // be expecting a buffer with at least "need" bytes (but it shouldn't happen)
686 assert(!z->fixed_dig);
687 z->dig = m_renew(mpz_dig_t, z->dig, z->alloc, need);
688 z->alloc = need;
689 }
690}
691
692STATIC mpz_t *mpz_clone(const mpz_t *src) {
693 assert(src->alloc != 0);
694 mpz_t *z = m_new_obj(mpz_t);
695 z->neg = src->neg;
696 z->fixed_dig = 0;
697 z->alloc = src->alloc;
698 z->len = src->len;
699 z->dig = m_new(mpz_dig_t, z->alloc);
700 memcpy(z->dig, src->dig, src->alloc * sizeof(mpz_dig_t));
701 return z;
702}
703
704/* sets dest = src
705 can have dest, src the same
706*/
707void mpz_set(mpz_t *dest, const mpz_t *src) {
708 mpz_need_dig(dest, src->len);
709 dest->neg = src->neg;
710 dest->len = src->len;
711 memcpy(dest->dig, src->dig, src->len * sizeof(mpz_dig_t));
712}
713
714void mpz_set_from_int(mpz_t *z, mp_int_t val) {
715 if (val == 0) {
716 z->len = 0;
717 return;
718 }
719
720 mpz_need_dig(z, MPZ_NUM_DIG_FOR_INT);
721
722 mp_uint_t uval;
723 if (val < 0) {
724 z->neg = 1;
725 uval = -val;
726 } else {
727 z->neg = 0;
728 uval = val;
729 }
730
731 z->len = 0;
732 while (uval > 0) {
733 z->dig[z->len++] = uval & DIG_MASK;
734 uval >>= DIG_SIZE;
735 }
736}
737
738void mpz_set_from_ll(mpz_t *z, long long val, bool is_signed) {
739 mpz_need_dig(z, MPZ_NUM_DIG_FOR_LL);
740
741 unsigned long long uval;
742 if (is_signed && val < 0) {
743 z->neg = 1;
744 uval = -val;
745 } else {
746 z->neg = 0;
747 uval = val;
748 }
749
750 z->len = 0;
751 while (uval > 0) {
752 z->dig[z->len++] = uval & DIG_MASK;
753 uval >>= DIG_SIZE;
754 }
755}
756
757#if MICROPY_PY_BUILTINS_FLOAT
758void mpz_set_from_float(mpz_t *z, mp_float_t src) {
759 mp_float_union_t u = {src};
760 z->neg = u.p.sgn;
761 if (u.p.exp == 0) {
762 // value == 0 || value < 1
763 mpz_set_from_int(z, 0);
764 } else if (u.p.exp == ((1 << MP_FLOAT_EXP_BITS) - 1)) {
765 // u.p.frc == 0 indicates inf, else NaN
766 // should be handled by caller
767 mpz_set_from_int(z, 0);
768 } else {
769 const int adj_exp = (int)u.p.exp - MP_FLOAT_EXP_BIAS;
770 if (adj_exp < 0) {
771 // value < 1 , truncates to 0
772 mpz_set_from_int(z, 0);
773 } else if (adj_exp == 0) {
774 // 1 <= value < 2 , so truncates to 1
775 mpz_set_from_int(z, 1);
776 } else {
777 // 2 <= value
778 const int dig_cnt = (adj_exp + 1 + (DIG_SIZE - 1)) / DIG_SIZE;
779 const unsigned int rem = adj_exp % DIG_SIZE;
780 int dig_ind, shft;
781 mp_float_uint_t frc = u.p.frc | ((mp_float_uint_t)1 << MP_FLOAT_FRAC_BITS);
782
783 if (adj_exp < MP_FLOAT_FRAC_BITS) {
784 shft = 0;
785 dig_ind = 0;
786 frc >>= MP_FLOAT_FRAC_BITS - adj_exp;
787 } else {
788 shft = (rem - MP_FLOAT_FRAC_BITS) % DIG_SIZE;
789 dig_ind = (adj_exp - MP_FLOAT_FRAC_BITS) / DIG_SIZE;
790 }
791 mpz_need_dig(z, dig_cnt);
792 z->len = dig_cnt;
793 if (dig_ind != 0) {
794 memset(z->dig, 0, dig_ind * sizeof(mpz_dig_t));
795 }
796 if (shft != 0) {
797 z->dig[dig_ind++] = (frc << shft) & DIG_MASK;
798 frc >>= DIG_SIZE - shft;
799 }
800 #if DIG_SIZE < (MP_FLOAT_FRAC_BITS + 1)
801 while (dig_ind != dig_cnt) {
802 z->dig[dig_ind++] = frc & DIG_MASK;
803 frc >>= DIG_SIZE;
804 }
805 #else
806 if (dig_ind != dig_cnt) {
807 z->dig[dig_ind] = frc;
808 }
809 #endif
810 }
811 }
812}
813#endif
814
815// returns number of bytes from str that were processed
816size_t mpz_set_from_str(mpz_t *z, const char *str, size_t len, bool neg, unsigned int base) {
817 assert(base <= 36);
818
819 const char *cur = str;
820 const char *top = str + len;
821
822 mpz_need_dig(z, len * 8 / DIG_SIZE + 1);
823
824 if (neg) {
825 z->neg = 1;
826 } else {
827 z->neg = 0;
828 }
829
830 z->len = 0;
831 for (; cur < top; ++cur) { // XXX UTF8 next char
832 // mp_uint_t v = char_to_numeric(cur#); // XXX UTF8 get char
833 mp_uint_t v = *cur;
834 if ('0' <= v && v <= '9') {
835 v -= '0';
836 } else if ('A' <= v && v <= 'Z') {
837 v -= 'A' - 10;
838 } else if ('a' <= v && v <= 'z') {
839 v -= 'a' - 10;
840 } else {
841 break;
842 }
843 if (v >= base) {
844 break;
845 }
846 z->len = mpn_mul_dig_add_dig(z->dig, z->len, base, v);
847 }
848
849 return cur - str;
850}
851
852void mpz_set_from_bytes(mpz_t *z, bool big_endian, size_t len, const byte *buf) {
853 int delta = 1;
854 if (big_endian) {
855 buf += len - 1;
856 delta = -1;
857 }
858
859 mpz_need_dig(z, (len * 8 + DIG_SIZE - 1) / DIG_SIZE);
860
861 mpz_dig_t d = 0;
862 int num_bits = 0;
863 z->neg = 0;
864 z->len = 0;
865 while (len) {
866 while (len && num_bits < DIG_SIZE) {
867 d |= *buf << num_bits;
868 num_bits += 8;
869 buf += delta;
870 len--;
871 }
872 z->dig[z->len++] = d & DIG_MASK;
873 // Need this #if because it's C undefined behavior to do: uint32_t >> 32
874 #if DIG_SIZE != 8 && DIG_SIZE != 16 && DIG_SIZE != 32
875 d >>= DIG_SIZE;
876 #else
877 d = 0;
878 #endif
879 num_bits -= DIG_SIZE;
880 }
881
882 z->len = mpn_remove_trailing_zeros(z->dig, z->dig + z->len);
883}
884
#if 0
these functions are unused

// True if z is strictly greater than zero.
bool mpz_is_pos(const mpz_t *z) {
    return z->len > 0 && z->neg == 0;
}

// True if z is non-zero and its lowest bit is set.
bool mpz_is_odd(const mpz_t *z) {
    return z->len > 0 && (z->dig[0] & 1) != 0;
}

// True if z is zero or its lowest bit is clear.
bool mpz_is_even(const mpz_t *z) {
    return z->len == 0 || (z->dig[0] & 1) == 0;
}
#endif
900
901int mpz_cmp(const mpz_t *z1, const mpz_t *z2) {
902 // to catch comparison of -0 with +0
903 if (z1->len == 0 && z2->len == 0) {
904 return 0;
905 }
906 int cmp = (int)z2->neg - (int)z1->neg;
907 if (cmp != 0) {
908 return cmp;
909 }
910 cmp = mpn_cmp(z1->dig, z1->len, z2->dig, z2->len);
911 if (z1->neg != 0) {
912 cmp = -cmp;
913 }
914 return cmp;
915}
916
#if 0
// obsolete
// compares mpz with an integer that fits within DIG_SIZE bits
// returns -1, 0 or 1 according to whether z is less than, equal to or greater
// than sml_int
mp_int_t mpz_cmp_sml_int(const mpz_t *z, mp_int_t sml_int) {
    mp_int_t cmp;
    if (z->neg == 0) {
        // z is flagged non-negative
        if (sml_int < 0) {
            return 1;
        }
        if (sml_int == 0) {
            if (z->len == 0) {
                return 0;
            }
            return 1;
        }
        if (z->len == 0) {
            return -1;
        }
        assert(sml_int < (1 << DIG_SIZE));
        // more than one digit means z exceeds anything that fits in a digit
        if (z->len != 1) {
            return 1;
        }
        cmp = z->dig[0] - sml_int;
    } else {
        // z is flagged negative
        if (sml_int > 0) {
            return -1;
        }
        if (sml_int == 0) {
            if (z->len == 0) {
                return 0;
            }
            return -1;
        }
        if (z->len == 0) {
            return 1;
        }
        assert(sml_int > -(1 << DIG_SIZE));
        // more than one digit means z is below anything that fits in a digit
        if (z->len != 1) {
            return -1;
        }
        cmp = -z->dig[0] - sml_int;
    }
    // reduce the difference to its sign
    if (cmp < 0) {
        return -1;
    }
    if (cmp > 0) {
        return 1;
    }
    return 0;
}
#endif
968
#if 0
these functions are unused

/* returns abs(z)
   the result is a newly allocated clone of z
*/
mpz_t *mpz_abs(const mpz_t *z) {
    // TODO: handle case of z->alloc=0
    mpz_t *z2 = mpz_clone(z);
    z2->neg = 0;
    return z2;
}

/* returns -z
   the result is a newly allocated clone of z
*/
mpz_t *mpz_neg(const mpz_t *z) {
    // TODO: handle case of z->alloc=0
    mpz_t *z2 = mpz_clone(z);
    z2->neg = 1 - z2->neg;
    return z2;
}

/* returns lhs + rhs
   the result is a newly allocated mpz
   can have lhs, rhs the same
*/
mpz_t *mpz_add(const mpz_t *lhs, const mpz_t *rhs) {
    mpz_t *z = mpz_zero();
    mpz_add_inpl(z, lhs, rhs);
    return z;
}

/* returns lhs - rhs
   the result is a newly allocated mpz
   can have lhs, rhs the same
*/
mpz_t *mpz_sub(const mpz_t *lhs, const mpz_t *rhs) {
    mpz_t *z = mpz_zero();
    mpz_sub_inpl(z, lhs, rhs);
    return z;
}

/* returns lhs * rhs
   the result is a newly allocated mpz
   can have lhs, rhs the same
*/
mpz_t *mpz_mul(const mpz_t *lhs, const mpz_t *rhs) {
    mpz_t *z = mpz_zero();
    mpz_mul_inpl(z, lhs, rhs);
    return z;
}

/* returns lhs ** rhs
   the result is a newly allocated mpz
   can have lhs, rhs the same
*/
mpz_t *mpz_pow(const mpz_t *lhs, const mpz_t *rhs) {
    mpz_t *z = mpz_zero();
    mpz_pow_inpl(z, lhs, rhs);
    return z;
}

/* computes new integers in quo and rem such that:
       quo * rhs + rem = lhs
       0 <= rem < rhs
   *quo and *rem are set to newly allocated mpz objects
   can have lhs, rhs the same
*/
void mpz_divmod(const mpz_t *lhs, const mpz_t *rhs, mpz_t **quo, mpz_t **rem) {
    *quo = mpz_zero();
    *rem = mpz_zero();
    mpz_divmod_inpl(*quo, *rem, lhs, rhs);
}
#endif
1037
1038/* computes dest = abs(z)
1039 can have dest, z the same
1040*/
1041void mpz_abs_inpl(mpz_t *dest, const mpz_t *z) {
1042 if (dest != z) {
1043 mpz_set(dest, z);
1044 }
1045 dest->neg = 0;
1046}
1047
1048/* computes dest = -z
1049 can have dest, z the same
1050*/
1051void mpz_neg_inpl(mpz_t *dest, const mpz_t *z) {
1052 if (dest != z) {
1053 mpz_set(dest, z);
1054 }
1055 dest->neg = 1 - dest->neg;
1056}
1057
1058/* computes dest = ~z (= -z - 1)
1059 can have dest, z the same
1060*/
1061void mpz_not_inpl(mpz_t *dest, const mpz_t *z) {
1062 if (dest != z) {
1063 mpz_set(dest, z);
1064 }
1065 if (dest->len == 0) {
1066 mpz_need_dig(dest, 1);
1067 dest->dig[0] = 1;
1068 dest->len = 1;
1069 dest->neg = 1;
1070 } else if (dest->neg) {
1071 dest->neg = 0;
1072 mpz_dig_t k = 1;
1073 dest->len = mpn_sub(dest->dig, dest->dig, dest->len, &k, 1);
1074 } else {
1075 mpz_need_dig(dest, dest->len + 1);
1076 mpz_dig_t k = 1;
1077 dest->len = mpn_add(dest->dig, dest->dig, dest->len, &k, 1);
1078 dest->neg = 1;
1079 }
1080}
1081
1082/* computes dest = lhs << rhs
1083 can have dest, lhs the same
1084*/
1085void mpz_shl_inpl(mpz_t *dest, const mpz_t *lhs, mp_uint_t rhs) {
1086 if (lhs->len == 0 || rhs == 0) {
1087 mpz_set(dest, lhs);
1088 } else {
1089 mpz_need_dig(dest, lhs->len + (rhs + DIG_SIZE - 1) / DIG_SIZE);
1090 dest->len = mpn_shl(dest->dig, lhs->dig, lhs->len, rhs);
1091 dest->neg = lhs->neg;
1092 }
1093}
1094
1095/* computes dest = lhs >> rhs
1096 can have dest, lhs the same
1097*/
1098void mpz_shr_inpl(mpz_t *dest, const mpz_t *lhs, mp_uint_t rhs) {
1099 if (lhs->len == 0 || rhs == 0) {
1100 mpz_set(dest, lhs);
1101 } else {
1102 mpz_need_dig(dest, lhs->len);
1103 dest->len = mpn_shr(dest->dig, lhs->dig, lhs->len, rhs);
1104 dest->neg = lhs->neg;
1105 if (dest->neg) {
1106 // arithmetic shift right, rounding to negative infinity
1107 mp_uint_t n_whole = rhs / DIG_SIZE;
1108 mp_uint_t n_part = rhs % DIG_SIZE;
1109 mpz_dig_t round_up = 0;
1110 for (size_t i = 0; i < lhs->len && i < n_whole; i++) {
1111 if (lhs->dig[i] != 0) {
1112 round_up = 1;
1113 break;
1114 }
1115 }
1116 if (n_whole < lhs->len && (lhs->dig[n_whole] & ((1 << n_part) - 1)) != 0) {
1117 round_up = 1;
1118 }
1119 if (round_up) {
1120 if (dest->len == 0) {
1121 // dest == 0, so need to add 1 by hand (answer will be -1)
1122 dest->dig[0] = 1;
1123 dest->len = 1;
1124 } else {
1125 // dest > 0, so can use mpn_add to add 1
1126 dest->len = mpn_add(dest->dig, dest->dig, dest->len, &round_up, 1);
1127 }
1128 }
1129 }
1130 }
1131}
1132
1133/* computes dest = lhs + rhs
1134 can have dest, lhs, rhs the same
1135*/
1136void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1137 if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) {
1138 const mpz_t *temp = lhs;
1139 lhs = rhs;
1140 rhs = temp;
1141 }
1142
1143 if (lhs->neg == rhs->neg) {
1144 mpz_need_dig(dest, lhs->len + 1);
1145 dest->len = mpn_add(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1146 } else {
1147 mpz_need_dig(dest, lhs->len);
1148 dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1149 }
1150
1151 dest->neg = lhs->neg;
1152}
1153
1154/* computes dest = lhs - rhs
1155 can have dest, lhs, rhs the same
1156*/
1157void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1158 bool neg = false;
1159
1160 if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) {
1161 const mpz_t *temp = lhs;
1162 lhs = rhs;
1163 rhs = temp;
1164 neg = true;
1165 }
1166
1167 if (lhs->neg != rhs->neg) {
1168 mpz_need_dig(dest, lhs->len + 1);
1169 dest->len = mpn_add(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1170 } else {
1171 mpz_need_dig(dest, lhs->len);
1172 dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1173 }
1174
1175 if (neg) {
1176 dest->neg = 1 - lhs->neg;
1177 } else {
1178 dest->neg = lhs->neg;
1179 }
1180}
1181
1182/* computes dest = lhs & rhs
1183 can have dest, lhs, rhs the same
1184*/
1185void mpz_and_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1186 // make sure lhs has the most digits
1187 if (lhs->len < rhs->len) {
1188 const mpz_t *temp = lhs;
1189 lhs = rhs;
1190 rhs = temp;
1191 }
1192
1193 #if MICROPY_OPT_MPZ_BITWISE
1194
1195 if ((0 == lhs->neg) && (0 == rhs->neg)) {
1196 mpz_need_dig(dest, lhs->len);
1197 dest->len = mpn_and(dest->dig, lhs->dig, rhs->dig, rhs->len);
1198 dest->neg = 0;
1199 } else {
1200 mpz_need_dig(dest, lhs->len + 1);
1201 dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1202 lhs->neg == rhs->neg, 0 != lhs->neg, 0 != rhs->neg);
1203 dest->neg = lhs->neg & rhs->neg;
1204 }
1205
1206 #else
1207
1208 mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
1209 dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1210 (lhs->neg == rhs->neg) ? lhs->neg : 0, lhs->neg, rhs->neg);
1211 dest->neg = lhs->neg & rhs->neg;
1212
1213 #endif
1214}
1215
1216/* computes dest = lhs | rhs
1217 can have dest, lhs, rhs the same
1218*/
1219void mpz_or_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1220 // make sure lhs has the most digits
1221 if (lhs->len < rhs->len) {
1222 const mpz_t *temp = lhs;
1223 lhs = rhs;
1224 rhs = temp;
1225 }
1226
1227 #if MICROPY_OPT_MPZ_BITWISE
1228
1229 if ((0 == lhs->neg) && (0 == rhs->neg)) {
1230 mpz_need_dig(dest, lhs->len);
1231 dest->len = mpn_or(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1232 dest->neg = 0;
1233 } else {
1234 mpz_need_dig(dest, lhs->len + 1);
1235 dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1236 0 != lhs->neg, 0 != rhs->neg);
1237 dest->neg = 1;
1238 }
1239
1240 #else
1241
1242 mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
1243 dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1244 (lhs->neg || rhs->neg), lhs->neg, rhs->neg);
1245 dest->neg = lhs->neg | rhs->neg;
1246
1247 #endif
1248}
1249
1250/* computes dest = lhs ^ rhs
1251 can have dest, lhs, rhs the same
1252*/
1253void mpz_xor_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1254 // make sure lhs has the most digits
1255 if (lhs->len < rhs->len) {
1256 const mpz_t *temp = lhs;
1257 lhs = rhs;
1258 rhs = temp;
1259 }
1260
1261 #if MICROPY_OPT_MPZ_BITWISE
1262
1263 if (lhs->neg == rhs->neg) {
1264 mpz_need_dig(dest, lhs->len);
1265 if (lhs->neg == 0) {
1266 dest->len = mpn_xor(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1267 } else {
1268 dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len, 0, 0, 0);
1269 }
1270 dest->neg = 0;
1271 } else {
1272 mpz_need_dig(dest, lhs->len + 1);
1273 dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len, 1,
1274 0 == lhs->neg, 0 == rhs->neg);
1275 dest->neg = 1;
1276 }
1277
1278 #else
1279
1280 mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
1281 dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1282 (lhs->neg != rhs->neg), 0 == lhs->neg, 0 == rhs->neg);
1283 dest->neg = lhs->neg ^ rhs->neg;
1284
1285 #endif
1286}
1287
1288/* computes dest = lhs * rhs
1289 can have dest, lhs, rhs the same
1290*/
1291void mpz_mul_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1292 if (lhs->len == 0 || rhs->len == 0) {
1293 mpz_set_from_int(dest, 0);
1294 return;
1295 }
1296
1297 mpz_t *temp = NULL;
1298 if (lhs == dest) {
1299 lhs = temp = mpz_clone(lhs);
1300 if (rhs == dest) {
1301 rhs = lhs;
1302 }
1303 } else if (rhs == dest) {
1304 rhs = temp = mpz_clone(rhs);
1305 }
1306
1307 mpz_need_dig(dest, lhs->len + rhs->len); // min mem l+r-1, max mem l+r
1308 memset(dest->dig, 0, dest->alloc * sizeof(mpz_dig_t));
1309 dest->len = mpn_mul(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1310
1311 if (lhs->neg == rhs->neg) {
1312 dest->neg = 0;
1313 } else {
1314 dest->neg = 1;
1315 }
1316
1317 mpz_free(temp);
1318}
1319
1320/* computes dest = lhs ** rhs
1321 can have dest, lhs, rhs the same
1322*/
1323void mpz_pow_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1324 if (lhs->len == 0 || rhs->neg != 0) {
1325 mpz_set_from_int(dest, 0);
1326 return;
1327 }
1328
1329 if (rhs->len == 0) {
1330 mpz_set_from_int(dest, 1);
1331 return;
1332 }
1333
1334 mpz_t *x = mpz_clone(lhs);
1335 mpz_t *n = mpz_clone(rhs);
1336
1337 mpz_set_from_int(dest, 1);
1338
1339 while (n->len > 0) {
1340 if ((n->dig[0] & 1) != 0) {
1341 mpz_mul_inpl(dest, dest, x);
1342 }
1343 n->len = mpn_shr(n->dig, n->dig, n->len, 1);
1344 if (n->len == 0) {
1345 break;
1346 }
1347 mpz_mul_inpl(x, x, x);
1348 }
1349
1350 mpz_free(x);
1351 mpz_free(n);
1352}
1353
1354/* computes dest = (lhs ** rhs) % mod
1355 can have dest, lhs, rhs the same; mod can't be the same as dest
1356*/
1357void mpz_pow3_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs, const mpz_t *mod) {
1358 if (lhs->len == 0 || rhs->neg != 0 || (mod->len == 1 && mod->dig[0] == 1)) {
1359 mpz_set_from_int(dest, 0);
1360 return;
1361 }
1362
1363 mpz_set_from_int(dest, 1);
1364
1365 if (rhs->len == 0) {
1366 return;
1367 }
1368
1369 mpz_t *x = mpz_clone(lhs);
1370 mpz_t *n = mpz_clone(rhs);
1371 mpz_t quo;
1372 mpz_init_zero(&quo);
1373
1374 while (n->len > 0) {
1375 if ((n->dig[0] & 1) != 0) {
1376 mpz_mul_inpl(dest, dest, x);
1377 mpz_divmod_inpl(&quo, dest, dest, mod);
1378 }
1379 n->len = mpn_shr(n->dig, n->dig, n->len, 1);
1380 if (n->len == 0) {
1381 break;
1382 }
1383 mpz_mul_inpl(x, x, x);
1384 mpz_divmod_inpl(&quo, x, x, mod);
1385 }
1386
1387 mpz_deinit(&quo);
1388 mpz_free(x);
1389 mpz_free(n);
1390}
1391
1392#if 0
1393these functions are unused
1394
1395/* computes gcd(z1, z2)
1396 based on Knuth's modified gcd algorithm (I think?)
1397 gcd(z1, z2) >= 0
1398 gcd(0, 0) = 0
1399 gcd(z, 0) = abs(z)
1400*/
1401mpz_t *mpz_gcd(const mpz_t *z1, const mpz_t *z2) {
1402 if (z1->len == 0) {
1403 // TODO: handle case of z2->alloc=0
1404 mpz_t *a = mpz_clone(z2);
1405 a->neg = 0;
1406 return a;
1407 } else if (z2->len == 0) {
1408 mpz_t *a = mpz_clone(z1);
1409 a->neg = 0;
1410 return a;
1411 }
1412
1413 mpz_t *a = mpz_clone(z1);
1414 mpz_t *b = mpz_clone(z2);
1415 mpz_t c;
1416 mpz_init_zero(&c);
1417 a->neg = 0;
1418 b->neg = 0;
1419
1420 for (;;) {
1421 if (mpz_cmp(a, b) < 0) {
1422 if (a->len == 0) {
1423 mpz_free(a);
1424 mpz_deinit(&c);
1425 return b;
1426 }
1427 mpz_t *t = a;
1428 a = b;
1429 b = t;
1430 }
1431 if (!(b->len >= 2 || (b->len == 1 && b->dig[0] > 1))) { // compute b > 0; could be mpz_cmp_small_int(b, 1) > 0
1432 break;
1433 }
1434 mpz_set(&c, b);
1435 do {
1436 mpz_add_inpl(&c, &c, &c);
1437 } while (mpz_cmp(&c, a) <= 0);
1438 c.len = mpn_shr(c.dig, c.dig, c.len, 1);
1439 mpz_sub_inpl(a, a, &c);
1440 }
1441
1442 mpz_deinit(&c);
1443
1444 if (b->len == 1 && b->dig[0] == 1) { // compute b == 1; could be mpz_cmp_small_int(b, 1) == 0
1445 mpz_free(a);
1446 return b;
1447 } else {
1448 mpz_free(b);
1449 return a;
1450 }
1451}
1452
1453/* computes lcm(z1, z2)
1454 = abs(z1) / gcd(z1, z2) * abs(z2)
1455 lcm(z1, z1) >= 0
1456 lcm(0, 0) = 0
1457 lcm(z, 0) = 0
1458*/
1459mpz_t *mpz_lcm(const mpz_t *z1, const mpz_t *z2) {
1460 if (z1->len == 0 || z2->len == 0) {
1461 return mpz_zero();
1462 }
1463
1464 mpz_t *gcd = mpz_gcd(z1, z2);
1465 mpz_t *quo = mpz_zero();
1466 mpz_t *rem = mpz_zero();
1467 mpz_divmod_inpl(quo, rem, z1, gcd);
1468 mpz_mul_inpl(rem, quo, z2);
1469 mpz_free(gcd);
1470 mpz_free(quo);
1471 rem->neg = 0;
1472 return rem;
1473}
1474#endif
1475
1476/* computes new integers in quo and rem such that:
1477 quo * rhs + rem = lhs
1478 0 <= rem < rhs
1479 can have lhs, rhs the same
1480 assumes rhs != 0 (undefined behaviour if it is)
1481*/
1482void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const mpz_t *rhs) {
1483 assert(!mpz_is_zero(rhs));
1484
1485 mpz_need_dig(dest_quo, lhs->len + 1); // +1 necessary?
1486 memset(dest_quo->dig, 0, (lhs->len + 1) * sizeof(mpz_dig_t));
1487 dest_quo->len = 0;
1488 mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary?
1489 mpz_set(dest_rem, lhs);
1490 mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len);
1491
1492 // check signs and do Python style modulo
1493 if (lhs->neg != rhs->neg) {
1494 dest_quo->neg = 1;
1495 if (!mpz_is_zero(dest_rem)) {
1496 mpz_t mpzone;
1497 mpz_init_from_int(&mpzone, -1);
1498 mpz_add_inpl(dest_quo, dest_quo, &mpzone);
1499 mpz_add_inpl(dest_rem, dest_rem, rhs);
1500 }
1501 }
1502}
1503
1504#if 0
1505these functions are unused
1506
1507/* computes floor(lhs / rhs)
1508 can have lhs, rhs the same
1509*/
1510mpz_t *mpz_div(const mpz_t *lhs, const mpz_t *rhs) {
1511 mpz_t *quo = mpz_zero();
1512 mpz_t rem;
1513 mpz_init_zero(&rem);
1514 mpz_divmod_inpl(quo, &rem, lhs, rhs);
1515 mpz_deinit(&rem);
1516 return quo;
1517}
1518
1519/* computes lhs % rhs ( >= 0)
1520 can have lhs, rhs the same
1521*/
1522mpz_t *mpz_mod(const mpz_t *lhs, const mpz_t *rhs) {
1523 mpz_t quo;
1524 mpz_init_zero(&quo);
1525 mpz_t *rem = mpz_zero();
1526 mpz_divmod_inpl(&quo, rem, lhs, rhs);
1527 mpz_deinit(&quo);
1528 return rem;
1529}
1530#endif
1531
1532// must return actual int value if it fits in mp_int_t
1533mp_int_t mpz_hash(const mpz_t *z) {
1534 mp_uint_t val = 0;
1535 mpz_dig_t *d = z->dig + z->len;
1536
1537 while (d-- > z->dig) {
1538 val = (val << DIG_SIZE) | *d;
1539 }
1540
1541 if (z->neg != 0) {
1542 val = -val;
1543 }
1544
1545 return val;
1546}
1547
1548bool mpz_as_int_checked(const mpz_t *i, mp_int_t *value) {
1549 mp_uint_t val = 0;
1550 mpz_dig_t *d = i->dig + i->len;
1551
1552 while (d-- > i->dig) {
1553 if (val > (~(MP_OBJ_WORD_MSBIT_HIGH) >> DIG_SIZE)) {
1554 // will overflow
1555 return false;
1556 }
1557 val = (val << DIG_SIZE) | *d;
1558 }
1559
1560 if (i->neg != 0) {
1561 val = -val;
1562 }
1563
1564 *value = val;
1565 return true;
1566}
1567
1568bool mpz_as_uint_checked(const mpz_t *i, mp_uint_t *value) {
1569 if (i->neg != 0) {
1570 // can't represent signed values
1571 return false;
1572 }
1573
1574 mp_uint_t val = 0;
1575 mpz_dig_t *d = i->dig + i->len;
1576
1577 while (d-- > i->dig) {
1578 if (val > (~(MP_OBJ_WORD_MSBIT_HIGH) >> (DIG_SIZE - 1))) {
1579 // will overflow
1580 return false;
1581 }
1582 val = (val << DIG_SIZE) | *d;
1583 }
1584
1585 *value = val;
1586 return true;
1587}
1588
1589void mpz_as_bytes(const mpz_t *z, bool big_endian, size_t len, byte *buf) {
1590 byte *b = buf;
1591 if (big_endian) {
1592 b += len;
1593 }
1594 mpz_dig_t *zdig = z->dig;
1595 int bits = 0;
1596 mpz_dbl_dig_t d = 0;
1597 mpz_dbl_dig_t carry = 1;
1598 for (size_t zlen = z->len; zlen > 0; --zlen) {
1599 bits += DIG_SIZE;
1600 d = (d << DIG_SIZE) | *zdig++;
1601 for (; bits >= 8; bits -= 8, d >>= 8) {
1602 mpz_dig_t val = d;
1603 if (z->neg) {
1604 val = (~val & 0xff) + carry;
1605 carry = val >> 8;
1606 }
1607 if (big_endian) {
1608 *--b = val;
1609 if (b == buf) {
1610 return;
1611 }
1612 } else {
1613 *b++ = val;
1614 if (b == buf + len) {
1615 return;
1616 }
1617 }
1618 }
1619 }
1620
1621 // fill remainder of buf with zero/sign extension of the integer
1622 if (big_endian) {
1623 len = b - buf;
1624 } else {
1625 len = buf + len - b;
1626 buf = b;
1627 }
1628 memset(buf, z->neg ? 0xff : 0x00, len);
1629}
1630
1631#if MICROPY_PY_BUILTINS_FLOAT
1632mp_float_t mpz_as_float(const mpz_t *i) {
1633 mp_float_t val = 0;
1634 mpz_dig_t *d = i->dig + i->len;
1635
1636 while (d-- > i->dig) {
1637 val = val * DIG_BASE + *d;
1638 }
1639
1640 if (i->neg != 0) {
1641 val = -val;
1642 }
1643
1644 return val;
1645}
1646#endif
1647
1648#if 0
1649this function is unused
1650char *mpz_as_str(const mpz_t *i, unsigned int base) {
1651 char *s = m_new(char, mp_int_format_size(mpz_max_num_bits(i), base, NULL, '\0'));
1652 mpz_as_str_inpl(i, base, NULL, 'a', '\0', s);
1653 return s;
1654}
1655#endif
1656
1657// assumes enough space in str as calculated by mp_int_format_size
1658// base must be between 2 and 32 inclusive
1659// returns length of string, not including null byte
1660size_t mpz_as_str_inpl(const mpz_t *i, unsigned int base, const char *prefix, char base_char, char comma, char *str) {
1661 assert(str != NULL);
1662 assert(2 <= base && base <= 32);
1663
1664 size_t ilen = i->len;
1665
1666 char *s = str;
1667 if (ilen == 0) {
1668 if (prefix) {
1669 while (*prefix) {
1670 *s++ = *prefix++;
1671 }
1672 }
1673 *s++ = '0';
1674 *s = '\0';
1675 return s - str;
1676 }
1677
1678 // make a copy of mpz digits, so we can do the div/mod calculation
1679 mpz_dig_t *dig = m_new(mpz_dig_t, ilen);
1680 memcpy(dig, i->dig, ilen * sizeof(mpz_dig_t));
1681
1682 // convert
1683 char *last_comma = str;
1684 bool done;
1685 do {
1686 mpz_dig_t *d = dig + ilen;
1687 mpz_dbl_dig_t a = 0;
1688
1689 // compute next remainder
1690 while (--d >= dig) {
1691 a = (a << DIG_SIZE) | *d;
1692 *d = a / base;
1693 a %= base;
1694 }
1695
1696 // convert to character
1697 a += '0';
1698 if (a > '9') {
1699 a += base_char - '9' - 1;
1700 }
1701 *s++ = a;
1702
1703 // check if number is zero
1704 done = true;
1705 for (d = dig; d < dig + ilen; ++d) {
1706 if (*d != 0) {
1707 done = false;
1708 break;
1709 }
1710 }
1711 if (comma && (s - last_comma) == 3) {
1712 *s++ = comma;
1713 last_comma = s;
1714 }
1715 }
1716 while (!done);
1717
1718 // free the copy of the digits array
1719 m_del(mpz_dig_t, dig, ilen);
1720
1721 if (prefix) {
1722 const char *p = &prefix[strlen(prefix)];
1723 while (p > prefix) {
1724 *s++ = *--p;
1725 }
1726 }
1727 if (i->neg != 0) {
1728 *s++ = '-';
1729 }
1730
1731 // reverse string
1732 for (char *u = str, *v = s - 1; u < v; ++u, --v) {
1733 char temp = *u;
1734 *u = *v;
1735 *v = temp;
1736 }
1737
1738 *s = '\0'; // null termination
1739
1740 return s - str;
1741}
1742
1743#endif // MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
1744