bzoj 4446 小凸玩密室

树形 $dp$ .

设 $f(i,j)$ 表示将子树 $i$ 全部点亮,下一次点亮点 $j$ 的最小花费.这样的话,状态数是 $O(n^2)$ 的.

注意到点亮子树 $i$ 后,下一次要么点亮 $i$ 的某个祖先,要么点亮 $i$ 的某个祖先的另外一侧的儿子.树是完全二叉树,所以可以直接用深度表示,树深是 $O(\log n)$ 的,再通过位运算得到节点标号.

设 $f(i,j)$ 表示点亮子树 $i$ 后,下一次点亮 $i$ 的第 $j$ 级祖先的最小花费, $g(i,j)$ 表示点亮子树 $i$ 后,下一次点亮 $i$ 的第 $j$ 级祖先的另一个儿子的最小花费.这样状态数是 $O(n\log n)$ 的.

默认以 $1$ 为根, $dp$ 求出 $f,g$ 的值.然后枚举第一个点亮的点 $x$ ,先点亮子树 $x$ ,跳到 $fa_x$ ,再点亮 $fa_x$ 的另一侧子树,再跳到 $fa_{fa_x}$ ,点亮 $fa_{fa_x}$ 的另一颗子树…需要跳 $O(\log n)$ 次.

时间复杂度 $O(n\log n)$ .

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
inline ll read()
{
ll out=0,fh=1;
char jp=getchar();
while ((jp>'9'||jp<'0')&&jp!='-')
jp=getchar();
if (jp=='-')
fh=-1,jp=getchar();
while (jp>='0'&&jp<='9')
out=out*10+jp-'0',jp=getchar();
return out*fh;
}
int fa(int x)
{
return x>>1;
}
int lson(int x)
{
return x<<1;
}
int rson(int x)
{
return x<<1|1;
}
int Ancestor(int x,int i)
{
return x>>i;
}
int Brother(int x,int i)
{
return (x>>(i-1))^1;
}
const int MAXN=2e5+10;
const int Log=20;
const ll inf=1e18;
int n,dep[MAXN];
ll a[MAXN],dis[MAXN][Log],f[MAXN][Log],g[MAXN][Log];
int main()
{
n=read();
for(int i=1; i<=n; ++i)
a[i]=read();
dep[1]=1;
for(int i=2; i<=n; ++i)
{
dep[i]=dep[fa(i)]+1;
dis[i][1]=read();
for(int j=2; j<=dep[i]; ++j)
dis[i][j]=dis[fa(i)][j-1]+dis[i][1];
}
for(int i=n; i>=1; --i)
{
int l=lson(i),r=rson(i);
for(int j=1; j<=dep[i]; ++j)
{
if(r<=n) //lson and rson
{
f[i][j]=min(a[l]*dis[l][1]+g[l][1]+f[r][j+1],a[r]*dis[r][1]+g[r][1]+f[l][j+1]);
g[i][j]=min(a[l]*dis[l][1]+g[l][1]+g[r][j+1],a[r]*dis[r][1]+g[r][1]+g[l][j+1]);
}
else if(l<=n)// only lson
{
f[i][j]=a[l]*dis[l][1]+f[l][j+1];
g[i][j]=a[l]*dis[l][1]+g[l][j+1];
}
else //leaf
{
f[i][j]=dis[i][j]*a[Ancestor(i,j)];
g[i][j]=(dis[i][j]+dis[Brother(i,j)][1])*a[Brother(i,j)];
}
}
}
ll ans=inf;
for(int i=1; i<=n; ++i)
{
ll tmp=f[i][1];
for(int x=fa(i),y=i; x; y=x,x=fa(x))
{
int z=y^1;
if(z<=n)
tmp+=dis[z][1]*a[z]+f[z][2];
else
tmp+=dis[x][1]*a[fa(x)];
}
ans=min(ans,tmp);
}
cout<<ans<<endl;
return 0;
}