本文探讨AC自动机的相关内容,注意AC自动机并不能自动让你AC。
        UPD:更新两道例题。

        AC自动机在1975年诞生于贝尔实验室的多模匹配算法。之前的文章曾介绍过KMP匹配算法,它用于单模式串匹配。对于多模式串匹配,我们当然可以多次用KMP匹配算法去解决,但那样时间复杂度为$O(m+kn))$,k为模式串数量,比较低效。而在多模匹配中应用AC自动机可以达到$O(m+n)$,m为所有模式串长度之和。

有限状态自动机

        这里只粗略地谈谈自动机这个东西。
        在有限状态自动机(Deterministic finite automation, DFA)中,状态有有限个,它们彼此用有向边连接,表示状态的转移。当一个动作发生时,自动机会从一个状态转移到另一个状态,并做出相应的反应,这里的AC自动机便是一种有限状态自动机。

Trie树

        AC自动机的结构是一棵trie树,如下图所示:

        trie树又称字典树,常用于快速求字典序。在trie树中,根结点不记录字符串信息,所有的字符串信息只记录在子结点中。上图所示的trie树表示记录了he、is、she、shr四个字符串。
        如果需要同时记录she和sh两个字符串怎么办?trie树上每一个结点都拥有一个指标变量,它记录这个结点对应的字符是否表示有个字符串的结尾,这样在相应结点处打上标记就可以同时表示she和sh。

AC自动机

        下面介绍AC自动机进行的过程。
        首先,我们需要将所有模式串加入trie树,并做每一个串到对应结点的标记。如果字符串中只有小写字母,那么我们可以这样建立trie树:

1
2
3
4
5
6
7
8
9
int tr[N][26], ct = 1;//N是结点预估数量
void insert(const char *p, int j) {//将编号为j的模式串p加入到trie树中
int i = 0, now = 1;
while (p[i]) {
if (tr[now][p[i] - 'a'] == 0)tr[now][p[i] - 'a'] = ++ct;//没有就新建
now = tr[now][p[i] - 'a'], i++;//有就继续向下走
}
to[j] = now;//记录末尾编号
}

        建完trie树后,我们可以这样理解trie树:树上的每一个结点都对应一种状态,它表示到根结点的字符均匹配,这样trie树可以看成是一个状态机的雏形。
        假如当前匹配到了某个结点,它表示的字符为’a’,并且下一个字符需要匹配’b’,而该结点正好有对应’b’的子结点,那此时状态直接转移到对应’b’的子结点即可。但是如果不存在对应’b’的子结点呢?如何转移?这就需要引入失配指针的概念。
        当某个结点发生失配时(指它不存在满足匹配的子结点),应该由这个状态转移至另一个状态,描述这个失配转移关系的指针称为失配指针,它是AC自动机中的重要概念。和KMP算法类似,失配指针应该指向与它目前匹配子串拥有最长相同前后缀的子串的末尾结点。
        首先考虑在根结点如何转移。对于一个字符,如果根结点存在对应这个字符的结点,转移到该结点即可;而对于根结点不存在对应子结点的情况(比如上文图中根结点不存在表示’a’的子结点),也应该规定一个转移方向,这时应当转移至根,因为没有可以匹配的模式串。这一步的示例代码如下:
1
2
3
4
for (int i = 0; i < 26; i++) {//规定1号结点为根
if (tree[1][i])fail[tree[1][i]] = 1, que.push(tree[1][i]);//fail记录失配指向编号,que是队列
else tree[1][i] = 1;//转移到根,直接记在trie树上
}

        其实也可以这样做:
1
2
for (int i = 0; i < 26; i++)tree[0][i] = 1;
que.push(1);

        这里将结点加入队列,以进行之后的求失配指针的过程。在这里,我们不断从队列中取点,然后遍历它的所有子结点,对于它存在的子结点,子结点的失配指针指向为其父结点(就是当前取出的结点)失配指针指向结点相对应的子结点;对于不存在的结点,转移方向也是父结点失配指针指向结点相对应的子结点。这个过程可以感性地理解一下,它其实是对父结点失配后子串的延伸。
1
2
3
4
5
6
7
8
while (!que.empty()) {
int p = que.front();
que.pop();
for (int i = 0; i < 26; i++) {
if (tree[p][i])fail[tree[p][i]] = tree[fail[p]][i], que.push(tree[p][i]);
else tree[p][i] = tree[fail[p]][i];
}
}

        之后就是匹配的过程,十分简洁,不断转移状态即可。途中需要记录每一个状态被访问到的次数。
1
for (int i = 0, now = 1; str[i]; ++i)now = tree[now][str[i] - 'a'], ++vis[now];//str是主串,vis记录次数

        那现在如何求每一个模式串匹配的次数呢?这就需要fail树。

fail树

        对于trie树上的每一个非根结点,单独建立一张fail[x]->x的图,容易证明这个图是一棵树,称为fail树。fail树满足一个性质:对于fail树的一个子树,若其子树上某一个结点的状态成立,则子树根结点状态必成立。这样建立fail树,然后进行一遍DFS,求出子树上状态vis的值之和即可。
        这里我们需要在一种更高的角度去理解fail树。在fail树上,以某个结点为根的子树,其上的结点都有什么特点?
        fail树上的任何一个结点都代表着一个前缀,对于任何一个结点,其一定是其子树上所有结点的公共后缀。由于子串就是前缀的后缀,故在fail树上,以某个结点为根的子树上的结点就是包含这个结点对应前缀的所有字符串前缀。

1
2
3
4
5
6
7
8
inline void add(int x, int y) {
static int cnt = 1;
edge[cnt].to = y, edge[cnt].next = head[x], head[x] = cnt++;
}

void DFS(int x) {
for (int i = head[x]; i; i = edge[i].next)DFS(edge[i].to), vis[x] += vis[edge[i].to];
}

代码模板

        洛谷上三道模板题。

模板一

        很水,直接计数判是否为0即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include<bits/stdc++.h>

using namespace std;
char str[1000005], op[1000005];
int vis[1000005], to[1000005], ct = 1, fail[1000005], tr[1000005][26];

struct Edge {
int next, to;
} edge[1000005];
int head[1000005];

inline void add(int x, int y) {
static int cnt = 1;
edge[cnt].to = y, edge[cnt].next = head[x], head[x] = cnt++;
}

queue<int> que;

void insert(const char *p, int j) {
int i = 0, now = 1;
while (p[i]) {
if (tr[now][p[i] - 'a'] == 0)tr[now][p[i] - 'a'] = ++ct;
now = tr[now][p[i] - 'a'], i++;
}
to[j] = now;
}

void DFS(int x) {
for (int i = head[x]; i; i = edge[i].next)DFS(edge[i].to), vis[x] += vis[edge[i].to];
}

int main() {
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++)scanf("%s", op), insert(op, i);
for (int i = 0; i < 26; i++)tr[0][i] = 1;
que.push(1);
while (!que.empty()) {
int p = que.front();
que.pop();
for (int i = 0; i < 26; i++) {
if (tr[p][i])fail[tr[p][i]] = tr[fail[p]][i], que.push(tr[p][i]);
else tr[p][i] = tr[fail[p]][i];
}
}
scanf("%s", str);
for (int i = 0, now = 1; str[i]; ++i)now = tr[now][str[i] - 'a'], ++vis[now];
for (int i = 2; i <= ct; i++)add(fail[i], i);
DFS(1);
int ans = 0;
for (int i = 1; i <= n; i++)ans += vis[to[i]] > 0;
cout << ans;
return 0;
}

模板二

        根据匹配次数输出最大的字符串。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include<bits/stdc++.h>

using namespace std;
char str[1000005], op[151][80];
int vis[1000005], to[1000005], ct = 1, fail[1000005], tr[1000005][26], cnt = 1;

struct Edge {
int next, to;
} edge[1000005];
int head[1000005];

inline void add(int x, int y) {
edge[cnt].to = y, edge[cnt].next = head[x], head[x] = cnt++;
}

queue<int> que;

void insert(const char *p, int j) {
int i = 0, now = 1;
while (p[i]) {
if (tr[now][p[i] - 'a'] == 0)tr[now][p[i] - 'a'] = ++ct;
now = tr[now][p[i] - 'a'], i++;
}
to[j] = now;
}

void DFS(int x) {
for (int i = head[x]; i; i = edge[i].next)DFS(edge[i].to), vis[x] += vis[edge[i].to];
}

int main() {
int n;
while (scanf("%d", &n) == 1) {
if (n == 0)return 0;
cnt = ct = 1, memset(tr, 0, sizeof(tr)), memset(head, 0, sizeof(head));
memset(vis, 0, sizeof(vis)), memset(fail, 0, sizeof(fail));
for (int i = 1; i <= n; i++)scanf("%s", op[i]), insert(op[i], i);
for (int i = 0; i < 26; i++)tr[0][i] = 1;
que.push(1);
while (!que.empty()) {
int p = que.front();
que.pop();
for (int i = 0; i < 26; i++) {
if (tr[p][i])fail[tr[p][i]] = tr[fail[p]][i], que.push(tr[p][i]);
else tr[p][i] = tr[fail[p]][i];
}
}
scanf("%s", str);
for (int i = 0, now = 1; str[i]; ++i)now = tr[now][str[i] - 'a'], ++vis[now];
for (int i = 1; i <= ct; i++)head[i] = 0;
for (int i = 2; i <= ct; i++)add(fail[i], i);
DFS(1);
int maxn = 0;
for (int i = 1; i <= n; i++)maxn = max(maxn, vis[to[i]]);
cout << maxn << endl;
for (int i = 1; i <= n; i++)if (vis[to[i]] == maxn)cout << op[i] << endl;
}
}

模板三

        输出匹配数目。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include<bits/stdc++.h>

using namespace std;
char str[2000010], op[200005];
int vis[500000], to[200005], ct = 1, fail[500000], tree[500000][26];

struct Edge {
int next, to;
} edge[500000];
int head[500000];

inline void add(int x, int y) {
static int cnt = 1;
edge[cnt].to = y, edge[cnt].next = head[x], head[x] = cnt++;
}

queue<int> que;

void insert(const char *p, int j) {
int i = 0, now = 1;
while (p[i]) {
if (tree[now][p[i] - 'a'] == 0)tree[now][p[i] - 'a'] = ++ct;
now = tree[now][p[i] - 'a'], i++;
}
to[j] = now;
}

void DFS(int x) {
for (int i = head[x]; i; i = edge[i].next)DFS(edge[i].to), vis[x] += vis[edge[i].to];
}

int main() {
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++)scanf("%s", op), insert(op, i);
for (int i = 0; i < 26; i++) {
if (tree[1][i])fail[tree[1][i]] = 1, que.push(tree[1][i]);
else tree[1][i] = 1;
}
while (!que.empty()) {
int p = que.front();
que.pop();
for (int i = 0; i < 26; i++) {
if (tree[p][i])fail[tree[p][i]] = tree[fail[p]][i], que.push(tree[p][i]);
else tree[p][i] = tree[fail[p]][i];
}
}
scanf("%s", str);
for (int i = 0, now = 1; str[i]; ++i)now = tree[now][str[i] - 'a'], ++vis[now];
for (int i = 2; i <= ct; i++)add(fail[i], i);
DFS(1);
for (int i = 1; i <= n; i++)cout << vis[to[i]] << endl;
return 0;
}

例题

        为了加深对AC自动机的理解,这里介绍几道例题。以下题目不涉及题面,可点开标题自查。

[NOI2011]阿狸的打字机

        AC自动机的好题,强烈建议一做。
        首先将所有字符串放入AC自动机,然后求出fail树。由上面我们关于fail树性质的讨论,可以发现,要求串x在串y中出现多少次,就是求Trie树中属于串y的结点在以x为根的fail子树上有多少个。现在问题变得十分清晰明了了。
        本题当然可以离线做,这里我们用在线做法(能在线绝不离线)。
        首先一遍DFS将树上问题转化为区间问题。然后将某一个串y的Trie树上的结点打上标记,用线段树求和就可以求出所有关于串y的询问。对于所有的串都开一个线段树就可以在线了,但是这样显然很费空间,考虑可持久化线段树。
        注意到Trie树上的结点,其进行修改时仅修改一个值(只给它自己这一个结点打上标记),那么这里就可以用树上主席树的方法建立可持久化线段树,然后就可以将空间复杂度降到$O(nlogn)$。对于某一个询问,只需要调出所需的线段树,然后区间求和即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <bits/stdc++.h>

#define N 200005
using namespace std;
char op[N];
int ct = 1, to[N], fail[N], head[N], DFN[N], ID = 1, size[N], root[N], ps;
struct Trie {
int next[28], fa;
} trie[N];
struct Edge {
int next, to;
} edge[N];
struct Node {
int l, r, v;
} nd[N << 6];
queue<int> que;

inline void add(int x, int y) {
static int cnt = 1;
edge[cnt].next = head[x], edge[cnt].to = y, head[x] = cnt++;
}

void DFS(int x) {//DFS序转化
size[x] = 1, DFN[x] = ID++;
for (int i = head[x]; i; i = edge[i].next)DFS(edge[i].to), size[x] += size[edge[i].to];
}

int newTree(int s, int pre, int l = 1, int r = ID - 1) {
int v = ++ps;
nd[v] = nd[pre], ++nd[v].v;
int mid = (l + r) >> 1;
if (l < r) {
if (s <= mid)nd[v].l = newTree(s, nd[pre].l, l, mid);
else nd[v].r = newTree(s, nd[pre].r, mid + 1, r);
}
return v;
}

void DFS2(int x) {//建立主席树
for (int i = 0; i < 26; i++) {
if (trie[trie[x].next[i]].fa == x)
root[trie[x].next[i]] = newTree(DFN[trie[x].next[i]], root[x]), DFS2(trie[x].next[i]);
}
}

int build(int l = 1, int r = ID - 1) {
int s = ++ps;
if (l != r)nd[s].l = build(l, (l + r) >> 1), nd[s].r = build(((l + r) >> 1) + 1, r);
return s;
}

int query(int rt, int l, int r, int L = 1, int R = ID - 1) {
if (L >= l && R <= r)return nd[rt].v;
int mid = (L + R) >> 1;
if (l > mid)return query(nd[rt].r, l, r, mid + 1, R);
else if (r <= mid)return query(nd[rt].l, l, r, L, mid);
return query(nd[rt].l, l, mid, L, mid) + query(nd[rt].r, mid + 1, r, mid + 1, R);
}

int main() {
scanf("%s", op);
for (int i = 0, now = 1, s = 0; op[i]; i++) {
if (op[i] == 'P')to[++s] = now;
else if (op[i] == 'B' && now != 1)now = trie[now].fa;
else if (trie[now].next[op[i] - 'a'])now = trie[now].next[op[i] - 'a'];
else trie[now].next[op[i] - 'a'] = ++ct, trie[ct].fa = now, now = ct;
}
for (int i = 0; i < 26; i++) {
if (trie[1].next[i])fail[trie[1].next[i]] = 1, que.push(trie[1].next[i]);
else trie[1].next[i] = 1;
}
while (!que.empty()) {//建立AC自动机
int f = que.front();
que.pop();
for (int i = 0; i < 26; i++) {
if (trie[f].next[i])fail[trie[f].next[i]] = trie[fail[f]].next[i], que.push(trie[f].next[i]);
else trie[f].next[i] = trie[fail[f]].next[i];
}
}
for (int i = 2; i <= ct; i++)add(fail[i], i);
DFS(1), root[1] = build(), DFS2(1);
int m;
scanf("%d", &m);
while (m--) {
int x, y;
scanf("%d%d", &x, &y);
printf("%d\n", query(root[to[y]], DFN[to[x]], DFN[to[x]] + size[to[x]] - 1));
}
return 0;
}

[POI2000]病毒

        把字符串存到AC自动机里,然后建fail树,再将不能走的点打上标记,重建状态转移图,再来一步拓扑排序判环即可。本题比较容易。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include <bits/stdc++.h>

#define N 50000
using namespace std;
struct Edge {
int to, next;
} edge[N], edge2[N];
int trie[N][2], fail[N], ct = 1, n, to[N], head[N], mark[N], head2[N], vis[N];
int m, in[N], ans;
char op[N];

inline void insert(const char *s, int p) {
int now = 1;
for (int i = 0; s[i]; i++) {
if (trie[now][s[i] - '0'] == 0)now = trie[now][s[i] - '0'] = ++ct;
else now = trie[now][s[i] - '0'];
}
to[p] = now;
}

inline void add(int x, int y) {
static int cnt = 1;
edge[cnt].to = y, edge[cnt].next = head[x], head[x] = cnt++;
}

inline void add2(int x, int y) {
static int cnt = 1;
++in[y], edge2[cnt].to = y, edge2[cnt].next = head2[x], head2[x] = cnt++;
}

void DFS(int x) {
mark[x] = 1;
for (int i = head[x]; i; i = edge[i].next)DFS(edge[i].to);
}

void DFS2(int x) {
for (int i = 0; i < 2; i++) {
if (!mark[trie[x][i]]) {
add2(x, trie[x][i]);
if (!vis[trie[x][i]])vis[trie[x][i]] = 1, DFS2(trie[x][i]);
}
}
}

queue<int> que;

int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++)scanf("%s", op), insert(op, i);
for (int i = 0; i < 2; i++) {
if (trie[1][i])fail[trie[1][i]] = 1, que.push(trie[1][i]);
else trie[1][i] = 1;
}
while (!que.empty()) {
int f = que.front();
que.pop();
for (int i = 0; i < 2; i++) {
if (trie[f][i])fail[trie[f][i]] = trie[fail[f]][i], que.push(trie[f][i]);
else trie[f][i] = trie[fail[f]][i];
}
}
for (int i = 2; i <= ct; i++)add(fail[i], i);
for (int i = 1; i <= n; i++)DFS(to[i]);
vis[1] = 1, DFS2(1);
for (int i = 1; i <= ct; i++) {
if (!mark[i]) {
++m;
if (!in[i])que.push(i);
}
}
while (!que.empty()) {
int f = que.front();
que.pop(), ++ans;
for (int i = head2[f]; i; i = edge2[i].next) {
--in[edge2[i].to];
if (!in[edge2[i].to])que.push(edge2[i].to);
}
}
if (ans == m)cout << "NIE";
else cout << "TAK";
return 0;
}