1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
#include <bits/stdc++.h>

// Reads n, K and the n - 1 weighted edges of a tree from stdin and prints the
// largest total edge weight obtainable from K + 1 vertex-disjoint paths
// (a single vertex counts as a path; vertices may also be left unused).
//
// Method: charge every path a penalty `lambda`, solve the unconstrained problem
// with a tree DP, and binary-search `lambda` until the optimal path count
// reaches K + 1 (WQS binary search / "Aliens trick"); the answer is then
// DP value + lambda * (K + 1).
// NOTE(review): this relies on the optimum being concave in the number of
// paths; that property is assumed, not checked, here.

using ll = long long;

namespace {

constexpr int N = 300005;  // array capacity; assumes n <= 300000

// Reads one decimal integer from stdin, skipping any non-digit prefix; a '-'
// anywhere in that prefix makes the result negative. Returns 0 if EOF is hit
// before any digit (instead of spinning forever).
int read_int() {
    int ch = std::getchar();
    bool negative = false;
    for (; ch < '0' || ch > '9'; ch = std::getchar()) {
        if (ch == EOF) return 0;
        if (ch == '-') negative = true;
    }
    int x = 0;
    for (; ch >= '0' && ch <= '9'; ch = std::getchar()) x = x * 10 + (ch - '0');
    return negative ? -x : x;
}

// Value of a partial DP solution:
//   v = total edge weight minus penalty * (number of closed paths),
//   k = number of closed paths.
// Ordered by v, ties broken by k, so std::max prefers MORE paths on equal v.
struct node {
    ll v;
    int k;
    friend bool operator<(const node &a, const node &b) {
        return a.v != b.v ? a.v < b.v : a.k < b.k;
    }
    friend node operator+(const node &a, const node &b) { return node{a.v + b.v, a.k + b.k}; }
    friend node operator+(const node &a, ll w) { return node{a.v + w, a.k}; }
};

// Forward-star adjacency list: edge i leads to ver[i] with weight val[i];
// ne[i] is the next edge out of the same vertex, 0 terminates the list.
int ver[N << 1], val[N << 1], ne[N << 1], head[N], tot = 1;

// par[u]: parent of u with the tree rooted at vertex 1 (par[1] = 0).
// ord[]:  BFS order from vertex 1, so every vertex appears after its parent.
int par[N], ord[N];

// Per-vertex DP results read by the parent's transitions:
//   dp_closed[u]: best in subtree(u) where u does not continue a path upward.
//   dp_open[u]:   best in subtree(u) where u ends a path that may still be
//                 extended to the parent (that path's penalty not yet charged).
node dp_closed[N], dp_open[N];

// Inserts the undirected edge x - y with weight w (x -> y first, then y -> x).
void add_edge(int x, int y, int w) {
    ver[++tot] = y; val[tot] = w; ne[tot] = head[x]; head[x] = tot;
    ver[++tot] = x; val[tot] = w; ne[tot] = head[y]; head[y] = tot;
}

// Fills par[] and ord[] by BFS from vertex 1. Done once: the tree shape does
// not change between binary-search steps.
void build_order() {
    int front = 0, back = 0;
    ord[back++] = 1;
    par[1] = 0;
    while (front < back) {
        const int u = ord[front++];
        for (int i = head[u]; i; i = ne[i]) {
            const int v = ver[i];
            if (v != par[u]) {
                par[v] = u;
                ord[back++] = v;
            }
        }
    }
}

// Best value over all sets of vertex-disjoint paths when every path costs
// `penalty`; ties are broken toward more paths. Vertices are processed in
// reverse BFS order (children before parents) rather than recursively, so a
// path-shaped tree with ~3e5 vertices cannot overflow the call stack.
node evaluate(int n, ll penalty) {
    const auto close_path = [penalty](const node &a) { return node{a.v - penalty, a.k + 1}; };
    for (int idx = n - 1; idx >= 0; --idx) {
        const int u = ord[idx];
        node idle{0, 0};         // u is on no path yet
        node chain{0, 0};        // u ends an open path (initially just {u})
        node thru{-penalty, 1};  // u is on a closed path (initially the single vertex {u})
        for (int i = head[u]; i; i = ne[i]) {
            const int v = ver[i];
            if (v == par[u]) continue;
            // Order matters: each line reads the previous values of the others.
            thru = std::max(thru + dp_closed[v], close_path(chain + dp_open[v] + val[i]));
            chain = std::max(chain + dp_closed[v], idle + dp_open[v] + val[i]);
            idle = idle + dp_closed[v];
        }
        dp_closed[u] = std::max(std::max(idle, thru), close_path(chain));
        dp_open[u] = chain;
    }
    return dp_closed[1];
}

}  // namespace

int main() {
    const int n = read_int();
    const int paths = read_int() + 1;  // K + 1 paths are chosen

    // Every marginal gain of one extra path lies in [-S, S] with S the sum of
    // positive weights (0 <= best value <= S for up to n paths), so any penalty
    // range strictly wider than that suffices. Assumes K < n.
    ll weight_bound = 1;
    for (int i = 2; i <= n; ++i) {
        // Read into named variables: the evaluation order of function
        // arguments is unspecified, so add_edge(read_int(), ...) is not portable.
        const int x = read_int();
        const int y = read_int();
        const int w = read_int();
        add_edge(x, y, w);
        weight_bound += std::abs(static_cast<ll>(w));
    }
    build_order();

    // Smallest penalty whose (max-count) optimum uses at most `paths` paths.
    ll lo = -weight_bound, hi = weight_bound;
    while (lo <= hi) {
        const ll mid = lo + (hi - lo) / 2;  // floor((lo + hi) / 2) without shifting negatives
        if (evaluate(n, mid).k > paths) lo = mid + 1;
        else hi = mid - 1;
    }
    const node best = evaluate(n, lo);
    std::printf("%lld\n", best.v + lo * paths);
    return 0;
}
|