/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.experiment.ml.classifier.meta;
import rapaio.data.*;
import rapaio.ml.classifier.AbstractClassifier;
import rapaio.ml.classifier.CFit;
import rapaio.ml.classifier.Classifier;
import rapaio.experiment.ml.classifier.linear.BinaryLogistic;
import rapaio.ml.common.Capabilities;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;
import static java.util.stream.Collectors.toList;
/**
* Stacking with Binary Logistic as stacking classifier
* <p>
* Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 9/30/15.
*/
public class CBinaryLogisticStacking extends AbstractClassifier {
private static final long serialVersionUID = -9087871586729573030L;
private static final Logger logger = Logger.getLogger(CBinaryLogisticStacking.class.getName());
private List<Classifier> weaks = new ArrayList<>();
private BinaryLogistic log = new BinaryLogistic();
private double tol = 1e-5;
private int maxRuns = 1_000_000;
public CBinaryLogisticStacking withLearners(Classifier... learners) {
weaks.clear();
Collections.addAll(weaks, learners);
return this;
}
public CBinaryLogisticStacking withTol(double tol) {
this.tol = tol;
return this;
}
public CBinaryLogisticStacking withMaxRuns(int maxRuns) {
this.maxRuns = maxRuns;
return this;
}
@Override
public Classifier newInstance() {
return new CBinaryLogisticStacking();
}
@Override
public String name() {
return "CBinaryLogisticStacking";
}
@Override
public String fullName() {
return null;
}
@Override
public Capabilities capabilities() {
return new Capabilities()
.withAllowMissingTargetValues(false)
.withAllowMissingInputValues(false)
.withInputTypes(VarType.BINARY, VarType.INDEX, VarType.NUMERIC)
.withTargetTypes(VarType.NOMINAL)
.withInputCount(1, 100_000)
.withTargetCount(1, 1);
}
@Override
protected BaseTrainSetup baseTrain(Frame df, Var weights, String... targetVars) {
logger.config("train method called.");
List<Var> vars = new ArrayList<>();
int pos = 0;
logger.config("check learners for learning.... ");
weaks.parallelStream().map(weak -> {
if (!weak.hasLearned()) {
logger.config("started learning for weak learner ...");
weak.train(df, weights, targetVars);
}
logger.config("started fitting weak learner...");
return weak.fit(df).firstDensity().var(1);
}).collect(toList()).forEach(var -> vars.add(var.solidCopy().withName("V" + vars.size())));
List<Var> quadratic = vars.stream()
.map(v -> v.solidCopy().stream().transValue(x -> x * x).toMappedVar().withName(v.name() + "^2").solidCopy())
.collect(toList());
vars.addAll(quadratic);
List<String> targets = VRange.of(targetVars).parseVarNames(df);
vars.add(df.var(targets.get(0)).solidCopy());
return BaseTrainSetup.valueOf(SolidFrame.byVars(vars), weights, targetVars);
}
@Override
protected boolean coreTrain(Frame df, Var weights) {
logger.config("started learning for binary logistic...");
log.withTol(tol);
log.withMaxRuns(maxRuns);
log.train(df, weights, targetNames());
logger.config("end train method call");
return true;
}
@Override
protected BaseFitSetup baseFit(Frame df, boolean withClasses, boolean withDistributions) {
logger.config("fit method called.");
List<Var> vars = new ArrayList<>();
weaks.parallelStream().map(weak -> {
logger.config("started fitting weak learner ...");
return weak.fit(df).firstDensity().var(1);
}).collect(toList()).forEach(var -> vars.add(var.solidCopy().withName("V" + vars.size())));
List<Var> quadratic = vars.stream()
.map(v -> v.solidCopy().stream().transValue(x -> x * x).toMappedVar().withName(v.name() + "^2").solidCopy())
.collect(toList());
vars.addAll(quadratic);
return BaseFitSetup.valueOf(SolidFrame.byVars(vars), withClasses, withDistributions);
}
@Override
protected CFit coreFit(Frame df, boolean withClasses, boolean withDistributions) {
logger.config("started fitting binary logistic regression.. ");
CFit fit = log.fit(df);
logger.config("end fit method call");
return fit;
}
}