package quickml.supervised.featureEngineering;

import com.google.common.collect.Lists;
import org.testng.annotations.Test;
import quickml.data.*;
import quickml.data.instances.ClassifierInstance;
import quickml.data.instances.InstanceWithAttributesMap;
import quickml.supervised.PredictiveModel;
import quickml.supervised.PredictiveModelBuilder;
import quickml.supervised.featureEngineering1.AttributesEnrichStrategy;
import quickml.supervised.featureEngineering1.AttributesEnricher;
import quickml.supervised.featureEngineering1.FeatureEngineeredClassifier;
import quickml.supervised.featureEngineering1.FeatureEngineeringClassifierBuilder;

import javax.annotation.Nullable;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Smoke test for {@link FeatureEngineeringClassifierBuilder}.
 *
 * <p>The test wires the builder up with stub collaborators defined in this file. The stubs throw
 * {@link IllegalArgumentException} whenever they are handed attributes that lack the
 * {@code "enriched"} key, so the test passes only if no such exception escapes.
 */
public class FeatureEngineeringClassifierBuilderTest {

    /** Classification label used both as the training label and as the key of the stub prediction. */
    private static final Integer VALUE_TO_TEST = 1;

    /**
     * Builds a feature-engineered model from a single un-enriched training instance and then asks
     * it for a probability. There is no explicit assertion: failure is signalled by the stubs
     * throwing when they see attributes without the {@code "enriched"} key.
     */
    @Test
    public void simpleTest() {
        List<InstanceWithAttributesMap<?>> trainingData = Lists.newArrayList();
        trainingData.add(new ClassifierInstance(AttributesMap.newHashMap(), 1));

        // Raw type kept as in the original wiring; the builder's generic bounds are declared
        // outside this file.
        PredictiveModelBuilder testPMB = new TestPMBuilder();
        FeatureEngineeringClassifierBuilder feBuilder =
                new FeatureEngineeringClassifierBuilder(testPMB, Lists.newArrayList(new TestAEBS()));

        final FeatureEngineeredClassifier predictiveModel = feBuilder.buildPredictiveModel(trainingData);

        // The original training attributes are un-enriched; the call is expected to complete
        // without the stub model rejecting them.
        predictiveModel.getProbability(trainingData.get(0).getAttributes(), VALUE_TO_TEST);
    }

    /**
     * Stub enrich strategy whose enricher copies the incoming attributes into a fresh map and adds
     * an {@code "enriched"} entry with value {@code 1}.
     */
    public static class TestAEBS implements AttributesEnrichStrategy {

        /**
         * @param trainingData ignored by this stub
         * @return an enricher that adds the {@code "enriched"} marker attribute
         */
        @Override
        public AttributesEnricher build(final Iterable<InstanceWithAttributesMap<?>> trainingData) {
            return new AttributesEnricher() {
                private static final long serialVersionUID = -4851048617673142530L;

                /**
                 * @param attributes attributes to enrich; may be {@code null}, in which case the
                 *                   result contains only the marker entry
                 * @return a new map holding the input entries plus {@code "enriched" -> 1}
                 */
                public AttributesMap apply(@Nullable final AttributesMap attributes) {
                    AttributesMap enriched = AttributesMap.newHashMap();
                    // The parameter is declared @Nullable, so do not hand null to putAll.
                    if (attributes != null) {
                        enriched.putAll(attributes);
                    }
                    enriched.put("enriched", 1);
                    return enriched;
                }
            };
        }
    }

    /**
     * Stub model builder that rejects any training instance whose attributes lack the
     * {@code "enriched"} key and otherwise returns a new {@link TestPM}.
     */
    public static class TestPMBuilder implements PredictiveModelBuilder<TestPM, InstanceWithAttributesMap<?>> {

        /**
         * @param trainingData instances to validate
         * @return a new {@link TestPM}
         * @throws IllegalArgumentException if any instance is missing the {@code "enriched"} key
         */
        @Override
        public TestPM buildPredictiveModel(Iterable<InstanceWithAttributesMap<?>> trainingData) {
            for (InstanceWithAttributesMap<?> instance : trainingData) {
                if (!instance.getAttributes().containsKey("enriched")) {
                    throw new IllegalArgumentException(
                            "Predictive model training data must contain enriched instances");
                }
            }
            return new TestPM();
        }

        /** No configuration is needed for this stub; the call is a no-op. */
        @Override
        public void updateBuilderConfig(Map<String, Serializable> config) {
        }
    }

    /**
     * Stub model that rejects attributes lacking the {@code "enriched"} key and otherwise predicts
     * probability {@code 0.0} for {@link #VALUE_TO_TEST}.
     */
    public static class TestPM implements PredictiveModel<AttributesMap, PredictionMap> {
        private static final long serialVersionUID = -3449746370937561259L;

        /**
         * @param attributes attributes to predict on
         * @return a prediction map with the single entry {@code VALUE_TO_TEST -> 0.0}
         * @throws IllegalArgumentException if {@code attributes} is missing the {@code "enriched"} key
         */
        @Override
        public PredictionMap predict(AttributesMap attributes) {
            if (!attributes.containsKey("enriched")) {
                // Message names the prediction path so a failure here is not mistaken for the
                // training-time check in TestPMBuilder.
                throw new IllegalArgumentException(
                        "Attributes passed to predict must contain the enriched attribute");
            }
            Map<Serializable, Double> probabilities = new HashMap<>();
            probabilities.put(VALUE_TO_TEST, 0.0);
            return new PredictionMap(probabilities);
        }

        /**
         * Delegates to {@link #predict(AttributesMap)}; {@code attributesToIgnore} is not consulted.
         */
        @Override
        public PredictionMap predictWithoutAttributes(AttributesMap attributes, Set<String> attributesToIgnore) {
            return predict(attributes);
        }
    }
}