氪金手游[CTS2019]

题目链接

loj

题解

mdzz…刚开始没看到权值只能在 $[1,3]$ 之间…导致自闭了很久…

这题还挺简单的。

首先如果把题目的有向变成无向,那么就是一棵无根树。

假设权值已经固定了。

再假设它是一个外向树的结构,那么显然可以树形DP一下:

设 $f_u$ 表示子树内答案,有:$f_u=\prod_{v\in son_u}f_v\frac{w_u}{\sum_{v\in tree_u} w_v}$。

但是现在权值 $w_i$ 不固定,那么 $\sum_{v\in tree_u} w_v$ 也不固定,那么在DP的时候还需要记多一维表示子树的 $w_v$ 的和,这个还是 $O(n)$ 级别。那么时间复杂度 $O(n^2)$。

想到这里,你就会获得0分的好成绩。。。

那么当这棵树不是外向树的时候的概率该怎么算呢?

考虑容斥原理,假设你至少有 $i$ 条内向边改成了外向,其他内向边随意,那么对答案的贡献为 $(-1)^i$。

那么随意的外向边相当于断开,那么 $\sum_{v\in tree_u} w_v$ 这里就不需要算这棵断开了的子树。

只需要在树形DP的时候容斥,遇到一条内向边就考虑它是否改成外向,还是断开两种情况就可以了。

时间复杂度 $O(n^2)$。

程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <bits/stdc++.h>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define lowbit(x) ((x)&-(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second
inline int read()
{
int x=0; char ch=getchar(); bool f=0;
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
#define CASET fo(___,1,read())
const ll mod=998244353;
inline ll Add(ll x,ll y){x+=y; return (x<mod)?x:x-mod;}
inline ll Dec(ll x,ll y){x-=y; return (x<0)?x+mod:x;}
inline ll Mul(ll x,ll y){return x*y%mod;}
inline ll Pow(ll x,ll y){y%=(mod-1);ll ans=1;for(;y;y>>=1,x=x*x%mod)if(y&1) ans=ans*x%mod;return ans;}
const int N=3005;
ll inv[N],a[N][4];
ll f[N][N],g[N];
int n,num[N],siz[N];
int ver[N],val[N],ne[N],head[N],tot;
inline void add(int y,int x)
{
ver[++tot]=y; val[tot]=1; ne[tot]=head[x]; head[x]=tot;
ver[++tot]=x; val[tot]=0; ne[tot]=head[y]; head[y]=tot;
}
void dfs(int u,int pre)
{
siz[u]=1;
fo(i,1,3) f[u][i]=a[u][i]*i%mod;
vector<int> V; ll t;
for(int i=head[u],v;i;i=ne[i])
if((v=ver[i])!=pre)
dfs(v,u),V.pb(v),num[v]=val[i];
sort(all(V),[&](const int &x,const int &y){return siz[x]<siz[y];});
for(auto v:V)
{
fo(i,1,siz[u]*3) fo(j,1,siz[v]*3)
{
t=f[u][i]*f[v][j]%mod;
if(num[v]) g[i+j]=Add(g[i+j],t);
else g[i+j]=Dec(g[i+j],t),g[i]=Add(g[i],t);
}
siz[u]+=siz[v];
fo(i,1,siz[u]*3) f[u][i]=g[i],g[i]=0;
}
fo(i,1,siz[u]*3) f[u][i]=Mul(f[u][i],inv[i]);
}
int main()
{
fo(i,1,N-1) inv[i]=Pow(i,mod-2);
n=read();
fo(i,1,n)
{
int s=0;
fo(j,1,3) s+=(a[i][j]=read());
s=Pow(s,mod-2);
fo(j,1,3) a[i][j]=Mul(a[i][j],s);
}
fo(i,2,n) add(read(),read());
dfs(1,0);
ll ans=0;
fo(i,1,n*3) ans=Add(ans,f[1][i]);
printf("%lld",ans);
return 0;
}