/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.ml.classifier.tree; import rapaio.core.RandomSource; import rapaio.data.Frame; import rapaio.data.Mapping; import rapaio.data.Var; import rapaio.data.stream.FSpot; import rapaio.util.Pair; import rapaio.util.Tagged; import rapaio.util.func.SPredicate; import java.io.Serializable; import java.util.*; import java.util.stream.IntStream; import static java.util.stream.Collectors.toList; /** * Created by <a href="mailto:padreati@yahoo.com>Aurelian Tutuianu</a>. */ public interface CTreeMissingHandler extends Tagged, Serializable { CTreeMissingHandler Ignored = new CTreeMissingHandler() { private static final long serialVersionUID = -9017265383541294518L; @Override public String name() { return "Ignored"; } @Override public Pair<List<Frame>, List<Var>> performSplit(Frame df, Var weights, CTreeCandidate candidate) { List<SPredicate<FSpot>> p = candidate.getGroupPredicates(); List<Mapping> mappings = IntStream.range(0, p.size()).boxed().map(i -> Mapping.empty()).collect(toList()); df.stream().forEach(s -> { for (int i = 0; i < p.size(); i++) { if (p.get(i).test(s)) { mappings.get(i).add(s.row()); break; } } }); return Pair.from( mappings.stream().map(df::mapRows).collect(toList()), mappings.stream().map(weights::mapRows).collect(toList()) ); } }; CTreeMissingHandler ToMajority = new CTreeMissingHandler() { private static final long serialVersionUID = -5858151664805703831L; @Override public String name() { return "ToMajority"; } @Override public Pair<List<Frame>, List<Var>> performSplit(Frame df, Var weights, CTreeCandidate candidate) { List<SPredicate<FSpot>> p = candidate.getGroupPredicates(); List<Mapping> mappings = IntStream.range(0, p.size()).boxed().map(i -> Mapping.empty()).collect(toList()); List<Integer> missingSpots = new LinkedList<>(); df.stream().forEach(s -> { for (int i = 0; i < p.size(); i++) { if (p.get(i).test(s)) { mappings.get(i).add(s.row()); return; } } missingSpots.add(s.row()); }); List<Integer> lens = mappings.stream().map(Mapping::size).collect(toList()); Collections.shuffle(lens); int majorityGroup = 0; int majoritySize = 0; for (int i = 0; i < mappings.size(); i++) { if (mappings.get(i).size() > majoritySize) { majorityGroup = i; majoritySize = mappings.get(i).size(); } } final int index = majorityGroup; mappings.get(index).addAll(missingSpots); return Pair.from( mappings.stream().map(df::mapRows).collect(toList()), mappings.stream().map(weights::mapRows).collect(toList()) ); } }; CTreeMissingHandler ToAllWeighted = new CTreeMissingHandler() { private static final long serialVersionUID = 5936044048099571710L; @Override public Pair<List<Frame>, List<Var>> performSplit(Frame df, Var weights, CTreeCandidate candidate) { List<SPredicate<FSpot>> pred = candidate.getGroupPredicates(); List<Mapping> mappings = IntStream.range(0, pred.size()).boxed().map(i -> Mapping.empty()).collect(toList()); List<Integer> missingSpots = new ArrayList<>(); df.stream().forEach(s -> { for (int i = 0; i < pred.size(); i++) { if (pred.get(i).test(s)) { mappings.get(i).add(s.row()); return; } } missingSpots.add(s.row()); }); final double[] p = new double[mappings.size()]; double n = 0; for (int i = 0; i < mappings.size(); i++) { p[i] = mappings.get(i).size(); n += p[i]; } for (int i = 0; i < p.length; i++) { p[i] /= n; } List<Var> weightsList = mappings.stream().map(weights::mapRows).map(Var::solidCopy).collect(toList()); for (int i = 0; i < mappings.size(); i++) { final int ii = i; missingSpots.forEach(row -> { mappings.get(ii).add(row); weightsList.get(ii).addValue(weights.missing(row) ? p[ii] : weights.value(row) * p[ii]); }); } List<Frame> frames = mappings.stream().map(df::mapRows).collect(toList()); return Pair.from(frames, weightsList); } @Override public String name() { return "ToAllWeighted"; } }; CTreeMissingHandler ToRandom = new CTreeMissingHandler() { private static final long serialVersionUID = -4762758695801141929L; @Override public Pair<List<Frame>, List<Var>> performSplit(Frame df, Var weights, CTreeCandidate candidate) { List<SPredicate<FSpot>> pred = candidate.getGroupPredicates(); List<Mapping> mappings = IntStream.range(0, pred.size()).boxed().map(i -> Mapping.empty()).collect(toList()); final Set<Integer> missingSpots = new HashSet<>(); df.stream().forEach(s -> { for (int i = 0; i < pred.size(); i++) { if (pred.get(i).test(s)) { mappings.get(i).add(s.row()); return; } } missingSpots.add(s.row()); }); missingSpots.forEach(rowId -> mappings.get(RandomSource.nextInt(mappings.size())).add(rowId)); List<Frame> frameList = mappings.stream().map(df::mapRows).collect(toList()); List<Var> weightList = mappings.stream().map(weights::mapRows).collect(toList()); return Pair.from(frameList, weightList); } @Override public String name() { return "ToRandom"; } }; Pair<List<Frame>, List<Var>> performSplit(Frame df, Var weights, CTreeCandidate candidate); }