package edu.cmu.minorthird.classify.algorithms.svm;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_parameter;
import libsvm.svm_problem;
import org.apache.log4j.Logger;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.FeatureFactory;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.SampleDatasets;
import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane;
/**
*
* This class is responsible for testing VisibleSVM class.
*
* @author chiachi
*/
public class VisibleSVMTest extends TestCase{
Logger log=Logger.getLogger(this.getClass());
svm_model m_toy1Model;
svm_model m_toy2Model;
FeatureFactory m_toy1FeatureFactory;
FeatureFactory m_toy2FeatureFactory;
ExampleSchema m_toy2ExampleSchema;
MutableInstance[] m_toy1Instances;
MutableInstance[] m_toy2Instances;
/**
* Standard test class constructior for VisibleSVMTest
* @param name Name of the test
*/
public VisibleSVMTest(String name){
super(name);
createTestSettings();
}
/**
* Convinence constructior for VisibleSVMTest
*/
public VisibleSVMTest(){
super("VisibleSVMTest");
}
/**
* setUp to run before each test
*/
protected void setUp(){
Logger.getRootLogger().removeAllAppenders();
org.apache.log4j.BasicConfigurator.configure();
//TODO add initializations if needed
}
/**
* clean up to run after each test
*/
protected void tearDown(){
//TODO clean up resources if needed
}
public void testConstructorWithTwoParams(){
System.out.println("Testing Constructor for SVMLearner...");
VisibleSVM testedVSVM=new VisibleSVM(m_toy1Model,m_toy1FeatureFactory);
assertNotNull(testedVSVM);
}
public void testConstructorWithThreeParams(){
System.out.println("Testing Constructor for MultiClassSVMLearner...");
VisibleSVM testedVSVM=new VisibleSVM(m_toy2Model,m_toy2FeatureFactory,m_toy2ExampleSchema);
assertNotNull(testedVSVM);
}
public void testGetExamples(){
System.out.println("Testing getExamples()...");
VisibleSVM testedVSVM=new VisibleSVM(m_toy1Model,m_toy1FeatureFactory);
Example[] returnedExamples=testedVSVM.getExamples();
assertTrue(isReturnedDataMatched(returnedExamples,false));
}
public void testGetExamplesMultiClass(){
System.out.println("Testing getExamples() for MultiClassSVM...");
VisibleSVM testedVSVM=new VisibleSVM(m_toy2Model,m_toy2FeatureFactory,m_toy2ExampleSchema);
Example[] returnedExamples=testedVSVM.getExamples();
assertEquals(returnedExamples.length,m_toy2Instances.length);
// need to deal with this crap
System.out.println(">>> "+isReturnedDataMatched(returnedExamples,true));
assertTrue(isReturnedDataMatched(returnedExamples,true));
//assertTrue(isToy3MatchedReturnedData(returnedExamples));
}
public void testGetExampleWeightLabels(){
System.out.println("Testing getExampleWeightLabels()...");
VisibleSVM testedVSVM=new VisibleSVM(m_toy1Model,m_toy1FeatureFactory);
String[][] returnedLabels=testedVSVM.getExampleWeightLabels();
libsvm.m3gateway gate=new libsvm.m3gateway(m_toy1Model);
double[][] weights=gate.getCoefficientsForSVsInDecisionFunctions();
double[] rlTemp=new double[returnedLabels.length];
double[] rw=new double[weights[0].length];
DecimalFormat df=new DecimalFormat("0.0000");
for(int index=0;index<weights[0].length;++index){
rlTemp[index]=Double.parseDouble(returnedLabels[index][0]);
rw[index]=Double.parseDouble(df.format(weights[0][index]));
assertEquals(rw[index],rlTemp[index]);
}
}
public void testGetExampleWeightLabelsMultiClass(){
System.out.println("Testing getExampleWeightLabels() for MultiClassSVM...");
VisibleSVM testedVSVM=new VisibleSVM(m_toy2Model,m_toy2FeatureFactory);
String[][] returnedLabels=testedVSVM.getExampleWeightLabels();
libsvm.m3gateway gate=new libsvm.m3gateway(m_toy2Model);
double[][] weights=gate.getCoefficientsForSVsInDecisionFunctions();
double[] rlTemp=new double[returnedLabels.length];
double[] rw=new double[weights[0].length];
for(int k=0;k<weights.length;++k){
DecimalFormat df=new DecimalFormat("0.0000");
for(int index=0;index<weights[0].length;++index){
if((k==1&&index==4)||returnedLabels[index][k]=="null"){
rlTemp[index]=0.0;
rw[index]=0.0;
}else{
rlTemp[index]=Double.parseDouble(returnedLabels[index][k]);
rw[index]=Double.parseDouble(df.format(weights[k][index]));
}
assertEquals(rw[index],rlTemp[index]);
}
}
}
public void testGetHyperplane(){
System.out.println("Testing getHyperplane()...");
VisibleSVM testedVSVM=new VisibleSVM(m_toy1Model,m_toy1FeatureFactory);
Hyperplane hp=testedVSVM.getHyperplane(0);
assertNotNull(hp);
}
public void testGetHyperplaneMultiClass(){
System.out.println("Testing getHyperplane() for MultiClassSVM...");
VisibleSVM testedVSVM=new VisibleSVM(m_toy2Model,m_toy2FeatureFactory,m_toy2ExampleSchema);
Hyperplane hp1=testedVSVM.getHyperplane(0);
Hyperplane hp2=testedVSVM.getHyperplane(1);
assertNotNull(hp1);
assertNotNull(hp2);
}
public void testGetHyperplaneLabel(){
System.out.println("Testing getHyperplaneLabel()...");
VisibleSVM testedVSVM=new VisibleSVM(m_toy1Model,m_toy1FeatureFactory);
String returnedHPLabels=testedVSVM.getHyperplaneLabel(0);
assertTrue(returnedHPLabels.equals(""));
}
public void testGetHyperplaneLabelMultiClass(){
System.out.println("Testing getHyperplaneLabel() for MultiClassSVM...");
VisibleSVM testedVSVM=new VisibleSVM(m_toy2Model,m_toy2FeatureFactory,m_toy2ExampleSchema);
String label1=testedVSVM.getHyperplaneLabel(0);
String label2=testedVSVM.getHyperplaneLabel(1);
assertEquals("marge vs. homer",label1);
assertEquals("marge vs. bart",label2);
}
public void testToGUI(){
System.out.println("Testing toGUI()...");
VisibleSVM testedVSVM=new VisibleSVM(m_toy1Model,m_toy1FeatureFactory);
assertNotNull(testedVSVM.toGUI());
}
public void testToGUIMultiClass(){
System.out.println("Testing toGUI() for MultiClassSVM...");
VisibleSVM testedVSVM=new VisibleSVM(m_toy2Model,m_toy2FeatureFactory,m_toy2ExampleSchema);
assertNotNull(testedVSVM.toGUI());
}
public void testGetHyperplaneOutOfBounds(){
System.out.println("Testing GetHyperplane() out of bounds...");
try{
VisibleSVM testedVSVM=new VisibleSVM(m_toy1Model,m_toy1FeatureFactory);
testedVSVM.getHyperplane(3);
fail("Hyperplane retrieved using out of bounds index!");
}catch(IllegalArgumentException success){
assertNotNull(success.getMessage());
}
}
public void testGetHyperplaneMultiClassOutOfBounds(){
System.out.println("Testing GetHyperplane() out of bounds for MultiClassSVM...");
try{
VisibleSVM testedVSVM=new VisibleSVM(m_toy2Model,m_toy2FeatureFactory,m_toy2ExampleSchema);
testedVSVM.getHyperplane(4);
fail("Hyperplane retrieved using out of bounds index!");
}catch(IllegalArgumentException success){
assertNotNull(success.getMessage());
}
}
public void testGetHyperplaneLabelOutOfBounds(){
System.out.println("Testing GetHyperplaneLabel() out of bounds...");
try{
VisibleSVM testedVSVM=new VisibleSVM(m_toy1Model,m_toy1FeatureFactory);
testedVSVM.getHyperplaneLabel(3);
fail("HPLabel retrieved using out of bounds index!");
}catch(IllegalArgumentException success){
assertNotNull(success.getMessage());
}
}
public void testGetHyperplaneLabelMultiClassOutOfBounds(){
System.out.println("Testing GetHyperplaneLabel() out of bounds for MultiClassSVM...");
try{
VisibleSVM testedVSVM=new VisibleSVM(m_toy2Model,m_toy2FeatureFactory,m_toy2ExampleSchema);
testedVSVM.getHyperplaneLabel(4);
fail("HPLabel retrieved using out of bounds index!");
}catch(IllegalArgumentException success){
assertNotNull(success.getMessage());
}
}
/**
* Creates a TestSuite from all testXXX methods
* @return TestSuite
*/
public static Test suite(){
return new TestSuite(VisibleSVMTest.class);
}
private Dataset makeToy1Dataset(){
Dataset result=SampleDatasets.sampleData("toy",false);
m_toy1Instances=new MutableInstance[result.size()];
Iterator<Example> it=result.iterator();
for(int i=0;it.hasNext();i++){
Example example1=it.next();
Iterator<Feature> fLoop=example1.featureIterator();
m_toy1Instances[i]=new MutableInstance();
for(int j=0;fLoop.hasNext();j++){
Feature f=fLoop.next();
m_toy1Instances[i].addNumeric(f,example1.getWeight(f));
}
}
return result;
}
private Dataset makeToy2Dataset(){
Dataset result=SampleDatasets.makeToy3ClassData(new Random(100),5);
m_toy2Instances=new MutableInstance[result.size()];
Iterator<Example> it=result.iterator();
for(int i=0;it.hasNext();i++){
Example example1=it.next();
Iterator<Feature> fLoop=example1.featureIterator();
m_toy2Instances[i]=new MutableInstance();
for(int j=0;fLoop.hasNext();j++){
Feature f=fLoop.next();
m_toy2Instances[i].addNumeric(f,example1.getWeight(f));
}
}
return result;
}
private void createTestSettings(){
//init parameters, exactly the same as initparams in SVMLearner, and MultiClassSVMLearner
svm_parameter parameters=new svm_parameter();
parameters.svm_type=svm_parameter.C_SVC;
parameters.kernel_type=svm_parameter.LINEAR;
parameters.degree=3;
parameters.gamma=0; // 1/k
parameters.coef0=0;
parameters.nu=0.5;
parameters.cache_size=40;
parameters.C=1;
parameters.eps=1e-3;
parameters.p=0.1;
parameters.shrinking=1;
parameters.nr_weight=0;
parameters.weight_label=new int[0];
parameters.weight=new double[0];
parameters.probability=0;
Dataset dataset1=makeToy1Dataset();
m_toy1FeatureFactory=dataset1.getFeatureFactory();
svm_problem problem1=SVMUtils.convertToSVMProblem(dataset1);
m_toy1Model=svm.svm_train(problem1,parameters);
Dataset dataset2=makeToy2Dataset();
m_toy2FeatureFactory=dataset2.getFeatureFactory();
m_toy2ExampleSchema=dataset2.getSchema();
svm_problem problem2=SVMUtils.convertToSVMProblem(dataset2);
m_toy2Model=svm.svm_train(problem2,parameters);
}
private boolean isReturnedDataMatched(Example[] returnedExamples,boolean isMultiClassSVM){
MutableInstance[] instances;
if(isMultiClassSVM){
instances=m_toy2Instances;
}else{
instances=m_toy1Instances;
}
int featureCount=10;
String[][] originalNames=new String[instances.length][featureCount];
int[][] originalNumNames=new int[instances.length][featureCount];
for(int index=0;index<instances.length;++index){
int subidx=0;
String temphold="";
for(Iterator<Feature> flidx1=instances[index].featureIterator();flidx1
.hasNext();){
Feature ftemp1=flidx1.next();
originalNames[index][subidx]=ftemp1.toString();
originalNumNames[index][subidx]=ftemp1.numericName();
temphold+=originalNames[index][subidx]+" ";
++subidx;
}
}
boolean[] exampleChecked=new boolean[returnedExamples.length];
String[][] returnedNames=new String[returnedExamples.length][featureCount];
int[][] returnedNumNames=new int[returnedExamples.length][featureCount];
for(int index=0;index<returnedExamples.length;++index){
exampleChecked[index]=false;
int subidx=0;
String temphold="";
for(Iterator<Feature> flidx1=returnedExamples[index].featureIterator();flidx1
.hasNext();){
Feature ftemp1=flidx1.next();
returnedNames[index][subidx]=ftemp1.toString();
returnedNumNames[index][subidx]=ftemp1.numericName();
temphold+=returnedNames[index][subidx]+" ";
++subidx;
}
}
int[] matchedMap=new int[instances.length];
for(int index=0;index<exampleChecked.length;++index){
for(int idx=0;idx<originalNames.length;++idx){
if(originalNames[idx].length==returnedNames[index].length){
Arrays.sort(originalNumNames[idx]);
Arrays.sort(returnedNumNames[index]);
if(Arrays.equals(originalNumNames[idx],returnedNumNames[index])){
matchedMap[idx]=index+1;
exampleChecked[index]=true;
idx=originalNames.length;
}
}
}
}
for(int index=0;index<exampleChecked.length;++index){
if(!exampleChecked[index]){
return false;
}
}
return true;
}
/**
* Run the full suite of tests with text output
* @param args - unused
*/
public static void main(String args[]){
junit.textui.TestRunner.run(suite());
}
}