1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
#include <bits/stdc++.h>
#define mp make_pair
#define pb push_back
#define sz(x) (int)x.size()
#define all(x) begin(x), end(x)
#define fi first
#define se second
#define debug(x) cerr << #x << " " << x << '\n'
using namespace std;
using ll = long long;
using pii = pair<int,int>;
using pli = pair<ll,int>;
const int INF = 0x3f3f3f3f, N = 4e5 + 5;
const ll LINF = 1e18 + 5;
constexpr int mod = 1e9 + 7;
constexpr int LOG = 19;  // highest binary-lifting level used; 2^19 exceeds N

// Overview
// --------
// The input tree on vertices 1..n is subdivided: the i-th input edge (u, v)
// becomes u - (n+i) - v, so every input edge has length 2 in the stored graph.
// A multi-source BFS from the r "special" vertices unions, in a DSU, every
// vertex at BFS distance < k with all of its neighbours. A query (u, v) is
// answered YES when the two vertices are at most 2*k apart in the subdivided
// tree, or when the vertices reached by moving each endpoint k steps towards
// the other lie in the same DSU set.

// n: number of input vertices; k: step limit; r: number of special vertices.
// sp[1..r]: special vertices; step[x]: BFS distance (-1 = not reached);
// f[x]: DSU parent pointer.
int n, k, r, sp[N], step[N], f[N];
// Adjacency lists of the subdivided tree.
vector <int> G[N];
// fa[x][i]: 2^i-th ancestor of x (0 when it does not exist); dep[x]: depth,
// with the root at depth 1. Vertex 0 is never a real vertex, so its
// zero-initialised row/depth serve as the "above the root" sentinel.
int fa[N][22], dep[N];

// Fills fa[][] and dep[] for the subtree hanging below u.
// Expects fa[u][0] and dep[u] to be set by the caller (fa[root][0] == 0).
// An explicit stack is used instead of recursion: the subdivided tree can be
// a path of ~4e5 vertices, which is too deep for a default-sized call stack.
void dfs(int u) {
    vector<int> pending(1, u);
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        // All ancestors of cur were popped earlier, so their rows are final.
        for (int i = 1; i <= LOG; i++)
            fa[cur][i] = fa[fa[cur][i - 1]][i - 1];
        for (int v : G[cur]) {
            if (v == fa[cur][0]) continue;  // do not walk back to the parent
            dep[v] = dep[cur] + 1;
            fa[v][0] = cur;
            pending.push_back(v);
        }
    }
}

// Returns the lowest common ancestor of x and y via binary lifting.
int LCA(int x, int y) {
    if (dep[x] > dep[y]) swap(x, y);
    // Lift the deeper vertex y to the depth of x. dep[0] == 0 is smaller than
    // any real depth, so jumps past the root are never taken.
    for (int i = LOG; i >= 0; i--)
        if (dep[fa[y][i]] >= dep[x]) y = fa[y][i];
    if (x == y) return x;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int i = LOG; i >= 0; i--)
        if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

// Returns the d-th ancestor of x (only bits 0..LOG of d are considered).
int up(int x, int d) {
    for (int i = LOG; i >= 0; i--)
        if ((d >> i) & 1) x = fa[x][i];
    return x;
}

// DSU find with full path compression.
// Iterative on purpose: bfs() can build parent chains as long as the BFS
// depth, and a recursive walk over such a chain risks a stack overflow.
int find(int x) {
    int root = x;
    while (root != f[root]) root = f[root];
    // Second pass: repoint every vertex on the walked path at the root.
    while (f[x] != root) {
        const int next = f[x];
        f[x] = root;
        x = next;
    }
    return root;
}

// Unites the DSU sets of x and y (the root of x is attached below the root of y).
void merge(int x, int y) {
    x = find(x), y = find(y);
    if (x != y) f[x] = y;
}

// Resets the DSU and BFS distances for indices 1..2n, then runs a multi-source
// BFS from sp[1..r]. Every dequeued vertex with step < k is united with all of
// its neighbours; the search stops at the first dequeued vertex with
// step >= k (BFS distances leave the queue in non-decreasing order).
void bfs() {
    queue <int> q;
    for (int i = 1; i <= 2 * n; i++) {
        f[i] = i;
        step[i] = -1;
    }
    for (int i = 1; i <= r; i++) {
        q.push(sp[i]);
        step[sp[i]] = 0;
    }
    while (!q.empty()) {
        const int u = q.front();
        q.pop();
        if (step[u] >= k) break;
        for (int v : G[u]) {
            merge(u, v);
            if (step[v] == -1) {
                step[v] = step[u] + 1;
                q.push(v);
            }
        }
    }
}

// Replaces u by the vertex k steps away from u on the tree path u -> v, where
// w is LCA(u, v). Callers ensure the path is longer than k.
void walk(int &u, int v, int w, int k) {
    if (dep[u] - dep[w] >= k) {
        // The target is still on u's own branch: climb k levels.
        u = up(u, k);
    } else {
        // The target lies on v's branch: it is (path length - k) levels above v.
        u = up(v, dep[v] - k + dep[u] - 2 * dep[w]);
    }
}

// Answers one query: true when u and v are within 2*k steps of each other in
// the subdivided tree, or when the vertices k steps in from each end of the
// u-v path belong to the same DSU set.
bool ok(int u, int v) {
    const int w = LCA(u, v);
    // 64-bit right-hand side so that a large k cannot overflow 2*k.
    if (dep[u] + dep[v] - 2 * dep[w] <= 2LL * k) return true;
    walk(u, v, w, k);
    // u has already moved, but it is still more than k steps from v and w is
    // still the LCA of the pair, so the same formula applies from v's side.
    walk(v, u, w, k);
    return find(u) == find(v);
}

// True when x is a valid input vertex id (1..n).
static bool is_vertex(int x) {
    return 1 <= x && x <= n;
}

// Reads "n k r", n-1 edges, r special vertices, the query count and the
// queries from stdin; prints YES/NO per query.
// Returns 1 on truncated, malformed or out-of-range input instead of running
// on uninitialised values or indexing outside the fixed-size arrays.
int main() {
    if (scanf("%d%d%d", &n, &k, &r) != 3) return 1;
    // bfs() touches indices up to 2n and sp[] holds r entries from index 1.
    if (n < 1 || n > (N - 1) / 2 || r >= N) return 1;

    for (int i = 1; i < n; i++) {
        int u, v;
        if (scanf("%d%d", &u, &v) != 2) return 1;
        if (!is_vertex(u) || !is_vertex(v)) return 1;
        // Subdivide the edge with the extra vertex n+i.
        G[u].push_back(n + i);
        G[v].push_back(n + i);
        G[n + i].push_back(u);
        G[n + i].push_back(v);
    }
    for (int i = 1; i <= r; i++) {
        if (scanf("%d", &sp[i]) != 1) return 1;
        if (!is_vertex(sp[i])) return 1;
    }

    dep[1] = 1;  // root at depth 1 so that depth 0 marks the sentinel vertex 0
    dfs(1);
    bfs();

    int q;
    if (scanf("%d", &q) != 1) return 1;
    while (q-- > 0) {
        int u, v;
        if (scanf("%d%d", &u, &v) != 2) return 1;
        if (!is_vertex(u) || !is_vertex(v)) return 1;
        if (ok(u, v)) puts("YES");
        else puts("NO");
    }
    return 0;
}
|