/* * File: DiscreteSamplingUtilTest.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright June 16, 2010, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. Export * of this program may require a license from the United States Government. * See CopyrightHistory.txt for complete details. */ package gov.sandia.cognition.statistics; import gov.sandia.cognition.math.matrix.VectorFactory; import java.util.ArrayList; import java.util.Collection; import java.util.LinkedList; import java.util.List; import java.util.Random; import junit.framework.TestCase; /** * Unit tests for class {@code DiscreteSamplingUtil}. * * @author Justin Basilico * @since 3.1 */ public class DiscreteSamplingUtilTest extends TestCase { protected Random random = new Random(211); /** * Creates a new test. * * @param testName The test name. */ public DiscreteSamplingUtilTest( String testName) { super(testName); } /** * Test of sampleIndexFromProbabilities method, of class DiscreteSamplingUtil. */ public void testSampleIndexFromProbabilities() { double[] probabilities = {0.1, 0.7, 0.2}; int[] counts = new int[3]; for (int i = 0; i < 100; i++) { int index = DiscreteSamplingUtil.sampleIndexFromProbabilities(random, probabilities); assertTrue(index >= 0 && index < 3); counts[index]++; } assertTrue(counts[0] > 0); assertTrue(counts[1] > 0); assertTrue(counts[2] > 0); assertTrue(counts[0] < counts[1]); assertTrue(counts[1] > counts[2]); assertTrue(counts[0] < counts[1]); counts = new int[3]; for (int i = 0; i < 100; i++) { int index = DiscreteSamplingUtil.sampleIndexFromProbabilities(random, VectorFactory.getDefault().copyArray(probabilities)); assertTrue(index >= 0 && index < 3); counts[index]++; } assertTrue(counts[0] > 0); assertTrue(counts[1] > 0); assertTrue(counts[2] > 0); assertTrue(counts[0] < counts[1]); assertTrue(counts[1] > counts[2]); assertTrue(counts[0] < counts[1]); } /** * Test of sampleIndexFromProportions method, of class DiscreteSamplingUtil. */ public void testSampleIndexFromProportions() { double[] proportions = {0.5, 3.6, 1.1}; int[] counts = new int[3]; for (int i = 0; i < 100; i++) { int index = DiscreteSamplingUtil.sampleIndexFromProportions(random, proportions); assertTrue(index >= 0 && index < 3); counts[index]++; } assertTrue(counts[0] > 0); assertTrue(counts[1] > 0); assertTrue(counts[2] > 0); assertTrue(counts[0] < counts[1]); assertTrue(counts[1] > counts[2]); assertTrue(counts[0] < counts[1]); counts = new int[3]; double sum = 0.5 + 3.6 + 1.1; for (int i = 0; i < 100; i++) { int index = DiscreteSamplingUtil.sampleIndexFromProportions(random, proportions, sum); assertTrue(index >= 0 && index < 3); counts[index]++; } assertTrue(counts[0] > 0); assertTrue(counts[1] > 0); assertTrue(counts[2] > 0); assertTrue(counts[0] < counts[1]); assertTrue(counts[1] > counts[2]); assertTrue(counts[0] < counts[1]); } /** * Test of sampleIndicesFromProportions method, of class DiscreteSamplingUtil. */ public void testSampleIndicesFromProportions() { double[] proportions = {0.5, 3.6, 1.1}; int[] counts = new int[3]; int[] samples = DiscreteSamplingUtil.sampleIndicesFromProportions(random, proportions, 100); for (int i = 0; i < samples.length; i++) { int index = samples[i]; assertTrue(index >= 0 && index < 3); counts[index]++; } assertTrue(counts[0] > 0); assertTrue(counts[1] > 0); assertTrue(counts[2] > 0); assertTrue(counts[0] < counts[1]); assertTrue(counts[1] > counts[2]); assertTrue(counts[0] < counts[1]); } /** * Test of sampleIndexFromCumulativeProportions method, of class DiscreteSamplingUtil. */ public void testSampleIndexFromCumulativeProportions() { double[] cumulativeProportions = {0.5, 4.1, 5.2}; int[] counts = new int[3]; for (int i = 0; i < 100; i++) { int index = DiscreteSamplingUtil.sampleIndexFromCumulativeProportions(random, cumulativeProportions); assertTrue(index >= 0 && index < 3); counts[index]++; } assertTrue(counts[0] > 0); assertTrue(counts[1] > 0); assertTrue(counts[2] > 0); assertTrue(counts[0] < counts[1]); assertTrue(counts[1] > counts[2]); assertTrue(counts[0] < counts[1]); } /** * Test of sampleIndicesFromCumulativeProportions method, of class DiscreteSamplingUtil. */ public void testSampleIndicesFromCumulativeProportions() { double[] cumulativeProportions = {0.5, 4.1, 4.1, 4.1, 5.2, 5.2}; int[] counts = new int[cumulativeProportions.length]; int[] samples = DiscreteSamplingUtil.sampleIndicesFromCumulativeProportions( random, cumulativeProportions, 100); for (int i = 0; i < samples.length; i++) { int index = samples[i]; assertTrue(index >= 0 && index < cumulativeProportions.length); counts[index]++; } assertTrue(counts[0] > 0); assertTrue(counts[1] > 0); assertTrue(counts[2] == 0); assertTrue(counts[3] == 0); assertTrue(counts[4] > 0); assertTrue(counts[5] == 0); assertTrue(counts[0] < counts[1]); assertTrue(counts[1] > counts[2]); assertTrue(counts[0] < counts[1]); } /** * Test of sampleWithReplacement method, of class DiscreteSamplingUtil. */ public void testSampleWithReplacement() { List<String> data = new ArrayList<String>(); data.add("a"); data.add("b"); data.add("c"); List<String> result = DiscreteSamplingUtil.sampleWithReplacement(random, data, 1); assertEquals(1, result.size()); assertTrue(result.contains("a") || result.contains("b") || result.contains("c")); result = DiscreteSamplingUtil.sampleWithReplacement(random, data, 2); assertEquals(2, result.size()); assertTrue(result.contains("a") || result.contains("b") || result.contains("c")); result = DiscreteSamplingUtil.sampleWithoutReplacement(random, data, 3); assertEquals(3, result.size()); assertTrue(result.contains("a") || result.contains("b") || result.contains("c")); result = DiscreteSamplingUtil.sampleWithReplacement(random, data, 4); assertEquals(4, result.size()); assertTrue(result.contains("a") || result.contains("b") || result.contains("c")); result = DiscreteSamplingUtil.sampleWithReplacement(random, data, 100); assertEquals(100, result.size()); assertTrue(result.contains("a") && result.contains("b") && result.contains("c")); result = DiscreteSamplingUtil.sampleWithReplacement(random, data, 0); assertEquals(0, result.size()); } /** * Test of sampleWithReplacementInto method, of class DiscreteSamplingUtil. */ public void testSampleWithReplacementInto() { List<String> data = new ArrayList<String>(); data.add("a"); data.add("b"); data.add("c"); Collection<Object> result = new LinkedList<Object>(); result.clear(); DiscreteSamplingUtil.sampleWithReplacementInto(random, data, 1, result); assertEquals(1, result.size()); assertTrue(result.contains("a") || result.contains("b") || result.contains("c")); result.clear(); DiscreteSamplingUtil.sampleWithReplacementInto(random, data, 2, result); assertEquals(2, result.size()); assertTrue(result.contains("a") || result.contains("b") || result.contains("c")); result.clear(); DiscreteSamplingUtil.sampleWithReplacementInto(random, data, 3, result); assertEquals(3, result.size()); assertTrue(result.contains("a") || result.contains("b") || result.contains("c")); result.clear(); DiscreteSamplingUtil.sampleWithReplacementInto(random, data, 4, result); assertEquals(4, result.size()); assertTrue(result.contains("a") || result.contains("b") || result.contains("c")); result.clear(); DiscreteSamplingUtil.sampleWithReplacementInto(random, data, 100, result); assertEquals(100, result.size()); assertTrue(result.contains("a") && result.contains("b") && result.contains("c")); result.clear(); DiscreteSamplingUtil.sampleWithReplacementInto(random, data, 0, result); assertEquals(0, result.size()); // This tests that it doesn't remove stuff. result.clear(); result.add("d"); result.add("e"); DiscreteSamplingUtil.sampleWithReplacementInto(random, data, 1, result); assertEquals(3, result.size()); assertTrue(result.contains("d")); assertTrue(result.contains("e")); assertTrue(result.contains("a") || result.contains("b") || result.contains("c")); } /** * Test of sampleWithoutReplacement method, of class DiscreteSamplingUtil. */ public void testSampleWithoutReplacement() { List<String> data = new ArrayList<String>(); data.add("a"); data.add("b"); data.add("c"); List<String> result = DiscreteSamplingUtil.sampleWithoutReplacement(random, data, 1); assertEquals(1, result.size()); assertTrue(result.contains("a") || result.contains("b") || result.contains("c")); result = DiscreteSamplingUtil.sampleWithoutReplacement(random, data, 2); assertEquals(2, result.size()); int unique = 0; if (result.contains("a")) unique++; if (result.contains("b")) unique++; if (result.contains("c")) unique++; assertEquals(2, unique); result = DiscreteSamplingUtil.sampleWithoutReplacement(random, data, 3); assertEquals(3, result.size()); assertTrue(result.contains("a") && result.contains("b") && result.contains("c")); boolean exceptionThrown = false; try { DiscreteSamplingUtil.sampleWithoutReplacement(random, data, 4); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } exceptionThrown = false; try { DiscreteSamplingUtil.sampleWithoutReplacement(random, data, 0); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } exceptionThrown = false; try { DiscreteSamplingUtil.sampleWithoutReplacement(random, data, -1); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } } }