package quickml.supervised.featureEngineering;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.testng.Assert;
import org.testng.annotations.Test;
import quickml.data.AttributesMap;
import quickml.data.instances.InstanceWithAttributesMap;
import quickml.supervised.featureEngineering1.AttributesEnricher;
import quickml.supervised.featureEngineering1.enrichStrategies.probabilityInjector.ProbabilityEnrichStrategy;
import java.util.List;
public class ProbabilityEnrichStrategyTest {
@Test
public void testCreateAttributesEnricher() throws Exception {
List<InstanceWithAttributesMap<?>> trainingData = Lists.newLinkedList();
AttributesMap attributes = AttributesMap.newHashMap() ;
attributes.put("k1",2);
attributes.put("k2",1);
trainingData.add(new InstanceWithAttributesMap(attributes, "true"));
attributes = AttributesMap.newHashMap() ;
attributes.put("k1",1);
attributes.put("k2",2);
trainingData.add(new InstanceWithAttributesMap(attributes, "true"));
attributes = AttributesMap.newHashMap() ;
attributes.put("k1",2);
attributes.put("k2",2);
trainingData.add(new InstanceWithAttributesMap(attributes, "false"));
attributes = AttributesMap.newHashMap() ;
attributes.put("k1",1);
attributes.put("k2",2);
trainingData.add(new InstanceWithAttributesMap(attributes, "false"));
ProbabilityEnrichStrategy probabilityEnrichStrategy = new ProbabilityEnrichStrategy(Sets.newHashSet("k1", "k2"), "true");
final AttributesEnricher attributesEnricher = probabilityEnrichStrategy.build(trainingData);
{
AttributesMap inputAttributes = AttributesMap.newHashMap() ;
inputAttributes.put("k1", 1);
inputAttributes.put("k2", 1);
final AttributesMap outputAttributes = attributesEnricher.apply(inputAttributes);
Assert.assertEquals(outputAttributes.get("k1-PROB"), 0.5);
Assert.assertEquals(outputAttributes.get("k2-PROB"), 1.0);
}
{
AttributesMap inputAttributes = AttributesMap.newHashMap() ;
inputAttributes.put("k1", 2);
inputAttributes.put("k2", 2);
final AttributesMap outputAttributes = attributesEnricher.apply(inputAttributes);
Assert.assertEquals(outputAttributes.get("k1-PROB"), 0.5);
Assert.assertEquals(outputAttributes.get("k2-PROB"), 1.0/3.0);
}
}
}