package edu.stanford.nlp.ie.crf; import edu.stanford.nlp.util.logging.Redwood; import java.io.Serializable; import java.util.HashSet; import java.util.Map; import java.util.Set; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.stats.Counters; import edu.stanford.nlp.util.Generics; import edu.stanford.nlp.util.HashIndex; import edu.stanford.nlp.util.Index; /** * Constrains test-time inference to labels observed in training. * * @author Spence Green * */ public class LabelDictionary implements Serializable { /** A logger for this class */ private static Redwood.RedwoodChannels log = Redwood.channels(LabelDictionary.class); private static final long serialVersionUID = 6790400453922524056L; private final boolean DEBUG = false; /** * Initial capacity of the bookkeeping data structures. */ private final int DEFAULT_CAPACITY = 30000; // Bookkeeping private Counter<String> observationCounts; private Map<String,Set<String>> observedLabels; // Final data structure private Index<String> observationIndex; private int[][] labelDictionary; /** * Constructor. */ public LabelDictionary() { this.observationCounts = new ClassicCounter<>(DEFAULT_CAPACITY); this.observedLabels = Generics.newHashMap(DEFAULT_CAPACITY); } /** * Increment counts for an observation/label pair. * * @param observation * @param label */ public void increment(String observation, String label) { if (labelDictionary != null) { throw new RuntimeException("Label dictionary is already locked."); } observationCounts.incrementCount(observation); if ( ! observedLabels.containsKey(observation)) { observedLabels.put(observation, new HashSet<>()); } observedLabels.get(observation).add(label.intern()); } /** * True if this observation is constrained, and false otherwise. */ public boolean isConstrained(String observation) { return observationIndex.indexOf(observation) >= 0; } /** * Get the allowed label set for an observation. * * @param observation * @return The allowed label set, or null if the observation is unconstrained. */ public int[] getConstrainedSet(String observation) { int i = observationIndex.indexOf(observation); return i >= 0 ? labelDictionary[i] : null; } /** * Setup the constrained label sets and free bookkeeping resources. * * @param threshold * @param labelIndex */ public void lock(int threshold, Index<String> labelIndex) { if (labelDictionary != null) throw new RuntimeException("Label dictionary is already locked"); log.info("Label dictionary enabled"); System.err.printf("#observations: %d%n", (int) observationCounts.totalCount()); Counters.retainAbove(observationCounts, threshold); Set<String> constrainedObservations = observationCounts.keySet(); labelDictionary = new int[constrainedObservations.size()][]; observationIndex = new HashIndex<>(constrainedObservations.size()); for (String observation : constrainedObservations) { int i = observationIndex.addToIndex(observation); assert i < labelDictionary.length; Set<String> allowedLabels = observedLabels.get(observation); labelDictionary[i] = new int[allowedLabels.size()]; int j = 0; for (String label : allowedLabels) { labelDictionary[i][j++] = labelIndex.indexOf(label); } if (DEBUG) { System.err.printf("%s : %s%n", observation, allowedLabels.toString()); } } observationIndex.lock(); System.err.printf("#constraints: %d%n", labelDictionary.length); // Free bookkeeping data structures observationCounts = null; observedLabels = null; } }