package edu.cmu.graphchi.util; import java.util.Random; import org.apache.commons.math.linear.ArrayRealVector; import org.apache.commons.math.linear.RealVector; /** * A huge dense matrix, which internally splits to many sub-blocks * Row-directed storage, so scanning row by row is efficient. This is useful * in keeping all vertex-values in memory efficiently. * @author akyrola */ public class HugeDoubleMatrix implements Cloneable { private int BLOCKSIZE = 1024 * 1024 * 16; // 16M * 4 = 64 megabytes private long nrows, ncols; private double[][] data; public HugeDoubleMatrix(long nrows, long ncols, double initialValue) { this.nrows = (long)nrows; this.ncols = (long)ncols; while(BLOCKSIZE % ncols != 0) BLOCKSIZE++; long elements = nrows * ncols; int nblocks = (int) (elements / (long)BLOCKSIZE + (elements % BLOCKSIZE == 0 ? 0 : 1)); data = new double[nblocks][]; System.out.println("Creating " + nblocks + " blocks"); for(int i=0; i<nblocks; i++) { data[i] = new double[BLOCKSIZE]; if (initialValue != 0.0f) { double[] mat = data[i]; for(int j=0; j<BLOCKSIZE; j++) { mat[j] = initialValue; } } } } public HugeDoubleMatrix(long nrows, long ncols) { this(nrows, ncols, 0.0); } public long size() { return nrows * ncols; } public long getNumRows() { return nrows; } public double getValue(int row, int col) { long idx = (long)row * ncols + (long)col; int block = (int) (idx / BLOCKSIZE); int blockidx = (int) (idx % BLOCKSIZE); return data[block][blockidx]; } public void setValue(int row, int col, double val) { long idx = (long)row * ncols + (long)col; int block = (int) (idx / BLOCKSIZE); int blockidx = (int) (idx % BLOCKSIZE); data[block][blockidx] = val; } public void add(int row, int col, double delta) { long idx = (long)row * ncols + (long)col; int block = (int) (idx / BLOCKSIZE); int blockidx = (int) (idx % BLOCKSIZE); data[block][blockidx] += delta; } // Premature optimization public double[] getRowBlock(int row) { long idx = (long)row * ncols; int block = (int) (idx / BLOCKSIZE); return data[block]; } public int getBlockIdx(int row) { long idx = (long)row * ncols; int blockidx = (int) (idx % BLOCKSIZE); return blockidx; } public double[] getEmptyRow() { double[] arr = new double[(int)ncols]; return arr; } public void multiplyRow(int row, float mul) { for(int i=0; i<ncols; i++) { setValue(row, i, getValue(row, i) * mul); // TODO make faster } } public void getRow(int row, double[] arr) { long idx = (long)row * ncols; int block = (int) (idx / BLOCKSIZE); int blockidx = (int) (idx % BLOCKSIZE); System.arraycopy(data[block], blockidx, arr, 0, (int)ncols); } public RealVector getRowAsVector(int row) { double [] arr = new double[(int) ncols]; long idx = (long)row * ncols; int block = (int) (idx / BLOCKSIZE); int blockidx = (int) (idx % BLOCKSIZE); System.arraycopy(data[block], blockidx, arr, 0, (int)ncols); return new ArrayRealVector(arr); } public void setRow(int row, double[] arr) { long idx = (long)row * ncols; int block = (int) (idx / BLOCKSIZE); int blockidx = (int) (idx % BLOCKSIZE); System.arraycopy(arr, 0, data[block], blockidx, (int)ncols); } /** Divides each item by the sum of squares */ public void normalizeSquared(int col) { double sqr = 0.0f; for(int j=0; j < nrows; j++) { double x = (double) getValue(j, col); sqr += x * x; } System.out.println("Normalize-squared: " + col + " sqr: " + sqr); float div = (float) Math.sqrt(sqr); System.out.println("Div : " + div); if (Float.isInfinite(div) || Float.isNaN(div)) throw new RuntimeException("Illegal normalizer: " + div); if (sqr == 0.0f) throw new IllegalArgumentException("Column was all-zeros!"); for(int j=0; j < nrows; j++) { double x = getValue(j, col); setValue(j, col, x / div); } } public void setColumn(int col, double val) { for(int j=0; j < nrows; j++) { this.setValue(j, col, val); } } /** * Sets every value less than a cutoff value to zero. * @param cutOff */ public void zeroLessThan(float cutOff) { for(int i=0; i < data.length; i++) { double[] block = data[i]; for(int j=0; j < block.length; j++) { if (block[j] != 0.0f && block[j] < cutOff) block[j] = 0.0f; } } } /** * Sets all values less than cutOff to zero, everyone else to value * @param cutOff * @param value */ public void binaryFilter(double cutOff, double value) { for(int i=0; i < data.length; i++) { double[] block = data[i]; for(int j=0; j < block.length; j++) { block[j] = (block[j] >= cutOff ? 1.0 : 0.0) * value; } } } /** * Randomize the content with numbers between from and to * @param from min value * @param to max value */ public void randomize(double from, double to) { Random r = new Random(); for(int i=0; i < data.length; i++) { double[] block = data[i]; for(int j=0; j < block.length; j++) { block[j] = from + (to - from) * r.nextDouble(); } } } }