线段树
/ / 阅读耗时 22 分钟 这一次介绍线段树,线段树是一种可以维护区间和,区间最值的高效数据结构。线段树本质上是一棵平衡二叉树。
先探讨线段树的构造。由于线段树基于平衡二叉树,且仅有叶子节点储存实际的序列信息,故线段树的空间消耗是比较大的,实际空间消耗大约为给定序列的四倍。
以维护区间和的线段树为例。
构造线段树时,假定f(l,r)为将[l,r]的序列构造为线段树,f(l,r)可以拆分成f(l,mid)和f(mid+1,r),其中mid=(l+r)/2。按照二叉树的节点规律向下构造,并给相应的节点赋值即可。
示例代码:1
2
3
4
5
6
7
8
9void make(int x, int y, int k) {
if (x == y) {
tree[k] = op[x];
return;
}
int mid = (x + y) / 2;
make(x, mid, 2 * k), make(mid + 1, y, 2 * k + 1);
tree[k] = tree[2 * k] + tree[2 * k + 1];
}
求特定区间[l,r]的区间和则通过拆分区间,从线段树中读出分区和,再相加即可。通过这种操作可以将线段树的时间复杂度降至O(logn)。
难点在于如何进行区间拆分。实际上区间拆分可分为四种情况:
1. 给定区间和线段树划分区间恰好相同。直接返回该区间的值即可。
2. 给定区间在划分区间的右半部分。将右半部分作为新的划分区间再进行拆分
3. 给定区间在划分区间的左半部分。将左半部分作为新的划分区间再进行拆分
4. 给定区间部分在左区间,部分在右区间。分别将左半部分和右半部分作为新的划分区间,再分别在两个划分区间中拆分各自的给定区间部分,返回两者结果的和。
示例代码:1
2
3
4
5
6
7int find(int x, int y, int l, int r, int k) {
if (l <= x && y <= r)return tree[k];
int mid = (x + y) / 2;
if (l >= mid + 1)return find(mid + 1, y, l, r, 2 * k + 1);
if (r <= mid)return find(x, mid, l, r, 2 * k);
return find(x, mid, l, mid, 2 * k) + find(mid + 1, y, mid + 1, r, 2 * k + 1);
}
至此,通过这两个函数已经可以处理一个不再修改的线段树。但是在实际操作中,有时不仅需要频繁求出区间和,还要不断地修改数据。下面介绍线段树的值修改方法。
一种显而易见的思路是直接修改数据再重新构造线段树,或者直接在线段树中查找到相关的节点,修改这些节点的值。显然后者要优于前者,但是倘若修改的区间之后不再参与查找区间和的操作,那么时间就会白白浪费。所以,需要一种方法,在求区间和时才对相关节点进行实际的值修改操作,否则只作个标记即可。这种方法可以大大提升效率。
为了实现值修改的标记,需要引入懒标记(Lazy)的概念。
给每一个节点添加变量Lazy并初始化为0,它的意义是标明这个节点的所有子节点都要在原有的基础上加上Lazy。值得注意的是,该节点本身不在标记的范围内,并且Lazy标记仅是一个标记,子节点的值实际上并没有改变。
当需要将区间[l,r]全体加上x时,需要进行如下操作:
1. 将每一个拆分区间(假定为[a,b])的值加上(b-a+1)*x。((b-a+1)是该拆分区间的数据量)
2. 用与查找区间和相同的方法拆分区间[l,r],给每一个拆分的区间的Lazy加上x。
3. 更新拆分区间节点的祖先节点和父节点的值。
示例代码:1
2
3
4
5
6
7
8
9
10
11
12void add(int x, int y, int l, int r, int c, int k) {
if (l <= x && y <= r) {
tree[k] += (r - l + 1) * c;
lazy[k] += c;
return;
}
int mid = (x + y) / 2;
if (l >= mid + 1)add(mid + 1, y, l, r, c, 2 * k + 1);
else if (r <= mid)add(x, mid, l, r, c, 2 * k);
else add(x, mid, l, mid, c, 2 * k), add(mid + 1, y, mid + 1, r, c, 2 * k + 1);
tree[k] = tree[2 * k] + tree[2 * k + 1];
}
那么如何使用这个标记?这里需要用到down()函数来完成Lazy下压实现操作。
除了构造线段树,当任何时候需要使用一个节点时,需要先检查该节点的Lazy是否为0。若为0,不必调用down();否则应调用down()来下压标记。
down()的具体实现如下:
1. 两个子节点的值自加上它们数据量与该节点Lazy的乘积,更新自身数据。
2. 两个子节点的Lazy自加该节点的Lazy,来继承父节点的标记。
3. 该节点Lazy清空为0,表示下压完成。
值得注意的是,叶子节点由于没有子节点,它的Lazy是没有意义的。
示例代码:1
2
3
4
5
6
7void down(int l, int r, int k) {
if (l == r)return;
int mid = (l + r) / 2;
tree[2 * k] += (mid - l + 1) * lazy[k], tree[2 * k + 1] += (r - mid) * lazy[k];
lazy[2 * k] += lazy[k], lazy[2 * k + 1] += lazy[k];
lazy[k] = 0;
}
维护区间最值类比即可。
线段树进阶
新加一种操作:给区间[l,r]上的所有数乘上某个数,如何用线段树维护区间和?
新引入一个懒标记表示乘法是显然的,但是问题并没有想像中那样简单。
如果下压标记时发现该结点同时有加法标记和乘法标记,究竟是先加再乘还是先乘再加?顺序的不同显然会影响结果。如果原有的数为a,考虑先加上b再乘上c的结果$(a+b)c=ac+bc$和顺序反过来的结果$ac+b$。
可以发现,无论是先乘再加还是先加再乘,原数一定会被乘,区别只是加数要不要乘的问题。显然先加再乘可以转化为先乘然后再加上一个“处理后”的加数,这就是我们之后操作的原理。我们所有的步骤都基于先乘再加的思想。
给每一个结点开两个标记:add、muti,(下面代码中用的lazy1和lazy2)分别表示加法标记和乘法标记。初始化add=0,muti=1(注意是1)。
懒标记的意义是:这个结点的所有子结点需要先乘上muti,再加上add。(先乘再加思想)
有关加法的任何操作都与普通的线段数相同。
乘法的步骤便有所不同:
- 将该结点值乘上p(假设需要乘的是p)
- 乘法标记乘上p(表示下面的子树也要乘p)
- 加法标记乘上p(这是很重要的一步!!)
为什么还要处理加法标记呢?前文已经说明我们要先乘再加,先加则用处理加数的方法实现。如果发现当前结点已经有了加法标记,说明“先加”,那么必须给这个加数也乘上p,这就相当于处理了加数,从而把先加再乘转化为等价的先乘再加问题。
懒标记下压也有所不同,根据先乘再加的思想,有以下步骤:
- 该两个子结点值乘上muti(先乘)
- 两个子结点乘法标记乘上muti(乘法标记继承)
- 两个子结点加法标记乘上muti(处理加数!!)
- 恢复muti=1(清空乘法标记)
- 两个子结点加上区间长度乘以add(再加)
- 两个子结点加法标记加上add(加法标记继承)
- 恢复add=0(清空加法标记)
下面是线段树模板2的示例代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
using namespace std;
long long tree[MAX * 4], op[MAX], lazy1[MAX * 4], lazy2[MAX * 4], n, m, mod;
inline long long read() {
char e = getchar();
while ((e < '0' || e > '9') && (e != '-'))e = getchar();
bool k = false;
long long s = 0;
if (e == '-')k = true, e = getchar();
while (e >= '0' && e <= '9')s = s * 10 + e - '0', e = getchar();
return k ? -s : s;
}
void build(int l, int r, int k) {
lazy1[k] = 0, lazy2[k] = 1;
if (l == r) {
tree[k] = op[l];
return;
}
int mid = (l + r) >> 1;
build(l, mid, 2 * k), build(mid + 1, r, 2 * k + 1), tree[k] = (tree[2 * k] + tree[2 * k + 1]) % mod;
}
void down(int x, int l, int r) {
int ls = 2 * x, rs = 2 * x + 1, mid = (l + r) / 2;
lazy2[ls] = (lazy2[ls] * lazy2[x]) % mod, lazy2[rs] = (lazy2[rs] * lazy2[x]) % mod;
lazy1[ls] = (lazy1[ls] * lazy2[x]) % mod, lazy1[rs] = (lazy1[rs] * lazy2[x]) % mod;
tree[ls] = (tree[ls] * lazy2[x]) % mod, tree[rs] = (tree[rs] * lazy2[x]) % mod;
lazy2[x] = 1;
lazy1[ls] = (lazy1[ls] + lazy1[x]) % mod, lazy1[rs] = (lazy1[rs] + lazy1[x]) % mod;
tree[ls] = (tree[ls] + (mid - l + 1) * lazy1[x]) % mod;
tree[rs] = (tree[rs] + (r - mid) * lazy1[x]) % mod;
lazy1[x] = 0;
}
long long find(int a, int b, int l, int r, int k) {
if (l >= a && r <= b)return tree[k];
if (l < r && (lazy1[k] != 0 || lazy2[k] != 1))down(k, l, r);
int mid = (l + r) >> 1;
if (a >= mid + 1)return find(a, b, mid + 1, r, 2 * k + 1);
if (b <= mid)return find(a, b, l, mid, 2 * k);
return (find(a, mid, l, mid, 2 * k) + find(mid + 1, b, mid + 1, r, 2 * k + 1)) % mod;
}
void add(int a, int b, int l, int r, long long s, int k) {
if (l >= a && r <= b) {
tree[k] = ((r - l + 1) * s + tree[k]) % mod;
lazy1[k] += s, lazy1[k] %= mod;
return;
}
if (l < r && (lazy1[k] != 0 || lazy2[k] != 1))down(k, l, r);
int mid = (l + r) >> 1;
if (mid + 1 <= a)add(a, b, mid + 1, r, s, 2 * k + 1);
else if (b <= mid)add(a, b, l, mid, s, 2 * k);
else add(a, mid, l, mid, s, 2 * k), add(mid + 1, b, mid + 1, r, s, 2 * k + 1);
tree[k] = (tree[2 * k] + tree[2 * k + 1]) % mod;
}
void times(int a, int b, int l, int r, long long s, int k) {
if (l >= a && r <= b) {
tree[k] = tree[k] * s % mod, lazy2[k] *= s, lazy2[k] %= mod, lazy1[k] *= s, lazy1[k] %= mod;
return;
}
if (l < r && (lazy1[k] != 0 || lazy2[k] != 1))down(k, l, r);
int mid = (l + r) >> 1;
if (mid + 1 <= a)times(a, b, mid + 1, r, s, 2 * k + 1);
else if (b <= mid)times(a, b, l, mid, s, 2 * k);
else times(a, mid, l, mid, s, 2 * k), times(mid + 1, b, mid + 1, r, s, 2 * k + 1);
tree[k] = (tree[2 * k] + tree[2 * k + 1]) % mod;
}
int main() {
n = read(), m = read(), mod = read();
for (int i = 1; i <= n; i++)op[i] = read();
build(1, n, 1);
for (int i = 0; i < m; i++) {
long long x, a, b, c;
x = read();
if (x == 1)a = read(), b = read(), c = read(), times(a, b, 1, n, c, 1);
else if (x == 2)a = read(), b = read(), c = read(), add(a, b, 1, n, c, 1);
else a = read(), b = read(), cout << find(a, b, 1, n, 1) << endl;
}
return 0;
}
另一个进阶(其实就是线段树另一个功能):维护区间连续和最大(小)值,模板题:戳这里。下面以最大值为例。
方法是维护用线段树维护四个值:anssum、lsum、rsum、totsum。分别表示这个区间最大连续和,从左端点开始的最大连续和、从右端点开始的最大连续和以及区间和。这四个值可以互相递推。
示例代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
using namespace std;
struct NP {//线段树数据结构体定义
long long lsum, rsum, totsum, anssum;
} tree[N << 2];
int n, m, op[N << 2];
inline int readInt() {
char e = getchar();
int s = 0;
bool g = false;
while (e < '-')e = getchar();
if (e == '-')g = true, e = getchar();
while (e > '-')s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return g ? -s : s;
}
inline void update(int k) {//递推出来
tree[k].lsum = max(tree[k << 1].lsum, tree[k << 1].totsum + tree[k << 1 | 1].lsum);
tree[k].rsum = max(tree[k << 1 | 1].rsum, tree[k << 1 | 1].totsum + tree[k << 1].rsum);
tree[k].totsum = tree[k << 1].totsum + tree[k << 1 | 1].totsum;
tree[k].anssum = max(tree[k << 1 | 1].lsum + tree[k << 1].rsum,
max(tree[k << 1].anssum, tree[k << 1 | 1].anssum));
}
void build(int l, int r, int k = 1) {//构建线段树
if (l == r) {
tree[k].anssum = tree[k].totsum = tree[k].lsum = tree[k].rsum = op[l];
return;
}
int mid = (l + r) >> 1;
build(l, mid, k << 1), build(mid + 1, r, k << 1 | 1);
update(k);
}
void modify(int pos, int what, int l, int r, int k = 1) {//单点修改
if (l == r) {
tree[k].anssum = tree[k].totsum = tree[k].lsum = tree[k].rsum = what;
return;
}
int mid = (l + r) >> 1;
if (pos > mid)modify(pos, what, mid + 1, r, k << 1 | 1);
else modify(pos, what, l, mid, k << 1);
update(k);
}
NP query(int l, int r, int a, int b, int k = 1) {//查询
if (a >= l && b <= r)return tree[k];
int mid = (a + b) >> 1;
if (l > mid)return query(l, r, mid + 1, b, k << 1 | 1);
else if (r <= mid)return query(l, r, a, mid, k << 1);
else {
NP t1 = query(l, mid, a, mid, k << 1), t2 = query(mid + 1, r, mid + 1, b, k << 1 | 1), tmp;
tmp.totsum = t1.totsum + t2.totsum;//下面是合并答案的过程
tmp.lsum = max(t1.lsum, t1.totsum + t2.lsum);
tmp.rsum = max(t2.rsum, t2.totsum + t1.rsum);
tmp.anssum = max(t1.rsum + t2.lsum, max(t1.anssum, t2.anssum));
return tmp;
}
}
int main() {
n = readInt();
for (int i = 1; i <= n; i++)op[i] = readInt();
build(1, n);
m = readInt();
for (int i = 1; i <= m; i++) {
int a = readInt(), b = readInt(), c = readInt();
if (a == 0)modify(b, c, 1, n);
else printf("%lld\n", query(b, c, 1, n).anssum);
}
return 0;
}