Loj 2304 泳池

$dp$ + 常系数线性递推.

  • 面积恰好 $=K$ 的概率不太好求,考虑求出面积 $\le K$ 与面积 $\le K-1$ 的概率,两者相减即为答案.
  • 设 $f(i,j)$ 表示矩形长为 $i$ ,最下面 $j$ 行都安全,而第 $j+1$ 行至少一个位置危险,最大面积不超过 $K$ 的概率.
  • 记 $g(i,j)$ 表示矩形长为 $i$ ,最下面 $j$ 行都安全,最大面积不超过 $K$ 的概率.则 $g(i,j)=\sum_{p\ge j} f(i,p)$ .
  • 边界为 $g(0,j)=f(0,j)=1,g(i,j)=f(i,j)=0\ (i\cdot j>K)$ .我们需要求得 $g(n,0)$ .
  • 枚举第 $j+1$ 行第一个危险的格子在 $r+1$ 列.那么要求前 $r$ 列 $j+1$ 行都安全, $r+2\sim i$ 列前 $j$ 行安全,第 $r+1$ 列前 $j$ 行安全, 第 $r+1$ 列第 $j+1$ 行危险,则转移有,

$$
f(i,j)=\sum_{r=0}^{i-1} g(r,j+1)\cdot g(i-r-1,j) \cdot q^j\cdot(1-q)
$$

$4$ 个限制依次对应了转移方程中的 $4$ 项.

  • 大力 $dp$ ,时间复杂度为 $O(n^2)$ .
  • 考虑如何优化.注意到当 $n>K$ 时,仅有 $f(i,0)$ 与 $g(i,0)$ 这些项不为 $0$ ,而我们要求的是 $g(n,0)$ .
  • 所以只用考虑它们的转移.将 $j=0$ 代入原来的转移方程,可以发现,

$$
f(i,0)=g(i,0)=\sum_{r=0}^K g(r,1)\cdot g(i-r-1,0)\cdot (1-q)
$$

  • $g(r,1)$ 最多只有前 $K+1$ 项非 $0$ ,这部分可以通过大力 $dp$ 求出.那么 $g(r,1)\cdot (1-q)$ 就可看做常系数.
  • 求 $g(i,0)$ 就是一个常系数线性递推,递推式的长度为 $K$ .使用 $O(K^2\cdot logn)$ 的大力取模做法即可.
  • 总时间复杂度 $O(K^2\cdot logn)$ .
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
inline int read()
{
int out=0,fh=1;
char jp=getchar();
while ((jp>'9'||jp<'0')&&jp!='-')
jp=getchar();
if (jp=='-')
fh=-1,jp=getchar();
while (jp>='0'&&jp<='9')
out=out*10+jp-'0',jp=getchar();
return out*fh;
}
const int P=998244353;
inline int add(int a,int b)
{
return (a + b) % P;
}
inline int mul(int a,int b)
{
return 1LL * a * b % P;
}
int fpow(int a,int b)
{
int res=1;
while(b)
{
if(b&1)
res=mul(res,a);
a=mul(a,a);
b>>=1;
}
return res;
}
int inv(int x)
{
return fpow(x,P-2);
}
const int MAXN=1e3+10;
int base[MAXN],ans[MAXN],tmp[MAXN<<1];
void Mul(int *a,int *b,int *f,int k)
{
memset(tmp,0,k<<3);
for(int i=0;i<k;++i)
for(int j=0;j<k;++j)
tmp[i+j]=add(tmp[i+j],mul(a[i],b[j]));
for(int i=2*k-2;i>=k;--i)
for(int j=0;j<k;++j)
tmp[i-j-1]=add(tmp[i-j-1],mul(tmp[i],f[j]));
memcpy(a,tmp,k<<2);
}
int solve(int *a,int *f,int n,int k)
{
memset(base,0,sizeof base);
memset(ans,0,sizeof ans);
base[1]=ans[0]=1;
while(n)
{
if(n&1)
Mul(ans,base,f,k);
Mul(base,base,f,k);
n>>=1;
}
int res=0;
for(int i=0;i<k;++i)
res=add(res,mul(a[i],ans[i]));
return res;
}
int n,q,pw[MAXN];
int f[MAXN][MAXN],g[MAXN][MAXN];
int a[MAXN],F[MAXN];
int Solve(int k)
{
memset(f,0,sizeof f);
memset(g,0,sizeof g);
memset(a,0,sizeof a);
memset(F,0,sizeof F);
for(int j=0;j<=k+1;++j)
g[0][j]=f[0][j]=1;
for(int i=1;i<=1000;++i)
{
for(int j=k/i;j>=0;--j)
{
int L=j?(i-1-k/j):0;
L=max(L,0);
int R=min(i-1,k/(j+1));
for(int r=L;r<=R;++r)
{
int t=mul(pw[j],add(1,P-q));
t=mul(t,mul(g[r][j+1],g[i-1-r][j]));
f[i][j]=add(f[i][j],t);
}
g[i][j]=add(g[i][j+1],f[i][j]);
}
}
if(n<=1000)
return g[n][0];
for(int i=0;i<=k;++i)
a[i]=g[i][0];
for(int r=0;r<=k;++r)
F[r]=mul(g[r][1],add(1,P-q));
return solve(a,F,n,k+1);
}
int main()
{
n=read();
int k=read();
int x=read(),y=read();
q=mul(x,inv(y));
pw[0]=1;
for(int i=1;i<=k;++i)
pw[i]=mul(pw[i-1],q);
cout<<add(Solve(k),P-Solve(k-1))<<endl;
return 0;
}