08 Mar 2021 2892字 10分 次 网络流, 费用流, 线段树, and 数据结构
如果这篇博客帮助到你,可以请我喝一杯咖啡~
CC BY-NC-SA 4.0 (除特别声明或转载文章外) 首先,这题有一个显然的费用流做法,连边如下:
- S 向 i 连边,费用为 ai,流量为 1。(1≤i≤n)
- i 向 T’ 连边,费用为 bi,流量为 1。(1≤i≤n)
- i 向 i+1 连边,费用为 0,流量为 inf。(1≤i<n)
- T’ 向 T 连边,费用为 0,流量为 k。
S 向 T 跑最小费用最大流,费用即为答案,时间复杂度 O(knm)。
我们发现复杂度瓶颈在于求 k 次最短路,考虑优化最短路,我们注意到,可以使用线段树优化这个过程,这样找最短路的复杂度就被优化为 O(logn),增广的过程也可以修改线段树复杂度也是 O(logn)。总复杂度 O(klogn)。
接下来讲解如何用线段树维护,我们维护从左向右的最短路,从右向左的最短路,从右向左的最短可流路,ai 的最小值,bi 的最小值,可以接受来自右边流量的最小 bi,可以向左流的最小 ai,从右向左的流量。
找最短路可以直接访问根节点的信息,得到最短路 S→u→v→T’→T,增广可以把 au 和 bv 设为 inf,区间修改从右向左的流量。
这样我们就获得了 O(klogn) 的做法。
#include<cstdio>
int n,k,a[500010],b[500010];
struct node{
struct path{
int x,y;
path():x(),y(){}
path(int const &a,int const &b):x(a),y(b){}
path operator +(path const &k)const{
return a[x]+b[y]<a[k.x]+b[k.y]?*this:k;
}
}va,vb,vc;
int aa,ab,ba,bb,vm,tag;
node():va(),vb(),vc(),aa(),ab(),ba(),bb(),vm(),tag(){}
friend node operator + (node const &l,node const &r){
node x;
x.va=l.va+r.va+path(l.aa,r.ab);
x.vc=l.vc+r.vc+path(r.aa,l.ab);
x.aa=a[l.aa]<a[r.aa]?l.aa:r.aa;
x.ab=b[l.ab]<b[r.ab]?l.ab:r.ab;
x.vm=l.vm>r.vm?r.vm:l.vm;
if(l.vm<r.vm){
x.vb=l.vb+r.vb+r.vc+path(r.aa,l.bb);
x.ba=l.ba;
x.bb=b[r.ab]<b[l.bb]?r.ab:l.bb;
}else if(l.vm>r.vm){
x.vb=l.vb+r.vb+l.vc+path(r.ba,l.ab);
x.ba=a[r.ba]<a[l.aa]?r.ba:l.aa;
x.bb=r.bb;
}else{
x.vb=l.vb+r.vb+path(r.ba,l.bb);
x.ba=l.ba;
x.bb=r.bb;
}
return x;
}
}tr[2000010];
inline void add(int const &x,int const &p){tr[x].tag+=p,tr[x].vm+=p;}
inline void pushdown(int const &x){if(tr[x].tag)add(x<<1,tr[x].tag),add(x<<1|1,tr[x].tag),tr[x].tag=0;}
void build(int const &x=1,int const &l=0,int const &r=n){
if(l==r) return tr[x].va=tr[x].vc=node::path(l,l),tr[x].aa=tr[x].ab=tr[x].ba=l,void();
int mid=(l+r)>>1;
build(x<<1,l,mid),build(x<<1|1,mid+1,r);
tr[x]=tr[x<<1]+tr[x<<1|1];
}
void update(int const &pl,int const &pr,int const &v,int const &x=1,int const &l=0,int const &r=n){
if(l==pl&&r==pr) return add(x,v);
pushdown(x);
int mid=(l+r)>>1;
if(pr<=mid) update(pl,pr,v,x<<1,l,mid);
else if(pl>mid) update(pl,pr,v,x<<1|1,mid+1,r);
else update(pl,mid,v,x<<1,l,mid),update(mid+1,pr,v,x<<1|1,mid+1,r);
tr[x]=tr[x<<1]+tr[x<<1|1];
}
void upd(int const &p,int const &x=1,int const &l=0,int const &r=n){
if(l==r) return;
pushdown(x);
int mid=(l+r)>>1;
if(p<=mid) upd(p,x<<1,l,mid);
else upd(p,x<<1|1,mid+1,r);
tr[x]=tr[x<<1]+tr[x<<1|1];
}
int main(){
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++)scanf("%d",a+i);
for(int i=1;i<=n;i++)scanf("%d",b+i);
a[0]=b[0]=0x3f3f3f3f;
build();
long long ans=0;
while(k--){
node::path t=tr[1].va+tr[1].vb;
int i=t.x,j=t.y;
ans+=a[i]+b[j];
if(i<j)update(i,j-1,1);
if(i>j)update(j,i-1,-1);
a[i]=0x3f3f3f3f,upd(i);
b[j]=0x3f3f3f3f,upd(j);
}
printf("%lld\n",ans);
return 0;
}