P6623 [省选联考 2020 A 卷] 树 gxy001

看到题目,首先思考怎么从子节点转移到父节点,假设在子节点有一个桶,存着异或起来得到子节点答案的信息,我们需要的是其内部每一个数加一后的异或值,然后就是经典套路 ,不讲了

简单说一下怎么维护,从低位往高位建 01trie,加一操作实际上就是交换左右子树,进入交换后的 0 子树递归,记得进位;开桶维护每一位上 1 的数量,用来求异或值。

从子节点转移到父节点时先给子节点加一,再 01trie 合并,方法和线段树合并完全一致。

时间复杂度 $O(n\log v)$,证明:每个节点只会进行一次插入操作,一次加一操作,一次查询操作,这部分是严格 $O(n\log v)$,被合并一次,合并的复杂度与线段树合并一致,也是 $O(n\log v)$。

#include<cstdio>
#include<algorithm>
#include<utility>
namespace IO{
	#define BUFSIZE 10000000
	struct read{
		char buf[BUFSIZE],*p1,*p2,c,f;
		read():p1(buf),p2(buf){}
		char gc(void){
			return (p1==p2)&&(p2=buf+fread(p1=buf,1,BUFSIZE,stdin),p1==p2)?EOF:*p1++;
		}
		read& operator >>(int& x){
			c=gc(),f=1,x=0;
			for(;~c&&(c<'0'||c>'9');c=gc())if(c=='-')f=-1;
			for(;c>='0'&&c<='9';c=gc())x=x*10+c-'0';
			x*=f;
			return *this;
		}
	};
	read in;
}
using IO::in;
int const maxn=525020;
struct trie{
	static int cnt,s[maxn*32][2],sz[maxn*32];
	int rt,a[24];
	void insert(int x){
		if(!rt) rt=++cnt;
		int p=rt,i=0;
		while(x){
			++sz[p];
			a[i++]+=x&1;
			if(!s[p][x&1]) s[p][x&1]=++cnt;
			p=s[p][x&1];
			x>>=1;
		}
		++sz[p];
	}
	void add(){
		int p=rt,i=0,k=0;
		while(p){
			std::swap(s[p][0],s[p][1]);
			a[i]=a[i]-sz[s[p][0]]+sz[s[p][1]];
			if((k=sz[p]-sz[s[p][0]])-sz[s[p][1]]){
				if(!s[p][1]) s[p][1]=++cnt;
				a[i]+=k-sz[s[p][1]];
				sz[s[p][1]]=k;
			}
			++i;
			p=s[p][0];
		}
	}
	int getxor(){
		int ans=0;
		for(int i=0;i<21;i++) ans|=(a[i]&1)<<i;
		return ans;
	}
	static void pmerge(int &x,int const &px,int const &i);
	void merge(const trie& x){
		for(int i=0;i<21;i++) a[i]+=x.a[i];
		return pmerge(rt,x.rt,0);
	}
}tr[maxn];
void trie::pmerge(int &x,int const &px,int const &i){
	if(!x) return x=px,void();
	if(!px) return;
	sz[x]+=sz[px];
	pmerge(s[x][0],s[px][0],i+1),pmerge(s[x][1],s[px][1],i+1);
}
int trie::cnt,trie::s[maxn*32][2],trie::sz[maxn*32];
int n,v[maxn],head[maxn],to[maxn],nxt[maxn];
long long ans;
inline void add(int const &x,int const &y){
	static int cnt=0;
	to[++cnt]=y,nxt[cnt]=head[x],head[x]=cnt;
}
void dfs(int const &x){
	tr[x].insert(v[x]);
	for(int i=head[x];i;i=nxt[i])
		dfs(to[i]),tr[to[i]].add(),tr[x].merge(tr[to[i]]);
	ans+=tr[x].getxor();
}
int main(){
	in>>n;
	for(int i=1;i<=n;i++)in>>v[i];
	for(int i=2,x;i<=n;i++)in>>x,add(x,i);
	dfs(1);
	printf("%lld\n",ans);
//	fprintf(stderr,"%d\n",trie::cnt);
	return 0;
}