1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
#include <bits/stdc++.h>
using namespace std;

// Redirect stdin to "<x>.in" and stdout to "<x>.out".
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
// Integer loops: fo = inclusive ascending, ff = exclusive ascending,
// fd = inclusive descending. The bound k is evaluated once, into end_i.
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define ff(i,j,k) for(int i=(j),end_i=(k);i< end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
// Print "name=value" to stderr.
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
// Zero-fill an array; relies on sizeof(x), so x must be a real array, not a pointer.
#define cle(x) memset(x,0,sizeof(x))
// Lowest set bit of x.
#define lowbit(x) ((x)&-(x))
// Type and member shorthands used throughout the file.
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second

/**
 * Reads the next decimal integer from stdin with getchar().
 *
 * Every character before the first digit is skipped; if any skipped character
 * is '-', the result is negated. Digits are then accumulated until the first
 * non-digit, which is consumed.
 *
 * Returns 0 when the end of input is reached before any digit is seen.
 * The previous version looped forever in that case because EOF is never a
 * digit, and it stored the getchar() result in a char, which cannot represent
 * EOF reliably.
 *
 * No overflow check is performed: the value must fit in an int.
 */
inline int read()
{
	int x = 0;
	bool negative = false;
	int ch = getchar(); // int, not char, so EOF stays distinguishable
	for (; ch != EOF && (ch < '0' || ch > '9'); ch = getchar())
		if (ch == '-') negative = true;
	for (; ch >= '0' && ch <= '9'; ch = getchar())
		x = x * 10 + (ch - '0');
	return negative ? -x : x;
}

// Repeat the following statement read() times (the count is read once).
#define CASET fo(___,1,read())

// Capacity of the global arrays.
const int N=3e5+10;
// n: first integer read by main(); upper bound of every loop over the arrays.
int n;
// k: second integer read by main(); compared against h[n] in the penalty search.
int k;
ll a[N],s[N]; ll f[N]; int h[N]; inline ll calc(int l,int r) { l--; if(l>r) return 0; return s[r]+s[l]-s[(l+r)>>1]-s[(l+r+1)>>1]; } inline ll g(int l,int r) { return f[l]+calc(l+1,r); } inline bool check(int x,int y,int r) { ll X=g(x,r),Y=g(y,r); if(X!=Y) return X>Y; return h[x]>h[y]; } int le[N],bel[N],top,st; inline int find(int i) { if(st>top) return n; int l=le[top],r=n,mid; for(;l<=r;) { mid=(l+r)>>1; if(check(bel[top],i,mid)) r=mid-1; else l=mid+1; } return l; } inline ll calc(ll x) { f[0]=0; h[0]=0; le[1]=1; bel[1]=0; st=top=1; fo(i,1,n) { if(st<top&&i==le[st+1]) st++; f[i]=g(bel[st],i)+x; h[i]=h[bel[st]]+1; for(;top>=st&&check(bel[top],i,le[top]);) top--; int tmp=find(i); if(tmp<=n) ++top,le[top]=tmp,bel[top]=i; } return f[n]; }
int main() { n=read(); k=read(); if(n==k) {puts("0"); return 0;} fo(i,1,n) a[i]=read(),s[i]=s[i-1]+a[i]; ll l=0,r=3e15,mid,tmp,ans=9e18; int now=n+1; for(;l+1<r;) { mid=(l+r)/2; tmp=calc(mid); if(h[n]<=k) r=mid,ans=tmp-1ll*h[n]*mid; else l=mid; } tmp=calc(r); ans=tmp-1ll*k*r; printf("%lld\n",ans); return 0; }
|