2020秦皇岛 J Jewel Splitting

(11 mins to read)

题意

一个长度为n的字符串,让你排成长为d,宽为$\lfloor \frac{n}{d}\rfloor$的矩形,即将连续的长度d的子串作为一行,多余的n%d个丢弃,问有多少种不同的矩形,mod 998244353$n \leq 3 \times 10^5$

做法

显然不同的d互不影响,可以分开做当$d \mid n$,就是一个可重集排列,用map来维护每种hash值出现的个数当$d \nmid n$,那么我们可能会丢弃$[1,n%d],[d+1,d+n%d]…$,容易发现相邻两种只有两个字符串是不同的,因此动态维护即可。有个问题是,丢弃不同的区间,剩余的串的类型是一模一样的,即这个状态的map是一样的,这时候不能重复算,因此还要对每个时刻的map状态进行hash。思路不难,但是卡常。不双hash会wa,双hash又tle🤡

字符串的rollinghash的本质其实就是将这个字符串用base进制的数来表示,然后可以O(1)的求每个子区间的hash值(我好像到现在才明白字符串hash就是用base进制数来表示。。。)考虑map的状态怎么求,就相对于有很多pairs,(hash,cnt)。我的做法是把hash值丢到map里然后映射成小的编号,然后cnt就可以用数组来存了,然后map的状态就相当于第id[hash]位的值为cnt,用cnt*pw[id[hash]]来表示即可。如果只要整个串的hash值,各个位的值用异或也是可以,这样可以快点。然后就是双hash的值,用pair的话肯定会慢,这里用int存,然后高低位合并成一个ull即可。

通过这题测了一下各种map和setunordered_set很没用,set和map比较稳定,unordered_map有时快有时慢,加了手写hash也慢,很不稳定。所以一般情况下还是用set和map就行了,稳定log,unordered容易卡成n如果要实现无序的set,不推荐unordered_set。如果可以离线,那就vector此外pbds的gp_hash_table很快,被卡map可以用这个代替unordered_map

1
2
3
4
5
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
gp_hash_table<ull, int> id;
gp_hash_table<ull, bool> vis;
gp_hash_table<ull, null_type> vis;

无序set也可以用下面两个,测出来bool快一点,null_type内存小点gp_hash_table封装的函数不多,只有下标、insert要实现count,要靠find==end所以最好开局部,不要开全局

1
2
map:2700ms
gp_hash_table:1700ms
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
using ull = uint64_t;
const int N = 3e5 + 5;
const int mod = 998244353, mod2 = 1e9 + 9, base = 19260817;
int kase;
int n, fac[N<<1], ifac[N<<1], cnt[N<<1], ans, tot;
int HSH, HSH2;
//map<ull, int> id;
gp_hash_table<ull, int> id;
gp_hash_table<ull, bool> vis;
//map<int, bool> vis;
//set<int> st;
char s[N];
int pw[N<<1], pw2[N<<1];
int hsh[N], hsh2[N];
void Add(int &x, int y) { x += y; if(x>=mod) x -= mod; }
int Pow(int a, int b)
{
int ans = 1;
while(b)
{
if(b&1) ans = 1ll*ans*a%mod;
a = 1ll*a*a%mod;
b >>= 1;
}
return ans;
}
void pre(int n)
{
pw[0] = pw2[0] = fac[0] = 1;
for(int i=1; i<=2*n; i++)
{
pw[i] = 1ll*pw[i-1]*base%mod;
pw2[i] = 1ll*pw2[i-1]*base%mod2;
fac[i] = 1ll*fac[i-1]*i%mod;
}
ifac[n] = Pow(fac[n], mod-2);
for(int i=n-1; i>=0; i--) ifac[i] = 1ll*ifac[i+1]*(i+1)%mod;
}
void init()
{
for(int i=1; i<=n; i++)
{
hsh[i] = (1ll*hsh[i-1]*base%mod + s[i])%mod;
hsh2[i] = (1ll*hsh2[i-1]*base%mod2 + s[i])%mod2;
}
}
inline int Sub(int x, int y) { x -= y; if(x<0) x += mod; return x; }
inline int Sub2(int x, int y) { x -= y; if(x<0) x += mod2; return x; }
int get(int l, int r) { return Sub(hsh[r], 1ll*hsh[l-1]*pw[r-l+1]%mod); }
int get2(int l, int r) { return Sub2(hsh2[r], 1ll*hsh2[l-1]*pw2[r-l+1]%mod2); }
void upd(int l, int r, int v)
{
ull hv = (ull)get(l, r)<<32|get2(l, r), p = 0;
auto it = id.find(hv);
if(it==id.end()) p = id[hv] = ++tot;
else p = it->second;
HSH ^= 1ll*pw[p]*cnt[p]%mod;
HSH2 ^= 1ll*pw2[p]*cnt[p]%mod2;
ans = 1ll*ans*fac[cnt[p]]%mod*ifac[cnt[p] + v]%mod;
cnt[p] += v;
HSH ^= 1ll*pw[p]*cnt[p]%mod;
HSH2 ^= 1ll*pw2[p]*cnt[p]%mod2;
}
void clear(gp_hash_table<ull, int> &table)
{
gp_hash_table<ull, int> tmp;
table.swap(tmp);
}
void clear(gp_hash_table<ull, bool> &table)
{
gp_hash_table<ull, bool> tmp;
table.swap(tmp);
}
int work(int d)
{
for(int i=1; i<=tot; i++) cnt[i] = 0;
//st.clear();
clear(id); clear(vis);
tot = HSH = HSH2 = 0;
int sum = 0;
ans = fac[n/d];
for(int i=n%d+1; i<=n; i+=d) upd(i, i+d-1, 1);
//vector<pair<ull, int>> lsh;
//lsh.push_back({(ull)HSH<<32|HSH2, ans});
Add(sum, ans); vis.insert({(ull)HSH<<32|HSH2, 1});
if(n%d)
{
int m = n/d*d;
for(int i=1; i<m; i+=d)
{
upd(i, i+d-1, 1);
upd(i+n%d, i+n%d+d-1, -1);
//lsh.push_back({(ull)HSH<<32|HSH2, ans});
if(vis.find((ull)HSH<<32|HSH2)==vis.end()) Add(sum, ans), vis.insert({(ull)HSH<<32|HSH2, 1});
}
}
/*
sort(begin(lsh), end(lsh));
for(int i=0; i<(int)lsh.size(); )
{
int j = i + 1;
while(j<(int)lsh.size() && lsh[j]==lsh[i])
{
//assert(lsh[j].second==lsh[i].second);
++j;
}
Add(sum, lsh[i].second);
i = j;
}
*/
return sum;
}
void solve()
{
scanf("%s", s+1);
//for(int i=1; i<=300000; i++) s[i] = char('a'+i%26);
//s[300001] = 0;
n = strlen(s+1);
init();
int res = 0;
for(int d=1; d<=n; d++) Add(res, work(d));
printf("Case #%d: %d\n", ++kase, res);
}
int main()
{
pre(N-5);
int _; scanf("%d", &_);
while(_--) solve();
return 0;
}