/*
* Copyright (C) 2016 RankSys http://ranksys.org
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
package org.ranksys.lda;
import cc.mallet.pipe.Noop;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import es.uam.eps.ir.ranksys.fast.preference.FastPreferenceData;
import java.io.IOException;
import java.util.Iterator;
import static java.util.stream.IntStream.range;
/**
* LDA model estimator. See ParallelTopicModel in Mallet (http://mallet.cs.umass.edu/) for more details.
*
* @author Saúl Vargas (Saul.Vargas@glasgow.ac.uk)
*/
public class LDAModelEstimator {
/**
 * Builds a {@link ParallelTopicModel} from collaborative filtering data and runs its estimation.
 * <p>
 * The preference data is wrapped in an instance list, handed to the model, and the model is
 * then configured and estimated using as many threads as the runtime reports available processors.
 *
 * @param <U> user type
 * @param <I> item type
 * @param preferences preference data used as training instances
 * @param k number of topics
 * @param alpha alpha in model; it is multiplied by {@code k} before being passed to the model constructor
 * @param beta beta in model
 * @param numIterations number of iterations
 * @param burninPeriod burnin period
 * @return the topic model after estimation
 * @throws IOException when internal IO error occurs
 */
public static <U, I> ParallelTopicModel estimate(FastPreferenceData<U, I> preferences, int k, double alpha, double beta, int numIterations, int burninPeriod) throws IOException {
    // The model constructor receives alpha scaled by the number of topics.
    final double scaledAlpha = alpha * k;
    ParallelTopicModel model = new ParallelTopicModel(k, scaledAlpha, beta);

    InstanceList trainingInstances = new LDAInstanceList<>(preferences);
    model.addInstances(trainingInstances);

    // The display interval is set past the last iteration
    // (presumably so topics are never printed during estimation -- confirm against Mallet docs).
    final int displayInterval = numIterations + 1;
    model.setTopicDisplay(displayInterval, 0);
    model.setNumIterations(numIterations);
    model.setBurninPeriod(burninPeriod);

    final int availableCores = Runtime.getRuntime().availableProcessors();
    model.setNumThreads(availableCores);

    model.estimate();
    return model;
}
/**
 * Alphabet whose reported size is the fixed number supplied at construction time.
 */
private static class LDAAlphabet extends Alphabet {

    /** Value returned by {@link #size()}. */
    private final int fixedSize;

    /**
     * Creates an alphabet reporting the given size.
     *
     * @param numItems number of items, reported as the size of this alphabet
     */
    public LDAAlphabet(int numItems) {
        fixedSize = numItems;
    }

    /**
     * Returns the size given to the constructor.
     *
     * @return the fixed size of this alphabet
     */
    @Override
    public int size() {
        return fixedSize;
    }
}
/**
 * Instance list backed by preference data: iterating it yields one instance per user index,
 * built on the fly from that user's preferences.
 *
 * @param <U> user type
 * @param <I> item type
 */
private static class LDAInstanceList<U, I> extends InstanceList {

    private final FastPreferenceData<U, I> preferences;
    private final Alphabet alphabet;

    /**
     * Wraps the given preference data.
     *
     * @param preferences preference data; its number of items determines the alphabet size
     */
    public LDAInstanceList(FastPreferenceData<U, I> preferences) {
        super(new Noop());
        this.alphabet = new LDAAlphabet(preferences.numItems());
        this.preferences = preferences;
    }

    /**
     * Creates one instance per user index. Each instance holds a feature sequence in which
     * {@code pref.v1} (presumably the item index -- confirm against the preference type) is added
     * {@code (int) pref.v2} times for every preference of that user.
     *
     * @return iterator over the generated instances
     */
    @Override
    public Iterator<Instance> iterator() {
        return preferences.getAllUidx()
                .mapToObj(uidx -> {
                    FeatureSequence tokens = new FeatureSequence(alphabet);
                    preferences.getUidxPreferences(uidx).forEach(pref -> {
                        // The preference value, truncated to int, is the repetition count.
                        int copies = (int) pref.v2;
                        range(0, copies).forEach(n -> tokens.add(pref.v1));
                    });
                    return new Instance(tokens, null, null, null);
                })
                .iterator();
    }

    /**
     * Returns the alphabet created in the constructor.
     *
     * @return the data alphabet of this instance list
     */
    @Override
    public Alphabet getDataAlphabet() {
        return alphabet;
    }
}
}