package mikera.matrixx.impl; import java.util.List; import java.util.ArrayList; // import java.util.HashMap; // import java.util.HashSet; // import java.util.Map; // import java.util.Map.Entry; import mikera.arrayz.ISparse; import mikera.indexz.Index; import mikera.matrixx.AMatrix; import mikera.matrixx.Matrix; import mikera.vectorz.AVector; import mikera.vectorz.Op; import mikera.vectorz.Vector; import mikera.vectorz.Vectorz; import mikera.vectorz.impl.SingleElementVector; import mikera.vectorz.impl.SparseIndexedVector; import mikera.vectorz.util.ErrorMessages; import mikera.vectorz.util.VectorzException; /** * Matrix stored as a collection of normally sparse column vectors * * This format is especially efficient for: * - transposeInnerProduct() with another matrix * - access via getColumn() operation * - transpose into SparseRowMatrix * * @author Mike * */ public class SparseColumnMatrix extends ASparseRCMatrix implements ISparse, IFastColumns { private static final long serialVersionUID = -5994473197711276621L; private static final long SPARSE_ELEMENT_THRESHOLD = 1000L; private final AVector emptyColumn; protected SparseColumnMatrix(int rowCount, int columnCount) { this(new AVector[columnCount],rowCount,columnCount); } protected SparseColumnMatrix(AVector[] data, int rowCount, int columnCount) { super(rowCount,columnCount,data); if (data.length != columnCount) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(columnCount, data.length)); emptyColumn=Vectorz.createZeroVector(rowCount); } protected SparseColumnMatrix(AVector... vectors) { this(vectors, vectors[0].length(), vectors.length); } protected SparseColumnMatrix(List<AVector> data, int rowCount, int columnCount) { this(data.toArray(new AVector[0]),rowCount,columnCount); } protected SparseColumnMatrix(List<AVector> data) { this(data.toArray(new AVector[0])); } // protected SparseColumnMatrix(HashMap<Integer,AVector> data, int rowCount, int columnCount) { // super(rowCount,columnCount,data); // emptyColumn=Vectorz.createZeroVector(rowCount); // } public static SparseColumnMatrix create(int rows, int cols) { return new SparseColumnMatrix(rows, cols); } public static SparseColumnMatrix create(AVector[] data, int rows, int cols) { return new SparseColumnMatrix(data, rows, cols); } public static SparseColumnMatrix create(AVector... vecs) { return new SparseColumnMatrix(vecs); // don't validate; user can call validate() if they want it. } public static SparseColumnMatrix create(List<AVector> columns) { return create(columns.toArray(new AVector[columns.size()])); } public static SparseColumnMatrix wrap(AVector[] vecs, int rows, int cols) { return create(vecs, rows, cols); } public static SparseColumnMatrix wrap(AVector... vecs) { return create(vecs); } public static SparseColumnMatrix create(AMatrix source) { if (source instanceof SparseRowMatrix) return ((SparseRowMatrix)source).toSparseColumnMatrix(); int cc=source.columnCount(); int rc=source.rowCount(); AVector[] data = new AVector[cc]; for (int i=0; i<cc; i++) { AVector col = source.getColumn(i); if (!col.isZero()) data[i] = Vectorz.createSparse(col); } return new SparseColumnMatrix(data,rc,cc); } public static SparseColumnMatrix wrap(List<AVector> vecs) { return create(vecs); } // public static SparseColumnMatrix wrap(HashMap<Integer,AVector> cols, int rowCount, int columnCount) { // return new SparseColumnMatrix(cols,rowCount,columnCount); // } @Override public int componentCount() { return cols; } @Override public AVector getComponent(int k) { AVector v=data[(int)k]; if (v==null) return emptyColumn; return v; } @Override protected int lineLength() { return rows; } @Override public double get(int i, int j) { return getColumn(j).get(i); } @Override public void set(int i, int j, double value) { checkIndex(i,j); unsafeSet(i,j,value); } @Override public double unsafeGet(int row, int column) { return getColumn(column).unsafeGet(row); } @Override public void unsafeSet(int i, int j, double value) { AVector v = unsafeGetVector(j); if (v==null) { if (value == 0.0) return; v = SingleElementVector.create(value, i, rows); } else if (v.isFullyMutable()) { v.unsafeSet(i,value); return; } else { v = v.sparseClone(); v.unsafeSet(i, value); } unsafeSetVec(j, v); } @Override public void set(AMatrix a) { checkSameShape(a); List<AVector> scols=a.getColumns(); for (int i=0; i<cols; i++) { setColumn(i,scols.get(i)); } } @Override public void addAt(int i, int j, double d) { AVector v=getColumn(j); if (v.isFullyMutable()) { v.addAt(i, d); } else { v=v.mutable(); v.addAt(i, d); replaceColumn(j,v); } } @Override public void addToArray(double[] targetData, int offset) { for (int i = 0; i < cols; ++i) { AVector v = unsafeGetVector(i); if (v != null) v.addToArray(targetData, offset+i, cols); } } @Override public List<AVector> getRows() { return toSparseRowMatrix().getRows(); } public SparseRowMatrix toSparseRowMatrix() { SparseRowMatrix rm=SparseRowMatrix.create(rows, cols); for (int j = 0; j < cols; j++) { AVector colVec = unsafeGetVector(j); if (colVec!=null) { Index nonSparseRows = colVec.nonSparseIndex(); int n=nonSparseRows.length(); for (int k = 0; k < n; k++) { int i = nonSparseRows.unsafeGet(k); double v=colVec.unsafeGet(i); if (v!=0.0) { rm.unsafeSet(i,j, v); } } } } return rm; } private AVector ensureMutableColumn(int i) { AVector v = unsafeGetVector(i); if (v == null) { AVector nv=SparseIndexedVector.createLength(rows); unsafeSetVec(i, nv); return nv; } if (v.isFullyMutable()) return v; AVector mv=v.mutable(); unsafeSetVec(i, mv); return mv; } @Override public AVector getColumn(int j) { AVector v = unsafeGetVector(j); if (v==null) return emptyColumn; return v; } @Override public AVector getColumnView(int j) { return ensureMutableColumn(j); } @Override public boolean isLowerTriangular() { int cc=columnCount(); for (int i=1; i<cc; i++) { if (!getColumn(i).isRangeZero(0, i)) return false; } return true; } @Override public void swapColumns(int i, int j) { if (i == j) return; AVector a = unsafeGetVector(i); AVector b = unsafeGetVector(j); unsafeSetVec(i, b); unsafeSetVec(j, a); } @Override public void replaceColumn(int i, AVector vec) { checkColumn(i); if (vec.length()!=rows) throw new IllegalArgumentException(ErrorMessages.incompatibleShape(vec)); unsafeSetVec(i, vec); } @Override public void add(AMatrix a) { int count=columnCount(); for (int i=0; i<count; i++) { AVector myVec=unsafeGetVector(i); AVector aVec=a.getColumn(i); if (myVec==null) { if (!aVec.isZero()) { unsafeSetVec(i,aVec.copy()); } } else if (myVec.isMutable()) { myVec.add(aVec); } else { unsafeSetVec(i,myVec.addCopy(aVec)); } } } @Override public void copyColumnTo(int i, double[] targetData, int offset) { getColumn(i).copyTo(targetData, offset); } @Override public void copyRowTo(int row, double[] targetData, int offset) { for (int i = 0; i < cols; ++i) { AVector e = unsafeGetVector(i); targetData[offset+i] = (e==null)? 0.0 : e.unsafeGet(row); } } @Override public SparseRowMatrix getTransposeView() { return SparseRowMatrix.wrap(data, cols, rows); } @Override public AMatrix multiplyCopy(double a) { long n=componentCount(); AVector[] ndata=new AVector[(int)n]; for (int i = 0; i < n; ++i) { AVector v = unsafeGetVector(i); if (v != null) { ndata[i] = v.multiplyCopy(a); } } return wrap(ndata,rows,cols); } @Override public AVector innerProduct(AVector a) { return transform(a); } @Override public AVector transform(AVector a) { Vector r=Vector.createLength(rows); for (int i=0; i<cols; i++) { getColumn(i).addMultipleToArray(a.get(i), 0, r.getArray(), 0, rows); } return r; } @Override public Matrix toMatrixTranspose() { Matrix m=Matrix.create(cols, rows); for (int i = 0; i < cols; ++i) { getColumn(i).copyTo(m.data, rows*i); } return m; } @Override public double[] toDoubleArray() { Matrix m=Matrix.create(rows, cols); for (int i=0; i<cols; i++) { AVector v = unsafeGetVector(i); if (v != null) { m.getColumn(i).set(v); } } return m.getArray(); } @Override public AMatrix transposeInnerProduct(AMatrix a) { return getTranspose().innerProduct(a); } @Override public SparseColumnMatrix exactClone() { SparseColumnMatrix result= new SparseColumnMatrix(rows,cols); for (int i = 0; i < cols; ++i) { AVector col = unsafeGetVector(i); if (col != null) { result.replaceColumn(i, col.exactClone()); } } return result; } @Override public AMatrix clone() { if (this.elementCount() < SPARSE_ELEMENT_THRESHOLD) return super.clone(); return exactClone(); } @Override public AMatrix sparse() { return this; } @Override public void validate() { super.validate(); for (int i=0; i<cols; i++) { if (getColumn(i).length()!=rows) throw new VectorzException("Invalid row count at column: "+i); } } @Override public boolean equals(AMatrix m) { if (m==this) return true; if (!isSameShape(m)) return false; for (int i=0; i<cols; i++) { AVector v=unsafeGetVector(i); AVector ov = m.getColumn(i); if (v==null) { if (!ov.isZero()) return false; } else { if (!v.equals(ov)) return false; } } return true; } }