经典题[20200415模拟]

题意

$m$ 个整数变量,求满足一下条件的正整数解的个数对 $998244353$ 取模后的结果。

  • $\forall i\in [1,n],1\leq x_i\leq T$
  • $\sum_{i=1}^m\leq S$。

$n\leq 10^{18},n\leq 10^9,nT\leq S,m-n\leq 10^5$

题解

考场上只能做到 $m-n=1$ 的情况。

这题太毒瘤了…

由于 $nT\leq S$,那么考虑暴力枚举前 $n$ 个数选了 $x_i$,那么这时,后面 $m-n$ 个数就有 $\binom{S-\sum_{i=1}^nx_i}{m-n}$ 种情况。

这个东西显然没法直接求,考虑这个组合数,我们把它转换成上升幂的形式。

$$ans=\frac{1}{(m-n)!}\sum_{x_1,x_2,\cdots,x_n}(S-(m-n)+1-\sum_{x=1}^mx_i)^{\overline{m-n}}$$

然后由第一类斯特林数的性质得到:

$$ans=\frac{1}{(m-n)!}\sum_{i=0}^{m-n}\left ^{m-n} _{\ \ \ i}\right ^{i}$$

而第一类斯特林数可以 $O(n\log n)$ 求出来。

以下为了方便,设 $k=m-n$,用 $S$ 代表原来的 $S-(m-n)+1$。

那么原式就变为:

$$\frac{1}{k!}\sum_{i=0}^k\left ^{k} _{i} \right ^{i}$$

由多项式定理暴力展开:

$$(S-\sum_{j=1}^nx_j)^i=\sum_{a_0+a_1+\cdots+a_n=i}\frac{i!}{a_0!a_1!\cdots a_n!}S^{a_0}(-x_1)^{a_1}\cdots (-x_n)^{a_n}$$

那么就可以用EGF来算一下。

设 $G(x)=\sum \frac{S^i}{i!}x^i,F(x)=\sum_{i=0}^{\infty}\frac{(-1)^i\sum_{j=1}^{T}j^i}{i!}x^i$

那么 $i$ 次方的答案就是 $[x^i] i!G(x)F^n(x)$。

也就是说,如果搞出了 $F(x)$,然后多项式快速幂就可以了。

剩下的问题是,对于所有的 $k$,计算 $\sum_{i=1}^Ti^k$。

可以用伯努利数来算。

由伯努利数的性质可以得到:

$$\sum_{i=0}^{n-1}i^k=\frac{1}{k+1}\sum_{i=0}^k\binom{k+1}{i}B_in^{k+1-i}$$

伯努利数可以用多项式求逆 $O(n\log n)$ 求出前几项。

上式也是一个卷积形式,很容易算出。

所以你只需要,求出第一类斯特林数的其中一列,多项式快速幂,以及伯努利数前 $n$ 项。

多项式快速幂用ln+exp搞,那么总的时间复杂度就是 $O(n\log n)$。

代码6k…

程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
#include <queue>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define ll long long
#define pb push_back
const ll mod=998244353;
inline ll Add(ll x,ll y){x+=y; return (x<mod)?x:x-mod;}
inline ll Dec(ll x,ll y){x-=y; return (x<0)?x+mod:x;}
inline ll Mul(ll x,ll y){return x*y%mod;}
inline ll Pow(ll x,ll y){y%=(mod-1);ll ans=1;for(;y;y>>=1,x=x*x%mod)if(y&1) ans=ans*x%mod;return ans;}
const int M=1<<20;
ll W[M]; int R[M];
inline void PolyInit()
{
ll w;
for(int i=1;i<M;i<<=1)
{
W[i]=1; w=Pow(3,(mod-1)/2/i);
fo(j,1,i-1) W[i+j]=W[i+j-1]*w%mod;
}
}
typedef vector<ll> Poly;
inline void ntt(ll *a,int n,int t)
{
fo(i,0,n-1)
{
R[i]=(R[i>>1]>>1)|((i&1)*(n>>1));
if(i<R[i]) swap(a[i],a[R[i]]);
}
ll w;
for(int i=1;i<n;i<<=1)
for(int j=0;j<n;j+=(i<<1))
for(int k=0;k<i;k++)
w=W[i+k]*a[i+j+k]%mod,
a[i+j+k]=Dec(a[j+k],w),
a[j+k]=Add(a[j+k],w);
if(t==1) return;
reverse(a+1,a+n);
w=Pow(n,mod-2);
fo(i,0,n-1) a[i]=w*a[i]%mod;
}
inline void ntt(Poly &A,int n,int t){ntt(&A[0],n,t);}
inline Poly operator +(Poly A,Poly B)
{
A.resize(max(A.size(),B.size()));
fo(i,0,B.size()-1) A[i]=Add(A[i],B[i]);
return A;
}
inline Poly operator -(Poly A,Poly B)
{
A.resize(max(A.size(),B.size()));
fo(i,0,B.size()-1) A[i]=Dec(A[i],B[i]);
return A;
}
inline Poly df(Poly A)
{
fo(i,1,A.size()-1) A[i-1]=A[i]*i%mod;
A.resize(A.size()-1);
return A;
}
inline Poly jf(Poly A)
{
A.pb(0);
fd(i,A.size()-1,1) A[i]=A[i-1]*Pow(i,mod-2)%mod;
A[0]=0; return A;
}
inline Poly operator *(Poly A,ll k)
{
fo(i,0,A.size()-1) A[i]=Mul(A[i],k);
return A;
}
inline Poly operator *(Poly A,Poly B)
{
int n=A.size(),m=B.size(),k=n+m-1,len=1;
for(;len<k;len<<=1);
A.resize(len); ntt(A,len,1);
B.resize(len); ntt(B,len,1);
fo(i,0,len-1) A[i]=A[i]*B[i]%mod;
ntt(A,len,-1);
A.resize(k);
return A;
}
inline Poly operator ~(Poly f)
{
int n=f.size();
Poly g,h;
g.pb(Pow(f[0],mod-2));
int m=2;
for(;m<n;m<<=1)
{
h.resize(m<<1); g.resize(m<<1);
fo(i,0,m-1) h[i]=f[i];
ntt(h,m<<1,1); ntt(g,m<<1,1);
fo(i,0,(m<<1)-1) g[i]=Mul(2+mod-Mul(g[i],h[i]),g[i]);
ntt(g,m<<1,-1); g.resize(m);
}
g.resize(m<<1); f.resize(m<<1);
ntt(f,m<<1,1); ntt(g,m<<1,1);
fo(i,0,(m<<1)-1) g[i]=Mul(2+mod-Mul(g[i],f[i]),g[i]);
ntt(g,m<<1,-1); g.resize(n);
return g;
}
inline Poly ln(Poly A)
{
int n=A.size();
A=jf((~A)*df(A));
A.resize(n); return A;
}
inline Poly exp(Poly A)
{
int n=1; for(;n<A.size();n<<=1);
Poly B,C,D; B.clear(); B.pb(1);
for(int m=2;m<=n;m<<=1)
{
C=B; C.resize(m); D=A; D.resize(m);
C=D-ln(C); C[0]=Add(C[0],1);
B=B*C; B.resize(m);
}
B.resize(A.size()); return B;
}
inline Poly operator ^(Poly A,ll k)
{
if(!A.size()) return A;
ll tmp=A[0],w=Pow(tmp,k);
tmp=Pow(tmp,mod-2);
fo(i,0,A.size()-1) A[i]=Mul(A[i],tmp);
A=exp(ln(A)*k);
fo(i,0,A.size()-1) A[i]=Mul(A[i],w);
return A;
}

const int N=1e5+5;
ll fc[N],fv[N],iv[N];
inline void init(int n)
{
PolyInit();
fc[0]=1;
fo(i,1,n) fc[i]=fc[i-1]*i%mod;
fv[n]=Pow(fc[n],mod-2);
fd(i,n,1) fv[i-1]=fv[i]*i%mod;
iv[1]=1;
fo(i,2,n) iv[i]=(mod-mod/i)*iv[mod%i]%mod;
}
ll S,T,n,m,s[N];
Poly G,F,B;

ll d[M],c[M];
inline void calcS(int n)
{
if(!n) {s[0]=1; return;}
if(n==1) {s[1]=1; return;}
if(n&1)
{
calcS(n-1);
fd(i,n,1) s[i]=Add(s[i-1],Mul(s[i],n-1));
return;
}
else
{
calcS(n>>1); int l=n>>1,len;
for(len=1;len<=n;len<<=1);
d[0]=1; fo(i,1,l) d[i]=d[i-1]*l%mod;
fo(i,0,l) d[i]=d[i]*fv[i]%mod,c[i]=s[i]*fc[i]%mod;
reverse(&d[0],&d[l+1]);
ntt(d,len,1); ntt(c,len,1);
fo(i,0,len-1) d[i]=d[i]*c[i]%mod;
ntt(d,len,-1);
fo(i,0,l) d[i]=d[i+l]*fv[i]%mod;
fo(i,l+1,len-1) d[i]=c[i]=0;
fo(i,0,l) c[i]=s[i];
ntt(d,len,1); ntt(c,len,1);
fo(i,0,len-1) d[i]=d[i]*c[i]%mod;
ntt(d,len,-1);
fo(i,0,n) s[i]=d[i];
fo(i,0,len-1) d[i]=c[i]=0;
}
}
inline void calcB(int n)
{
n+=2;
B.resize(n+1);
fo(i,0,n-1) B[i]=fv[i+1];
B=~B;
fo(i,0,n) B[i]=B[i]*fc[i]%mod;
B.resize(n);
}
inline void calcF(int n)
{
T=(T+1)%mod;
ll tmp=T;
fo(i,1,n+1) d[i]=Mul(tmp,fv[i]),tmp=Mul(tmp,T);
fo(i,0,n) c[i]=Mul(B[i],fv[i]);
int len=1; for(;len<=n+n+1;len<<=1);
ntt(d,len,1); ntt(c,len,1);
fo(i,0,len-1) d[i]=d[i]*c[i]%mod;
ntt(d,len,-1);
F.resize(n+1);
fo(j,0,n) F[j]=Mul(d[j+1],fc[j]);
F[0]=Dec(F[0],1);
fo(j,0,n) F[j]=Mul(F[j],(j&1)?(mod-fv[j]):fv[j]);
}
inline ll solve()
{
if(S<m) return 0;
S=(S-(m-n)+1)%mod;
n=m-n;
calcS(n); calcB(n+1);
G.resize(n+1);
ll tmp=1,ans=0;
fo(i,0,n) G[i]=Mul(tmp,fv[i]),tmp=Mul(tmp,S);
calcF(n);
F=(F^(m-n))*G;
F.resize(n+1);
fo(i,0,n) ans=Add(ans,Mul(F[i],Mul(fc[i],s[i])));
return Mul(ans,fv[n]);
}
int main()
{
FO(count);
cin>>S>>T>>n>>m;
init(m-n+3);
printf("%lld",solve());
return 0;
}