如果这篇博客帮助到你,可以请我喝一杯咖啡~
CC BY-NC-SA 4.0 (除特别声明或转载文章外)
首先,我们可以在 $O(m)$ 的时间复杂度内求出每个点的 SG 值,直接按照 SG 函数的定义求解即可。
题目转化为给定一个数组 $a$,长度为 $n$,和一个初值为 $0$ 的变量 $v$,执行如下操作:
- 在 $[1,n+1]$ 中等概率随机一个整数 $p$。若 $p\le n$,进行操作 2;否则,进行操作 3。
- $v\larr v\oplus a_p$,进行操作 1。
- 游戏结束,若 $v=0$ 则输,否则赢。
求赢的概率。
设 $f_i$ 为一次随机过程使得 $v\larr v\oplus i$ 的概率,设其集合幂级数为 $F$,设游戏结束时 $v=i$ 的概率为 $g_i$,$G$ 为其集合幂级数,乘法为异或卷积,则有等式
\[\begin{aligned} G=&\frac{\sum\limits_{i=0}^{\infin}F^i}{n+1}\\ =&\frac{1}{(n+1)(1-F)} \end{aligned}\]这里我们要证明 $1-F$ 有逆,有逆的充要条件是 FWT 后每一项都不为零。
证明需要先了解 FWT 的一个性质, FWT 是线性变换且每一项系数都为 $\pm1$,且常数项对 FWT 后每一项的系数贡献都为 $1$。
因为 $F$ 的和为 $\frac n{n+1}$ 且每一项都为非负数,所以对 $F$ 进行 FWT 后每一项取值应该在 $-\frac n{n+1}$ 到 $\frac n{n+1}$ 之间,$-F$ 进行 FWT 后的结果为 $F$ 的结果取反,所以每一项取值也在 $-\frac n{n+1}$ 到 $\frac n{n+1}$ 之间,给原集合幂级数常数项加 $1$,就是给 FWT 后每一项都加 $1$,所以 $1-F$ 在 FWT 后每一项的取值应该在 $1-\frac n{n+1}$ 到 $1+\frac n{n+1}$ 之间,不包括零,所以 $1-F$ 有逆。
知道这点就可以直接做了,求逆的方式就是点值每一项都求逆,时间复杂度就是 FWT 的复杂度。
不过 SG 值的值域是 $\sqrt m$,所以复杂度是 $O(m+\sqrt m \log m)$。
#include<cstdio>
#include<algorithm>
#include<vector>
int const mod=998244353,inv2=(mod+1)/2;
std::vector<int> v[100010],p[100010];
int n,m,d[100010],a[100010],f[600],lim;
int pow(int x,int y){
int res=1;
while(y){
if(y&1) res=1ll*res*x%mod;
x=1ll*x*x%mod,y>>=1;
}
return res;
}
void fwt(int *a){
for(int mid=1;mid<lim;mid<<=1)
for(int r=mid<<1,j=0;j<lim;j+=r)
for(int k=0;k<mid;k++){
int x=a[k+j],y=a[mid+j+k];
a[k+j]=(x+y)%mod,a[mid+j+k]=(x-y+mod)%mod;
}
}
void ifwt(int *a){
for(int mid=1;mid<lim;mid<<=1)
for(int r=mid<<1,j=0;j<lim;j+=r)
for(int k=0;k<mid;k++){
int x=a[k+j],y=a[mid+j+k];
a[k+j]=1ll*(x+y)*inv2%mod,a[mid+j+k]=1ll*inv2*(x-y+mod)%mod;
}
}
int main(){
scanf("%d%d",&n,&m);
int lm=pow(n+1,mod-2);
for(int i=1,x,y;i<=m;i++)scanf("%d%d",&x,&y),v[y].push_back(x),++d[x];
static int que[100010];
int *hd=que,*tl=que;
for(int i=1;i<=n;i++) if(!d[i]) *tl++=i;
while(hd!=tl){
int x=*hd++;
std::sort(p[x].begin(),p[x].end());
p[x].erase(std::unique(p[x].begin(),p[x].end()),p[x].end());
for(a[x]=0;a[x]<(int)p[x].size()&&p[x][a[x]]==a[x];a[x]++);
lim=std::max(a[x],lim);
f[a[x]]=(f[a[x]]-lm+mod)%mod;
for(auto u:v[x]){
p[u].push_back(a[x]);
if(!--d[u]) *tl++=u;
}
}
if(lim==0) lim=2;
else do lim+=lim&-lim; while((lim&-lim)!=lim);
f[0]=(f[0]+1)%mod;
for(int i=0;i<lim;i++) f[i]=1ll*f[i]*(n+1)%mod;
fwt(f);
for(int i=0;i<lim;i++) f[i]=pow(f[i],mod-2);
ifwt(f);
printf("%d\n",(1-f[0]+mod)%mod);
return 0;
}