树分治

树分治

树分治

树分治将分治技术运用到了树上,可以更好地维护两点间的路径问题,如果整棵树是一个序列,则可以看做是 CDQ 分治或者是整体二分。

主要思想就是递归的层数不超过 $\log$ 层,每层我们可以 $O(n)$ 遍历所有点,以执行其他操作。

点分治

点分治用于处理树上的路径问题。

一条路径如果指定了根节点,那可以被拆分为 $u \to root$ 和 $v \to root$ 两段进行讨论,点分治就是在这样的条件下产生的。

我们以 P3806 【模板】点分治 1 来讲解,这道题目要在 $O(n \log n)$ 的时间内求出是否有一条路径的长度等于 $k$。

如果选择了根,那么我们可以用桶在 $O(n)$ 时间内求出跨越根的路径有没有等于 $k$ 的,这个实现不难。

把根删除会得到若干棵子树,这些子树递归下去查找即可,但是这样做的时间复杂度是 $O(n^2)$ 的,因为我们有可能不断地选择一条链的端点作为根。

于是可以提出改进,每次选择的根是当前子树的重心,这样的话递归层数不超过 $\log$ 层,时间复杂度就是 $O(n \log n)$,但是记住每次清空的时候不能直接 memset,而需要记录修改了哪些位置上的值然后弹出修改才能保证时间复杂度,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include<bits/stdc++.h>
#define N 20005
#define M 10000005
using namespace std;
int n,m,a[N],i,x,y,z,la[N],ne[N],to[N],val[N],tot,q[N],ans[N],col[N],siz[N],tim,sta[N],top;
bool has[M];
inline void merge(int x,int y,int z){tot++,ne[tot]=la[x],la[x]=tot,to[tot]=y,val[tot]=z;}
void found(int x,int fa,int &pos,int &now,int all){
int cut = 0;
siz[x] = 1;
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[to[i]]!=col[x]) continue;
found(to[i],x,pos,now,all);
siz[x] += siz[to[i]];
cut = max(cut,siz[to[i]]);
}
cut = max(cut,all-siz[x]);
if(cut<now) now=cut,pos=x;
}
void solve(int x,int fa,int step){
for(int i=1;i<=m;i++) if(q[i]>=step&&has[q[i]-step]) ans[i]=1;
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve(to[i],x,step+val[i]);
}
}
void solve2(int x,int fa,int step){
if(step<=1e7) has[step]=1,sta[++top]=step;
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve2(to[i],x,step+val[i]);
}
}
void solve3(int x,int fa,int tim){
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve3(to[i],x,tim);
}
col[x] = tim;
}
void dfs(int x,int all){
int pos = -1,now = INT_MAX;
found(x,-1,pos,now,all);
has[0] = 1,sta[++top] = 0;
for(int i=la[pos];i;i=ne[i]){
if(col[to[i]]!=col[pos]) continue;
solve(to[i],pos,val[i]);
solve2(to[i],pos,val[i]);
}
while(top) has[sta[top]]=0,top--;
for(int i=la[pos];i;i=ne[i]){
if(col[to[i]]!=col[pos]) continue;
solve3(to[i],pos,++tim);
dfs(to[i],siz[to[i]]);
}
}
int main(){
ios::sync_with_stdio(false);
cin>>n>>m;
for(i=1;i<n;i++) cin>>x>>y>>z,merge(x,y,z),merge(y,x,z);
for(i=1;i<=m;i++) cin>>q[i];
dfs(1,n);
for(i=1;i<=m;i++) cout<<(ans[i]?"AYE":"NAY")<<endl;
return 0;
}

注意:一定要区分自己是哪棵子树的重心,代码中使用了颜色数组来记录。

例题 1

P4178 Tree

这道题相对于上一道题的不同是要求小于等于 $k$ 的路径条数,于是用树状数组维护加减查询即可。

代码如下,只是多了树状数组,时间复杂度 $O(n \log^2 n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include<bits/stdc++.h>
#define N 100005
using namespace std;
int n,m,a[N],i,x,y,z,la[N],ne[N],to[N],val[N],tot,col[N],siz[N],tim,sta[N],top,tr[N];
long long ans;
inline void add(int x){
x++;
while(x<=(2e4+1)) tr[x]++,x+=x&(-x);
}
inline int query(int x){
int num = 0;
x++;
while(x) num+=tr[x],x-=x&(-x);
return num;
}
inline void clear(int x){
x++;
while(x<=(2e4+1)) tr[x]=0,x+=x&(-x);
}
inline void merge(int x,int y,int z){tot++,ne[tot]=la[x],la[x]=tot,to[tot]=y,val[tot]=z;}
void found(int x,int fa,int &pos,int &now,int all){
int cut = 0;
siz[x] = 1;
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[to[i]]!=col[x]) continue;
found(to[i],x,pos,now,all);
siz[x] += siz[to[i]];
cut = max(cut,siz[to[i]]);
}
cut = max(cut,all-siz[x]);
if(cut<now) now=cut,pos=x;
}
void solve(int x,int fa,int step){
if(step<=m) ans+=query(m-step);
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve(to[i],x,step+val[i]);
}
}
void solve2(int x,int fa,int step){
if(step<=m) add(step),sta[++top]=step;
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve2(to[i],x,step+val[i]);
}
}
void solve3(int x,int fa,int tim){
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve3(to[i],x,tim);
}
col[x] = tim;
}
void dfs(int x,int all){
int pos = -1,now = INT_MAX;
found(x,-1,pos,now,all);
add(0),sta[++top] = 0;
for(int i=la[pos];i;i=ne[i]){
if(col[to[i]]!=col[pos]) continue;
solve(to[i],pos,val[i]);
solve2(to[i],pos,val[i]);
}
while(top) clear(sta[top]),top--;
for(int i=la[pos];i;i=ne[i]){
if(col[to[i]]!=col[pos]) continue;
solve3(to[i],pos,++tim);
dfs(to[i],siz[to[i]]);
}
}
int main(){
ios::sync_with_stdio(false);
cin>>n;
for(i=1;i<n;i++) cin>>x>>y>>z,merge(x,y,z),merge(y,x,z);
cin>>m;
dfs(1,n);
cout<<ans<<endl;
return 0;
}

例题 2

P2664 树上游戏

这道题需要我们求出以每个节点开始的每条路径上的颜色种类之和,因为涉及到了较为复杂的路径查询,因此我们使用点分治解决它。

一条路径,要么是 $lca \to u$,要么是 $u \to lca \to v$,对于第一种情况,我们只需要开一个桶,存储每个点到我们枚举的点的颜色种类数就可以了。

对于第二种路径,如果当前枚举的点是 $u$,我们考虑对每种颜色分开计算贡献,如果某种颜色在 $u \to lca$ 的路径上出现过,那就会贡献所有 $v$(不和 $u$ 在同一棵 $lca$ 的子树内)的数量的答案。

否则我们另外开一个桶,当遍历到 $v$ 的时候如果 $a_v$ 在 $v \to lca$ 上只经过了一次,那么就让 $cnt_{a_v} \gets cnt_{a_v}+siz_v$,然后对于上面的第二种情况剩下的情况,就加上所有 $cnt_{p}$,其中 $p$ 不在 $u \to lca$ 上出现。

当然 $cnt$ 要减去当前子树的贡献。

最后不要忘了加上 $lca \to lca$​ 的答案,以及清空我们用过的每个桶。

代码如下,不算难写:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include<bits/stdc++.h>
#define ll long long
#define N 200005
using namespace std;
ll n,m,a[N],i,x,y,z,la[N],ne[N],to[N],tot,col[N],siz[N],sum[N],tim,sta[N],top,ans[N],qzh[N],vis[N],res;
inline void merge(ll x,ll y){tot++,ne[tot]=la[x],la[x]=tot,to[tot]=y;}
void found(ll x,ll fa,ll &pos,ll &now,ll all){
ll cut = 0;
siz[x] = 1;
for(ll i=la[x];i;i=ne[i]){
if(to[i]==fa||col[to[i]]!=col[x]) continue;
found(to[i],x,pos,now,all);
siz[x] += siz[to[i]];
cut = max(cut,siz[to[i]]);
}
cut = max(cut,all-siz[x]);
if(cut<now) now=cut,pos=x;
}
void solve(ll x,ll fa,ll root,ll colsum){
ll temp = 0;
if(!vis[a[x]]) temp=1,colsum++;
vis[a[x]]++;
siz[x] = 1;
for(ll i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve(to[i],x,root,colsum);
siz[x] += siz[to[i]];
}
if(temp) sta[++top]=a[x],sum[a[x]]+=siz[x],res+=siz[x];
ans[root]+=colsum;
vis[a[x]]--;
}
void solve2(ll x,ll fa,ll root,ll colsum,ll alls,ll rest){
ll temp = 0;
if(!vis[a[x]]) temp=1,colsum++,rest-=sum[a[x]];
vis[a[x]]++;
siz[x] = 1;
for(ll i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve2(to[i],x,root,colsum,alls,rest);
siz[x] += siz[to[i]];
}
ans[x]+=colsum*alls+rest;
vis[a[x]]--;
}
void solve3(ll x,ll fa,ll tim){
for(ll i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve3(to[i],x,tim);
}
col[x] = tim;
}
void solve4(ll x,ll fa){
ll temp = 0;
if(!vis[a[x]]) temp=1;
vis[a[x]]++;
for(ll i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve4(to[i],x);
}
if(temp) sum[a[x]]-=siz[x],res-=siz[x];
vis[a[x]]--;
}
void solve5(ll x,ll fa){
ll temp = 0;
if(!vis[a[x]]) temp=1;
vis[a[x]]++;
for(ll i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve5(to[i],x);
}
if(temp) sum[a[x]]+=siz[x],res+=siz[x];
vis[a[x]]--;
}
void dfs(ll x,ll all){
ll pos = -1,now = LLONG_MAX;
found(x,-1,pos,now,all);
ans[pos]++,res=0;
// cout<<"! "<<pos<<endl;
for(ll i=la[pos];i;i=ne[i]){
if(col[to[i]]!=col[pos]) continue;
vis[a[pos]]=1;
solve(to[i],pos,pos,1);
}
for(ll i=la[pos];i;i=ne[i]){
if(col[to[i]]!=col[pos]) continue;
vis[a[pos]]=1;
solve4(to[i],pos);
vis[a[pos]]=1;
solve2(to[i],pos,pos,1,all-siz[to[i]],res);
vis[a[pos]]=1;
solve5(to[i],pos);
}
vis[a[pos]]=0;
while(top) sum[sta[top]]=0,top--;
// for(i=1;i<=n;i++) cout<<ans[i]<<" ";
// cout<<endl;
for(ll i=la[pos];i;i=ne[i]){
if(col[to[i]]!=col[pos]) continue;
solve3(to[i],pos,++tim);
dfs(to[i],siz[to[i]]);
}
}
int main(){
ios::sync_with_stdio(false);
cin>>n;
for(i=1;i<=n;i++) cin>>a[i],qzh[i]=a[i];
sort(qzh+1,qzh+n+1);
for(i=1;i<=n;i++) a[i]=lower_bound(qzh+1,qzh+n+1,a[i])-qzh;
for(i=1;i<n;i++) cin>>x>>y,merge(x,y),merge(y,x);
dfs(1,n);
for(i=1;i<=n;i++) cout<<ans[i]<<endl;
return 0;
}

点分治序

给定一个 $N$ 个结点的树,结点用正整数 $1 \sim N$ 编号。每条边有一个正整数权值。
用 $d(a,b)$ 表示从结点 $a$ 到结点 $b$ 的简单路径的距离。其中要求 $a<b$。将这 $\frac {N(N-1)}{2}$ 个距离从大到小排序,输出前 $M$ 个距离值。

如果按照点分治的顺序来加点到某个序列里面,那么这个序列的长度是 $O(n \log n)$ 的,并且每个点能够与之配对的点都在一个区间里面,即 $dis_{u,v}=dis_u+dis_v-2 \times dis_{\operatorname{lca}(u,v)}$

于是我们直接可以用超级钢琴题目的思路来解决,直接用 ST 表从堆里面取出 $M$ 个距离值,然后输出即可。

时间复杂度为 $O(n \log n \log k)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#include<bits/stdc++.h>
#pragma GCC optimize("Ofast")
#define N 100005
#define M 1000005
using namespace std;
struct node{
int l,r,id,x;
bool operator<(const node& a)const{
return a.x>x;
}
}ress;
int n,k,a[N],i,j,x,y,z,la[N],ne[N],to[N],val[N],tot,col[N],siz[N],tim,c,res;
int ls[M],rs[M],dist[M],m,st[M][21],st2[M][21],LOG[M];
long long ans;
priority_queue<node> q;
inline void merge(int x,int y,int z){tot++,ne[tot]=la[x],la[x]=tot,to[tot]=y,val[tot]=z;}
void found(int x,int fa,int &pos,int &now,int all){
int cut = 0;
siz[x] = 1;
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[to[i]]!=col[x]) continue;
found(to[i],x,pos,now,all);
siz[x] += siz[to[i]];
cut = max(cut,siz[to[i]]);
}
cut = max(cut,all-siz[x]);
if(cut<now) now=cut,pos=x;
}
void solve(int x,int fa,int step,int l,int r){
m++,dist[m] = step,ls[m] = l,rs[m] = r;
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve(to[i],x,step+val[i],l,r);
}
}
void solve2(int x,int fa,int tim){
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve2(to[i],x,tim);
}
col[x] = tim;
}
void dfs(int x,int all){
int pos = -1,now = INT_MAX,beg = m+1,las = m+1;
found(x,-1,pos,now,all);
m++,dist[m]=0,ls[m]=m,rs[m]=m;
for(int i=la[pos];i;i=ne[i]){
if(col[to[i]]!=col[pos]) continue;
solve(to[i],pos,val[i],beg,las);
las=m;
}
for(int i=la[pos];i;i=ne[i]){
if(col[to[i]]!=col[pos]) continue;
solve2(to[i],pos,++tim);
dfs(to[i],siz[to[i]]);
}
}
inline int query1(int l,int r){
c = LOG[r-l+1];
return max(st[l][c],st[r-(1<<c)+1][c]);
}
inline int query2(int l,int r){
c = LOG[r-l+1],res = max(st[l][c],st[r-(1<<c)+1][c]);
if(res==st[l][c]) return st2[l][c];
return st2[r-(1<<c)+1][c];
}
inline char nc(){
static char buf[1000000],*p=buf,*q=buf;
return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++;
}
inline int read(){
int res = 0;
char c = nc();
while(c<'0'||c>'9')c=nc();
while(c<='9'&&c>='0')res=res*10+c-'0',c=nc();
return res;
}
char obuf[1<<21],*p3=obuf;
inline void pc(char c){
p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c);
}
inline void write(int x){
if(x>9) write(x/10);
pc(x%10+'0');
}
int main(){
n=read(),k=read();
for(i=1;i<n;i++) x=read(),y=read(),z=read(),merge(x,y,z),merge(y,x,z);
dfs(1,n);
for(i=m;i>=1;i--){
LOG[i] = log2(i);
st[i][0] = dist[i],st2[i][0] = i;
for(j=1;i+(1<<j)-1<=m;j++){
st[i][j] = max(st[i][j-1],st[i+(1<<j-1)][j-1]);
if(st[i][j]==st[i][j-1]) st2[i][j]=st2[i][j-1];
else st2[i][j]=st2[i+(1<<j-1)][j-1];
}
}
for(i=1;i<=m;i++) q.push((node){ls[i],rs[i],i,dist[i]+query1(ls[i],rs[i])});
while(k--){
write((ress=q.top()).x),pc('\n');
q.pop();
int pos = query2(ress.l,ress.r);
if(ress.l<pos) q.push((node){ress.l,pos-1,ress.id,dist[ress.id]+query1(ress.l,pos-1)});
if(ress.r>pos) q.push((node){pos+1,ress.r,ress.id,dist[ress.id]+query1(pos+1,ress.r)});
}
return fwrite(obuf,p3-obuf,1,stdout),0;
}

边分治

简单提一下,就是砍掉某条边,然后分治计算,但是这样的话对于菊花图来说时间复杂度仍然是 $O(n^2)$ 的,于是我们可以用多叉树转二叉树的方法转成类似线段树的样子,时间复杂度就正确了。

常数肯定比点分治大,也不推荐这种写法。

点分树

用于动态维护点分治的信息,通常会有强制在线、点权/边权修改等操作。

构建

我们考虑把整棵树进行重构,具体而言,设当前这棵树选择了 $pos$ 作为重心,划分成了 $T_1,T_2,T_3,\dots,T_k$ $k$ 棵子树,那么 $pos$ 在点分树上的儿子就是这 $k$ 棵子树选择的重心。

容易发现,这样构造出来的点分树的树高是 $O(\log)$ 级别的,因此很多并不太正确的暴力在点分树上执行操作都能得到较好的时间复杂度。

特别注意:点分树上如果 $u \to v \to w$ 不代表 $u$ 是 $v$ 在原树的祖先,$u \to w$ 在原树的距离也不能使用 $\operatorname{dis}(u,v)+\operatorname{dis}(v,w)$ 得到。

每次涉及到修改点权的时候就暴力跳祖先进行维护,如果是修改边权,就考虑边权变成点权之后再维护信息,时间复杂度通常是 $O(n \log n \times T)$,$T$ 是维护一次的时间复杂度,如果使用线段树等数据结构就是 $\log^2$ 的。

例题 1

P2056 [ZJOI2007] 捉迷藏

这道题因为有点权的修改和路径的查询,所以考虑点分树处理。

一条路径可能会被划分为 $u \to lca$ 加上 $lca \to v$ 两个部分,因此我们根据路径的这个性质开始处理。

具体的,对于每个节点我们都需要维护一个它在点分树上的所有后代(包括本身)到它父亲的距离之和,这个距离可以使用倍增在原树上计算,并且我们可以维护一个平衡树轻松 $O(\log^2)$ 维护这个操作。

然后因为上面路径的性质,一个节点还需要维护它在点分树上的儿子节点中到它距离最远的两个(类似于 dp 树的直径),这个我们可以用平衡树来存储所有儿子的子树中到它距离最远的节点。(每个儿子的子树只存一个,避免路径重复计算)

最后用一个全局平衡树维护答案即可。

为了常数,程序中使用了可删堆来代替平衡树,时间复杂度总的就是 $O(n \log^2 n)$,并且要记得判断所有灯都开启的情况。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include<bits/stdc++.h>
#define N 200005
using namespace std;
struct heap{
priority_queue<int> q1,q2;
inline void insert(int x){q1.push(x);}
inline int top(){
while(q2.size()&&q1.top()==q2.top()) q1.pop(),q2.pop();
return q1.top();
}
inline int topp(){
while(q2.size()&&q1.top()==q2.top()) q1.pop(),q2.pop();
int tmp = q1.top();
q1.pop();
while(q2.size()&&q1.top()==q2.top()) q1.pop(),q2.pop();
int res = q1.top();
q1.push(tmp);
return res;
}
inline int size(){
return q1.size()-q2.size();
}
inline void erase(int x){q2.push(x);}
}calc[N],alls[N],ans;
int n,m,q,a[N],i,x,y,z,la[N],ne[N],to[N],tot,col[N],siz[N],st[N][21],dep[N],tim,fath[N],vis[N],sum;
char opt;
inline void merge(int x,int y){tot++,ne[tot]=la[x],la[x]=tot,to[tot]=y;}
void dfs(int x,int fa){
for(int i=1;i<=20;i++) st[x][i]=st[st[x][i-1]][i-1];
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa) continue;
dep[to[i]] = dep[x]+1,st[to[i]][0] = x;
dfs(to[i],x);
}
}
inline int query(int x,int y){
if(!x||!y) return 0;
int ans = dep[x]+dep[y];
if(dep[x]>dep[y]) swap(x,y);
for(int i=20;i>=0;i--) if(dep[st[y][i]]>=dep[x]) y=st[y][i];
if(x==y) return ans-2*dep[x];
for(int i=20;i>=0;i--) if(st[x][i]!=st[y][i]) x=st[x][i],y=st[y][i];
return ans-2*dep[st[x][0]];
}
void found(int x,int fa,int &pos,int &now,int all){
int cut = 0;
siz[x] = 1;
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[to[i]]!=col[x]) continue;
found(to[i],x,pos,now,all);
siz[x] += siz[to[i]];
cut = max(cut,siz[to[i]]);
}
cut = max(cut,all-siz[x]);
if(cut<now) now=cut,pos=x;
}
void solve(int x,int fa,int root){
if(fath[root]) alls[root].insert(query(x,fath[root]));
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve(to[i],x,root);
}
}
void solve2(int x,int fa,int tim){
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve2(to[i],x,tim);
}
col[x] = tim;
}
void init(int x,int all,int fa){
int pos = -1,now = INT_MAX;
found(x,-1,pos,now,all);
fath[pos] = fa;
if(fa) alls[pos].insert(query(pos,fa));
calc[pos].insert(0);
for(int i=la[pos];i;i=ne[i]){
if(col[to[i]]!=col[pos]) continue;
solve(to[i],pos,pos);
}
if(fa) calc[fa].insert(alls[pos].top());
for(int i=la[pos];i;i=ne[i]){
if(col[to[i]]!=col[pos]) continue;
solve2(to[i],pos,++tim);
init(to[i],siz[to[i]],pos);
}
if(calc[pos].size()>=2) ans.insert(calc[pos].top()+calc[pos].topp());
}
int main(){
ios::sync_with_stdio(false);
cin>>n;
for(i=1;i<n;i++) cin>>x>>y,merge(x,y),merge(y,x);
dep[1] = 1,dfs(1,-1);
init(1,n,0);
cin>>q;
while(q--){
cin>>opt;
if(opt=='G'){
if(sum==n) cout<<"-1\n";
else if(ans.size()) cout<<ans.top()<<endl;
else cout<<"0\n";
}
else{
cin>>x;
vis[x]^=1;
if(vis[x]){
sum++;
if(calc[x].size()>=2) ans.erase(calc[x].top()+calc[x].topp());
calc[x].erase(0);
if(calc[x].size()>=2) ans.insert(calc[x].top()+calc[x].topp());
for(i=x;fath[i];i=fath[i]){
if(calc[fath[i]].size()>=2) ans.erase(calc[fath[i]].top()+calc[fath[i]].topp());
calc[fath[i]].erase(alls[i].top());
alls[i].erase(query(x,fath[i]));
if(alls[i].size()) calc[fath[i]].insert(alls[i].top());
if(calc[fath[i]].size()>=2) ans.insert(calc[fath[i]].top()+calc[fath[i]].topp());
}
}
else{
sum--;
if(calc[x].size()>=2) ans.erase(calc[x].top()+calc[x].topp());
calc[x].insert(0);
if(calc[x].size()>=2) ans.insert(calc[x].top()+calc[x].topp());
for(i=x;fath[i];i=fath[i]){
if(calc[fath[i]].size()>=2) ans.erase(calc[fath[i]].top()+calc[fath[i]].topp());
if(alls[i].size()) calc[fath[i]].erase(alls[i].top());
alls[i].insert(query(x,fath[i]));
calc[fath[i]].insert(alls[i].top());
if(calc[fath[i]].size()>=2) ans.insert(calc[fath[i]].top()+calc[fath[i]].topp());
}
}
}
}
return 0;
}

例题 2

P6329 【模板】点分树 | 震波

定义分治块 $x$ 为点分树上以 $x$​ 为根的子树的节点集合。

与上一道题不同的是,这道题是维护一个距离信息,而不是某条链的信息,我们依然可以考虑点分治。

因为不是维护链的信息,所以实现起来要简单许多,首先对于 $pos$ 节点来说,建立两棵线段树,一棵维护 $pos$ 的点分树子树内距离 $pos$ 节点为 $k$ 的点权和,一棵维护点分树子树内距离 $fa_{pos}$(点分树上的父亲)为 $k$ 的点权和。

修改是容易的,只需要在点分树上暴力跳父亲就可以了,考虑查询。

首先要加上当前点分树子树内距离 $pos$ 不超过 $v$ 的点权和,此时还有一些节点没有考虑,这些节点到 $pos$ 的路径一定经过了 $pos$ 在点分树上的祖先,考虑枚举祖先,设枚举到了 $y$,$y$ 往 $pos$ 走一步(下级祖先)是 $x$。

那么我们要加上所有经过 $y$ 到 $pos$ 的节点的点权和,答案就要加上 $y$ 子树对 $pos$ 的贡献,也就是 $0 \sim v-\operatorname{dis}(y,pos)$ 下标的和,这样的话我们算重复了一部分,那一部分并不需要经过 $y$ 到达 $pos$,但它们一定在 $x$ 及其子树内,于是减去 $x$ 的那棵维护点分树子树内距离 $fa_{pos}$(点分树上的父亲)为 $k$ 的点权和的线段树上 $0 \sim v-\operatorname{dis}(y,pos)$ 下标的和就可以了。

时间复杂度是 $O(n \log^2 n)$,空间复杂度卡满同阶。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include<bits/stdc++.h>
#define N 200005
using namespace std;
struct node{int l,r,s;}tr[N<<5];
int n,m,q,a[N],i,x,y,z,la[N],ne[N],to[N],tot,col[N],siz[N],dep[N],tim,fath[N],vis[N],val[N],depp[N],sum,son[N],summ[N],top[N],faa[N],tr_tot;
int alls[N],calc[N],opt,ans;
void add(int x,int c,int s,int t,int &p){
if(!p) p=++tr_tot;
if(s==t){
tr[p].s+=c;
return ;
}
if(x<=(s+t)/2) add(x,c,s,(s+t)/2,tr[p].l);
else add(x,c,(s+t)/2+1,t,tr[p].r);
tr[p].s = tr[tr[p].l].s + tr[tr[p].r].s;
}
int query(int l,int r,int s,int t,int p){
if(!p) return 0;
if(l<=s&&t<=r) return tr[p].s;
int ans = 0;
if(l<=(s+t)/2) ans+=query(l,r,s,(s+t)/2,tr[p].l);
if(r>(s+t)/2) ans+=query(l,r,(s+t)/2+1,t,tr[p].r);
return ans;
}
inline void merge(int x,int y,int z){tot++,ne[tot]=la[x],la[x]=tot,to[tot]=y,val[tot]=z;}
void dfs(int x,int fa){
summ[x] = 1;
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa) continue;
dep[to[i]] = dep[x]+1,depp[to[i]]=depp[x]+val[i];
dfs(to[i],x);
if(summ[to[i]]>summ[son[x]]) son[x]=to[i];
summ[x]+=summ[to[i]];
}
}
void dfs2(int x,int fa,int topp){
faa[x] = fa,top[x] = topp;
if(son[x]) dfs2(son[x],x,topp);
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||to[i]==son[x]) continue;
dfs2(to[i],x,to[i]);
}
}
inline int query(int x,int y){
if(!x||!y) return 0;
int ans = depp[x]+depp[y];
while(top[x]!=top[y]){
if(dep[top[x]]>dep[top[y]]) swap(x,y);
y = faa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
return ans-2*depp[x];
}
void found(int x,int fa,int &pos,int &now,int all){
int cut = 0;
siz[x] = 1;
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[to[i]]!=col[x]) continue;
found(to[i],x,pos,now,all);
siz[x] += siz[to[i]];
cut = max(cut,siz[to[i]]);
}
cut = max(cut,all-siz[x]);
if(cut<now) now=cut,pos=x;
}
void solve(int x,int fa,int root){
if(fath[root]) add(query(x,fath[root]),a[x],0,n,alls[root]);
add(query(x,root),a[x],0,n,calc[root]);
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve(to[i],x,root);
}
}
void solve2(int x,int fa,int tim){
for(int i=la[x];i;i=ne[i]){
if(to[i]==fa||col[x]!=col[to[i]]) continue;
solve2(to[i],x,tim);
}
col[x] = tim;
}
void init(int x,int all,int fa){
int pos = -1,now = INT_MAX;
found(x,-1,pos,now,all);
fath[pos] = fa;
add(0,a[pos],0,n,calc[pos]);
if(fath[pos]) add(query(pos,fath[pos]),a[pos],0,n,alls[pos]);
for(int i=la[pos];i;i=ne[i]){
if(col[to[i]]!=col[pos]) continue;
solve(to[i],pos,pos);
}
for(int i=la[pos];i;i=ne[i]){
if(col[to[i]]!=col[pos]) continue;
solve2(to[i],pos,++tim);
init(to[i],siz[to[i]],pos);
}
}
int main(){
ios::sync_with_stdio(false);
cin>>n>>q;
for(i=1;i<=n;i++) cin>>a[i];
for(i=1;i<n;i++) cin>>x>>y,merge(x,y,1),merge(y,x,1);
dep[1] = 1,dfs(1,0),dfs2(1,0,1);
init(1,n,0);
while(q--){
cin>>opt;
if(opt==0){
cin>>x>>y;
x^=ans,y^=ans;
ans = query(0,y,0,n,calc[x]);
for(i=x;fath[i];i=fath[i]){
int dis = query(x,fath[i]);
ans += query(0,y-dis,0,n,calc[fath[i]]);
ans -= query(0,y-dis,0,n,alls[i]);
}
cout<<ans<<endl;
}
else{
cin>>x>>y;
x^=ans,y^=ans;
int tmp = y-a[x];
a[x] = y;
for(i=x;i;i=fath[i]){
add(query(x,i),tmp,0,n,calc[i]);
if(fath[i]) add(query(x,fath[i]),tmp,0,n,alls[i]);
}
}
}
return 0;
}

总结

点分树的运用一般都是每个节点维护当前节点的点分树的子树内所有节点对于它的一些信息,和对于它在点分树上的父亲的一些信息。

涉及到链相关的询问,一般都是剖分成 $u \to lca,lca \to v$ 的信息进行合并处理。

涉及到层数/深度相关的询问,一般都是利用一些区间的数据结构处理,并且还需要小小的容斥一下,即上文”经过 $y$ 到达 $pos$ 的节点“,因为 $x$ 内的节点已经处理过了,删除掉 $x$ 子树对于答案的贡献即可。

点分树转化成普通的链的情况就是线段树,我们可以思考成树上的分治(本质也是如此)。