BOJ 15647 - 로스팅하는 엠마도 바리스타입니다

다음과 같은 두 개의 값을 각 노드마다 관리한다고 하자.

$\quad dp_i=$ 자신의 서브트리에서 자기까지 오는데 총 거리

$\quad size_i=$자신의 서브 트리의 크기

일단 1번 노드를 루트로 이 두 값을 $O(N)$에 구해줄 수 있다.

그럼 1번 노드에서의 $dp_1$ 은 문제에서 요구하는 정답일 것이다.

이제 1번 노드부터 인접 노드들로 DFS를 돌려주는데, DFS로 방문하면서 다음과 같이 Re-Rooting을 적용시켜준다.

Re-Rooting은 트리의 루트를 인접한 노드로 계속 바꿔주는 테크닉으로, 문제에 따라 적절히 $dp$ 값들을 변형해주면서 진행한다. 이 문제의 경우 이걸 아래와 같이 $O(1)$에 할 수 있다.

$x\to y$ 로 간다고 할 때,

$dp_x \coloneqq dp_x - (dp_y + w_{x,\,y} + size_y)$

$size_x \coloneqq size_{x}- size_y$

$size_{y}\coloneqq n$

$dp_{y}\coloneqq dp_y+ size_{x}\cdot w_{x,y}+dp_x$

  1. $dp_x$ 는 $y$를 자식으로 가지기 때문에 더해져있던 값을 뺀다.
  2. $size_x$ 는 $y$를 자식으로 가지기 때문에 더해져있던 값을 뺀다.
  3. $size_y$는 $y$가 루트가 되므로 $n$이 된다.
  4. $dp_y$는 $x$를 새로운 자식으로 갖게 되었기 때문에 증가하는 값을 더해준다.

이렇게 되면 $y$를 기준으로 루트가 되는 서브트리에서의 $dp_y$와 $size_y$ 값이 적절히 변경된다.

이제 DFS를 마치고 다시 돌아왔을땐 저걸 역순으로 다시 적용시켜주면 된다.

따라서 모든 정답을 $O(N)$에 구할 수 있다.

void solve() {  
   int n;  
   cin >> n;  
   vector<vector<pi>> edges(n);  
   for (int i = 0; i < n - 1; i++) {  
      int u, v, w;  
      cin >> u >> v >> w, u--, v--;  
      edges[u].pb({v, w});  
      edges[v].pb({u, w});  
   }  
   vi dp(n), size(n);  
   function<int(int, int)> dfs1 = [&](int i, int p) -> int {  
      size[i] = 1;  
      for (auto&[to, w]: edges[i]) {  
         if (to == p) continue;  
         int s = dfs1(to, i);  
         size[i] += s;  
         dp[i] += w * s + dp[to];  
      }  
      return size[i];  
   };  
   dfs1(0, -1);  
   assert(size[0] == n);  
   vi vis(n), ans(n);  
   function<void(int)> dfs2 = [&](int i) -> void {  
      vis[i] = 1;  
      ans[i] = dp[i];  
      for (auto&[to, w]: edges[i]) {  
         if (vis[to]) continue;  
  
         dp[i] -= dp[to] + w * size[to];  
         size[i] -= size[to];  
         dp[to] += size[i] * w + dp[i];  
         int ps = size[to];  
         size[to] = n;  
         dfs2(to);  
         size[to] = ps;  
         dp[to] -= size[i] * w + dp[i];  
         size[i] += size[to];  
         dp[i] += dp[to] + w * size[to];  
      }  
   };  
   dfs2(0);  
   for (int i = 0; i < n; i++) cout << ans[i] << endl;  
}

Tags:

Categories:

Updated:

Comments