Buying a New String AC自动机

(9 mins to read)

给定两个字符串 $A,B$,再给出 $n$ 个字符串 $s_i$,每个 $s_i$ 有价值 $b_i$。从 $A$ 中取一个子串(可空),再从 $B$ 中取一个子串(可空),拼接成串 $C$,定义 $C$ 的价值为 $\sum_i cnt_i(C)\cdot b_i$,其中 $cnt_i(C)$ 表示 $s_i$ 在 $C$ 中作为子串出现的次数。要求输出最大的价值。

Constraints:
- $1 \le T \le 10$
- $1 \le |A|, |B| \le 10^3$
- $1 \le N \le 10^4$
- $1 \le |S_i| \le 26$ for each valid $i$
- $1 \le b_i \le 10^5$ for each valid $i$
- $A, B, S_1, S_2, \ldots, S_N$ contain only lowercase English letters
- $S_1, S_2, \ldots, S_N$ are pairwise distinct
- the sum of $|A|$ over all test cases does not exceed $2 \cdot 10^3$
- the sum of $|B|$ over all test cases does not exceed $2 \cdot 10^3$
- the sum of $N$ over all test cases does not exceed $2 \cdot 10^4$

做法

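简单的贪心可以发现,肯定是取 $A$ 的某个前缀和 $B$ 的某个后缀拼接:因为 $b_i \ge 1$,把从 $A$ 中取的子串向左延伸到开头、把从 $B$ 中取的子串向右延伸到结尾,只会增加出现次数,价值不会变小。此题的关键在于 $|S_i| \le 26$ 很小,如果 $|S_i|$ 很大,我不确定还能不能这样做。利用这点,先用 AC 自动机预处理出 $A$ 的每个前缀的价值和 $B$ 的每个后缀的价值(对反转后的 $B$ 和反转后的模式串再建一个自动机)。然后枚举前缀 $i$ 和后缀 $j$ 拼接。此时 $C$ 中还没有统计的只有跨越拼接处的出现,它们只涉及拼接处前后各不超过 26 个字母。做法是从前缀 $i$ 对应的自动机状态出发,沿 $B$ 的后缀在 AC 自动机上走至多 25 步。第 $k$ 步($k$ 从 0 开始)只统计长度 $\ge k+2$ 的匹配,这样的匹配一定起始于 $A$ 中。最后取 max。如果每一步都暴力跳 fail 链,复杂度是 $O(|A|\cdot|B|\cdot 25\cdot 26)$,单组数据约 $6.5\times 10^8$,会 TLE。跳 fail 是 $O(26)$ 的,所以再建一棵 fail 树,预处理 $tmp[u][l]$ = 状态 $u$ 的 fail 链上所有长度 $\ge l$ 的模式串价值之和(在 fail 树上自顶向下累加)。这样每一步就是 $O(1)$ 查询,把这个 $O(26)$ 去掉就能通过。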

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include <bits/stdc++.h>
// Shorthand macros. Arguments are parenthesised so that expressions such as
// sz(*ptr) or debug(a & b) expand with the intended precedence.
#define mp make_pair
#define pb push_back
#define sz(x) ((int)(x).size())
#define all(x) begin(x), end(x)
#define fi first
#define se second
#define debug(x) cerr << #x << " " << (x) << '\n'
using namespace std;
using ll = long long;
using pii = pair<int,int>;
using pli = pair<ll,int>;
// N bounds the trie size of one automaton: per test at most 10^4 patterns of
// length <= 26, i.e. at most 2.6*10^5 + 1 states.
const int INF = 0x3f3f3f3f, N = 3e5 + 5;
const ll LINF = 1e18 + 5;
constexpr int mod = 1e9 + 7;
char a[N], b[N], rb[N];     // A, B, and B reversed (rb is not NUL-terminated)
char s[N][30], rs[N][30];   // patterns S_1..S_n (1-indexed) and their reversals
int n, val[N];              // number of patterns and their values b_i
// After solve()'s prefix-sum pass: pre[i] = value of A[0..i],
// suf[k] = value of B[|B|-1-k .. |B|-1].
ll pre[N], suf[N];
struct ACA
{
int nxt[N][26], fail[N], len[N], cnt, end[N];
ll tmp[N][30];
vector <int> to[N];
void clear()
{
for(int i=0; i<=cnt; i++)
{
fail[i] = end[i] = len[i] = 0;
to[i].clear();
for(int j=0; j<26; j++)
nxt[i][j] = tmp[i][j+1] = 0;
}
cnt = 0;
}
void insert(char *s, int v, int n)
{
int p = 0;
for(int i=0; i<n; i++)
{
int k = s[i] - 'a';
if(!nxt[p][k]) nxt[p][k] = ++cnt;
p = nxt[p][k];
}
end[p] = v; len[p] = n;
}
void build()
{
queue <int> q;
for(int i=0;i<26;i++) if(nxt[0][i]) q.push(nxt[0][i]);
while(!q.empty())
{
int k = q.front(); q.pop();
for(int i=0;i<26;i++)
{
if(nxt[k][i])
{
fail[nxt[k][i]] = nxt[fail[k]][i];
q.push(nxt[k][i]);
}
else nxt[k][i] = nxt[fail[k]][i];
}
}
}
void go(char *s, ll *v, int n)
{
int p = 0;
for(int i=0; i<n; i++)
{
p = nxt[p][s[i]-'a'];
for(int j=p; j; j=fail[j])
v[i] += end[j];
}
}
void dfs(int u)
{
for(int i=len[u]; i>0; i--) tmp[u][i] += end[u];
for(int v : to[u])
{
for(int i=26; i>0; i--) tmp[v][i] += tmp[u][i];
dfs(v);
}
}
}ac, ac2;
void solve()
{
scanf("%s%s", a, b);
scanf("%d", &n);
int x = strlen(a), y = strlen(b);
for(int i=0; i<y; i++) rb[i] = b[y-i-1];
ac.clear(); ac2.clear();
for(int i=1; i<=n; i++)
{
scanf("%s%d", s[i], val+i);
int t = strlen(s[i]);
ac.insert(s[i], val[i], t);
for(int j=0; j<t; j++) rs[i][j] = s[i][t-j-1];
ac2.insert(rs[i], val[i], t);
}
ac.build(); ac2.build();
for(int i=0; i<x; i++) pre[i] = 0;
for(int i=0; i<y; i++) suf[i] = 0;
ac.go(a, pre, x); ac2.go(rb, suf, y);
ll ans = 0;
int p = 0;
for(int i=1; i<x; i++) pre[i] += pre[i-1];
for(int i=1; i<y; i++) suf[i] += suf[i-1];
for(int i=1; i<=ac.cnt; i++) ac.to[ac.fail[i]].pb(i);
ac.dfs(0);
for(int i=0; i<x; i++)
{
p = ac.nxt[p][a[i]-'a'];
for(int j=0; j<y; j++)
{
ll cur = pre[i] + suf[y-j-1];
int pp = p;
for(int k=0; k<25&&j+k<y; k++)
{
pp = ac.nxt[pp][b[j+k]-'a'];
cur += ac.tmp[pp][k+2];
}
if(cur>ans) ans = cur;
}
}
printf("%lld\n", ans);
}
// Entry point: read the number of test cases T and solve each one.
int main()
{
    int T = 0;
    // Without this check a failed read leaves T uninitialised and while(T--) is UB.
    if(scanf("%d", &T) != 1) return 0;
    while(T--) solve();
    return 0;
}