挖宝[FJWC2020Day3]

题意

有一个挖宝游戏,它在一棵 n 个点的树上进行,宝藏埋在某个未知的点 $𝑥$ 。每次挖掘一个点 $u$,玩家得到的反馈信息是一个数值 $d$,表示 $u$ 号点到 $𝑥$ 号点简单路径上的边数。这个游戏会进行 $q$ 次,每次游戏藏宝的位置不一定相同。

你作为一名优秀的 OIer,对自己无比自信。你希望用最少的挖掘次数来找出宝藏。于是你挑了两个不同的点 $a,b$ 进行挖掘,并得到了反馈信息,分别为 $d_a,d_b$。接下来的第三次挖掘中,你想要直接奔着一个可能的 $𝑥$ 进行挖掘。由于树太大了,凭借人眼无法找出 $𝑥$ 的确切位置,你便转向了电脑,开始写一个程序,帮助你解决这个问题。

$n,q\leq 10^6,1\leq d \leq n$

题解

一道这么简单的lca写完后都调了1个多小时。。。

主要还是有一些细节没有考虑到。

首先通过树上倍增,我们可以在 $O(\log n)$ 的时间内,求出两点的 $lca$,某个点到根节点路径中距离为 $d$ 的点,某个点到另外一点的路径中距离为 $d$ 的点。

设 $k=\frac{d_a+d_b-dist(a,b)}{2}$,这个 $k$ 表示的是,$a$ 到 $x$ 和 $b$ 到 $x$ 这两条路径在最后面的重合的长度。

显然如果 $k$ 不为整数或者 $<0$ 就一定无解。

那么就可以找到从 $a$ 出发到 $b$ 的路径中第 $d_a-k$ 个点,设为 $z$。(显然如果 $d_a-k<0$ 或者 $d_a-k>dist(a,b)$ 也无解。)

转换成求任意一点 $x$,满足从 $z$ 出发,不经过 $a,b$ 两点的方向到达 $x$,且距离为 $k$。

那么就可以先判断满足条件的最远的点到 $z$ 的距离是否大于等于 $k$,如果存在,那么就从 $z$ 往这个点跳 $k$ 步即可。

假定 $1$ 为根,然后dfs找出子树中路径前 $3$ 长的路径,以及往根方向走的最长的路径。这样就可以枚举这 $4$ 个方向,然后判断这四个方向到 $z$ 的路径是否符合条件,分类讨论一下即可。

时间复杂度 $O(n\log n)$,常数有点大。

程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <vector>
#include <cstring>
#include <queue>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define ff(i,j,k) for(int i=(j),end_i=(k);i< end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define lowbit(x) ((x)&-(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define pii pair<int,int>
#define fi first
#define se second
inline int read()
{
int x=0; char ch=getchar(); bool f=0;
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
#define CASET fo(___,1,read())
const int N=1e6+5;
int n,dep[N],f[N][21];
vector<int> adj[N];
inline void add(int x,int y)
{
adj[x].pb(y); adj[y].pb(x);
}
struct node{
int u,d;
friend inline bool operator>=(const node &A,const node &B)
{
return A.d>=B.d;
}
friend inline bool operator<(const node &A,const node &B)
{
if(!B.u) return 0;
if(!A.u) return 1;
return A.d<B.d;
}
friend inline node operator+(const node &A,const int &B)
{
return (node){A.u,A.d+B};
}
};
struct P{
node a[3];
inline void ins(node b)
{
b.d++;
fo(i,0,2)
if(b>=a[i])
{
fd(j,1,i) a[j+1]=a[j];
a[i]=b;
return;
}
}
}mx[N];
node h[N];
int tim,l[N],r[N];
void dfs(int u,int pre)
{
l[u]=++tim;
dep[u]=dep[pre]+1;
f[u][0]=pre;
fo(i,1,20) f[u][i]=f[f[u][i-1]][i-1];
for(auto v:adj[u]) if(v!=pre)
{
dfs(v,u);
mx[u].ins(mx[v].a[0]);
}
mx[u].ins((node){u,-1});
r[u]=tim;
}
inline int lca(int x,int y)
{
if(dep[x]>dep[y]) swap(x,y);
fd(i,20,0) if(dep[f[y][i]]>=dep[x]) y=f[y][i];
if(x==y) return x;
fd(i,20,0) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
inline int dist(int x,int y)
{
return dep[x]+dep[y]-(dep[lca(x,y)]<<1);
}
inline int jump(int x,int d)
{
fo(i,0,20) if((1<<i)&d) x=f[x][i];
return x;
}

void dfs2(int u,int pre,node g)
{
h[u]=g;
for(auto v:adj[u]) if(v!=pre)
{
if(mx[u].a[0].d==dist(v,mx[u].a[0].u)+1) dfs2(v,u,max(g+1,mx[u].a[1]+1));
else dfs2(v,u,max(g+1,mx[u].a[0]+1));
}
}
inline int jump(int x,int y,int d)
{
int z=lca(x,y);
if(dep[x]-dep[z]>=d) return jump(x,d);
else return jump(y,(dep[x]+dep[y]-(dep[z]<<1))-d);
}
inline int calc(int x,int y,int z)
{
return dist(x,y)+dist(y,z);
}
inline bool check(int x,int a)
{
return l[x]<=l[a]&&l[a]<=r[x];
}
inline int work(int x,int k,int a,int b)
{
node u;
fo(i,0,2)
if(mx[x].a[i].d>=k)
{
u=mx[x].a[i];
//DEBUG(u.u);
if((!check(x,a)||lca(u.u,a)==x)&&(!check(x,b)||lca(u.u,b)==x)) return jump(x,u.u,k);
}
//cerr<<"FUCK"<<h[x].d;
if(h[x].d>=k)
{
u=h[x];
//DEBUG(u.u);
if(check(x,a)&&check(x,b))
{
//DEBUG(k);
return jump(x,u.u,k);
}
}
return -1;
}

int main()
{
FO(hunting);
n=read(); int q=read();
fo(i,2,n) add(read(),read());
dfs(1,0); dfs2(1,0,(node){0,0});
for(int x,y,dx,dy,dis,k,z;q--;)
{
x=read(); dx=read(); y=read(); dy=read();
z=lca(x,y); dis=dep[x]+dep[y]-(dep[z]<<1);
k=dx+dy-dis;
if(k<0||(k&1)) {puts("-1"); continue;}
k/=2;
if(dx-k<0||dx-k>dis) {puts("-1"); continue;}
if(dep[x]-dep[z]>=dx-k) z=jump(x,dx-k);
else z=jump(y,dy-k);
//DEBUG(z); DEBUG(k);
if(k==0) printf("%d\n",z);
else printf("%d\n",work(z,k,x,y));
}
return 0;
}