/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.ml.classifier.tree; import rapaio.core.RandomSource; import rapaio.core.tools.DTable; import rapaio.data.Frame; import rapaio.data.Index; import rapaio.data.RowComparators; import rapaio.data.Var; import rapaio.data.filter.var.VFRefSort; import rapaio.sys.WS; import rapaio.util.Tagged; import java.io.Serializable; /** * Impurity test implementation * <p> * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 10/9/15. 
*/
public interface CTreePurityTest extends Tagged, Serializable {

    /**
     * Test that never proposes a split: {@link #computeCandidate} always returns {@code null}.
     */
    CTreePurityTest Ignore = new CTreePurityTest() {
        private static final long serialVersionUID = 2862814158096438654L;

        @Override
        public String name() {
            return "Ignore";
        }

        @Override
        public CTreeCandidate computeCandidate(CTree c, Frame df, Var w, String testName, String targetName,
                                               CTreePurityFunction function) {
            return null;
        }
    };

    /**
     * Binary split of a numeric test variable into {@code testName <= t} and {@code testName > t}.
     * <p>
     * Rows are swept in sorted order of the test value. Every cut point between two distinct consecutive
     * values that satisfies the {@code c.minCount()} bounds is scored with the purity function, and the
     * highest-scoring cut is kept (ties are broken at random). The threshold {@code t} is the midpoint of
     * the two values around the cut. Rows with a missing test value belong to neither group.
     */
    CTreePurityTest NumericBinary = new CTreePurityTest() {
        private static final long serialVersionUID = -2093990830002355963L;

        @Override
        public String name() {
            return "NumericBinary";
        }

        @Override
        public CTreeCandidate computeCandidate(CTree c, Frame df, Var weights, String testName, String targetName,
                                               CTreePurityFunction function) {
            Var test = df.var(testName);
            Var target = df.var(targetName);
            int rowCount = df.rowCount();

            // Row 0 collects rows with a missing test value, row 1 the "<=" side and row 2 the ">" side.
            // Before the sweep, every non-missing row sits on the ">" side.
            DTable dt = DTable.empty(DTable.NUMERIC_DEFAULT_LABELS, target.levels(), false);
            int misCount = 0;
            for (int i = 0; i < rowCount; i++) {
                boolean missing = test.missing(i);
                if (missing) {
                    misCount++;
                }
                dt.update(missing ? 0 : 2, target.index(i), weights.value(i));
            }

            // Row indexes ordered by test value; the sweep below relies on ascending order.
            Var sort = new VFRefSort(RowComparators.numeric(test, true)).fitApply(Index.seq(rowCount));

            CTreeCandidate best = null;
            double bestScore = 0.0;
            for (int i = 0; i < rowCount; i++) {
                int row = sort.index(i);
                if (test.missing(row)) {
                    continue;
                }

                // Move the current row from the ">" side to the "<=" side.
                dt.update(2, target.index(row), -weights.value(row));
                dt.update(1, target.index(row), +weights.value(row));

                // NOTE(review): the lower bound assumes missing values are sorted first, so that
                // (i - misCount + 1) rows are on the "<=" side here; confirm RowComparators.numeric.
                boolean sidesLargeEnough = i >= misCount + c.minCount() - 1 && i < rowCount - c.minCount();
                // i + 1 < rowCount keeps sort.index(i + 1) in range even when minCount() <= 0.
                if (!sidesLargeEnough || i + 1 >= rowCount) {
                    continue;
                }
                double currentValue = test.value(row);
                double nextValue = test.value(sort.index(i + 1));
                // Only cut between distinct values, so equal values never end up on different sides.
                if (!(currentValue < nextValue)) {
                    continue;
                }

                double currentScore = function.compute(dt);
                if (best != null) {
                    int comp = Double.compare(bestScore, currentScore);
                    // Higher scores are preferred.
                    if (comp > 0) {
                        continue;
                    }
                    // On ties, replace the current best with probability ~1/2.
                    if (comp == 0 && RandomSource.nextDouble() > 0.5) {
                        continue;
                    }
                }
                double testValue = (currentValue + nextValue) / 2.0;
                // The candidate carries the score of this cut, not the score of the previous best.
                best = new CTreeCandidate(currentScore, testName);
                best.addGroup(
                        String.format("%s <= %s", testName, WS.formatFlex(testValue)),
                        spot -> !spot.missing(testName) && spot.value(testName) <= testValue);
                best.addGroup(
                        String.format("%s > %s", testName, WS.formatFlex(testValue)),
                        spot -> !spot.missing(testName) && spot.value(testName) > testValue);
                bestScore = currentScore;
            }
            return best;
        }
    };

    /**
     * Split of a binary test variable into a {@code == 1} group and a {@code != 1} group, routed by
     * {@code spot.binary(testName)}.
     * <p>
     * The minimum-count guard is evaluated on raw row counts. The split is scored on the weighted table,
     * the same way {@code NominalFull} does it.
     */
    CTreePurityTest BinaryBinary = new CTreePurityTest() {
        private static final long serialVersionUID = 1771541941375729870L;

        @Override
        public String name() {
            return "BinaryBinary";
        }

        @Override
        public CTreeCandidate computeCandidate(CTree c, Frame df, Var w, String testName, String targetName,
                                               CTreePurityFunction function) {
            Var test = df.var(testName);
            Var target = df.var(targetName);
            if (!DTable.fromCounts(test, target, false).hasColsWithMinimumCount(c.minCount(), 2)) {
                return null;
            }
            // Use the per-row weights for the score; previously w was ignored here.
            DTable dt = DTable.fromWeights(test, target, w, false);
            CTreeCandidate best = new CTreeCandidate(function.compute(dt), testName);
            best.addGroup(testName + " == 1", spot -> spot.binary(testName));
            best.addGroup(testName + " != 1", spot -> !spot.binary(testName));
            return best;
        }
    };

    /**
     * Multi-way split of a nominal test variable: one group per level, skipping level index 0
     * (presumably the missing-value label; confirm against Var.levels()).
     * <p>
     * The minimum-count guard is evaluated on raw row counts, and the split is scored on the weighted
     * table. Rows with a missing test value belong to no group.
     */
    CTreePurityTest NominalFull = new CTreePurityTest() {
        private static final long serialVersionUID = 2261155834044153945L;

        @Override
        public String name() {
            return "NominalFull";
        }

        @Override
        public CTreeCandidate computeCandidate(CTree c, Frame df, Var weights, String testName, String targetName,
                                               CTreePurityFunction function) {
            Var test = df.var(testName);
            Var target = df.var(targetName);
            if (!DTable.fromCounts(test, target, false).hasColsWithMinimumCount(c.minCount(), 2)) {
                return null;
            }
            DTable dt = DTable.fromWeights(test, target, weights, false);
            CTreeCandidate candidate = new CTreeCandidate(function.compute(dt), testName);
            String[] levels = test.levels();
            for (int i = 1; i < levels.length; i++) {
                final String label = levels[i];
                candidate.addGroup(
                        String.format("%s == %s", testName, label),
                        spot -> !spot.missing(testName) && spot.label(testName).equals(label));
            }
            return candidate;
        }
    };

    /**
     * One-vs-rest split of a nominal test variable.
     * <p>
     * For each level (skipping index 0) that has at least {@code c.minCount()} rows, the test scores the
     * weighted split {@code == level} vs. {@code != level}. It keeps the highest-scoring level, breaking
     * ties at random.
     * <p>
     * NOTE(review): unlike {@code NominalFull}, these groups do not check {@code spot.missing(testName)},
     * so missing rows are routed by whatever {@code spot.label(testName)} returns for them; confirm this is
     * intended.
     */
    CTreePurityTest NominalBinary = new CTreePurityTest() {
        private static final long serialVersionUID = -1257733788317891040L;

        @Override
        public String name() {
            return "Nominal_Binary";
        }

        @Override
        public CTreeCandidate computeCandidate(CTree c, Frame df, Var weights, String testName, String targetName,
                                               CTreePurityFunction function) {
            Var test = df.var(testName);
            Var target = df.var(targetName);
            DTable counts = DTable.fromCounts(test, target, false);
            if (!counts.hasColsWithMinimumCount(c.minCount(), 2)) {
                return null;
            }
            String[] levels = test.levels();
            // Unweighted row totals of the count table, indexed like the test levels.
            double[] rowCounts = counts.rowTotals();

            CTreeCandidate best = null;
            double bestScore = 0.0;
            for (int i = 1; i < levels.length; i++) {
                // The "==" group would hold fewer than minCount rows.
                if (rowCounts[i] < c.minCount()) {
                    continue;
                }
                String testLabel = levels[i];
                DTable dt = DTable.binaryFromWeights(test, target, weights, testLabel, false);
                double currentScore = function.compute(dt);
                if (best != null) {
                    int comp = Double.compare(bestScore, currentScore);
                    // Higher scores are preferred.
                    if (comp > 0) {
                        continue;
                    }
                    // On ties, replace the current best with probability ~1/2.
                    if (comp == 0 && RandomSource.nextDouble() > 0.5) {
                        continue;
                    }
                }
                best = new CTreeCandidate(currentScore, testName);
                best.addGroup(testName + " == " + testLabel, spot -> spot.label(testName).equals(testLabel));
                best.addGroup(testName + " != " + testLabel, spot -> !spot.label(testName).equals(testLabel));
                // Remember the winning score; without this, every level was compared against 0.0.
                bestScore = currentScore;
            }
            return best;
        }
    };

    /**
     * Evaluates a split of the rows of {@code df} on the variable {@code testName}.
     *
     * @param c          the tree being grown; the implementations above read {@code c.minCount()}
     * @param df         the rows to evaluate
     * @param w          per-row weights, indexed like the rows of {@code df}
     * @param testName   name of the variable used to build the split
     * @param targetName name of the target variable
     * @param function   purity function used to score a split; the implementations above prefer higher scores
     * @return the chosen split candidate, or {@code null} when no admissible split was found
     */
    CTreeCandidate computeCandidate(CTree c, Frame df, Var w, String testName, String targetName,
                                    CTreePurityFunction function);
}