如果这篇博客帮助到你,可以请我喝一杯咖啡~
CC BY-NC-SA 4.0 (除特别声明或转载文章外)
看到题目,首先思考怎么从子节点转移到父节点,假设在子节点有一个桶,存着异或起来得到子节点答案的信息,我们需要的是其内部每一个数加一后的异或值,然后就是经典套路 ,不讲了。
简单说一下怎么维护,从低位往高位建 01trie,加一操作实际上就是交换左右子树,进入交换后的 0 子树递归,记得进位;开桶维护每一位上 1 的数量,用来求异或值。
从子节点转移到父节点时先给子节点加一,再 01trie 合并,方法和线段树合并完全一致。
时间复杂度 $O(n\log v)$,证明:每个节点只会进行一次插入操作,一次加一操作,一次查询操作,这部分是严格 $O(n\log v)$,被合并一次,合并的复杂度与线段树合并一致,也是 $O(n\log v)$。
#include<cstdio>
#include<algorithm>
#include<utility>
namespace IO{
#define BUFSIZE 10000000
struct read{
char buf[BUFSIZE],*p1,*p2,c,f;
read():p1(buf),p2(buf){}
char gc(void){
return (p1==p2)&&(p2=buf+fread(p1=buf,1,BUFSIZE,stdin),p1==p2)?EOF:*p1++;
}
read& operator >>(int& x){
c=gc(),f=1,x=0;
for(;~c&&(c<'0'||c>'9');c=gc())if(c=='-')f=-1;
for(;c>='0'&&c<='9';c=gc())x=x*10+c-'0';
x*=f;
return *this;
}
};
read in;
}
using IO::in;
int const maxn=525020;
struct trie{
static int cnt,s[maxn*32][2],sz[maxn*32];
int rt,a[24];
void insert(int x){
if(!rt) rt=++cnt;
int p=rt,i=0;
while(x){
++sz[p];
a[i++]+=x&1;
if(!s[p][x&1]) s[p][x&1]=++cnt;
p=s[p][x&1];
x>>=1;
}
++sz[p];
}
void add(){
int p=rt,i=0,k=0;
while(p){
std::swap(s[p][0],s[p][1]);
a[i]=a[i]-sz[s[p][0]]+sz[s[p][1]];
if((k=sz[p]-sz[s[p][0]])-sz[s[p][1]]){
if(!s[p][1]) s[p][1]=++cnt;
a[i]+=k-sz[s[p][1]];
sz[s[p][1]]=k;
}
++i;
p=s[p][0];
}
}
int getxor(){
int ans=0;
for(int i=0;i<21;i++) ans|=(a[i]&1)<<i;
return ans;
}
static void pmerge(int &x,int const &px,int const &i);
void merge(const trie& x){
for(int i=0;i<21;i++) a[i]+=x.a[i];
return pmerge(rt,x.rt,0);
}
}tr[maxn];
void trie::pmerge(int &x,int const &px,int const &i){
if(!x) return x=px,void();
if(!px) return;
sz[x]+=sz[px];
pmerge(s[x][0],s[px][0],i+1),pmerge(s[x][1],s[px][1],i+1);
}
int trie::cnt,trie::s[maxn*32][2],trie::sz[maxn*32];
int n,v[maxn],head[maxn],to[maxn],nxt[maxn];
long long ans;
inline void add(int const &x,int const &y){
static int cnt=0;
to[++cnt]=y,nxt[cnt]=head[x],head[x]=cnt;
}
void dfs(int const &x){
tr[x].insert(v[x]);
for(int i=head[x];i;i=nxt[i])
dfs(to[i]),tr[to[i]].add(),tr[x].merge(tr[to[i]]);
ans+=tr[x].getxor();
}
int main(){
in>>n;
for(int i=1;i<=n;i++)in>>v[i];
for(int i=2,x;i<=n;i++)in>>x,add(x,i);
dfs(1);
printf("%lld\n",ans);
// fprintf(stderr,"%d\n",trie::cnt);
return 0;
}