/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.classifier.boost;
import rapaio.data.Frame;
import rapaio.data.Var;
import rapaio.data.VarType;
import rapaio.data.filter.FFilter;
import rapaio.data.sample.Sample;
import rapaio.data.sample.RowSampler;
import rapaio.ml.classifier.AbstractClassifier;
import rapaio.ml.classifier.CFit;
import rapaio.ml.classifier.Classifier;
import rapaio.ml.classifier.tree.CTree;
import rapaio.ml.common.Capabilities;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;
/**
* AdaBoost.SAMME classifier: the multi-class variant of the classical AdaBoost
* algorithm. It adds a correction term to the weight of each weak learner so
* that boosting also works for classification with more than two labels.
* <p>
* User: Aurelian Tutuianu <paderati@yahoo.com>
*/
public class AdaBoostSAMME extends AbstractClassifier {

    private static final long serialVersionUID = -9154973036108114765L;

    /**
     * Tolerance used when comparing the weighted error against the
     * random-guess threshold, and as the error floor for a perfect learner.
     * Note that the literal {@code 10e-10} equals {@code 1e-9}.
     */
    private static final double delta_error = 10e-10;

    // parameters

    /** Prototype of the weak learner; a fresh instance is created for every boosting round. */
    private Classifier weak = CTree.newCART().withMaxDepth(6).withMinCount(6);
    /** When true, training stops at the first round which does not produce a usable learner. */
    private boolean stopOnError = false;
    /** Multiplier applied to the learner weight when row weights are updated. */
    private double shrinkage = 1.0;

    // model artifacts
    // NOTE(review): field names are kept as they are; renaming them would change
    // the default serialized form if instances are serialized - confirm before renaming.

    /** Weight (alpha) of each retained weak learner, parallel to {@link #h}. */
    private List<Double> a;
    /** Retained weak learners, in the order in which they were trained. */
    private List<Classifier> h;
    /** Current row weights, normalized to sum to one after each round. */
    private Var w;
    /**
     * Number of target levels minus one, as computed in {@link #coreTrain(Frame, Var)}.
     * Presumably the first level is the missing value label, which makes this
     * the number of real classes - TODO confirm.
     */
    private double k;

    /**
     * Builds a classifier with empty model artifacts and 10 boosting runs.
     */
    public AdaBoostSAMME() {
        this.a = new ArrayList<>();
        this.h = new ArrayList<>();
        withRuns(10);
    }

    /**
     * Creates a new unfitted classifier which carries the same parameters
     * (input filters, weak learner prototype, stop-on-error flag, shrinkage,
     * sampler, runs, running hook and run pool size) as this one.
     *
     * @return new classifier instance with copied parameters
     */
    @Override
    public AdaBoostSAMME newInstance() {
        return (AdaBoostSAMME) new AdaBoostSAMME()
                .withInputFilters(inputFilters())
                .withClassifier(this.weak.newInstance())
                .withStopOnError(stopOnError)
                .withShrinkage(shrinkage)
                .withSampler(sampler())
                .withRuns(runs())
                .withRunningHook(runningHook())
                .withRunPoolSize(runPoolSize());
    }

    /**
     * @return short name of the algorithm
     */
    @Override
    public String name() {
        return "AdaBoost.SAMME";
    }

    /**
     * Builds a description which contains the algorithm name together with
     * the weak learner, number of runs, sampler and stop-on-error flag.
     * The text ends with a trailing {@code ", }"}; this is kept unchanged
     * since callers might rely on the exact text.
     *
     * @return algorithm name with parameter values
     */
    @Override
    public String fullName() {
        StringBuilder sb = new StringBuilder();
        sb.append("AdaBoost.SAMME {");
        sb.append("weak: ").append(weak.fullName()).append(", ");
        sb.append("runs: ").append(runs()).append(", ");
        sb.append("sampler: ").append(sampler().name()).append(", ");
        sb.append("stopOnError: ").append(stopOnError).append(", ");
        sb.append("}");
        return sb.toString();
    }

    /**
     * Declares what kind of data this classifier accepts: between 1 and 10,000
     * numeric, nominal, index or binary inputs with missing values allowed,
     * and exactly one nominal target without missing values.
     *
     * @return capabilities descriptor
     */
    @Override
    public Capabilities capabilities() {
        return new Capabilities()
                .withInputTypes(VarType.NUMERIC, VarType.NOMINAL, VarType.INDEX, VarType.BINARY)
                .withInputCount(1, 10_000)
                .withAllowMissingInputValues(true)
                .withTargetTypes(VarType.NOMINAL)
                .withTargetCount(1, 1)
                .withAllowMissingTargetValues(false);
    }

    /**
     * Sets the weak learner prototype used at each boosting round.
     *
     * @param weak weak learner prototype
     * @return this instance
     */
    public AdaBoostSAMME withClassifier(Classifier weak) {
        this.weak = weak;
        return this;
    }

    /**
     * Sets the row sampler used to draw the training sample of each round.
     *
     * @param sampler row sampler
     * @return this instance
     */
    public AdaBoostSAMME withSampler(RowSampler sampler) {
        return (AdaBoostSAMME) super.withSampler(sampler);
    }

    /**
     * Sets whether training stops at the first unsuccessful round.
     *
     * @param stopOnError true to stop at the first unsuccessful round
     * @return this instance
     */
    public AdaBoostSAMME withStopOnError(boolean stopOnError) {
        this.stopOnError = stopOnError;
        return this;
    }

    /**
     * Sets the shrinkage factor which multiplies the learner weight when
     * the row weights of misclassified rows are increased.
     *
     * @param shrinkage shrinkage factor
     * @return this instance
     */
    public AdaBoostSAMME withShrinkage(double shrinkage) {
        this.shrinkage = shrinkage;
        return this;
    }

    /**
     * Resets the model artifacts, normalizes a copy of the given weights and
     * runs at most {@code runs()} boosting rounds. The running hook, if any,
     * is called after each round which did not stop the training.
     *
     * @param df      training frame
     * @param weights initial row weights
     * @return always true
     */
    @Override
    protected boolean coreTrain(Frame df, Var weights) {
        k = firstTargetLevels().length - 1;

        h = new ArrayList<>();
        a = new ArrayList<>();
        w = weights.solidCopy();
        normalizeWeights();

        for (int i = 0; i < runs(); i++) {
            boolean success = learnRound(df);
            if (!success && stopOnError) {
                break;
            }
            if (runningHook() != null) {
                runningHook().accept(this, i + 1);
            }
        }
        return true;
    }

    /**
     * Runs a single boosting round: trains a weak learner on a sample, computes
     * its weighted error on the whole frame, stores the learner together with
     * its weight and increases the weights of the misclassified rows.
     *
     * @param df training frame
     * @return true if a learner was added and the row weights were updated,
     * false if the round produced a perfect learner or, when stop on error is
     * enabled, a learner with an error above the random-guess threshold
     */
    private boolean learnRound(Frame df) {
        Classifier hh = weak.newInstance();
        Sample sample = sampler().nextSample(df, w);
        hh.train(sample.df, sample.weights.solidCopy(), targetNames());

        CFit fit = hh.fit(df, true, false);
        Var predicted = fit.firstClasses();
        Var target = df.var(firstTargetName());

        // weighted fraction of misclassified rows
        double err = 0;
        for (int j = 0; j < df.rowCount(); j++) {
            if (predicted.index(j) != target.index(j)) {
                err += w.value(j);
            }
        }
        err /= w.stream().mapToDouble().sum();

        if (err == 0) {
            // A perfect learner leaves nothing to boost. It is retained only if it
            // is the first one, since otherwise the model would be empty. The
            // formula evaluated at zero error yields an infinite weight, which later
            // produces NaN densities at fit time; the error is floored at
            // delta_error to obtain a large but finite weight instead.
            if (h.isEmpty()) {
                h.add(hh);
                a.add(learnerWeight(delta_error));
            }
            return false;
        }
        if (stopOnError && err > (1.0 - 1.0 / k) + delta_error) {
            // not better than random guessing among k classes
            return false;
        }

        double alpha = learnerWeight(err);
        h.add(hh);
        a.add(alpha);

        // increase the weight of each misclassified row, then renormalize
        double boost = Math.exp(alpha * shrinkage);
        for (int j = 0; j < w.rowCount(); j++) {
            if (predicted.index(j) != target.index(j)) {
                w.setValue(j, w.value(j) * boost);
            }
        }
        normalizeWeights();
        return true;
    }

    /**
     * Computes the SAMME weight of a learner from its weighted error:
     * {@code log((1 - err) / err) + log(k - 1)}.
     *
     * @param err weighted error of the learner
     * @return learner weight
     */
    private double learnerWeight(double err) {
        return Math.log((1.0 - err) / err) + Math.log(k - 1.0);
    }

    /**
     * Divides each row weight by the sum of all row weights.
     */
    private void normalizeWeights() {
        double total = w.stream().mapToDouble().reduce(0.0, (x, y) -> x + y);
        for (int i = 0; i < w.rowCount(); i++) {
            w.setValue(i, w.value(i) / total);
        }
    }

    /**
     * Scores each row by adding the weight of every retained learner to the
     * class predicted by that learner, then picks the class with the highest
     * score and divides the scores by their row total.
     * <p>
     * Densities are always built, regardless of {@code withDistributions},
     * because they are used to accumulate the scores.
     *
     * @param df                frame to fit
     * @param withClasses       passed through to {@code CFit.build}
     * @param withDistributions not used by this implementation
     * @return fit result
     */
    @Override
    protected CFit coreFit(Frame df, boolean withClasses, boolean withDistributions) {
        CFit fit = CFit.build(this, df, withClasses, true);
        Frame density = fit.firstDensity();

        for (int i = 0; i < h.size(); i++) {
            CFit hp = h.get(i).fit(df, true, false);
            Var votes = hp.firstClasses();
            double alpha = a.get(i);
            for (int j = 0; j < df.rowCount(); j++) {
                int index = votes.index(j);
                density.setValue(j, index, density.value(j, index) + alpha);
            }
        }

        // Pick the best class and normalize. Column 0 is skipped on purpose;
        // presumably it holds the missing value level - TODO confirm.
        int cols = density.varCount();
        for (int i = 0; i < density.rowCount(); i++) {
            double max = 0;
            int best = 0;
            double total = 0;
            for (int j = 1; j < cols; j++) {
                double score = density.value(i, j);
                total += score;
                if (score > max) {
                    best = j;
                    max = score;
                }
            }
            for (int j = 1; j < cols; j++) {
                density.setValue(i, j, density.value(i, j) / total);
            }
            fit.firstClasses().setIndex(i, best);
        }
        return fit;
    }

    /**
     * Sets the maximum number of boosting rounds.
     *
     * @param runs number of rounds
     * @return this instance
     */
    @Override
    public AdaBoostSAMME withRuns(int runs) {
        return (AdaBoostSAMME) super.withRuns(runs);
    }

    /**
     * Sets the hook called after each boosting round with this classifier
     * and the 1-based round number.
     *
     * @param runningHook hook to call, may be null
     * @return this instance
     */
    @Override
    public AdaBoostSAMME withRunningHook(BiConsumer<Classifier, Integer> runningHook) {
        return (AdaBoostSAMME) super.withRunningHook(runningHook);
    }

    /**
     * Sets the input filters.
     *
     * @param filters input filters
     * @return this instance
     */
    @Override
    public AdaBoostSAMME withInputFilters(FFilter... filters) {
        return (AdaBoostSAMME) super.withInputFilters(filters);
    }

    /**
     * Sets the input filters.
     *
     * @param filters input filters
     * @return this instance
     */
    @Override
    public AdaBoostSAMME withInputFilters(List<FFilter> filters) {
        return (AdaBoostSAMME) super.withInputFilters(filters);
    }

    /**
     * Builds a short text report which contains the full name of the
     * classifier and the number of weak learners retained by training.
     *
     * @return summary text
     */
    @Override
    public String summary() {
        StringBuilder sb = new StringBuilder();
        sb.append("\n > ").append(fullName()).append("\n");

        sb.append("prediction:\n");
        sb.append("weak learners built: ").append(h.size()).append("\n");
        return sb.toString();
    }
}