AtCoder ABC294 G - Distance Queries on a Tree
이걸로 ABC294 마무리!
간선 가중치를 하나씩 바꾸면서 두 정점 사이의 경로 길이를 구하는 문제라, LCA를 쓰거나 HLD(Heavy-Light Decomposition)와 세그먼트 트리를 함께 쓰면 무난하게 풀 수 있다.
문제는 필자가 HLD 구현이 기억나지 않아서 예전에 만들어 둔 템플릿을 그대로 가져다 풀었다는 것인데,
복습할 때 HLD에 대해서 다시 잘 알아보자.
/// Point-assign / range-query segment tree over indices [0, N).
///
/// Every node stores the sum, minimum and maximum of its range. Indices that
/// were never assigned hold the identity element (sum 0, min INF, max -INF),
/// so they do not affect any aggregate.
///
/// T must be an integral type: INF is derived with a right shift.
template<class T>
struct seg_tree {
private:
    // Sentinel for the identity element: half of T's maximum
    // (presumably to leave headroom against overflow).
    const T INF = numeric_limits<T>::max() >> 1;

    struct node {
        T sum = 0, min = 0, max = 0;
        node(T v) : sum(v), min(v), max(v) {}
        node(T sum, T min, T max) : sum(sum), min(min), max(max) {}
    };

    // Must be declared after INF: members are initialized in declaration order.
    const node identity{0, INF, -INF};
    int N;
    vector<node> tree;

    // Combines two child aggregates.
    node merge(node l, node r) {
        node ret = {l.sum + r.sum, min(l.min, r.min), max(l.max, r.max)};
        return ret;
    }

    // Assigns (not adds) v at position i inside node n covering [nl, nr].
    void update(int n, int nl, int nr, int i, T v) {
        if (nl > i || nr < i) return;
        if (nl == nr) {
            tree[n] = node(v);
            return;
        }
        int m = (nl + nr) >> 1;
        update(n * 2, nl, m, i, v);
        update(n * 2 + 1, m + 1, nr, i, v);
        tree[n] = merge(tree[n * 2], tree[n * 2 + 1]);
    }

    // Aggregate of [l, r] intersected with node n's range [nl, nr].
    // An empty query range (l > r) yields the identity element.
    node query(int n, int nl, int nr, int l, int r) {
        if (nl > r || nr < l) return identity;
        if (nl >= l && nr <= r) return tree[n];
        int m = (nl + nr) >> 1;
        return merge(query(n * 2, nl, m, l, r), query(n * 2 + 1, m + 1, nr, l, r));
    }

public:
    /// Builds a tree for N positions, all initialized to the identity element.
    seg_tree(int N) : N(N) {
        // Smallest power of two >= max(N, 1), computed in integers. The former
        // ceil(log2(N)) form is undefined for N == 0 (log2(0) is -inf), and this
        // form yields the same size for every N >= 1.
        int leaves = 1;
        while (leaves < N) leaves <<= 1;
        tree.assign(2 * leaves, identity);
    }

    /// Assigns value v to position i.
    void update(int i, T v) { update(1, 0, N - 1, i, v); }

    /// Returns {sum, min, max} over positions [l, r]; identity if l > r.
    node query(int l, int r) { return query(1, 0, N - 1, l, r); }
};
/// Heavy-light decomposition over a tree with N vertices (0-indexed), backed by
/// seg_tree<T> for path-sum queries.
///
/// Each edge weight is stored at the DFS position of its child endpoint. With
/// for_edge == 1 (the default) the topmost vertex's slot on the final chain
/// segment is skipped, so query_sum(a, b) sums edge weights on the a-b path.
/// With for_edge == 0 that slot is included.
///
/// add_edge(u, v) only records v in u's adjacency list. For an undirected tree,
/// call it in both directions, then call init(root) once before any query.
template<class T>
struct HLD {
    struct Edge { int to; T cost; };  // `to` is a vertex index, independent of T

    int N, next_dfsn = 0, for_edge;
    // par[v]: parent (-1 for the root); head[v]: top vertex of v's chain;
    // in[v] / out[v]: first / last DFS index inside v's subtree.
    vector<int> subsize, par, depth, head, in, out;
    vector<vector<Edge>> edges;
    seg_tree<T> seg;

    // Initializers are listed in declaration order (avoids -Wreorder).
    HLD(int N, int for_edge = 1)
        : N(N), for_edge(for_edge), subsize(N), par(N), depth(N), head(N), in(N), out(N),
          edges(N), seg(N) {}

    void add_edge(int u, int v, T cost = 0) { edges[u].push_back({v, cost}); }

    /// Decomposes the tree rooted at `root` and loads every edge weight into
    /// the segment tree.
    void init(int root = 0) {
        if (N == 0) return;
        depth[root] = 0;
        head[root] = root;  // the root starts its own chain, whatever its index
        dfs1(root, -1);
        dfs2(root, -1);
        for (int cur = 0; cur < N; cur++)
            for (const Edge &e : edges[cur])
                if (par[e.to] == cur)
                    update_node(e.to, e.cost);
    }

    /// Computes subtree sizes and depths, and moves the heavy child (largest
    /// subtree) to edges[cur][0].
    void dfs1(int cur, int p) {
        subsize[cur] = 1;
        for (Edge &e : edges[cur]) {
            if (e.to == p) continue;
            depth[e.to] = depth[cur] + 1;
            dfs1(e.to, cur);
            subsize[cur] += subsize[e.to];
            // The parent edge must never keep slot 0. Its subsize is still
            // partial here, and leaving it there would mean no child is
            // treated as heavy, which breaks the bound on chains per path.
            Edge &first = edges[cur][0];
            if (first.to == p || subsize[e.to] > subsize[first.to]) swap(first, e);
        }
    }

    /// Assigns parents, chain heads and DFS indices. edges[cur][0] is visited
    /// first, so each heavy chain occupies consecutive DFS indices.
    void dfs2(int cur, int p) {
        par[cur] = p;
        in[cur] = next_dfsn++;
        for (const Edge &e : edges[cur]) {
            if (e.to == par[cur]) continue;
            head[e.to] = (e.to == edges[cur][0].to) ? head[cur] : e.to;
            dfs2(e.to, cur);
        }
        out[cur] = next_dfsn - 1;
    }

    /// Assigns v to vertex i's slot (for edges: the edge from par[i] to i).
    void update_node(int i, T v) { seg.update(in[i], v); }

    /// Assigns weight v to the edge between adjacent vertices a and b; the
    /// value lives at the deeper endpoint.
    void update_path(int a, int b, T v) {
        if (depth[a] > depth[b]) swap(a, b);
        seg.update(in[b], v);
    }

    /// Sum of stored values on the a-b path, climbing chain by chain.
    T query_sum(int a, int b) {
        T ret = 0;
        for (; head[a] != head[b]; a = par[head[a]]) {
            if (depth[head[a]] < depth[head[b]]) swap(a, b);
            ret += seg.query(in[head[a]], in[a]).sum;
        }
        if (depth[a] > depth[b]) swap(a, b);
        // `a` is now the shallower vertex (the LCA). for_edge skips its slot,
        // which holds the edge above the LCA and is not on the path.
        return ret + seg.query(in[a] + for_edge, in[b]).sum;
    }
};
void solve() {
int n;
cin >> n;
HLD<int> hld(n, 1);
vector<pi> edges;
for (int i = 0; i < n - 1; i++) {
int u, v, w;
cin >> u >> v >> w;
u--, v--;
hld.add_edge(u, v, w);
hld.add_edge(v, u, w);
edges.pb({u, v});
}
hld.init(0);
int q;
cin >> q;
while (q--) {
int cmd, x, y;
cin >> cmd >> x >> y;
if (cmd == 1) {
x--;
hld.update_path(edges[x].fi, edges[x].se, y);
} else {
x--, y--;
cout << hld.query_sum(x, y) << endl;
}
}
}
Comments