/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.regression.tree;
import rapaio.core.RandomSource;
import rapaio.data.*;
import rapaio.data.stream.FSpot;
import rapaio.util.Pair;
import rapaio.util.func.SPredicate;
import java.io.Serializable;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
 * Strategy that distributes the rows of a frame among the child groups of a regression tree split.
 * <p>
 * A split candidate supplies an ordered list of group predicates; each row is routed to the first
 * group whose predicate accepts it. The strategies below differ only in how they treat the
 * "remaining" rows, i.e. rows accepted by none of the predicates.
 * <p>
 * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 11/24/14.
 */
public interface RTreeSplitter extends Serializable {

    /**
     * Rows accepted by no group predicate are dropped: they appear in none of the child frames.
     */
    RTreeSplitter REMAINS_IGNORED = new RTreeSplitter() {
        private static final long serialVersionUID = -3841482294679686355L;

        @Override
        public String name() {
            return "REMAINS_IGNORED";
        }

        @Override
        public Pair<List<Frame>, List<Var>> performSplit(Frame df, Var weights, RTree.RTreeCandidate candidate) {
            List<SPredicate<FSpot>> predicates = candidate.getGroupPredicates();
            List<Mapping> mappings = new ArrayList<>();
            List<Var> weightsList = new ArrayList<>();
            for (int i = 0; i < predicates.size(); i++) {
                mappings.add(Mapping.empty());
                weightsList.add(Numeric.empty());
            }
            df.stream().forEach(s -> {
                for (int i = 0; i < predicates.size(); i++) {
                    if (predicates.get(i).test(s)) {
                        mappings.get(i).add(s.row());
                        weightsList.get(i).addValue(weights.value(s.row()));
                        return;
                    }
                }
                // no predicate accepted this row: it is intentionally left out
            });
            List<Frame> frames = new ArrayList<>();
            for (Mapping mapping : mappings) {
                frames.add(df.mapRows(mapping));
            }
            return Pair.from(frames, weightsList);
        }
    };

    /**
     * Rows accepted by no group predicate are added, with their full weight, to the group that
     * received the most matched rows (the lowest-index group wins ties).
     */
    RTreeSplitter REMAINS_TO_MAJORITY = new RTreeSplitter() {
        private static final long serialVersionUID = 5206066415613740170L;

        @Override
        public String name() {
            return "REMAINS_TO_MAJORITY";
        }

        @Override
        public Pair<List<Frame>, List<Var>> performSplit(Frame df, Var weights, RTree.RTreeCandidate candidate) {
            List<SPredicate<FSpot>> predicates = candidate.getGroupPredicates();
            List<Mapping> mappings = new ArrayList<>();
            List<Var> weightsList = new ArrayList<>();
            for (int i = 0; i < predicates.size(); i++) {
                mappings.add(Mapping.empty());
                weightsList.add(Numeric.empty());
            }
            // row indexes (in frame order) that no predicate accepted
            List<Integer> missingRows = new ArrayList<>();
            df.stream().forEach(s -> {
                for (int i = 0; i < predicates.size(); i++) {
                    if (predicates.get(i).test(s)) {
                        mappings.get(i).add(s.row());
                        weightsList.get(i).addValue(weights.value(s.row()));
                        return;
                    }
                }
                missingRows.add(s.row());
            });
            // strict '>' keeps the first group on ties
            int majorityGroup = 0;
            int majoritySize = 0;
            for (int i = 0; i < mappings.size(); i++) {
                if (mappings.get(i).size() > majoritySize) {
                    majorityGroup = i;
                    majoritySize = mappings.get(i).size();
                }
            }
            // NOTE(review): with zero group predicates and at least one row this throws
            // IndexOutOfBoundsException, as before; presumably candidates always have >= 1 group
            for (int row : missingRows) {
                mappings.get(majorityGroup).add(row);
                weightsList.get(majorityGroup).addValue(weights.value(row));
            }
            List<Frame> frames = new ArrayList<>();
            for (Mapping mapping : mappings) {
                frames.add(MappedFrame.byRow(df, mapping));
            }
            return Pair.from(frames, weightsList);
        }
    };

    /**
     * Rows accepted by no group predicate are added to every group. In each group their weight is
     * multiplied by that group's share of the matched rows (counted by rows, not by weight). If no
     * row matched any group, the shares are undefined and an even split (1/k) is used instead.
     */
    RTreeSplitter REMAINS_TO_ALL_WEIGHTED = new RTreeSplitter() {
        private static final long serialVersionUID = -7751464101852319794L;

        @Override
        public String name() {
            return "REMAINS_TO_ALL_WEIGHTED";
        }

        @Override
        public Pair<List<Frame>, List<Var>> performSplit(Frame df, Var weights, RTree.RTreeCandidate candidate) {
            List<SPredicate<FSpot>> predicates = candidate.getGroupPredicates();
            List<Mapping> mappings = new ArrayList<>();
            List<Var> weightsList = new ArrayList<>();
            for (int i = 0; i < predicates.size(); i++) {
                mappings.add(Mapping.empty());
                weightsList.add(Numeric.empty());
            }
            // a list (not a HashSet) keeps the unmatched rows in frame order, so the child
            // frames are built deterministically in row order
            final List<Integer> missingRows = new ArrayList<>();
            df.stream().forEach(s -> {
                for (int i = 0; i < predicates.size(); i++) {
                    if (predicates.get(i).test(s)) {
                        mappings.get(i).add(s.row());
                        weightsList.get(i).addValue(weights.value(s.row()));
                        return;
                    }
                }
                missingRows.add(s.row());
            });
            final int groups = mappings.size();
            final double[] p = new double[groups];
            double matched = 0;
            for (int i = 0; i < groups; i++) {
                p[i] = mappings.get(i).size();
                matched += p[i];
            }
            for (int i = 0; i < groups; i++) {
                // matched == 0 would give 0.0 / 0.0 = NaN weights; split evenly instead so each
                // unmatched row's weight still sums to its original value across the groups
                p[i] = (matched > 0) ? p[i] / matched : 1.0 / groups;
            }
            for (int i = 0; i < groups; i++) {
                for (int row : missingRows) {
                    mappings.get(i).add(row);
                    weightsList.get(i).addValue(weights.value(row) * p[i]);
                }
            }
            List<Frame> frames = new ArrayList<>();
            for (Mapping mapping : mappings) {
                frames.add(MappedFrame.byRow(df, mapping));
            }
            return Pair.from(frames, weightsList);
        }
    };

    /**
     * Rows accepted by no group predicate are each assigned to one group picked with
     * {@code RandomSource.nextInt}. Child weights are obtained via {@code weights.mapRows}.
     */
    RTreeSplitter REMAINS_TO_RANDOM = new RTreeSplitter() {
        private static final long serialVersionUID = -592529235216896819L;

        @Override
        public String name() {
            return "REMAINS_TO_RANDOM";
        }

        @Override
        public Pair<List<Frame>, List<Var>> performSplit(Frame df, Var weights, RTree.RTreeCandidate candidate) {
            List<SPredicate<FSpot>> predicates = candidate.getGroupPredicates();
            List<Mapping> mappings = IntStream.range(0, predicates.size())
                    .mapToObj(i -> Mapping.empty())
                    .collect(Collectors.toList());
            df.stream().forEach(s -> {
                for (int i = 0; i < predicates.size(); i++) {
                    if (predicates.get(i).test(s)) {
                        mappings.get(i).add(s.row());
                        return;
                    }
                }
                // unmatched row: send it to a randomly chosen group
                mappings.get(RandomSource.nextInt(mappings.size())).add(s.row());
            });
            List<Frame> frameList = mappings.stream().map(df::mapRows).collect(Collectors.toList());
            List<Var> weightList = mappings.stream().map(weights::mapRows).collect(Collectors.toList());
            return Pair.from(frameList, weightList);
        }
    };

    /**
     * @return the identifying name of this splitting strategy
     */
    String name();

    /**
     * Splits the rows of {@code df} among the groups described by {@code candidate}.
     *
     * @param df        frame whose rows are distributed
     * @param weights   per-row weights, read by the row indexes of {@code df}
     * @param candidate split candidate providing the ordered group predicates
     * @return the child frames paired with their row weights; the implementations in this
     *         interface produce one entry per group predicate, in predicate order
     */
    Pair<List<Frame>, List<Var>> performSplit(Frame df, Var weights, RTree.RTreeCandidate candidate);
}