BOJ 1215 - 잘못 작성한 요세푸스 코드

image.png

문제 설명

코드를 읽어보면, 이 문제는 $1 \le n,k \le 10^9$ 인 제한에서 $k$를 $1$부터 $n$까지 나눈 나머지들의 합을 구하는 문제이다.

그러나 $O(n)$ 으로 풀 수 없고 좀 더 빠른 방법으로 풀어야 한다.

풀이

1- 일반성을 잃지 않고 $n < k$ 이다.

그렇지 않다면 $k$를 $k$로 나눈 나머지는 $0$이고 $n>k$ 인 경우엔 $k+1 \sim n$ 으로 $k$ 를 나눈 나머지는 모두 $k$ 그 자체가 되어서 $k \cdot (n - k)$ 만큼 정답에 더해주면 된다.


$k-1$ 라는 수를 보자.

$k$ 가 충분히 크다고 한다면, $k$ 를 $k-1$ 로 나눈 나머지는 $k-(k-1)$ 일 것이다.

또한 $k-2$ 로 나눈 나머지는 $k-(k-2)$ 이다.

즉, 나누는 수를 $u~(u \in \N)$ 라 할 때,

$$ k-u $$

가 나머지가 된다.

그런데 항상 나머지가 $k-u$ 라는 것은 말이 안된다. 그럼 어떤 $u$ 들에 대해 나머지가 $k-u$ 가 될까?

$k$ 로 나눈 수가 $k-u$ 가 되기 위한 정수 $u$ 의 범위는 $\left\lfloor \dfrac k2 \right\rfloor+1 \le u \le k-1$ 이다.

예시로 $k=1,000$ 라면 $[501, 999]$ 이다.

따라서 이 구간에서 $k$를 $u$로 나눈 나머지의 합은

$$ k \cdot(k-1-\left\lfloor \frac k2 \right\rfloor)-\displaystyle \sum_{i=\left\lfloor \frac k2 \right\rfloor+1}^{k-1} i $$

이고 이는 나눗셈 $2$번으로 구해줄 수 있다.

$\because ~~ \displaystyle \sum_{i=a}^b i=\dfrac {(b+a)(b-a+1)}2$

그럼 $\left\lfloor \dfrac k2 \right\rfloor+1$ 보다 작은 수들로 $k$ 를 나누면 어떻게 될까?

그 다음 살펴야 할 구간은 $k$를 $u$로 나눈 나머지가 $k-2u$ 인 구간이다.

이 구간은 $\left\lfloor \dfrac k3 \right\rfloor+1 \le u \le \left\lceil \dfrac k2 \right\rceil-1$ 이다.

이렇게 계속 $k-3u, k-4u, \cdots$ 처럼 진행하다가 $u=1$ 이 되는 시점까지 진행해주면 된다.

이를 빠르게 진행하는 법은 현재 범위의 가장 작은 $u$에서 $u \nmid k$ 인 $u$가 나올 때 까지 줄여보는 것이다. Harmonic Lemma 의 구현과 비슷하다.

결국 대략 나눗셈 연산의 횟수는 구간의 개수가 $C$라 할 때, 구간을 특정하는데 2번, 나머지의 합을 계산하는데 2번만 쓰이면 되어서 $4C$ 정도로 국한된다.

구간의 개수

서로 다른 $\left\lfloor \dfrac ku \right\rfloor$의 개수가 $C$ 에 비례한다.

$u < \sqrt{k}$ 라면 $u$ 의 개수 자체가 $\sqrt{k}$ 개 이하이고, $u \ge \sqrt{k}$ 라면 $\left\lfloor \dfrac ku \right\rfloor$ 의 서로 다른 가지수가 $\sqrt{k}$개의 상한을 갖기 때문에 $C$ 는 아무리 많아봤자 $2\sqrt{k}$ 개이기 때문에,

결론적으로 $4 \cdot 2 \sqrt{k}$ 번의 나눗셈 연산이면 정답을 구할 수 있다.

문제 풀이 구현

void solve() {  
   int n, k;  
   cin >> n >> k;  
   int ans = 0;  
   if (n > k) {  
      ans += k * (n - k);  
      n = k;  
   }  
   vi d;  
   for (int i = 1; i * i <= k; i++) {  
      if (k % i == 0) {  
         d.pb(i);  
         if (k / i != i) {  
            d.pb(k / i);  
         }  
      }  
   }  
   uniq(d);  
   auto sum = [&](int l, int r) {  
      return (r + l) * (r - l + 1) / 2;  
   };  
   auto get_ans = [&](int l, int r, int x) {  
      int hi = ubi(d, r);  
      int lo = lbi(d, l);  
      int ret = (r - l + 1) * k - sum(l, r) * x;  
      for (int i = lo; i < hi; i++) {  
         ret -= k - x * d[i];  
      }  
      return ret;  
   };  
   for (int r = min(k, n), l; r > 0;) {  
      if (k % r == 0) {  
         r--;  
         continue;  
      }  
      int x = k / r;  
      l = k / (x + 1) + 1;  
      if (l <= r)  
         ans += get_ans(l, r, x);  
      r = min(l, r) - 1;  
   }  
   cout << ans;  
}

Comments