package edu.berkeley.nlp.util; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Random; /** * Finds the minimum k of a set of observations * * @author denero */ public class Beam<T> { private int size; private double kstat = Double.POSITIVE_INFINITY; // kth order statistic private List<T> kbest; private double[] kbestValues; public Beam(int size) { this.size = size; kbest = new ArrayList<T>(size + 1); kbestValues = new double[size]; Arrays.fill(kbestValues, Double.POSITIVE_INFINITY); } public void observe(T t, double val) { if (val < kstat || (val <= kstat && kbest.size() < size)) { // something falls off the beam if (kbest.size() == size) kbest.remove(kbest.size() - 1); int index = Arrays.binarySearch(kbestValues, val); int pos = (index < 0) ? -1 * index - 1 : index; kbest.add(pos, t); kbestValues[kbest.size() - 1] = val; Arrays.sort(kbestValues); // This step might be a little slow kstat = kbestValues[size - 1]; } } public int getSize() { return size; } public double beamCutoff() { return kstat; } public void setSize(int size) { this.size = size; } public List<T> contents() { return kbest; } public double[] getKbestValues() { return kbestValues; } public int size() { return kbest.size(); } public T argMin() { if (size() == 0) return null; return kbest.get(0); } public static void main(String[] args) { Beam<String> bs = new Beam<String>(3); bs.observe("what1", 1); bs.observe("what2", 4); bs.observe("what3", 0); bs.observe("what4", 2); bs.observe("what5", 3); bs.observe("what6", 1); System.out.println(bs.contents()); int n = 10000; Beam<Double> bsd = new Beam<Double>(n / 10); List<Double> l = new ArrayList<Double>(n); Random random = new Random(); for (int i = 0; i < n; i++) { double r = random.nextDouble(); bsd.observe(r, r); l.add(r); } Collections.sort(l); for (int i = 0; i < n; i++) { if (i < bsd.kbest.size()) { System.out.println("Same?:\t"+ bsd.kbest.get(i) + "\t" + l.get(i)); } } } }