package quickml.supervised.classifier.logisticRegression;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import quickml.data.instances.ClassifierInstance;
import quickml.data.instances.InstanceWithAttributesMap;

import java.io.Serializable;
import java.util.*;

/**
 * Created by alexanderhawk on 10/12/15.
 *
 * Static helpers that derive lookup structures from a list of training instances:
 * an attribute-name to index map, a label to numeric-class map, the set of distinct
 * {@code Double} labels, and the key used for a one-hot encoded attribute value.
 */
public class InstanceTransformerUtils {

    /** Name registered at index 0 of the name-to-index map when a bias term is requested. */
    public static final String BIAS_TERM = "biasTerm";

    /**
     * Assigns every distinct attribute name found in {@code trainingData} a consecutive
     * integer index, in order of first appearance.
     *
     * <p>Note: when {@code useBias} is true, an attribute that is itself named
     * {@link #BIAS_TERM} is treated as already registered and therefore shares index 0
     * with the bias term.
     *
     * @param trainingData instances whose attribute names are collected
     * @param useBias      when true, {@link #BIAS_TERM} is registered first at index 0 and
     *                     attribute indices start at 1; otherwise they start at 0
     * @return a new map from attribute name (and optionally the bias term) to its index
     */
    public static <T extends InstanceWithAttributesMap> HashMap<String, Integer> populateNameToIndexMap(
            List<T> trainingData, boolean useBias) {
        HashMap<String, Integer> nameToIndexMap = Maps.newHashMap();
        int nextIndex = 0;
        if (useBias) {
            nameToIndexMap.put(BIAS_TERM, nextIndex);
            nextIndex++;
        }
        for (T instance : trainingData) {
            for (String key : instance.getAttributes().keySet()) {
                // Only the first occurrence of a name consumes an index.
                if (!nameToIndexMap.containsKey(key)) {
                    nameToIndexMap.put(key, nextIndex);
                    nextIndex++;
                }
            }
        }
        return nameToIndexMap;
    }

    /**
     * Builds a map from each class label to the numeric value that represents it.
     *
     * <p>If every label equals {@code Double} 1.0 or 0.0 (which is also the case for an
     * empty list), the labels are kept as they are and the result always contains both
     * entries {@code 1.0 -> 1.0} and {@code 0.0 -> 0.0}, regardless of which of the two
     * actually occur. Otherwise the distinct labels are numbered 0.0, 1.0, 2.0, ... in
     * order of first appearance.
     *
     * @param trainingData instances whose labels are mapped
     * @return a new map from label to its numeric class representation
     */
    public static <T extends InstanceWithAttributesMap> Map<Serializable, Double> determineNumericClassLabels(
            List<T> trainingData) {
        Map<Serializable, Double> classifications = Maps.newHashMap();
        if (hasOneZeroLabels(trainingData)) {
            classifications.put(1.0, 1.0);
            classifications.put(0.0, 0.0);
            return classifications;
        }

        double nextClassValue = 0.0;
        for (T instance : trainingData) {
            Serializable label = instance.getLabel();
            if (!classifications.containsKey(label)) {
                classifications.put(label, nextClassValue);
                nextClassValue += 1.0;
            }
        }
        return classifications;
    }

    /**
     * Reports whether every label in {@code trainingData} equals {@code Double} 1.0 or 0.0.
     * The comparison uses {@code equals} against a {@code Double}, so labels of any other
     * type (for example an {@code Integer} 1) do not count. An empty list yields true.
     */
    private static <T extends InstanceWithAttributesMap> boolean hasOneZeroLabels(List<T> trainingData) {
        final Double one = Double.valueOf(1.0);
        final Double zero = Double.valueOf(0.0);
        for (T instance : trainingData) {
            Serializable label = instance.getLabel();
            if (!label.equals(one) && !label.equals(zero)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Collects the distinct labels of {@code trainingData}, all of which must be
     * {@code Double} instances.
     *
     * @param trainingData instances whose labels are collected
     * @return a new set holding each distinct label
     * @throws IllegalArgumentException if any label is not a {@code Double}
     */
    public static <T extends ClassifierInstance> Set<Double> getClassifications(List<T> trainingData) {
        Set<Double> classifications = Sets.newHashSet();
        for (T instance : trainingData) {
            Object label = instance.getLabel();
            if (!(label instanceof Double)) {
                // The check is on the label, not on the features, so say so and show the value.
                throw new IllegalArgumentException("must have numeric labels, but found label: " + label);
            }
            classifications.add((Double) label);
        }
        return classifications;
    }

    /**
     * Builds the name of a one-hot encoded attribute by joining the attribute name and
     * the string form of its value with {@code "--"}.
     *
     * @param attributeName  name of the original attribute
     * @param attributeValue value being encoded
     * @return {@code attributeName + "--" + attributeValue}
     */
    public static String oneHotEncode(String attributeName, Serializable attributeValue) {
        return attributeName + "--" + attributeValue;
    }
}