kmp学习笔记

(13 mins to read)

应用

给定主串t和模式串s,求出s在t中的所有匹配位置求出一个串的所有前缀的最大border(next数组)

定义

前缀pre(s, l):s长度为l的前缀(s[1…l])后缀suf(s, l):s长度为l的后缀(s[|s|-l+1…|s|])周期period:s[i]=s[i+p],$\forall i \in1…|s|-ps[i]=s[i+p]$,则p是s的周期border: pre(s, l) = suf(s, l),则pre(s, l)是s的border

性质

若s有长为x的border,则|s|-x为s的一个周期。证明可以分成border相交和不相交两种情况最小周期=串长-最大borderborder具有传递性,串s的所有border可以用nxt[|s|],nxt[nxt[|s|]]...表示Weak Periodicity Lemmap和q是s的周期,p+q<=|s|,则gcd(p, q)也是s的周期Periodicity Lemmap和q是s的周期,p+q-gcd(p, q)<=|s|,则gcd(p, q)也是s的周期(牛客多校1的签到是给出两个无限串各自的周期,要求比较这两个串的字典序大小,只需要比较前p+q-gcd(p,q)或者前p+q个字符,如果都相同,说明p和q都是各自的周期,两个串相等)

求法

nxt数组的求法可以当作是s串自己和自己匹配的过程。规定nxt[0]=−1,s串从0开始,nxt[i]表示s[0…i-1]的最大border考虑当前已经求出了nxt[0…i],且nxt[i]=j,说明s[i-j...i-1]=s[0...j-1]

  • s[i]==s[j]nxt[++i] = ++j
  • ss[i]!=s[j],此时我们需要不断往前跳nxt,直到找到一个border的后一个字符等于s[j],然后转为第一种情况
1
2
3
4
5
6
7
8
9
void kmp_pre(char s[], int n, int nxt[])
{
int i = 0, j = nxt[0] = -1;
while(i<n)
{
while(j!=-1&&s[i]!=s[j]) j = nxt[j];
nxt[++i] = ++j;
}
}

字符串匹配的过程和求next数组的过程类似(模式串都是s,只不过主串变了)。规定s串为模式串,t为主串,下标均从0开始。先求出s串的next数组考虑当前t串匹配到i-1,s串匹配到j-1,说明t[i-j...i-1]=s[0...j-1]

  • t[i]=s[j],++i,++j。如果j>|s|,说明找到一个匹配,记录下位置,让j=nxt[j]
  • t[i]!=s[j],j不断跳nxt直到找到一个border的后一个字符等于t[i]为止,然后转化为第一种情况
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
int kmp_work(char s[], int n, char t[], int m)
{
int i = 0, j = 0, ans = 0;
kmp_pre(s, n, nxt);
while(i<m)
{
while(j!=-1&&t[i]!=s[j]) j = nxt[j];
i++; j++;
if(j>=n)
{
ans++;
match[i-n] = 1;
j = nxt[j];
}
}
return ans;
}

习题

  • p3375求出所有匹配位置+nxt数组

  • p5829引入一种失配树(和ac自动机的fail树类似):按照next数组进行建树,next[i]作为i的父亲。显然每个点有且仅有一个父节点。容易发现每个点的所有祖先就是它的所有border,跳nxt的过程相当于不断往父亲走本题多次询问一个串s的某两个前缀的最大公共border。建出失配树后很直观的可以发现答案就是两个节点的LCA代表的前缀

  • p3435求每个前缀的最大周期。求最大周期相当于求最小border。放在失配树上就相当于找到每个点祖先中离根最近的那个。树上dp一下即可。也可以路径压缩。

  • poj2185一个n*m字符矩阵,求一个最小的子矩阵,使其重复多次后所得矩阵包含原矩阵。行列分开做一遍kmp,求出各自的最大border。按行做的时候,每一行的字符当作一个字符(hash一下,直接暴力也行)

  • p2375对s的所有前缀,求出长度不超过一半的border数量放到失配树上,倍增一下(border长度单减)(无脑+暴力)$O(n\log n)$$O(n)$做法:考虑在求解nxt数组的时候,再增加一个k指针,该指针和j指针一样,但是需要额外满足长度的限制,如果不满足继续跳nxt。i点的答案就是k点的答案的nxt前缀和(同时dp递推即可)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
void solve()
{
cin >> s;
int n = strlen(s);
int i = 0, j = nxt[0] = -1, k = -1;
ll ans = 1;
while(i<n)
{
while(~j && s[i]!=s[j]) j = nxt[j];
while(~k && s[i]!=s[k]) k = nxt[k];
nxt[++i] = ++j;
++k;
dp[i] = dp[nxt[i]] + 1;
while(2*k>i) k = nxt[k];
ans = ans*(dp[k]+1)%mod; //求的是每个点答案+1的连乘积
}
cout << ans << '\n';
}
  • p3426对于字符串s,求一个最短的字符串t,其可重叠地重复数次后等于s显然t是s的一个border如果当前等价于s[1…i],相当于t是pre(s, i)的一个border。所以这个t是s的某些前缀的border,而且这些前缀的位置的差值要<=这个串t的长度。在失配树上考虑,答案是0到n这条链上的某个点x(s的某个border),并且该点子树内所有点的位置的最大差值<=x(x是其子树内的所有点的某个border),如果最大差值>x,那么这两个前缀间就不能通过x拼接而成。失配树上从0走到n,不断删去非子树内的点,用set或双向链表维护maxgap即可考虑dp[i]:s[1…i]的答案dp[i]要么等于i,要么等于dp[nxt[i]]等于dp[nxt[i]]的条件是nxt[i]…i-1间$\exists j \ dp[j] = dp[nxt[i]] \land i-j<=dp[nxt[i]]$
1
2
3
4
5
6
7
for(int i=1; i<=n; i++)
{
dp[i] = i;
if(p[dp[nxt[i]]] && i-p[dp[nxt[i]]]<=dp[nxt[i]])
dp[i] = dp[nxt[i]];
p[dp[i]] = i;
}
  • p3193给定一个长为m的字符串t,问能构造出多少个长为n的字符串s,使得t不是s的子串。$m<=20, n<=10^9$f[i][j]:s串构造到i,t串匹配到j的方案数先预处理出t串当前匹配到i,加一个字符后匹配到j的方案数g[i][j],这个用kmp即可转移:$f[i][j] = \sum_{k=0}^{m-1} f[i-1][k] * g[k][j]$显然可以矩阵快速幂优化

$\begin{bmatrix}f[i][0] & \cdots & f[i][m-1] \end{bmatrix} = \begin{bmatrix} f[i-1][0] &\cdots & f[i-1][m-1] \end{bmatrix} \begin{bmatrix} g[0][0] & \cdots &g[0][m-1] \\ \vdots & \ddots & \vdots \\ g[m-1][0] & \cdots & g[m-1][m-1]\\ \end{bmatrix}$

其中f[0][0]=1最后答案即为$g^n$的第一行的和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <bits/stdc++.h>
#define mp make_pair
#define pb push_back
#define sz(x) (int)x.size()
#define all(x) begin(x), end(x)
#define fi first
#define se second
#define debug(x) cerr << #x << " " << x << '\n'
using namespace std;
using ll = long long;
using pii = pair<int,int>;
using pli = pair<ll,int>;
const int INF = 0x3f3f3f3f, N = 25;
const ll LINF = 1e18 + 5;
int n, m, mod;
int g[N][N], nxt[N];
char s[N];
struct matrix
{
int n, m;
int mat[N][N];
matrix() {}
matrix(int a, int b) : n(a), m(b) { memset(mat, 0, sizeof(mat)); }
void one()
{
for(int i=1; i<=n; i++)
for(int j=1; j<=m; j++)
mat[i][j] = (i==j);
}
};
matrix operator * (matrix a, matrix b)
{
matrix res(a.n, b.m);
for(int i=1; i<=a.n; i++)
for(int j=1; j<=b.m; j++)
for(int k=1; k<=a.m; k++)
res.mat[i][j] = (res.mat[i][j]+a.mat[i][k]*b.mat[k][j])%mod;
return res;
}
matrix operator ^ (matrix a, ll x)
{
matrix res(a.n, a.m); res.one();
while(x)
{
if(x&1) res = res * a;
a = a * a;
x >>= 1;
}
return res;
}
void get_nxt()
{
int i = 0, j = nxt[0] = -1;
while(i<m)
{
while(~j && s[i]!=s[j]) j = nxt[j];
nxt[++i] = ++j;
}
}
void init_g()
{
for(int i=-1; i<m-1; i++)
for(int k=0; k<10; k++)
{
int j = i + 1;
while(~j && s[j]-'0'!=k) j = nxt[j];
g[i+1][j+1]++;
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n >> m >> mod;
cin >> s;
get_nxt();
init_g();
matrix f(m, m);
for(int i=1; i<=m; i++)
for(int j=1; j<=m; j++)
f.mat[i][j] = g[i-1][j-1];
f = f ^ n;
int ans = 0;
for(int i=1; i<=m; i++) ans = (ans + f.mat[1][i])%mod;
cout << ans << '\n';
return 0;
}