BOJ 28254 - 트리와 깃발
대회중에 최근에 공부했던 테크닉이 기억나서 쉽게 풀었던 문제이다.
정해는 Virtual Tree라는데, 이 문제는 Sack으로도 풀 수 있다.
Sack테크닉을 써서 현재 정점의 서브트리에서 모든 색의 개수를 관리한다.
하나의 정점은 동일한 색을 여러개 가지지 못하기 때문에, 서브트리에 색 $c$가 $k$개 있고, 색 $c$가 총 $k'$개 있다면
$k(k'-k)$ 개를 정답에 더해주면 된다.
$\displaystyle \sum_{j=1}^{M}C_j \le 500,000$ 이기 때문에 자식 정점 중 가장 색을 많이 가지고 있는 정점을 무거운 간선으로 설정해주면 된다.
void solveE() {
int n, m;
cin >> n >> m;
vector<vector<pi>> edges(n);
for (int i = 0, u, v; i < n - 1; i++) {
cin >> u >> v, u--, v--;
edges[u].pb({v, i}), edges[v].pb({u, i});
}
vi ans(n - 1);
vi color_cnt(m);
vvi colors(n);
for (int i = 0; i < m; i++) {
int c;
cin >> c;
color_cnt[i] = c;
for (int j = 0; j < c; j++) {
int x;
cin >> x, x--;
colors[x].pb(i);
}
}
vi size(n), in(n), out(n), in_rev(n);
int dfsn = 0;
vi p_edge_idx(n, -1);
function<void(int, int)> fn = [&](int cur, int p) -> void {
size[cur] = sz(colors[cur]);
in[cur] = dfsn, in_rev[dfsn] = cur;
dfsn++;
for (auto &[to, idx]: edges[cur]) {
if (to == p) continue;
fn(to, cur);
p_edge_idx[to] = idx;
size[cur] += size[to];
}
out[cur] = dfsn - 1;
};
fn(0, -1);
vi cnt_of_color(m);
int not_all = 0;
auto add = [&](int x) {
for (int c: colors[x]) {
int other = color_cnt[c] - cnt_of_color[c];
not_all -= other * cnt_of_color[c];
cnt_of_color[c]++;
other--;
not_all += other * cnt_of_color[c];
}
};
auto remove = [&](int x) {
for (int c: colors[x]) {
int other = color_cnt[c] - cnt_of_color[c];
not_all -= other * cnt_of_color[c];
cnt_of_color[c]--;
other++;
not_all += other * cnt_of_color[c];
}
};
function<void(int, int, bool)> sack = [&](int cur, int p, bool keep) -> void {
int B = -1;
for (auto &[to, _]: edges[cur]) {
if (to == p) continue;
if (B == -1 || size[B] < size[to]) B = to;
}
for (auto &[to, _]: edges[cur])
if (to != p && to != B)
sack(to, cur, false);
if (B != -1)
sack(B, cur, true);
add(cur);
for (auto &[to, idx]: edges[cur])
if (to != p && to != B)
for (int i = in[to]; i <= out[to]; i++)
add(in_rev[i]);
if (p_edge_idx[cur] != -1) {
ans[p_edge_idx[cur]] = not_all;
}
if (!keep)
for (int i = in[cur]; i <= out[cur]; i++)
remove(in_rev[i]);
};
sack(0, -1, 1);
for (int i = 0; i < n - 1; i++) cout << ans[i] << endl;
}
Comments