package dist;
import shared.Copyable;
import shared.DataSet;
import shared.Instance;
import util.ABAGAILArrays;
/**
 * A class implementing a look up table for
 * conditional probability, that is
 * representing P(O = o | I = i) for some single
 * discrete random variables O the output and
 * I the input. Row i of the table is the discrete
 * distribution over output values given input value i.
 * @author Andrew Guillory gtg008g@mail.gatech.edu
 * @version 1.0
 */
public class DiscreteDistributionTable extends AbstractConditionalDistribution implements Copyable {
    /**
     * The probability table: entry i is the output
     * distribution P(O | I = i) for input value i
     */
    private DiscreteDistribution[] discreteDistributions;

    /**
     * Make a new look up table conditional probability
     * from a matrix of probabilities.
     * @param probabilities the initial probabilities; row i is
     * passed to a new DiscreteDistribution for input value i
     */
    public DiscreteDistributionTable(double[][] probabilities) {
        discreteDistributions = new DiscreteDistribution[probabilities.length];
        for (int i = 0; i < discreteDistributions.length; i++) {
            discreteDistributions[i] = new DiscreteDistribution(probabilities[i]);
        }
    }

    /**
     * Make a new table from existing distributions.
     * The array is stored directly, not copied.
     * @param table the table to use; entry i is the output
     * distribution for input value i
     */
    public DiscreteDistributionTable(DiscreteDistribution[] table) {
        this.discreteDistributions = table;
    }

    /**
     * Get the output distribution for the input value of an instance.
     * @param i the instance whose discrete value selects the table row
     * @return the stored (not copied) distribution for that input value
     * @see dist.ConditionalDistribution#distributionFor(shared.Instance)
     */
    public Distribution distributionFor(Instance i) {
        return discreteDistributions[i.getDiscrete()];
    }

    /**
     * Re-estimate every row of the table from weighted observations.
     * Each observation's discrete value is the input, and its label's
     * discrete value is the output. Row i is smoothed with that row's
     * own prior and m value (an m-estimate):
     * P(j | i) = (w(i, j) + prior[j] * m) / (w(i) + m),
     * where w are summed observation weights.
     * @param observations the weighted observations to estimate from
     */
    public void estimate(DataSet observations) {
        double[][] probabilities = getProbabilityMatrix();
        double[] sums = new double[discreteDistributions.length];
        // Reset the counts before accumulating the new observations
        for (int i = 0; i < probabilities.length; i++) {
            for (int j = 0; j < probabilities[i].length; j++) {
                probabilities[i][j] = 0;
            }
        }
        // sums[input] is the total weight seen with that input;
        // probabilities[input][output] the weight seen with that pair
        for (int k = 0; k < observations.size(); k++) {
            Instance cur = observations.get(k);
            int input = cur.getDiscrete();
            int output = cur.getLabel().getDiscrete();
            sums[input] += cur.getWeight();
            probabilities[input][output] += cur.getWeight();
        }
        for (int i = 0; i < probabilities.length; i++) {
            double[] prior = discreteDistributions[i].getPrior();
            double m = discreteDistributions[i].getM();
            double denominator = sums[i] + m;
            for (int j = 0; j < probabilities[i].length; j++) {
                if (denominator == 0) {
                    // No evidence and no prior weight: 0/0 would yield NaN.
                    // Use the prior, which is the m-estimate's limit as m -> 0.
                    probabilities[i][j] = prior[j];
                } else {
                    // The prior is over output values, so it is indexed by j
                    probabilities[i][j] = (probabilities[i][j] + prior[j] * m) / denominator;
                }
            }
        }
        // Store explicitly so the estimate is kept even if
        // getProbabilities() hands out copies rather than live arrays
        setProbabilityMatrix(probabilities);
    }

    /**
     * Get the probability matrix.
     * Value [i][j] in the matrix is the
     * probability of observing output j given input i.
     * Each row is whatever DiscreteDistribution.getProbabilities()
     * returns; NOTE(review): presumably the live array, so writes
     * may alias the table. Confirm before relying on it.
     * @return the matrix
     */
    public double[][] getProbabilityMatrix() {
        double[][] matrix = new double[getInputRange()][];
        for (int i = 0; i < getInputRange(); i++) {
            matrix[i] = discreteDistributions[i].getProbabilities();
        }
        return matrix;
    }

    /**
     * Set the probability matrix, row by row.
     * @param matrix the matrix; row i becomes the probabilities
     * for input value i and must have at least getInputRange() rows
     */
    public void setProbabilityMatrix(double[][] matrix) {
        for (int i = 0; i < getInputRange(); i++) {
            discreteDistributions[i].setProbabilities(matrix[i]);
        }
    }

    /**
     * Get the discrete distributions.
     * @return the internal array of distributions (not a copy)
     */
    public DiscreteDistribution[] getDistributions() {
        return discreteDistributions;
    }

    /**
     * Set the distributions.
     * @param distributions the distributions, stored without copying
     */
    public void setDistributions(DiscreteDistribution[] distributions) {
        this.discreteDistributions = distributions;
    }

    /**
     * Get the input range.
     * @return the number of input values (rows in the table)
     */
    public int getInputRange() {
        return discreteDistributions.length;
    }

    /**
     * Get the output range, read from the first row.
     * Assumes every row has the same range; an empty table
     * throws ArrayIndexOutOfBoundsException.
     * @return the range
     */
    public int getOutputRange() {
        return discreteDistributions[0].getRange();
    }

    /**
     * @see java.lang.Object#toString()
     */
    public String toString() {
        return ABAGAILArrays.toString(getProbabilityMatrix());
    }

    /**
     * Make a uniform table.
     * @param inputRange the input range
     * @param outputRange the output range
     * @return the table, with every row uniform over outputRange values
     */
    public static DiscreteDistributionTable uniform(int inputRange, int outputRange) {
        DiscreteDistribution[] table = new DiscreteDistribution[inputRange];
        for (int i = 0; i < table.length; i++) {
            table[i] = DiscreteDistribution.uniform(outputRange);
        }
        return new DiscreteDistributionTable(table);
    }

    /**
     * Make a random table.
     * @param inputRange the input range
     * @param outputRange the output range
     * @return the table, with every row from DiscreteDistribution.random
     */
    public static DiscreteDistributionTable random(int inputRange, int outputRange) {
        DiscreteDistribution[] table = new DiscreteDistribution[inputRange];
        for (int i = 0; i < table.length; i++) {
            table[i] = DiscreteDistribution.random(outputRange);
        }
        return new DiscreteDistributionTable(table);
    }

    /**
     * Copy the table, copying each row distribution via its own copy().
     * @return the new table
     * @see shared.Copyable#copy()
     */
    public Copyable copy() {
        DiscreteDistribution[] copies = new DiscreteDistribution[discreteDistributions.length];
        for (int i = 0; i < copies.length; i++) {
            copies[i] = (DiscreteDistribution) discreteDistributions[i].copy();
        }
        DiscreteDistributionTable copy = new DiscreteDistributionTable(copies);
        return copy;
    }
}