package edu.berkeley.nlp.PCFGLA; import java.io.Serializable; import java.util.Map; import java.util.Set; import edu.berkeley.nlp.util.ArrayUtil; import edu.berkeley.nlp.util.Counter; import edu.berkeley.nlp.util.MapFactory; import edu.berkeley.nlp.util.MapFactory.HashMapFactory; public class BinaryCounterTable implements Serializable { /** * Based on Counter. * * A map from objects to doubles. Includes convenience methods for getting, * setting, and incrementing element counts. Objects not in the counter will * return a count of zero. The counter is backed by a HashMap (unless specified * otherwise with the MapFactory constructor). * * @author Slav Petrov */ private static final long serialVersionUID = 1L; Map<BinaryRule, double[][][]> entries; short[] numSubStates; BinaryRule searchKey; /** * The elements in the counter. * * @return set of keys */ public Set<BinaryRule> keySet() { return entries.keySet(); } /** * The number of entries in the counter (not the total count -- use totalCount() instead). */ public int size() { return entries.size(); } /** * True if there are no entries in the counter (false does not mean totalCount > 0) */ public boolean isEmpty() { return size() == 0; } /** * Returns whether the counter contains the given key. Note that this is the * way to distinguish keys which are in the counter with count zero, and those * which are not in the counter (and will therefore return count zero from * getCount(). * * @param key * @return whether the counter contains the key */ public boolean containsKey(BinaryRule key) { return entries.containsKey(key); } /** * Get the count of the element, or zero if the element is not in the * counter. Can return null! * * @param key * @return */ public double[][][] getCount(BinaryRule key) { double[][][] value = entries.get(key); return value; } public double[][][] getCount(short pState, short lState, short rState) { searchKey.setNodes(pState,lState,rState); double[][][] value = entries.get(searchKey); return value; } /** * Set the count for the given key, clobbering any previous count. * * @param key * @param count */ public void setCount(BinaryRule key, double[][][] counts) { entries.put(key, counts); } /** * Increment a key's count by the given amount. * Assumes for efficiency that the arrays have the same size. * * @param key * @param increment */ public void incrementCount(BinaryRule key, double[][][] increment) { double[][][] current = getCount(key); if (current==null) { setCount(key,increment); return; } for (int i=0; i<current.length; i++){ for (int j=0; j<current[i].length; j++){ // test if increment[i][j] is null or zero, in which case // we needn't add it if (increment[i][j]==null) continue; // allocate more space as needed if (current[i][j]==null) current[i][j] = new double[increment[i][j].length]; // if we've gotten here, then both current and increment // have correct arrays in index i for (int k=0; k<current[i][j].length; k++){ current[i][j][k]+=increment[i][j][k]; } } } setCount(key, current); } public void incrementCount(BinaryRule key, double increment) { double[][][] current = getCount(key); if (current == null){ double[][][] tmp = key.getScores2(); current = new double[tmp.length][tmp[0].length][tmp[0][0].length]; ArrayUtil.fill(current,increment); setCount(key, current); return; } for (int i=0; i<current.length; i++){ for (int j=0; j<current[i].length; j++){ if (current[i][j]==null) current[i][j] = new double[numSubStates[key.getParentState()]]; for (int k=0; k<current[i][j].length; k++){ current[i][j][k]+=increment; } } } setCount(key, current); } public BinaryCounterTable(short[] numSubStates) { this(new MapFactory.HashMapFactory<BinaryRule, double[][][]>(), numSubStates); } public BinaryCounterTable(MapFactory<BinaryRule, double[][][]> mf, short[] numSubStates) { entries = mf.buildMap(); searchKey = new BinaryRule((short)0,(short)0,(short)0); this.numSubStates = numSubStates; } public static void main(String[] args) { Counter<String> counter = new Counter<String>(); System.out.println(counter); counter.incrementCount("planets", 7); System.out.println(counter); counter.incrementCount("planets", 1); System.out.println(counter); counter.setCount("suns", 1); System.out.println(counter); counter.setCount("aliens", 0); System.out.println(counter); System.out.println(counter.toString(2)); System.out.println("Total: " + counter.totalCount()); } }