平衡树之替罪羊树

(6 mins to read)

很简单的思路,在每次插入和删除的时候,检查一下左右子树是不是失衡,如果失衡就将整棵子树进行拍扁(中序遍历存到数组),重构(每次以中心点为当前的根,然后递归左右区间)。失衡的判定:当左子树或者右子树的大小超过了整个子树大小的某个比例alpha(一般定义在0.7~0.8)

1
bool bad(int p) { return t[p].cnt && t[p].sz*alpha<max(t[t[p].ls].sz, t[t[p].rs].sz); }

几个细节:

  • 删除的实现采用最简单的惰性删除,找到对应的节点后让cnt–即可,即使等于0也没有关系,因为在拍扁重构的时候可以把等于0的节点直接丢弃掉
  • 由于插入删除的时候可能需要重构,因此一定要记得传引用
  • 这种平衡树最大的特点就是好写加常数小(均摊mlogn)

记住拍扁重构的思想应该就很容易手写了吧

前驱、后继、第k大、排名其实没必要每个都实现一个函数。只需要实现求第k大,然后求小于x的数的个数以及小于等于x的数的个数即可,只与后两者只有细微的差异,完全可以合并到一个函数中洛谷模板p3369开O2 170ms:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5;
const double alpha = 0.7;
struct node
{
int ls, rs;
int val, sz, cnt;
}t[N];
int rt, tot;
void up(int p) { t[p].sz = t[t[p].ls].sz + t[t[p].rs].sz + t[p].cnt; }
bool bad(int p) { return t[p].cnt && t[p].sz*alpha<max(t[t[p].ls].sz, t[t[p].rs].sz); }
int ord[N], top;
void dfs(int p)
{
if(!p) return;
dfs(t[p].ls);
if(t[p].cnt) ord[++top] = p;
dfs(t[p].rs);
}
int build(int l, int r)
{
if(l>r) return 0;
int mid = (l+r)>>1, p = ord[mid];
t[p].ls = build(l, mid-1);
t[p].rs = build(mid+1, r);
up(p);
return p;
}
void rebuild(int &p)
{
top = 0;
dfs(p);
p = build(1, top);
}
void newnode(int &p, int v)
{
p = ++tot;
t[p].val = v, t[p].ls = t[p].rs = 0;
t[p].cnt = t[p].sz = 1;
}
void ins(int &p, int v)
{
if(!p) newnode(p, v);
else
{
if(v==t[p].val) t[p].cnt++;
else if(v<t[p].val) ins(t[p].ls, v);
else ins(t[p].rs, v);
up(p);
if(bad(p)) rebuild(p);
}
}
void del(int &p, int v)
{
if(v==t[p].val) t[p].cnt--;
else if(v<t[p].val) del(t[p].ls, v);
else del(t[p].rs, v);
up(p);
if(bad(p)) rebuild(p);
}
int Less(int p, int v, bool eq)
{
if(!p) return 0;
if(t[p].cnt&&v==t[p].val) return t[t[p].ls].sz + t[p].cnt*eq;
else if(v<t[p].val) return Less(t[p].ls, v, eq);
else return t[t[p].ls].sz + t[p].cnt + Less(t[p].rs, v, eq);
}
int kth(int p, int k)
{
if(k<=t[t[p].ls].sz) return kth(t[p].ls, k);
else if(t[t[p].ls].sz+t[p].cnt>=k) return t[p].val;
else return kth(t[p].rs, k-t[p].cnt-t[t[p].ls].sz);
}
int main()
{
int n; scanf("%d", &n);
while(n--)
{
int t, x;
scanf("%d%d", &t, &x);
if(t==1) ins(rt, x);
else if(t==2) del(rt, x);
else if(t==3) printf("%d\n", Less(rt, x, 0)+1);
else if(t==4) printf("%d\n", kth(rt, x));
else if(t==5) printf("%d\n", kth(rt, Less(rt, x, 0)));
else printf("%d\n", kth(rt, Less(rt, x, 1)+1));
}
return 0;
}