package quickml.supervised.classifier.logisticRegression;
import com.google.common.collect.Lists;
import quickml.data.instances.ClassifierInstance;
import quickml.data.instances.ClassifierInstanceFactory;
import quickml.supervised.Utils;
import quickml.supervised.dataProcessing.AttributeCharacteristics;
import quickml.supervised.dataProcessing.BasicTrainingDataSurveyor;
import quickml.supervised.dataProcessing.ElementaryDataTransformer;
import quickml.supervised.dataProcessing.instanceTranformer.*;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Created by alexanderhawk on 10/14/15.
*/
public class DatedAndMeanNormalizedLogisticRegressionDataTransformer extends StandardDataTransformer<MeanNormalizedAndDatedLogisticRegressionDTO> {
    /*
     * Provides transformData, which converts a list of classifier instances into the sparse
     * instances (plus supporting metadata) that can be processed by the LogisticRegressionBuilder.
     *
     * Stages, in order: optional label-to-digit conversion, one-hot encoding, optional product
     * feature appending, normalization, and conversion to SparseClassifierInstance.
     *
     * It assumes that all attributes with numeric values are numeric, and are not in need of
     * one-hot encoding. Product-feature appending as well as common co-occurrences should be
     * hyper-parameters within logistic regression.
     *
     * NOTE(review): productFeatureAppender, useProductFeatures, doLabelToDigitConversion,
     * minObservationsOfAttribute and numericClassLabels are not declared in this class; they are
     * presumably inherited from StandardDataTransformer - confirm.
     *
     * TODO: get the label-to-digit map and stick it in the DTO (and transform to logistic
     *       regression eventually).
     * TODO: make LogisticRegressionBuilder use this class and not be tightly coupled to mean
     *       normalization (e.g. allow log^2 values).
     * TODO: make the cross validator take a data transformer (specifically, the logistic
     *       regression PMB), and then do the data normalization and set the date-time extractor.
     *
     * Open design options: wrap logistic regression in a new logistic regression class that has a
     * logistic regression transformer, or change the type of logistic regression to the sparse
     * classifier instance (the latter is almost preferred), so that to use it one just passes in
     * a normal list of training instances.
     */

    // Transformers fitted by the most recent transformData call; null until then
    // (oneHotEncoder and normalizer are always set, labelToDigitConverter only when label
    // conversion is enabled).
    private OneHotEncoder<Serializable, ClassifierInstance, ClassifierInstance> oneHotEncoder;
    private BinaryAndNumericAttributeNormalizer<Serializable, ClassifierInstance, ClassifierInstance> normalizer;
    private LabelToDigitConverter<Serializable, ClassifierInstance, ClassifierInstance> labelToDigitConverter;
    private ClassifierInstance2SparseClassifierInstance<Serializable, ClassifierInstance> inputType2ReturnTypeTransformer;

    /** Creates a transformer; configure it through the setter methods before calling transformData. */
    public DatedAndMeanNormalizedLogisticRegressionDataTransformer() {}

    /**
     * Sets the appender used to add product attributes when product features are enabled.
     *
     * @param productFeatureAppender the appender to use
     * @return this transformer, for chaining
     */
    public DatedAndMeanNormalizedLogisticRegressionDataTransformer productFeatureAppender(ProductFeatureAppender<ClassifierInstance> productFeatureAppender) {
        this.productFeatureAppender = productFeatureAppender;
        return this;
    }

    /**
     * Reports whether a product feature appender has been set.
     *
     * Note that this checks the appender only; transformData decides whether to append product
     * features from the flag set via {@link #usingProductFeatures(boolean)}.
     *
     * @return true if a product feature appender is present
     */
    public boolean usingProductFeatures() {
        return productFeatureAppender != null;
    }

    /**
     * Enables or disables the label-to-digit conversion stage of transformData.
     *
     * @param doLabelToDigitConversion true to convert labels before one-hot encoding
     */
    public void doLabelToDigitConversion(boolean doLabelToDigitConversion) {
        this.doLabelToDigitConversion = doLabelToDigitConversion;
    }

    /**
     * Sets the minimum observation count handed to the one-hot encoder.
     *
     * @param minObservationsOfAttribute the threshold passed to the OneHotEncoder constructor
     * @return this transformer, for chaining
     */
    public DatedAndMeanNormalizedLogisticRegressionDataTransformer minObservationsOfAttribute(int minObservationsOfAttribute) {
        this.minObservationsOfAttribute = minObservationsOfAttribute;
        return this;
    }

    /**
     * Returns the numeric class labels recorded by the last transformData call that performed
     * label-to-digit conversion.
     *
     * @return the current numericClassLabels value
     */
    public Map<Serializable, Double> getNumericClassLabels() {
        return numericClassLabels;
    }

    /**
     * Enables or disables the product-feature stage of transformData.
     *
     * @param useProductFeatures true to append product attributes after one-hot encoding
     * @return this transformer, for chaining
     */
    public DatedAndMeanNormalizedLogisticRegressionDataTransformer usingProductFeatures(boolean useProductFeatures) {
        this.useProductFeatures = useProductFeatures;
        return this;
    }

    /**
     * Runs the full transformation pipeline over the training data and packages the result.
     *
     * TODO: the return type shouldn't be hard coded as a logistic regression DTO; it should be an
     * abstract type or a generic.
     *
     * @param trainingData the instances to transform
     * @return a DTO holding the sparse instances, the name-to-index map, the normalization
     *         statistics and the numeric class labels
     * @throws IllegalStateException if product features are enabled but no product feature
     *         appender has been set
     */
    @Override
    public MeanNormalizedAndDatedLogisticRegressionDTO transformData(List<ClassifierInstance> trainingData) {
        // Fail fast, before any expensive fitting, instead of hitting a bare NullPointerException
        // half way through the pipeline.
        if (useProductFeatures && productFeatureAppender == null) {
            throw new IllegalStateException(
                    "useProductFeatures is enabled but no productFeatureAppender has been set");
        }

        List<ClassifierInstance> firstStageData;
        if (doLabelToDigitConversion) {
            labelToDigitConverter = getLabelToDigitConverter(trainingData);
            firstStageData = applySingleTransformation(labelToDigitConverter, trainingData);
            numericClassLabels = labelToDigitConverter.getNumericClassLabels();
        } else {
            firstStageData = trainingData;
        }

        oneHotEncoder = getOneHotEncoder(firstStageData, minObservationsOfAttribute);
        List<ClassifierInstance> oneHotEncoded = oneHotEncoder.transformAll(firstStageData);

        List<ClassifierInstance> instancesToNormalize;
        if (useProductFeatures) {
            instancesToNormalize = productFeatureAppender.addProductAttributes(oneHotEncoded);
        } else {
            instancesToNormalize = oneHotEncoded;
        }

        // The normalizer is fitted on the (possibly product-augmented) one-hot encoded data.
        normalizer = getNormalizer(instancesToNormalize);
        List<ClassifierInstance> normalized = applySingleTransformation(normalizer, instancesToNormalize);

        // Final stage: no further input-to-input transformations, only the conversion to the
        // sparse return type.
        inputType2ReturnTypeTransformer = getInputType2ReturnTypeTransformer(normalized);
        List<InstanceTransformer<ClassifierInstance, ClassifierInstance>> noPreTransformations = Lists.newArrayList();
        ElementaryDataTransformer<ClassifierInstance, SparseClassifierInstance> inputType2OutputDataTransformer =
                new ElementaryDataTransformer<ClassifierInstance, SparseClassifierInstance>(
                        noPreTransformations, inputType2ReturnTypeTransformer);
        List<SparseClassifierInstance> sparseClassifierInstances = inputType2OutputDataTransformer.transformInstances(normalized);

        return new MeanNormalizedAndDatedLogisticRegressionDTO(sparseClassifierInstances, getNameToIndexMap(), getMeanStdMaxMins(), numericClassLabels);
    }

    /** Runs a single input-to-input transformation over the instances via an ElementaryDataTransformer. */
    private static List<ClassifierInstance> applySingleTransformation(
            InstanceTransformer<ClassifierInstance, ClassifierInstance> transformation,
            List<ClassifierInstance> instances) {
        List<InstanceTransformer<ClassifierInstance, ClassifierInstance>> transformations = Lists.newArrayList();
        transformations.add(transformation);
        ElementaryDataTransformer<ClassifierInstance, ClassifierInstance> dataTransformer =
                new ElementaryDataTransformer<ClassifierInstance, ClassifierInstance>(transformations, null);
        return dataTransformer.transformInstances(instances);
    }

    /**
     * Returns the name-to-index map of the sparse-instance converter built by transformData.
     *
     * @return the map reported by the converter
     * @throws IllegalStateException if transformData has not been called yet
     */
    public HashMap<String, Integer> getNameToIndexMap() {
        if (inputType2ReturnTypeTransformer == null) {
            throw new IllegalStateException("transformData must be called before getNameToIndexMap");
        }
        return inputType2ReturnTypeTransformer.getNameToIndexMap();
    }

    /**
     * Returns the per-attribute statistics held by the normalizer built by transformData.
     *
     * @return the map reported by the normalizer
     * @throws IllegalStateException if transformData has not been called yet
     */
    public Map<String, Utils.MeanStdMaxMin> getMeanStdMaxMins() {
        if (normalizer == null) {
            throw new IllegalStateException("transformData must be called before getMeanStdMaxMins");
        }
        return normalizer.getMeanStdMaxMins();
    }

    /** Builds the converter from ClassifierInstance to the sparse return type from the given data. */
    static ClassifierInstance2SparseClassifierInstance<Serializable, ClassifierInstance> getInputType2ReturnTypeTransformer(List<ClassifierInstance> trainingData) {
        return new ClassifierInstance2SparseClassifierInstance<>(trainingData);
    }

    /** Builds a label-to-digit converter from the given data. */
    static LabelToDigitConverter<Serializable, ClassifierInstance, ClassifierInstance> getLabelToDigitConverter(List<ClassifierInstance> trainingData) {
        return new LabelToDigitConverter<>(new ClassifierInstanceFactory(), trainingData);
    }

    /** Builds a normalizer from the given data with no attribute exempted from normalization. */
    static BinaryAndNumericAttributeNormalizer<Serializable, ClassifierInstance, ClassifierInstance> getNormalizer(List<ClassifierInstance> trainingData) {
        return new BinaryAndNumericAttributeNormalizer<>(trainingData, new ClassifierInstanceFactory(), new BinaryAndNumericAttributeNormalizer.NoNormalizationCondition() {
            @Override
            public boolean noNormalization(String key) {
                // Never exempt an attribute. An earlier variant, left here as a hint, exempted
                // keys containing "timeOfArrival-".
                return false;
            }
        });
    }

    /**
     * Surveys the attributes of the given data and builds a one-hot encoder from the result.
     *
     * @param trainingData the instances to survey
     * @param minObservationsOfAttribute the threshold passed to the OneHotEncoder constructor
     */
    static OneHotEncoder<Serializable, ClassifierInstance, ClassifierInstance> getOneHotEncoder(List<ClassifierInstance> trainingData, int minObservationsOfAttribute) {
        BasicTrainingDataSurveyor<ClassifierInstance> surveyor = new BasicTrainingDataSurveyor<ClassifierInstance>(false);
        Map<String, AttributeCharacteristics> attributeCharacteristics = surveyor.getMapOfAttributesToAttributeCharacteristics(trainingData);
        return new OneHotEncoder<>(attributeCharacteristics, new ClassifierInstanceFactory(), minObservationsOfAttribute);
    }
}