package mikera.matrixx.impl; import java.util.Arrays; import java.util.List; import java.util.ArrayList; // import java.util.HashMap; // import java.util.HashSet; // import java.util.Map; // import java.util.Map.Entry; import mikera.arrayz.INDArray; import mikera.arrayz.ISparse; import mikera.indexz.Index; import mikera.matrixx.AMatrix; import mikera.matrixx.Matrixx; import mikera.vectorz.AVector; import mikera.vectorz.Op; import mikera.vectorz.Vector; import mikera.vectorz.Vectorz; import mikera.vectorz.impl.RepeatedElementVector; import mikera.vectorz.impl.SingleElementVector; import mikera.vectorz.impl.SparseIndexedVector; import mikera.vectorz.util.ErrorMessages; import mikera.vectorz.util.VectorzException; /** * Matrix stored as a collection of normally sparse row vectors. * * This format is especially efficient for: * - innerProduct() with another matrix, especially one with efficient * column access like SparseColumnMatrix * - access via getRow() operation * - transpose into SparseColumnMatrix * * @author Mike * */ public class SparseRowMatrix extends ASparseRCMatrix implements ISparse, IFastRows { private static final long serialVersionUID = 8646257152425415773L; private static final long SPARSE_ELEMENT_THRESHOLD = 1000L; private final AVector emptyRow; protected SparseRowMatrix(int rowCount, int columnCount) { this(new AVector[rowCount],rowCount,columnCount); } protected SparseRowMatrix(AVector[] data, int rowCount, int columnCount) { super(rowCount,columnCount,data); if (data.length != rowCount) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(rowCount, data.length)); emptyRow=Vectorz.createZeroVector(columnCount); } protected SparseRowMatrix(AVector... vectors) { this(vectors, vectors.length, vectors[0].length()); } protected SparseRowMatrix(List<AVector> data, int rowCount, int columnCount) { this(data.toArray(new AVector[0]),rowCount,columnCount); } protected SparseRowMatrix(List<AVector> data) { this(data.toArray(new AVector[0])); } // protected SparseRowMatrix(HashMap<Integer,AVector> data, int rowCount, int columnCount) { // super(rowCount,columnCount,data); // emptyColumn=Vectorz.createZeroVector(rowCount); // } public static SparseRowMatrix create(int rows, int cols) { return new SparseRowMatrix(rows, cols); } public static SparseRowMatrix create(AVector[] data, int rows, int cols) { return new SparseRowMatrix(data, rows, cols); } public static SparseRowMatrix create(AVector... vecs) { return new SparseRowMatrix(vecs); // don't validate; user can call validate() if they want it. } public static SparseRowMatrix create(List<AVector> rows) { return create(rows.toArray(new AVector[rows.size()])); } public static INDArray create(ArrayList<INDArray> slices, int rows, int cols) { AVector[] vecs=new AVector[rows]; for (int i=0; i<rows; i++) { INDArray a=slices.get(i); if ((a.dimensionality()!=1)||(a.sliceCount()!=cols)) throw new IllegalArgumentException(ErrorMessages.incompatibleShape(a)); vecs[i]=a.asVector(); } return wrap(vecs,rows,cols); } public static SparseRowMatrix wrap(AVector[] vecs, int rows, int cols) { return create(vecs, rows, cols); } public static SparseRowMatrix wrap(AVector... vecs) { return create(vecs); } public static SparseRowMatrix create(AMatrix source) { if (source instanceof SparseColumnMatrix) return ((SparseColumnMatrix)source).toSparseRowMatrix(); int rc = source.rowCount(); int cc = source.columnCount(); AVector[] data = new AVector[rc]; List<AVector> rows=source.getRows(); for (int i = 0; i < rc; i++) { AVector row = rows.get(i); if (!row.isZero()) { data[i] = Vectorz.createSparse(row); } } return SparseRowMatrix.wrap(data,rc,cc); } public static SparseRowMatrix wrap(List<AVector> vecs) { return create(vecs); } // public static SparseRowMatrix wrap(HashMap<Integer, AVector> data, int rows, int cols) { // return new SparseRowMatrix(data, rows, cols); // } @Override public int componentCount() { return rows; } @Override public AVector getComponent(int k) { AVector v=data[(int)k]; if (v==null) return emptyRow; return v; } @Override protected int lineLength() { return cols; } @Override public double get(int i, int j) { return getRow(i).get(j); } @Override public void set(int i, int j, double value) { checkIndex(i,j); unsafeSet(i,j,value); } @Override public double unsafeGet(int row, int column) { return getRow(row).unsafeGet(column); } @Override public void unsafeSet(int i, int j, double value) { AVector v = unsafeGetVector(i); if (v == null) { if (value == 0.0) return; v = SingleElementVector.create(value, j, cols); } else if (v.isFullyMutable()) { v.set(j, value); return; } else { v = v.sparseClone(); v.unsafeSet(j, value); } unsafeSetVec(i, v); } @Override public void set(AMatrix a) { checkSameShape(a); List<AVector> srows=a.getRows(); for (int i=0; i<rows; i++) { setRow(i,srows.get(i)); } } @Override public void setRow(int i, AVector v) { data[i]=v.copy(); } @Override public void addAt(int i, int j, double d) { if (d==0.0) return; AVector v=unsafeGetVector(i); if (v.isFullyMutable()) { v.addAt(j, d); } else { v=v.mutable(); v.addAt(j, d); replaceRow(i,v); } } @Override public void addToArray(double[] targetData, int offset) { for (int i = 0; i < rows; ++i) { AVector v = unsafeGetVector(i); if (v != null) v.addToArray(targetData, offset+cols*i); } } private AVector ensureMutableRow(int i) { AVector v = unsafeGetVector(i); if (v == null) { AVector nv=SparseIndexedVector.createLength(cols); unsafeSetVec(i, nv); return nv; } if (v.isFullyMutable()) return v; AVector mv=v.mutable(); unsafeSetVec(i, mv); return mv; } @Override public List<AVector> getColumns() { return toSparseColumnMatrix().getColumns(); } public SparseColumnMatrix toSparseColumnMatrix() { SparseColumnMatrix cm=SparseColumnMatrix.create(rows,cols); for (int i = 0; i < rows; i++) { AVector rowVec = unsafeGetVector(i); if (null != rowVec) { Index nonSparseCols = rowVec.nonSparseIndex(); int n=nonSparseCols.length(); for (int k = 0; k < n; k++) { int j = nonSparseCols.unsafeGet(k); double v=rowVec.unsafeGet(j); if (v!=0.0) { cm.unsafeSet(i,j, v); } } } } return cm; } @Override public AVector getRow(int i) { AVector v = unsafeGetVector(i); if (v == null) return emptyRow; return v; } @Override public List<AVector> getRows() { ArrayList<AVector> rowList = new ArrayList<AVector>(rows); for (int i = 0; i < rows; i++) { AVector v = unsafeGetVector(i); if (v == null) v=emptyRow; rowList.add(v); } return rowList; } @Override public AVector getRowView(int i) { return ensureMutableRow(i); } @Override public boolean isUpperTriangular() { int rc=rowCount(); for (int i=1; i<rc; i++) { if (!getRow(i).isRangeZero(0, i)) return false; } return true; } @Override public void swapRows(int i, int j) { if (i == j) return; AVector a = unsafeGetVector(i); AVector b = unsafeGetVector(j); unsafeSetVec(i, b); unsafeSetVec(j, a); } @Override public void replaceRow(int i, AVector vec) { if (vec.length() != cols) throw new IllegalArgumentException(ErrorMessages.incompatibleShape(vec)); unsafeSetVec(i, vec); } @Override public void add(AMatrix a) { checkSameShape(a); int rc=rowCount(); for (int i=0; i<rc; i++) { AVector myVec=unsafeGetVector(i); AVector aVec=a.getRow(i); if (myVec==null) { if (!aVec.isZero()) { unsafeSetVec(i,aVec.copy()); } } else if (myVec.isFullyMutable()) { myVec.add(aVec); } else { unsafeSetVec(i,myVec.addCopy(aVec)); } } } @Override public void copyRowTo(int i, double[] data, int offset) { AVector v=this.unsafeGetVector(i); if (v==null) { Arrays.fill(data, offset, offset+cols, 0.0); } else { v.getElements(data, offset); } } @Override public void copyColumnTo(int col, double[] targetData, int offset) { for (int i = 0; i < rows; ++i) { AVector v = unsafeGetVector(i); targetData[offset+i] = (v==null) ? 0.0 : v.unsafeGet(col); } } @Override public SparseColumnMatrix getTransposeView() { return SparseColumnMatrix.wrap(data, cols, rows); } @Override public AMatrix multiplyCopy(double a) { long n=componentCount(); AVector[] ndata=new AVector[(int)n]; for (int i = 0; i < n; ++i) { AVector v = unsafeGetVector(i); if (v != null) ndata[i] = v.multiplyCopy(a); } return wrap(ndata,rows,cols); } @Override public void multiplyRow(int i, double value) { if (value==0.0) { unsafeSetVec(i,null); return; } AVector v = unsafeGetVector(i); if (v==null) return; v=v.multiplyCopy(value); unsafeSetVec(i,v); } @Override public AVector innerProduct(AVector a) { return transform(a); } @Override public AVector transform(AVector a) { AVector r=Vector.createLength(rows); for (int i=0; i<rows; i++) { r.set(i,getRow(i).dotProduct(a)); } return r; } @Override public double[] toDoubleArray() { double[] ds=new double[rows*cols]; // we use adding to array, since rows themselves are likely to be sparse for (int i = 0; i < rows; ++i) { AVector v = unsafeGetVector(i); if (v != null) v.addToArray(ds, i*cols); } return ds; } @Override public AMatrix innerProduct(AMatrix a) { if (a instanceof SparseColumnMatrix) { return innerProduct((SparseColumnMatrix) a); } SparseRowMatrix r = Matrixx.createSparse(rows, a.columnCount()); for (int i = 0; i < rows; ++i) { AVector row = unsafeGetVector(i); if (! ((row == null) || (row.isZero()))) { r.replaceRow(i,row.innerProduct(a)); } } return r; } /** * Specialised inner product for sparse row matrix multiplied by sparse column matrix. * * @param a * @return */ public SparseRowMatrix innerProduct(SparseColumnMatrix a) { return innerProduct(SparseRowMatrix.create(a)); } public SparseRowMatrix innerProduct(SparseRowMatrix a) { SparseRowMatrix r = Matrixx.createSparse(rows, a.columnCount()); for (int i = 0; i < rows; ++i) { AVector row = unsafeGetVector(i); if (row != null) { r.replaceRow(i,row.innerProduct(a)); } } return r; } @Override public SparseRowMatrix exactClone() { SparseRowMatrix result = new SparseRowMatrix(rows, cols); for (int i = 0; i < rows; ++i) { AVector row = unsafeGetVector(i); if (row != null) result.replaceRow(i, row.exactClone()); } return result; } @Override public AMatrix clone() { if (this.elementCount() < SPARSE_ELEMENT_THRESHOLD) return super.clone(); return exactClone(); } @Override public AMatrix sparse() { return this; } @Override public void validate() { super.validate(); for (int i=0; i<rows; i++) { if (getRow(i).length()!=cols) throw new VectorzException("Invalid column count at row: "+i); } } @Override public boolean equals(AMatrix m) { if (m==this) return true; if (!isSameShape(m)) return false; for (int i=0; i<rows; i++) { AVector v=unsafeGetVector(i); AVector ov = m.getRow(i); if (v==null) { if (!ov.isZero()) return false; } else { if (!v.equals(ov)) return false; } } return true; } public static AVector innerProduct(AMatrix a, AVector b) { // TODO: consider reducing working set? return create(a).innerProduct(b); } public static AMatrix innerProduct(AMatrix a, AMatrix b) { // TODO: consider reducing working set? return create(a).innerProduct(b); } }