猎人杀[pkuwc2018]

题目链接

loj

题解

一道很不错的思维题。

想来想去想不到什么直接的做法。

先看看 $n\leq 20$ 该怎么做吧。

状态压缩,设 $f_S$ 表示 $S$ 集合先于 $1$ 号猎人死亡的概率。

直接 DP 就可以了。

接着继续想,还是想不到什么直接做的做法。

那么只能试试容斥了。

把 $f_S$ 的状态改一改,变成至少是 $S$ 中的人在 $1$ 之后死的概率。

那么答案即为:$\sum_S (-1)^Sf_S$

而 $f_S$ 显然与不在 $S$ 的且不是 $1$ 的猎人无关。

则 $f_S=\frac{w_1}{w_1+\sum_{i\in S} w_i}$

有一个很重要的数据范围是 $\sum_{i=1}^nw_i\leq 10^5$

那就是说 $f_S$ 的分母不会超过 $10^5$。

那或许可以试着枚举 $j$,然后求出所有满足 $\sum_{i\in S}w_i=j$ 的集合的 $(-1)^S$ 的和?

考虑每个 $w_i$ 的贡献,选相当于乘 $-1$,不会相当于乘 $1$。则可以设以下生成函数:$(1-x^{w_i})$

$j$ 的答案即为 $\prod_{i=2}^n (1-x^{w_i})$ 的第 $j$ 项的系数。

分治FFT 即可。

时间复杂度 $O(n\log ^2n)$

总结

一道思维较好的容斥。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
#define N 210000
#define G 3
#define mod 998244353ll
#define ll long long
inline int read()
{
int x=0,f=0; char ch=getchar();
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
inline ll Pow(ll x,int y)
{
ll ans=1;
for(;y;y>>=1,x=x*x%mod) if(y&1) ans=ans*x%mod;
return ans;
}
int w[N],n;
int siz[32],m;
int len,L,R[N];
ll a[32][N];
void ntt(ll *a,int n,int opt)
{
for(int i=1;i<n;i++) if(i>R[i]) swap(a[i],a[R[i]]);
for(int i=1;i<n;i<<=1)
{
ll wn=Pow(G,(mod-1)/(i<<1));
if(opt==-1) wn=Pow(wn,mod-2);
for(int j=0;j<n;j+=(i<<1))
{
ll w=1,x,y;
for(int k=0;k<i;k++,w=w*wn%mod)
x=a[j+k],y=a[i+j+k]*w%mod,
a[j+k]=(x+y)%mod,a[i+j+k]=(x-y+mod)%mod;
}
}
if(opt==1) return;
ll invn=Pow(n,mod-2);
for(int i=0;i<n;i++) a[i]=a[i]*invn%mod;
}
inline void pre_ntt(int n)
{
for(L=0,len=1;len<=n;len<<=1,L++); L--;
for(int i=1;i<len;i++) R[i]=(R[i>>1]>>1)|((i&1)<<L);
}
void solve(int l,int r)
{
if(l==r)
{
siz[++m]=w[l];
a[m][0]=1; a[m][w[l]]=mod-1;
for(int i=1;i<w[l];i++) a[m][i]=0;
return;
}
int mid=l+r>>1;
solve(l,mid); solve(mid+1,r);
int ml=m-1,mr=m,s=siz[ml]+siz[mr];
pre_ntt(s);
for(int i=siz[ml]+1;i<len;i++) a[ml][i]=0;
for(int i=siz[mr]+1;i<len;i++) a[mr][i]=0;
ntt(a[ml],len,1); ntt(a[mr],len,1);
for(int i=0;i<len;i++) a[ml][i]=a[ml][i]*a[mr][i]%mod;
ntt(a[ml],len,-1);
siz[--m]=s;
}
int main()
{
n=read();
int sum=0; ll ans=0;
for(int i=1;i<=n;i++) sum+=(w[i]=read());
solve(2,n);
for(int i=0;i<=sum;i++) (ans+=a[1][i]*w[1]%mod*Pow(w[1]+i,mod-2))%=mod;
printf("%d",ans);
return 0;
}