[CF438E] The Child and Binary Tree
阅读耗时 7 分钟

题目描述
Our child likes computer science very much, especially he likes binary trees.
Consider a sequence of $n$ distinct positive integers: $c_{1},c_{2},\ldots,c_{n}$. The child calls a vertex-weighted rooted binary tree good if and only if for every vertex $v$, the weight of $v$ is in the set $\{c_{1},c_{2},\ldots,c_{n}\}$. Also, our child thinks that the weight of a vertex-weighted tree is the sum of all vertices' weights.
Given an integer $m$, can you, for all $s$ $(1\leq s\leq m)$, calculate the number of good vertex-weighted rooted binary trees with weight $s$? Please check the samples for a better understanding of which trees are considered different.
We only want to know the answer modulo $998244353$ ($7\times 17\times 2^{23}+1$, a prime number).
输入格式
The first line contains two integers $n, m$ $(1\leq n\leq 10^{5};\ 1\leq m\leq 10^{5})$. The second line contains $n$ space-separated pairwise distinct integers $c_{1},c_{2},\ldots,c_{n}$ $(1\leq c_{i}\leq 10^{5})$.
输出格式
Print $m$ lines, each line containing a single integer. The $i$-th line must contain the number of good vertex-weighted rooted binary trees whose weight is exactly $i$. Print the answers modulo $998244353$.
题解
其实就是个$DP$+多项式。
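构造一个多项式:

$$C(x)=\sum_{i\ge 1} s_i x^i$$

其中 $s_i$ 在 $i\in\{c_1,c_2,\ldots,c_n\}$ 时为 $1$,否则为 $0$。然后令 $f(x)$ 为权值为 $x$ 时的答案(规定 $f(0)=1$,表示空树),那么有:

$$f(x)=\sum_{i=1}^{x} s_i\sum_{j=0}^{x-i} f(j)\,f(x-i-j)\qquad (x\ge 1)$$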
就是枚举根的权值 $i$,再把剩下的权值 $x-i$ 分配给左右子树。右边的内层求和是一个卷积,它是 $f*f$ 的第 $x-i$ 项:

$$f(x)=\sum_{i=1}^{x} s_i\,(f*f)(x-i)$$
记 $F(x)=\sum_{k\ge 0} f(k)x^k$,发现外层求和又是一个卷积,得到:

$$F=C\cdot F^2$$
由于 $f(0)=1$,而上式右边的常数项为 $0$(因为 $C(0)=0$),漏掉了这一项(相当于没有给出递推边界),于是修改成:

$$F=C\cdot F^2+1$$
解这个关于 $F$ 的二次方程:

$$F=\frac{1\pm\sqrt{1-4C}}{2C}$$
分子分母同乘 $1\mp\sqrt{1-4C}$ 进行有理化,于是:

$$F=\frac{2}{1\mp\sqrt{1-4C}}$$
代入 $x=0$:$\sqrt{1-4C}$ 的常数项为 $1$。若分母取负号,其常数项为 $0$,无法求逆;而 $F(0)=1$ 要求分母取正号,即

$$F=\frac{2}{1+\sqrt{1-4C}}$$

然后多项式开根 + 求逆就完了。
代码(注:原文代码开头的头文件以及 `N`、`MOD`、`G` 的定义未包含在此摘录中。`MOD` 应为 $998244353$,`G` 可取其原根 $3$;数组大小需满足 `N << 2` 不小于 NTT 的最大长度 $262144$):
using namespace std;
// Bit-reversal permutation table used by NTT(); inv() and polSqrt() rebuild
// it for the current transform length before every transform.
int tr[N << 2];

// Truncated polynomial. Coefficients op[0..l] are the meaningful ones
// (l is the highest index in use); slots above l serve as NTT workspace.
struct Pol {
    int l;
    int op[N << 2] = {0};
};

// pol1 / pol2 carry the main polynomials of the solution; tmp and tmp2 are
// shared scratch buffers overwritten by inv() and polSqrt().
Pol pol1, pol2, tmp, tmp2;

// Input sizes: n = number of allowed weights, m = largest weight queried.
int n, m;
// Binary exponentiation: returns a^b modulo MOD.
// b must be non-negative. A negative base is tolerated; the result is then a
// representative (possibly negative) that is congruent to a^b modulo MOD.
int qPow(int a, int b) {
    int result = 1;
    for (int base = a; b != 0; b >>= 1) {
        if (b & 1) result = 1ll * result * base % MOD;
        base = 1ll * base * base % MOD;
    }
    return result;
}
// In-place iterative number-theoretic transform of length l over Z_MOD.
//   type ==  1 : forward transform.
//   type == -1 : inverse transform WITHOUT the final 1/l scaling; callers
//                multiply by l^{-1} themselves.
// Preconditions: l is a power of two dividing MOD-1, the global tr[] already
// holds the bit-reversal permutation for l, and G is assumed to be a
// primitive root of MOD (both constants are defined outside this excerpt).
// Inputs may be negative representatives; outputs stay in (-MOD, MOD).
void NTT(int l, int *c, int type) {
    // Bit-reverse the input so the butterflies below can run in place.
    for (int i = 0; i < l; i++) {
        if (i < tr[i]) std::swap(c[i], c[tr[i]]);
    }
    for (int half = 1; half < l; half <<= 1) {
        const int span = half << 1;
        // Principal span-th root of unity, or its inverse for type == -1.
        int root = qPow(G, (MOD - 1) / span);
        if (type == -1) root = qPow(root, MOD - 2);
        for (int start = 0; start < l; start += span) {
            int *lo = c + start;
            int *hi = lo + half;
            int w = 1;
            for (int k = 0; k < half; k++) {
                const int u = lo[k];
                const int v = 1ll * w * hi[k] % MOD;
                lo[k] = (1ll * u + v) % MOD;
                hi[k] = (1ll * u - v) % MOD;
                w = 1ll * w * root % MOD;
            }
        }
    }
}
// Power-series inverse by Newton iteration.
// On return, ans holds B with A * B == 1 (mod x^rk) and ans->l == rk - 1.
// Only the coefficients x->op[0..min(x->l, rk-1)] of A are read.
// Requirements: rk >= 1, and x->op[0] is invertible modulo MOD.
// ans must not alias x, and x must not be the global tmp: ans is overwritten
// while x is still being read.
// Side effects: clobbers the global scratch polynomial tmp and the table tr[].
// Coefficients are left as representatives in (-MOD, MOD).
void inv(Pol *ans, const Pol *x, int rk) {
    if (rk == 1) {
        // Constant term: modular inverse via Fermat's little theorem.
        ans->l = 0;
        ans->op[0] = qPow(x->op[0], MOD - 2);
        return;
    }
    // B0 = A^{-1} mod x^{ceil(rk/2)}.
    inv(ans, x, (rk + 1) >> 1);

    // Transform length strictly greater than 2*rk, so the product A*B0*B0
    // (degree <= 2*rk - 2) does not wrap around in the cyclic convolution.
    int size = 1, bits = 0;
    for (; size <= (rk << 1); size <<= 1) ++bits;

    const int lastIn = std::min(x->l, rk - 1);
    for (int i = 0; i < size; i++) {
        tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (bits - 1));
        tmp.op[i] = (i <= lastIn) ? x->op[i] : 0;
        if (i > ans->l) ans->op[i] = 0;
    }
    NTT(size, tmp.op, 1);
    NTT(size, ans->op, 1);

    // Point-wise Newton step: B = B0 * (2 - A * B0) = 2*B0 - A*B0*B0.
    for (int i = 0; i < size; i++) {
        const long long b0 = ans->op[i];
        const long long ab = tmp.op[i] * b0 % MOD;
        ans->op[i] = (2 * b0 % MOD - ab * b0 % MOD) % MOD;
    }
    NTT(size, ans->op, -1);

    // Finish the inverse transform by scaling with size^{-1}; keep rk terms.
    ans->l = rk - 1;
    const int invSize = qPow(size, MOD - 2);
    for (int i = 0; i <= ans->l; i++) ans->op[i] = 1ll * ans->op[i] * invSize % MOD;
}
// Power-series square root by Newton iteration.
// On return, ans holds S with S * S == A (mod x^rk) and ans->l == rk - 1.
// Only the coefficients x->op[0..min(x->l, rk-1)] of A are read.
// The base case fixes S(0) = 1, so this assumes x->op[0] == 1.
// Side effects: clobbers the global scratch polynomials tmp and tmp2 and the
// table tr[]. Coefficients are left as representatives in (-MOD, MOD).
void polSqrt(Pol *ans, const Pol *x, int rk) {
    if (rk == 1) {
        // Square root of the constant term 1.
        ans->l = 0;
        ans->op[0] = 1;
        return;
    }
    // S0 = sqrt(A) mod x^{ceil(rk/2)}, then tmp2 = S0^{-1} mod x^rk.
    polSqrt(ans, x, (rk + 1) >> 1);
    inv(&tmp2, ans, rk);

    const int half = qPow(2, MOD - 2);  // modular inverse of 2
    // Transform length strictly greater than 2*rk, so (A + S0^2) * S0^{-1}
    // (degree <= 2*rk - 2) does not wrap around.
    int size = 1, bits = 0;
    for (; size <= (rk << 1); size <<= 1) ++bits;

    const int lastIn = std::min(x->l, rk - 1);
    for (int i = 0; i < size; i++) {
        tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (bits - 1));
        tmp.op[i] = (i <= lastIn) ? x->op[i] : 0;
        if (i > ans->l) ans->op[i] = 0;
        if (i > tmp2.l) tmp2.op[i] = 0;
    }
    NTT(size, tmp.op, 1);
    NTT(size, ans->op, 1);
    NTT(size, tmp2.op, 1);

    // Point-wise Newton step: S = (A + S0^2) / (2 * S0).
    for (int i = 0; i < size; i++) {
        const long long s0 = ans->op[i];
        const long long num = (tmp.op[i] + s0 * s0) % MOD;
        ans->op[i] = num * half % MOD * tmp2.op[i] % MOD;
    }
    NTT(size, ans->op, -1);

    // Finish the inverse transform by scaling with size^{-1}; keep rk terms.
    ans->l = rk - 1;
    const int invSize = qPow(size, MOD - 2);
    for (int i = 0; i <= ans->l; i++) ans->op[i] = 1ll * ans->op[i] * invSize % MOD;
}
// Reads n, m and the allowed weights c_1..c_n. Then, for every s = 1..m,
// prints the number of good trees of weight s modulo MOD.
// With C(x) = sum over allowed weights c of x^c, the answers' generating
// function F satisfies F = 1 + C * F^2; the power-series root with F(0) = 1
// is F = 2 / (1 + sqrt(1 - 4C)).
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);  // no interleaved prompts: don't flush cout before each read
    cin >> n >> m;

    // pol1 = 1 - 4C(x). Weights larger than m are stored but never read,
    // because every computation below is truncated modulo x^(m+1).
    pol1.l = 100001;
    for (int i = 1, x; i <= n; i++) {
        cin >> x;
        pol1.op[x] = -4;
    }
    pol1.op[0] = 1;

    // Only coefficients 0..m are printed, and they depend only on the input
    // modulo x^(m+1). Precision m+1 therefore gives identical output and is
    // cheaper than always working to x^100001.
    const int precision = m + 1;
    polSqrt(&pol2, &pol1, precision);  // pol2 = sqrt(1 - 4C)
    pol2.op[0]++;                      // pol2 = 1 + sqrt(1 - 4C)
    inv(&pol1, &pol2, precision);      // pol1 = 1 / (1 + sqrt(1 - 4C))
    for (int i = 0; i <= pol1.l; i++) pol1.op[i] = 2ll * pol1.op[i] % MOD;

    // Coefficients may be negative representatives; shift them into [0, MOD).
    // '\n' instead of endl: flushing after each of up to 1e5 lines is slow.
    for (int i = 1; i <= m; i++) cout << (1ll * pol1.op[i] + MOD) % MOD << '\n';
    return 0;
}