泳池[NOI2017]

题目链接

loj

题解

这种毒瘤DP谁想得到。。。

显然恰好为 $k$ 比较难算,

由于 $k$ 比较小,考虑DP,设 $g_i$ 表示前 $i$ 列中,第 $i$ 列的第一行格子必须不安全时的答案。

那么显然 $g_i$ 只跟前 $k$ 个值有关系。那么有 $g_i=\sum_{j=1}^kg_{i-j}h_j$。

最终的答案显然为 $\frac{g_{n+1}}{1-q}$。

如果算出了 $h_j$ 以及前 $k$ 项,那么就可以用常系数线性齐次递推做了,暴力多项式取模即可做到 $O(k^2\log n)$。

这个 $h_j$ 表示的 $j-1$ 列,$1001$ 行时,最大面积不超过 $k$ 的方案数。

考虑一个我如何都想不到的DP:设 $f_{i,j}$ 表示只有 $j$ 列,不考虑前 $i$ 行(即前 $i$ 行全都是安全的),第 $i+1$ 行开始出现不安全的,且最大面积不超过 $k$ 的方案数。设 $g_{i,j}=\sum_{k\geq i} f_{k,j}$。

那么当 $ij>k$ 时,$f_{i,j}=0$。

否则,枚举第 $i+1$ 行,从左到右第 $1$ 个出现的不安全的点。

那么有:

$$f_{i,j}=\sum_{k=1}^jq^i(1-q)g_{i+1,k-1}g_{i,j-k}$$

这个DP看起来是 $O(k^3)$ 的,但由于 $ij\leq k$,实际上是 $O(k^2\log k)$ 的。

那么 $h_i$ 就很容易用 $f_{1,x}$ 来表示了。

时间复杂度 $O(k^2(\log k+\log n))$。

程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <map>
#include <set>
#include <cmath>
#include <queue>
#include <bitset>
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define ff(i,j,k) for(int i=(j),end_i=(k);i< end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define lowbit(x) ((x)&-(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second
inline int read()
{
int x=0; char ch=getchar(); bool f=0;
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
const ll mod=998244353;
inline ll Add(ll x,ll y){x+=y; return (x<mod)?x:x-mod;}
inline ll Dec(ll x,ll y){x-=y; return (x<0)?x+mod:x;}
inline ll Mul(ll x,ll y){return x*y%mod;}
inline ll Pow(ll x,ll y)
{
y%=(mod-1);ll ans=1;for(;y;y>>=1,x=x*x%mod)if(y&1) ans=ans*x%mod;
return ans;
}
const int N=2005;
int n,m;
ll p,q,f[N][N],qw[N],g[N],h[N];
ll a[N],b[N],c[N];
inline void mul(ll *a,ll *b,int k)
{
fo(i,0,k) fo(j,0,k) c[i+j]=Add(c[i+j],Mul(a[i],b[j]));
fd(i,k<<1,k) fo(j,0,k) c[i+j-k]=Dec(c[i+j-k],Mul(c[i],h[j]));
fo(i,0,k-1) a[i]=c[i],c[i]=0;
}
inline void Ppow(int n,int m)
{
memset(b,0,sizeof(b));
memset(a,0,sizeof(a));
a[1]=1; b[0]=1;
for(;n;n>>=1,mul(a,a,m)) if(n&1) mul(b,a,m);
}
inline ll work(int m)
{
if(m==0) return Pow(p,n);
memset(g,0,sizeof(g));
memset(h,0,sizeof(h));
memset(f,0,sizeof(f));
f[m+1][0]=1;
int t;
fd(i,m,1)
{
f[i][0]=1; t=min(n,m/i);
fo(j,1,t)
{
fo(k,0,j-1) f[i][j]=Add(f[i][j],Mul(f[i+1][k],f[i][j-1-k]));
f[i][j]=Add(f[i+1][j],Mul(f[i][j],qw[i]*p%mod));
}
}
m++;
g[0]=1;
fo(i,1,m-1) fo(j,0,i-1) g[i]=Add(g[i],Mul(g[j],Mul(p,f[1][i-j-1])));
h[m]=1; fo(i,1,m) h[m-i]=Dec(0,Mul(p,f[1][i-1]));
Ppow(n+1,m);
ll ans=0;
fo(i,0,m-1) ans=Add(ans,b[i]*g[i]%mod);
return Mul(ans,Pow(p,mod-2));
}
int main()
{
n=read(); m=read(); q=read(); q=Mul(q,Pow(read(),mod-2));
qw[0]=1; p=Dec(1,q);
fo(i,1,1000) qw[i]=Mul(qw[i-1],q);
printf("%lld",Dec(work(m),work(m-1)));
return 0;
}