类欧几里得算法学习笔记

发现自己完全不会类欧几里得算法,于是来学一学.

问题最基本的形式是求
$$
\sum_{x=1}^n \lfloor\frac{ax+b}{c}\rfloor
$$
其中 $a,b,c$ 为给定的常数, $0\le a,b,c,n\le 10^9, c\neq 0$ .

当 $a \ge c$ 时,可以把 $a$ 写成 $a=kc+a\bmod c$ 的形式,答案就变成了
$$
\sum_{x=1}^n \lfloor\frac{(a\bmod c)\cdot x+b}{c}\rfloor + \sum_{x=1}^n kx
$$
后面那个 $\sum$ 是个等差数列求和,可以 $O(1)$ 求出.

对于前面那个 $\sum$ ,它的形式和原问题是一致的,我们可以递归调用该算法求解.

当 $a<c$ 时,若 $a=0$ ,则答案为 $x\cdot \lfloor\frac b c\rfloor$ ,否则,我们通过画图来观察性质:

答案应该是 $x=0,x=n,y=0,y=\frac{a}{c}x+\frac{b}{c}$ 这四条直线围成的直角梯形内 $x>0,y>0$ 的整点的数量.

我们过交点 $(0,\frac{b}{c})$ 向 $x=n$ 引一条垂线,将图形分为两部分.

黄色部分是一个矩形,贡献可以 $O(1)$ 求出.

对于蓝色部分的直角三角形,我们可以通过翻转与平移将它变成一个子问题,递归下去求解即可.

这样递归一次会交换 $a,c$ .

考虑每递归一次,若 $a\ge c$ ,则 $a$ 变为 $a\bmod c$ , $c$ 不变.否则,就交换 $a,c$ ,递归边界是 $a=0$ .

可以发现这就类似于欧几里得算法求 $\gcd$ 的递归过程,可以得出时间复杂度 $O(\log \max(a,c))$ .

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int read()
{
int out = 0, sgn = 1;
char jp = getchar();
while (jp != '-' && (jp < '0' || jp > '9'))
jp = getchar();
if (jp == '-')
sgn = -1, jp = getchar();
while (jp >= '0' && jp <= '9')
out = out * 10 + jp - '0', jp = getchar();
return out * sgn;
}
const int P = 998244353;
int add(int a, int b)
{
return a + b >= P ? a + b - P : a + b;
}
void inc(int &a, int b)
{
a = add(a, b);
}
int mul(int a, int b)
{
return 1LL * a * b % P;
}
// ans = \sum_{x=1}^{n} \lfloor \frac{ax+b}{c} \rfloor
int s1(int n)
{
return 1LL * n * (n + 1) / 2 % P;
}
int solve(int a, int b, int c, ll n)
{
if (a == 0)
return mul(n, b / c);
int ans = 0;
if (a >= c || b >= c)
{
ans = solve(a % c, b % c, c, n);
inc(ans, mul(a / c, s1(n % P)));
inc(ans, mul(b / c, (n + 1) % P));
}
else
{
ll m = (n * a + b) / c;
ans = mul(n % P, m % P);
inc(ans, P - solve(c, c - b - 1, a, m - 1));
}
return ans;
}
int bf(int a, int b, int c, int n)
{
int ans = 0;
for (int x = 1; x <= n; ++x)
inc(ans, (1LL * a * x + b) / c % P);
return ans;
}
int main()
{
int T = read();
while (T--)
{
int n = read(), a = read(), b = read(), c = read();
int ans = solve(a, b, c, n);
inc(ans, P - b / c);
printf("%d\n", ans);
// printf("%d\n", bf(a, b, c, n));
}
return 0;
}