package edu.stanford.nlp.semparse.open.model; import java.util.*; import edu.stanford.nlp.semparse.open.core.eval.IterativeTester; import edu.stanford.nlp.semparse.open.dataset.Dataset; import edu.stanford.nlp.semparse.open.dataset.Example; import edu.stanford.nlp.semparse.open.model.candidate.Candidate; import edu.stanford.nlp.semparse.open.model.candidate.CandidateGroup; import edu.stanford.nlp.semparse.open.model.feature.FeatureType; import fig.basic.Fmt; import fig.basic.LogInfo; import fig.basic.MapUtils; import fig.basic.NumUtils; import fig.basic.Option; import fig.basic.Pair; import fig.exec.Execution; public class LearnerMaxEnt implements Learner { public static class Options { @Option(gloss = "Number of iterations to train") public int numTrainIters = 3; @Option(gloss = "L2 Regularization") public double beta = 0.01; @Option(gloss = "L1 Regularization") public double lambda = 0; @Option(gloss = "Only retain features with parameter weight of at least this magnitude") public double pruneSmallFeaturesThreshold = 0; @Option(gloss = "Keep features that occur at least this many times") public int featureMinimumCount = 0; @Option public boolean getOnly1CandidatePerGroup = false; } public static Options opts = new Options(); protected Params params; // Parameters of the model protected AdvancedWordVectorParams advancedWordVectorParams; protected IterativeTester iterativeTester; public boolean beVeryQuiet = false; // ============================================================ // Log // ============================================================ @Override public void logParam() { params.log(); if (advancedWordVectorParams != null) { advancedWordVectorParams.log(); } } @Override public void logFeatureWeights(Candidate candidate) { LogInfo.begin_track("Features: [sum = %s]", Fmt.D(getScore(candidate))); FeatureVector.logFeatureWeights("normal", candidate.getCombinedFeatures(), params); if (advancedWordVectorParams != null) { advancedWordVectorParams.logFeatureWeights(candidate); } LogInfo.end_track(); } @Override public void logFeatureDiff(Candidate trueCandidate, Candidate predCandidate) { double trueScore = getScore(trueCandidate), predScore = getScore(predCandidate); LogInfo.begin_track("(TRUE - PRED) Features: [sum = %s = %s - %s]", Fmt.D(trueScore - predScore), Fmt.D(trueScore), Fmt.D(predScore)); FeatureVector.logFeatureDiff("normal", trueCandidate.getCombinedFeatures(), predCandidate.getCombinedFeatures(), params); if (advancedWordVectorParams != null) { advancedWordVectorParams.logFeatureDiff(trueCandidate, predCandidate); } LogInfo.end_track(); } @Override public void shutUp() { beVeryQuiet = true; } // ============================================================ // Predict // ============================================================ /** * Return a list of (candidate, score) pairs sorted by score. */ @Override public List<Pair<Candidate, Double>> getRankedCandidates(Example example) { List<Pair<Candidate, Double>> answer = new ArrayList<>(); for (Candidate candidate : getCandidates(example)) { answer.add(new Pair<Candidate, Double>(candidate, getScore(candidate))); } Collections.sort(answer, new Pair.ReverseSecondComparator<Candidate, Double>()); return answer; } protected double getScore(Candidate candidate) { return getScore(candidate, AllFeatureMatcher.matcher); } protected double getScore(Candidate candidate, FeatureMatcher matcher) { double score = candidate.features.dotProduct(params, matcher); score += candidate.group.features.dotProduct(params, matcher); if (advancedWordVectorParams != null) { score += advancedWordVectorParams.getScore(candidate); } return score; } // ============================================================ // Learn // ============================================================ @Override public void setIterativeTester(IterativeTester tester) { iterativeTester = tester; } @Override public void learn(Dataset dataset, FeatureMatcher additionalFeatureMatcher) { dataset.cacheRewards(); // Select features based on count FeatureMatcher featureMatcher; if (opts.featureMinimumCount > 0) { FeatureCountPruner pruner = new FeatureCountPruner(beVeryQuiet); LogInfo.begin_track("Removing features with count < %d ...", opts.featureMinimumCount); for (Example example : dataset.trainExamples) pruner.add(example); pruner.applyThreshold(opts.featureMinimumCount); LogInfo.end_track(); featureMatcher = pruner; } else { featureMatcher = AllFeatureMatcher.matcher; } // Additional feature filter (for ablation) if (additionalFeatureMatcher != null) featureMatcher = additionalFeatureMatcher; // Learn parameters stochasticGradientDescent(dataset.trainExamples, featureMatcher); // Prune features will small weights params.prune(opts.pruneSmallFeaturesThreshold); if (!beVeryQuiet) params.write(Execution.getFile("params")); } protected int trainIter; // 1, 2, ..., opts.numTrainIters /** * Perform stochastic gradient descent to learn the parameters using the maximum entropy model. */ protected void stochasticGradientDescent(Collection<Example> examples, FeatureMatcher featureMatcher) { params = new Params(); advancedWordVectorParams = FeatureType.usingAdvancedWordVectorFeature() ? AdvancedWordVectorParams.create() : null; for (trainIter = 1; trainIter <= opts.numTrainIters; trainIter++) { if (!beVeryQuiet) { LogInfo.begin_track("Iteration %d/%d", trainIter, opts.numTrainIters); Execution.putOutput("currIter", trainIter); } for (Example example : examples) { if (!beVeryQuiet) Execution.putOutput("currExample", example.displayId); boolean updated = gradientUpdate(getCandidates(example), featureMatcher); if (!updated) { if (!beVeryQuiet) LogInfo.logs("Skip %s ...", example); } else { if (!beVeryQuiet) LogInfo.logs("Computed gradient for example %s ...", example); performL1Regularization(opts.lambda / examples.size()); } } if (iterativeTester != null) { iterativeTester.message = "Iteration " + trainIter + "/" + opts.numTrainIters; iterativeTester.run(); } if (!beVeryQuiet) LogInfo.end_track(); } // Summarize if (iterativeTester != null && !beVeryQuiet) { iterativeTester.summarize(); } } protected List<Candidate> getCandidates(Example example) { if (!opts.getOnly1CandidatePerGroup) { return example.candidates; } else { List<Candidate> candidates = new ArrayList<>(); for (CandidateGroup group : example.candidateGroups) { if (group.numCandidate() > 0) { candidates.add(group.getCandidates().get(0)); } } return candidates; } } protected void performL1Regularization(double cutoff) { if (cutoff <= 0) return; params.applyL1Regularization(cutoff); if (advancedWordVectorParams != null) advancedWordVectorParams.applyL1Regularization(cutoff); } /** * Compute the gradient and update the parameters. * If there are no good candidates, do not update the parameters and return false. * Otherwise, return true. * * If the score function is g(x,y,params), then the gradient is * sum_y [ { gradient of g(x,y,params) } * expectationDiff(x,y,params) ] * where expectationDiff(x,y,params) is * normalized (exp[g(x,y,params)]*R(y)) - normalized (exp[g(x,y,params)]) */ protected boolean gradientUpdate(List<Candidate> candidates, FeatureMatcher featureMatcher) { double[] expectationDiff = computeExpectationDiff(candidates, featureMatcher); if (expectationDiff == null) { // No good candidate -- skip example return false; } // Compute the gradient Map<String, Double> gradient = new HashMap<>(); for (int i = 0; i < expectationDiff.length; i++) { Candidate candidate = candidates.get(i); candidate.group.features.increment(expectationDiff[i], gradient, featureMatcher); candidate.features.increment(expectationDiff[i], gradient, featureMatcher); } // Regularization if (opts.beta != 0) { for (String featureName : gradient.keySet()) { MapUtils.incr(gradient, featureName, (- opts.beta) * params.getWeight(featureName)); } } // Perform gradient updates params.update(gradient); if (advancedWordVectorParams != null) { // Compute the gradient AdvancedWordVectorGradient advGradient = advancedWordVectorParams.createGradient(); for (int i = 0; i < expectationDiff.length; i++) { Candidate candidate = candidates.get(i); advGradient.addToGradient(candidate, expectationDiff[i]); } // Regularization if (opts.beta != 0) { advGradient.addL2Regularization(opts.beta); } advancedWordVectorParams.update(advGradient); } return true; } /** * For each candidate i, compute * expectationDiff[i] = normalized (exp[g(x,y[i],params)]*R(y)) - normalized (exp[g(x,y[i],params)]) * * Return null if any denominator for normalization is 0. */ protected double[] computeExpectationDiff(List<Candidate> candidates, FeatureMatcher featureMatcher) { int n = candidates.size(); double expScore[] = new double[n], expScoreTimesReward[] = new double[n]; for (int i = 0; i < n; i++) { Candidate candidate = candidates.get(i); double prediction = getScore(candidate, featureMatcher); expScore[i] = prediction; expScoreTimesReward[i] = expScore[i] + Math.log(candidate.getReward()); } // Exponentiate and normalize. // Skip this example if there are no good candidates. if (!NumUtils.expNormalize(expScore)) return null; if (!NumUtils.expNormalize(expScoreTimesReward)) return null; // Sum up the expectations double[] expectationDiff = new double[n]; for (int i = 0; i < n; i++) { expectationDiff[i] = expScoreTimesReward[i] - expScore[i]; } return expectationDiff; } // ============================================================ // Persistence // ============================================================ @Override public void saveModel(String path) { params.write(path); } @Override public void loadModel(String path) { params = new Params(); params.read(path); } }