package com.github.lwhite1.tablesaw.api.ml.classification;
import com.github.lwhite1.tablesaw.api.BooleanColumn;
import com.github.lwhite1.tablesaw.api.CategoryColumn;
import com.github.lwhite1.tablesaw.api.IntColumn;
import com.github.lwhite1.tablesaw.api.NumericColumn;
import com.github.lwhite1.tablesaw.api.ShortColumn;
import com.github.lwhite1.tablesaw.util.DoubleArrays;
import com.google.common.base.Preconditions;
import java.util.SortedSet;
import java.util.TreeSet;
/**
 * A logistic regression classifier backed by Smile's
 * {@code smile.classification.LogisticRegression}.
 *
 * <p>Instances are created through the static {@code learn} factories, which accept the class
 * labels as a Tablesaw column and the features as one or more {@link NumericColumn}s. Every
 * prediction is delegated to the wrapped Smile model.
 */
public class LogisticRegression extends AbstractClassifier {

  /** The trained Smile model that all predictions are delegated to. */
  private final smile.classification.LogisticRegression classifierModel;

  /**
   * Trains a model using Smile's default settings.
   *
   * @param labels the class label of each training row, read via {@code toIntArray()}
   * @param predictors the feature columns, converted with {@link DoubleArrays#to2dArray}
   * @return a classifier wrapping the trained Smile model
   */
  public static LogisticRegression learn(ShortColumn labels, NumericColumn... predictors) {
    smile.classification.LogisticRegression classifierModel =
        new smile.classification.LogisticRegression(
            DoubleArrays.to2dArray(predictors), labels.toIntArray());
    return new LogisticRegression(classifierModel);
  }

  /**
   * Trains a model using Smile's default settings.
   *
   * @param labels the class label of each training row, read via {@code data().toIntArray()}
   * @param predictors the feature columns, converted with {@link DoubleArrays#to2dArray}
   * @return a classifier wrapping the trained Smile model
   */
  public static LogisticRegression learn(IntColumn labels, NumericColumn... predictors) {
    smile.classification.LogisticRegression classifierModel =
        new smile.classification.LogisticRegression(
            DoubleArrays.to2dArray(predictors), labels.data().toIntArray());
    return new LogisticRegression(classifierModel);
  }

  /**
   * Trains a model using Smile's default settings.
   *
   * @param labels the class label of each training row, read via {@code toIntArray()}
   * @param predictors the feature columns, converted with {@link DoubleArrays#to2dArray}
   * @return a classifier wrapping the trained Smile model
   */
  public static LogisticRegression learn(BooleanColumn labels, NumericColumn... predictors) {
    smile.classification.LogisticRegression classifierModel =
        new smile.classification.LogisticRegression(
            DoubleArrays.to2dArray(predictors), labels.toIntArray());
    return new LogisticRegression(classifierModel);
  }

  /**
   * Trains a model using Smile's default settings.
   *
   * @param labels the class label of each training row, read via {@code data().toIntArray()}
   *     (presumably the column's internal integer codes -- confirm against CategoryColumn)
   * @param predictors the feature columns, converted with {@link DoubleArrays#to2dArray}
   * @return a classifier wrapping the trained Smile model
   */
  public static LogisticRegression learn(CategoryColumn labels, NumericColumn... predictors) {
    smile.classification.LogisticRegression classifierModel =
        new smile.classification.LogisticRegression(
            DoubleArrays.to2dArray(predictors), labels.data().toIntArray());
    return new LogisticRegression(classifierModel);
  }

  /**
   * Trains a model with an explicit {@code lambda}.
   *
   * @param labels the class label of each training row, read via {@code toIntArray()}
   * @param lambda forwarded unchanged to the Smile constructor (its regularization factor per
   *     Smile's API documentation)
   * @param predictors the feature columns, converted with {@link DoubleArrays#to2dArray}
   * @return a classifier wrapping the trained Smile model
   */
  public static LogisticRegression learn(
      ShortColumn labels, double lambda, NumericColumn... predictors) {
    smile.classification.LogisticRegression classifierModel =
        new smile.classification.LogisticRegression(
            DoubleArrays.to2dArray(predictors), labels.toIntArray(), lambda);
    return new LogisticRegression(classifierModel);
  }

  /**
   * Trains a model with an explicit {@code lambda}.
   *
   * @param labels the class label of each training row, read via {@code data().toIntArray()}
   * @param lambda forwarded unchanged to the Smile constructor (its regularization factor per
   *     Smile's API documentation)
   * @param predictors the feature columns, converted with {@link DoubleArrays#to2dArray}
   * @return a classifier wrapping the trained Smile model
   */
  public static LogisticRegression learn(
      IntColumn labels, double lambda, NumericColumn... predictors) {
    smile.classification.LogisticRegression classifierModel =
        new smile.classification.LogisticRegression(
            DoubleArrays.to2dArray(predictors), labels.data().toIntArray(), lambda);
    return new LogisticRegression(classifierModel);
  }

  /**
   * Trains a model with an explicit {@code lambda}.
   *
   * @param labels the class label of each training row, read via {@code toIntArray()}
   * @param lambda forwarded unchanged to the Smile constructor (its regularization factor per
   *     Smile's API documentation)
   * @param predictors the feature columns, converted with {@link DoubleArrays#to2dArray}
   * @return a classifier wrapping the trained Smile model
   */
  public static LogisticRegression learn(
      BooleanColumn labels, double lambda, NumericColumn... predictors) {
    smile.classification.LogisticRegression classifierModel =
        new smile.classification.LogisticRegression(
            DoubleArrays.to2dArray(predictors), labels.toIntArray(), lambda);
    return new LogisticRegression(classifierModel);
  }

  /**
   * Trains a model with an explicit {@code lambda}.
   *
   * @param labels the class label of each training row, read via {@code data().toIntArray()}
   *     (presumably the column's internal integer codes -- confirm against CategoryColumn)
   * @param lambda forwarded unchanged to the Smile constructor (its regularization factor per
   *     Smile's API documentation)
   * @param predictors the feature columns, converted with {@link DoubleArrays#to2dArray}
   * @return a classifier wrapping the trained Smile model
   */
  public static LogisticRegression learn(
      CategoryColumn labels, double lambda, NumericColumn... predictors) {
    smile.classification.LogisticRegression classifierModel =
        new smile.classification.LogisticRegression(
            DoubleArrays.to2dArray(predictors), labels.data().toIntArray(), lambda);
    return new LogisticRegression(classifierModel);
  }

  /**
   * Trains a model with explicit {@code lambda}, {@code tolerance} and {@code maxIters}.
   *
   * @param labels the class label of each training row, read via {@code toIntArray()}
   * @param lambda forwarded unchanged to the Smile constructor (its regularization factor per
   *     Smile's API documentation)
   * @param tolerance forwarded unchanged to the Smile constructor (its stopping tolerance per
   *     Smile's API documentation)
   * @param maxIters forwarded unchanged to the Smile constructor (its iteration limit per Smile's
   *     API documentation)
   * @param predictors the feature columns, converted with {@link DoubleArrays#to2dArray}
   * @return a classifier wrapping the trained Smile model
   */
  public static LogisticRegression learn(
      ShortColumn labels,
      double lambda,
      double tolerance,
      int maxIters,
      NumericColumn... predictors) {
    smile.classification.LogisticRegression classifierModel =
        new smile.classification.LogisticRegression(
            DoubleArrays.to2dArray(predictors),
            labels.toIntArray(),
            lambda,
            tolerance,
            maxIters);
    return new LogisticRegression(classifierModel);
  }

  /**
   * Trains a model with explicit {@code lambda}, {@code tolerance} and {@code maxIters}.
   *
   * @param labels the class label of each training row, read via {@code data().toIntArray()}
   * @param lambda forwarded unchanged to the Smile constructor (its regularization factor per
   *     Smile's API documentation)
   * @param tolerance forwarded unchanged to the Smile constructor (its stopping tolerance per
   *     Smile's API documentation)
   * @param maxIters forwarded unchanged to the Smile constructor (its iteration limit per Smile's
   *     API documentation)
   * @param predictors the feature columns, converted with {@link DoubleArrays#to2dArray}
   * @return a classifier wrapping the trained Smile model
   */
  public static LogisticRegression learn(
      IntColumn labels,
      double lambda,
      double tolerance,
      int maxIters,
      NumericColumn... predictors) {
    smile.classification.LogisticRegression classifierModel =
        new smile.classification.LogisticRegression(
            DoubleArrays.to2dArray(predictors),
            labels.data().toIntArray(),
            lambda,
            tolerance,
            maxIters);
    return new LogisticRegression(classifierModel);
  }

  /**
   * Trains a model with explicit {@code lambda}, {@code tolerance} and {@code maxIters}.
   *
   * @param labels the class label of each training row, read via {@code toIntArray()}
   * @param lambda forwarded unchanged to the Smile constructor (its regularization factor per
   *     Smile's API documentation)
   * @param tolerance forwarded unchanged to the Smile constructor (its stopping tolerance per
   *     Smile's API documentation)
   * @param maxIters forwarded unchanged to the Smile constructor (its iteration limit per Smile's
   *     API documentation)
   * @param predictors the feature columns, converted with {@link DoubleArrays#to2dArray}
   * @return a classifier wrapping the trained Smile model
   */
  public static LogisticRegression learn(
      BooleanColumn labels,
      double lambda,
      double tolerance,
      int maxIters,
      NumericColumn... predictors) {
    smile.classification.LogisticRegression classifierModel =
        new smile.classification.LogisticRegression(
            DoubleArrays.to2dArray(predictors),
            labels.toIntArray(),
            lambda,
            tolerance,
            maxIters);
    return new LogisticRegression(classifierModel);
  }

  /**
   * Trains a model with explicit {@code lambda}, {@code tolerance} and {@code maxIters}.
   *
   * @param labels the class label of each training row, read via {@code data().toIntArray()}
   *     (presumably the column's internal integer codes -- confirm against CategoryColumn)
   * @param lambda forwarded unchanged to the Smile constructor (its regularization factor per
   *     Smile's API documentation)
   * @param tolerance forwarded unchanged to the Smile constructor (its stopping tolerance per
   *     Smile's API documentation)
   * @param maxIters forwarded unchanged to the Smile constructor (its iteration limit per Smile's
   *     API documentation)
   * @param predictors the feature columns, converted with {@link DoubleArrays#to2dArray}
   * @return a classifier wrapping the trained Smile model
   */
  public static LogisticRegression learn(
      CategoryColumn labels,
      double lambda,
      double tolerance,
      int maxIters,
      NumericColumn... predictors) {
    smile.classification.LogisticRegression classifierModel =
        new smile.classification.LogisticRegression(
            DoubleArrays.to2dArray(predictors),
            labels.data().toIntArray(),
            lambda,
            tolerance,
            maxIters);
    return new LogisticRegression(classifierModel);
  }

  /** Wraps an already-trained Smile model; instances are obtained via the {@code learn} factories. */
  private LogisticRegression(smile.classification.LogisticRegression classifierModel) {
    this.classifierModel = classifierModel;
  }

  /**
   * Predicts the class label for a single feature vector.
   *
   * @param data the feature values of one observation
   * @return the label returned by the wrapped Smile model
   */
  public int predict(double[] data) {
    return classifierModel.predict(data);
  }

  /**
   * Builds a confusion matrix by handing the given labels and predictors to
   * {@code populateMatrix}.
   *
   * @param labels the actual class labels
   * @param predictors the feature columns; at least one is required
   * @return the populated confusion matrix
   * @throws IllegalArgumentException if no predictor column is supplied
   */
  public ConfusionMatrix predictMatrix(ShortColumn labels, NumericColumn... predictors) {
    Preconditions.checkArgument(predictors.length > 0);
    SortedSet<Object> labelSet = new TreeSet<>(labels.asSet());
    ConfusionMatrix confusion = new StandardConfusionMatrix(labelSet);
    populateMatrix(labels.toIntArray(), confusion, predictors);
    return confusion;
  }

  /**
   * Builds a confusion matrix by handing the given labels and predictors to
   * {@code populateMatrix}.
   *
   * @param labels the actual class labels
   * @param predictors the feature columns; at least one is required
   * @return the populated confusion matrix
   * @throws IllegalArgumentException if no predictor column is supplied
   */
  public ConfusionMatrix predictMatrix(IntColumn labels, NumericColumn... predictors) {
    Preconditions.checkArgument(predictors.length > 0);
    SortedSet<Object> labelSet = new TreeSet<>(labels.asSet());
    ConfusionMatrix confusion = new StandardConfusionMatrix(labelSet);
    populateMatrix(labels.data().toIntArray(), confusion, predictors);
    return confusion;
  }

  /**
   * Builds a confusion matrix by handing the given labels and predictors to
   * {@code populateMatrix}. Unlike the other overloads, the result is a
   * {@code CategoryConfusionMatrix} constructed from the label column itself.
   *
   * @param labels the actual class labels
   * @param predictors the feature columns; at least one is required
   * @return the populated confusion matrix
   * @throws IllegalArgumentException if no predictor column is supplied
   */
  public ConfusionMatrix predictMatrix(CategoryColumn labels, NumericColumn... predictors) {
    Preconditions.checkArgument(predictors.length > 0);
    SortedSet<String> labelSet = new TreeSet<>(labels.asSet());
    ConfusionMatrix confusion = new CategoryConfusionMatrix(labels, labelSet);
    populateMatrix(labels.data().toIntArray(), confusion, predictors);
    return confusion;
  }

  /**
   * Builds a confusion matrix by handing the given labels and predictors to
   * {@code populateMatrix}.
   *
   * @param labels the actual class labels
   * @param predictors the feature columns; at least one is required
   * @return the populated confusion matrix
   * @throws IllegalArgumentException if no predictor column is supplied
   */
  public ConfusionMatrix predictMatrix(BooleanColumn labels, NumericColumn... predictors) {
    Preconditions.checkArgument(predictors.length > 0);
    SortedSet<Object> labelSet = new TreeSet<>(labels.asSet());
    ConfusionMatrix confusion = new StandardConfusionMatrix(labelSet);
    populateMatrix(labels.toIntArray(), confusion, predictors);
    return confusion;
  }

  /**
   * Predicts a class label for every row of the given predictor columns.
   *
   * <p>The number of rows is taken from the first predictor column.
   *
   * @param predictors the feature columns; at least one is required
   * @return one predicted label per row, in row order
   * @throws IllegalArgumentException if no predictor column is supplied
   */
  public int[] predict(NumericColumn... predictors) {
    Preconditions.checkArgument(predictors.length > 0);
    int rowCount = predictors[0].size();
    int[] predictedLabels = new int[rowCount];
    for (int row = 0; row < rowCount; row++) {
      predictedLabels[row] = classifierModel.predict(featureVectorForRow(row, predictors));
    }
    return predictedLabels;
  }

  /**
   * Predicts the class label for a single feature vector using the wrapped Smile model.
   *
   * @param data the feature values of one observation
   * @return the label returned by the wrapped Smile model
   */
  @Override
  int predictFromModel(double[] data) {
    return classifierModel.predict(data);
  }

  /**
   * Returns the value of {@code loglikelihood()} on the wrapped Smile model.
   *
   * @return the log-likelihood reported by Smile
   */
  public double logLikelihood() {
    return classifierModel.loglikelihood();
  }

  /**
   * Predicts the class label for a single feature vector, passing {@code posteriori} through to
   * Smile's two-argument {@code predict}.
   *
   * @param x the feature values of one observation
   * @param posteriori array handed to Smile, which per its API documentation fills it with the
   *     posterior probabilities
   * @return the label returned by Smile, widened to {@code double}
   */
  public double predictFromModel(double[] x, double[] posteriori) {
    return classifierModel.predict(x, posteriori);
  }

  /**
   * Predicts the class label for one row of the given predictor columns, passing
   * {@code posteriori} through to Smile's two-argument {@code predict}.
   *
   * @param row the index of the row to classify
   * @param posteriori array handed to Smile, which per its API documentation fills it with the
   *     posterior probabilities
   * @param predictors the feature columns; at least one is required
   * @return the label returned by Smile, widened to {@code double}
   * @throws IllegalArgumentException if no predictor column is supplied
   */
  public double predictFromModel(int row, double[] posteriori, NumericColumn... predictors) {
    Preconditions.checkArgument(predictors.length > 0);
    return classifierModel.predict(featureVectorForRow(row, predictors), posteriori);
  }

  /**
   * Collects the values found at {@code row} in each predictor column into one feature vector.
   *
   * <p>The vector is indexed by column position: element {@code i} comes from
   * {@code predictors[i]}. Indexing it by row instead would write every value into one slot and
   * overflow the array once {@code row >= predictors.length}.
   */
  private static double[] featureVectorForRow(int row, NumericColumn... predictors) {
    double[] features = new double[predictors.length];
    for (int col = 0; col < predictors.length; col++) {
      features[col] = predictors[col].getFloat(row);
    }
    return features;
  }
}