package com.formulasearchengine.mathosphere.mlp.ml;
import com.formulasearchengine.mathosphere.mlp.cli.MachineLearningDefinienExtractionConfig;
import com.formulasearchengine.mathosphere.mlp.pojos.EvaluationResult;
import com.formulasearchengine.mathosphere.mlp.pojos.WikiDocumentOutput;
import com.formulasearchengine.mlp.evaluation.Evaluator;
import com.formulasearchengine.mlp.evaluation.pojo.GoldEntry;
import edu.stanford.nlp.parser.nndep.DependencyParser;
import org.apache.commons.io.FileUtils;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.util.ListCollector;
import org.apache.flink.util.Collector;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.Evaluation;
import weka.classifiers.functions.LibSVM;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.converters.ArffLoader;
import weka.core.converters.ArffSaver;
import weka.core.tokenizers.NGramTokenizer;
import weka.filters.Filter;
import weka.filters.supervised.instance.Resample;
import weka.filters.supervised.instance.SMOTE;
import weka.filters.unsupervised.attribute.Remove;
import weka.filters.unsupervised.attribute.StringToWordVector;
import java.io.*;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Stream;
import static com.formulasearchengine.mathosphere.mlp.ml.WekaUtils.*;
import static java.util.stream.Collectors.toList;
import static weka.core.Range.indicesToRangeList;
/**
 * Trains and evaluates an SVM classifier (Weka's LibSVM wrapper) that decides whether an
 * identifier-definien pair extracted from a Wikipedia article is a genuine definition.
 * Supports coarse and fine grid search over cost and gamma, resampling of the skewed
 * classes, and ten-fold as well as leave-one-out cross-evaluation.
 * <p>
 * Created by Leo on 23.12.2016.
 */
public class WekaLearner implements GroupReduceFunction<WikiDocumentOutput, EvaluationResult> {
private static final int folds = 10;
private static final int totalQids = 100;
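//fixed pseudo-random permutation of the qIds 1 to 100, used to assign documents to cross-validation folds reproducibly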
private static final Integer[] rand = new Integer[]{
99, 53, 70, 23, 86, 19, 84, 76, 34, 82,
32, 50, 47, 63, 54, 20, 74, 94, 46, 18,
22, 42, 100, 88, 96, 24, 31, 7, 12, 17,
26, 33, 29, 25, 79, 90, 56, 81, 4, 72,
27, 80, 60, 97, 5, 11, 67, 37, 2, 78,
6, 21, 51, 28, 91, 35, 8, 71, 14, 41,
52, 55, 62, 38, 30, 66, 58, 75, 93, 40,
48, 49, 89, 15, 69, 61, 65, 85, 9, 77,
98, 1, 43, 3, 83, 16, 36, 95, 44, 92,
59, 68, 45, 13, 39, 73, 87, 10, 57, 64};
/**
 * Cost values for the coarse grid search.
 */
public static final Double[] C_coarse = {
Math.pow(2, -7), Math.pow(2, -5), Math.pow(2, -3), Math.pow(2, -1),
Math.pow(2, 1), Math.pow(2, 3), Math.pow(2, 5), Math.pow(2, 7),
Math.pow(2, 9), Math.pow(2, 11), Math.pow(2, 13), Math.pow(2, 15)};
/**
 * Gamma values for the coarse grid search.
 */
public static final Double[] Y_coarse = {
Math.pow(2, -15), Math.pow(2, -13), Math.pow(2, -11), Math.pow(2, -9),
Math.pow(2, -7), Math.pow(2, -5), Math.pow(2, -3), Math.pow(2, -1),
Math.pow(2, 1), Math.pow(2, 3)};
/**
 * Cost values for the fine grid search over the interesting region.
 */
public static final Double[] C_fine = {
Math.pow(2, -9)
, Math.pow(2, -8.75), Math.pow(2, -8.5), Math.pow(2, -8.25), Math.pow(2, -8)
, Math.pow(2, -7.75), Math.pow(2, -7.5), Math.pow(2, -7.25), Math.pow(2, -7)
, Math.pow(2, -6.75), Math.pow(2, -6.5), Math.pow(2, -6.25), Math.pow(2, -6)
, Math.pow(2, -5.75), Math.pow(2, -5.5), Math.pow(2, -5.25), Math.pow(2, -5)
, Math.pow(2, -4.75), Math.pow(2, -4.5), Math.pow(2, -4.25), Math.pow(2, -4)
, Math.pow(2, -3.75), Math.pow(2, -3.5), Math.pow(2, -3.25), Math.pow(2, -3)
, Math.pow(2, -2.75), Math.pow(2, -2.5), Math.pow(2, -2.25), Math.pow(2, -2)
, Math.pow(2, -1.75), Math.pow(2, -1.5), Math.pow(2, -1.25), Math.pow(2, -1)
, Math.pow(2, -0.75), Math.pow(2, -0.5), Math.pow(2, -0.25), Math.pow(2, 0)
};
/**
 * Gamma values for the fine grid search over the interesting region.
 */
public static final Double[] Y_fine = {
Math.pow(2, -7)
, Math.pow(2, -6.75), Math.pow(2, -6.5), Math.pow(2, -6.25), Math.pow(2, -6)
, Math.pow(2, -5.75), Math.pow(2, -5.5), Math.pow(2, -5.25), Math.pow(2, -5)
, Math.pow(2, -4.75), Math.pow(2, -4.5), Math.pow(2, -4.25), Math.pow(2, -4)
, Math.pow(2, -3.75), Math.pow(2, -3.5), Math.pow(2, -3.25), Math.pow(2, -3)
, Math.pow(2, -2.75), Math.pow(2, -2.5), Math.pow(2, -2.25), Math.pow(2, -2)
};
/**
* Cost for highest accuracy in the SVM.
* c = 0.074325445
*/
public static final Double[] C_best_accuracy = {Math.pow(2, -3.75)};
/**
* Gamma for highest accuracy in the SVM.
* γ = 0.026278013
*/
public static final Double[] Y_best_accuracy = {Math.pow(2, -5.25)};
/**
* Cost for highest recall in the evaluation (post-SVM evaluation metric).
* tp: 77 fn: 233 fp: 110
* c = 0.074325445
*/
public static final Double[] C_best_recall = {Math.pow(2, -3.75)};
/**
* Gamma for highest recall in the evaluation (post-SVM evaluation metric).
* tp: 77 fn: 233 fp: 110
* γ = 0.011048544
*/
public static final Double[] Y_best_recall = {Math.pow(2, -6.5)};
/**
* Cost for highest F1 in the evaluation (post-SVM evaluation metric).
* tp: 70 fn: 240 fp: 44
* c = 0.4204482076
*/
public static final Double[] C_best_F1 = {Math.pow(2, -1.25)};
/**
* Gamma for highest F1 in the evaluation (post-SVM evaluation metric).
* tp: 70 fn: 240 fp: 44
* γ = 0.018581361
*/
public static final Double[] Y_best_F1 = {Math.pow(2, -5.75)};
public static final String INSTANCES_ARFF_FILE_NAME = "/instances.arff";
private final ArrayList<GoldEntry> gold;
public final MachineLearningDefinienExtractionConfig config;

public WekaLearner(MachineLearningDefinienExtractionConfig config) throws IOException {
this.config = config;
this.gold = (new Evaluator()).readGoldEntries(new File(config.getGoldFile()));
}
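/**
 * Collects the relations of all documents into one instance set, optionally writes the
 * instances to an ARFF file and persists a model trained on all data, and finally runs the
 * cross-evaluation over the configured parameter grid.
 */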
@Override
public void reduce(Iterable<WikiDocumentOutput> values, Collector<EvaluationResult> out) throws Exception {
DependencyParser parser = DependencyParser.loadFromModelFile(config.dependencyParserModel());
WekaUtils wekaUtils = new WekaUtils();
Instances instances = wekaUtils.createInstances("AllRelations");
for (WikiDocumentOutput value : values) {
wekaUtils.addRelationsToInstances(parser, value.getRelations(), value.getTitle(), value.getqId(), instances, value.getMaxSentenceLength());
}
if (config.isWriteInstances()) {
File instancesFile = new File(config.getOutputDir() + INSTANCES_ARFF_FILE_NAME);
ArffSaver arffSaver = new ArffSaver();
arffSaver.setFile(instancesFile);
arffSaver.setInstances(instances);
arffSaver.writeBatch();
}
//optionally train a model on all of the data once and persist it
if (config.getWriteSvmModel()) {
generateAndWriteFullModel(instances);
}
process(out, instances);
}
/**
 * Generates a model from all data and writes it, together with the string filter needed to
 * apply it, to the output directory.
 *
 * @param instances as returned from {@link WekaUtils#createInstances(String)}
 * @throws Exception if filtering, training or serialization fails.
 */
private void generateAndWriteFullModel(Instances instances) throws Exception {
StringToWordVector stringToWordVector = getStringToWordVectorFilter(instances);
Instances stringsReplacedData = Filter.useFilter(instances, stringToWordVector);
Instances resampled = dumbResample(stringsReplacedData);
Remove removeFilter = getRemoveFilter(stringsReplacedData);
LibSVM svmForOut = new LibSVM();
svmForOut.setCost(config.getSvmCost().get(0));
svmForOut.setGamma(config.getSvmGamma().get(0));
FilteredClassifier filteredClassifierForOut = new FilteredClassifier();
filteredClassifierForOut.setClassifier(svmForOut);
filteredClassifierForOut.setFilter(removeFilter);
filteredClassifierForOut.buildClassifier(resampled);
weka.core.SerializationHelper.write(config.getOutputDir() + "/svm_model_c_" + config.getSvmCost().get(0) + "_gamma_" + config.getSvmGamma().get(0) + ".model", filteredClassifierForOut);
weka.core.SerializationHelper.write(config.getOutputDir() + "/string_filter_c_" + config.getSvmCost().get(0) + "_gamma_" + config.getSvmGamma().get(0) + ".model", stringToWordVector);
}
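/*
* A minimal usage sketch, not called anywhere in this class, for reloading the persisted
* model and filter; outputDir, cost, gamma and newInstances are hypothetical placeholders
* that must match the file names written above:
*
* FilteredClassifier model = (FilteredClassifier) weka.core.SerializationHelper
*     .read(outputDir + "/svm_model_c_" + cost + "_gamma_" + gamma + ".model");
* StringToWordVector filter = (StringToWordVector) weka.core.SerializationHelper
*     .read(outputDir + "/string_filter_c_" + cost + "_gamma_" + gamma + ".model");
* Instances vectorized = Filter.useFilter(newInstances, filter);
* double predictedClass = model.classifyInstance(vectorized.firstInstance());
*/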
/**
 * Runs the pre-processing, training and testing on instances loaded from the configured ARFF file.
 *
 * @return the evaluation results for all parameter combinations.
 * @throws Exception if reading, filtering, training or evaluation fails.
 */
public List<EvaluationResult> processFromInstances() throws Exception {
Instances instances;
try (BufferedReader reader = new BufferedReader(new FileReader(config.getInstancesFile()))) {
ArffLoader.ArffReader arff = new ArffLoader.ArffReader(reader);
instances = arff.getData();
}
instances.setClassIndex(instances.numAttributes() - 1);
ArrayList<EvaluationResult> evaluationResults = new ArrayList<>();
//wrap the result list in a Flink collector so process() can emit into it
Collector<EvaluationResult> c = new ListCollector<>(evaluationResults);
process(c, instances);
return evaluationResults;
}
/**
 * Does the pre-processing, training and testing for every parameter combination.
 *
 * @param out collector for the results of the testing.
 * @param instances the instances to use for training and testing.
 * @throws Exception if filtering, training or evaluation fails.
 */
private void process(Collector<EvaluationResult> out, Instances instances) throws Exception {
if (config.isCoarseSearch()) {
config.setSvmCost(Arrays.asList(WekaLearner.C_coarse));
config.setSvmGamma(Arrays.asList(WekaLearner.Y_coarse));
} else if (config.isFineSearch()) {
config.setSvmCost(Arrays.asList(WekaLearner.C_fine));
config.setSvmGamma(Arrays.asList(WekaLearner.Y_fine));
}
List<Double> percentages = config.getPercent();
List<Double> C_used = config.getSvmCost();
List<Double> Y_used = config.getSvmGamma();
File output = new File(config.getOutputDir() + "/svm_cross_eval_statistics.csv");
File outputDetails = new File(config.getOutputDir() + "/svm_cross_eval_detailed_statistics.txt");
File extractedDefiniens = new File(config.getOutputDir() + "/classifications.csv");
StringToWordVector stringToWordVector = getStringToWordVectorFilter(instances);
Instances stringsReplacedData = Filter.useFilter(instances, stringToWordVector);
Remove removeFilter = getRemoveFilter(stringsReplacedData);//input format is already set inside getRemoveFilter
FileUtils.deleteQuietly(output);
FileUtils.deleteQuietly(outputDetails);
FileUtils.deleteQuietly(extractedDefiniens);
//oversampling percentages for SMOTE; alternatives: 10d, 20d, 50d, 70d, 100d, 120d, 150d
Double[] oversample = new Double[]{0d};
List<Double[]> parameters = new ArrayList<>();
for (double p : percentages) {
for (double c : C_used) {
for (double y : Y_used) {
for (double o : oversample) {
parameters.add(new Double[]{p, c, y, o});
}
}
}
}
//run the parallel stream inside a dedicated pool to control the level of parallelism
ForkJoinPool forkJoinPool = new ForkJoinPool(config.getParallelism());
Stream<EvaluationResult> a = parameters.parallelStream().map(
parameter -> crossEvaluate(stringsReplacedData, removeFilter, parameter[0], parameter[1], parameter[2], parameter[3]));
Callable<List<EvaluationResult>> task = () -> a.collect(toList());
List<EvaluationResult> evaluationResults;
try {
evaluationResults = forkJoinPool.submit(task).get();
} finally {
forkJoinPool.shutdown();
}
for (EvaluationResult evaluationResult : evaluationResults) {
FileUtils.write(outputDetails, "Cost; " + Utils.doubleToString(evaluationResult.cost, 10) + "; gamma; " + Utils.doubleToString(evaluationResult.gamma, 10) + "\n" + Arrays.toString(evaluationResult.text) + "\n", true);
//remove duplicates from the extractions and sort them
StringBuilder e = new StringBuilder();
List<String> list = new ArrayList<>(new HashSet<>(evaluationResult.extractions));
list.sort(Comparator.naturalOrder());
for (String extraction : list) {
e.append(extraction).append("\n");
}
Evaluator evaluator = new Evaluator();
StringReader reader = new StringReader(e.toString());
evaluationResult.setScoreSummary(evaluator.evaluate(evaluator.readExtractions(reader, gold, false), gold));
//Output files
FileUtils.write(extractedDefiniens, "Cost; "
+ Utils.doubleToString(evaluationResult.cost, 10)
+ "; gamma; " + Utils.doubleToString(evaluationResult.gamma, 10)
+ "; percentage_of_data_used; " + evaluationResult.percent
+ "\n", true);
FileUtils.write(extractedDefiniens, e.toString(), true);
FileUtils.write(output, evaluationResult.toString() + "\n", true);
out.collect(evaluationResult);
}
}
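/**
 * Builds a {@link StringToWordVector} filter that turns the five textual context attributes
 * into word-vector attributes, tokenized into 1- to 3-grams, attempting to keep 1000 words.
 * The colon is removed from the delimiters so that tokens containing ':' are not split.
 */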
private StringToWordVector getStringToWordVectorFilter(Instances instances) throws Exception {
StringToWordVector stringToWordVector = new StringToWordVector();
stringToWordVector.setAttributeIndices(indicesToRangeList(new int[]{
instances.attribute(SURFACE_TEXT_AND_POS_TAG_OF_TWO_PRECEDING_AND_FOLLOWING_TOKENS_AROUND_THE_DESC_CANDIDATE).index(),
instances.attribute(SURFACE_TEXT_AND_POS_TAG_OF_THREE_PRECEDING_AND_FOLLOWING_TOKENS_AROUND_THE_PAIRED_MATH_EXPR).index(),
instances.attribute(SURFACE_TEXT_OF_THE_FIRST_VERB_THAT_APPEARS_BETWEEN_THE_DESC_CANDIDATE_AND_THE_TARGET_MATH_EXPR).index(),
instances.attribute(SURFACE_TEXT_AND_POS_TAG_OF_DEPENDENCY_WITH_LENGTH_3_FROM_IDENTIFIER).index(),
instances.attribute(SURFACE_TEXT_AND_POS_TAG_OF_DEPENDENCY_WITH_LENGTH_3_FROM_DEFINIEN).index()}));
stringToWordVector.setWordsToKeep(1000);
NGramTokenizer nGramTokenizer = new NGramTokenizer();
nGramTokenizer.setNGramMaxSize(3);
nGramTokenizer.setNGramMinSize(1);
nGramTokenizer.setDelimiters(nGramTokenizer.getDelimiters().replaceAll(":", ""));
stringToWordVector.setTokenizer(nGramTokenizer);
stringToWordVector.setInputFormat(instances);
return stringToWordVector;
}
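/**
 * Builds a {@link Remove} filter that strips the bookkeeping attributes title, identifier,
 * definien and qId so that the classifier cannot see them.
 */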
private Remove getRemoveFilter(Instances instances) throws Exception {
Remove removeFilter = new Remove();
removeFilter.setAttributeIndices(indicesToRangeList(new int[]{
instances.attribute(TITLE).index(),
instances.attribute(IDENTIFIER).index(),
instances.attribute(DEFINIEN).index(),
instances.attribute(Q_ID).index(),
}));
removeFilter.setInputFormat(instances);
return removeFilter;
}
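/**
 * Cross-evaluates a single parameter combination. Optionally downsamples the data to the
 * given percentage, resamples the training data, and then trains and tests once per fold,
 * or once per qId in leave-one-out mode.
 *
 * @return the collected evaluation result; an empty result if an exception occurred.
 */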
private EvaluationResult crossEvaluate(Instances stringsReplacedData, Remove removeFilter, double percent, double cost, double gamma, double oversample) {
try {
System.out.println("Cost; " + Utils.doubleToString(cost, 10)
+ "; gamma; " + Utils.doubleToString(gamma, 10));
EvaluationResult result = new EvaluationResult(config.isLeaveOneOutEvaluation() ? totalQids : folds, percent, cost, gamma);
result.prefix = "oversample; " + oversample;
Instances reduced;
if (percent != 100) {
//draw a random sample; careful, this has an effect even with setSampleSizePercent(100) and setBiasToUniformClass(0)
reduced = downsample(stringsReplacedData, percent);
} else {
reduced = stringsReplacedData;
}
Instances resampled = resample(oversample, reduced);
int counter = 0;
while (10 * counter < totalQids) {
for (int n = 0; n < folds; n++) {
trainAndTest(10 * counter + n, removeFilter, cost, gamma, stringsReplacedData, resampled, result);
}
if (!config.isLeaveOneOutEvaluation()) {
break;
}
counter++;
}
return result;
} catch (Exception e) {
e.printStackTrace();
return new EvaluationResult(folds, percent, cost, gamma);
}
}
private Instances resample(double oversample, Instances reduced) throws Exception {
//oversampling to deal with the skewed ratio of the classes; SMOTE is currently disabled in
//favour of simple resampling with a bias towards a uniform class distribution
boolean useSmote = false;
if (useSmote) {
return smote(reduced, oversample);
}
return dumbResample(reduced);
}
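/**
 * Draws a random sample of the given size in percent without biasing the class distribution.
 */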
private Instances downsample(Instances stringsReplacedData, double percent) throws Exception {
Instances reduced;
Resample sampler = new Resample();
sampler.setRandomSeed(1);
//do not change distribution
sampler.setBiasToUniformClass(0);
sampler.setSampleSizePercent(percent);
sampler.setInputFormat(stringsReplacedData);
reduced = Filter.useFilter(stringsReplacedData, sampler);
return reduced;
}
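/**
 * Resamples with a bias towards a uniform class distribution to counter the class imbalance.
 */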
private Instances dumbResample(Instances reduced) throws Exception {
Resample resampleFilter = new Resample();
resampleFilter.setRandomSeed(1);
resampleFilter.setBiasToUniformClass(1);
resampleFilter.setInputFormat(reduced);
return Filter.useFilter(reduced, resampleFilter);
}
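/**
 * Oversamples the minority class with SMOTE by the given percentage.
 */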
private Instances smote(Instances stringsReplacedData, double oversample) throws Exception {
Instances resampled;
SMOTE smote = getSmoteFilter(stringsReplacedData, oversample);
resampled = Filter.useFilter(stringsReplacedData, smote);
return resampled;
}
private SMOTE getSmoteFilter(Instances stringsReplacedData, double oversample) throws Exception {
SMOTE smote = new SMOTE();
smote.setRandomSeed(1);
smote.setPercentage(oversample);
smote.setNearestNeighbors(5);
smote.setInputFormat(stringsReplacedData);
return smote;
}
/**
 * Trains a filtered LibSVM classifier on all folds but the given one and tests it on the
 * held-out fold.
 *
 * @param n the fold index, or the qId to hold out in leave-one-out mode.
 * @param removeFilter the filter that removes the string attributes title, qId, identifier and definien.
 * @param cost cost for the SVM.
 * @param gamma gamma for the SVM.
 * @param beforeResampling plain data (strings replaced) for building the test set.
 * @param resampled resampled training data.
 * @param result receives the evaluation results.
 * @throws Exception if Weka fails during training or evaluation.
 */
private void trainAndTest(int n, Filter removeFilter, double cost, double gamma, Instances beforeResampling, Instances resampled,
EvaluationResult result) throws Exception {
LibSVM svm = new LibSVM();
svm.setCost(cost);
svm.setGamma(gamma);
FilteredClassifier filteredClassifier = new FilteredClassifier();
filteredClassifier.setClassifier(svm);
filteredClassifier.setFilter(removeFilter);
List<Integer> testIds;
if (config.isLeaveOneOutEvaluation()) {
testIds = new ArrayList<>();
testIds.add(n);
} else {
testIds = Arrays.asList(Arrays.copyOfRange(rand, folds * n, folds * (n + 1)));
}
Instances train = new Instances(resampled, 1);
Instances test = new Instances(beforeResampling, 1);
//build test and training set independently
for (int i = 0; i < resampled.numInstances(); i++) {
Instance a = resampled.instance(i);
if (!testIds.contains(Integer.parseInt(a.stringValue(a.attribute(resampled.attribute(Q_ID).index()))))) {
train.add(a);
}
}
for (int i = 0; i < beforeResampling.numInstances(); i++) {
Instance a = beforeResampling.instance(i);
if (testIds.contains(Integer.parseInt(a.stringValue(a.attribute(beforeResampling.attribute(Q_ID).index()))))) {
//test on unresampled data so that the accuracy estimates are not distorted by resampling
test.add(a);
}
}
Classifier clsCopy = FilteredClassifier.makeCopy(filteredClassifier);
clsCopy.buildClassifier(train);
//collect the positive predictions; class value 0 marks an identifier-definien match
for (int i = 0; i < test.size(); i++) {
Instance instance = test.get(i);
String match = train.classAttribute().value(0);
String predictedClass = train.classAttribute().value((int) clsCopy.classifyInstance(instance));
if (match.equals(predictedClass)) {
String extraction =
instance.stringValue(instance.attribute(train.attribute(Q_ID).index())) + ","
+ "\"" + instance.stringValue(instance.attribute(train.attribute(TITLE).index())).replaceAll("\\s", "_") + "\","
+ "\"" + instance.stringValue(instance.attribute(train.attribute(IDENTIFIER).index())) + "\","
+ "\"" + instance.stringValue(instance.attribute(train.attribute(DEFINIEN).index())).toLowerCase() + "\"";
result.extractions.add(extraction);
}
}
Evaluation eval = new Evaluation(resampled);
eval.setPriors(train);
eval.evaluateModel(clsCopy, test);
result.averagePrecision[n] = eval.precision(0);
result.averageRecall[n] = eval.recall(0);
result.accuracy[n] = eval.pctCorrect() / 100d;
StringBuilder b = new StringBuilder();
b.append(", fold, ").append(n).append("\n").append(eval.toClassDetailsString()).append("\n").append(eval.toSummaryString(true));
result.text[n] = b.toString();
}
}