常系数线性齐次递推

引入

给定一递推式:$f_n=\sum_{i=1}^{k}a_if_{n-i}$,其中 $a_1,a_2,\cdots,a_k$ 为常实数,并给定 $f_0,f_1,\cdots, f_{k-1}$。

求 $f_n$。

方法

矩阵乘法

设 :

$$V_i=\begin{bmatrix}
f_i\
f_{i+1}\
\vdots\
f_{i+k-1}
\end{bmatrix},M=\begin{bmatrix}
0 & 1 & 0 & \cdots \
0 & 0 & \ddots & \vdots \
0 & 0 & \cdots & 1 \
a_k & a_{k-1} & \cdots & a_1
\end{bmatrix}$$

,则有:$V_{i+1}=M\times V_i$。

若要计算 $f_n$,可计算 $V_n=M^n\times V_0$。

利用矩阵乘法,算出 $M^n$,时间复杂度 $O(k^3\log n)$。

利用特征方程与多项式

算法过程

我们现在是要求 $M^n$。

设 $f(x)=x^n$。

若能找到一个多项式 $g$,使得 $g(M)=0$,那么就可以将 $f(x)$ 表示成 $A(x)g(x)+B(x)$ 了,其中 $B(x)$ 的次数小于 $g(x)$ 的次数。然后求 $f(M)$ 就相当于求 $B(M)$。

$g(x)$ 可取矩阵 $M$ 的特征方程 $\Gamma (x)=\text{det}(Ix-M)$,其中 $I$ 为单位矩阵,$\text{det}(A)$ 表示 $A$ 的行列式。

因为根据Cayley-Hamilton定理, $\Gamma (M)=0$。

而 $M$ 矩阵十分特殊,根据归纳等方法可证明:$\Gamma (x)=\text{det}(Ix-M)=x^k-\sum_{i=1}^ka_ix_{k-i}$。

显然有 $x^nV_0=f(x)V_0$,当 $x=M$ 时,

$x^nV_0=f(x)V_0=B(x)V_0=\sum_{i=0}^{k-1}x^ib_iV_0$。

将 $x$ 换成 $M$:

$V_n=M^nV_0=\sum_{i=0}^{k-1}b_iM^iV_0=\sum_{i=0}^{k-1}b_iV_i$。

取出 $V$ 中的第一行,有 :

$f_n=\sum_{i=0}^{k-1}b_if_i$。

因此,只需求出所有的 $b_i$ 即可,也就是求出 $B(x)$。

由上面的推导,$B(x)=f(x)\bmod \Gamma (x)$。

由于 $f(x)=x^n$,考虑倍增,然后边乘边对 $g$ 取模。

暴力多项式乘法及取模,时间复杂度 $O(k^2\log n)$。

用FFT优化,时间复杂度 $O(k\log k\log n)$。

程序

多项式取模:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
#include <bits/stdc++.h>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define ff(i,j,k) for(int i=(j),end_i=(k);i< end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define lowbit(x) ((x)&-(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second
inline int read()
{
int x=0; char ch=getchar(); bool f=0;
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
#define CASET fo(___,1,read())

const ll mod=998244353;
inline ll Add(ll x,ll y){x+=y; return (x<mod)?x:x-mod;}
inline ll Dec(ll x,ll y){x-=y; return (x<0)?x+mod:x;}
inline ll Mul(ll x,ll y){return x*y%mod;}
inline ll Pow(ll x,ll y)
{
ll ans=1;for(;y;y>>=1,x=x*x%mod)if(y&1) ans=ans*x%mod;
return ans;
}

const int M=1<<17;
ll W[M];
inline void PolyInit()
{
ll w;
for(int i=1;i<M;i<<=1)
{
W[i]=1; w=Pow(3,(mod-1)/2/i);
fo(j,1,i-1) W[i+j]=W[i+j-1]*w%mod;
}
}
typedef vector<ll> Poly;
inline void print(Poly A)
{
ff(i,0,A.size()) printf("%d ",A[i]);
printf("\n");
}
inline void ntt(ll *a,int n,int t)
{
static int R[M];
fo(i,0,n-1)
{
R[i]=(R[i>>1]>>1)|((i&1)*(n>>1));
if(i<R[i]) swap(a[i],a[R[i]]);
}
ll w;
for(int i=1;i<n;i<<=1)
for(int j=0;j<n;j+=(i<<1))
for(int k=0;k<i;k++)
w=W[i+k]*a[i+j+k]%mod,
a[i+j+k]=Dec(a[j+k],w),
a[j+k]=Add(a[j+k],w);
if(t^1)
{
reverse(a+1,a+n);
w=Pow(n,mod-2);
fo(i,0,n-1) a[i]=w*a[i]%mod;
}
}
inline void ntt(Poly &A,int n,int t){ntt(&A[0],n,t);}
inline Poly operator +(Poly A,Poly B)
{
A.resize(max(A.size(),B.size()));
fo(i,0,B.size()-1) A[i]=Add(A[i],B[i]);
return A;
}
inline Poly operator -(Poly A,Poly B)
{
A.resize(max(A.size(),B.size()));
fo(i,0,B.size()-1) A[i]=Dec(A[i],B[i]);
return A;
}
inline Poly df(Poly A)
{
fo(i,1,A.size()-1) A[i-1]=A[i]*i%mod;
A.resize(A.size()-1);
return A;
}
inline Poly jf(Poly A)
{
A.pb(0);
fd(i,A.size()-1,1) A[i]=A[i-1]*Pow(i,mod-2)%mod;
A[0]=0; return A;
}
inline Poly operator *(Poly A,ll k)
{
fo(i,0,A.size()-1) A[i]=A[i]*k%mod;
return A;
}
inline Poly operator *(Poly A,Poly B)
{
int n=A.size(),m=B.size(),k=n+m-1,len=1;
for(;len<k;len<<=1);
A.resize(len); ntt(A,len,1);
B.resize(len); ntt(B,len,1);
fo(i,0,len-1) A[i]=A[i]*B[i]%mod;
ntt(A,len,-1);
A.resize(k);
return A;
}
inline Poly operator ~(Poly f)
{
int n=f.size();
Poly g,h;
g.pb(Pow(f[0],mod-2));
int m=2;
for(;m<n;m<<=1)
{
h.resize(m<<1); g.resize(m<<1);
fo(i,0,m-1) h[i]=f[i];
ntt(h,m<<1,1); ntt(g,m<<1,1);
fo(i,0,(m<<1)-1) g[i]=Mul(2+mod-Mul(g[i],h[i]),g[i]);
ntt(g,m<<1,-1); g.resize(m);
}
g.resize(m<<1); f.resize(m<<1);
ntt(f,m<<1,1); ntt(g,m<<1,1);
fo(i,0,(m<<1)-1) g[i]=Mul(2+mod-Mul(g[i],f[i]),g[i]);
ntt(g,m<<1,-1); g.resize(n);
return g;
}
inline Poly Ln(Poly A)
{
int n=A.size();
A=jf((~A)*df(A));
A.resize(n); return A;
}
inline Poly Exp(const Poly &A)
{
int n=1; for(;n<A.size();n<<=1);
Poly B,C,D; B.clear(); B.pb(1);
for(int m=2;m<=n;m<<=1)
{
C=B; C.resize(m); D=A; D.resize(m);
C=D-Ln(C); C[0]=Add(C[0],1);
B=B*C; B.resize(m);
}
B.resize(A.size()); return B;
}
inline Poly operator ^(Poly A,ll k)
{
if(!A.size()) return A;
ll tmp=A[0],w=Pow(tmp,k);
tmp=Pow(tmp,mod-2);
fo(i,0,A.size()-1) A[i]=A[i]*tmp%mod;
A=Exp(Ln(A)*k);
fo(i,0,A.size()-1) A[i]=A[i]*w%mod;
return A;
}
inline Poly Cos(Poly A)
{
static ll w4=Pow(3,(mod-1)/4);
return (Exp(A*w4)+Exp(A*(mod-w4)))*((mod+1)/2);
}
inline Poly Sin(Poly A)
{
static ll w4=Pow(3,(mod-1)/4);
return (Exp(A*w4)-Exp(A*(mod-w4)))*(Pow(w4,mod-2)*((mod+1)/2)%mod);
}
inline Poly Sqrt(Poly A)
{
Poly C,D,B(1,1);
C.clear(); D.clear();
int n=A.size();
for(int l=4;(l>>2)<n;l<<=1)
{
C=A; C.resize(l>>1);
D=B; D.resize(l>>1); D=(~D);
C.resize(l); D.resize(l);
ntt(C,l,1); ntt(D,l,1);
ff(i,0,l) C[i]=C[i]*D[i]%mod;
ntt(C,l,-1);
B.resize(l>>1);
ff(i,0,(l>>1)) B[i]=Add(C[i],B[i])*((mod+1)>>1)%mod;
}
B.resize(n);
return B;
}
inline Poly operator/(Poly A,Poly B)
{
int len=1,deg=A.size()-B.size()+1;
reverse(all(A)); reverse(all(B));
for(;len<=deg;len<<=1);
B.resize(len); B=~B; B.resize(deg);
A=A*B; A.resize(deg);
reverse(all(A));
return A;
}
inline Poly operator%(const Poly &A,const Poly &B)
{
if(A.size()<B.size()) return A;
Poly C=A-(A/B)*B;
C.resize(B.size()-1);
return C;
}

inline Poly Pow(Poly A,ll n,Poly M)
{
Poly B=A; n--;
for(;n;n>>=1,A=(A*A)%M) if(n&1ll) B=(B*A)%M;
return B;
}

ll a[M],f[M];
Poly A,G;
int n,k;
int main()
{
PolyInit();
n=read(); k=read();
fo(i,1,k) a[i]=read();
fo(i,1,k) G.pb((mod+mod-a[k-i+1])%mod);
G.pb(1);
A.pb(0); A.pb(1);
A=Pow(A,n,G);
ll ans=0,x;
fo(i,0,k-1)
{
f[i]=(mod+mod+read())%mod;
ans+=f[i]*A[i]%mod;
}
printf("%lld",ans%mod);
return 0;
}

另一种方法

还有一种时间复杂度也是 $O(k\log k \log n)$ 的做法,不过常数很小。详见