package de.fub.agg2graph.gpseval; import de.fub.agg2graph.gpseval.data.AggregatedData; import de.fub.agg2graph.gpseval.features.Feature; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Random; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.classifiers.bayes.NaiveBayes; import weka.classifiers.functions.MultilayerPerceptron; import weka.classifiers.trees.J48; import weka.classifiers.trees.RandomForest; import weka.core.Attribute; import weka.core.DenseInstance; import weka.core.FastVector; import weka.core.Instance; import weka.core.Instances; /** * WekaEval is used to run Weka with training- and test-sets to test the * different classification algorithms with different parameters. Instead of * using training/test-sets you can run a cross validation, too. */ public class WekaEval { private Map<String, List<AggregatedData>> mGpsData; private List<Feature> mFeatures; private Map<String, Attribute> mFeatureAttrMapping = new HashMap<>(); private Attribute mClassAttribute; private FastVector mAttrs; private double mTrainingSetSize = 0.6; private int mCrossValidationFolds = 10; private List<WekaResult> mCrossValidationResults = new LinkedList<>(); private List<WekaResult> mTrainingTestResults = new LinkedList<>(); // TODO implement setter private Classifier[] mClassifiers = new Classifier[]{new J48(), new RandomForest(), new NaiveBayes(), // TODO new BayesNet(), --> produces warning new MultilayerPerceptron()}; /** * * @param gpsData The GPS-data to use * @param features The feature-set */ @SuppressWarnings("unchecked") public WekaEval(Map<String, List<AggregatedData>> gpsData, List<Feature> features) { mGpsData = gpsData; mFeatures = features; /* * Create attributes based on feature-list */ mAttrs = new FastVector(); for (Feature feature : mFeatures) { Attribute attr = new Attribute(feature.getIdentifier(), mFeatureAttrMapping.keySet().size()); mFeatureAttrMapping.put(feature.getIdentifier(), attr); mAttrs.addElement(attr); } // the last attribute is the class attributes which accepts // the class name as string. These names must be added to a FastVector. int classCount = gpsData.keySet().size(); FastVector classAttributeValue = new FastVector(classCount); for (String className : gpsData.keySet()) { classAttributeValue.addElement(className); } mClassAttribute = new Attribute("class", classAttributeValue); mAttrs.addElement(mClassAttribute); } /** * Get the Attribute-instance specified by the feature-name. * * @param feature * @return */ public Attribute getAttributeByFeature(String feature) { return mFeatureAttrMapping.get(feature); } /** * Runs Weka (1) using training- and test-set and (2) using cross * validation. * * @throws Exception */ public void run() throws Exception { runTrainTestSet(); runCrossValidation(); } /** * Runs Weka using training- and test-set. * * @throws Exception */ @SuppressWarnings("unchecked") public void runTrainTestSet() throws Exception { // Training set Instances trainingSet = new Instances("Classes", mAttrs, 0); trainingSet.setClassIndex(trainingSet.numAttributes() - 1); // Test set Instances testingSet = new Instances("Classes", mAttrs, 0); testingSet.setClassIndex(testingSet.numAttributes() - 1); // fill training and test set for (String className : mGpsData.keySet()) { List<AggregatedData> gpsData = mGpsData.get(className); int curTrainingSetSize = (int) Math.ceil(gpsData.size() * mTrainingSetSize); for (int i = 0; i < gpsData.size(); i++) { Instance instance = getInstance(className, gpsData.get(i)); if (i < curTrainingSetSize) { trainingSet.add(instance); } else { testingSet.add(instance); } } } if (trainingSet.numInstances() < 1) { throw new Exception("Empty training set!"); } if (testingSet.numInstances() < 1) { throw new Exception("Empty test set!"); } // start evaluating for each algorithm for (Classifier cls : mClassifiers) { evaluate(cls, trainingSet, testingSet); } } /** * Runs Weka using cross-validation. */ @SuppressWarnings("unchecked") public void runCrossValidation() throws Exception { // Training set Instances trainingSet = new Instances("Classes", mAttrs, 0); trainingSet.setClassIndex(trainingSet.numAttributes() - 1); // fill training set for (String className : mGpsData.keySet()) { List<AggregatedData> gpsData = mGpsData.get(className); for (AggregatedData data : gpsData) { Instance instance = getInstance(className, data); trainingSet.add(instance); } } // start evaluating for each algorithm for (Classifier cls : mClassifiers) { evaluate(cls, trainingSet, mCrossValidationFolds); } } /** * Get a Weka-Instance-object for the given aggregated data and class name. * * @param className * @param data * @return */ public Instance getInstance(String className, AggregatedData data) { int capacity = mAttrs.size(); Instance instance = new DenseInstance(capacity); for (Feature feature : mFeatures) { String featureId = feature.getIdentifier(); Attribute attr = getAttributeByFeature(featureId); double value = data.getData(featureId); instance.setValue(attr, value); } instance.setValue(mClassAttribute, className); return instance; } /** * Start Weka-evaluation using training- and test-set. * * @param cls * @param trainingSet * @param testingSet * @throws Exception */ public void evaluate(Classifier cls, Instances trainingSet, Instances testingSet) throws Exception { cls.buildClassifier(trainingSet); Evaluation eval = new Evaluation(trainingSet); eval.evaluateModel(cls, testingSet); mTrainingTestResults.add(new WekaResult(cls.getClass().getSimpleName(), eval, trainingSet)); } /** * Start Weka-evaluation using cross-validation. * * @param cls * @param trainingSet * @param numFolds * @throws Exception */ public void evaluate(Classifier cls, Instances trainingSet, int numFolds) throws Exception { Evaluation eval = new Evaluation(trainingSet); eval.crossValidateModel(cls, trainingSet, numFolds, new Random()); mCrossValidationResults.add(new WekaResult(cls.getClass() .getSimpleName(), eval, trainingSet)); } /** * Sets the training set size. * * @param size Set size between 0 and 1. */ public void setTrainingSetSize(double size) { if (size > 1.0 || size <= 0.0) { throw new IllegalArgumentException( "Training size must be a value between 0 (exclusive) and 1."); } mTrainingSetSize = size; } /** * Set the number of folds used for cross-validation. * * @param numFolds */ public void setCrossValidationFolds(int numFolds) { mCrossValidationFolds = numFolds; } /** * Get the results for the cross-validation evaluation. * * @return */ public List<WekaResult> getCrossValidationResults() { return mCrossValidationResults; } /** * Get the results for the evaluation based on training/test set. * * @return */ public List<WekaResult> getTrainingTestResults() { return mTrainingTestResults; } }