P5399 [Ynoi2018] 駄作 gxy001

首先你要会树分块,这种树分块满足每个块都是一个连通块,且每个块只有最多两个节点与其他块相连,更多内容不赘述,可以看我的博客

一个邻域可以被拆分成不同块内的 $O(B)$ 个邻域,其中只有 $O(1)$ 个邻域的中心不是界点。

先对每个块单独处理,只考虑同一块内的两个邻域的贡献,分两种情况。

  • 两个邻域的中心有至少一个不是界点,这种情况总共只有 $O(m)$ 次。我们知道 $d(a,b)=d(rt,a)+d(rt,b)-2d(rt,\operatorname{lca}(a,b))$,我们只要知道 $\sum d(rt,\operatorname{lca}(a,b))$ 就可以了。对一个邻域内的点到根的路径全部加一,求另一个邻域内的每个点到根的路径的权值和就是我们要求得值,总时间复杂度 $O(mB)$。
  • 两个邻域的中心都是界点,中心只有三种情况,半径只有 $O(B^2)$ 种情况,所以总情况数只有 $O(B^2)$ 种,预处理出来就行了,预处理的方法是枚举第一个邻域的半径,然后用上面的方法就可以求出第二个邻域所有半径下的答案,总复杂度 $O(\frac {n}{B}B^2)=O(nB)$。

在考虑不同块之间的贡献,对于两个不相交的邻域,有一个求出答案的方法,找到一个点使得从两个邻域中各选一个点路径必定经过这个点。只要求出每个邻域所有点到这个点的距离和点的个数,称为 $d_i,c_i$,则答案为 $d_0c_1+d_1c_0$。对每个询问每个块求出邻域在这个块内点的个数和到两个界点的距离和,在收缩树上按照这个式子 $\text{dp}$ 一下子就完事了。

时间复杂度 $O((n+m)\sqrt n)$,空间复杂度 $O(n+m)$。

#include<cstdio>
#include<vector>
#include<utility>
int const B=430;
int const maxb=B+10,maxk=600000/B+10,maxn=100010; 
int n,m,top[maxn],sz[maxn],book[maxn],bot[maxn],st[maxn],tp,cct[maxk],ctu1[maxn],ctu0[maxn],dfk,dfn[maxn],dep[maxn],star[maxn],ed[maxn];
std::vector<int> v[maxn];
int p0[maxn],d0[maxn],p1[maxn],d1[maxn],fa[maxn],w[maxk],dis0[maxn][2],dis1[maxn][2],cta[maxk][maxb],ctb[maxk][maxb];
long long ans[maxn],dab[maxk][maxb],daa[maxk][maxb],dba[maxk][maxb],dbb[maxk][maxb],dsu1[maxn],dsu0[maxn],psu0[maxn],psu1[maxn]; 
namespace bloc{
	int trans[maxn],pcnt;
	int pool[maxk],*pl=pool,f[maxk],son[maxk],*ve[maxk],ecnt;
	std::pair<int,int> pr[maxk];
	void alloc(int x){trans[x]=++pcnt;}
	void add(int x,int y){
		x=trans[x],y=trans[y];
		pr[++ecnt]=std::make_pair(x,y);
		++son[x];
		f[y]=x;
	}
	int ct1[maxk],cnt1[maxk];
	long long dis1[maxk],ds1[maxk],ps1[maxk];
	int ct0[maxk],cnt0[maxk];
	long long dis0[maxk],ds0[maxk],ps0[maxk];
	void init(int p,int d,int dis0,int dis1,int ctu,long long dsu,long long psu,int *ct,long long *ds,long long *ps){
		static int bk[maxk],que[maxk],dis[maxk];
		for(int i=1;i<=pcnt;i++) dis[i]=0x7fffffff,bk[i]=ct[i]=ds[i]=ps[i]=0;
		int *hd=que,*tl=que;
		if(dis0<=d) dis[trans[top[p]]]=dis0,*tl++=trans[top[p]];
		if(dis1<=d) dis[trans[bot[p]]]=dis1,*tl++=trans[bot[p]];
		bk[trans[bot[p]]]=1;
		if(p==1) ct[trans[1]]=1,ds[trans[1]]=ps[trans[1]]=0;
		else{
			int u=trans[bot[p]];
			ct[u]=ctu,ds[u]=dsu,ps[u]=psu;
		}
		while(hd!=tl){
			int x=*hd++;
			for(int i=0;i<son[x];i++)if(!bk[ve[x][i]]){
				int u=ve[x][i],dk=std::min(d-dis[x],cct[u]);
				ct[u]=cta[u][dk],ds[u]=dab[u][dk],ps[u]=daa[u][dk],bk[u]=1;
			}
			if(!bk[x]){
				int dk=std::min(d-dis[x],cct[x]);
				ct[x]=ctb[x][dk],ds[x]=dbb[x][dk],ps[x]=dba[x][dk],bk[x]=1;
			}
			for(int i=0;i<son[x];i++)if(dis[ve[x][i]]==0x7fffffff&&dis[x]+w[ve[x][i]]<=d) dis[ve[x][i]]=dis[x]+w[ve[x][i]],*tl++=ve[x][i];
			if(dis[f[x]]==0x7fffffff&&dis[x]+w[x]<=d) dis[f[x]]=dis[x]+w[x],*tl++=f[x]; 
		}
	}
	void solve(){
		for(int i=1;i<=pcnt;i++)ve[i]=pl,pl+=son[i],son[i]=0;
		for(int i=1;i<=ecnt;i++)ve[pr[i].first][son[pr[i].first]++]=pr[i].second; 
		cct[trans[1]]=1;
		cta[trans[1]][0]=ctb[trans[1]][0]=cta[trans[1]][1]=ctb[trans[1]][1]=1;
		for(int i=1;i<=m;i++){
			init(p0[i],d0[i],::dis0[i][0],::dis0[i][1],ctu0[i],dsu0[i],psu0[i],ct0,ds0,ps0);
			init(p1[i],d1[i],::dis1[i][0],::dis1[i][1],ctu1[i],dsu1[i],psu1[i],ct1,ds1,ps1);
			for(int i=1;i<=pcnt;i++) dis0[i]=dis1[i]=cnt0[i]=cnt1[i]=0;
			for(int x=1;x<=pcnt;x++){
				ans[i]+=ct0[x]*dis1[x]+ct1[x]*dis0[x]+cnt0[x]*ds1[x]+cnt1[x]*ds0[x];
				dis0[x]+=1ll*cnt0[x]*w[x]+ps0[x],dis1[x]+=1ll*cnt1[x]*w[x]+ps1[x];
				cnt0[x]+=ct0[x],cnt1[x]+=ct1[x];
				if(f[x]){
					ans[i]+=cnt1[f[x]]*dis0[x]+cnt0[f[x]]*dis1[x]+cnt1[x]*dis0[f[x]]+cnt0[x]*dis1[f[x]];
					cnt0[f[x]]+=cnt0[x],cnt1[f[x]]+=cnt1[x];
					dis0[f[x]]+=dis0[x],dis1[f[x]]+=dis1[x]; 
				}
			}
		}
	}
}
void dfs(int x){
	dfn[++dfk]=x;
	star[x]=dfk;
	int cnt=0;
	for(auto u:v[x]){
		dep[u]=dep[x]+1;
		dfs(u);
		sz[x]+=sz[u];
		if(book[u])++cnt,book[x]=book[u];
	}
	if(sz[x]>=B||x==1||cnt>1){
		static int son[maxn];
		int vv=0;
		cnt=0;
		int p=0;
		for(int i=v[x].size()-1;~i;i--){
			p+=sz[v[x][i]];
			if(p>B||(cnt&&book[v[x][i]])){
				p-=sz[v[x][i]];
				if(!cnt) bloc::alloc(cnt=v[x][i+1]);
				son[vv++]=cnt;
				while(p--){
					int u=st[tp--];
					top[u]=x,bot[u]=cnt;
				}
				p=sz[v[x][i]];
				cnt=0;
			}
			cnt|=book[v[x][i]];
		}
		if(!cnt) bloc::alloc(cnt=v[x][0]);
		son[vv++]=cnt;
		while(p--){
			int u=st[tp--];
			top[u]=x,bot[u]=cnt;
		}
		bloc::alloc(x);
		for(int i=0;i<vv;i++) bloc::add(x,son[i]);
		book[x]=x;
		sz[x]=0;
	}
	st[++tp]=x;
	++sz[x];
	ed[x]=dfk+1;
}
namespace get{
	int ff[maxb],dp[maxb],bg[maxb],trans[maxn],ecnt,cnt,son[maxb],pool[1010],*pl,*ve[maxb],a,b,dis1[maxn],dis2[maxn];
	std::pair<int,int> pr[maxb];
	void add(int x,int y){
		pr[++ecnt]=std::make_pair(x,y);
		++son[x],++son[y];
		ff[y]=x;
	}
	void clear(int x){
		a=x;
		b=ecnt=0;
		trans[x]=cnt=1;
		pl=pool;
		for(int i=1;i<=B+1;i++) son[i]=dp[i]=0;
		for(int i=1;i<=n;i++) dis1[i]=dis2[i]=0x7fffffff;
		dp[trans[x]]=1;
	}
	void insert(int x){
		b=bot[x];
		add(trans[a],trans[x]=++cnt);
		dp[trans[x]]=2,bg[trans[x]]=trans[a];
		static int que[maxb];
		int *hd=que,*tl=que;
		*tl++=x;
		while(hd!=tl){
			int x=*hd++;
			for(auto u:v[x])
				if(top[u]==a) bg[trans[u]=++cnt]=trans[x],dp[trans[u]]=dp[trans[x]]+1,*tl++=u,add(trans[x],trans[u]);
				else break;
		}
	}
	void bfs(int x,int *ans){
		for(int i=star[x];i<ed[x];i++) ans[dfn[i]]=dep[dfn[i]]-dep[x];
		int l=star[x],r=ed[x];
		for(int i=fa[x];i;i=fa[i]){
			for(int j=star[i];j<l;j++) ans[dfn[j]]=dep[dfn[j]]+dep[x]-2*dep[i]; 
			for(int j=r;j<ed[i];j++) ans[dfn[j]]=dep[dfn[j]]+dep[x]-2*dep[i];
			l=star[i],r=ed[i]; 
		}
	}
	int bfsn[maxb][maxb],cn[maxb][maxb];
	void getbfs(int x,int *bfsx,int *cx){
		static int que[maxb];
		int *hd=que,*tl=que;
		*tl++=x;
		for(int i=1;i<=cnt;i++) cx[i]=-1;
		cx[x]=0;
		int pp=0;
		while(hd!=tl){
			int x=*hd++;
			bfsx[++pp]=x;
			for(int i=0;i<son[x];i++)if(!~cx[ve[x][i]]) cx[ve[x][i]]=cx[x]+1,*tl++=ve[x][i];
		}
	}
	long long mark[maxb]; 
	void getans(int x,int y,long long s[][maxb]){
		for(int i=0;i<cnt;i++)
			for(int j=0;j<cnt;j++)s[i][j]=0;
		x=trans[x],y=trans[y];
		int *bfsx=bfsn[x],*bfsy=bfsn[y],*cx=cn[x],*cy=cn[y];
		for(int i=0;i<cnt;i++){
			for(int j=1;j<=cnt;j++) mark[j]=0;
			for(int j=1;j<=cnt;j++) if(cx[bfsx[j]]<=i){if(bfsx[j]!=trans[a]) mark[bfsx[j]]=1;}else break;
			long long d=0;
			for(int j=1;j<=cnt;j++) if(mark[j]) d+=dp[j];
			for(int a=cnt;a;a--) mark[bg[a]]+=mark[a];
			int p=mark[1];
			for(int a=2;a<=cnt;a++) mark[a]+=mark[bg[a]];
			for(int j=1;j<=cnt;j++) if(bfsy[j]!=trans[a]) s[i][cy[bfsy[j]]]+=1ll*p*dp[bfsy[j]]+d-2*mark[bfsy[j]];
			for(int j=1;j<cnt;j++) s[i][j]+=s[i][j-1];
		}
	}
	void init(int x,int y,long long *aa,long long *ab,int *ct){
		x=trans[x],y=trans[y];
		int *bfsx=bfsn[x],*cx=cn[x];
		for(int i=0;i<cnt;i++){
			for(int j=1;j<=cnt;j++) mark[j]=0;
			for(int j=1;j<=cnt;j++) if(cx[bfsx[j]]<=i){if(bfsx[j]!=trans[a]) mark[bfsx[j]]=1;}else break;
			long long d=0;
			for(int j=1;j<=cnt;j++) if(mark[j]) d+=dp[j],++ct[i];
			for(int a=cnt;a;a--) mark[bg[a]]+=mark[a];
			int p=mark[1];
			for(int a=2;a<=cnt;a++) mark[a]+=mark[bg[a]];
			aa[i]=1ll*p*dp[x]+d-2*mark[x];
			ab[i]=1ll*p*dp[y]+d-2*mark[y];
		}
	}
	void gets(int pp,int dpp,int &ctu,long long &dsu,long long &psu){
		for(int j=1;j<=cnt;j++) mark[j]=0;
		int *bfsx=bfsn[pp],*cx=cn[pp];
		for(int i=1;i<=cnt;i++) if(cx[bfsx[i]]<=dpp){if(bfsx[i]!=trans[a]) mark[bfsx[i]]=1;}else break;
		long long d=0;
		for(int j=1;j<=cnt;j++) if(mark[j]) d+=dp[j],++ctu;
		for(int a=cnt;a;a--) mark[bg[a]]+=mark[a];
		int p=mark[1];
		for(int a=2;a<=cnt;a++) mark[a]+=mark[bg[a]];
		dsu=1ll*p*dp[trans[b]]+d-2*mark[trans[b]];
		psu=1ll*p*dp[trans[a]]+d-2*mark[trans[a]];
	}
	long long f[maxb][maxb],g[maxb][maxb],h[maxb][maxb];
	void solve(){
		for(int i=1;i<=cnt;i++) ve[i]=pl,pl+=son[i],son[i]=0;
		for(int i=1;i<=ecnt;i++) ve[pr[i].first][son[pr[i].first]++]=pr[i].second,ve[pr[i].second][son[pr[i].second]++]=pr[i].first;
		for(int i=1;i<=cnt;i++) getbfs(i,bfsn[i],cn[i]);
		cct[bloc::trans[b]]=cnt-1;
		bfs(a,dis1),bfs(b,dis2);
		getans(a,a,f),getans(a,b,g),getans(b,b,h);
		init(a,b,daa[bloc::trans[b]],dab[bloc::trans[b]],cta[bloc::trans[b]]);
		init(b,a,dbb[bloc::trans[b]],dba[bloc::trans[b]],ctb[bloc::trans[b]]); 
		w[bloc::trans[b]]=dp[trans[b]]-1;
		for(int i=1;i<=m;i++){
			int p0=::p0[i],d0=::d0[i],p1=::p1[i],d1=::d1[i];
			if(bot[p0]==b||bot[p1]==b){
				if(bot[p1]==b) ::dis1[i][0]=dp[trans[p1]]-1,::dis1[i][1]=cn[trans[b]][trans[p1]],gets(trans[p1],d1,ctu1[i],dsu1[i],psu1[i]);
				if(bot[p0]==b) ::dis0[i][0]=dp[trans[p0]]-1,::dis0[i][1]=cn[trans[b]][trans[p0]],gets(trans[p0],d0,ctu0[i],dsu0[i],psu0[i]),std::swap(p0,p1),std::swap(d0,d1);
				if(bot[p0]!=b){
					if(dis1[p0]<=d0||dis2[p0]<=d0){
						if(dis1[p0]<=dis2[p0]) d0=d0-dis1[p0],p0=a;
						else d0=d0-dis2[p0],p0=b;
					}else continue;
				}
				d0=std::min(d0,cnt-1),d1=std::min(d1,cnt-1);
				int x=trans[p0],y=trans[p1];
				for(int j=1;j<=cnt;j++) mark[j]=0;
				int *bfsx=bfsn[x],*bfsy=bfsn[y],*cx=cn[x],*cy=cn[y];
				for(int j=1;j<=cnt;j++) if(cx[bfsx[j]]<=d0){if(bfsx[j]!=trans[a]) mark[bfsx[j]]=1;}else break;
				long long d=0;
				for(int j=1;j<=cnt;j++) if(mark[j]) d+=dp[j];
				for(int a=cnt;a;a--) mark[bg[a]]+=mark[a];
				int p=mark[1];
				for(int a=2;a<=cnt;a++) mark[a]+=mark[bg[a]];
				for(int j=1;j<=cnt;j++) if(cy[bfsy[j]]<=d1){if(bfsy[j]!=trans[a])ans[i]+=1ll*p*dp[bfsy[j]]+d-2*mark[bfsy[j]];}else break;
			}else if((dis1[p0]<=d0||dis2[p0]<=d0)&&(dis1[p1]<=d1||dis2[p1]<=d1)){
				if(dis1[p0]>=dis2[p0]) std::swap(p0,p1),std::swap(d0,d1);
				if(dis1[p0]<=dis2[p0]&&dis1[p1]<=dis2[p1]){
					int di1=d0-dis1[p0],di2=d1-dis1[p1];
					di1=std::min(di1,cnt-1),di2=std::min(di2,cnt-1);
					ans[i]+=f[di1][di2];
				}else if(dis1[p0]<=dis2[p0]&&dis1[p1]>=dis2[p1]){
					int di1=d0-dis1[p0],di2=d1-dis2[p1];
					di1=std::min(di1,cnt-1),di2=std::min(di2,cnt-1);
					ans[i]+=g[di1][di2];
				}else{
					int di1=d0-dis2[p0],di2=d1-dis2[p1];
					di1=std::min(di1,cnt-1),di2=std::min(di2,cnt-1);
					ans[i]+=h[di1][di2];
				}
			}
		}
	}
}
int main(){
	#ifndef ONLINE_JUDGE 
	freopen("1.in","r",stdin);
	freopen("1.out","w",stdout);
	#endif
	scanf("%d",&n);
	for(int i=2,x;i<=n;i++) scanf("%d",&x),v[x].push_back(i),fa[i]=x;
	dfs(1);
	scanf("%d",&m);
	for(int i=1;i<=m;i++) scanf("%d%d%d%d",p0+i,d0+i,p1+i,d1+i); 
	for(int x=1;x<=n;x++)if(book[x]==x){
		for(int i=v[x].size()-1;~i;){
			int u=bot[v[x][i]];
			get::clear(x);
			while(~i&&bot[v[x][i]]==u) get::insert(v[x][i]),--i;
			get::solve();
		}
	}
	bot[1]=1;
	bloc::solve();
	for(int i=1;i<=m;i++) printf("%lld\n",ans[i]);
	return 0;
}