package edu.berkeley.nlp.util;
import java.io.Serializable;
import java.util.Arrays;
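/**
* Class representing a sparse array. A <code>float[]</code> and an
* <code>int[]</code> array together store a sparse representation of a float
* array: only explicitly set entries are kept, ordered by index. The arrays are
* grown behind the scenes when necessary. Getting a count, or overwriting an
* existing one, takes time O(log n), where n is the number of stored entries;
* setting a count for an index that is not stored yet additionally shifts the
* entries after it, which takes O(n) in the worst case.
*
* @author aria42
*
*/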
public class SparseFloatArray implements Serializable {
    private static final long serialVersionUID = 42L;

    /** Minimum number of slots added whenever the backing arrays are enlarged. */
    private static final int MIN_GROWTH = 10;

    /** Stored values; for {@code i < length}, {@code data[i]} belongs to {@code indices[i]}. */
    float[] data = new float[0];

    /** Stored dimensions, in ascending order over the first {@code length} slots. */
    int[] indices = new int[0];

    /** Number of stored entries (the active prefix of both arrays). */
    int length = 0;

    /**
     * Enlarges both backing arrays, keeping the existing entries in place.
     * Capacity grows by half of the current size (at least {@link #MIN_GROWTH}
     * slots) so that repeated insertions copy an amortized constant number of
     * elements each.
     */
    private void grow() {
        int curSize = data.length;
        int newSize = Math.max(curSize + MIN_GROWTH, curSize + (curSize >> 1));
        float[] newData = Arrays.copyOf(data, newSize);
        int[] newIndices = Arrays.copyOf(indices, newSize);
        // Pad the unused tail with the largest int so the complete indices
        // array stays sorted, not just its active prefix.
        Arrays.fill(newIndices, curSize, newSize, Integer.MAX_VALUE);
        Arrays.fill(newData, curSize, newSize, Float.POSITIVE_INFINITY);
        data = newData;
        indices = newIndices;
    }

    /**
     * Binary-searches the active prefix for a dimension.
     *
     * @param index the dimension to look for
     * @return the slot holding {@code index} if it is stored; otherwise
     *         {@code -(insertionPoint) - 1}, as {@link Arrays#binarySearch} does
     */
    private int find(int index) {
        // Restricting the search to [0, length) keeps the Integer.MAX_VALUE
        // padding from being mistaken for a stored Integer.MAX_VALUE key.
        return Arrays.binarySearch(indices, 0, length, index);
    }

    /**
     * Inserts a new entry at the given slot, shifting later entries right.
     *
     * @param pos   slot at which the entry is placed, in {@code [0, length]}
     * @param index the dimension being added (must not be stored already)
     * @param x     the value for that dimension
     */
    private void insert(int pos, int index, float x) {
        if (length == data.length) {
            grow();
        }
        int tail = length - pos;
        System.arraycopy(data, pos, data, pos + 1, tail);
        System.arraycopy(indices, pos, indices, pos + 1, tail);
        indices[pos] = index;
        data[pos] = x;
        length++;
    }

    /**
     * Verifies that {@code i} addresses a stored entry.
     *
     * @param i position within the stored entries
     * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, length)}
     */
    private void checkActive(int i) {
        if (i < 0 || i >= length) {
            throw new IndexOutOfBoundsException(
                    "active index " + i + " out of range [0, " + length + ")");
        }
    }

    /**
     * Returns the value stored for a dimension.
     *
     * @param index the dimension to look up (any int, including negative values)
     * @return the stored value, or {@code 0.0f} if nothing is stored for {@code index}
     */
    public float getCount(int index) {
        int res = find(index);
        return res >= 0 ? data[res] : 0.0f;
    }

    /**
     * Adds {@code x0} to the value of a dimension, treating a missing entry as
     * zero (the entry is created in that case).
     *
     * @param index0 the dimension to update
     * @param x0     the amount to add
     */
    public void incrementCount(int index0, float x0) {
        int res = find(index0);
        if (res >= 0) {
            data[res] += x0;
        } else {
            // Same value a lookup-then-set would store: zero plus the increment.
            insert(-(res + 1), index0, 0.0f + x0);
        }
    }

    /**
     * Returns the number of stored entries.
     *
     * @return how many dimensions currently have an explicit value
     */
    public int size() {
        return length;
    }

    /**
     * Stores a value for a dimension, overwriting any previous value. A value
     * of zero is stored explicitly; entries are never removed.
     *
     * @param index0 the dimension to set
     * @param x      the value to store
     */
    public void setCount(int index0, float x) {
        int res = find(index0);
        if (res >= 0) {
            data[res] = x;
            return;
        }
        insert(-(res + 1), index0, x);
    }

    /**
     * Returns the dimension of the {@code i}-th stored entry (entries are
     * ordered by ascending dimension).
     *
     * @param i position within the stored entries, in {@code [0, size())}
     * @return the dimension held at that position
     * @throws IndexOutOfBoundsException if {@code i} is outside {@code [0, size())}
     */
    public int getActiveDimension(int i) {
        checkActive(i);
        return indices[i];
    }

    /**
     * Returns the value of the {@code i}-th stored entry.
     *
     * @param i position within the stored entries, in {@code [0, size())}
     * @return the value held at that position
     * @throws IndexOutOfBoundsException if {@code i} is outside {@code [0, size())}
     */
    public float getActiveCount(int i) {
        checkActive(i);
        return data[i];
    }

    /**
     * Multiplies every stored value by a constant, in place.
     *
     * @param c the scaling factor
     */
    public void scale(float c) {
        for (int i = 0; i < length; ++i) {
            data[i] *= c;
        }
    }

    /**
     * Renders the stored entries as {@code { dim : value ... }} with five
     * decimal places per value.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("{ ");
        for (int i = 0; i < length; ++i) {
            builder.append(String.format("%d : %.5f", indices[i], data[i]));
            builder.append(" ");
        }
        builder.append(" }");
        return builder.toString();
    }

    /**
     * Renders the stored entries like {@link #toString()}, but prints the
     * object the indexer returns for each dimension instead of the raw int.
     *
     * @param indexer supplies the object shown for each stored dimension
     * @return the formatted entries
     */
    public String toString(Indexer<?> indexer) {
        StringBuilder builder = new StringBuilder();
        builder.append("{ ");
        for (int i = 0; i < length; ++i) {
            builder.append(String.format("%s : %.5f", indexer.getObject(indices[i]),
                    data[i]));
            builder.append(" ");
        }
        builder.append(" }");
        return builder.toString();
    }

    /**
     * Computes the dot product with another sparse array by looking up each of
     * this array's stored dimensions in {@code other}.
     *
     * @param other the array to multiply with
     * @return the sum over this array's entries of value times {@code other}'s
     *         value for the same dimension
     */
    public float dotProduct(SparseFloatArray other) {
        float sum = 0.0f;
        for (int i = 0; i < length; ++i) {
            int dim = indices[i];
            sum += data[i] * other.getCount(dim);
        }
        return sum;
    }

    /** Small manual demo: sets and increments a few counts, then prints the array. */
    public static void main(String[] args) {
        SparseFloatArray sv = new SparseFloatArray();
        sv.setCount(0, 1.0f);
        sv.setCount(1, 2.0f);
        sv.incrementCount(1, 1.0f);
        sv.incrementCount(-1, 10.0f);
        System.out.println(sv);
    }
}