Number of Binominal Coefficients[CF582D]

链接

luogu

题解

一个组合数能否整除 $p^k$(其中 $p$ 是质数),这个可以算出组合数的含 $p$ 的次幂,然后判断是否大于等于 $k$ 即可。

由kummer定理,$\binom{n+m}{m}$ 的含 $p$ 的次幂等于 $n+m$ 在 $p$ 进制意义下的进位次数。

由于 $\binom{n}{k}=\binom{(n-k)+k}{k}$,那么题目转换成,问有多少对 $0\le a,b\leq A$,且 $a+b$ 进位次数大于等于 $k$。

可以注意到,$\alpha \le 10^9$ 是吓人的。当极限数据 $p=2,A=10^{1000}$ 时,算出最小的 $\alpha$,满足 $p^{\alpha}\geq A$,发现不超过 $3400$。

那么就可以进行数位DP了。

设 $f_{i,j,0/1,0/1}$ 表示从高位到低位考虑到第 $i$ 位,此时(包括 $i-1$ 位进上来的)已经有 $j$ 个进位,第 $i-1$ 位是否进上来第 $i$ 位,数位是否一直取最大值的方案数。

转移分四类大力讨论,算一堆东西即可。

具体见程序。

程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <bits/stdc++.h>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define ff(i,j,k) for(int i=(j),end_i=(k);i< end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define lowbit(x) ((x)&-(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second
const int N=3405;
const ll mod=1e9+7;
ll p;
int alpha,x[N],m;
int f[N][N][2][2];
inline ll solve()
{
f[m+1][0][0][1]=1;
ll p1=1ll*(p+1)*p/2%mod,p2=1ll*p*(p-1)/2%mod;
ll g1,g2,g3,g4;
fd(i,m,1)
{
ll x1=1ll*(x[i]+1)*x[i]/2%mod,x2=1ll*x[i]*(x[i]-1)/2%mod;
ll t1=1ll*x[i]*(p*2-x[i]-1)/2%mod,t2=1ll*x[i]*(p*2-x[i]+1)/2%mod;
fo(j,0,m)
{
g1=f[i+1][j][0][0]; g2=f[i+1][j][0][1]; g3=f[i+1][j][1][0]; g4=f[i+1][j][1][1];
f[i][j][0][0]=(p1*g1+x1*g2+p2*g3+t1*g4)%mod;
f[i][j][0][1]=(g2*(x[i]+1)+g4*(p-x[i]-1))%mod;
f[i][j+1][1][0]=(p2*g1+x2*g2+p1*g3+t2*g4)%mod;
f[i][j+1][1][1]=(g2*x[i]+g4*(p-x[i]))%mod;
}
}
ll ans=0;
fo(i,alpha,m) (ans+=f[1][i][0][0]+f[1][i][0][1])%=mod;
return ans;
}

int main()
{
static char s[N];
scanf("%d%d\n%s",&p,&alpha,s+1);
if(alpha>=3400) return printf("0"),0;
static int n=strlen(s+1),a[N];
fo(i,1,n) a[i]=s[n-i+1]-48;
for(;n;)
{
ll tmp=0;
fd(i,n,1)
{
tmp=tmp*10+a[i]; a[i]=tmp/p; tmp%=p;
if(i==n&&!a[i]) n--;
}
x[++m]=tmp;
}
printf("%lld",solve());
return 0;
}