/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.ml.regression.tree; import rapaio.core.CoreTools; import rapaio.core.stat.OnlineStat; import rapaio.data.Frame; import rapaio.data.Mapping; import rapaio.data.Var; import rapaio.data.filter.Filters; import rapaio.data.stream.VSpot; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; /** * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a>. 
*/
public interface RTreeNumericMethod extends Serializable {

    // Strategy used by a regression tree to build split candidates on a numeric test variable.
    // Two implementations are provided below: IGNORE and BINARY.

    /**
     * Method which proposes no split at all: it always returns a new empty list.
     */
    RTreeNumericMethod IGNORE = new RTreeNumericMethod() {

        @Override
        public String name() {
            return "IGNORE";
        }

        @Override
        public List<RTree.RTreeCandidate> computeCandidates(RTree c, Frame df, Var weights, String testVarName, String targetVarName, RTreeTestFunction function) {
            return new ArrayList<>();
        }
    };

    /**
     * Method which scans every cut point between two distinct consecutive values of the
     * test variable and keeps the single best two-way split ({@code <=} versus {@code >}).
     * <p>
     * A cut point is scored with {@code c.function.computeTestValue(...)}; only cut points
     * with a score strictly greater than {@code 0.0} can be selected. The result is either
     * a singleton list with the best candidate or an empty list when no cut point qualifies.
     */
    RTreeNumericMethod BINARY = new RTreeNumericMethod() {

        @Override
        public String name() {
            return "BINARY";
        }

        @Override
        public List<RTree.RTreeCandidate> computeCandidates(RTree c, Frame dfOld, Var weights, String testVarName, String targetVarName, RTreeTestFunction function) {

            // Order the rows by the test variable, then keep only the rows selected by
            // stream().complete() (presumably the rows where the test value is not missing).
            Frame df = Filters.refSort(dfOld, dfOld.var(testVarName).refComparator());
            Mapping cleanMapping = Mapping.wrap(df.var(testVarName).stream().complete().map(VSpot::row).collect(Collectors.toList()));

            Var test = df.var(testVarName).mapRows(cleanMapping);
            Var target = df.var(targetVarName).mapRows(cleanMapping);

            int n = test.rowCount();

            // leftXxx[i]  describes rows 0..i      (prefix statistics)
            // rightXxx[i] describes rows i..n-1    (suffix statistics)
            double[] leftWeight = new double[n];
            double[] leftVar = new double[n];
            double[] rightWeight = new double[n];
            double[] rightVar = new double[n];

            // NOTE(review): weights is indexed with positions of the sorted/filtered view,
            // but it is not remapped here the way test and target are. Confirm that callers
            // pass weights already aligned with that ordering.
            OnlineStat so = OnlineStat.empty();
            double w = 0.0;
            for (int i = 0; i < n; i++) {
                so.update(target.value(i));
                w += weights.value(i);
                leftWeight[i] = w;
                leftVar[i] = so.variance();
            }

            // The suffix pass must start from a fresh accumulator; reusing the one from the
            // prefix pass would mix all the left observations into every right variance.
            so = OnlineStat.empty();
            w = 0.0;
            for (int i = n - 1; i >= 0; i--) {
                w += weights.value(i);
                so.update(target.value(i));
                rightWeight[i] = w;
                rightVar[i] = so.variance();
            }

            RTree.RTreeCandidate best = null;
            double bestScore = 0.0;

            RTreeTestPayload p = new RTreeTestPayload(2);
            p.totalVar = CoreTools.var(target).value();

            for (int i = c.minCount; i < n - c.minCount - 1; i++) {
                // a cut is possible only between two different values
                if (test.value(i) == test.value(i + 1)) {
                    continue;
                }

                // Cutting after row i sends rows 0..i to the left group and rows i+1..n-1
                // to the right group, hence the suffix statistics are read at i + 1.
                p.splitVar[0] = leftVar[i];
                p.splitVar[1] = rightVar[i + 1];
                p.splitWeight[0] = leftWeight[i];
                p.splitWeight[1] = rightWeight[i + 1];

                // NOTE(review): the score is computed with c.function, the 'function'
                // parameter is not used. Confirm they are always the same object.
                double value = c.function.computeTestValue(p);
                if (value > bestScore) {
                    bestScore = value;
                    best = new RTree.RTreeCandidate(value, testVarName);

                    double testValue = test.value(i);
                    best.addGroup(
                            String.format("%s <= %.6f", testVarName, testValue),
                            spot -> !spot.missing(testVarName) && spot.value(testVarName) <= testValue);
                    best.addGroup(
                            String.format("%s > %.6f", testVarName, testValue),
                            spot -> !spot.missing(testVarName) && spot.value(testVarName) > testValue);
                }
            }

            if (best == null) {
                return Collections.emptyList();
            }
            return Collections.singletonList(best);
        }
    };

    /**
     * @return the name of this method
     */
    String name();

    /**
     * Builds the split candidates for a numeric test variable.
     *
     * @param c             tree for which candidates are computed; the BINARY implementation
     *                      reads its {@code minCount} and {@code function} fields
     * @param df            frame which contains the test and the target variables
     * @param weights       variable with row weights, read by position
     * @param testVarName   name of the numeric variable to split on
     * @param targetVarName name of the target variable
     * @param function      test function; not used by the implementations defined in this interface
     * @return list of candidates, possibly empty
     */
    List<RTree.RTreeCandidate> computeCandidates(RTree c, Frame df, Var weights, String testVarName, String targetVarName, RTreeTestFunction function);
}