Loj 2462 完美的集合

树上连通块的 dp 技巧 + 奇怪的组合数取模.

“点数减边数”

对于树上满足某种性质 $P$ 的连通块计数,有一个巧妙的转化 “点数减边数” .

枚举树上的每个点 $x$ ,计算出包含 $x$ ,且满足性质 $P$ 的连通块数目,对所有 $x$ 求和得到 $a$ .

再枚举树上每条边 $(x,y)$ ,计算出同时包含 $x,y$ ,且满足性质 $P$ 的连通块数目,对所有 $(x,y)$ 求和得到 $b$ .

每个连通块对 $a$ 贡献是点数,对 $b$ 贡献是边数,而树上的连通块也是一棵树,点数 = 边数 + 1.

于是可以得出满足性质 $P$ 的连通块数目就是 $a-b$ .

这种转化常用于求若干连通块的交,因为这样转化就变成了每个连通块都包含枚举的点/边,连通块之间就独立了.

计算完美集合数目

这道题就可以使用上面这个转化,计算选出能被 $x​$ 测试的完美集合方案,减去能被 $x,y​$ 同时测试的完美集合方案.

集合内肯定不能包含 $dist*v>Max$ 的点,可以将它们去掉,对剩下的点做一个背包,维护最优解和最优解的数目.

树上背包可以按照 dfs 序转移,若选了 $i$ ,则考虑其子树,转移到 $f_{dfn(i)+1}$ ,否则跳过其子树,转移到 $f_{dfn(i)+siz(i)}$ .

于是我们可以在 $O(nm)$ 的时间复杂度内求出能被某一个点或某两个点同时测试的完美集合数目.

计算模意义下组合数

如果没有再从集合中选出 $k$ 个的要求和 $5^{23}$ 这个奇怪的模数,这道题就已经优美地结束了.

出题人为了加大力度,还需要我们计算出 $\binom{s}{k}\bmod 5^{23}$ ,其中 $s$ 是计算出的集合数目,可以达到 $2^{60}$ , $k$ 可以达到 $10^9$ .

用拓展 Lucas 的思路,其实就是要设法算出 $p!$ 中含有的因子 $5$ 的个数以及 $\prod_{i\le p}[i\bmod 5\neq 0] i$ 的值.

前者可以直接 $O(\log p)​$ 计算,关键在于如何计算后者.

在拓展 Lucas 中,我们是 $O(mod)$ 暴力算出一个循环节的贡献,但这里 $mod=5^{23}$ ,显然行不通.

考虑构造多项式 $F_p(x)=\prod_{1\le i\le p} [i\bmod 5\neq 0] (x+i)$ ,要求的答案就是这个多项式展开后的常数项.

若 $p$ 不是 $10$ 的整数倍,就先将最后不超过 $9$ 个一次式的乘积暴力乘出来,就可以将 $p$ 凑成 $10k$ 的形式.

考虑 $F_{10k}(x)=F_{5k}(x)\cdot F_{5k}(x+5k)​$ ,其中 $F_{5k}​$ 的常数项可以递归下去算.

而 $F_{5k}(x+5k)​$ 在模 $5^{23}​$ 下可以表示成一个关于 $(x+5k)​$ 的多项式,形如 $\sum c_i\cdot (x+5k)^i​$ .

而我们只关心展开后不含 $x​$ 的常数项,可以发现当 $i\ge 23​$ 时, 将 $(x+5k)^i​$ 展开后常数项都为 $0​$ ,没有贡献.

那么递归求出 $F_{5k}$ 展开后的前 $23$ 项,将 $(x+5k)^i$ 代入,展开后即可求得 $F_{5k}(x+5k)$ 的前 $23$ 项.

再暴力卷积合并得到 $F_{10k}$ 的前 $23$ 项,这样只会需要求 $O(\log s)$ 个 $F_{5k}$ 的前 $23$ 项,复杂度可以接受.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
//%std
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
inline ll read()
{
ll out = 0, fh = 1;
char jp = getchar();
while ((jp > '9' || jp < '0') && jp != '-')
jp = getchar();
if (jp == '-')
fh = -1, jp = getchar();
while (jp >= '0' && jp <= '9')
out = out * 10 + jp - '0', jp = getchar();
return out * fh;
}
const ll P = 11920928955078125;
ll add(ll a, ll b)
{
return a + b >= P ? a + b - P : a + b;
}
void inc(ll &a, ll b)
{
a = add(a, b);
}
ll mul(ll a, ll b)
{
ll res = a * b - (ll)((long double)a / P * b + 1e-8) * P;
return res < 0 ? res + P : res % P;
}
ll fpow(ll a, ll b)
{
ll res = 1;
while (b)
{
if (b & 1LL)
res = mul(res, a);
a = mul(a, a);
b >>= 1;
}
return res;
}
ll Inv(ll x)
{
return fpow(x, P / 5 * 4 - 1);
}
typedef pair<ll, ll> pll;
typedef vector<ll> poly;
poly operator * (poly a, poly b)
{
poly c(23);
for (int j = 0; j < 23; ++j) if (b[j])
for (int i = 0; i + j < 23; ++i) if (a[i])
inc(c[i + j], mul(a[i], b[j]));
return c;
}
poly p[10000];
ll pw[23], binom[23][23];
void init()
{
binom[0][0] = 1;
for (int i = 1; i < 23; ++i)
{
binom[i][0] = 1;
for (int j = 1; j <= i; ++j)
binom[i][j] = add(binom[i - 1][j], binom[i - 1][j - 1]);
}
p[0].resize(23);
p[0][0] = 1;
for (int i = 1; i < 10000; ++i)
{
if (i % 5)
{
poly tmp(23);
tmp[0] = i, tmp[1] = 1;
p[i] = p[i - 1] * tmp;
}
else
p[i] = p[i - 1];
}
}
poly trans(poly a, ll k)
{
pw[0] = 1;
for (int i = 1; i < 23; ++i)
pw[i] = mul(pw[i - 1], k);
poly b(23);
for (int i = 0; i < 23; ++i)
for (int j = 0; j <= i; ++j)
inc(b[j], mul(a[i], mul(pw[i - j], binom[i][j])));
return b;
}
poly facpoly(ll n)
{
if (n < 10000)
return p[n];
ll k = n / 10 * 10;
poly t1 = facpoly(k >> 1);
poly t2 = trans(t1, k >> 1);
t1 = t1 * t2;
for (ll i = k + 1; i <= n; ++i)
if (i % 5)
{
poly tmp(23);
tmp[0] = i % P, tmp[1] = 1;
t1 = t1 * tmp;
}
return t1;
}
pll Fac(ll n)
{
pll res = make_pair(facpoly(n)[0], n / 5);
if (n >= 5)
{
pll tmp = Fac(n / 5);
res.first = mul(res.first, tmp.first);
res.second += tmp.second;
}
return res;
}
ll Binom(ll n, ll k)
{
if (n < k)
return 0;
pll f1 = Fac(n), f2 = Fac(k), f3 = Fac(n - k);
f1.second -= f2.second + f3.second;
f1.first = mul(f1.first, mul(Inv(f2.first), Inv(f3.first)));
if (f1.second >= 23)
return 0;
return mul(f1.first, fpow(5, f1.second));
}
const int N = 60 + 10, M = 1e4 + 10, inf = 1e9;
int n, m, k, ecnt = 0, nx[N << 1], to[N << 1], head[N];
ll Max;
void addedge(int u, int v)
{
++ecnt;
to[ecnt] = v;
nx[ecnt] = head[u];
head[u] = ecnt;
}
int w[N], val[N], dist[N][N], valid[N], rnk[N], siz[N], _fa[N], fa[N], idx = 0;
void dfs(int u, int F)
{
_fa[u] = F, siz[u] = 1, rnk[++idx] = u;
for (int i = head[u]; i; i = nx[i])
{
int v = to[i];
if (v == F || !valid[v])
continue;
dfs(v, u);
siz[u] += siz[v];
}
}
ll f[N][M], g[N][M], mx = 0;
pll calc(int x, int y)
{
idx = 0;
dfs(x, 0);
for (int i = 0; i <= m; ++i)
{
f[idx + 1][i] = 0;
g[idx + 1][i] = 1;
}
for (int i = idx; i >= 1; --i)
{
int u = rnk[i];
for (int j = 0; j <= m; ++j)
{
ll v1 = j >= w[u] ? f[i + 1][j - w[u]] + val[u] : 0;
ll v2 = f[i + siz[u]][j];
if (u == y && j < w[u])
f[i][j] = g[i][j] = 0;
else if (u == y || (j >= w[u] && v1 > v2))
{
f[i][j] = f[i + 1][j - w[u]] + val[u];
g[i][j] = g[i + 1][j - w[u]];
}
else if (j < w[u] || v1 < v2)
{
f[i][j] = f[i + siz[u]][j];
g[i][j] = g[i + siz[u]][j];
}
else
{
f[i][j] = f[i + siz[u]][j];
g[i][j] = g[i + siz[u]][j] + g[i + 1][j - w[u]];
}
}
}
return make_pair(f[1][m], g[1][m]);
}
ll solve(int x, int y)
{
for (int i = 1; i <= n; ++i)
if (1LL * dist[x][i] * val[i] > Max || (y && 1LL * dist[y][i] * val[i] > Max))
valid[i] = 0;
else
valid[i] = 1;
if (!valid[x] || (y && !valid[y]))
return 0;
pll tmp = calc(x, y);
if (tmp.first < mx)
return 0;
return Binom(tmp.second, k);
}
int main()
{
n = read(), m = read(), k = read(), Max = read();
init();
for (int i = 1; i <= n; ++i)
w[i] = read();
for (int i = 1; i <= n; ++i)
val[i] = read();
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= n; ++j)
dist[i][j] = (i == j) ? 0 : inf;
for (int i = 1; i < n; ++i)
{
int u = read(), v = read(), l = read();
addedge(u, v), addedge(v, u);
dist[u][v] = dist[v][u] = l;
}
for (int k = 1; k <= n; ++k)
for (int i = 1; i <= n; ++i) if (dist[i][k] < inf)
for (int j = 1; j <= n; ++j) if (dist[k][j] < inf)
dist[i][j] = min(dist[i][j], dist[i][k] + dist[k][j]);
for (int i = 1; i <= n; ++i)
valid[i] = 1;
for (int i = 1; i <= n; ++i)
mx = max(mx, calc(i, 0).first);
for (int i = 1; i <= n; ++i)
fa[i] = _fa[i];
ll ans = 0;
for (int i = 1; i <= n; ++i)
inc(ans, add(solve(i, 0), P - (fa[i] ? solve(i, fa[i]) : 0)));
printf("%lld\n", ans);
return 0;
}