package edu.stanford.nlp.stats;
import java.util.Random;
public class DirichletProcess<E> implements ProbabilityDistribution<E> {
/**
*
*/
private static final long serialVersionUID = -8653536087199951278L;
private final ProbabilityDistribution<E> baseMeasure;
private final double alpha;
private final ClassicCounter<E> sampled;
public DirichletProcess(ProbabilityDistribution<E> baseMeasure, double alpha) {
this.baseMeasure = baseMeasure;
this.alpha = alpha;
this.sampled = new ClassicCounter<>();
sampled.incrementCount(null, alpha);
}
public E drawSample(Random random) {
E drawn = Counters.sample(sampled);
if (drawn == null) {
drawn = baseMeasure.drawSample(random);
}
sampled.incrementCount(drawn);
return drawn;
}
public double numOccurances(E object) {
if (object == null) {
throw new RuntimeException("You cannot ask for the number of occurances of null.");
}
return sampled.getCount(object);
}
public double probabilityOf(E object) {
if (object == null) {
throw new RuntimeException("You cannot ask for the probability of null.");
}
if (sampled.keySet().contains(object)) {
return sampled.getCount(object) / sampled.totalCount();
} else {
return 0.0;
}
}
public double logProbabilityOf(E object) {
return Math.log(probabilityOf(object));
}
public double probabilityOfNewObject() {
return alpha / sampled.totalCount();
}
}