/*
* File: ExtendedKalmanFilterTest.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Apr 13, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics.bayesian;
import gov.sandia.cognition.evaluator.AbstractStatefulEvaluator;
import gov.sandia.cognition.learning.function.scalar.AtanFunction;
import gov.sandia.cognition.learning.function.vector.MultivariateDiscriminant;
import gov.sandia.cognition.learning.function.vector.GeneralizedLinearModel;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.signals.LinearDynamicalSystem;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import java.util.ArrayList;
/**
 * Unit tests for ExtendedKalmanFilter.
 *
 * @author krdixon
 */
public class ExtendedKalmanFilterTest
    extends RecursiveBayesianEstimatorTestHarness<Vector,Vector,MultivariateGaussian>
{

    /**
     * Creates a new test case for class ExtendedKalmanFilter.
     * @param testName Name of the test.
     */
    public ExtendedKalmanFilterTest(
        String testName)
    {
        super(testName);
    }

    /**
     * Tests the constructors of class ExtendedKalmanFilter: the default
     * constructor leaves the models null, and the full constructor stores
     * exactly the references it was given.
     */
    @Override
    public void testConstructors()
    {
        System.out.println( "Constructors" );

        ExtendedKalmanFilter ekf = new ExtendedKalmanFilter();
        assertNull( ekf.getMotionModel() );
        assertNull( ekf.getObservationModel() );

        int stateDim = 2;
        int outputDim = 1;
        GeneralizedLinearModel observationModel =
            new GeneralizedLinearModel(
                stateDim, outputDim, new AtanFunction() );
        StateSummer motionModel = new StateSummer(stateDim);
        Vector input = VectorFactory.getDefault().createVector(stateDim);
        Matrix modelCovariance =
            MatrixFactory.getDefault().createIdentity(stateDim,stateDim);
        Matrix measurementCovariance =
            MatrixFactory.getDefault().createIdentity(outputDim, outputDim);
        ekf = new ExtendedKalmanFilter(
            motionModel, observationModel, input,
            modelCovariance, measurementCovariance);
        assertSame( motionModel, ekf.getMotionModel() );
        assertSame( observationModel, ekf.getObservationModel() );
        assertSame( modelCovariance, ekf.getModelCovariance() );
        assertSame( measurementCovariance, ekf.getMeasurementCovariance() );
        assertSame( input, ekf.getCurrentInput() );
    }

    /**
     * Test of clone method, of class ExtendedKalmanFilter. Verifies that the
     * clone is a distinct object whose motion and observation models are
     * also distinct from the original's.
     */
    public void testClone()
    {
        System.out.println("clone");

        ExtendedKalmanFilter instance = this.createInstance();
        ExtendedKalmanFilter clone = instance.clone();
        assertNotNull( clone );
        assertNotSame( instance, clone );
        assertNotNull( clone.getMotionModel() );
        assertNotSame( instance.getMotionModel(), clone.getMotionModel() );
        assertNotNull( clone.getObservationModel() );
        // Previously this repeated the motion-model check, so a shallow copy
        // of the observation model would have gone unnoticed.
        assertNotSame( instance.getObservationModel(), clone.getObservationModel() );
    }

    /**
     * A simple stateful motion model used by these tests. Each evaluation
     * scales the current state by 0.9, adds the input to it, stores the
     * result as the new state, and returns it.
     */
    public static class StateSummer
        extends AbstractStatefulEvaluator<Vector,Vector,Vector>
    {

        /**
         * Dimensionality of the state vector.
         */
        private final int dim;

        /**
         * Creates a new StateSummer.
         * @param dim
         * Dimensionality of the state vector.
         */
        public StateSummer(
            int dim)
        {
            this.dim = dim;
        }

        /**
         * Creates the default state: a vector of length {@code dim} with
         * every element set to 1.0.
         * @return
         * The default state vector.
         */
        @Override
        public Vector createDefaultState()
        {
            return VectorFactory.getDefault().createVector(dim,1.0);
        }

        /**
         * Advances the state as {@code state = 0.9 * state + input}. The
         * state vector is updated in place and then returned.
         * @param input
         * Input added to the decayed state.
         * @return
         * The updated state (the same object held as this evaluator's state).
         */
        @Override
        public Vector evaluate(
            Vector input)
        {
            Vector state = this.getState();
            state.scaleEquals(0.9);
            state.plusEquals(input);
            this.setState(state);
            return state;
        }

    }

    /**
     * Creates an ExtendedKalmanFilter with a 2-dimensional StateSummer
     * motion model, a 2-to-1 GeneralizedLinearModel with an atan output
     * function as the observation model, a zero input vector, and identity
     * model and measurement covariances.
     * @return
     * A new ExtendedKalmanFilter instance.
     */
    @Override
    public ExtendedKalmanFilter createInstance()
    {
        int stateDim = 2;
        int outputDim = 1;
        GeneralizedLinearModel observationModel =
            new GeneralizedLinearModel(
                stateDim, outputDim, new AtanFunction() );
        StateSummer motionModel = new StateSummer(stateDim);
        Vector input = VectorFactory.getDefault().createVector(stateDim);
        Matrix modelCovariance =
            MatrixFactory.getDefault().createIdentity(stateDim,stateDim);
        Matrix measurementCovariance =
            MatrixFactory.getDefault().createIdentity(outputDim, outputDim);
        return new ExtendedKalmanFilter(
            motionModel, observationModel, input,
            modelCovariance, measurementCovariance);
    }

    /**
     * Creates a random 1-dimensional Gaussian. Its mean is drawn from a
     * standard normal and its single variance element from
     * {@code RANDOM.nextDouble()}.
     * @return
     * A random univariate MultivariateGaussian.
     */
    @Override
    public MultivariateGaussian createConditionalDistribution()
    {
        Vector mean = VectorFactory.getDefault().copyValues( RANDOM.nextGaussian() );
        Matrix covariance = MatrixFactory.getDefault().createMatrix(1,1);
        covariance.setElement(0, 0, RANDOM.nextDouble());
        return new MultivariateGaussian(mean, covariance);
    }

    /**
     * Compares the ExtendedKalmanFilter against a KalmanFilter on data from
     * a linear dynamical system. The posterior means and covariances must
     * agree within a tolerance of 0.1.
     */
    @Override
    public void testKnownValues()
    {
        System.out.println( "Known Values" );

        // EKF and KF should be approximately equal for a LDS
        final int dim = 2;
        Matrix A = MatrixFactory.getDefault().createIdentity(dim, dim);
        Matrix B = MatrixFactory.getDefault().createIdentity(dim, dim);
        Matrix C = MatrixFactory.getDefault().createIdentity(dim, dim);
        LinearDynamicalSystem model = new LinearDynamicalSystem( A, B, C );
        MultivariateDiscriminant outputModel =
            new MultivariateDiscriminant( C );
        Vector input = VectorFactory.getDefault().createVector(dim,0.1);
        Matrix modelCovariance = MatrixFactory.getDefault().createIdentity(dim,dim);
        Matrix outputCovariance = MatrixFactory.getDefault().createIdentity(dim,dim);

        // Each filter gets its own clone so that neither shares state with
        // the generating model, which is advanced in the loop below.
        ExtendedKalmanFilter ekf = new ExtendedKalmanFilter(
            model.clone(), outputModel, input, modelCovariance, outputCovariance );
        // NOTE(review): the EKF is given `input` explicitly, but this
        // KalmanFilter constructor takes no input. Confirm that both filters
        // use the same control input; otherwise the comparison relies on EPS
        // absorbing the difference.
        KalmanFilter kalman = new KalmanFilter(
            model.clone(), modelCovariance, outputCovariance );

        MultivariateGaussian noiseMaker = new MultivariateGaussian(
            VectorFactory.getDefault().createVector(dim), outputCovariance );
        ArrayList<Vector> noise = noiseMaker.sample(RANDOM, 100);
        ArrayList<Vector> ks = new ArrayList<Vector>( noise.size() );
        for( int n = 0; n < noise.size(); n++ )
        {
            // Each call advances the generating model's state by one step.
            ks.add( model.evaluate(input).plus( noise.get(n) ) );
        }

        MultivariateGaussian gekf = ekf.learn(ks);
        MultivariateGaussian gk = kalman.learn(ks);
        System.out.println( "EKF:\n" + gekf );
        System.out.println( "Kalman:\n" + gk );

        final double EPS = 1e-1;
        Vector m1 = gk.getMean();
        Vector m2 = gekf.getMean();
        // Fall back to assertEquals only on mismatch, so the failure message
        // shows both values.
        if( !m1.equals(m2,EPS) )
        {
            assertEquals( m1, m2 );
        }

        Matrix C1 = gk.getCovariance();
        Matrix C2 = gekf.getCovariance();
        if( !C1.equals(C2,EPS) )
        {
            assertEquals( C1, C2 );
        }
    }

}