package com.formulasearchengine.mathosphere.mlp.ml;
import com.formulasearchengine.mathosphere.mlp.cli.MachineLearningDefinienClassifierConfig;
import com.formulasearchengine.mathosphere.mlp.pojos.Relation;
import com.formulasearchengine.mathosphere.mlp.pojos.Sentence;
import com.formulasearchengine.mathosphere.mlp.pojos.WikiDocumentOutput;
import com.formulasearchengine.mlp.evaluation.pojo.IdentifierDefinition;
import edu.stanford.nlp.parser.nndep.DependencyParser;
import edu.stanford.nlp.trees.GrammaticalStructure;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.configuration.Configuration;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.StringToWordVector;
import java.io.*;
import java.util.*;
import static com.formulasearchengine.mathosphere.mlp.ml.WekaUtils.*;
/**
* Created by Leo on 23.12.2016.
* Classifies extracted relations with the provided machine learning model.
* Retains only the highest ranking, positive relations.
*/
public class WekaClassifier extends RichMapFunction<WikiDocumentOutput, WikiDocumentOutput> {

  /** Paths to the serialized SVM model, StringToWordVector filter and dependency parser model. */
  public final MachineLearningDefinienClassifierConfig config;
  /** Pre-trained SVM wrapped in a FilteredClassifier; deserialized in {@link #open(Configuration)}. */
  private FilteredClassifier svm;
  /** Stanford dependency parser used to derive features for each relation. */
  private DependencyParser parser;
  /** The exact StringToWordVector filter the model was trained with (must match at classify time). */
  private StringToWordVector stringToWordVector;

  /**
   * @param config classification config; the heavyweight models are NOT loaded here but in
   *               {@link #open(Configuration)}, so this stays cheap to serialize for Flink.
   * @throws IOException kept for interface compatibility with existing callers; the constructor
   *                     itself performs no I/O.
   */
  public WekaClassifier(MachineLearningDefinienClassifierConfig config) throws IOException {
    this.config = config;
  }

  /**
   * Loads the serialized classifier, filter and parser once per task, before any {@link #map} call.
   */
  @Override
  public void open(Configuration parameters) throws Exception {
    svm = (FilteredClassifier) weka.core.SerializationHelper.read(config.getSvmModel());
    stringToWordVector =
        (StringToWordVector) weka.core.SerializationHelper.read(config.getStringToWordVectorFilter());
    parser = DependencyParser.loadFromModelFile(config.dependencyParserModel());
  }

  /**
   * Classifies every relation of the document and keeps only those predicted as {@code MATCH},
   * de-duplicated per (identifier, definiens) pair by highest classifier confidence.
   *
   * @param doc document whose relations are classified; its relation list is replaced in place
   * @return the same document, now containing only the highest-scoring positive relations
   * @throws Exception propagated from Weka filtering/classification
   */
  @Override
  public WikiDocumentOutput map(WikiDocumentOutput doc) throws Exception {
    System.out.println("Classifying " + doc.getTitle());
    WekaUtils wekaUtils = new WekaUtils();
    Instances instances = wekaUtils.createInstances("AllRelations");
    Map<Sentence, GrammaticalStructure> precomputedGraphStore = wekaUtils.getPrecomputedGraphStore();
    // Keyed map deals with duplicate extractions while preserving the highest score.
    Map<IdentifierDefinition, Relation> positiveClassifications = new HashMap<>();
    for (Relation relation : doc.getRelations()) {
      // addRelationToInstances appends the feature vector for this relation to `instances`;
      // fetch it as the LAST instance rather than by a parallel loop index, so the code does
      // not silently break if relation order and instance order ever diverge.
      wekaUtils.addRelationToInstances(parser, precomputedGraphStore, doc.getTitle(),
          doc.getqId(), instances, doc.getMaxSentenceLength(), relation);
      Instances toStringReplace = new Instances(instances, 1);
      toStringReplace.add(instances.get(instances.size() - 1));
      // Apply the (already trained) string-to-word-vector filter to the single new instance.
      Instances stringReplaced = Filter.useFilter(toStringReplace, stringToWordVector);
      Instance instance = stringReplaced.get(0);
      double[] distribution = svm.distributionForInstance(instance);
      String predictedClass = instances.classAttribute().value((int) svm.classifyInstance(instance));
      if (predictedClass.equals(MATCH)) {
        // Score = classifier confidence for the MATCH class.
        relation.setScore(distribution[instances.classAttribute().indexOfValue(MATCH)]);
        // NOTE(review): attribute indices are looked up on the pre-filter header `instances`
        // but read from the post-filter `instance` — assumes the filter keeps IDENTIFIER and
        // DEFINIEN at their original positions; confirm against the training pipeline.
        IdentifierDefinition extraction = new IdentifierDefinition(
            instance.stringValue(instance.attribute(instances.attribute(IDENTIFIER).index())),
            instance.stringValue(instance.attribute(instances.attribute(DEFINIEN).index())));
        // Keep the highest-scoring relation per extraction (replaces manual containsKey/get/put).
        positiveClassifications.merge(extraction, relation,
            (existing, candidate) -> existing.getScore() < candidate.getScore() ? candidate : existing);
      }
    }
    // Replace all relations with the surviving positive ones.
    doc.setRelations(new ArrayList<>(positiveClassifications.values()));
    System.out.println("Classifying done " + doc.getTitle() + " considered " + instances.size() + " definiens");
    return doc;
  }
}