package mikera.matrixx.impl; import mikera.arrayz.INDArray; import mikera.arrayz.impl.IDenseArray; import mikera.matrixx.AMatrix; import mikera.vectorz.AVector; import mikera.vectorz.Tools; import mikera.vectorz.Vector; import mikera.vectorz.Vectorz; import mikera.vectorz.impl.ADenseArrayVector; import mikera.vectorz.impl.AStridedVector; import mikera.vectorz.util.DoubleArrays; import mikera.vectorz.util.ErrorMessages; /** * Abstract base class for matrices wrapping a dense (rows*cols) subset of a double[] array * @author Mike * */ public abstract class ADenseArrayMatrix extends AStridedMatrix implements IFastRows, IDenseArray { private static final long serialVersionUID = -2144964424833585026L; protected ADenseArrayMatrix(double[] data, int rows, int cols) { super(data, rows, cols); } @Override public abstract int getArrayOffset(); @Override public boolean isPackedArray() { return (getArrayOffset()==0) && (data.length ==(rows*cols)); } @Override public boolean isZero() { return DoubleArrays.isZero(data, getArrayOffset(), rows*cols); } @Override public boolean isUpperTriangular() { // triangular test, taking into account cache layout to access via rows int rc=rowCount(); int cc=columnCount(); int offset=getArrayOffset(); for (int i=1; i<rc; i++) { if (!DoubleArrays.isZero(data, offset+i*cc, Math.min(cc, i))) return false; } return true; } @Override public boolean isLowerTriangular() { // triangular test, taking into account cache layout to access via rows int offset=getArrayOffset(); int cc=columnCount(); int testRows=Math.min(cc, rowCount()); for (int i=0; i<testRows; i++) { if (!DoubleArrays.isZero(data, offset+i+1, cc-i-1)) return false; offset+=cc; } return true; } @Override public int rowStride() { return cols; } @Override public int columnStride() { return 1; } @Override public double unsafeGet(int i, int j) { return data[index(i,j)]; } @Override public void set(AVector v) { int rc=rowCount(); if (v.length()!=cols) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, v)); double[] data=getArray(); int offset=getArrayOffset(); for (int i=0; i<rc; i++) { v.getElements(data, offset+i*cols); } } @Override public void set(AMatrix m) { if (!isSameShape(m)) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, m)); double[] data=getArray(); int offset=getArrayOffset(); m.getElements(data, offset); } @Override public void setElements(double[] values, int offset) { double[] data=getArray(); int di=getArrayOffset(); System.arraycopy(values, offset, data, di, Tools.toInt(elementCount())); } @Override public void setElements(int pos, double[] values, int offset, int length) { double[] data=getArray(); int di=getArrayOffset()+pos; System.arraycopy(values, offset, data, di, length); } @Override public ADenseArrayVector getRowView(int i) { return Vectorz.wrap(data, getArrayOffset()+i*cols, cols); } @Override public AStridedVector getColumnView(int i) { return Vectorz.wrapStrided(data, getArrayOffset()+i, rows, cols); } @Override public double elementSum() { return DoubleArrays.elementSum(data,getArrayOffset(), rows*cols); } @Override public double elementSquaredSum() { return DoubleArrays.elementSquaredSum(data,getArrayOffset(), rows*cols); } @Override public double elementMax(){ return DoubleArrays.elementMax(data,getArrayOffset(), rows*cols); } @Override public double elementMin(){ return DoubleArrays.elementMin(data,getArrayOffset(), rows*cols); } @Override public void copyRowTo(int row, double[] dest, int destOffset) { System.arraycopy(data, getArrayOffset()+row*cols, dest, destOffset, cols); } @Override public void unsafeSet(int i, int j,double value) { data[index(i,j)]=value; } protected int index(int row, int col) { return getArrayOffset()+(row*cols)+col; } @Override public void transform(AVector source, AVector dest) { if ((source instanceof Vector )&&(dest instanceof Vector)) { transform ((Vector)source, (Vector)dest); return; } if(rows!=dest.length()) throw new IllegalArgumentException(ErrorMessages.wrongDestLength(dest)); if(cols!=source.length()) throw new IllegalArgumentException(ErrorMessages.wrongSourceLength(source)); double[] data=getArray(); int offset=getArrayOffset(); for (int i=0; i<rows; i++) { dest.unsafeSet(i,source.dotProduct(data, offset+ i*cols)); } } @Override public void add(AVector v) { int rc=rowCount(); int cc=columnCount(); if(cc!=v.length()) throw new IllegalArgumentException(ErrorMessages.mismatch(this, v)); double[] data=getArray(); int offset=getArrayOffset(); for (int i=0; i<rc; i++) { v.addToArray(data, offset+i*cc); } } @Override public void add(AMatrix a) { if (!isSameShape(a)) { throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a)); } a.addToArray(getArray(), getArrayOffset()); } public void add(ADenseArrayMatrix a, ADenseArrayMatrix b) { checkSameShape(a); checkSameShape(b); DoubleArrays.add2(getArray(), getArrayOffset(), a.getArray(), a.getArrayOffset(), b.getArray(), b.getArrayOffset(), Tools.toInt(this.elementCount())); } @Override public void addToArray(double[] data, int offset) { DoubleArrays.add(getArray(), getArrayOffset(), data, offset, rows*cols); } @Override public boolean equals(AMatrix a) { if (!isSameShape(a)) return false; return a.equalsArray(getArray(), getArrayOffset()); } @Override public boolean equals(INDArray a) { if (!isSameShape(a)) return false; return a.equalsArray(getArray(), getArrayOffset()); } @Override public boolean equalsArray(double[] data, int offset) { return DoubleArrays.equals(getArray(), getArrayOffset(), data, offset, rows*cols); } @Override public boolean equals(ADenseArrayMatrix m) { if (!isSameShape(m)) return false; return DoubleArrays.equals(getArray(), getArrayOffset(), m.getArray(), m.getArrayOffset(), rows*cols); } }