LAMMP 4.1.0
Lamina High-Precision Arithmetic Library
载入中...
搜索中...
未找到
mullo.c
浏览该文件的文档.
1/*
2 * LAMMP - Copyright (C) 2025-2026 HJimmyK(Jericho Knox)
3 * This file is part of lammp, under the GNU LGPL v2 license.
4 * See LICENSE in the project root for the full license text.
5 */
6
7#include "../../include/lammp/impl/tmp_alloc.h"
8#include "../../include/lammp/lmmpn.h"
9#include "../../include/lammp/impl/mparam.h"
10
12 lmmp_param_assert(n > 0);
13 mp_size_t hn = lmmp_fft_next_size_((n + n + 1) >> 1);
14 lmmp_assert(n + n > hn);
15 mp_ptr tp = ALLOC_TYPE(hn + 1, mp_limb_t);
16
17 mp_srcptr amodm = numa;
18 mp_size_t nam = n;
19 if (n > hn) {
20 /*
21 Z = B^hn - 1
22 amodm = a mod Z
23 */
24 if (lmmp_add_(scratch, numa, hn, numa + hn, n - hn))
26 amodm = scratch;
27 nam = hn;
28 }
29 lmmp_mul_mersenne_(scratch, hn, amodm, nam, numb, n);
30
31 mp_srcptr amodp = numa;
32 mp_size_t nap = n;
33 if (n > hn) {
34 /*
35 Z = B^hn + 1
36 amodp = a mod Z
37 */
38 tp[hn] = 0;
39 if (lmmp_sub_(tp, numa, hn, numa + hn, n - hn))
40 lmmp_inc(tp);
41 amodp = tp;
42 nap = hn + 1;
43 }
44 lmmp_mul_fermat_(tp, hn, amodp, nap, numb, n);
45
47 cy <<= LIMB_BITS - 1;
48 scratch[hn - 1] += cy;
49 if (scratch[hn - 1] < cy)
51
52 if (n == hn) {
53 cy = tp[hn] + lmmp_sub_n_(scratch + hn, scratch, tp, hn);
54 // cy==1 means [tp,hn+1]!=0, then [dst,hn]!=0
55 // cy==2 is impossible since [tp,hn+1] is normalized.
56 // so the following dec won't overflow.
58 } else {
59 mp_size_t n2 = 2 * n;
60 cy = lmmp_sub_n_(scratch + hn, scratch, tp, n2 - hn);
61 cy = tp[hn] + lmmp_sub_nc_(tp + n2 - hn, scratch + n2 - hn, tp + n2 - hn, 2 * hn - n2, cy);
62 cy = lmmp_sub_1_(scratch, scratch, n2, cy);
63 }
65 lmmp_copy(dst, scratch, n);
66}
67
68/*
69 <---t---><---m--->
70 |--a1---|---a0---|
71 |--b1---|---b0---|
72
73 ,
74 |\
75 | \
76 | \
77 +-----,
78 | |
79 | |\
80 | | \
81 | | \
82 +-----+---`
83 ^ m ^ t ^
84
85 此算法是一种不平衡分块的算法,朴素的想法是计算平衡分块,计算一次完整的乘法,然后两次递归的调用此函数计算低位,
86 事实上,我们也可以不平衡的分块,以减少递归深度,具体分析如下:
87 取a和b的低位一定宽度为m,高位宽度为t,则有:
88 计算一次完整的平衡乘法m,然后递归调用计算mullo,长度为t
89 复杂度模型:
90 ML(n) = 2*ML((1-a)*n) + M(a*n), a = m/n
91 其中ML为mullo的复杂度,M为mul_n的复杂度
92 我们可以假定 M(n)=O(n^e) 即多项式复杂度
93 则有:
94 ML(n) = C(a) * n^e
95 C(a) = a^e / (1-2*(1-a)^e)
96 我们希望C(a)尽可能小,即希望ML(n)尽可能小,则有:
97 a_opt = 1 - 2^(-1/(e-1))
98 e=log(3)/log(2) [Toom-2] -> a ~= 0.694
99 e=log(5)/log(3) [Toom-3] -> a ~= 0.775
100 e=log(7)/log(4) [Toom-4] -> a ~= 0.820
101 e=log(11)/log(6) [Toom-6] -> a ~= 0.871
102 e=log(15)/log(8) [Toom-8] -> a ~= 0.899
103*/
104
105#define MUL_TOOM66_THRESHOLD MUL_FFT_THRESHOLD
106#define MUL_TOOM88_THRESHOLD 5621
107
109 mp_ptr restrict dst,
110 mp_srcptr restrict numa,
111 mp_srcptr restrict numb,
112 mp_ptr restrict tp,
113 mp_size_t n
114) {
115 if (n < MULLO_BASECASE_THRESHOLD) {
116 lmmp_mul_1_(dst, numa, n, numb[0]);
117 for (mp_size_t i = 1; i < n; ++i) {
118 lmmp_mul_1_(tp, numa, n - i, numb[i]);
119 lmmp_add_n_(dst + i, dst + i, tp, n - i);
120 }
121 return;
122 } else {
123 mp_size_t m, t;
124 if (n < MUL_TOOM33_THRESHOLD) {
125 m = 25 * n / 36;
126 } else if (n < MUL_TOOM44_THRESHOLD) {
127 m = 31 * n / 40;
128 } else if (n < MUL_TOOM66_THRESHOLD) {
129 m = 32 * n / 39;
130 } else if (n < MUL_TOOM88_THRESHOLD) {
131 m = 27 * n / 31;
132 } else {
133 m = 9 * n / 10;
134 }
135 t = n - m;
136
137#define a0 (numa)
138#define a1 (numa + m)
139#define b0 (numb)
140#define b1 (numb + m)
141#define c0 (dst)
142#define c1 (dst + m)
143#define lo1 (tp) // [tp, 2*t]
144#define tp1 (tp + 2 * t) // [tp+2*t, 2*t]
145#define lo2 (tp + 2 * t) // [tp+2*t, 2*t]
146#define tp2 (tp + 4 * t) // [tp+2*t, 2*t]
147 lmmp_mul_n_(tp, a0, b0, m);
148 lmmp_copy(c0, tp, n);
149 lmmp_mullo_dc_(lo1, a1, b0, tp1, t);
150 lmmp_mullo_dc_(lo2, a0, b1, tp2, t);
151 lmmp_add_n_(c1, c1, lo1, t);
152 lmmp_add_n_(c1, c1, lo2, t);
153 return;
154 }
155}
156
157void lmmp_sqrlo_dc_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_ptr restrict tp, mp_size_t n) {
158 if (n < MULLO_BASECASE_THRESHOLD) {
159 lmmp_mul_1_(dst, numa, n, numa[0]);
160 for (mp_size_t i = 1; i < n; ++i) {
161 lmmp_mul_1_(tp, numa, n - i, numa[i]);
162 lmmp_add_n_(dst + i, dst + i, tp, n - i);
163 }
164 return;
165 } else {
166 mp_size_t m, t;
167 if (n < MUL_TOOM33_THRESHOLD) {
168 m = 25 * n / 36;
169 } else if (n < MUL_TOOM44_THRESHOLD) {
170 m = 31 * n / 40;
171 } else if (n < MUL_TOOM66_THRESHOLD) {
172 m = 32 * n / 39;
173 } else if (n < MUL_TOOM88_THRESHOLD) {
174 m = 27 * n / 31;
175 } else {
176 m = 9 * n / 10;
177 }
178 t = n - m;
179#define a0 (numa)
180#define a1 (numa + m)
181#define c0 (dst)
182#define c1 (dst + m)
183#define lo1 (tp) // [tp, 2*t]
184#define tp1 (tp + 2 * t) // [tp+2*t, 2*t]
185 lmmp_sqr_(tp, a0, m);
186 lmmp_copy(c0, tp, n);
187 lmmp_mullo_dc_(lo1, a0, a1, tp1, t);
188 lmmp_addshl1_n_(c1, c1, lo1, t);
189 }
190}
191
192void lmmp_mullo_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_srcptr restrict numb, mp_size_t n) {
193 lmmp_param_assert(n > 0);
194 if (n < MULLO_DC_THRESHOLD) {
195 if (numa == numb) {
196 TEMP_DECL;
197 mp_ptr restrict tp = TALLOC_TYPE(2 * n, mp_limb_t);
198 lmmp_sqrlo_dc_(dst, numa, tp, n);
199 TEMP_FREE;
200 return;
201 }
202 TEMP_DECL;
203 mp_ptr restrict tp = TALLOC_TYPE(2 * n, mp_limb_t);
204 lmmp_mullo_dc_(dst, numa, numb, tp, n);
205 TEMP_FREE;
206 return;
207 } else {
208 TEMP_DECL;
209 mp_ptr restrict tp = TALLOC_TYPE(2 * n, mp_limb_t);
210 lmmp_mullo_fft_(dst, numa, numb, n, tp);
211 TEMP_FREE;
212 return;
213 }
214}
#define scratch
mp_limb_t * mp_ptr
Definition lmmp.h:215
#define lmmp_copy(dst, src, n)
Definition lmmp.h:364
uint64_t mp_size_t
Definition lmmp.h:212
const mp_limb_t * mp_srcptr
Definition lmmp.h:216
void lmmp_free(void *ptr)
内存释放函数(调用lmmp_heap_free_fn)
Definition memory.c:204
uint64_t mp_limb_t
Definition lmmp.h:211
#define lmmp_assert(x)
Definition lmmp.h:370
#define LIMB_BITS
Definition lmmp.h:221
#define lmmp_param_assert(x)
Definition lmmp.h:398
void lmmp_mul_mersenne_(mp_ptr dst, mp_size_t rn, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
梅森数模乘法 [dst,rn] = [numa,na]*[numb,nb] mod B^rn-1
Definition mul_fft.c:752
static mp_limb_t lmmp_add_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
大数加法静态内联函数 [dst,na]=[numa,na]+[numb,nb]
Definition lmmpn.h:1058
mp_limb_t lmmp_shr1add_nc_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n, mp_limb_t c)
带进位加法后右移1位 [dst,n] = ([numa,n] + [numb,n] + c) >> 1
Definition shr.c:79
#define lmmp_inc(p)
大数加1宏(预期无进位)
Definition lmmpn.h:946
void lmmp_sqr_(mp_ptr dst, mp_srcptr numa, mp_size_t na)
大数平方操作 [dst,2*na] = [numa,na]^2
Definition sqr.c:10
void lmmp_mul_fermat_(mp_ptr dst, mp_size_t rn, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
费马数模乘法 [dst,rn+1]=[numa,na]*[numb,nb] mod B^rn+1
Definition mul_fft.c:677
void lmmp_mul_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
等长大数乘法操作 [dst,2*n] = [numa,n] * [numb,n]
Definition mul.c:99
mp_limb_t lmmp_addshl1_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
加法结合左移1位操作 [dst,n] = [numa,n] + ([numb,n] << 1)
Definition shl.c:56
mp_size_t lmmp_fft_next_size_(mp_size_t n)
计算满足 >=n 的最小费马/梅森乘法可行尺寸
Definition mul_fft.c:84
static mp_limb_t lmmp_sub_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
大数减法静态内联函数 [dst,na]=[numa,na]-[numb,nb]
Definition lmmpn.h:1072
mp_limb_t lmmp_mul_1_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_limb_t x)
大数乘以单limb操作 [dst,na] = [numa,na] * x
#define lmmp_dec_1(p, dec)
大数减指定值宏(预期无借位)
Definition lmmpn.h:985
static mp_limb_t lmmp_sub_1_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_limb_t x)
大数减单精度数静态内联函数 [dst,na]=[numa,na]-x
Definition lmmpn.h:1122
mp_limb_t lmmp_sub_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
无借位的n位减法 [dst,n] = [numa,n] - [numb,n]
Definition sub_n.c:70
mp_limb_t lmmp_sub_nc_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n, mp_limb_t c)
带借位的n位减法 [dst,n] = [numa,n] - [numb,n] - c
Definition sub_n.c:9
mp_limb_t lmmp_add_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
无进位的n位加法 [dst,n] = [numa,n] + [numb,n]
Definition add_n.c:71
#define MULLO_BASECASE_THRESHOLD
Definition mparam.h:57
#define MUL_TOOM33_THRESHOLD
Definition mparam.h:50
#define MULLO_DC_THRESHOLD
Definition mparam.h:59
#define MUL_TOOM44_THRESHOLD
Definition mparam.h:52
#define tp
void lmmp_mullo_dc_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_srcptr restrict numb, mp_ptr restrict tp, mp_size_t n)
Definition mullo.c:108
#define lo2
#define b0
#define MUL_TOOM66_THRESHOLD
Definition mullo.c:105
#define b1
void lmmp_mullo_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_srcptr restrict numb, mp_size_t n)
Definition mullo.c:192
#define tp2
#define c1
#define tp1
void lmmp_sqrlo_dc_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_ptr restrict tp, mp_size_t n)
Definition mullo.c:157
#define a0
#define a1
void lmmp_mullo_fft_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n, mp_ptr scratch)
低位FFT乘法 [dst,n] = [numa,n] * [numb,n] mod B^n
Definition mullo.c:11
#define c0
#define MUL_TOOM88_THRESHOLD
Definition mullo.c:106
#define lo1
#define TEMP_DECL
Definition tmp_alloc.h:72
#define ALLOC_TYPE(n, type)
Definition tmp_alloc.h:112
#define TEMP_FREE
Definition tmp_alloc.h:93
#define TALLOC_TYPE(n, type)
Definition tmp_alloc.h:91