KTT & Segment Tree Beats

KTT & Segment Tree Beats

KTT & 吉司机线段树

首先,吉司机线段树是一种用来维护区间取 $\max$,区间求和的操作,它的时间复杂度是通过势能分析得到的 $O(n \log n)$,我们也可以在这上面进行一定的扩展,从而达到更加高级的区间操作类型。

KTT 简介

KTT 本质上是一种类型的吉司机线段树,这个类型的线段树主要还是在于每个节点维护一个 $key$ 键值,当修改的 $c \le key$ 我们可以直接打标记进行修改;否则就一直递归。

这样的时间复杂度视情况而定,最朴素的吉司机线段树就是 $\log$ 级别的,如果加上区间加等操作就是 $\log^2$ 级别的,这里介绍的 KTT 需要维护区间最大子段和,这个代价是 $\log^3$ 级别的,具体证明可以看发明者的博客或者在网上搜索资料。

KTT 构建

因为每个节点需要维护 $lmax,smax,rmax,sum$ 四个变量,即左边的最大子段和,整个的最大子段和,右边的最大子段和和整个序列的和。

现在我们需要支持:

  • 区间加正数。
  • 区间最大子段和。

接下来我们考虑维护一个键值 $s$ 表示当一次性加的数 $>s$ 的时候上面四个变量就会有至少一个的转移方程被改动。

转移方程被改动:以 $smax$ 的转移为例,$smax=\max{smax_{ls},smax_{rs},rmax_{ls}+lmax_{rs}}$,当一次性加的数 $>s$ 的时候就有可能不会取某个值,而是转到另外一种转移去取。

这个时候我们就需要把上面的四个变量看做四个函数 $y=kx+b$,$x$ 就是一次性加的值,是不确定的,$k$ 就是区间的长度,如果转移方式不变,那么一次性加 $x$ 就相当于让答案加了 $kx$,$b$ 就是当 $x=0$ 的时候的值,就是原本的变量的值。

那么键值 $s$ 怎么维护呢?我们于是把 $lmax,smax,rmax,sum$ 四个函数的转移拿出来,$\max$ 中的任意两个函数都有一个交点,取这些交点最小的 $x$​ 是多少就好。

特别的,$s$ 还要和子树中所有 $s$ 取 $\min$,因为有可能会影响到转移时候的值,当交点 $x<0$ 的时候就不用管它,如果所有交点都小于 $0$,$s \gets \inf$​。

函数的加减法既可以通过意义得出,也可以通过数学推导得出:$k’ \gets k_1+k_2,b’ \gets b_1+b_2$。

还有 pushtag 的时候记得让 $b \gets b+kx$。

题目:P5693 EI 的第六分块,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include<bits/stdc++.h>
#define ll int
#define N 400005
using namespace std;
inline char nc(){
static char buf[1000000],*p=buf,*q=buf;
return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++;
}
inline ll read(){
ll res = 0,w = 1;
char c = nc();
while(c<'0'||c>'9'){
if(c=='-') w=-1;
c=nc();
}
while(c<='9'&&c>='0')res=res*10+c-'0',c=nc();
return res*w;
}
char obuf[1<<21],*p3=obuf;
inline void pc(char c){
p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c);
}
inline void write(long long x){
if(x<0) pc('-'),x=-x;
if(x>9) write(x/10);
pc(x%10+'0');
}
struct poly{
long long k,b;
inline void add(long long x){b+=k*x;}
};
inline poly operator+(const poly a,const poly b){return (poly){a.k+b.k,a.b+b.b};}
inline bool operator<(poly a,poly b){return a.b<b.b;}
inline long long get(poly a,poly b){
if(a.k==b.k) return 0x3f3f3f3f3f3f3f3f;
if(a.k<b.k) swap(a,b);
if(b.b<=a.b) return 0x3f3f3f3f3f3f3f3f;
return (b.b-a.b)/(a.k-b.k);
}
struct node{
poly smax,lmax,rmax,sum;
long long s,tag;
}tr[N<<2],c;
inline node merge(node a,node b){
c.sum = a.sum + b.sum;
c.smax = max({a.smax,b.smax,a.rmax+b.lmax});
c.lmax = max(a.lmax,b.lmax+a.sum);
c.rmax = max(b.rmax,a.rmax+b.sum);
c.s = min({a.s,b.s,get(a.lmax,b.lmax+a.sum),get(b.rmax,a.rmax+b.sum),get(a.smax,b.smax),get(b.smax,a.rmax+b.lmax),get(a.smax,a.rmax+b.lmax)});
c.tag = 0;
return c;
}
ll n,m,opt,x,y,z,i,a[N];
inline void build(ll s,ll t,ll p){
if(s==t){
tr[p].sum = tr[p].smax = tr[p].lmax = tr[p].rmax = (poly){1,a[s]},tr[p].s = 0x3f3f3f3f3f3f3f3f;
return ;
}
build(s,(s+t)/2,2*p),build((s+t)/2+1,t,2*p+1);
tr[p] = merge(tr[2*p],tr[2*p+1]);
}
inline void pushtag(ll p,long long c){tr[p].tag+=c,tr[p].s-=c,tr[p].lmax.add(c),tr[p].rmax.add(c),tr[p].smax.add(c),tr[p].sum.add(c);}
inline void pushdown(ll p){
if(tr[p].tag){
pushtag(2*p,tr[p].tag),pushtag(2*p+1,tr[p].tag);
tr[p].tag = 0;
}
}
inline void defeat(long long c,ll s,ll t,ll p){
if(c<=tr[p].s) return pushtag(p,c);
c += tr[p].tag,tr[p].tag = 0;
defeat(c,s,(s+t)/2,2*p),defeat(c,(s+t)/2+1,t,2*p+1);
tr[p] = merge(tr[2*p],tr[2*p+1]);
}
inline void upd(ll l,ll r,ll c,ll s,ll t,ll p){
if(l<=s&&t<=r) return defeat(c,s,t,p);
pushdown(p);
if(l<=(s+t)/2) upd(l,r,c,s,(s+t)/2,2*p);
if(r>(s+t)/2) upd(l,r,c,(s+t)/2+1,t,2*p+1);
tr[p] = merge(tr[2*p],tr[2*p+1]);
}
inline node query(ll l,ll r,ll s,ll t,ll p){
if(l<=s&&t<=r) return tr[p];
pushdown(p);
if(l<=(s+t)/2&&r>(s+t)/2) return merge(query(l,r,s,(s+t)/2,2*p),query(l,r,(s+t)/2+1,t,2*p+1));
else if(l<=(s+t)/2) return query(l,r,s,(s+t)/2,2*p);
else return query(l,r,(s+t)/2+1,t,2*p+1);
}
int main(){
n=read(),m=read();
for(i=1;i<=n;i++) a[i]=read();
build(1,n,1);
while(m--){
opt=read();
if(opt==1){
x=read(),y=read(),z=read();
upd(x,y,z,1,n,1);
}
else{
x=read(),y=read();
write(max(query(x,y,1,n,1).smax.b,0ll)),pc('\n');
}
}
fwrite(obuf,p3-obuf,1,stdout);
return 0;
}

KTT 扩展

KTT 还可以解决区间取 $\max$,区间求最大子段和的操作,如何实现呢?

方法和上面一样,只不过如果值变动了,一定变动的只有最小值,次小值不会变动,因为我们可以在外面套个吉司机线段树模板。

然后 $y=kx+b$ 中 $k$ 的含义变成最小值的数量就可以了,特别注意记录 tag 的顺序以及求交点的时候。

因为这道题是求最小值,而最小值有一些特殊的性质,于是当交点等于 $0$ 的时候也要记录进去答案,严格来说,这么做才是对的。

题目:P6792 [SNOI2020] 区间和,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include<bits/stdc++.h>
#define inf 0x3f3f3f3f3f3f3f3f
#define ll long long
#define N 400005
using namespace std;
inline char nc(){
static char buf[1000000],*p=buf,*q=buf;
return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++;
}
inline ll read(){
ll res = 0,w = 1;
char c = nc();
while(c<'0'||c>'9'){
if(c=='-') w=-1;
c=nc();
}
while(c<='9'&&c>='0')res=res*10+c-'0',c=nc();
return res*w;
}
char obuf[1<<21],*p3=obuf;
inline void pc(char c){
p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c);
}
inline void write(long long x){
if(x<0) pc('-'),x=-x;
if(x>9) write(x/10);
pc(x%10+'0');
}
struct poly{
long long k,b;
inline void add(long long x){b+=k*x;}
};
inline poly operator+(const poly a,const poly b){return (poly){a.k+b.k,a.b+b.b};}
inline poly max(poly a,poly b){
if(a.k<b.k||(a.k==b.k&&a.b<b.b)) swap(a,b);
if(a.b>=b.b) return a;
return b;
}
inline long long get(poly a,poly b){
if(a.k==b.k) return inf;
if(a.k<b.k) swap(a,b);
if(b.b<a.b) return inf;
return (b.b-a.b)/(a.k-b.k);
}
struct node{
poly smax,lmax,rmax,sum;
long long s,tag,minn,sminn;
}tr[N<<2],c;
inline node merge(node a,node b){
c.minn = min(a.minn,b.minn);
if(c.minn==a.minn&&c.minn==b.minn) c.sminn=min(a.sminn,b.sminn);
else if(c.minn==a.minn) c.sminn=min(a.sminn,b.minn);
else c.sminn=min(a.minn,b.sminn);
if(c.minn!=a.minn) a.sum.k=a.lmax.k=a.rmax.k=a.smax.k=0;
if(c.minn!=b.minn) b.sum.k=b.lmax.k=b.rmax.k=b.smax.k=0;
c.sum = a.sum + b.sum;
c.smax = max(a.smax,max(b.smax,a.rmax+b.lmax));
c.lmax = max(a.lmax,b.lmax+a.sum);
c.rmax = max(b.rmax,a.rmax+b.sum);
c.s = min({a.s,b.s,get(a.lmax,b.lmax+a.sum),get(b.rmax,a.rmax+b.sum),get(a.smax,b.smax),get(b.smax,a.rmax+b.lmax),get(a.smax,a.rmax+b.lmax)});
c.tag = -inf;
return c;
}
ll n,m,opt,x,y,z,i,a[N];
inline void build(ll s,ll t,ll p){
tr[p].tag = -inf;
if(s==t){
tr[p].sum = tr[p].smax = tr[p].lmax = tr[p].rmax = (poly){1,a[s]},tr[p].s = inf;
tr[p].minn = a[s],tr[p].sminn = inf;
return ;
}
build(s,(s+t)/2,2*p),build((s+t)/2+1,t,2*p+1);
tr[p] = merge(tr[2*p],tr[2*p+1]);
}
inline void pushtag(ll p,long long c){
if(c<=tr[p].minn) return ;
long long delta = c-tr[p].minn;
tr[p].minn=c,tr[p].tag=max(tr[p].tag,c),tr[p].s-=delta,tr[p].lmax.add(delta),tr[p].rmax.add(delta),tr[p].smax.add(delta),tr[p].sum.add(delta);
}
inline void pushdown(ll p){
if(tr[p].tag!=-inf){
pushtag(2*p,tr[p].tag),pushtag(2*p+1,tr[p].tag);
tr[p].tag = -inf;
}
}
inline void defeat(long long c,ll s,ll t,ll p){
if(c-tr[p].minn<=tr[p].s) return pushtag(p,c);
pushdown(p);
defeat(c,s,(s+t)/2,2*p),defeat(c,(s+t)/2+1,t,2*p+1);
tr[p] = merge(tr[2*p],tr[2*p+1]);
}
inline void upd(ll l,ll r,ll c,ll s,ll t,ll p){
if(tr[p].minn>=c) return ;
if(l<=s&&t<=r&&c<tr[p].sminn) return defeat(c,s,t,p);
pushdown(p);
if(l<=(s+t)/2) upd(l,r,c,s,(s+t)/2,2*p);
if(r>(s+t)/2) upd(l,r,c,(s+t)/2+1,t,2*p+1);
tr[p] = merge(tr[2*p],tr[2*p+1]);
}
inline node query(ll l,ll r,ll s,ll t,ll p){
if(l<=s&&t<=r) return tr[p];
pushdown(p);
if(l<=(s+t)/2&&r>(s+t)/2) return merge(query(l,r,s,(s+t)/2,2*p),query(l,r,(s+t)/2+1,t,2*p+1));
else if(l<=(s+t)/2) return query(l,r,s,(s+t)/2,2*p);
else return query(l,r,(s+t)/2+1,t,2*p+1);
}
int main(){
// freopen("hack1.in","r",stdin);
// freopen("1.out","w",stdout);
n=read(),m=read();
for(i=1;i<=n;i++) a[i]=read();
build(1,n,1);
while(m--){
// cout<<m<<endl;
opt=read();
if(opt==0){
x=read(),y=read(),z=read();
upd(x,y,z,1,n,1);
}
else{
x=read(),y=read();
write(max(query(x,y,1,n,1).smax.b,0ll)),pc('\n');
}
}
fwrite(obuf,p3-obuf,1,stdout);
return 0;
}

KTT 思想

重复一次,KTT 的思想是在于维护一个键值,当修改的范围超过键值就暴力递归,否则我们可以 $O(1)$ 处理键值修改所带来的影响。

即:在 pushtag 操作的时候就可以 $O(1)$ 更新当前节点维护的所有区间信息。

很多题目我们都可以这么来做,无疑是一个新的线段树技巧!