跳转至

最近公共祖先

概述

如果一棵树上的两个结点向上移动,最后交汇的第一个结点,也就是两个结点的lca。

本章是依靠倍增法求出这样一个结点。

Note

倍增法类似于二分,以2倍,4倍...等倍数增加

倍增这个思想可以借鉴这个博客👉【白话系列】倍增算法


思路

求两个结点的最近公共祖先,我们首先分以下两种情况

  • 两结点深度不同
  • 两结点深度相同

我们思考能否将其化成一种情况,对于深度不同的情况,我们发现若将深度大的结点向着它的父节点迭代,直到和深度小的一样,这样它们俩的最近公共祖先并不会变化,因此可以将深度不同的情况转化成深度相同的情况

那如何迭代呢?

其实这里用的就是倍增法,我们预处理了一个\(f[i,j]\)数组表示\(i\)往根节点跳\(2^j\)所达到的结点,因此\(i\)可以跳入到\(f[i,j]\)的位置

这里应该会有人会提问

上面的倍增都是\(2\)的幂次,那若是两者的深度差是其他数呢?

这里只要把一个数看成二进制的话,应该就知道了吧,每个数都能被表示成\(2\)的幂次的和

然后我们思考下一个问题

都转化成深度相同的情况后又该怎么办呢

这里还是用的倍增法,只是和之前的倍增有点不同

  • 此时是两个结点一起迭代
  • 迭代停止的条件不一样

具体哪不一样看例题体会一下


代码模板

bfs求倍增数组

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
void bfs() 
{
    memset(dis, 0x3f, sizeof(dis)); // dis数组为深度数组
    dis[0] = 0;
    dis[root] = 1;
    q.push(root);
    while(q.size())
    {
        int pos = q.front();
        q.pop();
        if (vis[pos]) continue;
        vis[pos] = true;
        for(int i = head[pos]; i != -1; i = ne[i])
        {
            int j = e[i];
            if (dis[j] > dis[pos] + 1) 
            {
                dis[j] = dis[pos] + 1;
                f[j][0] = pos;
                q.push(j);
                for(int k = 1; k <= 15; k ++ )
                   f[j][k] = f[f[j][k-1]][k-1]; // f[j][k]是结点j往上跳2^(k-1)步后的结点,倍增思想
            }
        }
    }
}

求lca

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
int lca(int a, int b)
{
    if (dis[a] < dis[b]) swap(a, b); // 保证a在b的下面
    for(int i = 15; i >= 0; i -- ) // 第一次迭代
    {
        if (dis[f[a][i]] >= dis[b]) // a跳的下一个点的深度和b相等时也要跳,所以可以相等
            a = f[a][i];  
    }

    if (a == b) return a; // 跳到相同深度嗷
    for(int i = 15; i >= 0; i -- ) // 第二次迭代
    {
        if (f[a][i] != f[b][i]) // 找公共祖先 终止条件是a和b下一个结点相同
        {
            a = f[a][i];
            b = f[b][i];
        }
    }
    return f[a][0]; // 没有跳到lca上嗷
}

例1.祖孙询问

题目链接

题目描述

给定一棵包含 n 个节点的有根无向树,节点编号互不相同,但不一定是 1∼n。

有 m 个询问,每个询问给出了一对节点的编号 x 和 y,询问 x 与 y 的祖孙关系。

输入格式

输入第一行包括一个整数 表示节点个数;

接下来 n 行每行一对整数 a 和 b,表示 a 和 b 之间有一条无向边。如果 b 是 −1,那么 a 就是树的根;

第 n+2 行是一个整数 m 表示询问个数;

接下来 m 行,每行两个不同的正整数 x 和 y,表示一个询问。

输出格式

对于每一个询问,若 x 是 y 的祖先则输出 1,若 y 是 x 的祖先则输出 2,否则输出 0。

数据范围

\(1≤n,m≤4×104 ,\)

\(1≤每个节点的编号≤4×104\)

输入样例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
10
234 -1
12 234
13 234
14 234
15 234
16 234
17 234
18 234
19 234
233 19
5
234 233
233 12
233 13
233 15
233 19

输出样例

1
2
3
4
5
1
0
0
0
2

题解

LCA模板题

代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include<bits/stdc++.h>
using namespace std;
const int N = 4e4+5, M = 8e4+5;
int head[N], e[M], w[M], ne[M], idx = 0;
void add(int a, int b)
{
    e[idx] = b;
    ne[idx] = head[a];
    head[a] = idx ++ ;
}
int root;
int dis[N];
bool vis[N];
int f[N][17];
queue<int> q;
void bfs() 
{
    memset(dis, 0x3f, sizeof(dis));
    dis[0] = 0;
    dis[root] = 1;
    q.push(root);
    while(q.size())
    {
        int pos = q.front();
        q.pop();
        if (vis[pos]) continue;
        vis[pos] = true;
        for(int i = head[pos]; i != -1; i = ne[i])
        {
            int j = e[i];
            if (dis[j] > dis[pos] + 1) 
            {
                dis[j] = dis[pos] + 1;
                f[j][0] = pos;
                q.push(j);
                for(int k = 1; k <= 15; k ++ )
                   f[j][k] = f[f[j][k-1]][k-1]; // f[j][k]是结点j往上跳2^(k-1)步后的结点,倍增思想
            }
        }
    }
}
int lca(int a, int b)
{
    if (dis[a] < dis[b]) swap(a, b); // 保证a在b的下面
    for(int i = 15; i >= 0; i -- )
    {
        if (dis[f[a][i]] >= dis[b]) // a跳的下一个点的深度和b相等时也要跳,所以可以相等
            a = f[a][i];
    }

    if (a == b) return a;
    for(int i = 15; i >= 0; i -- )
    {
        if (f[a][i] != f[b][i]) // 找公共祖先
        {
            a = f[a][i];
            b = f[b][i];
        }
    }
    return f[a][0];
}
int main()
{
    int n, m;
    cin >> n;
    memset(head,-1,sizeof(head));
    for(int i = 1; i <= n; i ++ ) 
    {
        int a, b;
        cin >> a >> b;
        if (b == -1) root = a;
        else 
        {
            add(a, b);
            add(b, a);
        }
    }
    bfs();
    cin >> m;
    for(int i = 1; i <= m; i ++ )
    {
        int a, b;
        cin >> a >> b;
        int p = lca(a, b);
        if (p == a) cout << "1" << endl;
        else if (p == b) cout << "2" << endl;
        else cout << "0" << endl;
    }
    return 0;
}

例2.距离

题目链接

题目描述

给出 n 个点的一棵树,多次询问两点之间的最短距离。

注意:

  • 边是无向的。
  • 所有节点的编号是 1,2,…,n。

输入格式

第一行为两个整数 n 和 m。n 表示点数,m 表示询问次数;

下来 n−1 行,每行三个整数 x,y,k,表示点 x 和点 y 之间存在一条边长度为 k;

再接下来 m 行,每行两个整数 x,y,表示询问点 x 到点 y 的最短距离。

树中结点编号从 1 到 n。

输出格式

共 m 行,对于每次询问,输出一行询问结果。

数据范围

\(2≤n≤10^4,\)

\(1≤m≤2×10^4,\)

\(0<k≤100,\)

\(1≤x,y≤n\)

输入样例

1
2
3
4
2 2 
1 2 100 
1 2 
2 1

输出样例

1
2
100
100

题解

对于树上任意两点的距离,我们求出它的lca,则\(dis{x,y}=depth[y]+depth[x]-2\times depth[lca]\)

于是只要能求出lca,并处理出两个结点到根节点的距离就行了

代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e4 + 5, M = 2 * N;
int head[N], e[M], ne[M], w[M], idx = 0;
void add(int a, int b, int c) {
    e[idx] = b;
    w[idx] = c;
    ne[idx] = head[a];
    head[a] = idx ++ ;
}
int n, m;

int dis[N];
int depth[N];
int f[N][17];
void bfs(int s) {
    memset(depth, 0x3f, sizeof(depth));
    depth[s] = 0;
    depth[0] = 0;
    queue<int> q;
    q.push(s);
    while(q.size()) {
        int pos = q.front();
        q.pop();
        for(int i = head[pos]; i != -1; i = ne[i]) {
            int j = e[i];
            if(depth[j] > depth[pos] + 1) {
                depth[j] = depth[pos] + 1;
                dis[j] = dis[pos] + w[i];
                q.push(j);
                f[j][0] = pos;
                for(int k = 1; k <= 15; k ++ ) {
                    f[j][k] = f[f[j][k - 1]][k - 1];
                }
            }
        }
    }
}

int lca(int a, int b) {
    if(depth[a] < depth[b]) {
        int t = a;
        a = b;
        b = t;
    }

    for(int i = 15; i >= 0; i -- ) {
        if(depth[f[a][i]] >= depth[b]) {
            a = f[a][i]; 
        }
    }

    if(a == b) return a;

    for(int i = 15; i >= 0; i -- ) {
        if(f[a][i] != f[b][i]) {
            a = f[a][i];
            b = f[b][i];
        }
    }
    return f[a][0];
}
signed main() {
    cin >> n >> m;
    memset(head, -1, sizeof(head));
    for(int i = 1; i <= n - 1; i ++ ) {
        int a, b, c;
        cin >> a >> b >> c;
        add(a, b, c);
        add(b, a, c);
    }

    bfs(1);

    while(m -- ) {
        int a, b;
        cin >> a >> b;
        int p = lca(a, b);
        cout << dis[a] + dis[b] - 2 * dis[p] << '\n';
    }
    return 0;
}