package aima.test.core.unit.learning.framework;
import java.util.Hashtable;
import org.junit.Assert;
import org.junit.Test;
import aima.core.learning.framework.DataSet;
import aima.core.learning.framework.DataSetFactory;
import aima.core.util.Util;
/**
* @author Ravi Mohan
*
*/
public class InformationAndGainTest {
@Test
public void testInformationCalculation() {
double[] fairCoinProbabilities = new double[] { 0.5, 0.5 };
double[] loadedCoinProbabilities = new double[] { 0.01, 0.99 };
Assert.assertEquals(1.0, Util.information(fairCoinProbabilities), 0.001);
Assert.assertEquals(0.08079313589591118,
Util.information(loadedCoinProbabilities), 0.000000000000000001);
}
@Test
public void testBasicDataSetInformationCalculation() throws Exception {
DataSet ds = DataSetFactory.getRestaurantDataSet();
double infoForTargetAttribute = ds.getInformationFor();// this should
// be the
// generic
// distribution
Assert.assertEquals(1.0, infoForTargetAttribute, 0.001);
}
@Test
public void testDataSetSplit() throws Exception {
DataSet ds = DataSetFactory.getRestaurantDataSet();
Hashtable<String, DataSet> hash = ds.splitByAttribute("patrons");// this
// should
// be
// the
// generic
// distribution
Assert.assertEquals(3, hash.keySet().size());
Assert.assertEquals(6, hash.get("Full").size());
Assert.assertEquals(2, hash.get("None").size());
Assert.assertEquals(4, hash.get("Some").size());
}
@Test
public void testGainCalculation() throws Exception {
DataSet ds = DataSetFactory.getRestaurantDataSet();
double gain = ds.calculateGainFor("patrons");
Assert.assertEquals(0.541, gain, 0.001);
gain = ds.calculateGainFor("type");
Assert.assertEquals(0.0, gain, 0.001);
}
}