难度:省选/NOI-

题目描述

        请写一个程序,要求维护一个数列,支持以下 6 种操作:(请注意,格式栏 中的下划线‘ _ ’表示实际输入文件中的空格)

输入格式

        输入文件的第 1 行包含两个数 N 和 M,N 表示初始时数列中数的个数,M 表示要进行的操作数目。 第 2 行包含 N 个数字,描述初始时的数列。 以下 M 行,每行一条命令,格式参见问题描述中的表格。

输出格式

        对于输入数据中的 GET-SUM 和 MAX-SUM 操作,向输出文件依次打印结 果,每个答案(数字)占一行。

输入输出样例

Sample input

9 8
2 -6 3 5 1 -5 -3 6 3
GET-SUM 5 4
MAX-SUM
INSERT 8 3 -5 7 2
DELETE 12 1
MAKE-SAME 3 3 2
REVERSE 3 6
GET-SUM 5 4
MAX-SUM

Sample output

-1
10
1
10

说明

        你可以认为在任何时刻,数列中至少有 1 个数。
        输入数据一定是正确的,即指定位置的数在数列中一定存在。
        50%的数据中,任何时刻数列中最多含有 30000 个数;
        100%的数据中,任何时刻数列中最多含有 500000 个数。
        100%的数据中,任何时刻数列中任何一个数字均在[-1000, 1000]内。
        100%的数据中,M ≤20000,插入的数字总数不超过 4000000 。

题解

        很让人自闭的平衡树题,做这个题你需要知道平衡树相关的知识以及良好的心态。强烈推荐将本题代码多打几遍,能很好地提高码力。
        用平衡树维护数列的相对位置信息。对于插入和删除操作,直接在平衡树上进行就可以了。对于第三个操作,需要用一个懒标记,表明这棵子树上的结点点权需要统一修改。
        结点太多会炸空间,但注意到任意时刻最多有500000个数,故那些被删除的结点我们用一个队列回收,在需要开新结点时再拿出来复用。这样可以牺牲一些时间来换空间。
        对于翻转操作,只需要打上一个翻转懒标记就可以了(还记得文艺平衡树吗?)。求和自然也很简单,直接像线段树那样维护即可。比较麻烦的是最后一个,需要维护额外的两个信息,关于这部分见线段树一文的后半部分。
        本题有很多细节需要注意,一不留神就会疯狂WA,这里提一个比较隐晦的错误(我找了N个小时的bug)。
        这里面有两个懒标记,一个是区间整体修改标记,另一个是翻转标记。这两个标记的含义是不同的,第一个与线段树相同,表示它的子树需要修改,而这个结点本身已经经过更新,后者表示这棵树需要翻转(即尚未翻转),这一点点区别会带来很大的麻烦。事实证明,对于前者,如果在split和merge过程中不断下放标记,结点是可以得到更新的,但是后者不可以!!即使下压了标记,也会出现某些结点未得到更新而出现惨案。
        为了避免这种情况的发生,最好采用统一的标记下放方式,对于文艺平衡树也是如此。即标记代表其子树需要更新,而该结点本身已更新
        下面代码(需要卡常)给出一些注释,标明一些可能出现细节问题的BUG。这里平衡树用的非旋Treap。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#include <bits/stdc++.h>

#define N 510000
using namespace std;
struct Node {
int l, r, k, v, sum, num, la1, la2;
int max, lmax, rmax;
} nd[N];
queue<int> que;//回收队列
int root, n, m;
char opt[50];

inline void down(int x) {//下压标记
if (nd[x].la1 != 1000000000) {//标记表示已被处理,1000000000就是没有标记,不要用0、1、-1这样的数做无标记的指标
nd[nd[x].l].v = nd[nd[x].r].v = nd[x].la1;
nd[nd[x].l].la1 = nd[nd[x].r].la1 = nd[x].la1;
nd[nd[x].l].lmax = nd[nd[x].l].rmax = nd[nd[x].l].max =
nd[x].la1 >= 0 ? nd[x].la1 * nd[nd[x].l].num : nd[x].la1;
nd[nd[x].r].lmax = nd[nd[x].r].rmax = nd[nd[x].r].max =
nd[x].la1 >= 0 ? nd[x].la1 * nd[nd[x].r].num : nd[x].la1;
nd[nd[x].l].sum = nd[nd[x].l].num * nd[x].la1, nd[nd[x].r].sum = nd[nd[x].r].num * nd[x].la1;
nd[x].la1 = 1000000000;
}
if (nd[x].la2) {//第二个懒标记下压
swap(nd[nd[x].l].l, nd[nd[x].l].r), swap(nd[nd[x].l].lmax, nd[nd[x].l].rmax);//别忘了交换lmax和rmax
swap(nd[nd[x].r].l, nd[nd[x].r].r), swap(nd[nd[x].r].lmax, nd[nd[x].r].rmax);
nd[nd[x].l].la2 ^= 1, nd[nd[x].r].la2 ^= 1, nd[x].la2 = 0;
}
}

inline void update(int x) {//更新结点信息
nd[0].v = 0;
nd[x].num = nd[nd[x].l].num + nd[nd[x].r].num + 1;
nd[x].sum = nd[nd[x].l].sum + nd[nd[x].r].sum + nd[x].v;
down(nd[x].l), down(nd[x].r);
if (nd[x].l == 0 && nd[x].r == 0)nd[x].max = nd[x].lmax = nd[x].rmax = nd[x].v;//这一步和线段树有区别,最好分类讨论
else if (nd[x].l == 0) {
nd[x].lmax = max(nd[x].v, nd[x].v + nd[nd[x].r].lmax);
nd[x].rmax = max(nd[nd[x].r].rmax, nd[nd[x].r].sum + nd[x].v);
nd[x].max = max(nd[nd[x].r].max, nd[x].v + nd[nd[x].r].lmax), nd[x].max = max(nd[x].max, nd[x].v);
return;
} else if (nd[x].r == 0) {
nd[x].rmax = max(nd[x].v, nd[x].v + nd[nd[x].l].rmax);
nd[x].lmax = max(nd[nd[x].l].lmax, nd[nd[x].l].sum + nd[x].v);
nd[x].max = max(nd[nd[x].l].max, nd[x].v + nd[nd[x].l].rmax), nd[x].max = max(nd[x].max, nd[x].v);
} else {
nd[x].max = max(nd[nd[x].l].rmax + nd[nd[x].r].lmax + nd[x].v, max(nd[nd[x].l].max, nd[nd[x].r].max));
nd[x].max = max(nd[x].max, max(nd[nd[x].l].rmax, nd[nd[x].r].lmax) + nd[x].v);
nd[x].max = max(nd[x].max, nd[x].v);
nd[x].lmax = max(nd[nd[x].l].lmax,
max(nd[nd[x].l].sum + nd[nd[x].r].lmax + nd[x].v, nd[nd[x].l].sum + nd[x].v));
nd[x].rmax = max(nd[nd[x].r].rmax,
max(nd[nd[x].r].sum + nd[nd[x].l].rmax + nd[x].v, nd[nd[x].r].sum + nd[x].v));
}
}


inline int newNode(int v) {//分配新结点
static int cnt = 1;
int sv;
if (!que.empty())sv = que.front(), que.pop();//优先从回收队列中取
else sv = cnt++;
nd[sv].k = rand(), nd[sv].v = v, nd[sv].num = 1, nd[sv].max = nd[sv].rmax = nd[sv].lmax = v;
nd[sv].r = nd[sv].l = 0, nd[sv].sum = v, nd[sv].la1 = 1000000000, nd[sv].la2 = 0;
return sv;
}

void split(int rt, int &a, int &b, int s) {
if (rt == 0) {
a = b = 0;
return;
}
down(rt);//下压标记
if (nd[nd[rt].l].num + 1 <= s)a = rt, split(nd[rt].r, nd[a].r, b, s - nd[nd[rt].l].num - 1);
else b = rt, split(nd[rt].l, a, nd[b].l, s);
update(rt);
}

void merge(int &rt, int a, int b) {
down(a), down(b);
if (a == 0 || b == 0) {
rt = a + b;
return;
}
if (nd[a].k < nd[b].k)rt = b, merge(nd[rt].l, a, nd[b].l);
else rt = a, merge(nd[rt].r, nd[a].r, b);
update(rt);
}

inline void insert() {//插入
int pos, tot, pr = 0, a, b;
scanf("%d%d", &pos, &tot);
for (int i = 0, x; i < tot; i++)scanf("%d", &x), merge(pr, pr, newNode(x));
split(root, a, b, pos), merge(a, a, pr), merge(root, a, b);
}

void trans(int x) {//回收结点
if (x == 0)return;
que.push(x), trans(nd[x].l), trans(nd[x].r);
}

inline void del() {//删除
int pos, a, b, c, tot;
scanf("%d%d", &pos, &tot);
split(root, a, c, pos - 1), split(c, b, c, tot), merge(root, a, c), trans(b);//回收b树的空间
}

inline void ms() {//区间整体修改
int pos, tot, C, a, b, c;
scanf("%d%d%d", &pos, &tot, &C);
split(root, a, c, pos - 1), split(c, b, c, tot);
nd[b].v = C, nd[b].sum = nd[b].num * C, nd[b].max = nd[b].lmax = nd[b].rmax = C >= 0 ? C * nd[b].num : C;
nd[b].la1 = C, merge(a, a, b), merge(root, a, c);
}

inline void rs() {//翻转
int pos, tot, a, b, c;
scanf("%d%d", &pos, &tot);
split(root, a, c, pos - 1), split(c, b, c, tot);
swap(nd[b].l, nd[b].r), swap(nd[b].lmax, nd[b].rmax);
nd[b].la2 ^= 1, merge(a, a, b), merge(root, a, c);
//打完标记再合并
}

inline void gs() {//获得区间和
int pos, tot, a, b, c;
scanf("%d%d", &pos, &tot);
split(root, a, c, pos - 1), split(c, b, c, tot), printf("%d\n", nd[b].sum);
merge(a, a, b), merge(root, a, c);
}

int main() {
scanf("%d%d", &n, &m);
for (int i = 1, x; i <= n; i++)scanf("%d", &x), merge(root, root, newNode(x));
while (m--) {
scanf("%s", opt);
if (opt[0] == 'I')insert();
else if (opt[0] == 'D')del();
else if (opt[0] == 'M' && opt[2] == 'K')ms();
else if (opt[0] == 'R')rs();
else if (opt[0] == 'G')gs();
else printf("%d\n", nd[root].max);
}
return 0;
}

        常数小一些的:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <time.h>
#pragma GCC optimize(3)
using namespace std;
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define INF 0X3f3f3f3f
struct Node
{
int l, r, k, num, lz1, lz2, sum, lmax, rmax, max, v;
} nd[5000005];
int cnt = 1, root, pos, tmp[5000005];
char opt[20];
inline int newNode(int v)
{
nd[cnt].num = 1, nd[cnt].k = rand(), nd[cnt].l = nd[cnt].r = 0;
nd[cnt].v = nd[cnt].lmax = nd[cnt].rmax = nd[cnt].max = nd[cnt].sum = v;
nd[cnt].lz1 = 0, nd[cnt].lz2 = INF;
return cnt++;
}
inline void upd(int x)
{
if (x) {
nd[x].num = 1 + nd[nd[x].l].num + nd[nd[x].r].num;
nd[x].sum = nd[x].v + nd[nd[x].l].sum + nd[nd[x].r].sum;
nd[x].max = MAX(MAX(nd[nd[x].l].max, nd[nd[x].r].max), MAX(nd[x].v + nd[nd[x].l].rmax + nd[nd[x].r].lmax, MAX(nd[x].v, MAX(nd[x].v + nd[nd[x].l].rmax, nd[x].v + nd[nd[x].r].lmax))));
nd[x].lmax = MAX(nd[nd[x].l].lmax, MAX(nd[nd[x].l].sum + nd[x].v + nd[nd[x].r].lmax, nd[nd[x].l].sum + nd[x].v));
nd[x].rmax = MAX(nd[nd[x].r].rmax, MAX(nd[nd[x].r].sum + nd[x].v + nd[nd[x].l].rmax, nd[nd[x].r].sum + nd[x].v));
}
}
inline void solve(int x, int to)
{
nd[x].v = to, nd[x].sum = to * nd[x].num;
if (to >= 0) {
nd[x].lmax = nd[x].rmax = nd[x].max = nd[x].sum;
}
else {
nd[x].lmax = nd[x].rmax = nd[x].max = to;
}
nd[x].lz2 = to;
}
inline void pud(int x)
{
if (x && nd[x].lz1) {
if (nd[x].l) {
swap(nd[nd[x].l].l, nd[nd[x].l].r), swap(nd[nd[x].l].lmax, nd[nd[x].l].rmax);
nd[nd[x].l].lz1 ^= 1;
}
if (nd[x].r) {
swap(nd[nd[x].r].l, nd[nd[x].r].r), swap(nd[nd[x].r].lmax, nd[nd[x].r].rmax);
nd[nd[x].r].lz1 ^= 1;
}
nd[x].lz1 = 0;
}
if (x && nd[x].lz2 != INF) {
solve(nd[x].l, nd[x].lz2), solve(nd[x].r, nd[x].lz2);
nd[x].lz2 = INF;
}
}
void split(int rt, int& a, int& b, int num)
{
pud(rt);
if (rt == 0) {
a = b = 0;
return;
}
if (nd[nd[rt].l].num + 1 <= num)
a = rt, split(nd[rt].r, nd[a].r, b, num - nd[nd[rt].l].num - 1);
else
b = rt, split(nd[rt].l, a, nd[rt].l, num);
upd(rt);
}
void merge(int& rt, int a, int b)
{
pud(a), pud(b);
if (a == 0 || b == 0) {
rt = a + b;
return;
}
if (nd[a].k > nd[b].k)
rt = a, merge(nd[rt].r, nd[a].r, b);
else
rt = b, merge(nd[rt].l, a, nd[b].l);
upd(rt);
}
inline int read()
{
int s = 0, f = 0;
char e = getchar();
while (e < '-') e = getchar();
if (e == '-') f = 1, e = getchar();
while (e > '-') s = (s << 1) + (s << 3) + (e & 15), e = getchar();
return f ? -s : s;
}
void build(int l, int r, int& rt)
{
if (l > r) {
rt = 0;
return;
}
int mid = (l + r) >> 1;
rt = newNode(tmp[mid]), build(l, mid - 1, nd[rt].l), build(mid + 1, r, nd[rt].r);
upd(rt);
}
inline void add(int pos, int tot)
{
int a, b, c = 0;
split(root, a, b, pos);
for (int i = 1; i <= tot; i++) tmp[i] = read();
build(1, tot, c), merge(a, a, c), merge(root, a, b);
}
inline void del(int pos, int tot)
{
int a, b, c;
split(root, a, c, pos + tot), split(a, a, b, pos), merge(root, a, c);
}
inline void ch(int pos, int tot, int to)
{
int a, b, c;
split(root, a, c, pos + tot), split(a, a, b, pos);
solve(b, to);
merge(a, a, b), merge(root, a, c);
}
inline void res(int pos, int tot)
{
int a, b, c;
split(root, a, c, pos + tot), split(a, a, b, pos);
if (b) swap(nd[b].l, nd[b].r), swap(nd[b].lmax, nd[b].rmax), nd[b].lz1 ^= 1;
merge(a, a, b), merge(root, a, c);
}
inline int sum(int pos, int tot)
{
int a, b, c, ans;
split(root, a, c, pos + tot), split(a, a, b, pos);
ans = nd[b].sum;
merge(a, a, b), merge(root, a, c);
return ans;
}
int main()
{
srand(time(0));
int n, m, a, b, c;
nd[0].max = nd[0].lmax = nd[0].rmax = -INF;
n = read(), m = read();
for (int i = 1; i <= n; i++) tmp[i] = read();
build(1, n, root);
while (m--) {
scanf("%s", opt);
if (opt[0] == 'I') {
a = read(), b = read(), add(a, b);
}
else if (opt[0] == 'D') {
a = read(), b = read(), del(a - 1, b);
}
else if (opt[2] == 'K') {
a = read(), b = read(), c = read(), ch(a - 1, b, c);
}
else if (opt[0] == 'R') {
a = read(), b = read(), res(a - 1, b);
}
else if (opt[0] == 'G') {
a = read(), b = read(), printf("%d\n", sum(a - 1, b));
}
else if (opt[0] == 'M') {
printf("%d\n", nd[root].max);
}
}
return 0;
}