这两天打比赛打的贼烂,我还是菜啊。这里记一个比赛原题,一个树型DP。

题目描述

        You have been given a weighted tree which has n nodes. Each node has a value vi and each edge has a length wi. You start at node 1, and can move a distance of d at most. How many values you can gather at most? For each node you can pick the value once no matter how many times you pass it.

输入输出格式

输入格式

        The first line of the input contains one integer $n (1 ≤ n ≤ 100)$.
        The second line contains n numbers, each of them denotes the value of one node: $v_i (0 ≤ v_i ≤ 2)$.
        For the next n-1 lines, each line contains 3 numbers $(ai, bi, wi)$, representing an edge between ai and bi of length $w_i (1 ≤ a_i, b_i ≤ n, 1 ≤ w_i ≤ 10^4)$.
        The next line has a number q, representing the total number of queries$ (0 ≤ q ≤ 100000)$. The next $q$ lines, each line contains a number $d$ representing the distance you can move at most.

输出格式

        For each query output the maximum values you can gather in one line.

输入输出样例

Sample input

3
0 1 1
1 2 5
1 3 3
3
3
10
11

Sample output

1
1
2

题解

        注意到点权是0~2,并且没有给$d$的范围,那么我们可以从点权上入手,考虑树型DP。
        令$DP(x,y,z)$表示在以$x$为根的树上,获得至少$y$的点权,并且在$z$状态下($z$只能取0或1,$z=0$时表示不用返回根,否则需要返回根)时至少要移动的距离。
        为了使问题更简便,我仍然使用那个粗鄙的方法:多叉树转二叉树。引入一些新的点,使原树转化为二叉树,新加入的边边权为0,新加入的点点权也为0,容易知道两棵树是等价的。转化为二叉树可以使状态转移方程更容易列出,时间复杂度没变(只是常数大了),这就是我们转换的目的。
        下面考虑转移。
        任何时候如果根结点的点权$v_x \geq y$,则$DP(x,y,0/1)=0$,这是很好理解的。
        若根结点只有一个儿子,则有:

        这里$ch[0]$是$x$的儿子,$w_0$则是边权。这个方程也是很好理解的。
        两个儿子的时候就不是那么容易转移,考虑对两个儿子结点进行点权分配。首先全部分配给左儿子,在$z$取值的两种情况下有以下的转移:

        只给另一个儿子分配类比即可。如果两个儿子都分配,且给第一个儿子$i$的权值,那么在$z=0$时转移为:

        $z=1$时就是:

        取上面的最小值转移到相应的$DP$值,不成立时取无穷大。对于询问直接二分找答案就可以了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <bits/stdc++.h>

#define inf (1<<27)
using namespace std;
struct Edge {
int next, to, v;
} edge[500], edge2[50000];
int head[500], cnt = 1, n, q, op[50000], dp[50000][250][2], now, head2[50000], cnt2 = 1, ss;

inline void add(int x, int y, int z) {
edge[cnt].to = y, edge[cnt].next = head[x], edge[cnt].v = z, head[x] = cnt++;
}

inline void add2(int x, int y, int z) {
edge2[cnt2].to = y, edge2[cnt2].next = head2[x], edge2[cnt2].v = z, head2[x] = cnt2++;
}

int DP(int rt, int dd, int isOK) {
if (dp[rt][dd][isOK] != -1)return dp[rt][dd][isOK];
int ch[2], now = 0, v[2];
for (int i = head2[rt]; i; i = edge2[i].next)ch[now] = edge2[i].to, v[now++] = edge2[i].v;
if (now == 0)return dp[rt][dd][isOK] = (dd <= op[rt] ? 0 : inf);
if (now == 1)return dp[rt][dd][isOK] = (dd <= op[rt] ? 0 : DP(ch[0], dd - op[rt], isOK) + v[0] * (isOK + 1));
dp[rt][dd][isOK] = inf;
if (op[rt] >= dd)return dp[rt][dd][isOK] = 0;
for (int i = 0; i <= dd - op[rt]; i++) {
if (i == 0)dp[rt][dd][isOK] = min(dp[rt][dd][isOK], DP(ch[1], dd - op[rt], isOK) + v[1] * (isOK + 1));
else if (i == dd - op[rt])
dp[rt][dd][isOK] = min(dp[rt][dd][isOK], DP(ch[0], dd - op[rt], isOK) + v[0] * (isOK + 1));
else if (!isOK)
dp[rt][dd][isOK] = min(min(DP(ch[0], i, true) + DP(ch[1], dd - op[rt] - i, false) + 2 * v[0] + v[1],
DP(ch[0], i, false) + DP(ch[1], dd - op[rt] - i, true) + v[0] + 2 * v[1]),
dp[rt][dd][isOK]);
else
dp[rt][dd][isOK] = min(DP(ch[0], i, true) + DP(ch[1], dd - op[rt] - i, true) + 2 * v[0] + 2 * v[1],
dp[rt][dd][isOK]);
}
return dp[rt][dd][isOK];
}

void DFS(int x, int pre) {
int pp = 0, ss = x;
for (int i = head[x]; i; i = edge[i].next) {
if (edge[i].to == pre)continue;
if (pp == 0)add2(ss, edge[i].to, edge[i].v), ++pp;
else if (edge[i].next == 0)add2(ss, edge[i].to, edge[i].v);
else add2(ss, now, 0), add2(now, edge[i].to, edge[i].v), ss = now++;
}
for (int i = head[x]; i; i = edge[i].next)if (edge[i].to != pre)DFS(edge[i].to, x);
}

int main() {
cin >> n, memset(dp, -1, sizeof(dp));
now = n + 1;
for (int i = 1; i <= n; i++) {
cin >> op[i];
ss += op[i];
}
for (int i = 1, x, y, z; i < n; i++) {
cin >> x >> y >> z;
add(x, y, z), add(y, x, z);
}
DFS(1, 0), cin >> q;
while (q--) {
int x;
cin >> x;
int l = 0, r = ss + 1;//[,)
while (l < r) {
if (r == l + 1) {
cout << l << endl;
break;
}
int mid = (l + r) >> 1;
if (DP(1, mid, 0) <= x)l = mid;
else r = mid;
}
}
return 0;
}