1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
// Reads two integers n and k from stdin and prints, modulo 998244353, the sum
// over every split k = r + b with r >= b of:
//   * C(r + b, r)                              when r >= b + n;
//   * C(r + b', r) - C(r + b', 2r - n + 1)     otherwise, where b' = b - 1 if
//                                              r == b and b' = b if not.
// Prints "0" immediately when n > k.
// NOTE(review): this looks like a contest solution whose subtracted term has the
// shape of a reflection-principle count of lattice paths; the problem statement
// is not part of this file — confirm the combinatorial model against it.
#include <cstdio>
#include <iostream>
#include <vector>

using ll = long long;

// Prime modulus; being prime is what makes the Fermat inverse in init() valid.
const ll mod = 998244353;

// (x + y) mod p for x, y already reduced into [0, mod).
inline ll Add(ll x, ll y) {
    x += y;
    return x < mod ? x : x - mod;
}

// (x - y) mod p for x, y already reduced into [0, mod).
inline ll Dec(ll x, ll y) {
    x -= y;
    return x < 0 ? x + mod : x;
}

// x^y mod p by binary exponentiation; y must be non-negative.
inline ll Pow(ll x, ll y) {
    ll result = 1;
    x %= mod;
    for (; y > 0; y >>= 1, x = x * x % mod)
        if (y & 1) result = result * x % mod;
    return result;
}

// Factorials and inverse factorials mod p for indices 0..k, filled by init().
// Sized at run time so there is no compile-time cap on k (the previous fixed
// 5e5+5 arrays were overrun for larger k).
std::vector<ll> fac, inv;

// Binomial coefficient C(n, m) mod p; 0 when n < 0, m < 0 or m > n.
// The range check runs before any indexing, so out-of-range m is safe as long
// as init() covered index n.
inline ll C(int n, int m) {
    if (n < 0 || m < 0 || m > n) return 0;
    return fac[n] * inv[m] % mod * inv[n - m] % mod;
}

// Precomputes fac[0..n] and inv[0..n]. Does nothing for n < 0 (empty tables).
void init(int n) {
    if (n < 0) {
        fac.clear();
        inv.clear();
        return;
    }
    fac.assign(n + 1, 1);
    inv.assign(n + 1, 1);
    for (int i = 1; i <= n; ++i) fac[i] = fac[i - 1] * i % mod;
    // Fermat inverse of n!; non-zero mod p because n < mod.
    inv[n] = Pow(fac[n], mod - 2);
    // 1/(i-1)! = i * (1/i!), walking downward from n.
    for (int i = n; i >= 1; --i) inv[i - 1] = inv[i] * i % mod;
}

int main() {
    int n = 0, k = 0;  // zero-initialised, matching the old globals on a failed read
    std::cin >> n >> k;
    if (n > k) {
        std::puts("0");
        return 0;
    }
    init(k);

    ll ans = 0;
    for (int r = 0; r <= k; ++r) {
        int b = k - r;
        if (r < b) continue;  // splits with r < b contribute nothing
        if (r >= b + n) {     // no subtraction needed: add C(k, r) outright
            ans = Add(ans, C(r + b, r));
            continue;
        }
        // Tie case: drop one from b before counting.
        // NOTE(review): presumably the final item is forced in this case — confirm.
        if (r == b) --b;
        ans = Add(ans, Dec(C(r + b, r), C(r + b, 2 * r - n + 1)));
    }
    std::printf("%lld", ans);
    return 0;
}
|