CF479E Riding in a Lift gxy001

本文为 $O(n\log V)$ 的 wqs 二分+斜率优化做法。

下面这段仅用于凸性的证明,与做法无关

设 $w(l,r)$ 表示在 $\lfloor\frac{l+r}{2}\rfloor$ 处建立邮局,$l$ 到 $r$ 的村庄全部去往该邮局的距离之和,则有 dp 式子: \(f_{i,j}=\min_{0\le t<i}\{f_{t,j-1}+w(t+1,i)\}\) 由于 $w(l,r)$ 满足四边形不等式,则 $ans_m=f_{n,m}$ 是关于 $m$ 的凸函数,这一经典结论的证明参见此处

进行 wqs 二分,消去 $m$ 的限制,二分的斜率为 $K$。

关于 $K$ 的二分范围,这个凸包是单调减的下凸壳,考虑 $0\ge ansK=ans_{m+1}-ans_m\ge ans_{n}-ans_m=-ans_m$,由题目的保证知道 $ans_m\le 10^9$,所以二分范围取 $[-10^9,0]$ 即可。

设 $f_i$ 为已经考虑完了前 $i$ 个村庄的最小代价($s$ 为 $a$ 的前缀和),我们有: \(f_i=\min_{1\le j\le i}\{(s_i-s_j)-(i-j)a_j+\min_{0\le k<j}\{(j-k)a_j-(s_j-s_k)+f_k\}\}-K\) 令 $g_j=\min_{0\le k<j}{(j-k)a_j-(s_j-s_k)+f_k}$,$g$ 可以用斜率优化得到。

$f_i=\min_{1\le j\le i}{(s_i-s_j)-(i-j)a_j+g_j}-K$,$f$ 可以用斜率优化得到。

都是经典的斜率优化形式,具体做法就不加赘述了。

另外,提醒一下 wqs 二分时记得处理斜率相同的段。

#include<iostream>
using std::cin,std::cout;
int n,m;
long long a[500010],s[500010],f[500010],g[500010],cntf[500010],cntg[500010];
int check(long long K){
	auto x1=[](int x){return x;};
	auto y1=[](int x){return s[x]+f[x];};
	auto x2=[](int x){return a[x];};
	auto y2=[](int x){return -s[x]+x*a[x]+g[x];};
	f[0]=0,cntf[0]=0;
	static int q1[500010];
	int *hd1=q1,*tl1=q1;
	*tl1++=0;
	static int q2[500010];
	int *hd2=q2,*tl2=q2;
	for(int i=1;i<=n;i++){
		while(tl1-hd1>1){
			if(std::pair(y1(*hd1)-a[i]*x1(*hd1),cntf[*hd1])<std::pair(y1(hd1[1])-a[i]*x1(hd1[1]),cntf[hd1[1]])) break;
			else ++hd1;
		}
		cntg[i]=cntf[*hd1];
		g[i]=i*a[i]-s[i]+y1(*hd1)-a[i]*x1(*hd1);
		while(tl2-hd2>1){
			if(std::pair((y2(i)-y2(tl2[-1]))*(x2(tl2[-1])-x2(tl2[-2])),cntg[i])>std::pair((y2(tl2[-1])-y2(tl2[-2]))*(x2(i)-x2(tl2[-1])),cntg[tl2[-1]])) break;
			else --tl2;
		}
		*tl2++=i;
		while(tl2-hd2>1){
			if(std::pair(y2(*hd2)-i*x2(*hd2),cntg[*hd2])<std::pair(y2(hd2[1])-i*x2(hd2[1]),cntg[hd2[1]])) break;
			else ++hd2;
		}
		cntf[i]=cntg[*hd2]+1;
		f[i]=-K+s[i]+y2(*hd2)-i*x2(*hd2);
		while(tl1-hd1>1){
			if(std::pair((y1(i)-y1(tl1[-1]))*(x1(tl1[-1])-x1(tl1[-2])),cntf[i])>std::pair((y1(tl1[-1])-y1(tl1[-2]))*(x1(i)-x1(tl1[-1])),cntf[tl1[-1]])) break;
			else --tl1;
		}
		*tl1++=i;
	}
	return cntf[n];
}
int main(){
	cin.tie(nullptr)->sync_with_stdio(false);
	cin>>n>>m;
	for(int i=1;i<=n;i++) cin>>a[i],s[i]=a[i]+s[i-1];
	int l=-1e9,r=0;
	long long ans=0;
	while(l<=r){
		int mid=(l+r)>>1;
		if(check(mid)<=m) l=mid+1,ans=mid;
		else r=mid-1;
	}
	check(ans);
	cout<<f[n]+m*ans<<'\n';
	return 0;
}