package com.github.lwhite1.tablesaw.api.ml.classification; import com.github.lwhite1.tablesaw.api.BooleanColumn; import com.github.lwhite1.tablesaw.api.CategoryColumn; import com.github.lwhite1.tablesaw.api.IntColumn; import com.github.lwhite1.tablesaw.api.NumericColumn; import com.github.lwhite1.tablesaw.api.ShortColumn; import com.github.lwhite1.tablesaw.util.DoubleArrays; import com.google.common.base.Preconditions; import smile.classification.LDA; import java.util.SortedSet; import java.util.TreeSet; /** * */ public class Lda extends AbstractClassifier { private final LDA classifierModel; public static Lda learn(ShortColumn labels, NumericColumn... predictors) { LDA classifierModel = new LDA(DoubleArrays.to2dArray(predictors), labels.toIntArray()); return new Lda(classifierModel); } public static Lda learn(IntColumn labels, NumericColumn... predictors) { LDA classifierModel = new LDA(DoubleArrays.to2dArray(predictors), labels.data().toIntArray()); return new Lda(classifierModel); } public static Lda learn(BooleanColumn labels, NumericColumn... predictors) { LDA classifierModel = new LDA(DoubleArrays.to2dArray(predictors), labels.toIntArray()); return new Lda(classifierModel); } public static Lda learn(CategoryColumn labels, NumericColumn... predictors) { LDA classifierModel = new LDA(DoubleArrays.to2dArray(predictors), labels.data().toIntArray()); return new Lda(classifierModel); } public static Lda learn(ShortColumn labels, double[] priors, NumericColumn... predictors) { LDA classifierModel = new LDA(DoubleArrays.to2dArray(predictors), labels.toIntArray(), priors); return new Lda(classifierModel); } public static Lda learn(IntColumn labels, double[] priors, NumericColumn... predictors) { LDA classifierModel = new LDA(DoubleArrays.to2dArray(predictors), labels.data().toIntArray(), priors); return new Lda(classifierModel); } public static Lda learn(BooleanColumn labels, double[] priors, NumericColumn... predictors) { LDA classifierModel = new LDA(DoubleArrays.to2dArray(predictors), labels.toIntArray(), priors); return new Lda(classifierModel); } public static Lda learn(CategoryColumn labels, double[] priors, NumericColumn... predictors) { LDA classifierModel = new LDA(DoubleArrays.to2dArray(predictors), labels.data().toIntArray(), priors); return new Lda(classifierModel); } public static Lda learn(ShortColumn labels, double[] priors, double tolerance, NumericColumn... predictors) { LDA classifierModel = new LDA(DoubleArrays.to2dArray(predictors), labels.toIntArray(), priors, tolerance); return new Lda(classifierModel); } public static Lda learn(IntColumn labels, double[] priors, double tolerance, NumericColumn... predictors) { LDA classifierModel = new LDA(DoubleArrays.to2dArray(predictors), labels.data().toIntArray(), priors, tolerance); return new Lda(classifierModel); } public static Lda learn(BooleanColumn labels, double[] priors, double tolerance, NumericColumn... predictors) { LDA classifierModel = new LDA(DoubleArrays.to2dArray(predictors), labels.toIntArray(), priors, tolerance); return new Lda(classifierModel); } public static Lda learn(CategoryColumn labels, double[] priors, double tolerance, NumericColumn... predictors) { LDA classifierModel = new LDA(DoubleArrays.to2dArray(predictors), labels.data().toIntArray(), priors, tolerance); return new Lda(classifierModel); } private Lda(LDA classifierModel) { this.classifierModel = classifierModel; } public int predict(double[] data) { return classifierModel.predict(data); } public ConfusionMatrix predictMatrix(ShortColumn labels, NumericColumn... predictors) { Preconditions.checkArgument(predictors.length > 0); SortedSet<Object> labelSet = new TreeSet<>(labels.asSet()); ConfusionMatrix confusion = new StandardConfusionMatrix(labelSet); populateMatrix(labels.toIntArray(), confusion, predictors); return confusion; } public ConfusionMatrix predictMatrix(IntColumn labels, NumericColumn... predictors) { Preconditions.checkArgument(predictors.length > 0); SortedSet<Object> labelSet = new TreeSet<>(labels.asSet()); ConfusionMatrix confusion = new StandardConfusionMatrix(labelSet); populateMatrix(labels.data().toIntArray(), confusion, predictors); return confusion; } public ConfusionMatrix predictMatrix(BooleanColumn labels, NumericColumn... predictors) { Preconditions.checkArgument(predictors.length > 0); SortedSet<Object> labelSet = new TreeSet<>(labels.asSet()); ConfusionMatrix confusion = new StandardConfusionMatrix(labelSet); populateMatrix(labels.toIntArray(), confusion, predictors); return confusion; } public ConfusionMatrix predictMatrix(CategoryColumn labels, NumericColumn... predictors) { Preconditions.checkArgument(predictors.length > 0); SortedSet<String> labelSet = new TreeSet<>(labels.asSet()); ConfusionMatrix confusion = new CategoryConfusionMatrix(labels, labelSet); populateMatrix(labels.data().toIntArray(), confusion, predictors); return confusion; } public int[] predict(NumericColumn... predictors) { Preconditions.checkArgument(predictors.length > 0); int[] predictedLabels = new int[predictors[0].size()]; for (int row = 0; row < predictors[0].size(); row++) { double[] data = new double[predictors.length]; for (int col = 0; col < predictors.length; col++) { data[row] = predictors[col].getFloat(row); } predictedLabels[row] = classifierModel.predict(data); } return predictedLabels; } @Override int predictFromModel(double[] data) { return classifierModel.predict(data); } }