/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.types;
import java.io.Serializable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import cc.mallet.util.ArrayUtils;
// Generated package name
/**
 * Sparse implementation of {@link Matrix} that allows an arbitrary
 * number of dimensions.  All entries are stored in a single
 * {@link SparseVector}, addressed by a "single index" that flattens the
 * per-dimension indices (see {@link #singleIndex(int[])}).
 *
 * Created: Tue Sep 16 14:52:37 2003
 *
 * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a>
 * @version $Id: SparseMatrixn.java,v 1.1 2007/10/22 21:37:39 mccallum Exp $
 */
public class SparseMatrixn implements Matrix, Cloneable, Serializable {

  /** Backing storage for every entry, keyed by single index. */
  private SparseVector values;
  /** Number of dimensions; always equal to {@code sizes.length}. */
  private int numDimensions;
  /** Extent of each dimension.  A private copy, never handed out directly. */
  private int[] sizes;
  /** Product of all entries of {@code sizes}, cached at construction. */
  private int singleSize;

  /**
   * Creates a one-dimensional matrix holding the given values.
   *
   * @param vals The entries; the matrix has {@code vals.length} elements.
   */
  public SparseMatrixn (double[] vals) {
    this (new int[] { vals.length }, vals);
  }

  /**
   * Creates an all-zero matrix with the given dimensions.  The backing
   * vector is built from a zero-filled array covering every entry.
   *
   * @param szs An array containing the size of each dimension.  It is copied.
   */
  public SparseMatrixn (int[] szs) {
    this (szs, new double [product (szs)]);
  }

  /**
   * Creates a matrix with the given dimensions from a full array of entries.
   *
   * @param szs An array containing the size of each dimension.  It is copied.
   * @param vals The entries in row-major order (the layout produced by
   *   {@link #singleIndex(int[], int[])}).  NOTE(review): the length is not
   *   checked against the product of {@code szs}.
   */
  public SparseMatrixn (int[] szs, double[] vals) {
    numDimensions = szs.length;
    sizes = (int[]) szs.clone ();
    values = new SparseVector (vals);
    singleSize = product (sizes);
  }

  /**
   * Creates a sparse matrix with the given dimensions and
   * the given values.
   *
   * @param szs An array containing the size of each dimension.  It is copied.
   * @param idxs An array containing the single index
   *   for each stored entry of the matrix.  A single index is
   *   an integer computed from the indices of each dimension,
   *   as returned by {@link Matrixn#singleIndex}.
   * @param vals The values of the stored entries, parallel to {@code idxs}.
   */
  public SparseMatrixn (int[] szs, int[] idxs, double[] vals) {
    numDimensions = szs.length;
    sizes = (int[]) szs.clone ();
    // NOTE(review): the two flags are presumably "copy arrays" and
    // "check that indices are sorted"; confirm against SparseVector.
    values = new SparseVector (idxs, vals, true, true);
    singleSize = product (sizes);
  }

  /** Returns the product of all entries of {@code dims} (1 for an empty array). */
  private static int product (int[] dims)
  {
    int result = 1;
    for (int i = 0; i < dims.length; i++) {
      result *= dims[i];
    }
    return result;
  }

  /** Returns the number of dimensions of this matrix. */
  public int getNumDimensions () { return numDimensions; }

  /**
   * Copies the size of each dimension into {@code szs}.
   *
   * @param szs Destination array; must hold at least
   *   {@link #getNumDimensions()} elements.
   * @return The number of dimensions.
   */
  public int getDimensions (int[] szs) {
    for (int i = 0; i < numDimensions; i++) {
      szs [i] = sizes [i];
    }
    return numDimensions;
  }

  /** Returns the entry at the given per-dimension indices. */
  public double value (int[] indices) {
    return values.value (singleIndex (indices));
  }

  /** Sets the entry at the given per-dimension indices. */
  public void setValue (int[] indices, double value) {
    values.setValue (singleIndex (indices), value);
  }

  /**
   * Returns the single indices held by the backing vector.
   * Callers must not modify the return value.
   */
  public int[] getIndices () {
    return values.getIndices ();
  }

  /**
   * Returns a copy of this matrix.  The shape array is cloned by the
   * three-argument constructor, which also passes {@code true} flags to the
   * SparseVector constructor (presumably requesting copies of the
   * index/value arrays; confirm against SparseVector).
   */
  public ConstantMatrix cloneMatrix () {
    return new SparseMatrixn (sizes, values.getIndices (), values.getValues ());
  }

  /** Equivalent to {@link #cloneMatrix()}. */
  @Override
  public Object clone () {
    return cloneMatrix ();
  }

  /** Converts per-dimension indices into a single index for this matrix's shape. */
  public int singleIndex (int[] indices)
  {
    return Matrixn.singleIndex (sizes, indices);
  }

  /**
   * Converts per-dimension indices into a row-major single index.
   * This is public static so it is useful as a general
   * dereferencing utility for multidimensional arrays.
   *
   * @param szs The size of each dimension.
   * @param indices The index along each dimension.
   * @return The flattened index.
   */
  public static int singleIndex (int[] szs, int[] indices)
  {
    int idx = 0;
    for (int dim = 0; dim < indices.length; dim++) {
      idx = (idx * szs[dim]) + indices [dim];
    }
    return idx;
  }

  /** Writes the per-dimension indices corresponding to {@code single} into {@code indices}. */
  public void singleToIndices (int single, int[] indices) {
    Matrixn.singleToIndices (single, indices, sizes);
  }

  /**
   * Two SparseMatrixn objects are equal when they have the same shape
   * (compared element-wise) and their backing vectors are equal.
   * This could be extended to work for all Matrix implementations.
   */
  @Override
  public boolean equals (Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof SparseMatrixn)) {
      return false;
    }
    SparseMatrixn m2 = (SparseMatrixn) o;
    // Arrays.equals compares contents; int[].equals would compare references.
    return numDimensions == m2.numDimensions
        && Arrays.equals (sizes, m2.sizes)
        && values.equals (m2.values);
  }

  /**
   * Hashes only the shape.  Equal matrices always have equal shapes, so this
   * is consistent with {@link #equals(Object)} regardless of how
   * SparseVector implements equality.
   */
  @Override
  public int hashCode () {
    return Arrays.hashCode (sizes);
  }

  /**
   * Returns the value array of the backing SparseVector.
   * Caller must not modify the return value.
   * NOTE(review): when the backing vector is sparse this presumably holds
   * only the stored entries, aligned with {@link #getIndices()}, rather than
   * a full row-major array; confirm against SparseVector.getValues.
   */
  public double[] toArray () {
    return values.getValues ();
  }

  // Methods from Matrix

  /** Returns the entry at the given single index. */
  public double singleValue (int i)
  {
    return values.singleValue (i);
  }

  /** Returns the total number of entries (the product of all dimension sizes). */
  public int singleSize ()
  {
    return singleSize;
  }

  // Access by index into sparse array, efficient for sparse and dense matrices

  public int numLocations ()
  {
    return values.numLocations ();
  }

  public int location (int index)
  {
    return values.location (index);
  }

  public double valueAtLocation (int location)
  {
    return values.valueAtLocation (location);
  }

  public void setValueAtLocation (int location, double value)
  {
    values.setValueAtLocation (location, value);
  }

  /** Returns the single index stored at the given location. */
  public int indexAtLocation (int location)
  {
    return values.indexAtLocation (location);
  }

  public double dotProduct (ConstantMatrix m)
  {
    return values.dotProduct (m);
  }

  public double absNorm ()
  {
    return values.absNorm ();
  }

  public double oneNorm ()
  {
    return values.oneNorm ();
  }

  public double twoNorm ()
  {
    return values.twoNorm ();
  }

  public double infinityNorm ()
  {
    return values.infinityNorm ();
  }

  public void print ()
  {
    values.print ();
  }

  public boolean isNaN ()
  {
    return values.isNaN ();
  }

  /** Sets the entry at the given single index. */
  public void setSingleValue (int i, double value)
  {
    values.setValue (i, value);
  }

  /** Adds {@code delta} to the entry at the given single index. */
  public void incrementSingleValue (int i, double delta)
  {
    double value = values.value (i);
    values.setValue (i, value + delta);
  }

  public void setAll (double v)
  {
    values.setAll (v);
  }

  public void set (ConstantMatrix m)
  {
    throw new UnsupportedOperationException ("Not yet implemented.");
  }

  public void setWithAddend (ConstantMatrix m, double addend)
  {
    throw new UnsupportedOperationException ("Not yet implemented.");
  }

  public void setWithFactor (ConstantMatrix m, double factor)
  {
    throw new UnsupportedOperationException ("Not yet implemented.");
  }

  /** Adds {@code m} to this matrix; see {@link #plusEquals(ConstantMatrix, double)}. */
  public void plusEquals (ConstantMatrix m)
  {
    plusEquals (m, 1.0);
  }

  /**
   * Adds {@code factor * m} to this matrix.  Dispatches on the runtime type of
   * {@code m}: only SparseVector and SparseMatrixn (of identical shape) are
   * supported.  Not often used, so a simple type test suffices.
   *
   * @throws UnsupportedOperationException if {@code m} is of another type or
   *   a SparseMatrixn whose shape differs.
   */
  public void plusEquals (ConstantMatrix m, double factor)
  {
    if (m instanceof SparseVector) {
      values.plusEqualsSparse ((SparseVector) m, factor);
    } else if (m instanceof SparseMatrixn) {
      SparseMatrixn smn = (SparseMatrixn) m;
      if (Arrays.equals (sizes, smn.sizes)) {
        values.plusEqualsSparse (smn.values, factor);
      } else {
        throw new UnsupportedOperationException ("sizes of " + m + " do not match " + this);
      }
    } else {
      throw new UnsupportedOperationException ("Can't add " + m + " to " + this);
    }
  }

  public void equalsPlus (double factor, ConstantMatrix m)
  {
    throw new UnsupportedOperationException ("Not yet implemented.");
  }

  public void timesEquals (double factor)
  {
    values.timesEquals (factor);
  }

  public void elementwiseTimesEquals (ConstantMatrix m)
  {
    throw new UnsupportedOperationException ("Not yet implemented.");
  }

  public void elementwiseTimesEquals (ConstantMatrix m, double factor)
  {
    throw new UnsupportedOperationException ("Not yet implemented.");
  }

  public void divideEquals (double factor)
  {
    values.timesEquals (1 / factor);
  }

  public void elementwiseDivideEquals (ConstantMatrix m)
  {
    throw new UnsupportedOperationException ("Not yet implemented.");
  }

  public void elementwiseDivideEquals (ConstantMatrix m, double factor)
  {
    throw new UnsupportedOperationException ("Not yet implemented.");
  }

  // The normalizers below scale by 1/norm.  A zero norm yields an infinite
  // factor in double arithmetic, so the entries become non-finite.

  /** Scales so the one-norm is 1; returns the norm before scaling. */
  public double oneNormalize ()
  {
    double norm = values.oneNorm ();
    values.timesEquals (1 / norm);
    return norm;
  }

  /** Scales so the two-norm is 1; returns the norm before scaling. */
  public double twoNormalize ()
  {
    double norm = values.twoNorm ();
    values.timesEquals (1 / norm);
    return norm;
  }

  /** Scales so the abs-norm is 1; returns the norm before scaling. */
  public double absNormalize ()
  {
    double norm = values.absNorm ();
    values.timesEquals (1 / norm);
    return norm;
  }

  /** Scales so the infinity-norm is 1; returns the norm before scaling. */
  public double infinityNormalize ()
  {
    double norm = values.infinityNorm ();
    values.timesEquals (1 / norm);
    return norm;
  }

  // Serialization

  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 1;

  private void writeObject (ObjectOutputStream out) throws IOException
  {
    out.defaultWriteObject ();
    out.writeInt (CURRENT_SERIAL_VERSION);
  }

  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
  {
    in.defaultReadObject ();
    // Consume the version tag written by writeObject; only version 1 exists,
    // so no version-specific handling is needed yet.
    in.readInt ();
  }
}