package ch.akuhn.matrix;
import static ch.akuhn.foreach.For.range;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
/**
 * A {@link Vector} with a fixed logical {@link #size()} that stores only the
 * entries that have been explicitly written.
 * <p>
 * Entries live in two parallel arrays: {@code keys[0..used)} holds the stored
 * indices in strictly ascending order (the binary searches below rely on this)
 * and {@code values[0..used)} holds the value for each of those indices. Slots
 * at or beyond {@code used} are spare capacity and carry no meaning. An index
 * that is not stored reads as {@code 0}. Note that {@link #put(int, double)} and
 * {@link #add(int, double)} create an entry even when the value written is zero.
 */
public class SparseVector extends Vector {

    /** Stored indices, strictly ascending in {@code [0, used)}; the rest is spare capacity. */
    /* default */int[] keys;
    /** Logical dimension ({@code size}) and number of stored entries ({@code used}). */
    /* default */int size, used;
    /** Stored values, parallel to {@link #keys}; only {@code [0, used)} is meaningful. */
    /* default */double[] values;

    /**
     * Creates a sparse copy of a dense array, storing only its non-zero elements.
     *
     * @param values dense values; the vector's size is {@code values.length}
     */
    protected SparseVector(double[] values) {
        this(values.length);
        for (int n: range(values.length)) {
            if (values[n] != 0) {
                put(n, values[n]);
            }
        }
    }

    /**
     * Creates an empty vector of the given size with a default initial capacity of 10.
     *
     * @param size logical dimension
     */
    protected SparseVector(int size) {
        this(size, 10);
    }

    /**
     * Creates an empty vector.
     *
     * @param size logical dimension (asserted non-negative)
     * @param capacity initial number of entry slots (asserted non-negative)
     */
    public SparseVector(int size, int capacity) {
        assert size >= 0;
        assert capacity >= 0;
        this.size = size;
        this.keys = new int[capacity];
        this.values = new double[capacity];
    }

    /**
     * Adds {@code value} to the element at {@code key}, creating the entry if absent.
     *
     * @return the element's new value
     * @throws IndexOutOfBoundsException if {@code key} is outside {@code [0, size)}
     */
    @Override
    public double add(int key, double value) {
        if (key < 0 || key >= size) throw new IndexOutOfBoundsException(Integer.toString(key));
        int spot = Arrays.binarySearch(keys, 0, used, key);
        if (spot >= 0) {
            return values[spot] += value;
        }
        return update(-1 - spot, key, value);
    }

    /**
     * Returns a read-only view of the stored entries in ascending index order.
     * The iterator does not support {@code remove()}.
     */
    @Override
    public Iterable<Entry> entries() {
        return new Iterable<Entry>() {
            @Override
            public Iterator<Entry> iterator() {
                return new Iterator<Entry>() {
                    private int spot = 0;

                    @Override
                    public boolean hasNext() {
                        return spot < used;
                    }

                    @Override
                    public Entry next() {
                        if (!hasNext()) throw new NoSuchElementException();
                        return new Entry(keys[spot], values[spot++]);
                    }

                    @Override
                    public void remove() {
                        throw new UnsupportedOperationException();
                    }
                };
            }
        };
    }

    @Override
    public boolean equals(Object obj) {
        return obj instanceof SparseVector && this.equals((SparseVector) obj);
    }

    /**
     * Structural equality: same size and the same stored indices with bit-identical
     * values. Spare capacity beyond {@code used} is ignored, so a trimmed and an
     * untrimmed copy of the same vector are equal.
     */
    public boolean equals(SparseVector v) {
        if (v == this) return true;
        if (size != v.size || used != v.used) return false;
        for (int k = 0; k < used; k++) {
            if (keys[k] != v.keys[k]) return false;
            // Same comparison Arrays.equals(double[], double[]) uses.
            if (Double.doubleToLongBits(values[k]) != Double.doubleToLongBits(v.values[k])) return false;
        }
        return true;
    }

    /**
     * Returns the element at {@code key}, or {@code 0} if no entry is stored there.
     *
     * @throws IndexOutOfBoundsException if {@code key} is outside {@code [0, size)}
     */
    @Override
    public double get(int key) {
        if (key < 0 || key >= size) throw new IndexOutOfBoundsException(Integer.toString(key));
        int spot = Arrays.binarySearch(keys, 0, used, key);
        return spot < 0 ? 0 : values[spot];
    }

    /** Hashes only the size and the used entries, consistent with {@link #equals(SparseVector)}. */
    @Override
    public int hashCode() {
        int h = size;
        for (int k = 0; k < used; k++) {
            h = 31 * h + keys[k];
            long bits = Double.doubleToLongBits(values[k]);
            h = 31 * h + (int) (bits ^ (bits >>> 32));
        }
        return h;
    }

    /** Returns whether an entry is stored at {@code key}; no bounds check is performed. */
    public boolean isUsed(int key) {
        return 0 <= Arrays.binarySearch(keys, 0, used, key);
    }

    /**
     * Sets the element at {@code key}, creating the entry if absent.
     *
     * @return the stored value
     * @throws IndexOutOfBoundsException if {@code key} is outside {@code [0, size)}
     */
    @Override
    public double put(int key, double value) {
        if (key < 0 || key >= size) throw new IndexOutOfBoundsException(Integer.toString(key));
        int spot = Arrays.binarySearch(keys, 0, used, key);
        if (spot >= 0) {
            // Store at full double precision (previously truncated to float).
            return values[spot] = value;
        }
        return update(-1 - spot, key, value);
    }

    /**
     * Enlarges the logical size; existing entries are kept.
     *
     * @throws UnsupportedOperationException if {@code newSize} is smaller than the current size
     */
    public void resizeTo(int newSize) {
        if (newSize < this.size) throw new UnsupportedOperationException();
        this.size = newSize;
    }

    @Override
    public int size() {
        return size;
    }

    /**
     * Inserts a new entry at position {@code spot}, keeping {@code keys} sorted.
     * Grows the backing arrays by roughly 1.5x when they are full.
     *
     * @param spot insertion point as derived from {@link Arrays#binarySearch}
     * @return the stored value
     */
    private double update(int spot, int key, double value) {
        // grow if reaching end of capacity
        if (used == keys.length) {
            int capacity = (keys.length * 3) / 2 + 1;
            keys = Arrays.copyOf(keys, capacity);
            values = Arrays.copyOf(values, capacity);
        }
        // shift values if not appending
        if (spot < used) {
            System.arraycopy(keys, spot, keys, spot + 1, used - spot);
            System.arraycopy(values, spot, values, spot + 1, used - spot);
        }
        used++;
        keys[spot] = key;
        return values[spot] = value;
    }

    /** Returns the number of stored entries. */
    @Override
    public int used() {
        return used;
    }

    /** Shrinks the backing arrays to exactly {@code used} slots. */
    public void trim() {
        keys = Arrays.copyOf(keys, used);
        values = Arrays.copyOf(values, used);
    }

    /** Dot product, summed over this vector's stored entries only. */
    @Override
    public double dot(Vector x) {
        double product = 0;
        for (int k = 0; k < used; k++) {
            product += x.get(keys[k]) * values[k];
        }
        return product;
    }

    /** Computes {@code y += a * this}, touching only this vector's stored indices. */
    @Override
    public void scaleAndAddTo(double a, Vector y) {
        for (int k = 0; k < used; k++) {
            y.add(keys[k], a * values[k]);
        }
    }

    /**
     * Returns whether {@code v} has the same size and every element differs from
     * this vector's by at most {@code epsilon}. A NaN difference counts as unequal.
     */
    @Override
    public boolean equals(Vector v, double epsilon) {
        if (v == null || size != v.size()) return false;
        for (int n = 0; n < size; n++) {
            // Written as !(<=) so that NaN differences are treated as unequal.
            if (!(Math.abs(get(n) - v.get(n)) <= epsilon)) return false;
        }
        return true;
    }

    /**
     * Returns a new vector equal to {@code scalar * this}; this vector is not modified.
     */
    @Override
    public Vector times(double scalar) {
        SparseVector y = new SparseVector(size, used);
        System.arraycopy(keys, 0, y.keys, 0, used);
        for (int k = 0; k < used; k++) {
            y.values[k] = values[k] * scalar;
        }
        y.used = used;
        return y;
    }

    /**
     * Replaces every stored value {@code v} with {@code f.apply(v)}. Unstored elements
     * are left at zero, which is why {@code f(0) == 0} is asserted.
     */
    @Override
    public void apply(Function f) {
        assert f.apply(0) == 0; // assume zero is fixpoint
        for (int n = 0; n < used; n++) {
            values[n] = f.apply(values[n]);
        }
    }
}