这是一篇模板类文章,意思即:本文的代码均可以作为模板代码。代码会时不时进行更新,以保证较小的常数。对于本博客中其它关于多项式的算法均以本文代码为准。
        本文整理ACM相关的多项式有关算法,长期更新。注意,这里的模块顺序虽然保证了前置知识在前的原则,但是它们不一定有必然的逻辑关系。
        UPD:2019.8.2更新,代码效率提升。

多项式乘法

        多项式乘法运算可以说是多项式中的基本功。朴素的多项式乘法时间复杂度$O(n^2)$,通常可以使用FFT以及NTT来将时间复杂度降至$O(nlogn)$,关于这两个算法前面的文章有所介绍,见FFT以及NTT
        为了使问题更简便,这里封装一下NTT的代码。首先定义多项式类:

1
2
3
struct Pol {//多项式类
int l, op[N << 2] = {0};//l为最高项指数,N是最高项次数,通常取一个更大一点的值
};

        封装NTT以及多项式乘法:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#define N 2000005
#define MOD 998244353//模数
#define G 3//原根
int tr[N << 2];//交换数组

int qPow(int a, int b) {
int ans = 1, x = a;
while (b) {
if (b & 1)ans = 1ll * ans * x % MOD;
x = 1ll * x * x % MOD, b >>= 1;
}
return ans;
}


void NTT(int l, int *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1; mid < l; mid <<= 1) {
int wn = qPow(G, (MOD - 1) / (mid << 1));
if (type == -1)wn = qPow(wn, MOD - 2);
for (int len = mid << 1, j = 0; j < l; j += len) {
int w = 1;
for (int k = 0; k < mid; k++, w = 1ll * w * wn % MOD) {
int x = c[j + k], y = 1ll * w * c[j + mid + k] % MOD;
c[j + k] = (1ll * x + y) % MOD, c[j + mid + k] = (1ll * x - y) % MOD;
}
}
}
}

inline void multiple(Pol *ans, Pol *a, Pol *b) {//将a与b相乘并计入ans
int l = 1, le = 0;
while (l <= a->l + b->l)l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
if (i > a->l)a->op[i] = 0;//高次项系数清零
if (i > b->l)b->op[i] = 0;
}
NTT(l, a->op, 1), NTT(l, b->op, 1);
for (int i = 0; i < l; i++)a->op[i] = 1ll * a->op[i] * b->op[i] % MOD;
NTT(l, a->op, -1), l = qPow(l, MOD - 2), ans->l = a->l + b->l;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * a->op[i] * l % MOD;
}


        上面的板子直接拿来用就可以了。这里需要注意一个问题:multiple函数会修改a与b的值,如果不希望修改需要事先备份,这里不事先备份是为了保证更小的常数。其次,这里的多项式乘数a与b不能相同。此处的板子可以直接切掉多项式乘法模板题(见FFT)。
        我们可以从更高的层次去理解多项式乘法。注意到对于两个多项式$A(x)$,$B(x)$,它们的乘积$C(x)=A(x)B(x)$的系数满足:

        这就是卷积,也写成如下形式:

        凡是形如这样的求和式都可以用FFT/NTT优化。

多项式分治乘法

        两个多项式相乘可以使用$NTT/FFT$做到$O(nlogn)$,考虑如何求$n$个小次数多项式相乘的结果。例如求$n$个形如$ax+b$的多项式乘积。
        朴素的暴力累乘算法会导致$O(n^2logn)$,这是因为暴力累乘会导致多项式次数越来越大,复杂度逐步提高。
        考虑一种分治解法。我们将$n$个多项式划分为左右两个部分,然后求出两个部分的结果。容易发现两个部分的多项式的次数是$O(n)$的,因此复杂度为:

        复杂度大概在$O(nlog^2n)$。

        由于需要存很多中间结果,该算法的空间复杂度达到$O(nlogn)$。在这里并不建议使用递归中申请临时变量的方法来记忆多项式,而是将它们开到全局变量中,这样可以降低常数。

1
2
3
4
5
6
7
8
9
10
11
12
Pol pols[20];//临时变量池
//处理[l,r]的多项式,深度为k,结果存入pol
void solve(int l, int r, int k, Pol* pol)
{
if (l == r) {
//在这里更新多项式
return;
}
int mid = (l + r) >> 1;
solve(l, mid, k + 1, pols + k), solve(mid + 1, r, k + 1, pol);
multiple(pol, pols + k, pol);
}

多项式求逆

        多项式求逆是多项式专题中的重要问题,多项式的逆定义如下。
        对于最高次项次数为n-1的多项式$A(x)$,若存在$B(x)$使得:

        则称$B(x)$为$A(x)$在模$x^n$下的逆多项式。比如$A(x)=1+x$与$B(x)=1-x$就在模$x^2$下互为逆多项式,这是因为:

        若$B(x)$为$A(x)$的逆多项式,也记$A^{-1}(x)=B(x)$。
        易知如果$A(x)$与$B(x)$在模$x^n$下互逆,则$A(x)B(x)$的所有次数为$1$到$n-1$的项系数都为0而常数项为1。
        下面探讨逆多项式的求法。
        现在要求$A(x)$在模$x^n$下的逆多项式。如果$A(x)$最高次项次数为0,那么其逆多项式就是常数项的逆元。假设在$x^{\lceil \frac {n} {2}\rceil}$下的逆多项式已经求出,为$B’(x)$,即有:

        假设需要求的逆多项式为$B(x)$,那么有:

        也满足:

        于是有:

        两边同时去乘$B’(x)$,得到:

        两边平方:

        同时乘上$A(x)$,得到:

        立即得到:

        上面公式给出了$B(x)$与$B’(x)$的递推关系,这使得我们可以在$O(lognM)$复杂度下递推出逆多项式,$M$是多项式乘法的复杂度。
        如果要求系数在模空间下给出,并且模数恰好是NTT模数,那么我们可以用NTT来使乘法复杂度达到$O(nlogn)$,总复杂度大约就是$O(nlog^2n)$。
        注意到在相乘时,$A(x)$中次数不小于n的项是无用的,它们不会影响答案,因此可以在相乘时滤掉这些项,达到每一次相乘时最高次项次数不断下降的效果,最终复杂度就是$O(nlogn)$。
        递归版的核心代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
void inv(Pol *ans, const Pol *x, int rk) {//求x的逆,在x^rk意义下,存入ans,x不会被修改
if (rk == 1) {//递推边界
ans->l = 0, ans->op[0] = qPow(x->op[0], MOD - 2);
return;
}
inv(ans, x, (rk + 1) >> 1);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i <= min(x->l, rk - 1) ? x->op[i] : 0;//高次项滤掉,零项补零
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1);
for (int i = 0; i < l; i++)
ans->op[i] = (2ll * ans->op[i] % MOD - (1ll * tmp.op[i] * ans->op[i] % MOD) * 1ll * ans->op[i] % MOD) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

        附模板题代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include <bits/stdc++.h>

#define MOD 998244353
#define G 3
#define N 100005
using namespace std;
int tr[N << 2];
struct Pol {
int l, op[N << 2] = {0};
} pol1, pol2, tmp;
int n;

int qPow(int a, int b) {
int ans = 1, x = a;
while (b) {
if (b & 1)ans = 1ll * ans * x % MOD;
x = 1ll * x * x % MOD, b >>= 1;
}
return ans;
}

void NTT(int l, int *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1; mid < l; mid <<= 1) {
int wn = qPow(G, (MOD - 1) / (mid << 1));
if (type == -1)wn = qPow(wn, MOD - 2);
for (int len = mid << 1, j = 0; j < l; j += len) {
int w = 1;
for (int k = 0; k < mid; k++, w = 1ll * w * wn % MOD) {
int x = c[j + k], y = 1ll * w * c[j + mid + k] % MOD;
c[j + k] = (1ll * x + y) % MOD, c[j + mid + k] = (1ll * x - y) % MOD;
}
}
}
}

void inv(Pol *ans, const Pol *x, int rk) {
if (rk == 1) {
ans->l = 0, ans->op[0] = qPow(x->op[0], MOD - 2);
return;
}
inv(ans, x, (rk + 1) >> 1);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i <= min(x->l, rk - 1) ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1);
for (int i = 0; i < l; i++)
ans->op[i] = (2ll * ans->op[i] % MOD - (1ll * tmp.op[i] * ans->op[i] % MOD) * 1ll * ans->op[i] % MOD) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

int main() {
ios::sync_with_stdio(false);
cin >> n;
pol1.l = n - 1;
for (int i = 0; i <= pol1.l; i++)cin >> pol1.op[i];
inv(&pol2, &pol1, n);
for (int i = 0; i <= pol2.l; i++)cout << (pol2.op[i] + MOD) % MOD << " ";
return 0;
}

        和乘法逆元类似,多项式逆元有着它的应用,其中一个重要应用就是求$\displaystyle\frac {B(x)} {A(x)}$在$x^n$模意义下的多项式,答案就是$B(x)A^{-1}(x)$。

多项式对数函数

        本算法必须保证常数项系数为1。
        给定一个n-1次多项式$A(x)$,求在模$x^n$意义下的$B(x)$满足:

        这种问题可以用求导的方式解决,首先求个导:

        就是:

        之后对$B’(x)$积分就可以了。这里定义一下求导和积分函数:

1
2
3
4
5
6
7
8
9
10
inline void diff(Pol *s) {
for (int i = 0; i < s->l; i++)s->op[i] = 1ll * s->op[i + 1] * (i + 1) % MOD;
s->op[s->l] = 0, --s->l;
}

inline void integral(Pol *s) {
for (int i = s->l + 1; i >= 1; i--)s->op[i] = 1ll * s->op[i - 1] * qPow(i, MOD - 2) % MOD;
++s->l, s->op[0] = 0;
}


        然后用上面的模板代码去求逆以及乘积即可,复杂度$O(nlogn)$,模板题全代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <bits/stdc++.h>

#define MOD 998244353
#define G 3
#define N 100005
using namespace std;
int tr[N << 2];
struct Pol {
int l, op[N << 2] = {0};
} pol1, pol2, tmp;
int n;

int qPow(int a, int b) {
int ans = 1, x = a;
while (b) {
if (b & 1)ans = 1ll * ans * x % MOD;
x = 1ll * x * x % MOD, b >>= 1;
}
return ans;
}

inline void diff(Pol *s) {//求导
for (int i = 0; i < s->l; i++)s->op[i] = 1ll * s->op[i + 1] * (i + 1) % MOD;
s->op[s->l] = 0, --s->l;
}


inline void integral(Pol *s) {//积分
for (int i = s->l + 1; i >= 1; i--)s->op[i] = 1ll * s->op[i - 1] * qPow(i, MOD - 2) % MOD;
++s->l, s->op[0] = 0;
}

void NTT(int l, int *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1; mid < l; mid <<= 1) {
int wn = qPow(G, (MOD - 1) / (mid << 1));
if (type == -1)wn = qPow(wn, MOD - 2);
for (int len = mid << 1, j = 0; j < l; j += len) {
int w = 1;
for (int k = 0; k < mid; k++, w = 1ll * w * wn % MOD) {
int x = c[j + k], y = 1ll * w * c[j + mid + k] % MOD;
c[j + k] = (1ll * x + y) % MOD, c[j + mid + k] = (1ll * x - y) % MOD;
}
}
}
}

void inv(Pol *ans, const Pol *x, int rk) {//求逆
if (rk == 1) {
ans->l = 0, ans->op[0] = qPow(x->op[0], MOD - 2);
return;
}
inv(ans, x, (rk + 1) >> 1);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i <= min(x->l, rk - 1) ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1);
for (int i = 0; i < l; i++)
ans->op[i] = (2ll * ans->op[i] % MOD - (1ll * tmp.op[i] * ans->op[i] % MOD) * 1ll * ans->op[i] % MOD) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

inline void multiple(Pol *ans, Pol *a, Pol *b) {
int l = 1, le = 0;
while (l <= a->l + b->l)l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
if (i > a->l)a->op[i] = 0;
if (i > b->l)b->op[i] = 0;
}
NTT(l, a->op, 1), NTT(l, b->op, 1);
for (int i = 0; i < l; i++)a->op[i] = 1ll * a->op[i] * b->op[i] % MOD;
NTT(l, a->op, -1), l = qPow(l, MOD - 2), ans->l = a->l + b->l;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * a->op[i] * l % MOD;
}

int main() {
ios::sync_with_stdio(false);
cin >> n;
pol1.l = n - 1;
for (int i = 0; i <= pol1.l; i++)cin >> pol1.op[i];
inv(&pol2, &pol1, n), diff(&pol1), multiple(&pol1, &pol1, &pol2), integral(&pol1);
for (int i = 0; i <= n - 1; i++)cout << (pol1.op[i] + MOD) % MOD << " ";
return 0;
}

多项式指数函数

        本算法必须保证常数项系数为0。
        可以认为是对数函数的逆运算。为了让问题更简便,封装一下求ln的函数:

1
2
3
4
5
inline void getLn(Pol *ans, const Pol *x, int rk) {//求x的ln,在x^rk意义下,存入ans,x不会被改变
inv(ans, x, rk);
for (int i = 0; i < rk; i++)tmp.op[i] = i <= x->l ? x->op[i] : 0, tmp.l = rk - 1;
diff(&tmp), multiple(ans, &tmp, ans), integral(ans), ans->l = rk - 1;
}

        给定n-1次多项式$A(x)$,求模$x^n$意义下的$B(x)$满足:

        首先两边取个$ln$,然后移项:

        左侧式中,$A(x)$是给定的,可以认为是常量,而$B(x)$是变量,那么可以将左侧看成关于$B(x)$的函数$G(B(x))$,令其对$B(x)$求导可得:

        下面考虑其它问题。既然将$B(x)$看成是$G$的自变量,那么现在要做的实质上就是求$G$的零点,但这个函数的零点显然不是那么容易求:

        考虑迭代。假设现在已经求出$F(x)$满足:

        将$G(X)$在这一点做幂级数展开:

        $X$表示一个多项式。现在要求满足$G(X)\equiv 0\pmod {x^n}$的$X$,$X$当然也满足$G(X)\equiv 0\pmod {x^{\lceil \frac {n} {2}\rceil}}$。
        根据$G(X)$的形式,当然可以推出:

        于是:

        这样只需要拿出幂级数展开的前两项就可以了,因为后面的全是0:

        这样$X$就求出来了,这里的$X$就是$B(x)$:

        这就是牛顿迭代公式,代入上面求出的导数:

        最终:

        边界条件为$n=1$时,$B(x)=1$。这样就可以写出核心迭代函数,和inv很像:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
void getExp(Pol *ans, const Pol *x, int rk) {
if (rk == 1) {//迭代边界
ans->l = 0, ans->op[0] = 1;
return;
}
getExp(ans, x, (rk + 1) >> 1), getLn(&tmp2, ans, rk);//用一个临时变量存ln
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i < rk ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
tmp2.op[i] = i <= tmp2.l ? tmp2.op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1), NTT(l, tmp2.op, 1);
for (int i = 0; i < l; i++)ans->op[i] = 1ll * ans->op[i] * (1ll - tmp2.op[i] + tmp.op[i]) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}


        复杂度$O(nlogn)$(?),附模板题全代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include <bits/stdc++.h>

#define MOD 998244353
#define G 3
#define N 100005
using namespace std;
int tr[N << 2];
struct Pol {//多项式类
int l, op[N << 2] = {0};//l最高项指数
} pol1, pol2, tmp, tmp2;
int n;

int qPow(int a, int b) {
int ans = 1, x = a;
while (b) {
if (b & 1)ans = 1ll * ans * x % MOD;
x = 1ll * x * x % MOD, b >>= 1;
}
return ans;
}

inline void diff(Pol *s) {
for (int i = 0; i < s->l; i++)s->op[i] = 1ll * s->op[i + 1] * (i + 1) % MOD;
s->op[s->l] = 0, --s->l;
}

inline void integral(Pol *s) {
for (int i = s->l + 1; i >= 1; i--)s->op[i] = 1ll * s->op[i - 1] * qPow(i, MOD - 2) % MOD;
++s->l, s->op[0] = 0;
}

void NTT(int l, int *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1; mid < l; mid <<= 1) {
int wn = qPow(G, (MOD - 1) / (mid << 1));
if (type == -1)wn = qPow(wn, MOD - 2);
for (int len = mid << 1, j = 0; j < l; j += len) {
int w = 1;
for (int k = 0; k < mid; k++, w = 1ll * w * wn % MOD) {
int x = c[j + k], y = 1ll * w * c[j + mid + k] % MOD;
c[j + k] = ((x + y) % MOD + MOD) % MOD, c[j + mid + k] = ((x - y) % MOD + MOD) % MOD;
}
}
}
}

void inv(Pol *ans, const Pol *x, int rk) {//求逆
if (rk == 1) {
ans->l = 0, ans->op[0] = qPow(x->op[0], MOD - 2);
return;
}
inv(ans, x, (rk + 1) >> 1);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i <= min(x->l, rk - 1) ? x->op[i] : 0;//高次项滤掉,零项补零
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1);
for (int i = 0; i < l; i++)
ans->op[i] = (2ll * ans->op[i] % MOD - (1ll * tmp.op[i] * ans->op[i] % MOD) * 1ll * ans->op[i] % MOD) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

inline void multiple(Pol *ans, Pol *a, Pol *b) {//将a与b相乘并计入ans
int l = 1, le = 0;
while (l <= a->l + b->l)l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
if (i > a->l)a->op[i] = 0;
if (i > b->l)b->op[i] = 0;
}
NTT(l, a->op, 1), NTT(l, b->op, 1);
for (int i = 0; i < l; i++)a->op[i] = 1ll * a->op[i] * b->op[i] % MOD;
NTT(l, a->op, -1), l = qPow(l, MOD - 2), ans->l = a->l + b->l;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * a->op[i] * l % MOD;
}

inline void getLn(Pol *ans, const Pol *x, int rk) {//求x的ln
inv(ans, x, rk);
for (int i = 0; i < rk; i++)tmp.op[i] = i <= x->l ? x->op[i] : 0, tmp.l = rk - 1;
diff(&tmp), multiple(ans, &tmp, ans), integral(ans), ans->l = rk - 1;
}

void getExp(Pol *ans, const Pol *x, int rk) {
if (rk == 1) {
ans->l = 0, ans->op[0] = 1;
return;
}
getExp(ans, x, (rk + 1) >> 1), getLn(&tmp2, ans, rk);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i < rk ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
tmp2.op[i] = i <= tmp2.l ? tmp2.op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1), NTT(l, tmp2.op, 1);
for (int i = 0; i < l; i++)ans->op[i] = 1ll * ans->op[i] * (1ll - tmp2.op[i] + tmp.op[i]) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

int main() {
ios::sync_with_stdio(false);
cin >> n;
pol1.l = n - 1;
for (int i = 0; i <= pol1.l; i++)cin >> pol1.op[i];
getExp(&pol2, &pol1, n);
for (int i = 0; i <= pol2.l; i++)cout << pol2.op[i] << " ";
return 0;
}

多项式快速幂

        本算法必须保证常数项系数为1。
        在上面求对数以及求指数的基础上我们可以很方便地求多项式快速幂。它适用于$A(x)$常数项为1的情况。
        现在要求一个$B(x)$满足:

        两边取对数:

        再取指数:

        然后就转化为求指数和求对数的问题了,复杂度$O(nlogn)$。这个算法支持k非常大的情况,因为k无论多大最终都会取模。模板题全代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include <bits/stdc++.h>

#define MOD 998244353
#define G 3
#define N 100005
using namespace std;
int tr[N << 2];
struct Pol {//多项式类
int l, op[N << 2] = {0};//l最高项指数
} pol1, pol2, tmp, tmp2;
int n;

int qPow(int a, int b) {
int ans = 1, x = a;
while (b) {
if (b & 1)ans = 1ll * ans * x % MOD;
x = 1ll * x * x % MOD, b >>= 1;
}
return ans;
}

inline void diff(Pol *s) {
for (int i = 0; i < s->l; i++)s->op[i] = 1ll * s->op[i + 1] * (i + 1) % MOD;
s->op[s->l] = 0, --s->l;
}

inline void integral(Pol *s) {
for (int i = s->l + 1; i >= 1; i--)s->op[i] = 1ll * s->op[i - 1] * qPow(i, MOD - 2) % MOD;
++s->l, s->op[0] = 0;
}

void NTT(int l, int *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1; mid < l; mid <<= 1) {
int wn = qPow(G, (MOD - 1) / (mid << 1));
if (type == -1)wn = qPow(wn, MOD - 2);
for (int len = mid << 1, j = 0; j < l; j += len) {
int w = 1;
for (int k = 0; k < mid; k++, w = 1ll * w * wn % MOD) {
int x = c[j + k], y = 1ll * w * c[j + mid + k] % MOD;
c[j + k] = (1ll * x + y) % MOD, c[j + mid + k] = (1ll * x - y) % MOD;
}
}
}
}

void inv(Pol *ans, const Pol *x, int rk) {//求逆
if (rk == 1) {
ans->l = 0, ans->op[0] = qPow(x->op[0], MOD - 2);
return;
}
inv(ans, x, (rk + 1) >> 1);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i <= min(x->l, rk - 1) ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1);
for (int i = 0; i < l; i++)
ans->op[i] = (2ll * ans->op[i] % MOD - (1ll * tmp.op[i] * ans->op[i] % MOD) * 1ll * ans->op[i] % MOD) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

inline void multiple(Pol *ans, Pol *a, Pol *b) {
int l = 1, le = 0;
while (l <= a->l + b->l)l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
if (i > a->l)a->op[i] = 0;
if (i > b->l)b->op[i] = 0;
}
NTT(l, a->op, 1), NTT(l, b->op, 1);
for (int i = 0; i < l; i++)a->op[i] = 1ll * a->op[i] * b->op[i] % MOD;
NTT(l, a->op, -1), l = qPow(l, MOD - 2), ans->l = a->l + b->l;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * a->op[i] * l % MOD;
}

inline void getLn(Pol *ans, const Pol *x, int rk) {//求x的ln
inv(ans, x, rk);
for (int i = 0; i < rk; i++)tmp.op[i] = i <= x->l ? x->op[i] : 0, tmp.l = rk - 1;
diff(&tmp), multiple(ans, &tmp, ans), integral(ans), ans->l = rk - 1;
}

void getExp(Pol *ans, const Pol *x, int rk) {
if (rk == 1) {
ans->l = 0, ans->op[0] = 1;
return;
}
getExp(ans, x, (rk + 1) >> 1), getLn(&tmp2, ans, rk);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i < rk ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
tmp2.op[i] = i <= tmp2.l ? tmp2.op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1), NTT(l, tmp2.op, 1);
for (int i = 0; i < l; i++)ans->op[i] = 1ll * ans->op[i] * (1ll - tmp2.op[i] + tmp.op[i]) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

inline int read() {
char e = getchar();
long long s = 0;
while (e < '-')e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar(), s %= MOD;
return s;
}

int main() {
int n = read(), k = read();
pol1.l = n - 1;
for (int i = 0; i <= pol1.l; i++)pol1.op[i] = read();
getLn(&pol2, &pol1, n);
for (int i = 0; i <= pol2.l; i++)pol2.op[i] = 1ll * pol2.op[i] * k % MOD;
getExp(&pol1, &pol2, n);
for (int i = 0; i <= pol1.l; i++)cout << (pol1.op[i] + MOD) % MOD << " ";
return 0;
}

        上面基于$ln$和$exp$的快速幂常数很大,在某些情况下会T,所以这里也补充一下倍增快速幂。
        思路很简单,就是按照普通快速幂的思路,只不过将数乘改成多项式乘法。复杂度$O(nlog^2n)$。我们可以用下面的样式去写倍增多项式快速幂。

1
2
3
4
5
6
7
8
9
10
11
while (n) {
if (n & 1) {
tmp.l = pol.l;//复制一份
for (int i = 0; i <= pol.l; i++)tmp.op[i] = pol.op[i];
multiple(&ans, &tmp, &ans);
}
for (int i = 0; i <= pol.l; i++)tmp.op[i] = pol.op[i];//继续复制
tmp.l = pol.l, multiple(&pol, &tmp, &pol);
n >>= 1;
}


        这里主要注意一下复制的问题,其余就没有什么了。这种方法在某些情况下比上面的快速幂更优。

多项式开根

        多项式开根其实就是平方的逆运算(废话),是要求一个多项式$B(x)$满足:

        其实这里两边取个$ln$就能转化为$ln$和$exp$问题了,但是这样做常数巨大,不是很好的做法。这里的解法是迭代+求逆。
        根据之前迭代的思路,我们先假定求出了一个$F(x)$满足:

        $B(x)$当然也满足:

        那么:

        两边平方并拆项:

        两边开根:

        所以:

        迭代进行这个过程就可以完成开根运算。边界条件为$n=1$时,$B(x)=1$(因为题目保证$A(x)$常数项为1)。模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <bits/stdc++.h>

#define MOD 998244353
#define G 3
#define N 100005
using namespace std;
int tr[N << 2];
struct Pol {
int l, op[N << 2] = {0};
} pol1, pol2, tmp, tmp2;
int n;

int qPow(int a, int b) {
int ans = 1, x = a;
while (b) {
if (b & 1)ans = 1ll * ans * x % MOD;
x = 1ll * x * x % MOD, b >>= 1;
}
return ans;
}

void NTT(int l, int *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1; mid < l; mid <<= 1) {
int wn = qPow(G, (MOD - 1) / (mid << 1));
if (type == -1)wn = qPow(wn, MOD - 2);
for (int len = mid << 1, j = 0; j < l; j += len) {
int w = 1;
for (int k = 0; k < mid; k++, w = 1ll * w * wn % MOD) {
int x = c[j + k], y = 1ll * w * c[j + mid + k] % MOD;
c[j + k] = (1ll * x + y) % MOD, c[j + mid + k] = (1ll * x - y) % MOD;
}
}
}
}

void inv(Pol *ans, const Pol *x, int rk) {//求逆
if (rk == 1) {
ans->l = 0, ans->op[0] = qPow(x->op[0], MOD - 2);
return;
}
inv(ans, x, (rk + 1) >> 1);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i <= min(x->l, rk - 1) ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1);
for (int i = 0; i < l; i++)
ans->op[i] = (2ll * ans->op[i] % MOD - (1ll * tmp.op[i] * ans->op[i] % MOD) * 1ll * ans->op[i] % MOD) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

void polSqrt(Pol *ans, const Pol *x, int rk) {//开根函数
if (rk == 1) {//边界条件
ans->l = 0, ans->op[0] = 1;
return;
}
polSqrt(ans, x, (rk + 1) >> 1), inv(&tmp2, ans, rk);
int l = 1, le = 0, inv2 = qPow(2, MOD - 2);
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = tmp.op[i] = i <= min(x->l, rk - 1) ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
tmp2.op[i] = i <= tmp2.l ? tmp2.op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1), NTT(l, tmp2.op, 1);
for (int i = 0; i < l; i++) {
ans->op[i] = 1ll * (tmp.op[i] + 1ll * ans->op[i] * ans->op[i]) % MOD * inv2 % MOD * tmp2.op[i] % MOD;
}
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

int main() {
ios::sync_with_stdio(false);
cin >> n;
pol1.l = n - 1;
for (int i = 0; i <= pol1.l; i++)cin >> pol1.op[i];
polSqrt(&pol2, &pol1, n);
for (int i = 0; i <= pol2.l; i++)cout << (pol2.op[i] + MOD) % MOD << " ";
return 0;
}

拉格朗日插值

        这个部分和上面没有多大关系,但是这是一个重要方法。
        易知$n+1$个点可以确定一个最高$n$次多项式,求这个多项式的一个方法是列方程然后高斯消元,时间复杂度$O(n^3)$,拉格朗日插值法可以很好地解决这个问题,将时间复杂度降至$O(nlogn)$。
        拉格朗日插值法构造多项式其实就像一个结论,根据这$n$个点构造多项式如下:

        将这$n$个点代入上面的多项式,都可以得到相应的值,可知这个多项式是正确的,而构造这个多项式的时间复杂度显然是$O(n^2)$。
        模板题示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include <bits/stdc++.h>

#define MOD 998244353
using namespace std;

int qPow(int x, int y) {
int ans = 1, sta = x;
while (y) {
if (y & 1)ans = 1ll * ans * sta % MOD;
sta = 1ll * sta * sta % MOD, y >>= 1;
}
return ans;
}

int n, k, x[2500], y[2500], ans;

int main() {
cin >> n >> k;
for (int i = 1; i <= n; i++)cin >> x[i] >> y[i];
for (int i = 1; i <= n; i++) {
int tmp = y[i];
for (int j = 1; j <= n; j++) {
if (i != j)tmp = 1ll * tmp * (k - x[j]) % MOD * qPow(x[i] - x[j], MOD - 2) % MOD;
}
ans = (ans + tmp) % MOD;
}
cout << (ans + MOD) % MOD;
return 0;
}

多项式除法

        给定一个$n$次多项式$A(x)$和一个$m$次多项式$G(x)$,求$n-m$次多项式$Q(x)$以及次数小于$m$的多项式$R(x)$满足:

        这里可以把$Q(x)$看成商,$R(x)$看成余数(模)。下面考虑如何求$Q(x)$以及$R(x)$。
        提出$A(x)$中的$x^n$,并定义这种操作为$R$。那么有:

        易知$A_R(x)$与$A(x)$的系数满足如下关系:

        其实就是将系数反置,这个操作可以$O(n)$完成。
        根据一开始的式子:

        可得:

        一步步往下推:

        然后求个逆求个乘法就求出了$Q_R(x)$,经系数反置后得到$Q(x)$。之后:

        求出$R(x)$即可。复杂度$O(nlogn)$。模板题代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <bits/stdc++.h>

#define MOD 998244353
#define G0 3
#define N 100005
using namespace std;
int tr[N << 2];
struct Pol {//多项式类
int l, op[N << 2] = {0};//l最高项指数
} F, G, pol3, tmp, tmp2;

int qPow(int a, int b) {
int ans = 1, x = a;
while (b) {
if (b & 1)ans = 1ll * ans * x % MOD;
x = 1ll * x * x % MOD, b >>= 1;
}
return ans;
}

void NTT(int l, int *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1; mid < l; mid <<= 1) {
int wn = qPow(G0, (MOD - 1) / (mid << 1));
if (type == -1)wn = qPow(wn, MOD - 2);
for (int len = mid << 1, j = 0; j < l; j += len) {
int w = 1;
for (int k = 0; k < mid; k++, w = 1ll * w * wn % MOD) {
int x = c[j + k], y = 1ll * w * c[j + mid + k] % MOD;
c[j + k] = (1ll * x + y) % MOD, c[j + mid + k] = (1ll * x - y) % MOD;
}
}
}
}

void inv(Pol *ans, const Pol *x, int rk) {
if (rk == 1) {
ans->l = 0, ans->op[0] = qPow(x->op[0], MOD - 2);
return;
}
inv(ans, x, (rk + 1) >> 1);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i <= min(x->l, rk - 1) ? x->op[i] : 0;//高次项滤掉,零项补零
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1);
for (int i = 0; i < l; i++)
ans->op[i] = (2ll * ans->op[i] % MOD - (1ll * tmp.op[i] * ans->op[i] % MOD) * 1ll * ans->op[i] % MOD) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

inline void toR(Pol *a, const Pol *b) {//R操作
for (int i = 0; i <= b->l; i++)a->op[i] = b->op[b->l - i];
a->l = b->l;
}

inline void multiple(Pol *ans, Pol *a, Pol *b) {//将a与b相乘并计入ans
int l = 1, le = 0;
while (l <= a->l + b->l)l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
if (i > a->l)a->op[i] = 0;//高次项系数清零
if (i > b->l)b->op[i] = 0;
}
NTT(l, a->op, 1), NTT(l, b->op, 1);
for (int i = 0; i < l; i++)a->op[i] = 1ll * a->op[i] * b->op[i] % MOD;
NTT(l, a->op, -1), l = qPow(l, MOD - 2), ans->l = a->l + b->l;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * a->op[i] * l % MOD;
}

int main() {
ios::sync_with_stdio(false);
cin >> F.l >> G.l;
for (int i = 0; i <= F.l; i++)cin >> F.op[i];
for (int i = 0; i <= G.l; i++)cin >> G.op[i];
toR(&pol3, &G), inv(&tmp2, &pol3, F.l - G.l + 1), toR(&tmp, &F);
multiple(&tmp2, &tmp, &tmp2), tmp2.l = F.l - G.l, toR(&tmp, &tmp2);
for (int i = 0; i <= tmp.l; i++)cout << (tmp.op[i] + MOD) % MOD << " ";
cout << endl;
multiple(&tmp, &tmp, &G);
for (int i = 0; i < G.l; i++)cout << (1ll * F.op[i] - tmp.op[i] + MOD) % MOD << " ";
return 0;
}

        为了方便后面的操作,封装一步求模的过程:
1
2
3
4
5
6
7
8
9
inline void mod(Pol *ans, const Pol *a, const Pol *b) {//a除以b得到余数ans
toR(&tmp3, b), inv(&tmp2, &tmp3, a->l - b->l + 1), toR(ans, a);
multiple(&tmp2, ans, &tmp2), tmp2.l = a->l - b->l, toR(ans, &tmp2);
//到这一步ans就是商
for (int i = 0; i <= b->l; i++)tmp.op[i] = b->op[i];//复制一份b
tmp.l = b->l, multiple(ans, ans, &tmp);
for (int i = 0; i < b->l; i++)ans->op[i] = (a->op[i] - ans->op[i] + MOD) % MOD;
ans->l = b->l - 1;
}

分治 FFT/NTT

        对于大多数卷积问题,我们卷一下就好了。但对于某些问题,计算后半段需要前半段的结果,这样似乎只能递推去求,导致复杂度退化。其实,对于这样的问题有一种很好的方法去解决,这就是分治思想。分治FFT/NTT就是这样一种思想的体现,它在很多地方都有应用,其中也蕴含了CDQ分治(以后文章探讨)的理念。
        来看这样一个问题:分治FFT模板
        显然,后面的$f_i$需要前面的$f_i$,这样就不能直接卷积,考虑分治。
        分治FFT/NTT的思想是将问题规模一分为二,考虑前半个问题对后半个问题的影响,然后再解决后半个问题。这里的重点在于分治时两个子问题会有单方面的影响。
        将区间$[0,n-1]$一分为二,分别求解其中的$f_i$。显然,后半区间的值对前半区间没有任何影响,但是没有前半区间的值,无法求后半区间。于是我们先递归地求前半区间的$f_i$值。当区间左右端点相同时直接结束递归,注意当$l=r=0$时赋值$f_0=1$。
        这里提一下为什么$l=r$时直接结束递归。当我们在一个区间$[l,r]$上计算时,只考虑当前这个区间对答案的影响,这是一个重要思想。当$l=r$相同时,$f_i$对自己的贡献是多少?根据题目中的式子可知它没有任何贡献,于是直接结束递归。
        在递归求右半边区间之前,先考虑左半区间对右半区间的影响。
        显然,对于右半区间的某一个点$x,x\in [mid+1,r]$,左半区间对其系数的贡献为:

        这是一个卷积,拿NTT去求,可将时间复杂度降到$O(nlogn)$,求完后将答案加到数组中就好了,之后再递归地求右半区间。
        分治的时间复杂度为$O(logn)$,每一次还需要$O(nlogn)$的NTT,总时间复杂度为$O(nlog^2n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <bits/stdc++.h>

#define MOD 998244353
#define G 3
#define N 100005
using namespace std;
int tr[N << 2];
struct Pol {//多项式类
int l, op[N << 2] = {0};//l最高项指数
} pol1, tmp1, tmp2, ans;
int n;

int qPow(int a, int b) {
int ans = 1, x = a;
while (b) {
if (b & 1)ans = 1ll * ans * x % MOD;
x = 1ll * x * x % MOD, b >>= 1;
}
return ans;
}

void NTT(int l, int *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1; mid < l; mid <<= 1) {
int wn = qPow(G, (MOD - 1) / (mid << 1));
if (type == -1)wn = qPow(wn, MOD - 2);
for (int len = mid << 1, j = 0; j < l; j += len) {
int w = 1;
for (int k = 0; k < mid; k++, w = 1ll * w * wn % MOD) {
int x = c[j + k], y = 1ll * w * c[j + mid + k] % MOD;
c[j + k] = ((x + y) % MOD + MOD) % MOD, c[j + mid + k] = ((x - y) % MOD + MOD) % MOD;
}
}
}
}

inline void multiple(Pol *ans, Pol *a, Pol *b) {//将a与b相乘并计入ans
int l = 1, le = 0;
while (l <= a->l + b->l)l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
if (i > a->l)a->op[i] = 0;//高次项系数清零
if (i > b->l)b->op[i] = 0;
}
NTT(l, a->op, 1), NTT(l, b->op, 1);
for (int i = 0; i < l; i++)a->op[i] = 1ll * a->op[i] * b->op[i] % MOD;
NTT(l, a->op, -1), l = qPow(l, MOD - 2), ans->l = a->l + b->l;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * a->op[i] * l % MOD;
}

void merge_NTT(int l, int r) {//分治NTT
if (l == r) {
if (l == 0)ans.op[0] = 1;//递归结束边界条件
return;
}
int mid = (l + r) >> 1;
merge_NTT(l, mid);//递归左半区间
for (int i = l, j = 0; i <= mid; i++, j++)tmp1.op[j] = ans.op[i];//复制一遍数组
for (int i = 0; i <= r - l; i++)tmp2.op[i] = pol1.op[i];
tmp1.l = mid - l, tmp2.l = r - l, multiple(&tmp1, &tmp1, &tmp2);//求卷积
for (int i = mid + 1, j = mid - l + 1; i <= r; i++, j++)ans.op[i] = (ans.op[i] + tmp1.op[j]) % MOD;//答案加上去
merge_NTT(mid + 1, r);//递归求右半区间
}

int main() {
ios::sync_with_stdio(false);
cin >> n;
pol1.l = n - 1;
for (int i = 1; i <= n - 1; i++)cin >> pol1.op[i];
merge_NTT(0, n - 1);
for (int i = 0; i <= n - 1; i++)cout << ans.op[i] << " ";
return 0;
}


        这个题比较有趣的地方是它可以用多项式求逆直接切掉,复杂度更低($O(nlogn)$)。
        为了方便,不妨令$g_0=0$,这样就有:

        上式适用于$x>0$的情况,对于$x=0$时,$f_0=1$。
        然后构造$n-1$次多项式$F(x)$以及$G(x)$:

        那么:

        其实就是:

        很快就能得到:

        求个逆就可以了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <bits/stdc++.h>

#define MOD 998244353
#define G 3
#define N 100005
using namespace std;
int tr[N << 2];
struct Pol {//多项式类
int l, op[N << 2] = {0};//l最高项指数
} pol1, pol2, tmp;
int n;

int qPow(int a, int b) {
int ans = 1, x = a;
while (b) {
if (b & 1)ans = 1ll * ans * x % MOD;
x = 1ll * x * x % MOD, b >>= 1;
}
return ans;
}

void NTT(int l, int *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1; mid < l; mid <<= 1) {
int wn = qPow(G, (MOD - 1) / (mid << 1));
if (type == -1)wn = qPow(wn, MOD - 2);
for (int len = mid << 1, j = 0; j < l; j += len) {
int w = 1;
for (int k = 0; k < mid; k++, w = 1ll * w * wn % MOD) {
int x = c[j + k], y = 1ll * w * c[j + mid + k] % MOD;
c[j + k] = ((x + y) % MOD + MOD) % MOD, c[j + mid + k] = ((x - y) % MOD + MOD) % MOD;
}
}
}
}

void inv(Pol *ans, const Pol *x, int rk) {
if (rk == 1) {
ans->l = 0, ans->op[0] = qPow(x->op[0], MOD - 2);
return;
}
inv(ans, x, (rk + 1) >> 1);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i <= min(x->l, rk - 1) ? x->op[i] : 0;//高次项滤掉,零项补零
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1);
for (int i = 0; i < l; i++)
ans->op[i] = (2ll * ans->op[i] % MOD - (1ll * tmp.op[i] * ans->op[i] % MOD) * 1ll * ans->op[i] % MOD) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

int main() {
ios::sync_with_stdio(false);
cin >> n;
pol1.l = n - 1, pol1.op[0] = 1;
for (int i = 1; i <= pol1.l; i++)cin >> pol1.op[i];
for (int i = 1; i <= pol1.l; i++)pol1.op[i] = -pol1.op[i];
inv(&pol2, &pol1, n);
for (int i = 0; i <= pol2.l; i++)cout << (pol2.op[i] + MOD) % MOD << " ";
return 0;
}

多项式多点求值

        给定一个多项式$A(x)$求其在若干个点的值。
        考虑构造两个多项式:

        然后求$A(x)$对这两个多项式的模$R_1(x)$和$R_2(x)$。容易知道当$1\leq i\leq \lfloor\frac {n} {2} \rfloor$时,$A(x_i)=R_1(x_i)$,当$\lfloor \frac{n}{2}\rfloor < i\leq n$时,$A(x_i)=R_2(x_i)$,从而将多项式的次数降低。于是我们将一个问题转化成两个规模减半的子问题,递归求解即可。
        多项式取模的复杂度为$O(nlogn)$,故多点求值复杂度为$O(nlog^2n)$。
        小范围暴力优化:当分治到某一个小范围时(左右区间端点相差500左右),我们可以暴力求值,这样能够有效提高效率。
        对于$R_0$和$R_1$可以在之前用一步分治NTT预处理出来,用类似线段树的结构储存。模板题,本代码经过了疯狂卡常。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#include <bits/stdc++.h>

#pragma GCC diagnostic error "-std=c++14"
#pragma GCC target("avx")
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#define MOD 998244353
#define G0 3
#define N 64010
#define SWAP(a, b) (a^=b^=a^=b)
using namespace std;
int tr[N << 2];
struct Pol {
int l, op[N << 2] = {0};
} pol, tmp3, tmp, tmp2, tmp4, sss[25];
int ans[N], n, m, M[N], w1[22][N << 3], w2[22][N << 3];
int mem[N << 6], cnt = 1, L[N << 2], R[N << 2];

inline int read() {
char e = getchar();
int s = 0;
while (e < '-')e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return s;
}

inline int qPow(int a, int b) {
int ans = 1, x = a;
while (b) {
if (b & 1)ans = 1ll * ans * x % MOD;
x = 1ll * x * x % MOD, b >>= 1;
}
return ans;
}

inline void NTT(int l, int *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])SWAP(c[i], c[tr[i]]);
for (int mid = 1, p = 0; mid < l; mid <<= 1, p++) {
for (int len = mid << 1, j = 0; j < l; j += len) {
for (int k = 0, w = 1; k < mid; ++k, w = type == 1 ? w1[p][k] : w2[p][k]) {
int x = c[j + k], y = 1ll * w * c[j + mid + k] % MOD;
c[j + k] = (1ll * x + y) % MOD, c[j + mid + k] = (1ll * x - y) % MOD;
}
}
}
}

void inv(Pol *ans, const Pol *x, int rk) {
if (rk == 1) {
ans->l = 0, ans->op[0] = qPow(x->op[0], MOD - 2);
return;
}
inv(ans, x, (rk + 1) >> 1);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = (i <= x->l & i <= rk - 1) ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1);
for (int i = 0; i < l; i++)
ans->op[i] = (2ll * ans->op[i] % MOD - (1ll * tmp.op[i] * ans->op[i] % MOD) * 1ll * ans->op[i] % MOD) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

inline void toR(Pol *a, const Pol *b) {//R操作
for (int i = 0; i <= b->l; i++)a->op[i] = b->op[b->l - i];
a->l = b->l;
}

inline void multiple(Pol *ans, Pol *a, Pol *b) {
int l = 1, le = 0;
while (l <= a->l + b->l)l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
if (i > a->l)a->op[i] = 0;
if (i > b->l)b->op[i] = 0;
}
NTT(l, a->op, 1), NTT(l, b->op, 1);
for (int i = 0; i < l; i++)a->op[i] = 1ll * a->op[i] * b->op[i] % MOD;
NTT(l, a->op, -1), l = qPow(l, MOD - 2), ans->l = a->l + b->l;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * a->op[i] * l % MOD;
}

inline void mod(Pol *ans, const Pol *a, Pol *b) {
toR(&tmp3, b), inv(&tmp2, &tmp3, a->l - b->l + 1), toR(ans, a);
multiple(&tmp2, ans, &tmp2), tmp2.l = a->l - b->l, toR(ans, &tmp2), multiple(ans, ans, b);
for (int i = 0; i < b->l; i++)ans->op[i] = (1ll * a->op[i] - ans->op[i] + MOD) % MOD;
ans->l = b->l - 1;
}

void solve(int l, int r, Pol *s, int k, int dep) {//求[l,r]中s的答案, k用来索引线段树结果
if (r - l <= 600) {
for (int i = l; i <= r; i++) {
int p = 1;
for (int j = 0; j <= s->l; j++)ans[i] = (1ll * p * s->op[j] + ans[i]) % MOD, p = 1ll * p * M[i] % MOD;
printf("%d\n", ans[i]);
}
} else {
for (int i = L[k << 1], j = 0; i <= R[k << 1]; i++, j++)tmp4.op[j] = mem[i];
tmp4.l = R[k << 1] - L[k << 1], mod(sss + dep, s, &tmp4), solve(l, (l + r) >> 1, dep + sss, k << 1, dep + 1);
for (int i = L[k << 1 | 1], j = 0; i <= R[k << 1 | 1]; i++, j++)tmp4.op[j] = mem[i];
tmp4.l = R[k << 1 | 1] - L[k << 1 | 1], mod(sss + dep, s, &tmp4);
solve(1 + ((l + r) >> 1), r, dep + sss, k << 1 | 1, dep + 1);
}
}

void merge_NTT(int l, int r, int k) {//类似线段树的结构,存分治NTT结果
if (l == r)L[k] = cnt, mem[cnt++] = -M[l], mem[cnt++] = 1, R[k] = cnt - 1;//就是x-xi
else {
merge_NTT(l, (l + r) >> 1, k << 1), merge_NTT(((l + r) >> 1) + 1, r, k << 1 | 1);//分治
memcpy(tmp.op, mem + L[k << 1], sizeof(int) * (R[k << 1] - L[k << 1] + 1));
memcpy(tmp2.op, mem + L[k << 1 | 1], sizeof(int) * (R[k << 1 | 1] - L[k << 1 | 1] + 1));
tmp.l = R[k << 1] - L[k << 1], tmp2.l = R[k << 1 | 1] - L[k << 1 | 1];
multiple(&tmp, &tmp, &tmp2), L[k] = cnt, R[k] = cnt + tmp.l;
memcpy(mem + cnt, tmp.op, sizeof(int) * (tmp.l + 1)), cnt += tmp.l + 1;
}
}

int main() {
for (int i = 0, p = 1; i < 19; i++, p <<= 1) {//预处理原根可以有效降低常数
w1[i][0] = w2[i][0] = 1;
int s1 = qPow(G0, (MOD - 1) / (p << 1)), s2 = qPow(s1, MOD - 2);
for (int k = 1; k < p; k++)w1[i][k] = 1ll * w1[i][k - 1] * s1 % MOD, w2[i][k] = 1ll * w2[i][k - 1] * s2 % MOD;
}
pol.l = n = read(), m = read();
for (int i = 0; i <= pol.l; i++)pol.op[i] = read();
for (int i = 1; i <= m; i++)M[i] = read();
merge_NTT(1, m, 1), solve(1, m, &pol, 1, 1);
return 0;
}

多项式快速插值

        给定$n+1$个互异点,确定一个$n$次多项式,这就是多项式快速插值。
        上面介绍了一种求多项式插值的拉格朗日算法,这个算法是$O(n^2)$的,现在考虑优化。
        先列出拉格朗日法的多项式:

        这个式子经过整理可以得到:

        考虑左边的式子,分子是一个常数,现在考虑求分母。构造一个多项式:

        这样分母就是:

        发现分子分母都趋于0,用一步洛必达法则求上式,得到:

        这样就可以用分治NTT求出$g(x)$,然后求导,再用一步多项式多点求值求出这些值。复杂度$O(nlog^2n)$。
        然后考虑分治。用$h(l,r)$表示从$l$到$r$的答案,那么我们要求$h(1,n)$。
        然后推一波式子:

        当$l=r$时,$h(l,l)=\frac {y_l}{g’(x_l)}$,是一个常数。递归求解即可,复杂度为$O(nlog^2n)$。模板题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include <bits/stdc++.h>

#pragma GCC diagnostic error "-std=c++14"
#pragma GCC target("avx")
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#define MOD 998244353
#define G0 3
#define N 100500
#define SWAP(a, b) (a^=b^=a^=b)
using namespace std;
int tr[N << 2];
struct Pol {
int l, op[N << 2] = {0};
} A, tmp3, tmp, tmp2, tmp4, sss[25];
int ans[N], n, m, w1[22][N << 3], w2[22][N << 3], X[N], Y[N];
int mem[N << 6], cnt = 1, L[N << 2], R[N << 2];

inline int read() {
char e = getchar();
int s = 0;
while (e < '-')e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return s;
}

inline int qPow(int a, int b) {
int ans = 1, x = a;
while (b) {
if (b & 1)ans = 1ll * ans * x % MOD;
x = 1ll * x * x % MOD, b >>= 1;
}
return ans;
}

inline void NTT(int l, int *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])SWAP(c[i], c[tr[i]]);
for (int mid = 1, p = 0; mid < l; mid <<= 1, p++) {
for (int len = mid << 1, j = 0; j < l; j += len) {
for (int k = 0, w = 1; k < mid; ++k, w = type == 1 ? w1[p][k] : w2[p][k]) {
int x = c[j + k], y = 1ll * w * c[j + mid + k] % MOD;
c[j + k] = (1ll * x + y) % MOD, c[j + mid + k] = (1ll * x - y) % MOD;
}
}
}
}

inline void diff(Pol *s) {
for (int i = 0; i < s->l; i++)s->op[i] = 1ll * s->op[i + 1] * (i + 1) % MOD;
s->op[s->l] = 0, --s->l;
}

void inv(Pol *ans, const Pol *x, int rk) {
if (rk == 1) {
ans->l = 0, ans->op[0] = qPow(x->op[0], MOD - 2);
return;
}
inv(ans, x, (rk + 1) >> 1);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = (i <= x->l & i <= rk - 1) ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1);
for (int i = 0; i < l; i++)
ans->op[i] = (2ll * ans->op[i] % MOD - (1ll * tmp.op[i] * ans->op[i] % MOD) * 1ll * ans->op[i] % MOD) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

inline void toR(Pol *a, const Pol *b) {
for (int i = 0; i <= b->l; i++)a->op[i] = b->op[b->l - i];
a->l = b->l;
}

inline void multiple(Pol *ans, Pol *a, Pol *b) {//将a与b相乘并计入ans
int l = 1, le = 0;
while (l <= a->l + b->l)l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
if (i > a->l)a->op[i] = 0;
if (i > b->l)b->op[i] = 0;
}
NTT(l, a->op, 1), NTT(l, b->op, 1);
for (int i = 0; i < l; i++)a->op[i] = 1ll * a->op[i] * b->op[i] % MOD;
NTT(l, a->op, -1), l = qPow(l, MOD - 2), ans->l = a->l + b->l;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * a->op[i] * l % MOD;
}

inline void mod(Pol *ans, const Pol *a, Pol *b) {//a除以b得到余数ans
toR(&tmp3, b), inv(&tmp2, &tmp3, a->l - b->l + 1), toR(ans, a);
multiple(&tmp2, ans, &tmp2), tmp2.l = a->l - b->l, toR(ans, &tmp2), multiple(ans, ans, b);
for (int i = 0; i < b->l; i++)ans->op[i] = (1ll * a->op[i] - ans->op[i] + MOD) % MOD;
ans->l = b->l - 1;
}

void solve(int l, int r, Pol *s, int k, int dep) {//求[l,r]中s的答案, k用来索引线段树结果
if (r - l <= 600) {
for (int i = l; i <= r; i++) {
int p = 1;
for (int j = 0; j <= s->l; j++)ans[i] = (1ll * p * s->op[j] + ans[i]) % MOD, p = 1ll * p * X[i] % MOD;
}
} else {
for (int i = L[k << 1], j = 0; i <= R[k << 1]; i++, j++)tmp4.op[j] = mem[i];
tmp4.l = R[k << 1] - L[k << 1], mod(sss + dep, s, &tmp4), solve(l, (l + r) >> 1, dep + sss, k << 1, dep + 1);
for (int i = L[k << 1 | 1], j = 0; i <= R[k << 1 | 1]; i++, j++)tmp4.op[j] = mem[i];
tmp4.l = R[k << 1 | 1] - L[k << 1 | 1], mod(sss + dep, s, &tmp4);
solve(1 + ((l + r) >> 1), r, dep + sss, k << 1 | 1, dep + 1);
}
}

void merge_NTT(int l, int r, int k) {
if (l == r)L[k] = cnt, mem[cnt++] = -X[l], mem[cnt++] = 1, R[k] = cnt - 1;
else {
merge_NTT(l, (l + r) >> 1, k << 1), merge_NTT(((l + r) >> 1) + 1, r, k << 1 | 1);
memcpy(tmp.op, mem + L[k << 1], sizeof(int) * (R[k << 1] - L[k << 1] + 1));
memcpy(tmp2.op, mem + L[k << 1 | 1], sizeof(int) * (R[k << 1 | 1] - L[k << 1 | 1] + 1));
tmp.l = R[k << 1] - L[k << 1], tmp2.l = R[k << 1 | 1] - L[k << 1 | 1];
multiple(&tmp, &tmp, &tmp2), L[k] = cnt, R[k] = cnt + tmp.l;
memcpy(mem + cnt, tmp.op, sizeof(int) * (tmp.l + 1)), cnt += tmp.l + 1;
}
}

void calc(Pol *a, int l, int r, int k, int dep) {
if (l == r)a->l = 0, a->op[0] = 1ll * Y[l] * qPow(ans[l], MOD - 2) % MOD;
else {
calc(sss + dep, l, (l + r) >> 1, k << 1, dep + 1);
memcpy(a->op, mem + L[k << 1 | 1], sizeof(int) * (R[k << 1 | 1] - L[k << 1 | 1] + 1));
a->l = R[k << 1 | 1] - L[k << 1 | 1], multiple(a, a, sss + dep);
calc(sss + dep, ((l + r) >> 1) + 1, r, k << 1 | 1, dep + 1);
memcpy(tmp.op, mem + L[k << 1], sizeof(int) * (R[k << 1] - L[k << 1] + 1));
tmp.l = R[k << 1] - L[k << 1], multiple(&tmp, &tmp, sss + dep);
for (int i = 0; i <= tmp.l; i++)a->op[i] = (1ll * a->op[i] + tmp.op[i]) % MOD;
}
}

int main() {
for (int i = 0, p = 1; i < 19; i++, p <<= 1) {
w1[i][0] = w2[i][0] = 1;
int s1 = qPow(G0, (MOD - 1) / (p << 1)), s2 = qPow(s1, MOD - 2);
for (int k = 1; k < p; k++)w1[i][k] = 1ll * w1[i][k - 1] * s1 % MOD, w2[i][k] = 1ll * w2[i][k - 1] * s2 % MOD;
}
n = read();
for (int i = 1; i <= n; i++)X[i] = read(), Y[i] = read();
merge_NTT(1, n, 1), memcpy(A.op, mem + L[1], sizeof(int) * (R[1] - L[1] + 1)), A.l = R[1] - L[1];
diff(&A), solve(1, n, &A, 1, 1), calc(&A, 1, n, 1, 1);
for (int i = 0; i <= A.l; i++)printf("%d ", (A.op[i] + MOD) % MOD);
return 0;
}

多项式三角函数

        本算法要求多项式常数项为0。
        给定一个$n$次多项式$A(x)$,求$B(x)$满足$B(x)\equiv sinA(x)\pmod{x^n}$或$B(x)\equiv cosA(x)\pmod {x^n}$。
        要解决这个问题,自然要与已知的知识相结合。考虑欧拉公式:

        然后就得到:

        于是:

        就转化成多项式指数函数的内容了。这里有一处优化,我们不用求两次多项式exp,因为可以发现两个多项式是互逆关系,只需要求出一个,另一个再求逆即可,能提升效率。
        发现这里有一个虚数单位$i$没法处理,考虑虚数的定义:

        那么就有:

        把$i$替换成这个数即可。模板题代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include <bits/stdc++.h>

#define MOD 998244353
#define G 3
#define N 100005
#define I 86583718//虚数单位
using namespace std;
int tr[N << 2];
struct Pol {
int l, op[N << 2] = {0};
} pol1, pol2, A, tmp, tmp2;
int n, type;

int qPow(int a, int b) {
int ans = 1, x = a;
while (b) {
if (b & 1)ans = 1ll * ans * x % MOD;
x = 1ll * x * x % MOD, b >>= 1;
}
return ans;
}

inline void diff(Pol *s) {
for (int i = 0; i < s->l; i++)s->op[i] = 1ll * s->op[i + 1] * (i + 1) % MOD;
s->op[s->l] = 0, --s->l;
}

inline void integral(Pol *s) {
for (int i = s->l + 1; i >= 1; i--)s->op[i] = 1ll * s->op[i - 1] * qPow(i, MOD - 2) % MOD;
++s->l, s->op[0] = 0;
}

void NTT(int l, int *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1; mid < l; mid <<= 1) {
int wn = qPow(G, (MOD - 1) / (mid << 1));
if (type == -1)wn = qPow(wn, MOD - 2);
for (int len = mid << 1, j = 0; j < l; j += len) {
int w = 1;
for (int k = 0; k < mid; k++, w = 1ll * w * wn % MOD) {
int x = c[j + k], y = 1ll * w * c[j + mid + k] % MOD;
c[j + k] = ((x + y) % MOD + MOD) % MOD, c[j + mid + k] = ((x - y) % MOD + MOD) % MOD;
}
}
}
}

void inv(Pol *ans, const Pol *x, int rk) {//求逆
if (rk == 1) {
ans->l = 0, ans->op[0] = qPow(x->op[0], MOD - 2);
return;
}
inv(ans, x, (rk + 1) >> 1);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i <= min(x->l, rk - 1) ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1);
for (int i = 0; i < l; i++)
ans->op[i] = (2ll * ans->op[i] % MOD - (1ll * tmp.op[i] * ans->op[i] % MOD) * 1ll * ans->op[i] % MOD) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

inline void multiple(Pol *ans, Pol *a, Pol *b) {
int l = 1, le = 0;
while (l <= a->l + b->l)l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
if (i > a->l)a->op[i] = 0;
if (i > b->l)b->op[i] = 0;
}
NTT(l, a->op, 1), NTT(l, b->op, 1);
for (int i = 0; i < l; i++)a->op[i] = 1ll * a->op[i] * b->op[i] % MOD;
NTT(l, a->op, -1), l = qPow(l, MOD - 2), ans->l = a->l + b->l;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * a->op[i] * l % MOD;
}

inline void getLn(Pol *ans, const Pol *x, int rk) {//求x的ln
inv(ans, x, rk);
for (int i = 0; i < rk; i++)tmp.op[i] = i <= x->l ? x->op[i] : 0, tmp.l = rk - 1;
diff(&tmp), multiple(ans, &tmp, ans), integral(ans), ans->l = rk - 1;
}

void getExp(Pol *ans, const Pol *x, int rk) {
if (rk == 1) {
ans->l = 0, ans->op[0] = 1;
return;
}
getExp(ans, x, (rk + 1) >> 1), getLn(&tmp2, ans, rk);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i < rk ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
tmp2.op[i] = i <= tmp2.l ? tmp2.op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1), NTT(l, tmp2.op, 1);
for (int i = 0; i < l; i++)ans->op[i] = 1ll * ans->op[i] * (1ll - tmp2.op[i] + tmp.op[i]) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

int main() {
ios::sync_with_stdio(false);
cin >> n >> type;
for (int i = 0; i < n; i++)cin >> pol1.op[i];
A.l = pol1.l = n - 1;
for (int i = 0; i <= pol1.l; i++)pol1.op[i] = 1ll * pol1.op[i] * I % MOD;
getExp(&pol2, &pol1, n);
for (int i = 0, inv = qPow(type ? 2 : 2 * I, MOD - 2); i <= pol2.l; i++) {
A.op[i] = 1ll * pol2.op[i] * inv % MOD;
}
for (int i = 0; i <= pol1.l; i++)pol1.op[i] = pol2.op[i];
inv(&pol2, &pol1, n);
for (int i = 0, inv = qPow(type ? 2 : 2 * I, MOD - 2); i <= pol2.l; i++) {
A.op[i] = (1ll * A.op[i] + (type ? 1ll : -1ll) * pol2.op[i] * inv) % MOD;
}
for (int i = 0; i <= A.l; i++)cout << (A.op[i] + MOD) % MOD << " ";
return 0;
}

多项式反三角函数

        本算法要求常数项为0。
        有了上面的基础再做这个就很简单了。就是给定一个n次多项式$A(x)$,求$B(x)$满足$B(x)\equiv arcsinA(x)\pmod{x^n}$或$B(x)\equiv arctanA(x)\pmod{x^n}$。
        对于$B(x)\equiv arcsinA(x)\pmod{x^n}$,求个导:

        对于$B(x)\equiv arctanA(x)\pmod{x^n}$,同样可得:

        下面应该怎么做就很显然了,把板子粘过来,然后求即可。模板题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include <bits/stdc++.h>

#define MOD 998244353
#define G 3
#define N 100005
using namespace std;
int tr[N << 2];
struct Pol {
int l, op[N << 2] = {0};
} pol1, pol2, pol3, tmp, tmp2;
int n, type;

int qPow(int a, int b) {
int ans = 1, x = a;
while (b) {
if (b & 1)ans = 1ll * ans * x % MOD;
x = 1ll * x * x % MOD, b >>= 1;
}
return ans;
}

inline void diff(Pol *s) {
for (int i = 0; i < s->l; i++)s->op[i] = 1ll * s->op[i + 1] * (i + 1) % MOD;
s->op[s->l] = 0, --s->l;
}

inline void integral(Pol *s) {
for (int i = s->l + 1; i >= 1; i--)s->op[i] = 1ll * s->op[i - 1] * qPow(i, MOD - 2) % MOD;
++s->l, s->op[0] = 0;
}

void NTT(int l, int *c, int type) {
for (int i = 0; i < l; i++)if (i < tr[i])swap(c[i], c[tr[i]]);
for (int mid = 1; mid < l; mid <<= 1) {
int wn = qPow(G, (MOD - 1) / (mid << 1));
if (type == -1)wn = qPow(wn, MOD - 2);
for (int len = mid << 1, j = 0; j < l; j += len) {
int w = 1;
for (int k = 0; k < mid; k++, w = 1ll * w * wn % MOD) {
int x = c[j + k], y = 1ll * w * c[j + mid + k] % MOD;
c[j + k] = ((x + y) % MOD + MOD) % MOD, c[j + mid + k] = ((x - y) % MOD + MOD) % MOD;
}
}
}
}

void inv(Pol *ans, const Pol *x, int rk) {//求逆
if (rk == 1) {
ans->l = 0, ans->op[0] = qPow(x->op[0], MOD - 2);
return;
}
inv(ans, x, (rk + 1) >> 1);
int l = 1, le = 0;
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = i <= min(x->l, rk - 1) ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1);
for (int i = 0; i < l; i++)
ans->op[i] = (2ll * ans->op[i] % MOD - (1ll * tmp.op[i] * ans->op[i] % MOD) * 1ll * ans->op[i] % MOD) % MOD;
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

inline void multiple(Pol *ans, Pol *a, Pol *b) {
int l = 1, le = 0;
while (l <= a->l + b->l)l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
if (i > a->l)a->op[i] = 0;
if (i > b->l)b->op[i] = 0;
}
NTT(l, a->op, 1), NTT(l, b->op, 1);
for (int i = 0; i < l; i++)a->op[i] = 1ll * a->op[i] * b->op[i] % MOD;
NTT(l, a->op, -1), l = qPow(l, MOD - 2), ans->l = a->l + b->l;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * a->op[i] * l % MOD;
}

void polSqrt(Pol *ans, const Pol *x, int rk) {//开根函数
if (rk == 1) {//边界条件
ans->l = 0, ans->op[0] = 1;
return;
}
polSqrt(ans, x, (rk + 1) >> 1), inv(&tmp2, ans, rk);
int l = 1, le = 0, inv2 = qPow(2, MOD - 2);
while (l <= (rk << 1))l <<= 1, ++le;
for (int i = 0; i < l; i++) {
tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (le - 1));
tmp.op[i] = tmp.op[i] = i <= min(x->l, rk - 1) ? x->op[i] : 0;
ans->op[i] = i <= ans->l ? ans->op[i] : 0;
tmp2.op[i] = i <= tmp2.l ? tmp2.op[i] : 0;
}
NTT(l, tmp.op, 1), NTT(l, ans->op, 1), NTT(l, tmp2.op, 1);
for (int i = 0; i < l; i++) {
ans->op[i] = 1ll * (tmp.op[i] + 1ll * ans->op[i] * ans->op[i]) % MOD * inv2 % MOD * tmp2.op[i] % MOD;
}
NTT(l, ans->op, -1), l = qPow(l, MOD - 2), ans->l = rk - 1;
for (int i = 0; i <= ans->l; i++)ans->op[i] = 1ll * ans->op[i] * l % MOD;
}

int main() {
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
cin >> n >> type;
for (int i = 0; i < n; i++) {
cin >> pol1.op[i];
tmp.op[i] = pol2.op[i] = pol1.op[i];
}
tmp.l = pol2.l = pol1.l = n - 1, multiple(&pol2, &pol2, &tmp);
if (type == 0) {
for (int i = 0; i <= pol2.l; i++)pol2.op[i] = -pol2.op[i];
++pol2.op[0], diff(&pol1), polSqrt(&pol3, &pol2, n), inv(&pol2, &pol3, n);
multiple(&pol1, &pol1, &pol2), integral(&pol1);
for (int i = 0; i < n; i++)cout << (pol1.op[i] + MOD) % MOD << " ";
} else {
++pol2.op[0], inv(&pol3, &pol2, n), diff(&pol1), multiple(&pol1, &pol1, &pol3), integral(&pol1);
for (int i = 0; i < n; i++)cout << (pol1.op[i] + MOD) % MOD << " ";
}
return 0;
}