/*
* File: MapBasedPointMassDistributionTest.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Dec 3, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics.distribution;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.ProbabilityMassFunctionUtil;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import junit.framework.TestCase;
import java.util.Random;
/**
 * Unit tests for {@link DefaultDataDistribution} and its nested PMF and
 * WeightedEstimator classes. The test class keeps its historical name
 * (MapBasedPointMassDistributionTest).
 * <p>
 * Floating-point results are compared with an explicit tolerance. A
 * two-argument {@code assertEquals} on doubles boxes both values and
 * compares them with {@code Double.equals}, which is exact and therefore
 * brittle for accumulated sums.
 *
 * @author krdixon
 */
public class MapBasedPointMassDistributionTest
    extends TestCase
{

    /**
     * Random number generator to use for a fixed random seed. Deliberately an
     * instance field (not static) so that every test-case instance starts from
     * the same seed.
     */
    public final Random RANDOM = new Random( 1 );

    /**
     * Default tolerance of the regression tests, 1e-5.
     */
    public final double TOLERANCE = 1e-5;

    /**
     * Creates the test case.
     * @param testName Name of the test.
     */
    public MapBasedPointMassDistributionTest(
        String testName)
    {
        super(testName);
    }

    /**
     * Creates the fixture used by most tests: three keys "a", "b" and "c"
     * with masses 0.5, 2.0 and 2.5 (total 5.0).
     * @return
     * Instance.
     */
    DefaultDataDistribution<String> createInstance()
    {
        DefaultDataDistribution<String> f =
            new DefaultDataDistribution<String>();
        f.increment( "a", 0.5 );
        f.increment( "b", 2.0 );
        f.increment( "c", 2.5 );
        return f;
    }

    /**
     * Tests the constructors of class DefaultDataDistribution: both the
     * default and the int-argument constructor must yield an empty
     * distribution with zero total.
     */
    public void testConstructors()
    {
        System.out.println( "Constructors" );

        DefaultDataDistribution<String> instance =
            new DefaultDataDistribution<String>();
        assertEquals( 0, instance.getDomain().size() );
        assertEquals( 0.0, instance.getTotal(), TOLERANCE );

        instance = new DefaultDataDistribution<String>(2);
        assertEquals( 0, instance.getDomain().size() );
        assertEquals( 0.0, instance.getTotal(), TOLERANCE );
    }

    /**
     * Test of clone method, of class DefaultDataDistribution. The clone must
     * be a distinct object with a distinct backing map but the same entries.
     */
    public void testClone()
    {
        System.out.println("clone");

        DefaultDataDistribution<String> f = this.createInstance();
        @SuppressWarnings("unchecked")
        DefaultDataDistribution<String> clone = f.clone();
        assertNotNull( clone );
        assertNotSame( clone, f );
        assertEquals( f.getTotal(), clone.getTotal(), TOLERANCE );
        assertNotSame( f.asMap(), clone.asMap() );
        assertEquals( f.getDomain().size(), clone.asMap().size() );
        for( String value : f.getDomain() )
        {
            assertEquals( f.get(value), clone.get(value), TOLERANCE );
        }
    }

    /**
     * Test of sample method, of class DefaultDataDistribution. Only the
     * number of returned samples is checked.
     */
    public void testSample()
    {
        System.out.println("sample");

        int numSamples = 10;
        DefaultDataDistribution<String> f = this.createInstance();
        Collection<String> samples = f.sample( RANDOM, numSamples );
        assertEquals( numSamples, samples.size() );
    }

    /**
     * Test of the single-argument increment method, of class
     * DefaultDataDistribution.
     */
    public void testAdd()
    {
        System.out.println("add");

        DefaultDataDistribution<String> f = this.createInstance();
        // Total of the fixture: 0.5 + 2.0 + 2.5.
        double tm = 5.0;
        for( String value : f.getDomain() )
        {
            double w1 = f.get(value);
            f.increment( value );
            assertEquals( 1.0+w1, f.get(value), TOLERANCE );
            tm += 1.0;
            assertEquals( tm, f.getTotal(), TOLERANCE );
        }

        // Incrementing an unseen key adds it to the domain with mass 1.0.
        int n0 = f.getDomain().size();
        String z = "z";
        assertEquals( 0.0, f.get(z), TOLERANCE );
        assertEquals( tm, f.getTotal(), TOLERANCE );
        f.increment( z );
        assertEquals( 1.0, f.get(z), TOLERANCE );
        assertEquals( tm+1.0, f.getTotal(), TOLERANCE );
        assertEquals( n0+1, f.getDomain().size() );
    }

    /**
     * Test of the (key, double) increment method, of class
     * DefaultDataDistribution.
     */
    public void testAdd_GenericType_double()
    {
        System.out.println("add");

        DefaultDataDistribution<String> f = this.createInstance();
        // Total of the fixture: 0.5 + 2.0 + 2.5.
        double tm = 5.0;
        for( String value : f.getDomain() )
        {
            double w1 = f.get(value);
            double mass = RANDOM.nextDouble();
            tm += mass;
            f.increment( value, mass );
            assertEquals( w1+mass, f.get(value), TOLERANCE );
            assertEquals( tm, f.getTotal(), TOLERANCE );
        }

        // Incrementing an unseen key by zero must not grow the domain.
        int n0 = f.getDomain().size();
        String z = "z";
        assertEquals( 0.0, f.get(z), TOLERANCE );
        assertEquals( tm, f.getTotal(), TOLERANCE );
        f.increment( z, 0.0 );
        assertEquals( 0.0, f.get(z), TOLERANCE );
        assertEquals( n0, f.getDomain().size() );

        // A positive increment creates the entry.
        double mz = RANDOM.nextDouble();
        tm += mz;
        f.increment( z, mz );
        assertEquals( mz, f.get(z), TOLERANCE );
        assertEquals( tm, f.getTotal(), TOLERANCE );
        assertEquals( n0+1, f.getDomain().size() );

        // +1.0 followed by -1.0 leaves the total where it was; the result is
        // only equal up to rounding, hence the tolerance.
        f.increment(z);
        double before = f.get(z);
        f.increment( z, -1.0 );
        assertEquals( before-1.0, f.get(z), TOLERANCE );
        assertEquals( tm, f.getTotal(), TOLERANCE );
    }

    /**
     * Test of the single-argument decrement method and of compact, of class
     * DefaultDataDistribution.
     */
    public void testRemove()
    {
        System.out.println("remove");

        DefaultDataDistribution<String> f = this.createInstance();
        // Total of the fixture: 0.5 + 2.0 + 2.5.
        double tm = 5.0;

        String z = "z";
        int nz = f.getDomain().size();
        f.increment( z );
        assertEquals( nz+1, f.getDomain().size() );
        assertEquals( 1.0, f.get(z), TOLERANCE );

        // Decrementing to zero keeps the key in the domain until compact().
        f.decrement(z);
        assertEquals( tm, f.getTotal(), TOLERANCE );
        assertEquals( nz+1, f.getDomain().size() );
        assertEquals( 0.0, f.get(z), TOLERANCE );
        f.compact();
        assertEquals( nz, f.getDomain().size() );
        assertEquals( 0.0, f.get(z), TOLERANCE );

        // A key that still has positive mass must survive compact().
        f.increment( z, 1.0 + RANDOM.nextDouble() );
        assertEquals( nz+1, f.getDomain().size() );
        f.decrement(z);
        assertEquals( nz+1, f.getDomain().size() );
        f.compact();
        assertEquals( nz+1, f.getDomain().size() );
    }

    /**
     * Test of the (key, double) decrement method, of class
     * DefaultDataDistribution.
     */
    public void testRemove_GenericType_double()
    {
        System.out.println("remove");

        DefaultDataDistribution<String> f = this.createInstance();
        // Total of the fixture: 0.5 + 2.0 + 2.5.
        double tm = 5.0;
        for( String value : f.getDomain() )
        {
            double w1 = f.get(value);
            double rm = RANDOM.nextDouble();
            // Only remove mass when the entry stays strictly positive.
            if( w1 > rm )
            {
                tm -= rm;
                f.decrement( value, rm );
                double em = w1-rm;
                assertEquals( em, f.get(value), TOLERANCE );
                assertEquals( tm, f.getTotal(), TOLERANCE );
            }
        }

        String z = "z";
        int nz = f.getDomain().size();
        double value = RANDOM.nextDouble();
        f.increment( z, value );
        assertEquals( nz+1, f.getDomain().size() );
        assertEquals( value, f.get(z), TOLERANCE );

        // Decrementing by zero is a no-op.
        f.decrement( z, 0.0 );
        assertEquals( nz+1, f.getDomain().size() );
        assertEquals( value, f.get(z), TOLERANCE );

        // Removing more than is present leaves the entry at zero.
        f.decrement(z, value*2.0);
        assertEquals( nz+1, f.getDomain().size() );
        assertEquals( 0.0, f.get(z), TOLERANCE );

        f.compact();
        assertEquals( nz, f.getDomain().size() );
        assertEquals( 0.0, f.get(z), TOLERANCE );
        assertEquals( tm, f.getTotal(), TOLERANCE );

        f.increment( z, value );
        tm += value;
        assertEquals( tm, f.getTotal(), TOLERANCE );
        assertEquals( nz+1, f.getDomain().size() );

        tm -= value/2.0;
        f.decrement(z, value/2.0);
        assertEquals( nz+1, f.getDomain().size() );
        assertEquals( value/2.0, f.get(z), TOLERANCE );
        assertEquals( tm, f.getTotal(), TOLERANCE );

        // A negative decrement behaves like an increment.
        f.increment(z,1.0);
        double before = f.get(z);
        f.decrement( z, -1.0 );
        assertEquals( before+1.0, f.get(z), TOLERANCE );

        // An oversized decrement removes only the mass that was present.
        double expected = f.getTotal() - f.get(z);
        f.decrement(z,1000.0);
        assertEquals( 0.0, f.get(z), TOLERANCE );
        assertEquals( expected, f.getTotal(), TOLERANCE );
    }

    /**
     * Test of set method, of class DefaultDataDistribution.
     */
    public void testSetMass()
    {
        System.out.println("setMass");

        DefaultDataDistribution<String> f = this.createInstance();
        // Total of the fixture: 0.5 + 2.0 + 2.5.
        double tm = 5.0;
        for( String value : f.getDomain() )
        {
            double w1 = f.get(value);
            double mass = RANDOM.nextDouble();
            tm += mass - w1;
            f.set( value, mass );
            assertEquals( mass, f.get(value), TOLERANCE );
            assertEquals( tm, f.getTotal(), TOLERANCE );
        }

        // Setting an unseen key to zero must not grow the domain.
        int n0 = f.getDomain().size();
        String z = "z";
        assertEquals( 0.0, f.get(z), TOLERANCE );
        assertEquals( tm, f.getTotal(), TOLERANCE );
        f.set( z, 0.0 );
        assertEquals( 0.0, f.get(z), TOLERANCE );
        assertEquals( n0, f.getDomain().size() );

        double mz = RANDOM.nextDouble();
        tm += mz;
        f.set( z, mz );
        assertEquals( mz, f.get(z), TOLERANCE );
        assertEquals( tm, f.getTotal(), TOLERANCE );
        assertEquals( n0+1, f.getDomain().size() );

        // A negative mass is stored as zero and the total drops accordingly.
        tm -= mz;
        f.set( z, -0.1 );
        assertEquals( 0.0, f.get(z), TOLERANCE );
        assertEquals( tm, f.getTotal(), TOLERANCE );
    }

    /**
     * Regression test: the total must stay consistent when an entry is
     * lowered, zeroed, raised again and finally set below zero.
     */
    public void testSetBelowZeroBug()
    {
        DefaultDataDistribution<String> f = new DefaultDataDistribution<>();
        f.set("a", 3);
        f.set("b", 4);
        assertEquals(7, f.getTotal(), 0.0);
        f.set("b", 2);
        assertEquals(5, f.getTotal(), 0.0);
        f.set("b", 0);
        assertEquals(3, f.getTotal(), 0.0);
        f.set("b", 5);
        assertEquals(8, f.getTotal(), 0.0);
        f.set("b", -4);
        assertEquals(3, f.getTotal(), 0.0);
    }

    /**
     * Test of get method, of class DefaultDataDistribution.
     */
    public void testget()
    {
        System.out.println("getMass");

        DefaultDataDistribution<String> f = this.createInstance();
        assertEquals( 0.5, f.get("a"), TOLERANCE );
        assertEquals( 2.0, f.get("b"), TOLERANCE );
        assertEquals( 2.5, f.get("c"), TOLERANCE );
        assertEquals( 0.0, f.get("z"), TOLERANCE );
    }

    /**
     * Test of getFraction method, of class DefaultDataDistribution.
     */
    public void testGetFraction()
    {
        double epsilon = 0.000000001;

        DefaultDataDistribution<String> instance =
            new DefaultDataDistribution<String>();
        // An empty distribution reports a zero fraction for every key.
        assertEquals(0.0, instance.getFraction("a"), epsilon);
        assertEquals(0.0, instance.getFraction("b"), epsilon);
        assertEquals(0.0, instance.getFraction("c"), epsilon);
        assertEquals(0.0, instance.getFraction("d"), epsilon);

        instance.increment("a");
        assertEquals(1 / 1.0, instance.getFraction("a"), epsilon);

        instance.increment("a");
        assertEquals(2 / 2.0, instance.getFraction("a"), epsilon);

        instance.increment("b");
        assertEquals(2 / 3.0, instance.getFraction("a"), epsilon);
        assertEquals(1 / 3.0, instance.getFraction("b"), epsilon);

        instance.increment("c", 4.7);
        assertEquals(2 / 7.7, instance.getFraction("a"), epsilon);
        assertEquals(1 / 7.7, instance.getFraction("b"), epsilon);
        assertEquals(4.7 / 7.7, instance.getFraction("c"), epsilon);

        instance.increment("a", 2);
        assertEquals(4 / 9.7, instance.getFraction("a"), epsilon);
        assertEquals(1 / 9.7, instance.getFraction("b"), epsilon);
        assertEquals(4.7 / 9.7, instance.getFraction("c"), epsilon);

        instance.decrement("a", 1.0);
        assertEquals(3 / 8.7, instance.getFraction("a"), epsilon);
        assertEquals(1 / 8.7, instance.getFraction("b"), epsilon);
        assertEquals(4.7 / 8.7, instance.getFraction("c"), epsilon);

        instance.decrement("c", 3);
        assertEquals(3 / 5.7, instance.getFraction("a"), epsilon);
        assertEquals(1 / 5.7, instance.getFraction("b"), epsilon);
        assertEquals(1.7 / 5.7, instance.getFraction("c"), epsilon);

        instance.decrement("b", 1);
        assertEquals(3 / 4.7, instance.getFraction("a"), epsilon);
        assertEquals(0 / 4.7, instance.getFraction("b"), epsilon);
        assertEquals(1.7 / 4.7, instance.getFraction("c"), epsilon);

        instance.increment("d");
        assertEquals(3 / 5.7, instance.getFraction("a"), epsilon);
        assertEquals(0 / 5.7, instance.getFraction("b"), epsilon);
        assertEquals(1.7 / 5.7, instance.getFraction("c"), epsilon);
        assertEquals(1 / 5.7, instance.getFraction("d"), epsilon);
    }

    /**
     * Test of getMaxValue method, of class DefaultDataDistribution.
     */
    public void testGetMaxValue()
    {
        DefaultDataDistribution<String> instance =
            new DefaultDataDistribution<String>();
        assertEquals(0.0, instance.getMaxValue(), TOLERANCE);

        instance.increment("a");
        assertEquals(1.0, instance.getMaxValue(), TOLERANCE);

        instance.increment("b");
        assertEquals(1.0, instance.getMaxValue(), TOLERANCE);

        instance.increment("b");
        assertEquals(2.0, instance.getMaxValue(), TOLERANCE);

        instance.increment("c", 7.4);
        assertEquals(7.4, instance.getMaxValue(), TOLERANCE);
    }

    /**
     * Test of getMaxValueKey method, of class DefaultDataDistribution.
     */
    public void testgetMaxValueKey()
    {
        DefaultDataDistribution<String> instance =
            new DefaultDataDistribution<String>();
        assertNull(instance.getMaxValueKey());

        instance.increment("a");
        assertEquals("a", instance.getMaxValueKey());

        instance.increment("b");
        // "a" and "b" are tied; a should be the first value encountered.
        assertEquals("a", instance.getMaxValueKey());

        instance.increment("b");
        assertEquals("b", instance.getMaxValueKey());

        instance.increment("c", 7.4);
        assertEquals("c", instance.getMaxValueKey());
    }

    /**
     * Test of getMaxValueKeys method, of class DefaultDataDistribution.
     */
    public void testgetMaxValueKeys()
    {
        DefaultDataDistribution<String> instance =
            new DefaultDataDistribution<String>();
        assertTrue(instance.getMaxValueKeys().isEmpty());

        instance.increment("a");
        assertEquals(1, instance.getMaxValueKeys().size());
        assertTrue(instance.getMaxValueKeys().contains("a"));

        // Tied keys are all reported.
        instance.increment("b");
        assertEquals(2, instance.getMaxValueKeys().size());
        assertTrue(instance.getMaxValueKeys().contains("a"));
        assertTrue(instance.getMaxValueKeys().contains("b"));

        instance.increment("b");
        assertEquals(1, instance.getMaxValueKeys().size());
        assertTrue(instance.getMaxValueKeys().contains("b"));

        instance.increment("c", 7.4);
        assertEquals(1, instance.getMaxValueKeys().size());
        assertTrue(instance.getMaxValueKeys().contains("c"));
    }

    /**
     * Test of getDomain method, of class DefaultDataDistribution.
     */
    public void testGetDomain()
    {
        System.out.println("getDomain");

        DefaultDataDistribution<String> instance = this.createInstance();
        assertEquals( 3, instance.getDomain().size() );
    }

    /**
     * Test of getProbabilityFunction method, of class
     * DefaultDataDistribution.
     */
    public void testGetDistributionFunction()
    {
        System.out.println("getDistributionFunction");

        DefaultDataDistribution<String> instance = this.createInstance();
        DataDistribution.PMF<String> pmf = instance.getProbabilityFunction();
        assertNotNull( pmf );
        assertNotSame( instance, pmf );
    }

    /**
     * PMF.getProbabilityFunction: a PMF must return itself.
     */
    public void testPMFGetDistributionFunction()
    {
        System.out.println("PMF.getDistributionFunction");

        DefaultDataDistribution.PMF<String> instance =
            (DefaultDataDistribution.PMF<String>) this.createInstance().getProbabilityFunction();
        assertSame( instance, instance.getProbabilityFunction() );
    }

    /**
     * Test of getEntropy method, of class DefaultDataDistribution.PMF,
     * checked against ProbabilityMassFunctionUtil.getEntropy.
     */
    public void testPMFGetEntropy()
    {
        System.out.println("getEntropy");

        DataDistribution.PMF<String> instance =
            this.createInstance().getProbabilityFunction();
        assertEquals( ProbabilityMassFunctionUtil.getEntropy( instance ),
            instance.getEntropy(), TOLERANCE );
    }

    /**
     * Test of evaluate method, of class DefaultDataDistribution.PMF.
     */
    public void testPMFEvaluate()
    {
        System.out.println("evaluate");

        DataDistribution.PMF<String> instance =
            this.createInstance().getProbabilityFunction();
        for( String value : instance.getDomain() )
        {
            assertEquals( instance.get(value)/instance.getTotal(),
                instance.evaluate(value), TOLERANCE );
        }
        assertEquals( 0.0, instance.get("z"), TOLERANCE );

        // A freshly constructed PMF holds no mass for any key.
        instance = new DefaultDataDistribution.PMF<String>();
        assertEquals( 0.0, instance.get("a"), TOLERANCE );
        assertEquals( 0.0, instance.get("z"), TOLERANCE );
    }

    /**
     * Test of asMap method, of class DefaultDataDistribution.
     */
    public void testasMap()
    {
        System.out.println("asMap");

        DefaultDataDistribution<String> instance = this.createInstance();
        assertNotNull( instance.asMap() );
    }

    /**
     * Test of getTotal method, of class DefaultDataDistribution.
     */
    public void testgetTotal()
    {
        System.out.println("getTotal");

        DefaultDataDistribution<String> instance = this.createInstance();
        assertEquals( 5.0, instance.getTotal(), TOLERANCE );
    }

    /**
     * clear(): empties the distribution and is safe to call repeatedly.
     */
    public void testClear()
    {
        System.out.println( "clear" );

        DataDistribution.PMF<String> instance =
            this.createInstance().getProbabilityFunction();
        assertEquals( 5.0, instance.getTotal(), TOLERANCE );
        assertEquals( 3, instance.getDomain().size() );

        instance.clear();
        assertEquals( 0.0, instance.getTotal(), TOLERANCE );
        assertEquals( 0, instance.getDomain().size() );

        // Clearing an already-empty distribution changes nothing.
        instance.clear();
        assertEquals( 0.0, instance.getTotal(), TOLERANCE );
        assertEquals( 0, instance.getDomain().size() );
        assertEquals( 0.0, instance.evaluate("z"), TOLERANCE );
    }

    /**
     * toString(): only checks that a non-null string is produced.
     */
    public void testToString()
    {
        System.out.println( "toString" );

        DefaultDataDistribution<String> instance = this.createInstance();
        String s = instance.toString();
        System.out.println( "Distribution:\n" + s );
        assertNotNull( s );
    }

    /**
     * Learner: a WeightedEstimator fed the fixture's (key, mass) pairs must
     * reproduce the fixture.
     */
    public void testLearner()
    {
        System.out.println( "Learner" );

        DefaultDataDistribution<String> instance = this.createInstance();
        ArrayList<WeightedValue<String>> values =
            new ArrayList<WeightedValue<String>>( instance.getDomain().size() );
        for( String s : instance.getDomain() )
        {
            values.add( new DefaultWeightedValue<String>( s, instance.get(s) ) );
        }

        DefaultDataDistribution.WeightedEstimator<String> learner =
            new DefaultDataDistribution.WeightedEstimator<String>();
        DefaultDataDistribution<String> pmf = learner.learn(values);

        assertEquals( instance.getDomain().size(), pmf.getDomain().size() );
        for( String s : instance.getDomain() )
        {
            assertEquals( instance.get(s), pmf.get(s), TOLERANCE );
        }
    }

}