题意
链接
给三年整数 $x,y,z$,然后有 $n$ 个数组,第 $i$ 个数组有 $x$ 个 $a_i$,$y$ 个 $b_i$,$z$ 个 $c_i$。
一种方案为:从 $n$ 个数组中各选择 $1$ 个数。
对于每个 $t\in[0,2^k)$,求出有多少种方案,使得该方案中的数的异或和为 $t$,对 $998244353$ 取模。
$1\leq n \leq 10^5,1\leq k \leq 17$
题解
一个显然的做法是每个数组对应一个生成函数,都做一次 $\mbox{FWT}$。时间复杂度 $O(nk2^k)$,显然不行。
但这个生成函数只有三个位置有数,比较特殊。
考虑将 ${ a_i,b_i,c_i}$ 变为 ${0,a_i\bigoplus b_i,a_i\bigoplus c_i}$,最后异或上 $\bigoplus_{i=1}^n a_i$ 就是答案。
那么 $\mbox{FWT}$ 以后就只剩下四种答案:$x+y+z$,$x+y-z$,$x-y+z$,$x-y-z$。
最后乘起来的结果就是 $(x+y+z)^{a_1}(x+y-z)^{a_2}(x-y+z)^{a_3}(x-y-z)^{a_4}$
求出了 $a_1,a_2,a_3,a_4$ 就可以了。
首先有 $a_1+a_2+a_3+a_4=n$
我们发现当 $x=0,y=1,z=0$ 的时候答案只跟 $y$ 有关,由于$FWT(A+B)=FWT(A)+FWT(B)$,那么对于每个 $i$ ,将 $f[a_i\bigoplus b_i]$ 加 $1$ 后,$FWT$ 一下,求出的 $f_i$ 就有: $a_1+a_2-a_3-a_4=f_i$。
同理对 $z$ 进行同样的处理,有 $a_1-a_2+a_3-a_4=f_i$。
还差一个方程就可以解出来了。
可以发现,我们上面处理的是 $a$^$b$ 和 $a$^$c$ 的,还有一个 $b$^$c$ 没有处理。
考虑变为 ${a_i\bigoplus b_i,0,c_i\bigoplus b_i}$,这时候会有四种答案:$x+y+z,x+y-z,-x+y+z,-x+y-z$
将 $f[b_i\bigoplus c_i]$ 加 $1$ 后 $FWT$,就相当于 $a_1-a_2+a_4-a_3=f_i$
于是便可以解出 $a_1,a_2,a_3,a_4$。
然后快速幂,再 $UFWT$ 一下就好了。
时间复杂度 $O((n+2^k)k)$
程序
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
| #include <cstring> #include <iostream> #include <cstring> #include <algorithm> #include <cmath> using namespace std; #define ll long long const int N=1<<20; const ll mod=998244353; const ll inv2=(mod+1)/2; const ll inv4=inv2*inv2%mod; inline int read() { int x=0; char ch=getchar(); bool f=0; for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1; for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48); return f?-x:x; } inline ll Pow(ll x,int y) { ll ans=1; for(;y;y>>=1,x=x*x%mod) if(y&1) ans=ans*x%mod; return ans; } inline void fwt(int n,ll *a,int t) { ll x,y; for(int i=1;i<n;i<<=1) for(int j=0;j<n;j+=(i<<1)) for(int k=0;k<i;k++) { x=a[j+k],y=a[i+j+k]; a[j+k]=(x+y)%mod,a[i+j+k]=(x-y+mod)%mod; if(t!=1) (a[j+k]*=inv2)%=mod,(a[i+j+k]*=inv2)%=mod; } } int n,k,m,a,b,c,s; ll x,y,z,d1,d2,d3,d4,f1[N],f2[N],f3[N],g[N]; int main() { n=read(); k=read(); m=1<<k; x=read(),y=read(),z=read(); for(int i=1;i<=n;i++) { a=read(),b=read(),c=read(); f1[a^b]++,f2[a^c]++,f3[b^c]++; s^=a; } fwt(m,f1,1); fwt(m,f2,1); fwt(m,f3,1); d1=(x+y+z)%mod; d2=(mod+x+y-z)%mod; d3=(x-y+z+mod)%mod; d4=(x-y-z+mod+mod)%mod; for(int i=0;i<m;i++) g[i]= Pow(d1,((ll)n+f1[i]+f2[i]+f3[i])%mod*inv4%mod)* Pow(d2,((ll)n+f1[i]-f2[i]-f3[i]+mod*2)%mod*inv4%mod)%mod* Pow(d3,((ll)n-f1[i]+f2[i]-f3[i]+mod*2)%mod*inv4%mod)%mod* Pow(d4,((ll)n-f1[i]-f2[i]+f3[i]+mod*2)%mod*inv4%mod)%mod; fwt(m,g,-1); for(int i=0;i<m;i++) printf("%lld ",g[i^s]); return 0; }
|