虚树
/ / 阅读耗时 22 分钟 介绍虚树,一个解决树上动态规划的利器。
先来一个经典例题:洛谷P2011消耗战。
很明显是树型DP,但是奈何n过大,时间复杂度达到$O(nm)$,必然TLE,但是我们又没有什么好的解决思路。这时看到了一个条件$\displaystyle\sum_{i=1}^mk_i<=500000$,这说明无论m有多大,k的总和总是不大于500000,这个条件能不能用来优化算法?
答案是肯定的,这时就需要引入虚树这个黑科技。
所谓虚树是指将原树中与问题无关的点剔除,保留与问题有关的点及其lca,之后得到的树。经过这种处理后的树保留了点的相对关系,又去除了冗余的点,可以提高算法效率,这样得到的树就称为虚树。相对应地,原有的树称为原树。虚树能够有效利用与问题有关的点,这样k总和不大于500000的条件就很好地利用上了。
上图是例题样例图。假如现在需要处理10和6两个点,那么树可以简化成这样:
然后复杂度就降下来了,这就是虚树。
虚树的构建
现在给定原树和我们需要处理的点,如何求虚树呢?这里需要前缀知识:求LCA。只要不是暴力,选哪一个求LCA的方法均可,包括但不仅限于倍增法、tarjan算法、树链剖分。这里选择倍增法。
首先,我们需要对树上结点进行DFS排序,得到其编号DFN。
再次,对给定的点按照DFS序排序,用一个栈维护加入虚树的结点遍号,之后将点按照DFS序从小到大的顺序依次加入虚树,规则如下:
- 若栈中元素数目小于2个,则直接进栈。
- 若栈中元素数目不小于2个,则求待进栈结点与栈首结点的lca。如果lca就是栈首元素,说明待入栈结点是栈首元素的子结点,直接进栈。
- 如果lca不是栈首元素,说明栈首元素已经没有需要处理的子结点(毕竟是按照DFS序进栈的),这时构建虚树子树,构建方法是:取栈首元素和第二个元素,若两者都是lca的子结点(这个用DFS序判断),则两者连边作为虚树的树边,根据DFS序,可知第二个元素必定是栈首元素的父结点,加完边后栈首元素出栈,重复这个过程直到栈中元素少于2个或者栈首元素和第二元素不都是lca的子结点。这个过程完成后,判断lca是否是栈首元素,若否说明当前栈首元素是栈中唯一一个为lca子结点的结点,连接lca和这个结点,然后将栈首替换为lca。最后待进栈元素进栈。
- 最后,可能虚树仍然没有构建完成,最后还需要一次出栈处理。
这一步示例代码:1
2
3
4
5
6
7
8
9
10
11
12inline void insert(int x) {//建立虚树
if (top <= 1)sta[++top] = x;//数组模拟栈
else {
int lca = LCA(sta[top], x);//求lca
if (lca == sta[top])sta[++top] = x;
else {
while (top > 1 && DFN[lca] <= DFN[sta[top - 1]])add(sta[top - 1], sta[top]), --top;
if (lca != sta[top])add(lca, sta[top]), sta[top] = lca;
sta[++top] = x;
}
}
}
最后一次出栈处理代码:1
while (top > 1)add(sta[top - 1], sta[top]), --top;//最后一次的出栈处理
这里会有一个问题:虚树边权如何定义?这个需要具体分析,本题中将点和点之间的链化成了一条边,根据题目性质,这一条边权值应该是它所连接的两点在原树中链边权的最小值。这个可以用倍增法快速求。
建完虚树之后,DP方程就很好列了(本来就好列)。规定$dp(i)$为以i点为根的子树中,去除所有目标点所需的最小代价,那么有转移方程:
$vis[x]$在i为目标结点时为1,否则为0。k为i结点的儿子数目,编号为1~k,$val(i,j)$是结点i到结点j在虚树上的连边权值。
然后本题就切掉了。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
using namespace std;
inline int read() {//读入优化
char e = getchar();
int s = 0;
while (e < '-')e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return s;
}
unordered_set<int> ssp;//用于预处理
struct Edge {
int to, next, v;
} edge[N << 1], vedge[N];
int n, m, head[N], vhead[N], x, y, z, cnt = 1, vcnt = 1, grand[N][19], DFN[N], depth[N], DFNCNT = 1;
int op[500002], sta[500002], top, grand_min[N][19];
long long dp[N];
bool vis[N];
inline void add(int x, int y, int w) {//原树
edge[cnt].to = y, edge[cnt].next = head[x], edge[cnt].v = w, head[x] = cnt++;
}
inline void addv(int x, int y, int w) {//虚树
vedge[vcnt].to = y, vedge[vcnt].next = vhead[x], vedge[vcnt].v = w, vhead[x] = vcnt++;
}
inline int LCA(int x, int y) {//倍增LCA
if (depth[x] > depth[y])swap(x, y);
for (int i = 18; i >= 0; i--)if (depth[grand[y][i]] >= depth[x])y = grand[y][i];
if (x == y)return x;
for (int i = 18; i >= 0; i--)if (grand[x][i] != grand[y][i])x = grand[x][i], y = grand[y][i];
return grand[x][0];
}
inline int minEdge(int x, int y) {//求最小边权
int ans = 0x7fffffff;
if (depth[x] > depth[y])swap(x, y);
for (int i = 18; i >= 0; i--)if (depth[grand[y][i]] >= depth[x])ans = min(ans, grand_min[y][i]), y = grand[y][i];
return ans;
}
void DFS(int x) {
DFN[x] = DFNCNT++;
for (int i = head[x]; i; i = edge[i].next) {
if (depth[edge[i].to] == 0) {
depth[edge[i].to] = depth[x] + 1, grand[edge[i].to][0] = x, grand_min[edge[i].to][0] = edge[i].v;
for (int j = 1; j <= 18; j++)grand[edge[i].to][j] = grand[grand[edge[i].to][j - 1]][j - 1];
for (int j = 1; j <= 18; j++)
grand_min[edge[i].to][j] = min(grand_min[edge[i].to][j - 1],
grand_min[grand[edge[i].to][j - 1]][j - 1]);
DFS(edge[i].to);
}
}
}
bool cmp(int x, int y) {
return DFN[x] < DFN[y];
}
inline void insert(int x) {//建立虚树
ssp.insert(x);
if (top <= 1)sta[++top] = x;
else {
int lca = LCA(sta[top], x);
if (lca == sta[top])sta[++top] = x;
else {
while (top > 1 && DFN[lca] <= DFN[sta[top - 1]])
addv(sta[top - 1], sta[top], minEdge(sta[top - 1], sta[top])), --top;
if (lca != sta[top])addv(lca, sta[top], minEdge(lca, sta[top])), sta[top] = lca, ssp.insert(lca);
sta[++top] = x;
}
}
}
long long DP(int x) {//DP过程
if (dp[x] != -1)return dp[x];
if (vis[x])return dp[x] = 1l << 60;
dp[x] = 0ll;
for (int i = vhead[x]; i; i = vedge[i].next)dp[x] += min(1ll * vedge[i].v, DP(vedge[i].to));
return dp[x];
}
int main() {
n = read();
for (int i = 1; i < n; i++)x = read(), y = read(), z = read(), add(x, y, z), add(y, x, z);
depth[1] = 1, DFS(1), m = read();
for (int i = 1; i <= m; i++) {
int k = read();
sta[top = 1] = 1, vcnt = 1, ssp.insert(1);
for (int j = 1; j <= k; j++)op[j] = read(), vis[op[j]] = true;
sort(op + 1, op + k + 1, cmp);
for (int j = 1; j <= k; j++)insert(op[j]);//构建虚树
while (top > 1)addv(sta[top - 1], sta[top], minEdge(sta[top - 1], sta[top])), --top;
for (int it : ssp)dp[it] = -1;
printf("%lld\n", DP(1));
for (int it : ssp)vis[it] = false, vhead[it] = 0;
ssp.clear();
}
return 0;
}
这里需要注意一个问题:不能在每一次查询中都清一遍数组,这样时间复杂度就退化了(又达到$O(nm)$)。正确解法是记录这一次查询用到了那些位置,然后再单独对这些位置做初始化。上面代码中用unordered_set实现。
例题
下面会整理例题,长期更新。不涉及题面,点击标题可跳转。
[HNOI2014]世界树
都能看出来这是虚树,但是怎么$DP$是难点。
建出虚树后,一遍DFS,用子结点更新父结点,可以找到子结点到父结点最近的议事处结点,并能记录编号。然后再来一遍DFS,用父结点更新子结点,使在结点祖先上的议事处结点信息更新过来。这里的DFS需要记录信息,本质上就是DP。
做完这一步,在虚树上的结点的归宿就全部清楚了,现在考虑不在虚树上的结点如何处理。
如果一个结点的子树不在虚树上,但它自己在虚树中,那么它的子树上结点的归宿与该结点一致,这是显然的。
对于虚树上的一条边,它对应原树上的一条链(以及上边的子树),如果边连接的两个点归宿相同,那么边对应的链上结点归宿必然可以确定。
如果边连接的两个点归宿不同,那么这条边上必然有一个分界点,使得分界点之上的归宿为第一个点,之下为第二个点,这一步可以倍增地找到。
全程都需要注意当距离相同时,找编号小的那一个。本题比较考验码力。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
using namespace std;
struct Edge {
int next, to;
} edge[N << 1], vedge[N];
int head[N], n, cnt = 1, DFN[N], DFNID, sz[N], gr[N][20], dep[N], q, xx, op[N], sta[N], top;
int vhead[N], vcnt = 1, ans[N], f[N], g[N], vis[N], TIME, op2[N];
unordered_set<int> ssp;
bool cmp(int a, int b) {
return DFN[a] < DFN[b];
}
inline void add(int x, int y) {
edge[cnt].to = y, edge[cnt].next = head[x], head[x] = cnt++;
}
inline void vadd(int x, int y) {
vedge[vcnt].to = y, vedge[vcnt].next = vhead[x], vhead[x] = vcnt++, ssp.insert(x), ssp.insert(y);
}
inline int LCA(int x, int y) {
if (dep[x] > dep[y])swap(x, y);
for (int i = 19; i >= 0; i--)if (dep[gr[y][i]] >= dep[x])y = gr[y][i];
if (x == y)return x;
for (int i = 19; i >= 0; i--)if (gr[x][i] != gr[y][i])x = gr[x][i], y = gr[y][i];
return gr[x][0];
}
inline void insert(int x) {
if (top <= 1)sta[++top] = x;
else {
int lca = LCA(sta[top], x);
if (lca == sta[top])sta[++top] = x;
else {
while (top > 1 && DFN[lca] <= DFN[sta[top - 1]])vadd(sta[top - 1], sta[top]), --top;
if (lca != sta[top])vadd(lca, sta[top]), sta[top] = lca;
sta[++top] = x;
}
}
}
void DFS(int x, int fa) {
DFN[x] = ++DFNID, sz[x] = 1;
for (int i = head[x]; i; i = edge[i].next) {
if (edge[i].to != fa) {
dep[edge[i].to] = dep[x] + 1, gr[edge[i].to][0] = x;
for (int j = 1; j < 20; j++)gr[edge[i].to][j] = gr[gr[edge[i].to][j - 1]][j - 1];
DFS(edge[i].to, x), sz[x] += sz[edge[i].to];
}
}
}
void DFS2(int x) {
if (vis[x] == TIME)f[x] = 0, g[x] = x;
else f[x] = 0x7fffffff;
for (int i = vhead[x]; i; i = vedge[i].next) {
DFS2(vedge[i].to);
if (f[vedge[i].to] + dep[vedge[i].to] - dep[x] < f[x])
f[x] = f[vedge[i].to] + dep[vedge[i].to] - dep[x], g[x] = g[vedge[i].to];
else if (f[vedge[i].to] + dep[vedge[i].to] - dep[x] == f[x] && g[vedge[i].to] < g[x])g[x] = g[vedge[i].to];
}
}
void DFS3(int x, int dis, int who) {
if (vis[x] == TIME) {
for (int i = vhead[x]; i; i = vedge[i].next)DFS3(vedge[i].to, dep[vedge[i].to] - dep[x], x);
} else {
if (dis < f[x])f[x] = dis, g[x] = who;
else if (dis == f[x] && who < g[x])g[x] = who;
if (f[x] < dis)dis = f[x], who = g[x];
else if (f[x] == dis && g[x] < who)who = g[x];
for (int i = vhead[x]; i; i = vedge[i].next)DFS3(vedge[i].to, dis + dep[vedge[i].to] - dep[x], who);
}
}
inline int SON(int x, int y) {
for (int i = 19; i >= 0; i--)if (dep[gr[y][i]] > dep[x])y = gr[y][i];
return y;
}
void DFS_ans(int x) {
int num = 0;
for (int i = vhead[x]; i; i = vedge[i].next) {
DFS_ans(vedge[i].to), num += sz[SON(x, vedge[i].to)];
if (g[x] == g[vedge[i].to])ans[g[x]] += sz[SON(x, vedge[i].to)] - sz[vedge[i].to];
else {
int now = vedge[i].to, from;
for (int j = 19; j >= 0; j--) {
if (dep[gr[now][j]] > dep[x] &&
f[vedge[i].to] + dep[vedge[i].to] - dep[gr[now][j]] <= f[x] + dep[gr[now][j]] - dep[x]) {
now = gr[now][j];
}
}
ans[g[x]] += sz[SON(x, now)] - sz[now];
ans[g[vedge[i].to]] += sz[from = SON(now, vedge[i].to)] - sz[vedge[i].to];
if (f[vedge[i].to] + dep[vedge[i].to] - dep[now] < f[x] + dep[now] - dep[x]) {
ans[g[vedge[i].to]] += sz[now] - sz[from];
} else if (g[vedge[i].to] < g[x])ans[g[vedge[i].to]] += sz[now] - sz[from];
else ans[g[x]] += sz[now] - sz[from];
}
}
ans[g[x]] += sz[x] - num - 1;
}
int main() {
scanf("%d", &n);
for (int i = 1, x, y; i < n; i++)scanf("%d%d", &x, &y), add(x, y), add(y, x);
dep[1] = 1, DFS(1, 0), scanf("%d", &q);
while (q--) {
++TIME, scanf("%d", &xx), vcnt = 1, ssp.clear();
for (int i = 1; i <= xx; i++)scanf("%d", op + i), ans[op2[i] = op[i]] = 0, vis[op[i]] = TIME;
sort(op + 1, op + xx + 1, cmp);
if (op[1] != 1)sta[top = 1] = 1;
else top = 0;
for (int i = 1; i <= xx; i++)insert(op[i]);
while (top > 1)vadd(sta[top - 1], sta[top]), --top;
DFS2(1), DFS3(1, 0x7fffffff, -1), DFS_ans(1);
for (int x:ssp)++ans[g[x]], vhead[x] = 0;
for (int i = 1; i <= xx; i++)printf("%d ", ans[op2[i]]);
printf("\n");
}
return 0;
}