package edu.stanford.nlp.ie;
import edu.stanford.nlp.classify.*;
import edu.stanford.nlp.ie.machinereading.structure.Span;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.*;
import edu.stanford.nlp.pipeline.DefaultPaths;
import edu.stanford.nlp.simple.Sentence;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.*;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.util.logging.RedwoodConfiguration;
import java.io.*;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import static edu.stanford.nlp.util.logging.Redwood.Util.*;
/**
* A relation extractor to work with Victor's new KBP data.
*/
@SuppressWarnings("FieldCanBeLocal")
public class KBPStatisticalExtractor implements KBPRelationExtractor, Serializable {
private static final long serialVersionUID = 1L;

/** The dataset {@code main} trains on when no model exists at {@link #MODEL_FILE}. */
@ArgumentParser.Option(name="train", gloss="The dataset to train on")
public static File TRAIN_FILE = new File("train.conll");

/** The dataset {@code main} evaluates the model on. */
@ArgumentParser.Option(name="test", gloss="The dataset to test on")
public static File TEST_FILE = new File("test.conll");

/**
 * Location of the serialized model. {@code main} loads the model from here; if nothing exists at
 * this location it first trains a model and writes it to this path.
 */
@ArgumentParser.Option(name="model", gloss="The model to load; if it does not exist, a new model is trained and saved here")
public static String MODEL_FILE = DefaultPaths.DEFAULT_KBP_CLASSIFIER;

/**
 * Where {@code main} dumps model predictions, if set. The value "stdout" (case-insensitive)
 * selects standard output; any other value is treated as a file path.
 */
@ArgumentParser.Option(name="predictions", gloss="Dump model predictions to this file")
public static Optional<String> PREDICTIONS = Optional.empty();

/** The minimizers that {@code initFactory} knows how to configure. */
private enum MinimizerType{ QN, SGD, HYBRID, L1 }

/** The minimizer used when training the classifier. */
@ArgumentParser.Option(name="minimizer", gloss="The minimizer to use for training the classifier")
private static MinimizerType minimizer = MinimizerType.L1;

/** Passed to {@code trainMultinomialClassifier} as the feature count threshold. */
@ArgumentParser.Option(name="feature_threshold", gloss="The minimum number of times to see a feature to count it")
private static int FEATURE_THRESHOLD = 0;

/** Passed to {@code trainMultinomialClassifier} as sigma. */
@ArgumentParser.Option(name="sigma", gloss="The regularizer for the classifier")
private static double SIGMA = 1.0;

/** Redwood logging channel for this class. */
private static final Redwood.RedwoodChannels log = Redwood.channels(KBPStatisticalExtractor.class);
/**
 * A list of triggers for top employees.
 * Lower-cased sentence words are looked up in this set by {@code relationSpecificFeatures}.
 *
 * NOTE(review): "mayer" looks like a misspelling of "mayor" (which is also present). It is kept
 * so that the set's contents, and therefore the emitted features, stay exactly as before.
 */
private static final Set<String> TOP_EMPLOYEE_TRIGGERS = Collections.unmodifiableSet(new HashSet<>(Arrays.asList(
    "executive",
    "chairman",
    "president",
    "chief",
    "head",
    "general",
    "ceo",
    "officer",
    "founder",
    "found",
    "leader",
    "vice",
    "king",
    "prince",
    "manager",
    "host",
    "minister",
    "adviser",
    "boss",
    "chair",
    "ambassador",
    "shareholder",
    "star",
    "governor",
    "investor",
    "representative",
    "dean",
    "commissioner",
    "deputy",
    "commander",
    "scientist",
    "midfielder",
    "speaker",
    "researcher",
    "editor",
    "chancellor",
    "fellow",
    "leadership",
    "diplomat",
    "attorney",
    "associate",
    "striker",
    "pilot",
    "captain",
    "banker",
    "mayer",
    "premier",
    "producer",
    "architect",
    "designer",
    "major",
    "advisor",
    "presidency",
    "senator",
    "specialist",
    "faculty",
    "monitor",
    "chairwoman",
    "mayor",
    "columnist",
    "mediator",
    "prosecutor",
    "entrepreneur",
    "creator",
    "superstar",
    "commentator",
    "principal",
    "operative",
    "businessman",
    "peacekeeper",
    "investigator",
    "coordinator",
    "knight",
    "lawmaker",
    "justice",
    "publisher",
    "playmaker",
    "moderator",
    "negotiator"
)));
/**
 * Record one occurrence of a (feature template, feature value) pair.
 *
 * <p>
 * Features tend to come in families: e.g. every unigram feature shares the template "unigram" and
 * differs only in its value (the word). Encoding both parts into a single String key, separated
 * by a dedicated character, makes it possible to recover the template from a feature later on,
 * for instance to do post-hoc feature selection per template.
 * </p>
 *
 * <p>
 * Spaces inside the feature value are swapped for a special character before the key is built,
 * mostly out of paranoia.
 * </p>
 *
 * @param features The feature counter being updated; the count of the joined key is incremented.
 * @param featureTemplate The feature template the value belongs to.
 * @param featureValue The value within the template. It is joined with the template, so it only
 *                     needs to be unique within that template.
 */
private static void indicator(Counter<String> features, String featureTemplate, String featureValue) {
  String spaceFreeValue = featureValue.replace(' ', 'ˑ');
  String key = featureTemplate + "ℵ" + spaceFreeValue;
  features.incrementCount(key);
}
/**
 * Get information from the span between the two mentions.
 * Canonically, get the words in this span.
 * For instance, for "Obama was born in Hawaii", this would return the list
 * "was born in" if the selector is {@code CoreLabel::word},
 * or "be bear in" if the selector is {@code CoreLabel::lemma}.
 *
 * <p>
 * The span runs from the end of whichever mention comes first to the start of the other one.
 * Overlapping or directly adjacent mentions yield an empty (immutable) list.
 * </p>
 *
 * @param input The featurizer input.
 * @param selector The field to compute for each element in the span, e.g.
 *                 {@code CoreLabel::lemma}, {@code CoreLabel::ner} or {@code CoreLabel::tag}.
 * @param <E> The type of element returned by the selector.
 *
 * @return A list of elements between the two mentions.
 * @throws IllegalArgumentException if no consistent range between the two spans can be found.
 */
private static <E> List<E> spanBetweenMentions(KBPInput input, Function<CoreLabel, E> selector) {
  Span subjSpan = input.subjectSpan;
  Span objSpan = input.objectSpan;

  // Corner cases
  if (Span.overlaps(subjSpan, objSpan)) {
    return Collections.emptyList();
  }

  // Get the range between the subject and object, whichever of the two comes first
  int begin = subjSpan.end();
  int end = objSpan.start();
  if (begin > end) {
    begin = objSpan.end();
    end = subjSpan.start();
  }
  if (begin > end) {
    throw new IllegalArgumentException("Gabor sucks at logic and he should feel bad about it: " + subjSpan + " and " + objSpan);
  } else if (begin == end) {
    return Collections.emptyList();
  }

  // Only materialize the sentence once we know there is something to select from
  List<CoreLabel> sentence = input.sentence.asCoreLabels(Sentence::lemmas, Sentence::nerTags);

  // Compute the return value
  List<E> rtn = new ArrayList<>(end - begin);
  for (int i = begin; i < end; ++i) {
    rtn.add(selector.apply(sentence.get(i)));
  }
  return rtn;
}
/**
 * Wrap a span feature with placeholder subject / object tokens on the correct sides.
 *
 * <p>
 * A span feature alone is ambiguous about direction: "x is the son of y" and "y is the son of x"
 * share the same text between the mentions but mean different things. Surrounding the feature
 * with dummy mention tokens, in sentence order, keeps the two cases apart.
 * </p>
 *
 * <p>
 * Note that the subject-first form carries a leading '+' character; it is part of the emitted
 * feature string.
 * </p>
 *
 * @param input The featurizer input.
 * @param feature The span feature to augment.
 *
 * @return The augmented feature.
 */
private static String withMentionsPositioned(KBPInput input, String feature) {
  boolean subjectComesFirst = input.subjectSpan.isBefore(input.objectSpan);
  return subjectComesFirst
      ? "+__SUBJ__ " + feature + " __OBJ__"
      : "__OBJ__ " + feature + " __SUBJ__";
}
/**
 * Add the two features every input gets: the subject/object type signature and the relative
 * order of the two mentions.
 *
 * @param input The featurizer input.
 * @param sentence The sentence of the input (unused here).
 * @param feats The feature counter to add to.
 */
@SuppressWarnings("UnusedParameters")
private static void denseFeatures(KBPInput input, Sentence sentence, ClassicCounter<String> feats) {
  String mentionOrder;
  if (input.subjectSpan.isBefore(input.objectSpan)) {
    mentionOrder = "y";
  } else {
    mentionOrder = "n";
  }
  // Type signature
  String typeSignature = input.subjectType + "," + input.objectType;
  indicator(feats, "type_signature", typeSignature);
  // Relative position
  indicator(feats, "subj_before_obj", mentionOrder);
}
/**
 * Add surface-level features: sentence unigrams, lemma n-grams and NER/lemma bigrams over the
 * span between the mentions, the bucketed span length, punctuation parity, interceding NER tags,
 * and the tokens immediately to the left and right of each mention.
 *
 * @param input The featurizer input.
 * @param simpleSentence The sentence of the input (unused; the sentence is read from the input).
 * @param feats The feature counter to add to.
 */
@SuppressWarnings("UnusedParameters")
private static void surfaceFeatures(KBPInput input, Sentence simpleSentence, ClassicCounter<String> feats) {
  List<String> lemmas = spanBetweenMentions(input, CoreLabel::lemma);
  List<String> nerTags = spanBetweenMentions(input, CoreLabel::ner);
  List<String> posTags = spanBetweenMentions(input, CoreLabel::tag);
  final int gap = lemmas.size();

  // Unigram features of the sentence
  for (CoreLabel label : input.sentence.asCoreLabels(Sentence::lemmas, Sentence::nerTags)) {
    indicator(feats, "sentence_unigram", label.lemma());
  }

  // (A "full_lemma_span" feature for spans of at most 5 lemmas was tried here and is
  //  deliberately left out: it was marked as costing 0.3 F1.)

  // Lemma n-grams; "_^_" and "_$_" pad the start and end of the span
  for (int k = 0; k < gap; ++k) {
    String previous = (k == 0) ? "_^_" : lemmas.get(k - 1);
    String current = lemmas.get(k);
    indicator(feats, "lemma_bigram", withMentionsPositioned(input, previous + " " + current));
    indicator(feats, "lemma_unigram", withMentionsPositioned(input, current));
  }
  String finalLemma = (gap == 0) ? "_^_" : lemmas.get(gap - 1);
  indicator(feats, "lemma_bigram", withMentionsPositioned(input, finalLemma + " _$_"));

  // NER + lemma bi-grams: an entity tag next to a non-entity token tagged "IN"
  for (int left = 0, right = 1; right < gap; ++left, ++right) {
    String leftNer = nerTags.get(left);
    String rightNer = nerTags.get(right);
    boolean leftIsEntity = !"O".equals(leftNer);
    boolean rightIsEntity = !"O".equals(rightNer);
    if (leftIsEntity && !rightIsEntity && "IN".equals(posTags.get(right))) {
      indicator(feats, "ner/lemma_bigram", withMentionsPositioned(input, leftNer + " " + lemmas.get(right)));
    }
    if (rightIsEntity && !leftIsEntity && "IN".equals(posTags.get(left))) {
      indicator(feats, "ner/lemma_bigram", withMentionsPositioned(input, lemmas.get(left) + " " + rightNer));
    }
  }

  // Distance between mentions (anything longer than 15 tokens lands in ">10")
  String distanceBucket;
  if (gap == 0) {
    distanceBucket = "0";
  } else if (gap <= 3) {
    distanceBucket = "<=3";
  } else if (gap <= 5) {
    distanceBucket = "<=5";
  } else if (gap <= 10) {
    distanceBucket = "<=10";
  } else if (gap <= 15) {
    distanceBucket = "<=15";
  } else {
    distanceBucket = ">10";
  }
  indicator(feats, "distance_between_entities_bucket", distanceBucket);

  // Punctuation features
  int commaCount = 0;
  int quoteCount = 0;
  int parenBalance = 0;
  for (String lemma : lemmas) {
    switch (lemma) {
      case ",":
        commaCount++;
        break;
      case "\"":
      case "``":
      case "''":
        quoteCount++;
        break;
      case "(":
      case "-LRB-":
        parenBalance++;
        break;
      case ")":
      case "-RRB-":
        parenBalance--;
        break;
      default:
        break;
    }
  }
  indicator(feats, "comma_parity", commaCount % 2 == 0 ? "even" : "odd");
  indicator(feats, "quote_parity", quoteCount % 2 == 0 ? "even" : "odd");
  indicator(feats, "paren_parity", String.valueOf(parenBalance));

  // Is broken by entity
  Set<String> entityTagsInBetween = nerTags.stream()
      .filter(tag -> !tag.equals("O"))
      .collect(Collectors.toSet());
  if (!entityTagsInBetween.isEmpty()) {
    indicator(feats, "has_interceding_ner", "t");
  }
  for (String tag : entityTagsInBetween) {
    indicator(feats, "interceding_ner", tag);
  }

  // Left and right context; "^" / "$" stand in for the sentence boundaries
  List<CoreLabel> labels = input.sentence.asCoreLabels(Sentence::nerTags);
  int subjStart = input.subjectSpan.start();
  int subjEnd = input.subjectSpan.end();
  int objStart = input.objectSpan.start();
  int objEnd = input.objectSpan.end();
  String subjLeft = (subjStart == 0) ? "^" : labels.get(subjStart - 1).lemma();
  indicator(feats, "subj_left", subjLeft);
  indicator(feats, "subj_right", (subjEnd == labels.size()) ? "$" : labels.get(subjEnd).lemma());
  indicator(feats, "obj_left", (objStart == 0) ? "^" : labels.get(objStart - 1).lemma());
  indicator(feats, "obj_right", (objEnd == labels.size()) ? "$" : labels.get(objEnd).lemma());

  // Skip-word patterns: token left of the subject + the single token between the mentions
  if (gap == 1 && input.subjectSpan.isBefore(input.objectSpan)) {
    indicator(feats, "X<subj>Y<obj>", subjLeft + "_" + lemmas.get(0));
  }
}
/**
 * Add features derived from the dependency path between the head of the subject mention and
 * the head of the object mention: the bucketed path length, the path decorated with POS tags
 * and with NER types (for short paths), and the words, edges and three-element chunks on the path.
 *
 * <p>
 * For paths longer than 3 elements, appos edges (and the neighbouring element selected below)
 * are chopped out of the path first. Earlier versions of this method removed the first
 * <i>k</i> path elements instead of the <i>k</i> collected positions; models trained with that
 * behaviour will see different path features for appos paths and should be retrained.
 * </p>
 *
 * @param input The featurizer input.
 * @param sentence The sentence the mentions occur in.
 * @param feats The feature counter to add to.
 */
private static void dependencyFeatures(KBPInput input, Sentence sentence, ClassicCounter<String> feats) {
  int subjectHead = sentence.algorithms().headOfSpan(input.subjectSpan);
  int objectHead = sentence.algorithms().headOfSpan(input.objectSpan);
  // indicator(feats, "subject_head", sentence.lemma(subjectHead));
  // indicator(feats, "object_head", sentence.lemma(objectHead));
  if (input.objectType.isRegexNERType) {
    indicator(feats, "object_head", sentence.lemma(objectHead));
  }

  // Get the dependency path
  List<String> depparsePath = sentence.algorithms().dependencyPathBetween(subjectHead, objectHead, Optional.of(Sentence::lemmas));

  // Chop out appos edges
  if (depparsePath.size() > 3) {
    // Positions to delete, de-duplicated and ordered highest-first so that each removal leaves
    // the positions still to be removed untouched.
    SortedSet<Integer> apposChunks = new TreeSet<Integer>(Collections.<Integer>reverseOrder());
    for (int i = 1; i < depparsePath.size() - 1; ++i) {
      if ("-appos->".equals(depparsePath.get(i))) {
        if (i != 1) {
          apposChunks.add(i - 1);
        }
        apposChunks.add(i);
      } else if ("<-appos-".equals(depparsePath.get(i))) {
        // NOTE(review): this guard is always true inside this loop. By symmetry with the
        // `i != 1` check above it may have been meant to protect the last path element
        // (i.e. `i < size - 2`) -- confirm before changing.
        if (i < depparsePath.size() - 1) {
          apposChunks.add(i + 1);
        }
        apposChunks.add(i);
      }
    }
    for (int position : apposChunks) {
      depparsePath.remove(position);  // int argument: removes by index, not by value
    }
  }

  // Dependency path distance buckets
  String distanceBucket = ">10";
  if (depparsePath.size() == 3) {
    distanceBucket = "<=3";
  } else if (depparsePath.size() <= 5) {
    distanceBucket = "<=5";
  } else if (depparsePath.size() <= 7) {
    distanceBucket = "<=7";
  } else if (depparsePath.size() <= 9) {
    distanceBucket = "<=9";
  } else if (depparsePath.size() <= 13) {
    distanceBucket = "<=13";
  } else if (depparsePath.size() <= 17) {
    distanceBucket = "<=17";
  }
  indicator(feats, "parse_distance_between_entities_bucket", distanceBucket);

  // Add the path features (interior of the path only, i.e. without the two endpoints)
  if (depparsePath.size() > 2 && depparsePath.size() <= 7) {
    // indicator(feats, "deppath", StringUtils.join(depparsePath.subList(1, depparsePath.size() - 1), ""));
    // indicator(feats, "deppath_unlex", StringUtils.join(depparsePath.subList(1, depparsePath.size() - 1).stream().filter(x -> x.startsWith("-") || x.startsWith("<")), ""));
    indicator(feats, "deppath_w/tag",
        sentence.posTag(subjectHead) + StringUtils.join(depparsePath.subList(1, depparsePath.size() - 1), "") + sentence.posTag(objectHead));
    indicator(feats, "deppath_w/ner",
        input.subjectType + StringUtils.join(depparsePath.subList(1, depparsePath.size() - 1), "") + input.objectType);
  }

  // Add the edge features
  //noinspection Convert2streamapi
  for (String node : depparsePath) {
    // Elements that start with "-" or "<-" are edges; everything else counts as a word
    if (!node.startsWith("-") && !node.startsWith("<-")) {
      indicator(feats, "deppath_word", node);
    }
  }
  for (int i = 0; i < depparsePath.size() - 1; ++i) {
    indicator(feats, "deppath_edge", depparsePath.get(i) + depparsePath.get(i + 1));
  }
  for (int i = 0; i < depparsePath.size() - 2; ++i) {
    indicator(feats, "deppath_chunk", depparsePath.get(i) + depparsePath.get(i + 1) + depparsePath.get(i + 2));
  }
}
/**
 * Add features that only make sense for particular subject/object type combinations:
 * <ul>
 *   <li>NUMBER objects: whether the object text parses as a number, whether it is an integer,
 *       a bucket for its value, whether it was spelled out, and special cases for dashes and
 *       the String "one".</li>
 *   <li>PERSON/ORGANIZATION pairs (in either order): whether a TITLE tag or a top employee
 *       trigger word occurs within 5 tokens before, within 5 tokens after, or inside the span
 *       covering both mentions.</li>
 * </ul>
 *
 * @param input The featurizer input.
 * @param sentence The sentence the mentions occur in.
 * @param feats The feature counter to add to.
 */
@SuppressWarnings("UnusedParameters")
private static void relationSpecificFeatures(KBPInput input, Sentence sentence, ClassicCounter<String> feats) {
  if (input.objectType.equals(KBPRelationExtractor.NERTag.NUMBER)) {
    // Bucket the object value if it is a number
    // This is to prevent things like "age:9000" and to soft penalize "age:one"
    // The following features are extracted:
    //   1. Whether the object parses as a number (should always be true)
    //   2. Whether the object is an integer
    //   3. If the object is an integer, around what value is it (bucketed around common age values)
    //   4. Was the number spelled out, or written as a numeric number
    try {
      Number number = NumberNormalizer.wordToNumber(input.getObjectText());
      if (number != null) {
        indicator(feats, "obj_parsed_as_num", "t");
        // NOTE(review): this compares a Number against a boxed Integer, which can only be
        // true if wordToNumber returned an Integer -- confirm that this is the intended test.
        if (number.equals(number.intValue())) {
          indicator(feats, "obj_isint", "t");
          int numAsInt = number.intValue();
          String bucket;
          if (numAsInt < 0) {
            // Must be tested first; otherwise negatives fall into the "<5" bucket
            bucket = "<0";
          } else if (numAsInt == 0) {
            bucket = "0";
          } else if (numAsInt == 1) {
            bucket = "1";
          } else if (numAsInt < 5) {
            bucket = "<5";
          } else if (numAsInt < 18) {
            bucket = "<18";
          } else if (numAsInt < 25) {
            bucket = "<25";
          } else if (numAsInt < 50) {
            bucket = "<50";
          } else if (numAsInt < 80) {
            bucket = "<80";
          } else if (numAsInt < 125) {
            bucket = "<125";
          } else {
            // Everything from 125 upwards (the label is kept as-is for existing models)
            bucket = ">125";
          }
          indicator(feats, "obj_number_bucket", bucket);
        } else {
          indicator(feats, "obj_isint", "f");
        }
        if (input.getObjectText().replace(",", "").equalsIgnoreCase(number.toString())) {
          indicator(feats, "obj_spelledout_num", "f");
        } else {
          indicator(feats, "obj_spelledout_num", "t");
        }
      } else {
        indicator(feats, "obj_parsed_as_num", "f");
      }
    } catch (NumberFormatException e) {
      indicator(feats, "obj_parsed_as_num", "f");
    }
    // Special case dashes and the String "one"
    if (input.getObjectText().contains("-")) {
      indicator(feats, "obj_num_has_dash", "t");
    } else {
      indicator(feats, "obj_num_has_dash", "f");
    }
    if (input.getObjectText().equalsIgnoreCase("one")) {
      indicator(feats, "obj_num_is_one", "t");
    } else {
      indicator(feats, "obj_num_is_one", "f");
    }
  }

  if (
      (input.subjectType == KBPRelationExtractor.NERTag.PERSON && input.objectType.equals(KBPRelationExtractor.NERTag.ORGANIZATION)) ||
      (input.subjectType == KBPRelationExtractor.NERTag.ORGANIZATION && input.objectType.equals(KBPRelationExtractor.NERTag.PERSON))
      ) {
    // Try to capture some denser features for employee_of
    // These are:
    //   1. Whether a TITLE tag occurs either before, after, or inside the relation span
    //   2. Whether a top employee trigger occurs either before, after, or inside the relation span
    final int contextWindow = 5;  // number of tokens inspected on either side of the span
    Span relationSpan = Span.union(input.subjectSpan, input.objectSpan);

    // (triggers before span)
    for (int i = Math.max(0, relationSpan.start() - contextWindow); i < relationSpan.start(); ++i) {
      if ("TITLE".equals(sentence.nerTag(i))) {
        indicator(feats, "title_before", "t");
      }
      if (TOP_EMPLOYEE_TRIGGERS.contains(sentence.word(i).toLowerCase())) {
        indicator(feats, "top_employee_trigger_before", "t");
      }
    }

    // (triggers after span)
    // The upper bound has to reach past the end of the span; bounding it by the span end
    // itself (as was done before) makes this loop a no-op.
    for (int i = relationSpan.end(); i < Math.min(sentence.length(), relationSpan.end() + contextWindow); ++i) {
      if ("TITLE".equals(sentence.nerTag(i))) {
        indicator(feats, "title_after", "t");
      }
      if (TOP_EMPLOYEE_TRIGGERS.contains(sentence.word(i).toLowerCase())) {
        indicator(feats, "top_employee_trigger_after", "t");
      }
    }

    // (triggers inside span)
    for (int i : relationSpan) {
      if ("TITLE".equals(sentence.nerTag(i))) {
        indicator(feats, "title_inside", "t");
      }
      if (TOP_EMPLOYEE_TRIGGERS.contains(sentence.word(i).toLowerCase())) {
        indicator(feats, "top_employee_trigger_inside", "t");
      }
    }
  }
}
/**
 * Featurize a relation extraction input.
 *
 * <p>
 * Inputs whose subject and object spans overlap, or where either span is empty, produce an
 * empty feature counter. Otherwise the dense, surface, dependency and relation-specific
 * featurizers are run in that order over one shared counter.
 * </p>
 *
 * @param input The input to featurize.
 * @return The features of the input, as a counter keyed by feature string.
 */
public static Counter<String> features(KBPInput input) {
  boolean degenerateInput = Span.overlaps(input.subjectSpan, input.objectSpan)
      || input.subjectSpan.size() == 0
      || input.objectSpan.size() == 0;
  if (degenerateInput) {
    return new ClassicCounter<>();
  }

  // Actually featurize
  ClassicCounter<String> collected = new ClassicCounter<>();
  Sentence sentence = input.sentence;
  denseFeatures(input, sentence, collected);
  surfaceFeatures(input, sentence, collected);
  dependencyFeatures(input, sentence, collected);
  relationSpecificFeatures(input, sentence, collected);
  return collected;
}
/**
 * Create a classifier factory whose minimizer is chosen by the {@code minimizer} option.
 *
 * <p>
 * For the L1 setting the minimizer class is instantiated reflectively by name; if that fails,
 * an error is logged and a {@code QNMinimizer} is used instead.
 * </p>
 *
 * @param sigma The sigma value handed to the SGD, hybrid and L1 minimizers.
 * @param <L> The label class of the factory
 * @return A factory to minimize a classifier against.
 * @throws IllegalStateException If the configured minimizer type is not handled here.
 */
private static <L> LinearClassifierFactory<L, String> initFactory(double sigma) {
  LinearClassifierFactory<L, String> classifierFactory = new LinearClassifierFactory<>();
  Factory<Minimizer<DiffFunction>> creator;
  switch (minimizer) {
    case QN:
      creator = () -> new QNMinimizer(15);
      break;
    case SGD:
      creator = () -> new SGDMinimizer<>(sigma, 100, 1000);
      break;
    case HYBRID:
      classifierFactory.useHybridMinimizerWithInPlaceSGD(100, 1000, sigma);
      // SGD first, then quasi-Newton
      creator = () -> new HybridMinimizer(
          new SGDMinimizer<DiffFunction>(sigma, 50, 1000),
          new QNMinimizer(15),
          50);
      break;
    case L1:
      creator = () -> {
        try {
          return MetaClass.create("edu.stanford.nlp.optimization.OWLQNMinimizer").createInstance(sigma);
        } catch (Exception e) {
          log.err("Could not create l1 minimizer! Reverting to l2.");
          return new QNMinimizer(15);
        }
      };
      break;
    default:
      throw new IllegalStateException("Unknown minimizer: " + minimizer);
  }
  classifierFactory.setMinimizerCreator(creator);
  return classifierFactory;
}
/**
 * Train a multinomial classifier off of the provided dataset.
 *
 * <p>
 * The dataset is modified in place: the feature count threshold is applied to it and it is
 * randomized with a fixed seed before training. After training, the accuracy on the training
 * data itself is logged as a debugging aid.
 * </p>
 *
 * @param dataset The dataset to train the classifier off of.
 * @param featureThreshold The feature count threshold to apply to the dataset before training.
 * @param sigma The sigma value passed on to the classifier factory.
 * @return A classifier.
 */
public static Classifier<String, String> trainMultinomialClassifier(
    GeneralDataset<String, String> dataset,
    int featureThreshold,
    double sigma) {
  // Prepare the data
  log.info("Applying feature threshold (" + featureThreshold + ")...");
  dataset.applyFeatureCountThreshold(featureThreshold);
  log.info("Randomizing dataset...");
  dataset.randomize(42L);

  // Prepare the trainer
  log.info("Creating factory...");
  LinearClassifierFactory<String, String> trainer = initFactory(sigma);

  // Fit the model
  log.info("BEGIN training");
  LinearClassifier<String, String> model = trainer.trainClassifier(dataset);
  log.info("END training");

  // Debug: how well does the model do on the data it was trained on?
  Accuracy accuracyOnTrain = new Accuracy();
  for (Datum<String, String> example : dataset) {
    String predicted = model.classOf(example);
    String gold = example.label();
    accuracyOnTrain.predict(Collections.singleton(predicted), Collections.singleton(gold));
  }
  log.info("Training accuracy:");
  log.info(accuracyOnTrain.toString());
  log.info("");

  return model;
}
/**
* The implementing classifier of this extractor.
* It is assigned once, in the constructor, and is what {@code classify} asks for the
* per-relation scores of a featurized input.
*/
public final Classifier<String, String> classifier;
/**
* Create a new KBP relation extractor, from the given implementing classifier.
* The classifier is stored as-is; no copy is made.
*
* @param classifier The implementing classifier.
*/
public KBPStatisticalExtractor(Classifier<String, String> classifier) {
this.classifier = classifier;
}
/**
 * Score the given input, returning both the classification decision and the
 * probability of that decision.
 *
 * <p>
 * The classifier's scores are exponentiated and normalized. Starting from the highest scoring
 * relation, any relation whose expected subject type or allowed object types do not match the
 * input is dropped and the remaining scores are renormalized, until a relation type checks,
 * NO_RELATION comes out on top, or only a single candidate is left.
 * </p>
 *
 * @param input The input to classify.
 * @return A pair with the relation we classified into, along with its confidence.
 */
public Pair<String,Double> classify(KBPInput input) {
  Counter<String> distribution = classifier.scoresOf(new RVFDatum<String, String>(features(input)));
  Counters.expInPlace(distribution);
  Counters.normalize(distribution);

  String prediction = Counters.argmax(distribution);
  // Walk down the ranking for as long as the top prediction fails to type check.
  // NO_RELATION is always an option somewhere in there, so it is safe to keep going...
  while (!NO_RELATION.equals(prediction) && distribution.size() > 1) {
    RelationType relation = RelationType.fromString(prediction).get();
    boolean objectTypeOk = relation.validNamedEntityLabels.contains(input.objectType);
    boolean subjectTypeOk = relation.entityType == input.subjectType;
    if (objectTypeOk && subjectTypeOk) {
      break;
    }
    distribution.remove(prediction);
    Counters.normalize(distribution);
    prediction = Counters.argmax(distribution);
  }
  return Pair.makePair(prediction, distribution.getCount(prediction));
}
/**
 * Train (if necessary) and evaluate the statistical relation extractor.
 *
 * <p>
 * The test data is always read. If no model exists at {@link #MODEL_FILE}, the training data is
 * read and featurized in parallel, a classifier is trained, and the resulting extractor is
 * written to {@link #MODEL_FILE}. The model is then loaded from {@link #MODEL_FILE} and
 * evaluated on the test data; predictions are dumped to {@link #PREDICTIONS} if it is set.
 * </p>
 *
 * @param args Command-line options; see the {@code @ArgumentParser.Option} fields of this class.
 * @throws IOException If reading the data or reading / writing the model fails.
 * @throws ClassNotFoundException If the serialized model cannot be deserialized.
 */
public static void main(String[] args) throws IOException, ClassNotFoundException {
  RedwoodConfiguration.standard().apply();  // Disable SLF4J crap.
  ArgumentParser.fillOptions(KBPStatisticalExtractor.class, args);  // Fill command-line options

  // Load the test (or dev) data
  forceTrack("Test data");
  List<Pair<KBPInput, String>> testExamples = KBPRelationExtractor.readDataset(TEST_FILE);
  log.info("Read " + testExamples.size() + " examples");
  endTrack("Test data");

  // If we can't find an existing model, train one
  if (!IOUtils.existsInClasspathOrFileSystem(MODEL_FILE)) {
    forceTrack("Training data");
    List<Pair<KBPInput, String>> trainExamples = KBPRelationExtractor.readDataset(TRAIN_FILE);
    log.info("Read " + trainExamples.size() + " examples");
    log.info("" + trainExamples.stream().map(Pair::second).filter(NO_RELATION::equals).count() + " are " + NO_RELATION);
    endTrack("Training data");

    // Featurize + create the dataset
    forceTrack("Creating dataset");
    RVFDataset<String, String> dataset = new RVFDataset<>();
    final AtomicInteger numFeaturized = new AtomicInteger(0);
    long beginTime = System.currentTimeMillis();
    trainExamples.stream().parallel().forEach(example -> {
      // Use the value returned by the increment itself: re-reading the shared counter could
      // report a number that other threads have already moved past.
      int soFar = numFeaturized.incrementAndGet();
      if (soFar % 1000 == 0) {
        log.info("[" + Redwood.formatTimeDifference(System.currentTimeMillis() - beginTime) +
            "] Featurized " + soFar + " / " + trainExamples.size() + " examples");
      }
      Counter<String> features = features(example.first);  // This takes a while per example
      // The dataset is shared between the worker threads, so additions are serialized
      synchronized (dataset) {
        dataset.add(new RVFDatum<>(features, example.second));
      }
    });
    trainExamples.clear();  // Free up some memory
    endTrack("Creating dataset");

    // Train the classifier
    log.info("Training classifier:");
    Classifier<String, String> classifier = trainMultinomialClassifier(dataset, FEATURE_THRESHOLD, SIGMA);
    dataset.clear();  // Free up some memory

    // Save the classifier
    IOUtils.writeObjectToFile(new KBPStatisticalExtractor(classifier), MODEL_FILE);
  }

  // Read either a newly-trained or pre-trained model.
  // The stored object may be a bare classifier or a complete extractor.
  Object model = IOUtils.readObjectFromURLOrClasspathOrFileSystem(MODEL_FILE);
  KBPStatisticalExtractor classifier;
  if (model instanceof Classifier) {
    //noinspection unchecked
    classifier = new KBPStatisticalExtractor((Classifier<String, String>) model);
  } else {
    classifier = ((KBPStatisticalExtractor) model);
  }

  // Evaluate the model
  Optional<PrintStream> predictionsOut = PREDICTIONS.map(x -> {
    try {
      return "stdout".equalsIgnoreCase(x) ? System.out : new PrintStream(new FileOutputStream(x));
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    }
  });
  try {
    classifier.computeAccuracy(testExamples.stream(), predictionsOut);
  } finally {
    // Release the predictions file once we are done with it; standard output is only flushed.
    predictionsOut.ifPresent(out -> {
      if (out == System.out) {
        out.flush();
      } else {
        out.close();
      }
    });
  }
}
}