/*
 * Apache License
 * Version 2.0, January 2004
 * http://www.apache.org/licenses/
 *
 * Copyright 2013 Aurelian Tutuianu
 * Copyright 2014 Aurelian Tutuianu
 * Copyright 2015 Aurelian Tutuianu
 * Copyright 2016 Aurelian Tutuianu
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package rapaio.ml.classifier;

import rapaio.data.*;
import rapaio.data.filter.FFilter;
import rapaio.data.sample.RowSampler;
import rapaio.printer.format.TextTable;

import java.util.*;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static java.util.stream.Collectors.joining;

/**
 * Abstract base class for all classifiers.
 *
 * @author <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a>
 */
public abstract class AbstractClassifier implements Classifier {

    private static final long serialVersionUID = -6866948033065091047L;

    private List<FFilter> inputFilters = new ArrayList<>();
    private String[] inputNames;
    private VarType[] inputTypes;
    private String[] targetNames;
    private VarType[] targetTypes;
    private Map<String, String[]> dict;
    private RowSampler sampler = RowSampler.identity();
    private boolean learned = false;
    private int poolSize = Runtime.getRuntime().availableProcessors();
    private int runs = 1;
    private BiConsumer<Classifier, Integer> runningHook;

    @Override
    public RowSampler sampler() {
        return sampler;
    }

    @Override
    public AbstractClassifier withSampler(RowSampler sampler) {
        this.sampler = sampler;
        return this;
    }

    @Override
    public List<FFilter> inputFilters() {
        return inputFilters;
    }

    @Override
    public Classifier withInputFilters(List<FFilter> filters) {
        inputFilters = new ArrayList<>();
        for (FFilter filter : filters)
            inputFilters.add(filter.newInstance());
        return this;
    }

    @Override
    public String[] inputNames() {
        return inputNames;
    }

    @Override
    public VarType[] inputTypes() {
        return inputTypes;
    }

    @Override
    public String[] targetNames() {
        return targetNames;
    }

    @Override
    public VarType[] targetTypes() {
        return targetTypes;
    }

    @Override
    public Map<String, String[]> targetLevels() {
        return dict;
    }

    public boolean hasLearned() {
        return learned;
    }

    @Override
    public final Classifier train(Frame df, String... targetVars) {
        Numeric weights = Numeric.fill(df.rowCount(), 1);
        return train(df, weights, targetVars);
    }

    @Override
    public final Classifier train(Frame df, Var weights, String... targetVars) {
        BaseTrainSetup setup = baseTrain(df, weights, targetVars);
        Frame workDf = prepareTraining(setup.df, setup.w, setup.targetVars);
        learned = coreTrain(workDf, setup.w);
        return this;
    }
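
    /*
     * Usage sketch (illustrative only, not part of the class contract).
     * "MyClassifier", "trainDf" and "testDf" are hypothetical names used
     * to show the intended call flow through the final train/fit methods
     * defined above.
     *
     *   Classifier c = new MyClassifier().withRuns(10);   // MyClassifier is hypothetical
     *   c.train(trainDf, "class");                        // -> baseTrain, prepareTraining, coreTrain
     *   CFit prediction = c.fit(testDf);                  // -> baseFit, prepareFit, coreFit
     */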

    /**
     * This method prepares the learning phase. It is a generic method which works
     * for all learners. Its tasks include initialization of target names and
     * input names, checking the capabilities at the learning phase, etc.
     *
     * @param dfOld      data frame
     * @param weights    weights of instances
     * @param targetVars target variable names
     * @return working data frame with input filters applied
     */
    protected Frame prepareTraining(Frame dfOld, final Var weights, final String... targetVars) {
        Frame df = dfOld;
        for (FFilter filter : inputFilters) {
            df = filter.fitApply(df);
        }
        Frame result = df;

        List<String> targets = VRange.of(targetVars).parseVarNames(result);
        this.targetNames = targets.stream().toArray(String[]::new);
        this.targetTypes = targets.stream().map(name -> result.var(name).type()).toArray(VarType[]::new);
        this.dict = new HashMap<>();
        this.dict.put(firstTargetName(), result.var(firstTargetName()).levels());

        HashSet<String> targetSet = new HashSet<>(targets);
        List<String> inputs = Arrays.stream(result.varNames())
                .filter(varName -> !targetSet.contains(varName))
                .collect(Collectors.toList());
        this.inputNames = inputs.stream().toArray(String[]::new);
        this.inputTypes = inputs.stream().map(name -> result.var(name).type()).toArray(VarType[]::new);

        capabilities().checkAtLearnPhase(result, weights, targetVars);
        return result;
    }

    protected BaseTrainSetup baseTrain(Frame df, Var weights, String... targetVars) {
        return BaseTrainSetup.valueOf(df, weights, targetVars);
    }

    /**
     * Training logic implemented by each concrete classifier; called by train
     * after input filters are fitted and target/input metadata is initialized.
     */
    protected abstract boolean coreTrain(Frame df, Var weights);

    @Override
    public final CFit fit(Frame df) {
        return fit(df, true, true);
    }

    @Override
    public final CFit fit(Frame df, boolean withClasses, boolean withDistributions) {
        BaseFitSetup setup = baseFit(df, withClasses, withDistributions);
        Frame workDf = prepareFit(setup.df);
        return coreFit(workDf, setup.withClasses, setup.withDistributions);
    }

    // by default this is a pass-through; it is overridden only for two-stage training
    protected BaseFitSetup baseFit(Frame df, boolean withClasses, boolean withDistributions) {
        return BaseFitSetup.valueOf(df, withClasses, withDistributions);
    }

    protected Frame prepareFit(Frame df) {
        // applies the input filters already fitted during training
        Frame result = df;
        for (FFilter filter : inputFilters) {
            result = filter.apply(result);
        }
        return result;
    }

    /**
     * Prediction logic implemented by each concrete classifier; called by fit
     * after input filters are applied to the given frame.
     */
    protected abstract CFit coreFit(Frame df, boolean withClasses, boolean withDistributions);

    @Override
    public String summary() {
        return "not implemented";
    }

    public String baseSummary() {
        StringBuilder sb = new StringBuilder();

        sb.append("input vars: \n");

        int varCount = inputNames.length;
        TextTable tt = TextTable.newEmpty(varCount, 5);
        for (int i = 0; i < varCount; i++) {
            tt.set(i, 0, i + ".", 1);
            tt.set(i, 1, inputNames[i], 1);
            tt.set(i, 2, ":", -1);
            tt.set(i, 3, inputTypes[i].name(), -1);
            tt.set(i, 4, " |", 1);
        }
        tt.withMerge();
        sb.append("\n").append(tt.summary()).append("\n");

        sb.append("target vars:\n");
        IntStream.range(0, targetNames().length).forEach(i -> sb.append("> ")
                .append(targetName(i)).append(" : ")
                .append(targetType(i))
                .append(" [").append(Arrays.stream(targetLevels(targetName(i))).collect(joining(","))).append("]")
                .append("\n"));

        return sb.toString();
    }
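
    /*
     * Two-stage training sketch (illustrative only, not an existing subclass):
     * a learner that needs to rewrite its inputs before the standard pipeline
     * can override baseTrain (and analogously baseFit) above and return an
     * adjusted setup; prepareTraining, coreTrain and the rest of the flow then
     * operate on the replaced frame. "expandInputs" is a hypothetical helper.
     *
     *   @Override
     *   protected BaseTrainSetup baseTrain(Frame df, Var weights, String... targetVars) {
     *       Frame expanded = expandInputs(df);   // hypothetical preprocessing step
     *       return BaseTrainSetup.valueOf(expanded, weights, targetVars);
     *   }
     */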

    @Override
    public AbstractClassifier withRunPoolSize(int poolSize) {
        this.poolSize = poolSize < 0 ? Runtime.getRuntime().availableProcessors() : poolSize;
        return this;
    }

    @Override
    public int runPoolSize() {
        return poolSize;
    }

    @Override
    public int runs() {
        return runs;
    }

    @Override
    public Classifier withRuns(int runs) {
        this.runs = runs;
        return this;
    }

    @Override
    public BiConsumer<Classifier, Integer> runningHook() {
        return runningHook;
    }

    @Override
    public Classifier withRunningHook(BiConsumer<Classifier, Integer> runningHook) {
        this.runningHook = runningHook;
        return this;
    }

    protected static class BaseTrainSetup {
        public final Frame df;
        public final Var w;
        public final String[] targetVars;

        private BaseTrainSetup(Frame df, Var w, String[] targetVars) {
            this.df = df;
            this.w = w;
            this.targetVars = targetVars;
        }

        public static BaseTrainSetup valueOf(Frame df, Var w, String[] targetVars) {
            return new BaseTrainSetup(df, w, targetVars);
        }
    }

    protected static final class BaseFitSetup {
        public final Frame df;
        public final boolean withClasses;
        public final boolean withDistributions;

        private BaseFitSetup(Frame df, boolean withClasses, boolean withDistributions) {
            this.df = df;
            this.withClasses = withClasses;
            this.withDistributions = withDistributions;
        }

        public static BaseFitSetup valueOf(Frame df, boolean withClasses, boolean withDistributions) {
            return new BaseFitSetup(df, withClasses, withDistributions);
        }
    }
}