BOJ 25821 - Palindromic Primes

image.png

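$N$ 이하의 팰린드롬은 앞쪽 절반의 자릿수만 정하면 전체가 결정되므로, 그 개수는 $O(\sqrt{N})$ 개에 불과하다. 따라서 모든 팰린드롬을 브루트 포스로 생성하면서 밀러-라빈 소수 판정으로 소수인지 확인하면 된다. $f(x)$ 를 $x$ 이하의 팰린드롬 소수 개수라 하면 답은 $f(H) - f(L-1)$ 이다. (4자리 이상의 짝수 길이 팰린드롬은 항상 11의 배수이므로 소수가 될 수 없다.)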

struct _miller {
   bool primei(ll n) {
      if (n <= 1) return 0;
      for (auto &a: {2, 7, 61}) {
         if (n == a) return 1;
         if (!primable(n, a)) return 0;
      }
      return 1;
   }
   bool primell(ll n) {
      for (auto &a: {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) {
         if (n == a) return 1;
         if (n > 40 && !primable(n, a)) return 0;
      }
      if (n <= 40) return 0;
      return 1;
   }
private:
   inline ll poww(ll a, ll b, ll mod) {
      __int128_t ret = 1;
      while (b) {
         if (b & 1) ret = mul(ret, a, mod);
         a = mul(a, a, mod);
         b >>= 1;
      }
      return ret;
   }
   bool primable(ll n, ll a) {
      if (!(n % a))return 0;
      ll d = (n - 1) >> 1;
      while (1) {
         ll tmp = poww(a, d, n);
         if (tmp == n - 1) return 1;
         if (d & 1) return (tmp == 1 || tmp == n - 1);
         d >>= 1;
      }
   }
protected:
   inline ll mul(ll x, ll y, ll mod) { return (__int128_t) x * y % mod; }
} miller;

// Digit length of the palindromes currently enumerated by find();
// find_all() assigns it before each top-level find() call.
int len{1};

// Digits of the palindrome under construction. find() appends and pops one
// character per recursion level, so it is empty between top-level calls.
string num{};
// Counts the palindromes of exactly `len` digits (global) that extend the
// first i digits already stored in the global buffer `num`, are <= n, and
// are accepted by miller.primell.
//
// The number is built left to right. Positions in the left half, plus the
// middle digit when len is odd, are chosen freely (never a leading zero).
// Positions in the right half mirror num[len - 1 - i].
//
// Preconditions: num.size() == i, and len is small enough that the
// completed digit string fits in long long (stoll would throw otherwise).
// `num` is restored to its entry state before returning.
ll find(ll n, int i) {
   if (i == len) {
      const ll value = stoll(num);
      return (value <= n && miller.primell(value)) ? 1 : 0;
   }

   // Right half: the digit is forced to mirror its partner in the left half.
   // For even len this is i >= len/2; for odd len, i > len/2 (past the middle).
   if (i >= (len + 1) / 2) {
      num += num[len - 1 - i];
      const ll ret = find(n, i + 1);
      num.pop_back();
      return ret;
   }

   // Left half or middle digit: try every digit, skipping a leading zero.
   ll ret = 0;
   for (char c = (i == 0 ? '1' : '0'); c <= '9'; ++c) {
      num += c;
      ret += find(n, i + 1);
      num.pop_back();
   }
   return ret;
}

// Counts palindromic primes p with p <= n.
//
// Runs find() once per digit length, from 1 up to the number of characters
// in to_string(n), capped at 11. The global `len` is set for each run and is
// left at one past the last length tried.
// NOTE(review): because of the cap, palindromes with 13+ digits are never
// counted, so results are only exact for n < 10^12 — confirm against the
// problem's constraints. Skipping 12 digits loses nothing: even-length
// palindromes are multiples of 11.
ll find_all(ll n) {
   const int digits = static_cast<int>(to_string(n).size());
   const int max_len = digits < 11 ? digits : 11;
   ll total = 0;
   len = 1;
   while (len <= max_len) {
      total += find(n, 0);
      ++len;
   }
   return total;
}

// NOTE(review): `{}` as the initializer of an array with unknown bound yields
// a zero-length array. That is ill-formed in ISO C++ and compiles only as a
// GCC/Clang extension. Nothing in the visible code reads it; it looks like a
// leftover placeholder for precomputed per-length counts — consider removing.
int tot[] = {};

void solve() {
   ll L, H;
   cin >> L >> H;

   ll m = 10;
   for (int i = 0; i <= 12; i++) {
      debug(m, find_all(m * 10 - 1) - find_all(m - 1));
      m *= 10;
   }

   ll h = find_all(H);
   ll l = find_all(L - 1);
   cout << h - l;
}

Tags:

Categories:

Updated:

Comments