package mikera.matrixx.impl; import mikera.matrixx.AMatrix; import mikera.matrixx.Matrix; import mikera.matrixx.Matrixx; import mikera.vectorz.AVector; import mikera.vectorz.Op; import mikera.vectorz.Vector; import mikera.vectorz.Vectorz; import mikera.vectorz.impl.AStridedVector; import mikera.vectorz.util.ErrorMessages; import mikera.vectorz.util.VectorzException; /** * A general purpose strided matrix implementation * * @author Mike */ public final class StridedMatrix extends AStridedMatrix { private static final long serialVersionUID = -7928115802247422177L; private final int rowStride; private final int colStride; private final int offset; private StridedMatrix(double[] data, int rowCount, int columnCount, int offset, int rowStride, int columnStride) { super(data,rowCount,columnCount); this.offset = offset; this.rowStride = rowStride; this.colStride = columnStride; } public static StridedMatrix create(int rowCount, int columnCount) { double[] data = new double[rowCount * columnCount]; return new StridedMatrix(data, rowCount, columnCount, 0, columnCount, 1); } @Override public boolean isFullyMutable() { return true; } @Override public boolean isMutable() { return true; } @Override public AStridedVector getRowView(int i) { return Vectorz.wrapStrided(data, offset+i*rowStride, cols, colStride); } @Override public AStridedVector getColumnView(int i) { return Vectorz.wrapStrided(data, offset+i*colStride, rows, rowStride); } @Override public void copyRowTo(int row, double[] dest, int destOffset) { int rowOffset=offset+row*rowStride; for (int i=0;i<cols; i++) { dest[destOffset+i]=data[rowOffset+i*colStride]; } } @Override public void copyColumnTo(int col, double[] dest, int destOffset) { int colOffset=offset+col*colStride; for (int i=0;i<rows; i++) { dest[destOffset+i]=data[colOffset+i*rowStride]; } } @Override public int rowStride() { return rowStride; } @Override public int columnStride() { return colStride; } @Override public int getArrayOffset() { return offset; } @Override public boolean isPackedArray() { return (offset == 0) && (colStride == 1) && (rowStride == cols) && (data.length == rows * cols); } @Override public AStridedMatrix subMatrix(int rowStart, int rowCount, int colStart, int colCount) { if ((rowStart<0)||(rowStart>=this.rows)||(colStart<0)||(colStart>=this.cols)) throw new IndexOutOfBoundsException(ErrorMessages.position(rowStart,colStart)); if ((rowStart+rowCount>this.rows)||(colStart+colCount>this.cols)) throw new IndexOutOfBoundsException(ErrorMessages.position(rowStart+rowCount,colStart+colCount)); return new StridedMatrix(data, rowCount, colCount, offset+rowStart*rowStride+colStart*colStride, rowStride, colStride); } @Override public void applyOp(Op op) { int rc = rowCount(); int cc = columnCount(); int o=offset; for (int row = 0; row < rc; row++) { int ro=o+row*rowStride(); for (int col = 0; col < cc; col++) { int index = ro+col*colStride; double v = data[index]; data[index] = op.apply(v); } } } @Override public void getElements(double[] dest, int destOffset) { int rc = rowCount(); int cc = columnCount(); for (int row = 0; row < rc; row++) { copyRowTo(row, dest, destOffset+row*cc); } } @Override public AMatrix getTranspose() { return Matrixx.wrapStrided(data, cols, rows, offset, colStride, rowStride); } @Override public AMatrix getTransposeView() { return Matrixx.wrapStrided(data, cols, rows, offset, colStride, rowStride); } @Override public double get(int i, int j) { checkIndex(i,j); return data[index(i,j)]; } @Override public double unsafeGet(int i, int j) { return data[index(i,j)]; } @Override public AVector asVector() { if (isPackedArray()) { return Vector.wrap(data); } else if (cols==1) { return Vectorz.wrapStrided(data, offset, rows, rowStride); } else if (rows ==1){ return Vectorz.wrapStrided(data, offset, cols, colStride); } return super.asVector(); } @Override public void set(int i, int j, double value) { checkIndex(i,j); data[index(i,j)] = value; } @Override public void unsafeSet(int i, int j, double value) { data[index(i,j)] = value; } @Override public AMatrix exactClone() { return new StridedMatrix(data.clone(), rows, cols, offset, rowStride, colStride); } public static StridedMatrix create(AMatrix m) { StridedMatrix sm = StridedMatrix.create(m.rowCount(), m.columnCount()); sm.set(m); return sm; } public static StridedMatrix wrap(Matrix m) { return new StridedMatrix(m.data, m.rowCount(), m.columnCount(), 0, m.columnCount(), 1); } public static StridedMatrix wrap(double[] data, int rows, int columns, int offset, int rowStride, int columnStride) { return new StridedMatrix(data, rows, columns, offset, rowStride, columnStride); } @Override public void validate() { super.validate(); if (!equals(this.exactClone())) throw new VectorzException("Thing not equal to itself"); if (offset<0) throw new VectorzException("Negative offset! ["+offset+"]"); if (index(rows-1,cols-1)>=data.length) throw new VectorzException("Negative offset! ["+offset+"]"); } @Override protected final int index(int row, int col) { return offset+(row*rowStride)+(col*colStride); } @Override public Matrix clone() { return Matrix.create(this); } @Override public boolean equals(AMatrix a) { if (a==this) return true; if (a instanceof ADenseArrayMatrix) return equals((ADenseArrayMatrix)a); if (!isSameShape(a)) return false; for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { if (data[index(i, j)] != a.unsafeGet(i, j)) return false; } } return true; } @Override public boolean equalsArray(double[] data, int offset) { for (int i = 0; i < rows; i++) { int si=this.offset+i*rowStride; for (int j = 0; j < cols; j++) { if (this.data[si] != data[offset++]) return false; si+=colStride; } } return true; } }