0%

HDU 6031: Innumerable Ancestors

题意

给一棵树,询问两个点集间的LCA最大深度。

分析

首先,我们通过dfs+ST表进行预处理,让每次LCA的查询只要O(1)就可以完成。接着,利用以下这个性质:

根据 DFS 序,若两个点的 DFS 序越接近,则两个点的 LCA 的深度越大。

我们可以先将其中一个点集B中的元素按dfs序排序,然后枚举另一个点集A的元素,不妨设为a,利用二分查找在点集B中查找到与a的dfs序最相近的两个元素,不妨设为b1,b2,最后LCA(a, b1) 的深度和 LCA(a, b2)的深度就是可能的答案,取max即可。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143

#include<bits/stdc++.h>
#define x first
#define y second
#define ok cout << "ok" << endl;
using namespace std;
typedef long long ll;
typedef vector<int> vi;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const double PI = acos(-1.0);
const int INF=0x3f3f3f3f;
const ll LINF=0x3f3f3f3f3f3f3f3f;
const int N=1e5+9;
const int shift=1e3+9;
const double Eps=1e-7;


/*************************************************************/
// LCA在线算法--dfs+ST算法

int F[2*N]; //欧拉序列, 长度为2*n-1, 下标从1开始
int rmq[2*N]; //欧拉序列对应的深度序列
int P[N]; //P[i]表示点i在F中第一次出现的位置
int tot, head[N], cnt, root, n;

struct Edge{
int to, next;
}edge[2*N];

//加边, 无向边需要加两次
void addedge(int u, int v) {
edge[tot].to = v;
edge[tot].next = head[u];
head[u] = tot++;
}

//建树前的初始化
void init() {
tot = 0;
memset(head, -1, sizeof head);
}

//辅助函数
void dfs(int u, int pre, int dep) {
F[++cnt] = u;
rmq[cnt] = dep;
P[u] = cnt;
for(int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].to;
if(v == pre) continue;
dfs(v, u, dep+1);
F[++cnt] = u;
rmq[cnt] = dep;
}
}

//构建ST表
struct ST{
int mm[2*N], dp[2*N][20];
void build(int root, int n) {
cnt = 0;
dfs(root, root, 0);
mm[0] = -1;
for(int i = 1; i <= 2 * n - 1; i++) {
mm[i] = (i&(i-1)) == 0 ? mm[i-1]+1 : mm[i-1];
dp[i][0] = i;
}
for(int j = 1; j <= mm[2 * n - 1]; j++)
for(int i = 1; i + (1 << (j-1)) <= 2 * n - 1; i++)
dp[i][j] = rmq[dp[i][j-1]] < rmq[dp[i+(1<<(j-1))][j-1]] ?
dp[i][j-1] : dp[i+(1<<(j-1))][j-1];
}
int query(int a, int b) {
a = P[a], b = P[b];
if(a>b) swap(a, b);
int k = mm[b-a+1];
return F[rmq[dp[a][k]] <= rmq[dp[b-(1<<k)+1][k]] ?
dp[a][k] : dp[b-(1<<k)+1][k]];
}
}st;

/*************************************************************/

int m, u, v, ta, tb, a[N], b[N];

bool cmp(int a, int b) {
return P[a] < P[b];
}

bool check(int mid, int val) {
if(P[b[mid]] >= P[val])
return true;
return false;
}

int main(void) {
if(fopen("in", "r")!=NULL) {freopen("in", "r", stdin); freopen("out", "w", stdout);}
while(~scanf("%d%d", &n, &m)) {
init();
for(int i = 1; i < n; i++) {
scanf("%d%d", &u, &v);
addedge(u, v);
addedge(v, u);
}
st.build(1, n);
while(m--) {
scanf("%d", &ta);
for(int i = 0; i < ta; i++)
scanf("%d", a+i);
scanf("%d", &tb);
for(int i = 0; i < tb; i++)
scanf("%d", b+i);
sort(b, b + tb, cmp);
int ans = -INF;
for(int i = 0; i < ta; i++) {
int l = -1, r = tb - 1;
while(l + 1 != r) {
int mid = (l + r) >> 1;
if(check(mid, a[i]))
r = mid;
else
l = mid;
}
int rr = r;
l = 0, r = tb;
while(l + 1 != r) {
int mid = (l + r) >> 1;
if(check(mid, a[i]))
r = mid;
else
l = mid;
}
int ll = l;
ans = max(ans, rmq[P[st.query(a[i], b[ll])]]);
ans = max(ans, rmq[P[st.query(a[i], b[rr])]]);
//printf("t: %d\n", ans);
}
printf("%d\n", ans + 1);
}
}
return 0;
}