package com.github.lwhite1.tablesaw.api.ml.classification;
import com.github.lwhite1.tablesaw.api.BooleanColumn;
import com.github.lwhite1.tablesaw.api.Table;
import com.github.lwhite1.tablesaw.util.DoubleArrays;
import com.github.lwhite1.tablesaw.util.Example;
import org.junit.Test;
import smile.classification.KNN;
import java.util.SortedSet;
import java.util.TreeSet;
import static com.github.lwhite1.tablesaw.api.QueryHelper.column;
/**
 * Smoke tests that fill a {@link ConfusionMatrix} with the predictions of classifiers learned
 * from {@code data/KNN_Example_1.csv}.
 * <p>
 * Each test divides the table with {@code sampleSplit(.5)}, learns a model on the first part and
 * records one (actual, predicted) pair per row of the second part. Neither test makes assertions;
 * they only verify that learning, prediction and matrix updates complete without throwing.
 */
public class ConfusionMatrixTest extends Example {

  /** Path of the example data set loaded by both tests. */
  private static final String EXAMPLE_CSV = "data/KNN_Example_1.csv";

  /**
   * Learns a KNN classifier (k = 2) on the "X" and "Y" columns with column 2 as the label, then
   * records its prediction for every row of the held-out split in a confusion matrix.
   * <p>
   * NOTE(review): despite the name, nothing here calls an {@code asTable} method or inspects the
   * resulting matrix -- confirm whether an assertion on the matrix was intended.
   *
   * @throws Exception if the CSV cannot be loaded or any step of the pipeline fails
   */
  @Test
  public void testAsTable() throws Exception {
    Table example = Table.createFromCsv(EXAMPLE_CSV);
    Table[] splits = example.sampleSplit(.5);
    Table train = splits[0];
    Table test = splits[1];

    KNN<double[]> knn = KNN.learn(
        DoubleArrays.to2dArray(train.nCol("X"), train.nCol("Y")),
        train.shortColumn(2).toIntArray(), 2);

    ConfusionMatrix confusion = newConfusionMatrix(train);
    for (int row : test) {
      int predicted = knn.predict(features(test, row));
      confusion.increment(actualLabel(test, row), predicted);
    }
  }

  /**
   * Learns a logistic regression on a boolean target ("Label" equal to 1, appended to the table
   * as column 3), prints the matrix returned by {@code predictMatrix}, and then prints a second
   * matrix built row by row from {@code predict}.
   * <p>
   * NOTE(review): the row-by-row matrix compares the model's output with the raw values of
   * column 2 rather than with the boolean target the model was learned on -- confirm that this
   * is intended.
   *
   * @throws Exception if the CSV cannot be loaded or any step of the pipeline fails
   */
  @Test
  public void testWithBooleanColumn() throws Exception {
    Table example = Table.createFromCsv(EXAMPLE_CSV);
    BooleanColumn booleanTarget = example.selectIntoColumn("bt", column("Label").isEqualTo(1));
    // The derived column has to be part of the table before splitting so that both splits carry it.
    example.addColumn(booleanTarget);
    Table[] splits = example.sampleSplit(.5);
    Table train = splits[0];
    Table test = splits[1];

    LogisticRegression lr = LogisticRegression.learn(
        train.booleanColumn(3), train.nCol("X"), train.nCol("Y"));

    System.out.println(
        lr.predictMatrix(test.booleanColumn(3), test.floatColumn(0), test.floatColumn(1)).toString());

    ConfusionMatrix confusion = newConfusionMatrix(train);
    for (int row : test) {
      int predicted = lr.predict(features(test, row));
      confusion.increment(actualLabel(test, row), predicted);
    }
    System.out.println(confusion.toString());
  }

  /** Creates an empty confusion matrix whose labels are the sorted distinct values of column 2. */
  private static ConfusionMatrix newConfusionMatrix(Table train) {
    SortedSet<Object> labelSet = new TreeSet<>(train.shortColumn(2).asSet());
    return new StandardConfusionMatrix(labelSet);
  }

  /** Builds the two-element feature vector (float columns 0 and 1) for the given row. */
  private static double[] features(Table table, int row) {
    double[] data = new double[2];
    data[0] = table.floatColumn(0).getFloat(row);
    data[1] = table.floatColumn(1).getFloat(row);
    return data;
  }

  /** Reads the value of column 2 for the given row as an int. */
  private static int actualLabel(Table table, int row) {
    return (int) table.shortColumn(2).get(row);
  }
}