如果这篇博客帮助到你,可以请我喝一杯咖啡~
CC BY-NC-SA 4.0 (除特别声明或转载文章外)
做完这道题,感觉自己对计数的掌握又深了一些呢。
$\text{I}$:球互不相同,盒子互不相同。
显然,对于每个球都有 $m$ 种选择,根据乘法原理,答案即 $m^n$。
$\text{II}$:球互不相同,盒子互不相同,每个盒子至多装一个球。
对于每个球,可以从未被选择的盒子中任选一个,即 $\prod\limits_{i=0}^{n-1}(m-i)$。
$\text{III}$:球互不相同,盒子互不相同,每个盒子至少装一个球。
容斥一下枚举几个空盒,$\sum\limits_{i=0}^m(-1)^i\binom{m}{i}(m-i)^n$。
$\text{IV}$:球互不相同,盒子全部相同。
枚举几个盒子有球,\(\sum\limits_{i=0}^m \begin{Bmatrix}n\\i\end{Bmatrix}\),第二类斯特林数·行。
$\text{V}$:球互不相同,盒子全部相同,每个盒子至多装一个球。
由于盒子没有区别,所以答案为 $[n\le m]$。
$\text{VI}$:球互不相同,盒子全部相同,每个盒子至少装一个球。
根据斯特林数定义,\(\begin{Bmatrix}n\\m\end{Bmatrix}\)。
$\text{VII}$:球全部相同,盒子互不相同。
插板法,$\binom{n+m-1}{m-1}$。
$\text{VIII}$:球全部相同,盒子互不相同,每个盒子至多装一个球。
选出 $n$ 个盒子装球,$\binom{m}{n}$
$\text{IX}$:球全部相同,盒子互不相同,每个盒子至少装一个球。
插板法,$\binom{n-1}{m-1}$。
$\text{X}$:球全部相同,盒子全部相同。
问题等价于将 $n+m$ 拆分为 $m$ 个无序的正整数,根据 Ferrers 图的理论,等价于将 $n+m$ 拆分成若干个不超过 $m$ 的正整数,直接生成函数做就行了。
\[[x^{n+m-m}]\prod\limits_{i=1}^m\frac{1}{1-x^i}\]不会求这个的可以去看付公主的背包。
$\text{XI}$:球全部相同,盒子全部相同,每个盒子至多装一个球。
和 $\text{V}$ 相同,盒子没有区别,$[n\le m]$。
$\text{XII}$:球全部相同,盒子全部相同,每个盒子至少装一个球。
同 $\text{X}$:
\[[x^{n-m}]\prod\limits_{i=1}^m\frac{1}{1-x^i}\]代码
#include<cstdio>
#include<algorithm>
int n,m,fac[400010],minv[400010];
int const mod=998244353,g=3,gi=(mod+1)/g;
int C(int x,int y){
if(x<0||y<0||x<y) return 0;
else return 1ll*fac[x]*minv[y]%mod*minv[x-y]%mod;
}
int pow(int x,int y){
int res=1;
while(y){
if(y&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
y>>=1;
}
return res;
}
struct NTT{
int r[800010],lim;
NTT():r(),lim(){}
void getr(int lm){
lim=lm;
for(int i=0;i<lim;i++)r[i]=(r[i>>1]>>1)|((i&1)*(lim>>1));
}
void operator()(int *a,int type){
for(int i=0;i<lim;i++)if(i<r[i])std::swap(a[i],a[r[i]]);
for(int mid=1;mid<lim;mid<<=1){
int rt=pow(type==1?g:gi,(mod-1)/(mid<<1));
for(int j=0,r=mid<<1;j<lim;j+=r){
int p=1;
for(int k=0;k<mid;k++,p=1ll*p*rt%mod){
int x=a[j+k],y=1ll*a[j+mid+k]*p%mod;
a[j+k]=(x+y)%mod,a[j+mid+k]=(x-y+mod)%mod;
}
}
}
if(type==-1)for(int i=0,p=pow(lim,mod-2);i<lim;i++)a[i]=1ll*a[i]*p%mod;
}
}ntt;
void inv(int const *a,int *ans,int n){
static int tmp[800010];
for(int i=0;i<n<<1;i++)tmp[i]=ans[i]=0;
ans[0]=pow(a[0],mod-2);
for(int m=2;m<=n;m<<=1){
int lim=m<<1;
ntt.getr(lim);
for(int i=0;i<m;i++)tmp[i]=a[i];
ntt(tmp,1),ntt(ans,1);
for(int i=0;i<lim;i++)ans[i]=ans[i]*(2-1ll*ans[i]*tmp[i]%mod+mod)%mod,tmp[i]=0;
ntt(ans,-1);
for(int i=m;i<lim;i++)ans[i]=0;
}
}
void inte(int const *a,int *ans,int n){
for(int i=n-1;i;i--)ans[i]=1ll*a[i-1]*pow(i,mod-2)%mod;
ans[0]=0;
}
void der(int const *a,int *ans,int n){
for(int i=1;i<n;i++)ans[i-1]=1ll*i*a[i]%mod;
ans[n-1]=0;
}
void ln(int const *a,int *ans,int n){
static int b[800010];
for(int i=0;i<n<<1;i++)ans[i]=b[i]=0;
inv(a,ans,n);
der(a,b,n);
int lim=n<<1;
ntt.getr(lim);
ntt(b,1),ntt(ans,1);
for(int i=0;i<lim;i++)b[i]=1ll*ans[i]*b[i]%mod,ans[i]=0;
ntt(b,-1);
for(int i=n;i<lim;i++)b[i]=0;
inte(b,ans,n);
}
void exp(int const *a,int *ans,int n){
static int f[800010];
for(int i=0;i<n<<1;i++)ans[i]=f[i]=0;
ans[0]=1;
for(int m=2;m<=n;m<<=1){
int lim=m<<1;
ln(ans,f,m);
f[0]=(a[0]+1-f[0]+mod)%mod;
for(int i=1;i<m;i++)f[i]=(a[i]-f[i]+mod)%mod;
ntt.getr(lim);
ntt(f,1),ntt(ans,1);
for(int i=0;i<lim;i++)ans[i]=1ll*ans[i]*f[i]%mod,f[i]=0;
ntt(ans,-1);
for(int i=m;i<lim;i++)ans[i]=0;
}
}
void solve1(){printf("%d\n",pow(m,n));}
void solve2(){if(m<n)puts("0");else printf("%lld\n",1ll*fac[m]*minv[m-n]%mod);}
void solve3(){
if(n<m)return puts("0"),void();
int ans=0;
for(int i=0;i<=m;i++)
ans=(ans+1ll*pow(mod-1,i)*C(m,i)%mod*pow(m-i,n))%mod;
printf("%d\n",ans);
}
int s[800010];
void solve4(){
static int tmp[800010];
for(int i=0;i<=n;i++)tmp[i]=(i&1?mod-1ll:1ll)*minv[i]%mod,s[i]=1ll*pow(i,n)*minv[i]%mod;
int lim=1;
for(lim=1;lim<=n+n;lim<<=1);
ntt.getr(lim);
ntt(tmp,1),ntt(s,1);
for(int i=0;i<lim;i++)s[i]=1ll*s[i]*tmp[i]%mod;
ntt(s,-1);
for(int i=n+1;i<lim;i++)s[i]=0;
int ans=0;
for(int i=0;i<=m;i++)ans=(ans+s[i])%mod;
printf("%d\n",ans);
}
void solve5(){printf("%d\n",int(m>=n));}
void solve6(){printf("%d\n",s[m]);}
void solve7(){printf("%d\n",C(n+m-1,m-1));}
void solve8(){printf("%d\n",C(m,n));}
void solve9(){printf("%d\n",C(n-1,m-1));}
int ans[800010];
void solve10(){
static int tmp[800010];
for(int i=1;i<=m;i++)
for(int j=1;j*i<=n;j++)
ans[i*j]=(ans[i*j]-1ll*minv[j]*fac[j-1]%mod+mod)%mod;
int lim=1;
for(;lim<=n;lim<<=1);
exp(ans,tmp,lim);
for(int i=0;i<lim;i++)ans[i]=0;
inv(tmp,ans,lim);
printf("%d\n",ans[n]);
}
void solve11(){printf("%d\n",int(m>=n));}
void solve12(){
printf("%d\n",n-m>=0?ans[n-m]:0);
}
int main(){
scanf("%d%d",&n,&m);
fac[0]=1;
for(int i=1;i<=n+m;i++)fac[i]=1ll*fac[i-1]*i%mod;
minv[n+m]=pow(fac[n+m],mod-2);
for(int i=n+m;i;i--)minv[i-1]=1ll*minv[i]*i%mod;
solve1();
solve2();
solve3();
solve4();
solve5();
solve6();
solve7();
solve8();
solve9();
solve10();
solve11();
solve12();
return 0;
}