【学习笔记&题解】线段树优化建图

大概就是把边数从 $n^2$ 优化到 $n\log n$ 的一个操作?还是挺简单的吧。

$0$ 瞎扯

考虑三个场景:从 $[l_1,r_1]$ 连向 $x$, 从 $x$ 连向 $[l_2,r_2]$,从 $[l_1,r_1]$ 连向 $[l_2,r_2]$ 。

发现可以用两棵线段树来模拟这个东西,一棵线段树从 $rt$ 向 $lc,rc$ 连边,一棵从 $lc,rc$ 向 $rt$ 连边。只要把叶子节点设置为 $1\sim n$ ,这样最终点数就是 $O(n\log n)$ 的。稳得很。

然后就没有然后了。

$1$ CF786B Legacy

给定一张 $n$ 个点的图,有 $m$ 个下列操作:

  • 1.进行单点与单点连有向边

  • 2.进行单点与区间连有向边

  • 3.进行区间与单点连有向边。

求最短路。

发现就是个板子?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
typedef long long LL ;

const int N = 200010 ;
const int M = 400010 ;
const int NN = 2000010 ;
const LL Inf = (1ll << 60) ;

#define to(k) E[k].to
#define val(k) E[k].val
#define next(k) E[k].next

struct Edge{
int val ;
int to, next ;
}E[NN << 1] ;
LL dis[NN] ;
int n, m, s ;
bool vis[NN] ;
queue <int> q ;
int Id, head[NN], cnt ;
int rt1, rt2, lc[NN], rc[NN] ;

void add(int u, int v, int val){
to(++ cnt) = v, val(cnt) = val,
next(cnt) = head[u], head[u] = cnt ;
}
void build(int &rt, int l, int r, int p){
if (l == r)
return rt = l, void() ;
rt = ++ Id ;
int mid = (l + r) >> 1 ;
build(lc[rt], l, mid, p) ;
build(rc[rt], mid + 1, r, p) ;
if (!p) add(rt, lc[rt], 0), add(rt, rc[rt], 0) ;
else add(lc[rt], rt, 0), add(rc[rt], rt, 0) ;
}
void update(int rt, int l, int r, int u, int ul, int ur, int w, int p){
if (l >= ul && r <= ur){
p ? add(rt, u, w) : add(u, rt, w) ;
return ;
}
int mid = (l + r) >> 1 ;
if (ul <= mid) update(lc[rt], l, mid, u, ul, ur, w, p) ;
if (ur > mid) update(rc[rt], mid + 1, r, u, ul, ur, w, p) ;
}
int main(){
cin >> n ;
Id = n, cin >> m >> s ;
build(rt1, 1, n, 0) ;
build(rt2, 1, n, 1) ;
int mk, u, v, w, l, r ;
for (int i = 1 ; i <= m ; ++ i){
scanf("%d", &mk) ;
if (mk == 1)
scanf("%d%d%d", &u, &v, &w), add(u, v, w) ;
else if (mk == 2)
scanf("%d%d%d%d", &u, &l, &r, &w),
update(rt1, 1, n, u, l, r, w, mk - 2) ;
else
scanf("%d%d%d%d", &u, &l, &r, &w),
update(rt2, 1, n, u, l, r, w, mk - 2) ;
}
// for (int i = 1 ; i <= cnt ; ++ i) cout << E[i].to << " " << E[i].val << endl ;
// cout << cnt << endl ;
fill(dis, dis + Id + 1, Inf) ;
q.push(s), dis[s] = 0, vis[s] = 1 ;
while (!q.empty()){
int n = q.front() ;
vis[n] = 0, q.pop() ;
// cout << n << endl ;
for (int k = head[n] ; k ; k = next(k)){
if (dis[to(k)] > dis[n] + val(k)){
dis[to(k)] = dis[n] + val(k) ;
if (!vis[to(k)])
vis[to(k)] = 1, q.push(to(k)) ;
}
}
}
for (int i = 1 ; i <= n ; ++ i)
printf("%lld ", dis[i] == Inf ? -1 : dis[i]) ;
return 0 ;
}

$2$ LG3588 [POI2015]PUS

给定一个长度为 $n$ 的正整数序列 $a$ ,每个数都在 $1$ 到 $10^9$ 范围内,告诉你其中 $s$ 个数.

给出 $m$ 条信息,每条信息包含三个数 $l,r,k$ 以及接下来 $k$ 个正整数,表示 $a_l..a_{l+1}…a_{r-1}..a_r$ 里这 $k$ 个数中的任意一个都比任意一个剩下的 $r-l+1-k$ 个数大 (严格大于,即没有等号)

构造一组合法解或者输出 $-1$

似乎就是一个差分约束?发现每次需要从每个 $k_{i,j}$ 连向剩下的子区间。如果用线段树优化建图,这样就是 $\sum k_i^2\log k_i$ 的边数,炸的很惨。

这个地方有个很妙的 $\rm Idea$ ,即对每个操作建立一个新点 $p_i$,让每个 $k_{i,j}$ 连向 $p_i$,再让 $p_i$ 连向每个分割开的子区间。这样就成功地只需要 $\sum O(k_i\log k_i)$ 的空间,很稳。

话说这还是自己第一次真正去写差分约束

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
void add(int u, int v, int val){
to(++ cnt) = v ;
val(cnt) = val, deg[v] ++ ;
next(cnt) = head[u], head[u] = cnt ;
}
void build(int &rt, int l, int r){
if (l == r)
return rt = l, void() ;
rt = ++ Id ;
int mid = (l + r) >> 1 ;
build(lc[rt], l, mid) ;
build(rc[rt], mid + 1, r) ;
add(lc[rt], rt, 0), add(rc[rt], rt, 0) ;
}
void update(int rt, int l, int r, int u, int ul, int ur, int w){
int mid = (l + r) >> 1 ;
if (l >= ul && r <= ur) return add(rt, u, w) ;
if (ul <= mid) update(lc[rt], l, mid, u, ul, ur, w) ;
if (ur > mid) update(rc[rt], mid + 1, r, u, ul, ur, w) ;
}
int main(){
int u, v, l, r, k ;
cin >> n >> s >> m ;
Id = n ; build(root, 1, n) ;
for (int i = 1 ; i <= s ; ++ i)
scanf("%d%d", &u, &v), base[u] = dis[u] = v ;
for (int i = 1 ; i <= m ; ++ i){
scanf("%d%d%d", &l, &r, &k) ;
Id ++, last = l - 1 ;
for (int j = 1 ; j <= k ; ++ j){
scanf("%d", &v) ; add(Id, v, 1) ;
if (last + 1 < v)
update(root, 1, n, Id, last + 1, v - 1, 0) ;
last = v ;
}
if (last < r)
update(root, 1, n, Id, last + 1, r, 0) ;
}
for (int i = 1 ; i <= Id ; ++ i){
if (!dis[i]) dis[i] = 1 ;
if (!deg[i]) q.push(i) ;
}
while (!q.empty()){
u = q.front() ; q.pop() ; vis[u] = 1 ;
for (int k = head[u] ; k ; k = next(k)){
dis[to(k)] = max(dis[to(k)], dis[u] + val(k)) ;
if (base[to(k)] && dis[to(k)] > base[to(k)])
return puts("NIE"), 0 ;
if (! -- deg[to(k)]) q.push(to(k)) ;
}
}
for (int i = 1 ; i <= Id ; ++ i)
if (!vis[i] || dis[i] > Inf) return puts("NIE"), 0 ;
puts("TAK") ;
for (int i = 1 ; i <= n ; ++ i) printf("%d ", dis[i]) ;
}

$3$ [SNOI2017]炸弹

在一条直线上有 $n$ 个炸弹,每个炸弹的坐标是 $x_i$,爆炸半径是 $r_i$。

当一个炸弹爆炸时,如果另一个炸弹所在位置 $x_j$ 满足: $|x_j-x_i| \le r_i$ ,那么,该炸弹也会被引爆。计算先把第 $i$ 个炸弹引爆,将引爆多少个炸弹。

答案对 $10^9+7$ 取模。保证 $x_i$ 随着 $i$ 单增。

发现可以先二分出每个点要连的左右端点。然后考虑缩点。由于是 $\rm DAG$ 就可以直接求出每个点能到达的点的编号的 $\min,\max$ ,这东西可以一遍 $\rm topsort$ 求出来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include <queue>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std ;

typedef long long LL ;

const int N = 4000010 ;
const int M = 8000010 ;
const int P = 1000000007 ;

#define fr(k) E[k].fr
#define to(k) E[k].to
#define val(k) E[k].val
#define next(k) E[k].next

struct Edge{
int next ;
int fr, to ;
}E[M] ;
LL ans ;
int tot ;
int deg[N] ;
bool vis[N] ;
int s[N], tp ;
queue <int> q ;
LL x[N], r[N] ;
int L[N], R[N] ;
int SCC, blg[N] ;
int n, Id, root ;
int head[N], cnt ;
int val[N * 3][2] ;
int dfn[N], low[N] ;
int lc[N * 3], rc[N * 3] ;

void add(int u, int v){
to(++ cnt) = v, fr(cnt) = u,
next(cnt) = head[u], head[u] = cnt ;
// cout << u << " " << v << endl ;
}
void build(int &rt, int l, int r){
if (l == r){
rt = l ;
val[l][0] = l ;
val[l][1] = l ;
return void() ;
}
rt = ++ Id ;
int mid = (l + r) >> 1 ;
build(lc[rt], l, mid) ;
build(rc[rt], mid + 1, r) ;
add(rt, lc[rt]) ; add(rt, rc[rt]) ;
val[rt][0] = val[lc[rt]][0] ;
val[rt][1] = val[rc[rt]][1] ;
}
void update(int rt, int l, int r, int ul ,int ur, int f){
int mid = (l + r) >> 1 ;
if (l >= ul && r <= ur)
return f == rt ? void() : add(f, rt) ;
if (ul <= mid) update(lc[rt], l, mid, ul, ur, f) ;
if (ur > mid) update(rc[rt], mid + 1, r, ul, ur, f) ;
}
void tarjan(int x){
s[++ tp] = x, vis[x] = 1 ;
dfn[x] = low[x] = ++ tot ;
for (int k = head[x] ; k ; k = next(k)){
if (!dfn[to(k)])
tarjan(to(k)),
low[x] = min(low[x], low[to(k)]) ;
else if (vis[to(k)])
low[x] = min(low[x], dfn[to(k)]) ;
}
if (low[x] == dfn[x]){
int now ;
L[++ SCC] = n + 1 ;
// cout << x << endl ;
while (tp){
now = s[tp --] ;
L[SCC] = min(L[SCC], val[now][0]) ;
R[SCC] = max(R[SCC], val[now][1]) ;
vis[now] = 0, blg[now] = SCC ;
if (now == x) break ;
}
}
}
int main(){
cin >> n ;
int lq, rq ; LL xx, oo ;
Id = n, build(root, 1, n) ;
//cout << cnt << endl ;
// for (int i = 1 ; i <= cnt ; ++ i)
// cout << fr(i) << " " << to(i) << endl ;
for (int i = 1 ; i <= n ; ++ i)
scanf("%lld%lld", &x[i], &r[i]) ;
for (int i = 1 ; i <= n ; ++ i){
xx = x[i], oo = r[i] ;
lq = lower_bound(x + 1, x + n + 1, xx - oo) - x ;
rq = upper_bound(x + 1, x + n + 1, xx + oo) - x ;
rq -- ; if (lq <= rq) update(root, 1, n, lq, rq, i) ;
// cout << lq << " " << rq << endl ;
}
// puts("") ;
// cout << root << endl ;
tarjan(root) ;
// cout << cnt << " " << SCC << endl ;
tot = cnt, cnt = 0 ;
// for (int i = 1 ; i <= n ; ++ i) cout << blg[i] << " " ;
// puts("") ;
memset(head, 0, sizeof(head)) ;
for (int i = 1 ; i <= tot ; ++ i)
if (blg[fr(i)] != blg[to(i)])
add(blg[to(i)], blg[fr(i)]), deg[blg[fr(i)]] ++ ;
for (int i = 1 ; i <= SCC ; ++ i) if (!deg[i]) q.push(i) ;
while (!q.empty()){
oo = q.front() ; q.pop() ;
for (int k = head[oo] ; k ; k = next(k)){
L[to(k)] = min(L[to(k)], L[oo]) ;
R[to(k)] = max(R[to(k)], R[oo]) ;
if (!-- deg[to(k)]) q.push(to(k)) ;
}
}
for (int i = 1 ; i <= n ; ++ i){
(ans += 1ll * i * (R[blg[i]] - L[blg[i]] + 1) % P) %= P ;
//cout << L[blg[i]] << " " << R[blg[i]] << endl ;
}
cout << ans << endl ; return 0 ;
}