bzoj 4381 Odwiedziny

根号分治.

先预处理倍增数组用于支持查询 LCA, 预处理和查询 LCA 的时间复杂度为 $O(n\log n + m\log n)$ .

$k\le \sqrt n$

对所有 $i\le \sqrt n$ 预处理出 $nxt(x,i)$ 表示从 $x$ 向上跳 $i$ 步走到的点, $sum(x,i)$ 表示从 $x$ 出发(包含 $x$ 本身)向上每次跳 $i$ 步、直到跳出根为止所经过的点权和, 即 $sum(x,i)=a_x+sum(nxt(x,i),i)$ .

于是询问时先求出 LCA, 再用常数次倍增跳跃找出路径两侧第一个和最后一个停留点, 路径上的权和就是若干个 $sum$ 之差, 单次询问 $O(\log n)$ .

这部分预处理的时间复杂度为 $O(n\sqrt n)$ (本题询问数 $m=n-1$, 所以也可以写成 $O(m\sqrt n)$ ).

$k> \sqrt n$

对于 $k>\sqrt n$ 的询问, 路径长度不超过 $n-1$, 所以跳跃次数不超过 $\lceil\frac{n-1}{k}\rceil=O(\sqrt n)$ , 每次用倍增找出要跳到的下一个点, 更新答案即可.

这部分的时间复杂度为 $O(m\sqrt n\log n)$ .

适当调整

可以发现,若我们对 $k\le b,k>b$ 分别执行以上两种算法,两部分总时间复杂度为 $O(mb+m\frac n b\cdot \log n)$ .

取 $b=\sqrt {n\log n}$ 即可做到 $O(m\sqrt{n\log n})$ ,再算上求 LCA 部分,复杂度为 $O(m\sqrt {n\log n}+(m+n)\log n)$ .

然而这只是理论复杂度,实现时由于常数等原因需要自己调节参数以取得更好的效果, 例如下面的代码取的是 $b=\sqrt{n/4}$ .

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
//%std
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
// Reads one (optionally negative) decimal integer from stdin.
// Leading characters that are neither digits nor '-' are skipped.
// Returns 0 if end of input is reached before any digit.
inline int read()
{
    int out = 0, fh = 1;
    // Keep getchar()'s result in an int: EOF is an out-of-band int value that a
    // char cannot reliably represent, and the skip loop below must stop on it
    // (otherwise truncated input makes it spin forever).
    int jp = getchar();
    while (jp != EOF && (jp > '9' || jp < '0') && jp != '-')
        jp = getchar();
    if (jp == '-')
        fh = -1, jp = getchar();
    while (jp >= '0' && jp <= '9')
        out = out * 10 + jp - '0', jp = getchar();
    return out * fh;
}
// Writes the decimal digits of u to stdout, most significant digit first.
static void print_digits(unsigned int u)
{
    if (u >= 10)
        print_digits(u / 10);
    putchar(static_cast<int>('0' + u % 10));
}
// Writes x in decimal to stdout (no separator).
// Negative values get a leading '-'. The magnitude is computed in unsigned
// arithmetic so INT_MIN, whose int negation would overflow, prints correctly.
void print(int x)
{
    unsigned int u = static_cast<unsigned int>(x);
    if (x < 0)
    {
        putchar('-');
        u = 0u - u;
    }
    print_digits(u);
}
// Prints x in decimal via print(), then the separator character c.
void write(int x, char c)
{
    print(x), putchar(c);
}
const int N = 50000 + 10; // vertex-array capacity (presumably n <= 50000 plus slack)
const int K = 16;         // binary-lifting levels; 1 << 16 exceeds N
const int S = 200 + 10;   // second dimension of nxt/sum; indices up to B are used, so B < S is required
int n, m, B;              // n = vertex count, B = stride threshold set in main; m is not read by the visible code
int a[N], b[N], c[N];     // a = vertex weights, b = visiting order, c[i] = stride of walk b[i] -> b[i + 1]
int fa[N][K], dep[N];     // fa[v][i] = 2^i-th ancestor (0 where unset), dep[v] = depth (root at 0)
int ans[N];               // NOTE(review): not referenced by the visible code
vector<int> G[N];         // undirected adjacency lists
int nxt[N][S], sum[N][S]; // nxt[v][i] = i-th ancestor of v; sum[v][i] = a[v] + sum[nxt[v][i]][i]
void dfs(int u, int F)
{
fa[u][0] = F;
for (int i = 1; (1 << i) <= dep[u]; ++i)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
nxt[u][1] = F, sum[u][1] = sum[F][1] + a[u];
for (int i = 2; i <= B; ++i)
{
nxt[u][i] = nxt[F][i - 1];
sum[u][i] = sum[nxt[u][i]][i] + a[u];
}
for (int v : G[u])
if (v != F)
{
dep[v] = dep[u] + 1;
dfs(v, u);
}
}
// Returns the ancestor lying k levels above x, using the binary-lifting table
// fa[][] greedily from the largest power of two down. Yields 0 when k exceeds
// the depth of x, because fa entries that were never filled stay 0.
int jump(int x, int k)
{
    for (int lvl = K; lvl-- > 0;)
    {
        const int step = 1 << lvl;
        if (k >= step)
        {
            k -= step;
            x = fa[x][lvl];
        }
    }
    return x;
}
// One walk record (endpoints x and y, stride k, original position id), ordered by k.
// NOTE(review): neither info nor q is referenced by the visible code.
struct info
{
    int x, y, k, id;
    bool operator < (const info &other) const { return k < other.k; }
};
info q[N];
int LCA(int x, int y)
{
if (dep[x] < dep[y])
swap(x, y);
for (int i = K - 1; i >= 0; --i)
if ((1 << i) <= dep[x] - dep[y])
x = fa[x][i];
if (x == y)
return x;
for (int i = K - 1; i >= 0; --i)
if ((1 << i) <= dep[x] && fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
int bf(int x, int y, int k)
{
if (x == y)
return a[x];
int s = a[x], lca = LCA(x, y);
while (dep[x] - dep[lca] > k)
{
x = jump(x, k);
s += a[x];
}
int p = dep[x] + dep[y] - 2 * dep[lca];
if (p <= k)
return s + a[y];
p -= k;
while (1)
{
if (p <= 0)
return s + a[y];
s += a[jump(y, p)];
p -= k;
}
return s;
}
// Sum of a[] over the stops of the walk x -> y with stride k, answered from the
// precomputed tables (valid only for 1 <= k <= B):
//   nxt[v][k] = ancestor k levels above v (0 above the root),
//   sum[v][k] = a[v] + sum[nxt[v][k]][k]  (v, its k-th ancestor, ... up past the root).
// The walk stops every k edges and always ends on y, even if the last hop is shorter.
int query(int x, int y, int k)
{
    int s = 0, lca = LCA(x, y);
    int d = dep[x] - dep[lca];
    if (d >= k)
    {
        // Upward leg: the ancestors of x at distances 0, k, ..., d - d % k.
        // Their weight is a contiguous slice of the stride-k chain starting at x.
        d -= d % k;
        s += sum[x][k];
        x = jump(x, d);
        s -= sum[nxt[x][k]][k];
    }
    else
        s += a[x];
    // x is now the last stop of the upward leg. y has already been counted only
    // if x landed exactly on it. Testing `lca == y` instead would drop a[y]
    // whenever y is the LCA and the upward distance is not a multiple of k,
    // because the shorter final hop onto y is still a visit (bf() counts it).
    if (x == y)
        return s;
    // p = remaining path length from x to y (> 0 here).
    int p = dep[x] + dep[y] - 2 * dep[lca];
    if (p <= k)
        return s + a[y];
    // Downward leg: stops at distances p-k, p-2k, ..., (p-1)%k+1 above y. They
    // are all strictly below lca, because fewer than k edges remain above x.
    // z is the stop nearest y and w the stop nearest lca; they lie on one stride-k chain.
    int z = jump(y, (p - 1) % k + 1);
    int w = jump(y, p - k);
    s += a[y];
    s += sum[z][k], s -= sum[nxt[w][k]][k];
    return s;
}
int main()
{
n = read();
B = sqrt(n / 4);
for (int i = 1; i <= n; ++i)
a[i] = read();
for (int i = 1; i < n; ++i)
{
int u = read(), v = read();
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1, 0);
for (int i = 1; i <= n; ++i)
b[i] = read();
for (int i = 1; i <= n - 1; ++i)
{
c[i] = read();
if (c[i] <= B)
write(query(b[i], b[i + 1], c[i]), '\n');
else
write(bf(b[i], b[i + 1], c[i]), '\n');
}
return 0;
}