/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.experiment.ml.classifier.meta; import rapaio.data.*; import rapaio.ml.classifier.AbstractClassifier; import rapaio.ml.classifier.CFit; import rapaio.ml.classifier.Classifier; import rapaio.ml.classifier.ensemble.CForest; import rapaio.ml.common.Capabilities; import rapaio.util.Util; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.logging.Logger; import static java.util.stream.Collectors.toList; /** * Stacking with a stacking classifier * <p> * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 9/30/15. */ public class CStacking extends AbstractClassifier { private static final long serialVersionUID = -9087871586729573030L; private static final Logger logger = Logger.getLogger(CStacking.class.getName()); private List<Classifier> weaks = new ArrayList<>(); private Classifier stacker = CForest.newRF(); public CStacking withLearners(Classifier... learners) { weaks.clear(); Collections.addAll(weaks, learners); return this; } public CStacking withStacker(Classifier stacker) { this.stacker = stacker; return this; } @Override public Classifier newInstance() { return new CStacking() .withLearners(weaks.stream().map(Classifier::newInstance).toArray(Classifier[]::new)) .withStacker(stacker.newInstance()) .withRunPoolSize(runPoolSize()) .withRunningHook(runningHook()) .withRuns(runs()) .withInputFilters(inputFilters()); } @Override public String name() { return "CStacking"; } @Override public String fullName() { StringBuilder sb = new StringBuilder(); sb.append("CStacking{stacker=").append(stacker.fullName()).append(";"); return sb.toString(); } @Override public Capabilities capabilities() { return new Capabilities() .withAllowMissingTargetValues(false) .withAllowMissingInputValues(false) .withInputTypes(VarType.BINARY, VarType.INDEX, VarType.NUMERIC) .withTargetTypes(VarType.NOMINAL) .withInputCount(1, 100_000) .withTargetCount(1, 1); } protected BaseTrainSetup baseTrain(Frame df, Var w, String... targetVars) { logger.fine("train method called."); int pos = 0; logger.fine("check learners for learning.... "); List<Var> vars = Util.rangeStream(weaks.size(), true) .boxed() .map(i -> { if (!weaks.get(i).hasLearned()) { logger.fine("started learning for weak learner ..."); weaks.get(i).train(df, w, targetVars); } logger.fine("started fitting weak learner..."); return weaks.get(i).fit(df).firstDensity().var(1).solidCopy() .withName("V" + i); }) .collect(toList()); List<String> targets = VRange.of(targetVars).parseVarNames(df); vars.add(df.var(targets.get(0)).solidCopy()); return BaseTrainSetup.valueOf(SolidFrame.byVars(vars), w, targetVars); } @Override protected boolean coreTrain(Frame df, Var weights) { logger.fine("started learning for stacker classifier..."); stacker.train(df, weights, targetNames()); logger.fine("end train method call"); return true; } protected BaseFitSetup baseFit(Frame df, boolean withClasses, boolean withDistributions) { logger.fine("fit method called."); List<Var> vars = Util.rangeStream(weaks.size(), true) .boxed() .map(i -> { logger.fine("started fitting weak learner ..."); return weaks.get(i) .fit(df) .firstDensity() .var(1) .solidCopy() .withName("V" + i); }).collect(toList()); return BaseFitSetup.valueOf(SolidFrame.byVars(vars), withClasses, withDistributions); } @Override protected CFit coreFit(Frame df, boolean withClasses, boolean withDistributions) { logger.fine("started fitting stacker classifier .. "); CFit fit = stacker.fit(df); logger.fine("end fit method call"); return fit; } }