如果这篇博客帮助到你,可以请我喝一杯咖啡~
CC BY-NC-SA 4.0 (除特别声明或转载文章外)
建出 fail 树,称原字典树为 $T1$,fail 树为 $T2$,则所求为,对于每个点 $u$: \(\sum\limits_{x\in path_{T1}(rt,u)}\max\limits_{dep_{T1}(x)\ge a_v\land v\in subtree_{T2}(u)\land v\in subtree_{T1}(x)}val_v\) 注意到 $dep_{T1}(x)\ge a_v\land v\in subtree_{T1}(x)$ 等价于 $x\in path_{T1}(p_v,v)$($p_v$ 为 $v$ 的深度为 $a_v$ 的祖先)。
因此,所求可转化为: \(\sum\limits_{x\in path_{T1}(rt,u)}\max\limits_{x\in path_{T1}(p_v,v)\land v\in subtree_{T2}(u)}val_v\) 我们使用数据结构维护 $T1$,对 $subtree_{T2}(u)$ 内的每个点 $v$,对链 $path_{T1}(p_v,v)$ 上每个点与 $val_v$ 取 $\max$,然后再求和 $path_{T1}(rt,u)$,即为答案。
链取 $max$,链和,可以使用静态 top tree beats 完成。而我们要对 $u$ 子树内的每个点 $v$ 都执行一个只和 $v$ 有关的修改,不难想到可以使用静态 top tree 合并来维护。
另外,本题为二叉树,这使得每个点最多只有一个虚儿子,利用这点可以使维护更简单。
时间复杂度为 $O(n\log n)$。
#include<iostream>
#include<algorithm>
#include<vector>
#define BUFSIZE 10000000
struct read{
char buf[BUFSIZE],*p1,*p2,c;
read():p1(buf),p2(buf){}
char gc(void){
return p1==p2&&(p2=buf+fread(p1=buf,1,BUFSIZE,stdin),p1==p2)?EOF:*p1++;
}
read& operator >>(int& x){
for(c=gc(),x=0;c<'0'||c>'9';c=gc());
for(;c>='0'&&c<='9';c=gc())x=x*10+(c-'0');
return *this;
}
}cin;
struct write{
char buf[BUFSIZE],*p1,*p2;
write():p1(buf),p2(buf+BUFSIZE){}
~write(){fwrite(buf,1,p1-buf,stdout);}
void pc(char c){p1==p2&&(fwrite(buf,1,p1-buf,stdout),p1=buf),*p1++=c;}
write& operator <<(long long x){
static int stk[30],tp;
do stk[tp++]='0'+x%10,x/=10; while(x);
while(tp) pc(stk[--tp]);
return *this;
}
write& operator <<(char c){return pc(c),*this;}
}cout;
int n,s[500010][2],fail[500010];
int val[500010],a[500010],p[500010];
namespace TT{
int rt,cnt,sz[1000010],csz[1000010],A[1000010],B[1000010],ls[500010],rs[500010],tp[500010];
int s[500010],fa[500010],son[500010];
int newnode(int x){
int u=x+n;
csz[u]=sz[u]=1,A[u]=B[u]=x;
return u;
}
int bd(int x,int y,int p){
int u=++cnt;
ls[u]=x,rs[u]=y,sz[u]=sz[x]+sz[y],tp[u]=p;
if(p==0) csz[u]=csz[x]+csz[y],A[u]=A[x],B[u]=B[y];
else csz[u]=csz[x],A[u]=A[x],B[u]=B[x];
return u;
}
int c[500010],cs[500010],ccnt;
int cbuild(int l,int r,int p){
if(l==r) return c[l];
int u=std::lower_bound(cs+l,cs+r,cs[l-1]+(cs[r]-cs[l-1])/2)-cs;
if(u==r) return bd(cbuild(l,r-1,p),c[r],p);
else return bd(cbuild(l,u,p),cbuild(u+1,r,p),p);
}
int build(int x){
int u=x;
static int tmp[500010];
while(u){
for(auto w: ::s[u])if(w&&w!=son[u]) tmp[w]=build(w);
ccnt=0,c[++ccnt]=newnode(u),cs[ccnt]=1;
for(auto w: ::s[u])if(w&&w!=son[u]) c[++ccnt]=tmp[w],cs[ccnt]=cs[ccnt-1]+sz[tmp[w]];
tmp[u]=cbuild(1,ccnt,1),u=son[u];
}
u=x,ccnt=0;
while(u) c[++ccnt]=tmp[u],cs[ccnt]=cs[ccnt-1]+sz[tmp[u]],u=son[u];
return cbuild(1,ccnt,0);
}
int R[1000010],ct;
void ini(int x){
if(x>n) return A[x]=B[x]=R[x]=++ct,void();
ini(ls[x]),ini(rs[x]);
A[x]=A[A[x]+n],B[x]=B[B[x]+n];
R[x]=ct;
}
void dfs(int x){
static int pt,f[500010];
f[pt++]=x,s[x]=1,son[x]=0;
if(a[x]>=pt) p[x]=0;
else p[x]=f[a[x]];
for(int u: ::s[x]) if(u) fa[u]=x,dfs(u),s[x]+=s[u],son[x]=(s[son[x]]>s[u]?son[x]:u);
--pt;
}
void init(){
cnt=0;
dfs(1);
rt=build(1);
ini(rt);
}
}
std::vector<int> vec[500010];
int sz[500010];
void dfs(int x){
sz[x]=1;
for(auto u:vec[x]) dfs(u),sz[x]+=sz[u];
}
long long ans[500010];
#define MAXN 2100010
int cct,ls[MAXN],rs[MAXN],mn[MAXN],lmn[MAXN],cnt[MAXN],tag[MAXN];
long long sum[MAXN],stk[MAXN],tp;
inline int newnode(int pos){
using TT::csz;
int u=tp?stk[--tp]:++cct;
ls[u]=rs[u]=0,tag[u]=sum[u]=mn[u]=0,cnt[u]=csz[pos],lmn[u]=2e9;
return u;
}
inline void addtag(int x,int v){
sum[x]+=1ll*(v-mn[x])*cnt[x];
mn[x]=tag[x]=v;
}
inline void pushdown(int x,int pos){
if(!tag[x]) return;
if(!ls[x]) ls[x]=newnode(TT::ls[pos]);
if(TT::tp[pos]==0){
if(!rs[x]) rs[x]=newnode(TT::rs[pos]);
if(mn[ls[x]]==mn[rs[x]]) addtag(ls[x],tag[x]),addtag(rs[x],tag[x]);
else if(mn[ls[x]]<mn[rs[x]]) addtag(ls[x],tag[x]);
else addtag(rs[x],tag[x]);
}else addtag(ls[x],tag[x]);
tag[x]=0;
}
inline void pushup(int x,int pos){
using TT::csz;
if(TT::tp[pos]==0){
if(!ls[x]&&!rs[x]) sum[x]=mn[x]=0,cnt[x]=csz[pos],lmn[x]=2e9;
else if(!ls[x]) sum[x]=sum[rs[x]],mn[x]=0,lmn[x]=(mn[rs[x]]==0?lmn[rs[x]]:mn[rs[x]]),cnt[x]=csz[TT::ls[pos]]+(mn[rs[x]]==0?cnt[rs[x]]:0);
else if(!rs[x]) sum[x]=sum[ls[x]],mn[x]=0,lmn[x]=(mn[ls[x]]==0?lmn[ls[x]]:mn[ls[x]]),cnt[x]=csz[TT::rs[pos]]+(mn[ls[x]]==0?cnt[ls[x]]:0);
else{
if(mn[ls[x]]==mn[rs[x]]) sum[x]=sum[ls[x]]+sum[rs[x]],mn[x]=mn[ls[x]],lmn[x]=std::min(lmn[ls[x]],lmn[rs[x]]),cnt[x]=cnt[ls[x]]+cnt[rs[x]];
else if(mn[ls[x]]<mn[rs[x]]) sum[x]=sum[ls[x]]+sum[rs[x]],mn[x]=mn[ls[x]],lmn[x]=std::min(lmn[ls[x]],mn[rs[x]]),cnt[x]=cnt[ls[x]];
else sum[x]=sum[ls[x]]+sum[rs[x]],mn[x]=mn[rs[x]],lmn[x]=std::min(mn[ls[x]],lmn[rs[x]]),cnt[x]=cnt[rs[x]];
}
}else
if(!ls[x]) sum[x]=mn[x]=0,cnt[x]=csz[pos],lmn[x]=2e9;
else sum[x]=sum[ls[x]],mn[x]=mn[ls[x]],cnt[x]=cnt[ls[x]],lmn[x]=lmn[ls[x]];
}
inline void beats(int &rt,int v,int pos){
if(!rt) rt=newnode(pos);
if(v<=mn[rt]) return;
if(v<lmn[rt]) return addtag(rt,v);
tag[rt]=0;
beats(ls[rt],v,TT::ls[pos]),beats(rs[rt],v,TT::rs[pos]);
pushup(rt,pos);
}
inline void lupdate(int b,int v,int &rt,int pos=TT::rt){
using TT::B,TT::R;
if(!rt) rt=newnode(pos);
if(B[pos]==b) return beats(rt,v,pos);
pushdown(rt,pos);
if(b<=R[TT::ls[pos]]) lupdate(b,v,ls[rt],TT::ls[pos]);
else beats(ls[rt],v,TT::ls[pos]),lupdate(b,v,rs[rt],TT::rs[pos]);
pushup(rt,pos);
}
inline void rupdate(int a,int v,int &rt,int pos=TT::rt){
using TT::A,TT::R;
if(!rt) rt=newnode(pos);
if(A[pos]==a) return beats(rt,v,pos);
pushdown(rt,pos);
if(a>R[TT::ls[pos]]) rupdate(a,v,rs[rt],TT::rs[pos]);
else rupdate(a,v,ls[rt],TT::ls[pos]),beats(rs[rt],v,TT::rs[pos]);
pushup(rt,pos);
}
inline void update(int a,int b,int v,int &rt,int pos=TT::rt){
using TT::A,TT::B,TT::R;
if(!rt) rt=newnode(pos);
if(A[pos]==a&&B[pos]==b) return beats(rt,v,pos);
pushdown(rt,pos);
if(b<=R[TT::ls[pos]]) update(a,b,v,ls[rt],TT::ls[pos]);
else if(a>R[TT::ls[pos]]) update(a,b,v,rs[rt],TT::rs[pos]);
else rupdate(a,v,ls[rt],TT::ls[pos]),lupdate(b,v,rs[rt],TT::rs[pos]);
pushup(rt,pos);
}
inline long long query(int b,int rt,int pos=TT::rt){
using TT::B,TT::R;
if(!rt) return 0;
if(B[pos]==b)return sum[rt];
pushdown(rt,pos);
if(b<=R[TT::ls[pos]]) return query(b,ls[rt],TT::ls[pos]);
else return sum[ls[rt]]+query(b,rs[rt],TT::rs[pos]);
}
void merge(int &x,int y,int pos=TT::rt){
using TT::A,TT::B;
if(!x||!y) return x|=y,void();
if(pos>n) return sum[x]=mn[x]=std::max(mn[x],mn[y]),stk[tp++]=y,void();
int u=std::max(tag[x],tag[y]);
tag[x]=0;
merge(ls[x],ls[y],TT::ls[pos]);
merge(rs[x],rs[y],TT::rs[pos]);
pushup(x,pos);
if(u) beats(x,u,pos);
stk[tp++]=y;
}
int solve(int x){
int rt=0;
std::sort(vec[x].begin(),vec[x].end(),[&](int x,int y){return sz[x]>sz[y];});
for(auto u:vec[x]) merge(rt,solve(u));
if(p[x]==1) lupdate(TT::A[x+n],val[x],rt);
else if(p[x]) update(TT::A[p[x]+n],TT::A[x+n],val[x],rt);
ans[x]=query(TT::A[x+n],rt);
return rt;
}
#include<chrono>
using namespace std::chrono_literals;
void solve(){
auto u=std::chrono::steady_clock().now();
cin>>n;
for(int i=1;i<=n;i++) s[i][0]=s[i][1]=fail[i]=0,vec[i].clear(),vec[i].shrink_to_fit();
for(int i=2,x,v;i<=n;i++) cin>>x>>v,s[x][v]=i;
for(int i=1;i<=n;i++) cin>>val[i];
for(int i=1;i<=n;i++) cin>>a[i];
TT::init();
static int que[500010];
int *hd=que,*tl=que;
for(int i=0;i<2;i++) if(s[1][i]) *tl++=s[1][i],fail[s[1][i]]=1;else s[1][i]=1;
while(hd!=tl){
int p=*hd++;
for(int i=0;i<2;i++)
if(s[p][i]) fail[s[p][i]]=s[fail[p]][i],*tl++=s[p][i];
else s[p][i]=s[fail[p]][i];
}
for(int i=2;i<=n;i++) vec[fail[i]].push_back(i);
dfs(1);
cct=tp=0,lmn[0]=2e9;
solve(1);
for(int i=1;i<=n;i++) cout<<ans[i]<<' ';
cout<<'\n';
std::cerr<<cct<<' '<<(std::chrono::steady_clock().now()-u)/1.0s<<'\n';
}
int main(){
int T;
cin>>T;
while(T--) solve();
return 0;
}