package org.wikibrain.matrix;

import gnu.trove.map.TIntIntMap;
import gnu.trove.map.hash.TIntIntHashMap;

import java.io.File;
import java.io.IOException;
import java.util.Iterator;
import java.util.NoSuchElementException;

/**
 * A sparse matrix implementation that can be fit in memory.
 *
 * <p>Rows are stored as three parallel arrays: {@link #rowIds} holds the id of
 * each row, {@link #colIds} holds the column ids of each row, and
 * {@link #values} holds the matching cell values. {@link #rowMap} maps a row
 * id to its position in those arrays.
 *
 * @author Shilad Sen
 */
public class InMemorySparseMatrix implements Matrix<InMemorySparseMatrix.InMemorySparseMatrixRow> {
    /** Id of the row stored at each position. */
    public final int[] rowIds;

    /** Column ids of each row; {@code colIds[i]} belongs to {@code rowIds[i]}. */
    public final int[][] colIds;

    /** Cell values of each row; {@code values[i][j]} belongs to {@code colIds[i][j]}. */
    public final double[][] values;

    /** Row id -> position in {@link #rowIds}, as built at construction time. */
    public TIntIntMap rowMap;

    /** raw ids -> compressed ids; {@code null} until {@link #compressIds()} is called. */
    public TIntIntMap idMap;

    /**
     * Creates a matrix directly on top of the given arrays. The arrays are
     * used as-is (they are not copied).
     *
     * @param rowIds id of each row
     * @param colIds column ids of each row, parallel to {@code rowIds}
     * @param values cell values of each row, parallel to {@code colIds}
     */
    public InMemorySparseMatrix(int rowIds[], int colIds[][], double [][] values) {
        this.rowIds = rowIds;
        this.colIds = colIds;
        this.values = values;
        this.rowMap = buildRowMap(rowIds);
    }

    /**
     * Loads the sparse matrix stored in {@code file} into memory. The
     * file-backed matrix is closed once its contents have been copied.
     *
     * @param file path handed to the {@code SparseMatrix} constructor
     * @throws IOException if the matrix cannot be opened or closed
     */
    public InMemorySparseMatrix(File file) throws IOException {
        this(loadAndClose(new SparseMatrix(file)));
    }

    /**
     * Copies every row of {@code diskMatrix} into memory.
     *
     * <p>Column ids and values are copied cell by cell. The array returned by
     * {@code diskMatrix.getRowIds()} is used directly rather than copied.
     *
     * @param diskMatrix the matrix to copy; it is not closed by this constructor
     * @throws IllegalStateException if the rows are not iterated in the order
     *         given by {@code diskMatrix.getRowIds()}
     */
    public InMemorySparseMatrix(Matrix<? extends MatrixRow> diskMatrix) {
        rowIds = diskMatrix.getRowIds();
        colIds = new int[rowIds.length][];
        values = new double[rowIds.length][];

        int i = 0;
        for (MatrixRow row : diskMatrix) {
            if (rowIds[i] != row.getRowIndex()) {
                throw new IllegalStateException(
                        "row id mismatch at position " + i
                                + ": expected " + rowIds[i]
                                + " but iterated " + row.getRowIndex());
            }
            int numCols = row.getNumCols();
            colIds[i] = new int[numCols];
            values[i] = new double[numCols];
            for (int j = 0; j < numCols; j++) {
                colIds[i][j] = row.getColIndex(j);
                values[i][j] = row.getColValue(j);
            }
            i++;
        }
        rowMap = buildRowMap(rowIds);
    }

    /**
     * Adopts the arrays and row map of an already-loaded matrix without
     * copying them. Only used to finish {@link #InMemorySparseMatrix(File)}.
     */
    private InMemorySparseMatrix(InMemorySparseMatrix loaded) {
        this.rowIds = loaded.rowIds;
        this.colIds = loaded.colIds;
        this.values = loaded.values;
        this.rowMap = loaded.rowMap;
        this.idMap = loaded.idMap;
    }

    /** Copies {@code source} into memory and closes it, even if copying fails. */
    private static InMemorySparseMatrix loadAndClose(Matrix<? extends MatrixRow> source)
            throws IOException {
        try {
            return new InMemorySparseMatrix(source);
        } finally {
            source.close();
        }
    }

    /** Builds the row id -> array position lookup for {@code rowIds}. */
    private static TIntIntMap buildRowMap(int[] rowIds) {
        TIntIntMap map = new TIntIntHashMap(rowIds.length * 2);
        for (int i = 0; i < rowIds.length; i++) {
            map.put(rowIds[i], i);
        }
        return map;
    }

    /**
     * Replaces raw ids with a contiguous set of ids starting at 0.
     *
     * <p>Each row id is replaced by its position in {@link #rowIds}. Column ids
     * that are also row ids reuse that number; any other column id receives
     * the next unused number in order of first appearance. Afterwards every
     * row's columns are sorted by compressed id, with the values reordered to
     * match. The raw -> compressed mapping is available from
     * {@link #getIdMap()}.
     *
     * <p>The arrays are rewritten in place. {@link #rowMap} is not rebuilt
     * here, so it keeps the keys it was constructed with.
     *
     * @throws IllegalStateException if a row id occurs more than once
     */
    public void compressIds() {
        idMap = new TIntIntHashMap();
        for (int i = 0; i < rowIds.length; i++) {
            // Sanity check: exactly one mapping has been added per row so far.
            if (idMap.size() != i) throw new IllegalStateException();
            if (idMap.containsKey(rowIds[i])) {
                throw new IllegalStateException("duplicate row: " + rowIds[i]);
            }
            idMap.put(rowIds[i], i);
            rowIds[i] = i;
        }
        for (int i = 0; i < colIds.length; i++) {
            for (int j = 0; j < colIds[i].length; j++) {
                if (!idMap.containsKey(colIds[i][j])) {
                    // Column that is not a row: hand out the next free id.
                    idMap.put(colIds[i][j], idMap.size());
                }
                colIds[i][j] = idMap.get(colIds[i][j]);
            }
            quickSort(colIds[i], values[i], 0, colIds[i].length - 1);
        }
    }

    /**
     * @return the raw id -> compressed id mapping, or {@code null} if
     *         {@link #compressIds()} has not been called
     */
    public TIntIntMap getIdMap() {
        return idMap;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void decompressIds() {
        throw new UnsupportedOperationException();
    }

    /**
     * Looks a row up through {@link #rowMap}.
     *
     * @param rowId key to look up in {@link #rowMap}
     * @return the matching row, or {@code null} if {@link #rowMap} has no such key
     */
    @Override
    public InMemorySparseMatrixRow getRow(int rowId) throws IOException {
        if (rowMap.containsKey(rowId)) {
            return getMatrixRowInternal(rowMap.get(rowId));
        } else {
            return null;
        }
    }

    /** @return the backing row-id array itself (not a copy) */
    @Override
    public int[] getRowIds() {
        return rowIds;
    }

    /** @return the number of rows */
    @Override
    public int getNumRows() {
        return rowIds.length;
    }

    /** Wraps the arrays stored at position {@code i} in a row view (no copying). */
    private InMemorySparseMatrixRow getMatrixRowInternal(int i) {
        return new InMemorySparseMatrixRow(rowIds[i], colIds[i], values[i]);
    }

    /**
     * Iterates over the rows in storage order.
     *
     * @return an iterator that does not support {@code remove()}
     */
    @Override
    public Iterator<InMemorySparseMatrixRow> iterator() {
        return new Iterator<InMemorySparseMatrixRow>() {
            private int i;

            @Override
            public boolean hasNext() {
                return i < rowIds.length;
            }

            @Override
            public InMemorySparseMatrixRow next() {
                // Iterator contract: signal exhaustion with an exception, not null.
                if (i >= rowIds.length) {
                    throw new NoSuchElementException();
                }
                return getMatrixRowInternal(i++);
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException();
            }
        };
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public File getPath() {
        throw new UnsupportedOperationException();
    }

    /** Does nothing. */
    @Override
    public void close() throws IOException {
    }

    /**
     * A row view over a pair of parallel column-id / value arrays. The arrays
     * passed to the constructor are referenced, not copied.
     */
    public static class InMemorySparseMatrixRow extends BaseMatrixRow implements MatrixRow {
        int rowId;
        int colIds[];
        double colVals[];

        /**
         * @param rowId id reported by {@link #getRowIndex()}
         * @param colId column ids of the row
         * @param value cell values, parallel to {@code colId}
         */
        public InMemorySparseMatrixRow(int rowId, int[] colId, double[] value) {
            this.rowId = rowId;
            this.colIds = colId;
            this.colVals = value;
        }

        /** @return the column id stored at position {@code i} */
        @Override
        public int getColIndex(int i) {
            return colIds[i];
        }

        /** @return the value at position {@code i}, narrowed from double to float */
        @Override
        public float getColValue(int i) {
            return (float) colVals[i];
        }

        /** @return the id of this row */
        @Override
        public int getRowIndex() {
            return rowId;
        }

        /** @return the number of stored columns */
        @Override
        public int getNumCols() {
            return colIds.length;
        }
    }

    /**
     * Sorts {@code colIds[low..high]} into ascending order in place, applying
     * every swap to {@code colVals} as well so the two arrays stay aligned.
     *
     * Adapted from http://www.programcreek.com/2012/11/quicksort-array-in-java/
     *
     * @param colIds  keys to sort
     * @param colVals values carried along with the keys
     * @param low     first index of the range (inclusive)
     * @param high    last index of the range (inclusive)
     */
    private static void quickSort(int colIds[], double colVals[], int low, int high) {
        if (colIds.length == 0 || low >= high) {
            return;
        }

        // Pick the pivot; written this way so low + high cannot overflow.
        int middle = low + (high - low) / 2;
        int pivot = colIds[middle];

        // Partition around the pivot.
        int i = low;
        int j = high;
        while (i <= j) {
            while (colIds[i] < pivot) {
                i++;
            }
            while (colIds[j] > pivot) {
                j--;
            }
            if (i <= j) {
                int temp = colIds[i];
                double tempV = colVals[i];
                colIds[i] = colIds[j];
                colVals[i] = colVals[j];
                colIds[j] = temp;
                colVals[j] = tempV;
                i++;
                j--;
            }
        }

        // Recursively sort the two sub parts.
        quickSort(colIds, colVals, low, j);
        quickSort(colIds, colVals, i, high);
    }
}