/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.regression.tree;
import rapaio.core.CoreTools;
import rapaio.core.stat.OnlineStat;
import rapaio.core.tools.DVector;
import rapaio.data.Frame;
import rapaio.data.Mapping;
import rapaio.data.Var;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Strategy that builds split candidates for a regression tree node when the
 * test variable is nominal.
 * <p>
 * Three ready-made strategies are exposed as constants: {@link #IGNORE},
 * {@link #FULL} and {@link #BINARY}.
 * <p>
 * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a>.
 */
public interface RTreeNominalMethod extends Serializable {

    /**
     * Strategy which never proposes a split: it always returns an empty list.
     */
    RTreeNominalMethod IGNORE = new RTreeNominalMethod() {
        private static final long serialVersionUID = 7275580448899976553L;

        @Override
        public String name() {
            return "IGNORE";
        }

        @Override
        public List<RTree.RTreeCandidate> computeCandidates(RTree c, Frame df, Var weights, String testColName, String targetColName, RTreeTestFunction function) {
            return Collections.emptyList();
        }
    };

    /**
     * Strategy which proposes a single multi-way split, with one group for each
     * level of the test variable (the level at index 0 is skipped).
     * <p>
     * Rows with a missing test value are left out of the computations. No
     * candidate is produced unless at least two levels have a count of at
     * least {@code c.minCount}.
     * <p>
     * NOTE(review): the {@code function} argument is not consulted; the score
     * is computed with {@code c.function}.
     */
    RTreeNominalMethod FULL = new RTreeNominalMethod() {
        private static final long serialVersionUID = 2733570883914611103L;

        @Override
        public String name() {
            return "FULL";
        }

        @Override
        public List<RTree.RTreeCandidate> computeCandidates(RTree c, Frame dfOld, Var weightsOld, String testColName, String targetColName, RTreeTestFunction function) {
            // keep only the rows which have a value for the test variable
            Mapping cleanMapping = dfOld.stream().filter(s -> !s.missing(testColName)).collectMapping();
            Frame df = dfOld.mapRows(cleanMapping);
            Var testVar = df.var(testColName);
            Var targetVar = df.var(targetColName);
            Var weights = weightsOld.mapRows(cleanMapping);

            DVector dvWeights = DVector.fromWeights(false, testVar, weights);
            DVector dvCount = DVector.fromCount(false, testVar);

            // check to see if we have enough instances in at least 2 child nodes
            if (dvCount.countValues(x -> x >= c.minCount) <= 1) {
                return Collections.emptyList();
            }

            String[] levels = testVar.levels();

            // make the payload: slot (i - 1) describes level i
            RTreeTestPayload p = new RTreeTestPayload(levels.length - 1);
            p.totalVar = CoreTools.var(targetVar).value();
            for (int i = 1; i < levels.length; i++) {
                String label = levels[i];
                // the density vector is addressed by level index, the same way
                // BINARY reads it; reading (i - 1) here would shift every weight
                // by one slot
                p.splitWeight[i - 1] = dvWeights.get(i);
                p.splitVar[i - 1] = CoreTools.var(df.stream().filter(s -> s.label(testColName).equals(label)).toMappedFrame().var(targetColName)).value();
            }
            double value = c.function.computeTestValue(p);

            // one group for each level, selected by label equality
            RTree.RTreeCandidate candidate = new RTree.RTreeCandidate(value, testColName);
            for (int i = 1; i < levels.length; i++) {
                String label = levels[i];
                candidate.addGroup(testColName + " == " + label, spot -> spot.label(testColName).equals(label));
            }
            return Collections.singletonList(candidate);
        }
    };

    /**
     * Strategy which evaluates, for each level of the test variable (the level
     * at index 0 is skipped), the two-way split "equals the level" versus
     * "differs from the level", and returns the split with the highest score.
     * <p>
     * Rows with a missing test value are left out of the computations and do
     * not match any of the two groups. A level is evaluated only if both sides
     * of its split have a count of at least {@code c.minCount}. If no level
     * qualifies, an empty list is returned.
     * <p>
     * NOTE(review): the {@code function} argument is not consulted; the score
     * is computed with {@code c.function}.
     */
    RTreeNominalMethod BINARY = new RTreeNominalMethod() {
        private static final long serialVersionUID = -4703727362952157041L;

        @Override
        public String name() {
            return "BINARY";
        }

        @Override
        public List<RTree.RTreeCandidate> computeCandidates(RTree c, Frame dfOld, Var weightsOld, String testColName, String targetColName, RTreeTestFunction function) {
            // keep only the rows which have a value for the test variable
            Mapping cleanMapping = dfOld.stream().filter(s -> !s.missing(testColName)).collectMapping();
            Frame df = dfOld.mapRows(cleanMapping);
            Var testVar = df.var(testColName);
            Var targetVar = df.var(targetColName);
            Var weights = weightsOld.mapRows(cleanMapping);

            DVector dvWeights = DVector.fromWeights(false, testVar, weights);
            DVector dvCount = DVector.fromCount(false, testVar);

            String[] levels = testVar.levels();

            // compute online statistics for all level slices; slot (i - 1) holds level i
            OnlineStat[] os = new OnlineStat[levels.length - 1];
            for (int i = 0; i < os.length; i++) {
                os[i] = OnlineStat.empty();
            }
            for (int i = 0; i < testVar.rowCount(); i++) {
                int index = testVar.index(i);
                if (index == 0) {
                    continue;
                }
                os[index - 1].update(targetVar.value(i));
            }

            double totalVar = CoreTools.var(targetVar).value();

            RTree.RTreeCandidate best = null;
            // Double.MIN_VALUE is the smallest positive double, not the lowest
            // value, so it would reject every split with a score <= 0
            double bestScore = Double.NEGATIVE_INFINITY;

            for (int i = 1; i < levels.length; i++) {
                String testLabel = levels[i];

                // check to see if we have enough values on both sides of the split
                if (dvCount.get(i) < c.minCount || df.rowCount() - dvCount.get(i) < c.minCount) {
                    continue;
                }

                // statistics for the selected level and for all the other levels merged
                OnlineStat osSelect = os[i - 1];
                OnlineStat osOther = OnlineStat.empty();
                for (int j = 1; j < levels.length; j++) {
                    if (i == j) {
                        continue;
                    }
                    osOther.update(os[j - 1]);
                }

                RTreeTestPayload p = new RTreeTestPayload(2);
                p.totalVar = totalVar;

                // payload for current node
                p.splitWeight[0] = dvWeights.get(i);
                p.splitVar[0] = osSelect.variance();

                // payload for the others
                p.splitWeight[1] = dvWeights.sum() - dvWeights.get(i);
                p.splitVar[1] = osOther.variance();

                double value = c.function.computeTestValue(p);
                if (value > bestScore) {
                    bestScore = value;
                    best = new RTree.RTreeCandidate(value, testColName);
                    best.addGroup(testColName + " == " + testLabel,
                            spot -> !spot.missing(testColName) && spot.label(testColName).equals(testLabel));
                    best.addGroup(testColName + " != " + testLabel,
                            spot -> !spot.missing(testColName) && !spot.label(testColName).equals(testLabel));
                }
            }

            if (best == null) {
                return Collections.emptyList();
            }
            return Collections.singletonList(best);
        }
    };

    /**
     * @return the name of this strategy
     */
    String name();

    /**
     * Builds the split candidates for a nominal test variable.
     *
     * @param c             tree which requests the candidates; the built-in strategies
     *                      read {@code c.minCount} and {@code c.function} from it
     * @param df            frame which contains the test and the target variables
     * @param weights       weights of the observations; the built-in strategies map it
     *                      with the same row mapping they apply to {@code df}
     * @param testColName   name of the nominal test variable
     * @param targetColName name of the target variable
     * @param function      test function; NOTE(review): the built-in strategies do not
     *                      use this argument, they call {@code c.function} instead
     * @return list of candidates, empty if no split is proposed
     */
    List<RTree.RTreeCandidate> computeCandidates(RTree c, Frame df, Var weights, String testColName, String targetColName, RTreeTestFunction function);
}