/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.classifier.tree.ctree;
import org.junit.Test;
import rapaio.data.Frame;
import rapaio.data.VarType;
import rapaio.data.filter.frame.FFRetainTypes;
import rapaio.datasets.Datasets;
import rapaio.ml.classifier.CFit;
import rapaio.ml.classifier.tree.CTree;
import rapaio.ml.classifier.tree.CTreeCandidate;
import rapaio.ml.classifier.tree.CTreeNode;
import rapaio.printer.Summary;
import java.io.IOException;
import java.net.URISyntaxException;
import static org.junit.Assert.*;
/**
* Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a>.
*/
public class CTreeTest {

    /**
     * Selects the rows of the given frame whose "class" index equals the "fit" index.
     *
     * @param scored frame holding both a "class" and a "fit" variable
     * @return mapped frame containing only the rows where the two indexes agree
     */
    private static Frame rowsWithMatchingFit(Frame scored) {
        return scored.stream()
                .filter(spot -> spot.index("class") == spot.index("fit"))
                .toMappedFrame();
    }

    /**
     * Trains a decision stump on the iris data and verifies the root split.
     * Two outcomes are accepted: a split on petal-width at 0.8 or a split on
     * petal-length at 2.45 (presumably the two are equally good - confirm).
     */
    @Test
    public void testBuilderDecisionStump() throws IOException, URISyntaxException {
        Frame iris = Datasets.loadIrisDataset();

        CTree stump = CTree.newDecisionStump();
        assertEquals(1, stump.maxDepth());

        stump.train(iris, "class");
        stump.printSummary();

        CTreeNode rootNode = stump.getRoot();
        assertEquals("root", rootNode.getGroupName());

        CTreeCandidate best = rootNode.getBestCandidate();

        // Pick the expected names once, depending on which variable was selected.
        boolean splitOnWidth = "petal-width".equals(best.getTestName());
        String expectedTest = splitOnWidth ? "petal-width" : "petal-length";
        String expectedLower = splitOnWidth ? "petal-width <= 0.8" : "petal-length <= 2.45";
        String expectedUpper = splitOnWidth ? "petal-width > 0.8" : "petal-length > 2.45";

        assertEquals(expectedTest, best.getTestName());
        assertEquals(expectedLower, best.getGroupNames().get(0));
        assertEquals(expectedUpper, best.getGroupNames().get(1));
    }

    /**
     * Loads the mushrooms data, keeps only the nominal variables and prints
     * the variable names and a summary.
     * NOTE(review): no tree is built and nothing is asserted here - confirm
     * whether this test is still a work in progress.
     */
    @Test
    public void testBuilderID3() throws IOException, URISyntaxException {
        Frame mushrooms = Datasets.loadMushrooms();
        Summary.printNames(mushrooms);

        Frame nominalOnly = new FFRetainTypes(VarType.NOMINAL).fitApply(mushrooms);
        Summary.printSummary(nominalOnly);
    }

    /**
     * Checks the ordering of candidates by comparing a candidate built with 1
     * against candidates built with 2, -2 and 0.5, and checks that adding a
     * group under an already used name raises {@link IllegalArgumentException}.
     */
    @Test
    public void testCandidate() {
        CTreeCandidate base = new CTreeCandidate(1, "test");
        base.addGroup("test <= 0", spot -> spot.value("test") <= 0);
        base.addGroup("test > 0", spot -> spot.value("test") > 0);

        assertEquals(1, base.compareTo(new CTreeCandidate(2, "test")));
        assertEquals(-1, base.compareTo(new CTreeCandidate(-2, "test")));
        assertEquals(-1, base.compareTo(new CTreeCandidate(0.5, "test")));

        // "test <= 0" was already registered above, so this call has to be rejected.
        boolean rejected = false;
        try {
            base.addGroup("test <= 0", spot -> true);
        } catch (IllegalArgumentException expected) {
            rejected = true;
        }
        assertTrue("should raise an exception", rejected);
    }

    /**
     * Trains a CART tree (max depth 10000, min count 1) on the iris data,
     * binds the predicted classes as a "fit" variable and expects all 150 rows
     * to match the "class" variable. Afterwards the first four cells of row 0
     * are set to missing and the tree is fitted again.
     */
    @Test
    public void testPredictorStandard() throws IOException, URISyntaxException {
        Frame iris = Datasets.loadIrisDataset();

        CTree cart = CTree.newCART()
                .withMaxDepth(10000)
                .withMinCount(1);
        cart.train(iris, "class");
        cart.printSummary();

        CFit fitted = cart.fit(iris, true, true);
        Frame scored = iris.bindVars(fitted.firstClasses().solidCopy().withName("fit"));
        assertEquals(150, rowsWithMatchingFit(scored).rowCount());

        // Blank out columns 0..3 of the first row.
        for (int col = 0; col < 4; col++) {
            scored.setMissing(0, col);
        }

        // NOTE(review): the result of this second fit is discarded, so the check
        // below appears to re-read the "fit" variable bound earlier - confirm intent.
        cart.fit(scored, true, false);
        assertEquals(150, rowsWithMatchingFit(scored).rowCount());
    }
}