/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.classifier.boost;
import org.junit.Test;
import rapaio.core.SamplingTools;
import rapaio.data.Frame;
import rapaio.data.Numeric;
import rapaio.datasets.Datasets;
import rapaio.ml.classifier.Classifier;
import rapaio.ml.classifier.tree.CTree;
import rapaio.ml.eval.Confusion;
import rapaio.printer.IdeaPrinter;
import rapaio.sys.WS;
import java.io.IOException;
import java.net.URISyntaxException;
import static rapaio.graphics.Plotter.color;
public class AdaBoostSAMMETest {
@Test
public void testBuild() throws IOException, URISyntaxException {
WS.setPrinter(new IdeaPrinter());
Classifier ab = new AdaBoostSAMME()
.withClassifier(CTree.newC45().withMinCount(5).withMaxDepth(3).withMCols(5))
.withRuns(20);
Frame df = Datasets.loadSpamBase();
df.printSummary();
int[] rows = SamplingTools.sampleWOR(df.rowCount(), df.rowCount() / 2);
Frame tr = df.mapRows(rows);
Frame te = df.removeRows(rows);
String target = "spam";
Numeric runs = Numeric.empty().withName("runs");
Numeric errTr = Numeric.empty().withName("tr");
Numeric errTe = Numeric.empty().withName("te");
ab.withRunningHook((c, run) -> {
runs.addValue(run);
errTr.addValue(new Confusion(tr.var(target), ab.fit(tr).classes(target)).error());
errTe.addValue(new Confusion(te.var(target), ab.fit(te).classes(target)).error());
// WS.draw(
// plot(color(3))
// .lines(runs, errTr, color(1))
// .lines(runs, errTe, color(2))
// .yLim(0, Double.NaN));
});
ab.train(tr, target);
ab.printSummary();
new Confusion(tr.var(target), ab.fit(tr).firstClasses()).printSummary();
}
}