BOJ 13445 - 부분 수열 XOR
Trie를 쓰는건 뻔하지만 풀이를 좀 생각해보자.
아이디어 - XOR Range SumPermalink
까지 순회해가며 각각 가 부분 수열에서 마지막이 되는 수일 때 정답을 쿼리해와야한다.
를 trie에 넣어두었다고 하자.
을 볼 땐 자체만으로 이므로 신경써주지 않아도 된다.
그럼 이제 을 Trie에 넣었다고 생각해보자.
와 해서 미만이 되게 하는 수를 그냥 세면 안된다.
왜냐면 trie에 들어있는 수들은 그냥 과 일 뿐이지 과 이 아니기 때문이다.
그러므로 을 넣을 때 지금까지 있는 것들을 과 한 것으로 바꿔주어야 한다는 것인데 이걸 어떻게 할 수 있을까?
우리가 지금까지 를 거치며 트리에 삽입된 모든 수에 연산을 누적했다고 할 때, 그 누적된 수를 라고 하자.
는 교환법칙과 결합법칙이 성립하기 때문에 가능한 일이다.
이진수 의 자리가 이라는 것은 Trie에서 위치를 가리키는 노드들은 두 자식을 swap해서 가지고 있어야 한다는 뜻이다.
실제로 매 단계마다 수를 집어넣어 Trie의 모든 정점에 대해서 Swap을 직접 해두게 할 수도 있지만 시간복잡도가 이라 시간안에 돌아가지 않는다.
를 유지시키며 insert
와 query
마다 자리 일 때 가 자리가 1이라면, 두 자식 노드를 스위칭하게 만들면 문제가 해결된다.
아이디어 - Path less than KPermalink
현재 넣어야 할 수가 이고 문제에서 제시되는 가 있다고 하자.
Trie를 따라 내려가며 이미 이 path가 보다 작아지는 경로인지를 계속 들고간다.
이미 작아졌다면 해당 Subtree의 모든 개수를 반환해준다.
그렇지 않다면 4가지 경우가 있다.
와 가 자리가 인지가 중요하다.
- 모두 일경우
이 땐, 현재 가 인 것들과 연산을 하면 이 자리가 이 되므로 무조건 이하가 되기 때문에 을 의미하는 자식 노드의 모든 합을 더해준다.
을 의미하는 자식 노드에서도 정답이 나올 수 있기 때문에 이쪽은 계속 재귀적으로 이동한다.
- 는 는 일 경우
이 땐, 현재 가 인 것들과 연산을 하면 이 자리가 이 되므로 무조건 이하가 되기 때문에 을 의미하는 자식 노드의 모든 합을 더해준다.
을 의미하는 자식 노드에서도 정답이 나올 수 있기 때문에 이쪽은 계속 재귀적으로 이동한다.
- 는 는 일 경우
을 의미하는 자식 노드에서만 정답이 나올 수 있기 때문에 계속 재귀적으로 이동한다.
- 는 는 일 경우
을 의미하는 자식 노드에서만 정답이 나올 수 있기 때문에 계속 재귀적으로 이동한다.
이 모든건 아이디어 1에서 Swap과 병행되어야 한다는 점을 고려하자.
구현Permalink
아이디어는 어렵지 않게 떠올렸는데 오랜만에 XOR Trie를 짜니까 실수가 잦아 오래걸렸다.
int X = 0;
struct trie {
int cnt_sum = 0, cnt = 0;
trie *zero = 0, *one = 0;
int insert(int x, int depth) {
if (depth == -1) {
cnt_sum++;
cnt++;
return cnt_sum;
}
if (X & (1 << depth))swap(one, zero);
if ((1 << depth) & x) {
if (!one) one = new trie();
one->insert(x, depth - 1);
} else {
if (!zero) zero = new trie();
zero->insert(x, depth - 1);
}
if (X & (1 << depth))swap(one, zero);
cnt_sum = (one ? one->cnt_sum : 0) + (zero ? zero->cnt_sum : 0);
return cnt_sum;
}
void print(int depth, string &s) {
if (cnt) {
cout << s << endl;
}
if (depth == -1) return;
if (X & (1 << depth))swap(one, zero);
s += '1';
if (one) one->print(depth - 1, s);
s.pop_back();
s += '0';
if (zero)zero->print(depth - 1, s);
s.pop_back();
if (X & (1 << depth))swap(one, zero);
}
int query(int x, int depth, int k, int less) {
if (depth == -1) return less ? cnt_sum : 0;
int x_bit = !!(x & (1 << depth));
int k_bit = !!(k & (1 << depth));
int ret = 0;
if (X & (1 << depth))swap(one, zero);
if (less) {
ret = cnt_sum;
} else if (k_bit == 1 && x_bit == 1) {
ret = (!zero ? 0 : zero->query(x, depth - 1, k, less)) + (!one ? 0 : one->query(x, depth - 1, k, 1));
} else if (k_bit == 1 && x_bit == 0) {
ret = (!zero ? 0 : zero->query(x, depth - 1, k, 1)) + (!one ? 0 : one->query(x, depth - 1, k, less));
} else if (k_bit == 0 && x_bit == 1) {
if (!one) ret = 0;
else ret = one->query(x, depth - 1, k, less);
} else if (k_bit == 0 && x_bit == 0) {
if (!zero) ret = 0;
else ret = zero->query(x, depth - 1, k, less);
}
if (X & (1 << depth))swap(one, zero);
return ret;
}
};
void solve() {
int n, k;
cin >> n >> k;
vi a(n);
fv(a);
trie t;
ll ans = 0;
for (int i: a) debug(bitset<5>(i));
int mx = 30;
for (int i = 0; i < n; i++) {
int cur = a[i];
if (cur < k) ans++;
ans += t.query(cur, mx, k, 0);
X ^= cur;
t.insert(cur, mx);
}
cout << ans;
}
Comments