package edu.isistan.daclassifier;
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.StringBufferInputStream;
import java.util.ArrayList;
import java.util.List;
import au.com.bytecode.opencsv.CSVWriter;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.meta.HMC;
import mulan.classifier.transformation.LabelPowerset;
import mulan.data.LabelsBuilder;
import mulan.data.LabelsMetaData;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.*;
import mulan.evaluation.measure.*;
import weka.classifiers.functions.SMO;
import weka.classifiers.functions.supportVector.Kernel;
import weka.classifiers.functions.supportVector.PolyKernel;
import weka.classifiers.functions.supportVector.RBFKernel;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.RemoveByName;
import weka.filters.unsupervised.attribute.StringToWordVector;
@SuppressWarnings("deprecation")
public class MachineLearner {
private static String textExpression = ArffGenerator.sP_DESC + "|" + ArffGenerator.sA0_DESC + "|" + ArffGenerator.sA1_DESC + "|" + ArffGenerator.sA2_DESC;
private static String srlExpression = ArffGenerator.sP + "|" + ArffGenerator.sA0 + "|" + ArffGenerator.sA1 + "|" + ArffGenerator.sA2;
private Instances fullInstances;
private Instances textInstances;
private Instances srlInstances;
private Instances fullInstancesFiltered;
private Instances textInstancesFiltered;
private Instances srlInstancesFiltered;
private MultiLabelInstances fullDataset;
private MultiLabelInstances textDataset;
private MultiLabelInstances srlDataset;
@SuppressWarnings("unused")
private LabelsMetaData labels;
public MachineLearner() {
}
public void load(String[] filenames, String xmlfilename) {
try {
Instances clone;
//
fullInstances = ArffGenerator.readFromCSV(filenames);
clone = new Instances(fullInstances);
textInstances = Filter.useFilter(clone, getAttributeFilter(clone, textExpression, false));
clone = new Instances(fullInstances);
srlInstances = Filter.useFilter(clone, getAttributeFilter(clone, srlExpression, false));
//
clone = new Instances(fullInstances);
fullInstancesFiltered = Filter.useFilter(clone, getWordFilter(clone));
clone = new Instances(textInstances);
textInstancesFiltered = Filter.useFilter(clone, getWordFilter(clone));
clone = new Instances(srlInstances);
srlInstancesFiltered = Filter.useFilter(clone, getWordFilter(clone));
//
fullDataset = new MultiLabelInstances(new StringBufferInputStream(fullInstancesFiltered.toString()), new BufferedInputStream(new FileInputStream(xmlfilename)));
textDataset = new MultiLabelInstances(new StringBufferInputStream(textInstancesFiltered.toString()), new BufferedInputStream(new FileInputStream(xmlfilename)));
srlDataset = new MultiLabelInstances(new StringBufferInputStream(srlInstancesFiltered.toString()), new BufferedInputStream(new FileInputStream(xmlfilename)));
//
labels = LabelsBuilder.createLabels(xmlfilename);
} catch (Exception e) {
e.printStackTrace();
}
}
public static RemoveByName getAttributeFilter(Instances input, String expression, boolean inverse) throws Exception {
RemoveByName filter = new RemoveByName();
filter.setExpression(expression);
filter.setInvertSelection(inverse);
filter.setInputFormat(input);
return filter;
}
public static StringToWordVector getWordFilter(Instances input) throws Exception {
StringToWordVector filter = new StringToWordVector();
filter.setInputFormat(input);
//filter.setIDFTransform(true);
filter.setUseStoplist(true);
SnowballStemmer stemmer = new SnowballStemmer();
filter.setStemmer(stemmer);
filter.setLowerCaseTokens(true);
return filter;
}
public void trainAndEvalFull_HMC_LP_SMO(String outputfilepath) throws Exception {
String learnerName = "HMC-LP-SMO-Full";
String outputfilename = outputfilepath + learnerName + ".csv";
String[] parametersNames = new String[] { "Dataset", "Kernel", "C", "Gamma" };
//
MultiLabelInstances[] datasetValues = new MultiLabelInstances[] { fullDataset, textDataset, srlDataset };
String[] datasetNames = new String[] { "Full", "Text", "SRL" };
double[] cValues = new double[] { -1, 1, 3, 5, 7, 9, 11 };
double[] gammaValues = new double[] { -9, -7, -5, -3, -1, 1, 3 };
Kernel[] kernelValues = new Kernel[] { new PolyKernel(), new RBFKernel() };
int numFolds = 10;
//
CSVWriter writer = new CSVWriter(new FileWriter(outputfilename), ';');
storeHeads(writer, parametersNames);
//
MultiLabelLearner learner;
SMO smo;
//
for(int datasetIndex = 0; datasetIndex < datasetValues.length; datasetIndex++) {
String datasetName = datasetNames[datasetIndex];
MultiLabelInstances dataset = datasetValues[datasetIndex];
smo = new SMO();
for(Kernel kernelValue : kernelValues) {
smo.setKernel(kernelValue);
for(double cValue : cValues) {
double c = Math.pow(2, cValue);
smo.setC(c);
if(kernelValue instanceof RBFKernel) {
for(double gammaValue : gammaValues) {
double gamma = Math.pow(2, gammaValue);
((RBFKernel) kernelValue).setGamma(gamma);
learner = new HMC(new LabelPowerset(smo));
String[] parametersValues = new String[] { datasetName, kernelValue.getClass().getSimpleName(), String.valueOf(cValue), String.valueOf(gammaValue) };
runFull_HMC_LP_SMO(learner, dataset, numFolds, writer, learnerName, parametersValues);
}
}
else {
learner = new HMC(new LabelPowerset(smo));
String[] parametersValues = new String[] { datasetName, kernelValue.getClass().getSimpleName(), String.valueOf(cValue), String.valueOf(0) };
runFull_HMC_LP_SMO(learner, dataset, numFolds, writer, learnerName, parametersValues);
}
//
}
}
}
//
writer.close();
}
private void runFull_HMC_LP_SMO(MultiLabelLearner learner, MultiLabelInstances dataset, int numFolds, CSVWriter writer, String learnerName, String[] parametersValues) throws Exception {
CustomEvaluator evaluator;
MultipleEvaluation trainResults;
MultipleEvaluation testResults;
//
evaluator = new CustomEvaluator();
evaluator.setStrict(false);
evaluator.crossValidate(learner, dataset, numFolds);
trainResults = evaluator.getTrainingMultipleEvaluation();
testResults = evaluator.getTestingMultipleEvaluation();
//
storeExampleBasedAccuracyValues(writer, learnerName, true, trainResults, parametersValues);
storeExampleBasedPrecisionValues(writer, learnerName, true, trainResults, parametersValues);
storeExampleBasedRecallValues(writer, learnerName, true, trainResults, parametersValues);
storeExampleBasedFMeasureValues(writer, learnerName, true, trainResults, parametersValues);
//
storeExampleBasedAccuracyValues(writer, learnerName, false, testResults, parametersValues);
storeExampleBasedPrecisionValues(writer, learnerName, false, testResults, parametersValues);
storeExampleBasedRecallValues(writer, learnerName, false, testResults, parametersValues);
storeExampleBasedFMeasureValues(writer, learnerName, false, testResults, parametersValues);
}
public void trainAndEvalPercentage_HMC_LP_SMO(String outputfilepath) throws Exception {
String learnerName = "HMC-LP-SMO-Percentage";
String outputfilename = outputfilepath + learnerName + ".csv";
String[] parametersNames = new String[] { "Dataset", "Percentage", "Increment", "Kernel", "C", "Gamma" };
//
MultiLabelInstances[] datasetValues = new MultiLabelInstances[] { fullDataset, textDataset, srlDataset };
String[] datasetNames = new String[] { "Full", "Text", "SRL" };
int[] percentageValues = new int[] { 20, 30, 40, 50 , 60 };
double[] cValues = new double[] { -1, 1, 3, 5, 7, 9, 11 };
double[] gammaValues = new double[] { -9, -7, -5, -3, -1, 1, 3 };
Kernel[] kernelValues = new Kernel[] { new PolyKernel(), new RBFKernel() };
int incrementValue = 10;
//
CSVWriter writer = new CSVWriter(new FileWriter(outputfilename), ';');
storeHeads(writer, parametersNames);
//
MultiLabelLearner learner;
SMO smo;
//
for(int datasetIndex = 0; datasetIndex < datasetValues.length; datasetIndex++) {
String datasetName = datasetNames[datasetIndex];
MultiLabelInstances dataset = datasetValues[datasetIndex];
for(int percentageIndex = 0; percentageIndex < percentageValues.length; percentageIndex++) {
int percentageValue = percentageValues[percentageIndex];
smo = new SMO();
for(Kernel kernelValue : kernelValues) {
smo.setKernel(kernelValue);
for(double cValue : cValues) {
double c = Math.pow(2, cValue);
smo.setC(c);
if(kernelValue instanceof RBFKernel) {
for(double gammaValue : gammaValues) {
double gamma = Math.pow(2, gammaValue);
((RBFKernel) kernelValue).setGamma(gamma);
learner = new HMC(new LabelPowerset(smo));
String[] parametersValues = new String[] { datasetName, String.valueOf(percentageValue), String.valueOf(incrementValue), kernelValue.getClass().getSimpleName(), String.valueOf(cValue), String.valueOf(gammaValue) };
runPercentage_HMC_LP_SMO(learner, dataset, percentageValue, incrementValue, writer, learnerName, parametersValues);
}
}
else {
learner = new HMC(new LabelPowerset(smo));
String[] parametersValues = new String[] { datasetName, String.valueOf(percentageValue), String.valueOf(incrementValue), kernelValue.getClass().getSimpleName(), String.valueOf(cValue), String.valueOf(0) };
runPercentage_HMC_LP_SMO(learner, dataset, percentageValue, incrementValue, writer, learnerName, parametersValues);
}
}
}
}
}
//
writer.close();
}
private void runPercentage_HMC_LP_SMO(MultiLabelLearner learner, MultiLabelInstances dataset, int percentage, int increment, CSVWriter writer, String learnerName, String[] parametersValues) throws Exception {
CustomEvaluator evaluator;
MultipleEvaluation trainResults;
MultipleEvaluation testResults;
//
evaluator = new CustomEvaluator();
evaluator.setStrict(false);
evaluator.randomPercentageValidate(learner, dataset, percentage, increment);
trainResults = evaluator.getTrainingMultipleEvaluation();
testResults = evaluator.getTestingMultipleEvaluation();
//
storeExampleBasedAccuracyValues(writer, learnerName, true, trainResults, parametersValues);
storeExampleBasedPrecisionValues(writer, learnerName, true, trainResults, parametersValues);
storeExampleBasedRecallValues(writer, learnerName, true, trainResults, parametersValues);
storeExampleBasedFMeasureValues(writer, learnerName, true, trainResults, parametersValues);
//
storeExampleBasedAccuracyValues(writer, learnerName, false, testResults, parametersValues);
storeExampleBasedPrecisionValues(writer, learnerName, false, testResults, parametersValues);
storeExampleBasedRecallValues(writer, learnerName, false, testResults, parametersValues);
storeExampleBasedFMeasureValues(writer, learnerName, false, testResults, parametersValues);
}
private void storeHeads(CSVWriter writer, String[] parametersNames) {
String learner = "Machine Learner";
String type = "Subset Type";
String measure = "Measure Name";
String value = "Measure Value";
String idealValue = "Measure Ideal Value";
List<String> headsList = new ArrayList<String>();
headsList.add(learner);
headsList.add(type);
headsList.add(measure);
headsList.add(value);
headsList.add(idealValue);
for(String parameterName : parametersNames)
headsList.add(parameterName);
String[] heads = headsList.toArray(new String[] { });
writer.writeNext(heads);
}
private void storeExampleBasedAccuracyValues(CSVWriter writer, String learnerName, boolean isTrain, MultipleEvaluation results, String[] parametersValues) {
storeValues(writer, learnerName, isTrain, ExampleBasedAccuracy.class.getSimpleName(), String.valueOf(results.getMean("Example-Based Accuracy")), String.valueOf(new ExampleBasedAccuracy(true).getIdealValue()), parametersValues);
}
private void storeExampleBasedPrecisionValues(CSVWriter writer, String learnerName, boolean isTrain, MultipleEvaluation results, String[] parametersValues) {
storeValues(writer, learnerName, isTrain, ExampleBasedPrecision.class.getSimpleName(), String.valueOf(results.getMean("Example-Based Precision")), String.valueOf(new ExampleBasedPrecision(true).getIdealValue()), parametersValues);
}
private void storeExampleBasedRecallValues(CSVWriter writer, String learnerName, boolean isTrain, MultipleEvaluation results, String[] parametersValues) {
storeValues(writer, learnerName, isTrain, ExampleBasedRecall.class.getSimpleName(), String.valueOf(results.getMean("Example-Based Recall")), String.valueOf(new ExampleBasedRecall(true).getIdealValue()), parametersValues);
}
private void storeExampleBasedFMeasureValues(CSVWriter writer, String learnerName, boolean isTrain, MultipleEvaluation results, String[] parametersValues) {
storeValues(writer, learnerName, isTrain, ExampleBasedFMeasure.class.getSimpleName(), String.valueOf(results.getMean("Example-Based F Measure")), String.valueOf(new ExampleBasedFMeasure(true).getIdealValue()), parametersValues);
}
private void storeValues(CSVWriter writer, String learnerName, boolean isTrain, String measureName, String measureValue, String measureIdealValue, String[] parametersValues) {
List<String> valuesList = new ArrayList<String>();
valuesList.add(learnerName);
valuesList.add(isTrain ? "Train" : "Test");
valuesList.add(measureName);
valuesList.add(measureValue);
valuesList.add(measureIdealValue);
for(String parameterValue : parametersValues)
valuesList.add(parameterValue);
String[] values = valuesList.toArray(new String[] { });
writer.writeNext(values);
}
public void firstEvaluation(String outputfilepath) {
try {
this.trainAndEvalFull_HMC_LP_SMO(outputfilepath);
}
catch (Exception e) {
e.printStackTrace();
}
}
public void secondEvaluation(String outputfilepath) {
try {
this.trainAndEvalPercentage_HMC_LP_SMO(outputfilepath);
}
catch (Exception e) {
e.printStackTrace();
}
}
public static void main(String[] args) {
String[] filenames = Utils.getCSVFilenames();
String xmlfilename = Utils.getLabelsFilename();
String outputfilepath = Utils.getResultsFilepath();
MachineLearner learner = new MachineLearner();
learner.load(filenames, xmlfilename);
// Select best parameters in 10 fold cross validation
//learner.firstEvaluation(outputfilepath);
// Select best subset in random selection & different percentages
learner.secondEvaluation(outputfilepath);
}
}