/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.ml.classifier.tree; import rapaio.core.tools.DVector; import rapaio.data.Frame; import rapaio.data.Var; import rapaio.data.stream.FSpot; import rapaio.ml.common.VarSelector; import rapaio.util.Pair; import rapaio.util.func.SPredicate; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.IntStream; /** * Created by <a href="mailto:padreati@yahoo.com>Aurelian Tutuianu</a>. */ public class CTreeNode implements Serializable { private static final long serialVersionUID = -5045581827808911763L; private final CTreeNode parent; private final String groupName; private final SPredicate<FSpot> predicate; private final List<CTreeNode> children = new ArrayList<>(); private int id; private boolean leaf = true; private DVector density; private DVector counter; private int bestIndex; private CTreeCandidate bestCandidate; public CTreeNode(final CTreeNode parent, final String groupName, final SPredicate<FSpot> predicate) { this.parent = parent; this.groupName = groupName; this.predicate = predicate; } public CTreeNode getParent() { return parent; } public int getId() { return id; } public String getGroupName() { return groupName; } public Predicate<FSpot> getPredicate() { return predicate; } public DVector getCounter() { return counter; } public int getBestIndex() { return bestIndex; } public DVector getDensity() { return density; } public boolean isLeaf() { return leaf; } public List<CTreeNode> getChildren() { return children; } public CTreeCandidate getBestCandidate() { return bestCandidate; } public int fillId(int index) { id = index; int next = index; for (CTreeNode child : getChildren()) { next = child.fillId(next + 1); } return next; } public void cut() { leaf = true; children.clear(); } public void learn(CTree tree, Frame df, Var weights, int depth) { density = DVector.fromWeights(false, df.var(tree.firstTargetName()), weights); counter = DVector.fromCount(false, df.var(tree.firstTargetName())); bestIndex = density.findBestIndex(); if (df.rowCount() == 0) { bestIndex = parent.bestIndex; return; } if (counter.countValues(x -> x > 0) == 1 || depth < 1 || df.rowCount() <= tree.minCount()) { return; } VarSelector varSel = tree.varSelector(); String[] nextVarNames = varSel.nextAllVarNames(); List<CTreeCandidate> candidateList = new ArrayList<>(); Queue<String> exhaustList = new ConcurrentLinkedQueue<>(); if (tree.runPoolSize() == 0) { int m = varSel.mCount(); for (String testCol : nextVarNames) { if (m <= 0) { continue; } if (testCol.equals(tree.firstTargetName())) { continue; } CTreePurityTest test = null; if (tree.customTestMap().containsKey(testCol)) { test = tree.customTestMap().get(testCol); } if (tree.testMap().containsKey(df.var(testCol).type())) { test = tree.testMap().get(df.var(testCol).type()); } if (test == null) { throw new IllegalArgumentException("can't train ctree with no " + "tests for given variable: " + df.var(testCol).name() + " [" + df.var(testCol).type().name() + "]"); } CTreeCandidate candidate = test.computeCandidate( tree, df, weights, testCol, tree.firstTargetName(), tree.getFunction()); if (candidate != null) { candidateList.add(candidate); m--; } else { exhaustList.add(testCol); } } } else { int m = varSel.mCount(); int start = 0; while (m > 0 && start < nextVarNames.length) { List<CTreeCandidate> next = IntStream.range(start, Math.min(nextVarNames.length, start + m)) .parallel() .mapToObj(i -> nextVarNames[i]) .filter(testCol -> !testCol.equals(tree.firstTargetName())) .map(testCol -> { CTreePurityTest test = null; if (tree.customTestMap().containsKey(testCol)) { test = tree.customTestMap().get(testCol); } if (tree.testMap().containsKey(df.var(testCol).type())) { test = tree.testMap().get(df.var(testCol).type()); } if (test == null) { throw new IllegalArgumentException("can't train ctree with no " + "tests for given variable: " + df.var(testCol).name() + " [" + df.var(testCol).type().name() + "]"); } CTreeCandidate candidate = test.computeCandidate( tree, df, weights, testCol, tree.firstTargetName(), tree.getFunction()); if (candidate == null) { exhaustList.add(testCol); } return candidate; }) .filter(c -> c != null) .collect(Collectors.toList()); candidateList.addAll(next); start += m; m -= next.size(); } } Collections.sort(candidateList); if (candidateList.isEmpty() || candidateList.get(0).getGroupNames().isEmpty()) { return; } // leave as leaf if the gain is not bigger than minimum gain if(candidateList.get(0).getScore()<= tree.minGain()) { return; } leaf = false; bestCandidate = candidateList.get(0); String testName = bestCandidate.getTestName(); // now that we have a best candidate, do the effective split Pair<List<Frame>, List<Var>> frames = tree.getMissingHandler().performSplit(df, weights, bestCandidate); for (int i = 0; i < bestCandidate.getGroupNames().size(); i++) { CTreeNode child = new CTreeNode(this, bestCandidate.getGroupNames().get(i), bestCandidate.getGroupPredicates().get(i)); children.add(child); } tree.varSelector().removeVarNames(exhaustList); for (int i = 0; i < children.size(); i++) { children.get(i).learn(tree, frames._1.get(i), frames._2.get(i), depth - 1); } tree.varSelector().addVarNames(exhaustList); } }