BOJ 13445 - 부분 수열 XOR
Trie를 쓰는건 뻔하지만 풀이를 좀 생각해보자.
아이디어 - XOR Range Sum
$i=0 \to i=n-1$ 까지 순회해가며 각각 $i$가 부분 수열에서 마지막이 되는 수일 때 정답을 쿼리해와야한다.
$a_0$ 를 trie에 넣어두었다고 하자.
$a_1$을 볼 땐 $a_0$ 자체만으로 $a_0 \oplus a_1$ 이므로 신경써주지 않아도 된다.
그럼 이제 $a_1$ 을 Trie에 넣었다고 생각해보자.
$a_2$와 $\oplus$ 해서 $k$ 미만이 되게 하는 수를 그냥 세면 안된다.
왜냐면 trie에 들어있는 수들은 그냥 $a_0$ 과 $a_1$일 뿐이지 $a_0 \oplus a_1$ 과 $a_1$ 이 아니기 때문이다.
그러므로 $a_1$ 을 넣을 때 지금까지 있는 것들을 $a_1$ 과 $\oplus$ 한 것으로 바꿔주어야 한다는 것인데 이걸 어떻게 할 수 있을까?
우리가 지금까지 $i$ 를 거치며 트리에 삽입된 모든 수에 $\oplus$ 연산을 누적했다고 할 때, 그 누적된 수를 $X$ 라고 하자.
$\oplus$ 는 교환법칙과 결합법칙이 성립하기 때문에 가능한 일이다.
이진수 $X$의 $2^k$ 자리가 $1$이라는 것은 Trie에서 $2^k$ 위치를 가리키는 노드들은 두 자식을 swap해서 가지고 있어야 한다는 뜻이다.
실제로 매 단계마다 수를 집어넣어 Trie의 모든 정점에 대해서 Swap을 직접 해두게 할 수도 있지만 시간복잡도가 $O(N)$ 이라 시간안에 돌아가지 않는다.
$X$ 를 유지시키며 insert
와 query
마다 $2^k$ 자리 일 때 $X$가 $2^k$자리가 1이라면, 두 자식 노드를 스위칭하게 만들면 문제가 해결된다.
아이디어 - Path less than K
현재 넣어야 할 수가 $x$ 이고 문제에서 제시되는 $K$ 가 있다고 하자.
Trie를 따라 내려가며 이미 이 path가 $K$ 보다 작아지는 경로인지를 계속 들고간다.
이미 작아졌다면 해당 Subtree의 모든 개수를 반환해준다.
그렇지 않다면 4가지 경우가 있다.
$x$와 $K$가 $2^k$ 자리가 $1$인지가 중요하다.
- 모두 $1$일경우
이 땐, 현재 $2^k$ 가 $1$ 인 것들과 $\oplus$ 연산을 하면 이 자리가 $0$이 되므로 무조건 $K$ 이하가 되기 때문에 $1$ 을 의미하는 자식 노드의 모든 합을 더해준다.
$0$을 의미하는 자식 노드에서도 정답이 나올 수 있기 때문에 이쪽은 계속 재귀적으로 이동한다.
- $x$ 는 $0$ $K$는 $1$일 경우
이 땐, 현재 $2^k$ 가 $0$ 인 것들과 $\oplus$ 연산을 하면 이 자리가 $0$이 되므로 무조건 $K$ 이하가 되기 때문에 $0$ 을 의미하는 자식 노드의 모든 합을 더해준다.
$1$을 의미하는 자식 노드에서도 정답이 나올 수 있기 때문에 이쪽은 계속 재귀적으로 이동한다.
- $x$ 는 $0$ $K$는 $0$일 경우
$0$을 의미하는 자식 노드에서만 정답이 나올 수 있기 때문에 계속 재귀적으로 이동한다.
- $x$ 는 $1$ $K$는 $0$일 경우
$1$을 의미하는 자식 노드에서만 정답이 나올 수 있기 때문에 계속 재귀적으로 이동한다.
이 모든건 아이디어 1에서 Swap과 병행되어야 한다는 점을 고려하자.
구현
아이디어는 어렵지 않게 떠올렸는데 오랜만에 XOR Trie를 짜니까 실수가 잦아 오래걸렸다.
int X = 0;
struct trie {
int cnt_sum = 0, cnt = 0;
trie *zero = 0, *one = 0;
int insert(int x, int depth) {
if (depth == -1) {
cnt_sum++;
cnt++;
return cnt_sum;
}
if (X & (1 << depth))swap(one, zero);
if ((1 << depth) & x) {
if (!one) one = new trie();
one->insert(x, depth - 1);
} else {
if (!zero) zero = new trie();
zero->insert(x, depth - 1);
}
if (X & (1 << depth))swap(one, zero);
cnt_sum = (one ? one->cnt_sum : 0) + (zero ? zero->cnt_sum : 0);
return cnt_sum;
}
void print(int depth, string &s) {
if (cnt) {
cout << s << endl;
}
if (depth == -1) return;
if (X & (1 << depth))swap(one, zero);
s += '1';
if (one) one->print(depth - 1, s);
s.pop_back();
s += '0';
if (zero)zero->print(depth - 1, s);
s.pop_back();
if (X & (1 << depth))swap(one, zero);
}
int query(int x, int depth, int k, int less) {
if (depth == -1) return less ? cnt_sum : 0;
int x_bit = !!(x & (1 << depth));
int k_bit = !!(k & (1 << depth));
int ret = 0;
if (X & (1 << depth))swap(one, zero);
if (less) {
ret = cnt_sum;
} else if (k_bit == 1 && x_bit == 1) {
ret = (!zero ? 0 : zero->query(x, depth - 1, k, less)) + (!one ? 0 : one->query(x, depth - 1, k, 1));
} else if (k_bit == 1 && x_bit == 0) {
ret = (!zero ? 0 : zero->query(x, depth - 1, k, 1)) + (!one ? 0 : one->query(x, depth - 1, k, less));
} else if (k_bit == 0 && x_bit == 1) {
if (!one) ret = 0;
else ret = one->query(x, depth - 1, k, less);
} else if (k_bit == 0 && x_bit == 0) {
if (!zero) ret = 0;
else ret = zero->query(x, depth - 1, k, less);
}
if (X & (1 << depth))swap(one, zero);
return ret;
}
};
void solve() {
int n, k;
cin >> n >> k;
vi a(n);
fv(a);
trie t;
ll ans = 0;
for (int i: a) debug(bitset<5>(i));
int mx = 30;
for (int i = 0; i < n; i++) {
int cur = a[i];
if (cur < k) ans++;
ans += t.query(cur, mx, k, 0);
X ^= cur;
t.insert(cur, mx);
}
cout << ans;
}
Comments