BOJ 1215 - 잘못 작성한 요세푸스 코드

image.png

문제 설명Permalink

코드를 읽어보면, 이 문제는 1n,k1091 \le n,k \le 10^9 인 제한에서 kk11부터 nn까지 나눈 나머지들의 합을 구하는 문제이다.

그러나 O(n)O(n) 으로 풀 수 없고 좀 더 빠른 방법으로 풀어야 한다.

풀이Permalink

1- 일반성을 잃지 않고 n<kn < k 이다.

그렇지 않다면 kkkk로 나눈 나머지는 00이고 n>kn>k 인 경우엔 k+1nk+1 \sim n 으로 kk 를 나눈 나머지는 모두 kk 그 자체가 되어서 k(nk)k \cdot (n - k) 만큼 정답에 더해주면 된다.


k1k-1 라는 수를 보자.

kk 가 충분히 크다고 한다면, kkk1k-1 로 나눈 나머지는 k(k1)k-(k-1) 일 것이다.

또한 k2k-2 로 나눈 나머지는 k(k2)k-(k-2) 이다.

즉, 나누는 수를 u (uN)u~(u \in \N) 라 할 때,

ku k-u

가 나머지가 된다.

그런데 항상 나머지가 kuk-u 라는 것은 말이 안된다. 그럼 어떤 uu 들에 대해 나머지가 kuk-u 가 될까?

kk 로 나눈 수가 kuk-u 가 되기 위한 정수 uu 의 범위는 k2+1uk1\left\lfloor \dfrac k2 \right\rfloor+1 \le u \le k-1 이다.

예시로 k=1,000k=1,000 라면 [501,999][501, 999] 이다.

따라서 이 구간에서 kkuu로 나눈 나머지의 합은

k(k1k2)i=k2+1k1i k \cdot(k-1-\left\lfloor \frac k2 \right\rfloor)-\displaystyle \sum_{i=\left\lfloor \frac k2 \right\rfloor+1}^{k-1} i

이고 이는 나눗셈 22번으로 구해줄 수 있다.

  i=abi=(b+a)(ba+1)2\because ~~ \displaystyle \sum_{i=a}^b i=\dfrac {(b+a)(b-a+1)}2

그럼 k2+1\left\lfloor \dfrac k2 \right\rfloor+1 보다 작은 수들로 kk 를 나누면 어떻게 될까?

그 다음 살펴야 할 구간은 kkuu로 나눈 나머지가 k2uk-2u 인 구간이다.

이 구간은 k3+1uk21\left\lfloor \dfrac k3 \right\rfloor+1 \le u \le \left\lceil \dfrac k2 \right\rceil-1 이다.

이렇게 계속 k3u,k4u,k-3u, k-4u, \cdots 처럼 진행하다가 u=1u=1 이 되는 시점까지 진행해주면 된다.

이를 빠르게 진행하는 법은 현재 범위의 가장 작은 uu에서 uku \nmid kuu가 나올 때 까지 줄여보는 것이다. Harmonic Lemma 의 구현과 비슷하다.

결국 대략 나눗셈 연산의 횟수는 구간의 개수가 CC라 할 때, 구간을 특정하는데 2번, 나머지의 합을 계산하는데 2번만 쓰이면 되어서 4C4C 정도로 국한된다.

구간의 개수Permalink

서로 다른 ku\left\lfloor \dfrac ku \right\rfloor의 개수가 CC 에 비례한다.

u<ku < \sqrt{k} 라면 uu 의 개수 자체가 k\sqrt{k} 개 이하이고, uku \ge \sqrt{k} 라면 ku\left\lfloor \dfrac ku \right\rfloor 의 서로 다른 가지수가 k\sqrt{k}개의 상한을 갖기 때문에 CC 는 아무리 많아봤자 2k2\sqrt{k} 개이기 때문에,

결론적으로 42k4 \cdot 2 \sqrt{k} 번의 나눗셈 연산이면 정답을 구할 수 있다.

문제 풀이 구현Permalink

void solve() {  
   int n, k;  
   cin >> n >> k;  
   int ans = 0;  
   if (n > k) {  
      ans += k * (n - k);  
      n = k;  
   }  
   vi d;  
   for (int i = 1; i * i <= k; i++) {  
      if (k % i == 0) {  
         d.pb(i);  
         if (k / i != i) {  
            d.pb(k / i);  
         }  
      }  
   }  
   uniq(d);  
   auto sum = [&](int l, int r) {  
      return (r + l) * (r - l + 1) / 2;  
   };  
   auto get_ans = [&](int l, int r, int x) {  
      int hi = ubi(d, r);  
      int lo = lbi(d, l);  
      int ret = (r - l + 1) * k - sum(l, r) * x;  
      for (int i = lo; i < hi; i++) {  
         ret -= k - x * d[i];  
      }  
      return ret;  
   };  
   for (int r = min(k, n), l; r > 0;) {  
      if (k % r == 0) {  
         r--;  
         continue;  
      }  
      int x = k / r;  
      l = k / (x + 1) + 1;  
      if (l <= r)  
         ans += get_ans(l, r, x);  
      r = min(l, r) - 1;  
   }  
   cout << ans;  
}

Comments