1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
| #include <cstdio> #include <iostream> #include <algorithm> #include <cstring> using namespace std; #define N 210000 #define G 3 #define mod 998244353ll #define ll long long inline int read() { int x=0,f=0; char ch=getchar(); for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1; for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48); return f?-x:x; } inline ll Pow(ll x,int y) { ll ans=1; for(;y;y>>=1,x=x*x%mod) if(y&1) ans=ans*x%mod; return ans; } int w[N],n; int siz[32],m; int len,L,R[N]; ll a[32][N]; void ntt(ll *a,int n,int opt) { for(int i=1;i<n;i++) if(i>R[i]) swap(a[i],a[R[i]]); for(int i=1;i<n;i<<=1) { ll wn=Pow(G,(mod-1)/(i<<1)); if(opt==-1) wn=Pow(wn,mod-2); for(int j=0;j<n;j+=(i<<1)) { ll w=1,x,y; for(int k=0;k<i;k++,w=w*wn%mod) x=a[j+k],y=a[i+j+k]*w%mod, a[j+k]=(x+y)%mod,a[i+j+k]=(x-y+mod)%mod; } } if(opt==1) return; ll invn=Pow(n,mod-2); for(int i=0;i<n;i++) a[i]=a[i]*invn%mod; } inline void pre_ntt(int n) { for(L=0,len=1;len<=n;len<<=1,L++); L--; for(int i=1;i<len;i++) R[i]=(R[i>>1]>>1)|((i&1)<<L); } void solve(int l,int r) { if(l==r) { siz[++m]=w[l]; a[m][0]=1; a[m][w[l]]=mod-1; for(int i=1;i<w[l];i++) a[m][i]=0; return; } int mid=l+r>>1; solve(l,mid); solve(mid+1,r); int ml=m-1,mr=m,s=siz[ml]+siz[mr]; pre_ntt(s); for(int i=siz[ml]+1;i<len;i++) a[ml][i]=0; for(int i=siz[mr]+1;i<len;i++) a[mr][i]=0; ntt(a[ml],len,1); ntt(a[mr],len,1); for(int i=0;i<len;i++) a[ml][i]=a[ml][i]*a[mr][i]%mod; ntt(a[ml],len,-1); siz[--m]=s; } int main() { n=read(); int sum=0; ll ans=0; for(int i=1;i<=n;i++) sum+=(w[i]=read()); solve(2,n); for(int i=0;i<=sum;i++) (ans+=a[1][i]*w[1]%mod*Pow(w[1]+i,mod-2))%=mod; printf("%d",ans); return 0; }
|