Tuesday 22 October 2013

Matrix exponentiation

The Tribonacci Numbers:

F(n) = F(n-1) + F(n-2) + F(n-3),  F(1) = 1; F(2) = 1; F(3) = 2
 

* For N as large as 10^15, a plain DP that computes every term takes O(N) steps, which is far too slow under any realistic time limit.

* The basic idea behind matrix exponentiation is to use the base cases of the recurrence relation to assemble a matrix that lets us compute values fast.
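
For comparison, the plain DP mentioned above can be sketched as follows. This is only a baseline, assuming the base cases F(1) = 1, F(2) = 1, F(3) = 2 from the definition; the name tribonacci_linear and the modulus 1000000007 are choices made for this illustration. It does O(N) work, which is fine for small N and hopeless for N near 10^15, but it is handy for checking a faster method on small inputs.

```python
MOD = 1000000007  # modulus assumed for this sketch

def tribonacci_linear(n):
    # Assumes n >= 1. Walks the recurrence one term at a time: O(n) steps.
    a, b, c = 1, 1, 2  # F(1), F(2), F(3)
    if n == 1:
        return a
    if n == 2:
        return b
    for _ in range(n - 3):
        # Slide the window of three consecutive terms forward by one.
        a, b, c = b, c, (a + b + c) % MOD
    return c

print([tribonacci_linear(i) for i in range(1, 9)])
# prints [1, 1, 2, 4, 7, 13, 24, 44]
```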

In our case we have:
F(1) = 1
F(2) = 1
F(3) = 2
And we now have a relationship that will go like this:
|f(4)|    = MATRIX * |f(3)|
|f(3)|               |f(2)|
|f(2)|               |f(1)|
Note that both column vectors list their values from the newest term down to the oldest. This matters: the matrix must map one state to the next state in the same layout, otherwise it cannot be applied repeatedly.
Now all that is left is to assemble the matrix, and that is done based both on the rules of matrix multiplication and on the recurrence relation. In our example we see that to obtain f(4), the 1st line of the matrix needs to be composed only of ones, as f(4) = 1*f(3) + 1*f(2) + 1*f(1).
Now, denoting the unknown elements as *, we have:
|f(4)|   = |1 1 1| * |f(3)|
|f(3)|     |* * *|   |f(2)|
|f(2)|     |* * *|   |f(1)|
For the second line, we want to obtain f(3), which is the first entry of the vector on the right, so the only possible way of doing it is by having:
|f(4)|   = |1 1 1| * |f(3)|
|f(3)|     |1 0 0|   |f(2)|
|f(2)|     |* * *|   |f(1)|
To get the value of f(2), we can follow the same logic and get the final matrix:
|1 1 1|
|1 0 0|
|0 1 0|
To end it, we now need to generalize it. As we have 3 base cases, all we need to do to compute the Nth tribonacci number in O(log N) time, for N >= 3, is to raise the matrix to the power N-3 (the power 0 being the identity matrix) to get:

|f(N)|   =   |1 1 1|^(N-3) * |f(3)|
|f(N-1)|     |1 0 0|         |f(2)|
|f(N-2)|     |0 1 0|         |f(1)|
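
A quick way to convince yourself that a transition matrix is right is to apply it one step at a time and compare the result against the recurrence. The sketch below assumes the state is kept newest-first, i.e. (f(k), f(k-1), f(k-2)), and uses the matrix T = [[1,1,1],[1,0,0],[0,1,0]]; the helper name step is made up for this illustration.

```python
# Transition matrix, assuming the state layout (f(k), f(k-1), f(k-2)).
T = [[1, 1, 1],
     [1, 0, 0],
     [0, 1, 0]]

def step(M, v):
    # Multiply a 3x3 matrix by a length-3 column vector.
    return [sum(M[i][j] * v[j] for j in range(3)) for i in range(3)]

state = [2, 1, 1]  # (f(3), f(2), f(1))
for k in range(4, 8):
    state = step(T, state)
    print(k, state)
# prints:
# 4 [4, 2, 1]
# 5 [7, 4, 2]
# 6 [13, 7, 4]
# 7 [24, 13, 7]
```

Each output vector has the same layout as the input vector, which is what makes it legitimate to replace the repeated steps by a single matrix power.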

The power of the matrix can now be computed in O(log N) time using repeated squaring applied to matrices, and the method is complete. Below is the Python code that does this, computing the number modulo 1000000007.


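MOD = 1000000007

def matrix_mult(A, B):
  # Product of two 3x3 matrices, every entry reduced modulo MOD.
  C = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
  for i in range(3):
    for j in range(3):
      for k in range(3):
        C[i][k] = (C[i][k] + A[i][j] * B[j][k]) % MOD
  return C

def fast_exponentiation(A, n):
  # Repeated squaring: O(log n) matrix products. Requires n >= 1.
  if n == 1:
    return A
  else:
    if n % 2 == 0:
      # Integer division keeps n an int under Python 3.
      A1 = fast_exponentiation(A, n // 2)
      return matrix_mult(A1, A1)
    else:
      return matrix_mult(A, fast_exponentiation(A, n - 1))

def solve(n):
  # Assumes n >= 1. The base cases F(1) = 1, F(2) = 1, F(3) = 2 need no matrix power.
  if n <= 3:
    return [1, 1, 2][n - 1]
  # Each row builds one entry of the next state (F(k+1), F(k), F(k-1))
  # from the current state (F(k), F(k-1), F(k-2)).
  A = [[1, 1, 1], [1, 0, 0], [0, 1, 0]]
  A_n = fast_exponentiation(A, n - 3)
  # First row of A^(n-3) times the start vector (F(3), F(2), F(1)) = (2, 1, 1).
  return (A_n[0][0] * 2 + A_n[0][1] + A_n[0][2]) % MOD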
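
The recursion can also be unrolled into a loop that walks the bits of the exponent. The following is a minimal self-contained sketch of that variant, not the code above rewritten: the names mat_mul, mat_pow and tribonacci are made up here, the state vector is assumed to be ordered (F(n), F(n-1), F(n-2)), and starting from the identity matrix means the power 0 needs no special case.

```python
MOD = 1000000007

def mat_mul(A, B):
    # 3x3 matrix product with entries reduced modulo MOD.
    return [[sum(A[i][k] * B[k][j] for k in range(3)) % MOD
             for j in range(3)] for i in range(3)]

def mat_pow(A, e):
    # Iterative repeated squaring, valid for any integer e >= 0.
    result = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]  # identity = A^0
    while e > 0:
        if e & 1:
            # This bit of the exponent is set: fold the current power in.
            result = mat_mul(result, A)
        A = mat_mul(A, A)
        e >>= 1
    return result

def tribonacci(n):
    # Assumes n >= 1, with F(1) = 1, F(2) = 1, F(3) = 2.
    if n < 3:
        return 1
    T = [[1, 1, 1], [1, 0, 0], [0, 1, 0]]
    P = mat_pow(T, n - 3)
    # First row of T^(n-3) applied to (F(3), F(2), F(1)) = (2, 1, 1).
    return (2 * P[0][0] + P[0][1] + P[0][2]) % MOD

print([tribonacci(i) for i in range(1, 11)])
# prints [1, 1, 2, 4, 7, 13, 24, 44, 81, 149]
```

The loop performs at most about 2*log2(N) matrix products, so a call such as tribonacci(10**15) needs on the order of a hundred 3x3 multiplications.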




OTHER LINKS:
http://discuss.codechef.com/questions/2263/crowd-editorial