泳池[NOI2017]

题目链接

loj

题解

这种毒瘤DP谁想得到。。。

显然恰好为 k 比较难算,

由于 k 比较小,考虑DP,设 gi 表示前 i 列中,第 i 列的第一行格子必须不安全时的答案。

那么显然 gi 只跟前 k 个值有关系。那么有 gi=j=1kgijhj

最终的答案显然为 gn+11q

如果算出了 hj 以及前 k 项,那么就可以用常系数线性齐次递推做了,暴力多项式取模即可做到 O(k2logn)

这个 hj 表示的 j1 列,1001 行时,最大面积不超过 k 的方案数。

考虑一个我如何都想不到的DP:设 fi,j 表示只有 j 列,不考虑前 i 行(即前 i 行全都是安全的),第 i+1 行开始出现不安全的,且最大面积不超过 k 的方案数。设 gi,j=kifk,j

那么当 ij>k 时,fi,j=0

否则,枚举第 i+1 行,从左到右第 1 个出现的不安全的点。

那么有:

fi,j=k=1jqi(1q)gi+1,k1gi,jk

这个DP看起来是 O(k3) 的,但由于 ijk,实际上是 O(k2logk) 的。

那么 hi 就很容易用 f1,x 来表示了。

时间复杂度 O(k2(logk+logn))

程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <map>
#include <set>
#include <cmath>
#include <queue>
#include <bitset>
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define ff(i,j,k) for(int i=(j),end_i=(k);i< end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define lowbit(x) ((x)&-(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second
inline int read()
{
int x=0; char ch=getchar(); bool f=0;
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
const ll mod=998244353;
inline ll Add(ll x,ll y){x+=y; return (x<mod)?x:x-mod;}
inline ll Dec(ll x,ll y){x-=y; return (x<0)?x+mod:x;}
inline ll Mul(ll x,ll y){return x*y%mod;}
inline ll Pow(ll x,ll y)
{
y%=(mod-1);ll ans=1;for(;y;y>>=1,x=x*x%mod)if(y&1) ans=ans*x%mod;
return ans;
}
const int N=2005;
int n,m;
ll p,q,f[N][N],qw[N],g[N],h[N];
ll a[N],b[N],c[N];
inline void mul(ll *a,ll *b,int k)
{
fo(i,0,k) fo(j,0,k) c[i+j]=Add(c[i+j],Mul(a[i],b[j]));
fd(i,k<<1,k) fo(j,0,k) c[i+j-k]=Dec(c[i+j-k],Mul(c[i],h[j]));
fo(i,0,k-1) a[i]=c[i],c[i]=0;
}
inline void Ppow(int n,int m)
{
memset(b,0,sizeof(b));
memset(a,0,sizeof(a));
a[1]=1; b[0]=1;
for(;n;n>>=1,mul(a,a,m)) if(n&1) mul(b,a,m);
}
inline ll work(int m)
{
if(m==0) return Pow(p,n);
memset(g,0,sizeof(g));
memset(h,0,sizeof(h));
memset(f,0,sizeof(f));
f[m+1][0]=1;
int t;
fd(i,m,1)
{
f[i][0]=1; t=min(n,m/i);
fo(j,1,t)
{
fo(k,0,j-1) f[i][j]=Add(f[i][j],Mul(f[i+1][k],f[i][j-1-k]));
f[i][j]=Add(f[i+1][j],Mul(f[i][j],qw[i]*p%mod));
}
}
m++;
g[0]=1;
fo(i,1,m-1) fo(j,0,i-1) g[i]=Add(g[i],Mul(g[j],Mul(p,f[1][i-j-1])));
h[m]=1; fo(i,1,m) h[m-i]=Dec(0,Mul(p,f[1][i-1]));
Ppow(n+1,m);
ll ans=0;
fo(i,0,m-1) ans=Add(ans,b[i]*g[i]%mod);
return Mul(ans,Pow(p,mod-2));
}
int main()
{
n=read(); m=read(); q=read(); q=Mul(q,Pow(read(),mod-2));
qw[0]=1; p=Dec(1,q);
fo(i,1,1000) qw[i]=Mul(qw[i-1],q);
printf("%lld",Dec(work(m),work(m-1)));
return 0;
}