package org.wikibrain.sr.normalize; import gnu.trove.list.array.TDoubleArrayList; import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics; import org.wikibrain.sr.SRResultList; import java.io.Serializable; import java.util.Random; /** * A class that supports various kinds of normalization. * Usage: * 1. Create the normalizer. * 2. Call observe with each observation. * 3. Call finalize. * 4. Call normalize() on a new datapoint. * Make sure to set the missingScore value for the SRResultList version. */ public abstract class BaseNormalizer implements Serializable, Normalizer { public static final long serialVersionUID = 4305858822325261880L; public final static int SAMPLE_SIZE = 50000; public double min = Double.MIN_VALUE; protected double max = -Double.MAX_VALUE; // After calling finalize, stats will be non-null. protected TDoubleArrayList sample = new TDoubleArrayList(); protected DescriptiveStatistics stats; protected Integer numObservations = 0; protected Random random = new Random(); // mean actual similarity for scores that are missing or infinite. protected double missingMean = Double.NaN; // accumulators for missing values private double missingSum = 0.0; private int missingCount = 0; private boolean trained = false; private int sampleSize = SAMPLE_SIZE; /** * To meet the serializable contract. */ protected BaseNormalizer() {} public void setSampleSize(int sampleSize) { this.sampleSize = sampleSize; } @Override public void reset() { min = Double.MIN_VALUE; max = -Double.MAX_VALUE; sample.clear(); stats = null; numObservations = 0; missingMean = Double.NaN; missingSum = 0.0; missingCount = 0; trained = false; } @Override public void observe(SRResultList sims, int rank, double y) { if (rank >= 0) { observe(sims.get(rank).getScore(), y); } else { observe(Double.NaN, y); } } @Override public void observe(double x, double y){ if (Double.isNaN(x) || Double.isInfinite(x)) { synchronized (this) { missingSum += y; missingCount++; } } observe(x); } @Override public void observe(double x) { synchronized (this) { if (!Double.isNaN(x) && !Double.isInfinite(x)) { if (x < min) { min = x; } if (x > max) { max = x; } if (sample.size() < sampleSize) { sample.add(x); } else if (random.nextDouble() < 1.0 * sample.size() / (numObservations + 1)) { sample.set(random.nextInt(sample.size()), x); } } numObservations++; } } @Override public void observationsFinished() { sample.sort(); stats = new DescriptiveStatistics(sample.toArray()); if (missingCount > 0) { missingMean = missingSum / missingCount; missingSum = 0.0; missingCount = 0; } trained = true; } /** * A basic implementation of normalize. * @param list */ @Override public SRResultList normalize(SRResultList list) { SRResultList dsl = new SRResultList(list.numDocs()); list.setMissingScore(missingMean); for (int i = 0; i < list.numDocs(); i++) { dsl.set(i, list.getId(i), normalize(list.getScore(i))); } return dsl; } public String toString() { return "min=" + min + ", max=" + max; } public double getMin() { return min; } public double getMax() { return max; } public abstract String dump(); public boolean isTrained() { return trained; } }