package mikera.matrixx.impl; import mikera.arrayz.ISparse; import mikera.indexz.Index; import mikera.vectorz.AVector; import mikera.vectorz.Vector; import mikera.vectorz.impl.AxisVector; import mikera.vectorz.util.ErrorMessages; import mikera.vectorz.util.VectorzException; /** * Matrix that transforms to a selection of specific elements of the source vector * i.e. has exactly one 1.0 in each row * * @author Mike */ public final class SubsetMatrix extends ABooleanMatrix implements ISparse, IFastRows { private static final long serialVersionUID = 4937375232646236833L; private int inputDims; private Index components; private SubsetMatrix(int inputDimensions, Index components) { inputDims=inputDimensions; this.components=components; } public static SubsetMatrix create(Index components, int inputDimensions) { SubsetMatrix sm=new SubsetMatrix(inputDimensions,components); if (!sm.components.allInRange(0,sm.inputDims)) { throw new IllegalArgumentException("SubsetMatrix with input dimensionality "+sm.inputDims+" not valid for component indexes: "+sm.components); } return sm; } @Override public void transform(AVector source, AVector dest) { dest.set(source, components); } @Override public double elementSum() { return rowCount(); } @Override public long nonZeroCount() { return rowCount(); } @Override public int rowCount() { return components.length(); } @Override public int columnCount() { return inputDims; } @Override public double density() { return 1.0/inputDims; } @Override public AxisVector getRowView(int i) { return AxisVector.create(components.get(i), inputDims); } @Override public double calculateElement(int i, AVector inputVector) { return inputVector.unsafeGet(components.get(i)); } @Override public double calculateElement(int i, Vector inputVector) { return inputVector.unsafeGet(components.get(i)); } @Override public double get(int row, int column) { if (column<0||column>=inputDims) { throw new IndexOutOfBoundsException(ErrorMessages.invalidIndex(this, row,column)); } return (column==components.get(row))?1.0:0.0; } @Override public double unsafeGet(int row, int column) { return (column==components.get(row))?1.0:0.0; } @Override public void set(int row, int column, double value) { throw new UnsupportedOperationException(ErrorMessages.notFullyMutable(this, row, column)); } @Override public SubsetMatrix exactClone() { return SubsetMatrix.create(components.clone(),inputDims); } @Override public void validate() { int rc=rowCount(); int cc=columnCount(); for (int i=0; i<rc; i++) { int s=components.get(i); if ((s<0)||(s>=cc)) throw new VectorzException("Component out of range at row "+i); } super.validate(); } }