密涅瓦的谜题[牛客挑战赛31E]

题目链接

链接

题意

一个长度为 $n$ 的串,$q$ 次询问,给定一个 $m$:每次在 $n$ 的所有子串中(包括空串)选出一个,然后按顺序拼接起来形成一个大字符串。一共进行 $m$ 次,求形成了多少个不同的大字符串。对 $10^9+7$ 取模。

$n,q\leq 10^5,m\leq 10^{10}$。

题解

为了不重复计数,对于一个字符串,我们要让它匹配到尽量远。也就是说,若有 $T=s_1s_2\cdots s_m$,那么对于某个 $s_is_{i+1}$ 这个片段来说,我们要让 $s_i$ 伸展的尽量远。

那么接下来就可以DP了。设字符集为 $\sigma$。

用 $f_{i,j}$ 表示考虑到第 $i$ 个串时,最后的字母为 $j$ 的方案数是多少。为了方便,当 $j=\sigma$ 时表示最后的字母为空的答案。

得出DP方程:$f_{i,j}=\sum f_{i-1,k}\times A_{k,j}$,其中 $A_{k,j}$ 为从字母 $k$ 转移到字母 $j$ 有多少种情况。

初始状态为 $f_{0,\sigma}=1$,答案为 $\sum f_{n,i}$。

这个 $A_{k,j}$ 可以在SAM上用 $O(n\sigma)$ 的时间用一个简单的DP算出来。

那么就得到了一个 $O(nm\sigma)$ 的做法。

这个转移显然是个矩阵乘法的形式。可以用快速幂做到 $O(n\sigma+m\sigma^3)$。仍然过不了。

复杂度瓶颈主要出现在询问上,考虑优化这部分。

显然答案可以表示为:

$Ans=\begin{bmatrix}0 & \cdots & 0 & 1\end{bmatrix}\times A^n\times \begin{bmatrix}1\ 1\ \vdots\ 1\end{bmatrix}$

分块,预处理出当 $n=0,1,2\cdots\sqrt m$ 时前半部分的答案,以及当 $n=0,\sqrt m,2\sqrt m\cdots$ 时后半部分的答案。

那么对于一个询问 $k$,设 $k=x\sqrt m+y$,就可以用这两部分在 $O(\sigma)$ 的时间内合并了。

时间复杂度 $O(n\sigma+\sqrt m\sigma^2+q\sigma)$

程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#include <bits/stdc++.h>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cout<<#x<<"="<<x<<endl;
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second
inline ll read()
{
ll x=0; char ch=getchar(); bool f=0;
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
#define CASET fo(___,1,read())
const ll mod=1e9+7ll;
inline ll Add(ll x,ll y){x+=y; return (x<mod)?x:x-mod;}
inline ll Dec(ll x,ll y){x-=y; return (x<0)?x+mod:x;}
inline ll Mul(ll x,ll y){return x*y%mod;}
inline ll Pow(ll x,ll y){y%=(mod-1);ll ans=1;for(;y;y>>=1,x=x*x%mod)if(y&1) ans=ans*x%mod;return ans;}

const int N=200010;
const int M=100000;
const int S=26;
struct matrix{
ll a[S+1][S+1];
matrix() {memset(a,0,sizeof(a));}
ll *operator [](int x){return a[x];}
}A,B;
matrix operator *(matrix A,matrix B)
{
matrix C;
fo(i,0,S) fo(j,0,S) fo(k,0,S) C[i][j]=Add(C[i][j],1ll*A[i][k]*B[k][j]%mod);
return C;
}
matrix Pow(matrix A,int y)
{
matrix C;
fo(i,0,S) C[i][i]=1;
for(;y;y>>=1,A=A*A) if(y&1) C=C*A;
return C;
}
namespace SAM{
int las=1,siz=1,len[N],ne[N][S],fa[N];
inline void init()
{
for(int i=1;i<=siz;i++) memset(ne[i],0,sizeof(ne[i])),fa[i]=len[i]=0;
las=siz=1;
}
inline void extend(int c)
{
int cur=++siz;
len[cur]=len[las]+1;
int p=las;
for(;p&&!ne[p][c];p=fa[p]) ne[p][c]=cur;
if(!p) fa[cur]=1;
else
{
int q=ne[p][c];
if(len[q]==len[p]+1) fa[cur]=q;
else
{
int clone=++siz;
len[clone]=len[p]+1;
memcpy(ne[clone],ne[q],sizeof(ne[q]));
fa[clone]=fa[q];
for(;p&&ne[p][c]==q;p=fa[p]) ne[p][c]=clone;
fa[cur]=fa[q]=clone;
}
}
las=cur;
}
int base[N],a[N],f[N][S+1];
void dp(int n)
{
fo(i,1,siz) ++base[len[i]];
fo(i,1,siz) base[i]+=base[i-1];
fo(i,1,siz) a[base[len[i]]--]=i;
int u;
fd(i,siz,1)
{
u=a[i]; f[u][S]=1;
fo(j,0,S-1)
if(ne[u][j])
{
fo(k,0,S) f[u][k]=Add(f[u][k],f[ne[u][j]][k]);
}
else f[u][j]=Add(f[u][j],1);
}
fo(i,0,S-1) fo(j,0,S) A[j][i]=ne[1][i]?f[ne[1][i]][j]:0;
A[S][S]=1; B=Pow(A,M);
}
}
using namespace SAM;
ll p[M+1][S+1],s[M+1][S+1];
int n; char t[N];
int main()
{
scanf("%s",t+1);
n=strlen(t+1);
SAM::init();
fo(i,1,n) SAM::extend(t[i]-'a');
dp(n);

p[0][S]=1;
fo(i,1,M-1) fo(j,0,S) fo(k,0,S) p[i][k]=Add(p[i][k],p[i-1][j]*A[j][k]%mod);
fo(i,0,S) s[0][i]=1;
fo(i,1,M) fo(j,0,S) fo(k,0,S) s[i][j]=Add(s[i][j],s[i-1][k]*B[j][k]%mod);
int x,y; ll ans,k;
CASET
{
k=read(); x=k/M; y=k%M; ans=0;
fo(i,0,S) ans=Add(ans,Mul(p[y][i],s[x][i]));
printf("%lld\n",ans);
}
return 0;
}