package quickml.supervised.classifier.splitOnAttribute;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.supervised.PredictiveModelBuilder;
import quickml.data.instances.InstanceWithAttributesMap;
import quickml.supervised.classifier.Classifier;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import java.io.Serializable;
import java.util.*;
/**
* Created by ian on 5/29/14.
*/
/**
 * Builds a {@link SplitOnAttributeClassifier}: training data is partitioned by the value of a
 * single attribute, one wrapped classifier is trained per partition ("group"), and instances
 * from other groups are mixed in ("cross-pollinated") according to each group's
 * {@link SplitModelGroup} configuration.
 */
public class SplitOnAttributeClassifierBuilder<I extends InstanceWithAttributesMap<?>> implements PredictiveModelBuilder<SplitOnAttributeClassifier, I> {
    private static final Logger logger = LoggerFactory.getLogger(SplitOnAttributeClassifierBuilder.class);

    /** Attribute whose value decides which sub-model an instance is routed to. */
    private final String attributeKey;
    /** Builder used to train one classifier per group. */
    private final PredictiveModelBuilder<? extends Classifier, I> wrappedBuilder;
    /** Maps each raw split-attribute value to the id of the group that owns it. */
    private final Map<? extends Serializable, Integer> splitValToGroupIdMap;
    /** Group definitions keyed by group id. */
    private final Map<Integer, SplitModelGroup> splitModelGroups;
    /** Group id to fall back to for unmapped split values. */
    private final Integer defaultGroup;

    //TODO: this method should not have any parameters.
    public SplitOnAttributeClassifierBuilder(String attributeKey, Collection<SplitModelGroup> splitModelGroupsCollection, Integer defaultGroup, PredictiveModelBuilder<? extends Classifier, I> wrappedBuilder) {
        this.attributeKey = attributeKey;
        this.defaultGroup = defaultGroup;
        this.splitModelGroups = getSplitModelGroups(splitModelGroupsCollection);
        this.splitValToGroupIdMap = getSplitValToGroupIdMap(splitModelGroups);
        this.wrappedBuilder = wrappedBuilder;
    }

    /**
     * Partitions the training data by group id, cross-pollinates the partitions, trains one
     * wrapped classifier per group, and assembles them into a {@link SplitOnAttributeClassifier}.
     */
    @Override
    public SplitOnAttributeClassifier buildPredictiveModel(final Iterable<I> trainingData) {
        //split by groupId
        Map<Integer, List<I>> splitTrainingData = splitTrainingData(trainingData);
        Map<Integer, Classifier> splitModels = Maps.newHashMap();
        for (Map.Entry<Integer, List<I>> trainingDataEntry : splitTrainingData.entrySet()) {
            // parameterized logging; the original concatenated with a missing separator after "group"
            logger.info("Building predictive model for group {}={}", attributeKey, trainingDataEntry.getKey());
            splitModels.put(trainingDataEntry.getKey(), wrappedBuilder.buildPredictiveModel(trainingDataEntry.getValue()));
        }
        // NOTE: the original logged "Building default predictive model" here, but no default model
        // is built — the defaultGroup merely routes unknown split values to an existing model.
        return new SplitOnAttributeClassifier(attributeKey, splitValToGroupIdMap, defaultGroup, splitModels);
    }

    @Override
    public void updateBuilderConfig(Map<String, Serializable> config) {
        // Configuration is delegated wholesale to the wrapped builder.
        wrappedBuilder.updateBuilderConfig(config);
    }

    /**
     * Returns the group id owning the given split value.
     * NOTE(review): relies on SplitValTGroupIdMap returning the default group for unknown
     * values (presumably it overrides get()) — confirm, otherwise this unboxes null.
     */
    public int getGroupFromSplitVal(Serializable val) {
        return splitValToGroupIdMap.get(val);
    }

    /** Indexes the group definitions by their group id. */
    private Map<Integer, SplitModelGroup> getSplitModelGroups(Collection<SplitModelGroup> splitModelGroupCollection) {
        Map<Integer, SplitModelGroup> splitModelGroupMap = new HashMap<>();
        for (SplitModelGroup splitModelGroup : splitModelGroupCollection) {
            splitModelGroupMap.put(splitModelGroup.groupId, splitModelGroup);
        }
        return splitModelGroupMap;
    }

    /** Inverts the group definitions into a split-value -> group-id lookup (with a default group). */
    private Map<Serializable, Integer> getSplitValToGroupIdMap(Map<Integer, SplitModelGroup> splitModelGroups) {
        SplitValTGroupIdMap splitValToGroupIdMap = new SplitValTGroupIdMap(defaultGroup);
        for (Map.Entry<Integer, SplitModelGroup> entry : splitModelGroups.entrySet()) {
            for (Serializable splitVal : entry.getValue().valuesOfSplitVariableInTheGroup) {
                splitValToGroupIdMap.put(splitVal, entry.getKey());
            }
        }
        return splitValToGroupIdMap;
    }

    /**
     * Buckets the training instances by group id and then cross-pollinates the buckets.
     * Instances lacking the split attribute are dropped entirely.
     */
    private Map<Integer, List<I>> splitTrainingData(Iterable<I> trainingData) {
        Map<Integer, List<I>> splitTrainingData = Maps.newHashMap();
        for (I instance : trainingData) {
            Serializable value = instance.getAttributes().get(attributeKey);
            if (value == null) {
                continue; // no split attribute -> instance is not used at all
            }
            Integer groupId = splitValToGroupIdMap.get(value);
            splitTrainingData.computeIfAbsent(groupId, k -> Lists.<I>newArrayList()).add(instance);
        }
        crossPollinateData(splitTrainingData);
        return splitTrainingData;
    }

    /*
     * Add data to each split data set based on the desired cross data values. Maintain the same ratio of classifications in the split set by
     * selecting that ratio from outside sets. Only keep the attributes in the supporting instances that are in the white list
     * */
    private void crossPollinateData(Map<Integer, List<I>> splitTrainingData) {
        // Snapshot pre-pollination sizes: (a) the ideal-sample computation must see original
        // counts, and (b) donors must only hand out their ORIGINAL instances, never instances
        // another group already donated to them (pollinated data is appended, so the first
        // originalSize elements of each list are the originals). The original implementation
        // sampled from already-inflated lists, making results iteration-order dependent.
        Map<Integer, Long> groupIdToSamplesInTheGroup = new HashMap<>();
        for (Map.Entry<Integer, List<I>> entry : splitTrainingData.entrySet()) {
            groupIdToSamplesInTheGroup.put(entry.getKey(), (long) entry.getValue().size());
        }
        for (Map.Entry<Integer, SplitModelGroup> groupEntry : splitModelGroups.entrySet()) {
            List<I> dataForPresentGroup = splitTrainingData.get(groupEntry.getKey());
            if (dataForPresentGroup == null) {
                continue; // group configured but received no training data (NPE in the original)
            }
            Map<Integer, Long> numSamplesFromOtherGroupsMap =
                    groupEntry.getValue().computeIdealNumberOfSamplesToCollectFromOtherGroups(groupIdToSamplesInTheGroup);
            for (Map.Entry<Integer, Long> request : numSamplesFromOtherGroupsMap.entrySet()) {
                List<I> instancesFromCrossGroup = splitTrainingData.get(request.getKey());
                if (instancesFromCrossGroup == null) {
                    continue; // requested donor group has no data (NPE in the original)
                }
                int originalSize = groupIdToSamplesInTheGroup.get(request.getKey()).intValue();
                List<I> donors = instancesFromCrossGroup.subList(0, Math.min(originalSize, instancesFromCrossGroup.size()));
                dataForPresentGroup.addAll(filterToRequestedNumber(donors, request.getValue()));
            }
        }
    }

    /**
     * Obtains a random sublist of approximately {@code requestedNumInstances} elements from
     * {@code input} in O(requestedNumInstances) time via stride sampling.
     * Returns an empty list for non-positive requests (the original crashed on a negative
     * ArrayList capacity) and the whole list when more is requested than is available
     * (the original degenerated to repeating the leading elements).
     */
    private List<I> filterToRequestedNumber(List<I> input, long requestedNumInstances) {
        //TODO: consider allowing it to get the most recently dated instances.
        if (requestedNumInstances <= 0 || input.isEmpty()) {
            return new ArrayList<>();
        }
        if (requestedNumInstances >= input.size()) {
            // Donor cannot supply more than it has: provide everything (see TODO in
            // computeIdealNumberOfSamplesToCollectFromOtherGroups).
            return new ArrayList<>(input);
        }
        List<I> output = new ArrayList<>((int) requestedNumInstances);
        double currentSizeToReducedSizeRatio = (1.0 * input.size()) / requestedNumInstances; // >= 1 here
        int baseIncrement = (int) Math.floor(currentSizeToReducedSizeRatio);
        double randomIncrementProbability = currentSizeToReducedSizeRatio - baseIncrement;
        Random random = new Random();
        int currentIndex = 0;
        for (int i = 0; i < requestedNumInstances && currentIndex < input.size(); i++) {
            output.add(input.get(currentIndex));
            // advance by the base stride, plus one extra step with the fractional probability
            // so that on average the whole list is covered
            currentIndex += baseIncrement;
            if (random.nextDouble() < randomIncrementProbability) {
                currentIndex++;
            }
        }
        return output;
    }

    /**
     * True when the instance's split value differs from the model's split value (avoids
     * donating an instance back to its own group).
     * NOTE(review): the classification-ratio check (crossDataCount vs targetCount) was already
     * commented out in the original; the parameters are kept for interface compatibility and
     * the method is currently unused within this file.
     */
    private boolean shouldAddInstance(Serializable attributeValue, I instance, ClassificationCounter crossDataCount, double targetCount) {
        return !attributeValue.equals(instance.getAttributes().get(attributeKey));
    }

    /**
     * Describes one sub-model group: which split values it owns, how much of its training data
     * should come from other groups, and in what proportions the other groups contribute.
     */
    public static class SplitModelGroup {
        public final int groupId;
        /** Minimum total number of training samples this group's model should see. */
        public final long minTotalSamples;
        public double percentageOfTrainingDataThatIsFromOtherGroups;
        /** contributor group id -> fraction of the cross-group data that group should supply */
        public final Map<Integer, Double> groupIdToPercentageOfCrossDataProvidedMap;
        public final Set<? extends Serializable> valuesOfSplitVariableInTheGroup;

        public SplitModelGroup(int groupId, Set<? extends Serializable> valuesOfSplitVariableInTheGroup, long minTotalSamples, double percentageOfTrainingDataThatIsFromOtherGroups, Map<Integer, Double> relativeImportanceOfEachGroupThatContributesCrossGroupData) {
            this.groupId = groupId;
            this.valuesOfSplitVariableInTheGroup = valuesOfSplitVariableInTheGroup;
            this.minTotalSamples = minTotalSamples;
            this.percentageOfTrainingDataThatIsFromOtherGroups = percentageOfTrainingDataThatIsFromOtherGroups;
            this.groupIdToPercentageOfCrossDataProvidedMap = relativeImportanceOfEachGroupThatContributesCrossGroupData;
        }

        /**
         * Computes how many samples to request from each contributing group so that roughly
         * {@code percentageOfTrainingDataThatIsFromOtherGroups} of this group's training data
         * is cross-group data, while reaching at least {@code minTotalSamples} when possible.
         *
         * @param groupIdToSamplesInTheGroup pre-pollination sample counts per group
         * @return contributor group id -> number of samples requested from it
         */
        public Map<Integer, Long> computeIdealNumberOfSamplesToCollectFromOtherGroups(Map<Integer, Long> groupIdToSamplesInTheGroup) {
            Map<Integer, Long> numberOfSamplesToCollectFromGroups = new HashMap<>();
            // A group can be configured yet have no training data; the original unboxed null here.
            long numNonCrossTrainingDataSamples = groupIdToSamplesInTheGroup.getOrDefault(groupId, 0L);
            double percentageOfNonCrossTrainingData = 1 - percentageOfTrainingDataThatIsFromOtherGroups;
            long numCrossPolinatedInstancesNeeded;
            if (percentageOfNonCrossTrainingData <= 0
                    || minTotalSamples > numNonCrossTrainingDataSamples / percentageOfNonCrossTrainingData) {
                // Either ALL data should be cross-group data (the original divided by zero here,
                // producing a Long.MAX_VALUE request) or the desired percentage is unachievable:
                // just top the group up to its minimum size. Clamped at 0 for safety.
                numCrossPolinatedInstancesNeeded = Math.max(0L, minTotalSamples - numNonCrossTrainingDataSamples);
            } else {
                numCrossPolinatedInstancesNeeded = (long) Math.ceil(numNonCrossTrainingDataSamples * (1 - percentageOfNonCrossTrainingData) / percentageOfNonCrossTrainingData);
            }
            for (Map.Entry<Integer, Double> contributor : groupIdToPercentageOfCrossDataProvidedMap.entrySet()) {
                // entry iteration also avoids the original loop variable shadowing the groupId field
                numberOfSamplesToCollectFromGroups.put(contributor.getKey(), (long) (contributor.getValue() * numCrossPolinatedInstancesNeeded));
            }
            //TODO: compare requested counts to the actual sizes of the other groups and intelligently
            // rebalance; for now a donor with fewer instances than requested simply provides all of
            // them (filterToRequestedNumber handles that case).
            return numberOfSamplesToCollectFromGroups;
        }
    }
}