树套树
/ / 阅读耗时 12 分钟 本文介绍很基本的树套树,主要是怕自己忘了这个东西。
树套树(Tree in Tree),其实就是对于一棵树,其每一个结点也是树,由此形成的数据结构。树套树可以来维护区间、多维的信息。
本文只涉及区间查询第k小,查询前驱、后继以及单点修改所需要的树套树方法:线段树套平衡树。这是一类很基本的树套树方法。
线段树维护区间信息,而平衡树维护对应区间上的有序序列。两者结合可以比较容易地解决区间上的查询问题。模板题:戳这里。
下面探讨每一个操作怎么去做:
- 操作一:找到每一个分区间中严格小于k的数的数目,然后输出它们的数量+1。复杂度$O(log^2n)$。
- 操作二:找排名对应的数。由于分区间在这个问题上没有可加性,可以考虑二分的操作。先二分某一个值,用操作一的方法找到其对应的排名,与目标排名进行比较即可。复杂度$O(log^3n)$。
- 操作三:找到包含这个数的所有平衡树,删掉原数,加入新数。
- 操作四:找到所有分区间中的前驱,取最大的一个。
- 操作五:找到所有分区间中的后继,取最小的一个。
操作三到五的时间复杂度全为$O(log^2n)$。
平衡树有很多选择,有splay、treap、替罪羊树等等,这里选择splay树。另外树套树代码是真的长
下面给出模板题代码,常数大一点,吸个氧能过。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
using namespace std;
inline int read() {//读优
char e = getchar();
int s = 0;
while (e < '-')e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return s;
}
int n, m, op[N];
struct Node {
int f, ch[2], v, num, s;
Node() {
s = f = ch[0] = ch[1] = v = num = 0;
}
} nodes[N * 80];
struct Splay {//平衡树定义
static int newNode(int v) {//相当于内存池,分配新结点
static int cnt = 1;
nodes[cnt].v = v, nodes[cnt].s = nodes[cnt].num = 1;
return cnt++;
}
int root = 0;
static inline int identify(int x) {
return nodes[nodes[x].f].ch[1] == x;
}
static inline void update(int x) {
if (x)nodes[x].num = nodes[nodes[x].ch[0]].num + nodes[nodes[x].ch[1]].num + nodes[x].s;
}
static inline void change(int x, int y, int w) {
nodes[x].ch[w] = y, nodes[y].f = x;
}
inline void rotate(int x) {
if (x == root)return;
int f = nodes[x].f, g = nodes[f].f, i = identify(x), j = identify(f);
change(f, nodes[x].ch[i ^ 1], i), change(x, f, i ^ 1), change(g, x, j);
update(f), update(x);
}
inline void splay(int at, int to = 0) {
while (nodes[at].f != to) {
int f = nodes[at].f, g = nodes[f].f;
if (g != to) {
if (identify(at) == identify(f))rotate(f);
else rotate(at);
}
rotate(at);
}
if (to == 0)root = at;
}
inline void insert(int x) {
if (root == 0)root = newNode(x);
else {
int cur = root, nxt = 0, np;
while (true) {
if (nodes[cur].v < x)nxt = nodes[cur].ch[1];
else if (nodes[cur].v > x)nxt = nodes[cur].ch[0];
else break;
if (nxt != 0)cur = nxt;
else break;
}
if (nodes[cur].v == x)nodes[cur].s++, splay(cur);
else if (nodes[cur].v < x)change(cur, np = newNode(x), 1), splay(np);
else change(cur, np = newNode(x), 0), splay(np);
}
}
inline int findRink(int x) {
int cur = root, nxt = 0, ans = 0;
while (true) {
if (nodes[cur].v < x)nxt = nodes[cur].ch[1], ans += nodes[cur].num - nodes[nxt].num;
else if (nodes[cur].v > x)nxt = nodes[cur].ch[0];
else {
ans += nodes[nodes[cur].ch[0]].num;
break;
}
if (nxt != 0)cur = nxt;
else break;
}
splay(cur);
return ans + 1;
}
inline int nextNum(int x) {
int cur = root, nxt = 0, minn = inf, ans = -1;
while (true) {
if (nodes[cur].v < x)nxt = nodes[cur].ch[1];
else if (nodes[cur].v > x) {
nxt = nodes[cur].ch[0];
if (nodes[cur].v <= minn)minn = nodes[cur].v, ans = cur;
} else nxt = nodes[cur].ch[1];
if (nxt != 0)cur = nxt;
else break;
}
splay(ans);
return ans;
}
inline int preNum(int x) {
int cur = root, nxt = 0, maxn = -inf, ans = -1;
while (true) {
if (nodes[cur].v > x)nxt = nodes[cur].ch[0];
else if (nodes[cur].v < x) {
nxt = nodes[cur].ch[1];
if (nodes[cur].v >= maxn)maxn = nodes[cur].v, ans = cur;
} else nxt = nodes[cur].ch[0];
if (nxt != 0)cur = nxt;
else break;
}
splay(ans);
return ans;
}
int delNum(int x) {
int pre = preNum(x), nxt = nextNum(x);
splay(pre), splay(nxt, pre);
if (nodes[nodes[nxt].ch[0]].s > 1)nodes[nodes[nxt].ch[0]].s--, splay(nodes[nxt].ch[0]);
else nodes[nxt].ch[0] = 0, splay(nxt);
return 1;
}
} splay[N << 2];
void build(int l, int r, int k) {//构造树
if (l == r) {
splay[k].insert(op[l]), splay[k].insert(-inf), splay[k].insert(inf);
return;
}
for (int i = l; i <= r; i++)splay[k].insert(op[i]);
splay[k].insert(-inf), splay[k].insert(inf);//手动加入无穷大和无穷小点
build(l, ((l + r) >> 1), k << 1), build((((l + r) >> 1) + 1), r, k << 1 | 1);
}
int queryRank(int a, int b, int s, int l = 1, int r = n, int k = 1) {
if (l >= a && r <= b)return splay[k].findRink(s) - 2;
int mid = (l + r) >> 1;
if (b <= mid)return queryRank(a, b, s, l, mid, k << 1);
else if (mid + 1 <= a)return queryRank(a, b, s, mid + 1, r, k << 1 | 1);
return queryRank(a, mid, s, l, mid, k << 1) + queryRank(mid + 1, b, s, mid + 1, r, k << 1 | 1);
}
void modify(int x, int y, int l = 1, int r = n, int k = 1) {
if (l > r)return;
if (x >= l && x <= r)splay[k].delNum(op[x]), splay[k].insert(y);
int mid = (l + r) >> 1;
if (x > mid)modify(x, y, mid + 1, r, k << 1 | 1);
else if (mid != r)modify(x, y, l, mid, k << 1);
}
int queryPre(int a, int b, int s, int l = 1, int r = n, int k = 1) {
if (l >= a && r <= b)return nodes[splay[k].preNum(s)].v;
int mid = (l + r) >> 1;
if (b <= mid)return queryPre(a, b, s, l, mid, k << 1);
else if (a > mid)return queryPre(a, b, s, mid + 1, r, k << 1 | 1);
return max(queryPre(a, mid, s, l, mid, k << 1), queryPre(mid + 1, b, s, mid + 1, r, k << 1 | 1));
}
int queryNext(int a, int b, int s, int l = 1, int r = n, int k = 1) {
if (l >= a && r <= b)return nodes[splay[k].nextNum(s)].v;
int mid = (l + r) >> 1;
if (b <= mid)return queryNext(a, b, s, l, mid, k << 1);
else if (a > mid)return queryNext(a, b, s, mid + 1, r, k << 1 | 1);
return min(queryNext(a, mid, s, l, mid, k << 1), queryNext(mid + 1, b, s, mid + 1, r, k << 1 | 1));
}
int main() {
n = read(), m = read();
for (int i = 1; i <= n; i++)op[i] = read();
build(1, n, 1);
for (int i = 0; i < m; i++) {
int opt = read(), a, b, c;
if (opt == 1)a = read(), b = read(), c = read(), printf("%d\n", queryRank(a, b, c) + 1);
else if (opt == 2) {
a = read(), b = read(), c = read();
int l = 0, r = inf, mid, s;//[,),注意这里的二分细节
while (l < r) {
if (l == r - 1) {
printf("%d\n", l);
break;
}
mid = (l + r) >> 1, s = queryRank(a, b, mid) + 1;
if (s > c)r = mid;
else l = mid;
}
} else if (opt == 3)a = read(), b = read(), modify(a, b), op[a] = b;
else if (opt == 4)a = read(), b = read(), c = read(), printf("%d\n", queryPre(a, b, c));
else a = read(), b = read(), c = read(), printf("%d\n", queryNext(a, b, c));
}
return 0;
}