BOJ 25198 - 곰곰이의 심부름
대충 $O(n)$에 $c$를 루트로 한 트리에서 $s, h$의 LCA를 찾는다.
이제 해당 LCA를 기준으로 $s \to c \to u$ 까지 가는 경로에 가능한 것이 몇개가 있는지를 적절히 계산해주면 된다.
중복되게 세거나 빠뜨리는 경우가 있지 않게 조심하자.
void solve() {
int n, s, c, h, u, v;
cin >> n >> s >> c >> h, s--, c--, h--;
vvi e(n);
for (int i = 0; i < n - 1; i++) cin >> u >> v, u--, v--, e[u].pb(v), e[v].pb(u);
if (s == c && c == h) {
cout << 0;
return;
}
vi level(n), par(n);
function<void(int, int)> fn = [&](int cur, int p) -> void {
par[cur] = p;
for (int to: e[cur]) if (to ^ p) level[to] = level[cur] + 1, fn(to, cur);
};
level[c] = 0;
fn(c, -1);
int a = s, b = h;
if (level[a] < level[b]) swap(a, b);
while (level[a] ^ level[b]) a = par[a];
while (a ^ b) {
a = par[a];
b = par[b];
}
int lca = a;
ll A = level[c] - level[lca];
ll B = level[lca] - level[s];
ll C = level[lca] - level[h];
A *= -1, B *= -1, C *= -1;
debug(A, B, C);
ll ans = 0;
ans += A * (A - 1);
ans += B * (B - 1) / 2;
ans += C * (C - 1) / 2;
ans += B * A;
ans += B * C;
ans += A * C;
ans += A * 2;
ans += B;
ans += C;
cout << ans;
}
Comments