Hdu 6593 Coefficient

多项式多点求值.

给定 $a,b,c,d,n$ ,有一个函数
$$
f(x)=\frac{b}{c+\exp(ax+d)}
$$
记 $x_0$ 为 $ax+d=0$ 的唯一实根,即 $-\frac{d}{a}$ ,将 $f(x)$ 在 $x=x_0$ 处做泰勒展开,求 $(x-x_0)^n$ 这一项的系数.

每组数据的 $n$ 是一样的,有 $q$ 次询问,每次询问给出 $a,b,c,d$ ,需要求出对应的答案,系数对 $998244353$ 取模.
$0\le n,q\le 3\times 10^5,a\neq 0$ .

一个靠谱做法

显然 $a,b,d$ 都没什么大用, $a,b$ 可以先忽略掉,最后乘上 $a^n\cdot b$ ,而 $d$ 只有平移的作用,可以直接忽略.

于是转化为求
$$
ans=[x^n]\frac{1}{c+\exp(x)} \\
$$
尝试把分母构造成 $1-q\cdot (\exp(x)-1)$ 的形式,对比系数不难发现 $q=-\frac 1{c+1}$ .
$$
\begin{aligned}
ans&=\frac{-1}{q}[x^n]\frac{1}{1-q\cdot (\exp(x)-1)} \\
&=\frac{-1}{q}[x^n] \sum_{k=0}^n q^k(\exp(x)-1)^k \\
\end{aligned}
$$
这里就能做了, $(\exp(x)-1)^k$ 再除个 $k!$ 就是第二类斯特林数一列的 EGF, 问题转化为求一行的第二类斯特林数.

如果继续往下推也不是很复杂,
$$
\begin{aligned}
ans&=\frac{-1}{q}[x^n] \sum_{k\ge 0}q^k\sum_{i=0}^k \binom k i\cdot \exp(ix)\cdot (-1)^{k-i}\\
&=\frac{-1}{q}[x^n] \sum_{i+j\le n}q^{i+j}\cdot\binom{i+j}{i}\cdot\exp(ix)\cdot(-1)^j\\
&=\frac{-1}{q}[x^n] \sum_{i+j\le n}q^{i+j}\cdot\binom{i+j}{i}\cdot \frac{i^n}{n!}\cdot(-1)^j
\end{aligned}
$$

EI : 处理倒数上有些奇怪东西的时候,就得强行加一再减一.

这里可以看出,强行凑出 $\exp(x)-1$ 的形式后,它的常数项是 $0$ ,枚举它的指数时就只用枚举到 $n$ 了.

不难发现这是个关于 $q$ 的 $n$ 次多项式,做一次卷积求出它,然后做多点求值即可.

特判掉 $c=-1$ 和 $n=0$ 的情况.

一个玄学做法

由于某种原因,答案为
$$
\frac{a^n\cdot b}{n!(c+1)^{n+1}}\cdot \sum_{i=0}^{n-1}(-1)^{n+i}\cdot E(n,i)\cdot c^i
$$
$E(n,k)$ 表示长度为 $n$ 的排列 $p$ ,恰有 $k$ 个位置满足 $p_i<p_{i+1}$ ,即恰有 $k+1$ 个极长单调下降的连续段方案数.

考虑容斥,钦定这 $k+1$ 段有 $i$ 个拼接的位置不合法,注意段可以为空,开头结尾也要算上,贡献是 $\binom{n+1}{i}$ .

然后把 $n$ 个数分到剩下的 $k+1-i$ 段中,段内会自动排成单调下降,且根据每一段的长度,分割的位置也会唯一确定.
$$
E(n,k)=\sum_{i=0}^k (-1)^i\binom{n+1}{i}(k+1-i)^n
$$
做一次卷积即可求出一行所有的 $E(n,k)$ ,得到 $\sum$ 这一坨关于 $c$ 的 $n-1$ 多项式,然后做多点求值即可.

特判掉 $c=-1$ 和 $n=0​$ 的情况.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
//%std
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define y1 ysgh
const int P = 998244353;
inline int read()
{
int out = 0, fh = 1;
char jp = getchar();
while ((jp > '9' || jp < '0') && jp != '-')
jp = getchar();
if (jp == '-')
fh = -1, jp = getchar();
while (jp >= '0' && jp <= '9')
out = out * 10 + jp - '0', jp = getchar();
return out * fh < 0 ? out * fh % P + P : out * fh % P;
}
void print(int x)
{
if (x >= 10)
print(x / 10);
putchar('0' + x % 10);
}
void write(int x, char c)
{
if (x < 0)
putchar('-'), x = -x;
print(x);
putchar(c);
}
int add(int a, int b)
{
return a + b >= P ? a + b - P : a + b;
}
void inc(int &a, int b)
{
a = add(a, b);
}
int mul(int a, int b)
{
return 1LL * a * b % P;
}
int fpow(int a, int b)
{
int res = 1;
while (b)
{
if (b & 1)
res = mul(res, a);
a = mul(a, a);
b >>= 1;
}
return res;
}
const int N = 1 << 20 | 10;
const int M = 3e5 + 10;
namespace Polynomial
{
int curn = 0, rev[N], omega[N], inv[N], invn;
void init(int n)
{
if (n == curn)
return;
for (int i = 0; i < n; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (n >> 1));
for (int l = 2; l <= n; l <<= 1)
{
omega[l] = fpow(3, (P - 1) / l);
inv[l] = fpow(omega[l], P - 2);
}
invn = fpow(n, P - 2);
curn = n;
}
void DFT(int *a, int n, bool invflag)
{
init(n);
for (int i = 0; i < n; ++i)
if (i < rev[i])
swap(a[i], a[rev[i]]);
for (int l = 2; l <= n; l <<= 1)
{
int gi = omega[l], m = l >> 1;
if (invflag)
gi = inv[l];
for (int *p = a; p != a + n; p += l)
{
int g = 1;
for (int i = 0; i < m; ++i)
{
int t = mul(g, p[i + m]);
p[i + m] = add(p[i], P - t);
p[i] = add(p[i], t);
g = mul(g, gi);
}
}
}
if (invflag)
{
for (int i = 0; i < n; ++i)
a[i] = mul(a[i], invn);
}
}
void NTT(int *A, int *B, int *C, int lenA, int lenB)
{
int lenC = lenA + lenB - 1, n = 1;
while (n < lenC)
n <<= 1;
static int a[N], b[N];
copy(A, A + lenA, a), fill(a + lenA, a + n, 0);
copy(B, B + lenB, b), fill(b + lenB, b + n, 0);
DFT(a, n, false);
DFT(b, n, false);
for (int i = 0; i < n; ++i)
C[i] = mul(a[i], b[i]);
DFT(C, n, true);
}
typedef vector<int> poly;
void debug(poly A)
{
for (int i = 0; i < A.size(); ++i)
write(A[i], ' ');
puts("");
}
poly operator * (const poly &A, const poly &B)
{
int len = A.size() + B.size() - 1, n = 1;
while (n < len)
n <<= 1;
poly C(len);
static int a[N], b[N];
memcpy(a, &A[0], A.size() * 4), memset(a + A.size(), 0, (n - A.size()) * 4);
memcpy(b, &B[0], B.size() * 4), memset(b + B.size(), 0, (n - B.size()) * 4);
DFT(a, n, false);
DFT(b, n, false);
for (int i = 0; i < n; ++i)
a[i] = mul(a[i], b[i]);
DFT(a, n, true);
memcpy(&C[0], a, len * 4);
return C;
}
poly MULT(poly A, const poly &B)
{
reverse(A.begin(), A.end());
int lenA = A.size(), n = 1, lenB = min(lenA, (int)B.size());
while (n < lenA + lenB - 1)
n <<= 1;
poly C(lenA);
static int a[N], b[N];
memcpy(a, &A[0], lenA * 4), memset(a + lenA, 0, (n - lenA) * 4);
memcpy(b, &B[0], lenB * 4), memset(b + lenB, 0, (n - lenB) * 4);
DFT(a, n, false);
DFT(b, n, false);
for (int i = 0; i < n; ++i)
a[i] = mul(a[i], b[i]);
DFT(a, n, true);
memcpy(&C[0], a, lenA * 4);
reverse(C.begin(), C.end());
return C;
}
poly inverse(const poly &A)
{
int len = A.size(), n = 1;
while (n < len)
n <<= 1;
static int a[N], res[N], tmp[N];
memcpy(a, &A[0], len * 4), memset(a + len, 0, (n - len) * 4);
res[0] = fpow(a[0], P - 2);
for (int i = 2; i <= n; i <<= 1)
{
NTT(a, res, tmp, i, i);
NTT(tmp, res, tmp, i, i);
for (int j = 0; j < i; ++j)
res[j] = add(mul(2, res[j]), P - tmp[j]);
}
poly B(len);
memcpy(&B[0], res, len * 4);
return B;
}
int pos[M], idx = 1, ls[M * 2], rs[M * 2], ans[M], id[M];
poly prod[M << 1], G[M << 1];
void BuildTree(int x, int l, int r)
{
if (l == r)
{
prod[x].resize(2);
prod[x][0] = 1, prod[x][1] = add(0, P - pos[l]);
return;
}
int mid = (l + r) >> 1;
BuildTree(ls[x] = ++idx, l, mid);
BuildTree(rs[x] = ++idx, mid + 1, r);
prod[x] = prod[ls[x]] * prod[rs[x]];
}
void dfs(int x, int l, int r)
{
G[x].resize(r - l + 1);
if (l == r)
{
ans[id[l]] = G[x][0];
return;
}
int mid = (l + r) >> 1;
G[ls[x]] = MULT(G[x], prod[rs[x]]);
dfs(ls[x], l, mid);
G[rs[x]] = MULT(G[x], prod[ls[x]]);
dfs(rs[x], mid + 1, r);
}
}
using namespace Polynomial;
int n, m, a[M], b[M], c[M], d[M], fac[M], invfac[M];
int binom(int x, int y)
{
if (x < 0 || y < 0 || x < y)
return 0;
return mul(fac[x], mul(invfac[y], invfac[x - y]));
}
void solve()
{
if (!m)
return;
if (n == 0)
{
while (m--)
{
int A = read(), B = read(), C = read(), D = read();
if (C == P - 1)
puts("-1");
else
write(mul(B, fpow(C + 1, P - 2)), '\n');
}
return;
}
poly F(n + 1), g(n + 1);
for (int i = 0; i <= n; ++i)
{
F[i] = binom(n + 1, i);
if (i & 1)
F[i] = add(0, P - F[i]);
g[i] = fpow(i, n);
}
F = F * g;
for (int i = 0; i < n; ++i)
{
F[i] = F[i + 1];
if ((i + n) & 1)
F[i] = add(0, P - F[i]);
}
F.resize(n);
int tot = 0;
for (int i = 1; i <= m; ++i)
{
int A = read(), B = read(), C = read(), D = read();
if (C == P - 1)
ans[i] = -1;
else
{
ans[i] = 0;
++tot;
a[tot] = A, b[tot] = B, c[tot] = pos[tot] = C, d[tot] = D;
id[tot] = i;
}
}
if (tot)
{
idx = 1;
BuildTree(1, 1, tot);
poly Inv = prod[1];
Inv.resize(n), Inv = inverse(Inv);
G[1] = MULT(F, Inv);
dfs(1, 1, tot);
}
for (int i = 1; i <= m; ++i)
{
if (ans[i] != -1)
{
ans[i] = mul(ans[i], mul(fpow(a[i], n), b[i]));
ans[i] = mul(ans[i], invfac[n]);
ans[i] = mul(ans[i], fpow(fpow(c[i] + 1, n + 1), P - 2));
}
write(ans[i], '\n');
}
}
int main()
{
fac[0] = 1;
for (int i = 1; i <= M - 1; ++i)
fac[i] = mul(fac[i - 1], i);
invfac[M - 1] = fpow(fac[M - 1], P - 2);
for (int i = M - 2; i >= 0; --i)
invfac[i] = mul(invfac[i + 1], i + 1);
while (~scanf("%d%d", &n, &m))
solve();
return 0;
}