如果这篇博客帮助到你,可以请我喝一杯咖啡~
CC BY-NC-SA 4.0 (除特别声明或转载文章外)
本文为 $O(n\log V)$ 的 wqs 二分+斜率优化做法。
下面这段仅用于凸性的证明,与做法无关:
设 $w(l,r)$ 表示在 $\lfloor\frac{l+r}{2}\rfloor$ 处建立邮局,$l$ 到 $r$ 的村庄全部去往该邮局的距离之和,则有 dp 式子: \(f_{i,j}=\min_{0\le t<i}\{f_{t,j-1}+w(t+1,i)\}\) 由于 $w(l,r)$ 满足四边形不等式,则 $ans_m=f_{n,m}$ 是关于 $m$ 的凸函数,这一经典结论的证明参见此处。
进行 wqs 二分,消去 $m$ 的限制,二分的斜率为 $K$。
关于 $K$ 的二分范围,这个凸包是单调减的下凸壳,考虑 $0\ge ansK=ans_{m+1}-ans_m\ge ans_{n}-ans_m=-ans_m$,由题目的保证知道 $ans_m\le 10^9$,所以二分范围取 $[-10^9,0]$ 即可。
设 $f_i$ 为已经考虑完了前 $i$ 个村庄的最小代价($s$ 为 $a$ 的前缀和),我们有: \(f_i=\min_{1\le j\le i}\{(s_i-s_j)-(i-j)a_j+\min_{0\le k<j}\{(j-k)a_j-(s_j-s_k)+f_k\}\}-K\) 令 $g_j=\min_{0\le k<j}{(j-k)a_j-(s_j-s_k)+f_k}$,$g$ 可以用斜率优化得到。
$f_i=\min_{1\le j\le i}{(s_i-s_j)-(i-j)a_j+g_j}-K$,$f$ 可以用斜率优化得到。
都是经典的斜率优化形式,具体做法就不加赘述了。
另外,提醒一下 wqs 二分时记得处理斜率相同的段。
#include<iostream>
using std::cin,std::cout;
int n,m;
long long a[500010],s[500010],f[500010],g[500010],cntf[500010],cntg[500010];
int check(long long K){
auto x1=[](int x){return x;};
auto y1=[](int x){return s[x]+f[x];};
auto x2=[](int x){return a[x];};
auto y2=[](int x){return -s[x]+x*a[x]+g[x];};
f[0]=0,cntf[0]=0;
static int q1[500010];
int *hd1=q1,*tl1=q1;
*tl1++=0;
static int q2[500010];
int *hd2=q2,*tl2=q2;
for(int i=1;i<=n;i++){
while(tl1-hd1>1){
if(std::pair(y1(*hd1)-a[i]*x1(*hd1),cntf[*hd1])<std::pair(y1(hd1[1])-a[i]*x1(hd1[1]),cntf[hd1[1]])) break;
else ++hd1;
}
cntg[i]=cntf[*hd1];
g[i]=i*a[i]-s[i]+y1(*hd1)-a[i]*x1(*hd1);
while(tl2-hd2>1){
if(std::pair((y2(i)-y2(tl2[-1]))*(x2(tl2[-1])-x2(tl2[-2])),cntg[i])>std::pair((y2(tl2[-1])-y2(tl2[-2]))*(x2(i)-x2(tl2[-1])),cntg[tl2[-1]])) break;
else --tl2;
}
*tl2++=i;
while(tl2-hd2>1){
if(std::pair(y2(*hd2)-i*x2(*hd2),cntg[*hd2])<std::pair(y2(hd2[1])-i*x2(hd2[1]),cntg[hd2[1]])) break;
else ++hd2;
}
cntf[i]=cntg[*hd2]+1;
f[i]=-K+s[i]+y2(*hd2)-i*x2(*hd2);
while(tl1-hd1>1){
if(std::pair((y1(i)-y1(tl1[-1]))*(x1(tl1[-1])-x1(tl1[-2])),cntf[i])>std::pair((y1(tl1[-1])-y1(tl1[-2]))*(x1(i)-x1(tl1[-1])),cntf[tl1[-1]])) break;
else --tl1;
}
*tl1++=i;
}
return cntf[n];
}
int main(){
cin.tie(nullptr)->sync_with_stdio(false);
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>a[i],s[i]=a[i]+s[i-1];
int l=-1e9,r=0;
long long ans=0;
while(l<=r){
int mid=(l+r)>>1;
if(check(mid)<=m) l=mid+1,ans=mid;
else r=mid-1;
}
check(ans);
cout<<f[n]+m*ans<<'\n';
return 0;
}