P6624 [省选联考 2020 A 卷] 作业题 gxy001

根据套路,先去掉烦人的 $\gcd$。 \(\begin{aligned} val(T)&=\left(\sum\limits_{i=1}^{n-1} w_{e_i}\right) \times \gcd(w_{e_1},w_{e_2},\dots,w_{e_{n-1}})\\&=\left(\sum\limits_{i=1}^{n-1} w_{e_i}\right) \times \sum\limits_{d\mid w_{e_1}\land\dots\land d\mid w_{e_{n-1}}}\varphi(d)\\&=\left(\sum\limits_{i=1}^{n-1} w_{e_i}\right)\sum\limits_{d=1}^{max_v}\varphi(d)[d\mid w_{e_1}\land\dots\land d\mid w_{e_{n-1}}] \end{aligned}\)

我们可以对于每一个 $d$,将边权为 $d$ 的倍数的边加入图中,然后只需要求所有生成树的边权和即可 ,然后就不会了

我们可以利用矩阵树定理求得所有边权和。

我们知道,矩阵树定理求的是生成树上所有边乘积的和,我们可以将边权为 $w$ 的边的边权变为 $wx+1$,最后求得的结果的一次项系数就是边权和。

时间复杂度 $O(n^3max_v)$,不知道能不能过,我们发现实际上并不需要对每种 $d$ 都跑一遍矩阵树,只对边数大于等于 $n-1$ 的 $d$ 跑即可,然后就能过了。

更多细节都在代码里了。

#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
#include<utility>
#define int long long
long long const mod=998244353;
long long pow(long long x,long long y){
	long long res=1;
	while(y){
		if(y&1) res=res*x%mod;
		x=x*x%mod; 
		y>>=1;
	}
	return res;
}
struct node{
	long long a,b;
	node(long long const &_a,long long const &_b):a(_a),b(_b){}
	node():a(),b(){}
	friend node operator +(node const &x,node const &y){
		return node((x.a+y.a)%mod,(x.b+y.b)%mod);
	}
	friend node operator -(node const &x,node const &y){
		return node((x.a-y.a+mod)%mod,(x.b-y.b+mod)%mod);
	}
	friend node operator *(node const &x,node const &y){
		return node((x.a*y.b%mod+y.a*x.b%mod)%mod,x.b*y.b%mod);
	}
	friend node operator /(node const &x,node const &y){
		long long inv=pow(y.b,mod-2);
		return node((x.a*y.b%mod-x.b*y.a%mod+mod)*inv%mod*inv%mod,x.b*inv%mod);
	}
	node& operator +=(node const &x){
		*this=*this+x;
		return *this;
	}
	node& operator -=(node const &x){
		*this=*this-x;
		return *this;
	}
	node& operator *=(node const &x){
		*this=*this*x;
		return *this;
	}
	node& operator /=(node const &x){
		*this=*this/x;
		return *this;
	}
}f[50][50];
int n,m,u[1010],v[1010],w[1010],iinv[152502];
std::vector<int> p[152502];
long long ans;
long long phi(long long x){
	long long res=x;
	for(int i=2;i*i<=x;i++)
		if(x%i==0){
			res=res*(i-1)%mod;
			if(!iinv[i]) iinv[i]=pow(i,mod-2);
			res=res*iinv[i]%mod;
			while(x%i==0) x/=i;
		} 
	if(x!=1){
		res=res*(x-1)%mod;
		if(!iinv[x]) iinv[x]=pow(x,mod-2);
		res=res*iinv[x]%mod;
	}
	return res;
}
signed main(){
	scanf("%lld%lld",&n,&m);
	for(int i=1;i<=m;i++){
		scanf("%lld%lld%lld",u+i,v+i,w+i);
		for(int j=1;j*j<=w[i];j++)
			if(w[i]%j==0){
				p[j].push_back(i);
				if(j*j!=w[i]) p[w[i]/j].push_back(i);
			}
	}
	for(int i=1;i<=152501;i++)
		if((int)(p[i].size())>=n-1){
			memset(f,0,sizeof(f));
			for(std::vector<int>::iterator it=p[i].begin();it!=p[i].end();++it){
				int u=::u[*it],v=::v[*it],w=::w[*it];
				f[u][v]-=node(w,1);
				f[v][u]-=node(w,1);
				f[u][u]+=node(w,1);
				f[v][v]+=node(w,1);
			}
			node ans(0,1);
			int rev=0;
			for(int i=1;i<n;i++){
				if(!f[i][i].b)
					for(int j=i+1;j<n;++j)
						if(f[j][i].b){
							rev^=1;
							std::swap(f[i],f[j]);
							break;
						}
				node inv=node(0,1)/f[i][i];
				ans=ans*f[i][i];
				for(int j=i;j<n;j++) f[i][j]=f[i][j]*inv;
				for(int j=1;j<n;j++)
					if(j!=i){
						node tmp=f[j][i];
						for(int k=i;k<n;k++)
							f[j][k]=f[j][k]-f[i][k]*tmp;
					}
			}
			if(rev) ans=node(0,0)-ans;
			::ans=(::ans+ans.a*phi(i)%mod)%mod;
		}
	printf("%lld\n",ans);
	return 0;
}