P4103题解

本文最后更新于 2026年2月21日 下午

0 前言

这个题我能调半天充分证明了我完全是神人。

1 题意

题面非常简明易懂。

2 思路

由于我并不会DSU on Tree和DDP,所以这里的做法是复杂难写的虚树做法。

考虑在原树上对这 kk 的信息进行处理肯定很难做,所以先建出虚树来,每一条虚树边都是对应原树上的路径长度。

对于三个信息,考虑树上DP。我们把题目中给定的 kk 个点称为关键点,虚树上其余的点则是非关键点。

最简单的显然是维护两个关键点路径长最大值。用 fuf_u 表示点 uu 子树内关键点到点 uu 的最长路径。向上转移是平凡的,答案的话再枚举 uu 的子节点 vv 的同时更新一下就可以了,因为非叶子处的关键点一定不如叶子更优,而叶子一定都是关键点,所以没有什么需要特殊处理的地方。同时因为这题不可能有负边权,所以 ff 一开始全初始化为 00 即可。

和上面差不多的是维护两个关键点路径的最小值。这里需要考虑非叶子处关键点对答案的贡献。用 gug_u 表示点 uu 子树内关键点到点 uu 的最短路径。需要处理的一点是,若 uu 本身为关键点则 gug_u 显然为 00,因为自己到自己距离为 00。其余的基本同上。

难点在于维护任意两个关键点路径之和。这里难处理的一个原因在于和并不是可重复贡献的,需要精确计数。

考虑一个常见的想法。我们用 huh_u 表示 uu 字数内所有关键点到 uu 的长度之和。我们用 sizusiz_u 表示 uu 子树内的关键点数量,假设一个点 uu 仅有 v1,v2v_1,v_2 两个儿子,其中 v1v_1 子树内关键点到 v1v_1 的长度分别为 a1,a2,,asizv1a_1,a_2,\ldots,a_{siz_{v_1}}v2v_2 子树关键点到 v1v_1 的长度分别为 b1,b2,,bsizv2b_1,b_2,\ldots,b_{siz_{v_2}}。那么如何把它们互相之间的贡献计算合并进答案呢?有如下推导:

isizv1jsizv2ai+bj=isizv1(sizv2ai+jsizv2bj)=sizv2isizv1ai+sizv1jsizv2bj=sizv2(hv1+sizv1dis(u,v1))+sizv1(hv2+sizv2dis(u,v2))\sum_i^{siz_{v_1}}\sum_j^{siz_{v_2}}a_i+b_j\\ = \sum_i^{siz_{v_1}} (siz_{v_2}\cdot a_i + \sum_j^{siz_{v_2}}b_j)\\ = siz_{v_2}\sum_i^{siz_{v_1}} a_i + siz_{v_1} \sum_j^{siz_{v_2}}b_j\\ = siz_{v_2}\cdot (h_{v_1}+siz_{v_1}\cdot dis(u,v_1)) + siz_{v_1}\cdot (h_{v_2}+siz_{v_2}\cdot dis(u,v_2))

发现这样的合并是简单的。那么考虑我们已经合并完了子树 v1vi1v_1\sim v_{i-1}huh_u 中已经合并了子树 v1vi1v_1\sim v_{i-1} 的贡献,则合并子树 viv_i 产生的代价显然是:

sizvihu+sizu(hvi+sizvidis(u,vi))siz_{v_i}\cdot h_{u} + siz_{u}\cdot (h_{v_i}+siz_{v_i}\cdot dis(u,v_i))

注意,这里的 sizusiz_{u}huh_u 并非表示整个 uu 子树的信息,而是仅仅考虑已经合并完的部分的信息,也就是说要随着算贡献动态更新。

至于 huh_u 如何向上转移也是很平凡的,只需要枚举 viv_i 时考虑一下边 (u,v)(u,v) 被算多少次就可以了。显然是:

fufv1+sizvidis(u,v)f_u\leftarrow f_{v_1}+siz_{v_i}\cdot dis(u,v)

然后有些小细节处理一下,这题就做完了。

3 代码

又臭又长的奇怪代码,注意使用 #define int long long 带来的多一倍的空间问题。

多测还需要注意清空干净,注意不要漏清。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
struct edge{
int to,val;
};
std::vector<edge> vt[N];
int n,q;
int fa[N][30],sum[N][30];
int st[N],top,siz[N];
int dfn[N],tot,dep[N];
std::vector<int> g[N],node;
ll anss[N],anssum,ansi,ansa;
ll fmx[N],fmn[N];
bool isg[N];

void dfs(int u,int father){
dfn[u]=++tot;
fa[u][0]=father;
dep[u]=dep[father]+1;
for(auto &i:g[u]){
if(i==father)continue;
sum[i][0]=1;
dfs(i,u);
}
}
void init(){
for(int j=1;j<=25;++j){
for(int i=1;i<=n;++i){
fa[i][j]=fa[fa[i][j-1]][j-1];
sum[i][j]=sum[i][j-1]+sum[fa[i][j-1]][j-1];
}
}
}
int lca(int u,int v){
if(dep[u]<dep[v])std::swap(u,v);
for(int j=25;j>=0;--j){
if(dep[fa[u][j]]>=dep[v])u=fa[u][j];
}
if(u==v)return u;
for(int j=25;j>=0;--j){
if(fa[u][j]!=fa[v][j]){
u=fa[u][j];
v=fa[v][j];
}
}
return fa[u][0];
}
int dist(int u,int v){
int res=0;
if(dep[u]<dep[v])std::swap(u,v);
for(int j=25;j>=0;--j){
if(dep[fa[u][j]]>=dep[v]){res+=sum[u][j];u=fa[u][j];}
}
if(u==v)return res;
for(int j=25;j>=0;--j){
if(fa[u][j]!=fa[v][j]){
res+=sum[u][j];u=fa[u][j];
res+=sum[v][j];v=fa[v][j];
}
}
return sum[u][0]+sum[v][0]+res;
}
bool cmp(int x,int y){
return dfn[x]<dfn[y];
}
void addedge(int u,int v,int w){
vt[u].push_back({v,w});
vt[v].push_back({u,w});
}
void buildvt(){
std::sort(node.begin(),node.end(),cmp);
st[top=1]=node[0];
for(auto &i:node){
isg[i]=1;
if(i==node[0])continue;
int lc=lca(i,st[top]);
if(lc==st[top])goto end;
while(top>=2&&dfn[lc]<dfn[st[top-1]]){
addedge(st[top-1],st[top],dist(st[top-1],st[top]));
top--;
}
addedge(lc,st[top],dist(lc,st[top]));
if(dfn[lc]>dfn[st[top-1]])st[top]=lc;
else top--;
end:st[++top]=i;
}
for(int i=1;i<top;++i){
addedge(st[i],st[i+1],dist(st[i],st[i+1]));
}
}
void solve(int u,int father){
siz[u]=isg[u];
fmx[u]=0,fmn[u]=inf;
if(isg[u])fmn[u]=0;
for(auto &i:vt[u]){
if(i.to==father)continue;
solve(i.to,u);
ansi=std::min((long long)fmn[u]+i.val+fmn[i.to],ansi);
if(fmn[u]>i.val+fmn[i.to])fmn[u]=std::min(fmn[u],i.val+fmn[i.to]);
ansa=std::max((long long)fmx[u]+i.val+fmx[i.to],ansa);
if(fmx[u]<i.val+fmx[i.to])fmx[u]=std::max(fmx[u],i.val+fmx[i.to]);
anssum+=(anss[u]*siz[i.to]+siz[u]*(anss[i.to]+i.val*siz[i.to]));
siz[u]+=siz[i.to];
anss[u]+=(anss[i.to]+siz[i.to]*i.val);
}
}
void clr(int u,int father){
siz[u]=anss[u]=0;
fmx[u]=0;
fmn[u]=inf;
for(auto &i:vt[u]){
if(i.to==father)continue;
clr(i.to,u);
}
vt[u].clear();
isg[u]=0;
}
signed main(){
n=read();
for(int i=1;i<n;++i){
int u=read(),v=read();
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1,0);
init();
q=read();
while(q--){
ansi=inf;ansa=-inf;anssum=0;
node.clear();
int k=read();
for(int i=1;i<=k;++i){
int u=read();
node.push_back(u);
}
buildvt();
solve(node[0],0);
printf("%lld %lld %lld\n",anssum,ansi,ansa);
clr(node[0],0);
}
return 0;
}

4 复杂度分析

空间复杂度:瓶颈在于倍增,空间复杂度为 O(nlogn)O(n\log n)。虚树的空间复杂度显然为 O(k)O(k),看起来是很小的。

时间复杂度:倍增预处理 O(nlogn)O(n\log n),建虚树 O(klogk)O(k\log k),虚树上DP O(k)O(k)。总的复杂度为 O(nlogn+klogk)O(n\log n+\sum k\log k)

5 后记

做的费劲巴拉。

新年快乐!


P4103题解
http://ljhljh1102,github.io/2026/02/14/P4103题解/
作者
1102
发布于
2026年2月14日
许可协议