DEV Community

Nayan Pahuja


DAY 26 - Matrix Chain Multiplication

Hey Guys! Nayan here and this is the 26th day of the 100-DAY challenge. If you are reading this for the first time, I learn and solve some questions throughout the day and post them here.

In the coming articles, we will discuss problems related to a new pattern called “Partition DP”. Well, let's get started with the problem.

DAY 26 Problem: Matrix Chain Multiplication

Given a sequence of matrices, find the most efficient way to multiply these matrices together. The most efficient way is the one that needs the fewest scalar multiplications.

The dimensions of the matrices are given in an array arr[] of size N (such that N = number of matrices + 1) where the ith matrix has the dimensions (arr[i-1] x arr[i]).
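To make the indexing convention concrete, here is a small sketch that expands arr into the dimensions of each matrix. The helper name matrixDimensions is hypothetical (it is not part of the problem statement); it just applies the rule that matrix i is arr[i-1] x arr[i].

```cpp
#include <utility>
#include <vector>
using namespace std;

// Hypothetical helper: expands arr into the (rows, cols) of each matrix.
// Matrix i (1-indexed) is arr[i-1] x arr[i], so an arr of size N describes N-1 matrices.
// Note: dims[0] describes matrix 1, dims[1] matrix 2, and so on.
vector<pair<int, int>> matrixDimensions(const vector<int> &arr) {
    vector<pair<int, int>> dims;
    for (size_t i = 1; i < arr.size(); i++) {
        dims.push_back({arr[i - 1], arr[i]});
    }
    return dims;
}
```

Because consecutive matrices share arr[i], the columns of one matrix always equal the rows of the next, so the input can never describe an incompatible chain.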

Example:

Input: N = 5
arr = {40, 20, 30, 10, 30}
Output: 26000
Explanation: There are 4 matrices of dimensions 40x20, 20x30, 30x10 and 10x30. Say the matrices are named A, B, C and D. Out of all possible orders, the most efficient one is (A*(B*C))*D. The number of operations is 20*30*10 + 40*20*10 + 40*10*30 = 26000.
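To see why the order matters, here is the cost of each of the five possible parenthesizations of A, B, C and D, using the fact that multiplying a p x q matrix by a q x r matrix takes p·q·r scalar multiplications. The terms in each row are listed in the order the products are computed.

```latex
\begin{aligned}
((AB)C)D &: 40\cdot20\cdot30 + 40\cdot30\cdot10 + 40\cdot10\cdot30 = 48000 \\
(A(BC))D &: 20\cdot30\cdot10 + 40\cdot20\cdot10 + 40\cdot10\cdot30 = 26000 \\
(AB)(CD) &: 40\cdot20\cdot30 + 30\cdot10\cdot30 + 40\cdot30\cdot30 = 69000 \\
A((BC)D) &: 20\cdot30\cdot10 + 20\cdot10\cdot30 + 40\cdot20\cdot30 = 36000 \\
A(B(CD)) &: 30\cdot10\cdot30 + 20\cdot30\cdot30 + 40\cdot20\cdot30 = 51000
\end{aligned}
```

The same four matrices cost anywhere from 26000 to 69000 operations depending only on where the parentheses go.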

Intuition:
This is a different kind of DP problem from the ones we have solved so far. Let's start by recognizing the pattern, then look at how to approach this pattern whenever we encounter it and how to optimize the solution further.

Let's start by noting down the key prerequisite of this question: how normal matrix multiplication works.
Say there are two matrices A and B. A has size N x M and B has size n x m (the two sizes can differ).
To multiply A by B, the number of columns of the first matrix (M in this case) must equal the number of rows of the second matrix (n in this case).
The resulting matrix has size N x m (rows of the first matrix x columns of the second matrix). Each of its N x m entries is a dot product of length M, so the multiplication takes N x M x m scalar multiplications. This is the cost we are trying to minimize.
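A minimal sketch of the textbook triple-loop multiplication, with a counter added so the N x M x m cost can be observed directly. The Matrix alias and the mults counter are illustrative additions, not part of the original solution.

```cpp
#include <vector>
using namespace std;

using Matrix = vector<vector<long long>>;

// Standard triple-loop product of an N x M matrix A and an M x m matrix B.
// Assumes A's column count equals B's row count and B is non-empty.
// `mults` is incremented once per scalar multiplication.
Matrix multiply(const Matrix &A, const Matrix &B, long long &mults) {
    int N = A.size(), M = B.size(), m = B[0].size();
    Matrix C(N, vector<long long>(m, 0));  // result is N x m
    for (int r = 0; r < N; r++)
        for (int c = 0; c < m; c++)
            for (int t = 0; t < M; t++) {  // dot product of length M
                C[r][c] += A[r][t] * B[t][c];
                mults++;
            }
    return C;
}
```

Multiplying a 2 x 3 matrix by a 3 x 4 matrix this way produces a 2 x 4 result and performs 2 * 3 * 4 = 24 scalar multiplications.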

Now back to the question:
Suppose we have 4 matrices named A, B, C and D.
Matrix multiplication is not commutative, so the order of the matrices must stay fixed and each matrix can only be multiplied with its neighbor.
Hence AB is possible, but AC is not.
BC and CD are possible, but BD is not.
The only choice we have is where to put the parentheses, and that choice changes the total cost.

One option is (A) * (BCD); others are (AB) * (CD) and (ABC) * (D). That is exactly what we are going to do: split the chain into two parts at every possible position, solve each part, and combine them.
The same subproblems keep repeating (for example, BC appears inside both (A) * (BCD) and (ABC) * (D)), which is why we optimize the recursion using DP.
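The repetition can be observed with a plain recursion that tries every split and counts how often each subproblem (i, j) is solved. This is a sketch for illustration only: bruteForce and the visits table are hypothetical names, and the split cost arr[i-1] * arr[k] * arr[j] is the cost of the final multiplication of the two parts.

```cpp
#include <algorithm>
#include <climits>
#include <vector>
using namespace std;

// Plain recursion (no memoization) over every split point k.
// visits[i][j] counts how many times the subproblem (i, j) is solved.
int bruteForce(int i, int j, const vector<int> &arr, vector<vector<int>> &visits) {
    visits[i][j]++;
    if (i == j) return 0;  // a single matrix costs nothing
    int best = INT_MAX;
    for (int k = i; k < j; k++) {
        int cost = arr[i - 1] * arr[k] * arr[j]
                 + bruteForce(i, k, arr, visits)
                 + bruteForce(k + 1, j, arr, visits);
        best = min(best, cost);
    }
    return best;
}
```

For arr = {40, 20, 30, 10, 30} this returns 26000, and the subproblem (2, 3), i.e. the product BC, is solved twice: once while solving BCD and once while solving ABC. With longer chains the repeats grow exponentially, which is what memoization removes.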

Approach:

  • Represent the whole chain by the indices of its first and last matrix, i and j. We need to find f(i, j), the minimum cost to multiply matrices i..j. Since matrix i has dimensions arr[i-1] x arr[i], the matrices are numbered 1 to N-1, so the answer is f(1, N-1).
  • Try all cases: for every split point k from i to j-1, split the chain into f(i, k) and f(k+1, j). For 4 matrices (i = 1, j = 4), k runs from 1 to 3, giving the 3 partitions listed above.
  • Return the best (minimum) answer over all the cases.
  • Base case: if i == j, there is only one matrix and nothing to multiply it by, so return 0.
  • Recurrence: the left part produces an arr[i-1] x arr[k] matrix and the right part an arr[k] x arr[j] matrix, so the cost for split k is arr[i-1] * arr[k] * arr[j] + f(i, k) + f(k+1, j).
  • Apply memoization so each (i, j) is solved only once.
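Written as a formula, the recurrence from the steps above is (f is the recursive function, arr the input array, and the final answer is f(1, N-1)):

```latex
f(i, j) =
\begin{cases}
0, & i = j \\[4pt]
\displaystyle\min_{i \le k < j} \Bigl( \mathit{arr}[i-1] \cdot \mathit{arr}[k] \cdot \mathit{arr}[j] + f(i, k) + f(k+1, j) \Bigr), & i < j
\end{cases}
```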


Memoization Code:

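#include <bits/stdc++.h>
using namespace std;

// solve(i, j) = minimum cost to multiply matrices i..j (1-indexed),
// where matrix t has dimensions arr[t-1] x arr[t].
int solve(int i, int j, vector<int> &arr, vector<vector<int>> &dp){
    // A single matrix needs no multiplication.
    if(i == j){
        return 0;
    }
    // Reuse the answer if this subproblem was already solved.
    if(dp[i][j] != -1){
        return dp[i][j];
    }
    int mini = INT_MAX;
    // Try every split point: (i..k) * (k+1..j).
    for(int k = i; k < j; k++){
        // Cost of the final multiplication plus the cost of each side.
        int steps = arr[i-1] * arr[k] * arr[j] + solve(i,k,arr,dp) + solve(k+1,j,arr,dp);
        mini = min(mini, steps);
    }
    return dp[i][j] = mini;
}

int matrixMultiplication(vector<int> &arr, int N)
{
    // dp[i][j] == -1 means "not computed yet".
    vector<vector<int>> dp(N, vector<int>(N, -1));
    // The matrices are numbered 1..N-1.
    return solve(1, N-1, arr, dp);
}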
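The function above returns only the minimum cost. If you also want the order itself, such as (A*(B*C))*D from the example, one common approach is to record the best split point for each (i, j) and rebuild the parentheses from it. The sketch below assumes that approach; solveWithSplit, order and the split table are hypothetical names, and matrices are labeled A, B, C, ... (so it assumes at most 26 matrices).

```cpp
#include <climits>
#include <string>
#include <vector>
using namespace std;

// Same memoized recursion as above, but split[i][j] remembers the best k.
int solveWithSplit(int i, int j, const vector<int> &arr,
                   vector<vector<int>> &dp, vector<vector<int>> &split) {
    if (i == j) return 0;
    if (dp[i][j] != -1) return dp[i][j];
    int mini = INT_MAX;
    for (int k = i; k < j; k++) {
        int steps = arr[i - 1] * arr[k] * arr[j]
                  + solveWithSplit(i, k, arr, dp, split)
                  + solveWithSplit(k + 1, j, arr, dp, split);
        if (steps < mini) {
            mini = steps;
            split[i][j] = k;  // best split point found so far
        }
    }
    return dp[i][j] = mini;
}

// Rebuilds the optimal parenthesization of matrices i..j from the split table.
string order(int i, int j, const vector<vector<int>> &split) {
    if (i == j) return string(1, char('A' + i - 1));
    int k = split[i][j];
    return "(" + order(i, k, split) + order(k + 1, j, split) + ")";
}
```

For arr = {40, 20, 30, 10, 30}, solveWithSplit(1, 4, ...) returns 26000 and order(1, 4, split) returns "((A(BC))D)", which is (A*(B*C))*D with an extra pair of outer parentheses.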

Complexity Analysis:
Time: O(N^3). There are O(N^2) states (i, j), and each one tries O(N) split points.
Space: O(N^2) for the dp table + O(N) auxiliary stack space for the recursion.
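The cubic bound can be made exact. With n = N - 1 matrices, every triple (i, k, j) with 1 ≤ i ≤ k < j ≤ n is evaluated exactly once (memoization guarantees each (i, j) is solved once, and it loops over j - i split points), so the total work is:

```latex
\sum_{1 \le i < j \le n} (j - i) \;=\; \sum_{d=1}^{n-1} d\,(n - d) \;=\; \frac{(n-1)\,n\,(n+1)}{6} \;=\; \Theta(n^3)
```

For the 4-matrix example this gives 3 · 4 · 5 / 6 = 10 split evaluations.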

Tabulation Code:

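#include <bits/stdc++.h>
using namespace std;

int matrixMultiplication(vector<int> &arr, int N)
{
    // dp[i][j] = minimum cost to multiply matrices i..j.
    // The base case dp[i][i] = 0 is covered by initializing every cell to 0.
    vector<vector<int>> dp(N, vector<int>(N, 0));

    // dp[i][j] depends on dp[i][k] (same row, smaller column) and
    // dp[k+1][j] (a row with a larger index), so fill rows from the
    // bottom up and each row from left to right.
    for(int i = N-1; i >= 1; i--){
        for(int j = i+1; j < N; j++){
            int mini = 1e9;
            for(int k = i; k < j; k++){
                int steps = arr[i-1] * arr[k] * arr[j] + dp[i][k] + dp[k+1][j];
                mini = min(mini, steps);
            }
            dp[i][j] = mini;
        }
    }
    return dp[1][N-1];
}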
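Filling rows from the bottom up is one valid order; another common one, sketched below as an alternative rather than the method used above, fills the table by chain length. Every dp[i][k] and dp[k+1][j] needed for a chain of length len belongs to a strictly shorter chain, so it is already computed. The function name matrixMultiplicationByLength is hypothetical.

```cpp
#include <algorithm>
#include <climits>
#include <vector>
using namespace std;

// Alternative bottom-up order: chains of length 2, then 3, ..., up to n.
int matrixMultiplicationByLength(const vector<int> &arr, int N) {
    int n = N - 1;  // number of matrices, numbered 1..n
    vector<vector<int>> dp(N, vector<int>(N, 0));  // dp[i][i] = 0 (single matrix)
    for (int len = 2; len <= n; len++) {
        for (int i = 1; i + len - 1 <= n; i++) {
            int j = i + len - 1;
            dp[i][j] = INT_MAX;
            for (int k = i; k < j; k++) {
                // Both sub-chains are shorter than len, so they are already filled.
                dp[i][j] = min(dp[i][j], arr[i - 1] * arr[k] * arr[j] + dp[i][k] + dp[k + 1][j]);
            }
        }
    }
    return dp[1][n];
}
```

Both orders compute the same table and have the same O(N^3) time and O(N^2) space; the length-based order simply makes the dependency on shorter chains explicit.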

Complexity Analysis:
Time: O(N^3)
Space: O(N^2) for the dp table, with no recursion stack.

Thanks for reading. Feedback is highly appreciated.
