tree[jzoj6511]

题意

problem

题解

显然直接算是非常不好算的。因为有 $3$ 个数,如果能减少到枚举两个数或者一个就更好了。

对于一条有向的路径 $(u,v)$,我们定义 $W_{u,v}$ 为该路径权值模 $P$ 后是否不为 $0$。

那么一共有 $2^3=8$ 种情况,我们需要算的只有两种。

无论如何,先列个表看看再说:

$W_{u,v}$ $W_{u,t}$ $W_{t,v}$ 是否计算
$0$ $0$ $0$ Yes
$0$ $0$ $1$ No
$0$ $1$ $0$ No
$0$ $1$ $1$ No
$1$ $0$ $0$ No
$1$ $0$ $1$ No
$1$ $1$ $0$ No
$1$ $1$ $1$ Yes

考虑容斥,看看不计算的能不能算,最后用 $n^3$ 减去就好了。

我们发现,不需要计算的情况都包含两组 $W$ ,满足这两种 $W$ 的值是不相等的。

那么可以枚举这两组 $W$ 是哪两组,剩下那组不管它,然后把方案数加起来后,除以 $2$ 就是答案了。

分三类讨论:

  1. $W_{u,v}\not = W_{u,t}$
  2. $W_{u,v}\not = W_{t,v}$
  3. $W_{u,t}\not = W_{t,v}$

这样来看,我们只需要对于每个点 $u$,都算出 $W_{u,i}=0$ 和 $W_{i,u}=0$ 有多少种,这样就可以算出答案了。

至此,我们把枚举三个点变成了枚举两个点的路径问题。

跟树上所有的路径相关的,考虑点分治。

对于点分治过程中的一个分治结构,我们考虑用容斥的方法减去子树自己对自己的贡献。

那么只需要考虑一条经过分治中心的一条链的权值怎么算就好了。

先来看 $W_{i,u}$ 如何计算。

tree

如图,$w_i,l_i$ 表示为该链(有方向)的权值以及长度。

那么将这两条链合起来的权值就是:$w_2\times K^{l_1}+w_1$

要使得这个权值为 $0$,则有:$w_2=\frac{-w_1}{K^{l_1}}$。

$W_{u,i}$ 同理。

那么先来一次dfs,用个 map 或者哈希记录 $\frac{-w_1}{K^{l_1}}$ 的出现次数,然后再来一次dfs统计即可。

时间复杂度 $O(n\log n)$ 或 $O(n\log ^2n)$。

程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <map>
#include <set>
#include <cmath>
#include <queue>
#include <bitset>
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define ff(i,j,k) for(int i=(j),end_i=(k);i< end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define lowbit(x) ((x)&-(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second
inline int read()
{
int x=0; char ch=getchar(); bool f=0;
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
#define CASET fo(___,1,read())
ll base,mod;
inline ll Add(ll x,ll y){x+=y; return (x<mod)?x:x-mod;}
inline ll Dec(ll x,ll y){x-=y; return (x<0)?x+mod:x;}
inline ll Mul(ll x,ll y){return x*y%mod;}
inline ll Pow(ll x,ll y)
{
y%=(mod-1);ll ans=1;for(;y;y>>=1,x=x*x%mod)if(y&1) ans=ans*x%mod;
return ans;
}
const int N=1e5+5;
vector<int> adj[N];
int siz[N],mx[N],rt;
bool vis[N];
int f[N][2][2];
int n;
ll pw[N],iw[N],w[N];
inline void add(int x,int y) {adj[x].pb(y); adj[y].pb(x);}
void getroot(int u,int pre,int S)
{
siz[u]=1; mx[u]=0;
for(auto v:adj[u]) if(v!=pre&&!vis[v])
{
getroot(v,u,S);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],S-siz[u]);
if(mx[rt]>mx[u]) rt=u;
}
map<ll,int> g[2];
int now_siz;
void dfs(int u,int pre,int len,ll w1,ll w2)
{
g[0][(mod-w1)*iw[len]%mod]++; g[1][w2]++;
for(auto v:adj[u]) if(!vis[v]&&(v!=pre))
dfs(v,u,len+1,(w1+w[v])*base%mod,(w2+w[v]*pw[len+1])%mod);
}
void dfs2(int u,int pre,int opt,int len,ll w1,ll w2)
{
int tmp;
tmp=(mod-w2)*iw[len]%mod;
if(g[1].count(tmp)) f[u][0][0]+=opt*g[1][tmp];
tmp=w1;
if(g[0].count(tmp)) f[u][1][0]+=opt*g[0][tmp];
for(auto v:adj[u]) if(!vis[v]&&(v!=pre))
dfs2(v,u,opt,len+1,(w1+w[v]*pw[len+1])%mod,(w2+w[v])*base%mod);
}
inline void calc(int u,int pre,int opt)
{
g[0].clear(); g[1].clear();
if(opt==1)
{
dfs(u,pre,1,w[u]*base%mod,w[u]*base%mod);
dfs2(u,pre,opt,0,0,0);
}
else
{
dfs(u,pre,2,(w[pre]*base%mod+w[u])%mod*base%mod,(w[u]*base%mod+w[pre])%mod*base%mod);
dfs2(u,pre,opt,1,w[u]*base%mod,w[u]*base%mod);
}
}
void divide(int u,int S)
{
now_siz=S; calc(u,0,1);
vis[u]=1;
for(auto v:adj[u]) if(!vis[v])
{
now_siz=(siz[u]<siz[v])?S-siz[u]:siz[v];
calc(v,u,-1);
}
for(auto v:adj[u]) if(!vis[v])
{
int Si=(siz[u]<siz[v])?S-siz[u]:siz[v];
rt=0; getroot(v,u,Si); divide(rt,Si);
}
}
inline ll work(int u,int x,int y)
{
return 1ll*f[u][x][0]*f[u][y][1]+1ll*f[u][x][1]*f[u][y][0];
}
int main()
{
FO(tree);
n=read(); base=read(); mod=read();
pw[0]=1; iw[0]=1;
pw[1]=base; iw[1]=Pow(base,mod-2);
fo(i,2,n) pw[i]=pw[i-1]*pw[1]%mod,iw[i]=iw[i-1]*iw[1]%mod;
fo(i,1,n) w[i]=read()%mod;
fo(i,2,n) add(read(),read());
mx[0]=1e9; rt=0; getroot(1,0,n); divide(rt,n);
ll sum=0;
fo(i,1,n) fo(j,0,1) f[i][j][1]=n-f[i][j][0];
fo(i,1,n) sum+=work(i,1,1)+work(i,0,0)+work(i,0,1);
printf("%lld",1ll*n*n*n-(sum/2));
return 0;
}