/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.classifier.boost;
import rapaio.data.*;
import rapaio.data.sample.RowSampler;
import rapaio.ml.classifier.AbstractClassifier;
import rapaio.ml.classifier.CFit;
import rapaio.ml.classifier.Classifier;
import rapaio.ml.common.Capabilities;
import rapaio.ml.regression.RFit;
import rapaio.experiment.ml.regression.boost.gbt.BTRegression;
import rapaio.experiment.ml.regression.boost.gbt.GBTLossFunction;
import rapaio.ml.regression.tree.RTree;
import rapaio.sys.WS;
import java.util.ArrayList;
import java.util.List;
/**
 * Gradient boosted trees classifier for a single nominal target.
 * <p>
 * One additive model of regression trees is maintained per modeled class.
 * Each boosting round turns the current scores into class probabilities with
 * a softmax, fits one tree per class on the residuals {@code y - p} and adds
 * the shrunken tree output to the scores of that class.
 * <p>
 * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> at 12/12/14.
 */
public class GBTClassifier extends AbstractClassifier implements Classifier {

    private static final long serialVersionUID = -2979235364091072967L;

    // number of modeled classes: count of target levels minus one
    // (level index 0 is never modeled; presumably the missing-value level - TODO confirm)
    int K;
    // additive scores of the training rows, indexed as [row][class]
    double[][] f;
    // softmax probabilities of the training rows, indexed as [row][class]
    double[][] p;
    // multiplier applied to every tree output before it is added to the scores
    private double shrinkage = 1.0;

    // prediction artifact
    // prototype regressor, a fresh instance is created from it for every tree
    private BTRegression classifier = RTree.buildCART().withMaxDepth(4);
    // fitted trees, trees.get(k) holds the trees of class k in boosting order
    private List<List<BTRegression>> trees;

    /**
     * Builds a classifier configured with 10 boosting runs.
     */
    public GBTClassifier() {
        withRuns(10);
    }

    /**
     * Creates an untrained copy which shares the sampler, shrinkage and number
     * of runs, and which uses a new instance of the tree prototype.
     *
     * @return new classifier instance with the same configuration
     */
    @Override
    public GBTClassifier newInstance() {
        return (GBTClassifier) new GBTClassifier()
                .withSampler(sampler())
                .withShrinkage(shrinkage)
                .withTree(classifier.newInstance())
                .withRuns(runs());
    }

    @Override
    public String name() {
        return "GBTClassifier";
    }

    /**
     * @return the name followed by the configured number of runs
     */
    @Override
    public String fullName() {
        StringBuilder sb = new StringBuilder();
        sb.append(name()).append("{");
        sb.append("runs=").append(runs());
        sb.append("}");
        return sb.toString();
    }

    @Override
    public Capabilities capabilities() {
        return new Capabilities()
                .withInputCount(1, 1_000_000)
                .withInputTypes(VarType.BINARY, VarType.INDEX, VarType.NOMINAL, VarType.ORDINAL, VarType.NUMERIC)
                .withAllowMissingInputValues(true)
                .withTargetCount(1, 1)
                .withTargetTypes(VarType.NOMINAL)
                .withAllowMissingTargetValues(false);
    }

    /**
     * Sets the regression tree prototype used to build every boosting tree.
     *
     * @param rTree prototype regressor
     * @return this instance, for call chaining
     */
    public GBTClassifier withTree(BTRegression rTree) {
        this.classifier = rTree;
        return this;
    }

    /**
     * Sets the factor which multiplies each tree output when scores are updated.
     *
     * @param shrinkage shrinkage factor
     * @return this instance, for call chaining
     */
    public GBTClassifier withShrinkage(double shrinkage) {
        this.shrinkage = shrinkage;
        return this;
    }

    @Override
    public GBTClassifier withSampler(RowSampler sampler) {
        return (GBTClassifier) super.withSampler(sampler);
    }

    /**
     * Resets the model state and runs {@code runs()} boosting rounds, calling
     * the running hook, if any, after each round.
     *
     * @param df      training frame, including the target variable
     * @param weights instance weights
     * @return always {@code true}
     */
    @Override
    public boolean coreTrain(Frame df, Var weights) {
        // algorithm described in ESL, page 387
        K = firstTargetLevels().length - 1;
        f = new double[df.rowCount()][K];
        p = new double[df.rowCount()][K];
        trees = new ArrayList<>();
        for (int i = 0; i < K; i++) {
            trees.add(new ArrayList<>());
        }
        for (int m = 0; m < runs(); m++) {
            buildAdditionalTree(df, weights);
            if (runningHook() != null) {
                runningHook().accept(this, m);
            }
        }
        return true;
    }

    /**
     * Performs one boosting round: refreshes the probabilities from the current
     * scores, then fits and stores one additional tree for each class.
     */
    private void buildAdditionalTree(Frame df, Var weights) {
        // a) Set p_k(x)
        for (int i = 0; i < df.rowCount(); i++) {
            softmax(f[i], p[i]);
            for (int k = 0; k < K; k++) {
                // with the stabilized softmax this can only signal non finite scores
                if (Double.isNaN(p[i][k])) {
                    WS.println("ERROR");
                }
            }
        }

        // b)
        Var target = df.var(firstTargetName());
        for (int k = 0; k < K; k++) {
            // residuals of class k: indicator of the class minus its probability
            Numeric r = Numeric.empty().withName("##tt##");
            for (int i = 0; i < df.rowCount(); i++) {
                double y_i = (target.index(i) == k + 1) ? 1 : 0;
                r.addValue(y_i - p[i][k]);
            }

            Frame x = df.removeVars(targetNames());
            Frame train = x.bindVars(r);

            BTRegression tree = classifier.newInstance();

            // the tree is trained on a sample, but leaf values are refitted
            // and scores are updated using all the rows
            Mapping samplerMapping = sampler().nextSample(x, weights).mapping;
            tree.train(train.mapRows(samplerMapping), weights.mapRows(samplerMapping), "##tt##");
            tree.boostFit(x, r, r, new ClassifierLossFunction(K));

            RFit rr = tree.fit(train, true);

            for (int i = 0; i < df.rowCount(); i++) {
                f[i][k] += shrinkage * rr.firstFit().value(i);
            }
            trees.get(k).add(tree);
        }
    }

    /**
     * Sums the shrunken outputs of all trees of each class into the density
     * frame, converts each row of scores into probabilities with a softmax and
     * sets the predicted class to the level with the highest probability.
     *
     * @param df                frame to predict
     * @param withClasses       passed to {@code CFit.build}
     * @param withDistributions passed to {@code CFit.build}
     * @return fit result holding densities and predicted classes
     */
    @Override
    public CFit coreFit(Frame df, boolean withClasses, boolean withDistributions) {
        CFit cr = CFit.build(this, df, withClasses, withDistributions);

        // accumulate raw scores, class k is stored on density column k + 1
        for (int k = 0; k < K; k++) {
            for (BTRegression tree : trees.get(k)) {
                RFit rr = tree.fit(df, false);
                for (int i = 0; i < df.rowCount(); i++) {
                    double score = cr.firstDensity().value(i, k + 1);
                    score += shrinkage * rr.firstFit().value(i);
                    cr.firstDensity().setValue(i, k + 1, score);
                }
            }
        }

        // make probabilities and pick the most probable class
        double[] row = new double[K];
        for (int i = 0; i < df.rowCount(); i++) {
            for (int k = 0; k < K; k++) {
                row[k] = cr.firstDensity().value(i, k + 1);
            }
            softmax(row, row);

            int maxIndex = 0;
            double maxValue = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < K; k++) {
                cr.firstDensity().setValue(i, k + 1, row[k]);
                if (row[k] > maxValue) {
                    maxValue = row[k];
                    maxIndex = k + 1;
                }
            }
            cr.firstClasses().setIndex(i, maxIndex);
        }
        return cr;
    }

    /**
     * Writes the softmax of {@code scores} into {@code out}; both arguments may
     * be the same array.
     * <p>
     * The maximum score is subtracted before exponentiation. The result is
     * mathematically unchanged, but large scores no longer overflow to
     * infinity (which produced NaN probabilities) and the normalizing sum can
     * no longer underflow to zero.
     *
     * @param scores raw additive scores
     * @param out    receives the probabilities, must be at least as long as {@code scores}
     */
    private static void softmax(double[] scores, double[] out) {
        double max = Double.NEGATIVE_INFINITY;
        for (double score : scores) {
            if (score > max) {
                max = score;
            }
        }
        double sum = 0;
        for (int k = 0; k < scores.length; k++) {
            out[k] = Math.exp(scores[k] - max);
            sum += out[k];
        }
        for (int k = 0; k < scores.length; k++) {
            out[k] /= sum;
        }
    }
}
/**
 * Loss function handed to the boosting trees built by {@link GBTClassifier}.
 * <p>
 * Only {@link #findMinimum(Var, Var)} carries logic; {@link #gradient(Var, Var)}
 * returns {@code null}.
 */
class ClassifierLossFunction implements GBTLossFunction {

    private static final long serialVersionUID = -2622054975826334290L;

    // class count, stored as double so the final ratio is computed in floating point
    private final double K;

    /**
     * @param K number of classes used in the scaling factor {@code (K - 1) / K}
     */
    public ClassifierLossFunction(int K) {
        this.K = K;
    }

    @Override
    public String name() {
        return "ClassifierLossFunction";
    }

    /**
     * Computes {@code (K - 1) * sum(y) / (K * sum(|y| * (1 - |y|)))} over all
     * the rows of {@code y}.
     *
     * @param y  values to aggregate (the classifier passes residuals here)
     * @param fx not used by this implementation
     * @return the ratio above, or 0 when the denominator sum is zero or when
     * either sum is NaN
     */
    @Override
    public double findMinimum(Var y, Var fx) {
        double numerator = 0.0;
        double denominator = 0.0;
        final int rows = y.rowCount();
        for (int row = 0; row < rows; row++) {
            final double value = y.value(row);
            final double magnitude = Math.abs(value);
            numerator += value;
            denominator += magnitude * (1.0 - magnitude);
        }

        // degenerate sums give no usable step
        if (denominator == 0 || Double.isNaN(numerator) || Double.isNaN(denominator)) {
            return 0;
        }
        return (K - 1) * numerator / (K * denominator);
    }

    /**
     * Not implemented.
     *
     * @return always {@code null}
     */
    @Override
    public Numeric gradient(Var y, Var fx) {
        return null;
    }
}