package edu.stanford.nlp.semparse.open.model;
import java.util.*;
import edu.stanford.nlp.semparse.open.core.eval.IterativeTester;
import edu.stanford.nlp.semparse.open.dataset.Dataset;
import edu.stanford.nlp.semparse.open.dataset.Example;
import edu.stanford.nlp.semparse.open.model.candidate.Candidate;
import edu.stanford.nlp.semparse.open.model.candidate.PathEntry;
import fig.basic.LogInfo;
import fig.basic.MapUtils;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.ValueComparator;
/**
 * Baseline classifier.
 *
 * <p>During training, counts how often each tree-pattern path suffix occurs among
 * training candidates with a positive reward, and keeps the most frequent ones.
 * At prediction time a candidate whose path suffix was retained gets a positive
 * score (either the suffix frequency or the number of predicted entities,
 * depending on {@link Options#baselineUseMaxSize}); all other candidates score 0.
 */
public class LearnerBaseline implements Learner {

  /** Command-line options controlling the baseline. */
  public static class Options {
    /** Number of trailing path entries that make up a suffix. Negative values are treated as 0. */
    @Option public int baselineSuffixLength = 5;
    /** Maximum number of suffix patterns retained after training. Negative values are treated as 0. */
    @Option public int baselineMaxNumPatterns = 10000;
    /** If true, score by the number of predicted entities; if false, score by suffix frequency. */
    @Option public boolean baselineUseMaxSize = false;    // false = use most frequent
    /** How each path entry is rendered into the suffix (see {@link IndexType}). */
    @Option public IndexType baselineIndexType = IndexType.STAR;
    /** If true, the suffix entries are sorted, so the order of entries is ignored. */
    @Option public boolean baselineBagOfTags = true;
  }
  public static Options opts = new Options();

  /**
   * Rendering of a path entry inside a suffix:
   * NONE = tag only, STAR = tag followed by "[*]" when the entry is indexed,
   * FULL = the entry's {@code toString()}.
   */
  public enum IndexType {NONE, STAR, FULL}

  /** Tester invoked at the end of {@link #learn}; may be null if never set. */
  protected IterativeTester iterativeTester;
  /** When true, suppresses the progress logging done in {@link #learn}. */
  public boolean beVeryQuiet = false;

  /*
   * IDEA:
   * - Look at the training data and record the most frequent tree patterns (path suffixes)
   *   among the candidates that received a positive reward.
   * - For a test example, score each candidate by looking up its path suffix
   *   in the recorded patterns.
   */
  // Map from suffix to count; null until learn() has been called
  Map<List<String>, Integer> goodPathCounts;

  // ============================================================
  // Log
  // ============================================================

  /** Logs every retained suffix with its count, or a placeholder if nothing has been learned. */
  @Override
  public void logParam() {
    LogInfo.begin_track("Params");
    if (goodPathCounts == null) {
      LogInfo.log("No parameters.");
    } else {
      List<Map.Entry<List<String>, Integer>> entries = new ArrayList<>(goodPathCounts.entrySet());
      // NOTE(review): the `true` flag presumably requests descending order by count -- confirm.
      Collections.sort(entries, new ValueComparator<List<String>, Integer>(true));
      for (Map.Entry<List<String>, Integer> entry : entries) {
        LogInfo.logs("%8d : %s", entry.getValue(), entry.getKey());
      }
    }
    LogInfo.end_track();
  }

  /** The baseline has no features; only logs a notice. */
  @Override
  public void logFeatureWeights(Candidate candidate) {
    LogInfo.log("Using BASELINE Learner - no features");
  }

  /** The baseline has no features; only logs a notice. */
  @Override
  public void logFeatureDiff(Candidate trueCandidate, Candidate predCandidate) {
    LogInfo.log("Using BASELINE Learner - no features");
  }

  /** Turns off the progress logging done in {@link #learn}. */
  @Override
  public void shutUp() {
    beVeryQuiet = true;
  }

  // ============================================================
  // Predict
  // ============================================================

  /**
   * Scores every candidate of the example and returns (candidate, score) pairs
   * sorted with {@code Pair.ReverseSecondComparator}.
   *
   * @param example example whose candidates are ranked
   * @return a new list of (candidate, score) pairs in ranked order
   */
  @Override
  public List<Pair<Candidate, Double>> getRankedCandidates(Example example) {
    List<Pair<Candidate, Double>> answer = new ArrayList<>();
    for (Candidate candidate : example.candidates) {
      double score = getScore(candidate);
      answer.add(new Pair<Candidate, Double>(candidate, score));
    }
    Collections.sort(answer, new Pair.ReverseSecondComparator<Candidate, Double>());
    return answer;
  }

  /**
   * Computes the baseline score of a candidate.
   *
   * @param candidate candidate to score
   * @return 0 if no model has been learned yet or the candidate's path suffix was not
   *     retained; otherwise the number of predicted entities (when
   *     {@code baselineUseMaxSize} is set) or the suffix frequency
   */
  protected double getScore(Candidate candidate) {
    // Before learn() is called there are no patterns, so nothing can match.
    if (goodPathCounts == null) return 0;
    List<String> suffix = getPathSuffix(candidate);
    Integer frequency = goodPathCounts.get(suffix);
    if (frequency == null) return 0;
    return opts.baselineUseMaxSize ? candidate.predictedEntities.size() : frequency;
  }

  // ============================================================
  // Learn
  // ============================================================

  @Override
  public void setIterativeTester(IterativeTester tester) {
    this.iterativeTester = tester;
  }

  /**
   * Counts the path suffixes of all positively rewarded training candidates and
   * retains at most {@code baselineMaxNumPatterns} of them in {@link #goodPathCounts}.
   * Runs the iterative tester afterwards if one has been set.
   *
   * @param dataset dataset whose training examples are used
   * @param additionalFeatureMatcher unused by the baseline
   */
  @Override
  public void learn(Dataset dataset, FeatureMatcher additionalFeatureMatcher) {
    Map<List<String>, Integer> pathCounts = new HashMap<>();
    dataset.cacheRewards();
    // Learn good tree patterns (path suffix)
    if (!beVeryQuiet) LogInfo.begin_track("Learning tree patterns ...");
    for (Example ex : dataset.trainExamples) {
      for (Candidate candidate : ex.candidates) {
        if (candidate.getReward() > 0) {
          // Good candidate -- remember the tree pattern
          MapUtils.incr(pathCounts, getPathSuffix(candidate));
        }
      }
    }
    // Sort by count
    // NOTE(review): the `true` flag presumably requests descending order by count -- confirm.
    List<Map.Entry<List<String>, Integer>> entries = new ArrayList<>(pathCounts.entrySet());
    Collections.sort(entries, new ValueComparator<List<String>, Integer>(true));
    // Retain the top n paths; clamp at 0 so a negative option value cannot break subList
    int n = Math.max(0, Math.min(opts.baselineMaxNumPatterns, entries.size()));
    goodPathCounts = new HashMap<>();
    for (Map.Entry<List<String>, Integer> entry : entries.subList(0, n)) {
      goodPathCounts.put(entry.getKey(), entry.getValue());
    }
    if (!beVeryQuiet) LogInfo.logs("Found %d path patterns.", goodPathCounts.size());
    if (!beVeryQuiet) LogInfo.end_track();
    // The tester is supplied through an optional setter; skip evaluation when it is absent.
    if (iterativeTester != null) iterativeTester.run();
  }

  /** Returns the path suffix of the candidate's pattern. */
  private List<String> getPathSuffix(Candidate candidate) {
    return getPathSuffix(candidate.pattern.getPath());
  }

  /**
   * Renders the last {@code baselineSuffixLength} entries of the path as strings
   * according to {@code baselineIndexType}; the result is sorted when
   * {@code baselineBagOfTags} is set.
   *
   * @param path full path of a tree pattern
   * @return a new list holding the rendered (interned) suffix entries
   */
  private List<String> getPathSuffix(List<PathEntry> path) {
    List<String> suffix = new ArrayList<>();
    // Clamp at 0 so a negative option value yields an empty suffix instead of an exception.
    int suffixLength = Math.max(0, opts.baselineSuffixLength);
    int startIndex = Math.max(0, path.size() - suffixLength);
    for (PathEntry entry : path.subList(startIndex, path.size())) {
      String strEntry = "";
      switch (opts.baselineIndexType) {
        case NONE: strEntry = entry.tag; break;
        case STAR: strEntry = entry.tag + (entry.isIndexed() ? "[*]" : ""); break;
        case FULL: strEntry = entry.toString(); break;
      }
      suffix.add(strEntry.intern());
    }
    if (opts.baselineBagOfTags)
      Collections.sort(suffix);
    return suffix;
  }

  // ============================================================
  // Persistence
  // ============================================================

  /** Not supported by the baseline; reports a failure through {@code LogInfo.fail}. */
  @Override
  public void saveModel(String path) {
    LogInfo.fail("Not implemented");
  }

  /** Not supported by the baseline; reports a failure through {@code LogInfo.fail}. */
  @Override
  public void loadModel(String path) {
    LogInfo.fail("Not implemented");
  }
}