/*
* File: ProbabilisticLatentSemanticAnalysisTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright April 02, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.text.topic;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.SparseVectorFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Random;
import org.junit.Test;
import static org.junit.Assert.*;
/**
* Unit tests for class ProbabilisticLatentSemanticAnalysis.
*
* @author Justin Basilico
* @since 3.0
*/
public class ProbabilisticLatentSemanticAnalysisTest
{

    /**
     * Creates a new test.
     */
    public ProbabilisticLatentSemanticAnalysisTest()
    {
    }

    /**
     * Test of constructors of class ProbabilisticLatentSemanticAnalysis.
     * Checks that the default constructor starts from the same defaults the
     * getter/setter tests below rely on, and that the Random constructor
     * keeps the generator it is given.
     */
    @Test
    public void testConstructors()
    {
        ProbabilisticLatentSemanticAnalysis instance =
            new ProbabilisticLatentSemanticAnalysis();
        assertNotNull(instance.getRandom());
        assertEquals(ProbabilisticLatentSemanticAnalysis.DEFAULT_REQUESTED_RANK,
            instance.getRequestedRank());
        assertEquals(ProbabilisticLatentSemanticAnalysis.DEFAULT_MINIMUM_CHANGE,
            instance.getMinimumChange(), 0.0);
        assertSame(VectorFactory.getDefault(), instance.getVectorFactory());
        assertSame(MatrixFactory.getDefault(), instance.getMatrixFactory());

        // testLearn depends on the seeded generator passed here being the one
        // the algorithm uses, so verify it is retained.
        final Random random = new Random(211);
        instance = new ProbabilisticLatentSemanticAnalysis(random);
        assertSame(random, instance.getRandom());
    }

    /**
     * Test of learn method, of class ProbabilisticLatentSemanticAnalysis.
     * Learns a rank-2 model on a small document set and checks that:
     * the result is retained by the learner, every document maps to latent
     * probabilities summing to one, an all-zero query maps to a vector
     * summing to zero, and uniformly scaling a query does not change its
     * transform.
     */
    @Test
    public void testLearn()
    {
        // A fixed seed keeps the learned model reproducible between runs.
        final Random random = new Random(211);

        // Each row is one document; each column is one term.
        double[][] data = new double[][] {
            { 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            { 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0 },
            { 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0 },
            { 1, 0, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0 },
            { 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0 },
            { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 },
            { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0 },
            { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 },
            { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1 }
        };
        // Note: See the comment at the end of the file for this data as an
        // Octave/MATLAB matrix (stored transposed, one row per term).

        // Tolerance for checking that a transform sums to an exact value.
        final double sumTolerance = 0.00001;
        // Tolerance for comparing transforms of uniformly scaled queries.
        final double scaleTolerance = 0.01;
        final NumberFormat format = new DecimalFormat("0.00");

        final ArrayList<Vector> documents = new ArrayList<Vector>(data.length);
        for (double[] d : data)
        {
            documents.add(VectorFactory.getDefault().copyArray(d));
        }

        ProbabilisticLatentSemanticAnalysis instance =
            new ProbabilisticLatentSemanticAnalysis(random);
        instance.addIterativeAlgorithmListener(
            new ProbabilisticLatentSemanticAnalysis.StatusPrinter());
        final int rank = 2;
        instance.setRequestedRank(rank);

        ProbabilisticLatentSemanticAnalysis.Result result =
            instance.learn(documents);
        assertSame(result, instance.getResult());
        assertEquals(rank, result.latents.length);

        System.out.println("Result:");
        for (ProbabilisticLatentSemanticAnalysis.LatentData latent : result.latents)
        {
            System.out.println(" Latent " + latent.index);
            System.out.println(" p(z) = " + format.format(latent.pLatent));
            System.out.println(" p(t|z) = " + latent.pTermGivenLatent.toString(format));
            System.out.println(" p(d|z) = " + latent.pDocumentGivenLatent.toString(format));
        }

        // Each training document should map to latent probabilities that
        // sum to one.
        System.out.println("Transforms:");
        for (Vector document : documents)
        {
            Vector transformed = result.evaluate(document);
            assertEquals(1.0, transformed.sum(), sumTolerance);
            System.out.println(" p(z|q) = " + transformed.toString(format));
        }

        // An all-zero query is expected to produce a transform summing to zero.
        Vector zeros = documents.get(0).clone();
        zeros.zero();
        Vector transformedZeros = result.evaluate(zeros);
        assertEquals(0.0, transformedZeros.sum(), sumTolerance);
        System.out.println(" p(z|0) = " + transformedZeros.toString(format));

        // Uniformly scaling a query should not change its transform.
        Vector ones = zeros.clone();
        Vector tens = zeros.clone();
        for (int i = 0; i < ones.getDimensionality(); i++)
        {
            ones.setElement(i, 1.0);
            tens.setElement(i, 10.0);
        }
        Vector transformed1s = result.evaluate(ones);
        System.out.println(" p(z|1) = " + transformed1s.toString(format));
        Vector transformed10s = result.evaluate(tens);
        System.out.println(" p(z|10) = " + transformed10s.toString(format));
        assertTrue(transformed1s.equals(transformed10s, scaleTolerance));
    }

    /**
     * Test of getResult method, of class ProbabilisticLatentSemanticAnalysis.
     */
    @Test
    public void testGetResult()
    {
        // Tested by testLearn.
    }

    /**
     * Test of getRandom method, of class ProbabilisticLatentSemanticAnalysis.
     */
    @Test
    public void testGetRandom()
    {
        this.testSetRandom();
    }

    /**
     * Test of setRandom method, of class ProbabilisticLatentSemanticAnalysis.
     * Checks that the default is non-null and that any value set, including
     * null, is returned by the getter.
     */
    @Test
    public void testSetRandom()
    {
        ProbabilisticLatentSemanticAnalysis instance =
            new ProbabilisticLatentSemanticAnalysis();
        assertNotNull(instance.getRandom());

        Random random = new Random();
        instance.setRandom(random);
        assertSame(random, instance.getRandom());

        random = new Random();
        instance.setRandom(random);
        assertSame(random, instance.getRandom());

        random = null;
        instance.setRandom(random);
        assertSame(random, instance.getRandom());
    }

    /**
     * Test of getVectorFactory method, of class ProbabilisticLatentSemanticAnalysis.
     * Also exercises setVectorFactory, including setting null.
     */
    @Test
    public void testGetVectorFactory()
    {
        VectorFactory<? extends Vector> vectorFactory = VectorFactory.getDefault();
        ProbabilisticLatentSemanticAnalysis instance =
            new ProbabilisticLatentSemanticAnalysis();
        assertSame(vectorFactory, instance.getVectorFactory());

        vectorFactory = SparseVectorFactory.getDefault();
        instance.setVectorFactory(vectorFactory);
        assertSame(vectorFactory, instance.getVectorFactory());

        vectorFactory = null;
        instance.setVectorFactory(vectorFactory);
        assertSame(vectorFactory, instance.getVectorFactory());
    }

    /**
     * Test of getMatrixFactory method, of class ProbabilisticLatentSemanticAnalysis.
     * Also exercises setMatrixFactory, including setting null.
     */
    @Test
    public void testGetMatrixFactory()
    {
        MatrixFactory<? extends Matrix> matrixFactory = MatrixFactory.getDefault();
        ProbabilisticLatentSemanticAnalysis instance =
            new ProbabilisticLatentSemanticAnalysis();
        assertSame(matrixFactory, instance.getMatrixFactory());

        matrixFactory = new SparseMatrixFactoryMTJ();
        instance.setMatrixFactory(matrixFactory);
        assertSame(matrixFactory, instance.getMatrixFactory());

        matrixFactory = null;
        instance.setMatrixFactory(matrixFactory);
        assertSame(matrixFactory, instance.getMatrixFactory());
    }

    /**
     * Test of getRequestedRank method, of class ProbabilisticLatentSemanticAnalysis.
     */
    @Test
    public void testGetRequestedRank()
    {
        this.testSetRequestedRank();
    }

    /**
     * Test of setRequestedRank method, of class ProbabilisticLatentSemanticAnalysis.
     * Checks the default, a valid update, and that a rank of zero is
     * rejected with an IllegalArgumentException without changing the value.
     */
    @Test
    public void testSetRequestedRank()
    {
        int requestedRank = ProbabilisticLatentSemanticAnalysis.DEFAULT_REQUESTED_RANK;
        ProbabilisticLatentSemanticAnalysis instance =
            new ProbabilisticLatentSemanticAnalysis();
        assertEquals(requestedRank, instance.getRequestedRank());

        requestedRank++;
        instance.setRequestedRank(requestedRank);
        assertEquals(requestedRank, instance.getRequestedRank());

        // The assertion is made after the try/catch (not in a finally block)
        // so that any other exception propagates with its own stack trace
        // instead of being masked by a failed assertion.
        boolean exceptionThrown = false;
        try
        {
            instance.setRequestedRank(0);
        }
        catch (IllegalArgumentException e)
        {
            exceptionThrown = true;
        }
        assertTrue("Expected IllegalArgumentException for rank 0",
            exceptionThrown);
        assertEquals(requestedRank, instance.getRequestedRank());
    }

    /**
     * Test of getMinimumChange method, of class ProbabilisticLatentSemanticAnalysis.
     */
    @Test
    public void testGetMinimumChange()
    {
        this.testSetMinimumChange();
    }

    /**
     * Test of setMinimumChange method, of class ProbabilisticLatentSemanticAnalysis.
     * Checks the default, that zero and positive values are accepted, and
     * that a negative value is rejected with an IllegalArgumentException
     * without changing the value.
     */
    @Test
    public void testSetMinimumChange()
    {
        double minimumChange = ProbabilisticLatentSemanticAnalysis.DEFAULT_MINIMUM_CHANGE;
        ProbabilisticLatentSemanticAnalysis instance =
            new ProbabilisticLatentSemanticAnalysis();
        assertEquals(minimumChange, instance.getMinimumChange(), 0.0);

        minimumChange = 0.0;
        instance.setMinimumChange(minimumChange);
        assertEquals(minimumChange, instance.getMinimumChange(), 0.0);

        minimumChange = 2.4;
        instance.setMinimumChange(minimumChange);
        assertEquals(minimumChange, instance.getMinimumChange(), 0.0);

        // The assertion is made after the try/catch (not in a finally block)
        // so that any other exception propagates with its own stack trace
        // instead of being masked by a failed assertion.
        boolean exceptionThrown = false;
        try
        {
            instance.setMinimumChange(-0.1);
        }
        catch (IllegalArgumentException e)
        {
            exceptionThrown = true;
        }
        assertTrue("Expected IllegalArgumentException for negative minimum change",
            exceptionThrown);
        assertEquals(minimumChange, instance.getMinimumChange(), 0.0);
    }

}
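/*
 * Octave/MATLAB form of the unit test data used by testLearn. D is the
 * 12-term by 9-document matrix: row i holds the count of term i in each of
 * the 9 documents. D is therefore the transpose of the Java data array,
 * which has one row per document.
D = [
1 0 0 1 0 0 0 0 0;
1 0 1 0 0 0 0 0 0;
1 1 0 0 0 0 0 0 0;
0 1 1 0 1 0 0 0 0;
0 1 1 2 0 0 0 0 0;
0 1 0 0 1 0 0 0 0;
0 1 0 0 1 0 0 0 0;
0 0 1 1 0 0 0 0 0;
0 1 0 0 0 0 0 0 1;
0 0 0 0 0 1 1 1 0;
0 0 0 0 0 0 1 1 1;
0 0 0 0 0 0 0 1 1];
*/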