LAMMP 4.1.0
Lamina High-Precision Arithmetic Library
载入中...
搜索中...
未找到
mul_toom53.c
浏览该文件的文档.
1/*
2 * LAMMP - Copyright (C) 2025-2026 HJimmyK(Jericho Knox)
3 * This file is part of lammp, under the GNU LGPL v2 license.
4 * See LICENSE in the project root for the full license text.
5 */
6
7#include "../../include/lammp/impl/toom_interp.h"
8
9/*
10Evaluate in: 0, +1, -1, +2, -2, 1/2, +inf
11
12 <-s-><--n--><--n--><--n--><--n-->
13 |a4-|--a3--|--a2--|--a1--|--a0--|
14              |--b2|--b1--|--b0--|
15              <-t--><--n--><--n-->
16
17 v0 = a0 * b0 # A(0)*B(0)
18 v1 = ( a0+ a1+ a2+ a3+ a4)*( b0+ b1+ b2) # A(1)*B(1) ah <= 4 bh <= 2
19 vm1 = ( a0- a1+ a2- a3+ a4)*( b0- b1+ b2) # A(-1)*B(-1) |ah| <= 2 bh <= 1
20 v2 = ( a0+2a1+4a2+8a3+16a4)*( b0+2b1+4b2) # A(2)*B(2) ah <= 30 bh <= 6
21 vm2 = ( a0-2a1+4a2-8a3+16a4)*( b0-2b1+4b2) # A(-2)*B(-2) -9<=ah<=20 -1<=bh<=4
22 vh = (16a0+8a1+4a2+2a3+ a4)*(4b0+2b1+ b2) # A(1/2)*B(1/2) ah <= 30 bh <= 6
23 vinf= a4 * b2 # A(inf)*B(inf)
24*/
25
/*
 * lmmp_mul_toom53_ - unbalanced Toom multiplication of [numa, na] by [numb, nb].
 *
 * numa is split into five pieces a0..a4 and numb into three pieces b0..b2.
 * Every piece is n limbs long except the top ones: a4 has s = na - 4n limbs
 * and b2 has t = nb - 2n limbs.  Both operands are evaluated at the seven
 * points listed in the comment at the top of this file, the seven pointwise
 * products are formed, and lmmp_toom_interp7_ recombines them into dst.
 *
 * dst   receives the product; the pieces written here end at
 *       dst + 6n + s + t, i.e. na + nb limbs.  It is also used as temporary
 *       space (gp) while the evaluation points are being computed.
 * numa  first operand, na limbs.
 * numb  second operand, nb limbs.
 *
 * Preconditions checked with lmmp_param_assert:
 *   9 * na <= 20 * nb   and   5 * nb <= 3 * na.
 *
 * NOTE(review): s and t are computed by plain subtraction and mp_size_t is
 * listed as uint64_t in the index below, so they wrap around if 4n > na or
 * 2n > nb (e.g. na = 6, nb = 3 passes both asserts but gives n = 2).
 * Nothing in this function checks 0 < s <= n and 0 < t <= n; presumably the
 * caller only dispatches here for sufficiently large operands - confirm.
 *
 * NOTE(review): this listing skips its own line numbers 36 and 181.  The
 * cross-reference index names TEMP_S_DECL and TEMP_S_FREE, which presumably
 * belong on those lines (declaring and releasing the SALLOC_TYPE storage) -
 * verify against the real source file.
 */
26void lmmp_mul_toom53_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_size_t na, mp_srcptr restrict numb, mp_size_t nb) {
27 lmmp_param_assert(9 * na <= 20 * nb);
28 lmmp_param_assert(5 * nb <= 3 * na);
29 mp_size_t n, s, t;
30 mp_limb_t cy;
31 mp_ptr gp;
32 mp_ptr as1, asm1, as2, asm2, ash;
33 mp_ptr bs1, bsm1, bs2, bsm2, bsh;
34 mp_ptr tmp;
35 enum toom7_flags flags;
37
/* Piece accessors: a_i and b_i start at multiples of n limbs inside the inputs. */
38#define a0 numa
39#define a1 (numa + n)
40#define a2 (numa + 2 * n)
41#define a3 (numa + 3 * n)
42#define a4 (numa + 4 * n)
43#define b0 numb
44#define b1 (numb + n)
45#define b2 (numb + 2 * n)
46
/* n is the common piece length.  When the second parameter assert above
 * holds, 3 * na >= 5 * nb is always true, so n = 1 + (na - 1) / 5. */
47 n = 1 + (3 * na >= 5 * nb ? (na - 1) / (mp_size_t)5 : (nb - 1) / (mp_size_t)3);
/* Storage for the products v2, vm2, vh, vm1 and the interpolation workspace
 * (laid out by the macros defined further down). */
48 mp_ptr restrict scratch = SALLOC_TYPE(10 * (n + 1), mp_limb_t);
/* Limb counts of the top pieces a4 and b2. */
49 s = na - 4 * n;
50 t = nb - 2 * n;
51
/* One (n + 1)-limb slot per evaluated operand: five for A, then five for B. */
52 tmp = SALLOC_TYPE(10 * (n + 1), mp_limb_t);
53 as1 = tmp;
54 tmp += n + 1;
55 asm1 = tmp;
56 tmp += n + 1;
57 as2 = tmp;
58 tmp += n + 1;
59 asm2 = tmp;
60 tmp += n + 1;
61 ash = tmp;
62 tmp += n + 1;
63 bs1 = tmp;
64 tmp += n + 1;
65 bsm1 = tmp;
66 tmp += n + 1;
67 bs2 = tmp;
68 tmp += n + 1;
69 bsm2 = tmp;
70 tmp += n + 1;
71 bsh = tmp;
72 tmp += n + 1;
73
/* The output area is not needed until the products are formed, so it is
 * handed to the evaluation steps as temporary space. */
74 gp = dst;
75
/* flags accumulates the toom7_w3_neg (point -1) and toom7_w1_neg (point -2)
 * bits.  Presumably each eval helper returns a mask that is nonzero when the
 * value at the negative point was stored negated - confirm in the helpers. */
76 /* Compute as1 and asm1. */
77 flags = (enum toom7_flags)(toom7_w3_neg & lmmp_toom_eval_pm1_(as1, asm1, 4, numa, n, s, gp));
78
79 /* Compute as2 and asm2. */
80 flags = (enum toom7_flags)(flags | (toom7_w1_neg & lmmp_toom_eval_pm2_(as2, asm2, 4, numa, n, s, gp)));
81
82 /* Compute ash = 16 a0 + 8 a1 + 4 a2 + 2 a3 + a4
83 = 2*(2*(2*(2*a0 + a1) + a2) + a3) + a4 */
/* cy carries the limbs shifted out of the low n limbs; it is doubled at each
 * step because the running value is doubled. */
84 cy = lmmp_addshl1_n_(ash, a1, a0, n);
85 cy = 2 * cy + lmmp_addshl1_n_(ash, a2, ash, n);
86 cy = 2 * cy + lmmp_addshl1_n_(ash, a3, ash, n);
/* a4 may be shorter than n: add it to the low s limbs, shift the remaining
 * n - s limbs on their own, then propagate the carry from the low part. */
87 if (s < n) {
88 mp_limb_t cy2;
89 cy2 = lmmp_addshl1_n_(ash, a4, ash, s);
90 ash[n] = 2 * cy + lmmp_shl_(ash + s, ash + s, n - s, 1);
91 lmmp_inc_1(ash + s, cy2);
92 } else
93 ash[n] = 2 * cy + lmmp_addshl1_n_(ash, a4, ash, n);
94
95
96 /* Compute bs1 and bsm1. */
97 bs1[n] = lmmp_add_(bs1, b0, n, b2, t); /* b0 + b2 */
/* If b0 + b2 < b1 the value at -1 is negative: store b1 - (b0 + b2) instead
 * and toggle the sign flag for that point. */
98 if (bs1[n] == 0 && lmmp_cmp_(bs1, b1, n) < 0) {
99 bs1[n] = lmmp_add_n_sub_n_(bs1, bsm1, b1, bs1, n) >> 1;
100 bsm1[n] = 0;
101 flags = (enum toom7_flags)(flags ^ toom7_w3_neg);
102 } else {
/* The return value is used as two bits: bit 1 is added to the top limb of
 * the sum and bit 0 is subtracted from the top limb of the difference
 * (presumably add-carry and subtract-borrow - confirm in add_n_sub_n.c). */
103 cy = lmmp_add_n_sub_n_(bs1, bsm1, bs1, b1, n);
104 bsm1[n] = bs1[n] - (cy & 1);
105 bs1[n] += (cy >> 1);
106 }
107
108 /* Compute bs2 and bsm2. */
109
/* gp = 4 * b2, then bs2 = b0 + 4 * b2 (n + 1 limbs). */
110 cy = lmmp_shl_(gp, b2, t, 2);
111 bs2[n] = lmmp_add_(bs2, b0, n, gp, t);
112 lmmp_inc_1(bs2 + t, cy);
113
/* gp = 2 * b1 (n + 1 limbs). */
114 gp[n] = lmmp_shl_(gp, b1, n, 1);
115
/* bs2 = (b0 + 4 b2) + 2 b1 and bsm2 = |(b0 + 4 b2) - 2 b1|; the sign flag for
 * the -2 point is toggled when the difference had to be taken the other way. */
116 if (lmmp_cmp_(bs2, gp, n + 1) < 0) {
117 lmmp_add_n_sub_n_(bs2, bsm2, gp, bs2, n + 1);
118 flags = (enum toom7_flags)(flags ^ toom7_w1_neg);
119 } else {
120 lmmp_add_n_sub_n_(bs2, bsm2, bs2, gp, n + 1);
121 }
122
123 /* Compute bsh = 4 b0 + 2 b1 + b2 = 2*(2*b0 + b1)+b2. */
124
/* Same short-top-piece handling as for ash above, with t in place of s. */
125 cy = lmmp_addshl1_n_(bsh, b1, b0, n);
126 if (t < n) {
127 mp_limb_t cy2;
128 cy2 = lmmp_addshl1_n_(bsh, b2, bsh, t);
129 bsh[n] = 2 * cy + lmmp_shl_(bsh + t, bsh + t, n - t, 1);
130 lmmp_inc_1(bsh + t, cy2);
131 } else
132 bsh[n] = 2 * cy + lmmp_addshl1_n_(bsh, b2, bsh, n);
133
/* Sanity bounds on the top limb of every evaluated operand; they match the
 * ah / bh bounds given in the comment at the top of this file. */
134 lmmp_debug_assert(as1[n] <= 4);
135 lmmp_debug_assert(bs1[n] <= 2);
136 lmmp_debug_assert(asm1[n] <= 2);
137 lmmp_debug_assert(bsm1[n] <= 1);
138 lmmp_debug_assert(as2[n] <= 30);
139 lmmp_debug_assert(bs2[n] <= 6);
140 lmmp_debug_assert(asm2[n] <= 20);
141 lmmp_debug_assert(bsm2[n] <= 4);
142 lmmp_debug_assert(ash[n] <= 30);
143 lmmp_debug_assert(bsh[n] <= 6);
144
/* Placement of the seven products: v0, v1 and vinf go straight into dst,
 * the other four live in scratch. */
145#define v0 dst /* 2n */
146#define v1 (dst + 2 * n) /* 2n+1 */
147#define vinf (dst + 6 * n) /* s+t */
148#define v2 scratch /* 2n+1 */
149#define vm2 (scratch + 2 * n + 1) /* 2n+1 */
150#define vh (scratch + 4 * n + 2) /* 2n+1 */
151#define vm1 (scratch + 6 * n + 3) /* 2n+1 */
152#define scratch_out (scratch + 8 * n + 4) /* 2n+1 */
153 /* Total scratch need: 10*n+5 */
154
155 /* Must be in allocation order, as they overwrite one limb beyond
156 * 2n+1. */
157 lmmp_mul_n_(v2, as2, bs2, n + 1); /* v2, 2n+1 limbs */
158 lmmp_mul_n_(vm2, asm2, bsm2, n + 1); /* vm2, 2n+1 limbs */
159 lmmp_mul_n_(vh, ash, bsh, n + 1); /* vh, 2n+1 limbs */
160
/* For vm1 and v1 an n-limb multiply is used when both top limbs are zero;
 * the limb at index 2n is cleared beforehand so the result is still a valid
 * 2n+1-limb value in that case. */
161 /* vm1, 2n+1 limbs */
162 vm1[2 * n] = 0;
163 lmmp_mul_n_(vm1, asm1, bsm1, n + ((asm1[n] | bsm1[n]) != 0));
164
165
166 /* v1, 2n+1 limbs */
167 v1[2 * n] = 0;
168 lmmp_mul_n_(v1, as1, bs1, n + ((as1[n] | bs1[n]) != 0));
169
170
171 lmmp_mul_n_(v0, a0, b0, n); /* v0, 2n limbs */
172
/* The longer operand is passed first to lmmp_mul_. */
173 /* vinf, s+t limbs */
174 if (s > t)
175 lmmp_mul_(vinf, a4, s, b2, t);
176 else
177 lmmp_mul_(vinf, b2, t, a4, s);
178
/* Recombine the seven products into dst; v0, v1 and vinf are already in
 * place there, the rest are passed explicitly together with the sign flags. */
179 lmmp_toom_interp7_(dst, n, flags, vm2, vm1, v2, vh, s + t, scratch_out);
180
182}
#define scratch
mp_limb_t * mp_ptr
Definition lmmp.h:215
uint64_t mp_size_t
Definition lmmp.h:212
#define lmmp_debug_assert(x)
Definition lmmp.h:387
const mp_limb_t * mp_srcptr
Definition lmmp.h:216
uint64_t mp_limb_t
Definition lmmp.h:211
#define lmmp_param_assert(x)
Definition lmmp.h:398
static mp_limb_t lmmp_add_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
大数加法静态内联函数 [dst,na]=[numa,na]+[numb,nb]
Definition lmmpn.h:1058
static int lmmp_cmp_(mp_srcptr numa, mp_srcptr numb, mp_size_t n)
大数比较函数(内联)
Definition lmmpn.h:1004
void lmmp_mul_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
不等长大数乘法操作 [dst,na+nb] = [numa,na] * [numb,nb]
void lmmp_mul_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
等长大数乘法操作 [dst,2*n] = [numa,n] * [numb,n]
Definition mul.c:99
mp_limb_t lmmp_addshl1_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
加法结合左移1位操作 [dst,n] = [numa,n] + ([numb,n] << 1)
Definition shl.c:56
mp_limb_t lmmp_shl_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_size_t shl)
大数左移操作 [dst,na] = [numa,na]<<shl,dst的低shl位填充0
Definition shl.c:9
mp_limb_t lmmp_add_n_sub_n_(mp_ptr dsta, mp_ptr dstb, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
同时执行n位加法和减法 ([dsta,n],[dstb,n]) = ([numa,n]+[numb,n],[numa,n]-[numb,n])
Definition add_n_sub_n.c:10
#define lmmp_inc_1(p, inc)
大数加指定值宏(预期无进位)
Definition lmmpn.h:958
#define bsm1
#define asm1
#define bs1
#define bs2
#define as2
#define asm2
#define bsm2
#define as1
void lmmp_mul_toom53_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_size_t na, mp_srcptr restrict numb, mp_size_t nb)
Definition mul_toom53.c:26
#define b0
#define v0
#define a4
#define a3
#define b1
#define v2
#define vm1
#define scratch_out
#define vh
#define a2
#define a0
#define a1
#define b2
#define vinf
#define v1
#define vm2
#define SALLOC_TYPE(n, type)
Definition tmp_alloc.h:87
#define TEMP_S_DECL
Definition tmp_alloc.h:76
#define TEMP_S_FREE
Definition tmp_alloc.h:105
int lmmp_toom_eval_pm2_(mp_ptr xp2, mp_ptr xm2, unsigned k, mp_srcptr xp, mp_size_t n, mp_size_t hn, mp_ptr tp)
通用高阶 Toom 求值:k次多项式在 x = +2 和 x = -2 处求值
toom7_flags
Definition toom_interp.h:27
@ toom7_w1_neg
Definition toom_interp.h:27
@ toom7_w3_neg
Definition toom_interp.h:27
void lmmp_toom_interp7_(mp_ptr dst, mp_size_t n, enum toom7_flags flags, mp_ptr w1, mp_ptr w3, mp_ptr w4, mp_ptr w5, mp_size_t w6n, mp_ptr tp)
Toom插值计算(7点插值):用于Toom-44、Toom-53、Toom-62 乘法算法
int lmmp_toom_eval_pm1_(mp_ptr xp1, mp_ptr xm1, unsigned k, mp_srcptr xp, mp_size_t n, mp_size_t hn, mp_ptr tp)
通用高阶 Toom 求值:k次多项式在 x = +1 和 x = -1 处求值