1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
#include <map>
#include <set>
#include <cmath>
#include <queue>
#include <bitset>
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

// Shorthand macros. They are kept verbatim because the rest of the file
// (and possibly code outside this view) is written in terms of them.
// fo/ff/fd evaluate their bound exactly once and iterate with an `int` index:
//   fo: inclusive ascending, ff: exclusive ascending, fd: inclusive descending.
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define ff(i,j,k) for(int i=(j),end_i=(k);i< end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define lowbit(x) ((x)&-(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second

// Prime modulus (2^16 + 1) used for all arithmetic in this file.
const ll mod=65537;

// Returns (x + y) mod `mod`. Both arguments must already lie in [0, mod):
// only a single conditional subtraction is performed.
inline ll Add(ll x,ll y)
{
    x+=y;
    return (x<mod)?x:x-mod;
}

// Returns (x - y) mod `mod`. Both arguments must already lie in [0, mod):
// only a single conditional addition is performed.
inline ll Dec(ll x,ll y)
{
    x-=y;
    return (x<0)?x+mod:x;
}

// Returns (x * y) mod `mod`. Arguments are expected in [0, mod) so that the
// product cannot overflow and the result is non-negative.
inline ll Mul(ll x,ll y)
{
    return x*y%mod;
}

// Returns x^y mod `mod` by binary exponentiation.
//
// The exponent is reduced modulo (mod - 1), which is valid by Fermat's little
// theorem only when x is not a multiple of `mod`. A base congruent to 0 is
// therefore handled before the reduction: previously Pow(0, 65536) reduced
// the exponent to 0 and returned 1 instead of 0.
//
// The base is reduced into [0, mod) first so negative or large bases give a
// canonical result without overflowing x*x. A negative exponent is mapped to
// its non-negative residue (so Pow(x, -1) is the modular inverse of a
// non-zero x); previously a negative exponent never reached 0 under `>>=`
// and the loop did not terminate.
//
// By convention Pow(0, 0) == 1; Pow(0, y) == 0 for every other y (including
// negative y, where the mathematical value is undefined).
inline ll Pow(ll x,ll y)
{
    x%=mod;
    if(x<0) x+=mod;
    if(x==0) return (y==0)?1:0;

    y%=(mod-1);
    if(y<0) y+=mod-1;

    ll ans=1;
    for(;y;y>>=1,x=x*x%mod)
        if(y&1) ans=ans*x%mod;
    return ans;
}

// Program inputs, read in main() and shared by both solver namespaces.
ll n,m;
namespace CM{ const int N=5010; ll f[N],g[N],h[N],b[N],c[N]; inline void mul(ll *a,ll *b,int k) { fo(i,0,k) fo(j,0,k) c[i+j]=Add(c[i+j],a[i]*b[j]%mod); fd(i,k<<1,k) fo(j,0,k) c[i+j-k]=Dec(c[i+j-k],c[i]*g[j]%mod); fo(i,0,k-1) a[i]=c[i],c[i]=0; } inline void Ppow(ll *a,ll n,int m) { a[0]=1; b[1]=1; for(;n;n>>=1,mul(b,b,m)) if(n&1) mul(a,b,m); } inline ll solve() { h[0]=1; fo(i,1,m) h[i]=(h[i-1]*2)%mod; if(n<m) return h[n]; g[m]=1; ff(i,0,m) g[i]=mod-1; Ppow(f,n,m); ll ans=0; fo(i,0,m-1) ans=Add(ans,f[i]*h[i]%mod); return ans; } } namespace Lucas{ ll fac[mod],inv[mod]; inline ll C(ll n,ll m) { if(n<m) return 0; if(n<mod) return fac[n]*inv[m]%mod*inv[n-m]%mod; return C(n%mod,m%mod)*C(n/mod,m/mod)%mod; } inline ll work(ll n) { ll tmp=Pow(Pow(2,m+1),mod-2),now=Pow(2,n),ans=0; fo(i,0,n/(m+1)) { if(i&1) ans-=C(n-m*i,i)*now%mod; else ans+=C(n-m*i,i)*now%mod; now=now*tmp%mod; } return (ans%mod+mod)%mod; } inline ll solve() { fac[0]=1; fo(i,1,mod-1) fac[i]=fac[i-1]*i%mod; inv[mod-1]=Pow(fac[mod-1],mod-2); fd(i,mod-1,1) inv[i-1]=inv[i]*i%mod; return (work(n+1)-work(n)+mod)%mod; } }
int main() { cin>>n>>m; if(m==1) printf("1"); else printf("%lld",(m<=2500)?CM::solve():Lucas::solve()); return 0; }
|