package quickml.supervised.classifier.twoStageModel; import com.google.common.collect.Lists; import quickml.data.instances.ClassifierInstance; import quickml.supervised.PredictiveModelBuilder; import quickml.supervised.classifier.Classifier; import java.io.Serializable; import java.util.List; import java.util.Map; /** * Created by alexanderhawk on 10/7/14. */ public class TwoStageModelBuilder implements PredictiveModelBuilder<TwoStageClassifier, ClassifierInstance> { PredictiveModelBuilder< ? extends Classifier, ClassifierInstance> wrappedModelBuilder1; PredictiveModelBuilder<? extends Classifier, ClassifierInstance> wrappedModelBuilder2; public TwoStageModelBuilder(PredictiveModelBuilder< ? extends Classifier, ClassifierInstance> wrappedModelBuilder1, PredictiveModelBuilder<? extends Classifier, ClassifierInstance> wrappedModelBuilder2) { this.wrappedModelBuilder1 = wrappedModelBuilder1; this.wrappedModelBuilder2 = wrappedModelBuilder2; } @Override public TwoStageClassifier buildPredictiveModel(Iterable<ClassifierInstance> trainingData) { List<ClassifierInstance> stage1Data = Lists.newArrayList(); List<ClassifierInstance> stage2Data = Lists.newArrayList(); List<ClassifierInstance> validationData = Lists.newArrayList(); for (ClassifierInstance instance : trainingData) { if (instance.getLabel().equals("positive-both")) { stage1Data.add(new ClassifierInstance(instance.getAttributes(), 1.0)); stage2Data.add(new ClassifierInstance(instance.getAttributes(), 1.0)); validationData.add(new ClassifierInstance(instance.getAttributes(), 1.0)); } else if (instance.getLabel().equals("positive-first")) { stage1Data.add(new ClassifierInstance(instance.getAttributes(), 1.0)); stage2Data.add(new ClassifierInstance(instance.getAttributes(), 0.0)); validationData.add(new ClassifierInstance(instance.getAttributes(), 0.0)); } else if (instance.getLabel().equals("negative")) { stage1Data.add(new ClassifierInstance(instance.getAttributes(), 0.0)); validationData.add(new ClassifierInstance(instance.getAttributes(), 0.0)); } else { throw new RuntimeException("missing valid label"); } } Classifier c1 = wrappedModelBuilder1.buildPredictiveModel(stage1Data); Classifier c2 = wrappedModelBuilder2.buildPredictiveModel(stage2Data); return new TwoStageClassifier(c1, c2); } @Override public void updateBuilderConfig(Map<String, Serializable> config) { wrappedModelBuilder1.updateBuilderConfig(config); wrappedModelBuilder2.updateBuilderConfig(config); } }