/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering.lda.cvb;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixUtils;
import org.apache.mahout.math.function.DoubleFunction;
import org.junit.Test;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
/**
 * Tests for the CVB0 LDA trainers in this package, run against small synthetic corpora
 * built with {@link ClusteringTestUtils}.
 */
public class TestCVBModelTrainer extends MahoutTestCase {

  private static final double ETA = 0.1;
  private static final double ALPHA = 0.1;

  /**
   * Runs {@link InMemoryCollapsedVariationalBayes0} for every topic count from 1 up to (but not
   * including) twice the number of generating topics and prints the perplexity obtained for each.
   *
   * <p>This method contains no assertions: it only fails if training throws.
   */
  @Test
  public void testInMemoryCVB0() throws Exception {
    // One single-letter term per lowercase ASCII letter: "a" .. "z".
    String[] terms = new String[26];
    for (int i = 0; i < terms.length; i++) {
      terms[i] = String.valueOf((char) (i + 'a'));
    }

    int numGeneratingTopics = 3;
    // Derived from the dictionary so the two sizes cannot drift apart.
    int numTerms = terms.length;
    Matrix matrix = ClusteringTestUtils.randomStructuredModel(numGeneratingTopics, numTerms,
        new DoubleFunction() {
          @Override
          public double apply(double d) {
            return 1.0 / Math.pow(d + 1.0, 2);
          }
        });

    int numDocs = 100;
    int numSamples = 20;
    int numTopicsPerDoc = 1;
    Matrix sampledCorpus = ClusteringTestUtils.sampledCorpus(matrix, RandomUtils.getRandom(),
        numDocs, numSamples, numTopicsPerDoc);

    List<Double> perplexities = Lists.newArrayList();
    int numTrials = 1;
    for (int numTestTopics = 1; numTestTopics < 2 * numGeneratingTopics; numTestTopics++) {
      double[] perps = new double[numTrials];
      for (int trial = 0; trial < numTrials; trial++) {
        // Each trial gets its own seed, derived from the trial index.
        long seed = (trial + 1) * 123456L;
        InMemoryCollapsedVariationalBayes0 cvb =
            new InMemoryCollapsedVariationalBayes0(sampledCorpus, terms, numTestTopics, ALPHA, ETA,
                2, 1, 0, seed);
        cvb.setVerbose(true);
        perps[trial] = cvb.iterateUntilConvergence(0, 5, 0, 0.2);
        System.out.println(perps[trial]);
      }
      // After sorting ascending, perps[0] is the best (lowest) perplexity over all trials.
      Arrays.sort(perps);
      System.out.println(Arrays.toString(perps));
      perplexities.add(perps[0]);
    }
    System.out.println(Joiner.on(",").join(perplexities));
  }

  /**
   * Runs {@link CVB0Driver} on a synthetic corpus for several candidate topic counts around the
   * number of generating topics, and asserts that the candidate with the lowest perplexity is
   * exactly the number of generating topics.
   */
  @Test
  public void testRandomStructuredModelViaMR() throws Exception {
    int numGeneratingTopics = 3;
    int numTerms = 9;
    Matrix matrix = ClusteringTestUtils.randomStructuredModel(numGeneratingTopics, numTerms,
        new DoubleFunction() {
          @Override
          public double apply(double d) {
            return 1.0 / Math.pow(d + 1.0, 3);
          }
        });

    int numDocs = 500;
    int numSamples = 10;
    int numTopicsPerDoc = 1;
    Matrix sampledCorpus = ClusteringTestUtils.sampledCorpus(matrix, RandomUtils.getRandom(1234),
        numDocs, numSamples, numTopicsPerDoc);

    // Persist the corpus so the driver can read it from the test temp directory.
    Path sampleCorpusPath = getTestTempDirPath("corpus");
    MatrixUtils.write(sampleCorpusPath, new Configuration(), sampledCorpus);

    int numIterations = 5;
    List<Double> perplexities = Lists.newArrayList();
    // Candidates range from one below the generating count to one above it.
    int startTopic = numGeneratingTopics - 1;
    for (int numTestTopics = startTopic; numTestTopics < numGeneratingTopics + 2; numTestTopics++) {
      Path topicModelStateTempPath = getTestTempDirPath("topicTemp" + numTestTopics);
      Configuration conf = new Configuration();
      CVB0Driver.run(conf, sampleCorpusPath, null, numTestTopics, numTerms,
          ALPHA, ETA, numIterations, 1, 0, null, null, topicModelStateTempPath, 1234, 0.2f, 2,
          1, 3, 1, false);
      perplexities.add(lowestPerplexity(conf, topicModelStateTempPath));
    }

    // Index t in the list corresponds to a topic count of (t + startTopic).
    int bestTopic = -1;
    double lowestPerplexity = Double.MAX_VALUE;
    for (int t = 0; t < perplexities.size(); t++) {
      double perplexity = perplexities.get(t);
      if (perplexity < lowestPerplexity) {
        lowestPerplexity = perplexity;
        bestTopic = t + startTopic;
      }
    }

    // Print before asserting so the values are still visible when the assertion fails.
    System.out.println("Perplexities: " + Joiner.on(", ").join(perplexities));
    // JUnit order is (message, expected, actual).
    assertEquals("The optimal number of topics is not that of the generating distribution",
        numGeneratingTopics, bestTopic);
  }

  /**
   * Returns the smallest perplexity reported by {@link CVB0Driver#readPerplexity} for iterations
   * 2, 3, 4, ..., stopping at the first iteration for which it returns {@code NaN}.
   *
   * @param conf configuration passed through to {@code readPerplexity}
   * @param topicModelTemp topic model state directory passed through to {@code readPerplexity}
   * @return the minimum perplexity read, or {@link Double#MAX_VALUE} if the read for iteration 2
   *     is already {@code NaN}
   * @throws IOException if {@code readPerplexity} fails
   */
  private static double lowestPerplexity(Configuration conf, Path topicModelTemp)
      throws IOException {
    double lowest = Double.MAX_VALUE;
    // NOTE(review): scanning starts at iteration 2, not 1 - reason not visible here; confirm.
    int iteration = 2;
    double current = CVB0Driver.readPerplexity(conf, topicModelTemp, iteration);
    while (!Double.isNaN(current)) {
      lowest = Math.min(current, lowest);
      iteration++;
      current = CVB0Driver.readPerplexity(conf, topicModelTemp, iteration);
    }
    return lowest;
  }
}