package quickml.supervised.classifier.splitOnAttribute; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import quickml.data.AttributesMap; import quickml.data.PredictionMap; import quickml.supervised.classifier.AbstractClassifier; import quickml.supervised.classifier.Classifier; import java.io.IOException; import java.io.Serializable; import java.util.Map; import java.util.Set; /** * Created by ian on 5/29/14. */ public class SplitOnAttributeClassifier extends AbstractClassifier { private static final long serialVersionUID = 2642074639257374588L; private final String attributeKey; private final Map<? extends Serializable, Integer> splitValToGroupId; private final Map<Integer, Classifier> splitModels; private final Integer defaultGroup; private static final Logger logger = LoggerFactory.getLogger(SplitOnAttributeClassifier.class); public SplitOnAttributeClassifier(String attributeKey, Map<? extends Serializable, Integer> splitValToGroupId, Integer defaultGroup, final Map<Integer, Classifier> splitModels) { logger.info("creating split classifier"); this.attributeKey = attributeKey; this.splitModels = splitModels; this.splitValToGroupId = splitValToGroupId; this.defaultGroup = defaultGroup; } public Integer getDefaultGroup() { return defaultGroup; } public Map<? extends Serializable, Integer> getSplitValToGroupId() { return splitValToGroupId; } @Override public double getProbability(final AttributesMap attributes, final Serializable classification) { return getModelForAttributes(attributes).getProbability(attributes, classification); } @Override public double getProbabilityWithoutAttributes(final AttributesMap attributes, final Serializable classification, Set<String> attributesToIgnore) { return getModelForAttributes(attributes).getProbabilityWithoutAttributes(attributes, classification, attributesToIgnore); } @Override public PredictionMap predict(final AttributesMap attributes) { return getModelForAttributes(attributes).predict(attributes); } @Override public PredictionMap predictWithoutAttributes(final AttributesMap attributes, Set<String> attributesToIgnore) { return getModelForAttributes(attributes).predictWithoutAttributes(attributes, attributesToIgnore); } @Override public Serializable getClassificationByMaxProb(final AttributesMap attributes) { return getModelForAttributes(attributes).getClassificationByMaxProb(attributes); } public Map<Integer, Classifier> getSplitModels() { return splitModels; } private Classifier getModelForAttributes(AttributesMap attributes) { Serializable value = attributes.get(attributeKey); if (value == null) { throw new NullPointerException("not getting splitVar value"); } Integer groupId = splitValToGroupId.get(value); if (groupId == null) { groupId = defaultGroup; logger.error("not getting a groupId"); } return splitModels.get(groupId); } }