【学习笔记&题解】从笛卡尔树到[SPOJ 3734] Periodni

笛卡尔树的性质比较单一,但是应用起来也是很有趣的。

笛卡尔树学习笔记

大概就是类似 treap 的一个东西。首先笛卡尔树是二叉搜索树,即 $lc,root,rc$ 的键值满足单调的偏序关系。其次笛卡尔树每个节点也有一个权值,这个权值满足堆性质。

然后在保证了键值单调时,可以 $O(n)$ 建树,方式是维护最右链。假设键值单调递增,那么考虑每次加入一个新节点 $x$,要把栈里面所有 $val_y>val_x$ 的点都弹出,然后把当前栈顶元素 $z$ 的右儿子设置为 $x$ ,$x$ 的左儿子设置为原来 $z$ 的右儿子。

以下是洛谷 P5854 笛卡尔树。不知道为什么,我写笛卡尔树很容易忘了最后入栈233

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
int main(){
scanf("%d", &n) ;
stk[tp = 1] = 0 ; base[0] = -I ;
for (int i = 1 ; i <= n ; ++ i) base[i] = qr() ;
for (int i = 1 ; i <= n ; ++ i){
while (tp && base[stk[tp]] > base[i]) -- tp ;
fa[i] = stk[tp] ; int x = rc[fa[i]] ;
fa[x] = i, lc[i] = x, rc[fa[i]] = i ;
// for (int j = 1 ; j <= n ; ++ j)
// printf("%d %d\n", lc[j], rc[j]) ;
// printf("%d\n", x) ;
// puts("- - - - - - - - - - - - - -") ;
stk[++ tp] = i ;//!
}
for (int i = 1 ; i <= n ; ++ i)
ans1 ^= (1ll * (lc[i] + 1) * i) ;
for (int i = 1 ; i <= n ; ++ i)
ans2 ^= (1ll * (rc[i] + 1) * i) ;
printf("%lld %lld\n", ans1, ans2) ;
}

[SPOJ 3734] Periodni

给定一个 $n$ 列的表格,每列的高度各不相同,但底部对齐,然后向表格中填入 $k$ 个相同的数,填写时要求不能有两个数在同一列,或同一行。

注意,如果两个同一行的点之间有空白,那么我们不认为他们在同一行。

$1\le n,k\le 500$ 。

感觉这题是不是如果看不出来要建笛卡尔树,人就直接没了啊…

大概就是说按照高度为权值,序号为键值建笛卡尔树。发现这样的树有着良好的性质,就是对于一个点 $x$ 不包含自己的子树,他们在 $h_x$ 以上的那部分高度互相不影响,且如果保证不自选的话一定不会出现在同一列的情况。

然后就转化成了从子树内选择 $k$ 个无序点的方案数了。考虑树上背包。 $f_{i,v}$ 表示以 $i$ 为根,合法地选了 $v$ 个点的方案数。那么先考虑子树内的转移,就比较朴素。这部分可以随便做,因为并不关心是 $O(nk)$ 还是 $O(nk^2)$ 。然后考虑选定 $i$ 时的方案数,也就是算上 $\leq h_x$ 那一块小矩形的方案数。这个地方大概就是要考察一点我没有的组合能力了。考虑从一个 $n$ 行 $m$ 列的矩形里选出 $k$ 个不同行、不同列的点的方案数,本质上就是选 $k$ 个行号、$k$ 个列号,然后对行号和列号做个匹配,可以知道方案数是

那么同理,如果要算考虑 $x$ 的矩形的方案数,就是在算一个有 $h_x$ 行,$sz_x-(i-j)$ 列的矩形的方案数,就是 $c(h_x,sz_x-(i-j),k)$ 。

实现上有一点下传高度的小细节。

其实这题本质上是分类讨论了一下放的情况,即在 $h_x$ 的高度上面放和在 $h_x$ 的高度下面(含)放。是一道不错的计数题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#include <cstdio>
#include <vector>
#include <iostream>

using namespace std ;

typedef long long ll ;

const int N = 1050 ;

const int M = 1000050 ;

const int P = 1000000007 ;

int rt ;
int tp ;
int n, k ;
int fa[N] ;
int lc[N] ;
int rc[N] ;
int sz[N] ;
int fac[M] ;
int inv[M] ;
int stk[N] ;
int tmp[N] ;
int base[N] ;
int f[N][N] ;

vector <int> E[N] ;

inline int qr(){
int r = 0, f = 1 ;
char c = getchar() ;
while (c > '9' || c < '0'){
if (c == '-') f = -1 ; c = getchar() ;
}
while (c <= '9' && c >= '0'){
r = (r << 1) + (r << 3) + c - '0', c = getchar() ;
}
return r * f ;
}

template <typename T>
inline void add(T &x, ll y, int mod = P){
x += y ; x = x >= mod ? x - mod : x ;
}
template <typename T>
inline T addn(T x, ll y, int mod = P){
x += y ; return (x = x > mod ? x - mod : x) ;
}
ll expow(ll x, ll y){
ll ret = 1 ;
while (y){
if (y & 1)
ret = ret * x % P ;
x = x * x % P ; y >>= 1 ;
}
return ret ;
}
int comb(int a, int b){
return 1ll * fac[a] * inv[a - b] % P * inv[b] % P ;
}
inline int calc(int x, int y, int z){
if (x < 0 || y < 0) return 0 ;
return 1ll * comb(x, z) * comb(y, z) % P * fac[z] % P ;
}
void do_do(int x, int h){
sz[x] = 1 ;
f[x][0] = 1 ;
// printf("%d %d %d\n", k, x, h) ;
for (auto y : E[x]){
do_do(y, base[x]), sz[x] += sz[y] ;
for (int i = 0 ; i <= k ; ++ i)
for (int j = 0 ; j <= i ; ++ j)
add(tmp[i], 1ll * f[x][i - j] * f[y][j] % P) ;
for (int i = 0 ; i <= k ; ++ i)
f[x][i] = tmp[i], tmp[i] = 0 ;
}
int nh = base[x] - h ;
// printf("%d %d %d %d\n", k, x, h, nh) ;
for (int i = 0 ; i <= min(sz[x], k) ; ++ i)
for (int j = 0 ; j <= min(nh, i) ; ++ j)
add(tmp[i], 1ll * f[x][i - j] * calc(sz[x] - i + j, nh, j) % P) ;
// printf("%d %d %d\n", k, x, h) ;
for (int i = 0 ; i <= k ; ++ i)
f[x][i] = tmp[i], tmp[i] = 0 ;
}
int main(){
// freopen("1.in", "r", stdin) ;
cin >> n >> k ; fac[0] = 1 ;
stk[tp = 1] = 0 ; base[0] = -1 ;
for (int i = 1 ; i < M ; ++ i)
fac[i] = 1ll * fac[i - 1] * i % P ;
inv[M - 1] = expow(fac[M - 1], P - 2) ;
for (int i = M - 2 ; i >= 0 ; -- i)
inv[i] = 1ll * inv[i + 1] * (i + 1) % P ;
for (int i = 1 ; i <= n ; ++ i) base[i] = qr() ;
for (int i = 1 ; i <= n ; ++ i){
while (tp && base[stk[tp]] > base[i]) -- tp ;
fa[i] = stk[tp] ; int j = rc[stk[tp]] ; stk[++ tp] = i ;
fa[j] = i ; lc[i] = j ; rc[fa[i]] = i ;
}
for (int i = 1 ; i <= n ; ++ i){
E[fa[i]].push_back(i) ;
if (!fa[i]) rt = i ;
}
do_do(rt, 0) ; printf("%d\n", f[rt][k]) ; return 0 ;
}