package edu.cmu.minorthird.classify.algorithms.svm;
import java.util.Iterator;
import java.util.Random;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import libsvm.svm_node;
import org.apache.log4j.Logger;
import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.FeatureFactory;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.MutableInstance;
/**
*
* This class is responsible for testing nodeToFeature and nodeArrayToInstance
* functions from SVMUtils class.
*
* @author chiachi
*/
public class SVMUtilsTest extends TestCase{
static Logger logger=Logger.getLogger(SVMUtilsTest.class);
FeatureFactory featureFactory;
/**
* Standard test class constructior for SVMUtilsTest
* @param name Name of the test
*/
public SVMUtilsTest(String name){
super(name);
Dataset testDataSet=createTestDataset();
featureFactory=testDataSet.getFeatureFactory();
}
/**
* Convinence constructior for SVMUtilsTest
*/
public SVMUtilsTest(){
super("SVMUtilsTest");
}
/**
* setUp to run before each test
*/
protected void setUp(){
Logger.getRootLogger().removeAllAppenders();
org.apache.log4j.BasicConfigurator.configure();
//TODO add initializations if needed
}
/**
* clean up to run after each test
*/
protected void tearDown(){
//TODO clean up resources if needed
}
public void testNodeToFeature(){
System.out.println("Testing nodeToFeature()...");
System.out.println("FeatureFactory:");
System.out.println(featureFactory);
// test case 1
libsvm.svm_node svmNodeTemp=new libsvm.svm_node();
svmNodeTemp.index=1;
svmNodeTemp.value=10.7;
Feature testFeature=featureFactory.getFeature(svmNodeTemp.index-1);
String featureStrName=testFeature.toString();
Feature returnedFeature=SVMUtils.nodeToFeature(svmNodeTemp,featureFactory);
System.out.println("Feature Index "+svmNodeTemp.index+": "+returnedFeature);
assertNotNull(returnedFeature);
assertEquals(featureStrName,returnedFeature.toString());
// test case 2
libsvm.svm_node svmNodeTemp2=new libsvm.svm_node();
svmNodeTemp2.index=100;
svmNodeTemp2.value=10.7;
Feature returnedFeature2=SVMUtils.nodeToFeature(svmNodeTemp2,featureFactory);
System.out.println("Feature Index "+svmNodeTemp2.index+": "+returnedFeature2);
assertNull(returnedFeature2);
System.out.println("Done.");
}
public void testNodeArrayToInstance(){
System.out.println("Testing nodeArrayToInstance()...");
String[] featureNames=new String[3];
svm_node[] nodes=new svm_node[3];
for(int i=0;i<nodes.length;i++){
nodes[i]=new svm_node();
nodes[i].index=(i+1);
nodes[i].value=3.1+(double)i;
featureNames[i]=featureFactory.getFeature(i).toString();
}
// calling method and check returned object
Instance instance=SVMUtils.nodeArrayToInstance(nodes,featureFactory);
assertNotNull(instance);
checkInstance(instance,featureNames,nodes);
// call method with incorrect id
nodes[2].index=100;
instance=SVMUtils.nodeArrayToInstance(nodes,featureFactory);
assertNull(instance);
System.out.println("Done.");
}
private static void checkInstance(Instance instance,String[] featureNames,svm_node[] nodes){
for(Iterator<Feature> it=instance.numericFeatureIterator();it.hasNext();){
Feature feature=it.next();
boolean found=false;
for(int i=0;i<nodes.length;i++){
if(featureNames[i].equals(feature.toString())){
found=true;
assertEquals(nodes[i].value,instance.getWeight(feature));
}
}
assertTrue(found);
}
}
private static Dataset createTestDataset(){
int numInstances=10;
int numMaxFeatures=10;
Random random=new Random();
String[][] features=new String[][]{
{"bad","slow","mistake","complain","angry","stress"},
{"good","excellent","potentical","new","many","conquer","trial"}
};
String[] labels=new String[]{
"critisize",
"appreciate"
};
Dataset dataset=new BasicDataset();
for(int i=0;i<numInstances;i++){
MutableInstance instance=new MutableInstance();
int numFeatures=random.nextInt(numMaxFeatures)+1;
int labelIndex=random.nextInt(labels.length);
for(int j=0;j<numFeatures;j++){
int featureIndex=random.nextInt(features[labelIndex].length);
instance.addBinary(new Feature(new String[]{"testdata",features[labelIndex][featureIndex]}));
}
dataset.add(new Example(instance,new ClassLabel(labels[labelIndex])));
}
return dataset;
}
/**
* Creates a TestSuite from all testXXX methods
* @return TestSuite
*/
public static Test suite(){
return new TestSuite(SVMUtilsTest.class);
}
/**
* Run the full suite of tests with text output
* @param args - unused
*/
public static void main(String args[]){
junit.textui.TestRunner.run(suite());
}
}