/*
 * File:                LatentDirichletAllocationVectorGibbsSamplerTest.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright October 22, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government. Export
 * of this program may require a license from the United States Government.
 * See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.text.topic;

import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vector;
import static gov.sandia.cognition.math.ProbabilityUtil.*;
import gov.sandia.cognition.util.DoubleReuseRandom;
import java.io.IOException;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Random;
import org.junit.Test;
import static org.junit.Assert.*;

/**
 * Unit tests for class LatentDirichletAllocationVectorGibbsSampler.
 *
 * @author  Justin Basilico
 * @since   3.1
 */
public class LatentDirichletAllocationVectorGibbsSamplerTest {

    /** Seeded random number generator shared by the tests in this class. */
    protected Random random = new Random(211);

    /**
     * Creates a new test.
     */
    public LatentDirichletAllocationVectorGibbsSamplerTest() {
    }

    /**
     * Builds the six-document, nine-term example corpus used by the learn
     * tests.
     *
     * @return  A new list of sparse term-count vectors.
     */
    private static ArrayList<Vector> createExampleData() {
        final VectorFactory<?> factory = VectorFactory.getSparseDefault();
        final ArrayList<Vector> data = new ArrayList<Vector>();
        data.add(factory.copyValues(0, 0, 4, 2, 5, 6, 0, 3, 0));
        data.add(factory.copyValues(0, 0, 0, 8, 0, 3, 0, 0, 0));
        data.add(factory.copyValues(4, 0, 6, 0, 0, 0, 3, 5, 0));
        data.add(factory.copyValues(1, 0, 0, 3, 2, 0, 3, 8, 0));
        data.add(factory.copyValues(3, 0, 5, 3, 0, 5, 6, 0, 0));
        data.add(factory.copyValues(0, 0, 0, 1, 3, 3, 3, 2, 0));
        return data;
    }

    /**
     * Runs the given sampler and checks the shape and contents of its result.
     *
     * Verifies that learning from null and from an empty list both return
     * null, then learns from the given data and checks the reported topic,
     * document, and term counts, the array dimensions, and that every entry
     * of both probability tables passes assertIsProbability.
     *
     * @param   instance
     *      The sampler to run.
     * @param   data
     *      The documents to learn from.
     * @param   topicCount
     *      The number of topics the sampler was configured with.
     * @param   termCount
     *      The dimensionality of the document vectors.
     */
    private static void assertLearnsValidModel(
        final LatentDirichletAllocationVectorGibbsSampler instance,
        final ArrayList<Vector> data,
        final int topicCount,
        final int termCount) {
        assertNull(instance.learn(null));
        assertNull(instance.learn(new ArrayList<Vector>()));

        final LatentDirichletAllocationVectorGibbsSampler.Result result =
            instance.learn(data);

        assertEquals(topicCount, result.getTopicCount());
        assertEquals(topicCount, result.topicTermProbabilities.length);
        assertEquals(data.size(), result.getDocumentCount());
        assertEquals(data.size(), result.documentTopicProbabilities.length);
        assertEquals(termCount, result.getTermCount());

        for (int topic = 0; topic < topicCount; topic++) {
            final double[] termProbabilities =
                result.topicTermProbabilities[topic];
            assertEquals(termCount, termProbabilities.length);
            for (int term = 0; term < termCount; term++) {
                assertIsProbability(termProbabilities[term]);
            }
        }

        for (int document = 0; document < data.size(); document++) {
            final double[] topicProbabilities =
                result.documentTopicProbabilities[document];
            assertEquals(topicCount, topicProbabilities.length);
            for (int topic = 0; topic < topicCount; topic++) {
                assertIsProbability(topicProbabilities[topic]);
            }
        }
    }

    /**
     * Test of constructors of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testConstructors() {
        int topicCount =
            LatentDirichletAllocationVectorGibbsSampler.DEFAULT_TOPIC_COUNT;
        double alpha =
            LatentDirichletAllocationVectorGibbsSampler.DEFAULT_ALPHA;
        double beta =
            LatentDirichletAllocationVectorGibbsSampler.DEFAULT_BETA;
        int maxIterations =
            LatentDirichletAllocationVectorGibbsSampler.DEFAULT_MAX_ITERATIONS;
        int burnInIterations =
            LatentDirichletAllocationVectorGibbsSampler.DEFAULT_BURN_IN_ITERATIONS;
        int iterationsPerSample =
            LatentDirichletAllocationVectorGibbsSampler.DEFAULT_ITERATIONS_PER_SAMPLE;

        // The default constructor should pick up every default value.
        LatentDirichletAllocationVectorGibbsSampler instance =
            new LatentDirichletAllocationVectorGibbsSampler();
        assertEquals(topicCount, instance.getTopicCount());
        assertEquals(alpha, instance.getAlpha(), 0.0);
        assertEquals(beta, instance.getBeta(), 0.0);
        assertEquals(maxIterations, instance.getMaxIterations());
        assertEquals(burnInIterations, instance.getBurnInIterations());
        assertEquals(iterationsPerSample, instance.getIterationsPerSample());
        assertNotNull(instance.getRandom());

        // The full constructor should store exactly what it is given.
        topicCount = 1 + random.nextInt(100);
        alpha = random.nextDouble() * 10.0;
        beta = random.nextDouble() * 10.0;
        maxIterations = 1 + random.nextInt(100000);
        burnInIterations = random.nextInt(1000);
        iterationsPerSample = random.nextInt(100);
        instance = new LatentDirichletAllocationVectorGibbsSampler(
            topicCount, alpha, beta, maxIterations, burnInIterations,
            iterationsPerSample, random);
        assertEquals(topicCount, instance.getTopicCount());
        assertEquals(alpha, instance.getAlpha(), 0.0);
        assertEquals(beta, instance.getBeta(), 0.0);
        assertEquals(maxIterations, instance.getMaxIterations());
        assertEquals(burnInIterations, instance.getBurnInIterations());
        assertEquals(iterationsPerSample, instance.getIterationsPerSample());
        assertSame(random, instance.getRandom());
    }

    /**
     * Test of learn method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testLearn()
        throws Exception {
        final ArrayList<Vector> data = createExampleData();
        final int termCount = 9;
        final int topicCount = 3;

        final LatentDirichletAllocationVectorGibbsSampler instance =
            new LatentDirichletAllocationVectorGibbsSampler(
                topicCount, 2.0, 0.5, 50, 20, 10, random);

        assertLearnsValidModel(instance, data, topicCount, termCount);
    }

    /**
     * A test that first shows a bug in the code as it was, and then makes
     * sure the issue doesn't come back.
     */
    @Test
    public void testVectorWithDoubles() {
        final double[] d1 = {0, 0.4, 1.1, 1.5, 2.0};
        final double[] d2 = {1.2, 1.5, 2.3, 0, 1.2};
        final double[] d3 = {0, 2.5, 3.3, 1.0, 1.2};

        final ArrayList<Vector> data = new ArrayList<Vector>();
        data.add(VectorFactory.getDenseDefault().copyArray(d1));
        data.add(VectorFactory.getDenseDefault().copyArray(d2));
        data.add(VectorFactory.getDenseDefault().copyArray(d3));

        final LatentDirichletAllocationVectorGibbsSampler lda =
            new LatentDirichletAllocationVectorGibbsSampler();

        // This used to throw an exception, but doesn't anymore
        // Refer to issue MachineLearning/FoundryLearning#9 for details
        lda.learn(data);
    }

    /**
     * Test of learn method, of class LatentDirichletAllocationVectorGibbsSampler,
     * when the sampler is driven by a DoubleReuseRandom instead of a plain
     * Random.
     */
    @Test
    public void testDoubleReuse()
        throws Exception {
        final ArrayList<Vector> data = createExampleData();
        final int termCount = 9;
        final int topicCount = 3;

        final Random randDoubleReuse = new DoubleReuseRandom(1000);
        final LatentDirichletAllocationVectorGibbsSampler instance =
            new LatentDirichletAllocationVectorGibbsSampler(
                topicCount, 2.0, 0.5, 50, 20, 10, randDoubleReuse);

        assertLearnsValidModel(instance, data, topicCount, termCount);
    }

    /**
     * Test of learn method, of class LatentDirichletAllocationVectorGibbsSampler,
     * on data that has fewer terms (two) than topics (three).
     */
    @Test
    public void testLearnSmallTermCount()
        throws Exception {
        final VectorFactory<?> factory = VectorFactory.getSparseDefault();
        final ArrayList<Vector> data = new ArrayList<Vector>();
        data.add(factory.copyValues(1, 0));
        data.add(factory.copyValues(1, 0));
        data.add(factory.copyValues(0, 1));
        data.add(factory.copyValues(0, 1));
        data.add(factory.copyValues(1, 1));
        data.add(factory.copyValues(1, 1));

        final int termCount = 2;
        final int topicCount = 3;

        final LatentDirichletAllocationVectorGibbsSampler instance =
            new LatentDirichletAllocationVectorGibbsSampler(
                topicCount, 2.0, 0.5, 50, 20, 10, random);

        assertLearnsValidModel(instance, data, topicCount, termCount);
    }

    /**
     * Reads a text file into a list with one entry per line, in file order.
     *
     * @param   fileName
     *      The name of the file to read.
     * @return
     *      The lines of the file.
     * @throws  IOException
     *      If the file cannot be opened or read.
     */
    public static ArrayList<String> readVocab(
        final String fileName)
        throws IOException {
        final ArrayList<String> result = new ArrayList<String>();
        final BufferedReader in = new BufferedReader(
            new FileReader(fileName));
        try {
            String line;
            while ((line = in.readLine()) != null) {
                result.add(line);
            }
        } finally {
            in.close();
        }

        return result;
    }

    /**
     * Reads a file of sparse documents into vectors, one vector per line.
     *
     * Each non-blank line is a whitespace-separated list of tokens. The first
     * token is an integer that is used as the entry capacity of the sparse
     * vector; every following token has the form {@code index:value}, with
     * both parts integers. Blank lines are skipped.
     *
     * @param   fileName
     *      The name of the file to read.
     * @param   dimensionality
     *      The dimensionality to give each vector.
     * @return
     *      One vector per non-blank line, in file order.
     * @throws  IOException
     *      If the file cannot be opened or read.
     */
    public static ArrayList<Vector> readDataAsVector(
        final String fileName,
        final int dimensionality)
        throws IOException {
        final ArrayList<Vector> result = new ArrayList<Vector>();
        final BufferedReader in = new BufferedReader(
            new FileReader(fileName));
        try {
            String line;
            while ((line = in.readLine()) != null) {
                // Tolerate blank lines and runs of whitespace, which would
                // otherwise produce empty tokens that fail to parse.
                final String trimmed = line.trim();
                if (trimmed.length() == 0) {
                    continue;
                }

                final String[] parts = trimmed.split("\\s+");
                final int elements = Integer.parseInt(parts[0]);
                final Vector v =
                    VectorFactory.getSparseDefault().createVectorCapacity(
                        dimensionality, elements);
                for (int i = 1; i < parts.length; i++) {
                    final String part = parts[i];
                    final int split = part.indexOf(':');
                    final int index =
                        Integer.parseInt(part.substring(0, split));
                    final int value =
                        Integer.parseInt(part.substring(split + 1));
                    v.setElement(index, value);
                }

                result.add(v);
            }
        } finally {
            in.close();
        }

        return result;
    }

    /**
     * Reads a file of sparse documents into arrays of term indices, one array
     * per line.
     *
     * Each non-blank line is a whitespace-separated list of tokens. The first
     * token is an integer that is only used as the initial capacity of a
     * temporary list; every following token has the form {@code index:count},
     * with both parts integers. The index is repeated {@code count} times in
     * the output array for that line. Blank lines are skipped.
     *
     * @param   fileName
     *      The name of the file to read.
     * @return
     *      One array of term indices per non-blank line, in file order.
     * @throws  IOException
     *      If the file cannot be opened or read.
     */
    public static int[][] readDataAsArray(
        final String fileName)
        throws IOException {
        final ArrayList<int[]> result = new ArrayList<int[]>();
        final BufferedReader in = new BufferedReader(
            new FileReader(fileName));
        try {
            String line;
            while ((line = in.readLine()) != null) {
                // Tolerate blank lines and runs of whitespace, which would
                // otherwise produce empty tokens that fail to parse.
                final String trimmed = line.trim();
                if (trimmed.length() == 0) {
                    continue;
                }

                final String[] parts = trimmed.split("\\s+");
                final int elements = Integer.parseInt(parts[0]);
                final ArrayList<Integer> wordsList =
                    new ArrayList<Integer>(elements);
                for (int i = 1; i < parts.length; i++) {
                    final String part = parts[i];
                    final int split = part.indexOf(':');
                    final int index =
                        Integer.parseInt(part.substring(0, split));
                    final int count =
                        Integer.parseInt(part.substring(split + 1));

                    // Expand the count into that many copies of the index.
                    for (int j = 0; j < count; j++) {
                        wordsList.add(index);
                    }
                }

                final int[] wordsArray = new int[wordsList.size()];
                for (int i = 0; i < wordsArray.length; i++) {
                    wordsArray[i] = wordsList.get(i);
                }
                result.add(wordsArray);
            }
        } finally {
            in.close();
        }

        return result.toArray(new int[result.size()][]);
    }

    /**
     * Test of getResult method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testGetResult() {
        final LatentDirichletAllocationVectorGibbsSampler instance =
            new LatentDirichletAllocationVectorGibbsSampler();
        assertNull(instance.getResult());
    }

    /**
     * Test of getRandom method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testGetRandom() {
        this.testSetRandom();
    }

    /**
     * Test of setRandom method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testSetRandom() {
        final LatentDirichletAllocationVectorGibbsSampler instance =
            new LatentDirichletAllocationVectorGibbsSampler();
        assertNotNull(instance.getRandom());

        // A distinct local name keeps this from shadowing the shared field.
        Random replacement = new Random();
        instance.setRandom(replacement);
        assertSame(replacement, instance.getRandom());

        replacement = new Random();
        instance.setRandom(replacement);
        assertSame(replacement, instance.getRandom());

        instance.setRandom(null);
        assertNull(instance.getRandom());

        replacement = new Random();
        instance.setRandom(replacement);
        assertSame(replacement, instance.getRandom());
    }

    /**
     * Test of getTopicCount method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testGetTopicCount() {
        this.testSetTopicCount();
    }

    /**
     * Test of setTopicCount method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testSetTopicCount() {
        int topicCount =
            LatentDirichletAllocationVectorGibbsSampler.DEFAULT_TOPIC_COUNT;
        final LatentDirichletAllocationVectorGibbsSampler instance =
            new LatentDirichletAllocationVectorGibbsSampler();
        assertEquals(topicCount, instance.getTopicCount());

        topicCount = 77;
        instance.setTopicCount(topicCount);
        assertEquals(topicCount, instance.getTopicCount());

        // Invalid values must be rejected and must leave the old value alone.
        // fail() sits inside the try so that an unexpected exception type
        // propagates as itself instead of being masked.
        try {
            instance.setTopicCount(0);
            fail("Expected IllegalArgumentException for topic count 0");
        } catch (IllegalArgumentException expected) {
            // Expected: zero topics is not allowed.
        }
        assertEquals(topicCount, instance.getTopicCount());

        try {
            instance.setTopicCount(-1);
            fail("Expected IllegalArgumentException for topic count -1");
        } catch (IllegalArgumentException expected) {
            // Expected: a negative topic count is not allowed.
        }
        assertEquals(topicCount, instance.getTopicCount());
    }

    /**
     * Test of getAlpha method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testGetAlpha() {
        this.testSetAlpha();
    }

    /**
     * Test of setAlpha method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testSetAlpha() {
        double alpha = LatentDirichletAllocationVectorGibbsSampler.DEFAULT_ALPHA;
        final LatentDirichletAllocationVectorGibbsSampler instance =
            new LatentDirichletAllocationVectorGibbsSampler();
        assertEquals(alpha, instance.getAlpha(), 0.0);

        alpha = 1.1;
        instance.setAlpha(alpha);
        assertEquals(alpha, instance.getAlpha(), 0.0);

        // Invalid values must be rejected and must leave the old value alone.
        try {
            instance.setAlpha(0.0);
            fail("Expected IllegalArgumentException for alpha 0.0");
        } catch (IllegalArgumentException expected) {
            // Expected: alpha of zero is not allowed.
        }
        assertEquals(alpha, instance.getAlpha(), 0.0);

        try {
            instance.setAlpha(-0.1);
            fail("Expected IllegalArgumentException for alpha -0.1");
        } catch (IllegalArgumentException expected) {
            // Expected: a negative alpha is not allowed.
        }
        assertEquals(alpha, instance.getAlpha(), 0.0);
    }

    /**
     * Test of getBeta method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testGetBeta() {
        this.testSetBeta();
    }

    /**
     * Test of setBeta method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testSetBeta() {
        double beta = LatentDirichletAllocationVectorGibbsSampler.DEFAULT_BETA;
        final LatentDirichletAllocationVectorGibbsSampler instance =
            new LatentDirichletAllocationVectorGibbsSampler();
        assertEquals(beta, instance.getBeta(), 0.0);

        beta = 1.1;
        instance.setBeta(beta);
        assertEquals(beta, instance.getBeta(), 0.0);

        // Invalid values must be rejected and must leave the old value alone.
        try {
            instance.setBeta(0.0);
            fail("Expected IllegalArgumentException for beta 0.0");
        } catch (IllegalArgumentException expected) {
            // Expected: beta of zero is not allowed.
        }
        assertEquals(beta, instance.getBeta(), 0.0);

        try {
            instance.setBeta(-0.1);
            fail("Expected IllegalArgumentException for beta -0.1");
        } catch (IllegalArgumentException expected) {
            // Expected: a negative beta is not allowed.
        }
        assertEquals(beta, instance.getBeta(), 0.0);
    }

    /**
     * Test of getBurnInIterations method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testGetBurnInIterations() {
        this.testSetBurnInIterations();
    }

    /**
     * Test of setBurnInIterations method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testSetBurnInIterations() {
        int burnInIterations =
            LatentDirichletAllocationVectorGibbsSampler.DEFAULT_BURN_IN_ITERATIONS;
        final LatentDirichletAllocationVectorGibbsSampler instance =
            new LatentDirichletAllocationVectorGibbsSampler();
        assertEquals(burnInIterations, instance.getBurnInIterations());

        burnInIterations = 0;
        instance.setBurnInIterations(burnInIterations);
        assertEquals(burnInIterations, instance.getBurnInIterations());

        burnInIterations = 101;
        instance.setBurnInIterations(burnInIterations);
        assertEquals(burnInIterations, instance.getBurnInIterations());

        // Invalid values must be rejected and must leave the old value alone.
        try {
            instance.setBurnInIterations(-1);
            fail("Expected IllegalArgumentException for burn-in iterations -1");
        } catch (IllegalArgumentException expected) {
            // Expected: a negative burn-in count is not allowed.
        }
        assertEquals(burnInIterations, instance.getBurnInIterations());
    }

    /**
     * Test of getIterationsPerSample method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testGetIterationsPerSample() {
        this.testSetIterationsPerSample();
    }

    /**
     * Test of setIterationsPerSample method, of class LatentDirichletAllocationVectorGibbsSampler.
     */
    @Test
    public void testSetIterationsPerSample() {
        int iterationsPerSample =
            LatentDirichletAllocationVectorGibbsSampler.DEFAULT_ITERATIONS_PER_SAMPLE;
        final LatentDirichletAllocationVectorGibbsSampler instance =
            new LatentDirichletAllocationVectorGibbsSampler();
        assertEquals(iterationsPerSample, instance.getIterationsPerSample());

        iterationsPerSample = 1;
        instance.setIterationsPerSample(iterationsPerSample);
        assertEquals(iterationsPerSample, instance.getIterationsPerSample());

        iterationsPerSample = 12;
        instance.setIterationsPerSample(iterationsPerSample);
        assertEquals(iterationsPerSample, instance.getIterationsPerSample());

        // Invalid values must be rejected and must leave the old value alone.
        try {
            instance.setIterationsPerSample(0);
            fail("Expected IllegalArgumentException for iterations per sample 0");
        } catch (IllegalArgumentException expected) {
            // Expected: zero iterations per sample is not allowed.
        }
        assertEquals(iterationsPerSample, instance.getIterationsPerSample());

        try {
            instance.setIterationsPerSample(-1);
            fail("Expected IllegalArgumentException for iterations per sample -1");
        } catch (IllegalArgumentException expected) {
            // Expected: a negative iterations-per-sample is not allowed.
        }
        assertEquals(iterationsPerSample, instance.getIterationsPerSample());
    }

}