P5494 【模板】线段树分裂 gxy001

模板题,要求维护权值线段树合并,分裂,单项插入,区间查询,找第 $k$ 小。

权值线段树的节点 $[l,r]$ 维护的信息有,他的左孩子和右孩子,区间内元素个数。

单项插入


权值线段树基本操作,此题不用离散化。

void update(int const &p,int const &v,int &x,int const &l=1,int const &r=n){
	if(!x) x=newnode();//动态开点
	tr[x].v+=v;
	if(l==r) return;
	int mid=(l+r)>>1;
	if(p<=mid) update(p,v,tr[x].ls,l,mid);
	else update(p,v,tr[x].rs,mid+1,r);
}

区间查询


也是基本操作,注意的是动态开点线段树在找到空节点后就可以返回了。

long long query(int const &pl,int const &pr,int const &x,int const &l=1,int const &r=n){
	if(!x) return 0;
	if(pl==l&&pr==r) return tr[x].v;
	int mid=(l+r)>>1;
	if(pr<=mid) return query(pl,pr,tr[x].ls,l,mid);
	else if(pl>mid) return query(pl,pr,tr[x].rs,mid+1,r);
	else return query(pl,mid,tr[x].ls,l,mid)+query(mid+1,pr,tr[x].rs,mid+1,r);
}

找第 $k$ 小


在线段树上二分,如果左孩子的元素个数大于等于 $k$,说明第 $k$ 小在左子树内;否则,在右子树内。

int kth(long long const &p,int const &x,int const &l=1,int const &r=n){
	if(tr[x].v<p) return -1;
	if(l==r) return l;
	int mid=(l+r)>>1;
	if(tr[tr[x].ls].v>=p) return kth(p,tr[x].ls,l,mid);
	else return kth(p-tr[tr[x].ls].v,tr[x].rs,mid+1,r);
}

线段树合并


线段树合并有两种写法,这里写的是将一棵树合并到另一颗上的写法。对于 $t$ 中没有的节点不用遍历;对于 $p$ 中没有但 $t$ 中有的,可以直接把该节点挂在 $p$ 上。($p,t$ 的含义见题目)

void merge(int &x1,int &x2,int const &l=1,int const &r=n){
	if(!x2) return;
	if(!x1){x1=x2;x2=0;return;}
	tr[x1].v+=tr[x2].v;
	int mid=(l+r)>>1;
	merge(tr[x1].ls,tr[x2].ls,l,mid);
	merge(tr[x1].rs,tr[x2].rs,l,mid);
	delnode(x2);//垃圾回收
}

线段树分裂

终于到正题了,线段树分裂是线段树合并的逆操作。对于待分裂区域与其他区域的公有节点,复制一份;对于独有节点,直接拿过来挂上去;最后记得 pushup

void split(int const &pl,int const &pr,int &x1,int &x2,int const &l=1,int const &r=n){
	if(pl==l&&pr==r){
		x2=x1;
		x1=0;
		return;
	}
	if(!x2) x2=newnode();
	int mid=(l+r)>>1;
	if(pr<=mid) split(pl,pr,tr[x1].ls,tr[x2].ls,l,mid);
	else if(pl>mid) split(pl,pr,tr[x1].rs,tr[x2].rs,mid+1,r);
	else split(pl,mid,tr[x1].ls,tr[x2].ls,l,mid),split(mid+1,pr,tr[x1].rs,tr[x2].rs,mid+1,r);
	tr[x1].v=tr[tr[x1].ls].v+tr[tr[x1].rs].v;//pushup
	tr[x2].v=tr[tr[x2].ls].v+tr[tr[x2].rs].v;
}

可回收垃圾(


新建节点和删除节点,开个垃圾桶,回收废弃节点。

int newnode(){
	if(tp!=lj)return *--tp;
	else return ++cnt;
}
void delnode(int &x){
	*tp++=x;
	tr[x].v=tr[x].ls=tr[x].rs=0;
	x=0;
}

时间复杂度证明

  • 单点修改:由于线段树的二分结构,深度为 ${\log_2n}$,每层只访问 $1$ 个节点,所以时间复杂度为 $\mathrm {O(\log_2n)}$。

  • 区间查询:称被询问区间完整包含的节点为完整节点,被部分包含的为边缘节点,则每一层最多只访问 $2$ 个边缘节点,位于区间边缘;最多访问 $2$ 个完整节点,因为完整节点的兄弟一定不是完整节点。原因是,若一个节点的两个孩子均为完整节点,则该节点为完整节点,不会访问其孩子。所以一层最多访问 $4$ 节点,时间复杂度 $\mathrm{O(\log_2n)}$

  • 找第 $k$ 小:每层只会访问 $1$ 个节点,时间复杂度为 $\mathrm {O(\log_2n)}$。

  • 线段树合并:显然,我们每次只会访问重合节点,那么单次合并的时间复杂度就可以认为是较小树的节点个数,那么显然,多次合并的总复杂度为总点数级别,$\mathrm {O(n\log_2n)}$。

    • 注:不能简单的认为其复杂度为单次合并的最坏复杂度乘询问次数。感性理解,在本题中,显然我们可以花费一次询问使树的节点加一条链,让合并的复杂度增大,但也会失去一次进行线段树合并的机会,这样就保证了其总复杂度不会过高。(很玄学,我也不会更严谨的证明,希望有大佬教我)
  • 线段树分裂:我们可以看到,这个操作访问的节点和区间查询是一致的,复制边缘节点,直接拿走完整节点。时间复杂度 $\mathrm{O(\log_2n)}$

完整代码

#include<cstdio>
struct node{
	long long v;
	int ls,rs;
}tr[8000010];
int lj[8000010],*tp(lj),cnt;
int n,m,rt[200010],rtcnt=1;
int newnode(){
	if(tp!=lj)return *--tp;
	else return ++cnt;
}
void delnode(int &x){
	*tp++=x;
	tr[x].v=tr[x].ls=tr[x].rs=0;
	x=0;
}
void split(int const &pl,int const &pr,int &x1,int &x2,int const &l=1,int const &r=n){
	if(pl==l&&pr==r){
		x2=x1;
		x1=0;
		return;
	}
	if(!x2) x2=newnode();
	int mid=(l+r)>>1;
	if(pr<=mid) split(pl,pr,tr[x1].ls,tr[x2].ls,l,mid);
	else if(pl>mid) split(pl,pr,tr[x1].rs,tr[x2].rs,mid+1,r);
	else split(pl,mid,tr[x1].ls,tr[x2].ls,l,mid),split(mid+1,pr,tr[x1].rs,tr[x2].rs,mid+1,r);
	tr[x1].v=tr[tr[x1].ls].v+tr[tr[x1].rs].v;
	tr[x2].v=tr[tr[x2].ls].v+tr[tr[x2].rs].v;
}
void merge(int &x1,int &x2,int const &l=1,int const &r=n){
	if(!x2) return;
	if(!x1){x1=x2;x2=0;return;}
	tr[x1].v+=tr[x2].v;
	int mid=(l+r)>>1;
	merge(tr[x1].ls,tr[x2].ls,l,mid);
	merge(tr[x1].rs,tr[x2].rs,l,mid);
	delnode(x2);
}
void update(int const &p,int const &v,int &x,int const &l=1,int const &r=n){
	if(!x) x=newnode();
	tr[x].v+=v;
	if(l==r) return;
	int mid=(l+r)>>1;
	if(p<=mid) update(p,v,tr[x].ls,l,mid);
	else update(p,v,tr[x].rs,mid+1,r);
}
long long query(int const &pl,int const &pr,int const &x,int const &l=1,int const &r=n){
	if(!x) return 0;
	if(pl==l&&pr==r) return tr[x].v;
	int mid=(l+r)>>1;
	if(pr<=mid) return query(pl,pr,tr[x].ls,l,mid);
	else if(pl>mid) return query(pl,pr,tr[x].rs,mid+1,r);
	else return query(pl,mid,tr[x].ls,l,mid)+query(mid+1,pr,tr[x].rs,mid+1,r);
}
int kth(long long const &p,int const &x,int const &l=1,int const &r=n){
	if(tr[x].v<p) return -1;
	if(l==r) return l;
	int mid=(l+r)>>1;
	if(tr[tr[x].ls].v>=p) return kth(p,tr[x].ls,l,mid);
	else return kth(p-tr[tr[x].ls].v,tr[x].rs,mid+1,r);
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1,x;i<=n;i++){
		scanf("%d",&x);
		if(x) update(i,x,rt[1]);
	}
	for(int i=1,op,p,x,y;i<=m;i++){
		scanf("%d%d",&op,&p);
		switch(op){
			case 0:
				scanf("%d%d",&x,&y);
				split(x,y,rt[p],rt[++rtcnt]);
				break;
			case 1:
				scanf("%d",&x);
				merge(rt[p],rt[x]);
				break;
			case 2:
				scanf("%d%d",&x,&y);
				if(x)update(y,x,rt[p]);
				break;
			case 3:
				scanf("%d%d",&x,&y);
				printf("%lld\n",query(x,y,rt[p]));
				break;
			case 4:
				long long k;
				scanf("%lld",&k);
				printf("%d\n",kth(k,rt[p]));
				break;
		}
	}
	return 0;
}