package edu.usc.cssl.tacit.topicmodel.zlda.services;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Date;
import java.util.Random;

import org.eclipse.core.runtime.Platform;

import edu.usc.cssl.tacit.common.ui.views.ConsoleView;

/**
 * Z-label (topic-in-set knowledge) LDA fitted with collapsed Gibbs sampling.
 *
 * <p>Each word position of each document may carry a set of "seed" topics. While sampling, the
 * unnormalised weight of every topic that is NOT in a position's seed set is multiplied by
 * {@code (1 - eta)}. {@code alpha} holds one prior row per f-label ([F][T]) and {@code beta} holds
 * one prior row per topic ([T][W]).
 *
 * <p>Typical use: construct, call {@link #zLDA()}, then read {@link #getSample()},
 * {@link #getTheta()} and {@link #getPhi()}. Validation problems are logged to the console and
 * collected for {@link #writeReadMe(String)}.
 */
public class ZlabelLDA {

    /*
     * ToDo - when giving options for LDA - give option to use word indexes as in dictionary (This
     * is what it will most likely be), to use regular indexes and to give stopwords
     */

    /** Priors: alpha is [F][T], beta is [T][W]. ToDo - check all rows have the same size. */
    private double[][] alpha, beta;
    /** Confidence score: non-seed topics are scaled by (1 - eta). */
    private double eta;
    /** documents[d][i] is the vocabulary index of the i-th word of document d. */
    private int[][] documents;
    /** Per-document f-label; when null, validateInput() assigns 0 to every document. */
    private int[] fLabel = null; /* ToDo - figure out what to do for non-standard LDA */
    /** topicSeeds[d][i] is the (possibly null) set of seed topics for word i of document d. */
    private int[][][] topicSeeds;
    private int numsamp;
    private int numberOfDocuments;
    private int T; /* Number of Topics */
    private int W; /* Number of Words in the vocabulary */
    private int F; /* Number of f-labels (rows of alpha) */
    private double[] alphaSum, betaSum;
    /** Optional caller-supplied initial topic assignment; null means online initialisation. */
    private int[][] init;
    private Random random;
    private int[][] sample;
    private double[][] theta;
    private double[][] phi;
    private Counts counts;
    private Date date;
    private long prevTime;
    private long currentTime;

    /** Count caches maintained by the Gibbs sampler. */
    private static class Counts {
        /** nw[w][t]: number of occurrences of word w currently assigned to topic t. */
        final int[][] nw;
        /** nd[d][t]: number of words in document d currently assigned to topic t. */
        final int[][] nd;
        /** nwColSum[t]: total number of words currently assigned to topic t. */
        final int[] nwColSum;

        Counts(int w, int t, int d) {
            // New int arrays are zero-initialised by the JVM, so no explicit clearing is needed.
            nw = new int[w][t];
            nd = new int[d][t];
            nwColSum = new int[t];
        }
    }

    /* ToDo - also need phi, theta and sample - write those out */

    /** @return the final topic assignment, sample[d][i] for word i of document d */
    public int[][] getSample() {
        return sample;
    }

    /** @return the estimated document-topic distribution theta[d][t] = P(t|d) */
    public double[][] getTheta() {
        return theta;
    }

    /** @return the estimated topic-word distribution phi[t][w] = P(w|t) */
    public double[][] getPhi() {
        return phi;
    }

    /**
     * Creates a model over the given corpus. Nothing is computed until {@link #zLDA()} is called.
     *
     * @param docs       documents as arrays of vocabulary indices
     * @param zValues    per-document, per-position seed topic sets (inner arrays may be null)
     * @param eta        confidence score; non-seed topics are weighted by (1 - eta)
     * @param alpha      document-topic prior, one row per f-label
     * @param beta       topic-word prior, one row per topic
     * @param numsamp    number of Gibbs sweeps to run after initialisation
     * @param initSample optional initial topic assignment (same shape as docs), or null
     * @throws NullPointerException if docs, zValues, alpha or beta is null, or any document is null
     */
    public ZlabelLDA(int[][] docs, int[][][] zValues, double eta, double[][] alpha, double[][] beta,
            int numsamp, int[][] initSample) throws NullPointerException {
        if (docs == null || zValues == null || alpha == null || beta == null) {
            throw new NullPointerException();
        }
        documents = docs;
        topicSeeds = zValues;
        this.eta = eta;
        this.alpha = alpha;
        this.beta = beta;
        this.numsamp = numsamp;
        // Keep the caller's initial sample so zLDA() can seed the chain from it via givenInit().
        this.init = initSample;
        random = new Random(194582);
        numberOfDocuments = documents.length;
        sample = new int[numberOfDocuments][];
        for (int d = 0; d < numberOfDocuments; d++) {
            sample[d] = new int[documents[d].length];
        }
        date = new Date();
    }

    /** Same as the full constructor with no initial sample (online initialisation). */
    public ZlabelLDA(int[][] docs, int[][][] zValues, double eta, double[][] alpha, double[][] beta,
            int numsamp) throws NullPointerException {
        this(docs, zValues, eta, alpha, beta, numsamp, null);
    }

    /** @return a pseudo-random value in [0, 1) from the seeded generator */
    private double unif() {
        return random.nextFloat();
    }

    /**
     * Initialises the chain from the caller-supplied sample, building the count caches.
     *
     * @return false (after logging) if the sample's shape or values are invalid
     */
    private boolean givenInit() {
        if (init.length != documents.length) {
            ConsoleView.printlInConsoleln("Number of documents/number of init samples mismatch");
            return false;
        }
        counts = new Counts(W, T, documents.length);
        for (int d = 0; d < documents.length; d++) {
            int[] docInit = init[d];
            int[] doc = documents[d];
            int[] docSample = sample[d];
            if (docInit == null || docInit.length != doc.length) {
                ConsoleView.printlInConsoleln("Init sample/doc-length mismatch");
                return false;
            }
            for (int i = 0; i < doc.length; i++) {
                int zi = docInit[i];
                if (zi < 0 || zi >= T) {
                    ConsoleView.printlInConsoleln("Non-numeric or out of range sample value");
                    return false;
                }
                docSample[i] = zi;
                addAssignment(d, doc[i], zi);
            }
        }
        return true;
    }

    /**
     * Does an "online" init of the Gibbs chain, adding one word position at a time and sampling
     * a topic for each new position from the counts accumulated so far.
     */
    private void onlineInit() {
        counts = new Counts(W, T, documents.length);
        double[] weights = new double[T];
        for (int d = 0; d < documents.length; d++) {
            int[] doc = documents[d];
            int f = fLabel[d];
            int[][] docSeeds = topicSeeds[d];
            int[] docSample = sample[d];
            for (int i = 0; i < doc.length; i++) {
                int word = doc[i];
                double normSum = topicWeights(weights, d, f, word, docSeeds[i]);
                int z = multSample(weights, normSum);
                docSample[i] = z;
                addAssignment(d, word, z);
            }
        }
    }

    /** Runs one full Gibbs sweep, resampling the topic of every word position. */
    private void gibbsChain() {
        double[] weights = new double[T];
        for (int d = 0; d < documents.length; d++) {
            int[] doc = documents[d];
            int f = fLabel[d];
            int[][] docSeeds = topicSeeds[d];
            int[] docSample = sample[d];
            for (int i = 0; i < doc.length; i++) {
                int word = doc[i];
                // Take the current assignment out of the counts before resampling it.
                removeAssignment(d, word, docSample[i]);
                double normSum = topicWeights(weights, d, f, word, docSeeds[i]);
                int z = multSample(weights, normSum);
                docSample[i] = z;
                addAssignment(d, word, z);
            }
        }
    }

    /**
     * Fills {@code weights[j]} with the unnormalised conditional weight of topic j for one word
     * position and returns their sum. The alpha denominator is omitted because it is the same
     * for every topic.
     *
     * @param weights output buffer of length T
     * @param d       document index
     * @param f       f-label of the document (row of alpha)
     * @param word    vocabulary index of the word
     * @param seeds   seed topics for this position, or null for no constraint
     * @return the sum of all weights
     */
    private double topicWeights(double[] weights, int d, int f, int word, int[] seeds) {
        double normSum = 0;
        for (int j = 0; j < T; j++) {
            double weight = (counts.nw[word][j] + beta[j][word]) / (counts.nwColSum[j] + betaSum[j]);
            weight *= counts.nd[d][j] + alpha[f][j];
            // Multiplicative penalty for topics outside this position's seed set.
            if (seeds != null && !contains(seeds, j)) {
                weight *= 1 - eta;
            }
            weights[j] = weight;
            normSum += weight;
        }
        return normSum;
    }

    /** @return true if {@code target} occurs in {@code values} */
    private static boolean contains(int[] values, int target) {
        for (int v : values) {
            if (v == target) {
                return true;
            }
        }
        return false;
    }

    /** Records an assignment of {@code word} in document {@code d} to {@code topic}. */
    private void addAssignment(int d, int word, int topic) {
        counts.nw[word][topic]++;
        counts.nd[d][topic]++;
        counts.nwColSum[topic]++;
    }

    /** Removes an assignment of {@code word} in document {@code d} from {@code topic}. */
    private void removeAssignment(int d, int word, int topic) {
        counts.nw[word][topic]--;
        counts.nd[d][topic]--;
        counts.nwColSum[topic]--;
    }

    /** Uses the final sample to estimate phi = P(w|z). */
    private void estPhi() {
        phi = new double[T][W];
        for (int t = 0; t < T; t++) {
            double denom = betaSum[t] + counts.nwColSum[t];
            for (int w = 0; w < W; w++) {
                phi[t][w] = (beta[t][w] + counts.nw[w][t]) / denom;
            }
        }
    }

    /** Uses the final sample to estimate theta = P(z|d). */
    private void estTheta() {
        theta = new double[documents.length][T];
        for (int d = 0; d < documents.length; d++) {
            double rowSum = 0;
            for (int count : counts.nd[d]) {
                rowSum += count;
            }
            int f = fLabel[d];
            double denom = rowSum + alphaSum[f];
            for (int t = 0; t < T; t++) {
                theta[d][t] = (counts.nd[d][t] + alpha[f][t]) / denom;
            }
        }
    }

    /**
     * Draws an index from the unnormalised distribution {@code vals}.
     *
     * @param vals    non-negative weights (at least one element)
     * @param normSum the sum of {@code vals}
     * @return the sampled index, always within {@code [0, vals.length)}
     */
    private int multSample(double[] vals, double normSum) {
        double target = unif() * normSum;
        int idx = 0;
        double cumulative = vals[0];
        // The upper bound guards against rounding/NaN leaving the cumulative sum below target.
        while (cumulative < target && idx < vals.length - 1) {
            idx++;
            cumulative += vals[idx];
        }
        return idx;
    }

    /** Logs a validation failure (console and README log) and returns false. */
    private boolean fail(String message) {
        appendLog(message);
        return false;
    }

    /**
     * Checks that documents, f-labels, seeds and priors are mutually consistent, sets F, T and
     * W, and precomputes the row sums of alpha and beta.
     *
     * @return false (after logging the reason) if any input is invalid
     */
    private boolean validateInput() {
        /* If f-labels are not provided, every document gets f-label 0 */
        int fmax = 0;
        if (fLabel == null) {
            fLabel = new int[numberOfDocuments];
        } else if (fLabel.length != numberOfDocuments) {
            return fail("f-label array size does not match the number of documents");
        } else {
            for (int label : fLabel) {
                if (label < 0) {
                    return fail("Negative f-label - not valid input");
                }
                fmax = Math.max(fmax, label);
            }
        }

        /* There must be one seed array per document */
        if (topicSeeds.length != documents.length) {
            return fail("Topic Seeds array/ no. of documents size mismatch");
        }

        /* Read dimensions from the priors */
        if (alpha.length == 0 || beta.length == 0 || alpha[0] == null || beta[0] == null) {
            return fail("Invalid alpha or beta value");
        }
        F = alpha.length;
        T = alpha[0].length;
        W = beta[0].length;

        /* The largest f-label must index the last row of alpha */
        if (F - 1 != fmax) {
            return fail("Alpha/f dimensionality mismatch");
        }
        /* Beta needs one row per topic (this also guarantees T >= 1) */
        if (T != beta.length) {
            return fail("Beta size/no. of topics mismatch");
        }
        for (int i = 1; i < alpha.length; i++) {
            if (alpha[i] == null || alpha[i].length != T) {
                return fail("Alpha arrays do not have the same dimensionality");
            }
        }
        for (int i = 1; i < beta.length; i++) {
            if (beta[i] == null || beta[i].length != W) {
                return fail("Beta arrays do not have the same dimensionality");
            }
        }

        /* All prior values must be non-negative; !(x >= 0) also rejects NaN */
        for (double[] row : alpha) {
            for (double value : row) {
                if (!(value >= 0)) {
                    return fail("Invalid value in the alpha array");
                }
            }
        }
        for (double[] row : beta) {
            for (double value : row) {
                if (!(value >= 0)) {
                    return fail("Invalid value in the beta array");
                }
            }
        }

        /* Every document position needs a (possibly null) seed entry, each seed within [0, T) */
        for (int i = 0; i < topicSeeds.length; i++) {
            if (topicSeeds[i] == null || topicSeeds[i].length < documents[i].length) {
                return fail("Topic seeds/document length mismatch");
            }
            for (int j = 0; j < topicSeeds[i].length; j++) {
                if (topicSeeds[i][j] != null) {
                    for (int seed : topicSeeds[i][j]) {
                        if (seed < 0 || seed >= T) {
                            return fail("The topic seed value is invalid");
                        }
                    }
                }
            }
        }

        /* Every word index must lie within the vocabulary */
        for (int[] doc : documents) {
            for (int word : doc) {
                if (word < 0 || word >= W) {
                    return fail("The word value in document is invalid");
                }
            }
        }

        /* All input is valid: precompute the prior row sums */
        alphaSum = new double[alpha.length];
        betaSum = new double[beta.length];
        for (int i = 0; i < alpha.length; i++) {
            double sum = 0;
            for (double value : alpha[i]) {
                sum = sum + value;
            }
            alphaSum[i] = sum;
        }
        for (int i = 0; i < beta.length; i++) {
            double sum = 0;
            for (double value : beta[i]) {
                sum = sum + value;
            }
            betaSum[i] = sum;
        }
        return true;
    }

    /**
     * Validates the input, initialises the chain (from the supplied initial sample if one was
     * given, otherwise online), runs {@code numsamp} Gibbs sweeps and estimates phi and theta.
     *
     * @return true on success, false if the input or the initial sample is invalid
     */
    public boolean zLDA() {
        if (!validateInput()) {
            ConsoleView.printlInConsoleln("Invalid Input");
            return false;
        }
        prevTime = 0;
        currentTime = System.currentTimeMillis();
        if (init == null) {
            onlineInit();
        } else if (!givenInit()) {
            return false;
        }
        prevTime = currentTime;
        currentTime = System.currentTimeMillis();
        for (int si = 1; si <= numsamp; si++) {
            gibbsChain();
        }
        estPhi();
        estTheta();
        return true;
    }

    private StringBuilder readMe = new StringBuilder();

    /** Prints a message to the console and records it for the README. */
    private void appendLog(String message) {
        ConsoleView.printlInConsoleln(message);
        readMe.append(message + "\n");
    }

    /** @return the Bundle-Version header of the given bundle, or "unknown" if it is not found */
    private static String bundleVersion(String bundleId) {
        if (Platform.getBundle(bundleId) == null) {
            return "unknown";
        }
        return String.valueOf(Platform.getBundle(bundleId).getHeaders().get("Bundle-Version"));
    }

    /**
     * Writes {@code location/README.txt} with version/date information followed by the logged
     * messages. I/O errors are printed, not thrown; the writer is always closed.
     *
     * @param location output directory
     */
    public void writeReadMe(String location) {
        File readme = new File(location + "/README.txt");
        String plugV = bundleVersion("edu.usc.cssl.tacit.plugins.zlda");
        String appV = bundleVersion("edu.usc.cssl.tacit.application");
        BufferedWriter bw = null;
        try {
            bw = new BufferedWriter(new FileWriter(readme));
            Date date = new Date();
            bw.write("Zlabel LDA Output\n--------------------\n\nApplication Version: " + appV
                    + "\nPlugin Version: " + plugV + "\nDate: " + date.toString() + "\n\n");
            bw.write(readMe.toString());
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (bw != null) {
                try {
                    bw.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }
}