/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.ml.classifier.ensemble; import rapaio.core.CoreTools; import rapaio.core.distributions.Distribution; import rapaio.core.tools.DVector; import rapaio.data.*; import rapaio.data.filter.FFilter; import rapaio.data.filter.Filters; import rapaio.data.sample.Sample; import rapaio.data.sample.RowSampler; import rapaio.ml.classifier.AbstractClassifier; import rapaio.ml.classifier.CFit; import rapaio.ml.classifier.Classifier; import rapaio.ml.classifier.tree.CTree; import rapaio.ml.classifier.tree.CTreeNode; import rapaio.ml.common.Capabilities; import rapaio.ml.common.VarSelector; import rapaio.ml.eval.Confusion; import rapaio.util.Pair; import rapaio.util.Util; import java.util.*; import java.util.function.BiConsumer; import java.util.stream.Collectors; import java.util.stream.IntStream; import static java.util.stream.Collectors.*; /** * Breiman random forest implementation. * <p> * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 4/16/15. */ public class CForest extends AbstractClassifier { private static final long serialVersionUID = -145958939373105497L; private boolean oobComp = false; private boolean freqVIComp = false; private boolean gainVIComp = false; private boolean permVIComp = false; private Classifier c = CTree.newCART(); private BaggingMode baggingMode = BaggingMode.DISTRIBUTION; // learning artifacts private double oobError = Double.NaN; private List<Classifier> predictors = new ArrayList<>(); private Map<Integer, DVector> oobDensities; private Var oobFit; private Var oobTrueClass; private Map<String, List<Double>> freqVIMap = new HashMap<>(); private Map<String, List<Double>> gainVIMap = new HashMap<>(); private Map<String, List<Double>> permVIMap = new HashMap<>(); private CForest() { withRuns(10); this.baggingMode = BaggingMode.DISTRIBUTION; this.c = CTree.newCART().withVarSelector(VarSelector.AUTO); this.oobComp = false; this.withSampler(RowSampler.bootstrap()); } public static CForest newRF() { return new CForest(); } @Override public String name() { return "CForest"; } @Override public String fullName() { StringBuilder sb = new StringBuilder(); sb.append(name()); sb.append("{"); sb.append("runs:").append(runs()).append(";"); sb.append("baggingMode:").append(baggingMode.name()).append(";"); sb.append("oob:").append(oobComp).append(";"); sb.append("sampler:").append(sampler().name()).append(";"); sb.append("tree:").append(c.fullName()); sb.append("}"); return sb.toString(); } @Override public Classifier newInstance() { return new CForest() .withRuns(runs()) .withInputFilters(inputFilters()) .withBaggingMode(baggingMode) .withOobComp(oobComp) .withFreqVIComp(freqVIComp) .withGainVIComp(gainVIComp) .withPermVIComp(permVIComp) .withClassifier(c.newInstance()) .withSampler(sampler()); } public CForest withRuns(int runs) { return (CForest) super.withRuns(runs); } public CForest withFreqVIComp(boolean freqVIComp) { this.freqVIComp = freqVIComp; return this; } public CForest withGainVIComp(boolean gainVIComp) { this.gainVIComp = gainVIComp; return this; } public CForest withPermVIComp(boolean permVIComp) { this.permVIComp = permVIComp; return this; } public CForest withOobComp(boolean oobCompute) { this.oobComp = oobCompute; return this; } public CForest withBaggingMode(BaggingMode baggingMode) { this.baggingMode = baggingMode; return this; } @Override public CForest withSampler(RowSampler sampler) { return (CForest) super.withSampler(sampler); } public CForest withClassifier(Classifier c) { this.c = c; return this; } public CForest withMCols(int mcols) { if (c instanceof CTree) { ((CTree) c).withMCols(mcols); } return this; } public CForest withVarSelector(VarSelector varSelector) { if (c instanceof CTree) { ((CTree) c).withVarSelector(varSelector); } return this; } @Override public Capabilities capabilities() { Capabilities cc = c.capabilities(); return new Capabilities() .withInputCount(cc.getMinInputCount(), cc.getMaxInputCount()) .withInputTypes(cc.getInputTypes().stream().toArray(VarType[]::new)) .withAllowMissingInputValues(cc.getAllowMissingInputValues()) .withTargetCount(1, 1) .withTargetTypes(VarType.NOMINAL) .withAllowMissingTargetValues(false); } public List<Classifier> getClassifiers() { return predictors; } public double getOobError() { return oobError; } public Confusion getOobInfo() { return new Confusion(oobTrueClass, oobFit); } public Frame getFreqVIInfo() { Var name = Nominal.empty().withName("name"); Var score = Numeric.empty().withName("score mean"); Var sd = Numeric.empty().withName("score sd"); for (Map.Entry<String, List<Double>> e : freqVIMap.entrySet()) { name.addLabel(e.getKey()); Numeric scores = Numeric.copy(e.getValue()); sd.addValue(CoreTools.var(scores).sdValue()); score.addValue(CoreTools.mean(scores).value()); } double maxScore = CoreTools.max(score).value(); Var scaled = Numeric.from(score.rowCount(), row -> 100.0 * score.value(row) / maxScore).withName("scaled score"); return Filters.refSort(SolidFrame.byVars(name, score, sd, scaled), score.refComparator(false)).solidCopy(); } public Frame getGainVIInfo() { Var name = Nominal.empty().withName("name"); Var score = Numeric.empty().withName("score mean"); Var sd = Numeric.empty().withName("score sd"); for (Map.Entry<String, List<Double>> e : gainVIMap.entrySet()) { name.addLabel(e.getKey()); Numeric scores = Numeric.copy(e.getValue()); sd.addValue(CoreTools.var(scores).sdValue()); score.addValue(CoreTools.mean(scores).value()); } double maxScore = CoreTools.max(score).value(); Var scaled = Numeric.from(score.rowCount(), row -> 100.0 * score.value(row) / maxScore).withName("scaled score"); return Filters.refSort(SolidFrame.byVars(name, score, sd, scaled), score.refComparator(false)).solidCopy(); } public Frame getPermVIInfo() { Var name = Nominal.empty().withName("name"); Var score = Numeric.empty().withName("score mean"); Var sds = Numeric.empty().withName("score sd"); Var zscores = Numeric.empty().withName("z-score"); Var pvalues = Numeric.empty().withName("p-value"); Distribution normal = CoreTools.distNormal(); for (Map.Entry<String, List<Double>> e : permVIMap.entrySet()) { name.addLabel(e.getKey()); Numeric scores = Numeric.copy(e.getValue()); double mean = CoreTools.mean(scores).value(); double sd = CoreTools.var(scores).sdValue(); double zscore = mean / (sd); double pvalue = normal.cdf(2 * normal.cdf(-Math.abs(zscore))); score.addValue(Math.abs(mean)); sds.addValue(sd); zscores.addValue(Math.abs(zscore)); pvalues.addValue(pvalue); } return Filters.refSort(SolidFrame.byVars(name, score, sds, zscores, pvalues), zscores.refComparator(false)).solidCopy(); } @Override protected boolean coreTrain(Frame df, Var weights) { double totalOobInstances = 0; double totalOobError = 0; if (oobComp) { oobDensities = new HashMap<>(); oobTrueClass = df.var(firstTargetName()).solidCopy(); oobFit = Nominal.empty(df.rowCount(), firstTargetLevels()); for (int i = 0; i < df.rowCount(); i++) { oobDensities.put(i, DVector.empty(false, firstTargetLevels())); } } if (freqVIComp && c instanceof CTree) { freqVIMap.clear(); } if (gainVIComp && c instanceof CTree) { gainVIMap.clear(); } if (permVIComp) { permVIMap.clear(); } if (runPoolSize() == 0) { predictors = new ArrayList<>(); for (int i = 0; i < runs(); i++) { Pair<Classifier, List<Integer>> weak = buildWeakPredictor(df, weights); predictors.add(weak._1); if (oobComp) { oobCompute(df, weak); } if (freqVIComp && c instanceof CTree) { freqVICompute(weak); } if (gainVIComp && c instanceof CTree) { gainVICompute(weak); } if (permVIComp) { permVICompute(df, weak); } if (runningHook() != null) { runningHook().accept(this, i + 1); } } } else { // build in parallel the trees, than oob and running hook cannot run at the // same moment when weak tree was built // for a real running hook behavior run without threading predictors = new ArrayList<>(); List<Pair<Classifier, List<Integer>>> list = Util.rangeStream(runs(), runPoolSize() > 0).boxed() .map(s -> buildWeakPredictor(df, weights)) .collect(Collectors.toList()); for (int i = 0; i < list.size(); i++) { Pair<Classifier, List<Integer>> weak = list.get(i); predictors.add(weak._1); if (oobComp) { oobCompute(df, weak); } if (freqVIComp && c instanceof CTree) { freqVICompute(weak); } if (gainVIComp && c instanceof CTree) { gainVICompute(weak); } if (permVIComp) { permVICompute(df, weak); } if (runningHook() != null) { runningHook().accept(this, i + 1); } } } return true; } private void permVICompute(Frame df, Pair<Classifier, List<Integer>> weak) { Classifier c = weak._1; List<Integer> oobIndexes = weak._2; // build oob data frame Frame oobFrame = df.mapRows(Mapping.wrap(oobIndexes)); // build accuracy on oob data frame CFit fit = c.fit(oobFrame); double refScore = new Confusion( oobFrame.var(firstTargetName()), fit.firstClasses()) .acceptedCases(); // now for each input variable do computation for (String varName : inputNames()) { // shuffle values from variable Var shuffled = Filters.shuffle(oobFrame.var(varName)); // build oob frame with shuffled variable Frame oobReduced = oobFrame.removeVars(varName).bindVars(shuffled); // compute accuracy on oob shuffled frame CFit pfit = c.fit(oobReduced); double acc = new Confusion( oobReduced.var(firstTargetName()), pfit.firstClasses() ).acceptedCases(); if (!permVIMap.containsKey(varName)) { permVIMap.put(varName, new ArrayList<>()); } permVIMap.get(varName).add(refScore - acc); } } private void gainVICompute(Pair<Classifier, List<Integer>> weak) { CTree weakTree = (CTree) weak._1; DVector scores = DVector.empty(false, inputNames()); collectGainVI(weakTree.getRoot(), scores); for (int j = 0; j < inputNames().length; j++) { String varName = inputName(j); double score = scores.get(varName); if (!gainVIMap.containsKey(varName)) { gainVIMap.put(varName, new ArrayList<>()); } gainVIMap.get(varName).add(score); } } private void collectGainVI(CTreeNode node, DVector dv) { if (node.isLeaf()) return; String varName = node.getBestCandidate().getTestName(); double score = Math.abs(node.getBestCandidate().getScore()); dv.increment(varName, score * node.getDensity().sum()); node.getChildren().forEach(child -> collectGainVI(child, dv)); } private void freqVICompute(Pair<Classifier, List<Integer>> weak) { CTree weakTree = (CTree) weak._1; DVector scores = DVector.empty(false, inputNames()); collectFreqVI(weakTree.getRoot(), scores); for (int j = 0; j < inputNames().length; j++) { String varName = inputName(j); double score = scores.get(varName); if (!freqVIMap.containsKey(varName)) { freqVIMap.put(varName, new ArrayList<>()); } freqVIMap.get(varName).add(score); } } private void collectFreqVI(CTreeNode node, DVector dv) { if (node.isLeaf()) return; String varName = node.getBestCandidate().getTestName(); double score = Math.abs(node.getBestCandidate().getScore()); dv.increment(varName, node.getDensity().sum()); node.getChildren().forEach(child -> collectFreqVI(child, dv)); } private void oobCompute(Frame df, Pair<Classifier, List<Integer>> weak) { double totalOobError; double totalOobInstances; List<Integer> oobIndexes = weak._2; Frame oobTest = df.mapRows(Mapping.wrap(oobIndexes)); CFit fit = weak._1.fit(oobTest); for (int j = 0; j < oobTest.rowCount(); j++) { int fitIndex = fit.firstClasses().index(j); oobDensities.get(oobIndexes.get(j)).increment(fitIndex, 1.0); } oobFit.clear(); totalOobError = 0.0; totalOobInstances = 0.0; for (Map.Entry<Integer, DVector> e : oobDensities.entrySet()) { if (e.getValue().sum() > 0) { int bestIndex = e.getValue().findBestIndex(); String bestLevel = firstTargetLevels()[bestIndex]; oobFit.setLabel(e.getKey(), bestLevel); if (!bestLevel.equals(oobTrueClass.label(e.getKey()))) { totalOobError++; } totalOobInstances++; } } oobError = (totalOobInstances > 0) ? totalOobError / totalOobInstances : 0.0; } private Pair<Classifier, List<Integer>> buildWeakPredictor(Frame df, Var weights) { Classifier weak = c.newInstance(); Sample sample = sampler().nextSample(df, weights); Frame trainFrame = sample.df; Var trainWeights = sample.weights; weak.train(trainFrame, trainWeights, firstTargetName()); List<Integer> oobIndexes = new ArrayList<>(); if (oobComp) { Set<Integer> out = sample.mapping.rowStream().boxed().collect(toSet()); oobIndexes = IntStream.range(0, df.rowCount()).filter(row -> !out.contains(row)).boxed().collect(toList()); } return Pair.from(weak, oobIndexes); } @Override protected CFit coreFit(Frame df, boolean withClasses, boolean withDensities) { CFit cp = CFit.build(this, df, true, true); List<CFit> treeFits = predictors.stream().parallel() .map(pred -> pred.fit(df, baggingMode.needsClass(), baggingMode.needsDensity())) .collect(Collectors.toList()); baggingMode.computeDensity(firstTargetLevels(), new ArrayList<>(treeFits), cp.firstClasses(), cp.firstDensity()); return cp; } @Override public CForest withRunningHook(BiConsumer<Classifier, Integer> runningHook) { return (CForest) super.withRunningHook(runningHook); } @Override public CForest withRunPoolSize(int poolSize) { return (CForest) super.withRunPoolSize(poolSize); } @Override public CForest withInputFilters(List<FFilter> filters) { return (CForest) super.withInputFilters(filters); } @Override public CForest withInputFilters(FFilter... filters) { return (CForest) super.withInputFilters(filters); } @Override public String summary() { StringBuilder sb = new StringBuilder(); sb.append("CForest model\n"); sb.append("================\n\n"); sb.append("Description:\n"); sb.append(fullName().replaceAll(";", ";\n")).append("\n\n"); sb.append("Capabilities:\n"); sb.append(capabilities().summary()).append("\n"); sb.append("Learned model:\n"); if (!hasLearned()) { sb.append("Learning phase not called\n\n"); return sb.toString(); } sb.append(baseSummary()); // stuff specific to rf // todo return sb.toString(); } }