package dist;
import shared.DataSet;
import shared.DataSetDescription;
import shared.Instance;
import util.ABAGAILArrays;
import util.graph.Edge;
import util.graph.Node;
import util.graph.Tree;
/**
* A root node in a discrete dependency tree
* @author Andrew Guillory gtg008g@mail.gatech.edu
* @version 1.0
*/
public class DiscreteDependencyTreeRootNode extends Node {
/**
* The unconditional probabilities
*/
private double[] probabilities;
/**
* Build a dependency tree root out of the given node
* @param node the node to build the root out of
* @param ranges the ranges of the values
* @param data the data itself
*/
public DiscreteDependencyTreeRootNode(DataSet dataSet, Node node, double m, Tree t) {
DataSetDescription dsd = dataSet.getDescription();
probabilities = new double[dsd.getDiscreteRange(node.getLabel())];
double weightSum = 0;
for (int i = 0; i < dataSet.size(); i++) {
probabilities[dataSet.get(i).getDiscrete(node.getLabel())]
+= dataSet.get(i).getWeight();
weightSum += dataSet.get(i).getWeight();
}
for (int i = 0; i < probabilities.length; i++) {
probabilities[i] = (probabilities[i] + m / probabilities.length)
/ (weightSum + m);
}
t.addNode(this);
setLabel(node.getLabel());
for (int i = 0; i < node.getEdgeCount(); i++) {
DiscreteDependencyTreeNode dtn = new DiscreteDependencyTreeNode(dataSet,
node.getEdge(i).getOther(node), node.getLabel(), m, t);
connectDirected(dtn, new Edge());
}
}
/**
* Calculate the probability
* @param instance the instance
* @return the probability
*/
public double probabilityOf(Instance instance) {
DiscreteDistribution dd = new DiscreteDistribution(probabilities);
double p = dd.p(new Instance(instance.getDiscrete(getLabel())));
for (int i = 0; i < getEdgeCount(); i++) {
DiscreteDependencyTreeNode dtn = (DiscreteDependencyTreeNode) getEdge(i).getOther(this);
p *= dtn.probabilityOf(instance);
}
return p;
}
/**
* Sample from the root of the tree
* @param node the root of the tree
*/
public void generateRandom(Instance instance) {
DiscreteDistribution dd = new DiscreteDistribution(probabilities);
instance.getData().set(getLabel(), dd.sample(null).getDiscrete());
for (int i = 0; i < getEdgeCount(); i++) {
DiscreteDependencyTreeNode dtn = (DiscreteDependencyTreeNode) getEdge(i).getOther(this);
dtn.generateRandom(instance);
}
}
/**
* Sample from the root of the tree
* @param node the root of the tree
*/
public void generateMostLikely(Instance instance) {
DiscreteDistribution dd = new DiscreteDistribution(probabilities);
instance.getData().set(getLabel(), dd.mode(null).getDiscrete());
for (int i = 0; i < getEdgeCount(); i++) {
DiscreteDependencyTreeNode dtn = (DiscreteDependencyTreeNode) getEdge(i).getOther(this);
dtn.generateMostLikely(instance);
}
}
/**
* @see java.lang.Object#toString()
*/
public String toString() {
return super.toString() + "\n" + ABAGAILArrays.toString(probabilities);
}
}