快速阶乘算法
/ / 阅读耗时 13 分钟 介绍一种可以在$O(\sqrt nlogn)$复杂度下求$n! mod P$的算法。阅读本文需要前缀知识:任意模数NTT以及拉格朗日插值法。
目的很简单,就是求$n! mod P$,$P$是一个大质数。
考虑构造一个多项式$f(x)$如下:
然后把这个多项式扩展到二维:
现在令$B=\lfloor\sqrt n\rfloor$,那么$n!$就可以表示成$f(B,0)f(B,B)\cdots f(B,(B-1)B))$的乘积,然后再乘上剩余的项,也就是:
现在对$f(d,x)$做倍增,主要有下面两个操作:
将d乘2
已知:求出:
将d加1
已知:求出:
有了这两个操作我们可以在$O(logn)$的操作中从$d=1$倍增到任意一个值,那么倍增到$B$就很容易了。比如说我们需要倍增到5(二进制表示为101),那么只需要先乘2,再乘2,再加1,这三次操作即可。
将d加一
这一步比较容易,对于$0\leq i\leq d$,显然有:
而新的一项可以暴力求:
进行一次$O(n)$的操作即可完成。
将d乘2
易知$f(2d,iB)=f(d,iB)f(d,iB+d)$。
那么求出来下面这些序列:
以及:
就可以求出所需的序列了,想一想我们现在知道什么序列:
令$h(x)=f(d,xB)$,那么就是已知:
和要求的第一个序列很像,但是少一半。其实本质上就是已知上面的式子,求:
然后就得到了第一个所需的序列,再想一想怎么获得第二个序列,其实就是已知第一个待求序列,求:
发现两个问题有一个共性,就是已知:
求:
用一步拉格朗日插值算法:
提出累积中的式子:
进一步化式子可以得到:
后面的求和式是一个卷积,可以直接拿出来用任意模数NTT求。这里利用了两个乘积项中存在$i$和$n-i$,两者和为$n$可以用来卷积,但是前者$n-i$可能为负,需要将这个多项式整体右移$k$位,然后再求。
卷积完成之后发现前面还乘了一个系数,它是一段连续的数的积,并且数目总是相同的,可以用一个双指针维护这个乘积,在$O(nlogn)$下就可以求出所有需要的累积。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
using namespace std;
const double Pi = acos(-1.0);
struct Complex {
double a, b;
explicit Complex(double x = 0, double y = 0) : a(x), b(y) {}
Complex operator+(Complex c) { return Complex(a + c.a, b + c.b); }
Complex operator*(Complex c) { return Complex(a * c.a - b * c.b, a * c.b + b * c.a); }
Complex operator-(Complex c) { return Complex(a - c.a, b - c.b); }
Complex operator/(double s) { return Complex(a / s, b / s); }
} a[N], b[N], c[N], d[N], tmp[N], tmp2[N];
double co[N][25], si[N][25];
int ans[N], tr[N], x[N], y[N], fac[N], inv[N], h[N], f[N], h2[N], P;
inline int qPow(int x, int y, int mod) {
int ans = 1, sta = x;
while (y) {
if (y & 1)ans = 1ll * ans * sta % mod;
sta = 1ll * sta * sta % mod, y >>= 1;
}
return ans;
}
void FFT(int l, Complex *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1, p = 0; mid < l; mid <<= 1, p++) {
for (int len = mid << 1, j = 0; j < l; j += len) {
Complex w(1.0, 0);
for (int k = 0; k < mid; k++, w = Complex(co[k][p], type * si[k][p])) {
Complex x = c[j + k], y = w * c[j + mid + k];
c[j + k] = x + y, c[j + mid + k] = x - y;
}
}
}
}
inline void DFT(int l, Complex *x, Complex *y) {
for (int i = 0; i < l; i++)x[i].b = y[i].a;
FFT(l, x, 1);
for (int i = 0; i < l; i++)y[i].a = x[i == 0 ? 0 : l - i].a, y[i].b = -x[i == 0 ? 0 : l - i].b;
for (int i = 0; i < l; i++) {
Complex xx = x[i], yy = y[i];
x[i] = (xx + yy) / 2.0, y[i] = (xx - yy) / 2.0 * Complex(0, -1);
}
}
inline void IDFT(int l, Complex *x, Complex *y) {
for (int i = 0; i < l; i++)x[i] = x[i] + Complex(0, 1) * y[i];
FFT(l, x, -1);
for (int i = 0; i < l; i++)y[i].a = 0, y[i].a = x[i].b;
}
inline void init(int mod) {
for (int i = 0, j = 1; j < (1 << 18); i++, j <<= 1) {
for (int z = 1; z < j; z++)co[z][i] = cos(z * Pi / j), si[z][i] = sin(z * Pi / j);
}
fac[0] = 1, inv[0] = 1;
for (int i = 1; i <= 262154; i++)fac[i] = 1ll * i * fac[i - 1] % mod, inv[i] = qPow(fac[i], mod - 2, mod);
}
inline void MTT(const int *aa, const int *bb, int *anss, int nn, int mm, int mod) {
int base = sqrt(mod), l = 1, le = 0;
for (int i = 0, x; i <= nn; i++)x = aa[i] % mod, a[i].a = x / base, b[i].a = x % base;
for (int i = 0, x; i <= mm; i++)x = bb[i] % mod, c[i].a = x / base, d[i].a = x % base;
while (l <= nn + mm)l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
if (i > nn)a[i].a = b[i].a = 0;//清空
if (i > mm)c[i].a = d[i].a = 0;
a[i].b = b[i].b = c[i].b = d[i].b = 0, anss[i] = 0;
}
DFT(l, a, b), DFT(l, c, d);
for (int i = 0; i < l; i++)tmp[i] = a[i] * c[i];
for (int i = 0; i < l; i++)tmp2[i] = b[i] * c[i] + a[i] * d[i];
IDFT(l, tmp, tmp2);
for (int i = 0; i < l; i++)anss[i] = (anss[i] + (long long) (tmp[i].a / l + 0.5) % mod * base * base % mod) % mod;
for (int i = 0; i < l; i++)anss[i] = (anss[i] + (long long) (tmp2[i].a / l + 0.5) * base % mod) % mod;
for (int i = 0; i < l; i++)tmp[i] = b[i] * d[i];
FFT(l, tmp, -1);
for (int i = 0; i < l; i++)anss[i] = (anss[i] + (long long) (tmp[i].a / l + 0.5) % mod) % mod;
}
//上面封装了MTT
inline void mul2(int d, int mod, int base) {//乘2
int ss = 1ll * d * qPow(base, mod - 2, mod) % mod, tmp;
for (int i = 0; i <= (d << 1); i++) {
h[i] = f[i];
x[i] = qPow(d + 1 + i - d, mod - 2, mod);
if (i <= d)y[i] = 1ll * h[i] * inv[i] % mod * inv[d - i] % mod * ((d - i) % 2 ? -1 : 1);
else y[i] = 0;
y[i] = (1ll * y[i] + mod) % mod;
}
MTT(x, y, ans, d << 1, d << 1, mod), tmp = 1;
for (int i = 1; i <= d + 1; i++)tmp = 1ll * tmp * i % mod;
for (int i = 0, l = 1, r = d + 1; i <= d; ++i, ++l, ++r) {
h[i + d + 1] = 1ll * ans[i + d] * tmp % mod;
tmp = 1ll * tmp * qPow(l, mod - 2, mod) % mod, tmp = 1ll * tmp * (r + 1) % mod;
}
for (int i = 0; i <= (d << 2); i++) {
x[i] = qPow(ss + i - (d << 1), mod - 2, mod);
if (i <= (d << 1)) {
y[i] = 1ll * h[i] * inv[i] % mod * inv[(d << 1) - i] % mod * ((2 * d - i) % 2 ? -1 : 1);
} else y[i] = 0;
y[i] = (1ll * y[i] + mod) % mod;
}
MTT(x, y, ans, d << 2, d << 2, mod), tmp = 1;
for (int i = ss - 2 * d; i <= ss; i++)tmp = 1ll * tmp * i % mod;
for (int i = 0, l = ss - 2 * d, r = ss; i <= 2 * d; ++i, ++l, ++r) {
h2[i] = 1ll * ans[i + (d << 1)] * tmp % mod;
tmp = 1ll * tmp * qPow(l, mod - 2, mod) % mod, tmp = 1ll * tmp * (r + 1) % mod;
}
for (int i = 0; i <= (d << 1); i++)f[i] = 1ll * h[i] * h2[i] % mod;
}
inline void solve(int base, int mod) {
int d = 1, hb = 0;
for (int i = base; i; i >>= 1)++hb;
f[0] = 1, f[1] = base + 1;
for (int i = hb - 2; i >= 0; i--) {
mul2(d, mod, base), d <<= 1;//乘2
if (base & (1 << i)) {//加一
for (int j = 0; j <= d; j++)f[j] = 1ll * f[j] * (1ll * j * base % mod + d + 1) % mod;
f[d + 1] = 1;
for (int j = 1; j <= d + 1; j++)f[d + 1] = 1ll * f[d + 1] * (1ll * d * base % mod + base + j) % mod;
++d;
}
}
}
int main() {
int n, p, base, tot = 1;
cin >> n >> p;
base = sqrt(n), init(p), solve(base, p);
for (int i = 0; i < base; i++)tot = 1ll * tot * f[i] % p;
for (int i = base * base + 1; i <= n; i++)tot = 1ll * tot * i % p;
cout << (1ll * tot + p) % p;
return 0;
}