如果这篇博客帮助到你,可以请我喝一杯咖啡~
CC BY-NC-SA 4.0 (除特别声明或转载文章外)
模板题,要求维护权值线段树合并,分裂,单项插入,区间查询,找第 $k$ 小。
权值线段树的节点 $[l,r]$ 维护的信息有,他的左孩子和右孩子,区间内元素个数。
单项插入
权值线段树基本操作,此题不用离散化。
void update(int const &p,int const &v,int &x,int const &l=1,int const &r=n){
if(!x) x=newnode();//动态开点
tr[x].v+=v;
if(l==r) return;
int mid=(l+r)>>1;
if(p<=mid) update(p,v,tr[x].ls,l,mid);
else update(p,v,tr[x].rs,mid+1,r);
}
区间查询
也是基本操作,注意的是动态开点线段树在找到空节点后就可以返回了。
long long query(int const &pl,int const &pr,int const &x,int const &l=1,int const &r=n){
if(!x) return 0;
if(pl==l&&pr==r) return tr[x].v;
int mid=(l+r)>>1;
if(pr<=mid) return query(pl,pr,tr[x].ls,l,mid);
else if(pl>mid) return query(pl,pr,tr[x].rs,mid+1,r);
else return query(pl,mid,tr[x].ls,l,mid)+query(mid+1,pr,tr[x].rs,mid+1,r);
}
找第 $k$ 小
在线段树上二分,如果左孩子的元素个数大于等于 $k$,说明第 $k$ 小在左子树内;否则,在右子树内。
int kth(long long const &p,int const &x,int const &l=1,int const &r=n){
if(tr[x].v<p) return -1;
if(l==r) return l;
int mid=(l+r)>>1;
if(tr[tr[x].ls].v>=p) return kth(p,tr[x].ls,l,mid);
else return kth(p-tr[tr[x].ls].v,tr[x].rs,mid+1,r);
}
线段树合并
线段树合并有两种写法,这里写的是将一棵树合并到另一颗上的写法。对于 $t$ 中没有的节点不用遍历;对于 $p$ 中没有但 $t$ 中有的,可以直接把该节点挂在 $p$ 上。($p,t$ 的含义见题目)
void merge(int &x1,int &x2,int const &l=1,int const &r=n){
if(!x2) return;
if(!x1){x1=x2;x2=0;return;}
tr[x1].v+=tr[x2].v;
int mid=(l+r)>>1;
merge(tr[x1].ls,tr[x2].ls,l,mid);
merge(tr[x1].rs,tr[x2].rs,l,mid);
delnode(x2);//垃圾回收
}
线段树分裂
终于到正题了,线段树分裂是线段树合并的逆操作。对于待分裂区域与其他区域的公有节点,复制一份;对于独有节点,直接拿过来挂上去;最后记得 pushup
。
void split(int const &pl,int const &pr,int &x1,int &x2,int const &l=1,int const &r=n){
if(pl==l&&pr==r){
x2=x1;
x1=0;
return;
}
if(!x2) x2=newnode();
int mid=(l+r)>>1;
if(pr<=mid) split(pl,pr,tr[x1].ls,tr[x2].ls,l,mid);
else if(pl>mid) split(pl,pr,tr[x1].rs,tr[x2].rs,mid+1,r);
else split(pl,mid,tr[x1].ls,tr[x2].ls,l,mid),split(mid+1,pr,tr[x1].rs,tr[x2].rs,mid+1,r);
tr[x1].v=tr[tr[x1].ls].v+tr[tr[x1].rs].v;//pushup
tr[x2].v=tr[tr[x2].ls].v+tr[tr[x2].rs].v;
}
可回收垃圾(
新建节点和删除节点,开个垃圾桶,回收废弃节点。
int newnode(){
if(tp!=lj)return *--tp;
else return ++cnt;
}
void delnode(int &x){
*tp++=x;
tr[x].v=tr[x].ls=tr[x].rs=0;
x=0;
}
时间复杂度证明
单点修改:由于线段树的二分结构,深度为 ${\log_2n}$,每层只访问 $1$ 个节点,所以时间复杂度为 $\mathrm {O(\log_2n)}$。
区间查询:称被询问区间完整包含的节点为完整节点,被部分包含的为边缘节点,则每一层最多只访问 $2$ 个边缘节点,位于区间边缘;最多访问 $2$ 个完整节点,因为完整节点的兄弟一定不是完整节点。原因是,若一个节点的两个孩子均为完整节点,则该节点为完整节点,不会访问其孩子。所以一层最多访问 $4$ 节点,时间复杂度 $\mathrm{O(\log_2n)}$
找第 $k$ 小:每层只会访问 $1$ 个节点,时间复杂度为 $\mathrm {O(\log_2n)}$。
线段树合并:显然,我们每次只会访问重合节点,那么单次合并的时间复杂度就可以认为是较小树的节点个数,那么显然,多次合并的总复杂度为总点数级别,$\mathrm {O(n\log_2n)}$。
- 注:不能简单的认为其复杂度为单次合并的最坏复杂度乘询问次数。感性理解,在本题中,显然我们可以花费一次询问使树的节点加一条链,让合并的复杂度增大,但也会失去一次进行线段树合并的机会,这样就保证了其总复杂度不会过高。(很玄学,我也不会更严谨的证明,希望有大佬教我)
线段树分裂:我们可以看到,这个操作访问的节点和区间查询是一致的,复制边缘节点,直接拿走完整节点。时间复杂度 $\mathrm{O(\log_2n)}$
完整代码
#include<cstdio>
struct node{
long long v;
int ls,rs;
}tr[8000010];
int lj[8000010],*tp(lj),cnt;
int n,m,rt[200010],rtcnt=1;
int newnode(){
if(tp!=lj)return *--tp;
else return ++cnt;
}
void delnode(int &x){
*tp++=x;
tr[x].v=tr[x].ls=tr[x].rs=0;
x=0;
}
void split(int const &pl,int const &pr,int &x1,int &x2,int const &l=1,int const &r=n){
if(pl==l&&pr==r){
x2=x1;
x1=0;
return;
}
if(!x2) x2=newnode();
int mid=(l+r)>>1;
if(pr<=mid) split(pl,pr,tr[x1].ls,tr[x2].ls,l,mid);
else if(pl>mid) split(pl,pr,tr[x1].rs,tr[x2].rs,mid+1,r);
else split(pl,mid,tr[x1].ls,tr[x2].ls,l,mid),split(mid+1,pr,tr[x1].rs,tr[x2].rs,mid+1,r);
tr[x1].v=tr[tr[x1].ls].v+tr[tr[x1].rs].v;
tr[x2].v=tr[tr[x2].ls].v+tr[tr[x2].rs].v;
}
void merge(int &x1,int &x2,int const &l=1,int const &r=n){
if(!x2) return;
if(!x1){x1=x2;x2=0;return;}
tr[x1].v+=tr[x2].v;
int mid=(l+r)>>1;
merge(tr[x1].ls,tr[x2].ls,l,mid);
merge(tr[x1].rs,tr[x2].rs,l,mid);
delnode(x2);
}
void update(int const &p,int const &v,int &x,int const &l=1,int const &r=n){
if(!x) x=newnode();
tr[x].v+=v;
if(l==r) return;
int mid=(l+r)>>1;
if(p<=mid) update(p,v,tr[x].ls,l,mid);
else update(p,v,tr[x].rs,mid+1,r);
}
long long query(int const &pl,int const &pr,int const &x,int const &l=1,int const &r=n){
if(!x) return 0;
if(pl==l&&pr==r) return tr[x].v;
int mid=(l+r)>>1;
if(pr<=mid) return query(pl,pr,tr[x].ls,l,mid);
else if(pl>mid) return query(pl,pr,tr[x].rs,mid+1,r);
else return query(pl,mid,tr[x].ls,l,mid)+query(mid+1,pr,tr[x].rs,mid+1,r);
}
int kth(long long const &p,int const &x,int const &l=1,int const &r=n){
if(tr[x].v<p) return -1;
if(l==r) return l;
int mid=(l+r)>>1;
if(tr[tr[x].ls].v>=p) return kth(p,tr[x].ls,l,mid);
else return kth(p-tr[tr[x].ls].v,tr[x].rs,mid+1,r);
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1,x;i<=n;i++){
scanf("%d",&x);
if(x) update(i,x,rt[1]);
}
for(int i=1,op,p,x,y;i<=m;i++){
scanf("%d%d",&op,&p);
switch(op){
case 0:
scanf("%d%d",&x,&y);
split(x,y,rt[p],rt[++rtcnt]);
break;
case 1:
scanf("%d",&x);
merge(rt[p],rt[x]);
break;
case 2:
scanf("%d%d",&x,&y);
if(x)update(y,x,rt[p]);
break;
case 3:
scanf("%d%d",&x,&y);
printf("%lld\n",query(x,y,rt[p]));
break;
case 4:
long long k;
scanf("%lld",&k);
printf("%d\n",kth(k,rt[p]));
break;
}
}
return 0;
}