树链剖分是树上的重要的算法,阅读本文需要先了解线段树

        树链剖分主要解决维护并求树上最短路径和,子树和的问题,它的思想是将树划分为一条条链,再用数据结构来维护这些链。树链剖分需要保证每一个结点存在且仅存在于一条链中。下面用洛谷P3384来介绍树链剖分。
        首先需要清楚以下几个概念:

  • 结点数size:这个结点对应子树的总结点数(含自己)。
  • 父结点编号father:这个结点的父结点
  • 重儿子son:某个结点所有子结点中size最大的一个。叶子结点没有重儿子,其余结点有且仅有一个重儿子。
  • 重边:连接父结点和其重儿子的边。
  • 重链:顺次连接重边所得到的链边
  • 深度depth:结点深度
  • DFS序:结点在DFS中被首次访问的顺序编号
  • 链顶编号top:该结点所在链的顶端(depth最小的结点)标号

        下面介绍树链剖分步骤。

第一遍DFS:求father、size、son和depth

        第一遍DFS,求出基础数据,一遍DFS可以完成。

1
2
3
4
5
6
7
8
9
10
void DFS1(int x, int fa) {
int maxn = -1, r = -1;
for (int i = head[x]; i != -1; i = edge[i].next) {
if (edge[i].to != fa) {//防止向父结点回溯
father[edge[i].to] = x, depth[edge[i].to] = depth[x] + 1, DFS1(edge[i].to, x), size[x] += size[edge[i].to];
if (size[edge[i].to] > maxn)maxn = size[edge[i].to], r = edge[i].to;//更新son
}
}
son[x] = r, size[x]++;//含自己,size要自加
}

第二遍DFS:求top、DFS序

        这一步是关键。这一步的主要目的是建立树链,规则如下:

  • 如果结点存在重儿子,则优先连接重边,继承已有的树链。
  • 所有轻儿子(除了重儿子就是轻儿子)均分别作一条新链的首端,继续向下延伸。
1
2
3
4
5
6
7
8
void DFS2(int x, int t) {//现在访问的结点编号、该结点所在树链的顶端编号
id[x] = dCnt, rk[dCnt] = x, dCnt++, top[x] = t;//构造双向映射(id和rk),进行DFS序重新编号并更新top值
if (size[x] == 1)return;//叶子结点直接return
if (son[x])DFS2(son[x], t);//存在重儿子优先选重儿子
for (int i = head[x]; i != -1; i = edge[i].next) {
if (edge[i].to != son[x] && edge[i].to != father[x])DFS2(edge[i].to, edge[i].to);//所有轻儿子作新的树链起点
}
}

        这里可能会有一个疑惑,为什么要重新编号呢?其实这才是树链剖分的目的:将每一条树链上的点编号都变成连续的,这样就可以用数据结构(比如线段树)维护它们的和。
        这一步完成后,可以发现,每一个结点(比如x)的所有子树结点编号都成为连续的了,它们的新编号区间为[id[x],id[x]+size[x]-1]。这样有利于我们用线段树解决子树修改求和问题。

最短路径的划分

        如何求两个结点之间最短路径和呢?在树链剖分后容易发现如果两个结点在同一条树链中,它们之间的结点新编号必然是连续的,这样求和和修改就转化为区间求和修改问题,用线段树维护即可。如果两个结点不在一条树链中,就需要采用下面的方法来将它们分到同一个树链中:(原题要取模,这里为了方便省略取模过程)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
inline int queL(int x, int y) {
int tx = top[x], ty = top[y], ans = 0;//找到两条链的top
while (tx != ty) {//top不同说明不在一条链上
if (depth[tx] >= depth[ty]) {//对于更深的top
ans += query(id[tx], id[x], 1, n, 1);//从top开始到这个结点的路径必然都在最短路径上,加上它的影响
x = father[tx], tx = top[x];
} else {
ans += query(id[ty], id[y], 1, n, 1);
y = father[ty], ty = top[y];
}
}
if (depth[x] <= depth[y])return ans + query(id[x], id[y], 1, n, 1);//划分到同一条链上之后直接处理
return ans + query(id[y], id[x], 1, n, 1);
}

        修改类比即可。这里可以发现树链剖分的本质是将树上的问题转化为区间上的问题。
        下面给出全部代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include<iostream>
#include<cstdio>
#include<algorithm>

#define MAX (100000+5)
using namespace std;
struct {
int to, next;
} edge[MAX * 2];
int head[MAX], father[MAX], son[MAX] = {0}, size[MAX] = {0};
int top[MAX], depth[MAX], cnt = 1, dCnt = 1, n, m, root, mod, x, y, z;
int id[MAX] = {0}, rk[MAX] = {0}, init[MAX], tree[4 * MAX], lazy[4 * MAX] = {0};

inline int read() {
char e = getchar();
while (e < '0' || e > '9')e = getchar();
int s = 0;
while (e >= '0' && e <= '9')s = s * 10 + e - '0', e = getchar();
return s;
}

inline void add(int x, int y) {
edge[cnt].to = y, edge[cnt].next = head[x], head[x] = cnt++;
}

void DFS1(int x, int fa) {
int maxn = -1, r = -1;
for (int i = head[x]; i != -1; i = edge[i].next) {
if (edge[i].to != fa) {
father[edge[i].to] = x, depth[edge[i].to] = depth[x] + 1, DFS1(edge[i].to, x), size[x] += size[edge[i].to];
if (size[edge[i].to] > maxn)maxn = size[edge[i].to], r = edge[i].to;
}
}
son[x] = r, size[x]++;
}

void DFS2(int x, int t) {
id[x] = dCnt, rk[dCnt] = x, dCnt++, top[x] = t;
if (size[x] == 1)return;
if (son[x])DFS2(son[x], t);
for (int i = head[x]; i != -1; i = edge[i].next) {
if (edge[i].to != son[x] && edge[i].to != father[x])DFS2(edge[i].to, edge[i].to);
}
}

void make(int l, int r, int k) {
if (l == r) {
tree[k] = init[rk[l]];
return;
}
int mid = (l + r) >> 1;
make(l, mid, 2 * k), make(mid + 1, r, 2 * k + 1);
tree[k] = (tree[2 * k] + tree[2 * k + 1]) % mod;
}

inline void down(int k, int l, int r) {
int mid = (l + r) >> 1;
tree[2 * k] = (tree[2 * k] + (mid - l + 1) * lazy[k] % mod) % mod;
tree[2 * k + 1] = (tree[2 * k + 1] + (r - mid) * lazy[k] % mod) % mod;
lazy[2 * k] += lazy[k], lazy[2 * k + 1] += lazy[k], lazy[2 * k] %= mod, lazy[2 * k + 1] %= mod, lazy[k] = 0;
}

int query(int x, int y, int l, int r, int k) {
int mid = (l + r) >> 1;
if (l < r && lazy[k] != 0)down(k, l, r);
if (x == l && y == r)return tree[k];
if (mid + 1 <= x)return query(x, y, mid + 1, r, 2 * k + 1);
if (mid >= y)return query(x, y, l, mid, 2 * k);
return (query(x, mid, l, mid, 2 * k) + query(mid + 1, y, mid + 1, r, 2 * k + 1)) % mod;
}

void change(int x, int y, int l, int r, int s, int k) {
int mid = (l + r) >> 1;
if (l < r && lazy[k] != 0)down(k, l, r);
if (x == l && y == r) {
tree[k] = (tree[k] + (r - l + 1) * s % mod) % mod;
lazy[k] += s;
return;
}
if (mid + 1 <= x)change(x, y, mid + 1, r, s, 2 * k + 1);
else if (mid >= y)change(x, y, l, mid, s, 2 * k);
else change(x, mid, l, mid, s, 2 * k), change(mid + 1, y, mid + 1, r, s, 2 * k + 1);
tree[k] = (tree[2 * k] + tree[2 * k + 1]) % mod;
}

inline void addL(int x, int y, int s) {
int tx = top[x], ty = top[y];
while (tx != ty) {
if (depth[tx] >= depth[ty]) {
change(id[tx], id[x], 1, n, s, 1);
x = father[tx], tx = top[x];
} else {
change(id[ty], id[y], 1, n, s, 1);
y = father[ty], ty = top[y];
}
}
if (depth[x] <= depth[y])change(id[x], id[y], 1, n, s, 1);
else change(id[y], id[x], 1, n, s, 1);
}

inline int queL(int x, int y) {
int tx = top[x], ty = top[y], ans = 0;
while (tx != ty) {
if (depth[tx] >= depth[ty]) {
ans += query(id[tx], id[x], 1, n, 1), ans %= mod;
x = father[tx], tx = top[x];
} else {
ans += query(id[ty], id[y], 1, n, 1), ans %= mod;
y = father[ty], ty = top[y];
}
}
if (depth[x] <= depth[y])return (ans + query(id[x], id[y], 1, n, 1)) % mod;
return (ans + query(id[y], id[x], 1, n, 1)) % mod;
}

int main() {
n = read(), m = read(), root = read(), mod = read();
for (int i = 1; i <= n; i++)init[i] = read(), head[i] = -1;
for (int i = 1; i < n; i++)x = read(), y = read(), add(x, y), add(y, x);
depth[root] = 1, DFS1(root, 0), DFS2(root, root), make(1, n, 1);
int cmd;
for (int i = 0; i < m; i++) {
cmd = read();
if (cmd == 1) {
x = read(), y = read(), z = read();
addL(x, y, z);
} else if (cmd == 2) {
x = read(), y = read();
cout << queL(x, y) << endl;
} else if (cmd == 3) {
x = read(), y = read();
change(id[x], id[x] + size[x] - 1, 1, n, y, 1);
} else {
x = read();
cout << query(id[x], id[x] + size[x] - 1, 1, n, 1) << endl;
}
}
return 0;
}

        后记:最短路径和也可以通过LCA来完成,做法是预处理每一个结点到根结点的前缀和,求出LCA后即可计算出。