本节介绍Splay树的用法与实现。

二叉排序树

        Splay本质上是平衡二叉排序树(Binary Sort Tree,BST),是平衡树的一种。因此在了解Splay前必须先了解二叉排序树。
        二叉排序树是指满足以下性质的二叉树:

  • 若左子树不空,则左子树上所有结点的值均小于根结点。
  • 若右子树不空,则右子树上所有结点的值均大于根结点。
  • 左右子树也是BST。

        下面给出了一个BST的例子:

        在本文的splay中不允许有相同的数,因此每一个结点都有自己的计数变量(在本文中用s表示)来记录这个数的个数。
        另外,BST的左右子树性质不同,因此不能向往常一样把树当图来存,必须基于一定储存结构。这里用结构体来定义树的每一个结点:

1
2
3
4
5
6
7
struct Node {
int v, ch[2], s, num, f;

Node() : s(0), num(0) {
ch[0] = ch[1] = 0;
}
} node[500005];

        v表示该结点表示的数,ch[0]和ch[1]是左右子树根结点编号,s是数的重数,num是该树所有数的总数量(计入重数),f是其父结点编号。然后开一个数组储存这些结点。
        问题来了,BST有什么用呢?如果我们需要不断加入删除数据,不断地询问某个数的排序序号,查找某个数的前驱或者后继等操作,普通的顺序表效率过低,而链表查找效率硬伤,此时我们就需要BST来完成这些工作。平衡的BST可以在$O(logn)$复杂度下完成这些操作,这是容易理解的。
        但如果不平衡呢?当二叉排序树退化成一条链时,各种操作都会退化成$O(n)$复杂度,这是我们不希望看到的,于是就有了各种使BST平衡的方法,Splay就是其中之一。

旋转操作

        旋转操作是Splay中很重要的操作。旋转的意思就是:将某个结点移动到其父结点的位置,保持BST性质不变。旋转只是树的结构变形,没有实质性改变。
        当目标结点是其父结点的左儿子时,需要右旋(称为Zig操作),若为右儿子,需要左旋(称为Zag操作)。下面以右旋为例探讨旋转规律。

        比如我们想把值为3的结点移动到其父结点位置。
        发现4比5小,5应该接管以值为4的结点为根的子树,于是:

        5比3大,将值为5的结点移动下来作值为3的结点的左儿子(父结点变儿子结点)。

        然后以3为值的结点替代原父结点位置:

        旋转完成,这就是一次zig操作。
        无论左旋还是右旋,都可以发现旋转分为三步。

  1. 父结点接管目标结点反边上的子树。(接管反边是指:若目标结点为左儿子,则接管右子树,右儿子反之)。
  2. 父结点作目标结点反边上的子树。
  3. 目标结点代替父结点位置(父结点是爷爷结点的什么儿子,目标结点就作什么儿子)。

        现在就来编写实现旋转操作的函数,在这之前先写一个重建父子关系的函数:

1
2
3
void change(int father, int son, int w) {
node[father].ch[w] = son, node[son].f = father;
}

        change函数的作用就是让son编号去作father编号结点的w儿子(0为左,1为右)。然后还需要一个判别父子关系的函数:
1
2
3
4
int identify(int x) {
int f = node[x].f;
return node[f].ch[0] == x ? 0 : 1;
}

        这个函数用来判断编号为x的结点是其父结点的什么儿子。最后需要一个update函数:
1
2
3
void update(int x) {
node[x].num = node[node[x].ch[0]].num + node[node[x].ch[1]].num + node[x].s;
}

        这是更新一下编号为x的结点的num值。由于旋转后结点间关系发生改变,因此需要更新num。
        可以写出旋转函数:
1
2
3
4
5
6
void rotate(int x) {
if (x == root)return;
int f = node[x].f, g = node[f].f, i = identify(x), j = identify(f);
change(f, node[x].ch[i ^ 1], i), change(x, f, i ^ 1), change(g, x, j);
update(f), update(x);
}

        根结点不能旋转。根据三个步骤重建父子关系,然后更新目标结点和父结点的num。这里用i^1来取反边,这是一个技巧,在网络流中也用到了这个技巧。

伸展操作

        伸展树之所以叫伸展树就是因为伸展。
        所谓伸展,就是将某个结点(本文中称为at)通过一系列旋转移动到另一个结点(称为to)下方(也就是儿子)。
        如果to本身就在at的下方怎么办?在常规的splay操作中不会碰到这种情况。
        首先需要认识到一个问题:如果需要移动到根结点位置怎么办?根结点不是任何结点的儿子,其实从上面的叙述中的确无法做到,因此有个东西叫虚拟根结点,它是根结点的父结点。在本文中,虚拟根结点编号为0,根结点作其右儿子。这样就可以通过移动到虚拟根结点下面来移动到根结点位置了。在后文中可以看到引入虚拟根结点的很多好处。
        当at的父结点正好就是to时无需伸展。
        若at的爷爷结点正好是to时,一步旋转(zig或zag)就可以做到伸展效果。问题就是其余情况怎么处理。
        当at的子结点性质和其父结点不同时(不同是指at是左儿子而父结点是右儿子或者反之)可以通过两次旋转(zig-zag或者zag-zig)at结点来将其旋转到爷爷结点的位置。这称为之字形旋转。
        当at的子结点性质与其父结点相同时,同样两次旋转(zig-zig或者zag-zag)at结点来将其旋转到爷爷结点的位置。这称为一字形旋转。
        但实际情况中可以发现,对于一字形旋转,树的结构并没有得到很好的改善。通常我们这样处理一子形旋转的情况:先旋转父结点再旋转at结点。这种用先旋转父结点再旋转子结点来代替旋转两次子结点的操作称为双旋,普通的一子形方法就称为单旋。
        这样就可以写出伸展操作的函数了:

1
2
3
4
5
6
7
8
9
10
11
void splay(int at, int to = 0) {//to默认为0表示默认移动到根结点
while (node[at].f != to) {
int f = node[at].f, g = node[f].f;
if (g != to) {
if (identify(at) == identify(f))rotate(f);//双旋操作
else rotate(at);
}
rotate(at);
}
if (to == 0)root = at;//不要忘记修改根结点
}

        伸展有什么用呢?可以理解成保持树结构的随机性,也就是在一种平均意义上保持树的平衡。
        伸展的另一个作用是更新经过结点的num值,这是旋转时所做的。

插入元素

        这个不用多说,直接看代码就可以:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
void insert(int x) {
if (root == 0)change(0, cnt, 1), node[cnt].v = x, node[cnt].s = 1, node[cnt].num = 1, root = cnt++;//无根要建根
else {
int cur = root, nxt = 0;
while (true) {
if (x > node[cur].v)nxt = node[cur].ch[1];
else if (x < node[cur].v)nxt = node[cur].ch[0];
else break;
if (nxt != 0)cur = nxt;
else break;
}
if (node[cur].v == x)node[cur].s++, splay(cur);//已经有了,s直接加一
else if (node[cur].v > x)change(cur, cnt, 0), node[cnt].v = x, node[cnt].s = 1, node[cnt].num = 1, splay(cnt++);
else change(cur, cnt, 1), node[cnt].v = x, node[cnt].s = 1, node[cnt].num = 1, splay(cnt++);
}
}

        cnt是当出现新结点时的编号计数器,这样每有一个新的点加入都会分配一个新编号。这样导致的一个问题是不能重用空间,但是对于大多数情况空间都是够用的。如果希望重用空间可以选择每次重建根(为什么会多次重建根?因为可能会删元素啊)时重新从1编号,但是这样做一定要记得将已经被删除的结点完全清空(比如子结点编号设成0等等)。
        每一次插入元素后都要把已经插入的元素对应结点伸展到根结点以保持结构随机性。
        为什么没有看到更新num?伸展的时候顺便更新了,上文已经提到。

查找某数的编号

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int find(int x) {
int cur = root, nxt = 0;
while (true) {
if (node[cur].v > x)nxt = node[cur].ch[0];
else if (node[cur].v < x)nxt = node[cur].ch[1];
else break;
if (nxt != 0)cur = nxt;
else break;
}
if (node[cur].v == x) {
splay(cur);
return cur;
}
return -1;
}

        很容易理解,没有返回-1,最后将该结点伸展到根结点。

查找排名为k的数(第k小)

        第k小就是指计入重数后的排序为k的数,直接看代码吧:

1
2
3
4
5
6
7
8
9
10
11
12
13
int findNum(int x) {
int cur = root, nxt = 0;
while (true) {
if (node[node[cur].ch[0]].num >= x)nxt = node[cur].ch[0];//找左树
else if (x > node[node[cur].ch[0]].num && x <= node[node[cur].ch[0]].num + node[cur].s) {
splay(cur);
return cur;
} else nxt = node[cur].ch[1], x -= node[node[cur].ch[0]].num + node[cur].s;//找右树
if (nxt != 0)cur = nxt;
else break;
}
return -1;
}

        这个函数返回的是结点编号,如果没有返回-1。要把最后的结点伸展到根结点。

查询某元素排名

        对于相同的元素,这里的排名是指最小的排名。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int findRank(int x) {
int cur = root, nxt = 0, ans = 0;
while (true) {
if (node[cur].v < x)nxt = node[cur].ch[1], ans += node[cur].num - node[nxt].num;
else if (node[cur].v > x)nxt = node[cur].ch[0];
else {
ans += node[node[cur].ch[0]].num;
break;
}
if (nxt != 0)cur = nxt;
else break;
}
splay(cur);
return ans + 1;
}

        代码是容易理解的,最后仍然需要伸展。

查找前驱

        前驱是指小于某个数中的最大的那一个。值得注意的是该数不一定是已经插入的元素,即使之前没有插入也是可以查找前驱的。

1
2
3
4
5
6
7
8
9
10
11
12
13
int preNum(int x) {
int cur = root, nxt = 0, maxn = -0x7fffffff, ans = -1;
while (true) {
if (node[cur].v > x)nxt = node[cur].ch[0];
else if (node[cur].v < x) {
nxt = node[cur].ch[1];
if (node[cur].v > maxn)maxn = node[cur].v, ans = cur;
} else nxt = node[cur].ch[0];//相等向小了找
if (nxt != 0)cur = nxt;
else break;
}
return ans;
}

        同样只返回编号,没有返回-1。

查找后继

1
2
3
4
5
6
7
8
9
10
11
12
13
int nextNum(int x) {
int cur = root, nxt = 0, minn = 0x7fffffff, ans = -1;
while (true) {
if (node[cur].v < x)nxt = node[cur].ch[1];
else if (node[cur].v > x) {
nxt = node[cur].ch[0];
if (node[cur].v < minn)minn = node[cur].v, ans = cur;
} else nxt = node[cur].ch[1];
if (nxt != 0)cur = nxt;
else break;
}
return ans;
}

        和前驱基本相同。

删除操作

        删除是比较麻烦的一个,这里介绍两种实现方式。
        第一种:通过前驱和后继删除。
        将前驱伸展到根结点然后将后继伸展到根结点(就是前驱)下方(作右儿子),那么待删除结点一定是后继的左儿子并且它没有任何子树。从这个结点上删除即可。如果其s>1那么直接s减去一,否则删除该结点,最后进行一步伸展,还可以顺便更新一下num。
        没有前驱或后继怎么办?需要手动加入无穷小和无穷大两个数,这样所有数就都有前驱和后继了。

1
2
3
4
5
6
7
8
int delNum(int x) {
if (find(x) == -1)return 0;//没有该数就不能删除,如果保证存在该数,则该句可以忽略以提高效率
int pre = preNum(x), nxt = nextNum(x);
splay(pre), splay(nxt, pre);//作伸展
if (node[node[nxt].ch[0]].s > 1)node[node[nxt].ch[0]].s--, splay(node[nxt].ch[0]);
else node[nxt].ch[0] = 0, splay(nxt);
return 1;
}

        另一种方法就是分类讨论。
        将待删除结点伸展到根结点,然后分类讨论(只探讨完全删除结点的情况,s自减一的情况不探讨):

  • 根结点没有子树,直接删除,重置root=0。
  • 根结点有右子树无左子树,将右子树根作为根。仅有左子树相似操作。
  • 既有左子树又有右子树,按照方法一进行(因为这时候该数一定有前驱和后继)。

        实际上没有前驱和后继的情况仅有删除最小值最大值的时候才会出现,所以第二种方法常常会比第一种方法多一次splay,效率低一些。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
int delNum(int x) {
splay(find(x));//默认有该数
if (!node[root].ch[0] && !node[root].ch[1]) {
if (node[root].s > 1)node[root].s--;
else root = 0, node[0].ch[1] = 0;
} else if (!node[root].ch[0]) {
if (node[root].s > 1)node[root].s--;
else change(0, node[root].ch[1], 1), root = node[root].ch[1];
} else if (!node[root].ch[1]) {
if (node[root].s > 1)node[root].s--;
else change(0, node[root].ch[0], 1), root = node[root].ch[0];
} else {
int pre = preNum(x), nxt = nextNum(x);
splay(pre), splay(nxt, pre);
if (node[node[nxt].ch[0]].s > 1)node[node[nxt].ch[0]].s--, splay(node[nxt].ch[0]);
else node[nxt].ch[0] = 0, splay(nxt);
}
return 1;
}

        推荐洛谷P3369模板题。AC代码就是上面的组合。
        splay的操作函数尽量不要用递归写,因为splay操作的存在,递归时树的结构可能发生变化,极易出错。下面给出洛谷AC代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#include<iostream>

using namespace std;

struct Node {
int v, ch[2], s, num, f;

Node() : s(0), num(0) {
ch[0] = ch[1] = 0;
}
} node[100005];

int root = 0, n, cnt;

inline int identify(int x) {
int f = node[x].f;
return node[f].ch[0] == x ? 0 : 1;
}

inline void change(int t, int s, int w) {
node[t].ch[w] = s, node[s].f = t;
}

inline void update(int x) {
node[x].num = node[node[x].ch[0]].num + node[node[x].ch[1]].num + node[x].s;
}

inline void rotate(int x) {
if (x == root)return;
int f = node[x].f, g = node[f].f, i = identify(x), j = identify(f);
change(f, node[x].ch[i ^ 1], i), change(x, f, i ^ 1), change(g, x, j);
update(f), update(x);
}

inline void splay(int at, int to = 0) {
while (node[at].f != to) {
int f = node[at].f, g = node[f].f;
if (g != to) {
if (identify(at) == identify(f))rotate(f);
else rotate(at);
}
rotate(at);
}
if (to == 0)root = at;
}

inline int find(int x) {
int cur = root, nxt = 0;
while (true) {
if (node[cur].v > x)nxt = node[cur].ch[0];
else if (node[cur].v < x)nxt = node[cur].ch[1];
else break;
if (nxt != 0)cur = nxt;
else break;
}
if (node[cur].v == x) {
splay(cur);
return cur;
}
return -1;
}

inline int findRank(int x) {
int cur = root, nxt = 0, ans = 0;
while (true) {
if (node[cur].v < x)nxt = node[cur].ch[1], ans += node[cur].num - node[nxt].num;
else if (node[cur].v > x)nxt = node[cur].ch[0];
else {
ans += node[node[cur].ch[0]].num;
break;
}
if (nxt != 0)cur = nxt;
else break;
}
splay(cur);
return ans + 1;
}

inline int findNum(int x) {
int cur = root, nxt = 0;
while (true) {
if (node[node[cur].ch[0]].num >= x)nxt = node[cur].ch[0];
else if (x > node[node[cur].ch[0]].num && x <= node[node[cur].ch[0]].num + node[cur].s) {
splay(cur);
return cur;
} else nxt = node[cur].ch[1], x -= node[node[cur].ch[0]].num + node[cur].s;
if (nxt != 0)cur = nxt;
else break;
}
return -1;
}

inline int preNum(int x) {
int cur = root, nxt = 0, maxn = -0x7fffffff, ans = -1;
while (true) {
if (node[cur].v > x)nxt = node[cur].ch[0];
else if (node[cur].v < x) {
nxt = node[cur].ch[1];
if (node[cur].v > maxn)maxn = node[cur].v, ans = cur;
} else nxt = node[cur].ch[0];
if (nxt != 0)cur = nxt;
else break;
}
return ans;
}

inline int nextNum(int x) {
int cur = root, nxt = 0, minn = 0x7fffffff, ans = -1;
while (true) {
if (node[cur].v < x)nxt = node[cur].ch[1];
else if (node[cur].v > x) {
nxt = node[cur].ch[0];
if (node[cur].v < minn)minn = node[cur].v, ans = cur;
} else nxt = node[cur].ch[1];
if (nxt != 0)cur = nxt;
else break;
}
return ans;
}

inline void insert(int x) {
if (root == 0)change(0, 1, 1), node[1].v = x, node[1].s = 1, node[1].num = 1, root = 1, cnt = 2;
else {
int cur = root, nxt = 0;
while (true) {
if (x > node[cur].v)nxt = node[cur].ch[1];
else if (x < node[cur].v)nxt = node[cur].ch[0];
else break;
if (nxt != 0)cur = nxt;
else break;
}
if (node[cur].v == x)node[cur].s++, splay(cur);
else if (node[cur].v > x)change(cur, cnt, 0), node[cnt].v = x, node[cnt].s = 1, node[cnt].num = 1, splay(cnt++);
else change(cur, cnt, 1), node[cnt].v = x, node[cnt].s = 1, node[cnt].num = 1, splay(cnt++);
}
}

inline int delNum(int x) {
int pre = preNum(x), nxt = nextNum(x);
splay(pre), splay(nxt, pre);
if (node[node[nxt].ch[0]].s > 1)node[node[nxt].ch[0]].s--, splay(node[nxt].ch[0]);
else node[nxt].ch[0] = 0, splay(nxt);
return 1;
}


int main() {
ios::sync_with_stdio(false);
cin >> n;
insert(-(1 << 30)), insert(1 << 30);
for (int i = 1; i <= n; i++) {
int opt, x;
cin >> opt >> x;
if (opt == 1)insert(x);
else if (opt == 2)delNum(x);
else if (opt == 3)cout << findRank(x) - 1 << endl;
else if (opt == 4)cout << node[findNum(x + 1)].v << endl;
else if (opt == 5)cout << node[preNum(x)].v << endl;
else cout << node[nextNum(x)].v << endl;
}
return 0;
}


        更新文艺平衡树的splay版本,建议先阅读这篇文章的后半部分,这里其实是它的扩展。
        给定一个区间,需要不断翻转其中某一段子区间,如何用Splay完成这项操作?以洛谷模板题为例。
        我们可以先将1~n加入splay,这样splay的中序遍历序列就是原序列。
        加入无穷小点和无穷大点,保证所有值都有前驱和后继。
        当需要翻转[l,r]时,找到排名为l和r+2的两个结点(为什么是l和r+2?不要忘记之前加入了无穷小点!),这是区间的前驱和后继。将前驱splay到根结点,后继splay到根结点下方,然后后继的左子树就是目标区间,在其根结点上打上翻转标记。
        翻转标记表明这个树需要翻转,它自然有自己的下压函数:

1
2
3
4
5
6
void down(int x) {
swap(node[x].ch[0], node[x].ch[1]);
node[node[x].ch[0]].lazy ^= 1;
node[node[x].ch[1]].lazy ^= 1;
node[x].lazy = 0;
}

        每当rotate或者通过排名找数时都需要下压标记。
        有一个疑问:翻转子树难道不会破坏BST性质?从权值的角度讲当然会破坏,但是可以从更高的角度去理解。BST维护的是权重的大小关系,左树权重比右树小。在这里权重就意味着在区间中的位置,从这个角度看,BST的性质没有被破坏。寻找排名为k的数是一个基于权重的算法而非权值,因此这里需要通过排名来寻找结点。
        如果洛谷模板题的原始序列是随机的,问题就会更复杂一些。这是因为1~n序列很特殊,它的权值就是权重。我们可以通过一些基于权值的算法(比如insert)来构树,但是对于随机的序列,再用这种方法就不能保证原有的顺序。
        下附模板题AC代码,注意这里没有s变量,这是因为1~n一定不会重。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include<bits/stdc++.h>

using namespace std;
struct Node {
int v, ch[2], f, lazy, num;
} node[100005];
int cnt, n, m, root, p = 0;

void update(int x) {
node[x].num = node[node[x].ch[0]].num + node[node[x].ch[1]].num + 1;
}

void change(int x, int y, int w) {
node[x].ch[w] = y, node[y].f = x;
}

inline int identify(int x) {
return node[node[x].f].ch[0] != x;
}

inline void down(int x) {
swap(node[x].ch[0], node[x].ch[1]);
node[node[x].ch[0]].lazy ^= 1;
node[node[x].ch[1]].lazy ^= 1;
node[x].lazy = 0;
}

void rotate(int x) {
if (x == root)return;
if (node[x].lazy)down(x);
int f = node[x].f, g = node[f].f, w = identify(x), w1 = identify(f);
change(f, node[x].ch[w ^ 1], w);
change(x, f, w ^ 1);
change(g, x, w1);
update(x), update(f);
}

void splay(int at, int to = 0) {
while (node[at].f != to) {
int f = node[at].f, g = node[f].f;
if (g != to) {
if (identify(f) == identify(at))rotate(f);
else rotate(at);
}
rotate(at);
}
if (to == 0)root = at;
}

void insert(int x) {
cnt++, node[cnt].v = x, node[cnt].ch[0] = node[cnt].ch[1] = 0, node[cnt].lazy = 0, node[cnt].num = 1;
if (root == 0) {
change(0, cnt, 1), root = cnt;
return;
}
int cur = root, nxt;
while (true) {
if (node[cur].v <= x)nxt = node[cur].ch[1];
else nxt = node[cur].ch[0];
if (nxt == 0)break;
else cur = nxt;
}
if (node[cur].v <= x)change(cur, cnt, 1);
else change(cur, cnt, 0);
splay(cur);
}

int findNum(int x) {
int cur = root, nxt;
while (true) {
if (node[cur].lazy)down(cur);
if (node[node[cur].ch[0]].num + 1 == x) {
splay(cur);
return cur;
} else if (node[node[cur].ch[0]].num < x)nxt = node[cur].ch[1], x -= node[node[cur].ch[0]].num + 1;
else nxt = node[cur].ch[0];
if (nxt == 0)break;
else cur = nxt;
}
return -1;
}

void print(int x) {
if (x == 0)return;
if (node[x].lazy)down(x);
print(node[x].ch[0]);
if (p >= 1 && p <= n)cout << node[x].v << " ";
p++;
print(node[x].ch[1]);
}

int main() {
ios::sync_with_stdio(false);
cin >> n >> m;
insert(-0x7ffffff), insert(0x7ffffff);
for (int i = 1; i <= n; i++)insert(i);
for (int i = 1; i <= m; i++) {
int l, r;
cin >> l >> r;
l = findNum(l), r = findNum(r + 2);
splay(l), splay(r, l);
node[node[node[root].ch[1]].ch[0]].lazy ^= 1;
}
print(root);
return 0;
}