BOJ 25198 - 곰곰이의 심부름

image.png

대충 $O(n)$에 $c$를 루트로 한 트리에서 $s, h$의 LCA를 찾는다.

이제 해당 LCA를 기준으로 $s \to c \to u$ 까지 가는 경로에 가능한 것이 몇개가 있는지를 적절히 계산해주면 된다.

중복되게 세거나 빠뜨리는 경우가 있지 않게 조심하자.

void solve() {
   int n, s, c, h, u, v;
   cin >> n >> s >> c >> h, s--, c--, h--;
   vvi e(n);
   for (int i = 0; i < n - 1; i++) cin >> u >> v, u--, v--, e[u].pb(v), e[v].pb(u);
   if (s == c && c == h) {
      cout << 0;
      return;
   }
   vi level(n), par(n);
   function<void(int, int)> fn = [&](int cur, int p) -> void {
      par[cur] = p;
      for (int to: e[cur]) if (to ^ p) level[to] = level[cur] + 1, fn(to, cur);
   };
   level[c] = 0;
   fn(c, -1);
   int a = s, b = h;
   if (level[a] < level[b]) swap(a, b);
   while (level[a] ^ level[b]) a = par[a];
   while (a ^ b) {
      a = par[a];
      b = par[b];
   }
   int lca = a;

   ll A = level[c] - level[lca];
   ll B = level[lca] - level[s];
   ll C = level[lca] - level[h];
   A *= -1, B *= -1, C *= -1;
   debug(A, B, C);
   ll ans = 0;
   ans += A * (A - 1);
   ans += B * (B - 1) / 2;
   ans += C * (C - 1) / 2;
   ans += B * A;
   ans += B * C;
   ans += A * C;
   ans += A * 2;
   ans += B;
   ans += C;
   cout << ans;
}

Tags:

Categories:

Updated:

Comments