bigmul.c (9397B)
1 #include <rcx/all.h> 2 #include <rcx/bench.h> 3 #include <stdio.h> 4 #include <unistd.h> 5 6 // TODO: Add (slow) fallback for __builtin_{add,sub}cl when not using clang or 7 // GCC 14. Maybe that should go in rcx. 8 9 // This power of 2 results in the lowest run-time (on the hardware on which the 10 // benchmarks were run). This must be at least 4 (see the comments in 11 // karatsuba). 12 #define KARATSUBA_THRESH 32 13 14 15 /* ----- Wide u64 math ----- */ 16 17 inline void 18 mul64(u64 *rh, u64 *rl, u64 x, u64 y) { 19 u128 r = (u128)x * (u128)y; 20 *rh = r >> 64; 21 *rl = r; 22 } 23 24 inline void 25 fmaa64(u64 *rh, u64 *rl, u64 w, u64 x, u64 y, u64 z) { 26 u64 h0, h1, h2, l; 27 mul64(&h0, &l, w, x); // h0:l = w * x 28 l = __builtin_addcl(l, y, 0, &h1); // h1:l = l + y 29 l = __builtin_addcl(l, z, 0, &h2); // h2:l = l + z 30 *rh = h0 + h1 + h2; 31 *rl = l; 32 } 33 34 35 /* ----- Big nat math ----- */ 36 37 // Precondition: m >= n 38 u64 39 add(u64 *r, u64 *x, usize m, u64 *y, usize n) { 40 u64 c = 0; 41 usize i = 0; 42 43 for (; i + 3 < n; i += 4) { 44 r[i + 0] = __builtin_addcl(x[i + 0], y[i + 0], c, &c); 45 r[i + 1] = __builtin_addcl(x[i + 1], y[i + 1], c, &c); 46 r[i + 2] = __builtin_addcl(x[i + 2], y[i + 2], c, &c); 47 r[i + 3] = __builtin_addcl(x[i + 3], y[i + 3], c, &c); 48 } 49 for (; i < n; i++) 50 r[i] = __builtin_addcl(x[i], y[i], c, &c); 51 52 for (; i + 3 < m; i += 4) { 53 r[i + 0] = __builtin_addcl(x[i + 0], 0, c, &c); 54 r[i + 1] = __builtin_addcl(x[i + 1], 0, c, &c); 55 r[i + 2] = __builtin_addcl(x[i + 2], 0, c, &c); 56 r[i + 3] = __builtin_addcl(x[i + 3], 0, c, &c); 57 } 58 for (; i < m; i++) 59 r[i] = __builtin_addcl(x[i], 0, c, &c); 60 61 return c; 62 } 63 64 // Precondition: m >= n 65 // TODO: sub is not commutative like add, so we need a "bus" operation 66 // ("sub" backwards) for when m < n. 67 u64 68 sub(u64 *r, u64 *x, usize m, u64 *y, usize n) { 69 u64 c = 0; 70 usize i = 0; 71 72 for (; i + 3 < n; i += 4) { 73 r[i + 0] = __builtin_subcl(x[i + 0], y[i + 0], c, &c); 74 r[i + 1] = __builtin_subcl(x[i + 1], y[i + 1], c, &c); 75 r[i + 2] = __builtin_subcl(x[i + 2], y[i + 2], c, &c); 76 r[i + 3] = __builtin_subcl(x[i + 3], y[i + 3], c, &c); 77 } 78 for (; i < n; i++) 79 r[i] = __builtin_subcl(x[i], y[i], c, &c); 80 81 for (; i + 3 < m; i += 4) { 82 r[i + 0] = __builtin_subcl(x[i + 0], 0, c, &c); 83 r[i + 1] = __builtin_subcl(x[i + 1], 0, c, &c); 84 r[i + 2] = __builtin_subcl(x[i + 2], 0, c, &c); 85 r[i + 3] = __builtin_subcl(x[i + 3], 0, c, &c); 86 } 87 for (; i < m; i++) 88 r[i] = __builtin_subcl(x[i], 0, c, &c); 89 90 return c; 91 } 92 93 // Precondition: capacity(r) >= m + n, capacity(x) = m, capactiy(y) = n 94 // Precondition: r is disjoint with x and y 95 void 96 fma_quadratic(u64 *r, u64 *x, usize m, u64 *y, usize n) { 97 for (usize j = 0; j < n; j++) { 98 u64 c = 0; 99 for (usize i = 0; i < m; i++) 100 fmaa64(&c, &r[i + j], x[i], y[j], r[i + j], c); 101 r[m + j] = c; 102 } 103 } 104 105 // Precondition: capacity(r) >= 2 * n, capacity(x) = capactiy(y) = n 106 // Precondition: r is disjoint with x and y 107 // Precondition: capacity(tt) >= 2 * kk, where kk := 2 * (n - n / 2 + 1) 108 void 109 karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *tt) { 110 /* We seek to multiply x and y, which each have n "words" (digits of base 111 * b := 2^64), obtaining the full 2*n word product. For this, we let 112 * l := floor(n / 2) and h := ceil(n / 2) 113 * and partition x and y into low parts xl and yl with l words and high 114 * parts xh and yh with h words, such that 115 * x = xh * b^l + xl and y = yh * b^l + yl. 116 * Then 117 * x * y = s * b^(2*l) + t * b^l + u, 118 * where 119 * s := xh * yh, t := xh * yl + xl * yh, and u := xl * yl. 120 * Thus, we could multiply x and y by recursively evaluating the four 121 * products in the definition of s, t, and u (the products involving b^i 122 * are just bit shifts). However, this would result in a time complexity 123 * of O(n^2), the same asymptotic performance as the naive "doubly-nested 124 * for-loop" algorithm. Instead, we exploit the following identity, whose 125 * significance for multiplication algorithms was first noticed by Anatoly 126 * Karatsuba in 1960: 127 * t = p * q - u - s 128 * where 129 * p := xl + xh and q := yl + yh. 130 * Computing t in its latter form saves one multiplication at the expense 131 * of a few additions/subtractions (which have O(n) time complexity), 132 * thereby reducing the time complexity to O(n^(lg 3)). 133 * 134 * Let |z| denote the number of words in a number z. For well-founded 135 * recusion, we need n to strictly decrease in each of the three recursive 136 * calls. When we compute u = xl * yl and s = xh * yh, this is true as long 137 * as n >= 2, for then |xl|,|yl| <= l < n and |yh|,|yh| <= h < n. For the 138 * computation of p * q in t, on the other hand, we have 139 * |p|,|q| = |xl + xh|,|yl + yh| (definition of p and q) 140 * <= max(l, h) + 1 (addition adds at most 1 word) 141 * = h + 1 (h >= l) 142 * = ceil(n / 2) + 1 (definition of h) 143 * and as long as n >= 4, this quantity is strictly less than n. It 144 * therefore suffices to separate the case n < 4 for the recursion 145 * basis. */ 146 147 // 1. Basis 148 if (n < KARATSUBA_THRESH) { 149 fma_quadratic(r, x, n, y, n); 150 return; 151 } 152 153 // 2. Compute l, h, and k, and their doubles ll, hh, and kk 154 // The significance of these quantities is as follows: 155 // - l is the max width of xl and yl 156 // - h is the max width of xh and yh 157 // - k is the max width of p and q 158 // - ll is the max width of u 159 // - hh is the max width of s 160 // - kk is the max width of t 161 usize l = n / 2, ll = 2 * l; 162 usize h = n - l, hh = 2 * h; 163 usize k = h + 1, kk = 2 * k; 164 165 // 3. Split x and y 166 u64 *xh = x + l, *xl = x; 167 u64 *yh = y + l, *yl = y; 168 169 // 4. Assign blocks of memory for intermediate results 170 // Note that we use the output buffer r as temporary storage for p and q. 171 // We also store u and s directly in r at the appropriate offsets (free bit 172 // shifts!), such that p and q overlap with u and s, but that's ok, because 173 // we're done with p and q by the time we calculate u and s. 174 u64 *p = r; 175 u64 *q = r + k; 176 u64 *t0 = tt; // Storage for t in this invocation 177 u64 *t1 = tt + kk; // Storage for t in future (recursive) invocations 178 u64 *u = r; 179 u64 *s = r + ll; 180 181 // 5. Arithmetic 182 p[k] = add(p, xh, h, xl, l); // p = xh + xl 183 q[k] = add(q, yh, h, yl, l); // q = yh + yl 184 karatsuba(t0, p, q, k, t1); // t = p * q 185 karatsuba(u, xl, yl, l, t1); // u = xl * yl 186 karatsuba(s, xh, yh, h, t1); // s = xh * yh 187 sub(t0, t0, kk, u, ll); // t -= u [1] 188 sub(t0, t0, kk, s, hh); // t -= s [1] 189 add(r + l, r + l, hh + l, t0, kk); // r[l..] += t [2] [3] 190 // [1]: The borrow outs are guaranteed to be 0, because t0 - u - s must 191 // be positive. 192 // [2]: The carry out is guaranteed to be 0, because the full product 193 // x * y must fit in 2 * n words. 194 // [3]: The add precondition hh + l >= kk is satisfied here as long as 195 // n >= 4, and it is, because n < 4 is the recursion basis. 196 } 197 198 // Precondition: capacity(r) >= m + n, capacity(x) = m, capactiy(y) = n 199 // Precondition: r is disjoint with x and y 200 // Precondition: m >= n 201 void 202 fma_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) { 203 // TODO: Accept capacity of r as an argument, and use excess memory for 204 // scratch, if it's big enough, instead of allocating. 205 usize firstkk = 2 * (n - n / 2 + 1); 206 u64 *scratch = r_eallocn(2 * n + 2 * firstkk, sizeof *scratch); 207 u64 *prod = scratch; 208 u64 *tt = scratch + 2 * n; 209 210 for (;;) { 211 for (; m >= n; r += n, x += n, m -= n) { 212 karatsuba(prod, x, y, n, tt); 213 add(r, prod, 2 * n, r, n); 214 } 215 216 if (m == 0) break; 217 218 if (m < KARATSUBA_THRESH) { 219 fma_quadratic(r, x, m, y, n); 220 break; 221 } 222 223 u64 *t0 = x; x = y; y = t0; // Swap x and y 224 usize t1 = m; m = n; n = t1; // Swap m and n 225 } 226 227 free(scratch); 228 } 229 230 // Precondition: capacity(r) >= m + n, capacity(x) = m, capactiy(y) = n 231 // Precondition: r is disjoint with x and y 232 void 233 mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) { 234 if (m < n) { 235 u64 *t0 = x; x = y; y = t0; // Swap x and y 236 usize t1 = m; m = n; n = t1; // Swap m and n 237 } 238 239 memset(r, 0, (m + n) * sizeof *r); 240 (n < KARATSUBA_THRESH ? fma_quadratic : fma_karatsuba)(r, x, m, y, n); 241 } 242 243 244 /* ----- Benchmarks ----- */ 245 246 u64 x[4096]; 247 u64 y[4096]; 248 u64 r[8192]; 249 250 NOINLINE void 251 bench_karatsuba(u64 l, u64 n) { 252 r_bench_start(); 253 for (u64 i = 0; i < n; i++) mul_karatsuba(r, x, l, y, l); 254 r_bench_stop(); 255 } 256 257 void bench_karatsuba16(u64 n) { bench_karatsuba(16, n); } 258 void bench_karatsuba32(u64 n) { bench_karatsuba(32, n); } 259 void bench_karatsuba64(u64 n) { bench_karatsuba(64, n); } 260 void bench_karatsuba128(u64 n) { bench_karatsuba(128, n); } 261 void bench_karatsuba256(u64 n) { bench_karatsuba(256, n); } 262 void bench_karatsuba512(u64 n) { bench_karatsuba(512, n); } 263 void bench_karatsuba1024(u64 n) { bench_karatsuba(1024, n); } 264 void bench_karatsuba2048(u64 n) { bench_karatsuba(2048, n); } 265 void bench_karatsuba4096(u64 n) { bench_karatsuba(4096, n); } 266 267 int 268 main(void) { 269 for (usize i = 0; i < LEN(x); i++) x[i] = r_prand64(); 270 for (usize i = 0; i < LEN(y); i++) y[i] = r_prand64(); 271 272 r_bench(bench_karatsuba16, 1000); 273 r_bench(bench_karatsuba32, 1000); 274 r_bench(bench_karatsuba64, 1000); 275 r_bench(bench_karatsuba128, 1000); 276 r_bench(bench_karatsuba256, 1000); 277 r_bench(bench_karatsuba512, 1000); 278 r_bench(bench_karatsuba1024, 1000); 279 r_bench(bench_karatsuba2048, 1000); 280 r_bench(bench_karatsuba4096, 1000); 281 }