【学习笔记】树上莫队

嗯…好…序列数据结构就是用来上树的嘛/kk

发现首先树上分块很好处理,比较麻烦的一点在于,如何从 $(ou,ov)$ 这条路径移动到 $(nu,nv)$。

发现莫队的精髓在于排完序之后左右端点的移动是独立的,所以考虑如何转化成 $(ou,nu)$ 和 $(ov,nv)$。

$1$ 瞎构造

蜜汁more.jpg

有一个很常见的定理,就是「树上两条路径如果有交点,那么一定是在其中一条路径两端点的 $lca$ 」。

证明暂时不知道(但是我有在思考)什么比较优美的证明。但遇到这种情况总是可以分类讨论的…略了略了

于是为了不处理本来就有交点的情况,莫队上树选择维护 $p(u,v)=(u,v)\setminus\{lca(u,v)\}$ 。发现这东西有很好的性质,即

其中 $~\mathrm{xor}~$ 定义在集合上。

证明大概是考虑令 $r(x)$ 表示 $x$ 到根节点的路径的点集。那么就可以这么化式子:

然后接下来懒得写了,发现就是个交换律+结合律的套路。

于是考虑每次把 $p(ou,nu)$ 和 $p(ov, nv)$ 内的点全部取反,单独处理一下 $lca$ 就好了。

$2$ 如何分块

发现就是 loj#2151王室联邦 这题的分块方式。似乎这东西证明的话可以直接归纳出来?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
void dfs(int u, int f) {
int base = tp;
for (int k = head[u]; k; k = next(k)) {
if (to(k) == f)
continue;
dfs(to(k), u);
if (tp - base >= S) {
cap[++Id] = u;
while (tp > base)
blg[s[tp--]] = Id;
}
}
s[++tp] = u;
}
int main() {
cin >> N >> S;
int i, u, v;
for (i = 1; i < N; ++i) scanf("%d%d", &u, &v), add(u, v);
dfs(1, 0);
while (tp) blg[s[tp--]] = Id;
cout << Id << endl;
for (i = 1; i <= N; ++i) cout << blg[i] << " ";
puts("");
for (i = 1; i <= Id; ++i) cout << cap[i] << " ";
return 0;
}

$3$ 例题

$1$ SP10707 Count On a Tree 2

给定一个 $n$ 个节点的树,每个节点表示一个整数。

$q$ 组询问,询问 $u$ 到 $v$ 的路径上有多少个不同的整数。

套路题?不禁思考我都做了些什么浪费时间的题啊…

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
const int N = 100010 ;
const int M = 200010 ;

struct Edge{
int to ;
int next ;
}E[M] ;
struct qss{
int id ;
int u, v ;
}qs[M] ;
int res ;
int ans[N] ;
int buc[N] ;
int son[N] ;
int blg[N] ;
bool vis[N] ;
int b, bnum ;
int n, q, len ;
int stk[N], tp ;
int head[N], cnt ;
int top[N], sz[N] ;
int dep[N], fa[N] ;
int base[N], t[N] ;

void add(int u, int v){
to(++ cnt) = v, next(cnt) = head[u], head[u] = cnt ;
to(++ cnt) = u, next(cnt) = head[v], head[v] = cnt ;
}
void dfs(int u, int f){
//cout << u << endl ;
int old = tp ; sz[u] = 1 ;
fa[u] = f, dep[u] = dep[f] + 1 ;
for (int k = head[u] ; k ; k = next(k)){
if (to(k) != f){
dfs(to(k), u), sz[u] += sz[to(k)] ;
if (!son[u] || sz[son[u]] < sz[to(k)]) son[u] = to(k) ;
if (tp - old >= b){
++ bnum ;
while (tp > old)
blg[stk[tp --]] = bnum ;
}
}

}
stk[++ tp] = u ;
}
void dfs(int u, int f, int tp){
top[u] = tp ;
if (son[u]) dfs(son[u], u, tp) ;
for (int k = head[u] ; k ; k = next(k))
if (to(k) != f && to(k) != son[u]) dfs(to(k), u, to(k)) ;
}
bool comp(qss a, qss b){
return blg[a.u] == blg[b.u] ? blg[a.v] < blg[b.v] : blg[a.u] < blg[b.u] ;
}
int lca(int u, int v){
while (top[u] != top[v]){
if (dep[top[u]] < dep[top[v]])
swap(u, v) ; u = fa[top[u]] ;
}
return dep[u] < dep[v] ? u : v ;
}
void rev(int x){
if (!vis[x])
res += !buc[base[x]] ++ ;
else res -= !-- buc[base[x]] ;
vis[x] ^= 1 ;
}
void movemove(int u, int v){
if (dep[u] < dep[v]) swap(u, v) ;
while (dep[u] > dep[v]) rev(u), u = fa[u] ;
while (u != v) rev(u), rev(v), u = fa[u], v = fa[v] ;
}
int main(){
cin >> n >> q ;
int u, v, f ; b = sqrt(n) ;
for (int i = 1 ; i <= n ; ++ i)
scanf("%d", &base[i]), t[i] = base[i] ;
sort(t + 1, t + n + 1) ;
len = unique(t + 1, t + n + 1) - t - 1 ;
for (int i = 1 ; i <= n ; ++ i)
base[i] = lower_bound(t + 1, t + n + 1, base[i]) - t ;
for (int i = 1 ; i < n ; ++ i)
scanf("%d%d", &u, &v), add(u, v) ;
dfs(1, 0) ; dfs(1, 0, 0) ; bnum ++ ;
while (tp) blg[stk[tp --]] = bnum ;
//for (int i = 1 ; i <= n ; ++ i) cout << top[i] << endl ;
for (int i = 1 ; i <= q ; ++ i)
scanf("%d%d", &qs[i].u, &qs[i].v), qs[i].id = i ;
//cout << lca(5, 8) << " " << lca(2, 7) << " " << lca(3, 7) << endl ;
sort(qs + 1, qs + q + 1, comp) ; u = v = 1 ;
for (int i = 1 ; i <= q ; ++ i){
movemove(u, qs[i].u) ; u = qs[i].u ;
movemove(v, qs[i].v) ; v = qs[i].v ;
//cout << u << " " << v << " " << res << endl ;
f = lca(u, v) ; rev(f) ;
ans[qs[i].id] = res ; rev(f) ;
}
for (int i = 1 ; i <= q ; ++ i)
printf ("%d\n", ans[i]) ; return 0 ;
}

btw,感觉自己码风好多了?

$2$ uoj#53 [WC2013]糖果公园

给出一棵 $n$ 个点的树,每个节点有一个颜色。

每次或者询问你一条路径求 $\sum_{c}val_c\sum_{i=1}^{cnt_c}w_i$,或者更改一个点的颜色。

其中 $val$ 表示该颜色的价值, $cnt$ 表示其出现的次数, $w_i$ 表示第 $i$ 次出现的价值。

可以离线。

发现就是树上莫队套了一个带修?

写起来没什么感觉…就是很长…很长…

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include <cmath>
#include <cstdio>
#include <iostream>
#include <algorithm>

using namespace std ;

typedef long long LL ;

#define to(k) E[k].to
#define next(k) E[k].next

const int N = 200010 ;
const int M = 200010 ;
const int Q = 200010 ;

struct Edge{
int to, next ;
}E[N << 1] ;
struct uv{
int p, c ;
}us[Q] ;
struct qs{
int u, v ;
int t, id ;
}qy[Q] ;
LL res ;
LL ans[Q] ;
int n, m, q ;
int cntu, cntq ;
int stk[N], tp ;
int head[N], cnt ;
int sz[N], son[N] ;
int buc[N], vis[N] ;
int b, bnum, blg[N] ;
int dep[N], fa[N], top[N] ;
int val[M], wth[M], clr[N] ;

void dfs(int u, int f){
//cout << u << endl ;
dep[u] = dep[f] + 1 ;
sz[u] = 1, fa[u] = f ; int t = tp ;
for (int k = head[u] ; k ; k = next(k)){
if (to(k) == f) continue ;
dfs(to(k), u) ; sz[u] += sz[to(k)] ;
if (!son[u] || sz[son[u]] < sz[to(k)]) son[u] = to(k) ;
if (tp - t >= b){
++ bnum ;
while (tp > t)
blg[stk[tp --]] = bnum ;
}
}
stk[++ tp] = u ;
}
void add(int u, int v){
to(++ cnt) = v, next(cnt) = head[u], head[u] = cnt ;
to(++ cnt) = u, next(cnt) = head[v], head[v] = cnt ;
}
void dfs(int u, int f, int tp){
top[u] = tp ;
if (son[u]) dfs(son[u], u, tp) ;
for (int k = head[u] ; k ; k = next(k))
if (to(k) != son[u] && to(k) != f) dfs(to(k), u, to(k)) ;
}
int lca(int u, int v){
while (top[u] != top[v]){
if (dep[top[u]] < dep[top[v]])
u ^= v ^= u ^= v ; u = fa[top[u]] ;
}
return dep[u] < dep[v] ? u : v ;
}
bool comp(qs a, qs c){
return (blg[a.u] == blg[c.u]) ?
((blg[a.v] == blg[c.v]) ? a.t < c.t : blg[a.v] < blg[c.v]) : (blg[a.u] < blg[c.u]) ;
}
void rev(int x){
if (vis[x])
res -= 1ll * wth[buc[clr[x]] --] * val[clr[x]] ;
else
res += 1ll * wth[++ buc[clr[x]]] * val[clr[x]] ;
vis[x] ^= 1 ;
}
void movemove(int x, int y){
if (dep[x] < dep[y]) swap(x, y) ;
while (dep[x] > dep[y]) rev(x), x = fa[x] ;
while (x != y) rev(x), rev(y), x = fa[x], y = fa[y] ;
}
int main(){
// freopen("in.in", "r", stdin) ;
cin >> n >> m >> q ;
int u, v, f, l, r, mk, t ;
for (int i = 1 ; i <= m ; ++ i) scanf("%d", &val[i]) ;
for (int i = 1 ; i <= n ; ++ i) scanf("%d", &wth[i]) ;
for (int i = 1 ; i < n ; ++ i) scanf("%d%d", &u, &v), add(u, v) ;
for (int i = 1 ; i <= n ; ++ i) scanf("%d", &clr[i]) ;
//for (int i = 1 ; i <= cnt ; ++ i) cout << E[i].to << endl ;
b = pow(1.0 * n, 0.6667) ;
dfs(1, 0) ; ++ bnum ;
while (tp) blg[stk[tp --]] = bnum ;dfs(1, 0, 1) ;
u = 1, v = 1, t = 0 ;
// for (int i = 1 ; i <= n ; ++ i) cout << top[i] << endl ;
// cout << lca(1, 3) << " " << lca(2,4) << " " << lca(1, 4) << endl ;
// for (int i = 1 ; i <= n ; ++ i) cout << fa[i] << " " ;
for (int i = 1 ; i <= q ; ++ i){
scanf("%d%d%d", &mk, &l, &r) ;
if (mk)
qy[++ cntq].u = l, qy[cntq].v = r,
qy[cntq].t = cntu, qy[cntq].id = cntq ;
else
us[++ cntu].p = l, us[cntu].c = r ;
}
sort(qy + 1, qy + cntq + 1, comp) ;
for (int i = 1 ; i <= cntq ; ++ i){
movemove(u, qy[i].u) ;
movemove(v, qy[i].v) ;
u = qy[i].u, v = qy[i].v ;
while (t < qy[i].t){
++ t ;
if (vis[us[t].p])
rev(us[t].p),
swap(clr[us[t].p], us[t].c),
rev(us[t].p) ;
else swap(clr[us[t].p], us[t].c) ;
}
while (t > qy[i].t){
if (vis[us[t].p])
rev(us[t].p),
swap(clr[us[t].p], us[t].c),
rev(us[t].p) ;
else swap(clr[us[t].p], us[t].c) ;
t -- ;
}
f = lca(u, v) ; rev(f) ;
ans[qy[i].id] = res ; rev(f) ;
}
for (int i = 1 ; i <= cntq ; ++ i)
printf("%lld\n", ans[i]) ; return 0 ;
}