Loj 3267 Help Yourself

第二类斯特林数 + 线段树优化 dp 转移.

USACO 怎么也有如此套路的题…

可以先将所有线段按照左端点从小到大排序,方便后续处理.

记一个非空集合 $T$ 的线段形成的连通块数目为 $c(T)$ ,则答案为 $\sum_T c(T)^k$ .

朴素的 dp

设 $dp(i,x,p)$ 表示考虑了前 $i$ 条线段,选了的线段右端点最大值为 $x$ ,有 $p$ 个连通块的方案数.

这样做状态数为 $O(n^3)$ ,不太可行.

幂次展开为组合数

注意到 $k​$ 比较小,尝试将其展开为组合数的形式,
$$
\begin{aligned}
ans&=\sum_T c(T)^k \\
&=\sum_T\sum_{i=0}^k {k\brace i}\cdot i!\cdot \binom{c(T)}{i}\\
&=\sum_{i=0}^k{k\brace i}\cdot i!\cdot \sum_{T}\binom{c(T)}{i}
\end{aligned}
$$

当题目中要算 $c^k$ 的贡献, $c$ 大而 $k$ 小,尤其是 $c$ 是某种东西的数目时,可以尝试展开成组合数.

原来需要记录 $c$ 的大小,最后算贡献,现在就将贡献摊在每个 $c$ 增大的时候计算,只用记录 $k$ 的大小.

于是我们在状态中就不需要记录连通块数目,而是在新产生一个连通块时,考虑它是否被选.

设 $dp(i,x,p)$ 表示考虑了前 $i$ 条线段,选了的线段右端点最大值为 $x$ ,产生的连通块被选定了 $p$ 个的方案数.

状态数从 $O(n^3)$ 降到了 $O(n^2k)$ ,还需进一步优化.

线段树优化转移

考虑转移的形式,假定当前在考虑第 $i$ 条线段,其覆盖的区间为 $[l,r]$ .

若不选这条线段,则每个 $dp(i-1,x,p)$ 转移到 $dp(i,x,p)$ .

若选了这条线段,则根据 $x​$ 的大小分情况讨论.

若 $x<l​$ ,此时会产生新的一个连通块, $dp(i-1,x,p)​$ 可以转移到 $dp(i,r,p)​$ 和 $dp(i,r,p+1)​$ .

若 $l\le x\le r​$ , $dp(i-1,x,p)​$ 可以转移到 $dp(i,r,p)​$ .

若 $x>r$ , $dp(i-1,x,p)$ 可以转移到 $dp(i,x,p)$ .

不难发现,我们可以开 $k+1$ 棵线段树来维护这个 dp 数组,第 $p$ 棵线段树维护了当前所有的 $dp(i,x,p)$ .

时间复杂度 $O(nk\log n)$ .

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
//%std
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
inline int read()
{
int out = 0, fh = 1;
char jp = getchar();
while ((jp > '9' || jp < '0') && jp != '-')
jp = getchar();
if (jp == '-')
fh = -1, jp = getchar();
while (jp >= '0' && jp <= '9')
out = out * 10 + jp - '0', jp = getchar();
return out * fh;
}
void print(int x)
{
if (x >= 10)
print(x / 10);
putchar('0' + x % 10);
}
void write(int x, char c)
{
if (x < 0)
putchar('-'), x = -x;
print(x);
putchar(c);
}
const int P = 1e9 + 7, inv2 = (P + 1) >> 1;
int add(int a, int b)
{
return a + b >= P ? a + b - P : a + b;
}
void inc(int &a, int b)
{
a = add(a, b);
}
int mul(int a, int b)
{
return 1LL * a * b % P;
}
const int N = 2e5 + 10, K = 10 + 1;
struct node
{
int tag, sum;
node(){tag = 1, sum = 0;}
};
struct Segtree
{
node Tree[N << 2];
#define root Tree[x]
void pushup(int x)
{
root.sum = add(Tree[x << 1].sum, Tree[x << 1 | 1].sum);
}
void modify(int x, int c)
{
root.sum = mul(root.sum, c);
root.tag = mul(root.tag, c);
}
void pushdown(int x)
{
if (root.tag != 1)
{
modify(x << 1, root.tag);
modify(x << 1 | 1, root.tag);
root.tag = 1;
}
}
void Mul(int x, int l, int r, int L, int R, int c)
{
if (L > R || L > r || R < l)
return;
if (L <= l && r <= R)
return modify(x, c);
int mid = (l + r) >> 1;
pushdown(x);
if (L <= mid)
Mul(x << 1, l, mid, L, R, c);
if (R > mid)
Mul(x << 1 | 1, mid + 1, r, L, R, c);
pushup(x);
}
void Add(int x, int l, int r, int pos, int c)
{
if (l == r)
{
inc(root.sum, c);
return;
}
int mid = (l + r) >> 1;
pushdown(x);
if (pos <= mid)
Add(x << 1, l, mid, pos, c);
else
Add(x << 1 | 1, mid + 1, r, pos, c);
pushup(x);
}
int query(int x, int l, int r, int L, int R)
{
if (L > R || L > r || R < l)
return 0;
if (L <= l && r <= R)
return root.sum;
int mid = (l + r) >> 1, res = 0;
pushdown(x);
if (L <= mid)
inc(res, query(x << 1, l, mid, L, R));
if (R > mid)
inc(res, query(x << 1 | 1, mid + 1, r, L, R));
return res;
}
#undef root
} T[K];
int n, m, k, S[K][K], d[K];
pair<int, int> seg[N];
int main()
{
n = read(), k = read(), m = n << 1;
for (int i = 1; i <= n; ++i)
seg[i].first = read(), seg[i].second = read();
sort(seg + 1, seg + 1 + n);
S[0][0] = 1;
for (int i = 1; i <= k; ++i)
for (int j = 1; j <= i; ++j)
S[i][j] = add(S[i - 1][j - 1], mul(S[i - 1][j], j));
T[0].Add(1, 0, m, 0, 1);
for (int i = 1; i <= n; ++i)
{
int l = seg[i].first, r = seg[i].second;
for (int p = k; p >= 0; --p)
{
if (p < k)
T[p + 1].Add(1, 0, m, r, T[p].query(1, 0, m, 0, l - 1));
T[p].Add(1, 0, m, r, T[p].query(1, 0, m, 0, r));
T[p].Mul(1, 0, m, r + 1, m, 2);
}
}
int ans = 0, fac = 1;
for (int p = 0; p <= k; ++p)
{
int s = T[p].query(1, 0, m, 0, m);
inc(ans, mul(s, mul(S[k][p], fac)));
fac = mul(fac, p + 1);
}
write(ans, '\n');
return 0;
}