package com.jwetherell.algorithms.data_structures; import java.math.BigDecimal; import java.math.BigInteger; import java.util.Comparator; /** * Matrx. This Matrix implementation is designed to be more efficient * in cache. A matrix is a rectangular array of numbers, symbols, or expressions. * * http://en.wikipedia.org/wiki/Matrix_(mathematics) * * @author Justin Wetherell <phishman3579@gmail.com> */ @SuppressWarnings("unchecked") public class Matrix<T extends Number> { private int rows = 0; private int cols = 0; private T[] matrix = null; private final Comparator<T> comparator = new Comparator<T>() { /** * {@inheritDoc} */ @Override public int compare(T o1, T o2) { /* TODO: What if Java adds new numeric type? */ int result = 0; if (o1 instanceof BigDecimal || o2 instanceof BigDecimal) { BigDecimal c1 = (BigDecimal)o1; BigDecimal c2 = (BigDecimal)o2; result = c1.compareTo(c2); } else if (o1 instanceof BigInteger || o2 instanceof BigInteger) { BigInteger c1 = (BigInteger)o1; BigInteger c2 = (BigInteger)o2; result = c1.compareTo(c2); } else if (o1 instanceof Long || o2 instanceof Long) { Long c1 = o1.longValue(); Long c2 = o2.longValue(); result = c1.compareTo(c2); } else if (o1 instanceof Double || o2 instanceof Double) { Double c1 = o1.doubleValue(); Double c2 = o2.doubleValue(); result = c1.compareTo(c2); } else if (o1 instanceof Float || o2 instanceof Float) { Float c1 = o1.floatValue(); Float c2 = o2.floatValue(); result = c1.compareTo(c2); } else { Integer c1 = o1.intValue(); Integer c2 = o2.intValue(); result = c1.compareTo(c2); } return result; } }; /** * Matrix with 'rows' number of rows and 'cols' number of columns. * * @param rows Number of rows in Matrix. * @param cols Number of columns in Matrix. */ public Matrix(int rows, int cols) { this.rows = rows; this.cols = cols; this.matrix = (T[]) new Number[rows * cols]; } /** * Matrix with 'rows' number of rows and 'cols' number of columns, populates * the double index matrix. * * @param rows Number of rows in Matrix. * @param cols Number of columns in Matrix. * @param matrix 2D matrix used to populate Matrix. */ public Matrix(int rows, int cols, T[][] matrix) { this.rows = rows; this.cols = cols; this.matrix = (T[]) new Number[rows * cols]; for (int r=0; r<rows; r++) for (int c=0; c<cols; c++) this.matrix[getIndex(r,c)] = matrix[r][c]; } private int getIndex(int row, int col) { if (row == 0) return col; return ((row * cols) + col); } public T get(int row, int col) { return matrix[getIndex(row, col)]; } public T[] getRow(int row) { T[] result = (T[]) new Number[cols]; for (int c = 0; c < cols; c++) { result[c] = this.get(row, c); } return result; } public T[] getColumn(int col) { T[] result = (T[]) new Number[rows]; for (int r = 0; r < rows; r++) { result[r] = this.get(r, col); } return result; } public void set(int row, int col, T value) { matrix[getIndex(row, col)] = value; } public Matrix<T> identity() throws Exception{ if(this.rows != this.cols) throw new Exception("Matrix should be a square"); final T element = this.get(0, 0); final T zero; final T one; if (element instanceof BigDecimal) { zero = (T)BigDecimal.ZERO; one = (T)BigDecimal.ONE; } else if(element instanceof BigInteger){ zero = (T)BigInteger.ZERO; one = (T)BigInteger.ONE; } else if(element instanceof Long){ zero = (T)new Long(0); one = (T)new Long(1); } else if(element instanceof Double){ zero = (T)new Double(0); one = (T)new Double(1); } else if(element instanceof Float){ zero = (T)new Float(0); one = (T)new Float(1); } else { zero = (T)new Integer(0); one = (T)new Integer(1); } final T array[][] = (T[][])new Number[this.rows][this.cols]; for(int i = 0; i < this.rows; ++i) { for(int j = 0 ; j < this.cols; ++j){ array[i][j] = zero; } } final Matrix<T> identityMatrix = new Matrix<T>(this.rows, this.cols, array); for(int i = 0; i < this.rows;++i){ identityMatrix.set(i, i, one); } return identityMatrix; } public Matrix<T> add(Matrix<T> input) { Matrix<T> output = new Matrix<T>(this.rows, this.cols); if ((this.cols != input.cols) || (this.rows != input.rows)) return output; for (int r = 0; r < output.rows; r++) { for (int c = 0; c < output.cols; c++) { for (int i = 0; i < cols; i++) { T m1 = this.get(r, c); T m2 = input.get(r, c); T result; /* TODO: This is ugly and how to handle number overflow? */ if (m1 instanceof BigDecimal || m2 instanceof BigDecimal) { BigDecimal result2 = ((BigDecimal)m1).add((BigDecimal)m2); result = (T)result2; } else if (m1 instanceof BigInteger || m2 instanceof BigInteger) { BigInteger result2 = ((BigInteger)m1).add((BigInteger)m2); result = (T)result2; } else if (m1 instanceof Long || m2 instanceof Long) { Long result2 = (m1.longValue() + m2.longValue()); result = (T)result2; } else if (m1 instanceof Double || m2 instanceof Double) { Double result2 = (m1.doubleValue() + m2.doubleValue()); result = (T)result2; } else if (m1 instanceof Float || m2 instanceof Float) { Float result2 = (m1.floatValue() + m2.floatValue()); result = (T)result2; } else { // Integer Integer result2 = (m1.intValue() + m2.intValue()); result = (T)result2; } output.set(r, c, result); } } } return output; } public Matrix<T> subtract(Matrix<T> input) { Matrix<T> output = new Matrix<T>(this.rows, this.cols); if ((this.cols != input.cols) || (this.rows != input.rows)) return output; for (int r = 0; r < output.rows; r++) { for (int c = 0; c < output.cols; c++) { for (int i = 0; i < cols; i++) { T m1 = this.get(r, c); T m2 = input.get(r, c); T result; /* TODO: This is ugly and how to handle number overflow? */ if (m1 instanceof BigDecimal || m2 instanceof BigDecimal) { BigDecimal result2 = ((BigDecimal)m1).subtract((BigDecimal)m2); result = (T)result2; } else if (m1 instanceof BigInteger || m2 instanceof BigInteger) { BigInteger result2 = ((BigInteger)m1).subtract((BigInteger)m2); result = (T)result2; } else if (m1 instanceof Long || m2 instanceof Long) { Long result2 = (m1.longValue() - m2.longValue()); result = (T)result2; } else if (m1 instanceof Double || m2 instanceof Double) { Double result2 = (m1.doubleValue() - m2.doubleValue()); result = (T)result2; } else if (m1 instanceof Float || m2 instanceof Float) { Float result2 = (m1.floatValue() - m2.floatValue()); result = (T)result2; } else { // Integer Integer result2 = (m1.intValue() - m2.intValue()); result = (T)result2; } output.set(r, c, result); } } } return output; } public Matrix<T> multiply(Matrix<T> input) { Matrix<T> output = new Matrix<T>(this.rows, input.cols); if (this.cols != input.rows) return output; for (int r = 0; r < output.rows; r++) { for (int c = 0; c < output.cols; c++) { T[] row = getRow(r); T[] column = input.getColumn(c); T test = row[0]; /* TODO: This is ugly and how to handle number overflow? */ if (test instanceof BigDecimal) { BigDecimal result = BigDecimal.ZERO; for (int i = 0; i < cols; i++) { T m1 = row[i]; T m2 = column[i]; BigDecimal result2 = ((BigDecimal)m1).multiply(((BigDecimal)m2)); result.add(result2); } output.set(r, c, (T)result); } else if (test instanceof BigInteger) { BigInteger result = BigInteger.ZERO; for (int i = 0; i < cols; i++) { T m1 = row[i]; T m2 = column[i]; BigInteger result2 = ((BigInteger)m1).multiply(((BigInteger)m2)); result.add(result2); } output.set(r, c, (T)result); } else if (test instanceof Long) { Long result = 0l; for (int i = 0; i < cols; i++) { T m1 = row[i]; T m2 = column[i]; Long result2 = m1.longValue() * m2.longValue(); result = result+result2; } output.set(r, c, (T)result); } else if (test instanceof Double) { Double result = 0d; for (int i = 0; i < cols; i++) { T m1 = row[i]; T m2 = column[i]; Double result2 = m1.doubleValue() * m2.doubleValue(); result = result+result2; } output.set(r, c, (T)result); } else if (test instanceof Float) { Float result = 0f; for (int i = 0; i < cols; i++) { T m1 = row[i]; T m2 = column[i]; Float result2 = m1.floatValue() * m2.floatValue(); result = result+result2; } output.set(r, c, (T)result); } else { // Integer Integer result = 0; for (int i = 0; i < cols; i++) { T m1 = row[i]; T m2 = column[i]; Integer result2 = m1.intValue() * m2.intValue(); result = result+result2; } output.set(r, c, (T)result); } } } return output; } public void copy(Matrix<T> m) { for (int r = 0; r < m.rows; r++) { for (int c = 0; c < m.cols; c++) { set(r, c, m.get(r, c)); } } } /** * {@inheritDoc} */ @Override public int hashCode() { int hash = this.rows + this.cols; for (T t : matrix) hash += t.intValue(); return 31 * hash; } /** * {@inheritDoc} */ @Override public boolean equals(Object obj) { if (obj == null) return false; if (!(obj instanceof Matrix)) return false; Matrix<T> m = (Matrix<T>) obj; if (this.rows != m.rows) return false; if (this.cols != m.cols) return false; for (int i=0; i<matrix.length; i++) { T t1 = matrix[i]; T t2 = m.matrix[i]; int result = comparator.compare(t1, t2); if (result!=0) return false; } return true; } /** * {@inheritDoc} */ @Override public String toString() { StringBuilder builder = new StringBuilder(); builder.append("Matrix:\n"); for (int r = 0; r < rows; r++) { builder.append("row=[").append(r).append("] "); for (int c = 0; c < cols; c++) { builder.append(this.get(r, c)).append("\t"); } builder.append("\n"); } return builder.toString(); } }