package edu.nd.nina.hdtm;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Queue;
import java.util.Set;
import java.util.TreeMap;
import edu.nd.nina.graph.DefaultEdge;
import edu.nd.nina.graph.DirectedFeatureGraph;
import edu.nd.nina.graph.load.LoadFromFeatureGraph;
import edu.nd.nina.graph.load.LoadFromHBase;
import edu.nd.nina.io.FeatureGraph;
import edu.nd.nina.math.Randoms;
import edu.nd.nina.types.Alphabet;
import edu.nd.nina.types.FeatureSequence;
import edu.nd.nina.types.Instance;
import edu.nd.nina.util.ValueSorter;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TObjectDoubleHashMap;
/**
*
* Hierarchical Document Topic Models
*
* Generates a hierarchy from a directed feature graph. See Weninger, Bisk, Han.
* "Hierarchical Document Topic Models" in CIKM 2012 for details.
*
* @author Tim Weninger 4/9/2013
*
*/
public class HierachicalDocTopicModel {
/** Writer to output hierarchy with best log likelihood */
private PrintWriter bestHierarchyWriter;
/** Writer for log likelihood traces */
private PrintWriter logLikelihoodTraceWriter;
/** Graph of the dataset. Vertices are Instances, Edges are empty */
private DirectedFeatureGraph<Instance, DefaultEdge> graph;
/**
* Root vertex in the directedFeatureGraph. Input to the
* HierarchicalDocTopicModel must be a rooted graph. Currently, the first
* node in the input file is designated to be the root
*/
private Instance root;
/** */
private RWRNode rootNode;
/** Number of "instances" (documents/vertices probably) in the graph. */
private int numInstances;
/** Number of "features" (words probably) in the graph data */
private int numTypes;
/** LDA-style smoothing on topic distributions */
private final double alpha;
/** LDA-style smoothing on word distributions */
private final double eta;
/** LDA-style sum of the eta values */
private double etaSum;
/** Restart probability for random walk with restart */
private final double gamma;
/** Level in the induced hierarchy indexed with <doc, token> */
private int[][] levels;
/**
* Array of nodes in the hierarchy. Parents constitute the currently
* selected path through the hierarchy
*/
private RWRNode[] hierarchyNodes;
/** Random number utility class */
private Randoms random;
/** Shows algorithm progress if true */
private boolean showProgress = true;
/** Interval between topic outputs */
private int displayTopicsInterval = 50;
/** Number of words to print during output */
private int numWordsToDisplay = 10;
/**
* Constructor creates empty DirectedFeatureGraph and sets parameters to
* default values
*/
public HierachicalDocTopicModel(double alpha, double gamma, double eta) {
graph = new DirectedFeatureGraph<Instance, DefaultEdge>(
DefaultEdge.class, LoadFromFeatureGraph.createPipe());
this.alpha = alpha;
this.gamma = gamma;
this.eta = eta;
}
/**
* Mutator for log likelihood trace writer
*
* @param logLikelihoodTraceWriter
* Writer to print log likelihood trace
*/
public void setLLTrace(PrintWriter logLikelihoodTraceWriter) {
this.logLikelihoodTraceWriter = logLikelihoodTraceWriter;
}
/**
* Mutator for bestgraph trace writer
*
* @param bestGraphWriter
* Writer to print best graph
*/
public void setBestGraph(PrintWriter bestGraphWriter) {
this.bestHierarchyWriter = bestGraphWriter;
}
/**
* Mutator for progress display parameters
*
* @param displayTopicsInterval
* Interval between topic outputs
* @param numWordsToDisplay
* K-top words to display when showing topics
*/
public void setTopicDisplay(int displayTopicsInterval, int numWordsToDisplay) {
this.displayTopicsInterval = displayTopicsInterval;
this.numWordsToDisplay = numWordsToDisplay;
}
/**
* This parameter determines whether the sampler outputs shows progress by
* outputting a character after every iteration.
*
* @param showProgress
* Show algorithm progress if true
*/
public void setProgressDisplay(boolean showProgress) {
this.showProgress = showProgress;
}
/**
*
* @param featureGraphFile
* @param random
*/
public void initialize(File featureGraphFile, Randoms random,
String rootName) {
this.random = random;
this.root = FeatureGraph.loadFeatureGraphFromFile(featureGraphFile,
graph, rootName);
rootNode = new RWRNode(numTypes, root);
// The initial hierarchy is a breadth first iteration of the original
// graph starting from the predefined root node.
Map<RWRNode, RWRNode> c = new HashMap<RWRNode, RWRNode>();
Map<Instance, RWRNode> m = new HashMap<Instance, RWRNode>();
Set<Instance> mark = new HashSet<Instance>();
Queue<RWRNode> Q = new LinkedList<RWRNode>();
c.put(rootNode, null);
m.put(rootNode.ins, rootNode);
mark.add(rootNode.ins);
Q.add(rootNode);
while (!Q.isEmpty()) {
RWRNode t = Q.poll();
for (DefaultEdge te : graph.outgoingEdgesOf(t.ins)) {
Instance o = graph.getEdgeTarget(te);
if (!mark.contains(o)) {
RWRNode x = t.addChild(o);
Q.add(x);
mark.add(o);
m.put(o, x);
c.put(x, t);
}
}
}
ArrayList<Instance> removes = new ArrayList<Instance>();
// remove non linked
for (Instance ins : graph.vertexSet()) {
RWRNode rwrp = m.get(ins);
if (rwrp == null) {
removes.add(ins);
continue;
}
}
for (Instance ins : removes) {
graph.removeVertex(ins);
}
graph.resolveInstances();
if (!(graph.getInstances().get(0).getData() instanceof FeatureSequence)) {
throw new IllegalArgumentException(
"Input must be a FeatureSequence");
}
numInstances = graph.getInstances().size();
numTypes = graph.getInstances().getDataAlphabet().size();
// initialize the typecount arrays
for (RWRNode rn : m.values()) {
rn.typeCounts = new int[numTypes];
}
// renormalize the sourcecounts
int i = 0;
for (Instance ins : graph.vertexSet()) {
ins.unLock();
ins.setSource(i++);
ins.lock();
}
etaSum = eta * numTypes;
levels = new int[numInstances][];
hierarchyNodes = new RWRNode[numInstances];
// Initialize and fill the topic pointer arrays for every document. Set
// everything to the breadth first hierarchy that we added earlier.
for (Instance ins : graph.vertexSet()) {
FeatureSequence fs = (FeatureSequence) ins.getData();
int seqLen = fs.getLength();
LinkedList<RWRNode> path = new LinkedList<RWRNode>();
RWRNode rwrp = m.get(ins);
if (rwrp == null) {
continue;
}
do {
rwrp.customers++;
path.addFirst(rwrp);
rwrp = c.get(rwrp);
} while (rwrp != null);
RWRNode node = path.getLast();
assert (levels[(Integer) ins.getSource()] != null);
levels[(Integer) ins.getSource()] = new int[seqLen];
hierarchyNodes[(Integer) ins.getSource()] = node;
for (int token = 0; token < seqLen; token++) {
int type = fs.getIndexAtPosition(token);
levels[(Integer) ins.getSource()][token] = random.next(path
.size());// numLevels);
node = path.get(levels[(Integer) ins.getSource()][token]);
node.totalTokens++;
node.typeCounts[type]++;
}
path.clear();
}
}
/**
* This is the Gibbs sampling control method. For numIterations
* Gibbs-iterations, first sample a path from the root to each
* instance/vertex and then redistribute the topic words for each
* instance/vertex w.r.t its parents
*
* @param numIterations
* Number of Gibbs-iterations
* @param burnin
* Period at the start for which no samples are taken
* @param sample
* After burnin period, this is the number of iterations between
* recording sampels
*/
public void estimate(int numIterations, int burnin, int sample) {
double best = Double.NEGATIVE_INFINITY;
for (int iteration = 1; iteration <= numIterations; iteration++) {
// If we have passed the burnin period, and we are on a sampleable
// iteration, then we should add the current Gibbs-sample to the
// list of samples and possibly trace the likelihood values
if (iteration >= burnin && iteration % sample == 0) {
double ll = calcLogLikelihood(getTypeCounts(), 0, rootNode, 0d);
if (ll > best) {
printNodes(iteration, ll);
best = ll;
}
double[] ret = getStats();
logLikelihoodTraceWriter.println("\n" + iteration + "\t" + ll
+ "\t" + ret[0] + "\t" + ret[1] + "\t" + ret[2]);
System.out.println("\n" + iteration + "\t" + ll + "\t" + ret[0]
+ "\t" + ret[1] + "\t" + ret[2]);
logLikelihoodTraceWriter.flush();
for (int doc = 0; doc < numInstances; doc++) {
recordParent(hierarchyNodes[doc]);
}
}
// First draw a path through the directedFeatureGraph for each
// vertex
for (Instance ins : graph.vertexSet()) {
samplePath(ins);
}
// Second redistribute the topic-words for each vertex and its
// parents
for (Instance ins : graph.vertexSet()) {
sampleTopics(ins);
}
// Print the algorithms progress
if (showProgress) {
System.out.print(".");
if (iteration % displayTopicsInterval == 0) {
System.out.println(" " + iteration);
}
}
}
}
/**
* Same a path from the root to the current instance/vertex by sampling from
* RWR probabilities
*
* @param ins
* Current instance/vertex to which a path is drawn
*/
private void samplePath(Instance ins) {
LinkedList<RWRNode> path = new LinkedList<RWRNode>();
int doc = (Integer) ins.getSource();
RWRNode node = hierarchyNodes[doc];
assert (node != null);
// root doesn't need sampled.
if (node.parent == null)
return;
int depth = node.level + 1;
hierarchyNodes[doc].dropPath();
TObjectDoubleHashMap<RWRNode> nodeWeights = new TObjectDoubleHashMap<RWRNode>();
// Calculate p(c_m | c_{-m})
calculateRWR(nodeWeights, rootNode, hierarchyNodes[doc], 0.0);
// Add weights for p(w_m | c, w_{-m}, z)
// The path may have no further customers and therefore be unavailable,
// but it should still exist since we haven't reset hierarchyNodes[doc]
// yet...
Map<Instance, Map<Integer, TIntIntHashMap>> descTypeCounts = new HashMap<Instance, Map<Integer, TIntIntHashMap>>();
// Save the counts of every word at each level, and remove counts from
// the current path
Set<Instance> desc = hierarchyNodes[doc].descendents;
for (Instance desIns : desc) {
int desDoc = (Integer) desIns.getSource();
node = hierarchyNodes[desDoc];
while (node != null) {
path.addFirst(node);
node = node.parent;
}
Map<Integer, TIntIntHashMap> typeCounts = new HashMap<Integer, TIntIntHashMap>();
descTypeCounts.put(desIns, typeCounts);
int[] docLevels = levels[desDoc];
FeatureSequence fs = (FeatureSequence) desIns.getData();
for (int token = 0; token < docLevels.length; token++) {
int level = docLevels[token];
int type = fs.getIndexAtPosition(token);
if (!typeCounts.containsKey(level)) {
typeCounts.put(level, new TIntIntHashMap());
}
if (!typeCounts.get(level).containsKey(type)) {
typeCounts.get(level).put(type, 1);
} else {
typeCounts.get(level).increment(type);
}
path.get(level).typeCounts[type]--;
assert (path.get(level).typeCounts[type] >= 0);
assert (path.get(level).totalTokens >= 0);
}
path.clear();
}
// Calculate the weight for a new path at a given level.
double[] newTopicWeights = new double[depth];
int[] levelTotalTokens = new int[depth];
for (Map<Integer, TIntIntHashMap> typeCounts : descTypeCounts.values()) {
// Skip the root...
for (int level = 1; level < typeCounts.size() && level < depth; level++) {
if (!typeCounts.containsKey(level)) {
continue;
}
int[] types = typeCounts.get(level).keys();
for (int t : types) {
for (int i = 0; i < typeCounts.get(level).get(t); i++) {
newTopicWeights[level] += Math.log((eta + i)
/ (etaSum + levelTotalTokens[level]));
levelTotalTokens[level]++;
}
}
}
}
// Reevaluate the nodeWeights based on the current topic/word
// distribution
calculateWordLikelihood(nodeWeights, rootNode, hierarchyNodes[doc].ins,
0.0, descTypeCounts, newTopicWeights, 0);
RWRNode[] nodes = nodeWeights.keys(new RWRNode[] {});
double[] parenthoodProbabilities = new double[nodes.length];
double sum = 0.0;
double max = Double.NEGATIVE_INFINITY;
// To avoid underflow, we're using log weights and normalizing the node
// weights so that the largest weight is always 1.
for (int i = 0; i < nodes.length; i++) {
if (nodeWeights.get(nodes[i]) > max) {
max = nodeWeights.get(nodes[i]);
}
}
for (int i = 0; i < nodes.length; i++) {
parenthoodProbabilities[i] = Math.exp(nodeWeights.get(nodes[i])
- max);
sum += parenthoodProbabilities[i];
}
assert (parenthoodProbabilities.length > 0);
// Draw a parent instance/vertex from the probability-set
RWRNode newParent = nodes[random.GetDiscrete(parenthoodProbabilities,
sum)];
// add the picked parent to the hierarchy
int oldLevel = hierarchyNodes[doc].level;
hierarchyNodes[doc].parent = newParent;
newParent.children.add(hierarchyNodes[doc]);
hierarchyNodes[doc].addPath();
// Reassign levels to descendants
propagateLevelsToDesc(hierarchyNodes[doc], newParent.level + 1);
int newLevel = hierarchyNodes[doc].level;
RWRNode x = hierarchyNodes[doc];
RWRNode[] newpath = new RWRNode[x.level + 1];
for (int i = x.level; i >= 0; i--) {
newpath[i] = x;
x = x.parent;
}
x = hierarchyNodes[doc];
for (Instance descIns : descTypeCounts.keySet()) {
int descDoc = (Integer) descIns.getSource();
Map<Integer, TIntIntHashMap> typeCounts = descTypeCounts
.get(descIns);
for (int level = oldLevel; level > newLevel; level--) {
// new path is shorter than old path... we add counts to the
// node
Set<Integer> a = typeCounts.keySet();
Integer[] b = new Integer[a.size()];
a.toArray(b);
for (int i : b) {
if (i > newLevel) {
int[] types = typeCounts.get(i).keys();
for (int t : types) {
if (!typeCounts.containsKey(newLevel)) {
typeCounts.put(newLevel, new TIntIntHashMap());
}
if (typeCounts.get(newLevel).containsKey(t)) {
typeCounts.get(newLevel).adjustValue(t,
typeCounts.get(i).get(t));
} else {
typeCounts.get(newLevel).put(t,
typeCounts.get(i).get(t));
}
}
typeCounts.remove(i);
}
}
for (int i = 0; i < levels[descDoc].length; i++) {
if (levels[descDoc][i] > newLevel) {
levels[descDoc][i] = newLevel;
}
}
}
x = hierarchyNodes[descDoc];
for (int level = x.level; level >= 0; level--) {
if (newLevel > oldLevel
+ (hierarchyNodes[descDoc].level - level)) {
// new path is longer than old path... we add counts to head
// of path
x = x.parent;
continue;
}
if (!typeCounts.containsKey(level)) {
x = x.parent;
continue;
}
int[] types = typeCounts.get(level).keys();
for (int i : typeCounts.keySet()) {
if (i > hierarchyNodes[descDoc].level) {
System.out.println();
}
}
for (int t : types) {
x.typeCounts[t] += typeCounts.get(level).get(t);
x.totalTokens += typeCounts.get(level).get(t);
}
x = x.parent;
}
}
desc = hierarchyNodes[doc].descendents;
for (Instance descIns : desc) {
int desDoc = (Integer) descIns.getSource();
path.clear();
node = hierarchyNodes[desDoc];
while (node != null) {
path.addFirst(node);
node = node.parent;
}
assert (path.getLast().level == path.size() - 1);
}
}
/**
*
* @param ins
*/
private void sampleTopics(Instance ins) {
FeatureSequence fs = (FeatureSequence) ins.getData();
int seqLen = fs.getLength();
int doc = (Integer) ins.getSource();
int[] docLevels = levels[doc];
LinkedList<RWRNode> path = new LinkedList<RWRNode>();
RWRNode node;
Map<Integer, Integer> levelCounts = new TreeMap<Integer, Integer>();
int type, token, level;
double sum;
// calculate the depth
node = hierarchyNodes[doc];
int depth = 0;
while (node != null) {
path.addFirst(node);
node = node.parent;
depth++;
}
// Get the node
node = hierarchyNodes[doc];
assert (node != null);
// Initialize levelCounts to 0
for (int i = 0; i < path.size(); i++) {
levelCounts.put(i, 0);
}
// Populate levelCounts
for (token = 0; token < seqLen; token++) {
int lev = docLevels[token];
levelCounts.put(lev, levelCounts.get(lev) + 1);
}
double[] levelProbabilities = new double[depth];
// Calculate probabilities for each words appearing at each level
for (token = 0; token < seqLen; token++) {
type = fs.getIndexAtPosition(token);
int lev = docLevels[token];
levelCounts.put(lev, levelCounts.get(lev) - 1);
node = path.get(lev);
node.typeCounts[type]--;
node.totalTokens--;
sum = 0.0;
for (level = 0; level < depth; level++) {
levelProbabilities[level] = (alpha + levelCounts.get(level))
* (eta + path.get(level).typeCounts[type])
/ (etaSum + path.get(level).totalTokens);
sum += levelProbabilities[level];
}
// Sample a level from the probability set
level = random.GetDiscrete(levelProbabilities, sum);
docLevels[token] = level;
levelCounts.put(docLevels[token],
levelCounts.get(docLevels[token]) + 1);
node = path.get(level);
node.typeCounts[type]++;
node.totalTokens++;
}
}
/**
* Adds the current sample's parent to the list of sampled parents.
*
* @param currentVertex
* Vertex to record
*/
private void recordParent(RWRNode currentVertex) {
if (currentVertex != null && currentVertex.parent != null) {
currentVertex.addParent(currentVertex.parent.ins);
}
}
/**
* Convenience method to calculate statistics when displaying algorithm
* progress
*
* @return array of size 3 containing descriptive statistics: [0] - maximum
* depth of the current hierarchy; [1] - average depth of the
* current hierarchy; [2] - average degree of the current hierarchy
*/
private double[] getStats() {
double[] ret = new double[3];
double maxDepth = 0;
double sumDepth = 0;
double sumDegree = 0;
for (Instance ins : graph.vertexSet()) {
int doc = (Integer) ins.getSource();
if (hierarchyNodes[doc] == null)
continue;
maxDepth = Math.max(maxDepth, hierarchyNodes[doc].level);
sumDepth += hierarchyNodes[doc].level;
sumDegree += hierarchyNodes[doc].children.size();
}
ret[0] = maxDepth;
ret[1] = sumDepth / (double) numInstances;
ret[2] = sumDegree / (double) numInstances;
return ret;
}
/**
* For each document, get the count of the "type" (word/token/feature) for
* each level in the hierarchy
*
* @return Map of instance/vertex id to type counts <vertex, <level, type>>
*/
private Map<Integer, TIntIntHashMap> getTypeCounts() {
Map<Integer, TIntIntHashMap> typeCounts = new TreeMap<Integer, TIntIntHashMap>();
int[] docLevels;
for (Instance ins : graph.vertexSet()) {
int doc = (Integer) ins.getSource();
docLevels = levels[doc];
if (docLevels == null)
continue;
FeatureSequence fs = (FeatureSequence) ins.getData();
// Save the counts of every word at each level
for (int token = 0; token < docLevels.length; token++) {
int level = docLevels[token];
int type = fs.getIndexAtPosition(token);
if (!typeCounts.containsKey(level))
typeCounts.put(level, new TIntIntHashMap());
if (!typeCounts.get(level).containsKey(type)) {
typeCounts.get(level).put(type, 1);
} else {
typeCounts.get(level).increment(type);
}
}
}
return typeCounts;
}
/**
* Recursively calculates the log likelihood of the current hierarchy
* (represented by typeCounts).
*
* @param typeCounts
* Data structure which stores the topical hierarchy
* @param level
* Level currently under consideration (initially 0 i.e., root
* level)
* @param node
* The current node (initially root)
* @param weight
* The current log probability (initially 0)
* @return Log likelihood (goodness of fit). Higher is better.
*/
private double calcLogLikelihood(Map<Integer, TIntIntHashMap> typeCounts,
int level, RWRNode node, double weight) {
// First calculate the likelihood of the words at this level, given this
// topic/level.
double nodeWeight = 0.0, ll = 0.0;
// recursive base case
if (typeCounts.get(level) == null)
return ll;
int[] types = typeCounts.get(level).keys();
int totalTokens = 0;
for (int type : types) {
for (int i = 0; i < typeCounts.get(level).get(type); i++) {
nodeWeight += Math.log((eta + node.typeCounts[type] + i)
/ (etaSum + node.totalTokens + totalTokens));
totalTokens++;
}
}
// Propagate that weight to the child nodes
for (RWRNode child : node.children) {
nodeWeight += calcLogLikelihood(typeCounts, level + 1, child,
weight + nodeWeight);
}
return nodeWeight;
}
/**
* Updates the levels of the children after reassigning its parent
*
* @param node
* Descendant Node to reassign levels
* @param level
* New level
*/
private void propagateLevelsToDesc(RWRNode node, int level) {
node.level = level;
for (RWRNode n : node.children) {
propagateLevelsToDesc(n, level + 1);
}
}
/**
* Recursively calculates the probability of selecting a parent based on the
* random walk probability of reaching the target w.r.t the gamma restart
* probability. Probability is 0 if there is no edge between currentParent
* -> target in the original directedFeatureGraph
*
* @param parentProbabilities
* Store the probabilities of each node being the target's parent
* @param parentCandidate
* The current parent node to have its probability calculated
* @param target
* The node for which paths are being sampled
* @param parentProbability
* Calculated RWR probability for currentParent to be the
* target's parent
*/
private void calculateRWR(
TObjectDoubleHashMap<RWRNode> parentProbabilities,
RWRNode parentCandidate, final RWRNode target,
double parentProbability) {
for (RWRNode child : parentCandidate.children) {
if (child.ins != target.ins) {
double w = parentProbability
+ Math.log((1 - gamma)
/ (double) parentCandidate.children.size());
calculateRWR(parentProbabilities, child, target, w);
}
}
// Probability is 0 if there is no edge between parentCandidate ->
// target in the original directedFeatureGraph. In fact, its not even
// stored in the result map
if (graph.containsEdge(parentCandidate.ins, target.ins)) {
parentProbabilities.put(parentCandidate,
parentProbability + Math.log(gamma));
}
}
/**
* Calculates the probabilities for each word appearing in each
* topic/document. Given a set of parent probabilities calculated by RWR
* only, we reevaluate the parenthood probability w.r.t the words and
* topics.
*
* @param parentProbabilities
* Store the probabilities of each node being the target's parent
* @param parentCandidate
* The current parent node to have its probability calculated
* @param target
* The node for which paths are being sampled
* @param parentProbability
* Calculated RWR probability for currentParent to be the
* target's parent
* @param descTypeCounts
* The counts of the words at each level for each
* instance/vertex; this is the topic distribution along the
* parent path.
* @param newTopicWeights
* Weights of the topics along the path from root to target
* @param level
* Current level in the hierarchy, i.e,
* specificity/generalizability of the topics
*/
private void calculateWordLikelihood(
TObjectDoubleHashMap<RWRNode> parentProbabilities,
RWRNode parentCandidate, Instance target, double parentProbability,
Map<Instance, Map<Integer, TIntIntHashMap>> descTypeCounts,
double[] newTopicWeights, int level) {
// First calculate the likelihood of the words at this level, given this
// topic.
double nodeWeight = 0.0;
for (Map<Integer, TIntIntHashMap> typeCounts : descTypeCounts.values()) {
if (!typeCounts.containsKey(level))
continue;
int[] types = typeCounts.get(level).keys();
int totalTokens = 0;
for (int type : types) {
for (int i = 0; i < typeCounts.get(level).get(type); i++) {
nodeWeight += Math
.log((eta + parentCandidate.typeCounts[type] + i)
/ (etaSum + parentCandidate.totalTokens + totalTokens));
totalTokens++;
}
}
}
// Propagate that weight to the child nodes
for (RWRNode child : parentCandidate.children) {
if (child.descendents.contains(target) && child.ins != target) {
calculateWordLikelihood(parentProbabilities, child, target,
parentProbability + nodeWeight, descTypeCounts,
newTopicWeights, level + 1);
}
}
// Finally, add the weight of a new path
level++;
while (level < newTopicWeights.length) {
nodeWeight += newTopicWeights[level];
level++;
}
if (graph.containsEdge(parentCandidate.ins, target)) {
assert (parentProbabilities.contains(parentCandidate));
parentProbabilities.adjustValue(parentCandidate, nodeWeight);
}
}
/**
* Write a text file describing the current sampling state.
*
* @param writer
* PrintWriter to which output is written
* @throws IOException
* Thrown if writer error occurs
*/
public void printState(PrintWriter writer) throws IOException {
Alphabet alphabet = graph.getInstances().getDataAlphabet();
int count = 0;
double sum = 0;
for (Instance ins : graph.getInstances()) {
int doc = (Integer) ins.getSource();
FeatureSequence fs = (FeatureSequence) ins.getData();
int seqLen = fs.getLength();
int[] docLevels = levels[doc];
RWRNode node;
int type, token, level;
StringBuffer path = new StringBuffer();
// Start with the leaf, and build a string describing the path for
// this doc
node = hierarchyNodes[doc];
if (node == null)
continue;
int depth = node.level;
for (level = depth - 1; level >= 0; level--) {
path.append(node.ins.getSource() + " ");
node = node.parent;
}
for (token = 0; token < seqLen; token++) {
type = fs.getIndexAtPosition(token);
level = docLevels[token];
count++;
sum += level;
// The "" just tells java we're not trying to add a string and
// an int
writer.println(path + "" + type + " "
+ alphabet.lookupObject(type) + " " + level + " ");
}
writer.println(doc + " " + (double) sum / (double) count + " "
+ ins.getName());
}
writer.flush();
}
/**
* Prints the current hierarchy to the bestGraphWriter
*
* @param iteration
* Current iteration
* @param logLikelihood
* Likelihood of the graph
*/
public void printNodes(int iteration, double logLikelihood) {
bestHierarchyWriter.println("Iteration: " + iteration);
bestHierarchyWriter.println("LL: " + logLikelihood);
printNode(rootNode, 0);
bestHierarchyWriter.println("*****************************");
bestHierarchyWriter.println();
bestHierarchyWriter.println();
bestHierarchyWriter.flush();
}
/**
* Recursive convenience method to help print hierarchy.
*
* @param node
* Current node to be printed (Initially root)
* @param indent
* Number of spaced to indent for pretty printing
*/
private void printNode(RWRNode node, int indent) {
StringBuffer out = new StringBuffer();
for (int i = 0; i < indent; i++) {
out.append(" ");
}
out.append(node.totalTokens + "/" + node.customers + " ");
out.append(node.ins.getName() + " ");
out.append(node.getTopTypes());
bestHierarchyWriter.println(out);
for (RWRNode child : node.children) {
printNode(child, indent + 1);
}
}
/**
* Prints the final results of the algorithm. The final hierarchy is defined
* by selecting the most frequently sampled parent for each vertex
*
* @param writer
* PrintWriter to which output is written
*/
public void printResults(PrintWriter writer) {
for (RWRNode n : hierarchyNodes) {
int total = 0;
if (n == null)
continue;
Instance best = n.ins;
for (Entry<Instance, Integer> e : n.parentList.entrySet()) {
total += e.getValue();
}
for (Entry<Instance, Integer> e : n.parentList.entrySet()) {
float weight = ((float) (e.getValue() * 100)) / (float) total;
writer.println(e.getKey().getName().toString() + " -> "
+ n.ins.getName().toString() + "[weight=\""
+ (int) weight + "\"]");
}
}
writer.println();
writer.close();
}
/**
*
* Private internal class which stores the frequently modified hierarchy
*
* @author Tim Weninger 4/9/2013
*
*/
private class RWRNode {
/** The vertex in the corresponding DirectedFeatureGraph */
Instance ins;
/** Children nodes in the hierarchy */
List<RWRNode> children;
/** Parent node in the hierarchy */
RWRNode parent;
/** Current level in the hierarchy, i.e., depth */
int level;
/** Set of descendants */
Set<Instance> descendents;
/** Number of descendants */
int customers;
/**
* Number of tokens/words in the topic corresponding to the current
* vertex/document
*/
int totalTokens;
/**
* Counts of terms appearing in the topic corresponding to the current
* vertex/document
*/
int[] typeCounts;
/**
* Count of sampled parents. Used to generate the final output hierarchy
*/
Map<Instance, Integer> parentList;
/**
* Constructor
*
* @param parent
* Parent RWRNode in hierarchy or null if root
* @param dimensions
* Number of types/words
* @param level
* Level/depth if the new RWRNode
* @param ins
* The corresponding instance/vertex in the
* DirectedFeatureGraph
*/
private RWRNode(RWRNode parent, int dimensions, int level, Instance ins) {
this.ins = ins;
this.customers = 0;
this.parent = parent;
this.children = new ArrayList<RWRNode>();
this.level = level;
this.descendents = new HashSet<Instance>();
this.totalTokens = 0;
this.typeCounts = new int[dimensions];
this.parentList = new HashMap<Instance, Integer>();
}
/**
* Constructor for root
*
* @param dimensions
* Number of types/words
* @param ins
* The corresponding instance/vertex in the
* DirectedFeatureGraph
*/
private RWRNode(int dimensions, Instance ins) {
this(null, dimensions, 0, ins);
}
@Override
public String toString() {
return ins.toString();
}
/**
* Add sampled parent to the list of sampled parents
*
* @param parent
* Newly sampled parent
*/
private void addParent(Instance parent) {
if (parentList.containsKey(parent)) {
parentList.put(parent, parentList.get(parent) + 1);
} else {
parentList.put(parent, 1);
}
}
/**
* Create a new RWRNode corresponding to the provided instance/vertex
*
* @param ins
* Instance/vertex to create RWRNode around
* @return new RWRNode surrounding the provided instance
*/
private RWRNode addChild(Instance ins) {
RWRNode node = new RWRNode(this, typeCounts.length, level + 1, ins);
children.add(node);
RWRNode p = node.parent;
while (p != null) {
p.descendents.add(ins);
p = p.parent;
}
return node;
}
/**
* Removes the provided node from the list of children
*
* @param node
* child node to be removed
*/
private void removeChild(RWRNode node) {
children.remove(node);
}
/**
* Remove the current RWRNode from the hierarchy and update the
* descendants
*/
private void dropPath() {
RWRNode node = this;
int descendents = node.customers;
node.parent.removeChild(node);
Set<Instance> desc = node.descendents;
desc.add(ins);
while (node.parent != null) {
node = node.parent;
node.descendents.removeAll(desc);
node.customers -= descendents;
if (node.customers == 0) {
node.parent.removeChild(node);
}
}
}
/**
* Add a new path to the current RWRNode through the hierarchy and
* update its descendants
*/
private void addPath() {
RWRNode node = this;
int descendents = node.customers;
Set<Instance> desc = node.descendents;
desc.add(ins);
while (node.parent != null) {
node = node.parent;
node.descendents.addAll(desc);
node.customers += descendents;
}
}
/**
* Get the top K most frequent types/words
*
* @return String with K most frequent types/words
*/
private String getTopTypes() {
ValueSorter[] sortedTypes = new ValueSorter[typeCounts.length];
for (int type = 0; type < typeCounts.length; type++) {
sortedTypes[type] = new ValueSorter(type, typeCounts[type]);
}
Arrays.sort(sortedTypes);
StringBuffer out = new StringBuffer();
for (int i = 0; i < numWordsToDisplay; i++) {
out.append(graph.getInstances().getAlphabet()
.lookupObject(sortedTypes[i].getID())
+ " ");
}
return out.toString();
}
}
public static void main(String[] args) {
// Gamma parameter: CRP smoothing parameter; number of imaginary
// customers at next, as yet unused table
Double[] gammas = {.25, .50, .75, .99};
for(Double gamma : gammas){
// The filename in which to write the Gibbs sampling state after at the
// end of the iterations
File dataFile = new File("./data/hdtm/Category_Agriculture.txt");
File outputFile = new File("./data/hdtm/Category_Agriculture_output_"+gamma+".txt");
File resultsFile = new File("./data/hdtm/Category_Agriculture_results_"+gamma+".txt");
File lltraceFile = new File("./data/hdtm/Category_Agriculture_lltrace_"+gamma+".txt");
File bestgraphFile = new File("./data/hdtm/Category_Agriculture_bestgraph_"+gamma+".txt");
// The random seed for the Gibbs sampler. Default is 0, which will use
// the clock.
Integer randomSeed = 1;
// The number of iterations of Gibbs sampling
Integer numIterations = 1000;
// If true, print a character to standard output after every sampling
// iteration.
Boolean showProgress = true;
// The number of iterations between printing a brief summary of the
// topics so far
Integer showTopicsInterval = 50;
// The number of most probable words to print for each topic after model
// estimation
Integer topWords = 5;
// Alpha parameter: smoothing over level distributions.
Double alpha = 10.0;
// Eta parameter: smoothing over topic-word distributions
Double eta = 0.1;
HierachicalDocTopicModel hlda = new HierachicalDocTopicModel(alpha,
gamma, eta);
try {
hlda.setLLTrace(new PrintWriter(lltraceFile));
hlda.setBestGraph(new PrintWriter(bestgraphFile));
} catch (IOException e) {
e.printStackTrace();
}
// Display preferences
hlda.setTopicDisplay(showTopicsInterval, topWords);
hlda.setProgressDisplay(showProgress);
// Initialize random number generator
Randoms random = null;
if (randomSeed == 1) {
random = new Randoms();
} else {
random = new Randoms(randomSeed);
}
// Initialize and start the sampler
hlda.initialize(
dataFile,
random,
"Agriculture");
hlda.estimate(numIterations, 500, 10);
// Output results
if (outputFile != null) {
try {
hlda.printState(new PrintWriter(outputFile));
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
}
try {
hlda.printResults(new PrintWriter(resultsFile));
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
}
}