package quickml.supervised.tree.regressionTree.reducers;
import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import org.javatuples.Pair;
import quickml.data.instances.RegressionInstance;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.tree.reducers.AttributeStats;
import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter;
import java.io.Serializable;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import static quickml.supervised.tree.constants.MissingValue.MISSING_VALUE;
/**
 * Created by alexanderhawk on 4/22/15.
 *
 * <p>Builds one {@link MeanValueCounter} per observed value of a (presumably categorical) attribute
 * over the training data held by {@link RTreeReducer}, plus an aggregate counter over all instances.
 */
public class RTCatBranchReducer<I extends RegressionInstance> extends RTreeReducer<I> {

    /** Orders counters by mean ({@code accumulatedValue / total}), largest mean first. */
    private static final Comparator<MeanValueCounter> DESCENDING_MEAN = new Comparator<MeanValueCounter>() {
        @Override
        public int compare(MeanValueCounter mv1, MeanValueCounter mv2) {
            // Arguments swapped relative to natural order to sort descending.
            return Double.compare(mean(mv2), mean(mv1));
        }
    };

    /**
     * @param trainingData the instances to aggregate; handed unchanged to {@link RTreeReducer}
     */
    public RTCatBranchReducer(List<I> trainingData) {
        super(trainingData);
    }

    /**
     * Collects per-value statistics for {@code attribute} and sorts them by descending mean.
     *
     * @param attribute name of the attribute whose values are aggregated
     * @return the statistics with per-value counters sorted by descending
     *     {@code accumulatedValue / total}, or {@link Optional#absent()} when fewer than two
     *     distinct values were seen (instances with a {@code null} value all count as the single
     *     {@code MISSING_VALUE} bucket)
     */
    @Override
    public Optional<AttributeStats<MeanValueCounter>> getAttributeStats(String attribute) {
        Optional<AttributeStats<MeanValueCounter>> unsortedStats = getUnsortedAttributeStats(attribute);
        if (!unsortedStats.isPresent()) {
            return Optional.absent();
        }
        AttributeStats<MeanValueCounter> attributeStats = unsortedStats.get();
        // Sorted in place. NOTE(review): this relies on getStatsOnEachValue() returning the live
        // list rather than a copy -- confirm in AttributeStats.
        List<MeanValueCounter> valueCounters = attributeStats.getStatsOnEachValue();
        Collections.sort(valueCounters, DESCENDING_MEAN);
        return Optional.of(attributeStats);
    }

    /**
     * Mean label of a counter, computed as {@code accumulatedValue / total}.
     * NOTE(review): if getTotal() can be 0 (e.g. all weights zero), this yields NaN for
     * floating-point totals, and Double.compare ranks NaN above every other value.
     */
    private static double mean(MeanValueCounter counter) {
        return counter.getAccumulatedValue() / counter.getTotal();
    }

    /**
     * Wraps the per-value counters in {@link AttributeStats} without sorting them.
     *
     * @return absent when at most one distinct value (including the missing bucket) was seen
     */
    private Optional<AttributeStats<MeanValueCounter>> getUnsortedAttributeStats(String attribute) {
        Pair<MeanValueCounter, Map<Serializable, MeanValueCounter>> aggregateAndPerValue =
                getAggregateAndAttributeValueMeanValueCounters(attribute);
        MeanValueCounter aggregateStats = aggregateAndPerValue.getValue0();
        Map<Serializable, MeanValueCounter> countersByValue = aggregateAndPerValue.getValue1();
        List<MeanValueCounter> valueCounters = Lists.newArrayList(countersByValue.values());
        // A single bucket offers nothing to split on.
        if (valueCounters.size() <= 1) {
            return Optional.absent();
        }
        return Optional.of(new AttributeStats<>(valueCounters, aggregateStats, attribute));
    }

    /**
     * Groups the training instances by their value of {@code attribute}, updating one counter per
     * value with each instance's label and weight.
     *
     * @param attribute name of the attribute to group by
     * @return a pair whose first element is a counter updated with every instance, and whose second
     *     maps each observed value to its counter; instances whose value is {@code null} are pooled
     *     under {@code MISSING_VALUE}
     */
    protected Pair<MeanValueCounter, Map<Serializable, MeanValueCounter>> getAggregateAndAttributeValueMeanValueCounters(String attribute) {
        final Map<Serializable, MeanValueCounter> countersByValue = Maps.newHashMap();
        final MeanValueCounter totals = new MeanValueCounter();
        for (RegressionInstance instance : getTrainingData()) {
            final Serializable attrVal = instance.getAttributes().get(attribute);
            // Missing values are bucketed together rather than skipped, so every instance is counted.
            final Serializable key = attrVal != null ? attrVal : MISSING_VALUE;
            MeanValueCounter counter = countersByValue.get(key);
            if (counter == null) {
                counter = new MeanValueCounter(key);
                countersByValue.put(key, counter);
            }
            counter.update(instance.getLabel(), instance.getWeight());
            totals.update(instance.getLabel(), instance.getWeight());
        }
        return Pair.with(totals, countersByValue);
    }
}