BOJ 1623 - 신년 파티

image.png

단순 트리 DP로, 역추적까지 해주면 된다.

$dp[i][0]$ 을 자기 자신이 포함되지 않았을 때 $i$의 서브트리에서 최댓값으로 하고,

$dp[i][1]$를 자신이 포함되었을 때 최댓값으로 하면 된다.

const int inf = 2e18;
void solve() {
   int n;
   cin >> n;
   vi v(n);
   fv(v);

   vvi edges(n);
   for (int i = 1, p; i < n; i++) {
      cin >> p, p--;
      edges[p].pb(i);
   }

   vvi dp(n, vi(2, -inf));

   function<int(int, int)> fn = [&](int cur, int with) -> int {
      int &ret = dp[cur][with];
      if (ret != -inf) return ret;
      ret = -inf + 1;
      if (sz(edges[cur]) == 0) {
         if (with) ret = v[cur];
         else ret = 0ll;
      } else {
         if (with) {
            ret = v[cur];
            for (int to: edges[cur]) {
               ret += fn(to, 0);
            }
         } else {
            ret = 0;
            for (int to: edges[cur]) {
               ret += max(fn(to, 0), fn(to, 1));
            }
         }
      }
      return ret;
   };
   cout << fn(0, 1) << ' ' << fn(0, 0) << endl;

   vi ans1, ans2;
   vi *ans = &ans1;
   function<void(int, int)> track = [&](int cur, int with) -> void {
      if (sz(edges[cur]) == 0) {
         if (with) ans->pb(cur);
      } else {
         if (with) {
            ans->pb(cur);
            for (int to: edges[cur]) {
               track(to, 0);
            }
         } else {
            for (int to: edges[cur]) {
               if (fn(to, 1) > fn(to, 0)) {
                  track(to, 1);
               } else {
                  track(to, 0);
               }
            }
         }
      }
   };
   track(0, 1);
   ans = &ans2;
   track(0, 0);
   sort(all(ans1));
   sort(all(ans2));
   for (int i: ans1) cout << i + 1 << ' ';
   cout << -1 << endl;
   for (int i: ans2) cout << i + 1 << ' ';
   cout << -1 << endl;
}

Tags:

Categories:

Updated:

Comments