bzoj 4860 树的难题

点分治 + 线段树合并.

  • 英语月考的时候一直在想这个题…
  • 有路径长度的限制,可以考虑点分治.然后发现在合并两条路径时有两种情况.
  • 靠近当前分治中心的那两条边如果颜色不同,就直接将两条路径权值加起来.否则还要减去那条边的颜色权值.
  • 分治时把子树按照与当前分支中心连接的边的颜色排序,扫一遍,维护两颗线段树,分别表示连到分治中心的边与当前颜色不同的最大权值与相同的最大权值.
  • 处理完一种颜色的时候把两颗线段树合并起来就好了.时间复杂度 $O(nlog^2n)$ .

bzoj 不支持 C++11​ 是真的毒瘤…

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define mp make_pair
typedef pair<int,int> pii;
const ll inf=1e18;
inline int read()
{
int out=0,fh=1;
char jp=getchar();
while ((jp>'9'||jp<'0')&&jp!='-')
jp=getchar();
if (jp=='-')
fh=-1,jp=getchar();
while (jp>='0'&&jp<='9')
out=out*10+jp-'0',jp=getchar();
return out*fh;
}
const int MAXN=2e5+10;
int n,m,limL,limR,roota,rootb,cnt;
struct node
{
int ls,rs;
ll mxv;
}Tree[MAXN*30];
#define root Tree[o]
int newnode()
{
int o=++cnt;
root.ls=root.rs=0;
root.mxv=-inf;
return o;
}
void pushup(int o)
{
root.mxv=max(Tree[root.ls].mxv,Tree[root.rs].mxv);
}
int merge(int a,int b)
{
if(!a || !b)
return a+b;
Tree[a].mxv=max(Tree[a].mxv,Tree[b].mxv);
Tree[a].ls=merge(Tree[a].ls,Tree[b].ls);
Tree[a].rs=merge(Tree[a].rs,Tree[b].rs);
return a;
}
void insert(int &o,int l,int r,int pos,ll c)
{
if(!o)
o=newnode();
if(l==r)
{
root.mxv=max(root.mxv,c);
return;
}
int mid=(l+r)>>1;
if(pos<=mid)
insert(root.ls,l,mid,pos,c);
else
insert(root.rs,mid+1,r,pos,c);
pushup(o);
}
ll query(int o,int l,int r,int L,int R)
{
if(!o)
return -inf;
if(L<=l && r<=R)
return root.mxv;
ll res=-inf;
int mid=(l+r)>>1;
if(L<=mid)
res=max(res,query(root.ls,l,mid,L,R));
if(R>mid)
res=max(res,query(root.rs,mid+1,r,L,R));
return res;
}
vector<pii> edge[MAXN];
ll ans=-inf,mi;
int rt=0,totsize,vis[MAXN],siz[MAXN];
int val[MAXN];
void Findrt(int u)
{
siz[u]=1;
vis[u]=1;
int sonsize=0;
int SIZ=edge[u].size();
for(int id=0;id<SIZ;++id)
{
pii i=edge[u][id];
int v=i.second;
if(vis[v])
continue;
Findrt(v);
siz[u]+=siz[v];
sonsize=max(sonsize,siz[v]);
}
sonsize=max(sonsize,totsize-siz[u]);
if(sonsize<mi)
rt=u,mi=sonsize;
vis[u]=0;
}
void getsize(int u)
{
++totsize;
vis[u]=1;
int SIZ=edge[u].size();
for(int id=0;id<SIZ;++id)
{
pii i=edge[u][id];
int v=i.second;
if(!vis[v])
getsize(v);
}
vis[u]=0;
}
ll mx[MAXN];
void dfs(int u,ll c,int len,int curcol)
{
if(len>limR)
return;
vis[u]=1;
mx[len]=max(mx[len],c);
int SIZ=edge[u].size();
for(int id=0;id<SIZ;++id)
{
pii i=edge[u][id];
int v=i.second,r=i.first;
ll newc=c;
if(r!=curcol)
newc+=val[r];
if(!vis[v])
dfs(v,newc,len+1,r);
}
vis[u]=0;
}
void solve(int u)
{
totsize=cnt=roota=rootb=0;
mi=inf;
getsize(u);
Findrt(u);
int Rt=rt;
getsize(Rt);
vis[Rt]=1;
int precol=0;
int SIZ=edge[Rt].size();
for(int id=0;id<SIZ;++id)
{
pii e=edge[Rt][id];
int v=e.second,r=e.first;
if(vis[v])
continue;
for(int k=0;k<=siz[v];++k)
mx[k]=-inf;
if(rootb && r!=precol)
{
roota=merge(roota,rootb);
rootb=0;
}
dfs(v,val[r],1,r);
for(int i=1;i<=siz[v] && i<limR && mx[i]!=mx[0];++i)
{
ll tmp=max(ans,i>=limL && i<=limR? mx[i]:-inf);
tmp=max(tmp,query(roota,1,limR,max(1,limL-i),limR-i)+mx[i]);
tmp=max(tmp,query(rootb,1,limR,max(1,limL-i),limR-i)-val[r]+mx[i]);
ans=max(ans,tmp);
}
for(int i=1;i<=siz[v] && i<limR && mx[i]!=mx[0];++i)
{
insert(rootb,1,limR,i,mx[i]);
}
precol=r;
}
for(int id=0;id<SIZ;++id)
{
pii e=edge[Rt][id];
int v=e.second;
if(!vis[v])
solve(v);
}
}
int main()
{
Tree[0].mxv=-inf;
n=read(),m=read(),limL=read(),limR=read();
for(int i=1;i<=m;++i)
val[i]=read();
for(int i=1;i<n;++i)
{
int u=read(),v=read(),c=read();
edge[u].push_back(mp(c,v));
edge[v].push_back(mp(c,u));
}
for(int i=1;i<=n;++i)
sort(edge[i].begin(),edge[i].end());
solve(1);
cout<<ans<<endl;
return 0;
}