package mikera.arrayz.impl; import mikera.arrayz.INDArray; import mikera.vectorz.AVector; import mikera.vectorz.util.ErrorMessages; /** * Array created by joining two arrays along a specific dimension * * @author Mike * */ public class JoinedArray extends BaseShapedArray { private static final long serialVersionUID = 4929988077055768422L; final INDArray left; final INDArray right; final int dimension; final int split; private JoinedArray(INDArray left, INDArray right, int dim) { super(left.getShapeClone()); this.left=left; this.right=right; dimension=dim; this.split=shape[dimension]; shape[dimension]+=right.getShape(dimension); } public static JoinedArray join(INDArray a, INDArray b, int dim) { int n=a.dimensionality(); if (b.dimensionality()!=n) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(a, b)); for (int i=0; i<n; i++) { if ((i!=dim)&&(a.getShape(i)!=b.getShape(i))) { throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(a,b)); } } return new JoinedArray(a,b,dim); } @Override public double get(int... indexes) { if (indexes.length!=dimensionality()) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, indexes)); int di=indexes[dimension]; if (di<split) { return left.get(indexes); } else { indexes=indexes.clone(); indexes[dimension]-=split; return right.get(indexes); } } @Override public void set(int[] indexes, double value) { if (indexes.length!=dimensionality()) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, indexes)); int di=indexes[dimension]; if (di<split) { left.set(indexes,value); } else { indexes=indexes.clone(); indexes[dimension]-=split; right.set(indexes,value); } } @Override public INDArray slice(int majorSlice) { if (dimension==0) { return (majorSlice<split)?left.slice(majorSlice):right.slice(majorSlice-split); } else { return new JoinedArray(left.slice(majorSlice),right.slice(majorSlice),dimension-1); } } @Override public int componentCount() { return 2; } @Override public INDArray getComponent(int k) { switch (k) { case 0: return left; case 1: return right; } throw new IndexOutOfBoundsException(ErrorMessages.invalidComponent(this,k)); } @Override public INDArray slice(int dimension, int index) { if (this.dimension==dimension) { return (index<split)?left.slice(dimension,index):right.slice(dimension,index-split); } else if (dimension==0) { return slice(index); } else { int nd= (dimension<this.dimension)?dimension:dimension-1; return left.slice(dimension,index).join(right.slice(dimension,index),nd); } } @Override public boolean isView() { return true; } @Override public INDArray exactClone() { return new JoinedArray(left.exactClone(),right.exactClone(),dimension); } @Override public void validate() { if (left.getShape(dimension)+right.getShape(dimension)!=shape[dimension]) throw new Error("Inconsistent shape along split dimension"); super.validate(); } @Override public double get() { throw new IllegalArgumentException("0d get not supported on "+getClass()); } @Override public double get(int x) { if ((x<0)||(x>=sliceCount())) { throw new IndexOutOfBoundsException(ErrorMessages.invalidIndex(this, x)); } if (x<split) { return left.get(x); } else { return right.get(x-split); } } @Override public double get(int x, int y) { if (dimension==0) { if ((x<0)||(x>=sliceCount())) { throw new IndexOutOfBoundsException(ErrorMessages.invalidIndex(this, x,y)); } if (x<split) { return left.get(x,y); } else { return right.get(x-split,y); } } else { if ((y<0)||(y>=sliceCount())) { throw new IndexOutOfBoundsException(ErrorMessages.invalidIndex(this, x,y)); } if (y<split) { return left.get(x,y); } else { return right.get(x,y-split); } } } @Override public boolean equalsArray(double[] data, int offset) { return left.equalsArray(data, offset)&&right.equalsArray(data,(int) (offset+left.elementCount())); } }