BOJ 13246 - 행렬 제곱의 합

image.png

특별한 규칙을 찾아야 하는건줄 알았는데, 그냥 정말 분할정복으로 급수를 구하는 문제이다.

행렬의 곱셈은 교환법칙을 제외한 분배법칙과 결합법칙이 성립한다는 점을 명심한다.

$$ S^B=A^1+A^2+\cdots +A^B $$

라고 할 때,

$$ S^B=\begin{cases} (A^{B/2}+I)\cdot S^{B/2} &\text{if}~2 \mid B \\ (A^{\left\lfloor B/2 \right\rfloor+1}+(A^{\left\lfloor B/2 \right\rfloor+1}+I)\cdot S^{\left\lfloor B/2 \right\rfloor} & \text{if}~2 \nmid B \end{cases} $$

이다.

함정입력이 있으니 주의하자

const ll mod = 1000;  
inline ll md(ll x) { return md(mod, x); }  
  
int n;  
typedef vi row;  
typedef vector<row> mat;  
ostream &operator<<(ostream &os, const mat &mat) {  
   for (const auto &vec: mat) {  
      for (const auto &i: vec) cout << i << " ";  
      cout << endl;  
   }  
   cout << endl;  
   return os;  
}  
mat operator*(const mat &a, const mat &b) {  
   int aRow = a.size(), aColumn = a[0].size();  
   int bRow = b.size(), bColumn = b[0].size();  
  
   if (aColumn != bRow)  
      throw std::length_error("column of a is not equal to row of b");  
  
   mat c(aRow, row(bColumn, 0));  
  
   for (int i = 0; i < aRow; ++i)  
      for (int j = 0; j < bColumn; ++j)  
         for (int k = 0; k < aColumn; k++)  
            c[i][j] = md(c[i][j] + a[i][k] * b[k][j]);  
  
   return c;  
}  
mat operator*(const mat &a, int b) {  
   mat c = a;  
   for (int i = 0; i < sz(a); ++i)  
      for (int j = 0; j < sz(a[0]); ++j)  
         c[i][j] = md(c[i][j] * b);  
   return c;  
}  
mat operator+(const mat &a, const mat &b) {  
   int aRow = a.size(), aColumn = a[0].size();  
   int bRow = b.size(), bColumn = b[0].size();  
   if (aRow != bRow || aColumn != bColumn)  
      throw std::length_error("Length Error");  
   mat c(aRow, row(bColumn, 0));  
   for (int i = 0; i < aRow; ++i)  
      for (int j = 0; j < bColumn; ++j)  
         c[i][j] = md(c[i][j] + a[i][j] + b[i][j]);  
   return c;  
}  
mat operator+(const mat &a, int b) {  
   mat c = a;  
   for (int i = 0; i < sz(a); ++i)  
      for (int j = 0; j < sz(a[0]); ++j)  
         c[i][j] = md(c[i][j] + b);  
   return c;  
}  
mat identity(int n) {  
   mat ret(n, row(n));  
   for (int i = 0; i < n; i++) ret[i][i] = 1;  
   return ret;  
}  
mat pow_mat(mat a, ll n) {  
   mat ret = identity(sz(a));  
   while (n) {  
      if (n & 1) ret = ret * a;  
      a = a * a;  
      n >>= 1;  
   }  
   return ret;  
}  
  
void solve() {  
   int b;  
   cin >> n >> b;  
   mat a(n, row(n));  
   fv2(a);  
   for (int y = 0; y < n; y++)  
      for (int x = 0; x < n; x++) a[y][x] = md(a[y][x]);  
   function<mat(int b)> fn = [&](int b) -> mat {  
      if (b == 1) return a;  
      if (b % 2 == 1) {  
         return (pow_mat(a, b / 2 + 1) + identity(n)) * fn(b / 2) + pow_mat(a, b / 2 + 1);  
      } else {  
         return (pow_mat(a, b / 2) + identity(n)) * fn(b / 2);  
      }  
   };  
   cout << fn(b);  
}

Tags:

Categories:

Updated:

Comments