Loj 2564 原题识别

比较麻烦的二维数点.

随机生成数据带来的性质

首先需要注意数据的生成方式是给出的,表面上是减少输入量,实际上蕴含了这棵树的一些有用的性质.

前 $p$ 个点形成一条主链,后面的点随机向前面的点连边,那么每个点到主链的距离期望是 $O(\log n)$ 的.

于是对于树上任意两个点 $x,y$ 和它们的最近公共祖先 $l$ , $\min(dis(x,l),dis(y,l))$ 的期望也是 $O(\log n)$ 的.

此外,由于每个点颜色是在 $[1,n]$ 中随机的,所以在期望意义下,每种颜色只会有 $O(1)$ 个点.

第一问

先考虑第一问怎么做,求区间内颜色数目通常可以找出 $p_i$ 表示上一个和 $i$ 颜色相同的点,然后二维数点.

现在是在树上,就令 $p_i$ 表示 $i$ 的所有真祖先中,颜色与 $i​$ 相同且深度最大的点.

不妨令 $dis(x,l) <dis(y,l)$ ,我们可以先求出 $[l, y]$ 这条链上的颜色数目,这可以用主席树二维数点 $O(\log n)$ 求出.

再枚举 $(l,x]​$ 链上的每个点 $i​$ ,若 $i​$ 的颜色是 $c​$ ,就找出所有颜色是 $c​$ 的点,如果都没有在 $[l,y]​$ 上出现, $i​$ 就有贡献 $1​$.

检查某个点是否在 $[l,y]$ 这条链上可以直接用欧拉序判断.

$(l,x]​$ 这条链上期望只会有 $O(\log n)​$ 个点,而每个点期望只会有 $O(1)​$ 个点和它颜色相同,每次判断也是 $O(1)​$ 的.

于是单次第一种询问的时间复杂度为 $O(\log n)​$ .

第二问

仍然令 $dis(x,l) <dis(y,l)$ ,可以讨论 $a,b$ 所在的位置,拆成三个子问题.

Case 1

$a\in [1,l), b\in [1,y]$ .这等价于在一条链上询问,点 $i$ 能产生贡献,当且仅当 $L\in[p_i+1,i], R\ge i$ .

于是可以对 $i​$ 分段,将贡献大力拆开,用主席树维护一下区间的 $\sum 1, \sum i,\sum p_i,\sum i\cdot p_i​$ 就可以支持询问.

Case 2

$a\in [l,x],b\in[1,l)$ .我们把 $a\in [1,x],b\in[1,l) $ 的贡献算出来,再减去 $a\in [1,l),b\in[1,l) $ 的贡献.

这两种贡献和 Case 1 的形式是一样的.

Case 3

$a\in [l,x],b\in[l,y]$ .先计算 $[l,y]$ 上的点的贡献,此时不考虑 $[l,x]$ 内的点的颜色,贡献拆开后就是一个二维数点.

再考虑 $[l,x]$ 上的每个点 $i$ ,它有贡献当且仅当 $p_i<l$ 且 $[l,b]$ 内没有和它颜色相同的点.

找出 $q_i$ 表示 $[l,y]$ 中第一个和 $i$ 颜色相同的点(没有就记作 $y+1$ ).

当 $p_i<l$ 时,就有贡献 $(x-i+1)(q_i-l)$ ,仍然可以用二维数点计算.

每种情况都可以在 $O(\log n)$ 的时间内求出贡献,总复杂度 $O(n \log n)$ .

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
//%std
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
inline unsigned int read()
{
unsigned int out = 0, fh = 1;
char jp = getchar();
while ((jp > '9' || jp < '0') && jp != '-')
jp = getchar();
if (jp == '-')
fh = -1, jp = getchar();
while (jp >= '0' && jp <= '9')
out = out * 10 + jp - '0', jp = getchar();
return out * fh;
}
unsigned int SA, SB, SC;
unsigned int rng61()
{
SA ^= SA << 16;
SA ^= SA >> 5;
SA ^= SA << 1;
unsigned int t = SA;
SA = SB;
SB = SC;
SC ^= t ^ SA;
return SC;
}
const int N = 4e5 + 10, K = 19;
int n, m, len, ecnt, head[N], to[N], nx[N], col[N], fa[N];
void addedge(int u, int v)
{
++ecnt;
to[ecnt] = v;
nx[ecnt] = head[u];
head[u] = ecnt;
}
void gen()
{
for (int i = 2; i <= n; ++i)
{
if (i <= len)
fa[i] = i - 1;
else
fa[i] = rng61() % (i - 1) + 1;
addedge(fa[i], i);
}
for (int i = 1; i <= n; ++i)
col[i] = rng61() % n + 1;
}
vector<int> vec[N];
int dfn[N], ed[N], Log[N], st[N][K], dep[N], app[N], ptot, tot;
ll pre2[N], pre3[N];
struct info
{
ll sum1, sum2, sum3, sum4;
info(ll sum1 = 0, ll sum2 = 0, ll sum3 = 0, ll sum4 = 0) :
sum1(sum1), sum2(sum2), sum3(sum3), sum4(sum4) {}
info& operator += (const info &rhs)
{
sum1 += rhs.sum1, sum2 += rhs.sum2;
sum3 += rhs.sum3, sum4 += rhs.sum4;
return *this;
}
info& operator -= (const info &rhs)
{
sum1 -= rhs.sum1, sum2 -= rhs.sum2;
sum3 -= rhs.sum3, sum4 -= rhs.sum4;
return *this;
}
} tree[N * 20];
int ls[N * 20], rs[N * 20];
void upd(int &o, int lst, int l, int r, int x, int y)
{
o = ++ptot;
tree[o] = tree[lst], ls[o] = ls[lst], rs[o] = rs[lst];
++tree[o].sum1, tree[o].sum2 += y;
tree[o].sum3 += x, tree[o].sum4 += 1LL * x * y;
if (l == r)
return;
int mid = (l + r) >> 1;
if (y <= mid)
upd(ls[o], ls[lst], l, mid, x, y);
else
upd(rs[o], rs[lst], mid + 1, r, x, y);
}
void query(info &res, int tl, int tr, int l, int r, int L, int R)
{
if (tl == tr || l > R || r < L)
return;
if (L <= l && r <= R)
{
res += tree[tr];
res -= tree[tl];
return;
}
int mid = (l + r) >> 1;
if (L <= mid)
query(res, ls[tl], ls[tr], l, mid, L, R);
if (R > mid)
query(res, rs[tl], rs[tr], mid + 1, r, L, R);
}
int lst[N], rt[N];
void dfs(int u, int Fa)
{
st[dfn[u] = ++tot][0] = u, dep[u] = dep[Fa] + 1;
vec[col[u]].push_back(u), lst[u] = app[col[u]];
int t = app[col[u]];
app[col[u]] = u;
pre2[u] = pre2[Fa] + dep[u] - dep[lst[u]];
pre3[u] = pre3[Fa] + 1LL * (dep[u] - dep[lst[u]]) * dep[u];
upd(rt[u], rt[Fa], 0, n, dep[u], dep[lst[u]]);
for (int i = head[u]; i; i = nx[i])
{
int v = to[i];
dfs(v, u);
st[++tot][0] = u;
}
app[col[u]] = t, ed[u] = tot;
}
int LCA(int x, int y)
{
x = dfn[x], y = dfn[y];
if (x > y)
swap(x, y);
int k = Log[y - x + 1];
if (dep[st[x][k]] < dep[st[y - (1 << k) + 1][k]])
return st[x][k];
else
return st[y - (1 << k) + 1][k];
}
int vis[N];
bool onlink(int x, int y, int p)
{
return dfn[x] <= dfn[p] && dfn[p] <= ed[x] && dfn[p] <= dfn[y] && dfn[y] <= ed[p];
}
ll calc1(int x, int y)
{
++tot;
int la = LCA(x, len), lb = LCA(y, len);
if (la > lb)
swap(x, y);
int l = LCA(x, y);
info tmp = info(0, 0, 0, 0);
query(tmp, rt[fa[l]], rt[y], 0, n, 0, dep[l] - 1);
int res = tmp.sum1;
for (int i = x; i != l; i = fa[i])
if (vis[col[i]] != tot)
{
vis[col[i]] = tot;
int fl = 1, sz = vec[col[i]].size();
for (int j = 0; j < sz && fl; ++j)
if (onlink(l, y, vec[col[i]][j]))
fl = 0;
res += fl;
}
return res;
}
ll calc(int x, int y, info res)
{
int a = dep[x], b = dep[y];
ll s = (res.sum1 * a - res.sum2) * (b + 1) - res.sum3 * a + res.sum4;
return s + 1LL * (a + b + 2) * pre2[x] - 2 * pre3[x] - a;
}
ll calc2(int x, int y)
{
++tot;
int la = LCA(x, len), lb = LCA(y, len);
if (la > lb)
swap(x, y);
int l = LCA(x, y), pl = fa[l], dl = dep[l];
info t1 = info(0, 0, 0, 0), t2 = t1;
query(t1, rt[pl], rt[y], 0, n, 0, dl - 1);
query(t2, rt[pl], rt[x], 0, n, 0, dl - 1);
ll res = calc(pl, y, t1) + calc(pl, x, t2);
res -= 2 * (1LL * dl * pre2[pl] - pre3[pl]) + dl - 1;
res += (1LL * (dep[y] + 1) * t1.sum1 - t1.sum3) * (dep[x] - dl + 1);
int tp = 0;
for (int i = x; i != l; i = fa[i])
app[++tp] = i;
app[++tp] = l;
while (tp > 0)
{
int i = app[tp--];
if (vis[col[i]] != tot)
{
vis[col[i]] = tot;
int mn = dep[y] + 1, sz = vec[col[i]].size();
for (int j = 0; j < sz; ++j)
if (onlink(l, y, vec[col[i]][j]) && mn > dep[vec[col[i]][j]])
mn = dep[vec[col[i]][j]];
res += 1LL * (mn - dl) * (dep[x] - dep[i] + 1);
}
}
return res;
}
void solve()
{
dfs(1, 0);
Log[1] = 0;
for (int i = 2; i <= tot; ++i)
Log[i] = Log[i >> 1] + 1;
for (int j = 1; j < K; ++j)
for (int i = 1; i + (1 << j) - 1 <= tot; ++i)
if (dep[st[i][j - 1]] < dep[st[i + (1 << j >> 1)][j - 1]])
st[i][j] = st[i][j - 1];
else
st[i][j] = st[i + (1 << j >> 1)][j - 1];
tot = 0;
m = read();
for (int i = 1; i <= m; ++i)
{
int op = read(), x = read(), y = read();
if (op == 1)
printf("%lld\n", calc1(x, y));
else
printf("%lld\n", calc2(x, y));
}
}
void reset()
{
ecnt = ptot = tot = 0;
for (int i = 1; i <= n; ++i)
{
rt[i] = head[i] = vis[i] = app[i] = 0;
vec[i].clear();
}
}
int main()
{
int T = read();
while (T--)
{
n = read(), len = read();
SA = read(), SB = read(), SC = read();
gen();
solve();
reset();
}
return 0;
}