package quickml.supervised.classifier.logisticRegression;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.javatuples.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.MathUtils;
import quickml.data.AttributesMap;
import quickml.data.PredictionMap;
import quickml.supervised.classifier.AbstractClassifier;
import quickml.supervised.regressionModel.IsotonicRegression.PoolAdjacentViolatorsModel;
import java.io.Serializable;
import java.util.*;
/**
 * Logistic-regression classifier backed by a dense weight vector and a map from attribute
 * name to weight index.
 *
 * <p>{@code weights[0]} is always added to the linear score, so it acts as the intercept.
 * The classification whose numeric value equals {@code 1.0} is scored with
 * {@code sigmoid(w . x)}; every other classification is scored with the complement. An
 * optional {@link PoolAdjacentViolatorsModel} can be attached to post-process the raw score.
 *
 * <p>Created by alexanderhawk on 10/9/15.
 */
public class LogisticRegression extends AbstractClassifier {
    // NOTE: field declarations are intentionally left exactly as they were; changing their
    // modifiers would alter the default serialVersionUID if instances are Java-serialized.
    double[] weights;
    private final HashMap<String, Integer> nameToIndexMap;
    private static final Logger logger = LoggerFactory.getLogger(LogisticRegression.class);
    // Despite its name, the values of this map are the numeric class labels used internally.
    private Map<Serializable, Double> classificationToClassNameMap;
    // Optional calibrator; null means raw sigmoid outputs are returned.
    private PoolAdjacentViolatorsModel poolAdjacentViolatorsModel;
    private Set<Double> classifications =Sets.newHashSet();

    /**
     * Creates a calibrated view of an existing model. Weights, index map and class labels are
     * shared by reference with {@code uncalibrated}, not copied.
     *
     * @param uncalibrated               model supplying weights, attribute indices and class labels
     * @param poolAdjacentViolatorsModel calibrator applied to the raw probability in
     *                                   {@link #getProbability(AttributesMap, Serializable)}
     */
    public LogisticRegression(LogisticRegression uncalibrated, PoolAdjacentViolatorsModel poolAdjacentViolatorsModel) {
        this.poolAdjacentViolatorsModel = poolAdjacentViolatorsModel;
        this.weights = uncalibrated.weights;
        this.nameToIndexMap = uncalibrated.nameToIndexMap;
        this.classificationToClassNameMap = uncalibrated.classificationToClassNameMap;
        this.classifications = uncalibrated.classifications;
    }

    /**
     * Creates an uncalibrated model whose set of class labels is taken from the values of
     * {@code classificationToClassNameMap}.
     *
     * @param weights                      weight vector; index 0 is used as the intercept
     * @param nameToIndexMap               attribute name to index into {@code weights}
     * @param classificationToClassNameMap map whose values are the numeric class labels
     */
    public LogisticRegression(double[] weights, final HashMap<String, Integer> nameToIndexMap,
                              Map<Serializable, Double> classificationToClassNameMap) {
        this.weights = weights;
        this.nameToIndexMap = nameToIndexMap;
        this.classificationToClassNameMap = classificationToClassNameMap;
        this.classifications.addAll(classificationToClassNameMap.values());
    }

    /**
     * Creates an uncalibrated model from an explicit set of numeric class labels. The
     * classification map is left unset ({@code null}) by this constructor.
     *
     * @param weights         weight vector; index 0 is used as the intercept
     * @param nameToIndexMap  attribute name to index into {@code weights}
     * @param classifications numeric class labels; stored by reference
     */
    public LogisticRegression(double[] weights, final HashMap<String, Integer> nameToIndexMap,
                              Set<Double> classifications) {
        this.weights = weights;
        this.nameToIndexMap = nameToIndexMap;
        this.classifications= classifications;
    }

    /**
     * Returns the model's probability for {@code classification} given {@code attributes}.
     *
     * <p>Attributes that have no entry in the index map carry no weight and are skipped
     * (previously they caused a {@link NullPointerException} during unboxing). Attribute
     * values and the classification may be any {@link Number}, not only {@link Double}.
     *
     * @param attributes     attribute name to numeric value
     * @param classification numeric class label; a value equal to {@code 1.0} selects the
     *                       sigmoid output, anything else selects its complement
     * @return the raw probability, or the calibrator's output for it when a calibrator is set
     * @throws ClassCastException if an indexed attribute value or the classification is not a
     *                            {@link Number}
     */
    @Override
    public double getProbability(final AttributesMap attributes, final Serializable classification) {
        double dotProduct = weights[0];
        for (String attribute : attributes.keySet()) {
            Integer index = nameToIndexMap.get(attribute);
            if (index == null) {
                // Attribute unknown to the model: it has no weight, so it contributes nothing.
                continue;
            }
            dotProduct += weights[index] * ((Number) attributes.get(attribute)).doubleValue();
        }

        double positiveProbability = MathUtils.sigmoid(dotProduct);
        boolean isPositiveClass = ((Number) classification).doubleValue() == 1.0;
        double uncalibrated = isPositiveClass ? positiveProbability : 1.0 - positiveProbability;

        if (poolAdjacentViolatorsModel == null) {
            return uncalibrated;
        }
        // NOTE(review): the calibrator is applied to whichever class probability was requested,
        // so calibrated values for the two classes are not guaranteed to sum to 1 - confirm
        // this is intended.
        return poolAdjacentViolatorsModel.predictIfInterpolation(uncalibrated);
    }

    /**
     * Builds a prediction map with one entry per known class label, each scored by
     * {@link #getProbability(AttributesMap, Serializable)}.
     *
     * @param attributes attribute name to numeric value
     * @return map from class label to probability
     */
    @Override
    public PredictionMap predict(final AttributesMap attributes) {
        PredictionMap predictionMap = new PredictionMap(new HashMap<Serializable, Double>());
        for (Serializable classification : classifications) {
            predictionMap.put(classification, getProbability(attributes, classification));
        }
        return predictionMap;
    }

    /**
     * Not supported by this classifier.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public PredictionMap predictWithoutAttributes(final AttributesMap attributes, final Set<String> attributesToIgnore) {
        throw new UnsupportedOperationException("not implemented");
    }

    /**
     * Returns the leading fraction of the (weight, attribute name) pairs, ordered by ascending
     * signed weight.
     *
     * <p>The fraction is clamped to {@code [0, 1]}. The previous implementation used
     * {@code Math.max(1.0, fraction)}, which ignored any fraction below 1 and requested a
     * sub-list longer than the list for any fraction above 1.
     *
     * <p>NOTE(review): ordering is by signed weight, lowest first, so the head of the list
     * holds the most negative weights; confirm whether ordering by magnitude was intended.
     *
     * @param fractionOfList fraction of all attributes to return; values outside
     *                       {@code [0, 1]} are clamped
     * @return the first {@code (int) (fraction * total)} pairs of the sorted list
     */
    public List<Pair<Double, String>> getTopMostPredictiveAttributes(double fractionOfList){
        List<Pair<Double, String>> topAttributes = Lists.newArrayList();
        for (Map.Entry<String, Integer> entry: nameToIndexMap.entrySet()) {
            topAttributes.add(new Pair<Double, String>(weights[entry.getValue()],entry.getKey()));
        }
        Collections.sort(topAttributes, new Comparator<Pair<Double, String>>() {
            @Override
            public int compare(Pair<Double, String> o1, Pair<Double, String> o2) {
                return Double.compare(o1.getValue0(), o2.getValue0());
            }
        });
        // Clamp so the end index can never fall outside [0, size].
        double clampedFraction = Math.min(1.0, Math.max(0.0, fractionOfList));
        int attributesToReturn = (int) (clampedFraction * topAttributes.size());
        return topAttributes.subList(0, attributesToReturn);
    }
}