【学习笔记/题解】AC自动机瞎吹

AC自动机算是自动机系列比较早期引入OI的算法,全称为「$\rm Aho-Corasick ~Automaton$」。

然而其实这款自动机是以两个科学家的名字命名的。(这「款」?

并且上面这一条信息还是我从有道词典里查出来的,百度根本百度不到好吗 QAQ

不过话说这东西也是好久之前学的,于是这篇 blog 实际上是在炒冷饭。

$1$ 原理

原理还是很简单的,就是把 KMP 的一半放到了 Trie 上。那找 fail 就肯定要 bfs 了对吧。根据 Trie 的性质,一开始肯定是要把根节点所有非空孩子当作起始状态。大概是这样:

1
2
3
4
5
6
7
8
9
10
11
12
void bfs(){
int i ;
for (i = 0 ; i < 26 ; ++ i){
if (tr[0][i]) q.push(tr[0][i]) ;
while (!q.empty()){
int now = q.front() ; q.pop() ;
for (i = 0 ; i < 26 ; ++ i)
if (!tr[now][i]) tr[now][i] = tr[fail[now]][i] ;
else fail[tr[now][i]] = tr[fail[now]][i], q.push(tr[now][i]) ;
}
}
}

考虑求 $fail$ 的时候,如果没有当前的孩子,执行一个路径压缩。否则直接转移就可以了。

关于 $fail$ 的意义,一个点的 $fail$ 指向是其他模式串和他的最长的公共后缀,并且和这个最长公共后缀必然从根开始,不同于 KMP 的 border 。

大概AC自动机就这点知识吧?

$2$ 水题杀手

嗯,这个标题很炫酷。

1 LG3808 【模板】AC自动机(简单版)

给定 $n$ 个模式串和 $1$ 个文本串,求有多少个模式串在文本串里出现过。

建 Trie 时记录一下每个串的 $endpos$ 就完了。

2 LG3796 【模板】AC自动机(加强版)

有 $\rm N$ 个由小写字母组成的模式串以及一个文本串 $\rm T$。

每个模式串可能会在文本串中出现多次。你需要找出哪些模式串在文本串 $\rm T$ 中出现的次数最多。

嗯,这个其实是个弱化版本。考虑记录最暴力的解法大概就是每匹配到一个点,就不断向上跳 $fail$ 去找有哪些点打上了 $endpos$ 标记,每次匹配的时候不断向上跳,遇到 $endpos$ 就 $ ++$,最后再扫一遍。

当然其实这个地方存在一个剪枝。完全可以记录离现在这个点最近的一个有 $endpos$ 的 $fail$。推的方式跟推 $fail$ 大同小异。

大概可以把 $fail$ 理解为一阶失配指针,$last$ 为二阶的,所以理所应当在一阶上面跑。

1
2
3
4
5
6
7
8
9
10
11
12
void BFS(){
for (i = 0 ; i < 26 ; ++ i)
if(Trie[0][i]) q.push(Trie[0][i]) ; //1
while (!q.empty()){
int now = q.front() ; q.pop() ;
for (i = 0 ; i < 26 ; ++ i){
if (!v(now)) Trie[now][i] = Trie[fail[now]][i] ;
else fail[v(now)] = Trie[fail[now]][i], q.push(v(now)),
last[v(now)] = Send[fail[v(now)]] ? fail[v(now)] : last[fail[v(now)]] ;
}
}
}

3 LG5357 【模板】AC自动机(二次加强版)

有 $\rm N$ 个由小写字母组成的模式串以及一个文本串 $\rm T$。

每个模式串可能会在文本串中出现多次。你需要找出哪些模式串在文本串 $\rm T$ 中出现的次数最多。

en,好久之前rqy就有讲过这个问题。其实上面那题无论剪不剪枝,最坏复杂度都是 $O(n^2)$ 的,只要这么构造就可以:

1
2
3
4
5
6
a
aa
aaa
aaaa
........(省略)
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa

于是考虑这东西似乎正着算不好算,于是考虑倒着计算贡献。发现原来是个 $dp$ 的过程。

大概就是,在AC自动机的fail树上,自己的孩子一定和自己有着相同后缀。所以只需要按照fail边建树,然后算一遍子树大小。之前需要记录对应位置,最后输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#define MAXD 60010
#define MAXP 201000
#define MAXN 2000200

using namespace std ;
int N ; char In[MAXP], S[MAXN] ;

namespace Hash{
int f[MAXN] ;
typedef unsigned long long ull ;
map <ull, int> htble ;
const ull base = 137 ;
ull get_h(char *s){
int l = strlen(s) ;
ull ret = (s[0] ^ 32768) + 3 ;
for (int i = 0 ; i < l ; ++ i)
ret = ret * base + (ull)s[i] ;
return ret ;
}
}
using namespace Hash ;

namespace Graph{
struct Edge{
int to, next ;
}E[MAXP << 1] ; int cnt1, head[MAXP] ;
void add(int u, int v){
E[++ cnt1].to = v, E[cnt1].next = head[u], head[u] = cnt1 ;
E[++ cnt1].to = u, E[cnt1].next = head[v], head[v] = cnt1 ;
}
#define to(k) E[k].to
}
using namespace Graph ;

namespace AC{
int cnt, e[MAXP], last[MAXP], ans ;
int tr[MAXP][27], res[MAXP], Id[MAXP], fail[MAXP] ;
void Init(){
ans = 0 ;
memset(e, 0, sizeof(e)) ;
memset(Id, 0, sizeof(Id)) ;
memset(res, 0, sizeof(res)) ;
memset(fail, 0, sizeof(fail)) ;
memset(last, 0, sizeof(last)) ;
memset(tr, 0, sizeof(tr)), cnt = 0 ;
}
void insert(char *p){
int rt = 0, l = strlen(p) ;
for (int i = 0 ; i < l ; ++ i){
int now = p[i] - 'a' ;
if (!tr[rt][now])
tr[rt][now] = ++ cnt ;
rt = tr[rt][now] ;
}
++ e[rt] ;
}
void idins(char *p, int ID){
int rt = 0, l = strlen(p) ;
for (int i = 0 ; i < l ; ++ i){
int now = p[i] - 'a' ;
if (!tr[rt][now])
tr[rt][now] = ++ cnt ;
rt = tr[rt][now] ;
}
e[rt] = ID, Id[ID] = rt ;
}
int sz[MAXN] ;
queue <int> q ;
void bfs(){
int i ;
for (i = 0 ; i < 26 ; ++ i)
if (tr[0][i]) q.push(tr[0][i]) ; // 1
while (!q.empty()){
#define qwq tr[now][i]
int now = q.front() ; q.pop() ;
// cout << q.size() << endl ;
for (i = 0 ; i < 26 ; ++ i){
if (!qwq) qwq = tr[fail[now]][i] ;
else fail[qwq] = tr[fail[now]][i], q.push(qwq) ;
if (e[fail[qwq]]) last[qwq] = fail[qwq] ;
else last[qwq] = last[fail[qwq]] ;
}
}
}
int work1(char *p){
int rt = 0, ret = 0, i, l = strlen(p), q ;
for (i = 0 ; i < l ; ++ i){
rt = tr[rt][p[i] - 'a'], q = rt ;
while (q && (~e[q]))
ret += e[q], e[q] = -1, q = fail[q] ;
}
return ret ;
}
void work2(char *p){
int rt = 0, ret = 0, i, l = strlen(p), q ;
for (i = 0 ; i < l ; ++ i){
rt = tr[rt][p[i] - 'a'], q = rt ;
while (q) { if (e[q]) res[e[q]] ++ ; q = last[q] ; }
}
}
void work3(char *p){
int rt = 0, i, l = strlen(p) ;
for (i = 0 ; i < l ; ++ i)
rt = tr[rt][p[i] - 'a'], sz[rt] ++ ;
for (i = 1 ; i <= cnt ; ++ i) add(i, fail[i]) ;
}
void dfs(int u, int v){
for (int k = head[u] ; k ; k = E[k].next){
if (to(k) == v) continue ;
dfs(to(k), u), sz[u] += sz[to(k)] ;
}
}
}
using namespace AC ;
int main(){
int i, ans ; cin >> N ;
for (i = 1 ; i <= N ; ++ i) f[i] = i ;
for (i = 1 ; i <= N ; ++ i){
scanf("%s", In) ;
ull p = get_h(In) ;
if (!htble.count(p))
idins(In, i), htble[p] = i ;
else f[i] = htble[p] ;
}
bfs() ; scanf("%s", S) ; work3(S) ; dfs(0, -1) ;
for (i = 1 ; i <= N ; ++ i) printf("%d\n", sz[Id[f[i]]]) ;
return 0 ;
}

唉,学了好几遍才真正学明白 $fail$ 指针是个什么东西,wtcl

btw,感觉AC自动机如果真要起名的话,叫后缀自动机也可以啊(

$4$ 例题

[$\rm POI2000$] 病毒

给定一堆 $01$ 串,判断是否有一个无限长的 $01$ 串使得不包含任何一个给出 $01$ 串作为其子串。

考虑怎样才能算是无限长的安全代码。

先给出结论:对每个存在病毒代码的状态打上标记。那么如果完整的trie图里面存在一个不带标记的圈,就会有无限长的安全代码。此处完整指的是,对于给定一棵trie树,补全他的所有叶子。

可知这是显然的,因为本质上只需要考虑到深度=最长的病毒串长度这一步就ok了,剩下都可以循环生成。

注意到本质上 trie 不用显式地建出来,于是最后 $dfs$ 的时候统计一下即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
struct ACAM{
int fail[MAXS], vis[MAXS], sz ;
int _ed[MAXS], trans[MAXS][2], ans[MAXS] ;
void Ins(char *t){
int rt = 0 ; N = strlen(t + 1) ;
for (int i = 1 ; i <= N ; ++ i){
t[i] -= '0' ;
if (!trans[rt][t[i]])
trans[rt][t[i]] = ++ sz ;
rt = trans[rt][t[i]] ;
}
_ed[rt] ++ ;
}
void build(){
int i, n ;
for (i = 0 ; i <= 1 ; ++ i)
if (trans[0][i]) q.push(trans[0][i]) ;
while (!q.empty()){
n = q.front() ; q.pop() ;
for (i = 0 ; i <= 1 ; ++ i){
if (!trans[n][i])
trans[n][i] = trans[fail[n]][i] ;
else
fail[trans[n][i]] = trans[fail[n]][i],
_ed[trans[n][i]] |= _ed[fail[trans[n][i]]], q.push(trans[n][i]) ;
}
}
}
void dfs(int x){
if (ans[x])
puts("TAK"), exit(0) ;
if (_ed[x] || vis[x]) return ;
// cout << x << endl ;
vis[x] = ans[x] = 1 ;
for (int i = 0 ; i <= 1 ; ++ i)
//if (trans[x][i]) dfs(trans[x][i]) ;
dfs(trans[x][i]) ;
ans[x] = 0 ;
}
}A ;

int main(){
cin >> T ;
while (T --)
cin >> (s + 1), A.Ins(s) ;
A.build() ; A.dfs(0) ; return puts("NIE"), 0 ;
}

注意到代码中 $dfs$ 里面注释的一行,显然是错的,可以被:

1
2
3
2
1
11

这组数据直接给卡掉。但是数据太水就给过掉了…

…说实在的我本来想尝试去证明为什么那样错的写是对的,尝试努力去编出一个合理解释,到最后才发现是自己原来的 $code$ 错了QAQ