package mikera.vectorz.impl;
import java.util.Iterator;
import mikera.arrayz.Arrayz;
import mikera.arrayz.INDArray;
import mikera.arrayz.impl.IStridedArray;
import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrixx;
import mikera.matrixx.impl.AStridedMatrix;
import mikera.matrixx.impl.StridedMatrix;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import mikera.vectorz.util.DoubleArrays;
import mikera.vectorz.util.ErrorMessages;
/**
* Abstract base class for vectors backed by a double[] array with a constant stride
*
* The double array can be directly accessed for performance purposes
*
* @author Mike
*/
public abstract class AStridedVector extends AArrayVector implements IStridedArray {
protected AStridedVector(int length, double[] data) {
super(length,data);
}
private static final long serialVersionUID = -7239429584755803950L;
public double[] getArray() {
return data;
}
public abstract int getArrayOffset();
public abstract int getStride();
@Override
public AStridedVector ensureMutable() {
return clone();
}
@Override
public boolean isRangeZero(int start, int length) {
int stride=getStride();
int offset=getArrayOffset()+start*stride;
return DoubleArrays.isZero(data, offset, length,stride);
}
@Override public double dotProduct(double[] data, int offset) {
double[] array=getArray();
int thisOffset=getArrayOffset();
int stride=getStride();
int length=length();
double result=0.0;
for (int i=0; i<length; i++) {
result+=array[i*stride+thisOffset]*data[i+offset];
}
return result;
}
@Override
public double elementSum() {
int len=length();
double[] array=getArray();
int offset=getArrayOffset();
int stride=getStride();
double result=0.0;
for (int i=0; i<len; i++) {
result+=array[offset+i*stride];
}
return result;
}
@Override
public double elementProduct() {
int len=length();
double[] array=getArray();
int offset=getArrayOffset();
int stride=getStride();
double result=1.0;
for (int i=0; i<len; i++) {
result*=array[offset+i*stride];
}
return result;
}
@Override
public double elementMax(){
int len=length();
double[] array=getArray();
int offset=getArrayOffset();
int stride=getStride();
double max = -Double.MAX_VALUE;
for (int i=0; i<len; i++) {
double d=array[offset+i*stride];
if (d>max) max=d;
}
return max;
}
@Override
public double elementMin(){
int len=length();
double[] array=getArray();
int offset=getArrayOffset();
int stride=getStride();
double min = Double.MAX_VALUE;
for (int i=0; i<len; i++) {
double d=array[offset+i*stride];
if (d<min) min=d;
}
return min;
}
@Override
public INDArray broadcast(int... shape) {
int dims=shape.length;
if (dims==0) {
throw new IllegalArgumentException(ErrorMessages.incompatibleBroadcast(this, shape));
} else if (dims==1) {
if (shape[0]!=length()) throw new IllegalArgumentException(ErrorMessages.incompatibleBroadcast(this, shape));
return this;
} else if (dims==2) {
int rc=shape[0];
int cc=shape[1];
if (cc!=length()) throw new IllegalArgumentException(ErrorMessages.incompatibleBroadcast(this, shape));
return Matrixx.wrapStrided(getArray(), rc, cc, getArrayOffset(), 0, getStride());
}
if (shape[dims-1]!=length()) throw new IllegalArgumentException(ErrorMessages.incompatibleBroadcast(this, shape));
int[] newStrides=new int[dims];
newStrides[dims-1]=getStride();
return Arrayz.wrapStrided(getArray(),getArrayOffset(),shape,newStrides);
}
@Override
public AMatrix broadcastLike(AMatrix target) {
if (length()==target.columnCount()) {
return StridedMatrix.wrap(getArray(), target.rowCount(), length(), getArrayOffset(), 0, getStride());
} else {
throw new IllegalArgumentException(ErrorMessages.incompatibleBroadcast(this, target));
}
}
@Override
public AVector selectView(int... inds) {
int n=inds.length;
int[] ix=new int[n];
int off=getArrayOffset();
int stride=getStride();
for (int i=0; i<n; i++) {
ix[i]=off+stride*inds[i];
}
return IndexedArrayVector.wrap(getArray(), ix);
}
@Override
public AStridedVector clone() {
return Vector.create(this);
}
@Override
public void set(AVector v) {
int length=checkSameLength(v);
int stride=getStride();
v.copyTo(0, getArray(), getArrayOffset(), length, stride);
}
@Override
public void setElements(double[] values, int offset) {
double[] data=getArray();
int stride=getStride();
int off=getArrayOffset();
for (int i=0; i<length; i++) {
data[off+i*stride]=values[offset+i];
}
}
@Override
public void setElements(int pos, double[] values, int offset, int length) {
double[] data=getArray();
int stride=getStride();
int off=getArrayOffset()+pos*stride;
for (int i=0; i<length; i++) {
data[off+i*stride]=values[offset+i];
}
}
@Override
public void add(Vector v) {
checkSameLength(v);
v.addToArray(getArray(), getArrayOffset(), getStride());
}
@Override
public void add(double[] data, int offset) {
int stride=getStride();
double[] tdata=getArray();
int toffset=getArrayOffset();
int length=length();
for (int i = 0; i < length; i++) {
tdata[toffset+i*stride]+=data[offset+i];
}
}
@Override
public void add(int offset, AVector a) {
int stride=getStride();
a.addToArray(getArray(), getArrayOffset()+offset*stride,stride);
}
@Override
public void add(int offset, AVector a, int aOffset, int length) {
double[] tdata=getArray();
int stride=getStride();
int toffset=getArrayOffset()+offset*stride;
a.subVector(aOffset, length).addToArray(tdata, toffset, stride);
}
@Override
public void addAt(int i, double v) {
int ix=checkIndex(i);
double[] data=getArray();
data[ix]+=v;
}
@Override
public void addToArray(int offset, double[] destData, int destOffset,int length) {
int stride=getStride();
double[] tdata=getArray();
int toffset=getArrayOffset()+offset*stride;
for (int i = 0; i < length; i++) {
destData[destOffset+i]+=tdata[toffset+i*stride];
}
}
@Override
public void addToArray(double[] dest, int destOffset, int destStride) {
int stride=getStride();
double[] tdata=getArray();
int toffset=getArrayOffset();
for (int i = 0; i < length; i++) {
dest[destOffset+i*destStride]+=tdata[toffset+i*stride];
}
}
@Override
public void copyTo(int offset, double[] dest, int destOffset, int length, int stride) {
int thisStride=getStride();
int thisOffset=this.getArrayOffset();
for (int i=offset; i<length; i++) {
dest[destOffset+i*stride]=data[thisOffset+i*thisStride];
}
}
@Override
public void clamp(double min, double max) {
int len=length();
int stride=getStride();
double[] data=getArray();
int offset=getArrayOffset();
for (int i = 0; i < len; i++) {
int ix=offset+i*stride;
double v=data[ix];
if (v<min) {
data[ix]=min;
} else if (v>max) {
data[ix]=max;
}
}
}
@Override
public double[] asDoubleArray() {
if (isPackedArray()) return getArray();
return null;
}
@Override
public AStridedMatrix asColumnMatrix() {
return Matrixx.wrapStrided(data, length, 1, getArrayOffset(), getStride(), 0);
}
@Override
public boolean isPackedArray() {
return (getStride()==1)&&(getArrayOffset()==0)&&(getArray().length==length());
}
@Override
public int[] getStrides() {
return new int[] {getStride()};
}
@Override
public Iterator<Double> iterator() {
return new StridedElementIterator(getArray(),getArrayOffset(),length(),getStride());
}
@Override
public int getStride(int dimension) {
switch (dimension) {
case 0: return getStride();
default: throw new IllegalArgumentException(ErrorMessages.invalidDimension(this, dimension));
}
}
@Override
public void fill(double value) {
int stride=getStride();
double[] array=getArray();
int di=getArrayOffset();
for (int i=0; i<length; i++) {
array[di]=value;
di+=stride;
}
}
@Override
public boolean equalsArray(double[] data, int offset) {
int stride=getStride();
double[] array=getArray();
int di=getArrayOffset();
for (int i=0; i<length; i++) {
if (data[offset+i]!=array[di]) return false;
di+=stride;
}
return true;
}
}