多点求值

题目

给定一 $n$ 次多项式 $f(x)$,$m$ 个值 $x_1,x_2,\cdots,x_m$,求出 $f(x_i)$ 的值对 $998244353$ 取模后的结果。

$n,m\le 64000$。

解法

算法一

最经典的多点求值。

设 $P_{l,r}(x)=\prod_{i=l}^r(x-x_i)$。

设 $k=\lfloor \frac{m}{2} \rfloor$,显然有 $\forall i\in [1,k],P_{1,k}(x_i)=0,\forall i\in[k+1,m],P_{m+1,n}(x_i)=0$。

对于 $i\in [1,k]$,考虑多项式 $f(x)$ 表示成 $D(x)P_{1,k}(x)+R(x)$ 的形式(其中 $deg_R<deg_{P_{1,k}}$),那么有 $\forall i\in [1,k],f(x_i)=R(x_i)$。

于是,经过一次多项式取模,将其转换为 $\frac{k}{2}$ 的子问题。

那么先分治,预处理出所有分治结构中的 $P(x)$,然后再来一次分治,每次递归时将当前多项式对 $P_{l,r}$ 取模。

到 $l=r$ 时,$[x^0]P_{l,l}(x)$ 就是答案了。

时间复杂度 $T(n)=2T(\frac{n}{2})+O(n\log n)$,由主定理得,$T(n)=O(n\log ^2n)$。

空间复杂度 $O(n\log n)$。

这样做不仅需要写多项式取模,而且常数较大。

算法二

定义差卷积 $\text{Mul}^T(A,B)i=\sum{j}A_{i+j}\times B_j$。

显然有如下性质:

$\text{Mul}^T(A,B\times C)=\text{Mul}^T(\text{Mul}^T(A,B),C)$。

可以发现, $f(x_0)=\sum_{i=0}^nf_ix_0^i=[x^0]\text{Mul}^T(f,\frac{1}{1-x_0x})$。

于是,考虑算出 $P_{l,r}=\prod_{i=l}^r(1-x_ix)$,先计算 $G=\text{Mul}^T(f,\frac{1}{P_{1,m}})$。

然后分治,递归时维护当前的 $G$,下传到 $[l,mid]$ 时,新的 $G$ 则为 $\text{Mul}^T(G,P_{mid+1,r})$。下传到 $[mid+1,r]$ 同理。

于是就做完了,时间复杂度也为 $O(n\log ^2n)$。

但是这个只需要写到多项式求逆,不需要再写多项式取模,常数少了 $\frac{1}{2}$。

发现计算差卷积时可以运用FFT计算循环卷积,这样常数又减少了约 $\frac{1}{3}$。

最后大概的时间是算法一的 $\frac{1}{3}$。

程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
#include <bits/stdc++.h>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define ff(i,j,k) for(int i=(j),end_i=(k);i< end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define lowbit(x) ((x)&-(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second
inline int read()
{
int x=0; char ch=getchar(); bool f=0;
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
#define CASET fo(___,1,read())

const ll mod=998244353;
inline ll Add(ll x,ll y){x+=y; return (x<mod)?x:x-mod;}
inline ll Dec(ll x,ll y){x-=y; return (x<0)?x+mod:x;}
inline ll Mul(ll x,ll y){return x*y%mod;}
inline ll Pow(ll x,ll y)
{
ll ans=1;for(;y;y>>=1,x=x*x%mod)if(y&1) ans=ans*x%mod;
return ans;
}

const int M=1<<19;
ll W[M];
inline void PolyInit()
{
ll w;
for(int i=1;i<M;i<<=1)
{
W[i]=1; w=Pow(3,(mod-1)/2/i);
fo(j,1,i-1) W[i+j]=W[i+j-1]*w%mod;
}
}
typedef vector<ll> Poly;
inline void print(Poly A)
{
ff(i,0,A.size()) printf("%d ",A[i]);
printf("\n");
}
inline void ntt(ll *a,int n,int t)
{
static int R[M];
fo(i,0,n-1)
{
R[i]=(R[i>>1]>>1)|((i&1)*(n>>1));
if(i<R[i]) swap(a[i],a[R[i]]);
}
ll w;
for(int i=1;i<n;i<<=1)
for(int j=0;j<n;j+=(i<<1))
for(int k=0;k<i;k++)
w=W[i+k]*a[i+j+k]%mod,
a[i+j+k]=Dec(a[j+k],w),
a[j+k]=Add(a[j+k],w);
if(t^1)
{
reverse(a+1,a+n);
w=Pow(n,mod-2);
fo(i,0,n-1) a[i]=w*a[i]%mod;
}
}
inline void ntt(Poly &A,int n,int t){ntt(&A[0],n,t);}
inline Poly operator +(Poly A,Poly B)
{
A.resize(max(A.size(),B.size()));
fo(i,0,B.size()-1) A[i]=Add(A[i],B[i]);
return A;
}
inline Poly operator -(Poly A,Poly B)
{
A.resize(max(A.size(),B.size()));
fo(i,0,B.size()-1) A[i]=Dec(A[i],B[i]);
return A;
}
inline Poly df(Poly A)
{
fo(i,1,A.size()-1) A[i-1]=A[i]*i%mod;
A.resize(A.size()-1);
return A;
}
inline Poly jf(Poly A)
{
A.pb(0);
fd(i,A.size()-1,1) A[i]=A[i-1]*Pow(i,mod-2)%mod;
A[0]=0; return A;
}
inline Poly operator *(Poly A,ll k)
{
fo(i,0,A.size()-1) A[i]=A[i]*k%mod;
return A;
}
inline Poly operator *(Poly A,Poly B)
{
int n=A.size(),m=B.size(),k=n+m-1,len=1;
for(;len<k;len<<=1);
A.resize(len); ntt(A,len,1);
B.resize(len); ntt(B,len,1);
fo(i,0,len-1) A[i]=A[i]*B[i]%mod;
ntt(A,len,-1);
A.resize(k);
return A;
}
inline Poly operator ~(Poly f)
{
int n=f.size();
Poly g,h;
g.pb(Pow(f[0],mod-2));
int m=2;
for(;m<n;m<<=1)
{
h.resize(m<<1); g.resize(m<<1);
fo(i,0,m-1) h[i]=f[i];
ntt(h,m<<1,1); ntt(g,m<<1,1);
fo(i,0,(m<<1)-1) g[i]=Mul(2+mod-Mul(g[i],h[i]),g[i]);
ntt(g,m<<1,-1); g.resize(m);
}
g.resize(m<<1); f.resize(m<<1);
ntt(f,m<<1,1); ntt(g,m<<1,1);
fo(i,0,(m<<1)-1) g[i]=Mul(2+mod-Mul(g[i],f[i]),g[i]);
ntt(g,m<<1,-1); g.resize(n);
return g;
}
const int N=64005;
#define lc (u<<1)
#define rc ((u<<1)|1)
Poly P[N<<2];
ll ans[N],a[N];
inline Poly MulT(Poly A,Poly B)
{
int n=A.size(),m=B.size();
reverse(all(B));
int len=1;
for(;len<n;len<<=1);
A.resize(len); B.resize(len);
ntt(A,len,1); ntt(B,len,1);
ff(i,0,len) A[i]=A[i]*B[i]%mod;
ntt(A,len,-1);
B.clear();
len--;
fo(i,m-1,n+m-2) B.pb(A[i&len]);
return B;
}
void solve(int u,int l,int r)
{
if(l==r)
{
P[u].pb(1); P[u].pb(Dec(0,a[l]));
return;
}
int mid=(l+r)>>1;
solve(lc,l,mid); solve(rc,mid+1,r);
P[u]=P[lc]*P[rc];
}
void solve(int u,int l,int r,Poly A)
{
A.resize(r-l+1);
if(l==r) return (void)(ans[l]=A[0]);
int mid=(l+r)>>1;
solve(lc,l,mid,MulT(A,P[rc])); solve(rc,mid+1,r,MulT(A,P[lc]));
}
int n,m,k;
Poly F,G;
int main()
{
PolyInit();
n=read(); m=read(); k=max(n,m);
fo(i,0,n) F.pb(read());
F.resize(n+k+1);
fo(i,1,m) a[i]=read();
solve(1,1,k);
F=MulT(F,(~P[1]));
solve(1,1,k,F);
fo(i,1,m) printf("%lld\n",ans[i]);
return 0;
}