package quickml.supervised.dataProcessing.instanceTranformer; import com.google.common.collect.Lists; import quickml.data.instances.ClassifierInstance; import quickml.data.instances.SparseClassifierInstanceFactory; import quickml.supervised.classifier.logisticRegression.SparseClassifierInstance; import java.io.Serializable; import java.util.HashMap; import java.util.List; import static quickml.supervised.classifier.logisticRegression.InstanceTransformerUtils.populateNameToIndexMap; /** * Created by alexanderhawk on 10/14/15. */ public class ClassifierInstance2SparseClassifierInstance<L extends Serializable, I extends ClassifierInstance> implements InstanceTransformer<I, SparseClassifierInstance> { private SparseClassifierInstanceFactory instanceFactory; private HashMap<String, Integer> nameToIndexMap; public ClassifierInstance2SparseClassifierInstance(List<I> trainingData) { this.nameToIndexMap = populateNameToIndexMap(trainingData, true); this.instanceFactory = new SparseClassifierInstanceFactory(nameToIndexMap); } public HashMap<String, Integer> getNameToIndexMap() { return nameToIndexMap; } public SparseClassifierInstanceFactory getInstanceFactory() { return instanceFactory; } public static <L extends Serializable, I extends ClassifierInstance> List<SparseClassifierInstance> transformAllInstances(List<I> instances) { ClassifierInstance2SparseClassifierInstance transformer = new ClassifierInstance2SparseClassifierInstance(instances); SparseClassifierInstanceFactory instanceFactory = transformer.getInstanceFactory(); List<SparseClassifierInstance> returnInstances = Lists.<SparseClassifierInstance>newArrayList(); for (I instance : instances) { returnInstances.add(instanceFactory.createInstance(instance.getAttributes(), instance.getLabel(), instance.getWeight())); } return returnInstances; } @Override public SparseClassifierInstance transformInstance(I instance) { return instanceFactory.createInstance(instance.getAttributes(), instance.getLabel(), instance.getWeight()); } }