package quickml.supervised.regressionModel.IsotonicRegression;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.supervised.regressionModel.SingleVariableRealValuedFunction;
import java.io.IOException;
import java.io.Serializable;
import java.util.*;
// TODO: This should be split out into a separate Builder rather than using the constructor,
// so that it follows the same pattern as other PredictiveModels
/**
 * A monotone calibration function of a single real-valued input, fitted with the
 * pool-adjacent-violators (PAV) algorithm.
 *
 * <p>Construction sorts the supplied {@link Observation}s by input, optionally pre-smooths them
 * by merging neighbours until each carries at least {@code minWeight}, and then pools adjacent
 * observations whose outputs are out of order. {@link #predict(Double)} interpolates between the
 * two calibration points that bracket the input and, depending on the
 * {@link #interpolateThroughOrigin(boolean)} and {@link #extropolateOffUpperEnd(boolean)}
 * settings, extrapolates beyond either end of the calibration set.
 */
public class PoolAdjacentViolatorsModel implements SingleVariableRealValuedFunction {
    private static final Logger logger = LoggerFactory.getLogger(PoolAdjacentViolatorsModel.class);
    private static final long serialVersionUID = 4389814244047503245L;

    /** Source of the tie-breaking seed given to every {@link Observation}. */
    private static final Random rand = new Random();

    /** Number of points in {@link #calibrationSet}, refreshed whenever this class mutates the set. */
    private int size;

    /** The input-ordered (and, if requested, pre-smoothed) observations handed to the PAV step. */
    ArrayList<Observation> calibrationList = Lists.newArrayList();

    /** The pooled calibration points that {@link #predict(Double)} works from. */
    TreeSet<Observation> calibrationSet = Sets.newTreeSet();

    /** Snapshot of {@link #calibrationList} taken before pooling. */
    TreeSet<Observation> preSmoothingSet = Sets.newTreeSet();

    /** When {@code true}, pooling enforces non-increasing instead of non-decreasing outputs. */
    boolean reversed = false;

    private boolean interpolateThroughOrigin = true;
    private boolean extropolateOffUpperEnd = true;

    /**
     * Builds a model without pre-smoothing (equivalent to a {@code minWeight} of 1).
     *
     * @param predictions The input to the calibration function
     */
    public PoolAdjacentViolatorsModel(final Iterable<Observation> predictions) {
        this(predictions, 1);
    }

    /**
     * @param predictions The input to the calibration function
     * @param minWeight   The minimum weight of a point, used to pre-smooth the function
     * @throws NullPointerException     if {@code predictions} is null
     * @throws IllegalArgumentException if {@code minWeight} is less than 1
     */
    public PoolAdjacentViolatorsModel(final Iterable<Observation> predictions, int minWeight) {
        Preconditions.checkNotNull(predictions);
        Preconditions.checkArgument(minWeight >= 1, "minWeight %s must be >= 1", minWeight);
        TreeSet<Observation> orderedCalibrations = Sets.newTreeSet();
        Iterables.addAll(orderedCalibrations, predictions);
        if (minWeight > 1) {
            // Walk the observations in input order, folding each one into the pending point
            // until that point has accumulated at least minWeight.
            Observation pending = null;
            for (final Observation next : orderedCalibrations) {
                if (pending == null) {
                    pending = next;
                } else if (pending.weight < minWeight) {
                    pending = pending.mergeWith(next);
                } else {
                    calibrationList.add(pending);
                    pending = next;
                }
            }
            // The trailing point is kept even if it never reached minWeight.
            if (pending != null) {
                calibrationList.add(pending);
            }
        } else {
            calibrationList.addAll(orderedCalibrations);
        }
        preSmoothingSet.addAll(calibrationList);
        calibrationSet = createCalibrationSet(calibrationList);
        this.size = calibrationSet.size();
    }

    /**
     * Runs the pool-adjacent-violators pass over an input-ordered list of observations.
     *
     * <p>Each observation is compared with the most recently pooled block; while its output
     * violates the required ordering (decreasing when {@code reversed} is false, increasing when
     * it is true) the two are merged with {@link Observation#mergeWith(Observation)} and the
     * merged block is re-checked against the block before it.
     *
     * <p>The argument is not modified. An empty list yields an empty set.
     *
     * @param inputOrderedList observations sorted by input
     * @return a new set holding the pooled blocks
     */
    public TreeSet<Observation> createCalibrationSet(List<Observation> inputOrderedList) {
        // Blocks pooled so far, in input order; the last element is the top of the stack.
        final List<Observation> pooled = new ArrayList<Observation>(inputOrderedList.size());
        for (final Observation observation : inputOrderedList) {
            Observation current = observation;
            while (!pooled.isEmpty() && violatesOrder(pooled.get(pooled.size() - 1), current)) {
                current = pooled.remove(pooled.size() - 1).mergeWith(current);
            }
            pooled.add(current);
        }
        TreeSet<Observation> localCalibrationSet = Sets.newTreeSet();
        localCalibrationSet.addAll(pooled);
        return localCalibrationSet;
    }

    /** Returns true when {@code current}'s output breaks the required ordering relative to {@code preceding}. */
    private boolean violatesOrder(final Observation preceding, final Observation current) {
        return reversed
                ? current.output > preceding.output
                : current.output < preceding.output;
    }

    /**
     * Chooses how inputs below the first calibration point are handled: when {@code true} (the
     * default) they are interpolated against a synthetic point at (0, 0); when {@code false} the
     * line through the first two calibration points is extended downwards instead.
     *
     * @return this model, for chaining
     */
    public PoolAdjacentViolatorsModel interpolateThroughOrigin(boolean interpolateThroughOrigin) {
        this.interpolateThroughOrigin = interpolateThroughOrigin;
        return this;
    }

    /**
     * Chooses how inputs above the last calibration point are handled: when {@code true} (the
     * default) the line through the last two calibration points is extended upwards; when
     * {@code false} the output of the last calibration point is returned.
     *
     * @return this model, for chaining
     */
    public PoolAdjacentViolatorsModel extropolateOffUpperEnd(boolean extropolateOffUpperEnd) {
        this.extropolateOffUpperEnd = extropolateOffUpperEnd;
        return this;
    }

    /** Returns the live (not copied) set of pooled calibration points. */
    public TreeSet<Observation> getCalibrationSet() {
        return calibrationSet;
    }

    /** Returns the live (not copied) set of observations as they were before pooling. */
    public TreeSet<Observation> getPreSmoothingSet() {
        return preSmoothingSet;
    }

    /** Removes calibration points from the low end for as long as the lowest one has an output of exactly 0. */
    public void stripZeroOutputs() {
        while (!calibrationSet.isEmpty() && calibrationSet.first().output == 0) {
            calibrationSet.pollFirst();
        }
        this.size = calibrationSet.size();
    }

    /**
     * Inserts an observation directly into the calibration set. The pooling pass is not re-run,
     * so the caller is responsible for the point not breaking monotonicity.
     */
    public void addObservation(Observation observation) {
        calibrationSet.add(observation);
        // Keep size() in step with the set, as stripZeroOutputs() does.
        this.size = calibrationSet.size();
    }

    /**
     * Reports whether {@code input} has no calibration point on at least one side of it, i.e.
     * whether a prediction for it cannot be bracketed by two points of the calibration set.
     */
    public boolean willExtrapolate(double input) {
        final Observation probe = new Observation(input, 0);
        return calibrationSet.floor(probe) == null || calibrationSet.ceiling(probe) == null;
    }

    @Override
    public Double predictWithoutAttributes(Double input, Set<String> attributesToIgnore) {
        //there is only one attribute in predict, therefore it does not make sense to drop anything
        return predict(input);
    }

    /**
     * Predicts only inside the range of calibrated inputs; an input outside
     * {@code [first.input, last.input]} is returned unchanged.
     *
     * @throws NoSuchElementException if the calibration set is empty
     */
    public double predictIfInterpolation(double input) {
        final boolean outsideRange =
                input < calibrationSet.first().input || input > calibrationSet.last().input;
        return outsideRange ? input : predict(input);
    }

    /**
     * Maps {@code input} through the calibration function.
     *
     * <ul>
     * <li>Below the first point: extends the line through the first two points (clamped at 0)
     *     when interpolation through the origin is disabled and a second point exists; otherwise
     *     interpolates against a synthetic (0, 0) point carrying the first point's weight.</li>
     * <li>Above the last point: extends the line through the last two points when upper
     *     extrapolation is enabled and a second-to-last point exists; otherwise returns the last
     *     point's output.</li>
     * <li>Exactly on a bracketing point: returns that point's output.</li>
     * <li>Otherwise: a distance- and weight-weighted average of the two bracketing outputs. If
     *     that average is NaN or infinite, {@code input} itself is returned.</li>
     * </ul>
     *
     * @throws IllegalStateException if the calibration set is empty
     */
    @Override
    public Double predict(Double input) {
        Preconditions.checkState(!calibrationSet.isEmpty());
        final Observation toCorrect = new Observation(input, 0);
        final Observation first = calibrationSet.first();
        final Observation last = calibrationSet.last();

        Observation floor = calibrationSet.floor(toCorrect);
        if (floor == null) {
            final Observation second = calibrationSet.higher(first);
            if (!interpolateThroughOrigin && second != null) {
                double slopeOffEnd = (second.output - first.output) / (second.input - first.input);
                double inputDistanceFromFirst = input - first.input;
                return Math.max(0, first.output + slopeOffEnd * inputDistanceFromFirst);
            }
            floor = new Observation(0, 0, first.weight);
        }

        final Observation ceiling = calibrationSet.ceiling(toCorrect);
        if (ceiling == null) {
            // Nothing at or above the input, so floor is the last calibration point.
            final Observation secondToLast = calibrationSet.lower(last);
            if (!extropolateOffUpperEnd || secondToLast == null) {
                return floor.output;
            }
            double slopeOffEnd = (last.output - secondToLast.output) / (last.input - secondToLast.input);
            double inputDistanceFromLast = input - last.input;
            return last.output + slopeOffEnd * inputDistanceFromLast;
        }

        // Because ties on input are ordered by a random seed, a calibration point with exactly
        // this input may turn up as either the floor or the ceiling.
        final boolean inputOnCeiling = input.equals(ceiling.input);
        if (inputOnCeiling || input.equals(floor.input)) {
            return inputOnCeiling ? ceiling.output : floor.output;
        }
        if (ceiling.input == floor.input) {
            return floor.output;
        }

        double floorWeight = (ceiling.input - input) * floor.weight;
        double ceilingWeight = (input - floor.input) * ceiling.weight;
        double corrected = (floor.output * floorWeight + ceiling.output * ceilingWeight)
                / (floorWeight + ceilingWeight);
        if (Double.isInfinite(corrected) || Double.isNaN(corrected)) {
            logger.info("corrected is NaN or inf");
            return input;
        }
        return corrected;
    }

    /**
     * Approximates the input whose prediction is {@code output} by 16 rounds of bisection over
     * the range of calibrated inputs.
     *
     * @return the midpoint of the final bracket
     * @throws NoSuchElementException if the calibration set is empty
     */
    public double reverse(final double output) {
        double lowerBound = calibrationSet.first().input;
        double upperBound = calibrationSet.last().input;
        for (int iteration = 0; iteration < 16; iteration++) {
            final double midpoint = (lowerBound + upperBound) / 2.0;
            if (predict(midpoint) < output) {
                lowerBound = midpoint;
            } else {
                upperBound = midpoint;
            }
        }
        return (lowerBound + upperBound) / 2.0;
    }

    /**
     * Writes every calibration point, one per line, to {@code ps}.
     *
     * @throws RuntimeException wrapping any {@link IOException} raised by {@code ps}
     */
    public void dump(final Appendable ps) {
        for (final Observation p : calibrationSet) {
            try {
                ps.append(String.valueOf(p)).append('\n');
            } catch (final IOException e) {
                throw new RuntimeException(e);
            }
        }
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        dump(sb);
        return sb.toString();
    }

    /** Returns the number of calibration points as of the last mutation made through this class. */
    public int size() {
        return size;
    }

    /**
     * Returns the lowest calibration point whose input is {@code >= 0.0}, or {@code null} if
     * there is none.
     */
    public Observation minNonZeroObservation() {
        // NOTE(review): the name suggests a strict "> 0" test, but a point with input exactly 0
        // is accepted here; left unchanged because callers may rely on it - confirm intent.
        for (final Observation observation : calibrationSet) {
            if (observation.input >= 0.0) {
                return observation;
            }
        }
        return null;
    }

    /**
     * Returns the calibration point with the highest input.
     *
     * @throws NoSuchElementException if the calibration set is empty
     */
    public Observation maxObservation() {
        return calibrationSet.last();
    }

    /**
     * An immutable weighted (input, output) sample.
     *
     * <p>Observations are ordered by input, with ties broken by a random per-instance seed so
     * that several observations sharing an input can coexist in a {@link TreeSet}. Equality and
     * hashing follow that same ordering.
     */
    public static final class Observation implements Comparable<Observation>, Serializable {
        private static final long serialVersionUID = -5472613396250257288L;
        public final double input;
        public final double output;
        private final int seed;
        public final double weight;

        /**
         * Creates an observation with a weight of 0.
         *
         * <p>When such an observation is merged with one whose weight is non-zero,
         * {@link #mergeWith(Observation)} returns the weighted observation unchanged; when the
         * two weights sum to zero the result is a weightless observation at the plain midpoint
         * of both.
         *
         * @param input  the x coordinate
         * @param output the y coordinate
         * @return a new zero-weight observation
         */
        public static Observation newWeightless(final double input, final double output) {
            return new Observation(input, output, 0);
        }

        /** Creates an observation with a weight of 1. */
        public Observation(final double input, final double output) {
            this(input, output, 1);
        }

        /**
         * @throws IllegalStateException if input, output and weight are all NaN
         */
        public Observation(final double input, final double output, final double weight) {
            // NOTE(review): this only rejects the case where ALL three values are NaN. Rejecting
            // any single NaN looks like the intent, but predict()/willExtrapolate() build probe
            // observations from caller input, so tightening it would turn predict(NaN) into an
            // exception - confirm before changing.
            Preconditions.checkState(!(Double.isNaN(input) && Double.isNaN(output) && Double.isNaN(weight)));
            this.input = input;
            this.output = output;
            seed = PoolAdjacentViolatorsModel.rand.nextInt();
            this.weight = weight;
        }

        @Override
        public int compareTo(final Observation o) {
            final int byInput = Double.compare(input, o.input);
            if (byInput != 0) {
                return byInput;
            }
            // Explicit comparison rather than subtraction, which could overflow.
            return seed < o.seed ? -1 : (seed == o.seed ? 0 : 1);
        }

        /** Two observations are equal when {@link #compareTo(Observation)} ranks them the same. */
        @Override
        public boolean equals(final Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof Observation)) {
                return false;
            }
            return compareTo((Observation) o) == 0;
        }

        @Override
        public int hashCode() {
            // Equal observations always share a seed, so the seed alone is a valid hash.
            return seed;
        }

        /**
         * Combines this observation with {@code other}.
         *
         * @return a weightless midpoint when the weights sum to zero; whichever observation has
         *         non-zero weight when exactly one weight is zero; otherwise a new observation at
         *         the weight-averaged input and output carrying the summed weight
         */
        public Observation mergeWith(final Observation other) {
            final double totalWeight = weight + other.weight;
            if (totalWeight == 0) {
                return Observation.newWeightless((input + other.input) / 2.0, (output + other.output) / 2.0);
            }
            if (other.weight == 0) {
                return this;
            }
            if (weight == 0) {
                return other;
            }
            return new Observation(
                    (input * weight + other.input * other.weight) / totalWeight,
                    (output * weight + other.output * other.weight) / totalWeight,
                    totalWeight);
        }

        @Override
        public String toString() {
            return "Observation [input=" + input
                    + ", output=" + output
                    + ", weight=" + weight
                    + "]";
        }
    }
}