P5217 贫穷 fhq treap

(8 mins to read)

给定一个字符串序列要求支持以下操作

  • 在x位置后插入一个字母
  • 删除x位置的字母
  • 翻转区间[x,y]
  • 输出初始字符串序列第x个字母当前位置
  • 输出当前字符串序列第x个字母
  • 输出区间[x,y]字母种类数

前三个很简单。第四个要查节点号为x的字母的位置,由于是维护区间,按size分裂,所以自上而下找不了,只能额外维护父亲,然后自下而上找。由于第三个区间翻转操作,导致左右儿子可能会改变,所以要先到根节点把所有懒标记pushdown下来,再自下而上找才能正确第五个就是查第k大(区间意义下),直接split和merge比较好,如果是从根递归着走,不要忘记pushdown第六个状压维护一下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 5;
char s[N];
int n, m;
int rt, tot;
int rnd() { return rand()<<15|rand(); }
struct node
{
int ls, rs;
int mask, sz, fa, key, c;
bool rev;
}t[N];
int newnode(int c)
{
int p = ++tot;
t[p].ls = t[p].rs = t[p].fa = 0;
t[p].mask = 1<<c, t[p].c = c, t[p].key = rnd();
t[p].sz = 1;
t[p].rev = 0;
return p;
}
void up(int p)
{
t[p].sz = t[t[p].ls].sz + t[t[p].rs].sz + 1;
t[p].mask = t[t[p].ls].mask | t[t[p].rs].mask | (1<<t[p].c);
}
void down(int p)
{
if(t[p].rev)
{
swap(t[p].ls, t[p].rs);
t[t[p].ls].rev ^= 1;
t[t[p].rs].rev ^= 1;
t[p].rev = 0;
}
}
void split(int p, int &x, int &y, int k, int fx=0, int fy=0)
{
if(!p) { x = y = 0; return; }
down(p);
if(k<=t[t[p].ls].sz)
{
y = p;
split(t[p].ls, x, t[p].ls, k, fx, p);
}
else
{
x = p;
split(t[p].rs, t[p].rs, y, k-t[t[p].ls].sz-1, p, fy);
}
if(x) t[x].fa = fx;
if(y) t[y].fa = fy;
up(p);
}
void merge(int &p, int x, int y)
{
if(!x||!y) { p = x|y; return; }
if(t[x].key<t[y].key)
{
down(x);
p = x;
merge(t[p].rs, t[p].rs, y);
t[t[p].rs].fa = p;
}
else
{
down(y);
p = y;
merge(t[p].ls, x, t[p].ls);
t[t[p].ls].fa = p;
}
up(p);
}
void downall(int p)
{
if(!p) return;
downall(t[p].fa);
down(p);
}
int get(int p)
{
downall(p);
int rk = t[t[p].ls].sz + 1;
while(t[p].fa)
{
if(t[t[p].fa].rs==p) rk += t[t[p].fa].sz - t[p].sz;
p = t[p].fa;
}
return rk*(p==rt);
}
int kth(int p, int k)
{
down(p);
if(k<=t[t[p].ls].sz) return kth(t[p].ls, k);
else if(k==t[t[p].ls].sz+1) return t[p].c;
else return kth(t[p].rs, k-t[t[p].ls].sz-1);
}
int main()
{
scanf("%d%d", &n, &m);
scanf("%s", s+1);
for(int i=1; i<=n; i++)
{
int x = newnode(s[i]-'a');
merge(rt, rt, x);
}
while(m--)
{
static char opt[2], c[2];
scanf("%s", opt);
if(opt[0]=='I')
{
int k;
scanf("%d%s", &k, c);
int x, y, z;
split(rt, x, y, k);
z = newnode(c[0]-'a');
merge(x, x, z);
merge(rt, x, y);
}
else if(opt[0]=='D')
{
int k;
scanf("%d", &k);
int x, y, z;
split(rt, x, y, k);
split(x, x, z, k-1);
merge(rt, x, y);
}
else if(opt[0]=='R')
{
int l, r; scanf("%d%d", &l, &r);
int x, y, z;
split(rt, x, y, r);
split(x, z, x, l-1);
t[x].rev ^= 1;
merge(x, z, x);
merge(rt, x, y);
}
else if(opt[0]=='P')
{
int k;
scanf("%d", &k);
printf("%d\n", get(k));
}
else if(opt[0]=='T')
{
int k;
scanf("%d", &k);
printf("%c\n", char(kth(rt, k)+'a'));
}
else
{
int l, r; scanf("%d%d", &l, &r);
int x, y, z;
split(rt, x, y, r);
split(x, z, x, l-1);
printf("%d\n", __builtin_popcount(t[x].mask));
merge(x, z, x);
merge(rt, x, y);
}
}
return 0;
}