bzoj 4180 字符串计数

SAM + 矩阵快速幂 floyd.

考虑如果给出了串 $S$ ,如何算最小操作次数,将 $S$ 放到 $T$ 的 SAM 上去匹配,若没有出边则返回 $1$ ,且操作次数 $+1$ .

可以二分答案 $mid$ ,尝试检查用 $mid$ 次操作能构造出的串最小长度,若 $\le n$ ,说明答案 $\ge mid$ .

考虑如何求出用 $mid$ 次操作能构造的串的最小长度,其实转移只会出现在根节点的出边指向的点中.

bfs 预处理出它们之间的距离,即要加入多少个字符,对转移矩阵求 $mid​$ 次幂即可得到用 $mid​$ 次操作的最短距离.

时间复杂度 $O(|\sum|\cdot |T|+|\sum|^3\log^2 n)$ ,其中 $|\sum|$ 代表字符集大小.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
//%std
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
inline ll read()
{
ll out = 0, fh = 1;
char jp = getchar();
while ((jp > '9' || jp < '0') && jp != '-')
jp = getchar();
if (jp == '-')
fh = -1, jp = getchar();
while (jp >= '0' && jp <= '9')
out = out * 10 + jp - '0', jp = getchar();
return out * fh;
}
void print(ll x)
{
if (x >= 10)
print(x / 10);
putchar('0' + x % 10);
}
void write(ll x, char c)
{
if (x < 0)
putchar('-'), x = -x;
print(x);
putchar(c);
}
const int N = 2e5 + 10, S = 4, inf = 1e9;
const ll INF = 4e18;
int n, idx = 1, lst = 1, len[N], fa[N], ch[N][S];
char buf[N];
void Extend(int c)
{
int p = lst, np = ++idx;
lst = np;
len[np] = len[p] + 1;
while (p && ch[p][c] == 0)
ch[p][c] = np, p = fa[p];
if (!p)
fa[np] = 1;
else
{
int q = ch[p][c];
if (len[q] == len[p] + 1)
fa[np] = q;
else
{
int nq = ++idx;
len[nq] = len[p] + 1;
fa[nq] = fa[q], fa[q] = fa[np] = nq;
memcpy(ch[nq], ch[q], sizeof ch[nq]);
while (p && ch[p][c] == q)
ch[p][c] = nq, p = fa[p];
}
}
}
struct Martix
{
ll v[S][S];
Martix()
{
for (int i = 0; i < S; ++i)
for (int j = 0; j < S; ++j)
v[i][j] = INF;
}
Martix operator * (const Martix &rhs) const
{
Martix res;
for (int i = 0; i < S; ++i)
for (int k = 0; k < S; ++k) if (v[i][k] < INF)
for (int j = 0; j < S; ++j) if (rhs.v[k][j] < INF)
res.v[i][j] = min(res.v[i][j], v[i][k] + rhs.v[k][j]);
return res;
}
} A, tmp;
Martix fpow(Martix a, ll b)
{
Martix res;
for (int i = 0; i < S; ++i)
res.v[i][i] = 0;
while (b)
{
if (b & 1LL)
res = res * a;
a = a * a;
b >>= 1;
}
return res;
}
int dis[N], vis[N];
queue<int> q;
void bfs(int id, int st)
{
for (int i = 1; i <= idx; ++i)
vis[i] = 0, dis[i] = inf;
vis[S] = 1, dis[st] = 0, q.push(st);
while (!q.empty())
{
int x = q.front();
q.pop();
for (int i = 0; i < S; ++i)
if (!vis[ch[x][i]])
{
if (!ch[x][i])
A.v[id][i] = min(A.v[id][i], 1LL * dis[x] + 1);
else
{
vis[ch[x][i]] = 1;
dis[ch[x][i]] = dis[x] + 1;
q.push(ch[x][i]);
}
}
}
}
ll solve(ll k)
{
tmp = fpow(A, k);
ll mi = INF;
for (int i = 0; i < S; ++i)
for (int j = 0; j < S; ++j)
mi = min(mi, tmp.v[i][j]);
return mi;
}
int main()
{
ll len = read();
scanf("%s", buf + 1);
n = strlen(buf + 1);
for (int i = 1; i <= n; ++i)
Extend(buf[i] - 'A');
for (int i = 0; i < S; ++i)
bfs(i, ch[1][i]);
ll L = 1, R = len, ans;
while (L <= R)
{
ll mid = (L + R) >> 1;
if (solve(mid) <= len)
ans = mid, L = mid + 1;
else
R = mid - 1;
}
if (solve(ans) < len)
++ans;
write(ans, '\n');
return 0;
}