/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.analysis;
import rapaio.data.*;
import rapaio.data.stream.FSpot;
import rapaio.math.linear.RM;
import rapaio.math.linear.RV;
import rapaio.math.linear.EigenPair;
import rapaio.math.linear.Linear;
import rapaio.math.linear.dense.QR;
import rapaio.math.linear.dense.SolidRM;
import rapaio.math.linear.dense.SolidRV;
import rapaio.printer.Printable;
import rapaio.printer.Summary;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.logging.Logger;
/**
 * Linear discriminant analysis.
 * <p>
 * Training builds the within-class scatter matrix {@code Sw} and the between-class scatter matrix
 * {@code Sb}. It then eigen-decomposes {@code Sb^(1/2) * Sw^(-1) * Sb^(1/2)} and maps the resulting
 * eigenvectors back with {@code Sb^(-1/2)}. Eigenvalues and eigenvectors are stored sorted by
 * decreasing eigenvalue, so {@link #fit(Frame, BiFunction)} can keep the leading components.
 * <p>
 * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 10/5/15.
 */
public class LDA implements Printable {

    private static final Logger logger = Logger.getLogger(LDA.class.getName());

    /** Tolerance forwarded to {@code Linear.pdPower} and {@code Linear.eigenDecomp}. */
    private double tol = 1e-24;
    /** Iteration limit forwarded to {@code Linear.pdPower} and {@code Linear.eigenDecomp}. */
    private int maxRuns = 10_000;
    /** When true, inputs are centered and divided by their standard deviation before use. */
    private boolean scaling = true;

    private String targetName;
    private String[] targetLevels;
    private String[] inputNames;

    /** Per-input column means computed on the training frame. */
    private RV mean;
    /** Per-input column standard deviations computed on the training frame. */
    private RV sd;
    /** Per-class mean vectors, expressed in the (possibly standardized) training space. */
    private RV[] classMean;
    /** Eigenvalues sorted in decreasing order. */
    private RV eigenValues;
    /** Discriminant directions; column i pairs with eigenvalue i. */
    private RM eigenVectors;

    /**
     * Sets the iteration limit used by the matrix power and eigen decomposition routines.
     *
     * @param maxRuns maximum number of iterations
     * @return this instance, for chaining
     */
    public LDA withMaxRuns(int maxRuns) {
        this.maxRuns = maxRuns;
        return this;
    }

    /**
     * Sets the convergence tolerance used by the matrix power and eigen decomposition routines.
     *
     * @param tol convergence tolerance
     * @return this instance, for chaining
     */
    public LDA withTol(double tol) {
        this.tol = tol;
        return this;
    }

    /**
     * Enables or disables standardization of the input columns. Scaling is enabled by default.
     *
     * @param scaling true to center and scale inputs, false to use them as they are
     * @return this instance, for chaining
     */
    public LDA withScaling(boolean scaling) {
        this.scaling = scaling;
        return this;
    }

    /**
     * @return eigenvalues computed by the last call to {@link #learn}, sorted in decreasing order
     */
    public RV getEigenValues() {
        return eigenValues;
    }

    /**
     * @return discriminant directions computed by the last call to {@link #learn}; column i
     * corresponds to the i-th value of {@link #getEigenValues()}
     */
    public RM getEigenVectors() {
        return eigenVectors;
    }

    /**
     * Trains the discriminant directions.
     * <p>
     * Column means and standard deviations are computed over all rows. Only rows whose target
     * label is one of the retained target levels contribute to the class slices.
     *
     * @param df         training frame; every non-target column must be binary, index, ordinal or
     *                   numeric
     * @param targetVars variable range that must select exactly one nominal target variable
     * @throws IllegalArgumentException if the target selection or the column types are invalid
     */
    public void learn(Frame df, String... targetVars) {
        validate(df, targetVars);

        logger.fine("start lda train");
        RM xx = SolidRM.copy(df.removeVars(targetName));

        // Keep mean and sd so that fit() can apply exactly the same transformation.
        mean = SolidRV.empty(xx.colCount());
        sd = SolidRV.empty(xx.colCount());
        for (int j = 0; j < xx.colCount(); j++) {
            mean.set(j, xx.mapCol(j).mean().value());
            sd.set(j, xx.mapCol(j).var().sdValue());
        }
        if (scaling) {
            standardize(xx);
        }

        // Rows of the (possibly standardized) data, grouped by class label.
        RM[] x = new RM[targetLevels.length];
        for (int i = 0; i < targetLevels.length; i++) {
            String level = targetLevels[i];
            x[i] = xx.mapRows(df.stream()
                    .filter(s -> s.label(targetName).equals(level))
                    .mapToInt(FSpot::row)
                    .toArray());
        }

        classMean = new RV[targetLevels.length];
        for (int i = 0; i < targetLevels.length; i++) {
            classMean[i] = SolidRV.empty(x[i].colCount());
            for (int j = 0; j < x[i].colCount(); j++) {
                classMean[i].set(j, x[i].mapCol(j).mean().value());
            }
        }

        // Within-class scatter: sum of the per-class scatter matrices.
        RM sw = SolidRM.empty(inputNames.length, inputNames.length);
        for (RM classRows : x) {
            sw.plus(classRows.scatter());
        }

        // Between-class scatter: sum of n_i * (m_i - m)(m_i - m)^T. After standardization every
        // column is centered, so the global mean m is zero and need not be subtracted.
        RM sb = SolidRM.empty(inputNames.length, inputNames.length);
        for (int i = 0; i < targetLevels.length; i++) {
            RM cm = scaling ? classMean[i].asMatrix() : classMean[i].asMatrix().minus(mean.asMatrix());
            sb.plus(cm.dot(cm.t()).dot(x[i].rowCount()));
        }

        RM swi = new QR(sw).solve(SolidRM.identity(inputNames.length));

        // NOTE(review): Sb is a sum of one rank-1 term per class, so it is singular whenever there
        // are fewer classes than inputs; how pdPower treats zero eigenvalues for the -0.5 power is
        // not visible here - confirm.
        RM sbplus = Linear.pdPower(sb, 0.5, maxRuns, tol);
        RM sbminus = Linear.pdPower(sb, -0.5, maxRuns, tol);

        logger.fine("compute eigenvalues");
        EigenPair p = Linear.eigenDecomp(sbplus.dot(swi).dot(sbplus), maxRuns, tol);
        eigenValues = p.values();
        eigenVectors = sbminus.dot(p.vectors());

        logger.fine("sort eigen values and vectors");
        Integer[] order = new Integer[eigenValues.count()];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        RV unsorted = eigenValues;
        // decreasing order of eigenvalue
        Arrays.sort(order, (a, b) -> Double.compare(unsorted.get(b), unsorted.get(a)));
        int[] indexes = Arrays.stream(order).mapToInt(Integer::intValue).toArray();
        eigenValues = eigenValues.asMatrix().mapRows(indexes).mapCol(0).solidCopy();
        eigenVectors = eigenVectors.mapCols(indexes).solidCopy();
    }

    /**
     * Projects a frame onto the leading discriminant directions.
     *
     * @param df        frame containing all the input variables used at training time
     * @param kFunction receives the sorted eigenvalues and eigenvectors and returns how many
     *                  leading components to keep
     * @return frame with columns {@code lda_1 .. lda_k}, followed by the variables of {@code df}
     * that are not training inputs
     * @throws IllegalStateException    if the model has not been trained
     * @throws IllegalArgumentException if {@code kFunction} returns a value outside
     *                                  {@code [0, number of components]}
     */
    public Frame fit(Frame df, BiFunction<RV, RM, Integer> kFunction) {
        if (eigenVectors == null) {
            throw new IllegalStateException("LDA model was not trained, call learn() first");
        }
        RM x = SolidRM.copy(df.mapVars(inputNames));
        if (scaling) {
            standardize(x);
        }

        int k = kFunction.apply(eigenValues, eigenVectors);
        if (k < 0 || k > eigenVectors.colCount()) {
            throw new IllegalArgumentException("number of components k=" + k
                    + " must be between 0 and " + eigenVectors.colCount());
        }
        int[] dim = new int[k];
        String[] names = new String[k];
        for (int i = 0; i < k; i++) {
            dim[i] = i;
            names[i] = "lda_" + (i + 1);
        }
        RM result = x.dot(eigenVectors.mapCols(dim));

        Frame projected = SolidFrame.matrix(result, names);
        Frame rest = df.removeVars(inputNames);
        return rest.varCount() == 0 ? projected : projected.bindVars(rest);
    }

    /**
     * Standardizes the columns of {@code m} in place using the stored training mean and sd.
     * <p>
     * Every column is centered. Only columns with a non-zero standard deviation are divided by it,
     * so constant columns become zero instead of NaN or infinity. Training and projection share
     * this method so that both apply the same transformation.
     */
    private void standardize(RM m) {
        for (int i = 0; i < m.rowCount(); i++) {
            for (int j = 0; j < m.colCount(); j++) {
                double centered = m.get(i, j) - mean.get(j);
                double s = sd.get(j);
                m.set(i, j, s != 0 ? centered / s : centered);
            }
        }
    }

    /**
     * Resolves the target variable, its levels and the input variable names, and checks the
     * variable types.
     *
     * @throws IllegalArgumentException if not exactly one target is selected, if the target is not
     *                                  nominal, or if an input column has a disallowed type
     */
    private void validate(Frame df, String... targetVars) {
        List<String> targetNames = VRange.of(targetVars).parseVarNames(df);
        if (targetNames.size() != 1) {
            throw new IllegalArgumentException("LDA needs one target var");
        }
        targetName = targetNames.get(0);

        Set<VarType> allowedTypes = new HashSet<>(Arrays.asList(VarType.BINARY, VarType.INDEX, VarType.ORDINAL, VarType.NUMERIC));
        df.varStream().forEach(var -> {
            if (targetName.equals(var.name())) {
                if (!var.type().equals(VarType.NOMINAL)) {
                    throw new IllegalArgumentException("target var must be nominal");
                }
                // levels()[0] is skipped; presumably the missing-value placeholder - TODO confirm
                targetLevels = new String[var.levels().length - 1];
                System.arraycopy(var.levels(), 1, targetLevels, 0, var.levels().length - 1);
                return;
            }
            if (!allowedTypes.contains(var.type())) {
                throw new IllegalArgumentException("column type not allowed");
            }
        });
        inputNames = df.varStream().filter(v -> !v.name().equals(targetName)).map(Var::name).toArray(String[]::new);
    }

    /**
     * Builds a text report of the eigenvalues and eigenvectors.
     * <p>
     * The "percent" column holds each eigenvalue divided by the sum of all eigenvalues, as a
     * fraction rather than a percentage.
     *
     * @return printable summary of the trained model
     */
    public String summary() {
        StringBuilder sb = new StringBuilder();

        Frame eval = SolidFrame.byVars(
                Numeric.empty(eigenValues.count()).withName("values"),
                Numeric.empty(eigenValues.count()).withName("percent")
        );
        double total = 0.0;
        for (int i = 0; i < eigenValues.count(); i++) {
            total += eigenValues.get(i);
        }
        for (int i = 0; i < eigenValues.count(); i++) {
            eval.setValue(i, "values", eigenValues.get(i));
            eval.setValue(i, "percent", eigenValues.get(i) / total);
        }

        sb.append("Eigen values\n");
        sb.append("============\n");
        sb.append(Summary.headString(true, eval)).append("\n");
        sb.append("Eigen vectors\n");
        sb.append("=============\n");
        sb.append(eigenVectors.summary()).append("\n");

        return sb.toString();
    }
}