package context.core.task.topicmodeling;

import cc.mallet.types.*;
import cc.mallet.pipe.*;
import cc.mallet.topics.*;

import java.util.*;
import java.util.regex.*;
import java.io.*;

import context.core.entity.FileData;
import context.core.util.JavaIO;

/**
 * Wraps MALLET's ParallelTopicModel: imports a corpus of FileData documents,
 * trains a topic model, and exposes the statistics needed to build the output
 * tables (topic top words, per-document topic distributions, per-word weights,
 * and log-likelihood per token).
 *
 * @author Aale
 */
public class MalletTopicModeling {

    public int numTopics;
    public int numWordsPerTopic;
    public int numIterations;
    public int numOptInterval;
    public double sumAlpha;
    public List<FileData> CorpusFiles;
    public String stopListPath;
    public Boolean isLowercase;

    public ParallelTopicModel model;
    public IDSorter[] sortedTopics;
    public Object[][] TopWords;
    public double[] Weight;
    public double[][] TopicDist;
    public Alphabet dataAlphabet;
    public int topicMask;
    public int topicBits;
    public int numTypes;
    public double beta;
    public int[][] typeTopicCounts;
    public int totalTokens;

    public MalletTopicModeling(int numTopics, int numWordsPerTopic, int numIterations,
            int numOptInterval, double sumAlpha, List<FileData> CorpusFiles,
            String stopListPath, Boolean isLowercase) {
        this.numTopics = numTopics;
        this.numWordsPerTopic = numWordsPerTopic;
        this.numIterations = numIterations;
        this.numOptInterval = numOptInterval;
        this.sumAlpha = sumAlpha;
        this.CorpusFiles = CorpusFiles;
        this.stopListPath = stopListPath;
        this.isLowercase = isLowercase;
        this.model = new ParallelTopicModel(this.numTopics, sumAlpha, 0.01);
        this.topicModeling();
    }

    /**
     * Formats the trained model as four strings:
     * [0] topic summaries (topic, weight, top words),
     * [1] per-document topic distributions,
     * [2] per-word weights within each topic,
     * [3] model log-likelihood per token.
     */
    public String[] topicModellingOutput() {
        String docProbs = "";
        String topicWords = "";
        String wordWeights = "";
        double llToken = 0;
        String llTokenStr = "";
        String[] allOuts = new String[4];

        // Output 1: topic summaries, ordered by topic weight
        for (int order = 0; order < numTopics; order++) {
            String tempString = "";
            String tempWordWeight = "";
            tempString = tempString.concat("Topic" + Integer.toString(order + 1) + ",");
            // Get the topic in the order of sortedTopics (which is sorted by weight)
            int topic = sortedTopics[order].getID();
            tempString = tempString.concat(Double.toString(Weight[topic]) + ",");
            // getTopWords may return fewer than numWordsPerTopic entries for a topic
            int wordsInTopic = Math.min(numWordsPerTopic, TopWords[topic].length);
            for (int word = 0; word < wordsInTopic; word++) {
                tempString = tempString.concat((String) TopWords[topic][word] + " - ");
            }
            tempString = tempString.concat("\n");

            // Output 3: weight of each top word within this topic
            for (int type = 0; type < numTypes; type++) {
                int[] topicCounts = typeTopicCounts[type];
                double wordWeight = beta;
                int index = 0;
                // typeTopicCounts packs (count << topicBits | topic) in each entry
                while (index < topicCounts.length && topicCounts[index] > 0) {
                    int currentTopic = topicCounts[index] & topicMask;
                    if (currentTopic == topic) {
                        wordWeight += topicCounts[index] >> topicBits;
                        break;
                    }
                    index++;
                }
                for (int word = 0; word < wordsInTopic; word++) {
                    // Compare word strings by value, not by reference
                    if (dataAlphabet.lookupObject(type).equals(TopWords[topic][word])) {
                        tempWordWeight = tempWordWeight.concat("Topic" + Integer.toString(order + 1) + ",");
                        tempWordWeight = tempWordWeight
                                .concat(TopWords[topic][word] + "," + Double.toString(wordWeight));
                        tempWordWeight = tempWordWeight.concat("\n");
                    }
                }
            }
            System.out.print(tempString);
            topicWords = topicWords.concat(tempString);
            wordWeights = wordWeights.concat(tempWordWeight);
        }
        System.out.print(wordWeights);

        // Output 2: per-document topic distributions, columns in sortedTopics order
        for (int doc = 0; doc < TopicDist.length; doc++) {
            String tempString = "";
            tempString = tempString.concat(CorpusFiles.get(doc).getFile().getName());
            for (int order = 0; order < TopicDist[doc].length; order++) {
                // Get the topic in the order of sortedTopics (which is sorted by weight)
                int topic = sortedTopics[order].getID();
                tempString = tempString.concat("," + Double.toString(TopicDist[doc][topic]));
            }
            tempString = tempString.concat("\n");
            System.out.print(tempString);
            docProbs = docProbs.concat(tempString);
        }

        // Output 4: model log-likelihood per token
        String tempLLToken = "";
        double likelihood = model.modelLogLikelihood();
        llToken = likelihood / (double) totalTokens;
        tempLLToken = Double.toString(llToken);
        llTokenStr = llTokenStr.concat(tempLLToken + "\n");
        System.out.println("LL Token: " + tempLLToken);

        allOuts[0] = topicWords;
        allOuts[1] = docProbs;
        allOuts[2] = wordWeights;
        allOuts[3] = llTokenStr;
        return allOuts;
    }

    /**
     * Imports the corpus into MALLET feature sequences, trains the topic model,
     * and caches the model statistics used by topicModellingOutput().
     */
    public void topicModeling() {
        // Begin by importing documents from text to feature sequences.
        // Pipes: tokenize, remove stopwords (if a stoplist exists), map tokens to features.
        // Lowercasing is applied to the raw text below when isLowercase is set,
        // instead of using a CharSequenceLowercase pipe.
        ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
        pipeList.add(new CharSequence2TokenSequence(Pattern
                .compile("\\p{L}[\\p{L}\\p{P}]+\\p{L}")));
        File stopList = new File(stopListPath);
        if (stopList.exists()) {
            pipeList.add(new TokenSequenceRemoveStopwords(stopList, "UTF-8",
                    false, false, false));
        }
        pipeList.add(new TokenSequence2FeatureSequence());

        InstanceList instances = new InstanceList(new SerialPipes(pipeList));

        int indx = 0;
        for (FileData file : CorpusFiles) {
            File filename = null;
            try {
                filename = file.getFile();
            } catch (Exception e) {
                e.printStackTrace();
            }
            indx++;
            try {
                if (isLowercase) {
                    final String fileContent = JavaIO.readFile(filename);
                    String fileContentLowerCase = fileContent.toLowerCase();
                    instances.addThruPipe(new Instance(fileContentLowerCase,
                            Integer.toString(indx), Integer.toString(indx),
                            Integer.toString(indx)));
                } else {
                    instances.addThruPipe(new Instance(JavaIO.readFile(filename),
                            Integer.toString(indx), Integer.toString(indx),
                            Integer.toString(indx)));
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        model.addInstances(instances);

        // Use two parallel samplers, which each look at one half of the corpus
        // and combine statistics after every iteration.
        model.setNumThreads(2);

        // Run the sampler for the requested number of iterations
        // (MALLET suggests 1000 to 2000 for real applications).
        model.setNumIterations(numIterations);

        // 2016.01.26 Added by Julian
        // Set the optimize interval, symmetric alpha, and burn-in period so that
        // alpha (i.e. weight) is optimized; defaults follow ParallelTopicModel.java.
        // The default burn-in period is 200, and the number of iterations must be
        // larger than the burn-in period or model.estimate() won't optimize alpha.
        // Default alpha = alphaSum / numTopics.
        model.setOptimizeInterval(numOptInterval);
        model.setSymmetricAlpha(false);
        model.setBurninPeriod(200);
        model.setRandomSeed(1337);
        // End of added defaults
        // model.setSaveState(model.numIterations, "./data/StateSave.txt");
        // model.setSaveSerializedModel(model.numIterations, "./data/ModelSave.txt");

        try {
            model.estimate();
        } catch (Exception e) {
            e.printStackTrace();
        }

        // The data alphabet maps word IDs to strings
        dataAlphabet = instances.getDataAlphabet();

        // Topic distribution for each document
        TopicDist = new double[model.data.size()][model.getNumTopics()];
        for (int indx1 = 0; indx1 < TopicDist.length; indx1++) {
            TopicDist[indx1] = model.getTopicProbabilities(indx1);
        }

        // Use weight in place of average topic fit.
        // See ParallelTopicModel.displayTopWords(): it uses alpha as the
        // second value in each topic's output (i.e. weight).
        Weight = model.alpha;
        TopWords = model.getTopWords(numWordsPerTopic);
        numTypes = model.numTypes;
        topicBits = model.topicBits;
        topicMask = model.topicMask;
        beta = model.beta;
        typeTopicCounts = model.typeTopicCounts;
        totalTokens = model.totalTokens;

        // Sort output topics by weight; IDSorter keeps the original topic ID
        // available after sorting.
        sortedTopics = new IDSorter[numTopics];
        for (int topic = 0; topic < numTopics; topic++) {
            sortedTopics[topic] = new IDSorter(topic, Weight[topic]);
        }
        Arrays.sort(sortedTopics);
    }
}
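
/*
 * Minimal usage sketch (not part of the original class): shows how
 * MalletTopicModeling might be driven once the corpus has been wrapped in
 * FileData objects elsewhere in the pipeline. FileData construction is not
 * shown in this file, so it is taken as given here; the class name
 * MalletTopicModelingDemo and the parameter values are illustrative
 * assumptions, not project conventions or recommended settings.
 */
class MalletTopicModelingDemo {

    static String[] run(List<FileData> corpus, String stopListPath) {
        // Training happens inside the constructor (it calls topicModeling()).
        MalletTopicModeling tm = new MalletTopicModeling(
                10,     // numTopics
                20,     // numWordsPerTopic
                1000,   // numIterations (must exceed the 200-iteration burn-in)
                10,     // numOptInterval: optimize hyperparameters every 10 iterations
                5.0,    // sumAlpha: initial alpha_k = sumAlpha / numTopics
                corpus,
                stopListPath,
                true);  // lowercase the raw text before it enters the pipes

        // Index 0: topic summaries, 1: per-document topic distributions,
        // 2: per-word weights, 3: log-likelihood per token.
        return tm.topicModellingOutput();
    }
}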