/* * File: MeanLearnerTest.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright April 21, 2008, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. Export * of this program may require a license from the United States Government. * See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.learning.algorithm.baseline; import gov.sandia.cognition.learning.data.DefaultInputOutputPair; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.learning.function.ConstantEvaluator; import gov.sandia.cognition.math.matrix.Vector; import gov.sandia.cognition.math.matrix.VectorFactory; import gov.sandia.cognition.util.CloneableSerializable; import java.util.ArrayList; import junit.framework.TestCase; /** * Tests of MeanLearner * * @author Justin Basilico * @since 2.1 */ public class MeanLearnerTest extends TestCase { /** * Creates a new test. * * @param testName The test name. */ public MeanLearnerTest( String testName) { super(testName); } /** * Tests the constructors of class MeanLearner. */ public void testConstructor() { MeanLearner instance = new MeanLearner(); assertNotNull(instance); } /** * Tests of clone */ public void testClone() { System.out.println( "Clone" ); MeanLearner instance = new MeanLearner(); CloneableSerializable clone = instance.clone(); assertNotNull( clone ); assertNotSame( instance, clone ); } /** * Test of learn method, of class MeanLearner. */ public void testLearn() { MeanLearner instance = new MeanLearner(); ArrayList<InputOutputPair<Vector, Double>> data = new ArrayList<InputOutputPair<Vector, Double>>(); ConstantEvaluator<Double> result = instance.learn(data); assertEquals(0.0, result.getValue()); double[][] values = new double[][]{ new double[]{0.00, -2.00}, new double[]{2.00, 2.00}, new double[]{3.00, 4.10}, new double[]{3.50, 5.00}, new double[]{4.00, 5.90}, new double[]{6.00, 10.10}, new double[]{8.00, 13.90}, new double[]{9.00, 16.00}}; VectorFactory<?> factory = VectorFactory.getDefault(); for (int i = 0; i < values.length; i++) { double input = values[i][0]; double output = values[i][1]; data.add(new DefaultInputOutputPair<Vector, Double>( factory.copyValues(input), output)); } result = instance.learn(data); assertEquals(6.875, result.getValue()); } }