Luogu U124868 第〇类循环数·行

组合计数 + 循环卷积.

多测不清空,爆零两行泪.

从递推式入手,简单整理合并同类项可得
$$
f(i,j)=f(i-1,j-1)\cdot \frac{i+1}{i}\cdot \frac{j+1}{j}+f(i-1,j)\cdot \frac{i+1}{i}
$$
把超出边界的值都看成 $0​$ ,可以发现第 $0​$ 列仍然不满足递推式,算一下差的值,不难发现在最左边再补一列 $1,2,3,4,\dots​$ 就行了,其中第 $j​$ 列的值为 $j+1​$ .可以看成从左上角出发,每次向下或向右下走一步.

其中第 $i-1$ 行到第 $i$ 行的贡献是 $\frac{i+1}i$ ,从第 $j-1$ 列到第 $j$ 列的贡献是 $\frac{j+1}{j}$ .

把所有贡献和行走的方案数乘在一起,可以算出 $f(i,j)=\binom{i}{j+1}\cdot (i+1)\cdot (j+1)$ .

要求的是
$$
ans=\sum_{1\le a\le A}\sum_{1\le b\le B}\sum_{1\le c\le C}f(k,(ax^2+bx+c)\bmod k)
$$
尝试对每个 $i​$ 算出有几组 $a,b,c​$ 满足 $(ax^2+bx+c)\bmod k=i​$ .

预处理
$$
f_a(i)=\sum_{1\le a\le A}[ax^2\bmod k=i],\\
f_b(i)=\sum_{1\le b\le B}[bx\bmod k=i],\\
f_c(i)=\sum_{1\le c\le C}[c\bmod k=i].
$$

考虑如何求 $f_a(i)$ ,记 $v=x^2\bmod k$ ,由于 $k$ 是 $2$ 的幂,特判 $v=0$ ,否则将 $v$ 分解成 $v=2^p\cdot t$ ,其中 $t$ 为奇数.

显然每 $\frac k {2^p}$ 个 $a$ 就会将 $0,2^p,2\times 2^p,\dots,(\frac{k}{2^p}-1)\times 2^p$ 遍历,算一下完整循环节个数,多余部分 $O(k)$ 暴力.

$f_b,f_c$ 的预处理同理,都可以在 $O(k)$ 内完成.

然后将 $f_a,f_b,f_c$ 做长度为 $k$ 的循环卷积就能求出每个 $i$ 有几组 $a,b,c$ 满足 $(ax^2+bx+c)\bmod k=i$ 了.

保证了 $k$ 为 $2$ 的幂,且对于模数有 $2\times 10^6\le 2^{21}$ ,只用做 $3$ 次长度为 $k$ 的 DFT ,一次长度为 $k$ 的 IDFT 即可.

时间复杂度 $O(T\cdot k\log k)$ .

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
//%std
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define y1 ysgh
inline ll read()
{
ll out = 0, fh = 1;
char jp = getchar();
while ((jp > '9' || jp < '0') && jp != '-') jp = getchar();
if (jp == '-')
fh = -1, jp = getchar();
while (jp >= '0' && jp <= '9') out = out * 10 + jp - '0', jp = getchar();
return out * fh;
}
void print(int x)
{
if (x >= 10)
print(x / 10);
putchar('0' + x % 10);
}
void write(int x, char c)
{
if (x < 0)
putchar('-'), x = -x;
print(x);
putchar(c);
}
namespace Module
{
const int P = 1004535809;
int add(int a, int b)
{
return a + b >= P ? a + b - P : a + b;
}
void inc(int &a, int b)
{
a = add(a, b);
}
void dec(int &a, int b)
{
a = add(a, P - b);
}
int mul(int a, int b)
{
return 1LL * a * b % P;
}
int fpow(int a, int b)
{
int res = 1;
while (b)
{
if (b & 1) res = mul(res, a);
a = mul(a, a);
b >>= 1;
}
return res;
}
}
using namespace Module;
// f(x, y) = binom(x, y + 1) * (y + 1) * (x + 1) = x * (x + 1) * binom(x - 1, y)
const int N = 1 << 21 | 10;
int fac[N], invfac[N];
int binom(int x, int y)
{
return mul(fac[x], mul(invfac[y], invfac[x - y]));
}
int omega[N], inv_omega[N], rev[N];
void init(int n)
{
for (int i = 0; i < n; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (n >> 1));
for (int l = 2; l <= n; l <<= 1)
{
omega[l] = fpow(3, (P - 1) / l);
inv_omega[l] = fpow(omega[l], P - 2);
}
}
void DFT(int *a, int n, bool invflag)
{
for (int i = 0; i < n; ++i)
if (i < rev[i])
swap(a[i], a[rev[i]]);
for (int l = 2; l <= n; l <<= 1)
{
int m = l >> 1, gi = invflag ? inv_omega[l] : omega[l];
for (int *p = a; p != a + n; p += l)
{
int g = 1;
for (int i = 0; i < m; ++i)
{
int t = mul(g, p[i + m]);
p[i + m] = add(p[i], P - t);
p[i] = add(p[i], t);
g = mul(g, gi);
}
}
}
if (invflag)
{
int invn = fpow(n, P - 2);
for (int i = 0; i < n; ++i)
a[i] = mul(a[i], invn);
}
}

int a[N], b[N], c[N], _a[N], _b[N];
void solve()
{
int A = read(), B = read(), C = read(), k = read(), x = read() % k, v;
for (int i = 0; i < k; ++i)
a[i] = b[i] = c[i] = 0;
v = 1LL * x * x % k;
if (v)
{
int mod = k, t = v;
while (t % 2 == 0)
mod >>= 1, t >>= 1;
for (int i = 0; i < mod; ++i)
a[1LL * i * (k / mod) % k] = A / mod;
for (int i = 1; i <= A % mod; ++i)
a[1LL * i * v % k]++;
}
else
a[0] = A;
v = x;
if (v)
{
int mod = k, t = v;
while (t % 2 == 0)
mod >>= 1, t >>= 1;
for (int i = 0; i < mod; ++i)
b[1LL * i * (k / mod) % k] = B / mod;
for (int i = 1; i <= B % mod; ++i)
b[1LL * i * v % k]++;
}
else
b[0] = B;
for (int i = 0; i < k; ++i)
c[i] = C / k;
for (int i = 1; i <= C % k; ++i)
inc(c[i], 1);
init(k);
DFT(a, k, false), DFT(b, k, false), DFT(c, k, false);
for (int i = 0; i < k; ++i)
a[i] = mul(a[i], mul(b[i], c[i]));
DFT(a, k, true);
int ans = 0;
for (int i = 0; i < k; ++i)
inc(ans, mul(a[i], binom(k - 1, i)));
ans = mul(ans, mul(k, k + 1));
ans = mul(ans, fpow(mul(A, mul(B, C)), P - 2));
cout << ans << '\n';
}
int main()
{
int mx = 2000000;
fac[0] = 1;
for (int i = 1; i <= mx; ++i)
fac[i] = mul(fac[i - 1], i);
invfac[mx] = fpow(fac[mx], P - 2);
for (int i = mx - 1; i >= 0; --i)
invfac[i] = mul(invfac[i + 1], i + 1);
int T = read();
while (T--)
solve();
return 0;
}