NIM游戏[牛客挑战赛36G]

题目链接

链接

题面描述

$n$ 个栈,第 $i$ 个栈有 $a_i$ 个石子,一共有 $b_i$ 中排列方式。选择一个栈后可任意决定其排列方式。

玩$\mbox{NIM}$游戏,每个人每次在某个栈中取走若干石子,不能取的人输。

求有多少种选栈的方案使得先手必败。

两个方案不同当且仅当存在一个栈仅出现在其中一个方案中或在两个方案中栈中物品的排列方式不同。

$n\leq 10^7,a_i\leq 10^5$

时限3s。

题解

设 $m=\max{a_i}$

要先手必败,那么异或和为0。

显而易见的,答案就是所有的 $(b_ix^{a_i}+1)$ 的异或卷积的常数项。

暴力$\mbox{FWT}$,复杂度 $O(m^2\log m)$。

但是我们发现每个函数都只有两项有数。

分析下$\mbox{FWT}$的性质发现自己想不到…

如果是算卷积,可以分治$\mbox{FFT}$。

那不妨试试分治$\mbox{FWT}$ ?

我们同时处理次数连续的区间:

分治的过程中,二进制表示下,前面若干位是相同的。

所以,当 $x$ 的次数的范围是 $[l,l+2^k)$ 的时候,做完$\mbox{FWT}$以后,$x$ 的次数在 $[0,2^k)\bigcup[l,l+2^k)$ 这个范围内有数。

那么就设 $fa$ 为次数是 $[0,2^k)$ 的答案,$fb$ 为 $[l,l+2^k)$ 的答案。

分治时,若一个分支是 $[l,l+2^k)$,则求的是 $[l+2^{k-1})$ 和 $[l+2^{k-1},l+2^k)$ 的答案。

分四类讨论一下就可以得到新的 $fa,fb$ 的值。

每次分治,$\mbox{FWT}$的次数减少一半。

时间复杂度 $T(m)=2T(m/2)+O(m\log m)$,即 $O(m\log ^2 m)$

程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include <cstdio>
using namespace std;
#define ll long long
#define fo(i,j,k) for(int i=(j);i<(k);i++)
const int N=1<<17;
const ll mod=998244353;
const ll inv2=(mod+1)/2;
inline int read()
{
int x=0; char ch=getchar(); bool f=0;
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
ll x,y;
inline void fwt(int n,ll *a,int t)
{
for(int i=1;i<n;i<<=1)
for(int j=0,p=(i<<1);j<n;j+=p)
for(int k=0,l=j+k,r=l+i;k<i;k++,l++,r++)
{
x=a[l],y=a[r]; a[l]=x+y,a[r]=x-y;
a[l]>=mod?(a[l]-=mod):0;
a[r]<0?(a[r]+=mod):0;
if(t!=1) (a[l]*=inv2)%=mod,(a[r]*=inv2)%=mod;
}
}

ll fa[N],fb[N],la[N],lb[N],ra[N],rb[N],g1[N],g2[N],g3[N],g4[N];
ll a[N],b[N];
int n;
void solve(int l,int m)
{
if(!m) {fa[l]=a[l]; fb[l]=b[l]; return;}
int r=l+m;
solve(l,m>>1); solve(r,m>>1);
for(int i=0,j=i+l,k=i+r;i<m;i++,j++,k++) la[i]=fa[j],ra[i]=fb[j],lb[i]=fa[k],rb[i]=fb[k];
fwt(m,la,1); fwt(m,ra,1); fwt(m,lb,1); fwt(m,rb,1);
fo(i,0,m) g1[i]=la[i]*lb[i]%mod,g2[i]=la[i]*rb[i]%mod,g3[i]=ra[i]*lb[i]%mod,g4[i]=ra[i]*rb[i]%mod;
fwt(m,g1,-1); fwt(m,g2,-1); fwt(m,g3,-1); fwt(m,g4,-1);
for(int i=0,j=i+l,k=i+r;i<m;i++,j++,k++) fa[j]=g1[i],fb[j]=g3[i],fb[k]=g2[i],fa[k]=g4[i];
}

int main()
{
n=read();
fo(i,0,N) a[i]=1;
int q,m,x,y;
fo(i,0,n)
{
q=read(); m=read();
x=(a[q]+b[q]*m)%mod;
y=(a[q]*m+b[q])%mod;
a[q]=x; b[q]=y;
}
solve(0,N/2);
printf("%lld",(fa[0]+fb[0]+mod-1)%mod);
return 0;
}