初探CDQ分治与整体二分
/ / 阅读耗时 37 分钟 一种神奇的分治思想,可以顶替高级数据结构,不过需要离线。
CDQ分治
CDQ分治的本质思想是将问题分为两个子问题,并用其中一个子问题更新另一个子问题的答案。
二维偏序
首先需要先引入二维偏序问题。
给定$n$个二元组$(a_i,b_i)$,求对于每一个组,有多少二元组$(a_j,b_j)$满足条件:
当然可以每一次都扫一遍,复杂度为$O(n^2)$,有没有更快的方法?这里的CDQ分治便可以在$O(nlogn)$下解决这个问题。
回顾一下归并排序求逆序对的步骤。那里得到逆序本质上就是给定了一些二元组,只不过第一维是这个数在数组中的下标,然后求满足条件:
这样的$(a_j,b_j)$数量,把它们都加起来就是逆序对数量。
现在在看二维偏序,其实就是逆序对的拓展情况而已,它当然也能用分治方法解决,这就是CDQ分治。
我们首先先按第一维a排个序,然后再用归并排序排第二维,同时统计答案。代码计较简单,就不放了。
这里主要说一个树状数组的做法,这个方法很常用。一开始仍然是按照第一维排序,然后,用一个树状数组维护第二维的信息,用求前缀和的方式统计答案。模板题:HDU1541 Stars。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
using namespace std;
struct P {
int x, y;
bool operator<(P s) {
return x < s.x;
}
} p[150005];
int tree[330100], n, now = 1, ans[150005];
inline void add(int x) {
for (; x <= 330001; x += (x & -x))++tree[x];
}
inline int sum(int x) {
int s = 0;
for (; x >= 1; x -= (x & -x))s += tree[x];
return s;
}
int main() {
while (scanf("%d", &n) != EOF) {
memset(tree, 0, sizeof(tree));
memset(ans, 0, sizeof(ans));
for (int i = 1; i <= n; i++)scanf("%d%d", &p[i].x, &p[i].y), ++p[i].y;
sort(p + 1, p + n + 1), now = 1;
for (int i = 1; i <= n; i++) {
while (now <= n && p[now].x <= p[i].x)add(p[now].y), ++now;
++ans[sum(p[i].y) - 1];
}
for (int i = 1; i <= n; i++)printf("%d\n", ans[i - 1]);
}
return 0;
}
二维偏序能在很多问题上得到应用,比如说这里有这样的一个例子。给定很多区间,每次有一个区间查询,求包含在这个查询区间中的给定区间有多少个。这个问题可以转化为二维偏序问题,离线用树状数组可解。
二维偏序也可以在线做,需要用到多链扩展的主席树。
三维偏序
有了二维偏序的基础,三维偏序自然就可以带出,同样地,三维偏序也可以用CDQ分治去搞,其实任意维数理论上都可以用CDQ。模板题。
方法还是比较容易的。首先按照第一维排序,从而保证了第一维是有序的,进而按照二维偏序的方法对第二维跑归并排序,二维偏序和三维偏序仅在这里统计答案时有一些不同。在统计答案时,由于有第三维的存在,需要用一个树状数组来统计。
假设归并排序时左右指针分别为$i$和$j$,那么有下面两种情况:
- $b_i\leq b_j$此时说明$i$这一位可以对$b_j$后的元素产生贡献,将$c_i$加入树状数组。
- $b_i>b_j$说明此时所有满足第二维不大于$j$这一位的元素均已加入树状数组,直接求前缀和统计答案。
看起来比较容易,但还没有完成。由于当两个三元组相同时,它们也是互为偏序的,这里的分治统计显然无法完全计入这种情况。因此在执行算法之前,需要先去重并统计数量,然后再求偏序。
复杂度$O(nlognlogk)$。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
using namespace std;
struct N {
int a, b, c, cnt, ans;
} nd1[100005], nd2[100005], tmp[100005];
int tree[200005], n, k, m, ans[100005];
bool cmp(N a, N b) {
if (a.a != b.a)return a.a < b.a;
if (a.b != b.b)return a.b < b.b;
return a.c < b.c;
}
inline void add(int x, int p) {
for (; x <= k; x += x & -x)tree[x] += p;
}
inline int sum(int x) {
int s = 0;
for (; x >= 1; x -= x & -x)s += tree[x];
return s;
}
void CDQ(int l, int r) {
if (l >= r)return;
int mid = (l + r) >> 1;
CDQ(l, mid), CDQ(mid + 1, r);
int a = l, b = mid + 1, c = l;
while (a <= mid && b <= r) {
if (nd2[a].b <= nd2[b].b)add(nd2[a].c, nd2[a].cnt), tmp[c++] = nd2[a++];
else nd2[b].ans += sum(nd2[b].c), tmp[c++] = nd2[b++];
}
if (b == r + 1)while (a <= mid)tmp[c++] = nd2[a++];
else while (b <= r)nd2[b].ans += sum(nd2[b].c), tmp[c++] = nd2[b++];
a = l, b = mid + 1;
while (a <= mid && b <= r) {
if (nd2[a].b <= nd2[b].b)add(nd2[a].c, -nd2[a].cnt), a++;
else b++;
}
for (int i = l; i <= r; i++)nd2[i] = tmp[i];
}
int main() {
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; i++)scanf("%d%d%d", &nd1[i].a, &nd1[i].b, &nd1[i].c);
sort(nd1 + 1, nd1 + n + 1, cmp);
for (int i = 1, num = 0; i <= n; i++, m = num) {
if (i == 1 || nd1[i].a != nd1[i - 1].a || nd1[i].b != nd1[i - 1].b || nd1[i].c != nd1[i - 1].c) {
++num, ++nd2[num].cnt;
nd2[num].a = nd1[i].a;
nd2[num].b = nd1[i].b;
nd2[num].c = nd1[i].c;
} else ++nd2[num].cnt;
}
CDQ(1, m);
for (int i = 1; i <= m; i++)ans[nd2[i].ans + nd2[i].cnt - 1] += nd2[i].cnt;
for (int i = 0; i < n; i++)printf("%d\n", ans[i]);
return 0;
}
整体二分
下面是一种神奇的离线二分方法,在离线时可以吊锤高级数据结构。
先想这么一个问题:如何求静态区间第k小?
很明显,这可以用主席树直接切掉。但我们就是要没事找事,用离线的思维解决静态区间第k小。
如果只有一次查询,我们当然可以二分答案。对于一个给定的答案,遍历一遍区间,然后记录不大于这个答案的数的数量,判断与k的大小关系,接下来再二分。这种做法的时间复杂度为$O(nlogn)$,其实和排一个序的复杂度一致。对于$q$次询问,复杂度达到$O(qnlogn)$,是无法承受的。
为什么上面的二分方法效率过低?注意到每一次二分答案时,都只对一次查询有效,并且判断答案可行性时也是如此。这样每一次二分和判定都只对一次查询有效,效率自然低下,如果二分和判定能对很多次查询均有效,那么效率就能大大提升,这就是整体二分的思想。
考虑到程序主要在判定过程中耗时,那么可以考虑这样一种做法:
将所有的查询放到一个数组中,然后二分一个答案,遍历原区间数组,用树状数组标记不大于这个答案的位置。下一步遍历查询,找出每一个查询对应区间中被标记位置的数量,从而判断这个答案对该询问是过大还是过小。最后将两种类型的查询分成两类,继续分治下去,直到找到答案。
这种做法保证了一次二分以及判定对所有的查询都有效,并且使答案域减半,而且判定时只多消耗了树状数组的一个$log$的复杂度,看起来可行。但是,仔细想一想就会发现,二分下去的时候,虽然答案域的规模减半,保证了递归深度为$O(logn)$级别,但是每一次都需要遍历一整个原区间数组,这个问题的规模丝毫没有下降,最终的时间复杂度将达到$O(n^2logn)$。
现在问题的关键是降低遍历原区间数组的问题规模。进一步思考可以发现,在答案域不断减半时,根本无需遍历那些不在答案域中的元素,这样就可以使遍历区间数组的问题规模随着答案域降低,时间复杂度达到理想的$O(nlog^2n)$。
很显然,在现有的处理下,不可能只遍历那些在答案域中的元素,还需要一些额外的处理。
我们将原数组的值看成$n$个赋值操作,将它们对应的下标和值也放到查询的数组中,和查询一起二分。每当二分出一个答案并且将查询分组时,我们将对应的赋值操作也进行分组。在判定时,不再遍历原数组,而是直接在查询数组中找到相应的操作,然后更新树状数组,这样便可以巧妙地降低判定时遍历数组的问题规模。
附上模板题的整体二分代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
using namespace std;
inline int read() {
char e = getchar();
int s = 0, g = 0;
while (e < '-')e = getchar();
if (e == '-')g = 1, e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return g ? -s : s;
}
struct Q {
int l, r, k, id, ans, rk;
} q[N << 1], tmp1[N << 1], tmp2[N << 1];
int n, m, ans[N], tmp[N], sv, tree[N], cnt;
inline void add(int x, int y) {
for (; x <= n; x += x & -x)tree[x] += y;
}
inline int sum(int x) {
int s = 0;
for (; x >= 1; x -= x & -x)s += tree[x];
return s;
}
void solve(int l, int r, int ql, int qr) {
if (l == r) {
for (int i = ql; i <= qr; i++)if (q[i].rk)q[i].ans = l;
return;
} else if (ql > qr)return;
int mid = (l + r) >> 1, a = 1, b = 1, c = ql;
for (int i = ql; i <= qr; i++) {
if (q[i].rk == 0) {
if (q[i].l <= mid)add(q[i].r, 1), tmp1[a++] = q[i];
else tmp2[b++] = q[i];
}
}
for (int i = ql; i <= qr; i++) {
if (q[i].rk) {
if (sum(q[i].r) - sum(q[i].l - 1) < q[i].k)q[i].k -= sum(q[i].r) - sum(q[i].l - 1), tmp2[b++] = q[i];
else tmp1[a++] = q[i];
}
}
for (int i = ql; i <= qr; i++) {
if (q[i].rk == 0)if (q[i].l <= mid)add(q[i].r, -1);
}
for (int i = 1; i < a; i++)q[c++] = tmp1[i];
for (int i = 1; i < b; i++)q[c++] = tmp2[i];
solve(l, mid, ql, ql + a - 2), solve(mid + 1, r, ql + a - 1, qr);
}
int main() {
n = read(), m = read();
for (int i = 1; i <= n; i++)q[++cnt].l = tmp[i] = read(), q[cnt].r = i;
sort(tmp + 1, tmp + n + 1), sv = unique(tmp + 1, tmp + n + 1) - tmp - 1;
for (int i = 1; i <= n; i++)q[i].l = lower_bound(tmp + 1, tmp + sv + 1, q[i].l) - tmp;
for (int i = 1; i <= m; i++) {
q[++cnt].l = read(), q[cnt].r = read(), q[cnt].k = read(), q[cnt].id = i, q[cnt].rk = 1;
}
solve(1, n, 1, cnt);
for (int i = 1; i <= cnt; i++)if (q[i].rk)ans[q[i].id] = tmp[q[i].ans];
for (int i = 1; i <= m; i++)printf("%d\n", ans[i]);
return 0;
}
例题
一些例题,可能是CDQ分治也可能是整体二分。不涉及题面,点击标题可跳转。
[POI2011]MET-Meteors
经典的整体二分题目,强烈推荐一做。
首先暴力找肯定是不行的,可以考虑整体二分,它的思路与求区间第k小基本相同。唯一需要注意的就是一个国家可能有多个太空站,并且有的国家可能没有。
具体思路就是二分一个答案,然后根据这个答案将不同的国家分成两部分,然后分治下去,复杂度为$O((k+n)logmlogk)$。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
using namespace std;
inline int read() {
char e = getchar();
int s = 0;
while (e < '-')e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return s;
}
struct Q {
int l, r, a, p, time;
} q[600005], tp1[600005], tp2[600005];
int n, m, k, ans[300005], cnt, obj[600005];
long long tree[300005], now[300005];
inline void add(int x, int y) {
for (; x <= m; x += x & -x)tree[x] += y;
}
inline long long sum(int x) {
long long s = 0;
for (; x >= 1; x -= x & -x)s += tree[x];
return s;
}
void solve(int l, int r, int ql, int qr) {
if (ql > qr)return;
if (l == r) {
for (int i = ql; i <= qr; i++)if (!q[i].p)ans[q[i].r] = l;
return;
}
int mid = (l + r) >> 1, c1 = 0, c2 = 0;
for (int i = ql; i <= qr; i++) {
if (q[i].p && q[i].time <= mid) {
if (q[i].l <= q[i].r)add(q[i].l, q[i].a), add(q[i].r + 1, -q[i].a);
else add(q[i].l, q[i].a), add(1, q[i].a), add(q[i].r + 1, -q[i].a);
tp1[++c1] = q[i];
} else if (q[i].p)tp2[++c2] = q[i];
}
for (int i = ql; i <= qr; i++)if (!q[i].p && now[q[i].r] <= obj[q[i].r])now[q[i].r] += sum(q[i].l);
for (int i = ql; i <= qr; i++) {
if (!q[i].p) {
if (now[q[i].r] >= obj[q[i].r] && obj[q[i].r] > 0)tp1[++c1] = q[i];
else {
tp2[++c2] = q[i];
if (obj[q[i].r] > 0)obj[q[i].r] -= now[q[i].r], obj[q[i].r] = -obj[q[i].r];
}
}
}
for (int i = ql; i <= qr; i++)
if (!q[i].p) {
now[q[i].r] = 0;
if (obj[q[i].r] < 0)obj[q[i].r] = -obj[q[i].r];
}
for (int i = ql; i <= qr; i++) {
if (q[i].p && q[i].time <= mid) {
if (q[i].l <= q[i].r)add(q[i].l, -q[i].a), add(q[i].r + 1, q[i].a);
else add(q[i].l, -q[i].a), add(1, -q[i].a), add(q[i].r + 1, q[i].a);
}
}
for (int i = 1, v = ql; i <= c1; i++, v++)q[v] = tp1[i];
for (int i = 1, v = ql + c1; i <= c2; i++, v++)q[v] = tp2[i];
solve(l, mid, ql, ql + c1 - 1), solve(mid + 1, r, ql + c1, qr);
}
int main() {
n = read(), m = read();
for (int i = 1; i <= m; i++)q[++cnt].l = i, q[cnt].r = read();
for (int i = 1; i <= n; i++)obj[i] = read();
k = read();
for (int i = 1; i <= k; i++) {
q[++cnt].l = read(), q[cnt].r = read(), q[cnt].a = read(), q[cnt].p = 1, q[cnt].time = i;
}
solve(1, k + 1, 1, cnt);
for (int i = 1; i <= n; i++) {
if (ans[i] == k + 1 || ans[i] == 0)printf("NIE\n");
else printf("%d\n", ans[i]);
}
return 0;
}
[ZJOI2013]k大数查询
动态地在区间中插数,求区间第k大,这题跑整体二分是很好的做法,吊锤树套树,只不过要离线。
和上文中的区间第k大没有什么大的差别,只不过需要区间修改和求区间和。用线段树太复杂了,可以用两个树状数组来达到同样的目的。
1 |
|
[HNOI2015]接水果
最麻烦的显然是树链的包含关系,这个不好描述。
对整棵树进行一遍树链剖分,得到DFS序,考虑一下如何用DFS序来描述树链之间的包含关系。需要分两种情况来讨论,假设链为$(u,v)$且有$DFN(u)<DFN(v)$,包含它的链是$(a,b),DFN(a)<DFN(b)$。子树$x$的结点数目用$size(x)$表示。
- 第一种情况:$lca(u,v)\not=u$且$lca(u,v)\not=v$
只需要$a$和$b$分别在两棵子树中即可,那么有: - 第二种情况:$lca(u,v)=u$
设$s$为$(u,v)$上不含端点的第一个结点(这个可以用树剖优雅地求出来),则有: 或者:
可以发现上面的不等式实质上是把一个点$(DFN(u),DFN(v))$放到了一个矩形中,如果一个矩形包含了这个点,那么这个矩形就可以对该点产生贡献。
给每一个矩形赋上权值,现在就是要求覆盖某一个点的所有矩形中,权值第$k$小的那一个。那这样问题变得简单了,我们只需要将矩形化成扫描线,进行一步整体二分,然后用树状数组配合差分维护即可。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
using namespace std;
inline int read() {
char e = getchar();
int s = 0;
while (e < '-')e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return s;
}
struct Edge {
int next, to;
} e[N << 1];
struct Q {
int x, down, up, v, rk;
bool operator<(Q q) {
return x < q.x;
}
} qu[N << 3], tp1[N << 3], tp2[N << 3], tp3[N << 3], tp4[N << 3];
int n, p, q, head[N], cnt = 1, top[N], dfn[N], DFN, son[N], sz[N], dep[N], fa[N], qt, ans[N];
int tree[N], num[N];
inline void add(int x, int y) {
e[cnt].to = y, e[cnt].next = head[x], head[x] = cnt++;
}
inline void addP(int a, int b, int c, int d, int v) {
qu[++qt].x = a, qu[qt].down = b, qu[qt].up = d, qu[qt].v = v, qu[qt].rk = 1;
qu[++qt].x = c, qu[qt].down = b, qu[qt].up = d, qu[qt].v = v, qu[qt].rk = -1;
}
inline void addQ(int x, int y, int k, int rk) {
qu[++qt].x = x, qu[qt].down = y, qu[qt].up = k, qu[qt].v = -rk;
}
inline void ins(int x, int y) {
for (; x <= n; x += x & -x)tree[x] += y;
}
inline int sum(int x) {
int s = 0;
for (; x >= 1; x -= x & -x)s += tree[x];
return s;
}
void DFS(int x) {
int f = 0;
sz[x] = 1;
for (int i = head[x]; i; i = e[i].next) {
if (!sz[e[i].to]) {
fa[e[i].to] = x, dep[e[i].to] = dep[x] + 1, DFS(e[i].to);
if (sz[e[i].to] > f)f = sz[e[i].to], son[x] = e[i].to;
sz[x] += sz[e[i].to];
}
}
}
void DFS(int x, int tp) {
top[x] = tp, dfn[x] = ++DFN;
if (son[x])DFS(son[x], tp);
else return;
for (int i = head[x]; i; i = e[i].next) {
if (!dfn[e[i].to])DFS(e[i].to, e[i].to);
}
}
inline int LCA(int x, int y) {
while (top[x] != top[y]) {
if (dep[top[x]] > dep[top[y]])x = fa[top[x]];
else y = fa[top[y]];
}
return dep[x] > dep[y] ? y : x;
}
inline int first(int x, int y) {
while (top[x] != top[y]) {
if (dep[x] > dep[y])swap(x, y);
if (fa[top[y]] == x)return top[y];
y = fa[top[y]];
}
if (dep[x] > dep[y])swap(x, y);
return son[x];
}
void solve(int l, int r, int ql, int qr) {
if (ql > qr)return;
if (l == r) {
for (int i = ql; i <= qr; i++)if (qu[i].v < 0)ans[-qu[i].v] = l;
return;
}
int mid = (l + r) >> 1, t1 = 0, t2 = 0, a = 1, t3 = 0, t4 = 0;
for (int i = ql; i <= qr; i++) {
if (qu[i].v < 0)tp1[++t1] = qu[i];
else tp2[++t2] = qu[i];
}
for (int i = 1; i <= t2;) {
while (a <= t1 && (i > t2 || tp1[a].x < tp2[i].x))num[-tp1[a].v] = sum(tp1[a].down), ++a;
do {
if (tp2[i].v <= mid)ins(tp2[i].down, tp2[i].rk), ins(tp2[i].up + 1, -tp2[i].rk);
++i;
} while (i <= t2 && tp2[i].x == tp2[i - 1].x);
}
for (int i = ql; i <= qr; i++) {
if (qu[i].v < 0 && num[-qu[i].v] >= qu[i].up)tp3[++t3] = qu[i];
else if (qu[i].v < 0 && num[-qu[i].v] < qu[i].up)qu[i].up -= num[-qu[i].v], tp4[++t4] = qu[i];
}
a = 1;
for (int i = 1; i <= t2;) {
while (a <= t1 && (i > t2 || tp1[a].x < tp2[i].x))num[-tp1[a].v] = 0, ++a;
do {
if (tp2[i].v <= mid)ins(tp2[i].down, -tp2[i].rk), ins(tp2[i].up + 1, tp2[i].rk);
++i;
} while (i <= t2 && tp2[i].x == tp2[i - 1].x);
}
for (int i = ql; i <= qr; i++) {
if (qu[i].v >= 0 && qu[i].v <= mid)tp3[++t3] = qu[i];
else if (qu[i].v >= 0 && qu[i].v > mid)tp4[++t4] = qu[i];
}
for (int i = 1, j = ql; i <= t3; i++, j++)qu[j] = tp3[i];
for (int i = 1, j = ql + t3; i <= t4; i++, j++)qu[j] = tp4[i];
solve(l, mid, ql, ql + t3 - 1), solve(mid + 1, r, ql + t3, qr);
}
int main() {
n = read(), p = read(), q = read();
for (int i = 1, x, y; i < n; i++)x = read(), y = read(), add(x, y), add(y, x);
dep[1] = 1, DFS(1), DFS(1, 1);
for (int i = 1, a, b, c, d; i <= p; i++) {
a = read(), b = read(), c = read();
if (dfn[a] > dfn[b])swap(a, b);
if (LCA(a, b) == a)d = first(a, b), addP(1, L(b), L(d), R(b), c), addP(L(b), R(d) + 1, R(b) + 1, n, c);
else addP(L(a), L(b), R(a) + 1, R(b), c);
}
int sv = qt + 1;
for (int i = 1, a, b, k; i <= q; i++) {
a = read(), b = read(), k = read();
if (dfn[a] > dfn[b])swap(a, b);
addQ(L(a), L(b), k, i);
}
sort(qu + 1, qu + sv), sort(qu + sv, qu + qt + 1), solve(0, (1 << 30), 1, qt);
for (int i = 1; i <= q; i++)printf("%d\n", ans[i]);
return 0;
}