/*******************************************************************************
* Copyright 2013
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/
package de.tudarmstadt.ukp.csniper.webapp.evaluation;
import static java.util.Collections.singleton;
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createPrimitive;
import static org.apache.uima.fit.factory.TypeSystemDescriptionFactory.createTypeSystemDescription;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EmptyStackException;
import java.util.List;
import java.util.Set;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.SystemUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.uima.UIMAException;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.CASException;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.util.CasCreationUtils;
import org.cleartk.classifier.CleartkProcessingException;
import org.cleartk.classifier.DataWriter;
import org.cleartk.classifier.Feature;
import org.cleartk.classifier.Instance;
import org.cleartk.classifier.jar.DirectoryDataWriterFactory;
import org.cleartk.classifier.jar.Train;
import com.google.common.io.Files;
import de.tudarmstadt.ukp.csniper.ml.DummySentenceSplitter;
import de.tudarmstadt.ukp.csniper.ml.GoldFromMetadataAnnotator;
import de.tudarmstadt.ukp.csniper.ml.TKSVMlightFeatureExtractor;
import de.tudarmstadt.ukp.csniper.ml.tksvm.DefaultTKSVMlightDataWriterFactory;
import de.tudarmstadt.ukp.csniper.ml.tksvm.TKSVMlightDataWriter;
import de.tudarmstadt.ukp.csniper.ml.tksvm.TKSVMlightSequenceClassifier;
import de.tudarmstadt.ukp.csniper.ml.tksvm.TKSVMlightSequenceClassifierBuilder;
import de.tudarmstadt.ukp.csniper.ml.tksvm.TreeFeatureVector;
import de.tudarmstadt.ukp.csniper.webapp.evaluation.model.CachedParse;
import de.tudarmstadt.ukp.csniper.webapp.evaluation.model.EvaluationItem;
import de.tudarmstadt.ukp.csniper.webapp.evaluation.model.EvaluationResult;
import de.tudarmstadt.ukp.csniper.webapp.evaluation.model.Mark;
import de.tudarmstadt.ukp.csniper.webapp.project.model.AnnotationType;
import de.tudarmstadt.ukp.csniper.webapp.search.tgrep.PennTreeUtils;
import de.tudarmstadt.ukp.csniper.webapp.statistics.SortableAggregatedEvaluationResultDataProvider.ResultFilter;
import de.tudarmstadt.ukp.csniper.webapp.statistics.model.AggregatedEvaluationResult;
import de.tudarmstadt.ukp.csniper.webapp.support.task.Task;
import de.tudarmstadt.ukp.csniper.webapp.support.uima.AnalysisEngineFactory;
import de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData;
import de.tudarmstadt.ukp.dkpro.core.api.syntax.type.PennTree;
/**
 * Predicts labels for {@link EvaluationResult}s with a tree-kernel SVM (TKSVMlight via ClearTK).
 * <p>
 * The only feature used is the whitespace-normalized Penn tree of an item's covered text
 * (feature name {@code "TK_tree"}). Parses are looked up in, and written back to, the parse
 * cache of the {@link EvaluationRepository}. Items whose SVM score is strictly greater than
 * {@code 0.0} are labelled {@link Mark#PRED_CORRECT}, all others {@link Mark#PRED_WRONG}.
 * <p>
 * This class performs no synchronization; the analysis engines, repository and task it holds
 * are shared mutable state.
 */
public class MlPipeline
{
    private static final Log LOG = LogFactory.getLog(MlPipeline.class);

    /** Scores strictly above this value are labelled {@link Mark#PRED_CORRECT}. */
    private static final double THRESHOLD = 0.0;

    /** Name of the single tree-kernel feature used for training and classification. */
    private static final String TREE_FEATURE = "TK_tree";

    /** Value stored in the parse cache for texts the parser could not handle. */
    private static final String PARSE_ERROR = "ERROR";

    /**
     * Options passed to ClearTK's {@link Train} after the model directory. Shared by
     * {@link #predict(List, List)} and {@link #train(List, EvaluationRepository)} so both build
     * identically configured models.
     */
    private static final String[] TRAIN_OPTIONS = { "-t", "5", "-c", "1.0", "-C", "+" };

    private final String language;
    private final AnalysisEngine gold;
    private final AnalysisEngine sent;
    private final AnalysisEngine tok;
    private final AnalysisEngine parser;

    private EvaluationRepository repository;
    private Task task;

    /**
     * Creates the pipeline and its preprocessing engines.
     *
     * @param aLanguage
     *            document language set on every CAS and passed to the segmenter and parser.
     * @throws ResourceInitializationException
     *             if one of the analysis engines cannot be created.
     */
    public MlPipeline(String aLanguage)
        throws ResourceInitializationException
    {
        language = aLanguage;
        gold = createPrimitive(GoldFromMetadataAnnotator.class);
        sent = createPrimitive(DummySentenceSplitter.class);
        // Sentences are added by the DummySentenceSplitter, which always runs before the
        // segmenter, so the segmenter is told not to create its own.
        tok = AnalysisEngineFactory.createAnalysisEngine(AnalysisEngineFactory.SEGMENTER,
                "language", aLanguage, "createSentences", false);
        parser = AnalysisEngineFactory.createAnalysisEngine(AnalysisEngineFactory.PARSER,
                "language", aLanguage);
    }

    /**
     * Sets the repository used for the parse cache and for aggregated results.
     * <p>
     * Note: the method name is misspelled; it is kept as-is for existing callers.
     *
     * @param aRepostitory
     *            the evaluation repository.
     */
    public void setRepostitory(EvaluationRepository aRepostitory)
    {
        repository = aRepostitory;
    }

    /**
     * Sets the task used for progress reporting and cancellation checks.
     *
     * @param aTask
     *            the task, or {@code null} to disable progress reporting.
     */
    public void setTask(Task aTask)
    {
        task = aTask;
    }

    /**
     * Obtains the Penn tree for {@code result}, from the parse cache if possible, otherwise by
     * running the parser on {@code cas}.
     * <p>
     * A cached tree is added to {@code cas} as a {@link PennTree} annotation. A freshly parsed
     * tree is whitespace-normalized and written to the cache. If the parse cannot be read from
     * the CAS ({@code JCasUtil.selectSingle} throws {@link IllegalArgumentException}), the error
     * marker is cached so the text is not parsed again.
     *
     * @param result
     *            the result whose item text is in {@code cas}.
     * @param cas
     *            a CAS already containing the text, sentences and tokens.
     * @return the Penn tree, or an empty string if the text could not be parsed.
     * @throws UIMAException
     *             if parsing or CAS access fails.
     */
    public String parse(EvaluationResult result, CAS cas)
        throws UIMAException
    {
        CachedParse cp = repository.getCachedParse(result.getItem());
        String cachedTree = (cp != null) ? cp.getPennTree() : null;

        if (cachedTree != null && !cachedTree.isEmpty()) {
            if (PARSE_ERROR.equals(cachedTree)) {
                LOG.warn("Unable to parse: [" + result.getItem().getCoveredText() + "] (cached)");
                return "";
            }
            // make the cached parse available to downstream engines (e.g. feature extraction)
            addPennTree(cas, cachedTree);
            return cachedTree;
        }

        parser.process(cas);
        try {
            String pennTree = StringUtils.normalizeSpace(JCasUtil.selectSingle(cas.getJCas(),
                    PennTree.class).getPennTree());
            repository.writeCachedParse(new CachedParse(result.getItem(), pennTree));
            return pennTree;
        }
        catch (IllegalArgumentException e) {
            LOG.warn("Unable to parse: [" + result.getItem().getCoveredText() + "]", e);
            repository.writeCachedParse(new CachedParse(result.getItem(), PARSE_ERROR));
            return "";
        }
    }

    /**
     * Runs the preprocessing pipeline on every training result and writes TKSVMlight training
     * data into {@code aModelDir}.
     *
     * @param aModelDir
     *            output directory for the training data.
     * @param aTrainingList
     *            results whose {@code getResult()} holds the gold label.
     * @throws UIMAException
     *             if an analysis engine fails.
     * @throws IOException
     *             if the training data cannot be written.
     */
    public void createTrainingData(File aModelDir, List<EvaluationResult> aTrainingList)
        throws UIMAException, IOException
    {
        AnalysisEngine extract = createPrimitive(TKSVMlightFeatureExtractor.class,
                DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY, aModelDir.getAbsolutePath(),
                TKSVMlightFeatureExtractor.PARAM_DATA_WRITER_FACTORY_CLASS_NAME,
                DefaultTKSVMlightDataWriterFactory.class.getName());
        try {
            ProgressMeter progress = new ProgressMeter(aTrainingList.size());
            CAS cas = CasCreationUtils.createCas(createTypeSystemDescription(), null, null);
            for (EvaluationResult result : aTrainingList) {
                // The label is stored as the document title; presumably GoldFromMetadataAnnotator
                // turns it into the gold annotation the extractor reads.
                DocumentMetaData.create(cas).setDocumentTitle(result.getResult());
                cas.setDocumentText(result.getItem().getCoveredText());
                cas.setDocumentLanguage(language);
                gold.process(cas);
                sent.process(cas);
                tok.process(cas);
                parse(result, cas);
                extract.process(cas);
                cas.reset();
                advance(progress);
            }
            // lets the extractor finalize the training data it has written
            extract.collectionProcessComplete();
        }
        finally {
            // release the extractor's resources on success, failure and cancellation alike
            extract.destroy();
        }
    }

    /**
     * Classifies every result with the model in {@code aModelDir} and stores the predicted
     * {@link Mark} title in the result.
     * <p>
     * One line (dummy label {@code 0} plus the encoded tree) is written per result, even for
     * unparseable ones, so the predictions stay aligned with {@code aToPredictList}. The
     * temporary input file is deleted afterwards.
     *
     * @param aModelDir
     *            directory containing a trained model.
     * @param aToPredictList
     *            results to label; modified in place.
     * @throws IOException
     *             if the number of predictions does not match the number of results, or the
     *             input file cannot be created or closed.
     * @throws UIMAException
     *             if preprocessing, encoding or writing fails (write errors are wrapped in an
     *             {@link AnalysisEngineProcessException}).
     */
    public void classify(File aModelDir, List<EvaluationResult> aToPredictList)
        throws IOException, UIMAException
    {
        TKSVMlightSequenceClassifierBuilder builder = new TKSVMlightSequenceClassifierBuilder();
        TKSVMlightSequenceClassifier classifier = builder
                .loadClassifierFromTrainingDirectory(aModelDir);

        File cFile = File.createTempFile("tkclassify", ".txt");
        List<Double> predictions;
        try {
            BufferedWriter bw = null;
            try {
                bw = new BufferedWriter(new FileWriter(cFile));
                CAS cas = CasCreationUtils.createCas(createTypeSystemDescription(), null, null);
                ProgressMeter progress = new ProgressMeter(aToPredictList.size());
                for (EvaluationResult result : aToPredictList) {
                    cas.setDocumentText(result.getItem().getCoveredText());
                    cas.setDocumentLanguage(language);
                    sent.process(cas);
                    tok.process(cas);
                    String pennTree = parse(result, cas);
                    try {
                        writeTree(bw, classifier, pennTree);
                    }
                    catch (IOException e) {
                        throw new AnalysisEngineProcessException(e);
                    }
                    cas.reset();
                    advance(progress);
                }
                // close explicitly so a failing final flush is reported instead of swallowed
                bw.close();
            }
            finally {
                IOUtils.closeQuietly(bw);
            }
            predictions = classifier.tkSvmLightPredict2(cFile);
        }
        finally {
            deleteRecursively(cFile);
        }

        if (predictions.size() != aToPredictList.size()) {
            // TODO throw different exception instead
            throw new IOException("there are [" + predictions.size() + "] predictions, but ["
                    + aToPredictList.size() + "] were expected.");
        }
        for (int i = 0; i < aToPredictList.size(); i++) {
            aToPredictList.get(i).setResult(toMark(predictions.get(i)).getTitle());
        }
    }

    /**
     * Trains a model on {@code aTrainingList} in a temporary directory and uses it to label
     * {@code aToPredictList}. The temporary directory is deleted afterwards.
     * <p>
     * Does nothing if either list is empty.
     *
     * @param aTrainingList
     *            labelled results.
     * @param aToPredictList
     *            results to label; modified in place.
     * @throws UIMAException
     *             if preprocessing or training fails.
     * @throws IOException
     *             if training data or classifier input cannot be written.
     */
    public void predict(List<EvaluationResult> aTrainingList, List<EvaluationResult> aToPredictList)
        throws UIMAException, IOException
    {
        // nothing to learn from, or nothing to label: skip the expensive training step
        if (aTrainingList.isEmpty() || aToPredictList.isEmpty()) {
            return;
        }

        if (task != null) {
            task.setTotal(aTrainingList.size() + aToPredictList.size());
        }

        File modelDir = Files.createTempDir();
        try {
            createTrainingData(modelDir, aTrainingList);
            try {
                trainModel(modelDir);
            }
            catch (Exception e) {
                throw new UIMAException(e);
            }
            classify(modelDir, aToPredictList);
        }
        finally {
            deleteRecursively(modelDir);
        }
    }

    /**
     * Splits {@code aResults} into manually labelled ({@link Mark#CORRECT}/{@link Mark#WRONG})
     * and unlabelled/predicted results, then trains on the former to label the latter.
     *
     * @param aResults
     *            all results; unlabelled ones are modified in place.
     * @param aMinItemsAnnotated
     *            minimum number of manually labelled results required.
     * @return {@code false} if too few results are labelled, {@code true} otherwise.
     * @throws UIMAException
     *             if preprocessing or training fails.
     * @throws IOException
     *             if training data or classifier input cannot be written.
     */
    public boolean predict(List<EvaluationResult> aResults, int aMinItemsAnnotated)
        throws UIMAException, IOException
    {
        List<EvaluationResult> annotated = new ArrayList<EvaluationResult>();
        List<EvaluationResult> empty = new ArrayList<EvaluationResult>();
        for (EvaluationResult result : aResults) {
            Mark m = Mark.fromString(result.getResult());
            switch (m) {
            case CORRECT:
            case WRONG:
                annotated.add(result);
                break;
            case NA:
            case PRED_CORRECT:
            case PRED_WRONG:
                empty.add(result);
                break;
            default:
                // CHECK: any other mark is neither used for training nor re-predicted
                break;
            }
        }

        // TODO differentiate between correct/wrong?
        // i.e. ensure the user to at least have X correct and X wrong items before predicting?
        // a classifier trained only on "correct"s will not issue "wrong"s for anything, etc.
        if (annotated.size() < aMinItemsAnnotated) {
            return false;
        }

        predict(annotated, empty);
        return true;
    }

    /**
     * Trains on aggregated (multi-user) results of the given collection and type and labels
     * every result in {@code aResults} that is not already marked {@link Mark#CORRECT} or
     * {@link Mark#WRONG}.
     *
     * @return {@code false} if the repository returned no aggregated results, {@code true}
     *         otherwise.
     * @throws UIMAException
     *             if preprocessing or training fails.
     * @throws IOException
     *             if training data or classifier input cannot be written.
     */
    public boolean predictAggregated(List<EvaluationResult> aResults, String aCollectionId,
            AnnotationType aType, Set<String> aUsers, double aUserThreshold,
            double aConfidenceThreshold)
        throws UIMAException, IOException
    {
        List<AggregatedEvaluationResult> aggregatedResults = repository.listAggregatedResults(
                singleton(aCollectionId), singleton(aType), aUsers, aUserThreshold,
                aConfidenceThreshold);
        if (aggregatedResults.isEmpty()) {
            return false;
        }

        List<EvaluationResult> trainingList = convertToSimple(aggregatedResults);

        List<EvaluationResult> toPredict = new ArrayList<EvaluationResult>();
        for (EvaluationResult er : aResults) {
            Mark result = Mark.fromString(er.getResult());
            if (result != Mark.CORRECT && result != Mark.WRONG) {
                toPredict.add(er);
            }
        }

        predict(trainingList, toPredict);
        return true;
    }

    /**
     * Adds a {@link PennTree} annotation spanning the whole document text.
     */
    private void addPennTree(CAS aCas, String aPennTree)
        throws CASException
    {
        PennTree tree = new PennTree(aCas.getJCas(), 0, aCas.getDocumentText().length());
        tree.setPennTree(aPennTree);
        tree.addToIndexes();
    }

    /**
     * Converts aggregated results whose classification is CORRECT or WRONG into simple
     * {@link EvaluationResult}s with user {@code "__dummy__"} and the classification's label.
     * Aggregated results with any other classification are dropped.
     *
     * @param aAgg
     *            aggregated results.
     * @return a new, mutable list of training results.
     */
    public static List<EvaluationResult> convertToSimple(List<AggregatedEvaluationResult> aAgg)
    {
        List<EvaluationResult> trainingList = new ArrayList<EvaluationResult>();
        for (AggregatedEvaluationResult aer : aAgg) {
            ResultFilter aggregated = aer.getClassification();
            if (aggregated == ResultFilter.CORRECT || aggregated == ResultFilter.WRONG) {
                trainingList.add(new EvaluationResult(aer.getItem(), "__dummy__", aggregated
                        .getLabel()));
            }
        }
        return trainingList;
    }

    /**
     * Trains a model from cached parses only (no parser is run). Results without a usable
     * cached parse are skipped.
     * <p>
     * On success the caller owns the returned temporary directory. On failure the directory is
     * deleted.
     *
     * @param aTrainingList
     *            labelled results; {@link Mark#CORRECT} is the positive class.
     * @param aRepository
     *            repository holding the parse cache.
     * @return the directory containing the trained model.
     * @throws IOException
     *             if the training data cannot be written.
     * @throws CleartkProcessingException
     *             if writing the training data or training fails.
     */
    public static File train(List<EvaluationResult> aTrainingList, EvaluationRepository aRepository)
        throws IOException, CleartkProcessingException
    {
        File modelDir = Files.createTempDir();
        boolean trained = false;
        try {
            DefaultTKSVMlightDataWriterFactory dataWriterFactory =
                    new DefaultTKSVMlightDataWriterFactory();
            dataWriterFactory.setOutputDirectory(modelDir);
            DataWriter<Boolean> dataWriter = dataWriterFactory.createDataWriter();
            for (EvaluationResult result : aTrainingList) {
                CachedParse cp = aRepository.getCachedParse(result.getItem());
                String pennTree = (cp != null) ? cp.getPennTree() : null;
                if (!isUsableTree(pennTree)) {
                    LOG.warn("Unable to parse: [" + result.getItem().getCoveredText()
                            + "] (cached)");
                    continue;
                }
                Instance<Boolean> instance = new Instance<Boolean>();
                instance.add(new Feature(TREE_FEATURE, StringUtils.normalizeSpace(pennTree)));
                instance.setOutcome(Mark.fromString(result.getResult()) == Mark.CORRECT);
                dataWriter.write(instance);
            }
            dataWriter.finish();

            try {
                trainModel(modelDir);
            }
            catch (Exception e) {
                throw new CleartkProcessingException(e);
            }
            trained = true;
            return modelDir;
        }
        finally {
            if (!trained) {
                deleteRecursively(modelDir);
            }
        }
    }

    /**
     * Classifies pre-parsed texts with the model in {@code aModelDir}.
     * <p>
     * This method may return fewer results than parses were passed to it. A parse that is
     * missing, empty, the error marker {@code "ERROR"}, or not convertible back to text
     * produces no result. The temporary classifier input file is deleted afterwards.
     *
     * @param aModelDir
     *            directory containing a trained model.
     * @param aParses
     *            cached parses to classify.
     * @param aType
     *            type assigned to the created evaluation items.
     * @param aUser
     *            user assigned to the created results.
     * @return one result per usable parse, labelled with the predicted {@link Mark} title.
     * @throws IOException
     *             if the number of predictions does not match the number of items.
     * @throws UIMAException
     *             if encoding fails, or writing the classifier input fails (wrapped in an
     *             {@link AnalysisEngineProcessException}).
     */
    public static List<EvaluationResult> classifyPreParsed(File aModelDir, List<CachedParse> aParses,
            String aType, String aUser)
        throws IOException, UIMAException
    {
        TKSVMlightSequenceClassifierBuilder builder = new TKSVMlightSequenceClassifierBuilder();
        TKSVMlightSequenceClassifier classifier = builder
                .loadClassifierFromTrainingDirectory(aModelDir);

        File cFile = File.createTempFile("tkclassify", ".txt");
        List<EvaluationItem> items = new ArrayList<EvaluationItem>();
        List<Double> predictions;
        try {
            BufferedWriter bw = null;
            try {
                bw = new BufferedWriter(new FileWriter(cFile));
                for (CachedParse parse : aParses) {
                    String pennTree = parse.getPennTree();
                    if (!isUsableTree(pennTree)) {
                        continue;
                    }
                    String coveredText;
                    try {
                        coveredText = PennTreeUtils.toText(pennTree);
                    }
                    catch (EmptyStackException e) {
                        LOG.error("Invalid Penn Tree: [" + pennTree + "]", e);
                        continue;
                    }
                    items.add(toEvaluationItem(parse, aType, coveredText));
                    writeTree(bw, classifier, pennTree);
                }
                // close explicitly so a failing final flush is reported instead of swallowed
                bw.close();
            }
            catch (IOException e) {
                throw new AnalysisEngineProcessException(e);
            }
            finally {
                IOUtils.closeQuietly(bw);
            }
            predictions = classifier.tkSvmLightPredict2(cFile);
        }
        finally {
            deleteRecursively(cFile);
        }

        if (predictions.size() != items.size()) {
            // TODO throw different exception instead
            throw new IOException("there are [" + predictions.size() + "] predictions, but ["
                    + items.size() + "] were expected.");
        }

        List<EvaluationResult> results = new ArrayList<EvaluationResult>(items.size());
        for (int i = 0; i < items.size(); i++) {
            EvaluationResult result = new EvaluationResult(items.get(i), aUser, "");
            result.setResult(toMark(predictions.get(i)).getTitle());
            results.add(result);
        }
        return results;
    }

    /** Builds the evaluation item describing the span of a cached parse. */
    private static EvaluationItem toEvaluationItem(CachedParse aParse, String aType,
            String aCoveredText)
    {
        EvaluationItem item = new EvaluationItem();
        item.setType(aType);
        item.setBeginOffset(aParse.getBeginOffset());
        item.setEndOffset(aParse.getEndOffset());
        item.setDocumentId(aParse.getDocumentId());
        item.setCollectionId(aParse.getCollectionId());
        item.setCoveredText(aCoveredText);
        return item;
    }

    /** Returns whether a cached tree string is non-null, non-empty and not the error marker. */
    private static boolean isUsableTree(String aPennTree)
    {
        return aPennTree != null && !aPennTree.isEmpty() && !PARSE_ERROR.equals(aPennTree);
    }

    /** Maps an SVM score to the predicted mark. */
    private static Mark toMark(double aScore)
    {
        return (aScore > THRESHOLD) ? Mark.PRED_CORRECT : Mark.PRED_WRONG;
    }

    /**
     * Writes one classifier input line: the dummy label {@code 0} followed by the encoded,
     * whitespace-normalized tree.
     */
    private static void writeTree(BufferedWriter aWriter, TKSVMlightSequenceClassifier aClassifier,
            String aPennTree)
        throws IOException, UIMAException
    {
        Feature tree = new Feature(TREE_FEATURE, StringUtils.normalizeSpace(aPennTree));
        TreeFeatureVector tfv = aClassifier.getFeaturesEncoder().encodeAll(Arrays.asList(tree));
        aWriter.write("0");
        aWriter.write(TKSVMlightDataWriter.createString(tfv));
        aWriter.write(SystemUtils.LINE_SEPARATOR);
    }

    /** Runs ClearTK's {@link Train} on {@code aModelDir} with {@link #TRAIN_OPTIONS}. */
    private static void trainModel(File aModelDir)
        throws Exception
    {
        String[] args = new String[TRAIN_OPTIONS.length + 1];
        args[0] = aModelDir.getPath();
        System.arraycopy(TRAIN_OPTIONS, 0, args, 1, TRAIN_OPTIONS.length);
        Train.main(args);
    }

    /** Advances the progress meter, logs it, and updates/checks the optional task. */
    private void advance(ProgressMeter aProgress)
        throws UIMAException, IOException
    {
        aProgress.next();
        LOG.info(aProgress);
        if (task != null) {
            task.increment();
            task.checkCanceled();
        }
    }

    /**
     * Best-effort deletion of a temporary file or directory tree. A failure is logged at debug
     * level and otherwise ignored, because cleanup must not mask the actual result.
     */
    private static void deleteRecursively(File aFile)
    {
        File[] children = aFile.listFiles();
        if (children != null) {
            for (File child : children) {
                deleteRecursively(child);
            }
        }
        if (!aFile.delete() && aFile.exists()) {
            LOG.debug("Unable to delete temporary file [" + aFile + "]");
        }
    }
}