1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
// Tree DP with inclusion-exclusion over directed tree edges, modulo 998244353.
//
// Input (stdin): n; then n lines of three integers a[u][1..3]; then n-1 pairs "a b".
// Output (stdout): one integer, no trailing newline.
//
// NOTE(review): the model below is inferred from the arithmetic, not stated anywhere in
// the code -- confirm against the problem statement. Node u appears to take weight
// w in {1,2,3} with probability a[u][w] / (a[u][1]+a[u][2]+a[u][3]); nodes are drawn
// with probability proportional to weight; a pair "a b" appears to mean "a is drawn
// before b", and the printed value the probability that all pairs hold.
#include <array>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

using ll = long long;

constexpr ll kMod = 998244353;

/// Reads the next decimal integer from stdin.
/// Non-digit characters before the number are skipped; a '-' among them makes the
/// result negative. `ch` is an int so EOF is distinguishable from any character and
/// the skip loop terminates on truncated input (returning 0) instead of spinning.
int read_int() {
    int ch = std::getchar();
    bool negative = false;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') negative = true;
        ch = std::getchar();
    }
    int value = 0;
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return negative ? -value : value;
}

/// (x + y) mod p for x, y in [0, p).
ll add_mod(ll x, ll y) {
    x += y;
    return x < kMod ? x : x - kMod;
}

/// (x - y) mod p for x, y in [0, p).
ll sub_mod(ll x, ll y) {
    x -= y;
    return x < 0 ? x + kMod : x;
}

/// x * y mod p for x, y in [0, p).
ll mul_mod(ll x, ll y) { return x * y % kMod; }

/// base^exp mod p by binary exponentiation (exp >= 0).
ll pow_mod(ll base, ll exp) {
    base %= kMod;
    if (base < 0) base += kMod;
    ll result = 1;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) result = result * base % kMod;
        base = base * base % kMod;
    }
    return result;
}

/// Modular inverse via Fermat's little theorem (998244353 is prime).
/// Returns 0 when x is a multiple of p (no inverse exists).
ll inv_mod(ll x) { return pow_mod(x, kMod - 2); }

/// One adjacency entry. For an input pair "a b", adj[a] holds {b, true} and
/// adj[b] holds {a, false}.
struct Arc {
    int to;
    bool forward;
};

/// Everything the DP reads; all vectors are 1-based.
struct Instance {
    std::vector<std::vector<Arc>> adj;    // adjacency lists, index 1..n
    std::vector<std::array<ll, 4>> prob;  // prob[u][w] = a[u][w] / sum_w a[u][w] (mod p), w in 1..3
    std::vector<ll> inv;                  // inv[i] = i^{-1} mod p, i in 1..3n
};

/// Computes the DP polynomial of the subtree rooted at u, entered from `parent`.
///
/// Returns P with P.size() == 3 * (subtree size) + 1 and P[0] == 0. Construction:
///   - start with P[w] = prob[u][w] * w for w = 1..3;
///   - for a child reached by a forward arc, multiply P by the child's polynomial Q
///     (coefficient indices add);
///   - for a child reached by a backward arc, multiply P by (S - Q), where S is the sum
///     of Q's coefficients (inclusion-exclusion: "unconstrained" minus "reversed");
///   - finally divide P[s] by s.
/// Presumably P[s] is then the probability that u is drawn first within a joined
/// component of total weight s -- TODO confirm against the problem statement.
///
/// Multiplication is commutative, so the order children are merged in does not matter.
/// Total work is O(n^2) over the whole tree.
std::vector<ll> subtree_poly(const Instance& inst, int u, int parent) {
    std::vector<ll> poly(4, 0);
    for (int w = 1; w <= 3; ++w) poly[w] = inst.prob[u][w] * w % kMod;

    for (const Arc& arc : inst.adj[u]) {
        if (arc.to == parent) continue;
        const std::vector<ll> child = subtree_poly(inst, arc.to, u);
        std::vector<ll> merged(poly.size() + child.size() - 1, 0);

        if (!arc.forward) {
            // Constant term of (S - Q): poly scaled by the child's total mass.
            ll child_total = 0;
            for (ll c : child) child_total = add_mod(child_total, c);
            for (std::size_t i = 1; i < poly.size(); ++i) {
                merged[i] = mul_mod(poly[i], child_total);
            }
        }
        // The poly * Q part: added for forward arcs, subtracted for backward arcs.
        for (std::size_t i = 1; i < poly.size(); ++i) {
            for (std::size_t j = 1; j < child.size(); ++j) {
                const ll term = mul_mod(poly[i], child[j]);
                merged[i + j] = arc.forward ? add_mod(merged[i + j], term)
                                            : sub_mod(merged[i + j], term);
            }
        }
        poly = std::move(merged);
    }

    for (std::size_t s = 1; s < poly.size(); ++s) poly[s] = mul_mod(poly[s], inst.inv[s]);
    return poly;
}

}  // namespace

int main() {
    const int n = read_int();
    if (n < 1) {
        // Empty tree: the answer is a sum over an empty range.
        std::printf("%lld", 0LL);
        return 0;
    }
    const std::size_t nodes = static_cast<std::size_t>(n) + 1;

    Instance inst;

    // Indices reach 3n: a subtree's weight sum is at most 3 * (its size).
    inst.inv.assign(3 * static_cast<std::size_t>(n) + 1, 0LL);
    for (std::size_t i = 1; i < inst.inv.size(); ++i) inst.inv[i] = inv_mod(static_cast<ll>(i));

    inst.prob.assign(nodes, std::array<ll, 4>{});
    for (int u = 1; u <= n; ++u) {
        ll total = 0;
        for (int w = 1; w <= 3; ++w) {
            inst.prob[u][w] = read_int();
            total += inst.prob[u][w];
        }
        const ll inv_total = inv_mod(total);
        for (int w = 1; w <= 3; ++w) inst.prob[u][w] = mul_mod(inst.prob[u][w], inv_total);
    }

    inst.adj.assign(nodes, std::vector<Arc>());
    for (int k = 1; k < n; ++k) {
        // Separate statements fix the read order. In a single call like
        // f(read_int(), read_int()) the argument evaluation order is unspecified,
        // so the edge orientation would depend on the compiler.
        const int a = read_int();
        const int b = read_int();
        inst.adj[a].push_back(Arc{b, true});
        inst.adj[b].push_back(Arc{a, false});
    }

    const std::vector<ll> root = subtree_poly(inst, 1, 0);
    ll answer = 0;
    for (ll coeff : root) answer = add_mod(answer, coeff);
    std::printf("%lld", answer);
    return 0;
}
|