package com.freetymekiyan.algorithms.level.medium;
import java.util.ArrayList;
import java.util.List;
/**
* Given two sparse matrices A and B, return the result of AB.
* <p>
* You may assume that A's column number is equal to B's row number.
* <p>
* Example:
* <p>
* | A = [
* | [ 1, 0, 0],
* | [-1, 0, 3]
* | ]
* <p>
* | B = [
* | [ 7, 0, 0 ],
* | [ 0, 0, 0 ],
* | [ 0, 0, 1 ]
* | ]
* <p>
* |      |  1 0 0 |   | 7 0 0 |   |  7 0 0 |
* | AB = | -1 0 3 | x | 0 0 0 | = | -7 0 3 |
* |                   | 0 0 1 |
* Company Tags: LinkedIn, Facebook
* Tags: Hash Table
*/
public class SparseMatrixMultiplication {
/**
* Matrix multiplication is row in A multiply column in B.
* When implement, for A[i][j], multiply a row B[j][k], 0 <= k < nB.
* Then add result to res[i][k].
* Skip zeros since the matrix is sparse.
* <p>
* Loop through A from left to right, row by row.
* In each row, multiply the column value A[i][j] with each value in row j in B, B[j][k].
* Add it to res[i][k].
*/
public int[][] multiply(int[][] A, int[][] B) {
int mA = A.length, nA = A[0].length;
int nB = B[0].length;
int[][] res = new int[mA][nB];
for (int i = 0; i < mA; i++) {
for (int j = 0; j < nA; j++) {
if (A[i][j] == 0) {
continue; // Skip zeros in A.
}
for (int k = 0; k < nB; k++) {
if (B[j][k] == 0) {
continue; // Skip zeros in B.
}
res[i][k] += A[i][j] * B[j][k];
}
}
}
return res;
}
/**
* A sparse matrix can be represented as a sequence of rows.
* Each row is a sequence of (column-number, value) pairs of the nonzero values in the row.
* Create result matrix.
* Create a List array for the rows of matrix A.
* | For each value in A:
* | Create a list.
* | If A[i][j] != 0:
* | Add j and A[i][j] to the list.
* | Set this list in the array.
* For each list in the array:
* | Get the list first.
* | For each pair, get column and value.
* | Get value in B.
* | Multiply and update result.
* Return result.
*/
public int[][] multiplyB(int[][] A, int[][] B) {
int m = A.length, n = A[0].length, nB = B[0].length;
int[][] result = new int[m][nB];
// Build list of rows for A.
List[] indexA = new List[m];
for (int i = 0; i < m; i++) {
List<Integer> numsA = new ArrayList<>();
for (int j = 0; j < n; j++) {
if (A[i][j] != 0) {
numsA.add(j); // Add column.
numsA.add(A[i][j]); // Add actual value.
}
}
indexA[i] = numsA;
}
for (int i = 0; i < m; i++) {
List<Integer> numsA = indexA[i];
for (int p = 0; p < numsA.size() - 1; p += 2) {
int colA = numsA.get(p); // Get column.
int valA = numsA.get(p + 1); // Get actual value after.
for (int j = 0; j < nB; j++) {
int valB = B[colA][j];
result[i][j] += valA * valB;
}
}
}
return result;
}
}