【学习笔记】整体二分

哇,一直想学整体二分诶,终于学会了233

十分感谢李煜东的蓝书!感觉从网上找讲解,一直看不明白呢…

整体二分,其实本质上跟线段树上二分、splay上二分很相似。是对值域的整体分治。

首先考虑一个简单的问题,如何求解整个序列的第 $k$ 大。有这么一种二分的方法。最开始取一个 $L=1$,一个 $R=\rm MAX$。如果当前 $L\sim mid$ 的数字个数不足 $k$ 个,那么令 $k$ 减去这些数量,并将 $mid+1$ 作为新的 $L$ 继续二分。可以证明这样做也是 $\log$ 的。

为什么要说这种方法?本质上,求出整个序列的第 $k$ 大二分方法有很多,比如你可以直接二分答案。但实际上,只有上述二分方式具有较强的可分治性,每次可以直接舍弃掉 $[L,mid]$ 的全部内容,每次问题规模都小一半;但显然直接去二分第 $k$ 大不具备这个性质。

所以从这个例子中,或许可以得出一些启发性的分治方式。

$\rm Part1$ 简述整体二分

考虑现在给定一个序列有 $n$ 个数,$m$ 组询问。每次询问一个区间的第 $k$ 小的数。 可以离线

考虑首先离散化,并且把序列中一开始给出的数当作插入操作。然后:

  • 1、每次二分,在把答案分治掉的同时,需要对整个序列的元素也分治。所以需要对整个序列进行重排,很简单地 $double-l$ 扫一遍再合并就可以了。
  • 2、考虑采用上方说的方法去处理插入和询问。对于一个插入操作,如果插入的数值 $<$ 当前二分的值域的 $mid$,那么就直接插入BIT并且放到前一半,否则放到后一半不管;对于一个询问,可以用树状数组求出现在区间 $[ql, qr]$ 内有多少被插入过的数——根据上文讨论得到的解法,考虑如果 $qk$ 比当前值大,说明之前还有比当前值域的 $mid$ 大的数没有被插入,所以把当前的询问减去当前的 $res$ 后放到后一半;否则什么都不处理,放到前一半。
  • 3、考虑这么做的正确性。发现插入操作一定在询问操作之前,所以不需要担心询问扑空;同时发现对于一个单独的询问,实际上就是进行了上文中类似全局二分的操作。
  • 4、每次用完BIT要清零。
  • 5、s发现每次值域规模、元素规模均减半。于是最后复杂度就是 $O(m\log^2 n)$。而如果一开始不离散化,复杂度就会变成 $O(m\log \mathrm{SIZE}\log n)$ 。值得注意的是,不离散化也根本不需要担心空间会挂,因为在分治整个值域的时候,把值域作为分治轴可以使得只用下标就可以计算出贡献。

然后以下是 LG3834【模板】可持久化线段树1 (主席树),裸的区间 $k$ 小。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <cstdio>
#include <iostream>
#include <algorithm>

#define MAXN 400010
#define low(x) (x & (-x))

using namespace std ;

struct query{
int op, x, y, z ;
}q[MAXN], lq[MAXN], rq[MAXN] ;

int N, M, base[MAXN] ;
int ans[MAXN], _bit[MAXN], t[MAXN], cnt ;

void upd(int p, int x){
for ( ; p <= N ; p += low(p)) _bit[p] += x ;
}
int ask(int p){
int ret = 0 ;
for ( ; p ; p -= low(p)) ret += _bit[p] ;
return ret ;
}
void solve(int vl, int vr, int l, int r){
if (l > r) return ;
if (vl == vr){
for (int i = l ; i <= r ; ++ i)
if (q[i].op > 0) ans[q[i].op] = vl ;
return void() ;
}
int vmid = (vl + vr) >> 1, bl = 0, br = 0 ;
for (int i = l ; i <= r ; ++ i){
if (!q[i].op){
if (q[i].y <= vmid)
upd(q[i].x, 1), lq[++ bl] = q[i] ;
else rq[++ br] = q[i] ;
}
else {
int res = ask(q[i].y) - ask(q[i].x - 1) ;
if (res >= q[i].z) lq[++ bl] = q[i] ;
else q[i].z -= res, rq[++ br] = q[i] ;
}
}
// cout << l << " " << r << " " << bl << " " << br << endl ;
for (int i = l ; i <= r ; ++ i)
if (!q[i].op && q[i].y <= vmid) upd(q[i].x, -1) ;
for (int i = 1 ; i <= bl ; ++ i) q[l + i - 1] = lq[i] ;
for (int i = 1 ; i <= br ; ++ i) q[l + bl + i - 1] = rq[i] ;
solve(vl, vmid, l, l + bl - 1), solve(vmid + 1, vr, l + bl, r) ;
}

int main(){
cin >> N >> M ;
int n, val, l, r, k ;
for (int i = 1 ; i <= N ; ++ i)
scanf("%d", &base[i]), t[i] = base[i] ;
sort(t + 1, t + N + 1) ;
n = unique(t + 1, t + N + 1) - t - 1 ;
// for (int i = 1 ; i <= n ; ++ i) cout << t[i] << endl ;
for (int i = 1 ; i <= N ; ++ i){
val = lower_bound(t + 1, t + n + 1, base[i]) - t ;
++ cnt, q[cnt].op = 0, q[cnt].x = i, q[cnt].y = val ;
// cout << q[cnt].y << endl ;
}
for (int i = 1 ; i <= M ; ++ i){
scanf("%d%d%d", &l, &r, &k), ++ cnt ;
q[cnt].op = i, q[cnt].x = l, q[cnt].y = r, q[cnt].z = k ;
}
solve(0, N * 2, 1, cnt) ;
// for (int i = 1 ; i <= N ; ++ i) cout << ans[i] << endl ;
for (int i = 1 ; i <= M ; ++ i) printf("%d\n", t[ ans[i] ]) ;
return 0 ;
}

$\rm Part2$ 如何处理简单带修

发现其实每个带修操作都是单点覆盖的话,可以拆成一个插入一个删除,而删除操作是BIT可维护的。

以下是 bzoj#1901 Dynamic Rankings ,裸的区间 $k$ 小+单点替换:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#define MAXN 500010
#define Inf 1000000009
#define low(x) (x & (-x))

using namespace std ;

char c[5] ;
struct query{
int op, x, y, z ;
}q[MAXN], lq[MAXN], rq[MAXN] ;
int N, M, base[MAXN], _bit[MAXN], ans[MAXN], cnt ;

void upd(int p, int x){
for ( ; p <= N ; p += low(p)) _bit[p] += x ;
}
int ask(int p){
int ret = 0 ;
for ( ; p ; p -= low(p)) ret += _bit[p] ;
return ret ;
}
void solve(int vl, int vr, int l, int r){
if (l > r) return ;
if (vl == vr){
for (int i = l ; i <= r ; ++ i)
if (q[i].op > 0) ans[q[i].op] = vr ;
return void() ;
}
int vmid = (vl + vr) >> 1, bl = 0, br = 0 ;
for (int i = l ; i <= r ; ++ i){
if (q[i].op == -1){
if (q[i].y <= vmid)
upd(q[i].x, 1), lq[++ bl] = q[i] ;
else rq[++ br] = q[i] ;
}
else if (q[i].op == -2){
if (q[i].y <= vmid)
upd(q[i].x, -1), lq[++ bl] = q[i] ;
else rq[++ br] = q[i] ;
}
else {
int res ;
res = ask(q[i].y) - ask(q[i].x - 1) ;
if (res >= q[i].z) lq[++ bl] = q[i] ;
else q[i].z -= res, rq[++ br] = q[i] ;
}
}
for (int i = l ; i <= r ; ++ i)
if (q[i].op == -1 && q[i].y <= vmid) upd(q[i].x, -1) ;
else if (q[i].op == -2 && q[i].y <= vmid) upd(q[i].x, 1) ;
for (int i = 1 ; i <= bl ; ++ i) q[l + i - 1] = lq[i] ;
for (int i = 1 ; i <= br ; ++ i) q[l + bl + i - 1] = rq[i] ;
solve(vl, vmid, l, l + bl - 1), solve(vmid + 1, vr, l + bl, r) ;
}
int main(){
int l, r, k ;
ios::sync_with_stdio(0) ;
cin.tie(0), cout.tie(0), cin >> N >> M ;
for (int i = 1 ; i <= N ; ++ i) cin >> base[i] ;
for (int i = 1 ; i <= N ; ++ i)
q[++ cnt].y = base[i],
q[cnt].x = i, q[cnt].op = -1 ;
memset(ans, -1, sizeof(ans)) ;
for (int i = 1 ; i <= M ; ++ i){
cin >> (c + 1) ;
if (c[1] == 'Q')
cin >> l >> r >> k, q[++ cnt].op = i,
q[cnt].x = l, q[cnt].y = r, q[cnt].z = k ;
else {
cin >> l >> k ;
q[++ cnt].y = base[l],
q[cnt].x = l, q[cnt].op = -2 ;
q[++ cnt].y = (base[l] = k),
q[cnt].x = l, q[cnt].op = -1 ;
}
}
solve(-Inf, Inf, 1, cnt) ;
for (int i = 1 ; i <= M ; ++ i)
if (ans[i] >= 0) cout << ans[i] << endl ;
return 0 ;
}

$\rm Part3$ 一道拓展

你需要维护 $n$ 个可重整数集,集合的编号从 $1$ 到 $n$。
这些集合初始都是空集,有 $m$ 个操作:

1 l r c:表示将 $c$ 加入到编号在 $[l,r]$ 内的集合中
2 l r c: 表示查询编号在 $[l,r]$ 内的集合的并集中,第 $c$ 大的数是多少。

发现这东西本质上就是把单点插入变成了区间插入,所以拿线段树维护一下区间和即可。并且由于查询 $K$ 大,原来的二分方式需要对称一下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <cstdio>
#include <iostream>

#define il inline
#define MAXN 100010
#define LL long long
#define Inf 1000000009

using namespace std ;

struct query{
int op, x, y ; LL z ;
}q[MAXN], lq[MAXN], rq[MAXN] ;
int N, M, base[MAXN], ans[MAXN], cnt ;

LL s[MAXN << 2], cov[MAXN << 2], tag[MAXN << 2] ;

il void _up(int rt){ s[rt] = s[rt << 1] + s[rt << 1 | 1] ; }
il void _down(int rt, int l, int r){
if (cov[rt]){
tag[rt << 1] = s[rt << 1] = 0 ;
tag[rt << 1 | 1] = s[rt << 1 | 1] = 0 ;
cov[rt << 1] = cov[rt << 1 | 1] = 1, cov[rt] = 0 ;
}
if (tag[rt]){
int mid = (l + r) >> 1 ;
tag[rt << 1] += tag[rt] ;
tag[rt << 1 | 1] += tag[rt] ;
s[rt << 1 | 1] += 1ll * (r - mid) * tag[rt] ;
s[rt << 1] += 1ll * (mid - l + 1) * tag[rt] ;
tag[rt] = 0 ;
}
}
void update(int rt, int l, int r, int ul, int ur, int v){
if (l >= ul && r <= ur)
return tag[rt] += v,
s[rt] += v * (r - l + 1), void() ;
int mid = (l + r) >> 1 ; _down(rt, l, r) ;
if (ul <= mid) update(rt << 1, l, mid, ul, ur, v) ;
if (ur > mid) update(rt << 1 | 1, mid + 1, r, ul, ur, v) ;
_up(rt) ;
}
LL query(int rt, int l, int r, int ql, int qr){
if (l >= ql && r <= qr) return s[rt] ;
int mid = (l + r) >> 1 ; LL res = 0 ; _down(rt, l, r) ;
if (ql <= mid) res += query(rt << 1, l, mid, ql, qr) ;
if (qr > mid) res += query(rt << 1 | 1, mid + 1, r, ql, qr) ;
return res ;
}
void solve(int vl, int vr, int l, int r){
if (l > r) return ;
if (vl == vr){
for (int i = l ; i <= r ; ++ i)
if (q[i].op > 0) ans[q[i].op] = vr ;
return void() ;
}
cov[1] = 1, tag[1] = s[1] = 0 ;
int vmid = (vl + vr) >> 1, bl = 0, br = 0 ;
for (int i = l ; i <= r ; ++ i){
if (!q[i].op){
if (q[i].z > vmid)
update(1, 1, N, q[i].x, q[i].y, 1), rq[++ br] = q[i] ;
else lq[++ bl] = q[i] ;
}
else {
LL res ;
res = query(1, 1, N, q[i].x, q[i].y) ;
if (res >= q[i].z) rq[++ br] = q[i] ;
else q[i].z -= res, lq[++ bl] = q[i] ;
}
}
for (int i = 1 ; i <= bl ; ++ i) q[l + i - 1] = lq[i] ;
for (int i = 1 ; i <= br ; ++ i) q[l + bl + i - 1] = rq[i] ;
solve(vl, vmid, l, l + bl - 1), solve(vmid + 1, vr, l + bl, r) ;
}
int main(){
LL op, l, r, k ;
ios::sync_with_stdio(0) ;
cin.tie(0), cout.tie(0), cin >> N >> M ;
for (int i = 1 ; i <= M ; ++ i){
cin >> q[i].op >> q[i].x >> q[i].y >> q[i].z ;
if (q[i].op == 1) q[i].op = 0 ; else q[i].op = ++ cnt ;
}
solve(-N, N, 1, M) ;
for (int i = 1 ; i <= cnt ; ++ i) cout << ans[i] << '\n' ;
return 0 ;
}