package mikera.matrixx.impl;
import java.util.Iterator;
import mikera.arrayz.impl.IStridedArray;
import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrixx;
import mikera.vectorz.AVector;
import mikera.vectorz.Op;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.AStridedVector;
import mikera.vectorz.util.ErrorMessages;
/**
 * Abstract base class for arbitrary strided matrices.
 *
 * <p>Elements live in the {@code double[]} backing array inherited from {@link AArrayMatrix}.
 * Element {@code (i, j)} is stored at index
 * {@code getArrayOffset() + i * rowStride() + j * columnStride()}. Views produced by
 * {@link #subMatrix}, {@link #getRowView}, {@link #getColumnView}, {@link #getBand} and
 * {@link #getTransposeView} are built over this same backing array rather than a copy.
 *
 * @author Mike
 */
public abstract class AStridedMatrix extends AArrayMatrix implements IStridedArray {
	private static final long serialVersionUID = -8908577438753599161L;

	/**
	 * @param data backing array holding the matrix elements
	 * @param rows number of rows
	 * @param cols number of columns
	 */
	protected AStridedMatrix(double[] data, int rows, int cols) {
		super(data, rows, cols);
	}

	/** Returns the index in the backing array of element {@code (0, 0)}. */
	public abstract int getArrayOffset();

	/** Returns the distance in the backing array between elements of consecutive rows. */
	public abstract int rowStride();

	/** Returns the distance in the backing array between elements of consecutive columns. */
	public abstract int columnStride();

	/**
	 * Returns a strided view of the rectangular region starting at
	 * {@code (rowStart, colStart)} with the given size, sharing this matrix's backing array.
	 *
	 * @throws IndexOutOfBoundsException if the start position lies outside the matrix, or the
	 *         region extends past the last row or column
	 * @throws IllegalArgumentException if {@code rowCount} or {@code colCount} is negative
	 */
	@Override
	public AStridedMatrix subMatrix(int rowStart, int rowCount, int colStart, int colCount) {
		if ((rowStart<0)||(rowStart>=this.rows)||(colStart<0)||(colStart>=this.cols)) {
			throw new IndexOutOfBoundsException(ErrorMessages.position(rowStart,colStart));
		}
		if ((rowCount<0)||(colCount<0)) {
			throw new IllegalArgumentException("Negative submatrix size: "+rowCount+"x"+colCount);
		}
		// Compare against the remaining extent instead of summing: rowStart+rowCount can
		// overflow for huge counts and would then slip past a "> rows" check.
		if ((rowCount>this.rows-rowStart)||(colCount>this.cols-colStart)) {
			throw new IndexOutOfBoundsException(ErrorMessages.position(rowStart+rowCount,colStart+colCount));
		}
		int rowStride=rowStride();
		int colStride=columnStride();
		int offset=getArrayOffset();
		return StridedMatrix.wrap(data, rowCount, colCount, offset+rowStart*rowStride+colStart*colStride, rowStride, colStride);
	}

	/**
	 * Returns a strided vector view of row {@code i} over the shared backing array.
	 *
	 * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, rows)}
	 */
	@Override
	public AStridedVector getRowView(int i) {
		// Without this check an out-of-range row would silently alias unrelated array data
		if ((i<0)||(i>=rows)) throw new IndexOutOfBoundsException("Row index out of range: "+i);
		return Vectorz.wrapStrided(data, getArrayOffset()+i*rowStride(), cols, columnStride());
	}

	/**
	 * Returns the product of the leading-diagonal elements
	 * (the first {@code min(rows, cols)} of them), or {@code 1.0} if there are none.
	 */
	@Override
	public double diagonalProduct() {
		int n=Math.min(rowCount(), columnCount());
		int offset=getArrayOffset();
		// Moving one row down and one column right advances by both strides
		int st=rowStride()+columnStride();
		double[] data=getArray();
		double result=1.0;
		for (int i=0; i<n; i++) {
			result*=data[offset];
			offset+=st;
		}
		return result;
	}

	/**
	 * Returns the sum of the leading-diagonal elements
	 * (the first {@code min(rows, cols)} of them), or {@code 0.0} if there are none.
	 */
	@Override
	public double trace() {
		int n=Math.min(rowCount(), columnCount());
		int offset=getArrayOffset();
		// Moving one row down and one column right advances by both strides
		int st=rowStride()+columnStride();
		double[] data=getArray();
		double result=0.0;
		for (int i=0; i<n; i++) {
			result+=data[offset];
			offset+=st;
		}
		return result;
	}

	/**
	 * Returns a strided vector view of column {@code i} over the shared backing array.
	 *
	 * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, cols)}
	 */
	@Override
	public AStridedVector getColumnView(int i) {
		// Without this check an out-of-range column would silently alias unrelated array data
		if ((i<0)||(i>=cols)) throw new IndexOutOfBoundsException("Column index out of range: "+i);
		return Vectorz.wrapStrided(data, getArrayOffset()+i*columnStride(), rows, rowStride());
	}

	/**
	 * Returns a strided vector view of band {@code i}. The view starts at
	 * {@code (bandStartRow(i), bandStartColumn(i))}, has length {@code bandLength(i)}, and
	 * steps along the diagonal direction ({@code rowStride() + columnStride()}).
	 *
	 * @throws IndexOutOfBoundsException if {@code i > cols} or {@code i < -rows}
	 */
	@Override
	public AStridedVector getBand(int i) {
		int cs=columnStride();
		int rs=rowStride();
		if ((i>cols)||(i<-rows)) throw new IndexOutOfBoundsException(ErrorMessages.invalidBand(this, i));
		return Vectorz.wrapStrided(data, getArrayOffset()+bandStartColumn(i)*cs+bandStartRow(i)*rs, bandLength(i), rs+cs);
	}

	/**
	 * Adds vector {@code v} to every row of this matrix, in place.
	 * The length of {@code v} is validated against the column count by {@code checkColumnCount}.
	 */
	@Override
	public void add(AVector v) {
		checkColumnCount(v.length());
		int offset=getArrayOffset();
		int colStride=columnStride();
		int rowStride=rowStride();
		for (int i=0; i<rows; i++) {
			v.addToArray(data, offset+i*rowStride, colStride);
		}
	}

	/**
	 * Adds this matrix's elements in row-major order to {@code dest}, starting at
	 * {@code destOffset}. {@code rows * cols} consecutive entries of {@code dest} are touched.
	 */
	@Override
	public void addToArray(double[] dest, int destOffset) {
		int offset=getArrayOffset();
		int colStride=columnStride();
		int rowStride=rowStride();
		for (int i=0; i<rows; i++) {
			int ro=offset+i*rowStride;
			for (int j=0; j<cols; j++) {
				dest[destOffset++]+=data[ro+j*colStride];
			}
		}
	}

	/** Replaces every element {@code x} of this matrix with {@code op.apply(x)}, in place. */
	@Override
	public void applyOp(Op op) {
		int offset=getArrayOffset();
		int colStride=columnStride();
		int rowStride=rowStride();
		for (int i=0; i<rows; i++) {
			int ro=offset+i*rowStride;
			for (int j=0; j<cols; j++) {
				int ix=ro+j*colStride;
				data[ix]=op.apply(data[ix]);
			}
		}
	}

	/**
	 * Adds matrix {@code m} to this matrix element-wise, in place.
	 * Shapes are validated by {@code checkSameShape}.
	 */
	@Override
	public void add(AMatrix m) {
		checkSameShape(m);
		int offset=getArrayOffset();
		int colStride=columnStride();
		int rowStride=rowStride();
		for (int i=0; i<rows; i++) {
			m.getRow(i).addToArray(data, offset+i*rowStride, colStride);
		}
	}

	/**
	 * Copies a row into {@code dest} starting at {@code destOffset}.
	 * Re-declared abstract so every strided subclass supplies its own implementation.
	 */
	@Override
	public abstract void copyRowTo(int row, double[] dest, int destOffset);

	/**
	 * Copies a column into {@code dest} starting at {@code destOffset}.
	 * Re-declared abstract so every strided subclass supplies its own implementation.
	 */
	@Override
	public abstract void copyColumnTo(int col, double[] dest, int destOffset);

	/** Returns a new array {@code {rowStride(), columnStride()}}. */
	@Override
	public int[] getStrides() {
		return new int[] {rowStride(), columnStride()};
	}

	/**
	 * Returns the stride for the given dimension: 0 for rows, 1 for columns.
	 *
	 * @throws IllegalArgumentException for any other dimension
	 */
	@Override
	public int getStride(int dimension) {
		switch (dimension) {
			case 0: return rowStride();
			case 1: return columnStride();
			default: throw new IllegalArgumentException(ErrorMessages.invalidDimension(this, dimension));
		}
	}

	/** Returns an iterator over the elements of this matrix, backed by a {@code StridedElementIterator}. */
	@Override
	public Iterator<Double> elementIterator() {
		return new StridedElementIterator(this);
	}

	/**
	 * Returns the transpose as a view (see {@link #getTransposeView()}).
	 * The result is not a copy; it wraps this matrix's backing array.
	 */
	@Override
	public AMatrix getTranspose() {
		return getTransposeView();
	}

	/**
	 * Returns a strided transpose view over the same backing array. Rows and columns, and
	 * their strides, are swapped.
	 */
	@Override
	public AMatrix getTransposeView() {
		return Matrixx.wrapStrided(getArray(),columnCount(),rowCount(),getArrayOffset(),columnStride(),rowStride());
	}

	/**
	 * Returns true when this matrix covers its entire backing array in dense row-major order:
	 * zero offset, unit column stride, row stride equal to the column count, and array length
	 * equal to the element count.
	 */
	@Override
	public boolean isPackedArray() {
		return (getArrayOffset()==0)&&(columnStride()==1)&&(rowStride()==columnCount())&&(getArray().length==elementCount());
	}

	/**
	 * Returns the backing array itself (not a copy) if {@link #isPackedArray()} holds,
	 * otherwise {@code null}.
	 */
	@Override
	public double[] asDoubleArray() {
		if (isPackedArray()) return getArray();
		return null;
	}
}