点分治
/ / 阅读耗时 34 分钟 点分治是一类解决树上路径问题的算法,下面以判断树中是否存在长度为k的路径这一问题来介绍点分治算法。点分治是树分治算法的一部分,另一部分是边分治。
下面是一棵树(边权未标出):
对于一个路径,我们可以把它归结为两类:经过根节点和不经过根节点。对于不经过根节点的路径,总可以找到一个新的根节点使之经过根节点。比如图中EBF这一路径没有经过根节点A,但是它经过了子树的根节点B。因此所有路径都可以归结到第一种(也就是经过根节点),于是可以通过不同根结点来判别各种路径。这里体现了分治的思想。
下面来看看点分治的操作:
确定树根
我们一开始得到的树通常是无根树,即使是有根树也可以发现答案与树根的选取无关。不妨选取一个点为根使得树尽量平衡,递归次数尽量少。这个点显然是树的重心。重心可以通过一遍DFS求出。关于树的重心见这篇文章。
从树根开始进行分治
这里就是点分治的核心了,先来看看代码模板:1
2
3
4
5
6
7
8
9
10void divide(int x) {
work();//第一个work
vis[x] = 1;
for (int i = head[x]; i != -1; i = edge[i].next) {
if (vis[edge[i].to])continue;
work();//第二个work
findRoot(edge[i].to);
divide(root);
}
}
该函数中x为现在树的树根编号。vis[x]=1标记该点已经被使用,下面遍历出边。当前函数处理的是经过点x的路径,所以对于出点(也就是子树的树根),我们需要对其进行递归处理(也就是处理经过子树树根的路径),发现这与原树是一个性质的问题,因此可以递归处理:先找重心(这里就相当于对子树重新选根构造)再递归。
这样分治会不会重呢?也就是同一条路径被多次处理?答案当然是不会,因为我们的路径处理只是在以该点为根的子树上。
两个work函数做什么?这就是处理经过点x的路径的函数。此时可能会有疑问:按理说只需要开头有一个work函数处理就可以了,为什么还要对每一个子结点再work一次?后一个work函数并不是子树的递归处理,而是去重。work函数与问题的性质有关,需要自己设计。
先来讨论本问题中如何处理work函数。
本问题中我们需要找有没有长度为k的路径。那么在work函数中先求出所有点(包括本身,路径长当然是0)到根节点x的路径长,然后如果两个互异点的路径长之和等于k,那么就可以判定存在长度为k的路径。求长度需要$O(n)$,互异点枚举需要$O(n^2)$。效率太低如何解决?先在$O(nlogn)$复杂度下对路径长进行排序,然后枚举每一个点进行二分,可以降低时间复杂度为$O(nlogn)$。当然还有更好的解决方法,那就是双指针扫描,时间复杂度$O(n)$,关于该技巧见本文末。
这样就完了吗?当然不会,还有前面提到的去重操作!注意到上面给出的图中,当处理点A时,若E到A的距离和F到A的距离之和是k,我们会认为这是一条合法的路径,但事实上它有重边,并不合法,因此需要将这条非法边去除。在第一个work函数中肯定无法知晓哪些点对是合法的,哪些是非法的,因此只能通过容斥原理来将非法边去除,这就是第二个work所做的事情。
如何在第二个work中统计非法边个数呢?注意到非法点对一定在同一子树中,并且凡是来源于子树B的非法点对(x,y),它们都满足:
v(A,B)是A点与B点之间连边的权值。这样我们在B子树中找到距离和(这里的距离就是到B的距离了,需要重求)为k-2v(A,B)的点对数量,再在答案中减去就可以了。由于两个work拥有同样的功能(都是求符合条件的点对数),因此这两个work可以用一个函数去处理。
下面给出洛谷P3806点分治模板题的示例代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
using namespace std;
struct Edge {
int to, next, v;
} edge[N * 2];
int head[N], vis[N], cnt = 1, n, m, ans[105], a, b, c, root, minn, size[N], f[N], dict[N], cntDFS;
int query[105];
inline int read() {
char e = getchar();
int s = 0;
while (e < '0' || e > '9')e = getchar();
while (e >= '0' && e <= '9')s = s * 10 + e - '0', e = getchar();
return s;
}
inline void add(int x, int y, int z) {
edge[cnt].to = y, edge[cnt].v = z, edge[cnt].next = head[x], head[x] = cnt++;
}
void DFS1(int x, int fa, int num) {
size[x] = 1, f[x] = 0;
for (int i = head[x]; i != -1; i = edge[i].next) {
if (vis[edge[i].to] || edge[i].to == fa)continue;
DFS1(edge[i].to, x, num), size[x] += size[edge[i].to], f[x] = max(f[x], size[edge[i].to]);
}
f[x] = max(f[x], num - f[x]);
if (f[x] < minn)minn = f[x], root = x;
}
inline void findRoot(int x, int num) {
minn = 0x7fffffff;
DFS1(x, 0, num);
}
void DFS2(int x, int fa, int v) {
dict[cntDFS++] = v;
for (int i = head[x]; i != -1; i = edge[i].next) {
if (vis[edge[i].to] || edge[i].to == fa)continue;
DFS2(edge[i].to, x, v + edge[i].v);
}
}
int find1(int l, int r, int p) {//(,]
if (l + 1 > r)return -1;
if (r == l + 1) {
if (dict[r] == p)return r;
return -1;
}
int mid = (l + r) >> 1;
if (dict[mid] >= p)return find1(l, mid, p);
return find1(mid, r, p);
}
int find2(int l, int r, int p) {//[,)
if (l + 1 > r)return -1;
if (r == l + 1) {
if (dict[l] == p)return l;
return -1;
}
int mid = (l + r) >> 1;
if (dict[mid] <= p)return find2(mid, r, p);
return find2(l, mid, p);
}
inline void work(int x, int p, int w) {
cntDFS = 1, DFS2(x, 0, 0), sort(dict + 1, dict + cntDFS);
for (int i = 1; i <= m; i++) {
for (int j = 1; j < cntDFS; j++) {
int l = find1(0, j - 1, query[i] - p - dict[j]), r = find2(1, j, query[i] - p - dict[j]);
if (l != -1 && r != -1)ans[i] += (r - l + 1) * w;
}
}
}
void divide(int x) {
work(x, 0, 1), vis[x] = 1;
for (int i = head[x]; i != -1; i = edge[i].next) {
if (vis[edge[i].to])continue;
work(edge[i].to, 2 * edge[i].v, -1);
findRoot(edge[i].to, size[edge[i].to]), divide(root);
}
}
int main() {
n = read(), m = read();
for (int i = 1; i <= n; i++)head[i] = -1;
for (int i = 1; i < n; i++)a = read(), b = read(), c = read(), add(a, b, c), add(b, a, c);
for (int i = 1; i <= m; i++)query[i] = read();
findRoot(1, n), divide(root);
for (int i = 1; i <= m; i++) {
if (ans[i] > 0)cout << "AYE" << endl;
else cout << "NAY" << endl;
}
return 0;
}
其它重要的事
点分治是一个很有用的算法,它将很不好优化的树上点对问题$O(n^2)$的复杂度直接降到$O(nlogn)$。这是如何做到的?
在向子树分治时,我们重新选取了子树的根,将其定为重心,这样可以保证每一次分治后,问题规模减小一半。如果每一次分治处理的时间复杂度为$O(n)$,那么最终复杂度当然是$O(nlogn)$。也正是因为这个原因,如果程序中求重心的部分出现了问题,那么时间复杂度很可能会退化到$O(n^2)$。
在上文点分治中,曾提到一个很重要的divide函数。这个函数在整个过程中会调用$n$次,但是每一次分治后,这个函数处理的问题规模减半,故即使其中再执行$O(n)$的操作,也能够保证时间复杂度。这启示我们:在某一次分治的操作中,只能在该分治子树上进行$O(n)$或者$O(nlogn)$的操作。否则,复杂度会很快退化。这里提出这个问题,一来可以帮助我们选取合适的算法来进行处理,二来可以避免一些玄学TLE。
有些时候,我们可能会在divide函数中用memset来清空一些东西。要注意,这是一个非常错误的做法!memset函数会对所有值进行均等的处理,即使问题规模已经减半,那些多余的部分也被memset处理。这样做不会有什么结果上的影响,但是却导致了时间复杂度退化为$O(n^2)$。事实证明,在divide中使用memset会使复杂度严重退化,极容易导致TLE。
到这里就可以看出点分治降复杂度的本质了:通过合理地划分子树来使问题规模不断减半。
前面提到了双指针扫描法,这其实是一种技巧。
给定一串数,如何找是否存在某个数对之和为k?
先排序然后枚举二分在$O(nlogn)$下解决当然可以,但我们也可以这样做:
先取两个指针l、r分别指向第一个元素和最后一个元素,然后判断:
- 若两个指针所指元素之和小于k,l++
- 若两个指针所指元素之和大于k,r—
- 若两个指针所指元素之和等于k,找到答案
这种方法对于是否存在的问题是很高效的,但计数并不易做到。快速排序中也有双指针的应用。
动态点分治
来看看动态点分治,例题引入:这里。
这题就是给一棵树,查询点权在某一个区间中的所有点到一个特定点的路径边权和。暴力的话,强行DFS,复杂度$O(nq)$。
对于更换点的路径和问题,可以用一个优雅的树链剖分算法解决,具体来说,我们先假定1为根,然后求出所有点到它的距离$dis[x]$,显然所有点到1的路径和为:
如果更换根结点为$r$,注意到有这样一个性质:
那么路径和就是:
这里的$lca$当然都是$r$的祖先结点,但是每一个占多少不好求。一个好方法是从每一个点开始走向1号结点,记录点被覆盖的次数,就是它在求和时的比重。根据这个式子,前面两项预处理完后$O(1)$就能求,后面那一项树链剖分配合线段树维护。
如果有了点权的限制,把线段树换成可持久化线段树就行,这样本题就得以解决。
树剖的做法确实巧妙,但是对于本题来说,占空间量是惊人的,很容易MLE,考虑别的做法,那就是动态点分治。
动态点分治其实就是在点分治的基础上,多了一步操作,那就是记录分治下去的子树的重心为当前的分治中心,从而建立一棵新树。这棵新树也被称为点分树。
在点分树中,所有子结点都是下面那一个树块的重心,所以点分树的高度是$O(logn)$级别的,这给予点分树很好的性质,可以完成一些很暴力的操作。比如,给点分树上每一个结点开一个vector存它的子结点的信息,复杂度仅为$O(nlogn)$,完全可以做到。
不过,点分树已经破坏了原树的结构关系,它往往需要一些空间来记录一些信息。首先这里有一个性质:点分树上的两个点的lca必是其在原树路径上的某一个结点。这启示我们可以这样求两个点的距离:
$dis(a)$是$a$到$a$与$b$在点分树上的lca在对应原树上的路径权和。
那么建出点分树,暴力地用vector存每一个点其对应树块上的所有结点到它的距离,并来一步sort使它们按点权有序。其次,还要开一个vector存这个结点到它在点分树上所有祖先的距离。
当处理查询时,找到对应的中心结点,暴力向上跳父结点,二分出对应的点,然后求距离和。这里的一个问题是父结点会将子结点已经计入的部分重新计入一次,这样的话,给每一个结点再开一个vector存其树块上的点到父结点的距离,在操作的时候减去这个值就好了。
最终总时间复杂度为$O(qlog^2n)$,比较考验码力。
1 |
|
另一个例题。
每一次修改,指定一个点$x$和数$w$,对于其余的点$i$,增加其点权$w-dis(i,x)$。查询单点权值。
对于这个题,我们仍然是建出点分树,修改时,找到$x$,在点分树上向上跳。在这里需要对其中的点进行两次修改。当从$x$跳到祖先$f$时,根据点分树的性质易知在树$f$上的节点$i$到$x$的距离为$dis(x,f)+dis(f,i)$(注意$i$不能在$f$到$x$方向上的子树上)。 这样向上跳时,给$f$增加$dis(x,f)$的贡献。这里需要注意不能重复计算,为了做到这一点,我们利用差分思想,当在$f$增加$dis(x,f)$,要在其对应的儿子节点上减去相应的贡献,表示这一部分重复计算了。
同样地,我们还需要维护另一个量,即上文中提到的$dis(i,f)$。维护一个$cnt$变量,将$x$到点分树树根上的点的$cnt$全部加一,然后利用$cnt$的差分来计算贡献。
至于$dis$可以使用倍增法求$LCA$来求。复杂度$O(nlog^2n)$。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
using namespace std;
inline int read()
{
char e = getchar();
int s = 0;
while (e < '-') e = getchar();
while (e > '-') s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return s;
}
struct Edge
{
int next, to;
} edge[N << 1];
int n, Q, head[N], c1, sz[N], num, root, vis[N], maxn, fa[N], val[N], cnt[N], gr[N][20], dep[N];
void init_DFS(int x, int f)
{
for (int i = head[x]; i; i = edge[i].next) {
if (edge[i].to == f) continue;
dep[edge[i].to] = dep[x] + 1, gr[edge[i].to][0] = x;
for (int j = 1; j < 20; j++) gr[edge[i].to][j] = gr[gr[edge[i].to][j - 1]][j - 1];
init_DFS(edge[i].to, x);
}
}
inline int lca(int x, int y)
{
if (dep[x] < dep[y]) swap(x, y);
for (int i = 19; i >= 0; i--) {
if (dep[gr[x][i]] >= dep[y]) x = gr[x][i];
}
if (x == y) return x;
for (int i = 19; i >= 0; i--) {
if (gr[x][i] != gr[y][i]) x = gr[x][i], y = gr[y][i];
}
return gr[x][0];
}
inline int getDis(int x, int y)
{
return dep[x] + dep[y] - 2 * dep[lca(x, y)];
}
long long ans;
inline void add1(int x, int y)
{
edge[c1].next = head[x], edge[c1].to = y, head[x] = c1++;
}
void DFS(int x, int f)
{
int s = 0;
sz[x] = 1;
for (int i = head[x]; i; i = edge[i].next) {
if (!vis[edge[i].to] && edge[i].to != f) {
DFS(edge[i].to, x), s = max(s, sz[edge[i].to]), sz[x] += sz[edge[i].to];
}
}
s = max(s, num - sz[x]);
if (maxn > s) maxn = s, root = x;
}
inline void findRoot(int x, int nm)
{
maxn = 0x7fffffff, num = nm, root = 0, DFS(x, 0);
}
void DFS2(int x, int f)
{
sz[x] = 1;
for (int i = head[x]; i; i = edge[i].next) {
if (!vis[edge[i].to] && edge[i].to != f) DFS2(edge[i].to, x), sz[x] += sz[edge[i].to];
}
}
void divide(int x)
{
vis[x] = 1, DFS2(x, 0);
for (int i = head[x]; i; i = edge[i].next) {
if (!vis[edge[i].to]) findRoot(edge[i].to, sz[edge[i].to]), fa[root] = x, divide(root);
}
}
inline void solve(int x)
{
int last = -1, now = x;
while (now) {
val[now] += getDis(x, now), ++cnt[now];
if (last != -1) val[last] -= getDis(now, x);
last = now, now = fa[now];
}
}
inline int getAns(int x)
{
if (x == 0) return 0;
int now = x, last = -1, ans = 0;
while (now) {
ans += val[now] + (last == -1 ? cnt[now] : cnt[now] - cnt[last]) * getDis(now, x);
last = now, now = fa[now];
}
return ans;
}
int sumW = 0, nmb[N];
signed main()
{
int TTT = read();
while (TTT--) {
n = read(), Q = read(), c1 = 1, maxn = num = root = 0;
for (int i = 1; i <= n; i++) dep[i] = sz[i] = fa[i] = head[i] = nmb[i] = cnt[i] = vis[i] = val[i] = 0;
memset(gr, 0, sizeof(gr)), sumW = 0;
for (int i = 1, x, y; i < n; i++) x = read(), y = read(), add1(x, y), add1(y, x);
dep[1] = 1, init_DFS(1, 0);
findRoot(1, n), divide(root);
int opt;
while (Q--) {
opt = read();
int x, w;
if (opt == 1) {
x = read(), w = read(), sumW += w, solve(x);
} else if (opt == 2) {
x = read(), nmb[x] = sumW - getAns(x);
} else if (opt == 3) {
x = read(), printf("%d\n", sumW - getAns(x) - nmb[x]);
}
}
}
return 0;
}