package com.cse10.classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LibSVM;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SelectedTag;
import weka.core.SerializationHelper;
import java.util.Random;
import org.apache.log4j.Logger;
/**
* Wrapper class for LibSVM
* Created by Chamath on 12/20/2014
*/
public class SVMClassifierHandler extends ClassifierHandler {

    /**
     * Default file the trained model is serialized to when saving is requested.
     * Forward slashes are used because Java accepts them on every platform, including Windows;
     * a backslash-separated path only resolves to a directory hierarchy on Windows.
     */
    public static final String DEFAULT_MODEL_PATH = "Classifier/src/main/resources/models/svm.model";

    /**
     * Index into {@code LibSVM.TAGS_KERNELTYPE} selecting the radial basis function (RBF) kernel
     * (the value of {@code LibSVM.KERNELTYPE_RBF}).
     */
    private static final int RBF_KERNEL_TYPE = 2;

    /** The wrapped LibSVM classifier; configured by {@link #configure} and trained by {@link #buildSVM}. */
    protected LibSVMExtended svm;

    private final Logger log;

    /**
     * Creates a handler backed by a new {@code LibSVMExtended} instance using the RBF kernel.
     * Cost, gamma, class weights and normalization are not set here; call {@link #configure} for that.
     */
    public SVMClassifierHandler() {
        log = Logger.getLogger(this.getClass());
        svm = new LibSVMExtended();
        svm.setKernelType(new SelectedTag(RBF_KERNEL_TYPE, LibSVM.TAGS_KERNELTYPE));
    }

    /**
     * Configures the SVM hyper-parameters.
     *
     * @param cost            the SVM cost parameter, passed to {@code setCost}
     * @param gamma           the kernel gamma parameter, passed to {@code setGamma}
     * @param weights         class weights string, passed unchanged to {@code setWeights}
     * @param isNormalizeData whether the data should be normalized before training
     */
    public void configure(double cost, double gamma, String weights, boolean isNormalizeData) {
        svm.setCost(cost);
        svm.setGamma(gamma);
        svm.setWeights(weights);
        svm.setNormalize(isNormalizeData);
    }

    /**
     * Trains the classifier on the given data and optionally saves it to {@link #DEFAULT_MODEL_PATH}.
     * Failures are logged, not thrown.
     *
     * @param filteredTrainingData the (already filtered) training instances
     * @param isSaving             whether the trained model should be serialized to disk
     */
    public void buildSVM(Instances filteredTrainingData, boolean isSaving) {
        buildSVM(filteredTrainingData, isSaving, DEFAULT_MODEL_PATH);
    }

    /**
     * Trains the classifier on the given data and optionally saves it to {@code modelPath}.
     * Failures (training or serialization) are logged, not thrown. If training fails,
     * nothing is saved.
     *
     * @param filteredTrainingData the (already filtered) training instances
     * @param isSaving             whether the trained model should be serialized to disk
     * @param modelPath            file the model is written to when {@code isSaving} is true
     */
    public void buildSVM(Instances filteredTrainingData, boolean isSaving, String modelPath) {
        try {
            svm.buildClassifier(filteredTrainingData);
            if (isSaving) {
                SerializationHelper.write(modelPath, svm);
            }
        } catch (Exception e) {
            log.error("Failed to build or save the SVM classifier (model path: " + modelPath + ")", e);
        }
    }

    /**
     * Gives access to the wrapped SVM model.
     *
     * @return the underlying {@code LibSVMExtended} instance (not a copy)
     */
    public LibSVMExtended getSvm() {
        return svm;
    }

    /**
     * Runs k-fold cross validation of the configured SVM and logs a summary, the weighted
     * area under ROC, the confusion matrix and per-class rates for the first two classes.
     * Failures are logged, not thrown.
     *
     * @param filteredTrainingData the (already filtered) instances to cross-validate on
     * @param numOfFolds           number of folds
     * @return the Evaluation object, or {@code null} if it could not even be constructed
     */
    public Evaluation crossValidateClassifier(Instances filteredTrainingData, int numOfFolds) {
        Evaluation evaluation = null;
        try {
            evaluation = new Evaluation(filteredTrainingData);
            // Fixed seed so fold assignment, and therefore the reported metrics, are reproducible.
            evaluation.crossValidateModel(svm, filteredTrainingData, numOfFolds, new Random(1));
            log.info(evaluation.toSummaryString());
            log.info("weighted area under ROC = " + evaluation.weightedAreaUnderROC());

            double[][] confusionMatrix = evaluation.confusionMatrix();
            logConfusionMatrix(confusionMatrix);

            // The crime/other labels assume class index 0 is "crime" and 1 is "other";
            // guard so a dataset with fewer classes does not throw.
            if (confusionMatrix.length >= 2) {
                log.info("accuracy for crime class= " + classRate(confusionMatrix, 0) * 100 + "% \n");
                log.info("accuracy for other class= " + classRate(confusionMatrix, 1) * 100 + "% \n");
            }
        } catch (Exception e) {
            log.error("Cross validation of the SVM classifier failed", e);
        }
        return evaluation;
    }

    /** Logs the confusion matrix one row per line, cells separated by a single space. */
    private void logConfusionMatrix(double[][] confusionMatrix) {
        for (double[] row : confusionMatrix) {
            StringBuilder line = new StringBuilder();
            for (double cell : row) {
                if (line.length() > 0) {
                    line.append(' ');
                }
                line.append(cell);
            }
            log.info(line.toString());
        }
    }

    /**
     * Fraction of instances of the given class that were predicted as that class:
     * the diagonal cell divided by its row total (Weka rows are actual classes).
     * Returns NaN when the class has no instances (0 / 0).
     */
    private static double classRate(double[][] confusionMatrix, int classIndex) {
        double[] row = confusionMatrix[classIndex];
        double total = 0;
        for (double cell : row) {
            total += cell;
        }
        return row[classIndex] / total;
    }

    /**
     * Classifies a single (already filtered) instance.
     *
     * @param filteredTestInstance the instance to classify
     * @return the value returned by the SVM's {@code classifyInstance}, or {@code -1.0} if
     *         classification failed (the failure is logged)
     */
    public double classifyInstance(Instance filteredTestInstance) {
        double result = -1.0;
        try {
            result = svm.classifyInstance(filteredTestInstance);
        } catch (Exception e) {
            log.error("Failed to classify instance", e);
        }
        return result;
    }
}