package dist; import util.linalg.DenseVector; import util.graph.DFSTree; import util.graph.Graph; import util.graph.KruskalsMST; import util.graph.Node; import util.graph.Tree; import util.graph.WeightedEdge; import shared.DataSet; import shared.DataSetDescription; import shared.Instance; /** * A discrete dependency distribution * @author Andrew Guillory gtg008g@mail.gatech.edu * @version 1.0 */ public class DiscreteDependencyTree extends AbstractDistribution { /** * The dependency tree root */ private DiscreteDependencyTreeRootNode root; /** * The tree */ private Tree dt; /** * The m value */ private double m; /** * Description the data set */ private DataSetDescription description; /** * Make a new discrete dependency tree distribution * @param m the small positive value to add when making the tree */ public DiscreteDependencyTree(double m) { this.m = m; } /** * Make a new discrete dependency tree distribution * @param m the small positive value to add when making the tree */ public DiscreteDependencyTree(double m, int[] ranges) { this.m = m; description = new DataSetDescription(); description.setMinVector(new DenseVector(ranges.length)); DenseVector max = new DenseVector(ranges.length); for (int i = 0; i < max.size(); i++) { max.set(i, ranges[i] - 1); } description.setMaxVector(max); } /** * @see dist.Distribution#probabilityOf(shared.Instance) */ public double p(Instance i) { return root.probabilityOf(i); } /** * @see dist.Distribution#generateRandom(shared.Instance) */ public Instance sample(Instance ignored) { Instance i = new Instance(new DenseVector(dt.getNodeCount())); root.generateRandom(i); return i; } /** * @see dist.Distribution#generateMostLikely(shared.Instance) */ public Instance mode(Instance ignored) { Instance i = new Instance(new DenseVector(dt.getNodeCount())); root.generateMostLikely(i); return i; } /** * @see dist.Distribution#estimate(shared.DataSet) */ public void estimate(DataSet observations) { if (description != null) { observations.setDescription(description); } else if (observations.getDescription() == null) { observations.setDescription(new DataSetDescription(observations)); } double[][] mutualI = calculateMutualInformation(observations); // construct the graph Tree rg = buildDirectedMST(observations, mutualI); // make the dependency tree dt = new Tree(); root = new DiscreteDependencyTreeRootNode(observations, rg.getRoot(), m, dt); dt.setRoot(root); } /** * Build the directed mst from the mutual information * and ranges * @param ranges the ranges * @param mutualI the mutual information values * @return the directed mst */ private Tree buildDirectedMST(DataSet observations, double[][] mutualI) { Graph g = new Graph(); for (int i = 0; i < observations.get(0).size(); i++) { Node n = new Node(i); g.addNode(n); } for (int i = 0; i < observations.get(0).size(); i++) { for (int j = 0; j < i; j++) { Node a = g.getNode(i); Node b = g.getNode(j); a.connect(b, new WeightedEdge(-mutualI[i][j])); } } // find the mst g = new KruskalsMST().transform(g); // direct it Tree rg = (Tree) new DFSTree().transform(g); return rg; } /** * Calculate the mutual information from the data * @param ranges the ranges of the data * @param data the data itself * @return the mutual informations */ private double[][] calculateMutualInformation(DataSet observations) { DataSetDescription dsd = observations.getDescription(); // probs[i][j] is the probability that x_i = j double[][] probs = new double[observations.get(0).size()][]; for (int i = 0; i < probs.length; i++) { probs[i] = new double[dsd.getDiscreteRange(i)]; } double weightSum = 0; // fill in probs for (int i = 0; i < observations.size(); i++) { for (int j = 0; j < observations.get(i).size(); j++) { probs[j][observations.get(i).getDiscrete(j)] += observations.get(i).getWeight(); } weightSum += observations.get(i).getWeight(); } // normalize for (int i = 0; i < probs.length; i++) { for (int j = 0; j < probs[i].length; j++) { probs[i][j] /= weightSum; } } // calculate the entropies of the different variables double[] entropies = new double[observations.get(0).size()]; for (int i = 0; i < observations.get(0).size(); i++) { for (int j = 0; j < dsd.getDiscreteRange(i); j++) { if (probs[i][j] != 0) { entropies[i] -= probs[i][j] * Math.log(probs[i][j]); } } } // calculate the mutual information between all variables double[][] mutualI = new double[observations.get(0).size()][]; for (int i = 0; i < mutualI.length; i++) { mutualI[i] = new double[i]; for (int j = 0; j < i; j++) { // the joint probabilities // joints[a][b] is the probability that x_i = a && x_j = b double[][] joints = new double[dsd.getDiscreteRange(i)][dsd.getDiscreteRange(j)]; // fill in the joints for (int k = 0; k < observations.size(); k++) { Instance instance = observations.get(k); joints[instance.getDiscrete(i)][instance.getDiscrete(j)] ++; } // normalize for (int k = 0; k < joints.length; k++) { for (int l = 0; l < joints[k].length; l++) { joints[k][l] /= weightSum; } } // calculate the mutual information I(x_i; x_j) // add the entropy of x_i mutualI[i][j] += entropies[i]; // and the entropy of x_j mutualI[i][j] += entropies[j]; // subtract the joint entropy for (int k = 0; k < joints.length; k++) { for (int l = 0; l < joints[k].length; l++) { if (joints[k][l] != 0) { mutualI[i][j] += joints[k][l] * Math.log(joints[k][l]); } } } } } return mutualI; } /** * @see java.lang.Object#toString() */ public String toString() { return dt.toString(); } }