package com.cse10.classifier;
import com.cse10.article.TrainingArticle;
import com.cse10.database.DatabaseHandler;
import libsvm.svm_model;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffSaver;
import weka.filters.Filter;
import weka.filters.supervised.instance.SMOTE;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Random;
import org.apache.log4j.Logger;
/**
 * Loads training data and balances the class distribution using a sampling
 * technique (SVM support-vector selection followed by SMOTE oversampling).
 * Created by Chamath on 12/20/2014.
 */
public class DataHandlerWithSampling extends DataHandler {

    /** Output directory for the intermediate ARFF snapshots written during balancing. */
    private static final String ARFF_DIR = "Classifier\\src\\main\\resources\\arffData\\";

    private final Logger log;

    public DataHandlerWithSampling() {
        isFeatureVectorTransformerRequired = false;
        fileName = "dataWithSampling";
        log = Logger.getLogger(this.getClass());
    }

    @Override
    protected String printDescription() {
        String description = "This data handler will load training data and use sampling method to generate training data.";
        log.info(description);
        return description;
    }

    /**
     * Fetches training data; there is no need to apply the FeatureVectorTransformer again.
     * <p>
     * Pipeline: load articles from the database, transform them to feature vectors,
     * train an SVM to identify per-class support vectors, keep only the support
     * vectors, then SMOTE-oversample the minority (crime) class so both classes
     * end up with roughly equal counts. Intermediate data sets are saved as ARFF
     * snapshots for inspection.
     *
     * @param featureVectorTransformer transformer used to convert article text to feature vectors
     * @return the balanced, randomized training set
     * @throws IllegalStateException if SMOTE cannot be applied (no crime-class
     *                               support vectors, or the filter itself fails)
     */
    public Instances loadTrainingData(FeatureVectorTransformer featureVectorTransformer) {
        printDescription();

        Instances trainingData = fetchTrainingInstances();

        // NOTE(review): configure(1, 1, false) presumably selects unigram features — confirm.
        featureVectorTransformer.configure(1, 1, false);
        featureVectorTransformer.setInputFormat(trainingData);
        Instances filteredData = featureVectorTransformer.getTransformedArticles(trainingData, "hybrid1");

        // Empty header-only copies that will receive each class's support vectors.
        Instances crimeClassSupportVectors = new Instances(filteredData); // crime class support vectors
        crimeClassSupportVectors.delete();
        Instances otherClassSupportVectors = new Instances(filteredData); // other class support vectors
        otherClassSupportVectors.delete();
        partitionSupportVectors(filteredData, crimeClassSupportVectors, otherClassSupportVectors);

        double crimeCount = crimeClassSupportVectors.numInstances();
        double otherCount = otherClassSupportVectors.numInstances();
        log.info("Crime Count " + crimeCount);
        log.info("Other Count " + otherCount);

        saveArff(otherClassSupportVectors, "otherClassSupportVectors.arff");
        saveArff(crimeClassSupportVectors, "crimeClassSupportVectors.arff");

        // Merge both classes' support vectors into one set and shuffle it.
        for (int j = 0; j < crimeClassSupportVectors.numInstances(); j++) {
            otherClassSupportVectors.add(crimeClassSupportVectors.instance(j));
        }
        Instances supportVectorUnion = otherClassSupportVectors;
        Random random = new Random();
        supportVectorUnion.randomize(random);
        saveArff(supportVectorUnion, "balancedTrainingDataHybrid.arff");

        Instances dataBalanced = applySmote(supportVectorUnion, crimeCount, otherCount);
        dataBalanced.randomize(random);
        saveArff(dataBalanced, "balancedTrainingDataHybridRandomized.arff");
        return dataBalanced;
    }

    /**
     * Builds the raw Weka training set from the database: one string attribute
     * ("text") plus a nominal class attribute with values {crime, other}.
     *
     * @return training instances with the class index set to the last attribute
     */
    private Instances fetchTrainingInstances() {
        FastVector attributeList = new FastVector(2);
        Attribute content = new Attribute("text", (FastVector) null); // null FastVector => string attribute
        FastVector classVal = new FastVector();
        classVal.addElement("crime");
        classVal.addElement("other");
        Attribute classValue = new Attribute("@@class@@", classVal);
        // add news text and class attribute
        attributeList.addElement(content);
        attributeList.addElement(classValue);

        Instances trainingData = new Instances("TrainingNews", attributeList, 0);
        trainingData.setClassIndex(trainingData.numAttributes() - 1);

        // load training data using database handler
        List<TrainingArticle> trainingArticles = DatabaseHandler.fetchTrainingArticles();
        for (TrainingArticle trainingArticle : trainingArticles) {
            Instance inst = new Instance(trainingData.numAttributes());
            inst.setValue(content, trainingArticle.getContent());
            inst.setValue(classValue, trainingArticle.getLabel());
            inst.setDataset(trainingData);
            trainingData.add(inst);
        }
        return trainingData;
    }

    /**
     * Trains an SVM on {@code filteredData} and copies each support vector into
     * the set matching its class (class value 0.0 = crime, 1.0 = other).
     *
     * @param filteredData             transformed training data to train on
     * @param crimeClassSupportVectors receives crime-class support vectors (mutated)
     * @param otherClassSupportVectors receives other-class support vectors (mutated)
     */
    private void partitionSupportVectors(Instances filteredData,
                                         Instances crimeClassSupportVectors,
                                         Instances otherClassSupportVectors) {
        SVMClassifierHandler svm = new SVMClassifierHandler();
        // NOTE(review): C=8.0, gamma=0.001953125 look grid-searched — confirm before changing.
        svm.configure(8.0, 0.001953125, "10 1", true);
        svm.buildSVM(filteredData, false);
        svm_model svmModel = svm.getSvm().getSVMModel();
        int[] svIndices = svmModel.sv_indices;
        log.info("Number of support vectors=" + svIndices.length);
        for (int k = 0; k < svIndices.length; k++) {
            // libsvm's sv_indices are 1-based; Weka instances are 0-based.
            Instance instance = filteredData.instance(svIndices[k] - 1);
            log.info(svIndices[k] - 1 + " " + instance.classValue());
            if (instance.classValue() == 0.0) {
                crimeClassSupportVectors.add(instance);
            } else if (instance.classValue() == 1.0) {
                otherClassSupportVectors.add(instance);
            }
        }
    }

    /**
     * SMOTE-oversamples the minority (crime) class. The percentage
     * ((other/crime) - 1) * 100 creates enough synthetic crime instances to
     * roughly equalize the two class counts.
     *
     * @param data       combined support-vector set to balance
     * @param crimeCount number of crime-class support vectors (must be &gt; 0)
     * @param otherCount number of other-class support vectors
     * @return the SMOTE-balanced data set, never null
     * @throws IllegalStateException if {@code crimeCount} is zero or the filter fails
     */
    private Instances applySmote(Instances data, double crimeCount, double otherCount) {
        if (crimeCount == 0) {
            // Original code would divide by zero here and feed Infinity to SMOTE.
            throw new IllegalStateException("No crime-class support vectors; cannot compute SMOTE percentage");
        }
        SMOTE smote = new SMOTE();
        try {
            smote.setInputFormat(data);
            // Specifies percentage of SMOTE instances to create.
            double percentage = ((otherCount / crimeCount) - 1) * 100;
            smote.setPercentage(Math.round(percentage));
            return Filter.useFilter(data, smote);
        } catch (Exception e) {
            // Original code printed the trace and then NPE'd on the null result;
            // fail fast with the cause preserved instead.
            throw new IllegalStateException("SMOTE oversampling failed", e);
        }
    }

    /**
     * Writes {@code data} to the ARFF snapshot directory. I/O failures are
     * logged (with cause) but not rethrown — snapshots are best-effort.
     *
     * @param data         instances to persist
     * @param arffFileName file name within {@link #ARFF_DIR}
     */
    private void saveArff(Instances data, String arffFileName) {
        ArffSaver saver = new ArffSaver();
        saver.setInstances(data);
        try {
            saver.setFile(new File(ARFF_DIR + arffFileName));
            saver.writeBatch();
        } catch (IOException e) {
            log.error("Failed to write ARFF snapshot " + arffFileName, e);
        }
    }
}