package quickml.supervised.crossValidation.attributeImportance;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import quickml.data.AttributesMap;
import quickml.data.instances.RegressionInstance;
import quickml.supervised.PredictiveModel;
import quickml.supervised.PredictiveModelBuilder;
import quickml.supervised.crossValidation.data.TrainingDataCycler;
import quickml.supervised.crossValidation.lossfunctions.regressionLossFunctions.RegressionLossFunction;
import java.util.List;
import java.util.Set;
import static com.google.common.base.Preconditions.checkArgument;
public class RegAttributeImportanceFinderBuilder<T extends RegressionInstance> {
private PredictiveModelBuilder<? extends PredictiveModel<AttributesMap, Double>, T> modelBuilder;
private TrainingDataCycler<T> dataCycler;
private double percentAttributesToRemovePerIteration = 0.2;
private int numberOfIterations = 5;
private Set<String> attributesToKeep = Sets.newHashSet();
private List<RegressionLossFunction> lossFunctions = Lists.newArrayList();
private RegressionLossFunction primaryLossFunction;
public RegAttributeImportanceFinderBuilder<T> modelBuilder(PredictiveModelBuilder<? extends PredictiveModel<AttributesMap, Double>, T> modelBuilder) {
this.modelBuilder = modelBuilder;
return this;
}
public RegAttributeImportanceFinderBuilder<T> dataCycler(TrainingDataCycler<T> dataCycler) {
this.dataCycler = dataCycler;
return this;
}
public RegAttributeImportanceFinderBuilder<T> numOfIterations(int numberOfIterations) {
this.numberOfIterations = numberOfIterations;
return this;
}
public RegAttributeImportanceFinderBuilder<T> percentAttributesToRemovePerIteration(double attributesToRemovePerIteration) {
this.percentAttributesToRemovePerIteration = attributesToRemovePerIteration;
return this;
}
public RegAttributeImportanceFinderBuilder<T> primaryLossFunction(RegressionLossFunction primaryLossFunction) {
lossFunctions.add(primaryLossFunction);
this.primaryLossFunction = primaryLossFunction;
return this;
}
public RegAttributeImportanceFinderBuilder<T> lossFunction(RegressionLossFunction lossFunction) {
this.lossFunctions.add(lossFunction);
return this;
}
public RegAttributeImportanceFinderBuilder<T> attributesToKeep(Set<String> attributesToKeep) {
this.attributesToKeep = attributesToKeep;
return this;
}
public RegAttributeImportanceFinder<T> build() {
checkArgument(primaryLossFunction != null, "A primary loss function must be set");
checkArgument(modelBuilder != null, "Must supply a model builder");
checkArgument(dataCycler != null, "Must supply a data cycler");
return new RegAttributeImportanceFinder<>(modelBuilder, dataCycler, percentAttributesToRemovePerIteration,
numberOfIterations, attributesToKeep, lossFunctions, primaryLossFunction);
}
}