yww与树上的回文串[loj6681]

题目链接

loj

题解

好神仙的一道题啊qwq…

首先统计路径嘛,那么点分治。

考虑经过重心的合法路径的方案数。先把所有子树合在一起算,最后容斥减掉各子树的贡献。

那么剩下就是如何计算的问题。考虑经过重心的字符串只有两种情况:

  • 被分成一样的两部分。建个Trie统计一下即可。
  • 被分成不同的两部分,即:

1

其中 $S$ 是一个字符串,$T$ 为非空回文串。

那么你建一个AC自动机,上面的形式相当于统计AC自动机上的某个节点的fail树的祖先中,减去该祖先所代表的字符串的长度后的前缀是回文串的个数。

我们知道,一个字符串的前缀回文串可以看成 $O(\log n)$ 个等差数列。

那么对于每个字符串记录下这些等差数列以后,相当于统计该节点跳fail链中的一些等差数列的和。

根据套路:

  • 当公差 $\geq \sqrt{n}$ 的时候暴力往上跳。

  • 当公差 $<\sqrt{n}$ 的时候,开个数组,记录下当前节点中,祖先字符串的长度模 $i$ 后为 $k$ 的有多少个,然后进行离线统计。

时间复杂度 $T(n)=2T(\frac{n}{2})+O(n\log n+n\sqrt{n})=O(n\sqrt{n})$。

程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
#include <bits/stdc++.h>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define lowbit(x) ((x)&-(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second
inline int read()
{
int x=0; char ch=getchar(); bool f=0;
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
const int N=50005;
const int inf=1e9;
const int base=43;
const ll mod=1e9+9;
const int B=100;
ll pw[N];
int n;
namespace AC{
struct node{
int l,r,d;
};
vector<node> s[N],q[N];
vector<int> adj[N];
int ne[N][2],fail[N],siz[N],len[N],cnt;
ll ans,h1[N],h2[N];
inline void init()
{
fo(i,0,cnt) len[i]=h1[i]=h2[i]=ne[i][0]=ne[i][1]=fail[i]=siz[i]=0,s[i].clear(),q[i].clear(),adj[i].clear();
cnt=0; ans=0;
}
inline int insert(int u,int c)
{
int &p=ne[u][c];
if(!p)
{
p=++cnt; len[p]=len[u]+1;
h1[p]=(h1[u]*base+c)%mod;
h2[p]=(h2[u]+pw[len[u]]*c)%mod;
s[p]=s[u];
if(h1[p]==h2[p])//is a plalindrome string
{
if(s[u].empty()) s[p].pb((node){len[p],len[p],inf});
else
{
auto las=s[u].back();
if(las.d==inf) s[p].back()=(node){las.l,len[p],len[p]-las.l};
else if(las.d==len[p]-las.r) s[p].back().r=len[p];
else s[p].pb((node){len[p],len[p],inf});
}
}
}
siz[p]++; return p;
}
int f[N],g[N],m;
inline void getfail()
{
queue<int> q;
q.push(0);
for(int u,v;!q.empty();)
{
u=q.front(); q.pop();
fo(i,0,1)
if(v=ne[u][i])
{
if(!u) fail[v]=0;
else fail[v]=ne[fail[u]][i];
q.push(v);
}
else ne[u][i]=ne[fail[u]][i];
}
fo(i,1,cnt) adj[fail[i]].pb(i),f[i]=-1;
}
void dfs1(int u)
{
f[len[u]]=u;
g[++m]=len[u];
int l,r,L,R;
for(auto v:s[u])
{
l=len[u]-v.l; r=len[u]-v.r;
swap(l,r);
if(v.d>B)
{
for(int i=l;i<=r;i+=v.d) if(f[i]!=-1) ans+=1ll*siz[f[i]]*siz[u];
}
else
{
L=lower_bound(g+1,g+m+1,l)-g-1;//<
R=upper_bound(g+1,g+m+1,r)-g-1;//<=
if(L==R) continue;
q[f[g[R]]].pb((node){v.d,r%v.d, siz[u]});
if(L) q[f[g[L]]].pb((node){v.d,r%v.d,-siz[u]});
}
}
for(int v:adj[u]) dfs1(v);
f[len[u]]=-1;
g[m--]=0;
}
ll sum[B+3][N/B+5];
void dfs2(int u)
{
fo(i,1,B) sum[i][len[u]%i]+=siz[u];
for(auto p:q[u]) ans+=1ll*p.d*sum[p.l][p.r];
for(int v:adj[u]) dfs2(v);
fo(i,1,B) sum[i][len[u]%i]-=siz[u];
}
inline void tle(int u)
{
f[len[u]]=u;
for(int i=u;i;i=fail[i])
{
int l=len[u]=len[i];
if(f[l]!=-1&&h1[f[l]]==h2[f[l]]) ans+=1ll*siz[i]*siz[u];
}
for(auto v:adj[u]) tle(v);
f[len[u]]=-1;
}
inline ll work()
{
ans=0;
getfail();
//tle(0); return ans;
fo(i,1,cnt) ans+=1ll*siz[i]*(siz[i]-1)/2;
dfs1(0); dfs2(0);
return ans;
}
}
namespace Tree{
ll ans;
int ver[N<<1],val[N<<1],ne[N<<1],head[N],tot=1;
inline void add(int z,int y,int x)
{
ver[++tot]=y; val[tot]=z; ne[tot]=head[x]; head[x]=tot;
ver[++tot]=x; val[tot]=z; ne[tot]=head[y]; head[y]=tot;
}
int rt,siz[N],mx[N]; bool vis[N];
void getroot(int u,int pre,int S)
{
siz[u]=1; mx[u]=0;
for(int i=head[u],v;i;i=ne[i])
if((v=ver[i])!=pre&&!vis[ver[i]])
{
getroot(v,u,S);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],S-siz[u]);
if(mx[rt]>mx[u]) rt=u;
}
void build(int u,int pre,int rt)
{
for(int i=head[u],v;i;i=ne[i])
if(!vis[ver[i]]&&(v=ver[i])!=pre)
build(v,u,AC::insert(rt,val[i]));
}
inline void calc(int u,int opt,int val=-1)
{
AC::init();
int v=(opt==1)?0:(AC::insert(0,val));
AC::siz[v]=1;
build(u,0,v);
ans+=AC::work()*opt;
}
void divide(int u,int Siz)
{
if(Siz==1) return;
//DEBUG(u);
calc(u,1);
vis[u]=1;
for(int i=head[u],v;i;i=ne[i])
if(!vis[v=ver[i]])
calc(v,-1,val[i]);
for(int i=head[u],v,S;i;i=ne[i])
if(!vis[v=ver[i]])
{
S=siz[v]>siz[u]?Siz-siz[u]:siz[v];
rt=0; getroot(v,u,S);
divide(rt,S);
}
}
inline ll work()
{
ans=0;mx[0]=inf;
getroot(1,0,n);
divide(rt,n);
return ans;
}
}
inline void init(int n)
{
pw[0]=1;
fo(i,1,n) pw[i]=pw[i-1]*base%mod;
}
int main()
{
n=read(); init(n);
fo(i,2,n) Tree::add(read(),read(),read());
printf("%lld\n",Tree::work());
return 0;
}