/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.ml.classifier.tree; import rapaio.core.tools.DVector; import rapaio.data.Frame; import rapaio.data.Mapping; import rapaio.data.Var; import rapaio.data.VarType; import rapaio.data.filter.FFilter; import rapaio.data.stream.FSpot; import rapaio.ml.classifier.AbstractClassifier; import rapaio.ml.classifier.CFit; import rapaio.ml.common.Capabilities; import rapaio.ml.common.VarSelector; import rapaio.sys.WS; import rapaio.util.FJPool; import rapaio.util.Pair; import rapaio.util.Tag; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import static java.util.stream.Collectors.joining; /** * Tree classifier. * * @author <a href="mailto:padreati@yahoo.com>Aurelian Tutuianu</a> */ public class CTree extends AbstractClassifier { private static final long serialVersionUID = 1203926824359387358L; // parameter default values private int minCount = 1; private int maxDepth = -1; private double minGain = -1000; private VarSelector varSelector = VarSelector.ALL; private Map<String, CTreePurityTest> customTestMap = new HashMap<>(); private Map<VarType, CTreePurityTest> testMap = new HashMap<>(); private CTreePurityFunction function = CTreePurityFunction.InfoGain; private CTreeMissingHandler splitter = CTreeMissingHandler.Ignored; private Tag<CTreePruning> pruning = CTreePruning.NONE; private Frame pruningDf = null; // tree root node private CTreeNode root; private transient Map<CTreeNode, Map<String, Mapping>> sortingCache = new HashMap<>(); // static builders public CTree() { testMap.put(VarType.BINARY, CTreePurityTest.BinaryBinary); testMap.put(VarType.ORDINAL, CTreePurityTest.NumericBinary); testMap.put(VarType.INDEX, CTreePurityTest.NumericBinary); testMap.put(VarType.NUMERIC, CTreePurityTest.NumericBinary); testMap.put(VarType.NOMINAL, CTreePurityTest.NominalBinary); withRuns(0); } public static CTree newID3() { return new CTree() .withMaxDepth(-1) .withMinCount(1) .withVarSelector(VarSelector.ALL) .withMissingHandler(CTreeMissingHandler.Ignored) .withTest(VarType.NOMINAL, CTreePurityTest.NominalFull) .withTest(VarType.NUMERIC, CTreePurityTest.Ignore) .withFunction(CTreePurityFunction.InfoGain) .withPruning(CTreePruning.NONE); } public static CTree newC45() { return new CTree() .withMaxDepth(-1) .withMinCount(1) .withVarSelector(VarSelector.ALL) .withMissingHandler(CTreeMissingHandler.ToAllWeighted) .withTest(VarType.NOMINAL, CTreePurityTest.NominalFull) .withTest(VarType.NUMERIC, CTreePurityTest.NumericBinary) .withFunction(CTreePurityFunction.GainRatio); } public static CTree newDecisionStump() { return new CTree() .withMaxDepth(1) .withMinCount(1) .withVarSelector(VarSelector.ALL) .withMissingHandler(CTreeMissingHandler.ToAllWeighted) .withFunction(CTreePurityFunction.GainRatio) .withTest(VarType.NOMINAL, CTreePurityTest.NominalBinary) .withTest(VarType.NUMERIC, CTreePurityTest.NumericBinary); } public static CTree newCART() { return new CTree() .withMaxDepth(-1) .withMinCount(1) .withVarSelector(VarSelector.ALL) .withMissingHandler(CTreeMissingHandler.ToAllWeighted) .withTest(VarType.NOMINAL, CTreePurityTest.NominalBinary) .withTest(VarType.NUMERIC, CTreePurityTest.NumericBinary) .withTest(VarType.INDEX, CTreePurityTest.NumericBinary) .withFunction(CTreePurityFunction.GiniGain); } @Override public CTree newInstance() { CTree tree = (CTree) new CTree() .withMinCount(minCount) .withMinGain(minGain) .withMaxDepth(maxDepth) .withFunction(function) .withMissingHandler(splitter) .withVarSelector(varSelector().newInstance()) .withRunningHook(runningHook()) .withSampler(sampler()); tree.withRunPoolSize(runPoolSize()); tree.withRuns(runs()); tree.testMap.clear(); tree.testMap.putAll(testMap); tree.customTestMap.clear(); tree.customTestMap.putAll(customTestMap); return tree; } public CTreeNode getRoot() { return root; } public VarSelector varSelector() { return varSelector; } public CTree withMCols(int mcols) { this.varSelector = new VarSelector(mcols); return this; } public CTree withVarSelector(VarSelector varSelector) { this.varSelector = varSelector; return this; } public int minCount() { return minCount; } public CTree withMinCount(int minCount) { if (minCount < 1) { throw new IllegalArgumentException("min cont must be an integer positive number"); } this.minCount = minCount; return this; } public double minGain() { return minGain; } public CTree withMinGain(double minGain) { this.minGain = minGain; return this; } public int maxDepth() { return maxDepth; } public CTree withMaxDepth(int maxDepth) { this.maxDepth = maxDepth; return this; } public CTree withTest(VarType varType, CTreePurityTest test) { this.testMap.put(varType, test); return this; } public CTree withTest(String varName, CTreePurityTest test) { this.customTestMap.put(varName, test); return this; } public Map<VarType, CTreePurityTest> testMap() { return testMap; } public Map<String, CTreePurityTest> customTestMap() { return customTestMap; } public CTree withNoTests() { this.testMap.clear(); return this; } public CTree withPruning(Tag<CTreePruning> pruning) { return withPruning(pruning, null); } public CTree withPruning(Tag<CTreePruning> pruning, Frame pruningDf) { this.pruning = pruning; this.pruningDf = pruningDf; return this; } public CTreePurityFunction getFunction() { return function; } public CTree withFunction(CTreePurityFunction function) { this.function = function; return this; } public CTreeMissingHandler getMissingHandler() { return splitter; } public CTree withMissingHandler(CTreeMissingHandler splitter) { this.splitter = splitter; return this; } @Override public String name() { return "CTree"; } @Override public String fullName() { StringBuilder sb = new StringBuilder(); sb.append("CTree {"); sb.append("varSelector=").append(varSelector().name()).append(";"); sb.append("minCount=").append(minCount).append(";"); sb.append("maxDepth=").append(maxDepth).append(";"); sb.append("tests=").append(testMap.entrySet().stream() .map(e -> e.getKey().name() + ":" + e.getValue().name()).collect(joining(",")) ).append(";"); if (!customTestMap.isEmpty()) sb.append("customTest=").append(customTestMap.entrySet().stream() .map(e -> e.getKey() + ":" + e.getValue().name()).collect(joining(",")) ).append(";"); sb.append("func=").append(function.name()).append(";"); sb.append("split=").append(splitter.name()).append(";"); sb.append("}"); return sb.toString(); } @Override public Capabilities capabilities() { return new Capabilities() .withInputTypes(VarType.NOMINAL, VarType.INDEX, VarType.NUMERIC, VarType.BINARY) .withInputCount(1, 1_000_000) .withAllowMissingInputValues(true) .withTargetTypes(VarType.NOMINAL) .withTargetCount(1, 1) .withAllowMissingTargetValues(false); } @Override protected boolean coreTrain(Frame df, Var weights) { additionalValidation(df); this.varSelector.withVarNames(inputNames()); int rows = df.rowCount(); root = new CTreeNode(null, "root", spot -> true); if (runPoolSize() == 0) { root.learn(this, df, weights, maxDepth() < 0 ? Integer.MAX_VALUE : maxDepth()); } else { FJPool.run(runPoolSize(), () -> root.learn(this, df, weights, maxDepth < 0 ? Integer.MAX_VALUE : maxDepth)); } this.root.fillId(1); pruning.get().prune(this, (pruningDf == null) ? df : pruningDf, false); return true; } public void prune(Frame df) { prune(df, false); } public void prune(Frame df, boolean all) { pruning.get().prune(this, df, all); } @Override protected CFit coreFit(Frame df, boolean withClasses, boolean withDensities) { CFit prediction = CFit.build(this, df, withClasses, withDensities); df.stream().forEach(spot -> { Pair<Integer, DVector> res = fitPoint(this, spot, root); int index = res._1; DVector dv = res._2; if (withClasses) prediction.firstClasses().setIndex(spot.row(), index); if (withDensities) for (int j = 0; j < firstTargetLevels().length; j++) { prediction.firstDensity().setValue(spot.row(), j, dv.get(j)); } }); return prediction; } protected Pair<Integer, DVector> fitPoint(CTree tree, FSpot spot, CTreeNode node) { if (node.isLeaf()) return Pair.from(node.getBestIndex(), node.getDensity().solidCopy().normalize()); for (CTreeNode child : node.getChildren()) { if (child.getPredicate().test(spot)) { return this.fitPoint(tree, spot, child); } } String[] dict = tree.firstTargetLevels(); DVector dv = DVector.empty(false, dict); double w = 0.0; for (CTreeNode child : node.getChildren()) { DVector d = this.fitPoint(tree, spot, child)._2; double wc = child.getDensity().sum(); dv.increment(d, wc); w += wc; } for (int i = 0; i < dict.length; i++) { dv.set(i, dv.get(i) / w); } return Pair.from(dv.findBestIndex(), dv); } private void additionalValidation(Frame df) { df.varStream().forEach(var -> { if (customTestMap.containsKey(var.name())) return; if (testMap.containsKey(var.type())) return; throw new IllegalArgumentException("can't train ctree with no " + "tests for given variable: " + var.name() + " [" + var.type().name() + "]"); }); } public int countNodes(boolean onlyLeaves) { int count = 0; LinkedList<CTreeNode> nodes = new LinkedList<>(); nodes.addLast(root); while (!nodes.isEmpty()) { CTreeNode node = nodes.pollFirst(); count += onlyLeaves ? (node.isLeaf() ? 1 : 0) : 1; node.getChildren().forEach(nodes::addLast); } return count; } @Override public CTree withInputFilters(List<FFilter> filters) { return (CTree) super.withInputFilters(filters); } @Override public String summary() { StringBuilder sb = new StringBuilder(); sb.append("CTree model\n"); sb.append("================\n\n"); sb.append("Description:\n"); sb.append(fullName().replaceAll(";", ";\n")).append("\n\n"); sb.append("Capabilities:\n"); sb.append(capabilities().summary()).append("\n"); sb.append("Learned model:\n"); if (!hasLearned()) { sb.append("Learning phase not called\n\n"); return sb.toString(); } sb.append(baseSummary()); sb.append("\n"); int nodeCount = 0; int leaveCount = 0; LinkedList<CTreeNode> queue = new LinkedList<>(); queue.add(root); while (!queue.isEmpty()) { CTreeNode node = queue.pollFirst(); nodeCount++; if (node.isLeaf()) leaveCount++; node.getChildren().forEach(queue::addLast); } sb.append("total number of nodes: ").append(nodeCount).append("\n"); sb.append("total number of leaves: ").append(leaveCount).append("\n"); sb.append("description:\n"); sb.append("split, n/err, classes (densities) [* if is leaf / purity if not]\n\n"); buildSummary(sb, root, 0); return sb.toString(); } private void buildSummary(StringBuilder sb, CTreeNode node, int level) { sb.append(level == 0 ? "|- " : "|"); for (int i = 0; i < level; i++) { sb.append((i == level - 1) ? " |- " : " |"); } sb.append(node.getId()).append(". ").append(node.getGroupName()).append(" "); sb.append(WS.formatFlexShort(node.getCounter().sum())).append("/"); sb.append(WS.formatFlexShort(node.getCounter().sumExcept(node.getBestIndex()))).append(" "); sb.append(firstTargetLevels()[node.getBestIndex()]).append(" ("); DVector d = node.getDensity().solidCopy().normalize(); for (int i = 1; i < firstTargetLevels().length; i++) { sb.append(WS.formatFlexShort(d.get(i))).append(" "); } sb.append(") "); if (node.isLeaf()) { sb.append("*"); } else { sb.append("[").append(WS.formatFlex(node.getBestCandidate().getScore())).append("]"); } sb.append("\n"); node.getChildren().stream().forEach(child -> buildSummary(sb, child, level + 1)); } }