package org.wikibrain.sr.normalize;
import com.typesafe.config.Config;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.sr.SRResultList;
import java.text.DecimalFormat;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Normalizer that fits an ordinary-least-squares model predicting the observed
 * target value from a result's position in its list and its raw score:
 *
 * <pre>
 *   y ~ intercept + rankCoeff * log(1 + rank) + scoreCoeff * f(score)
 * </pre>
 *
 * where {@code rank} is the zero-based index of the result in its list and
 * {@code f} is either the identity or {@code log(1 + score - min)} when the
 * log transform is enabled via {@link #setLogTransform(boolean)}.
 *
 * <p>Training observations are buffered in transient accumulators. Access to
 * those accumulators is guarded by this instance's monitor.
 */
public class RankAndScoreNormalizer extends BaseNormalizer {
    private static final Logger LOG = LoggerFactory.getLogger(RankAndScoreNormalizer.class);

    // Fitted regression parameters; assigned in observationsFinished().
    private double intercept;
    private double rankCoeff;
    private double scoreCoeff;

    // When true, scores are passed through log(1 + score - min) before use.
    private boolean logTransform = false;

    // Temporary accumulators that feed into the regression. They are transient,
    // so they can be null on an instance that was restored without running the
    // field initializers; always go through ensureAccumulators() before use.
    private transient TIntArrayList ranks = new TIntArrayList();
    private transient TDoubleArrayList scores = new TDoubleArrayList();
    private transient TDoubleArrayList ys = new TDoubleArrayList();

    /**
     * Recreates any accumulator that is currently null.
     * Callers must hold this instance's monitor.
     */
    private void ensureAccumulators() {
        if (ranks == null) {
            ranks = new TIntArrayList();
        }
        if (scores == null) {
            scores = new TDoubleArrayList();
        }
        if (ys == null) {
            ys = new TDoubleArrayList();
        }
    }

    /**
     * Discards all buffered training observations.
     *
     * NOTE(review): this does not delegate to a superclass reset; BaseNormalizer
     * is not visible here, so confirm whether its state also needs clearing.
     */
    @Override
    public void reset() {
        synchronized (this) {
            ensureAccumulators();
            ranks.clear();
            scores.clear();
            ys.clear();
        }
    }

    /**
     * Records one training observation and then forwards it to the superclass.
     * The observation is buffered for the regression only when {@code index}
     * is non-negative and the score at that index is finite.
     *
     * @param list  result list the observation refers to
     * @param index zero-based position in {@code list}; negative values are
     *              not buffered here but are still passed to the superclass
     * @param y     target value associated with the observation
     */
    @Override
    public void observe(SRResultList list, int index, double y) {
        if (index >= 0) {
            double score = list.getScore(index);
            if (!Double.isNaN(score) && !Double.isInfinite(score)) {
                // Lock on the instance rather than on a transient list, which
                // may be null and would make the lock itself throw.
                synchronized (this) {
                    ensureAccumulators();
                    ranks.add(index);
                    scores.add(score);
                    ys.add(y);
                }
            }
        }
        super.observe(list, index, y);
    }

    /**
     * Enables or disables the logarithmic transform applied to scores both
     * when fitting the model and when normalizing.
     *
     * @param logTransform true to use {@code log(1 + score - min)} in place of the raw score
     */
    public void setLogTransform(boolean logTransform) {
        this.logTransform = logTransform;
    }

    /**
     * Fits the regression on the buffered observations and stores the
     * resulting intercept and coefficients. Any exception raised by the
     * regression library for unusable sample data propagates to the caller.
     */
    @Override
    public void observationsFinished() {
        final double[] targets;
        final int[] observedRanks;
        final double[] observedScores;
        // Take a consistent snapshot so concurrent observe() calls cannot
        // leave the three parallel lists at different lengths mid-read.
        synchronized (this) {
            ensureAccumulators();
            targets = ys.toArray();
            observedRanks = ranks.toArray();
            observedScores = scores.toArray();
        }

        double[][] features = new double[targets.length][2];
        for (int i = 0; i < targets.length; i++) {
            features[i][0] = Math.log(1 + observedRanks[i]);
            features[i][1] = logIfNecessary(observedScores[i]);
        }

        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        regression.newSampleData(targets, features);
        double[] params = regression.estimateRegressionParameters();
        intercept = params[0];
        rankCoeff = params[1];
        scoreCoeff = params[2];

        super.observationsFinished();
        LOG.info("trained model on {} observations: {} with R-squared {}",
                features.length, dump(), regression.calculateRSquared());
    }

    /**
     * Produces a new result list with the same ids as {@code list}, where each
     * score is replaced by the fitted model's prediction for that position and
     * score. The missing score of the returned list is set to {@code missingMean}.
     *
     * @param list results to normalize; not modified by this method
     * @return a new list holding the predicted values
     */
    @Override
    public SRResultList normalize(SRResultList list) {
        SRResultList normalized = new SRResultList(list.numDocs());
        normalized.setMissingScore(missingMean);
        for (int i = 0; i < list.numDocs(); i++) {
            double transformed = logIfNecessary(list.getScore(i));
            double predicted = intercept
                    + rankCoeff * Math.log(i + 1)
                    + scoreCoeff * transformed;
            normalized.set(i, list.getId(i), predicted);
        }
        return normalized;
    }

    /**
     * Applies the optional log transform to a raw score.
     * {@code min} is not declared in this class (presumably maintained by
     * BaseNormalizer - confirm).
     */
    private double logIfNecessary(double x) {
        return logTransform ? Math.log(1 + x - min) : x;
    }

    /**
     * Not supported: the model needs a result's rank as well as its score.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double normalize(double x) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the fitted model as a human-readable formula with coefficients
     * rounded to at most three decimal places.
     */
    @Override
    public String dump() {
        DecimalFormat df = new DecimalFormat("#.###");
        // Name the score term after the transform that was actually applied so
        // the formula is not misleading when the log transform is enabled.
        String scoreTerm = logTransform ? "*log(1+score-min) + " : "*score + ";
        return (
                df.format(rankCoeff) + "*log(1+rank) + " +
                df.format(scoreCoeff) + scoreTerm +
                df.format(intercept)
        );
    }

    @Override
    public String toString() {
        return "Rank and score normalizer: " + dump();
    }

    /**
     * Configuration provider that creates a {@link RankAndScoreNormalizer}
     * for configuration entries whose {@code type} is {@code "rank"}.
     */
    public static class Provider extends org.wikibrain.conf.Provider<RankAndScoreNormalizer> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class getType() {
            return Normalizer.class;
        }

        @Override
        public String getPath() {
            return "sr.normalizer";
        }

        @Override
        public Scope getScope() {
            return Scope.INSTANCE;
        }

        /**
         * @return a new normalizer when the entry's {@code type} is
         *         {@code "rank"}; {@code null} otherwise
         */
        @Override
        public RankAndScoreNormalizer get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!"rank".equals(config.getString("type"))) {
                return null;
            }
            return new RankAndScoreNormalizer();
        }
    }
}