package quickml.supervised.tree.regressionTree.reducers;
import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import com.twitter.common.stats.ReservoirSampler;
import com.twitter.common.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.collections.MapUtils;
import quickml.data.AttributesMap;
import quickml.data.instances.RegressionInstance;
import quickml.supervised.tree.reducers.AttributeStats;
import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Created by alexanderhawk on 4/23/15.
 * <p>
 * Reducer that summarises a numeric attribute for a regression tree. It derives
 * {@code numNumericBins - 1} split points from the training data (or from a sample of it) and
 * then accumulates one {@link MeanValueCounter} per bin, plus one aggregate counter over all
 * instances.
 */
public class RTNumBranchReducer<I extends RegressionInstance> extends RTreeReducer<I> {
    private static final Logger logger = LoggerFactory.getLogger(RTNumBranchReducer.class);

    /**
     * Divisor used by {@link #getDelta(double[])} to turn a split-point spacing into the small
     * tolerance that is added to every threshold. Note that {@code 10E5} is 1,000,000.
     */
    public static final double DOWN_FACTOR = 10E5;

    // NOTE(review): nothing in this class reads this field (fillReservoirSampler builds its own
    // Random). Kept so the object's shape is unchanged; candidate for removal.
    private Random rand = Random.Util.fromSystemRandom(MapUtils.random);

    //TODO: once functionality is verified to be correct, remove these variables and get n counters which can then be further merged in the branchFinder
    final int numSamplesPerBin;
    final int numNumericBins;

    /**
     * @param trainingData     instances handed to the superclass
     * @param numSamplesPerBin samples wanted per bin when split points are estimated by sampling
     * @param numNumericBins   number of bins a numeric attribute is divided into
     */
    public RTNumBranchReducer(List<I> trainingData, int numSamplesPerBin, int numNumericBins) {
        super(trainingData);
        this.numSamplesPerBin = numSamplesPerBin;
        this.numNumericBins = numNumericBins;
    }

    /**
     * Computes the per-bin statistics of {@code attribute} over this reducer's training data.
     *
     * @param attribute name of the numeric attribute
     * @return the statistics, or absent when there are fewer training instances than bins or no
     * usable split points could be derived
     */
    @Override
    public Optional<AttributeStats<MeanValueCounter>> getAttributeStats(String attribute) {
        final List<I> trainingData = getTrainingData();
        if (trainingData.size() < numNumericBins) {
            return Optional.absent();
        }
        final Optional<double[]> splitsOptional = createNumericSplit(trainingData, attribute);
        if (!splitsOptional.isPresent()) {
            return Optional.absent();
        }
        return getAttributeStatsOptional(attribute, splitsOptional.get(), trainingData);
    }

    /**
     * Builds one {@link MeanValueCounter} per bin defined by {@code splitPoints} (plus a final bin
     * for values above the largest split point) and fills them from {@code trainingData}.
     * <p>
     * An instance goes into the first bin {@code i} for which {@code value <= splitPoints[i] + delta},
     * where split points equal to their predecessor are skipped; each instance is counted in at
     * most one bin. Instances that match no bin (for example a NaN value) are only counted in the
     * aggregate and reported in the log.
     *
     * @param attribute    name of the numeric attribute
     * @param splitPoints  bin thresholds; expected in ascending order
     * @param trainingData instances to accumulate
     * @return the statistics, or absent when {@code splitPoints} is empty
     */
    public static <I extends RegressionInstance> Optional<AttributeStats<MeanValueCounter>> getAttributeStatsOptional(String attribute, double[] splitPoints, List<I> trainingData) {
        //TODO: split points should not be doubles. They should be Numbers, which can be longs for the case that numeric values are longs.
        if (splitPoints.length == 0) {
            // No thresholds means no bins can be formed; previously this threw ArrayIndexOutOfBoundsException.
            return Optional.absent();
        }
        final int lastBinIndex = splitPoints.length;
        final double delta = getDelta(splitPoints);

        final List<MeanValueCounter> meanValueCounters = Lists.newArrayListWithCapacity(lastBinIndex + 1);
        for (double splitPoint : splitPoints) {
            meanValueCounters.add(new MeanValueCounter(splitPoint));
        }
        // The extra counter holds all values greater than the greatest split point.
        meanValueCounters.add(new MeanValueCounter(splitPoints[lastBinIndex - 1] + delta));
        final MeanValueCounter aggregateStats = new MeanValueCounter();

        int uncaughtMissingValues = 0;
        for (I instance : trainingData) {
            final double attributeVal = numericValueOrMissing(instance.getAttributes(), attribute);
            final int binIndex = findBin(attributeVal, splitPoints, delta);
            if (binIndex >= 0) {
                meanValueCounters.get(binIndex).update(instance.getLabel(), instance.getWeight());
            } else {
                uncaughtMissingValues++;
            }
            aggregateStats.update(instance.getLabel(), instance.getWeight());
        }

        if (uncaughtMissingValues > 0) {
            logger.info("uncaught missing values for attribute {} : {}", attribute, uncaughtMissingValues);
        }
        // Sanity check (only evaluated when assertions are enabled): the bins must account for
        // the same total as the aggregate counter.
        assert binTotalsMatchAggregate(meanValueCounters, aggregateStats);
        return Optional.of(new AttributeStats<>(meanValueCounters, aggregateStats, attribute));
    }

    /**
     * Returns the attribute's value as a double, or {@link Double#MIN_VALUE} when it is missing.
     */
    private static double numericValueOrMissing(AttributesMap attributes, String attribute) {
        final Serializable value = attributes.get(attribute);
        if (value == null) {
            // NOTE(review): Double.MIN_VALUE is the smallest POSITIVE double, not the most negative
            // one, so missing values are binned like a value just above zero. Left unchanged
            // because how missing values are treated at prediction time is not visible here -
            // confirm against old quickml before changing.
            return Double.MIN_VALUE;
        }
        return ((Number) value).doubleValue();
    }

    /**
     * Finds the bin for {@code attributeVal}.
     *
     * @return the bin index in {@code [0, splitPoints.length]}, or -1 when the value matches no bin
     */
    private static int findBin(double attributeVal, double[] splitPoints, double delta) {
        for (int i = 0; i < splitPoints.length; i++) {
            if (isDuplicateOfPrevious(splitPoints, i)) {
                continue;
            }
            // Adding delta is a total hack, and prevents quickml from working well with fine grained numeric attributes.
            if (attributeVal <= splitPoints[i] + delta) {
                return i;
            }
        }
        if (attributeVal > splitPoints[splitPoints.length - 1] + delta) {
            return splitPoints.length;
        }
        return -1;
    }

    /** Returns true when the bins' totals sum to the aggregate total within an absolute tolerance of 1E-5. */
    private static boolean binTotalsMatchAggregate(List<MeanValueCounter> bins, MeanValueCounter aggregateStats) {
        double total = 0;
        for (MeanValueCounter bin : bins) {
            total += bin.getTotal();
        }
        return total <= aggregateStats.getTotal() + 1E-5 && total >= aggregateStats.getTotal() - 1E-5;
    }

    /**
     * Tolerance added to each threshold: the spacing of the first two split points (or the single
     * split point itself) divided by {@link #DOWN_FACTOR}. Requires a non-empty array.
     */
    private static double getDelta(double[] splitPoints) {
        final double scale = (splitPoints.length >= 2) ? (splitPoints[1] - splitPoints[0]) : splitPoints[0];
        return scale / DOWN_FACTOR;
    }

    /** Returns true when the split point at {@code i} equals the one directly before it. */
    private static boolean isDuplicateOfPrevious(double[] splitPoints, int i) {
        return i != 0 && splitPoints[i - 1] == splitPoints[i];
    }

    /**
     * Derives split points from the samples held by {@code reservoirSampler}.
     *
     * @return the split points, or absent when there are fewer samples than bins
     */
    private Optional<double[]> getSplit(ReservoirSampler<Double> reservoirSampler) {
        final ArrayList<Double> sampledValues = Lists.newArrayList();
        for (final Double sample : reservoirSampler.getSamples()) {
            sampledValues.add(sample);
        }
        if (sampledValues.isEmpty() || sampledValues.size() < numNumericBins) {
            return Optional.absent();
        }
        return getBinDividerPoints(numNumericBins, sampledValues);
    }

    /**
     * Derives split points from every instance that has a non-null value for {@code attribute}.
     *
     * @param instances      instances to read the attribute from
     * @param attribute      name of the numeric attribute
     * @param numNumericBins number of bins wanted
     * @return the split points, or absent when fewer values than bins are available
     */
    public static <I extends RegressionInstance> Optional<double[]> getDeterministicSplit(List<I> instances, String attribute, int numNumericBins) {
        final ArrayList<Double> attributeValues = Lists.newArrayList();
        for (final I sample : instances) {
            // A key mapped to null is treated as missing, like the other methods of this class do;
            // casting it used to throw a NullPointerException.
            final Serializable value = sample.getAttributes().get(attribute);
            if (value != null) {
                attributeValues.add(((Number) value).doubleValue());
            }
        }
        if (attributeValues.isEmpty() || attributeValues.size() < numNumericBins) {
            return Optional.absent();
        }
        return getBinDividerPoints(numNumericBins, attributeValues);
    }

    /**
     * Computes {@code numNumericBins - 1} split points that divide the values into bins of (almost)
     * equal size. Each split point is the midpoint between the last value of the lower bin and the
     * first value of the upper bin. When the values do not divide evenly, the bins with the lowest
     * indices each receive one additional value.
     * <p>
     * Note: {@code attributeValues} is sorted in place.
     *
     * @param numNumericBins  number of bins wanted
     * @param attributeValues values of the attribute; sorted in place by this call
     * @return the split points, or absent when fewer than two bins are requested, when there are
     * fewer values than bins, or when all computed split points are identical
     */
    public static Optional<double[]> getBinDividerPoints(int numNumericBins, List<Double> attributeValues) {
        Collections.sort(attributeValues);
        if (numNumericBins < 2 || attributeValues.size() < numNumericBins) {
            // Fewer values than bins (or no split to make): the index arithmetic below would
            // otherwise run out of bounds or divide by zero.
            return Optional.absent();
        }
        final int numSplitPoints = numNumericBins - 1;
        final int baseBinSize = attributeValues.size() / numNumericBins;
        final int remainder = attributeValues.size() % numNumericBins;
        final double[] split = new double[numSplitPoints];

        int upperIndex = 0; // index of the first value of the bin above the current split point
        for (int splitPointIndex = 0; splitPointIndex < numSplitPoints; splitPointIndex++) {
            upperIndex += (splitPointIndex < remainder) ? baseBinSize + 1 : baseBinSize;
            split[splitPointIndex] = (attributeValues.get(upperIndex) + attributeValues.get(upperIndex - 1)) / 2.0;
        }

        if (allValuesSame(split)) {
            return Optional.absent();
        }
        return Optional.of(split);
    }

    /**
     * Tells whether every element of {@code split} equals its neighbour.
     *
     * @return false for a single-element array (a lone split point is never rejected), true for an
     * empty array, otherwise true only when all adjacent elements are equal
     */
    public static boolean allValuesSame(double[] split) {
        if (split.length == 1) {
            return false;
        }
        for (int x = 0; x < split.length - 1; x++) {
            if (split[x] != split[x + 1]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Chooses how split points are derived: deterministically from all of the data when there are
     * fewer instances than {@code numSamplesPerBin * numNumericBins} (which makes the code
     * testable), otherwise from a reservoir sample.
     */
    private Optional<double[]> createNumericSplit(final List<I> trainingData, final String attribute) {
        final int desiredSamples = numSamplesPerBin * numNumericBins;
        // A non-positive sample count cannot drive the sampler (it divides by it), so fall back to
        // the deterministic split in that case as well.
        if (desiredSamples <= 0 || trainingData.size() < desiredSamples) {
            return getDeterministicSplit(trainingData, attribute, numNumericBins);
        }
        final ReservoirSampler<Double> reservoirSampler = fillReservoirSampler(trainingData, attribute, desiredSamples);
        return getSplit(reservoirSampler);
    }

    /**
     * Feeds the non-null values of {@code attribute} into a new reservoir sampler, visiting the
     * training data with a stride of {@code trainingData.size() / desiredSamples} (at least 1).
     *
     * @param trainingData   instances to sample from
     * @param attribute      name of the numeric attribute
     * @param desiredSamples number of samples wanted; must not be zero (it is used as a divisor)
     * @return the filled sampler
     */
    public static <I extends RegressionInstance> ReservoirSampler<Double> fillReservoirSampler(List<I> trainingData, String attribute, int desiredSamples) {
        final Random rand = Random.Util.fromSystemRandom(MapUtils.random);
        final int reservoirSize = desiredSamples + trainingData.size() % desiredSamples;
        final ReservoirSampler<Double> reservoirSampler = new ReservoirSampler<Double>(reservoirSize, rand);
        // Clamp to 1: with fewer instances than desired samples the integer division yields 0,
        // which used to make this loop spin forever on index 0.
        final int stride = Math.max(1, trainingData.size() / desiredSamples);
        for (int i = 0; i < trainingData.size(); i += stride) {
            final Serializable value = trainingData.get(i).getAttributes().get(attribute);
            if (value == null) {
                continue;
            }
            reservoirSampler.sample(((Number) value).doubleValue());
        }
        return reservoirSampler;
    }
}