全局平衡二叉树

全局平衡二叉树

六月 30, 2024

引入

我们都知道,树链剖分的时间复杂度如下表:

链修改 链查询 子树修改 子树查询
时间复杂度 $O(\log^2 n)$ $O(\log^2 n)$ $O(\log n)$ $O(\log n)$
空间复杂度 $O(n)$ $O(n)$ $O(n)$ $O(n)$

容易发现,链修改和链查询很慢,多出来的一个 $\log$ 被用到了线段树的区间修改上,不过,我们可以为一条链单独开一棵线段树来减少常熟消耗,但不幸的是,这样子并不妨碍卡树链剖分的数据。

特别是在维护 DDP 的时候,树链剖分多出来的一个 $\log$ 尽管在随机数据上跑得跟 $O(1)$ 一样快,不过终究会 TLE,因此,我们这里引入一个科技:“全局平衡二叉树”,来优化掉这个 $\log$。

声明:在随机数据下,树链剖分比全局平衡二叉树快,但是全局平衡二叉树的时间复杂度是标准的单 $\log$,无法被卡。

构建

首先,我们还是像树链剖分那样分出来一棵树的轻重边,假设分出来的边如下图所示:

那么我们对每个重链重构二叉树,将二叉树的根节点与重链链头的父亲用边连接,这样我们就得到了一个跟 LCT 相似的辅助树:

现在,我们规定对每个重链重构二叉树的时候,需要满足这棵二叉树的中序遍历就是这条链从浅到深的节点序列,并且每次我们选择这个序列的带权中位数为根,左右两边递归构建子二叉树。

带权中位数的权值是当前节点去掉重儿子后的子树大小,即 $sum_x-sum_{son_x}$,一定要加上它自己所代表的 $1$。

这样建树,我们就有一个非常重要的性质:辅助树的树高不超过 $\log n$。

证明:首先从一个节点到根的轻边不超过 $\log$ 条,由树链剖分的性质决定;然后到根的二叉树边也不会超过 $\log$ 条,因为我们取带权中位数,这样的话每跳一跳二叉树边,子树大小至少乘 $2$,得证。

下面给出构建代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
inline int build(int s,int t,int fa){
if(s>t) return 0;
if(s==t){
tr[sta[s]].fa = fa,tr[sta[s]].val = get(sta[s]),pushup(sta[s]);
return sta[s];
}
int ans = INT_MAX,pos = 1;
qzh[s-1]=0;
for(int i=s;i<=t;i++) qzh[i]=qzh[i-1]+light[sta[i]]; //light[u] 表示 u 轻儿子子树大小
for(int i=s;i<=t;i++) if(2*qzh[i]>=qzh[t]){pos=i;break;} //请注意这里的写法
tr[sta[pos]].ls = build(s,pos-1,sta[pos]),tr[sta[pos]].rs = build(pos+1,t,sta[pos]),tr[sta[pos]].val=get(sta[pos]),pushup(sta[pos]),tr[sta[pos]].fa=fa;
return sta[pos];
}

操作杂谈

注意:以下讨论基于 $tag$ 具有线段树的可合并性以及结合律。

链修改

首先,一个链修改可以拆分成 $1 \to x$ 的形式,因此下面只讨论从 $1$ 开始的链修改/查询。

因为在辅助树上 $x$ 可以暴力往根跳父亲,于是 $x$ 在二叉树的时候就需要把中序遍历小于等于 $x$ 的节点打上标记,这并不困难,只需要从 $x$ 跳到当前二叉树的根节点,然后给路径上的各个节点打上标记($x$ 的右子树不能打标记)即可。然后再处理下一棵二叉树,时间复杂度 $O(\log n)$。

链查询

如果有上文链修改的 $tag$,那么需要先跳一遍把 $tag$ 下放再 pushup 回去,然后再暴力跳到辅助树的根节点统计答案即可;也可以选择一边统计答案一边下放 $tag$,不影响时间复杂度 $O(\log n)$。

子树修改

设子树修改根节点为 $u$,那么找到 $u$ 的重链所代表的二叉树 $S$,那么我们直接对 $u$ 的重链上属于 $u$ 子树的那部分打上标记,这个不难,只需要跳到 $S$ 根节点处理即可。问题是 $u$ 子树有通过轻边连接到 $u$ 重链的节点,那部分节点需要通过轻边下传标记,因为不管轻边还是二叉树边,条数都是 $O(\log n)$ 的,所以直接下放即可。时间复杂度 $O(\log n)$。

子树查询

同上,跳到整棵辅助树的根节点下放标记(轻边也要下放),然后再执行一遍修改的操作统计答案即可。时间复杂度 $O(\log n)$​。

动态 DP 优化

本质上是链修改/链查询,并且没有 $tag$ 的上传/下放,于是就可以优化到单 $\log$。

复杂度

注意到,全局平衡二叉树的复杂度消耗如下表:

链修改 链查询 子树修改 子树查询
时间复杂度 $O(\log n)$ $O(\log n)$ $O(\log n)$ $O(\log n)$
空间复杂度 $O(n)$ $O(n)$ $O(n)$ $O(n)$

十分优秀,不过常数很大,$\log$ 基本上要跑满,还是那句话,在随机数据下不如树链剖分,如果碰到卡树链剖分的题,全局平衡二叉树是不二之选。

下面提供卡树链剖分题目的全局平衡二叉树写法 SDOI2017 切树游戏

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
#include<bits/stdc++.h>
#define mod 10007
#define inv2 5004
#define N 60005
#define M 128
namespace IO{
inline char nc(){
static char buf[1000000],*p=buf,*q=buf;
return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++;
}
inline int read(){
int res = 0;
char c = nc();
while(c<'0'||c>'9')c=nc();
while(c<='9'&&c>='0')res=res*10+c-'0',c=nc();
return res;
}
char obuf[1<<21],*p3=obuf;
inline void pc(char c){
p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c);
}
inline void write(int x){
if(x<0) pc('-'),x=-x;
if(x>9) write(x/10);
pc(x%10+'0');
}
}
using namespace std;
using namespace IO;
inline int qmi(int a,int b,int p){
int res = 1%p,t = a;
while(b){
if(b&1) res=res*t%p;
t=t*t%p;
b>>=1;
}
return res;
}
vector<int> op[N];
int n,m,i,x,y,a[N],q,sum[N],son[N],top[N],dfn[N],tot,nid[N],leaf[N],fath[N],inv[N],light[N],sta[N],tott,qzh[N],istop[N];
inline void gxor1(int *a,int n){
for(int i=1;i<n;i*=2){
for(int j=0;j<n;j+=2*i){
for(int k=j;k<j+i;k++){
a[k] += a[k+i],a[k+i] = a[k]-a[k+i]-a[k+i]+2*mod;
a[k] %= mod,a[k+i] %= mod;
}
}
}
}
inline void gxor2(int *a,int n){
for(int i=1;i<n;i*=2){
for(int j=0;j<n;j+=2*i){
for(int k=j;k<j+i;k++){
a[k] += a[k+i],a[k+i] = a[k]-a[k+i]-a[k+i]+2*mod;
a[k] = a[k]*inv2%mod,a[k+i] = a[k+i]*inv2%mod;
}
}
}
}
struct node{int a[M];}emp,ep,f[N],h[N],lh[N],zero,one,getto[N];
struct nodes{pair<int,int> a[M];}lf[N];
inline node to_no1(nodes temp){
node ans;
for(int i=0;i<m;i++) ans.a[i]=(temp.a[i].second?0:temp.a[i].first);
return ans;
}
inline nodes to_no2(node temp){
nodes ans;
for(int i=0;i<m;i++) ans.a[i] = make_pair(temp.a[i],0);
return ans;
}
nodes operator*(nodes a,node b){
for(int i=0;i<m;i++){
if(b.a[i]==0) a.a[i].second++;
else a.a[i].first=a.a[i].first*b.a[i]%mod;
}
return a;
}
nodes operator/(nodes a,node b){
for(int i=0;i<m;i++){
if(b.a[i]==0) a.a[i].second--;
else a.a[i].first=a.a[i].first*inv[b.a[i]]%mod;
}
return a;
}
node operator*(node a,node b){
for(int i=0;i<m;i++) emp.a[i]=(a.a[i]*b.a[i])%mod;
return emp;
}
node operator+(node a,node b){
for(int i=0;i<m;i++) (emp.a[i]=(a.a[i]+b.a[i]))>=mod&&(emp.a[i]-=mod);
return emp;
}
node operator-(node a,node b){
for(int i=0;i<m;i++) (emp.a[i]=(a.a[i]-b.a[i]))<0&&(emp.a[i]+=mod);
return emp;
}
node get_to(int pos){
for(int i=0;i<m;i++) emp.a[i]=(i==pos);
gxor1(emp.a,m);
return emp;
}
struct matrix{node a[4];}temp,emps;
matrix operator*(matrix a,matrix b){
temp.a[0] = a.a[0]*b.a[0];
temp.a[1] = a.a[0]*b.a[1]+a.a[1];
temp.a[2] = b.a[0]*a.a[2]+b.a[2];
temp.a[3] = a.a[2]*b.a[1]+a.a[3]+b.a[3];
return temp;
}
matrix operator/(matrix a,matrix b){
temp.a[0] = (a.a[0]*b.a[0]+a.a[2]*b.a[2]);
temp.a[1] = (a.a[0]*b.a[1]+a.a[1]+a.a[2]*b.a[3]);
temp.a[2] = a.a[2];
return temp;
}
struct NODE{int fa,ls,rs;matrix val,p;}tr[N<<2];
inline bool isroot(int x){return tr[tr[x].fa].ls!=x&&tr[tr[x].fa].rs!=x;}
inline matrix get(int u){
node temp = to_no1(lf[u])*getto[a[u]];
matrix res;
res.a[0] = temp,res.a[1] = temp,res.a[2] = temp,res.a[3] = lh[u]+temp;
return res;
}
inline matrix query(int pos){
int las = leaf[pos];
matrix ans = emps;
ans.a[0]=getto[a[las]],ans.a[1]=ans.a[0],ans.a[2]=one;
if(pos==las) return ans;
while(!isroot(pos)) pos=tr[pos].fa;
ans = ans/tr[pos].p;
return ans;
}
inline void pushup(int p){
if(tr[p].ls&&!istop[tr[p].rs]&&tr[p].rs) tr[p].p=tr[tr[p].rs].p*tr[p].val*tr[tr[p].ls].p;
else if(tr[p].ls) tr[p].p=tr[p].val*tr[tr[p].ls].p;
else if(tr[p].rs&&!istop[tr[p].rs]) tr[p].p=tr[tr[p].rs].p*tr[p].val;
else tr[p].p=tr[p].val;
}
inline int build(int s,int t,int fa){
// cerr<<s<<" "<<t<<" "<<fa<<endl;
if(s>t) return 0;
if(s==t){
tr[sta[s]].fa = fa,tr[sta[s]].val = get(sta[s]),pushup(sta[s]);
return sta[s];
}
int ans = INT_MAX,pos = 1;
qzh[s-1]=0;
for(int i=s;i<=t;i++) qzh[i]=qzh[i-1]+light[sta[i]];
for(int i=s;i<=t;i++) if(2*qzh[i]>=qzh[t]){pos=i;break;}
tr[sta[pos]].ls = build(s,pos-1,sta[pos]),tr[sta[pos]].rs = build(pos+1,t,sta[pos]),tr[sta[pos]].val=get(sta[pos]),pushup(sta[pos]),tr[sta[pos]].fa=fa;
return sta[pos];
}
inline void init(int x,int fa){
sum[x] = 1;
for(int i=0;i<op[x].size();i++){
if(op[x][i]==fa) continue;
fath[op[x][i]] = x;
init(op[x][i],x);
sum[x] += sum[op[x][i]];
if(sum[op[x][i]]>sum[son[x]]) son[x]=op[x][i];
}
light[x]++;
for(int i=0;i<op[x].size();i++){
if(op[x][i]==fa||op[x][i]==son[x]) continue;
light[x] += sum[op[x][i]];
}
}
inline void init2(int x,int fa,int topp){
top[x] = topp,dfn[x] = ++tot,nid[tot] = x;
if(son[x]) init2(son[x],x,topp),leaf[x]=leaf[son[x]];
else leaf[x]=x,istop[x] = 1;
for(int i=0;i<op[x].size();i++){
if(op[x][i]==fa||op[x][i]==son[x]) continue;
init2(op[x][i],x,op[x][i]);
}
}
inline void init3(int x,int fa){
for(int i=0;i<op[x].size();i++){
if(op[x][i]==fa) continue;
init3(op[x][i],x);
}
f[x]=getto[a[x]],lf[x]=to_no2(one);
for(int i=0;i<op[x].size();i++){
if(op[x][i]==fa) continue;
if(op[x][i]==son[x]) h[x]=h[op[x][i]],f[x]=f[x]*(f[op[x][i]]+one);
else lh[x]=lh[x]+h[op[x][i]],lf[x]=lf[x]*(f[op[x][i]]+one);
}
f[x]=f[x]*to_no1(lf[x]);
h[x]=h[x]+lh[x]+f[x];
}
inline void init4(int x,int fa){
if(x==top[x]){
int temp = leaf[x];
tott = 0;
while(1){
sta[++tott] = temp;
if(temp==x) break;
temp = fath[temp];
}
reverse(sta+1,sta+tott+1);
build(1,tott,fath[x]);
}
for(int i=0;i<op[x].size();i++){
if(op[x][i]==fa) continue;
init4(op[x][i],x);
}
}
inline void upd(int x,int y){
int pos = x,temp = a[x],func = 0;
while(1){
// cout<<"!! "<<x<<endl;
if(isroot(x)){
// cout<<"??? "<<x<<endl;
a[pos] = temp;
matrix old = query(x);
a[pos] = y;
tr[x].val = get(x);
pushup(x);
matrix now = query(x);
if(tr[x].fa){
lf[tr[x].fa] = lf[tr[x].fa]/(old.a[0]+one)*(now.a[0]+one);
lh[tr[x].fa] = lh[tr[x].fa]-old.a[1]+now.a[1];
// cout<<lf[1].a[0]<<" "<<lf[1].a[1]<<" "<<lf[1].a[2]<<" "<<lf[1].a[3]<<endl;
// cout<<old.a[1].a[0]<<" "<<old.a[1].a[1]<<" "<<old.a[1].a[2]<<" "<<old.a[1].a[3]<<endl;
}
else break;
}
else{
a[pos] = y;
tr[x].val = get(x);
pushup(x);
}
x=tr[x].fa,func++;
}
}
int main(){
// freopen("1.in","r",stdin);
// freopen("1.out","w",stdout);
n=read(),m=read();
for(i=0;i<mod;i++) inv[i]=qmi(i,mod-2,mod);
while((m&(-m))!=m) m++;
for(i=0;i<m;i++) getto[i]=get_to(i);
for(i=0;i<m;i++) one.a[i]=1;
for(i=1;i<=n;i++) a[i]=read();
for(i=1;i<n;i++) x=read(),y=read(),op[x].push_back(y),op[y].push_back(x);
init(1,0),init2(1,0,1),init3(1,0),init4(1,0);
q=read();
while(q--){
char opt = nc();
while(opt!='C'&&opt!='Q') opt=nc();
if(opt=='Q'){
x=read();
matrix ans = query(1);
gxor2(ans.a[1].a,m);
write(ans.a[1].a[x]),pc('\n');
}
else{
x=read(),y=read();
upd(x,y);
}
}
return fwrite(obuf,p3-obuf,1,stdout),0;
}