package mikera.arrayz; import java.nio.DoubleBuffer; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import mikera.arrayz.impl.BaseShapedArray; import mikera.arrayz.impl.IDenseArray; import mikera.arrayz.impl.IStridedArray; import mikera.arrayz.impl.ImmutableArray; import mikera.indexz.Index; import mikera.matrixx.Matrix; import mikera.vectorz.AVector; import mikera.vectorz.IOperator; import mikera.vectorz.Op; import mikera.vectorz.Scalar; import mikera.vectorz.Tools; import mikera.vectorz.Vector; import mikera.vectorz.Vectorz; import mikera.vectorz.impl.ArrayIndexScalar; import mikera.vectorz.impl.StridedElementIterator; import mikera.vectorz.util.DoubleArrays; import mikera.vectorz.util.ErrorMessages; import mikera.vectorz.util.IntArrays; import mikera.vectorz.util.VectorzException; /** * General purpose mutable dense N-dimensional array * * This is the general multi-dimensional equivalent of Matrix and Vector, and as such is the * most efficient storage type for dense 3D+ arrays * * @author Mike * */ public final class Array extends BaseShapedArray implements IStridedArray, IDenseArray { private static final long serialVersionUID = -8636720562647069034L; private final int dimensions; private final int[] strides; private final double[] data; private Array(int dims, int[] shape, int[] strides) { super(shape); this.dimensions = dims; this.strides = strides; int n = (int) IntArrays.arrayProduct(shape); this.data = new double[n]; } private Array(int[] shape, double[] data) { this(shape.length, shape, IntArrays.calcStrides(shape), data); } private Array(int dims, int[] shape, double[] data) { this(dims, shape, IntArrays.calcStrides(shape), data); } public static INDArray wrap(double[] data, int... shape) { long ec=IntArrays.arrayProduct(shape); if (data.length!=ec) throw new IllegalArgumentException("Data array does not have correct number of elements, expected: "+ec); return new Array(shape.length,shape,data); } private Array(int dims, int[] shape, int[] strides, double[] data) { super(shape); this.dimensions = dims; this.strides = strides; this.data = data; } public static Array wrap(Vector v) { return new Array(v.getShape(),v.getArray()); } public static Array wrap(Matrix m) { return new Array(m.getShape(),m.getArray()); } public static Array newArray(int... shape) { return new Array(shape.length, shape, createStorage(shape)); } public static Array create(INDArray a) { int[] shape=a.getShape(); return new Array(a.dimensionality(), shape, a.toDoubleArray()); } public static double[] createStorage(int... shape) { long ec=1; for (int i=0; i<shape.length; i++) { int si=shape[i]; if ((ec*si)!=(((int)ec)*si)) throw new IllegalArgumentException(ErrorMessages.tooManyElements(shape)); ec*=shape[i]; } int n=(int)ec; if (ec!=n) throw new IllegalArgumentException(ErrorMessages.tooManyElements(shape)); return new double[n]; } @Override public int dimensionality() { return dimensions; } @Override public long[] getLongShape() { long[] lshape = new long[dimensions]; IntArrays.copyIntsToLongs(shape, lshape); return lshape; } public int getStride(int dim) { return strides[dim]; } public int getIndex(int... indexes) { int ix = 0; for (int i = 0; i < dimensions; i++) { ix += indexes[i] * getStride(i); } return ix; } @Override public double get(int... indexes) { return data[getIndex(indexes)]; } @Override public void set(int[] indexes, double value) { data[getIndex(indexes)] = value; } @Override public Vector asVector() { return Vector.wrap(data); } @Override public Vector toVector() { return Vector.create(data); } @Override public INDArray slice(int majorSlice) { return slice(0, majorSlice); } @Override public INDArray slice(int dimension, int index) { if ((dimension < 0) || (dimension >= dimensions)) throw new IndexOutOfBoundsException(ErrorMessages.invalidDimension(this,dimension)); if (dimensions == 1) return ArrayIndexScalar.wrap(data, index); if (dimensions == 2) { if (dimension == 0) { return Vectorz.wrap(data, index * shape[1], shape[1]); } else { return Vectorz.wrapStrided(data, index, shape[0], strides[0]); } } int offset = index * getStride(dimension); return new NDArray( data, offset, IntArrays.removeIndex(shape, dimension), IntArrays.removeIndex(strides, dimension)); } @Override public INDArray getTranspose() { return getTransposeView(); } @Override public INDArray getTransposeView() { return NDArray.wrapStrided(data, 0, IntArrays.reverse(shape), IntArrays.reverse(strides)); } @Override public INDArray subArray(int[] offsets, int[] shape) { int n=dimensions; if (offsets.length!=n) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, offsets)); if (shape.length!=n) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, offsets)); if (IntArrays.equals(shape, this.shape)) { if (IntArrays.isZero(offsets)) { return this; } else { throw new IllegalArgumentException("Invalid subArray offsets"); } } int[] strides=IntArrays.calcStrides(this.shape); return new NDArray(data, IntArrays.dotProduct(offsets, strides), IntArrays.copyOf(shape), strides); } @Override public long elementCount() { return data.length; } @Override public double elementSum() { return DoubleArrays.elementSum(data); } @Override public double elementMax(){ return DoubleArrays.elementMax(data); } @Override public double elementMin(){ return DoubleArrays.elementMin(data); } @Override public double elementSquaredSum() { return DoubleArrays.elementSquaredSum(data); } @Override public void abs() { DoubleArrays.abs(data); } @Override public void signum() { DoubleArrays.signum(data); } @Override public void square() { DoubleArrays.square(data); } @Override public void exp() { DoubleArrays.exp(data); } @Override public void log() { DoubleArrays.log(data); } @Override public boolean isMutable() { return true; } @Override public boolean isFullyMutable() { return true; } @Override public boolean isElementConstrained() { return false; } @Override public boolean isView() { return false; } @Override public void applyOp(Op op) { op.applyTo(data); } @Override public void applyOp(IOperator op) { if (op instanceof Op) { ((Op) op).applyTo(data); } else { for (int i = 0; i < data.length; i++) { data[i] = op.apply(data[i]); } } } @Override public boolean equals(INDArray a) { if (a instanceof Array) return equals((Array) a); if (!isSameShape(a)) return false; return a.equalsArray(data, 0); } public boolean equals(Array a) { if (a.dimensions != dimensions) return false; if (!IntArrays.equals(shape, a.shape)) return false; return DoubleArrays.equals(data, a.data); } @Override public Array exactClone() { return new Array(dimensions, shape, strides, data.clone()); } @Override public void setElements(int pos, double[] values, int offset, int length) { System.arraycopy(values, offset, data, pos, length); } @Override public void getElements(double[] values, int offset) { System.arraycopy(data, 0, values, offset, data.length); } @Override public Iterator<Double> elementIterator() { return new StridedElementIterator(data,0,(int)elementCount(),1); } @Override public void multiply(double factor) { DoubleArrays.multiply(data, 0, data.length, factor); } @Override public List<?> getSlices() { if (dimensions==1) { int n=sliceCount(); ArrayList<Double> al=new ArrayList<Double>(n); for (int i=0; i<n; i++) { al.add(get(i)); } return al; } else { return super.getSliceViews(); } } @Override public void toDoubleBuffer(DoubleBuffer dest) { dest.put(data); } @Override public double[] toDoubleArray() { return DoubleArrays.copyOf(data); } @Override public double[] asDoubleArray() { return data; } @Override public INDArray clone() { // always return the efficient type for each dimensionality switch (dimensions) { case 0: return Scalar.create(data[0]); case 1: return Vector.create(data); case 2: return Matrix.wrap(shape[0], shape[1], DoubleArrays.copyOf(data)); default: return Array.wrap(DoubleArrays.copyOf(data),shape); } } @Override public void validate() { super.validate(); if (dimensions != shape.length) throw new VectorzException("Inconsistent dimensionality"); if ((dimensions > 0) && (strides[dimensions - 1] != 1)) throw new VectorzException("Last stride should be 1"); if (data.length != IntArrays.arrayProduct(shape)) throw new VectorzException("Inconsistent shape"); if (!IntArrays.equals(strides, IntArrays.calcStrides(shape))) throw new VectorzException("Inconsistent strides"); } /** * Creates a new matrix using the elements in the specified vector. * Truncates or zero-pads the data as required to fill the new matrix * @param data * @param rows * @param columns * @return */ public static Array createFromVector(AVector a, int... shape) { Array m = Array.newArray(shape); int n=(int)Math.min(m.elementCount(), a.length()); a.copyTo(0, m.data, 0, n); return m; } @Override public double[] getArray() { return data; } @Override public int getArrayOffset() { return 0; } @Override public int[] getStrides() { return strides; } @Override public boolean isPackedArray() { return true; } @Override public boolean isZero() { return DoubleArrays.isZero(data); } @Override public INDArray immutable() { return ImmutableArray.wrap(DoubleArrays.copyOf(data), this.shape); } @Override public double get() { if (dimensions==0) { return data[0]; } else { throw new IllegalArgumentException("O-d get not supported on Array of shape: "+Index.of(this.getShape()).toString()); } } @Override public double get(int x) { if (dimensions==1) { return data[x]; } else { throw new IllegalArgumentException("1-d get not supported on Array of shape: "+Index.of(this.getShape()).toString()); } } @Override public double get(int x, int y) { if (dimensions==2) { return data[x*strides[0]+y]; } else { throw new IllegalArgumentException("2-d get not supported on Array of shape: "+Index.of(this.getShape()).toString()); } } @Override public boolean equalsArray(double[] data, int offset) { return DoubleArrays.equals(this.data, 0, data, offset, Tools.toInt(elementCount())); } }