Tree[CF1010F]

题目链接

链接

题意

给你一棵有根二叉树,现在你要在上面填数。

根节点为 $1$,且已经填上了数字 $x$。

树边可以任意断掉,最后保留存在能到根节点路径的节点所形成的子树。

对于子树中的这些节点 $u$,满足它所有儿子所填的数之和小于等于 $u$ 填的数。

问最后有多少种形态。两种形态不同当且仅当子树的节点集合不同;或者集合相同,但存在一个节点 $u$,使得在两种方案中所填之数不同。

模 $998244353$。

$n\leq 10^5$,时限7s。

题解

这题也是够毒瘤了QwQ。

先来看看不能断树边该怎么算。

设 $a_u$ 为节点 $u$ 所填的数字。对于所有的 $u$ ,要使得满足它所有儿子所填的数之和小于等于 $u$ 填的数,我们转换条件,设:$b_u=a_u-\sum_{v\in son_u}a_v$。

那么就会有 $b_u\geq 0$,且可以发现,每个不同的 $b$ 一一对应这不同的 $a$。

并且还有 $\sum_{i=1}^nb_i=n$。

那么有插板法可得总方案数为 $\binom{n+x-1}{n}$。这个 $n$ 表示树的节点个数。

现在题目转换成,对于每个 $i$,求有多少个包含 $1$ 号点的连通子图,且大小为 $i$。

显然我们可以树形DP,设 $f_{u,i}$ 表示子树 $u$ 中节点个数为 $i$ 的方案数。

当有两个儿子的时候,有:$f_{u,i}=\sum_{j+k=i-1}f_{v,j}\times f_{w,k},f_{u,0}=1$。

设 $f_u$ 的OGF为 $F_u$,上式写成生成函数的形式是:$F_u=xF_vF_w+1$。

当 $u$ 为空的时候,$F_u=1$。

显然上式用FFT优化,但显然还是没有用的。

那就先考虑链吧?这个很简单。

链中每个点插多一个子树呢?

那么我对于每个点 $u$,设插进去的这个子树算出来的生成函数乘以 $x$ 后的式子为 $g_u$。

对于这条链,考虑从上往下计算,则有:$F_u=F_{son_u}g_u+1$

我们不妨将这条链从上往下标号为 $1,2\cdots k$。那么最终的生成函数就是:

$$F=(g_1(g_2(g_3\cdots)+1))+1\=\sum_{i=0}^k\prod_{j=1}^ig_j$$

考虑分治计算这个东西,设当前的分治结构为 $[l,r]$,$s=\prod_{i=l}^rg_i$,$F=\sum_{i=l-1}^r\prod_{j=l}^ig_j$,那么有:

$$s_{[l,r]}=s_{[l,mid]}\times s_{[mid+1,r]}\F_{[l,r]}=(F_{[l,mid]}-1)\times s_{[mid+1,r]}+F_{[mid+1,r]}$$

可以用ntt优化上面的乘法。

那么可以先树链剖分进行链分治,对于每条链,我们按照上面的方法算。设链顶的节点的子树大小为 $siz$ ,那么上面的方法的复杂度就是 $O(siz\log ^2siz)$。

由于每个节点往上跳只会统计最多 $\log n$ 次,那么总的复杂度就是 $O(n\log ^3n)$。

程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#include <bits/stdc++.h>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","W",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define lowbit(x) ((x)&-(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second
inline int read()
{
int x=0; char ch=getchar(); bool f=0;
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
const ll mod=998244353;
inline ll Add(ll x,ll y){x+=y; return (x<mod)?x:x-mod;}
inline ll Dec(ll x,ll y){x-=y; return (x<0)?x+mod:x;}
inline ll Mul(ll x,ll y){return x*y%mod;}
inline ll Pow(ll x,ll y){y%=(mod-1);ll ans=1;for(;y;y>>=1,x=x*x%mod)if(y&1) ans=ans*x%mod;return ans;}
const int N=100005;
const int M=1<<18;
typedef vector<ll> Poly;
namespace P{
int R[M]; ll W[M];
inline void PolyInit()
{
ll wn;
for(int i=1;i<M;i<<=1)
{
W[i]=1; wn=Pow(3,(mod-1)/2/i);
fo(j,1,i-1) W[i+j]=W[i+j-1]*wn%mod;
}
}
inline void ntt(ll *a,int n,int opt)
{
ll w;
for(int i=0;i<n;i++)
{
R[i]=(R[i>>1]>>1)|((i&1)*(n>>1));
if(i<R[i]) swap(a[i],a[R[i]]);
}
for(int i=1;i<n;i<<=1)
for(int j=0;j<n;j+=(i<<1))
for(int k=0;k<i;k++)
w=W[i+k]*a[i+j+k]%mod,
a[i+j+k]=Dec(a[j+k],w),
a[j+k]=Add(a[j+k],w);
if(opt==1) return;
reverse(a+1,a+n);
w=Pow(n,mod-2);
fo(i,0,n-1) a[i]=w*a[i]%mod;
}
inline void ntt(Poly &A,int n,int opt) {ntt(&A[0],n,opt);}
inline Poly operator*(Poly A,Poly B)
{
int n=A.size(),m=B.size(),k=n+m-1,len=1;
for(;len<=k;len<<=1);
A.resize(len); B.resize(len);
ntt(A,len,1); ntt(B,len,1);
fo(i,0,len-1) A[i]=1ll*A[i]*B[i]%mod;
ntt(A,len,-1);
A.resize(k);
return A;
}
inline Poly operator+(const Poly &A,const Poly &B)
{
Poly C=A;
C.resize(max(A.size(),B.size()));
fo(i,0,B.size()-1) C[i]=Add(C[i],B[i]);
return C;
}
}
using namespace P;
int n;
namespace Tree{
int son[N],siz[N],oth[N];
vector<int> adj[N];
inline void add(int x,int y)
{
adj[x].pb(y); adj[y].pb(x);
}
void dfs1(int u,int pre)
{
siz[u]=1;
for(auto v:adj[u]) if(v!=pre)
{
dfs1(v,u);
oth[u]^=v;
siz[u]+=siz[v];
if(siz[son[u]]<siz[v]) son[u]=v;
}
oth[u]^=son[u];
}
Poly p[N],w[N];
inline void solve(int l,int r,Poly &f,Poly &g)
{
if(l==r){f=g=w[l]; return;}
int mid=l+r>>1;
Poly lf,lg,rf,rg;
solve(l,mid,lf,lg); solve(mid+1,r,rf,rg);
g=lg*rg; f=rf*lg+lf;
}
Poly dfs2(int u)
{
int m=0;
for(int v=u;v;v=son[v])
{
p[v].clear();
if(oth[v]) p[v]=dfs2(oth[v]);
if(p[v].empty()) p[v].pb(0);
p[v][0]++; p[v].insert(p[v].begin(),0);
w[++m].swap(p[v]);
}
Poly f,g;
solve(1,m,f,g);
return f;
}
Poly f;
inline ll work(ll x)
{
dfs1(1,0);
f=dfs2(1);
ll ans=0,now=1;
fo(i,1,n)
{
ans=Add(ans,Mul(now,f[i]));
now=Mul(now,Mul(Add(x,i),Pow(i,mod-2)));
}
return ans;
}
}
int main()
{
PolyInit();
n=read();
ll x; scanf("%lld",&x); x%=mod;
fo(i,2,n) Tree::add(read(),read());
printf("%lld",Tree::work(x));
return 0;
}