最近公共祖先
概述
如果一棵树上的两个结点向上移动,最后交汇的第一个结点,也就是两个结点的lca。
本章是依靠倍增法求出这样一个结点。
Note
倍增法类似于二分,以2倍,4倍...等倍数增加
倍增这个思想可以借鉴这个博客👉【白话系列】倍增算法
思路
求两个结点的最近公共祖先,我们首先分以下两种情况
我们思考能否将其化成一种情况,对于深度不同的情况,我们发现若将深度大的结点向着它的父节点迭代,直到和深度小的一样,这样它们俩的最近公共祖先并不会变化,因此可以将深度不同的情况转化成深度相同的情况
那如何迭代呢?
其实这里用的就是倍增法,我们预处理了一个\(f[i,j]\)数组表示\(i\)往根节点跳\(2^j\)所达到的结点,因此\(i\)可以跳入到\(f[i,j]\)的位置
这里应该会有人会提问
上面的倍增都是\(2\)的幂次,那若是两者的深度差是其他数呢?
这里只要把一个数看成二进制的话,应该就知道了吧,每个数都能被表示成\(2\)的幂次的和
然后我们思考下一个问题
都转化成深度相同的情况后又该怎么办呢
这里还是用的倍增法,只是和之前的倍增有点不同
具体哪不一样看例题体会一下
代码模板
bfs求倍增数组
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26 | void bfs()
{
memset(dis, 0x3f, sizeof(dis)); // dis数组为深度数组
dis[0] = 0;
dis[root] = 1;
q.push(root);
while(q.size())
{
int pos = q.front();
q.pop();
if (vis[pos]) continue;
vis[pos] = true;
for(int i = head[pos]; i != -1; i = ne[i])
{
int j = e[i];
if (dis[j] > dis[pos] + 1)
{
dis[j] = dis[pos] + 1;
f[j][0] = pos;
q.push(j);
for(int k = 1; k <= 15; k ++ )
f[j][k] = f[f[j][k-1]][k-1]; // f[j][k]是结点j往上跳2^(k-1)步后的结点,倍增思想
}
}
}
}
|
求lca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 | int lca(int a, int b)
{
if (dis[a] < dis[b]) swap(a, b); // 保证a在b的下面
for(int i = 15; i >= 0; i -- ) // 第一次迭代
{
if (dis[f[a][i]] >= dis[b]) // a跳的下一个点的深度和b相等时也要跳,所以可以相等
a = f[a][i];
}
if (a == b) return a; // 跳到相同深度嗷
for(int i = 15; i >= 0; i -- ) // 第二次迭代
{
if (f[a][i] != f[b][i]) // 找公共祖先 终止条件是a和b下一个结点相同
{
a = f[a][i];
b = f[b][i];
}
}
return f[a][0]; // 没有跳到lca上嗷
}
|
例1.祖孙询问
题目链接
题目描述
给定一棵包含 n 个节点的有根无向树,节点编号互不相同,但不一定是 1∼n。
有 m 个询问,每个询问给出了一对节点的编号 x 和 y,询问 x 与 y 的祖孙关系。
输入格式
输入第一行包括一个整数 表示节点个数;
接下来 n 行每行一对整数 a 和 b,表示 a 和 b 之间有一条无向边。如果 b 是 −1,那么 a 就是树的根;
第 n+2 行是一个整数 m 表示询问个数;
接下来 m 行,每行两个不同的正整数 x 和 y,表示一个询问。
输出格式
对于每一个询问,若 x 是 y 的祖先则输出 1,若 y 是 x 的祖先则输出 2,否则输出 0。
数据范围
\(1≤n,m≤4×104 ,\)
\(1≤每个节点的编号≤4×104\)
输入样例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 | 10
234 -1
12 234
13 234
14 234
15 234
16 234
17 234
18 234
19 234
233 19
5
234 233
233 12
233 13
233 15
233 19
|
输出样例
题解
LCA模板题
代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90 | #include<bits/stdc++.h>
using namespace std;
const int N = 4e4+5, M = 8e4+5;
int head[N], e[M], w[M], ne[M], idx = 0;
void add(int a, int b)
{
e[idx] = b;
ne[idx] = head[a];
head[a] = idx ++ ;
}
int root;
int dis[N];
bool vis[N];
int f[N][17];
queue<int> q;
void bfs()
{
memset(dis, 0x3f, sizeof(dis));
dis[0] = 0;
dis[root] = 1;
q.push(root);
while(q.size())
{
int pos = q.front();
q.pop();
if (vis[pos]) continue;
vis[pos] = true;
for(int i = head[pos]; i != -1; i = ne[i])
{
int j = e[i];
if (dis[j] > dis[pos] + 1)
{
dis[j] = dis[pos] + 1;
f[j][0] = pos;
q.push(j);
for(int k = 1; k <= 15; k ++ )
f[j][k] = f[f[j][k-1]][k-1]; // f[j][k]是结点j往上跳2^(k-1)步后的结点,倍增思想
}
}
}
}
int lca(int a, int b)
{
if (dis[a] < dis[b]) swap(a, b); // 保证a在b的下面
for(int i = 15; i >= 0; i -- )
{
if (dis[f[a][i]] >= dis[b]) // a跳的下一个点的深度和b相等时也要跳,所以可以相等
a = f[a][i];
}
if (a == b) return a;
for(int i = 15; i >= 0; i -- )
{
if (f[a][i] != f[b][i]) // 找公共祖先
{
a = f[a][i];
b = f[b][i];
}
}
return f[a][0];
}
int main()
{
int n, m;
cin >> n;
memset(head,-1,sizeof(head));
for(int i = 1; i <= n; i ++ )
{
int a, b;
cin >> a >> b;
if (b == -1) root = a;
else
{
add(a, b);
add(b, a);
}
}
bfs();
cin >> m;
for(int i = 1; i <= m; i ++ )
{
int a, b;
cin >> a >> b;
int p = lca(a, b);
if (p == a) cout << "1" << endl;
else if (p == b) cout << "2" << endl;
else cout << "0" << endl;
}
return 0;
}
|
例2.距离
题目链接
题目描述
给出 n 个点的一棵树,多次询问两点之间的最短距离。
注意:
输入格式
第一行为两个整数 n 和 m。n 表示点数,m 表示询问次数;
下来 n−1 行,每行三个整数 x,y,k,表示点 x 和点 y 之间存在一条边长度为 k;
再接下来 m 行,每行两个整数 x,y,表示询问点 x 到点 y 的最短距离。
树中结点编号从 1 到 n。
输出格式
共 m 行,对于每次询问,输出一行询问结果。
数据范围
\(2≤n≤10^4,\)
\(1≤m≤2×10^4,\)
\(0<k≤100,\)
\(1≤x,y≤n\)
输入样例
输出样例
题解
对于树上任意两点的距离,我们求出它的lca,则\(dis{x,y}=depth[y]+depth[x]-2\times depth[lca]\)
于是只要能求出lca,并处理出两个结点到根节点的距离就行了
代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83 | #include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e4 + 5, M = 2 * N;
int head[N], e[M], ne[M], w[M], idx = 0;
void add(int a, int b, int c) {
e[idx] = b;
w[idx] = c;
ne[idx] = head[a];
head[a] = idx ++ ;
}
int n, m;
int dis[N];
int depth[N];
int f[N][17];
void bfs(int s) {
memset(depth, 0x3f, sizeof(depth));
depth[s] = 0;
depth[0] = 0;
queue<int> q;
q.push(s);
while(q.size()) {
int pos = q.front();
q.pop();
for(int i = head[pos]; i != -1; i = ne[i]) {
int j = e[i];
if(depth[j] > depth[pos] + 1) {
depth[j] = depth[pos] + 1;
dis[j] = dis[pos] + w[i];
q.push(j);
f[j][0] = pos;
for(int k = 1; k <= 15; k ++ ) {
f[j][k] = f[f[j][k - 1]][k - 1];
}
}
}
}
}
int lca(int a, int b) {
if(depth[a] < depth[b]) {
int t = a;
a = b;
b = t;
}
for(int i = 15; i >= 0; i -- ) {
if(depth[f[a][i]] >= depth[b]) {
a = f[a][i];
}
}
if(a == b) return a;
for(int i = 15; i >= 0; i -- ) {
if(f[a][i] != f[b][i]) {
a = f[a][i];
b = f[b][i];
}
}
return f[a][0];
}
signed main() {
cin >> n >> m;
memset(head, -1, sizeof(head));
for(int i = 1; i <= n - 1; i ++ ) {
int a, b, c;
cin >> a >> b >> c;
add(a, b, c);
add(b, a, c);
}
bfs(1);
while(m -- ) {
int a, b;
cin >> a >> b;
int p = lca(a, b);
cout << dis[a] + dis[b] - 2 * dis[p] << '\n';
}
return 0;
}
|