package weka.classifiers.pmml.consumer;
import weka.core.Instances;
import weka.core.FastVector;
import weka.core.Attribute;
import weka.core.pmml.PMMLFactory;
import weka.core.pmml.PMMLModel;
import weka.test.Regression;
import weka.classifiers.evaluation.EvaluationUtils;
import java.io.*;
import junit.framework.TestCase;
import junit.framework.Test;
import junit.framework.TestSuite;
public abstract class AbstractPMMLClassifierTest extends TestCase {
protected FastVector m_modelNames = new FastVector();
protected FastVector m_dataSetNames = new FastVector();
public AbstractPMMLClassifierTest(String name) {
super(name);
}
public Instances getData(String name) {
Instances elnino = null;
try {
elnino =
new Instances(new BufferedReader(new InputStreamReader(
ClassLoader.getSystemResourceAsStream("weka/classifiers/pmml/data/" + name))));
} catch (Exception ex) {
ex.printStackTrace();
}
return elnino;
}
public PMMLClassifier getClassifier(String name) {
PMMLClassifier regression = null;
try {
PMMLModel model =
PMMLFactory.getPMMLModel(new BufferedInputStream(ClassLoader.getSystemResourceAsStream(
"weka/classifiers/pmml/data/" + name)));
regression = (PMMLClassifier)model;
} catch (Exception ex) {
ex.printStackTrace();
}
return regression;
}
public void testRegression() throws Exception {
PMMLClassifier classifier = null;
Instances testData = null;
EvaluationUtils evalUtils = null;
weka.test.Regression reg = new weka.test.Regression(this.getClass());
FastVector predictions = null;
boolean success = false;
for (int i = 0; i < m_modelNames.size(); i++) {
classifier = getClassifier((String)m_modelNames.elementAt(i));
testData = getData((String)m_dataSetNames.elementAt(i));
evalUtils = new EvaluationUtils();
try {
String className = classifier.getMiningSchema().getFieldsAsInstances().classAttribute().name();
Attribute classAtt = testData.attribute(className);
testData.setClass(classAtt);
predictions = evalUtils.getTestPredictions(classifier, testData);
success = true;
String predsString = weka.classifiers.AbstractClassifierTest.predictionsToString(predictions);
reg.println(predsString);
} catch (Exception ex) {
ex.printStackTrace();
String msg = ex.getMessage().toLowerCase();
if (msg.indexOf("not in classpath") > -1) {
return;
}
}
}
if (!success) {
fail("Problem during regression testing: no successful predictions generated");
}
try {
String diff = reg.diff();
if (diff == null) {
System.err.println("Warning: No reference available, creating.");
} else if (!diff.equals("")) {
fail("Regression test failed. Difference:\n" + diff);
}
} catch (java.io.IOException ex) {
fail("Problem during regression testing.\n" + ex);
}
}
}