/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.experiment.ml.classifier.meta; import rapaio.data.Frame; import rapaio.data.VRange; import rapaio.data.Var; import rapaio.ml.classifier.AbstractClassifier; import rapaio.ml.classifier.CFit; import rapaio.ml.classifier.Classifier; import rapaio.ml.common.Capabilities; import rapaio.ml.eval.Confusion; import rapaio.printer.format.TextTable; import rapaio.sys.WS; import java.util.*; import java.util.stream.Collectors; import static java.util.stream.Collectors.toList; /** * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 11/11/15. */ public class CStepwiseSelection extends AbstractClassifier { private static final long serialVersionUID = 2642562123626893974L; Classifier best; private Classifier c; private int minVars = 1; private int maxVars = 1; private String[] startSelection = new String[]{}; private int restartAfter = 2; private int maxSearch = Integer.MAX_VALUE; private Frame test; // training artifacts private List<String> selection; @Override public String name() { return "CStepwiseSelection"; } @Override public String fullName() { return null; } @Override public Classifier newInstance() { return new CStepwiseSelection() .withRestartAfter(restartAfter) .withClassifier(c) .withMaxVars(maxVars) .withMinVars(minVars) .withStartSelection(startSelection) .withTestFrame(test) .withRunPoolSize(runPoolSize()); } @Override public Capabilities capabilities() { return c.capabilities(); } public CStepwiseSelection withClassifier(Classifier c) { this.c = c; return this; } public CStepwiseSelection withMinVars(int minVars) { this.minVars = minVars; return this; } public CStepwiseSelection withMaxVars(int maxVars) { this.maxVars = maxVars; return this; } public CStepwiseSelection withStartSelection(String... startSelection) { this.startSelection = startSelection; return this; } public CStepwiseSelection withRestartAfter(int restartAfter) { this.restartAfter = restartAfter; return this; } public CStepwiseSelection withMaxSearch(int maxSearch) { this.maxSearch = maxSearch; return this; } public CStepwiseSelection withTestFrame(Frame test) { this.test = test; return this; } @Override protected boolean coreTrain(Frame df, Var weights) { selection = VRange.of(startSelection).parseVarNames(df); Frame testFrame = test != null ? test : df; List<String> bestSelection = new ArrayList<>(selection); Classifier bestClassifier = null; double bestAcc = 0.0; String forwardNext = null; String backwardNext = null; for (int r = 0; r < runs(); r++) { boolean found = false; Set<String> inSet = new HashSet<>(selection); if (selection.size() < maxVars) { // do forward selection List<String> in = Arrays.stream(inputNames()).collect(Collectors.toList()); Collections.shuffle(in); int restart = 0; for (int i = 0; i < in.size() && i < maxSearch; i++) { String test = in.get(i); if (inSet.contains(test)) continue; List<String> next = new ArrayList<>(selection); next.add(test); next.add(firstTargetName()); Classifier cNext = c.newInstance(); cNext.train(df.mapVars(next), firstTargetName()); Confusion cm = new Confusion(testFrame.var(firstTargetName()), cNext.fit(testFrame).firstClasses()); double acc = cm.accuracy(); if (acc > bestAcc) { WS.println(WS.formatFlex(acc)); bestAcc = acc; bestClassifier = cNext; forwardNext = test; backwardNext = null; found = true; restart++; if (restart >= restartAfter) { break; } } } } if (!found && selection.size() > minVars) { // do backward selection int restart = 0; Collections.shuffle(selection); for (int i = 0; i < selection.size() && i < maxSearch; i++) { String test = selection.get(i); List<String> next = selection.stream().filter(n -> !test.equals(n)).collect(toList()); next.add(firstTargetName()); Classifier cNext = c.newInstance(); cNext.train(df.mapVars(next), firstTargetName()); Confusion cm = new Confusion(testFrame.var(firstTargetName()), cNext.fit(testFrame).firstClasses()); double acc = cm.accuracy(); if (acc > bestAcc) { WS.println(WS.formatFlex(acc)); bestAcc = acc; bestClassifier = cNext; forwardNext = null; backwardNext = test; found = true; restart++; if (restart >= restartAfter) break; } } } if (!found) break; best = bestClassifier; String testNext = (forwardNext == null) ? backwardNext : forwardNext; if (forwardNext != null) { selection.add(testNext); } if (backwardNext != null) { selection = selection.stream().filter(n -> !n.equals(testNext)).collect(toList()); } WS.println("best selection: "); TextTable tt = TextTable.newEmpty(selection.size() + 1, 2); tt.set(0, 0, "No.", 0); tt.set(0, 1, "Name", -1); for (int i = 0; i < selection.size(); i++) { tt.set(i + 1, 0, i + ".", 1); tt.set(i + 1, 1, selection.get(i), -1); } tt.withMerge(); tt.withHeaderRows(1); tt.printSummary(); WS.println("last test: " + testNext); new Confusion(testFrame.var(firstTargetName()), best.fit(testFrame).firstClasses()).printSummary(); } return true; } @Override protected CFit coreFit(Frame df, boolean withClasses, boolean withDistributions) { return best.fit(df, withClasses, withDistributions); } }