/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.ml.regression.tree; import rapaio.core.stat.WeightedMean; import rapaio.data.Frame; import rapaio.data.Mapping; import rapaio.data.Var; import rapaio.data.VarType; import rapaio.data.stream.FSpot; import rapaio.ml.common.Capabilities; import rapaio.ml.common.VarSelector; import rapaio.ml.regression.AbstractRegression; import rapaio.ml.regression.RFit; import rapaio.experiment.ml.regression.boost.gbt.BTRegression; import rapaio.experiment.ml.regression.boost.gbt.GBTLossFunction; import rapaio.util.Pair; import rapaio.util.func.SPredicate; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.stream.IntStream; import static rapaio.sys.WS.formatFlex; /** * Implements a regression tree. * <p> * Created by <a href="mailto:padreati@yahoo.com>Aurelian Tutuianu</a> on 11/24/14. */ @Deprecated public class RTree extends AbstractRegression implements BTRegression { private static final long serialVersionUID = -2748764643670512376L; int minCount = 1; int maxDepth = -1; RTreeNominalMethod nominalMethod = RTreeNominalMethod.BINARY; RTreeNumericMethod numericMethod = RTreeNumericMethod.BINARY; RTreeTestFunction function = RTreeTestFunction.WeightedVarGain; RTreeSplitter splitter = RTreeSplitter.REMAINS_IGNORED; RTreePredictor predictor = RTreePredictor.STANDARD; VarSelector varSelector = VarSelector.ALL; // tree root node private RTreeNode root; private int rows; private RTree() { } public static RTree buildDecisionStump() { return new RTree() .withMaxDepth(2) .withNominalMethod(RTreeNominalMethod.BINARY) .withNumericMethod(RTreeNumericMethod.BINARY) .withSplitter(RTreeSplitter.REMAINS_TO_MAJORITY) ; } public static RTree buildC45() { return new RTree() .withMaxDepth(-1) .withNominalMethod(RTreeNominalMethod.FULL) .withNumericMethod(RTreeNumericMethod.BINARY) .withSplitter(RTreeSplitter.REMAINS_TO_RANDOM) .withMinCount(2) ; } public static RTree buildCART() { return new RTree() .withMaxDepth(-1) .withNominalMethod(RTreeNominalMethod.BINARY) .withNumericMethod(RTreeNumericMethod.BINARY) .withSplitter(RTreeSplitter.REMAINS_TO_RANDOM) .withFunction(RTreeTestFunction.WeightedSdGain) .withMinCount(1); } @Override public BTRegression newInstance() { return new RTree() .withMinCount(minCount) .withNumericMethod(numericMethod) .withNominalMethod(nominalMethod) .withMaxDepth(maxDepth) .withSplitter(splitter) .withFunction(function) .withVarSelector(varSelector); } @Override public String name() { return "RTree"; } @Override public String fullName() { StringBuilder sb = new StringBuilder(); sb.append("TreeClassifier {"); sb.append(" varSelector=").append(varSelector.name()).append(",\n"); sb.append(" minCount=").append(minCount).append(",\n"); sb.append(" maxDepth=").append(maxDepth).append(",\n"); sb.append(" numericMethod=").append(numericMethod.name()).append(",\n"); sb.append(" nominalMethod=").append(nominalMethod.name()).append(",\n"); sb.append(" function=").append(function.name()).append(",\n"); sb.append(" splitter=").append(splitter.name()).append(",\n"); sb.append(" predictor=").append(predictor.name()).append("\n"); sb.append("}"); return sb.toString(); } @Override public Capabilities capabilities() { return new Capabilities() .withInputCount(1, 1_000_000) .withTargetCount(1, 1) .withInputTypes(VarType.BINARY, VarType.INDEX, VarType.NUMERIC, VarType.ORDINAL, VarType.NOMINAL) .withTargetTypes(VarType.NUMERIC) .withAllowMissingInputValues(true) .withAllowMissingTargetValues(false); } @Override public void boostFit(Frame x, Var y, Var fx, GBTLossFunction lossFunction) { root.boostFit(x, y, fx, lossFunction); } public RTree withVarSelector(VarSelector varSelector) { this.varSelector = varSelector; return this; } public RTree withMinCount(int minCount) { this.minCount = minCount; return this; } public RTree withMaxDepth(int maxDepth) { this.maxDepth = maxDepth; return this; } public RTree withNumericMethod(RTreeNumericMethod numericMethod) { this.numericMethod = numericMethod; return this; } public RTree withNominalMethod(RTreeNominalMethod nominalMethod) { this.nominalMethod = nominalMethod; return this; } public RTree withFunction(RTreeTestFunction function) { this.function = function; return this; } public RTree withSplitter(RTreeSplitter splitter) { this.splitter = splitter; return this; } @Override protected boolean coreTrain(Frame df, Var weights) { if (targetNames().length == 0) { throw new IllegalArgumentException("tree classifier must specify a target variable"); } if (targetNames().length > 1) { throw new IllegalArgumentException("tree classifier can't fit more than one target variable"); } rows = df.rowCount(); root = new RTreeNode(null, "root", spot -> true); this.varSelector.withVarNames(inputNames()); root.learn(this, df, weights, maxDepth < 0 ? Integer.MAX_VALUE : maxDepth); return true; } @Override protected RFit coreFit(Frame df, boolean withResiduals) { RFit pred = RFit.build(this, df, withResiduals); df.stream().forEach(spot -> { Pair<Double, Double> result = predictor.predict(this, spot, root); pred.fit(firstTargetName()).setValue(spot.row(), result._1); }); pred.buildComplete(); return pred; } @Override public String summary() { StringBuilder sb = new StringBuilder(); sb.append("\n > ").append(fullName()).append("\n"); sb.append(String.format("n=%d\n", rows)); sb.append("\n"); sb.append("description:\n"); sb.append("split, mean (total weight) [* if is leaf]\n\n"); buildSummary(sb, root, 0); return sb.toString(); } private void buildSummary(StringBuilder sb, RTreeNode node, int level) { sb.append("|"); for (int i = 0; i < level; i++) { sb.append(" |"); } sb.append(node.getGroupName()).append(" "); sb.append(formatFlex(node.getValue())); sb.append(" (").append(formatFlex(node.getWeight())).append(") "); if (node.isLeaf()) sb.append(" *"); sb.append("\n"); // children if (!node.isLeaf()) { node.getChildren().stream().forEach(child -> buildSummary(sb, child, level + 1)); } } /** * Created by <a href="mailto:padreati@yahoo.com>Aurelian Tutuianu</a> on 11/24/14. */ @Deprecated public static class RTreeNode implements Serializable { private static final long serialVersionUID = 385363626560575837L; private final RTreeNode parent; private final String groupName; private final SPredicate<FSpot> predicate; private boolean leaf = true; private double value; private double weight; private List<RTreeNode> children; private RTreeCandidate bestCandidate; public RTreeNode(final RTreeNode parent, final String groupName, final SPredicate<FSpot> predicate) { this.parent = parent; this.groupName = groupName; this.predicate = predicate; } public RTreeNode getParent() { return parent; } public String getGroupName() { return groupName; } public SPredicate<FSpot> getPredicate() { return predicate; } public boolean isLeaf() { return leaf; } public List<RTreeNode> getChildren() { return children; } public RTreeCandidate getBestCandidate() { return bestCandidate; } public double getValue() { return value; } public double getWeight() { return weight; } public void learn(RTree tree, Frame df, Var weights, int depth) { value = WeightedMean.from(df.var(tree.firstTargetName()), weights).value(); weight = weights.stream().complete().mapToDouble().sum(); if (weight == 0) { // WS.println("ERROR"); value = parent!=null ? parent.value : Double.NaN; } if (df.rowCount() == 0 || df.rowCount() <= tree.minCount || depth <= 1) { return; } List<RTreeCandidate> candidateList = new ArrayList<>(); ConcurrentLinkedQueue<RTreeCandidate> candidates = new ConcurrentLinkedQueue<>(); Arrays.stream(tree.varSelector.nextVarNames()).parallel().forEach(testCol -> { if (testCol.equals(tree.firstTargetName())) return; if (df.var(testCol).type().isNumeric()) { tree.numericMethod.computeCandidates( tree, df, weights, testCol, tree.firstTargetName(), tree.function) .forEach(candidates::add); } else { tree.nominalMethod.computeCandidates( tree, df, weights, testCol, tree.firstTargetName(), tree.function) .forEach(candidates::add); } }); candidateList.addAll(candidates); Collections.sort(candidateList); if (candidateList.isEmpty()) { return; } leaf = false; bestCandidate = candidateList.get(0); // now that we have a best candidate,do the effective split if (bestCandidate.getGroupNames().isEmpty()) { leaf = true; return; } Pair<List<Frame>, List<Var>> frames = tree.splitter.performSplit(df, weights, bestCandidate); children = new ArrayList<>(frames._1.size()); for (int i = 0; i < frames._1.size(); i++) { RTreeNode child = new RTreeNode(this, bestCandidate.getGroupNames().get(i), bestCandidate.getGroupPredicates().get(i)); children.add(child); child.learn(tree, frames._1.get(i), frames._2.get(i), depth - 1); } } public void boostFit(Frame x, Var y, Var fx, GBTLossFunction lossFunction) { if (leaf) { value = lossFunction.findMinimum(y, fx); return; } Mapping[] mapping = IntStream .range(0, children.size()).boxed() .map(i -> Mapping.empty()).toArray(Mapping[]::new); x.stream().forEach(spot -> { for (int i = 0; i < children.size(); i++) { RTreeNode child = children.get(i); if (child.predicate.test(spot)) { mapping[i].add(spot.row()); return; } } }); for (int i = 0; i < children.size(); i++) { children.get(i).boostFit(x.mapRows(mapping[i]), y.mapRows(mapping[i]), fx.mapRows(mapping[i]), lossFunction); } } } /** * RTree split candidate. * <p> * Created by <a href="mailto:padreati@yahoo.com>Aurelian Tutuianu</a> on 11/24/14. */ public static class RTreeCandidate implements Comparable<RTreeCandidate>, Serializable { private static final long serialVersionUID = 6698766675237089849L; private final double score; private final String testName; private final List<String> groupNames = new ArrayList<>(); private final List<SPredicate<FSpot>> groupPredicates = new ArrayList<>(); public RTreeCandidate(double score, String testName) { this.score = score; this.testName = testName; } public void addGroup(String name, SPredicate<FSpot> predicate) { if (groupNames.contains(name)) { throw new IllegalArgumentException("group name already defined"); } groupNames.add(name); groupPredicates.add(predicate); } public List<String> getGroupNames() { return groupNames; } public List<SPredicate<FSpot>> getGroupPredicates() { return groupPredicates; } public double getScore() { return score; } public String getTestName() { return testName; } @Override public int compareTo(RTreeCandidate o) { if (o == null) return 1; return -Double.compare(score, o.score); } } }