bigmul

big multiplication in C
git clone git://git.rr3.xyz/bigmul
Log | Files | Refs | README | LICENSE

bigmul.c (9397B)


      1 #include <rcx/all.h>
      2 #include <rcx/bench.h>
      3 #include <stdio.h>
      4 #include <unistd.h>
      5 
      6 // TODO: Add (slow) fallback for __builtin_{add,sub}cl when not using clang or
      7 // GCC 14. Maybe that should go in rcx.
      8 
// This power of 2 results in the lowest run-time (on the hardware on which the
// benchmarks were run). This must be at least 4 (see the comments in
// karatsuba).
// Operand sizes (in 64-bit words) below this threshold are multiplied with
// fma_quadratic instead of recursing further in karatsuba.
#define KARATSUBA_THRESH 32
     13 
     14 
     15 /* ----- Wide u64 math ----- */
     16 
     17 inline void
     18 mul64(u64 *rh, u64 *rl, u64 x, u64 y) {
     19 	u128 r = (u128)x * (u128)y;
     20 	*rh = r >> 64;
     21 	*rl = r;
     22 }
     23 
     24 inline void
     25 fmaa64(u64 *rh, u64 *rl, u64 w, u64 x, u64 y, u64 z) {
     26 	u64 h0, h1, h2, l;
     27 	mul64(&h0, &l, w, x);               // h0:l = w * x
     28 	l = __builtin_addcl(l, y, 0, &h1);  // h1:l = l + y
     29 	l = __builtin_addcl(l, z, 0, &h2);  // h2:l = l + z
     30 	*rh = h0 + h1 + h2;
     31 	*rl = l;
     32 }
     33 
     34 
     35 /* ----- Big nat math ----- */
     36 
     37 // Precondition: m >= n
     38 u64
     39 add(u64 *r, u64 *x, usize m, u64 *y, usize n) {
     40 	u64 c = 0;
     41 	usize i = 0;
     42 
     43 	for (; i + 3 < n; i += 4) {
     44 		r[i + 0] = __builtin_addcl(x[i + 0], y[i + 0], c, &c);
     45 		r[i + 1] = __builtin_addcl(x[i + 1], y[i + 1], c, &c);
     46 		r[i + 2] = __builtin_addcl(x[i + 2], y[i + 2], c, &c);
     47 		r[i + 3] = __builtin_addcl(x[i + 3], y[i + 3], c, &c);
     48 	}
     49 	for (; i < n; i++)
     50 		r[i] = __builtin_addcl(x[i], y[i], c, &c);
     51 
     52 	for (; i + 3 < m; i += 4) {
     53 		r[i + 0] = __builtin_addcl(x[i + 0], 0, c, &c);
     54 		r[i + 1] = __builtin_addcl(x[i + 1], 0, c, &c);
     55 		r[i + 2] = __builtin_addcl(x[i + 2], 0, c, &c);
     56 		r[i + 3] = __builtin_addcl(x[i + 3], 0, c, &c);
     57 	}
     58 	for (; i < m; i++)
     59 		r[i] = __builtin_addcl(x[i], 0, c, &c);
     60 
     61 	return c;
     62 }
     63 
     64 // Precondition: m >= n
     65 // TODO: sub is not commutative like add, so we need a "bus" operation
     66 // ("sub" backwards) for when m < n.
     67 u64
     68 sub(u64 *r, u64 *x, usize m, u64 *y, usize n) {
     69 	u64 c = 0;
     70 	usize i = 0;
     71 
     72 	for (; i + 3 < n; i += 4) {
     73 		r[i + 0] = __builtin_subcl(x[i + 0], y[i + 0], c, &c);
     74 		r[i + 1] = __builtin_subcl(x[i + 1], y[i + 1], c, &c);
     75 		r[i + 2] = __builtin_subcl(x[i + 2], y[i + 2], c, &c);
     76 		r[i + 3] = __builtin_subcl(x[i + 3], y[i + 3], c, &c);
     77 	}
     78 	for (; i < n; i++)
     79 		r[i] = __builtin_subcl(x[i], y[i], c, &c);
     80 
     81 	for (; i + 3 < m; i += 4) {
     82 		r[i + 0] = __builtin_subcl(x[i + 0], 0, c, &c);
     83 		r[i + 1] = __builtin_subcl(x[i + 1], 0, c, &c);
     84 		r[i + 2] = __builtin_subcl(x[i + 2], 0, c, &c);
     85 		r[i + 3] = __builtin_subcl(x[i + 3], 0, c, &c);
     86 	}
     87 	for (; i < m; i++)
     88 		r[i] = __builtin_subcl(x[i], 0, c, &c);
     89 
     90 	return c;
     91 }
     92 
     93 // Precondition: capacity(r) >= m + n, capacity(x) = m, capactiy(y) = n
     94 // Precondition: r is disjoint with x and y
     95 void
     96 fma_quadratic(u64 *r, u64 *x, usize m, u64 *y, usize n) {
     97 	for (usize j = 0; j < n; j++) {
     98 		u64 c = 0;
     99 		for (usize i = 0; i < m; i++)
    100 			fmaa64(&c, &r[i + j], x[i], y[j], r[i + j], c);
    101 		r[m + j] = c;
    102 	}
    103 }
    104 
    105 // Precondition: capacity(r) >= 2 * n, capacity(x) = capactiy(y) = n
    106 // Precondition: r is disjoint with x and y
    107 // Precondition: capacity(tt) >= 2 * kk, where kk := 2 * (n - n / 2 + 1)
    108 void
    109 karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *tt) {
    110 	/* We seek to multiply x and y, which each have n "words" (digits of base
    111 	 * b := 2^64), obtaining the full 2*n word product. For this, we let
    112 	 *     l := floor(n / 2)   and   h := ceil(n / 2)
    113 	 * and partition x and y into low parts xl and yl with l words and high
    114 	 * parts xh and yh with h words, such that
    115 	 *     x = xh * b^l + xl   and   y = yh * b^l + yl.
    116 	 * Then
    117 	 *     x * y = s * b^(2*l) + t * b^l + u,
    118 	 * where
    119 	 *     s := xh * yh,   t := xh * yl + xl * yh,   and   u := xl * yl.
    120 	 * Thus, we could multiply x and y by recursively evaluating the four
    121 	 * products in the definition of s, t, and u (the products involving b^i
    122 	 * are just bit shifts). However, this would result in a time complexity
    123 	 * of O(n^2), the same asymptotic performance as the naive "doubly-nested
    124 	 * for-loop" algorithm. Instead, we exploit the following identity, whose
    125 	 * significance for multiplication algorithms was first noticed by Anatoly
    126 	 * Karatsuba in 1960:
    127 	 *     t = p * q - u - s
    128 	 * where
    129 	 *     p := xl + xh   and   q := yl + yh.
    130 	 * Computing t in its latter form saves one multiplication at the expense
    131 	 * of a few additions/subtractions (which have O(n) time complexity),
    132 	 * thereby reducing the time complexity to O(n^(lg 3)).
    133 	 *
    134 	 * Let |z| denote the number of words in a number z. For well-founded
    135 	 * recusion, we need n to strictly decrease in each of the three recursive
    136 	 * calls. When we compute u = xl * yl and s = xh * yh, this is true as long
    137 	 * as n >= 2, for then |xl|,|yl| <= l < n and |yh|,|yh| <= h < n. For the
    138 	 * computation of p * q in t, on the other hand, we have
    139 	 *     |p|,|q|  = |xl + xh|,|yl + yh|   (definition of p and q)
    140 	 *             <= max(l, h) + 1         (addition adds at most 1 word)
    141 	 *              = h + 1                 (h >= l)
    142 	 *              = ceil(n / 2) + 1       (definition of h)
    143 	 * and as long as n >= 4, this quantity is strictly less than n. It
    144 	 * therefore suffices to separate the case n < 4 for the recursion
    145 	 * basis. */
    146 
    147 	// 1. Basis
    148 	if (n < KARATSUBA_THRESH) {
    149 		fma_quadratic(r, x, n, y, n);
    150 		return;
    151 	}
    152 
    153 	// 2. Compute l, h, and k, and their doubles ll, hh, and kk
    154 	// The significance of these quantities is as follows:
    155 	//   - l is the max width of xl and yl
    156 	//   - h is the max width of xh and yh
    157 	//   - k is the max width of p and q
    158 	//   - ll is the max width of u
    159 	//   - hh is the max width of s
    160 	//   - kk is the max width of t
    161 	usize l = n / 2, ll = 2 * l;
    162 	usize h = n - l, hh = 2 * h;
    163 	usize k = h + 1, kk = 2 * k;
    164 
    165 	// 3. Split x and y
    166 	u64 *xh = x + l, *xl = x;
    167 	u64 *yh = y + l, *yl = y;
    168 
    169 	// 4. Assign blocks of memory for intermediate results
    170 	// Note that we use the output buffer r as temporary storage for p and q.
    171 	// We also store u and s directly in r at the appropriate offsets (free bit
    172 	// shifts!), such that p and q overlap with u and s, but that's ok, because
    173 	// we're done with p and q by the time we calculate u and s.
    174 	u64 *p = r;
    175 	u64 *q = r + k;
    176 	u64 *t0 = tt;       // Storage for t in this invocation
    177 	u64 *t1 = tt + kk;  // Storage for t in future (recursive) invocations
    178 	u64 *u = r;
    179 	u64 *s = r + ll;
    180 
    181 	// 5. Arithmetic
    182 	p[k] = add(p, xh, h, xl, l);        // p = xh + xl
    183 	q[k] = add(q, yh, h, yl, l);        // q = yh + yl
    184 	karatsuba(t0, p, q, k, t1);         // t = p * q
    185 	karatsuba(u, xl, yl, l, t1);        // u = xl * yl
    186 	karatsuba(s, xh, yh, h, t1);        // s = xh * yh
    187 	sub(t0, t0, kk, u, ll);             // t -= u  [1]
    188 	sub(t0, t0, kk, s, hh);             // t -= s  [1]
    189 	add(r + l, r + l, hh + l, t0, kk);  // r[l..] += t  [2] [3]
    190 	// [1]: The borrow outs are guaranteed to be 0, because t0 - u - s must
    191 	//      be positive.
    192 	// [2]: The carry out is guaranteed to be 0, because the full product
    193 	//      x * y must fit in 2 * n words.
    194 	// [3]: The add precondition hh + l >= kk is satisfied here as long as
    195 	//      n >= 4, and it is, because n < 4 is the recursion basis.
    196 }
    197 
    198 // Precondition: capacity(r) >= m + n, capacity(x) = m, capactiy(y) = n
    199 // Precondition: r is disjoint with x and y
    200 // Precondition: m >= n
    201 void
    202 fma_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
    203 	// TODO: Accept capacity of r as an argument, and use excess memory for
    204 	// scratch, if it's big enough, instead of allocating.
    205 	usize firstkk = 2 * (n - n / 2 + 1);
    206 	u64 *scratch = r_eallocn(2 * n + 2 * firstkk, sizeof *scratch);
    207 	u64 *prod = scratch;
    208 	u64 *tt = scratch + 2 * n;
    209 
    210 	for (;;) {
    211 		for (; m >= n; r += n, x += n, m -= n) {
    212 			karatsuba(prod, x, y, n, tt);
    213 			add(r, prod, 2 * n, r, n);
    214 		}
    215 
    216 		if (m == 0) break;
    217 
    218 		if (m < KARATSUBA_THRESH) {
    219 			fma_quadratic(r, x, m, y, n);
    220 			break;
    221 		}
    222 
    223 		u64 *t0 = x; x = y; y = t0;   // Swap x and y
    224 		usize t1 = m; m = n; n = t1;  // Swap m and n
    225 	}
    226 
    227 	free(scratch);
    228 }
    229 
    230 // Precondition: capacity(r) >= m + n, capacity(x) = m, capactiy(y) = n
    231 // Precondition: r is disjoint with x and y
    232 void
    233 mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
    234 	if (m < n) {
    235 		u64 *t0 = x; x = y; y = t0;   // Swap x and y
    236 		usize t1 = m; m = n; n = t1;  // Swap m and n
    237 	}
    238 
    239 	memset(r, 0, (m + n) * sizeof *r);
    240 	(n < KARATSUBA_THRESH ? fma_quadratic : fma_karatsuba)(r, x, m, y, n);
    241 }
    242 
    243 
    244 /* ----- Benchmarks ----- */
    245 
// Benchmark operands (filled with pseudo-random words in main) and the output
// buffer. r has 4096 + 4096 = 8192 words, enough for the product of the two
// largest operands benchmarked below.
u64 x[4096];
u64 y[4096];
u64 r[8192];
    249 
    250 NOINLINE void
    251 bench_karatsuba(u64 l, u64 n) {
    252 	r_bench_start();
    253 	for (u64 i = 0; i < n; i++) mul_karatsuba(r, x, l, y, l);
    254 	r_bench_stop();
    255 }
    256 
    257 void bench_karatsuba16(u64 n)   { bench_karatsuba(16, n); }
    258 void bench_karatsuba32(u64 n)   { bench_karatsuba(32, n); }
    259 void bench_karatsuba64(u64 n)   { bench_karatsuba(64, n); }
    260 void bench_karatsuba128(u64 n)  { bench_karatsuba(128, n); }
    261 void bench_karatsuba256(u64 n)  { bench_karatsuba(256, n); }
    262 void bench_karatsuba512(u64 n)  { bench_karatsuba(512, n); }
    263 void bench_karatsuba1024(u64 n) { bench_karatsuba(1024, n); }
    264 void bench_karatsuba2048(u64 n) { bench_karatsuba(2048, n); }
    265 void bench_karatsuba4096(u64 n) { bench_karatsuba(4096, n); }
    266 
    267 int
    268 main(void) {
    269 	for (usize i = 0; i < LEN(x); i++) x[i] = r_prand64();
    270 	for (usize i = 0; i < LEN(y); i++) y[i] = r_prand64();
    271 
    272 	r_bench(bench_karatsuba16,   1000);
    273 	r_bench(bench_karatsuba32,   1000);
    274 	r_bench(bench_karatsuba64,   1000);
    275 	r_bench(bench_karatsuba128,  1000);
    276 	r_bench(bench_karatsuba256,  1000);
    277 	r_bench(bench_karatsuba512,  1000);
    278 	r_bench(bench_karatsuba1024, 1000);
    279 	r_bench(bench_karatsuba2048, 1000);
    280 	r_bench(bench_karatsuba4096, 1000);
    281 }