bzoj 4598 模式字符串

点分治 + $hash$ .

  • 考虑点分治.由于分治时路径起点不确定,无法直接匹配,所以需要 $hash$ 暂时存储状态.
  • 若当前分治中心为 $rt$ ,维护 $pre(i),suf(i)$ 分别表示节点 $i$ 到 $rt$ 的路径, $rt$ 到 $i$ 的路径的 $hash$ 值.
  • 到一个点,先算得它的 $pre$ ,在后面接上 $rt$ 的字符,判断这个串在循环意义下是否与模式串的某个前缀匹配.
  • 可以反着推,假设它能与某个前缀循环匹配,那么根据这个串的长度,可以算出它应该是模式串重复了 $\lfloor len/m \rfloor$ 次,再接上一个长度为 $len\mod m$ 的前缀形成的.判一下 $pre$ 是否与理论上求得的 $hash$ 值相等即可.
  • 若在循环意义下匹配上了长度为 $i$ 的前缀,它的贡献就是当前能与长度 $m-i$ 的后缀循环匹配的 $suf$ 数目.
  • 再算这个点 $suf$ 的贡献,与上面的方法类似.维护一个反串的 $hash$ 会十分方便.

常数大的一批,写法是对的,但时限卡不进去.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
inline int read()
{
int out=0,fh=1;
char jp=getchar();
while ((jp>'9'||jp<'0')&&jp!='-')
jp=getchar();
if (jp=='-')
fh=-1,jp=getchar();
while (jp>='0'&&jp<='9')
out=out*10+jp-'0',jp=getchar();
return out*fh;
}
const int inf=1e9;
const int MAXN=1e5+10;
int n,m;
char buf[MAXN];
int ecnt=0,head[MAXN],to[MAXN<<1],nx[MAXN<<1];
void addedge(int u,int v)
{
++ecnt;
to[ecnt]=v;
nx[ecnt]=head[u];
head[u]=ecnt;
}
typedef unsigned long long ull;
const ull Base=37;
ull pw[MAXN],Hash[MAXN],revHash[MAXN];
ull val[MAXN],pre[MAXN],suf[MAXN];
ull Pattern_Power[MAXN],revPattern_Power[MAXN];
ull Add_Char(ull hash,int c)
{
return hash*Base+c;
}
void Init_Hash()
{
Hash[0]=0;
for(int i=1;i<=m;++i)
Hash[i]=Hash[i-1]*Base+(buf[i]-'A');
revHash[0]=0;
for(int i=1;i<=m;++i)
revHash[i]=revHash[i-1]*Base+(buf[m-i+1]-'A');
Pattern_Power[0]=0;
for(int i=1;i*m<=n;++i)
Pattern_Power[i]=Pattern_Power[i-1]*pw[m]+Hash[m];
revPattern_Power[0]=0;
for(int i=1;i*m<=n;++i)
revPattern_Power[i]=revPattern_Power[i-1]*pw[m]+revHash[m];
}
int rt,totsize,mi,siz[MAXN],vis[MAXN];
ll ans;
void Findrt(int u,int fa)
{
siz[u]=1;
int mxsiz=0;
for(int i=head[u];i;i=nx[i])
{
int v=to[i];
if(v==fa || vis[v])
continue;
Findrt(v,u);
siz[u]+=siz[v];
if(siz[v]>mxsiz)
mxsiz=v;
}
mxsiz=max(mxsiz,totsize-siz[u]);
if(mxsiz<mi)
mi=mxsiz,rt=u;
}
int sumpre[MAXN],sumsuf[MAXN];
int stk1[MAXN],stk2[MAXN],tp=0;
int lstpre[MAXN],lstsuf[MAXN],cnt=0;
int Match_pre(ull hash,int len)
{
int x=len/m,y=len%m;
ull exphash=Pattern_Power[x]*pw[y]+Hash[y];
if(exphash==hash)
return y;
else
return m;
}
int Match_suf(ull hash,int len)
{
int x=len/m,y=len%m;
ull exphash=revPattern_Power[x]*pw[y]+revHash[y];
if(exphash==hash)
return y;
else
return m;
}
int prelen,suflen;
ull preval,sufval;
void dfs(int u,int fa,int len,int Rt)
{
pre[u]=val[u]*pw[len-1]+pre[fa];
suf[u]=val[u]*pw[len-1]+suf[fa];
prelen=Match_pre(pre[u],len);
suflen=Match_suf(suf[u],len);
if(prelen!=m || suflen!=m)
{
++cnt;
lstpre[cnt]=prelen;
lstsuf[cnt]=suflen;
}
preval=Add_Char(pre[u],val[Rt]);
prelen=Match_pre(preval,len+1);
sufval=Add_Char(suf[u],val[Rt]);
suflen=Match_suf(sufval,len+1);
if(prelen!=m)
ans+=sumsuf[(m-prelen)%m];
if(suflen!=m)
ans+=sumpre[(m-suflen)%m];
for(int i=head[u];i;i=nx[i])
{
int v=to[i];
if(vis[v] || v==fa)
continue;
dfs(v,u,len+1,Rt);
}
}
void solve(int u)
{
vis[u]=1;
if(siz[u]<m)
return;
for(int i=1;i<=tp;++i)
{
sumpre[stk1[i]]=0;
sumsuf[stk2[i]]=0;
}
pre[u]=suf[u]=0;
++sumpre[0],++sumsuf[0];
tp=1;
stk1[tp]=0;
stk2[tp]=0;
for(int i=head[u];i;i=nx[i])
{
int v=to[i];
if(vis[v])
continue;
dfs(v,u,1,u);
while(cnt)
{
++tp;
stk1[tp]=lstpre[cnt];
stk2[tp]=lstsuf[cnt];
++sumpre[lstpre[cnt]];
++sumsuf[lstsuf[cnt]];
--cnt;
}
}
}
void Divide(int u)
{
solve(u);
for(int i=head[u];i;i=nx[i])
{
int v=to[i];
if(vis[v])
continue;
mi=inf,totsize=siz[v];
Findrt(v,0);
Divide(rt);
}
}
void Reset()
{
memset(head,0,sizeof head);
ecnt=0;
memset(vis,0,sizeof vis);
ans=0;
}
int main()
{
pw[0]=1;
for(int i=1;i<=100000;++i)
pw[i]=pw[i-1]*Base;
int T=read();
while(T--)
{
Reset();
n=read(),m=read();
scanf("%s",buf+1);
for(int i=1;i<=n;++i)
val[i]=buf[i]-'A';
for(int i=1;i<n;++i)
{
int u=read(),v=read();
addedge(u,v);
addedge(v,u);
}
scanf("%s",buf+1);
Init_Hash();
mi=inf,totsize=n;
Findrt(1,0);
Divide(rt);
printf("%lld\n",ans);
}
return 0;
}