/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.regression;
import rapaio.data.*;
import rapaio.data.filter.FFilter;
import rapaio.data.sample.RowSampler;
import java.util.*;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
/**
* Abstract class needed to implement prerequisites for all regression algorithms.
* <p>
* Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 11/20/14.
*/
public abstract class AbstractRegression implements Regression {

    private static final long serialVersionUID = 5544999078321108408L;

    // Names and types of input/target variables. They are assigned only by
    // prepareTraining, so they remain null until a training call reaches it.
    private String[] inputNames;
    private VarType[] inputTypes;
    private String[] targetNames;
    private VarType[] targetTypes;

    // Row sampler; initialized with RowSampler.identity().
    private RowSampler sampler = RowSampler.identity();

    // Holds the value returned by the last coreTrain call.
    private boolean hasLearned;

    // Pool size; defaults to the number of available processors.
    private int poolSize = Runtime.getRuntime().availableProcessors();

    // Number of runs; defaults to 1.
    private int runs = 1;

    // Filters applied to the frame before training (fitApply) and before fitting (apply).
    private List<FFilter> inputFilters = new ArrayList<>();

    // Optional hook stored on behalf of callers; this class never invokes it itself.
    private BiConsumer<Regression, Integer> runningHook;

    /**
     * Returns the input filters currently configured.
     *
     * @return the internal list of filters (not a copy)
     */
    @Override
    public List<FFilter> inputFilters() {
        return inputFilters;
    }

    /**
     * Replaces the configured input filters with fresh instances of the given filters.
     * <p>
     * Each filter is copied through {@code newInstance()}, so the regression does not
     * share instances with the caller. The new list is fully built before it replaces
     * the old one, so a failure while copying leaves the previous configuration intact.
     *
     * @param filters filters to install, in application order; a {@code null} array is
     *                treated as "no filters"; individual elements must be non-null
     * @return this regression instance
     */
    @Override
    public Regression withInputFilters(FFilter... filters) {
        List<FFilter> copies = new ArrayList<>();
        if (filters != null) {
            // Iterate over the argument, not over the (just replaced) field.
            for (FFilter filter : filters) {
                copies.add(filter.newInstance());
            }
        }
        inputFilters = copies;
        return this;
    }

    /**
     * @return names of the input variables recorded by the last training preparation,
     * or {@code null} if no training preparation has happened yet
     */
    @Override
    public String[] inputNames() {
        return inputNames;
    }

    /**
     * @return names of the target variables recorded by the last training preparation,
     * or {@code null} if no training preparation has happened yet
     */
    @Override
    public String[] targetNames() {
        return targetNames;
    }

    /**
     * @return the row sampler configured on this regression
     */
    @Override
    public RowSampler sampler() {
        return sampler;
    }

    /**
     * Sets the row sampler.
     *
     * @param sampler sampler to store; stored as given, without validation
     * @return this regression instance
     */
    @Override
    public AbstractRegression withSampler(RowSampler sampler) {
        this.sampler = sampler;
        return this;
    }

    /**
     * @return the configured number of runs
     */
    @Override
    public int runs() {
        return runs;
    }

    /**
     * Sets the number of runs.
     *
     * @param runs number of runs; stored as given, without validation
     * @return this regression instance
     */
    public Regression withRuns(int runs) {
        this.runs = runs;
        return this;
    }

    /**
     * Trains the regression using the weight variable produced by
     * {@code Numeric.fill(df.rowCount(), 1)} (presumably a weight of 1 for every row).
     *
     * @param df             training frame
     * @param targetVarNames target variable names or patterns
     * @return this regression instance
     */
    @Override
    public final Regression train(Frame df, String... targetVarNames) {
        return train(df, Numeric.fill(df.rowCount(), 1), targetVarNames);
    }

    /**
     * Trains the regression in three steps: {@link #baseTrain}, {@link #prepareTraining}
     * and {@link #coreTrain}. The value returned by {@code coreTrain} becomes the value
     * reported by {@link #hasLearned()}.
     *
     * @param df             training frame
     * @param weights        weights variable
     * @param targetVarNames target variable names or patterns
     * @return this regression instance
     */
    @Override
    public final Regression train(Frame df, Var weights, String... targetVarNames) {
        TrainSetup setup = baseTrain(df, weights, targetVarNames);
        setup = prepareTraining(setup.df, setup.w, setup.targetVars);
        hasLearned = coreTrain(setup.df, setup.w);
        return this;
    }

    /**
     * Prepares the data for training.
     * <p>
     * Applies {@code fitApply} of every input filter, in order, then resolves the target
     * names against the filtered frame and records target names and types. Every other
     * variable of the filtered frame is recorded as an input. Finally the learner
     * capabilities are checked against the filtered frame.
     *
     * @param dfOld          frame before input filters are applied
     * @param weights        weights variable
     * @param targetVarNames target variable names or patterns, resolved with
     *                       {@code VRange.of(...).parseVarNames(...)}
     * @return a setup holding the filtered frame and the weights; its {@code targetVars}
     * field is {@code null}
     */
    protected TrainSetup prepareTraining(Frame dfOld, Var weights, String... targetVarNames) {
        Frame df = dfOld;
        for (FFilter filter : inputFilters) {
            df = filter.fitApply(df);
        }
        // Effectively final alias, needed by the lambdas below.
        Frame result = df;

        List<String> targets = VRange.of(targetVarNames).parseVarNames(result);
        this.targetNames = targets.toArray(new String[0]);
        this.targetTypes = targets.stream()
                .map(name -> result.var(name).type())
                .toArray(VarType[]::new);

        // Inputs are all the variables of the filtered frame which are not targets.
        Set<String> targetSet = new HashSet<>(targets);
        List<String> inputs = Arrays.stream(result.varNames())
                .filter(varName -> !targetSet.contains(varName))
                .collect(Collectors.toList());
        this.inputNames = inputs.toArray(new String[0]);
        this.inputTypes = inputs.stream()
                .map(name -> result.var(name).type())
                .toArray(VarType[]::new);

        capabilities().checkAtLearnPhase(result, weights, targetVarNames);
        return TrainSetup.valueOf(df, weights);
    }

    /**
     * First training step, executed before {@link #prepareTraining}. This implementation
     * only wraps its arguments, unchanged, into a {@link TrainSetup}.
     *
     * @param df             training frame
     * @param weights        weights variable
     * @param targetVarNames target variable names or patterns
     * @return setup passed further to {@link #prepareTraining}
     */
    protected TrainSetup baseTrain(Frame df, Var weights, String... targetVarNames) {
        return TrainSetup.valueOf(df, weights, targetVarNames);
    }

    /**
     * Algorithm specific training, called with the frame and weights produced by
     * {@link #prepareTraining}.
     *
     * @param df      prepared training frame
     * @param weights weights variable
     * @return value stored as the learned flag, see {@link #hasLearned()}
     */
    protected abstract boolean coreTrain(Frame df, Var weights);

    /**
     * Fits the given frame with residuals enabled; equivalent to {@code fit(df, true)}.
     *
     * @param df frame to fit
     * @return fit result
     */
    @Override
    public RFit fit(Frame df) {
        return fit(df, true);
    }

    /**
     * Fits the given frame in three steps: {@link #baseFit}, {@link #prepareFit} and
     * {@link #coreFit}.
     *
     * @param df            frame to fit
     * @param withResiduals flag forwarded to the fit steps
     * @return fit result produced by {@link #coreFit}
     */
    @Override
    public RFit fit(Frame df, boolean withResiduals) {
        FitSetup setup = baseFit(df, withResiduals);
        // NOTE(review): prepareFit receives the caller's flag, not setup.withResiduals,
        // so a flag changed by an overridden baseFit is ignored. train() forwards all
        // values of its setup; confirm whether this asymmetry is intended.
        setup = prepareFit(setup.df, withResiduals);
        return coreFit(setup.df, setup.withResiduals);
    }

    /**
     * First fit step, executed before {@link #prepareFit}. By default it does nothing
     * besides wrapping its arguments; per the original author it exists only for two
     * stage training.
     *
     * @param df            frame to fit
     * @param withResiduals residuals flag
     * @return setup whose frame is passed further to {@link #prepareFit}
     */
    protected FitSetup baseFit(Frame df, boolean withResiduals) {
        return FitSetup.valueOf(df, withResiduals);
    }

    /**
     * Applies {@code apply} of every input filter, in order, on the given frame.
     *
     * @param df            frame to fit
     * @param withResiduals residuals flag, passed through unchanged
     * @return setup holding the filtered frame and the flag
     */
    protected FitSetup prepareFit(Frame df, boolean withResiduals) {
        Frame result = df;
        for (FFilter filter : inputFilters) {
            result = filter.apply(result);
        }
        return FitSetup.valueOf(result, withResiduals);
    }

    /**
     * Algorithm specific fit, called with the frame produced by {@link #prepareFit}.
     *
     * @param df            prepared frame
     * @param withResiduals residuals flag
     * @return fit result
     */
    protected abstract RFit coreFit(Frame df, boolean withResiduals);

    /**
     * @return the value returned by the most recent {@link #coreTrain} call,
     * {@code false} if training was never completed
     */
    @Override
    public boolean hasLearned() {
        return hasLearned;
    }

    /**
     * @return types of the input variables recorded by the last training preparation,
     * or {@code null} if no training preparation has happened yet
     */
    @Override
    public VarType[] inputTypes() {
        return inputTypes;
    }

    /**
     * @return types of the target variables recorded by the last training preparation,
     * or {@code null} if no training preparation has happened yet
     */
    @Override
    public VarType[] targetTypes() {
        return targetTypes;
    }

    /**
     * Sets the pool size.
     *
     * @param poolSize desired pool size; a negative value selects the number of
     *                 available processors, any other value (including 0) is stored
     *                 as given
     * @return this regression instance
     */
    @Override
    public Regression withPoolSize(int poolSize) {
        this.poolSize = poolSize < 0 ? Runtime.getRuntime().availableProcessors() : poolSize;
        return this;
    }

    /**
     * @return the configured pool size
     */
    @Override
    public int poolSize() {
        return poolSize;
    }

    /**
     * @return the configured running hook, or {@code null} if none was set
     */
    @Override
    public BiConsumer<Regression, Integer> runningHook() {
        return runningHook;
    }

    /**
     * Sets the running hook. This class only stores the hook; presumably concrete
     * implementations invoke it during training (not visible here).
     *
     * @param runningHook hook to store
     * @return this regression instance
     */
    @Override
    public Regression withRunningHook(BiConsumer<Regression, Integer> runningHook) {
        this.runningHook = runningHook;
        return this;
    }

    /**
     * Immutable holder for the values passed between the training steps.
     */
    protected static class TrainSetup {

        /** Training frame. */
        public final Frame df;
        /** Weights variable. */
        public final Var w;
        /** Target variable names; {@code null} when built with {@link #valueOf(Frame, Var)}. */
        public final String[] targetVars;

        private TrainSetup(Frame df, Var w, String[] targetVars) {
            this.df = df;
            this.w = w;
            this.targetVars = targetVars;
        }

        /**
         * Builds a setup from frame, weights and target variable names.
         */
        public static TrainSetup valueOf(Frame df, Var w, String[] targetVars) {
            return new TrainSetup(df, w, targetVars);
        }

        /**
         * Builds a setup from frame and weights, with {@code targetVars} set to {@code null}.
         */
        public static TrainSetup valueOf(Frame df, Var w) {
            return new TrainSetup(df, w, null);
        }
    }

    /**
     * Immutable holder for the values passed between the fit steps.
     */
    protected static final class FitSetup {

        /** Frame to fit. */
        public final Frame df;
        /** Residuals flag. */
        public final boolean withResiduals;

        private FitSetup(Frame df, boolean withResiduals) {
            this.df = df;
            this.withResiduals = withResiduals;
        }

        /**
         * Builds a setup from frame and residuals flag.
         */
        public static FitSetup valueOf(Frame df, boolean withResiduals) {
            return new FitSetup(df, withResiduals);
        }
    }
}