package edu.stanford.nlp.semparse.open.model;
import java.util.*;
import edu.stanford.nlp.semparse.open.dataset.Example;
import edu.stanford.nlp.semparse.open.model.candidate.Candidate;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
/**
 * A {@link LearnerMaxEnt} variant that can replace the base candidate set with
 * candidates derived from the top-scoring base candidates (a beam).
 *
 * <p>The derivation step ({@link #getDerivedCandidates(Candidate)}) is not implemented
 * yet: every supported {@code beamCandidateType} currently reports a failure through
 * {@link LogInfo#fails}.
 *
 * <p>NOTE(review): {@link #getBeamSearchedCandidates(Example)} calls
 * {@code super.getRankedCandidates}. If the superclass implementation obtains its
 * candidates through the overridable {@code getCandidates}, then once
 * {@code trainIter > opts.beamTrainStartIter} the calls would recurse back into
 * {@code getBeamSearchedCandidates} without terminating. The superclass is not visible
 * here -- confirm before enabling the beam path.
 */
public class LearnerMaxEntWithBeamSearch extends LearnerMaxEnt {

  /** Configurable options for the beam search. */
  public static class Options {
    /** Maximum number of top-ranked base candidates kept on the beam. */
    @Option public int beamSize = 500;
    /** Beam-searched candidates are used by {@code getCandidates} only after this training iteration. */
    @Option public int beamTrainStartIter = 1;
    /** Selects the derivation strategy; {@code "cutrange"} and {@code "endcut"} are recognized. */
    @Option public String beamCandidateType = "cutrange";
  }

  public static Options opts = new Options();

  /**
   * Returns the candidates for the given example.
   *
   * <p>While {@code trainIter <= opts.beamTrainStartIter} this defers to the superclass;
   * afterwards it returns the beam-searched candidates. ({@code trainIter} is not declared
   * in this class; presumably it is maintained by the superclass.)
   *
   * @param example the example to produce candidates for
   * @return the candidate list
   */
  @Override
  protected List<Candidate> getCandidates(Example example) {
    if (trainIter <= opts.beamTrainStartIter) {
      return super.getCandidates(example);
    }
    return getBeamSearchedCandidates(example);
  }

  /**
   * Builds the beam-searched candidate list: takes the first {@code opts.beamSize}
   * entries of the superclass ranking (or all of them if there are fewer) and
   * concatenates the candidates derived from each.
   *
   * @param example the example to produce candidates for
   * @return all derived candidates, in beam order
   */
  protected List<Candidate> getBeamSearchedCandidates(Example example) {
    List<Pair<Candidate, Double>> ranked = super.getRankedCandidates(example);
    int beamWidth = Math.min(opts.beamSize, ranked.size());
    List<Candidate> derivedCandidates = new ArrayList<>();
    for (Pair<Candidate, Double> entry : ranked.subList(0, beamWidth)) {
      derivedCandidates.addAll(getDerivedCandidates(entry.getFirst()));
    }
    return derivedCandidates;
  }

  /**
   * Derives new candidates from a beam candidate according to
   * {@code opts.beamCandidateType}.
   *
   * <p>No strategy is implemented yet, so every branch reports a failure via
   * {@link LogInfo#fails}. Should that call return normally, an empty list is returned
   * (never {@code null}) so that callers such as
   * {@link #getBeamSearchedCandidates(Example)} can safely pass the result to
   * {@code addAll}.
   *
   * @param original the beam candidate to expand
   * @return the derived candidates; currently always empty when this method returns
   */
  protected List<Candidate> getDerivedCandidates(Candidate original) {
    switch (opts.beamCandidateType) {
      case "cutrange":
        // TODO: intended implementation is TreePatternAndRange.generateCutRangeCandidates(original).
        LogInfo.fails("... not implemented yet ...");
        break;
      case "endcut":
        LogInfo.fails("... not implemented yet ...");
        break;
      default:
        LogInfo.fails("Unrecognized beam candidate type: %s", opts.beamCandidateType);
        break;
    }
    // Empty rather than null: the result is fed directly into List.addAll.
    return new ArrayList<>();
  }

  /**
   * Scores every beam-searched candidate as the dot product of its features with
   * {@code params} and returns the (candidate, score) pairs sorted with
   * {@code Pair.ReverseSecondComparator} (presumably descending by score -- confirm
   * against the comparator's definition).
   *
   * <p>Unlike {@link #getCandidates(Example)}, this always uses the beam-searched
   * candidates, regardless of {@code trainIter}.
   *
   * @param example the example whose candidates are ranked
   * @return the sorted list of scored candidates
   */
  @Override
  public List<Pair<Candidate, Double>> getRankedCandidates(Example example) {
    List<Pair<Candidate, Double>> answer = new ArrayList<>();
    for (Candidate candidate : getBeamSearchedCandidates(example)) {
      double score = candidate.features.dotProduct(params);
      answer.add(new Pair<Candidate, Double>(candidate, score));
    }
    Collections.sort(answer, new Pair.ReverseSecondComparator<Candidate, Double>());
    return answer;
  }
}