/*
TagRecommender:
A framework to implement and evaluate algorithms for the recommendation
of tags.
Copyright (C) 2013 Dominik Kowald
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package processing;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.Timer;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.TimeUnit;
import com.google.common.base.Stopwatch;
import com.google.common.primitives.Ints;
import common.DoubleMapComparator;
import common.Bookmark;
import common.MemoryThread;
import common.PerformanceMeasurement;
import common.Utilities;
import cc.mallet.pipe.StringList2FeatureSequence;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import file.BookmarkWriter;
import file.PredictionFileWriter;
import file.BookmarkReader;
import file.BookmarkSplitter;
/**
 * Wraps a Mallet {@link ParallelTopicModel} (LDA) that is trained on a list of
 * tag-count maps. Every map of the list becomes one LDA document; the position
 * of the map in the list is the id that is later passed to
 * {@link #getValueProbsForID(int, boolean)}.
 * <p>
 * The static entry points {@link #predictSample} and {@link #createSample} run
 * a complete train/predict cycle on a bookmark file and write the results.
 * <p>
 * NOTE(review): {@code TOPIC_THRESHOLD} and {@code timeString} are mutable
 * static state that is written by the static entry points, so concurrent use
 * of those entry points is presumably not safe - confirm before parallelizing.
 */
public class MalletCalculator {

    /** Upper bound on the number of entries kept in ranked result maps. */
    private final static int MAX_RECOMMENDATIONS = 10;
    /** Passed as term limit to getMaxTermsByTopics, which currently ignores it. */
    private final static int MAX_TERMS = 100;
    private final static int NUM_ITERATIONS = 2000;
    private final static double ALPHA = 0.01;
    private final static double BETA = 0.01;
    /** Minimum topic probability for a topic to be kept; overwritten by createSample. */
    private static double TOPIC_THRESHOLD = 0.001;

    private final int numTopics;
    private final List<Map<Integer, Integer>> maps;
    private InstanceList instances;
    /** Per document: topic id -> probability; filled by predictValuesProbs. */
    private List<Map<Integer, Double>> docList;
    /** Per topic: tag id -> relative weight; null when trained in topic-creation mode. */
    private List<Map<Integer, Double>> topicList;
    private final Map<Integer, Double> mostPopularTopics;

    /**
     * Creates a calculator and converts the given maps into Mallet instances.
     *
     * @param maps      one map per document, mapping a tag id to its count
     * @param numTopics number of LDA topics to estimate
     */
    public MalletCalculator(List<Map<Integer, Integer>> maps, int numTopics) {
        this.numTopics = numTopics;
        this.maps = maps;
        this.mostPopularTopics = new LinkedHashMap<Integer, Double>();
        initializeDataStructures();
    }

    /**
     * Builds the Mallet instance list: every map becomes one document in which
     * each tag id (as a string) is repeated as often as its count says.
     */
    private void initializeDataStructures() {
        this.instances = new InstanceList(new StringList2FeatureSequence());
        for (Map<Integer, Integer> map : this.maps) {
            List<String> tags = new ArrayList<String>();
            for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
                String tag = entry.getKey().toString();
                int count = entry.getValue();
                for (int i = 0; i < count; i++) {
                    tags.add(tag);
                }
            }
            Instance inst = new Instance(tags, null, null, null);
            inst.setData(tags);
            this.instances.addThruPipe(inst);
        }
    }

    /**
     * Collects, for every document, the topics whose probability exceeds
     * {@code TOPIC_THRESHOLD}, and accumulates the summed probabilities of
     * those topics over all documents into {@code mostPopularTopics} (capped at
     * {@code MAX_RECOMMENDATIONS} entries).
     *
     * @param LDA             the estimated topic model
     * @param maxTopicsPerDoc only topic indices below this value are inspected
     *                        (it bounds the topic index, not the number of
     *                        topics that pass the threshold)
     * @return one map per document (topic id -> probability), stored in a
     *         TreeMap ordered by {@link DoubleMapComparator}
     *         (presumably descending by probability - confirm)
     */
    private List<Map<Integer, Double>> getMaxTopicsByDocs(ParallelTopicModel LDA, int maxTopicsPerDoc) {
        List<Map<Integer, Double>> docList = new ArrayList<Map<Integer, Double>>();
        Map<Integer, Double> unsortedMostPopularTopics = new LinkedHashMap<Integer, Double>();
        int numDocs = this.instances.size();
        for (int doc = 0; doc < numDocs; ++doc) {
            Map<Integer, Double> topicList = new LinkedHashMap<Integer, Double>();
            double[] topicProbs = LDA.getTopicProbabilities(doc);
            for (int topic = 0; topic < topicProbs.length && topic < maxTopicsPerDoc; topic++) {
                double topicProb = topicProbs[topic];
                if (topicProb > TOPIC_THRESHOLD) { // TODO
                    topicList.put(topic, topicProb);
                    Double oldTopicProb = unsortedMostPopularTopics.get(topic);
                    unsortedMostPopularTopics.put(topic,
                            oldTopicProb == null ? topicProb : oldTopicProb.doubleValue() + topicProb);
                }
            }
            Map<Integer, Double> sortedTopicList = new TreeMap<Integer, Double>(new DoubleMapComparator(topicList));
            sortedTopicList.putAll(topicList);
            docList.add(sortedTopicList);
        }

        Map<Integer, Double> sortedMostPopularTopics =
                new TreeMap<Integer, Double>(new DoubleMapComparator(unsortedMostPopularTopics));
        sortedMostPopularTopics.putAll(unsortedMostPopularTopics);
        for (Map.Entry<Integer, Double> entry : sortedMostPopularTopics.entrySet()) {
            // Once the cap is reached nothing more is added, so stop iterating.
            if (this.mostPopularTopics.size() >= MAX_RECOMMENDATIONS) {
                break;
            }
            this.mostPopularTopics.put(entry.getKey(), entry.getValue());
        }
        return docList;
    }

    /**
     * Collects, for every topic, all terms with a positive weight and
     * normalizes the weights so that they sum to one per topic.
     *
     * @param LDA   the estimated topic model
     * @param limit currently unused; all positively weighted terms are kept
     * @return one map per topic (tag id -> relative weight) in the iteration
     *         order of Mallet's sorted word sets
     */
    private List<Map<Integer, Double>> getMaxTermsByTopics(ParallelTopicModel LDA, int limit) {
        Alphabet alphabet = LDA.getAlphabet();
        List<Map<Integer, Double>> topicList = new ArrayList<Map<Integer, Double>>();
        int numTopics = LDA.getNumTopics();
        List<TreeSet<IDSorter>> sortedWords = LDA.getSortedWords();
        for (int topic = 0; topic < numTopics; ++topic) {
            Map<Integer, Double> termList = new LinkedHashMap<Integer, Double>();
            double weightSum = 0.0;
            for (IDSorter entry : sortedWords.get(topic)) {
                double weight = entry.getWeight();
                if (weight > 0.0) {
                    // The alphabet entries are the tag ids that were added as strings.
                    int tag = Integer.parseInt(alphabet.lookupObject(entry.getID()).toString());
                    termList.put(tag, weight);
                    weightSum += weight;
                }
            }
            // Convert absolute weights to relative values. termList is only
            // non-empty when weightSum is positive, so no division by zero.
            for (Map.Entry<Integer, Double> entry : termList.entrySet()) {
                entry.setValue(entry.getValue() / weightSum);
            }
            topicList.add(termList);
        }
        return topicList;
    }

    /**
     * Estimates the LDA model on the instances and caches the document-topic
     * distributions and, unless in topic-creation mode, the topic-term
     * distributions.
     *
     * @param topicCreation if true, the topic-term list is not computed and
     *                      {@link #getValueProbsForID} returns topics instead
     *                      of tags
     */
    public void predictValuesProbs(boolean topicCreation) {
        ParallelTopicModel LDA = new ParallelTopicModel(this.numTopics, ALPHA * this.numTopics, BETA); // TODO
        LDA.addInstances(this.instances);
        LDA.setNumThreads(1);
        LDA.setNumIterations(NUM_ITERATIONS);
        LDA.setRandomSeed(43);
        try {
            LDA.estimate();
        } catch (Exception e) {
            // NOTE(review): a failed estimation is only reported; the code below
            // still reads from the model. Kept as-is to preserve behavior.
            e.printStackTrace();
        }
        this.docList = getMaxTopicsByDocs(LDA, this.numTopics);
        System.out.println("Fetched Doc-List");
        this.topicList = !topicCreation ? getMaxTermsByTopics(LDA, MAX_TERMS) : null;
        System.out.println("Fetched Topic-List");
    }

    /**
     * Returns the predicted values for the document with the given id.
     * {@link #predictValuesProbs(boolean)} must have been called before.
     *
     * @param id            index of the document in the list passed to the
     *                      constructor
     * @param topicCreation legacy flag, only relevant when the topic-term list
     *                      exists: if true, each sufficiently probable topic
     *                      contributes only its first term with the topic
     *                      probability (deprecated path)
     * @return {@code null} if the id is outside the trained documents
     *         (negative or too large); the document's topic map if no
     *         topic-term list exists; otherwise a map tag id -> summed
     *         (term probability * topic probability)
     */
    public Map<Integer, Double> getValueProbsForID(int id, boolean topicCreation) {
        // Unknown ids (including negative ones) yield null instead of an
        // IndexOutOfBoundsException; callers already handle null.
        if (id < 0 || id >= this.docList.size()) {
            return null;
        }
        Map<Integer, Double> docVals = this.docList.get(id);
        if (this.topicList == null) {
            return docVals;
        }
        Map<Integer, Double> terms = new LinkedHashMap<Integer, Double>();
        for (Map.Entry<Integer, Double> topic : docVals.entrySet()) { // look at each assigned topic
            double topicProb = topic.getValue();
            Map<Integer, Double> topicTerms = this.topicList.get(topic.getKey());
            for (Map.Entry<Integer, Double> term : topicTerms.entrySet()) { // and its terms
                if (topicCreation) {
                    // DEPRECATED
                    if (topicProb > TOPIC_THRESHOLD) {
                        terms.put(term.getKey(), topicProb);
                        break; // only use first tag as topic-name with the topic probability
                    }
                } else {
                    double weighted = term.getValue() * topicProb;
                    Double val = terms.get(term.getKey());
                    terms.put(term.getKey(), val == null ? weighted : val + weighted);
                }
            }
        }
        return terms;
    }

    /**
     * @return the internal map of the globally most probable topics
     *         (topic id -> summed probability), at most
     *         {@code MAX_RECOMMENDATIONS} entries; not a copy
     */
    public Map<Integer, Double> getMostPopularTopics() {
        return this.mostPopularTopics;
    }

    // Statics -------------------------------------------------------------------------------------------------------------------------

    /** Accumulated performance report; threaded through PerformanceMeasurement. */
    private static String timeString;

    /**
     * Merges the user-based and resource-based predictions by summing the
     * values of equal keys.
     *
     * @param reader  currently unused
     * @param userMap user-based predictions, may be null
     * @param resMap  resource-based predictions, may be null
     * @param sorting if true, the merged map is ordered with
     *                {@link DoubleMapComparator} and cut to
     *                {@code MAX_RECOMMENDATIONS} entries
     * @return the merged (and optionally ranked and truncated) map, never null
     */
    private static Map<Integer, Double> getRankedTagList(BookmarkReader reader, Map<Integer, Double> userMap, Map<Integer, Double> resMap, boolean sorting) {
        Map<Integer, Double> resultMap = new LinkedHashMap<Integer, Double>();
        if (userMap != null) {
            for (Map.Entry<Integer, Double> entry : userMap.entrySet()) {
                resultMap.put(entry.getKey(), entry.getValue().doubleValue());
            }
        }
        if (resMap != null) {
            for (Map.Entry<Integer, Double> entry : resMap.entrySet()) {
                double resVal = entry.getValue().doubleValue();
                Double val = resultMap.get(entry.getKey());
                resultMap.put(entry.getKey(), val == null ? resVal : val.doubleValue() + resVal);
            }
        }
        if (!sorting) {
            return resultMap;
        }
        Map<Integer, Double> sortedResultMap = new TreeMap<Integer, Double>(new DoubleMapComparator(resultMap));
        sortedResultMap.putAll(resultMap);
        Map<Integer, Double> returnMap = new LinkedHashMap<Integer, Double>(MAX_RECOMMENDATIONS);
        for (Map.Entry<Integer, Double> entry : sortedResultMap.entrySet()) {
            if (returnMap.size() >= MAX_RECOMMENDATIONS) {
                break;
            }
            returnMap.put(entry.getKey(), entry.getValue());
        }
        return returnMap;
    }

    /**
     * Converts each prediction map into the array of its keys, keeping the
     * iteration order of the map.
     */
    private static List<int[]> toPredictionArrays(List<Map<Integer, Double>> ldaValues) {
        List<int[]> predictionValues = new ArrayList<int[]>(ldaValues.size());
        for (Map<Integer, Double> ldaVal : ldaValues) {
            predictionValues.add(Ints.toArray(ldaVal.keySet()));
        }
        return predictionValues;
    }

    /**
     * Trains the requested models on the first {@code size - sampleSize}
     * bookmarks and computes one prediction map per evaluated bookmark. The
     * training and test durations are appended to {@code timeString}.
     *
     * @param reader        source of the bookmarks
     * @param sampleSize    number of trailing bookmarks that form the test set
     * @param sorting       forwarded to getRankedTagList
     * @param numTopics     number of LDA topics
     * @param userBased     train and use a model on the user maps
     * @param resBased      train and use a model on the resource maps
     * @param topicCreation if true, predictions are produced for all bookmarks
     *                      (not only the test set) and resources without a
     *                      model entry fall back to the most popular topics
     * @return the prediction maps in bookmark order
     */
    private static List<Map<Integer, Double>> startLdaCreation(BookmarkReader reader, int sampleSize, boolean sorting, int numTopics, boolean userBased, boolean resBased, boolean topicCreation) {
        int size = reader.getBookmarks().size();
        int trainSize = size - sampleSize;
        Stopwatch timer = new Stopwatch();
        timer.start();

        MalletCalculator userCalc = null;
        if (userBased) {
            List<Map<Integer, Integer>> userMaps = Utilities.getUserMaps(reader.getBookmarks().subList(0, trainSize));
            userCalc = new MalletCalculator(userMaps, numTopics);
            userCalc.predictValuesProbs(topicCreation);
            System.out.println("User-Training finished");
        }
        MalletCalculator resCalc = null;
        if (resBased) {
            List<Map<Integer, Integer>> resMaps = Utilities.getResMaps(reader.getBookmarks().subList(0, trainSize));
            resCalc = new MalletCalculator(resMaps, numTopics);
            resCalc.predictValuesProbs(topicCreation);
            System.out.println("Res-Training finished");
        }

        List<Map<Integer, Double>> results = new ArrayList<Map<Integer, Double>>();
        if (topicCreation) {
            // In topic-creation mode every bookmark gets a prediction.
            trainSize = 0;
        }
        timer.stop();
        long trainingTime = timer.elapsed(TimeUnit.MILLISECONDS);

        timer.reset();
        timer.start();
        int mpCount = 0;
        for (int i = trainSize; i < size; i++) { // the test set
            Bookmark data = reader.getBookmarks().get(i);
            int userID = data.getUserID();
            int resID = data.getResourceID();

            Map<Integer, Double> userPredMap = null;
            if (userCalc != null) {
                userPredMap = userCalc.getValueProbsForID(userID, topicCreation);
            }
            Map<Integer, Double> resPredMap = null;
            if (resCalc != null) {
                resPredMap = resCalc.getValueProbsForID(resID, topicCreation);
                if (topicCreation && resPredMap == null) {
                    resPredMap = resCalc.getMostPopularTopics();
                    mpCount++;
                }
            }
            results.add(getRankedTagList(reader, userPredMap, resPredMap, sorting));
        }
        timer.stop();
        long testTime = timer.elapsed(TimeUnit.MILLISECONDS);

        timeString = PerformanceMeasurement.addTimeMeasurement(timeString, true, trainingTime, testTime, (topicCreation ? size : sampleSize));
        System.out.println("MpCount: " + mpCount);
        return results;
    }

    // public statics ---------------------------------------------------------------------------------------------------------------------------------

    /**
     * Reads the bookmark file, trains LDA, predicts tags for the test
     * bookmarks and writes the predictions to {@code filename + "_lda_" + numTopics}
     * as well as a timing/memory report below {@code ./data/metrics/}.
     * <p>
     * The memory-sampling timer is cancelled even if reading, training or
     * writing fails, so its thread cannot keep the JVM alive.
     *
     * @param filename   name of the bookmark file, also used as output prefix
     * @param trainSize  passed to the reader and used as start index of the
     *                   test lines
     * @param sampleSize number of trailing bookmarks to predict
     * @param numTopics  number of LDA topics
     * @param userBased  use the user-based model
     * @param resBased   use the resource-based model
     * @return the reader holding the bookmarks and the test lines
     */
    public static BookmarkReader predictSample(String filename, int trainSize, int sampleSize, int numTopics, boolean userBased, boolean resBased) {
        Timer timerThread = new Timer();
        BookmarkReader reader;
        try {
            MemoryThread memoryThread = new MemoryThread();
            timerThread.schedule(memoryThread, 0, MemoryThread.TIME_SPAN);

            reader = new BookmarkReader(trainSize, false);
            reader.readFile(filename);

            List<Map<Integer, Double>> ldaValues = startLdaCreation(reader, sampleSize, true, numTopics, userBased, resBased, false);
            List<int[]> predictionValues = toPredictionArrays(ldaValues);

            reader.setTestLines(reader.getBookmarks().subList(trainSize, reader.getBookmarks().size()));
            PredictionFileWriter writer = new PredictionFileWriter(reader, predictionValues);
            writer.writeFile(filename + "_lda_" + numTopics);

            timeString = PerformanceMeasurement.addMemoryMeasurement(timeString, false, memoryThread.getMaxMemory());
        } finally {
            timerThread.cancel();
        }
        Utilities.writeStringToFile("./data/metrics/" + filename + "_lda_" + numTopics + "_TIME.txt", timeString);
        return reader;
    }

    /**
     * Reads the bookmark file, trains a resource-based LDA model in
     * topic-creation mode and writes the complete sample plus a
     * {@code _train} and a {@code _test} split (cut at {@code trainSize}),
     * each annotated with the predicted topics, and a timing/memory report
     * below {@code ./data/metrics/}.
     * <p>
     * Side effect: sets the static {@code TOPIC_THRESHOLD} (0.001 if
     * {@code tagRec}, otherwise 0.01). The memory-sampling timer is cancelled
     * even if a step fails.
     *
     * @param filename                  name of the bookmark file, also used as
     *                                  output prefix
     * @param numTopics                 number of LDA topics
     * @param tagRec                    selects the topic threshold
     * @param trainSize                 index at which the written train/test
     *                                  split is cut
     * @param personalizedTopicCreation if true, only the first
     *                                  {@code trainSize} bookmarks are used for
     *                                  training; otherwise all bookmarks are
     */
    public static void createSample(String filename, short numTopics, boolean tagRec, int trainSize, boolean personalizedTopicCreation) {
        Timer timerThread = new Timer();
        try {
            MemoryThread memoryThread = new MemoryThread();
            timerThread.schedule(memoryThread, 0, MemoryThread.TIME_SPAN);

            String outputFile = filename + "_lda_" + numTopics;
            TOPIC_THRESHOLD = tagRec ? 0.001 : 0.01;

            int readerTrainSize = personalizedTopicCreation ? trainSize : 0;
            BookmarkReader reader = new BookmarkReader(readerTrainSize, false);
            reader.readFile(filename);
            int size = reader.getBookmarks().size();

            // Without personalization the whole file is the training set.
            int sampleSize = personalizedTopicCreation ? size - trainSize : 0;
            List<Map<Integer, Double>> ldaValues = startLdaCreation(reader, sampleSize, true, numTopics, false, true, true);
            List<int[]> predictionValues = toPredictionArrays(ldaValues);

            List<Bookmark> userSample = reader.getBookmarks().subList(0, size);
            BookmarkWriter.writeSample(reader, userSample, outputFile, predictionValues, false);

            List<Bookmark> trainUserSample = reader.getBookmarks().subList(0, trainSize);
            List<int[]> trainPredictionValues = predictionValues.subList(0, trainSize);
            List<Bookmark> testUserSample = reader.getBookmarks().subList(trainSize, size);
            List<int[]> testPredictionValues = predictionValues.subList(trainSize, size);
            BookmarkWriter.writeSample(reader, trainUserSample, outputFile + "_train", trainPredictionValues, false);
            BookmarkWriter.writeSample(reader, testUserSample, outputFile + "_test", testPredictionValues, false);

            timeString = PerformanceMeasurement.addMemoryMeasurement(timeString, false, memoryThread.getMaxMemory());
        } finally {
            timerThread.cancel();
        }
        Utilities.writeStringToFile("./data/metrics/" + filename + "_lda_creation_" + numTopics + "_TIME.txt", timeString);
    }
}