介绍快速数论变换,本文需要快速傅里叶变换前缀知识。

NTT

        FFT可以很好地解决计算卷积的问题,但是它也有缺点:

  • 精度问题:double的使用
  • 不能取模:复数的引入

        如果给定的所有系数均为整数,最后的结果又在模意义下,那么FFT没有很大优势。这时就需要快速数论变换(Number Theoretic Transform,NTT)来解决这个问题。
        NTT的使用有一定限制,即对模数有要求,它们需要为$k2^p+1$的素数形式,比如1004535809, 998244353, 469762049这三个数。
        在理解了FFT原理后,NTT是容易理解的。在此之前,需要先了解原根的性质。
        对于一个数x,如果它存在原根g,那么总有下式成立:

        类似FFT中的单位根,NTT中用$g^{\frac {x-1} {n}}$作为单位根,可以证明它是满足FFT中的三个引理的,鉴于上面所述的三个素数原根均为3,这样就可以将FFT的板子修改如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include<bits/stdc++.h>

#define N (2000005<<2)
#define MOD 998244353//模数
#define G 3//原根
using namespace std;
typedef long long ll;
int tr[N >> 1], le = 0;

inline int read() {
char e = getchar();
int s = 0, g = 0;
while (e < '-')e = getchar();
if (e == '-')g = 1, e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return g ? -s : s;
}

ll qPow(ll a, ll b) {
ll ans = 1, x = a;
while (b > 0) {
if (b & 1)ans = ans * x % MOD;
x *= x, x %= MOD, b >>= 1;
}
return ans;
}

ll a[N], b[N];

void NTT(int l, ll *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1; mid < l; mid <<= 1) {
ll wn = qPow(G, (MOD - 1) / (mid << 1));//快速幂求单位根
if (type == -1)wn = qPow(wn, MOD - 2);//type==-1时为逆元,用费马小定理来求
for (int len = mid << 1, j = 0; j < l; j += len) {
ll w = 1;
for (int k = 0; k < mid; k++, w = w * wn % MOD) {
ll x = c[j + k], y = w * c[j + mid + k] % MOD;
c[j + k] = ((x + y) % MOD + MOD) % MOD;
c[j + mid + k] = ((x - y) % MOD + MOD) % MOD;
}
}
}
}


int main() {
int n = read(), m = read();
ll l = 1;
for (int i = 0; i <= n; i++)a[i] = read();
for (int i = 0; i <= m; i++)b[i] = read();
while (l <= n + m)l <<= 1, ++le;
for (int i = 0; i < l; i++)tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
NTT(l, a, 1), NTT(l, b, 1);
for (int i = 0; i <= l; i++)a[i] = (long long) a[i] * b[i] % MOD;
NTT(l, a, -1), l = qPow(l, MOD - 2);//逆元
for (int i = 0; i <= n + m; i++)printf("%lld ", a[i] * l % MOD);
return 0;
}

任意模数MTT

三模数NTT

        如果模数不限形式,那么NTT便不再适用,需要采用MTT的方法。MTT可以用系数拆分FFT和三模数NTT两种方法来实现,这里介绍后者。模板题
        首先,如果可以保证最后答案不超过$10^{23}$,我们可以先找3个乘积超过$10^{23}$的三个素数,这里用上文提到的三个素数即可。对三个素数分别进行一次NTT,可以得到下面的同余式:

        这时我们可以用中国剩余定理合并这三个同余式,得到:

        然后拿c去模给定的mod即可。但是这里有一个问题就是求值过程中会爆long long,解决方法当然可以高精度,但这里有一个巧妙的方法解决这个问题。
        先用中国剩余定理合并前两个同余式,这里$m_1m_2$不会爆long long:

        上面的$m_2^{-1}$是$m_2$对$m_1$的逆元,$m_1^{-1}$为$m_1$对$m_2$的逆元。记后者为$A$,那么答案满足:

        然后有:

        即:

        因此答案满足:

        所以:

        由范围可知$s=0$,那么答案即为:

        这样就省去了高精度计算,求解这个式子即可。注意当两个数需要相乘时,可能爆long long,要采用快速乘算法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include<bits/stdc++.h>

#define N (2000005<<2)
#define G 3
using namespace std;
typedef long long ll;
int tr[N >> 1], le = 0;
const ll M[3] = {1004535809, 998244353, 469762049};

inline int read() {
char e = getchar();
int s = 0, g = 0;
while (e < '-')e = getchar();
if (e == '-')g = 1, e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return g ? -s : s;
}

ll qMultiple(ll a, ll b, ll mod) {//快速乘
ll ans = 0, s = a;
while (b > 0) {
if (b & 1)ans = (ans + s) % mod;
b >>= 1, s <<= 1, s %= mod;
}
return ans;
}

ll qPow(ll a, ll b, ll mod) {//快速幂
ll ans = 1, x = a;
while (b > 0) {
if (b & 1)ans = ans * x % mod;
x *= x, x %= mod, b >>= 1;
}
return ans;
}

ll a[N], b[N], f[N], g[N], l = 1, ans[3][N];
int n, m, mod;

void NTT(int l, ll *c, int type, int s) {//NTT
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1; mid < l; mid <<= 1) {
ll wn = qPow(G, (M[s] - 1) / (mid << 1), M[s]);
if (type == -1)wn = qPow(wn, M[s] - 2, M[s]);
for (int len = mid << 1, j = 0; j < l; j += len) {
ll w = 1;
for (int k = 0; k < mid; k++, w = w * wn % M[s]) {
ll x = c[j + k], y = w * c[j + mid + k] % M[s];
c[j + k] = ((x + y) % M[s] + M[s]) % M[s];
c[j + mid + k] = ((x - y) % M[s] + M[s]) % M[s];
}
}
}
}

void NTT2(int s) {//封装NTT的操作
for (int i = 0; i < l; i++)f[i] = a[i] % M[s];
for (int i = 0; i < l; i++)g[i] = b[i] % M[s];
NTT(l, f, 1, s), NTT(l, g, 1, s);
for (int i = 0; i < l; i++)f[i] = f[i] * g[i] % M[s];
NTT(l, f, -1, s);
ll p = qPow(l, M[s] - 2, M[s]);
for (int i = 0; i <= (n + m); i++)ans[s][i] = f[i] * p % M[s];
}

int main() {
n = read(), m = read(), mod = read();
for (int i = 0; i <= n; i++)a[i] = read();
for (int i = 0; i <= m; i++)b[i] = read();
while (l <= n + m)l <<= 1, ++le;
for (int i = 0; i < l; i++)tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
for (int i = 0; i < 3; i++)NTT2(i);
ll inv1 = qPow(M[1] % M[0], M[0] - 2l, M[0]), inv2 = qPow(M[0] % M[1], M[1] - 2l, M[1]);
ll inv3 = qPow(M[0] * M[1] % M[2], M[2] - 2l, M[2]);
for (int i = 0; i < l; i++) {
ll c;
ans[0][i] = (qMultiple(ans[0][i] * M[1] % (M[0] * M[1]), inv1, M[0] * M[1])
+ qMultiple(ans[1][i] * M[0] % (M[0] * M[1]), inv2, M[0] * M[1])) % (M[0] * M[1]);
c = ans[0][i];
ans[0][i] = ((ans[2][i] - c) % M[2] + M[2]) % M[2] * inv3 % M[2];
ans[0][i] = (M[0] * M[1] % mod * ans[0][i] % mod + c % mod) % mod;
}
for (int i = 0; i <= n + m; i++)printf("%lld ", ans[0][i]);
return 0;
}

拆系数七次FFT

        拆系数法FFT也是一种重要的任意模数方法。
        我们当然可以用FFT先求出系数来,然后再取模,但是这样很明显会爆long long/double。于是就有了拆系数FFT来解决这个问题。
        拆系数FFT的思想很简单。首先假定模数为$P$,那么先求出$\sqrt P$(需要取整,下同),将两个多项式的系数按照带余除法原理写成下面的形式:

        注:这里第一个式子是$A(x)$的系数拆分,第二个是$B(x)$。
        于是两个多项式可以写成下面的形式:

        这样将两个多项式分别又拆成两个多项式,在求$A(x)B(x)$将这四个多项式两两相乘(共乘四次)。这样求完这四个多项式后再把它们加起来就是最终的答案。
        这样做的可行性是拆系数使得系数压至$\sqrt P$级别,这样这四次多项式乘积中,系数最大也不会超过$nP$。由于大多数问题中$n$在$10^5$量级,$P$在$10^9$量级,$nP$的乘积用double完全存的下。
        这样,我们对这四个多项式分别进行一次DFT,然后求出四个多项式的两两乘积(共四次),其中系数相同的(只乘一个$\sqrt P$)我们将它们合并起来,这样IDFT只需要做三次,共做了七次FFT。下面给出模板题代码(注意精度):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include<bits/stdc++.h>

#define N (100005<<2)
using namespace std;
const double Pi = acos(-1.0);

struct Complex {
double a, b;

explicit Complex(double x = 0, double y = 0) : a(x), b(y) {}

Complex operator+(Complex c) { return Complex(a + c.a, b + c.b); }

Complex operator*(Complex c) { return Complex(a * c.a - b * c.b, a * c.b + b * c.a); }

Complex operator-(Complex c) { return Complex(a - c.a, b - c.b); }

Complex operator/(double s) { return Complex(a / s, b / s); }
} a[N], b[N], c[N], d[N], tmp[N];

double co[N][25], si[N][25];
int ans[N], tr[N], le = 0;

inline int read() {
char e = getchar();
int s = 0, g = 0;
while (e < '-')e = getchar();
if (e == '-')g = 1, e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return g ? -s : s;
}

void FFT(int l, Complex *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1, p = 0; mid < l; mid <<= 1, p++) {
for (int len = mid << 1, j = 0; j < l; j += len) {
Complex w(1.0, 0);
for (int k = 0; k < mid; k++, w = Complex(co[k][p], type * si[k][p])) {
Complex x = c[j + k], y = w * c[j + mid + k];
c[j + k] = x + y, c[j + mid + k] = x - y;
}
}
}
}

int main() {
int n = read(), m = read(), p = read(), l = 1, base = sqrt(p);
for (int i = 0, x; i <= n; i++)x = read() % p, a[i].a = x / base, b[i].a = x % base;
for (int i = 0, x; i <= m; i++)x = read() % p, c[i].a = x / base, d[i].a = x % base;
while (l <= n + m)l <<= 1, ++le;
for (int i = 0, j = 1; j < l; i++, j <<= 1) {//预处理单位根,减小精度误差
for (int z = 1; z < j; z++)co[z][i] = cos(z * Pi / j), si[z][i] = sin(z * Pi / j);
}
for (int i = 0; i < l; i++)tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
FFT(l, a, 1), FFT(l, b, 1), FFT(l, c, 1), FFT(l, d, 1);//四次DFT
for (int i = 0; i < l; i++)tmp[i] = a[i] * c[i];
FFT(l, tmp, -1);
for (int i = 0; i < l; i++)ans[i] = (ans[i] + (long long) (tmp[i].a / l + 0.5) % p * base * base % p) % p;
for (int i = 0; i < l; i++)tmp[i] = b[i] * c[i] + a[i] * d[i];//这两次合并成一个
FFT(l, tmp, -1);
for (int i = 0; i < l; i++)ans[i] = (ans[i] + (long long) (tmp[i].a / l + 0.5) * base % p) % p;
for (int i = 0; i < l; i++)tmp[i] = b[i] * d[i];
FFT(l, tmp, -1);
for (int i = 0; i < l; i++)ans[i] = (ans[i] + (long long) (tmp[i].a / l + 0.5) % p) % p;
for (int i = 0; i <= n + m; i++)cout << ans[i] << " ";
return 0;
}

拆系数四次FFT

        这可能是最好的任意模数NTT做法,要理解这个方法,你需要至少理解七次FFT的思路。
        在快速傅里叶变换一文的末尾,曾提及myy的论文,在这篇论文中他提出了将两次FFT合并成一次的做法。这启示我们,七次FFT可以两两合并,从而减少FFT的次数,提高程序效率。
        首先一开始的四次DFT,将它们两两合并,变成两次DFT,这样就减少了两次DFT。
        减少两次之后似乎没有办法再优化了。剩余的两次DFT由于项可能是复数,不能合并,而IDFT的项同样是复数,自然也很难处理。
        myy在他的论文中提到了一个十分巧妙的做法,可以帮助我们合并三次IDFT为两次,再减少一次IDFT。
        虽然参与IDFT的序列不是纯实数,但是它们的结果显然是实数。观察这个式子:

        左侧是经过DFT后的序列,也就是我们的已知式。通过这两个式子,显然可以求出$c_k$或者$d_k$。而根据$c_k$以及$d_k$的定义,我们对它们其中任何一个进行一次IDFT,之后都可以从实部和虚部直接读出两个序列经过IDFT后的值,这是一个极其巧妙的优化。
        为了实现四次FFT的过程,定义两个函数,分别用于DFT以及IDFT:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
inline void DFT(int l, Complex *x, Complex *y) {//x与y合并成一次,下同
for (int i = 0; i < l; i++)x[i].b = y[i].a;
FFT(l, x, 1);//只进行一次
for (int i = 0; i < l; i++)y[i].a = x[i == 0 ? 0 : l - i].a, y[i].b = -x[i == 0 ? 0 : l - i].b;
for (int i = 0; i < l; i++) {
Complex xx = x[i], yy = y[i];
x[i] = (xx + yy) / 2.0, y[i] = (xx - yy) / 2.0 * Complex(0, -1);
}
}

inline void IDFT(int l, Complex *x, Complex *y) {
for (int i = 0; i < l; i++)x[i] = x[i] + Complex(0, 1) * y[i];
FFT(l, x, -1);
for (int i = 0; i < l; i++)y[i].a = 0, y[i].a = x[i].b;
}


        这样,四次FFT就可以完美解决任意模数NTT的问题。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include<bits/stdc++.h>

#define N (100005<<2)
using namespace std;
const double Pi = acos(-1.0);

struct Complex {
double a, b;

explicit Complex(double x = 0, double y = 0) : a(x), b(y) {}

Complex operator+(Complex c) { return Complex(a + c.a, b + c.b); }

Complex operator*(Complex c) { return Complex(a * c.a - b * c.b, a * c.b + b * c.a); }

Complex operator-(Complex c) { return Complex(a - c.a, b - c.b); }

Complex operator/(double s) { return Complex(a / s, b / s); }
} a[N], b[N], c[N], d[N], tmp[N], tmp2[N];

double co[N][25], si[N][25];
int ans[N], tr[N], le = 0;

inline int read() {
char e = getchar();
int s = 0, g = 0;
while (e < '-')e = getchar();
if (e == '-')g = 1, e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return g ? -s : s;
}

void FFT(int l, Complex *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1, p = 0; mid < l; mid <<= 1, p++) {
for (int len = mid << 1, j = 0; j < l; j += len) {
Complex w(1.0, 0);
for (int k = 0; k < mid; k++, w = Complex(co[k][p], type * si[k][p])) {
Complex x = c[j + k], y = w * c[j + mid + k];
c[j + k] = x + y, c[j + mid + k] = x - y;
}
}
}
}

inline void DFT(int l, Complex *x, Complex *y) {
for (int i = 0; i < l; i++)x[i].b = y[i].a;
FFT(l, x, 1);
for (int i = 0; i < l; i++)y[i].a = x[i == 0 ? 0 : l - i].a, y[i].b = -x[i == 0 ? 0 : l - i].b;
for (int i = 0; i < l; i++) {
Complex xx = x[i], yy = y[i];
x[i] = (xx + yy) / 2.0, y[i] = (xx - yy) / 2.0 * Complex(0, -1);
}
}

inline void IDFT(int l, Complex *x, Complex *y) {
for (int i = 0; i < l; i++)x[i] = x[i] + Complex(0, 1) * y[i];
FFT(l, x, -1);
for (int i = 0; i < l; i++)y[i].a = 0, y[i].a = x[i].b;
}

int main() {
int n = read(), m = read(), p = read(), l = 1, base = sqrt(p);
for (int i = 0, x; i <= n; i++)x = read() % p, a[i].a = x / base, b[i].a = x % base;
for (int i = 0, x; i <= m; i++)x = read() % p, c[i].a = x / base, d[i].a = x % base;
while (l <= n + m)l <<= 1, ++le;
for (int i = 0, j = 1; j < l; i++, j <<= 1) {
for (int z = 1; z < j; z++)co[z][i] = cos(z * Pi / j), si[z][i] = sin(z * Pi / j);
}
for (int i = 0; i < l; i++)tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
DFT(l, a, b), DFT(l, c, d);//这里两次FFT
for (int i = 0; i < l; i++)tmp[i] = a[i] * c[i];
for (int i = 0; i < l; i++)tmp2[i] = b[i] * c[i] + a[i] * d[i];
IDFT(l, tmp, tmp2);//第三次FFT
for (int i = 0; i < l; i++)ans[i] = (ans[i] + (long long) (tmp[i].a / l + 0.5) % p * base * base % p) % p;
for (int i = 0; i < l; i++)ans[i] = (ans[i] + (long long) (tmp2[i].a / l + 0.5) * base % p) % p;
for (int i = 0; i < l; i++)tmp[i] = b[i] * d[i];
FFT(l, tmp, -1);//最后一次FFT
for (int i = 0; i < l; i++)ans[i] = (ans[i] + (long long) (tmp[i].a / l + 0.5) % p) % p;
for (int i = 0; i <= n + m; i++)cout << ans[i] << " ";
return 0;
}