Problem Statement

You are given a row-wise sorted 2-D matrix of size n*m, where each row is sorted in non-decreasing order. The task is to find the median of the matrix.

Constraints:

  1. The number of elements in the matrix (n*m) is odd. This guarantees a unique median.
  2. The matrix is not necessarily column-wise sorted.

Median Definition:

  • If the matrix is flattened into a sorted 1D array, the median is the middle element.
  • Equivalently, the median is the element at 0-based index (n * m) / 2 (integer division) in the sorted order of all matrix elements.

Examples

Example 1:

Input: matrix = [
                 [1, 3, 5],
                 [2, 6, 9],
                 [3, 6, 8]
                ]
Output: 5

Explanation:
Flattened sorted matrix: [1, 2, 3, 3, 5, 6, 6, 8, 9]

Median: sorted[9/2] = sorted[4] = 5 (0-based index, integer division)

Example 2:

Input: matrix = [
                 [1, 10, 20],
                 [2, 15, 30],
                 [5, 25, 35]
                ]
Output: 15

Explanation:
Flattened sorted matrix: [1, 2, 5, 10, 15, 20, 25, 30, 35]
Median: sorted[9/2] = sorted[4] = 15

Different Approaches

1️⃣ Brute Force Approach

Intuition:

The simplest approach is to copy every element of the matrix into a linear array or list and sort it in non-decreasing order. The middle element of the sorted list is the median of the matrix.

Approach:

  1. Traverse the rows of the matrix using a for loop.
  2. For each row, traverse its columns using a nested for loop.
  3. Inside the loops, append the element of the current cell to a linear array or list.
  4. Sort the linear array in non-decreasing order.
  5. Finally, return the middle element of the sorted array, at index (n * m) / 2.

Code:

#include <bits/stdc++.h>
using namespace std;

class Solution {
public:
    //Function to find the median of the matrix.
    int findMedian(vector<vector<int>>& matrix) {
        vector<int> lst;
        int n = matrix.size();
        int m = matrix[0].size(); 
        
        /* Traverse the matrix and 
        copy the elements to list*/
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                lst.push_back(matrix[i][j]);
            }
        }
        
        // Sort the list
        sort(lst.begin(), lst.end());
        
        // Return the median element
        return lst[(n * m) / 2];
    }
};

int main() {
    vector<vector<int>> matrix = {
        {1, 2, 3, 4, 5},
        {8, 9, 11, 12, 13},
        {21, 23, 25, 27, 29}
    };
    
    //Create an instance of Solution class
    Solution sol;
    
    int ans = sol.findMedian(matrix);
    
    //Print the answer
    cout << "The median element is: " << ans << endl;
    
    return 0;
}

Complexity Analysis:

  • Time Complexity: O(M*N) + O(M*N*log(M*N)), where N is the number of rows in the matrix and M is the number of columns in each row. Copying the elements takes O(M*N) time, and sorting the linear array of size M*N takes O(M*N*log(M*N)) time.
  • Space Complexity: O(M*N), since a temporary list stores every element of the matrix.
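
Since only the middle element is needed, a full sort does more work than necessary. The sketch below is a variation that is not part of the approach above: it swaps std::sort for std::nth_element, which places the element at a chosen position as if the range were sorted and runs in linear time on average. The function name findMedianNthElement is a hypothetical name chosen for illustration, and the code assumes a non-empty rectangular matrix with an odd number of elements.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper (not from the approach above): finds the median by
// partially ordering a flattened copy of the matrix with std::nth_element.
int findMedianNthElement(const std::vector<std::vector<int>>& matrix) {
    // Assumption: the matrix is non-empty, rectangular, and n * m is odd.
    assert(!matrix.empty() && !matrix[0].empty());

    // Copy every element into one linear array, as in the brute force code.
    std::vector<int> flat;
    flat.reserve(matrix.size() * matrix[0].size());
    for (const auto& row : matrix) {
        flat.insert(flat.end(), row.begin(), row.end());
    }

    // After this call, flat[k] holds the value a full sort would place there.
    const auto k = flat.size() / 2;
    std::nth_element(flat.begin(), flat.begin() + k, flat.end());
    return flat[k];
}
```

The extra space is still O(M*N) because the matrix is copied, so this only reduces the time spent ordering the elements.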

2️⃣ Binary Search

Intuition:

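Here, the idea is to binary search over the range of possible values rather than over positions in the matrix. To apply binary search, the search range should be sorted. In this case, the search range is the integer range [min, max], where min is the minimum element of the matrix and max is the maximum element of the matrix. Not every value in this range appears in the matrix, but the count of elements smaller than or equal to a value never decreases as the value grows, which is what makes binary search applicable.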

How to check if the current element(mid) is the median:

If the current element is the median, the number of elements smaller than or equal to it must be greater than (M*N) // 2 (integer division). The median is the smallest value for which this holds.
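
As a concrete check of this condition, the sketch below counts elements by scanning every cell and compares the count with (M*N) // 2. The names countAtMost and exceedsHalf are hypothetical helpers introduced only for this illustration, and the code assumes a non-empty rectangular matrix.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: counts the elements <= x by visiting every cell.
int countAtMost(const std::vector<std::vector<int>>& matrix, int x) {
    int count = 0;
    for (const auto& row : matrix) {
        for (int value : row) {
            if (value <= x) ++count;
        }
    }
    return count;
}

// Hypothetical helper: true when more than (M*N) / 2 elements are <= x.
bool exceedsHalf(const std::vector<std::vector<int>>& matrix, int x) {
    // Assumption: the matrix is non-empty and rectangular.
    assert(!matrix.empty());
    const int total = static_cast<int>(matrix.size() * matrix[0].size());
    return countAtMost(matrix, x) > total / 2;
}
```

For Example 1, (M*N) // 2 is 4. Four elements are smaller than or equal to 4, so exceedsHalf returns false for 4. Five elements are smaller than or equal to 5, so it returns true for 5, and 5 is the smallest value for which it does.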

How to count the elements smaller than or equal to ‘mid’:
  • One way is to traverse the entire matrix and count them, but that costs O(M*N) for every candidate value. Because each row is sorted, the concept of upper bound can be applied instead.
  • For each row, find the upper bound of 'mid', that is, the index of the first element greater than 'mid'. The returned index equals the number of elements smaller than or equal to 'mid' in that row. Repeat this for every row and add up the results to obtain the total for the entire matrix.
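
The row-wise counting described above can be sketched with the standard library's std::upper_bound instead of a hand-written binary search. The function name countSmallEqualStd is a hypothetical name for this illustration, and the code assumes every row is sorted in non-decreasing order.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Sketch: counts the elements <= x by running std::upper_bound on each row.
int countSmallEqualStd(const std::vector<std::vector<int>>& matrix, int x) {
    int cnt = 0;
    for (const auto& row : matrix) {
        // Assumption: each row is sorted in non-decreasing order.
        assert(std::is_sorted(row.begin(), row.end()));

        // upper_bound returns an iterator to the first element > x, so its
        // distance from begin() is the number of elements <= x in this row.
        cnt += static_cast<int>(
            std::upper_bound(row.begin(), row.end(), x) - row.begin());
    }
    return cnt;
}
```

For Example 1 and x = 6, the rows contribute 3, 2, and 2, so the function returns 7.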

Approach: 

Working of findMedian(matrix):
  1. Calculate the minimum and maximum values of the matrix. Since the given matrix is sorted row-wise, the minimum element will be found in the first column, and the maximum element will be found in the last column.
  2. Initialize low to the minimum element of the matrix and high to the maximum element of the matrix. Both are values, not indices.
  3. Run a while loop as long as low is less than or equal to high. In each iteration, calculate mid as (low+high) // 2 ( ‘//’ refers to integer division).
  4. The countSmallEqual() function is used to obtain the number of elements less than or equal to mid. Within the function, the upper bound technique described above is applied to each row, and the per-row results are summed. The total is then returned from the countSmallEqual() function.
  5. If smallEqual is less than or equal to (M*N) // 2, it can be concluded that the median must be a larger number. Therefore, the left half (smaller values) is eliminated by setting low = mid + 1.
  6. If smallEqual is greater than (M*N) // 2, it can be concluded that the element mid might be the median. However, to determine the actual median, we need to move towards smaller numbers. Therefore, in this case, the right half (larger values) is eliminated by setting high = mid - 1.
  7. When the loop ends, low is the smallest value whose count exceeds (M*N) // 2, which is the median. Return low.
Working of upperBound(arr, x, m):
  1. This helper function finds the upper bound of x in the array arr, that is, the index of the first element greater than x.
  2. Start with low = 0 and high = m - 1, where m is the size of the array arr. Initialize ans to m, which is the answer when every element is less than or equal to x.
  3. Use a while loop to perform binary search until low exceeds high. Calculate the 'mid' index as (low + high) / 2.
  4. If arr[mid] is greater than x, it is a candidate for the upper bound (the first element greater than x). Update ans to mid, then set high to (mid - 1) to search for a smaller index on the left side of mid. Otherwise, move low to (mid + 1) to search for the upper bound on the right side of mid.
  5. Once the while loop terminates, return the 'ans' variable.
Working of countSmallEqual(matrix, n, m, x):
  1. Start with cnt = 0, which will store the count of elements in the matrix that are less than or equal to x.
  2. Iterate through each row of the matrix and call upperBound() on it.
  3. The upperBound function returns the index of the first element in the current row that is greater than x. Add the returned index to cnt. This addition effectively counts all elements in the current row matrix[i] that are less than or equal to x.
  4. After iterating through all rows and accumulating counts, return 'cnt', which represents the total number of elements in the entire matrix that are less than or equal to x.
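
To see how the steps above narrow the range, the sketch below runs the same search and records every mid it probes. It is an illustration only: SearchTrace and traceMedianSearch are hypothetical names, std::upper_bound stands in for the upperBound helper, and the code assumes a non-empty, row-wise sorted matrix with an odd number of elements.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical record of one run of the value-space binary search.
struct SearchTrace {
    std::vector<int> mids;  // every mid value probed, in order
    int median = 0;         // value of low when the loop ends
};

SearchTrace traceMedianSearch(const std::vector<std::vector<int>>& matrix) {
    // Assumption: non-empty matrix, rows sorted, odd number of elements.
    assert(!matrix.empty() && !matrix[0].empty());

    // Minimum lies in the first column, maximum in the last column.
    int low = matrix[0].front(), high = matrix[0].back();
    for (const auto& row : matrix) {
        low = std::min(low, row.front());
        high = std::max(high, row.back());
    }
    const int req = static_cast<int>(matrix.size() * matrix[0].size()) / 2;

    SearchTrace trace;
    while (low <= high) {
        const int mid = low + (high - low) / 2;
        trace.mids.push_back(mid);

        // Count the elements <= mid, one upper bound per row.
        int smallEqual = 0;
        for (const auto& row : matrix) {
            smallEqual += static_cast<int>(
                std::upper_bound(row.begin(), row.end(), mid) - row.begin());
        }

        if (smallEqual <= req) low = mid + 1;  // median must be larger
        else high = mid - 1;                   // mid may be the median
    }
    trace.median = low;
    return trace;
}
```

On Example 1 the range starts at [1, 9] with req = 4. The probed values are 5 (count 5), 2 (count 2), 3 (count 4), and 4 (count 4), after which low = 5 exceeds high = 4 and the search returns 5.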

Code:

#include<bits/stdc++.h>
using namespace std;

class Solution {
private:
    //Function to find the upper bound of an element
    int upperBound(vector<int>& arr, int x, int m) {
        int low = 0, high = m - 1;
        int ans = m;
        
        //Apply binary search
        while (low <= high) {
            int mid = (low + high) / 2;
            
            /* If arr[mid] > x, it 
            can be a possible answer*/
            if (arr[mid] > x) {
                ans = mid;
                //Look for smaller index on the left
                high = mid - 1;
            }
            else {
                low = mid + 1; 
            }
        }
        //Return the answer
        return ans;
    }
    /*Function to count the elements
    smaller than or equal to x*/
    int countSmallEqual(vector<vector<int>>& matrix, int n, int m, int x) {
        int cnt = 0;
        for (int i = 0; i < n; i++) {
            cnt += upperBound(matrix[i], x, m);
        }
        //Return the count
        return cnt;
    }
public:
    //Function to find the median in a matrix
    int findMedian(vector<vector<int>>& matrix) {
        int n = matrix.size();
        int m = matrix[0].size();
        
        //Initialize low and high
        int low = INT_MAX, high = INT_MIN;

        //Point low and high to right elements
        for (int i = 0; i < n; i++) {
            low = min(low, matrix[i][0]);
            high = max(high, matrix[i][m - 1]);
        }

        int req = (n * m) / 2;
        
        //Perform binary search
        while (low <= high) {
            int mid = low + (high - low) / 2;
            
            /*Store the count of elements
            lesser than or equal to mid*/
            int smallEqual = countSmallEqual(matrix, n, m, mid);
            if (smallEqual <= req) low = mid + 1;
            else high = mid - 1;
        }
        //Return low as answer
        return low;
    }
};
int main() {
    vector<vector<int>> matrix = {{1, 2, 3, 4, 5},
                                  {8, 9, 11, 12, 13},
                                  {21, 23, 25, 27, 29}};
    
    //Create an instance of Solution class
    Solution sol;
    
    int ans = sol.findMedian(matrix);
    
    //Print the output
    cout << "The median element is: " << ans << endl;
    return 0;
}

Complexity Analysis:

  • Time Complexity: O(log(max - min)) * O(N*log(M)), where N is the number of rows in the matrix and M is the number of columns in each row. The search space is [min, max], where min is the minimum and max is the maximum element of the matrix, so the binary search over it takes O(log(max - min)) iterations. For each 'mid', the countSmallEqual() function runs an upper bound search of O(log(M)) on each of the N rows, which takes O(N*log(M)) time.
  • Space Complexity: O(1), as no additional space is used.