/*
* Created on Mar 5, 2006
*/
package org.seqcode.math.probability;
import java.util.*;
import java.text.*;
/**
* @author tdanford
*/
public class FiniteDistribution {
private static Random rand;
private static NumberFormat nf;
static {
rand = new Random();
nf = DecimalFormat.getInstance();
nf.setMaximumFractionDigits(4);
}
private double[] vals;
public FiniteDistribution(int s) {
if(s <= 0) { throw new IllegalArgumentException("non-zero length req: " + s); }
vals = new double[s];
clear(1.0);
normalize();
}
public FiniteDistribution(Collection<Integer> w) {
vals = new double[w.size()];
int i = 0;
for(int v : w) {
vals[i++] = (double)v;
}
normalize();
}
public FiniteDistribution(double[] probs) {
vals = probs.clone();
normalize();
}
public FiniteDistribution(int s, int v) {
this(s);
clear(0.0);
vals[v] = 1.0;
}
public int size() { return vals.length; }
public double getProb(int i) { return vals[i]; }
public boolean equals(Object o) {
if(!(o instanceof FiniteDistribution)) {
return false;
}
FiniteDistribution fd = (FiniteDistribution)o;
if(vals.length != fd.vals.length) { return false; }
for(int i = 0; i < vals.length; i++) {
if(vals[i] != fd.vals[i]) { return false; }
}
return true;
}
public int hashCode() {
int code = 17;
for(int i = 0; i < vals.length; i++) {
long bits = Double.doubleToLongBits(vals[i]);
code += (int)(bits >> 32); code *= 37;
}
return code;
}
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("[");
for(int i = 0; i < vals.length; i++) {
sb.append(nf.format(vals[i]));
if(i < vals.length-1) { sb.append(" "); }
}
sb.append("]");
return sb.toString();
}
public int sampleValue() {
double p = rand.nextDouble();
for(int i = 0; i < vals.length; i++) {
p -= vals[i];
if(p <= 0.0) { return i; }
}
return vals.length-1;
}
public void clear() {
clear(0.0);
}
public void clear(double clearValue) {
for(int i = 0; i < vals.length; i++) {
vals[i] = clearValue;
}
}
public void addValue(Integer i) {
addValue(i, 1.0);
}
public void addValue(Integer i, Double w) {
vals[i] += w;
}
public void normalize() {
double sum = 0.0;
for(int i = 0; i < vals.length; i++) {
sum += vals[i];
}
for(int i = 0; i < vals.length; i++) {
if(sum > 0.0) {
vals[i] /= sum;
} else {
vals[i] = 1.0 / (double)vals.length;
}
}
}
}