package dist; import util.ABAGAILArrays; import util.graph.Edge; import util.graph.Node; import util.graph.Tree; import shared.DataSet; import shared.DataSetDescription; import shared.Instance; /** * A node in a discrete dependency tree * @author Andrew Guillory gtg008g@mail.gatech.edu * @version 1.0 */ public class DiscreteDependencyTreeNode extends Node { /** * The conditional probabilities */ private double[][] probabilities; /** * The parent */ private int parent; /** * Make a dependency tree node * @param ranges the ranges of the data * @param data the data itself * @param node the node * @param parent the parent node index * @param m the bayesian estimate parameter * @param t the tree */ public DiscreteDependencyTreeNode(DataSet dataSet, Node node, int parent, double m, Tree t) { DataSetDescription dsd = dataSet.getDescription(); double[][] probabilities = new double[dsd.getDiscreteRange(parent)][dsd.getDiscreteRange(node.getLabel())]; double[] sums = new double[dsd.getDiscreteRange(parent)]; for (int i = 0; i < dataSet.size(); i++) { probabilities[dataSet.get(i).getDiscrete(parent)] [dataSet.get(i).getDiscrete(node.getLabel())] += dataSet.get(i).getWeight(); sums[dataSet.get(i).getDiscrete(parent)] += dataSet.get(i).getWeight(); } for (int i = 0; i < probabilities.length; i++) { for (int j = 0; j < probabilities[i].length; j++) { probabilities[i][j] = (probabilities[i][j] + m / probabilities[i].length) / (sums[i] + m); } } this.probabilities = probabilities; this.parent = parent; t.addNode(this); setLabel(node.getLabel()); for (int i = 0; i < node.getEdgeCount(); i++) { DiscreteDependencyTreeNode dtc = new DiscreteDependencyTreeNode( dataSet, node.getEdge(i).getOther(node), node.getLabel(), m, t); connectDirected(dtc, new Edge()); } } /** * Calculate the probability * @param instance the instance * @return the probability */ public double probabilityOf(Instance sample) { DiscreteDistribution dd = new DiscreteDistribution( probabilities[sample.getDiscrete(parent)]); double p = dd.p(new Instance(sample.getDiscrete(getLabel()))); for (int i = 0; i < getEdgeCount(); i++) { DiscreteDependencyTreeNode dtn = (DiscreteDependencyTreeNode) getEdge(i).getOther(this); p *= dtn.probabilityOf(sample); } return p; } /** * Sample from the node * @param sample the sample so far */ public void generateRandom(Instance sample) { DiscreteDistribution dd = new DiscreteDistribution( probabilities[sample.getDiscrete(parent)]); sample.getData().set(getLabel(), dd.sample(null).getDiscrete()); for (int i = 0; i < getEdgeCount(); i++) { DiscreteDependencyTreeNode dtn = (DiscreteDependencyTreeNode) getEdge(i).getOther(this); dtn.generateRandom(sample); } } /** * Sample from the node * @param sample the sample so far */ public void generateMostLikely(Instance sample) { DiscreteDistribution dd = new DiscreteDistribution( probabilities[sample.getDiscrete(parent)]); sample.getData().set(getLabel(), dd.mode(null).getDiscrete()); for (int i = 0; i < getEdgeCount(); i++) { DiscreteDependencyTreeNode dtn = (DiscreteDependencyTreeNode) getEdge(i).getOther(this); dtn.generateRandom(sample); } } /** * @see java.lang.Object#toString() */ public String toString() { return super.toString() + " Parent = " + parent + "\n" + ABAGAILArrays.toString(probabilities); } }