package mikera.arrayz.impl; import java.nio.DoubleBuffer; import java.util.Arrays; import mikera.arrayz.Arrayz; import mikera.arrayz.INDArray; import mikera.vectorz.AVector; import mikera.vectorz.Tools; import mikera.vectorz.impl.ImmutableScalar; import mikera.vectorz.impl.ImmutableVector; import mikera.vectorz.util.DoubleArrays; import mikera.vectorz.util.ErrorMessages; import mikera.vectorz.util.IntArrays; /** * Immutable N-dimensional array class * * Uses strided array storage * * @author Mike * */ public class ImmutableArray extends BaseNDArray implements IDense { private static final long serialVersionUID = 2078025371733533775L; private ImmutableArray(int dims, int[] shape, int[] strides) { this(new double[(int)IntArrays.arrayProduct(shape)],shape.length,0,shape,strides); } private ImmutableArray(double[] data, int dimensions, int offset, int[] shape, int[] stride) { super(data,dimensions,offset,shape,stride); } private ImmutableArray(int[] shape, double[] data) { this(shape.length, shape, IntArrays.calcStrides(shape), data); } private ImmutableArray(int dims, int[] shape, double[] data) { this(dims, shape, IntArrays.calcStrides(shape), data); } public static INDArray wrap(double[] data, int[] shape) { long ec=IntArrays.arrayProduct(shape); if (data.length!=ec) throw new IllegalArgumentException("Data array does not have correct number of elements, expected: "+ec); return new ImmutableArray(shape.length,shape,data); } private ImmutableArray(int dims, int[] shape, int[] strides, double[] data) { this(data,dims,0,shape,strides); } @Override public int dimensionality() { return dimensions; } @Override public boolean isMutable() { return false; } @Override public boolean isFullyMutable() { return false; } @Override public boolean isElementConstrained() { return true; } @Override public int[] getShape() { return shape; } @Override public int[] getShapeClone() { return shape.clone(); } @Override public long[] getLongShape() { long[] lshape = new long[dimensions]; IntArrays.copyIntsToLongs(shape, lshape); return lshape; } @Override public int getIndex(int... indexes) { int ix = offset; for (int i = 0; i < dimensions; i++) { ix += indexes[i] * getStride(i); } return ix; } @Override public double get(int... indexes) { return data[getIndex(indexes)]; } @Override public void set(int[] indexes, double value) { throw new UnsupportedOperationException(ErrorMessages.immutable(this)); } @Override public INDArray slice(int majorSlice) { if (dimensions==0) { throw new IllegalArgumentException("Can't slice a 0-d NDArray"); } else if (dimensions==1) { return ImmutableScalar.create(get(majorSlice)); } else if ((dimensions==2)&&(stride[1]==1)) { return ImmutableVector.wrap(data, offset+majorSlice*getStride(0), shape[1]); } else { return new ImmutableArray(data, dimensions-1, offset+majorSlice*getStride(0), Arrays.copyOfRange(shape, 1,dimensions), Arrays.copyOfRange(stride, 1,dimensions)); } } @Override public INDArray slice(int dimension, int index) { if ((dimension<0)||(dimension>=dimensions)) throw new IllegalArgumentException(ErrorMessages.invalidDimension(this, dimension)); if (dimension==0) return slice(index); return new ImmutableArray(data, dimensions-1, offset+index*stride[dimension], IntArrays.removeIndex(shape,index), IntArrays.removeIndex(stride,index)); } @Override public int sliceCount() { return shape[0]; } @Override public ImmutableArray subArray(int[] offsets, int[] shape) { int n=dimensions; if (offsets.length!=n) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, offsets)); if (shape.length!=n) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, offsets)); if (IntArrays.equals(shape, this.shape)) { if (IntArrays.isZero(offsets)) { return this; } else { throw new IllegalArgumentException("Invalid subArray offsets"); } } return new ImmutableArray(data, n, offset+IntArrays.dotProduct(offsets, stride), IntArrays.copyOf(shape), stride); } @Override public long elementCount() { return IntArrays.arrayProduct(shape); } @Override public boolean isView() { return false; } @Override public AVector asVector() { if (dimensions>0) return super.asVector(); return ImmutableVector.wrap(new double[] {data[offset]}); } @Override public INDArray exactClone() { return new ImmutableArray(data.clone(),dimensions,offset,shape.clone(),stride.clone()); } @Override public INDArray sparseClone() { return Arrayz.createSparse(this); } @Override public void toDoubleBuffer(DoubleBuffer dest) { if (dimensions>0) { super.toDoubleBuffer(dest); } else { dest.put(data[offset]); } } @Override public boolean equalsArray(double[] data, int offset) { return DoubleArrays.equals(this.data, this.offset, data, offset, Tools.toInt(elementCount())); } public static INDArray create(INDArray a) { int[] shape=a.getShape(); int n=(int)IntArrays.arrayProduct(shape); double[] newData = new double[n]; a.getElements(newData, 0); return ImmutableArray.wrap(newData, shape); } @Override public double[] getArray() { throw new UnsupportedOperationException("Array access not supported by ImmutableArray"); } }