package quickml.supervised.crossValidation.data;
import com.google.common.collect.Lists;
import org.joda.time.DateTime;
import org.joda.time.DateTimeConstants;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.supervised.crossValidation.utils.DateTimeExtractor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import static com.google.common.collect.Lists.newArrayList;
//TODO: generalize this to make object a generic
public class OutOfTimeData<I> implements TrainingDataCycler<I> {
private static final int MIN_INSTANCES_PER_VALIDATION_PERIOD = 5;
public static final double ACCEPTABLE_EXTRA_TAIL_TIME = 0.3;
private final List<I> allData;
private final double crossValidationFraction;
private final int timeSliceHours;
private DateTimeExtractor<I> dateTimeExtractor;
private int offSetForValidation =0;
private List<I> trainingSet;
private List<I> validationSet;
private static Logger logger = LoggerFactory.getLogger(OutOfTimeData.class);
private DateTime endValidationPeriod;
private int ignoredInstances=0;
public OutOfTimeData(List<I> allData, double crossValidationFraction, int timeSliceHours, DateTimeExtractor dateTimeExtractor) {
this.allData = allData;
this.crossValidationFraction = crossValidationFraction;
this.timeSliceHours = timeSliceHours;
this.dateTimeExtractor = dateTimeExtractor;
sortData();
reset();
}
public OutOfTimeData(List<I> allData, double crossValidationFraction, int timeSliceHours, DateTimeExtractor dateTimeExtractor, int offSetForValidation) {
this.allData = allData;
this.crossValidationFraction = crossValidationFraction;
this.timeSliceHours = timeSliceHours;
this.dateTimeExtractor = dateTimeExtractor;
this.offSetForValidation = offSetForValidation;
sortData();
reset();
}
@Override
public void reset() {
endValidationPeriod = null;
setTrainingSetBasedOnFraction();
updateValidationSet();
}
@Override
public List<I> getTrainingSet() {
return trainingSet;
}
@Override
public List<I> getValidationSet() {
return validationSet;
}
@Override
public List<I> getAllData() {
return allData;
}
@Override
public boolean nextCycle() {
if (hasMore()) {
//TODO the method in the if should be general for both cases but need to verify
if (offSetForValidation!=0) {
setNextTrainingSet();
} else {
trainingSet.addAll(validationSet);
}
updateValidationSet();
return true;
}
return false;
}
@Override
public boolean hasMore() {
return trainingSet.size() + validationSet.size() + ignoredInstances < allData.size();
}
private void updateValidationSet() {
logger.info("re-entering update validation set");
List<I> potentialValidationSet = getCandidateValidationSet();
if (getCandidateValidationSet().isEmpty()) {
validationSet = Lists.newArrayList();
return;
}
if (endValidationPeriod == null) {
endValidationPeriod = dateTimeExtractor.extractDateTime(potentialValidationSet.get(0)).plusHours(timeSliceHours);
} else {
DateTime lastValidationPeriodEnd = endValidationPeriod;
endValidationPeriod = lastValidationPeriodEnd.plusHours(timeSliceHours);
logger.info("endValidationPeriod: {}", endValidationPeriod.toString());
}
validationSet = newArrayList();
int instancesAddedToTheValidationSet = 0;
for (int i =0; i< potentialValidationSet.size(); i++) {
I instance = potentialValidationSet.get(i);
if (dateTimeExtractor.extractDateTime(instance).isBefore(endValidationPeriod)) {
validationSet.add(instance);
instancesAddedToTheValidationSet++;
} else if (validationSet.size() == potentialValidationSet.size()) {
break;
} else if (validationSet.isEmpty()) {
// If the set is empty and we are at the end of the validation period
// so we increase the validation period
endValidationPeriod = endValidationPeriod.plusHours(timeSliceHours);
i =-1; //the post incremente in the for loop will reset i to 0, allowing a complete re-run of the enclosing for loop.
logger.info("no data in time window, pushing endValidationPeriod: {}", endValidationPeriod.toString());
} else {
break; //because the list is sorted, once the first if fails, and else if fails, the loop should end.
}
}
addRemainderOfPotentialValidationSetIfNecessary(potentialValidationSet, instancesAddedToTheValidationSet);
DateTime dateTimeOfFirstInstance = dateTimeExtractor.extractDateTime(validationSet.get(0));
DateTime dateTimeOfLastInstance = dateTimeExtractor.extractDateTime(validationSet.get(validationSet.size() - 1));
logger.info("num instances in validation period: {}, with first entry at {}, and last entry at {}", validationSet.size(), dateTimeOfFirstInstance, dateTimeOfLastInstance);
if (instancesAddedToTheValidationSet < potentialValidationSet.size()) {
logger.info("num instances in potential validation set {}, with first entry not added in first pass at {}, and last entry at {}",
potentialValidationSet.size(),
dateTimeExtractor.extractDateTime(potentialValidationSet.get(instancesAddedToTheValidationSet)),
dateTimeExtractor.extractDateTime(potentialValidationSet.get(potentialValidationSet.size() - 1)));
} else {
logger.info("no more insntances potential val set.");
}
}
private List<I> getCandidateValidationSet() {
List<I> potentialValidationSet = Lists.newArrayList();
DateTime validationStartTime = dateTimeExtractor.extractDateTime(allData.get(trainingSet.size() - 1));
validationStartTime = validationStartTime.plusHours(offSetForValidation);
ignoredInstances = 0;
for (int i = trainingSet.size(); i<allData.size(); i++) {
DateTime instanceTime = dateTimeExtractor.extractDateTime(allData.get(i));
if (instanceTime.isAfter(validationStartTime)) {
break;
}
ignoredInstances++;
}
if (ignoredInstances + trainingSet.size() < allData.size()) {
potentialValidationSet = allData.subList(trainingSet.size() + ignoredInstances, allData.size());
}
return potentialValidationSet;
}
private void setNextTrainingSet() {
DateTime lastTrainingTime = dateTimeExtractor.extractDateTime(allData.get(trainingSet.size() - 1));
lastTrainingTime = lastTrainingTime.plusHours(timeSliceHours);
int initialTrainingSize = trainingSet.size();
for (int i = initialTrainingSize; i<allData.size(); i++) {
DateTime instanceTime = dateTimeExtractor.extractDateTime(allData.get(i));
if (instanceTime.isBefore(lastTrainingTime)) {
trainingSet.add(allData.get(i));
} else {
break;
}
}
}
private void addRemainderOfPotentialValidationSetIfNecessary(List<I> potentialValidationSet, int instancesAddedToTheValidationSet) {
/**this method adds prevents situations where the last validation period consists of very little data, by adding the data from the last
* validation period to the period before it.
*/
if (validationSet.size()>0) {
DateTime lastTimeOfValidationSet = dateTimeExtractor.extractDateTime(validationSet.get(validationSet.size() - 1));
DateTime lastTimeOfPotentialValidationSet = dateTimeExtractor.extractDateTime(potentialValidationSet.get(potentialValidationSet.size() - 1));
Duration durationRemaining = new Duration(lastTimeOfValidationSet, lastTimeOfPotentialValidationSet);
if (instancesAddedToTheValidationSet < potentialValidationSet.size() && durationRemaining.getStandardHours() < (long) (timeSliceHours * ACCEPTABLE_EXTRA_TAIL_TIME)) {
validationSet.addAll(potentialValidationSet.subList(instancesAddedToTheValidationSet, potentialValidationSet.size()));
}
}
}
public DateTime firstTimeOfValidationSet(){
return dateTimeExtractor.extractDateTime(validationSet.get(0));
}
private void sortData() {
Collections.sort(allData, new Comparator<I>() {
@Override
public int compare(I o1, I o2) {
return dateTimeExtractor.extractDateTime(o1).compareTo(dateTimeExtractor.extractDateTime(o2));
}
});
}
private void setTrainingSetBasedOnFraction() {
int size = (int) (allData.size() * (1 - crossValidationFraction));
verifySizeIsLessThanTotalSize(allData, size);
trainingSet = new ArrayList<>(allData.subList(0, size));
}
private static void verifySizeIsLessThanTotalSize(List data, int size) {
if (size == data.size()) {
throw new RuntimeException("fractionOfDataForCrossValidation must be non zero");
}
}
}