替罪羊树是博客里提到的第三种平衡树,前两种分别是Splay和Treap(含fhq Treap)。它更易写,也容易理解。

        替罪羊树的目的同样是使树尽可能平衡,它的思路很简单粗暴:当某棵树不平衡时,拍掉重建。
        国际惯例,先看结点结构体定义:

1
2
3
4
struct Node {
int size, ch[2], v, s, num;

} node[500005];

        这里有很多已经熟悉的变量,只是多了一个size。size记录子树中的有效结点数目,下文会提到。
        为什么会有有效结点这种说法呢?替罪羊树中删除某个结点,并不是直接删除,而是打上标记(即允许s=0的结点存在)。此时这个结点就是无效结点,不能计入size。
        何为拍掉重建?首先需要知道如何判定一棵树已经不平衡,这就需要引入平衡因子的概念。
        当左子树的size或着右子树的size大于整棵树的size乘以平衡因子时,我们认为这棵树不平衡。很明显左子树size和右子树size中一定有一个不小于整棵树size的一半,即平衡因子至少为0.5,而最大显然为1。平均来看,我们取平衡因子为0.75,记alpha=0.75。
        update定义:
1
2
3
4
void update(int x) {
node[x].num = node[node[x].ch[0]].num + node[node[x].ch[1]].num + node[x].s;
node[x].size = node[node[x].ch[0]].size + node[node[x].ch[1]].size + (node[x].s > 0);
}

        检验函数:
1
2
3
bool check(int x) {
return node[node[x].ch[0]].size > node[x].size * alpha || node[node[x].ch[1]].size > node[x].size * alpha;
}

        如何重建?首先对树进行一次中序遍历,记得要剔除所有无效结点,将遍历编号序列存在数组中。然后新树的树根就是这个序列的中点,两棵子树递归进行,易知这样建树符合BST定义。
        中序遍历代码:
1
2
3
4
5
6
void serach(int x) {
if (x == 0)return;
serach(node[x].ch[0]);
if (node[x].s > 0)temp[++ct] = x;
serach(node[x].ch[1]);
}

        重建函数:
1
2
3
4
void rebuild(int &x) {//引用!!因为要修改值
ct = 0, serach(x);
solve(x, 1, ct);
}

        solve就是负责建树的递归函数,这里记得传进来的x是引用。
1
2
3
4
5
6
7
8
9
10
11
void solve(int &x, int l, int r) {//引用!!
if (r < l) {
x = 0;
return;
}
int mid = (l + r) / 2;
x = temp[mid];
solve(node[x].ch[0], l, mid - 1);
solve(node[x].ch[1], mid + 1, r);
update(x);
}

        插入元素与BST基本相同:
1
2
3
4
5
6
7
8
9
10
11
12
void insert(int &x, int y) {//仍然是引用,与treap的思想类似
if (x == 0) {
x = ++cnt;//新建一个点
node[x].v = y, node[x].s = node[x].size = node[x].num = 1, node[x].ch[0] = node[x].ch[1] = 0;
return;
}
if (node[x].v == y)node[x].s++;//直接更新
else if (node[x].v < y)insert(node[x].ch[1], y);
else insert(node[x].ch[0], y);
update(x);//更新信息
if (check(x))rebuild(x);//重建
}

        删除函数:
1
2
3
4
5
6
7
8
void del(int &x, int y) {//也是引用
if (x == 0)return;
if (node[x].v < y) del(node[x].ch[1], y);
else if (node[x].v > y)del(node[x].ch[0], y);
else node[x].s--;//不需要检验s>1,s=0就意味着这个点无效
update(x);//更新
if (check(x))rebuild(x);//重建
}

        找排名和根据排名找元素与Splay和Treap都相同,但是求前驱后继的方法有所不同。
        由于无效结点的存在,无法判定前驱后继究竟在左子树上还是右子树上(以往是确定的,但替罪羊树上可能某一侧全部是无效结点),这就造成了麻烦。在替罪羊树中,可以用下面的思路求前驱和后继:

  • 求x的前驱:插入x,找到x的排名p,删掉x,找到排名p-1的值即为前驱。
  • 求x的后继:插入x+1,找到x+1的排名p,删掉x+1,找到排名p的值即为后继。

        这个算法很容易理解,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
int preNum(int x) {
insert(root, x);
int p = findRank(x);
del(root, x);
return findNum(p - 1);
}

int nextNum(int x) {
insert(root, x + 1);
int p = findRank(x + 1);
del(root, x + 1);
return findNum(p);
}

        下附模板题AC代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include<bits/stdc++.h>

using namespace std;
const double alpha = 0.75;

struct Node {
int size, ch[2], v, s, num;

} node[500005];

int root = 0, cnt = 0, ct = 0, temp[100005];

void update(int x) {
node[x].num = node[node[x].ch[0]].num + node[node[x].ch[1]].num + node[x].s;
node[x].size = node[node[x].ch[0]].size + node[node[x].ch[1]].size + (node[x].s > 0);
}

bool check(int x) {
return node[node[x].ch[0]].size > node[x].size * alpha || node[node[x].ch[1]].size > node[x].size * alpha;
}

void serach(int x) {
if (x == 0)return;
serach(node[x].ch[0]);
if (node[x].s > 0)temp[++ct] = x;
serach(node[x].ch[1]);
}

void solve(int &x, int l, int r) {
if (r < l) {
x = 0;
return;
}
int mid = (l + r) / 2;
x = temp[mid];
solve(node[x].ch[0], l, mid - 1);
solve(node[x].ch[1], mid + 1, r);
update(x);
}

void rebuild(int &x) {
ct = 0, serach(x);
solve(x, 1, ct);
}

void insert(int &x, int y) {
if (x == 0) {
x = ++cnt;
node[x].v = y, node[x].s = node[x].size = node[x].num = 1, node[x].ch[0] = node[x].ch[1] = 0;
return;
}
if (node[x].v == y)node[x].s++;
else if (node[x].v < y)insert(node[x].ch[1], y);
else insert(node[x].ch[0], y);
update(x);
if (check(x))rebuild(x);
}

void del(int &x, int y) {
if (x == 0)return;
if (node[x].v < y) del(node[x].ch[1], y);
else if (node[x].v > y)del(node[x].ch[0], y);
else node[x].s--;
update(x);
if (check(x))rebuild(x);
}

inline int findNum(int x) {
int cur = root, nxt = 0;
while (true) {
if (node[node[cur].ch[0]].num >= x)nxt = node[cur].ch[0];
else if (x > node[node[cur].ch[0]].num && x <= node[node[cur].ch[0]].num + node[cur].s) {
return cur;
} else nxt = node[cur].ch[1], x -= node[node[cur].ch[0]].num + node[cur].s;
if (nxt != 0)cur = nxt;
else break;
}
return -1;
}

inline int findRank(int x) {
int cur = root, nxt = 0, ans = 0;
while (true) {
if (node[cur].v < x)nxt = node[cur].ch[1], ans += node[cur].num - node[nxt].num;
else if (node[cur].v > x)nxt = node[cur].ch[0];
else {
ans += node[node[cur].ch[0]].num;
break;
}
if (nxt != 0)cur = nxt;
else break;
}
return ans + 1;
}

inline int preNum(int x) {
insert(root, x);
int p = findRank(x);
del(root, x);
return findNum(p - 1);
}

inline int nextNum(int x) {
insert(root, x + 1);
int p = findRank(x + 1);
del(root, x + 1);
return findNum(p);
}

int main() {
ios::sync_with_stdio(false);
int n;
cin >> n;
for (int i = 1; i <= n; i++) {
int opt, x;
cin >> opt >> x;
if (opt == 1)insert(root, x);
else if (opt == 2)del(root, x);
else if (opt == 3)cout << findRank(x) << endl;
else if (opt == 4)cout << node[findNum(x)].v << endl;
else if (opt == 5)cout << node[preNum(x)].v << endl;
else cout << node[nextNum(x)].v << endl;
}
return 0;
}