/* * File: AbstractFactorizationMachineLearnerTest.java * Authors: Justin Basilico * Project: Cognitive Foundry * * Copyright 2013 Cognitive Foundry. All rights reserved. */ package gov.sandia.cognition.learning.algorithm.factor.machine; import gov.sandia.cognition.learning.data.DefaultInputOutputPair; import gov.sandia.cognition.math.matrix.VectorFactory; import gov.sandia.cognition.util.NamedValue; import java.util.Arrays; import java.util.Collections; import java.util.Random; import org.junit.Test; import static org.junit.Assert.*; /** * Unit tests for class {@link AbstractFactorizationMachineLearner}. * * @author Justin Basilico * @since 3.4.0 */ public class AbstractFactorizationMachineLearnerTest extends Object { /** * Creates a new test. */ public AbstractFactorizationMachineLearnerTest() { super(); } /** * Creates a new instance. * * @return * A new instance; */ protected AbstractFactorizationMachineLearner createInstance() { return new DummyFactorizationMachineLearner(); } /** * Test of constructors of class AbstractFactorizationMachineLearner. */ @Test public void testConstructors() { int factorCount = AbstractFactorizationMachineLearner.DEFAULT_FACTOR_COUNT; double biasRegularization = AbstractFactorizationMachineLearner.DEFAULT_BIAS_REGULARIZATION; double weightRegularization = AbstractFactorizationMachineLearner.DEFAULT_WEIGHT_REGULARIZATION; double factorRegularization = AbstractFactorizationMachineLearner.DEFAULT_FACTOR_REGULARIZATION; double seedScale = AbstractFactorizationMachineLearner.DEFAULT_SEED_SCALE; int maxIterations = AbstractFactorizationMachineLearner.DEFAULT_MAX_ITERATIONS; AbstractFactorizationMachineLearner instance = new DummyFactorizationMachineLearner(); assertEquals(factorCount, instance.getFactorCount()); assertEquals(biasRegularization, instance.getBiasRegularization(), 0.0); assertEquals(weightRegularization, instance.getWeightRegularization(), 0.0); assertEquals(factorRegularization, instance.getFactorRegularization(), 0.0); assertEquals(seedScale, instance.getSeedScale(), 0.0); assertEquals(maxIterations, instance.getMaxIterations()); assertNotNull(instance.getRandom()); assertSame(instance.getRandom(), instance.getRandom()); factorCount = 22; biasRegularization = 3.33; weightRegularization = 44.44; factorRegularization = 555.55; seedScale = 0.6; maxIterations = 777; Random random = new Random(); instance = new DummyFactorizationMachineLearner(factorCount, biasRegularization, weightRegularization, factorRegularization, seedScale, maxIterations, random); assertEquals(factorCount, instance.getFactorCount()); assertEquals(biasRegularization, instance.getBiasRegularization(), 0.0); assertEquals(weightRegularization, instance.getWeightRegularization(), 0.0); assertEquals(factorRegularization, instance.getFactorRegularization(), 0.0); assertEquals(seedScale, instance.getSeedScale(), 0.0); assertEquals(maxIterations, instance.getMaxIterations()); assertSame(random, instance.getRandom()); // No negative factor counts. boolean exceptionThrown = false; try { instance = new DummyFactorizationMachineLearner(-1, biasRegularization, weightRegularization, factorRegularization, seedScale, maxIterations, random); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } // No negative bias regularization. exceptionThrown = false; try { instance = new DummyFactorizationMachineLearner(factorCount, -1.0, weightRegularization, factorRegularization, seedScale, maxIterations, random); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } // No negative weight regularization. exceptionThrown = false; try { instance = new DummyFactorizationMachineLearner(factorCount, biasRegularization, -1.0, factorRegularization, seedScale, maxIterations, random); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } // No negative factor regularization. exceptionThrown = false; try { instance = new DummyFactorizationMachineLearner(factorCount, biasRegularization, weightRegularization, -1.0, seedScale, maxIterations, random); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } // No negative seed scale. exceptionThrown = false; try { instance = new DummyFactorizationMachineLearner(factorCount, biasRegularization, weightRegularization, factorRegularization, -1.0, maxIterations, random); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } // No negative max iterations. exceptionThrown = false; try { instance = new DummyFactorizationMachineLearner(factorCount, biasRegularization, weightRegularization, factorRegularization, seedScale, -1, random); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } } /** * Test of getResult method, of class AbstractFactorizationMachineLearner. */ @Test public void testGetResult() { AbstractFactorizationMachineLearner instance = new DummyFactorizationMachineLearner(); assertNull(instance.getResult()); // Doing a learn should initialize the result. int k = 7; int d = 40; instance.setFactorCount(k); instance.learn(Collections.singletonList( DefaultInputOutputPair.create( VectorFactory.getDefault().createVector(d), 1.0))); FactorizationMachine result = instance.getResult(); assertSame(result, instance.getResult()); assertEquals(0.0, result.getBias(), 0.0); assertEquals(d, result.getInputDimensionality()); assertEquals(0.0, result.getWeights().norm2Squared(), 0.0); assertEquals(k, result.getFactorCount()); } /** * Test of getFactorCount method, of class AbstractFactorizationMachineLearner. */ @Test public void testGetFactorCount() { this.testSetFactorCount(); } /** * Test of setFactorCount method, of class AbstractFactorizationMachineLearner. */ @Test public void testSetFactorCount() { int factorCount = FactorizationMachineStochasticGradient.DEFAULT_FACTOR_COUNT; AbstractFactorizationMachineLearner instance = this.createInstance(); assertEquals(factorCount, instance.getFactorCount()); factorCount = 2; instance.setFactorCount(factorCount); assertEquals(factorCount, instance.getFactorCount()); factorCount = 0; instance.setFactorCount(factorCount); assertEquals(factorCount, instance.getFactorCount()); boolean exceptionThrown = false; try { instance.setFactorCount(-1); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } } /** * Test of isBiasEnabled method, of class AbstractFactorizationMachineLearner. */ @Test public void testIsBiasEnabled() { this.testSetBiasEnabled(); } /** * Test of setBiasEnabled method, of class AbstractFactorizationMachineLearner. */ @Test public void testSetBiasEnabled() { boolean biasEnabled = AbstractFactorizationMachineLearner.DEFAULT_BIAS_ENABLED; AbstractFactorizationMachineLearner instance = this.createInstance(); assertEquals(biasEnabled, instance.isBiasEnabled()); biasEnabled = !biasEnabled; instance.setBiasEnabled(biasEnabled); assertEquals(biasEnabled, instance.isBiasEnabled()); biasEnabled = !biasEnabled; instance.setBiasEnabled(biasEnabled); assertEquals(biasEnabled, instance.isBiasEnabled()); } /** * Test of isWeightsEnabled method, of class AbstractFactorizationMachineLearner. */ @Test public void testIsWeightsEnabled() { this.testSetWeightsEnabled(); } /** * Test of setWeightsEnabled method, of class AbstractFactorizationMachineLearner. */ @Test public void testSetWeightsEnabled() { boolean weightsEnabled = AbstractFactorizationMachineLearner.DEFAULT_WEIGHTS_ENABLED; AbstractFactorizationMachineLearner instance = this.createInstance(); assertEquals(weightsEnabled, instance.isWeightsEnabled()); weightsEnabled = !weightsEnabled; instance.setWeightsEnabled(weightsEnabled); assertEquals(weightsEnabled, instance.isWeightsEnabled()); weightsEnabled = !weightsEnabled; instance.setWeightsEnabled(weightsEnabled); assertEquals(weightsEnabled, instance.isWeightsEnabled()); } /** * Test of isFactorsEnabled method, of class AbstractFactorizationMachineLearner. */ @Test public void testIsFactorsEnabled() { AbstractFactorizationMachineLearner instance = this.createInstance(); assertTrue(instance.isFactorsEnabled()); instance.setFactorCount(0); assertFalse(instance.isFactorsEnabled()); instance.setFactorCount(1); assertTrue(instance.isFactorsEnabled()); instance.setFactorCount(2); assertTrue(instance.isFactorsEnabled()); } /** * Test of getBiasRegularization method, of class AbstractFactorizationMachineLearner. */ @Test public void testGetBiasRegularization() { this.testSetBiasRegularization(); } /** * Test of setBiasRegularization method, of class AbstractFactorizationMachineLearner. */ @Test public void testSetBiasRegularization() { double biasRegularization = AbstractFactorizationMachineLearner.DEFAULT_BIAS_REGULARIZATION; AbstractFactorizationMachineLearner instance = this.createInstance(); assertEquals(biasRegularization, instance.getBiasRegularization(), 0.0); double[] goodValues = {0.1, 0.2, 1.0, 12.1, 0.0, 111 }; for (double value : goodValues) { biasRegularization = value; instance.setBiasRegularization(biasRegularization); assertEquals(biasRegularization, instance.getBiasRegularization(), 0.0); } double[] badValues = { -0.1, -1, -10}; for (double value : badValues) { boolean exceptionThrown = false; try { instance.setBiasRegularization(value); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } } } /** * Test of getWeightRegularization method, of class AbstractFactorizationMachineLearner. */ @Test public void testGetWeightRegularization() { this.testSetWeightRegularization(); } /** * Test of setWeightRegularization method, of class AbstractFactorizationMachineLearner. */ @Test public void testSetWeightRegularization() { double weightRegularization = AbstractFactorizationMachineLearner.DEFAULT_WEIGHT_REGULARIZATION; AbstractFactorizationMachineLearner instance = this.createInstance(); assertEquals(weightRegularization, instance.getWeightRegularization(), 0.0); double[] goodValues = {0.1, 0.2, 1.0, 12.1, 0.0, 111 }; for (double value : goodValues) { weightRegularization = value; instance.setWeightRegularization(weightRegularization); assertEquals(weightRegularization, instance.getWeightRegularization(), 0.0); } double[] badValues = { -0.1, -1, -10}; for (double value : badValues) { boolean exceptionThrown = false; try { instance.setWeightRegularization(value); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } } } /** * Test of getFactorRegularization method, of class AbstractFactorizationMachineLearner. */ @Test public void testGetFactorRegularization() { this.testSetFactorRegularization(); } /** * Test of setFactorRegularization method, of class AbstractFactorizationMachineLearner. */ @Test public void testSetFactorRegularization() { double factorRegularization = AbstractFactorizationMachineLearner.DEFAULT_FACTOR_REGULARIZATION; AbstractFactorizationMachineLearner instance = this.createInstance(); assertEquals(factorRegularization, instance.getFactorRegularization(), 0.0); double[] goodValues = {0.1, 0.2, 1.0, 12.1, 0.0, 111 }; for (double value : goodValues) { factorRegularization = value; instance.setFactorRegularization(factorRegularization); assertEquals(factorRegularization, instance.getFactorRegularization(), 0.0); } double[] badValues = { -0.1, -1, -10}; for (double value : badValues) { boolean exceptionThrown = false; try { instance.setFactorRegularization(value); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } } } /** * Test of getSeedScale method, of class AbstractFactorizationMachineLearner. */ @Test public void testGetSeedScale() { this.testSetSeedScale(); } /** * Test of setSeedScale method, of class AbstractFactorizationMachineLearner. */ @Test public void testSetSeedScale() { double seedScale = AbstractFactorizationMachineLearner.DEFAULT_SEED_SCALE; AbstractFactorizationMachineLearner instance = this.createInstance(); assertEquals(seedScale, instance.getSeedScale(), 0.0); double[] goodValues = {0.1, 0.2, 1.0, 12.1, 0.0, 111 }; for (double value : goodValues) { seedScale = value; instance.setSeedScale(seedScale); assertEquals(seedScale, instance.getSeedScale(), 0.0); } double[] badValues = { -0.1, -1, -10}; for (double value : badValues) { boolean exceptionThrown = false; try { instance.setSeedScale(value); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } } } /** * Test of getRandom method, of class AbstractFactorizationMachineLearner. */ @Test public void testGetRandom() { this.testSetRandom(); } /** * Test of setRandom method, of class AbstractFactorizationMachineLearner. */ @Test public void testSetRandom() { AbstractFactorizationMachineLearner instance = this.createInstance(); assertNotNull(instance.getRandom()); assertSame(instance.getRandom(), instance.getRandom()); Random random = new Random(); instance.setRandom(random); assertSame(random, instance.getRandom()); random = null; instance.setRandom(random); assertSame(random, instance.getRandom()); random = new Random(); instance.setRandom(random); assertSame(random, instance.getRandom()); } public class DummyFactorizationMachineLearner extends AbstractFactorizationMachineLearner { public DummyFactorizationMachineLearner() { super(); } public DummyFactorizationMachineLearner( final int factorCount, final double biasRegularization, final double weightRegularization, final double factorRegularization, final double seedScale, final int maxIterations, final Random random) { super(factorCount, biasRegularization, weightRegularization, factorRegularization, seedScale, maxIterations, random); } @Override protected boolean step() { return false; } @Override protected void cleanupAlgorithm() { } @Override public NamedValue<? extends Number> getPerformance() { return null; } } }