bzoj 5416 冒泡排序

观察结论 + $dp$ 计数.

通过提示,可以发现一个排列合法,当且仅当它的最长下降子序列长度 $<3$ .

否则,若存在 $\ge 3$ 的下降子序列,中间的元素需要先与左边的交换,再与右边的交换,就存在了冗余的步数,不合法.

而根据 $Dirworth$ 定理,该条件又等价于这个排列能被 $2$ 个上升子序列覆盖.

从前往后依次填入所有数,设 $f(i,j)$ 表示还有 $i$ 个数没填,且这 $i$ 个数中有 $j$ 个数大于已经填入的最大值的方案数.

转移时,枚举填的数是这 $j$ 个数中的第 $k$ 小,若 $k=0$ ,则表示填入的数不在这 $j$ 个数中.
$$
f(i,j)=\sum_{k=0}^j f(i-1,j-k)
$$
容易发现它是个前缀和,
$$
f(i,j)=f(i,j-1)+f(i-1,j)
$$
而 $f(i,i)=0$ ,即不合法,那么可以看出 $f(i,j)$ 的组合意义.

它表示从 $(0,0)$ 出发,每次可以向右方或上方走一步,在中途不触碰直线 $y=x$ ,到达 $(i,j)$ 的方案数.

于是预处理阶乘及其逆元后,可以 $O(1)$ 求 $f(i,j)$ .

还有一个限制是字典序必须大于给出的排列 $p$ ,就像数位 $dp$ 那样做就可以了,时间复杂度 $O(n\log n)$ .

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
inline int read()
{
int out=0,fh=1;
char jp=getchar();
while ((jp>'9'||jp<'0')&&jp!='-')
jp=getchar();
if (jp=='-')
fh=-1,jp=getchar();
while (jp>='0'&&jp<='9')
out=out*10+jp-'0',jp=getchar();
return out*fh;
}
const int P=998244353;
int add(int a,int b)
{
return (a+b>=P)?(a+b-P):(a+b);
}
int mul(int a,int b)
{
return 1LL * a * b % P;
}
int fpow(int a,int b)
{
int res=1;
while(b)
{
if(b&1)
res=mul(res,a);
a=mul(a,a);
b>>=1;
}
return res;
}
const int MAXN=6e5+10;
int fac[MAXN<<1],invfac[MAXN<<1];
void init(int N)
{
fac[0]=1;
for(int i=1;i<=N;++i)
fac[i]=mul(fac[i-1],i);
invfac[N]=fpow(fac[N],P-2);
for(int i=N-1;i>=0;--i)
invfac[i]=mul(invfac[i+1],i+1);
}
int C(int M,int N)
{
if(M<0 || N<0 || M<N)
return 0;
return mul(fac[M],mul(invfac[N],invfac[M-N]));
}
int f(int i,int j)
{
return add(C(i+j-1,j),P-C(i+j-1,j-2));
}
int n,a[MAXN],val[MAXN],pre[MAXN],suf[MAXN];
struct FenwickTree
{
int bit[MAXN];
#define lowbit(x) x&(-x)
void init()
{
memset(bit,0,sizeof bit);
}
void add(int x,int c)
{
for(;x<=n;x+=lowbit(x))
bit[x]+=c;
}
int query(int x)
{
int s=0;
for(;x;x-=lowbit(x))
s+=bit[x];
return s;
}
}T;
void solve()
{
n=read();
for(int i=1;i<=n;++i)
a[i]=read();
T.init();
for(int i=n;i>=1;--i)
{
suf[i]=n-i-T.query(a[i]);
T.add(a[i],1);
pre[i]=i-1-(n-a[i]-suf[i]);
}
int cnt=n,ans=0;
for(int i=1;i<=n;++i)
{
if(!suf[i])
break;
bool flag=(suf[i]<cnt);
cnt=min(cnt,suf[i]);
ans=add(ans,f(n-i+1,cnt-1));
if(!flag && pre[i]!=a[i]-1)
break;
}
cout<<ans<<endl;
}
int main()
{
init(MAXN-10<<1);
int T=read();
while(T--)
solve();
return 0;
}