BOJ 21759 - 두 개의 팀

image.png

DP 3개로 풀 수 있다.

$dp_\text{valid max}(i):$ 자기 자신을 포함하고(팀장이나 상사로 하고) 조건 3에 위배되지 않게 자식 팀원들을 구성했을 때 최대합

이건 양수일 때만 포함시켜 더해주면 된다.

function<int(int)> dfs_valid_max = [&](int i) -> int {  
   dp_valid_max[i] = x[i];  
   for (int to: children[i]) {  
      int child_ret = dfs_valid_max(to);  
      dp_valid_max[i] += max(0ll, child_ret);  
   }  
   return dp_valid_max[i];  
};  
dfs_valid_max(0);

$dp_\text{all max}(i):$ 서브트리에서 아무 자식이나 팀장으로 했을 때 팀들 중 최대합

function<int(int)> dfs_all_max = [&](int i) -> int {  
   dp_all_max[i] = dp_valid_max[i];  
   for (int to: children[i]) {  
      int child_ret = dfs_all_max(to);  
      dp_all_max[i] = max(dp_all_max[i], child_ret);  
   }  
   return dp_all_max[i];  
};  
dfs_all_max(0);

$dp_{\text{without valid}}(i):$ 자기 자신을 팀장이나 상사로 했을 때 조건 3에 위배되지 않게 팀원들을 모두 가져갔을 때 거기에 포함이 되지 않아도 되는 어떤 정점을 팀장으로 하는 팀들 중 최대합

이건 양수여서 무조건 포함해야 될 때는 재귀적으로 더 파고들어가서 값을 꺼내오고, 포함하지 않아도 될 때는 $dp_{\text{all max}}$ 값을 써주면 된다.

function<int(int)> dfs_without_valid = [&](int cur) -> int {  
   for (int i = 0; i < sz(children[cur]); i++) {  
      int c = children[cur][i];  
      int k = dfs_without_valid(c);  
      if (dp_valid_max[c] <= 0) {  
         dp_without_valid[cur] = max(dp_without_valid[cur], dp_all_max[c]);  
      } else {  
         dp_without_valid[cur] = max(dp_without_valid[cur], k);  
      }  
   }  
   return dp_without_valid[cur];  
};  
dfs_without_valid(0);

이 $dp$ 들이 모두 준비되었다면, 모든 정점에 대해

  1. 자신을 팀장으로 하거나
  2. 자신의 다른 두 자식에 각각 팀장이 있거나

를 검사하여 최대값을 찾아내주면 된다.

int ans = -inf;  
for (int i = 0; i < n; i++) {  
   ans = max(ans, dp_valid_max[i] + dp_without_valid[i]);  
   if (sz(children[i]) >= 2) {  
      sort(all(children[i]), [&](int i, int j) { return dp_all_max[i] > dp_all_max[j]; });  
      ans = max(ans, dp_all_max[children[i][0]] + dp_all_max[children[i][1]]);  
   }  
}  
cout << ans;

전체 코드

const int inf = 1e17;  
  
void solve() {  
   int n;  
   cin >> n;  
   vvi children(n);  
   vi x(n), par(n);  
   for (int i = 0; i < n; i++) {  
      cin >> x[i] >> par[i];  
      par[i]--;  
      if (par[i] != -2) children[par[i]].pb(i);  
   }  
  
   vi dp_valid_max(n, -inf), dp_all_max(n, -inf), dp_without_valid(n, -inf);  
   function<int(int)> dfs_valid_max = [&](int i) -> int {  
      dp_valid_max[i] = x[i];  
      for (int to: children[i]) {  
         int child_ret = dfs_valid_max(to);  
         dp_valid_max[i] += max(0ll, child_ret);  
      }  
      return dp_valid_max[i];  
   };  
   dfs_valid_max(0);  
  
   function<int(int)> dfs_all_max = [&](int i) -> int {  
      dp_all_max[i] = dp_valid_max[i];  
      for (int to: children[i]) {  
         int child_ret = dfs_all_max(to);  
         dp_all_max[i] = max(dp_all_max[i], child_ret);  
      }  
      return dp_all_max[i];  
   };  
   dfs_all_max(0);  
  
   function<int(int)> dfs_without_valid = [&](int cur) -> int {  
      for (int i = 0; i < sz(children[cur]); i++) {  
         int c = children[cur][i];  
         int k = dfs_without_valid(c);  
         if (dp_valid_max[c] <= 0) {  
            dp_without_valid[cur] = max(dp_without_valid[cur], dp_all_max[c]);  
         } else {  
            dp_without_valid[cur] = max(dp_without_valid[cur], k);  
         }  
      }  
      return dp_without_valid[cur];  
   };  
   dfs_without_valid(0);  
  
   int ans = -inf;  
   for (int i = 0; i < n; i++) {  
      ans = max(ans, dp_valid_max[i] + dp_without_valid[i]);  
      if (sz(children[i]) >= 2) {  
         sort(all(children[i]), [&](int i, int j) { return dp_all_max[i] > dp_all_max[j]; });  
         ans = max(ans, dp_all_max[children[i][0]] + dp_all_max[children[i][1]]);  
      }  
   }  
   cout << ans;  
}

Tags:

Categories:

Updated:

Comments