Dynamic Programming: Matrix Chain Multiplication
Welcome to the third article on Dynamic Programming. You can refer to the first article (the introduction) for background. Let us get started.
Introduction
Before we formally define our problem, let us first refresh our memory about matrix multiplication. To multiply two matrices (A) and (B), the number of columns of matrix (A) must match the number of rows of matrix (B). For example, if (A) is a (2 x 3) matrix and (B) is a (3 x 2) matrix, then their product (C) is a (2 x 2) matrix. Take a look at the following demonstration.
```
A = | a11 a12 a13 |
    | a21 a22 a23 |

B = | b11 b12 |
    | b21 b22 |
    | b31 b32 |

C = | (a11*b11 + a12*b21 + a13*b31)  (a11*b12 + a12*b22 + a13*b32) |
    | (a21*b11 + a22*b21 + a23*b31)  (a21*b12 + a22*b22 + a23*b32) |
```
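The demonstration above can be checked with a short Java sketch: a naive triple-loop product that also counts every scalar multiplication it performs. The class `MatMulCount`, its `multiply` method and the sample entries are hypothetical, chosen only for illustration.

```java
// Hypothetical helper: a naive matrix product that also counts the scalar
// multiplications it performs, to check the (m x k x n) rule.
class MatMulCount {
    static long multiplications = 0; // running count of scalar multiplications

    // Multiplies an (m x k) matrix a by a (k x n) matrix b
    static long[][] multiply(long[][] a, long[][] b) {
        int m = a.length, k = b.length, n = b[0].length;
        if (a[0].length != k) {
            throw new IllegalArgumentException("columns of a must equal rows of b");
        }
        long[][] c = new long[m][n];
        for (int r = 0; r < m; r++) {
            for (int col = 0; col < n; col++) {
                long sum = 0;
                for (int t = 0; t < k; t++) {
                    sum += a[r][t] * b[t][col]; // one scalar multiplication
                    multiplications++;
                }
                c[r][col] = sum;
            }
        }
        return c;
    }

    public static void main(String[] args) {
        long[][] a = {{1, 2, 3}, {4, 5, 6}};      // 2 x 3
        long[][] b = {{7, 8}, {9, 10}, {11, 12}}; // 3 x 2
        long[][] c = multiply(a, b);              // 2 x 2
        System.out.println(java.util.Arrays.deepToString(c)); // [[58, 64], [139, 154]]
        System.out.println(multiplications);      // 12 = 2 x 3 x 2
    }
}
```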
As you can see in the example above, computing the product (C) takes (12) scalar multiplications: each of the (4) entries of (C) needs (3) of them. In general, if (A) is an (m x k) matrix and (B) is a (k x n) matrix, then the product is an (m x n) matrix and computing it takes (m x k x n) scalar multiplications. In our case that is (2 x 3 x 2 = 12). When multiplying two matrices there is only one way to do it; however, with a chain of three or more matrices the problem becomes more interesting. For example, to multiply three matrices (D = A x B x C) there are different ways to proceed. Here are the possible ways:
1. Multiply (A) and (B), then multiply the result by (C) to get (D):
```
D = (A x B) x C
```
2. Multiply (B) and (C), then multiply (A) by the result to get (D):
```
D = A x (B x C)
```
Both ways yield the same result (matrix multiplication is associative), so the question is which one is more efficient in terms of computational cost. The one with fewer scalar multiplications is cheaper. Does the number of scalar multiplications depend on the order in which we multiply the matrices? The answer is YES. Let us take an example and assume:
```
A: (2 x 3)
B: (3 x 2)
C: (2 x 3)

D1 = (A x B) x C
D2 = A x (B x C)
```
To calculate the number of scalar multiplications, just apply the rule we discussed earlier (m x k x n):
```
Scalar multiplications for D1
  = cost of Z = A x B (a 2 x 2 matrix) + cost of Z x C
  = (2 x 3 x 2) + (2 x 2 x 3)
  = 12 + 12 = 24

Scalar multiplications for D2
  = cost of Z = B x C (a 3 x 3 matrix) + cost of A x Z
  = (3 x 2 x 3) + (2 x 3 x 3)
  = 18 + 18 = 36
```
As you can see, (D1) is cheaper to compute: it involves only (24) scalar multiplications, compared to (36) for (D2). To conclude, the number of scalar multiplications involved in computing the product of a chain of matrices depends on where we place the parentheses.
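The comparison above only needs the matrix dimensions, so it can be sketched in a few lines of Java. This is an illustrative helper (`ThreeMatrixCost` and its method names are assumptions, not part of the article), where A is p0 x p1, B is p1 x p2 and C is p2 x p3.

```java
// Compares the two ways of multiplying a chain of three matrices using only
// their dimensions: A is p0 x p1, B is p1 x p2, C is p2 x p3.
class ThreeMatrixCost {
    // Scalar multiplications for an (m x k) times (k x n) product
    static long cost(long m, long k, long n) {
        return m * k * n;
    }

    // D1 = (A x B) x C: A x B is p0 x p2, then (p0 x p2) times (p2 x p3)
    static long leftFirst(long p0, long p1, long p2, long p3) {
        return cost(p0, p1, p2) + cost(p0, p2, p3);
    }

    // D2 = A x (B x C): B x C is p1 x p3, then (p0 x p1) times (p1 x p3)
    static long rightFirst(long p0, long p1, long p2, long p3) {
        return cost(p1, p2, p3) + cost(p0, p1, p3);
    }

    public static void main(String[] args) {
        // A: 2 x 3, B: 3 x 2, C: 2 x 3, as in the example above
        System.out.println(leftFirst(2, 3, 2, 3));  // 24
        System.out.println(rightFirst(2, 3, 2, 3)); // 36
    }
}
```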
Problem definition
We are now ready to define our problem. Given a chain of matrices:
```
A1 x A2 x A3 x ... x A(i-1) x Ai x A(i+1) x ... x A(n-1) x An
```
Any matrix in the chain (call it Ai) is a (P(i-1) x P(i)) matrix, which means its number of rows is P(i-1) and its number of columns is P(i). Note that (i) and (i-1) here are just indexes that denote the position of the matrix within the chain: P(i-1) does NOT mean (P) multiplied by (i-1), so please keep that in mind. Any two adjacent matrices (for example A(i-1) and Ai) must satisfy the following condition: the number of columns of the first matrix A(i-1), which is P(i-1), must match the number of rows of the second matrix Ai, which is also P(i-1). As a result, the whole chain is described by the n + 1 numbers P(0), P(1), ..., P(n). Because of the long introduction we have not actually stated the problem yet, so let us do that now: formulate an algorithm that parenthesizes the matrix chain product so that the total number of scalar multiplications is minimal.
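In code, the whole chain is conveniently described by a dimension array P of length n + 1, with Ai being P[i-1] x P[i]. The Java sketch below builds that array from a list of (rows, columns) shapes and rejects chains where adjacent matrices are incompatible; the class name `ChainDimensions` and the shape-array input format are assumptions made for illustration.

```java
import java.util.Arrays;

// Builds the dimension array P from matrix shapes, checking adjacency.
// shapes[i] = {rows, cols} of matrix A(i+1) (0-based array, 1-based names).
// The result has length n + 1, so that A(i) is P[i-1] x P[i].
class ChainDimensions {
    static int[] dimensions(int[][] shapes) {
        int n = shapes.length;
        if (n == 0) {
            throw new IllegalArgumentException("empty chain");
        }
        int[] p = new int[n + 1];
        p[0] = shapes[0][0];
        for (int i = 0; i < n; i++) {
            // Rows of this matrix must equal columns of the previous one
            if (shapes[i][0] != p[i]) {
                throw new IllegalArgumentException(
                    "A" + i + " and A" + (i + 1) + " are not compatible");
            }
            p[i + 1] = shapes[i][1];
        }
        return p;
    }

    public static void main(String[] args) {
        int[][] shapes = {{40, 50}, {50, 100}, {100, 20}, {20, 80}};
        System.out.println(Arrays.toString(dimensions(shapes))); // [40, 50, 100, 20, 80]
    }
}
```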
Solution
We can use brute force: enumerate all possible ways to parenthesize the matrix chain, then choose the one with the minimum number of scalar multiplications. This solution is hopeless for long chains because the number of parenthesizations, and therefore the running time, grows exponentially with (n). If you do not believe me, check this out:
```
n = 2: there is only 1 way to parenthesize the chain
(A1 x A2)

n = 3: there are 2 ways to parenthesize the chain
((A1 x A2) x A3)
(A1 x (A2 x A3))

n = 4: there are 5 ways to parenthesize the chain
((A1 x A2) x (A3 x A4))
(((A1 x A2) x A3) x A4)
((A1 x (A2 x A3)) x A4)
(A1 x ((A2 x A3) x A4))
(A1 x (A2 x (A3 x A4)))
```
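In general, let P(n) be the number of ways to parenthesize a chain of (n) matrices (this P is a count, not to be confused with the matrix dimensions P(i) used elsewhere). The outermost pair of parentheses splits the chain between A(k) and A(k+1) for some (k) from (1) to (n-1), and the two sides can then be parenthesized independently, which gives the formula below. Its values are the Catalan numbers, which grow exponentially, roughly like 4^n / n^(3/2).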
```
P(1) = 1
P(n) = Sum (P(k) P(n-k)) for k from 1 to n-1, if n >= 2
```
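The recurrence can be evaluated directly to watch the numbers grow. Here is a minimal Java sketch (the class name `CountParenthesizations` is hypothetical); it assumes n >= 1 and uses long, which overflows for chains of a few dozen matrices.

```java
// Counts the parenthesizations of a chain of n matrices (n >= 1) with
//   P(1) = 1, P(n) = sum over k = 1..n-1 of P(k) * P(n - k)
class CountParenthesizations {
    static long count(int n) {
        long[] p = new long[n + 1];
        p[1] = 1;
        for (int len = 2; len <= n; len++) {
            for (int k = 1; k < len; k++) {
                // Outermost split between A_k and A_(k+1)
                p[len] += p[k] * p[len - k];
            }
        }
        return p[n];
    }

    public static void main(String[] args) {
        // Prints 1, 1, 2, 5, 14, 42, 132, 429, 1430, 4862 for n = 1..10
        for (int n = 1; n <= 10; n++) {
            System.out.println("n = " + n + ": " + count(n));
        }
    }
}
```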
So we are now convinced that brute force is not a good approach. Since we are talking about Dynamic Programming, we will demonstrate that DP is in fact a good approach for the matrix chain multiplication problem. Again, we will follow the steps discussed in the first Dynamic Programming article.
Characterize the structure of an optimal solution
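Ask the question “what are we trying to optimize?” The answer is the number of scalar multiplications: we want to minimize the number of scalar multiplications needed to compute the matrix chain product. The key observation is that an optimal solution is built from optimal solutions to sub problems. If an optimal parenthesization of the sub chain Ai ... Aj splits it between Ak and A(k+1), then the two halves Ai ... Ak and A(k+1) ... Aj must themselves be parenthesized optimally; otherwise, replacing either half with a cheaper parenthesization would give a cheaper overall solution. In the next step we will define this number in terms of smaller sub problems using recursion.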
Recursively define the value of an optimal solution
Below is a schematic diagram of our matrix chain.
Our matrix chain starts at matrix (1) and ends at matrix (n). A sub problem is a sub chain that starts at matrix (i) and ends at matrix (j), where (i) goes from (1) to (n-1) and (j) goes from (i+1) to (n). Do not be confused by these indexes: we are simply saying that a sub chain can be anywhere between (1) and (n). That is the whole meaning in plain English.
Let M[i, j] be the minimum number of scalar multiplications needed to compute the product of the sub chain of matrices between (i) and (j). As you can see, M[i, j] is a two-dimensional matrix (or table) where (i) goes from (1) to (n-1) and (j) goes from (i+1) to (n). The minimum number of scalar multiplications needed to compute the whole chain (NOT a sub chain) is then M[1, n].
Note that whenever (i == j) the sub chain is a single matrix, which means the number of scalar multiplications is ZERO, so
```
M[i, j] = 0 if i = j
```
If we split the sub chain (i to j) at some position (k), as demonstrated in the schematic diagram above, then the number of scalar multiplications needed to compute the sub chain product (i to j) is the sum of three terms:
1. the number of scalar multiplications needed to compute the sub chain product between (i) and (k), which is M[i, k];
2. the number of scalar multiplications needed to compute the sub chain product between (k + 1) and (j), which is M[k + 1, j];
3. the number of scalar multiplications needed to multiply the two parts of the split (i to k and k + 1 to j) together, which is P(i-1)P(k)P(j).
We can write this formally as
```
M[i, j] = M[i, k] + M[k + 1, j] + P(i-1)P(k)P(j)
```
Let me quickly comment on the term P(i-1)P(k)P(j), as it might not be obvious where it comes from. In the introduction we saw that multiplying an (m x k) matrix by a (k x n) matrix takes (m x k x n) scalar multiplications, where (m) is the number of rows of the first matrix, (k) is the number of columns of the first matrix (which equals the number of rows of the second) and (n) is the number of columns of the second matrix. The same rule applies here: the left part of the split (Ai ... Ak) evaluates to a P(i-1) x P(k) matrix and the right part (A(k+1) ... Aj) evaluates to a P(k) x P(j) matrix, so multiplying them takes P(i-1)P(k)P(j) scalar multiplications. Recall that the split was done at some arbitrary position (k).
We need to find the position (k) between (i) and (j) at which to split the sub chain so that the number of scalar multiplications is minimized. In other words, we try every possible (k) and keep only the one that gives the smallest total. We can formulate that as
```
M[i, j] = Min (M[i, k] + M[k + 1, j] + P(i-1)P(k)P(j)) where i <= k < j
```
Let us write the complete recursive formula for the optimal solution
```
M[i, j] = 0                                               if i = j
M[i, j] = Min (M[i, k] + M[k + 1, j] + P(i-1)P(k)P(j))    if i < j, where i <= k < j
```
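Before the bottom-up version, note that the recurrence can also be evaluated exactly as written, top-down, with a cache so that each M[i, j] is computed only once (memoization). This is a variant of the table-filling approach, not the article's own code; the class `MemoMatrixChain` is a hypothetical name, and the dimensions are passed as an array p where Ai is p[i-1] x p[i].

```java
import java.util.Arrays;

// Top-down (memoized) evaluation of the recurrence
//   M[i, j] = 0                                                if i = j
//   M[i, j] = min over i <= k < j of M[i, k] + M[k + 1, j] + p[i-1] p[k] p[j]
// Matrix Ai is p[i-1] x p[i], for i = 1..n.
class MemoMatrixChain {
    private final int[] p;
    private final long[][] memo; // -1 means "not computed yet"

    MemoMatrixChain(int[] p) {
        this.p = p;
        int n = p.length - 1;
        memo = new long[n + 1][n + 1];
        for (long[] row : memo) {
            Arrays.fill(row, -1);
        }
    }

    long minCost(int i, int j) {
        if (i == j) return 0;                   // a single matrix costs nothing
        if (memo[i][j] >= 0) return memo[i][j]; // overlapping sub problem: look it up
        long best = Long.MAX_VALUE;
        for (int k = i; k < j; k++) {
            long cost = minCost(i, k) + minCost(k + 1, j) + (long) p[i - 1] * p[k] * p[j];
            best = Math.min(best, cost);
        }
        memo[i][j] = best;
        return best;
    }

    public static void main(String[] args) {
        int[] p = {2, 3, 2, 3}; // A: 2 x 3, B: 3 x 2, C: 2 x 3 from the introduction
        System.out.println(new MemoMatrixChain(p).minCost(1, p.length - 1)); // 24
    }
}
```

Without the cache, the same sub chains would be solved again and again and the running time would be exponential; with it, each of the O(n^2) sub problems is solved once.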
Do not forget that M[i, j] is not only defined recursively in terms of smaller sub problems; these sub problems also overlap, otherwise Dynamic Programming would not be of much value. The sub problems M[i, k] and M[k + 1, j] do overlap: for example, both M[1, 3] (split at k = 1) and M[2, 4] (split at k = 3) need M[2, 3]. Again, the same concept we learned before applies: we only calculate new values, and values that were already computed are simply looked up in the table M[i, j].
We are now ready to populate the table M[i, j]. Once it is filled, M[1, n] gives the minimum number of scalar multiplications needed to compute the matrix chain product; however, that is not all we eventually need. We need to know exactly where to put the parentheses to achieve that optimal cost. Again, this can be taken care of by using back pointers (please refer to the previous articles for more information). Let us use another table called S[i, j], in which we save the value (k) that gives the minimum number of scalar multiplications for the sub chain between (i) and (j). Populating M[i, j] and S[i, j] is done in the next step.
Compute the value of an optimal solution in a bottom up fashion
The following code builds M[i, j] (and S[i, j]) in a bottom-up fashion. It follows directly from the explanation and formulas above, so the comments in the code should be enough to follow it.
```java
// P holds the dimensions: matrix Ai is P[i-1] x P[i] for i = 1..n, so
// P.length == n + 1. M and S are (n + 1) x (n + 1) tables indexed from 1
// to match the text (row 0 and column 0 are unused).
static void matrixChainOrder(int[] P, long[][] M, int[][] S) {
    int n = P.length - 1;

    // Whenever i == j put M = 0 (explanation provided above)
    for (int i = 1; i <= n; i++) {
        M[i][i] = 0;
    }

    // M[i][j] depends on M[i][k] and M[k + 1][j], which cover shorter sub
    // chains, so the table is filled in order of increasing sub chain length
    // len = j - i + 1. (Looping over i from 1 upward would read M[k + 1][j]
    // before it has been computed.)
    for (int len = 2; len <= n; len++) {
        // Left index (i) of the sub chain; the right index (j) follows from len
        for (int i = 1; i <= n - len + 1; i++) {
            int j = i + len - 1;
            // Start from "infinity" so that the first split always wins
            M[i][j] = Long.MAX_VALUE;
            // The sub chain between (i) and (j) is split at position (k),
            // which goes from (i) to (j - 1)
            for (int k = i; k < j; k++) {
                long cost = M[i][k] + M[k + 1][j] + (long) P[i - 1] * P[k] * P[j];
                // Keep the (k) that gives the minimum number of scalar
                // multiplications; once the (k) loop is finished, M[i][j]
                // holds that minimum
                if (cost < M[i][j]) {
                    M[i][j] = cost;
                    // Save (k): it marks the position where a parenthesis is placed
                    S[i][j] = k;
                }
            }
        }
    }
}
```
Construct an optimal solution from the computed information
The table S[i, j] stores the back pointers needed to construct the final solution (refer to the introduction article for more information about back pointers). For each sub chain between (i) and (j) there is a stored value (k) in S[i, j], and we use this position to insert a parenthesis. The code below uses a recursive approach to print the overall matrix chain product, parenthesized such that the minimum number of scalar multiplications is involved. The function receives the back pointers table (S) and two indexes (i) and (j). To print the overall solution, call the function with (i = 1) and (j = n).
```java
// Prints the optimal parenthesization of the sub chain Ai..Aj using the back
// pointers in S. Call it with i = 1 and j = n to print the whole chain.
static void printParenthesis(int[][] S, int i, int j) {
    // When i == j the sub chain is a single matrix, so there is no need to
    // print parentheses; just print the matrix name
    if (i == j) {
        System.out.print("A" + i);
    } else {
        // Otherwise get the position at which the sub chain is split,
        // then enclose the two smaller sub chains in parentheses
        int k = S[i][j];
        System.out.print("(");
        // Recursively apply the function to the left side of the split around (k)
        printParenthesis(S, i, k);
        // Recursively apply the function to the right side of the split around (k)
        printParenthesis(S, k + 1, j);
        System.out.print(")");
    }
}
```
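Putting both steps together, here is a self-contained Java sketch that fills the tables bottom-up (by increasing sub chain length, so M[i][k] and M[k + 1][j] are ready before M[i][j] needs them) and rebuilds the parenthesization as a string instead of printing it. The class `MatrixChain` and its members are illustrative names. As a usage example, it is run on the four-matrix chain (40 x 50, 50 x 100, 100 x 20, 20 x 80) that a reader asks about in the comments.

```java
// Bottom-up matrix chain order plus reconstruction of the parenthesization.
// Matrix Ai is p[i-1] x p[i], for i = 1..n.
class MatrixChain {
    final long[][] m; // m[i][j]: minimum scalar multiplications for Ai..Aj
    final int[][] s;  // s[i][j]: split position k that achieves m[i][j]

    MatrixChain(int[] p) {
        int n = p.length - 1;
        m = new long[n + 1][n + 1]; // m[i][i] is already 0
        s = new int[n + 1][n + 1];
        for (int len = 2; len <= n; len++) {          // sub chain length
            for (int i = 1; i <= n - len + 1; i++) {
                int j = i + len - 1;
                m[i][j] = Long.MAX_VALUE;
                for (int k = i; k < j; k++) {
                    long cost = m[i][k] + m[k + 1][j] + (long) p[i - 1] * p[k] * p[j];
                    if (cost < m[i][j]) {
                        m[i][j] = cost;
                        s[i][j] = k;
                    }
                }
            }
        }
    }

    // Same recursion as PrintParenthesis, but appends to a StringBuilder
    String parenthesize(int i, int j) {
        StringBuilder sb = new StringBuilder();
        build(sb, i, j);
        return sb.toString();
    }

    private void build(StringBuilder sb, int i, int j) {
        if (i == j) {
            sb.append('A').append(i);
        } else {
            int k = s[i][j];
            sb.append('(');
            build(sb, i, k);
            build(sb, k + 1, j);
            sb.append(')');
        }
    }

    public static void main(String[] args) {
        int[] p = {40, 50, 100, 20, 80};
        int n = p.length - 1;
        MatrixChain mc = new MatrixChain(p);
        System.out.println(mc.m[1][n]);            // 204000
        System.out.println(mc.parenthesize(1, n)); // ((A1(A2A3))A4)
    }
}
```

Returning a string keeps the reconstruction easy to test; printing each piece directly, as printParenthesis does, works just as well.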
If you take a look at the code that populates the tables, it is not hard to tell that it runs in polynomial time. There are three nested loops, each running at most (n) times, so the running time is O(n^3). The inner loops do not run (n) times each, which lowers the constant factor (the innermost loop body executes (n^3 - n)/6 times in total), but the running time is still cubic, Θ(n^3), and the two tables take O(n^2) space. The recursive code used to construct the solution visits each of the 2n - 1 nodes of the parenthesization once, so it runs in O(n) time. This shows that Dynamic Programming is a powerful technique when the problem at hand can be solved with it.
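To make the complexity argument concrete, the sketch below counts how many times the innermost loop body runs, which is the number of triples (i, k, j) with 1 <= i <= k < j <= n, and compares it with the closed form (n^3 - n)/6. The class `LoopCount` is a hypothetical name. The shorter inner loops cut the work to roughly n^3/6 iterations, a constant-factor saving; the growth is still cubic.

```java
// Counts executions of the innermost (k) loop body when the table is filled
// by sub chain length, and compares the count with (n^3 - n) / 6.
class LoopCount {
    static long iterations(int n) {
        long count = 0;
        for (int len = 2; len <= n; len++) {
            for (int i = 1; i <= n - len + 1; i++) {
                int j = i + len - 1;
                for (int k = i; k < j; k++) {
                    count++;
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        // Prints matching pairs: 10 vs 10, 165 vs 165, 166650 vs 166650
        for (int n : new int[] {4, 10, 100}) {
            long formula = ((long) n * n * n - n) / 6;
            System.out.println(n + ": " + iterations(n) + " vs " + formula);
        }
    }
}
```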
We are done with part 3. In the next part I will explain the third Dynamic Programming problem, “Longest Common Subsequence”, discussed in “Introduction to Algorithms” by Thomas H. Cormen et al. Any comments or feedback are appreciated. Thanks for reading.
Never seen such an easy and fantastic explanation of dynamic programming and its examples. Thanks a lot. Keep up your good work.
Thank you very much. I am glad you liked it. In my personal opinion, abstraction is always hard to understand, and I am a believer in examples because they are easier to imagine and digest.
Wonderful. Very simple to understand. I feel lucky I bumped into this site. Thanks. Keep up the good work.
awesome stuff
For the sequence of matrices, given below, compute the order of the product, A1.A2.A3.A4, in such a way that minimizes the total number of scalar multiplications, using Dynamic Programming.
Order of A1 = 40 x 50
Order of A2 = 50 x 100
Order of A3 = 100 x 20
Order of A4 = 20 x 80
Please, can someone help me with how to solve it?
If you get the answer, please send it to me too.
This code does not work for me. Can someone tell me what is wrong?
package com.test.dp;
public class MatrixChainMulitplication {
//p is an array of matrices..
public static void matrixChainOrder(int[] p ,int n) {
int m[][] = new int[n][n];
int sol[][] = new int[n][n];
//if i == j m[i][j] = 0;
for(int i = 0; i