介绍虚树,一个解决树上动态规划的利器。

        先来一个经典例题:洛谷P2011消耗战
        很明显是树型DP,但是奈何n过大,时间复杂度达到$O(nm)$,必然TLE,但是我们又没有什么好的解决思路。这时看到了一个条件$\displaystyle\sum_{i=1}^mk_i<=500000$,这说明无论m有多大,k的总和总是不大于500000,这个条件能不能用来优化算法?
        答案是肯定的,这时就需要引入虚树这个黑科技。
        所谓虚树是指将原树中与问题无关的点剔除,保留与问题有关的点及其lca,之后得到的树。经过这种处理后的树保留了点的相对关系,又去除了冗余的点,可以提高算法效率,这样得到的树就称为虚树。相对应地,原有的树称为原树。虚树能够有效利用与问题有关的点,这样k总和不大于500000的条件就很好地利用上了。

        上图是例题样例图。假如现在需要处理10和6两个点,那么树可以简化成这样:

        然后复杂度就降下来了,这就是虚树。

虚树的构建

        现在给定原树和我们需要处理的点,如何求虚树呢?这里需要前缀知识:求LCA。只要不是暴力,选哪一个求LCA的方法均可,包括但不仅限于倍增法、tarjan算法、树链剖分。这里选择倍增法。
        首先,我们需要对树上结点进行DFS排序,得到其编号DFN。
        再次,对给定的点按照DFS序排序,用一个栈维护加入虚树的结点遍号,之后将点按照DFS序从小到大的顺序依次加入虚树,规则如下:

  • 若栈中元素数目小于2个,则直接进栈。
  • 若栈中元素数目不小于2个,则求待进栈结点与栈首结点的lca。如果lca就是栈首元素,说明待入栈结点是栈首元素的子结点,直接进栈。
  • 如果lca不是栈首元素,说明栈首元素已经没有需要处理的子结点(毕竟是按照DFS序进栈的),这时构建虚树子树,构建方法是:取栈首元素和第二个元素,若两者都是lca的子结点(这个用DFS序判断),则两者连边作为虚树的树边,根据DFS序,可知第二个元素必定是栈首元素的父结点,加完边后栈首元素出栈,重复这个过程直到栈中元素少于2个或者栈首元素和第二元素不都是lca的子结点。这个过程完成后,判断lca是否是栈首元素,若否说明当前栈首元素是栈中唯一一个为lca子结点的结点,连接lca和这个结点,然后将栈首替换为lca。最后待进栈元素进栈。
  • 最后,可能虚树仍然没有构建完成,最后还需要一次出栈处理。

        这一步示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
inline void insert(int x) {//建立虚树
if (top <= 1)sta[++top] = x;//数组模拟栈
else {
int lca = LCA(sta[top], x);//求lca
if (lca == sta[top])sta[++top] = x;
else {
while (top > 1 && DFN[lca] <= DFN[sta[top - 1]])add(sta[top - 1], sta[top]), --top;
if (lca != sta[top])add(lca, sta[top]), sta[top] = lca;
sta[++top] = x;
}
}
}

        最后一次出栈处理代码:
1
while (top > 1)add(sta[top - 1], sta[top]), --top;//最后一次的出栈处理

        这里会有一个问题:虚树边权如何定义?这个需要具体分析,本题中将点和点之间的链化成了一条边,根据题目性质,这一条边权值应该是它所连接的两点在原树中链边权的最小值。这个可以用倍增法快速求。
        建完虚树之后,DP方程就很好列了(本来就好列)。规定$dp(i)$为以i点为根的子树中,去除所有目标点所需的最小代价,那么有转移方程:

        $vis[x]$在i为目标结点时为1,否则为0。k为i结点的儿子数目,编号为1~k,$val(i,j)$是结点i到结点j在虚树上的连边权值。
        然后本题就切掉了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include<bits/stdc++.h>

#define N 250005
using namespace std;

inline int read() {//读入优化
char e = getchar();
int s = 0;
while (e < '-')e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return s;
}

unordered_set<int> ssp;//用于预处理
struct Edge {
int to, next, v;
} edge[N << 1], vedge[N];
int n, m, head[N], vhead[N], x, y, z, cnt = 1, vcnt = 1, grand[N][19], DFN[N], depth[N], DFNCNT = 1;
int op[500002], sta[500002], top, grand_min[N][19];
long long dp[N];
bool vis[N];

inline void add(int x, int y, int w) {//原树
edge[cnt].to = y, edge[cnt].next = head[x], edge[cnt].v = w, head[x] = cnt++;
}

inline void addv(int x, int y, int w) {//虚树
vedge[vcnt].to = y, vedge[vcnt].next = vhead[x], vedge[vcnt].v = w, vhead[x] = vcnt++;
}

inline int LCA(int x, int y) {//倍增LCA
if (depth[x] > depth[y])swap(x, y);
for (int i = 18; i >= 0; i--)if (depth[grand[y][i]] >= depth[x])y = grand[y][i];
if (x == y)return x;
for (int i = 18; i >= 0; i--)if (grand[x][i] != grand[y][i])x = grand[x][i], y = grand[y][i];
return grand[x][0];
}

inline int minEdge(int x, int y) {//求最小边权
int ans = 0x7fffffff;
if (depth[x] > depth[y])swap(x, y);
for (int i = 18; i >= 0; i--)if (depth[grand[y][i]] >= depth[x])ans = min(ans, grand_min[y][i]), y = grand[y][i];
return ans;
}

void DFS(int x) {
DFN[x] = DFNCNT++;
for (int i = head[x]; i; i = edge[i].next) {
if (depth[edge[i].to] == 0) {
depth[edge[i].to] = depth[x] + 1, grand[edge[i].to][0] = x, grand_min[edge[i].to][0] = edge[i].v;
for (int j = 1; j <= 18; j++)grand[edge[i].to][j] = grand[grand[edge[i].to][j - 1]][j - 1];
for (int j = 1; j <= 18; j++)
grand_min[edge[i].to][j] = min(grand_min[edge[i].to][j - 1],
grand_min[grand[edge[i].to][j - 1]][j - 1]);
DFS(edge[i].to);
}
}
}

bool cmp(int x, int y) {
return DFN[x] < DFN[y];
}

inline void insert(int x) {//建立虚树
ssp.insert(x);
if (top <= 1)sta[++top] = x;
else {
int lca = LCA(sta[top], x);
if (lca == sta[top])sta[++top] = x;
else {
while (top > 1 && DFN[lca] <= DFN[sta[top - 1]])
addv(sta[top - 1], sta[top], minEdge(sta[top - 1], sta[top])), --top;
if (lca != sta[top])addv(lca, sta[top], minEdge(lca, sta[top])), sta[top] = lca, ssp.insert(lca);
sta[++top] = x;
}
}
}

long long DP(int x) {//DP过程
if (dp[x] != -1)return dp[x];
if (vis[x])return dp[x] = 1l << 60;
dp[x] = 0ll;
for (int i = vhead[x]; i; i = vedge[i].next)dp[x] += min(1ll * vedge[i].v, DP(vedge[i].to));
return dp[x];
}

int main() {
n = read();
for (int i = 1; i < n; i++)x = read(), y = read(), z = read(), add(x, y, z), add(y, x, z);
depth[1] = 1, DFS(1), m = read();
for (int i = 1; i <= m; i++) {
int k = read();
sta[top = 1] = 1, vcnt = 1, ssp.insert(1);
for (int j = 1; j <= k; j++)op[j] = read(), vis[op[j]] = true;
sort(op + 1, op + k + 1, cmp);
for (int j = 1; j <= k; j++)insert(op[j]);//构建虚树
while (top > 1)addv(sta[top - 1], sta[top], minEdge(sta[top - 1], sta[top])), --top;
for (int it : ssp)dp[it] = -1;
printf("%lld\n", DP(1));
for (int it : ssp)vis[it] = false, vhead[it] = 0;
ssp.clear();
}
return 0;
}

        这里需要注意一个问题:不能在每一次查询中都清一遍数组,这样时间复杂度就退化了(又达到$O(nm)$)。正确解法是记录这一次查询用到了那些位置,然后再单独对这些位置做初始化。上面代码中用unordered_set实现。

例题

        下面会整理例题,长期更新。不涉及题面,点击标题可跳转。

[HNOI2014]世界树

        都能看出来这是虚树,但是怎么$DP$是难点。
        建出虚树后,一遍DFS,用子结点更新父结点,可以找到子结点到父结点最近的议事处结点,并能记录编号。然后再来一遍DFS,用父结点更新子结点,使在结点祖先上的议事处结点信息更新过来。这里的DFS需要记录信息,本质上就是DP。
        做完这一步,在虚树上的结点的归宿就全部清楚了,现在考虑不在虚树上的结点如何处理。
        如果一个结点的子树不在虚树上,但它自己在虚树中,那么它的子树上结点的归宿与该结点一致,这是显然的。
        对于虚树上的一条边,它对应原树上的一条链(以及上边的子树),如果边连接的两个点归宿相同,那么边对应的链上结点归宿必然可以确定。
        如果边连接的两个点归宿不同,那么这条边上必然有一个分界点,使得分界点之上的归宿为第一个点,之下为第二个点,这一步可以倍增地找到。
        全程都需要注意当距离相同时,找编号小的那一个。本题比较考验码力。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include <bits/stdc++.h>

#define N 300005
using namespace std;
struct Edge {
int next, to;
} edge[N << 1], vedge[N];
int head[N], n, cnt = 1, DFN[N], DFNID, sz[N], gr[N][20], dep[N], q, xx, op[N], sta[N], top;
int vhead[N], vcnt = 1, ans[N], f[N], g[N], vis[N], TIME, op2[N];
unordered_set<int> ssp;

bool cmp(int a, int b) {
return DFN[a] < DFN[b];
}

inline void add(int x, int y) {
edge[cnt].to = y, edge[cnt].next = head[x], head[x] = cnt++;
}

inline void vadd(int x, int y) {
vedge[vcnt].to = y, vedge[vcnt].next = vhead[x], vhead[x] = vcnt++, ssp.insert(x), ssp.insert(y);
}

inline int LCA(int x, int y) {
if (dep[x] > dep[y])swap(x, y);
for (int i = 19; i >= 0; i--)if (dep[gr[y][i]] >= dep[x])y = gr[y][i];
if (x == y)return x;
for (int i = 19; i >= 0; i--)if (gr[x][i] != gr[y][i])x = gr[x][i], y = gr[y][i];
return gr[x][0];
}

inline void insert(int x) {
if (top <= 1)sta[++top] = x;
else {
int lca = LCA(sta[top], x);
if (lca == sta[top])sta[++top] = x;
else {
while (top > 1 && DFN[lca] <= DFN[sta[top - 1]])vadd(sta[top - 1], sta[top]), --top;
if (lca != sta[top])vadd(lca, sta[top]), sta[top] = lca;
sta[++top] = x;
}
}
}

void DFS(int x, int fa) {
DFN[x] = ++DFNID, sz[x] = 1;
for (int i = head[x]; i; i = edge[i].next) {
if (edge[i].to != fa) {
dep[edge[i].to] = dep[x] + 1, gr[edge[i].to][0] = x;
for (int j = 1; j < 20; j++)gr[edge[i].to][j] = gr[gr[edge[i].to][j - 1]][j - 1];
DFS(edge[i].to, x), sz[x] += sz[edge[i].to];
}
}
}

void DFS2(int x) {
if (vis[x] == TIME)f[x] = 0, g[x] = x;
else f[x] = 0x7fffffff;
for (int i = vhead[x]; i; i = vedge[i].next) {
DFS2(vedge[i].to);
if (f[vedge[i].to] + dep[vedge[i].to] - dep[x] < f[x])
f[x] = f[vedge[i].to] + dep[vedge[i].to] - dep[x], g[x] = g[vedge[i].to];
else if (f[vedge[i].to] + dep[vedge[i].to] - dep[x] == f[x] && g[vedge[i].to] < g[x])g[x] = g[vedge[i].to];
}
}

void DFS3(int x, int dis, int who) {
if (vis[x] == TIME) {
for (int i = vhead[x]; i; i = vedge[i].next)DFS3(vedge[i].to, dep[vedge[i].to] - dep[x], x);
} else {
if (dis < f[x])f[x] = dis, g[x] = who;
else if (dis == f[x] && who < g[x])g[x] = who;
if (f[x] < dis)dis = f[x], who = g[x];
else if (f[x] == dis && g[x] < who)who = g[x];
for (int i = vhead[x]; i; i = vedge[i].next)DFS3(vedge[i].to, dis + dep[vedge[i].to] - dep[x], who);
}
}

inline int SON(int x, int y) {
for (int i = 19; i >= 0; i--)if (dep[gr[y][i]] > dep[x])y = gr[y][i];
return y;
}

void DFS_ans(int x) {
int num = 0;
for (int i = vhead[x]; i; i = vedge[i].next) {
DFS_ans(vedge[i].to), num += sz[SON(x, vedge[i].to)];
if (g[x] == g[vedge[i].to])ans[g[x]] += sz[SON(x, vedge[i].to)] - sz[vedge[i].to];
else {
int now = vedge[i].to, from;
for (int j = 19; j >= 0; j--) {
if (dep[gr[now][j]] > dep[x] &&
f[vedge[i].to] + dep[vedge[i].to] - dep[gr[now][j]] <= f[x] + dep[gr[now][j]] - dep[x]) {
now = gr[now][j];
}
}
ans[g[x]] += sz[SON(x, now)] - sz[now];
ans[g[vedge[i].to]] += sz[from = SON(now, vedge[i].to)] - sz[vedge[i].to];
if (f[vedge[i].to] + dep[vedge[i].to] - dep[now] < f[x] + dep[now] - dep[x]) {
ans[g[vedge[i].to]] += sz[now] - sz[from];
} else if (g[vedge[i].to] < g[x])ans[g[vedge[i].to]] += sz[now] - sz[from];
else ans[g[x]] += sz[now] - sz[from];
}

}
ans[g[x]] += sz[x] - num - 1;
}

int main() {
scanf("%d", &n);
for (int i = 1, x, y; i < n; i++)scanf("%d%d", &x, &y), add(x, y), add(y, x);
dep[1] = 1, DFS(1, 0), scanf("%d", &q);
while (q--) {
++TIME, scanf("%d", &xx), vcnt = 1, ssp.clear();
for (int i = 1; i <= xx; i++)scanf("%d", op + i), ans[op2[i] = op[i]] = 0, vis[op[i]] = TIME;
sort(op + 1, op + xx + 1, cmp);
if (op[1] != 1)sta[top = 1] = 1;
else top = 0;
for (int i = 1; i <= xx; i++)insert(op[i]);
while (top > 1)vadd(sta[top - 1], sta[top]), --top;
DFS2(1), DFS3(1, 0x7fffffff, -1), DFS_ans(1);
for (int x:ssp)++ans[g[x]], vhead[x] = 0;
for (int i = 1; i <= xx; i++)printf("%d ", ans[op2[i]]);
printf("\n");
}
return 0;
}