/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.math; import com.google.common.collect.AbstractIterator; import org.apache.mahout.math.list.IntArrayList; import org.apache.mahout.math.map.OpenIntObjectHashMap; import java.util.Iterator; import java.util.Map; /** Doubly sparse matrix. Implemented as a Map of RandomAccessSparseVector rows */ public class SparseMatrix extends AbstractMatrix { private OpenIntObjectHashMap<Vector> rowVectors; /** * Construct a matrix of the given cardinality with the given row map * * @param rows * a Map<Integer, RandomAccessSparseVector> of rows * @param columns * @param rowVectors */ public SparseMatrix(int rows, int columns, Map<Integer, RandomAccessSparseVector> rowVectors) { super(rows, columns); this.rowVectors = new OpenIntObjectHashMap<Vector>(); for (Map.Entry<Integer, RandomAccessSparseVector> entry : rowVectors.entrySet()) { this.rowVectors.put(entry.getKey(), entry.getValue().clone()); } } /** * Construct a matrix with specified number of rows and columns. */ public SparseMatrix(int rows, int columns) { super(rows, columns); this.rowVectors = new OpenIntObjectHashMap<Vector>(); } @Override public Matrix clone() { SparseMatrix clone = (SparseMatrix) super.clone(); clone.rowVectors = rowVectors.clone(); return clone; } @Override public Iterator<MatrixSlice> iterator() { final IntArrayList keys = new IntArrayList(rowVectors.size()); rowVectors.keys(keys); return new AbstractIterator<MatrixSlice>() { private int slice; @Override protected MatrixSlice computeNext() { if (slice >= rowVectors.size()) { return endOfData(); } int i = keys.get(slice); Vector row = rowVectors.get(i); slice++; return new MatrixSlice(row, i); } }; } @Override public double getQuick(int row, int column) { Vector r = rowVectors.get(row); return r == null ? 0.0 : r.getQuick(column); } @Override public Matrix like() { return new SparseMatrix(rowSize(), columnSize()); } @Override public Matrix like(int rows, int columns) { return new SparseMatrix(rows, columns); } @Override public void setQuick(int row, int column, double value) { Vector r = rowVectors.get(row); if (r == null) { r = new RandomAccessSparseVector(columnSize()); rowVectors.put(row, r); } r.setQuick(column, value); } @Override public int[] getNumNondefaultElements() { int[] result = new int[2]; result[ROW] = rowVectors.size(); for (Vector vectorEntry : rowVectors.values()) { result[COL] = Math.max(result[COL], vectorEntry .getNumNondefaultElements()); } return result; } @Override public Matrix viewPart(int[] offset, int[] size) { if (offset[ROW] < 0) { throw new IndexException(offset[ROW], rowSize()); } if (offset[ROW] + size[ROW] > rowSize()) { throw new IndexException(offset[ROW] + size[ROW], rowSize()); } if (offset[COL] < 0) { throw new IndexException(offset[COL], columnSize()); } if (offset[COL] + size[COL] > columnSize()) { throw new IndexException(offset[COL] + size[COL], columnSize()); } return new MatrixView(this, offset, size); } @Override public Matrix assignColumn(int column, Vector other) { if (rowSize() != other.size()) { throw new CardinalityException(rowSize(), other.size()); } if (column < 0 || column >= columnSize()) { throw new IndexException(column, columnSize()); } for (int row = 0; row < rowSize(); row++) { double val = other.getQuick(row); if (val != 0.0) { Vector r = rowVectors.get(row); if (r == null) { r = new RandomAccessSparseVector(columnSize()); rowVectors.put(row, r); } r.setQuick(column, val); } } return this; } @Override public Matrix assignRow(int row, Vector other) { if (columnSize() != other.size()) { throw new CardinalityException(columnSize(), other.size()); } if (row < 0 || row >= rowSize()) { throw new IndexException(row, rowSize()); } rowVectors.put(row, other); return this; } @Override public Vector viewRow(int row) { if (row < 0 || row >= rowSize()) { throw new IndexException(row, rowSize()); } Vector res = rowVectors.get(row); if (res == null) { res = new RandomAccessSparseVector(columnSize()); } return res; } }