Tournament[2018-2019 ACM-ICPC, Asia Nanjing Regional Contest, B]

链接

B题

题解

考虑DP,设 $f_{i,j}$ 表示考虑到第 $i$ 个点,$1\sim i$ 中一共放了 $j$ 个的最小答案。

$O(n^3)$ 转移,答案即为 $f_{n,k}$。

看到必须放 $k$ 个,可想到wqs二分,打表可知 $f_{n,*}$ 是凸的。

于是二分一个 $x$,让每次放一个则额外加多 $x$ 的贡献,DP时记录去最小值的时候最多放了多少个,找到放 $k$ 个的答案。

这样做的好处是将DP的第二维的限制给去掉了。

但是这个DP还是 $O(n^2)$ 的。

打表发现他有决策单调性,于是写个决策单调性优化DP就可以了。

时间复杂度是两个log的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include <bits/stdc++.h>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define ff(i,j,k) for(int i=(j),end_i=(k);i< end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define lowbit(x) ((x)&-(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second
inline int read()
{
int x=0; char ch=getchar(); bool f=0;
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
#define CASET fo(___,1,read())
const int N=3e5+10;

int n,k;

ll a[N],s[N];
ll f[N];
int h[N];
inline ll calc(int l,int r)
{
l--;
if(l>r) return 0;
return s[r]+s[l]-s[(l+r)>>1]-s[(l+r+1)>>1];
}
inline ll g(int l,int r)
{
return f[l]+calc(l+1,r);
}
inline bool check(int x,int y,int r)
{
ll X=g(x,r),Y=g(y,r);
if(X!=Y) return X>Y;
return h[x]>h[y];
}
int le[N],bel[N],top,st;
inline int find(int i)
{
if(st>top) return n;
int l=le[top],r=n,mid;
for(;l<=r;)
{
mid=(l+r)>>1;
if(check(bel[top],i,mid)) r=mid-1;
else l=mid+1;
}
return l;
}
inline ll calc(ll x)
{
f[0]=0; h[0]=0;
le[1]=1; bel[1]=0;
st=top=1;
fo(i,1,n)
{
if(st<top&&i==le[st+1]) st++;
f[i]=g(bel[st],i)+x; h[i]=h[bel[st]]+1;
for(;top>=st&&check(bel[top],i,le[top]);) top--;
int tmp=find(i);
if(tmp<=n) ++top,le[top]=tmp,bel[top]=i;
}
return f[n];
}

int main()
{
n=read(); k=read();
if(n==k) {puts("0"); return 0;}
fo(i,1,n) a[i]=read(),s[i]=s[i-1]+a[i];
ll l=0,r=3e15,mid,tmp,ans=9e18;
int now=n+1;
for(;l+1<r;)
{
mid=(l+r)/2;
tmp=calc(mid);
if(h[n]<=k) r=mid,ans=tmp-1ll*h[n]*mid;
else l=mid;
}
tmp=calc(r);
ans=tmp-1ll*k*r;
printf("%lld\n",ans);
return 0;
}