LAMMP 4.1.0
Lamina High-Precision Arithmetic Library
载入中...
搜索中...
未找到
mul_fft.c
浏览该文件的文档.
1/*
2 * LAMMP - Copyright (C) 2025-2026 HJimmyK(Jericho Knox)
3 * This file is part of lammp, under the GNU LGPL v2 license.
4 * See LICENSE in the project root for the full license text.
5 */
6
7#include "../../include/lammp/impl/mparam.h"
8#include "../../include/lammp/impl/tmp_alloc.h"
9#include "../../include/lammp/lmmpn.h"
10
11// ((mp_size_t)3 << (2 * (n) - 5)) + 1 是预计算的阈值,n是对应的k值
12#define _FFT_TABLE_ENTRY(n) {((mp_size_t)3 << (2 * (n) - 5)) + 1, (n)}
13#define _FFT_TABLE_ENTRY4(n) \
14 _FFT_TABLE_ENTRY(n), _FFT_TABLE_ENTRY((n) + 1), _FFT_TABLE_ENTRY((n) + 2), _FFT_TABLE_ENTRY((n) + 3)
15
16// best_k_(next_size_(n)) = best_k_(n)
17// table[i+1][0]-1 必须是 2^(table[i][1]-LOG2_LIMB_BITS) 的整数倍
18// LOG2_LIMB_BITS:每个 limb 的比特数的2对数,为 log2(64) = 6
19static const mp_size_t lmmp_fft_table_[][2] = {
20 {0, 6},
21 {1597, 7},
22 {1655, 6},
23 {1917, 7},
24 {3447, 8},
25 {3565, 7},
26 {3831, 8},
27 {7661, 9},
28 {8145, 8},
29 {8685, 9},
30 {14289, 10},
31 {16289, 9},
32 {20433, 10},
33 {24481, 9},
34 {26577, 10},
35 {28593, 11},
36 {32545, 10},
37 {57249, 11},
38 {65313, 10},
39 {73633, 11},
40 {98081, 12},
41 {130625, 11},
42 {196385, 12},
43 {261697, 11},
44 {294689, 12},
45 {392769, 13},
46 {523265, 12},
47 {654913, 11},
48 {917281, 13},
49 {1047553, 11},
50 {1600001, 12},
51 {1834561, 14},
52 {2095105, 12},
57 {(mp_size_t)-1, 127}};
58
59typedef struct {
60 mp_ptr temp_coef; // 用于数据交换的临时系数数组
61 mp_size_t lenw; // 系数的机器字(limb)长度
62 mp_ssize_t maxdepth; // 内存栈的最大深度(已分配的层数)
63 mp_ssize_t tempdepth; // 内存栈的当前深度(正在使用的层数)
64 void* mem[16]; // 存储16层内存块的指针
65 mp_size_t memsize[16]; // 存储每层内存块的大小(以字节为单位)
67
68/**
69 * @brief 查找对于 m>=n 的模 B^m+1 FFT运算的最优k值
70 * @param n - 输入的机器字长度
71 * @return 最优的k值
72 */
74 mp_size_t k = 0;
75 while (n >= lmmp_fft_table_[k + 1][0]) ++k;
76 return lmmp_fft_table_[k][1];
77}
78
79/**
80 * @brief 计算FFT运算所需的最小规整化长度(向上取整到2^k的倍数)
81 * @param n - 原始长度
82 * @return 规整后的长度(为2^k的倍数)
83 */
88 n = (((n - 1) >> k) + 1) << k;
89 return n;
90}
91
92/**
93 * @brief Allocate from / release to the FFT scratch-memory stack.
94 * @param ms - memory-stack state (current depth, allocated depth, per-level blocks)
95 * @param size - number of bytes to allocate; size == 0 releases the current level
96 * @return on allocation: pointer to the block of the level just entered; on release: 0
 * @note A level's block is obtained from lmmp_alloc only the first time that level is
 *       entered and is reused afterwards, so a later request at the same level must ask
 *       for the same size (checked by the debug assertion). All blocks are freed once
 *       the outermost level is released.
 * @note Releasing an already-empty stack is a no-op: the free loop uses a signed index
 *       (a maxdepth of -1 converted to an unsigned type would run the loop past mem[]),
 *       and tempdepth is clamped back to -1.
 * @warning ms->mem / ms->memsize hold 16 entries; the depth is not range-checked here.
97 */
98static void* lmmp_fft_memstack_(fft_memstack* ms, mp_size_t size) {
99 if (size) {
100 if (++ms->tempdepth > ms->maxdepth) { // first visit of this level: allocate its block
101 ms->mem[++ms->maxdepth] = lmmp_alloc(size);
102 ms->memsize[ms->maxdepth] = size;
103 }
104 lmmp_debug_assert(ms->memsize[ms->tempdepth] == size);
105 return ms->mem[ms->tempdepth];
106 } else {
107 if (--ms->tempdepth < 0) { // outermost level released: free every cached block
108 for (mp_ssize_t i = 0; i <= ms->maxdepth; ++i) lmmp_free(ms->mem[i]);
109 ms->maxdepth = ms->tempdepth = -1;
110 }
111 return 0;
112 }
113}
114
115/**
116 * @brief Extract a bit field into a coefficient: [dst,lenw+1] = [(bit*)numa+bitoffset, bits]
117 * @param dst - output coefficient array (lenw+1 limbs)
118 * @param numa - pointer to the input number
119 * @param bitoffset - starting bit offset (>=0)
120 * @param bits - number of bits to extract (0 < bits <= LIMB_BITS*lenw)
121 * @param lenw - limb length of the output coefficient
122 * @warning bitoffset>=0, 0<bits<=LIMB_BITS*lenw, sep(dst,numa)
123 */
124static void lmmp_fft_extract_coef_(mp_ptr dst, mp_srcptr numa, mp_size_t bitoffset, mp_size_t bits, mp_size_t lenw) {
125 // shr = bit offset inside the first limb (0..LIMB_BITS-1)
126 // offset = index of the first limb that is read
127 mp_size_t shr = bitoffset & (LIMB_BITS - 1), offset = bitoffset / LIMB_BITS;
128
129 mp_size_t lena = (bitoffset + bits - 1) / LIMB_BITS - offset + 1, endp = (bits - 1) / LIMB_BITS; // lena: limbs of numa touched by the field; endp: index in dst of the last limb holding field bits
130
131 if (shr)
132 lmmp_shr_(dst, numa + offset, lena, shr);
133 else
134 lmmp_copy(dst, numa + offset, lena);
135
136 dst[endp] &= LIMB_MAX >> (-bits & (LIMB_BITS - 1)); // mask off everything above the field in its top limb
137
138 lmmp_zero(dst + endp + 1, lenw - endp); // clear the remaining limbs, up to and including dst[lenw]
139}
140
141/**
142 * @brief Shift a coefficient left modulo 2^n+1
143 * @param ms - memory stack struct pointer
144 * @param coef - in/out pointer to the coefficient array (pointer to pointer, so the buffers can be swapped)
145 * @param shl - number of bits to shift left (0<shl<2*n)
146 * @warning n = ms->lenw * LIMB_BITS
147 * *coef is pseudo-normalized (mod 2^n+1)
148 * ms->temp_coef has at least lenw+1 limbs
 * @note The result is written into ms->temp_coef; on return *coef points at that buffer
 *       and the previous *coef buffer becomes the new ms->temp_coef.
149 */
150static void lmmp_fft_shl_coef_(fft_memstack* ms, mp_ptr* coef, mp_size_t shl) {
151 mp_size_t l = ms->lenw; // limb length of a coefficient
152 mp_size_t w = shl / LIMB_BITS; // whole limbs to shift by
153 shl &= LIMB_BITS - 1; // remaining bit shift (0..LIMB_BITS-1)
154 mp_ptr src = *coef; // source coefficient
155 mp_ptr dst = ms->temp_coef; // destination (scratch) buffer
156 mp_limb_t cc, rd; // carry-type temporaries (cc = carry, rd = read)
157
158 if (w >= l) { // shift of at least n bits: reduce by l limbs; the low source limbs are stored complemented
159 w -= l;
160 if (shl) {
161 lmmp_shl_(dst, src + l - w, w + 1, shl);
162 rd = dst[w];
163 cc = lmmp_shlnot_(dst + w, src, l - w, shl);
164 } else {
165 if (w)
166 lmmp_copy(dst, src + l - w, w);
167 rd = src[l];
168 lmmp_not_(dst + w, src, l - w);
169 cc = 0;
170 }
171 dst[l] = 0;
172 ++cc;
173 lmmp_inc_1(dst, cc);
174
175 if (++rd == 0)
176 lmmp_inc(dst + w + 1);
177 else
178 lmmp_inc_1(dst + w, rd);
179 } else { // shift of fewer than n bits: the wrapped-around high source limbs are stored complemented
180 if (shl) {
181 lmmp_shlnot_(dst, src + l - w, w + 1, shl);
182 rd = ~dst[w];
183 cc = lmmp_shl_(dst + w, src, l - w, shl);
184 } else {
185 if (w)
186 lmmp_not_(dst, src + l - w, w);
187 rd = src[l];
188
189 lmmp_copy(dst + w, src, l - w);
190 cc = 0;
191 }
192 dst[l] = 2;
193 lmmp_inc_1(dst, 3);
194 lmmp_dec_1(dst, cc);
195
196 if (++rd == 0)
197 lmmp_dec(dst + w + 1);
198 else
199 lmmp_dec_1(dst + w, rd);
200
201 cc = dst[l]; // fold the top limb back into the low end
202 dst[l] = dst[0] < cc;
203 lmmp_dec_1(dst, cc - dst[l]);
204 }
205
206 ms->temp_coef = src; // the old coefficient buffer becomes the scratch buffer
207 *coef = dst; // the result lives in the former scratch buffer
208}
209
210/**
211 * @brief Shift a coefficient right modulo 2^n+1
212 * A right shift by shr bits is performed as a left shift by (2n - shr) bits (cyclic property mod 2^n+1)
213 * @param ms - memory stack struct pointer
214 * @param coef - in/out pointer to the coefficient array
215 * @param shr - number of bits to shift right (0 < shr < 2*n)
216 */
217static void lmmp_fft_shr_coef_(fft_memstack* ms, mp_ptr* coef, mp_size_t shr) {
218 lmmp_fft_shl_coef_(ms, coef, 2 * ms->lenw * LIMB_BITS - shr); // n = ms->lenw * LIMB_BITS
219}
220
221/**
222 * @brief Forward-FFT butterfly operation
223 * (a,b) = (a + b, (a-b) << w ) mod 2^n+1
224 * a=[coef[0],ms->lenw+1], b=[coef[wing],ms->lenw+1], n=ms->lenw * LIMB_BITS
225 * @param ms - memory stack struct pointer
226 * @param coef - array of coefficient pointers (coef[0]=a, coef[wing]=b)
227 * @param wing - index of b
228 * @param w - number of bits to shift left (0<=w<n)
229 * @warning n = ms->lenw * LIMB_BITS
230 * a,b are pseudo-normalized (mod 2^n+1)
231 * ms->temp_coef has at least lenw + 1 limbs
 * @note a is updated in place; (a-b)<<w is built in ms->temp_coef, which then replaces
 *       coef[wing], and the old b buffer becomes the new ms->temp_coef.
 * @note Carry/borrow values are shifted as mp_limb_t: left-shifting a negative signed
 *       value, or shifting a signed 1/2 into the sign bit, is undefined behaviour in C.
 *       The unsigned shifts produce the bit patterns the signed ones were meant to give.
232 */
233static void lmmp_fft_bfy_(fft_memstack* ms, mp_ptr* coef, mp_size_t wing, mp_size_t w) {
234 mp_ptr numa = coef[0]; // coefficient a
235 mp_ptr numb = coef[wing]; // coefficient b
236 mp_ptr numc = ms->temp_coef; // scratch buffer (receives (a-b)<<w)
237 mp_size_t shl = w & (LIMB_BITS - 1); // bit-level shift amount
238 w /= LIMB_BITS; // limb-level shift amount
239 mp_size_t l = ms->lenw; // coefficient length in limbs
240
241 mp_slimb_t acyo = 0, scyo = 0, ch;
242 mp_limb_t shlcyo = 0, chp = 0, chn = 0;
243
244 for (mp_size_t off = 0; off < l - w; off += PART_SIZE) { // low l-w limbs: a-b goes to numc[w..l), a+b stays in numa
245 mp_size_t cursize = LMMP_MIN(l - w - off, PART_SIZE);
246 scyo = lmmp_sub_nc_(numc + w + off, numa + off, numb + off, cursize, scyo);
247 acyo = lmmp_add_nc_(numa + off, numa + off, numb + off, cursize, acyo);
248 if (shl)
249 shlcyo = lmmp_shl_c_(numc + w + off, numc + w + off, cursize, shl, shlcyo);
250 }
251
252 ch = (mp_slimb_t)(shlcyo - ((mp_limb_t)scyo << shl)); // net overflow of the first pass: shifted-out bits minus the shifted borrow
253 if (ch > 0)
254 chp = ch;
255 else
256 chn = -ch;
257
258 scyo = 0;
259 shlcyo = 0;
260
261 for (mp_size_t off = l - w; off < l; off += PART_SIZE) { // top w limbs wrap around to numc[0..w) with the operands swapped (b-a)
262 mp_size_t cursize = LMMP_MIN(l - off, PART_SIZE);
263 scyo = lmmp_sub_nc_(numc + off - (l - w), numb + off, numa + off, cursize, scyo);
264 acyo = lmmp_add_nc_(numa + off, numa + off, numb + off, cursize, acyo);
265 if (shl)
266 shlcyo = lmmp_shl_c_(numc + off - (l - w), numc + off - (l - w), cursize, shl, shlcyo);
267 }
268
269 numc[w] += shlcyo; // bits shifted out of the wrapped part are added into numc[w]
270 scyo = -scyo + numb[l] - numa[l]; // adjust the borrow (top limbs included)
271 acyo += numa[l] + numb[l]; // adjust the carry (top limbs included)
272
273 numa[l] = numa[0] < (mp_limb_t)(acyo); // fold the carry of a+b back into the low end
274 lmmp_dec_1(numa, acyo - numa[l]);
275
276 numc[l] = 1;
277 ++chn;
278 if (scyo > 0)
279 lmmp_inc_1(numc + w, (mp_limb_t)scyo << shl);
280 else if (scyo < 0) {
281 if (scyo == -2 && shl == LIMB_BITS - 1) // 2 << (LIMB_BITS-1) does not fit in a limb: borrow from the next limb instead
282 lmmp_dec(numc + w + 1);
283 else
284 lmmp_dec_1(numc + w, (mp_limb_t)(-scyo) << shl);
285 }
286 chp += numc[l];
287
288 if (chn >= chp) { // combine the positive (chp) and negative (chn) corrections at the low end
289 numc[l] = 0;
290 lmmp_inc_1(numc, chn - chp);
291 } else {
292 chp -= chn;
293 numc[l] = numc[0] < chp;
294 lmmp_dec_1(numc, chp - numc[l]);
295 }
296
297 coef[wing] = numc; // the result replaces b
298 ms->temp_coef = numb; // the old b buffer becomes the scratch buffer
299}
300
301/**
302 * @brief Inverse-FFT butterfly operation
303 * (a,b) = (a+(b>>w), a-(b>>w)) mod 2^n+1
304 * a=[coef[0],ms->lenw+1], b=[coef[wing],ms->lenw+1], n=ms->lenw * LIMB_BITS
305 * @param ms - memory stack struct pointer
306 * @param coef - array of coefficient pointers (coef[0]=a, coef[wing]=b)
307 * @param wing - index of b
308 * @param w - number of bits by which b is shifted right (0<=w<n)
309 * @warning n = ms->lenw * LIMB_BITS
310 * a,b are pseudo-normalized (mod 2^n+1)
311 * ms->temp_coef has at least lenw + 1 limbs
 * @note b is shifted in place and its buffer becomes the new ms->temp_coef; a is updated
 *       in place; a-(b>>w) is built in the old ms->temp_coef, which replaces coef[wing].
312 */
313static void lmmp_ifft_bfy_(fft_memstack* ms, mp_ptr* coef, mp_size_t wing, mp_size_t w) {
314 mp_ptr numa = coef[0]; // coefficient a
315 mp_ptr numb = coef[wing]; // coefficient b
316 mp_ptr numc = ms->temp_coef; // scratch buffer (receives a-(b>>w))
317 mp_size_t shr = w & (LIMB_BITS - 1); // bit-level shift amount
318 w /= LIMB_BITS; // limb-level shift amount
319 mp_size_t l = ms->lenw; // coefficient length in limbs
320
321 mp_slimb_t bcyo = 0, acyo = 0, ah;
322 mp_limb_t shrcyo = shr ? numb[0] << (LIMB_BITS - shr) : 0; // bits shifted out of the bottom of b
323
324 for (mp_size_t off = l - w; off < l; off += PART_SIZE) { // low w limbs of b wrap to the top: numc gets a+b there, numa gets a-b
325 mp_size_t cursize = LMMP_MIN(l - off, PART_SIZE);
326 if (shr)
327 lmmp_shr_c_(numb + off - (l - w), numb + off - (l - w), cursize, shr,
328 numb[off - (l - w) + cursize] << (LIMB_BITS - shr));
329 bcyo = lmmp_add_nc_(numc + off, numa + off, numb + off - (l - w), cursize, bcyo);
330 acyo = lmmp_sub_nc_(numa + off, numa + off, numb + off - (l - w), cursize, acyo);
331 }
332
333 for (mp_size_t off = 0; off < l - w; off += PART_SIZE) { // remaining limbs of b: numc gets a-b, numa gets a+b
334 mp_size_t cursize = LMMP_MIN(l - w - off, PART_SIZE);
335 if (shr)
336 lmmp_shr_c_(numb + w + off, numb + w + off, cursize, shr, numb[off + w + cursize] << (LIMB_BITS - shr));
337 bcyo = lmmp_sub_nc_(numc + off, numa + off, numb + w + off, cursize, bcyo);
338 acyo = lmmp_add_nc_(numa + off, numa + off, numb + w + off, cursize, acyo);
339 }
340
341 acyo += numb[l] >> shr; // account for the top limb of b
342 bcyo = -bcyo - (numb[l] >> shr);
343
344 acyo -= numa[l - w - 1] < shrcyo; // apply the bits shifted out of b at limb l-w-1 (subtracted from a, added to numc)
345 numa[l - w - 1] -= shrcyo;
346 numc[l - w - 1] += shrcyo;
347 bcyo += numc[l - w - 1] < shrcyo;
348
349 ah = numa[l]; // original top limb of a, reused for numc below
350
351 numa[l] += 1;
352 if (w == 0)
353 numa[l] += acyo;
354 else {
355 if (acyo < 0)
356 lmmp_dec(numa + l - w);
357 else
358 lmmp_inc_1(numa + l - w, acyo);
359 }
360 acyo = numa[l] - 1;
361 if (acyo < 0) {
362 numa[l] = 0;
363 lmmp_inc(numa);
364 } else {
365 numa[l] = numa[0] < (mp_limb_t)acyo; // fold the top limb of a back into the low end
366 lmmp_dec_1(numa, acyo - numa[l]);
367 }
368
369 numc[l] = ah + 2;
370 if (w == 0)
371 numc[l] += bcyo;
372 else {
373 if (bcyo > 0)
374 lmmp_inc(numc + l - w);
375 else
376 lmmp_dec_1(numc + l - w, -bcyo);
377 }
378 bcyo = numc[l] - 2;
379 if (bcyo <= 0) {
380 numc[l] = 0;
381 lmmp_inc_1(numc, -bcyo);
382 } else {
383 numc[l] = numc[0] < (mp_limb_t)bcyo; // fold the top limb of numc back into the low end
384 lmmp_dec_1(numc, bcyo - numc[l]);
385 }
386
387 coef[wing] = numc; // the result replaces b
388 ms->temp_coef = numb; // the old b buffer becomes the scratch buffer
389}
390
391/**
392 * @brief FFT递归函数
393 * @param ms - 内存栈结构体指针
394 * @param coef - 系数数组指针数组
395 * @param dis - 索引步长
396 * @param k - FFT层数(递归深度)
397 * @param w - 每次蝶形运算的移位基数
398 * @param w0 - 初始移位偏移
399 */
401 if (k == 1)
402 lmmp_fft_bfy_(ms, coef, dis, w0);
403 else {
404 k -= 2;
405 mp_size_t Kq = dis << k;
406 for (mp_size_t i = 0; i < Kq; i += dis) {
407 lmmp_fft_bfy_(ms, coef + i, 2 * Kq, i * w + w0);
408 lmmp_fft_bfy_(ms, coef + i + Kq, 2 * Kq, (i + Kq) * w + w0);
409 lmmp_fft_bfy_(ms, coef + i, Kq, 2 * (i * w + w0));
410 lmmp_fft_bfy_(ms, coef + i + Kq * 2, Kq, 2 * (i * w + w0));
411 }
412 if (k > 0) {
413 lmmp_fft_b1_(ms, coef, dis, k, 4 * w, 4 * w0);
414 lmmp_fft_b1_(ms, coef + Kq, dis, k, 4 * w, 4 * w0);
415 lmmp_fft_b1_(ms, coef + Kq * 2, dis, k, 4 * w, 4 * w0);
416 lmmp_fft_b1_(ms, coef + Kq * 3, dis, k, 4 * w, 4 * w0);
417 }
418 }
419}
420
/**
 * @brief Recursive forward transform of the 2^k coefficients coef[0 .. 2^k) (stride 1)
 *        For k == 1 a single butterfly on coef[0], coef[1] with shift 0 is performed.
 *        Otherwise two butterfly layers are applied across the four quarters of the
 *        range, then each quarter is transformed recursively with k-2 and 4*w.
 * @param ms - memory stack struct pointer
 * @param coef - array of coefficient pointers
 * @param k - number of FFT levels to perform (k >= 1)
 * @param w - shift base in bits used to form the butterfly shift amounts
 */
421static void lmmp_fft_4_(fft_memstack* ms, mp_ptr* coef, mp_size_t k, mp_size_t w) {
422 if (k == 1)
423 lmmp_fft_bfy_(ms, coef, 1, 0);
424 else {
425 k -= 2;
426 mp_size_t Kq = ((mp_size_t)1) << k; // quarter length
427 for (mp_size_t i = 0; i < Kq; ++i) {
428 lmmp_fft_bfy_(ms, coef + i, Kq * 2, i * w); // first layer: pairs half a range apart
429 lmmp_fft_bfy_(ms, coef + i + Kq, Kq * 2, (i + Kq) * w);
430 lmmp_fft_bfy_(ms, coef + i, Kq, 2 * i * w); // second layer: pairs a quarter apart
431 lmmp_fft_bfy_(ms, coef + i + 2 * Kq, Kq, 2 * i * w);
432 }
433 if (k > 0) {
434 lmmp_fft_4_(ms, coef, k, w * 4);
435 lmmp_fft_4_(ms, coef + Kq, k, w * 4);
436 lmmp_fft_4_(ms, coef + 2 * Kq, k, w * 4);
437 lmmp_fft_4_(ms, coef + 3 * Kq, k, w * 4);
438 }
439 }
440}
441
/**
 * @brief Forward transform of 2^k coefficients, performed in two passes
 *        Pass 1 runs lmmp_fft_b1_ for k1 = k/2 levels on each of the Kp sub-sequences
 *        that start at coef+i with stride Kp (initial shift offset i*w).
 *        Pass 2 runs lmmp_fft_4_ for the remaining k-k1 levels on each of the Kq
 *        contiguous groups of Kp coefficients, with shift base w*Kq.
 * @param ms - memory stack struct pointer
 * @param coef - array of coefficient pointers
 * @param k - total number of FFT levels
 * @param w - shift base in bits
 */
442static void lmmp_fft_(fft_memstack* ms, mp_ptr* coef, mp_size_t k, mp_size_t w) {
443 mp_size_t k1 = k >> 1; // k1 = k/2 (levels handled by the strided pass)
444 k -= k1; // k = remaining levels (handled by the contiguous pass)
445 mp_size_t Kp = ((mp_size_t)1) << k; // Kp = 2^k
446 mp_size_t Kq = ((mp_size_t)1) << k1; // Kq = 2^k1
447
448 for (mp_size_t i = 0; i < Kp; ++i) lmmp_fft_b1_(ms, coef + i, Kp, k1, w, i * w);
449
450 for (mp_size_t i = 0; i < Kq; ++i) lmmp_fft_4_(ms, coef + Kp * i, k, w * Kq);
451}
452
454 if (k == 1)
455 lmmp_ifft_bfy_(ms, coef, dis, w0);
456 else {
457 k -= 2;
458 mp_size_t Kq = dis << k;
459 if (k > 0) {
460 lmmp_ifft_b1_(ms, coef, dis, k, 4 * w, 4 * w0);
461 lmmp_ifft_b1_(ms, coef + Kq, dis, k, 4 * w, 4 * w0);
462 lmmp_ifft_b1_(ms, coef + Kq * 2, dis, k, 4 * w, 4 * w0);
463 lmmp_ifft_b1_(ms, coef + Kq * 3, dis, k, 4 * w, 4 * w0);
464 }
465 for (mp_size_t i = 0; i < Kq; i += dis) {
466 lmmp_ifft_bfy_(ms, coef + i, Kq, 2 * (i * w + w0));
467 lmmp_ifft_bfy_(ms, coef + i + Kq * 2, Kq, 2 * (i * w + w0));
468 lmmp_ifft_bfy_(ms, coef + i, 2 * Kq, i * w + w0);
469 lmmp_ifft_bfy_(ms, coef + i + Kq, 2 * Kq, (i + Kq) * w + w0);
470 }
471 }
472}
473
475 if (k == 1)
476 lmmp_ifft_bfy_(ms, coef, 1, 0);
477 else {
478 k -= 2;
479 mp_size_t Kq = ((mp_size_t)1) << k;
480 if (k > 0) {
481 lmmp_ifft_4_(ms, coef, k, w * 4);
482 lmmp_ifft_4_(ms, coef + Kq, k, w * 4);
483 lmmp_ifft_4_(ms, coef + 2 * Kq, k, w * 4);
484 lmmp_ifft_4_(ms, coef + 3 * Kq, k, w * 4);
485 }
486 for (mp_size_t i = 0; i < Kq; ++i) {
487 lmmp_ifft_bfy_(ms, coef + i, Kq, 2 * i * w);
488 lmmp_ifft_bfy_(ms, coef + i + 2 * Kq, Kq, 2 * i * w);
489 lmmp_ifft_bfy_(ms, coef + i, Kq * 2, i * w);
490 lmmp_ifft_bfy_(ms, coef + i + Kq, Kq * 2, (i + Kq) * w);
491 }
492 }
493}
494
/**
 * @brief Inverse transform of 2^k coefficients, performed in two passes
 *        Pass 1 runs lmmp_ifft_4_ for k-k1 levels on each of the Kq contiguous groups
 *        of Kp coefficients, with shift base w*Kq.
 *        Pass 2 runs lmmp_ifft_b1_ for k1 = k/2 levels on each of the Kp sub-sequences
 *        that start at coef+i with stride Kp (initial shift offset i*w).
 *        The two passes appear in the opposite order to those of lmmp_fft_.
 * @param ms - memory stack struct pointer
 * @param coef - array of coefficient pointers
 * @param k - total number of FFT levels
 * @param w - shift base in bits
 */
495static void lmmp_ifft_(fft_memstack* ms, mp_ptr* coef, mp_size_t k, mp_size_t w) {
496 mp_size_t k1 = k >> 1; // k1 = k/2 (levels handled by the strided pass)
497 k -= k1; // k = remaining levels (handled by the contiguous pass)
498 mp_size_t Kp = ((mp_size_t)1) << k; // Kp = 2^k
499 mp_size_t Kq = ((mp_size_t)1) << k1; // Kq = 2^k1
500
501 for (mp_size_t i = 0; i < Kq; ++i) lmmp_ifft_4_(ms, coef + Kp * i, k, w * Kq);
502
503 for (mp_size_t i = 0; i < Kp; ++i) lmmp_ifft_b1_(ms, coef + i, Kp, k1, w, i * w);
504}
505
506/**
507 * @brief 费马变换 模 B^n+1 乘法的结果合并
508 * @param ms - 内存栈结构体指针
509 * @param dst - 输出结果数组
510 * @param pfca - FFT系数数组指针数组
511 * @param K - FFT块数(2^k)
512 * @param k - FFT层数
513 * @param n - 系数总比特数
514 * @param M - 每个块的比特数
515 * @param rn - 结果长度(机器字)
516 */
518 fft_memstack* ms,
519 mp_ptr dst,
520 mp_ptr* pfca,
521 mp_size_t K,
522 mp_size_t k,
523 mp_size_t n,
524 mp_size_t M,
525 mp_size_t rn
526) {
527 mp_size_t rhead = 0, nlen = ms->lenw + 1;
528 mp_slimb_t borrow = 0, maxc = 0;
529
530 for (mp_size_t i = 0; i < K; ++i) {
531 lmmp_fft_shr_coef_(ms, pfca + i, (i * n >> k) + k);
532 mp_ptr nums = pfca[i];
533
534 if (nums[nlen - 1]) {
535 lmmp_dec(nums);
536 --nums[nlen - 1];
537 }
538 if (nums[nlen - 1] == 0 && nums[nlen - 2] >> (LIMB_BITS - 1)) {
539 lmmp_dec(nums);
540 --nums[nlen - 1];
541 }
542
543 if (borrow) {
544 mp_size_t brshift = borrow - 1 + n - M;
545 mp_size_t bshl = brshift & (LIMB_BITS - 1);
546 brshift /= LIMB_BITS;
547 --nums[nlen - 1];
548 lmmp_dec_1(nums + brshift, (mp_limb_t)1 << bshl);
549 ++nums[nlen - 1];
550 }
551 borrow = -nums[nlen - 1];
552 nums[nlen - 1] = 0;
553
554 mp_size_t roffset = i * M;
555 mp_size_t shl = roffset & (LIMB_BITS - 1);
556 roffset /= LIMB_BITS;
557
558 if (shl)
559 lmmp_shl_(nums, nums, nlen, shl);
560
561 if (i == 0) {
562 lmmp_copy(dst, nums, nlen);
563 rhead = nlen;
564 } else if (roffset + nlen <= rn) {
565 lmmp_add_(dst + roffset, nums, nlen, dst + roffset, rhead - roffset);
566 rhead = roffset + nlen;
567 } else {
568 maxc += lmmp_add_(dst + roffset, nums, rn - roffset, dst + roffset, rhead - roffset);
569 maxc -= lmmp_sub_(dst, dst, rn, nums + rn - roffset, nlen + roffset - rn);
570 rhead = rn;
571 }
572 }
573
574 if (borrow) {
575 mp_size_t cyshift = borrow - 1 + n - M;
576 mp_size_t cshl = cyshift & (LIMB_BITS - 1);
577 cyshift /= LIMB_BITS;
578 maxc += lmmp_add_1_(dst + cyshift, dst + cyshift, rn - cyshift, (mp_limb_t)1 << cshl);
579 }
580
581 if (maxc > 0) {
582 dst[rn] = dst[0] < (mp_limb_t)maxc;
583 lmmp_dec_1(dst, maxc - dst[rn]);
584 } else {
585 dst[rn] = 0;
586 lmmp_inc_1(dst, -maxc);
587 }
588}
589
590/**
591 * @brief 费马变换乘法递归函数(核心乘法逻辑)
592 * @param ms - 内存栈结构体指针
593 * @param pc1 - 第一个数的FFT系数数组指针数组
594 * @param pc2 - 第二个数的FFT系数数组指针数组
595 * @param K0 - FFT块数
596 * @warning K0>0
597 * 所有系数均已伪归一化(mod B^lenw+1)
598 * nsqr=1表示乘法,nsqr=0表示平方
599 */
601 int nsqr = pc1 != pc2; // 判断是否为平方运算
602 mp_ptr push_temp_coef = ms->temp_coef;
603 mp_size_t rn = ms->lenw; // 当前系数长度
604
605 // 小于阈值则不使用FFT
606 if (rn < MUL_FFT_MODF_THRESHOLD) {
607 mp_ptr temp_mul = (mp_ptr)lmmp_fft_memstack_(ms, (rn + 1) * 2 * LIMB_BYTES);
608 for (mp_size_t i = 0; i < K0; ++i) {
609 if (nsqr)
610 lmmp_mul_n_(temp_mul, pc1[i], pc2[i], rn + 1);
611 else
612 lmmp_sqr_(temp_mul, pc1[i], rn + 1);
613
614 // 模 B^rn+1 归一化:temp_mul - temp_mul[rn ...]
615 mp_limb_t maxc = lmmp_sub_n_(pc1[i], temp_mul, temp_mul + rn, rn) + temp_mul[rn * 2];
616 pc1[i][rn] = 0;
617 lmmp_inc_1(pc1[i], maxc);
618 }
619 lmmp_fft_memstack_(ms, 0);
620 } else {
621 mp_size_t N = rn * LIMB_BITS; // 总比特数
622 mp_size_t k = lmmp_fft_best_k_(rn); // 最优FFT层数
623 mp_size_t K = ((mp_size_t)1) << k; // FFT块数(2^k)
624 lmmp_debug_assert(!(N & (K - 1)));
625 mp_size_t M = N >> k; // 每个块的比特数(N/K)
626 mp_size_t n = 2 * M + k + 2; // 扩展系数长度(保证归一化)
627
628 // 规整n:必须是LIMB_BITS和K的整数倍
629 n = (n + LIMB_BITS - 1) & (-LIMB_BITS); // 向上取整到LIMB_BITS的倍数
630 n = (((n - 1) >> k) + 1) << k; // 向上取整到K的倍数
631
632 ms->lenw = n / LIMB_BITS;
633 mp_size_t nlen = ms->lenw + 1;
634
635 ms->temp_coef = (mp_ptr)lmmp_fft_memstack_(ms, (((nlen + 1) << (k + nsqr)) + nlen) * LIMB_BYTES);
636 mp_ptr *pfca = (mp_ptr*)(ms->temp_coef + nlen), *pfcb = pfca;
637 for (mp_size_t i = 0; i < K; ++i) pfca[i] = (mp_ptr)(pfca + K) + i * nlen;
638 if (nsqr) {
639 pfcb += (nlen + 1) << k;
640 for (mp_size_t i = 0; i < K; ++i) pfcb[i] = (mp_ptr)(pfcb + K) + i * nlen;
641 }
642
643 for (mp_size_t j = 0; j < K0; ++j) {
644 mp_ptr numa = pc1[j];
645 mp_ptr numb = pc2[j];
646
647 for (mp_size_t i = 0; i < K; ++i) {
648 lmmp_fft_extract_coef_(pfca[i], numa, M * i, M + (i == K - 1), ms->lenw);
649 if (i > 0)
650 lmmp_fft_shl_coef_(ms, pfca + i, i * n >> k);
651 }
652 lmmp_fft_(ms, pfca, k, n >> (k - 1));
653
654 if (nsqr) {
655 for (mp_size_t i = 0; i < K; ++i) {
656 lmmp_fft_extract_coef_(pfcb[i], numb, M * i, M + (i == K - 1), ms->lenw);
657 if (i > 0)
658 lmmp_fft_shl_coef_(ms, pfcb + i, i * n >> k);
659 }
660 lmmp_fft_(ms, pfcb, k, n >> (k - 1));
661 }
662
663 // dot product
664 lmmp_mul_fermat_recurse_(ms, pfca, pfcb, K);
665
666 lmmp_ifft_(ms, pfca, k, n >> (k - 1));
667
668 lmmp_mul_fermat_recombine_(ms, numa, pfca, K, k, n, M, rn);
669 }
670 lmmp_fft_memstack_(ms, 0);
671 }
672
673 ms->temp_coef = push_temp_coef;
674 ms->lenw = rn;
675}
676
678 int nsqr = numa != numb || na != nb; // 判断是否为平方运算
679 mp_size_t N = rn * LIMB_BITS; // 结果总比特数
680 mp_size_t k = lmmp_fft_best_k_(rn); // 最优FFT层数
681 mp_size_t K = ((mp_size_t)1) << k; // FFT块数(2^k)
682 lmmp_debug_assert(!(N & (K - 1)));
683 mp_size_t M = N >> k; // 每个块的比特数
684 mp_size_t n = 2 * M + k + 2; // 扩展系数长度
685
686 n = (n + LIMB_BITS - 1) & (-LIMB_BITS);
687 n = (((n - 1) >> k) + 1) << k;
688
689 // 初始化内存栈
690 fft_memstack msr;
691 msr.maxdepth = -1;
692 msr.tempdepth = -1;
693 msr.lenw = n / LIMB_BITS; // 系数长度(机器字)
694 mp_size_t nlen = msr.lenw + 1; // 系数总长度
695
696 msr.temp_coef = (mp_ptr)lmmp_fft_memstack_(&msr, (((nlen + 1) << (k + nsqr)) + nlen) * LIMB_BYTES);
697
698 mp_ptr *pfca = (mp_ptr*)(msr.temp_coef + nlen), *pfcb = pfca;
699 mp_size_t narest = na * LIMB_BITS, nbrest = nb * LIMB_BITS;
700
701 for (mp_size_t i = 0; i < K; ++i) {
702 mp_size_t coeflen;
703 pfca[i] = (mp_ptr)(pfca + K) + i * nlen;
704 if (narest > 0) {
705 coeflen = M + (i == K - 1);
706 coeflen = LMMP_MIN(narest, coeflen);
707 narest -= coeflen;
708 lmmp_fft_extract_coef_(pfca[i], numa, M * i, coeflen, msr.lenw);
709 // 非第一个块:左移补偿
710 if (i > 0)
711 lmmp_fft_shl_coef_(&msr, pfca + i, i * n >> k);
712 } else {
713 lmmp_zero(pfca[i], nlen);
714 }
715 }
716 lmmp_fft_(&msr, pfca, k, n >> (k - 1));
717
718 if (nsqr) {
719 pfcb += (nlen + 1) << k;
720 for (mp_size_t i = 0; i < K; ++i) {
721 mp_size_t coeflen;
722 pfcb[i] = (mp_ptr)(pfcb + K) + i * nlen;
723 if (nbrest > 0) {
724 coeflen = M + (i == K - 1);
725 coeflen = LMMP_MIN(nbrest, coeflen);
726 nbrest -= coeflen;
727 lmmp_fft_extract_coef_(pfcb[i], numb, M * i, coeflen, msr.lenw);
728 if (i > 0)
729 lmmp_fft_shl_coef_(&msr, pfcb + i, i * n >> k);
730 } else {
731 lmmp_zero(pfcb[i], nlen);
732 }
733 }
734 lmmp_fft_(&msr, pfcb, k, n >> (k - 1));
735 }
736
737 lmmp_mul_fermat_recurse_(&msr, pfca, pfcb, K);
738
739 lmmp_ifft_(&msr, pfca, k, n >> (k - 1));
740
741 lmmp_mul_fermat_recombine_(&msr, dst, pfca, K, k, n, M, rn);
742
743 // 处理模 B^rn+1 的溢出
744 if (dst[rn] && !lmmp_zero_q_(dst, rn)) {
745 dst[rn] = 0;
746 lmmp_dec(dst);
747 }
748
749 lmmp_fft_memstack_(&msr, 0);
750}
751
753 int nsqr = numa != numb || na != nb; // 判断是否为平方运算
754 mp_size_t N = rn * LIMB_BITS; // 结果总比特数
755 mp_size_t k = lmmp_fft_best_k_(rn); // 最优FFT层数
756 mp_size_t K = ((mp_size_t)1) << k; // FFT块数(2^k)
757 // 断言:N必须是K的整数倍
758 lmmp_debug_assert(!(N & (K - 1)));
759 mp_size_t M = N >> k; // 每个块的比特数
760 mp_size_t n = 2 * M + k; // 扩展系数长度(梅森数比费马数少2)
761
762 // 规整n:必须是LIMB_BITS和K/2的整数倍
763 n = (n + LIMB_BITS - 1) & (-LIMB_BITS);
764 n = (((n - 1) >> (k - 1)) + 1) << (k - 1);
765
766 // 初始化内存栈
767 fft_memstack msr;
768 msr.maxdepth = -1;
769 msr.tempdepth = -1;
770 msr.lenw = n / LIMB_BITS; // 系数长度(机器字)
771 mp_size_t nlen = msr.lenw + 1; // 系数总长度
772
773 msr.temp_coef = (mp_ptr)lmmp_fft_memstack_(&msr, (((nlen + 1) << (k + nsqr)) + nlen) * LIMB_BYTES);
774
775 mp_ptr *pfca = (mp_ptr*)(msr.temp_coef + nlen), *pfcb = pfca;
776 mp_size_t narest = na * LIMB_BITS, nbrest = nb * LIMB_BITS;
777
778 for (mp_size_t i = 0; i < K; ++i) {
779 mp_size_t coeflen;
780 pfca[i] = (mp_ptr)(pfca + K) + i * nlen;
781 if (narest > 0) {
782 coeflen = LMMP_MIN(narest, M);
783 narest -= coeflen;
784 lmmp_fft_extract_coef_(pfca[i], numa, M * i, coeflen, msr.lenw);
785 } else {
786 lmmp_zero(pfca[i], nlen);
787 }
788 }
789 lmmp_fft_(&msr, pfca, k, n >> (k - 1));
790
791 if (nsqr) {
792 pfcb += (nlen + 1) << k;
793 for (mp_size_t i = 0; i < K; ++i) {
794 mp_size_t coeflen;
795 pfcb[i] = (mp_ptr)(pfcb + K) + i * nlen;
796 if (nbrest > 0) {
797 coeflen = LMMP_MIN(nbrest, M);
798 nbrest -= coeflen;
799 lmmp_fft_extract_coef_(pfcb[i], numb, M * i, coeflen, msr.lenw);
800 } else {
801 lmmp_zero(pfcb[i], nlen);
802 }
803 }
804 lmmp_fft_(&msr, pfcb, k, n >> (k - 1));
805 }
806
807 lmmp_mul_fermat_recurse_(&msr, pfca, pfcb, K);
808
809 lmmp_ifft_(&msr, pfca, k, n >> (k - 1));
810
811 mp_size_t rhead = 0, maxc = 0;
812 for (mp_size_t i = 0; i < K; ++i) {
813 lmmp_fft_shr_coef_(&msr, pfca + i, k);
814 mp_ptr nums = pfca[i];
815
816 if (nums[nlen - 1]) {
817 lmmp_dec(nums);
818 lmmp_debug_assert(nums[nlen - 1] == 1);
819 nums[nlen - 1] = 0;
820 }
821
822 mp_size_t roffset = i * M;
823 mp_size_t shl = roffset & (LIMB_BITS - 1);
824 roffset /= LIMB_BITS;
825
826 if (shl)
827 lmmp_shl_(nums, nums, nlen, shl);
828
829 if (i == 0) {
830 lmmp_copy(dst, nums, nlen);
831 rhead = nlen;
832 } else if (roffset + nlen <= rn) {
833 lmmp_add_(dst + roffset, nums, nlen, dst + roffset, rhead - roffset);
834 rhead = roffset + nlen;
835 } else {
836 maxc += lmmp_add_(dst + roffset, nums, rn - roffset, dst + roffset, rhead - roffset);
837 maxc += lmmp_add_(dst, dst, rn, nums + rn - roffset, nlen + roffset - rn);
838 rhead = rn;
839 }
840 }
841
842 if (!lmmp_add_1_(dst, dst, rn, 1 + maxc))
843 lmmp_dec(dst);
844
845 lmmp_fft_memstack_(&msr, 0);
846}
847
848typedef struct {
853 int fermat_flag; // 是否分配了费马内存
854 int mersenne_flag; // 是否分配了梅森内存
855} fft_cache;
856
857static inline void lmmp_mul_fft_cache_free_(fft_cache* GH) {
858 if (GH->fermat_flag)
860 if (GH->mersenne_flag)
862}
863
865 mp_ptr dst,
866 mp_size_t rn,
867 mp_srcptr numa,
868 mp_size_t na,
869 mp_srcptr numb,
870 mp_size_t nb,
871 fft_cache* GH
872) {
873 int nsqr = numa != numb || na != nb; // 1为非平方,0为平方
874 lmmp_assert(nsqr);
875 mp_size_t N = rn * LIMB_BITS; // 结果总比特数
876 mp_size_t k = lmmp_fft_best_k_(rn); // 最优FFT层数
877 mp_size_t K = ((mp_size_t)1) << k; // FFT块数(2^k)
878 lmmp_debug_assert(!(N & (K - 1)));
879 mp_size_t M = N >> k; // 每个块的比特数
880 mp_size_t n = 2 * M + k + 2; // 扩展系数长度
881
882 n = (n + LIMB_BITS - 1) & (-LIMB_BITS);
883 n = (((n - 1) >> k) + 1) << k;
884
885 fft_memstack* bmsr = NULL;
886 fft_memstack amsr;
887 amsr.maxdepth = -1;
888 amsr.tempdepth = -1;
889 amsr.lenw = n / LIMB_BITS; // 系数长度(机器字)
890 mp_size_t nlen = amsr.lenw + 1; // 系数总长度
891 mp_size_t a_size = (((nlen + 1) << (k)) + nlen) * LIMB_BYTES;
892 mp_size_t b_size = (((nlen + 1) << (k)) + nlen) * LIMB_BYTES;
893 amsr.temp_coef = (mp_ptr)lmmp_fft_memstack_(&amsr, a_size);
894
895 mp_ptr* pfca = (mp_ptr*)(amsr.temp_coef + nlen);
896 mp_ptr* pfcb = NULL;
897
898 if (GH->fermat_flag) {
899 bmsr = &GH->msr_fermat;
900 bmsr->lenw = n / LIMB_BITS;
901 pfcb = (mp_ptr*)(GH->temp_coef_fermat + nlen);
902 } else {
903 bmsr = &GH->msr_fermat;
904 bmsr->maxdepth = -1;
905 bmsr->tempdepth = -1;
906 bmsr->lenw = n / LIMB_BITS;
907 bmsr->temp_coef = (mp_ptr)lmmp_fft_memstack_(bmsr, b_size);
908 GH->temp_coef_fermat = bmsr->temp_coef;
909 pfcb = (mp_ptr*)(bmsr->temp_coef + nlen);
910 }
911
912 mp_size_t narest = na * LIMB_BITS, nbrest = nb * LIMB_BITS;
913 for (mp_size_t i = 0; i < K; ++i) {
914 mp_size_t coeflen;
915 pfca[i] = (mp_ptr)(pfca + K) + i * nlen;
916 if (narest > 0) {
917 coeflen = M + (i == K - 1);
918 coeflen = LMMP_MIN(narest, coeflen);
919 narest -= coeflen;
920 lmmp_fft_extract_coef_(pfca[i], numa, M * i, coeflen, amsr.lenw);
921 if (i > 0)
922 lmmp_fft_shl_coef_(&amsr, pfca + i, i * n >> k);
923 } else {
924 lmmp_zero(pfca[i], nlen);
925 }
926 }
927 lmmp_fft_(&amsr, pfca, k, n >> (k - 1));
928
929 if (!GH->fermat_flag) {
930 GH->fermat_flag = 1;
931 for (mp_size_t i = 0; i < K; ++i) {
932 mp_size_t coeflen;
933 pfcb[i] = (mp_ptr)(pfcb + K) + i * nlen;
934 if (nbrest > 0) {
935 coeflen = M + (i == K - 1);
936 coeflen = LMMP_MIN(nbrest, coeflen);
937 nbrest -= coeflen;
938 lmmp_fft_extract_coef_(pfcb[i], numb, M * i, coeflen, bmsr->lenw);
939 if (i > 0)
940 lmmp_fft_shl_coef_(bmsr, pfcb + i, i * n >> k);
941 } else {
942 lmmp_zero(pfcb[i], nlen);
943 }
944 }
945 lmmp_fft_(bmsr, pfcb, k, n >> (k - 1));
946 }
947
948 lmmp_mul_fermat_recurse_(&amsr, pfca, pfcb, K);
949
950 lmmp_ifft_(&amsr, pfca, k, n >> (k - 1));
951
952 lmmp_mul_fermat_recombine_(&amsr, dst, pfca, K, k, n, M, rn);
953
954 if (dst[rn] && !lmmp_zero_q_(dst, rn)) {
955 dst[rn] = 0;
956 lmmp_dec(dst);
957 }
958
959 lmmp_fft_memstack_(&amsr, 0);
960}
961
963 mp_ptr dst,
964 mp_size_t rn,
965 mp_srcptr numa,
966 mp_size_t na,
967 mp_srcptr numb,
968 mp_size_t nb,
969 fft_cache* GH
970) {
971 int nsqr = numa != numb || na != nb; // 1为非平方,0为平方
972 lmmp_assert(nsqr);
973 mp_size_t N = rn * LIMB_BITS; // 结果总比特数
974 mp_size_t k = lmmp_fft_best_k_(rn); // 最优FFT层数
975 mp_size_t K = ((mp_size_t)1) << k; // FFT块数(2^k)
976 // 断言:N必须是K的整数倍
977 lmmp_debug_assert(!(N & (K - 1)));
978 mp_size_t M = N >> k; // 每个块的比特数
979 mp_size_t n = 2 * M + k; // 扩展系数长度(梅森数比费马数少2)
980
981 // 规整n:必须是LIMB_BITS和K/2的整数倍
982 n = (n + LIMB_BITS - 1) & (-LIMB_BITS);
983 n = (((n - 1) >> (k - 1)) + 1) << (k - 1);
984
985 // 初始化内存栈
986 fft_memstack* bmsr = NULL;
987 fft_memstack amsr;
988 amsr.maxdepth = -1;
989 amsr.tempdepth = -1;
990 amsr.lenw = n / LIMB_BITS; // 系数长度(机器字)
991 mp_size_t nlen = amsr.lenw + 1; // 系数总长度
992 mp_size_t a_size = (((nlen + 1) << (k)) + nlen) * LIMB_BYTES;
993 mp_size_t b_size = (((nlen + 1) << (k)) + nlen) * LIMB_BYTES;
994 amsr.temp_coef = (mp_ptr)lmmp_fft_memstack_(&amsr, a_size);
995
996 mp_ptr* pfca = (mp_ptr*)(amsr.temp_coef + nlen);
997 mp_ptr* pfcb = NULL;
998
999 if (GH->mersenne_flag) {
1000 bmsr = &GH->msr_mersenne;
1001 bmsr->lenw = n / LIMB_BITS;
1002 pfcb = (mp_ptr*)(GH->temp_coef_mersenne + nlen);
1003 } else {
1004 bmsr = &GH->msr_mersenne;
1005 bmsr->maxdepth = -1;
1006 bmsr->tempdepth = -1;
1007 bmsr->lenw = n / LIMB_BITS;
1008 bmsr->temp_coef = (mp_ptr)lmmp_fft_memstack_(bmsr, b_size);
1009 GH->temp_coef_mersenne = bmsr->temp_coef;
1010 pfcb = (mp_ptr*)(bmsr->temp_coef + nlen);
1011 }
1012
1013 mp_size_t narest = na * LIMB_BITS, nbrest = nb * LIMB_BITS;
1014
1015 for (mp_size_t i = 0; i < K; ++i) {
1016 mp_size_t coeflen;
1017 pfca[i] = (mp_ptr)(pfca + K) + i * nlen;
1018 if (narest > 0) {
1019 coeflen = LMMP_MIN(narest, M);
1020 narest -= coeflen;
1021 lmmp_fft_extract_coef_(pfca[i], numa, M * i, coeflen, amsr.lenw);
1022 } else {
1023 lmmp_zero(pfca[i], nlen);
1024 }
1025 }
1026 lmmp_fft_(&amsr, pfca, k, n >> (k - 1));
1027
1028 if (!GH->mersenne_flag) {
1029 GH->mersenne_flag = 1;
1030 for (mp_size_t i = 0; i < K; ++i) {
1031 mp_size_t coeflen;
1032 pfcb[i] = (mp_ptr)(pfcb + K) + i * nlen;
1033 if (nbrest > 0) {
1034 coeflen = LMMP_MIN(nbrest, M);
1035 nbrest -= coeflen;
1036 lmmp_fft_extract_coef_(pfcb[i], numb, M * i, coeflen, bmsr->lenw);
1037 } else {
1038 lmmp_zero(pfcb[i], nlen);
1039 }
1040 }
1041 lmmp_fft_(bmsr, pfcb, k, n >> (k - 1));
1042 }
1043
1044 lmmp_mul_fermat_recurse_(&amsr, pfca, pfcb, K);
1045
1046 lmmp_ifft_(&amsr, pfca, k, n >> (k - 1));
1047
1048 mp_size_t rhead = 0, maxc = 0;
1049 for (mp_size_t i = 0; i < K; ++i) {
1050 lmmp_fft_shr_coef_(&amsr, pfca + i, k);
1051 mp_ptr nums = pfca[i];
1052
1053 if (nums[nlen - 1]) {
1054 lmmp_dec(nums);
1055 lmmp_debug_assert(nums[nlen - 1] == 1);
1056 nums[nlen - 1] = 0;
1057 }
1058
1059 mp_size_t roffset = i * M;
1060 mp_size_t shl = roffset & (LIMB_BITS - 1);
1061 roffset /= LIMB_BITS;
1062
1063 if (shl)
1064 lmmp_shl_(nums, nums, nlen, shl);
1065
1066 if (i == 0) {
1067 lmmp_copy(dst, nums, nlen);
1068 rhead = nlen;
1069 } else if (roffset + nlen <= rn) {
1070 lmmp_add_(dst + roffset, nums, nlen, dst + roffset, rhead - roffset);
1071 rhead = roffset + nlen;
1072 } else {
1073 maxc += lmmp_add_(dst + roffset, nums, rn - roffset, dst + roffset, rhead - roffset);
1074 maxc += lmmp_add_(dst, dst, rn, nums + rn - roffset, nlen + roffset - rn);
1075 rhead = rn;
1076 }
1077 }
1078
1079 if (!lmmp_add_1_(dst, dst, rn, 1 + maxc))
1080 lmmp_dec(dst);
1081
1082 lmmp_fft_memstack_(&amsr, 0);
1083}
1084
1086 lmmp_param_assert(na > 0 && nb > 0);
1087 lmmp_param_assert(na >= nb);
1088 mp_size_t hn = lmmp_fft_next_size_((na + nb + 1) >> 1);
1089 lmmp_assert(na + nb > hn);
1090 mp_ptr tp = ALLOC_TYPE(hn + 1, mp_limb_t);
1091
1092 mp_srcptr amodm = numa;
1093 mp_size_t nam = na;
1094 if (na > hn) {
1095 /*
1096 Z = B^hn - 1
1097 amodm = a mod Z
1098 */
1099 if (lmmp_add_(dst, numa, hn, numa + hn, na - hn))
1100 lmmp_inc(dst);
1101 amodm = dst;
1102 nam = hn;
1103 }
1104 lmmp_mul_mersenne_(dst, hn, amodm, nam, numb, nb);
1105
1106 mp_srcptr amodp = numa;
1107 mp_size_t nap = na;
1108 if (na > hn) {
1109 /*
1110 Z = B^hn + 1
1111 amodp = a mod Z
1112 */
1113 tp[hn] = 0;
1114 if (lmmp_sub_(tp, numa, hn, numa + hn, na - hn))
1115 lmmp_inc(tp);
1116 amodp = tp;
1117 nap = hn + 1;
1118 }
1119 lmmp_mul_fermat_(tp, hn, amodp, nap, numb, nb);
1120
1121 mp_limb_t cy = lmmp_shr1add_nc_(dst, dst, tp, hn, tp[hn]);
1122 cy <<= LIMB_BITS - 1;
1123 dst[hn - 1] += cy;
1124 if (dst[hn - 1] < cy)
1125 lmmp_inc(dst);
1126
1127 if (na + nb == 2 * hn) {
1128 cy = tp[hn] + lmmp_sub_n_(dst + hn, dst, tp, hn);
1129 // cy==1 means [tp,hn+1]!=0, then [dst,hn]!=0
1130 // cy==2 is impossible since [tp,hn+1] is normalized.
1131 // so the following dec won't overflow.
1132 lmmp_dec_1(dst, cy);
1133 } else {
1134 cy = lmmp_sub_n_(dst + hn, dst, tp, na + nb - hn);
1135 cy = tp[hn] + lmmp_sub_nc_(tp + na + nb - hn, dst + na + nb - hn, tp + na + nb - hn, 2 * hn - (na + nb), cy);
1136 cy = lmmp_sub_1_(dst, dst, na + nb, cy);
1137 }
1138 lmmp_free(tp);
1139}
1140
1142 mp_ptr dst,
1143 mp_size_t hn,
1144 mp_srcptr numa,
1145 mp_size_t na,
1146 mp_srcptr numb,
1147 mp_size_t nb,
1148 fft_cache* GH
1149) {
1150 lmmp_param_assert(na > 0 && nb > 0);
1151 lmmp_param_assert(na >= nb);
1152 lmmp_assert(na + nb > hn);
1153 mp_ptr tp = ALLOC_TYPE(hn + 1, mp_limb_t);
1154
1155 mp_srcptr amodm = numa;
1156 mp_size_t nam = na;
1157 if (na > hn) {
1158 /*
1159 Z = B^hn - 1
1160 amodm = a mod Z
1161 */
1162 if (lmmp_add_(dst, numa, hn, numa + hn, na - hn))
1163 lmmp_inc(dst);
1164 amodm = dst;
1165 nam = hn;
1166 }
1167 lmmp_mul_mersenne_single_(dst, hn, amodm, nam, numb, nb, GH);
1168
1169 mp_srcptr amodp = numa;
1170 mp_size_t nap = na;
1171 if (na > hn) {
1172 /*
1173 Z = B^hn + 1
1174 amodp = a mod Z
1175 */
1176 tp[hn] = 0;
1177 if (lmmp_sub_(tp, numa, hn, numa + hn, na - hn))
1178 lmmp_inc(tp);
1179 amodp = tp;
1180 nap = hn + 1;
1181 }
1182 lmmp_mul_fermat_single_(tp, hn, amodp, nap, numb, nb, GH);
1183
1184 mp_limb_t cy = lmmp_shr1add_nc_(dst, dst, tp, hn, tp[hn]);
1185 cy <<= LIMB_BITS - 1;
1186 dst[hn - 1] += cy;
1187 if (dst[hn - 1] < cy)
1188 lmmp_inc(dst);
1189
1190 if (na + nb == 2 * hn) {
1191 cy = tp[hn] + lmmp_sub_n_(dst + hn, dst, tp, hn);
1192 // cy==1 means [tp,hn+1]!=0, then [dst,hn]!=0
1193 // cy==2 is impossible since [tp,hn+1] is normalized.
1194 // so the following dec won't overflow.
1195 lmmp_dec_1(dst, cy);
1196 } else {
1197 cy = lmmp_sub_n_(dst + hn, dst, tp, na + nb - hn);
1198 cy = tp[hn] + lmmp_sub_nc_(tp + na + nb - hn, dst + na + nb - hn, tp + na + nb - hn, 2 * hn - (na + nb), cy);
1199 cy = lmmp_sub_1_(dst, dst, na + nb, cy);
1200 }
1201 lmmp_free(tp);
1202}
1203
1205 mp_ptr restrict dst,
1206 mp_srcptr restrict numa,
1207 mp_size_t na,
1208 mp_srcptr restrict numb,
1209 mp_size_t nb
1210) {
1211 lmmp_param_assert(na >= 3 * nb);
1212 mp_ptr restrict ws = ALLOC_TYPE(nb, mp_limb_t);
1213 mp_size_t sna = 3 * nb;
1214 mp_size_t hn = lmmp_fft_next_size_((sna + nb + 1) >> 1);
1215 sna = (hn << 1) - 1 - nb;
1216 fft_cache GH = {.mersenne_flag = 0, .fermat_flag = 0};
1217 lmmp_mul_fft_cache_(dst, hn, numa, sna, numb, nb, &GH);
1218 dst += sna;
1219 numa += sna;
1220 na -= sna;
1221 lmmp_copy(ws, dst, nb);
1222 while (na >= sna) {
1223 lmmp_mul_fft_cache_(dst, hn, numa, sna, numb, nb, &GH);
1224 if (lmmp_add_n_(dst, dst, ws, nb))
1225 lmmp_inc(dst + nb);
1226 dst += sna;
1227 numa += sna;
1228 na -= sna;
1229 lmmp_copy(ws, dst, nb);
1230 }
1232 // remaining na < sna
1233 if (na >= nb)
1234 lmmp_mul_(dst, numa, na, numb, nb);
1235 else if (na > 0)
1236 lmmp_mul_(dst, numb, nb, numa, na);
1237 else // na == 0
1238 lmmp_zero(dst, nb);
1239 if (lmmp_add_n_(dst, dst, ws, nb))
1240 lmmp_inc(dst + nb);
1241 lmmp_free(ws);
1242}
#define k
mp_limb_t * mp_ptr
Definition lmmp.h:215
#define lmmp_copy(dst, src, n)
Definition lmmp.h:364
#define lmmp_zero(dst, n)
Definition lmmp.h:366
uint64_t mp_size_t
Definition lmmp.h:212
int64_t mp_slimb_t
Definition lmmp.h:213
#define lmmp_debug_assert(x)
Definition lmmp.h:387
void * lmmp_alloc(size_t size)
内存分配函数(调用lmmp_heap_alloc_fn)
Definition memory.c:166
const mp_limb_t * mp_srcptr
Definition lmmp.h:216
#define LIMB_MAX
Definition lmmp.h:224
void lmmp_free(void *ptr)
内存释放函数(调用lmmp_heap_free_fn)
Definition memory.c:204
int64_t mp_ssize_t
Definition lmmp.h:214
uint64_t mp_limb_t
Definition lmmp.h:211
#define lmmp_assert(x)
Definition lmmp.h:370
#define LMMP_MIN(l, o)
Definition lmmp.h:348
#define LIMB_BITS
Definition lmmp.h:221
#define LOG2_LIMB_BITS
Definition lmmp.h:223
#define lmmp_param_assert(x)
Definition lmmp.h:398
mp_limb_t lmmp_shlnot_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_size_t shl)
左移后按位取反操作 [dst,na] = ~([numa,na] << shl),dst的低shl位填充1
static mp_limb_t lmmp_add_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
大数加法静态内联函数 [dst,na]=[numa,na]+[numb,nb]
Definition lmmpn.h:1058
mp_limb_t lmmp_shr1add_nc_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n, mp_limb_t c)
带进位加法后右移1位 [dst,n] = ([numa,n] + [numb,n] + c) >> 1
Definition shr.c:79
mp_limb_t lmmp_shr_c_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_size_t shr, mp_limb_t c)
带进位的大数右移操作 [dst,na] = [numa,na]>>shr,dst的高shr位填充c的高shr位
Definition shr.c:30
#define lmmp_dec(p)
大数减1宏(预期无借位)
Definition lmmpn.h:973
static mp_limb_t lmmp_add_1_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_limb_t x)
大数加单精度数静态内联函数 [dst,na]=[numa,na]+x
Definition lmmpn.h:1111
#define lmmp_inc(p)
大数加1宏(预期无进位)
Definition lmmpn.h:946
mp_limb_t lmmp_shr_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_size_t shr)
大数右移操作 [dst,na] = [numa,na] >> shr,dst的高shr位填充0
Definition shr.c:9
void lmmp_mul_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
不等长大数乘法操作 [dst,na+nb] = [numa,na] * [numb,nb]
void lmmp_sqr_(mp_ptr dst, mp_srcptr numa, mp_size_t na)
大数平方操作 [dst,2*na] = [numa,na]^2
Definition sqr.c:10
void lmmp_mul_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
等长大数乘法操作 [dst,2*n] = [numa,n] * [numb,n]
Definition mul.c:99
mp_limb_t lmmp_shl_c_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_size_t shl, mp_limb_t c)
带进位的大数左移操作 [dst,na] = [numa,na]<<shl,dst的低shl位填充c的低shl位
Definition shl.c:32
mp_limb_t lmmp_add_nc_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n, mp_limb_t c)
带进位的n位加法 [dst,n] = [numa,n] + [numb,n] + c
Definition add_n.c:9
mp_limb_t lmmp_shl_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_size_t shl)
大数左移操作 [dst,na] = [numa,na]<<shl,dst的低shl位填充0
Definition shl.c:9
static mp_limb_t lmmp_sub_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
大数减法静态内联函数 [dst,na]=[numa,na]-[numb,nb]
Definition lmmpn.h:1072
#define lmmp_dec_1(p, dec)
大数减指定值宏(预期无借位)
Definition lmmpn.h:985
static mp_limb_t lmmp_sub_1_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_limb_t x)
大数减单精度数静态内联函数 [dst,na]=[numa,na]-x
Definition lmmpn.h:1122
void lmmp_not_(mp_ptr dst, mp_srcptr numa, mp_size_t na)
大数按位取反操作 [dst,na] = ~[numa,na] (对每个limb执行按位非操作)
mp_limb_t lmmp_sub_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
无借位的n位减法 [dst,n] = [numa,n] - [numb,n]
Definition sub_n.c:70
#define lmmp_inc_1(p, inc)
大数加指定值宏(预期无进位)
Definition lmmpn.h:958
mp_limb_t lmmp_sub_nc_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n, mp_limb_t c)
带借位的n位减法 [dst,n] = [numa,n] - [numb,n] - c
Definition sub_n.c:9
mp_limb_t lmmp_add_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
无进位的n位加法 [dst,n] = [numa,n] + [numb,n]
Definition add_n.c:71
static int lmmp_zero_q_(mp_srcptr p, mp_size_t n)
大数判零函数(内联)
Definition lmmpn.h:1027
#define MUL_FFT_MODF_THRESHOLD
Definition mparam.h:65
#define PART_SIZE
Definition mparam.h:89
#define LIMB_BYTES
Definition mparam.h:85
static void lmmp_fft_shr_coef_(fft_memstack *ms, mp_ptr *coef, mp_size_t shr)
对模 2^n+1 的系数执行右移操作 右移shr位 = 左移(2n - shr)位(mod 2^n+1的循环特性)
Definition mul_fft.c:217
void lmmp_mul_mersenne_(mp_ptr dst, mp_size_t rn, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
梅森数模乘法 [dst,rn] = [numa,na]*[numb,nb] mod B^rn-1
Definition mul_fft.c:752
mp_ptr temp_coef
Definition mul_fft.c:60
mp_ssize_t maxdepth
Definition mul_fft.c:62
static void lmmp_mul_fermat_recurse_(fft_memstack *ms, mp_ptr *pc1, mp_ptr *pc2, mp_size_t K0)
费马变换乘法递归函数(核心乘法逻辑)
Definition mul_fft.c:600
static void lmmp_mul_mersenne_single_(mp_ptr dst, mp_size_t rn, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb, fft_cache *GH)
Definition mul_fft.c:962
static void lmmp_ifft_bfy_(fft_memstack *ms, mp_ptr *coef, mp_size_t wing, mp_size_t w)
FFT蝶形运算(Butterfly Operation) (a,b) = (a+(b>>w), a-(b>>w)) mod 2^n+1 a=[coef[0],ms->lenw+1],...
Definition mul_fft.c:313
static void lmmp_fft_(fft_memstack *ms, mp_ptr *coef, mp_size_t k, mp_size_t w)
Definition mul_fft.c:442
#define _FFT_TABLE_ENTRY4(n)
Definition mul_fft.c:13
void lmmp_mul_fft_unbalance_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_size_t na, mp_srcptr restrict numb, mp_size_t nb)
Definition mul_fft.c:1204
mp_size_t memsize[16]
Definition mul_fft.c:65
static void * lmmp_fft_memstack_(fft_memstack *ms, mp_size_t size)
FFT内存栈的分配/释放接口
Definition mul_fft.c:98
static void lmmp_fft_shl_coef_(fft_memstack *ms, mp_ptr *coef, mp_size_t shl)
对模 2^n+1 的系数执行左移操作
Definition mul_fft.c:150
void * mem[16]
Definition mul_fft.c:64
mp_ptr temp_coef_mersenne
Definition mul_fft.c:852
static void lmmp_mul_fft_cache_(mp_ptr dst, mp_size_t hn, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb, fft_cache *GH)
Definition mul_fft.c:1141
static void lmmp_mul_fermat_recombine_(fft_memstack *ms, mp_ptr dst, mp_ptr *pfca, mp_size_t K, mp_size_t k, mp_size_t n, mp_size_t M, mp_size_t rn)
费马变换 模 B^n+1 乘法的结果合并
Definition mul_fft.c:517
static void lmmp_ifft_b1_(fft_memstack *ms, mp_ptr *coef, mp_size_t dis, mp_size_t k, mp_size_t w, mp_size_t w0)
Definition mul_fft.c:453
void lmmp_mul_fermat_(mp_ptr dst, mp_size_t rn, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
费马数模乘法 [dst,rn+1]=[numa,na]*[numb,nb] mod B^rn+1
Definition mul_fft.c:677
static void lmmp_mul_fft_cache_free_(fft_cache *GH)
Definition mul_fft.c:857
static void lmmp_ifft_4_(fft_memstack *ms, mp_ptr *coef, mp_size_t k, mp_size_t w)
Definition mul_fft.c:474
static void lmmp_fft_bfy_(fft_memstack *ms, mp_ptr *coef, mp_size_t wing, mp_size_t w)
FFT蝶形运算(Butterfly Operation) (a,b) = (a + b, (a-b) << w ) mod 2^n+1 a=[coef[0],ms->lenw+1],...
Definition mul_fft.c:233
mp_size_t lmmp_fft_next_size_(mp_size_t n)
计算FFT运算所需的最小规整化长度(向上取整到2^k的倍数)
Definition mul_fft.c:84
static mp_size_t lmmp_fft_best_k_(mp_size_t n)
查找对于 m>=n 的模 B^m+1 FFT运算的最优k值
Definition mul_fft.c:73
static void lmmp_ifft_(fft_memstack *ms, mp_ptr *coef, mp_size_t k, mp_size_t w)
Definition mul_fft.c:495
static void lmmp_fft_extract_coef_(mp_ptr dst, mp_srcptr numa, mp_size_t bitoffset, mp_size_t bits, mp_size_t lenw)
[dst,lenw+1] = [(bit*)numa+bitoffset, bits]
Definition mul_fft.c:124
mp_ssize_t tempdepth
Definition mul_fft.c:63
static const mp_size_t lmmp_fft_table_[][2]
Definition mul_fft.c:19
int fermat_flag
Definition mul_fft.c:853
int mersenne_flag
Definition mul_fft.c:854
fft_memstack msr_mersenne
Definition mul_fft.c:850
fft_memstack msr_fermat
Definition mul_fft.c:849
void lmmp_mul_fft_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
FFT乘法运算 [dst,na+nb] = [numa,na] * [numb,nb]
Definition mul_fft.c:1085
mp_size_t lenw
Definition mul_fft.c:61
static void lmmp_fft_b1_(fft_memstack *ms, mp_ptr *coef, mp_size_t dis, mp_size_t k, mp_size_t w, mp_size_t w0)
FFT递归函数
Definition mul_fft.c:400
static void lmmp_fft_4_(fft_memstack *ms, mp_ptr *coef, mp_size_t k, mp_size_t w)
Definition mul_fft.c:421
mp_ptr temp_coef_fermat
Definition mul_fft.c:851
static void lmmp_mul_fermat_single_(mp_ptr dst, mp_size_t rn, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb, fft_cache *GH)
Definition mul_fft.c:864
#define tp
#define w0
#define ALLOC_TYPE(n, type)
Definition tmp_alloc.h:112