bigmul

big multiplication in C
git clone git://git.rr3.xyz/bigmul
Log | Files | Refs | README | LICENSE

commit 0714d05592f0527d3572fb31b1b345e89e1facfc
parent a58e698cfc79a9511cf576aee2288e07b6c54e7c
Author: Robert Russell <robert@rr3.xyz>
Date:   Wed,  1 Jan 2025 19:31:37 -0800

Only karatsuba numbers of same width

This results in a performance boost, and simplifies the karatsuba
function and analysis.

This is still WIP.

Diffstat:
Mbigmul.c | 133++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------------
1 file changed, 87 insertions(+), 46 deletions(-)

diff --git a/bigmul.c b/bigmul.c @@ -3,6 +3,8 @@ #include <stdio.h> #include <unistd.h> +#define KARATSUBA_THRESH 32 // Best power of 2 determined via benchmarking + struct nat { usize cap; usize len; @@ -55,27 +57,32 @@ fmaa64(u64 *rh, u64 *rl, u64 w, u64 x, u64 y, u64 z) { } // Precondition: m >= n -void -add(u64 *r, u64 *x, usize m, u64 *y, usize n) { - u64 c = 0; +u64 +add(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 c) { for (usize i = 0; i < n; i++) add64(&c, &r[i], x[i], y[i], c); for (usize i = n; i < m; i++) add64(&c, &r[i], x[i], c, 0); - r[m] = c; + return c; +} + +u64 +addw(u64 *r, u64 *x, usize m, u64 y) { + for (usize i = 0; i < m; i++) + add64(&y, &r[i], x[i], y, 0); + return y; } // Precondition: m >= n // TODO: sub is not commutative like add, so we need a "bus" operation // ("sub" backwards) for when m < n. -void -sub(u64 *r, u64 *x, usize m, u64 *y, usize n) { - u64 b = 0; +u64 +sub(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 b) { for (usize i = 0; i < n; i++) sub64(&b, &r[i], x[i], y[i], b); for (usize i = n; i < m; i++) sub64(&b, &r[i], x[i], b, 0); - r[m] = -b; // TODO: I don't think this makes sense for nats. + return b; } // Precondition: r does not intersect x nor y @@ -93,8 +100,9 @@ mul_quadratic(u64 *r, u64 *x, usize m, u64 *y, usize n) { // Precondition: r does not intersect x nor y // TODO: Document precondition regarding size of scratch memory. void -karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) { - /* We seek to multiply x and y, which have m and n "words" (digits of +karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *scratch0) { + /* TODO: Update + * We seek to multiply x and y, which have m and n "words" (digits of * base b := 2^64), respectively. For this, we let k := ceil(max(m, n) / 2) * and split x and y as * x = xh * b^k + xl and y = yh * b^k + yl. @@ -135,54 +143,85 @@ karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) { * n >= 4 (instead of m >= 4), then |q| < n and |p| <= m. It therefore * suffices to separate the case m < 4 || n < 4 for the recursion basis. */ - usize maxmn = MAX(m, n); - if (m < 4 || n < 4 || maxmn < 32) { // 32 was determined via benchmarks. - mul_quadratic(r, x, m, y, n); + if (n < KARATSUBA_THRESH) { // TODO: Calibrate, and ensure necessary bounds + mul_quadratic(r, x, n, y, n); return; } - usize k = (maxmn + 1) / 2; - - // 1. Split x - usize mh = m > k ? m - k : 0, ml = MIN(k, m); - u64 *xh = x + ml, *xl = x; + // 1. Compute l, h, k, and their doubles + usize l = n / 2, ll = l * 2; + usize h = n - l, hh = h * 2; + usize k = h + 1, kk = k * 2; - // 2. Split y - usize nh = n > k ? n - k : 0, nl = MIN(k, n); - u64 *yh = y + nl, *yl = y; + // 2. Split x and y + u64 *xh = x + l, *xl = x; + u64 *yh = y + l, *yl = y; // 3. Assign blocks of memory for intermediate results. // Note that we use the output buffer r as temporary storage for p and q. // We also store u and s directly in r at the appropriate offsets, such // that p and q overlap with u and s, but that's ok, because we're done // with p and q by the time we calculate u and s. - usize pw = MIN(ml + 1, m); u64 *p = r; - usize qw = MIN(nl + 1, n); u64 *q = r + pw; - usize tw = pw + qw; u64 *t = scratch; - usize uw = ml + nl; u64 *u = r; - usize sw = mh + nh; u64 *s = r + 2 * k; + // TODO: Justify all these lengths + u64 *p = r; + u64 *q = r + k; + u64 *t = scratch0; + u64 *u = r; + u64 *s = r + ll; + u64 *scratch1 = scratch0 + kk; // 4. Arithmetic - add(p, xl, ml, xh, mh); // p = xl + xh - add(q, yl, nl, yh, nh); // q = yl + yh - karatsuba(t, p, pw, q, qw, scratch + tw); // t = p * q - karatsuba(u, xl, ml, yl, nl, scratch + tw); // u = xl * yl - for (usize i = uw; i < 2 * k; i++) r[i] = 0; // r[uw..2*k] = 0 - karatsuba(s, xh, mh, yh, nh, scratch + tw); // s = xh * yh - sub(t, t, tw, u, uw); // t -= u - sub(t, t, tw, s, sw); // t -= s - add(r + k, t, tw, r + k, k + sw); // r[k..] += t - for (usize i = tw + k + sw; i < m + n; i++) r[i] = 0; // TODO - // TODO: Prove that tw + sw <= m + n - k. + p[k] = add(p, xh, h, xl, l, 0); // p = xh + xl + q[k] = add(q, yh, h, yl, l, 0); // q = yh + yl + karatsuba(t, p, q, k, scratch1); // t = p * q + karatsuba(u, xl, yl, l, scratch1); // u = xl * yl + karatsuba(s, xh, yh, h, scratch1); // s = xh * yh + sub(t, t, kk, u, ll, 0); // t -= u (borrow out must be 0) TODO: explain + sub(t, t, kk, s, hh, 0); // t -= s (borrow out must be 0) TODO: explain + add(r + l, r + l, hh + l, t, kk, 0); // r[l..] += t (carry out must be 0) TODO: explain } void mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) { + if (m < n) { + u64 *t0 = x; x = y; y = t0; // Swap x and y + usize t1 = m; m = n; n = t1; // Swap m and n + } + + if (n < KARATSUBA_THRESH) { // TODO: Calibrate. + mul_quadratic(r, x, m, y, n); + return; + } + + // TODO: Special-case m == n? + // TODO: Accept capacity of r as an argument, and use excess memory for // scratch, if it's big enough, instead of allocating. - u64 *scratch = r_eallocn((m + n) * 2, sizeof *scratch); - karatsuba(r, x, m, y, n, scratch); - free(scratch); + usize firstkk = 2 * (n - n / 2 + 1); + u64 *mem = r_eallocn(2 * n + 2 * firstkk, sizeof *mem); + u64 *prod = mem; + u64 *scratch = mem + 2 * n; + + // TODO: The control flow here kinda sucks. + // TODO: There are unnecessary copies between prod and r. Try to do it in + // "one pass", without first initializing r to 0. + memset(r, 0, (m + n) * sizeof *r); + for (;;) { + for (; m >= n; r += n, x += n, m -= n) { + karatsuba(prod, x, y, n, scratch); + add(r, prod, 2 * n, r, n, 0); + } + if (m == 0) break; + if (m < KARATSUBA_THRESH) { // TODO: Calibrate. + mul_quadratic(prod, x, m, y, n); + add(r, prod, m + n, r, n, 0); + break; + } + u64 *t0 = x; x = y; y = t0; + usize t1 = m; m = n; n = t1; + } + + free(mem); } u64 x[4096]; @@ -225,12 +264,14 @@ void bench_karatsuba4096(u64 n) { bench_karatsuba(4096, n); } int main(void) { - // u64 x[] = { 0x1234123412341234, 0x5678567856785678, 0x89ab89ab89ab89ab, 0xcdefcdefcdefcdef }; - // u64 y[] = { 0x4321432143214321, 0x8765876587658765, 0xba98ba98ba98ba98, 0xfedcfedcfedcfedc }; - // u64 r0[LEN(x) + LEN(y)]; mul_quadratic(r0, x, LEN(x), y, LEN(y)); - // u64 r1[LEN(x) + LEN(y)]; mul_karatsuba(r1, x, LEN(x), y, LEN(y)); - // printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r0[7], r0[6], r0[5], r0[4], r0[3], r0[2], r0[1], r0[0]); - // printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r1[7], r1[6], r1[5], r1[4], r1[3], r1[2], r1[1], r1[0]); +/* + u64 x[] = { 0x1234123412341234, 0x5678567856785678, 0x89ab89ab89ab89ab, 0xcdefcdefcdefcdef }; + u64 y[] = { 0x4321432143214321, 0x8765876587658765, 0xba98ba98ba98ba98, 0xfedcfedcfedcfedc }; + u64 r0[LEN(x) + LEN(y)]; mul_quadratic(r0, x, LEN(x), y, LEN(y)); + u64 r1[LEN(x) + LEN(y)]; mul_karatsuba(r1, x, LEN(x), y, LEN(y)); + printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r0[7], r0[6], r0[5], r0[4], r0[3], r0[2], r0[1], r0[0]); + printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r1[7], r1[6], r1[5], r1[4], r1[3], r1[2], r1[1], r1[0]); +*/ for (usize i = 0; i < LEN(x); i++) x[i] = r_prand64(); for (usize i = 0; i < LEN(y); i++) y[i] = r_prand64();