/*
 * Apache License
 * Version 2.0, January 2004
 * http://www.apache.org/licenses/
 *
 * Copyright 2013 Aurelian Tutuianu
 * Copyright 2014 Aurelian Tutuianu
 * Copyright 2015 Aurelian Tutuianu
 * Copyright 2016 Aurelian Tutuianu
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package rapaio.experiment.ml.classifier.linear;

import rapaio.data.Frame;
import rapaio.data.Numeric;
import rapaio.data.Var;
import rapaio.data.VarType;
import rapaio.experiment.math.optimization.IRLSOptimizer;
import rapaio.ml.classifier.AbstractClassifier;
import rapaio.ml.classifier.CFit;
import rapaio.ml.common.Capabilities;
import rapaio.util.func.SFunction;

import java.util.ArrayList;
import java.util.List;

/**
 * Binary logistic regression classifier fitted with iteratively re-weighted least squares (IRLS).
 * <p>
 * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> at 2/3/15.
 */
public class BinaryLogistic extends AbstractClassifier {

    private static final long serialVersionUID = 1609956190070125059L;

    // fitted coefficients: intercept at index 0, followed by one coefficient per input variable
    private Numeric coef;
    // logistic regression function passed to the optimizer
    private final SFunction<Var, Double> logitF = this::logitReg;
    // its derivative, expressed through the output value: f'(z) = f(z) * (1 - f(z))
    private final SFunction<Var, Double> logitFD = var -> {
        double y = logitReg(var);
        return y * (1 - y);
    };
    private int maxRuns = 1_000_000;
    private double tol = 1e-5;

    // logistic (sigmoid) function
    private static double logit(double z) {
        return 1 / (1 + Math.exp(-z));
    }

    @Override
    public BinaryLogistic newInstance() {
        return new BinaryLogistic()
                .withMaxRuns(maxRuns)
                .withTol(tol);
    }

    @Override
    public String name() {
        return "BinaryLogistic";
    }

    @Override
    public String fullName() {
        StringBuilder sb = new StringBuilder();
        sb.append(name()).append("{");
        sb.append("tol=").append(tol).append(", ");
        sb.append("maxRuns=").append(maxRuns);
        sb.append("}");
        return sb.toString();
    }

    @Override
    public Capabilities capabilities() {
        return new Capabilities()
                .withInputTypes(VarType.BINARY, VarType.INDEX, VarType.NUMERIC, VarType.NOMINAL)
                .withInputCount(1, 10000)
                .withTargetTypes(VarType.NOMINAL)
                .withTargetCount(1, 1)
                .withAllowMissingInputValues(false)
                .withAllowMissingTargetValues(true);
    }

    /**
     * Maximum number of iterations if the optimum was not met yet
     * (default value is 1_000_000).
     */
    public BinaryLogistic withMaxRuns(int maxRuns) {
        this.maxRuns = maxRuns;
        return this;
    }
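    /*
     * For reference, a sketch of what these two parameters control (an illustration added
     * here, not part of the original code): the IRLS optimizer presumably keeps refining
     * the coefficient vector until the solution is within `tol` of optimal, or until
     * `maxRuns` iterations have been spent. The fitted model itself has the form
     *
     *     P(y = second level | x) = logit(coef[0] + coef[1] * x[0] + ... + coef[k] * x[k-1])
     *
     * with logit(z) = 1 / (1 + exp(-z)), exactly as evaluated by logitReg(Var) below.
     */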
    /**
     * Tolerance used to check the solution optimality
     * (default value 1e-5).
     */
    public BinaryLogistic withTol(double tol) {
        this.tol = tol;
        return this;
    }

    // linear combination of the inputs with the fitted coefficients (intercept first),
    // squashed through the logistic function
    private double logitReg(Var input) {
        double z = coef.value(0);
        for (int i = 1; i < coef.rowCount(); i++)
            z += input.value(i - 1) * coef.value(i);
        return logit(z);
    }

    // predicted probability of the second target level for a single row of the frame
    private double regress(Frame df, int row) {
        if (coef == null)
            throw new IllegalArgumentException("Model has not been trained");
        Numeric inst = Numeric.empty();
        for (int i = 0; i < inputNames().length; i++) {
            inst.addValue(df.value(row, inputName(i)));
        }
        return logitReg(inst);
    }

    @Override
    protected boolean coreTrain(Frame df, Var weights) {
        // collect the numeric input values, one Numeric instance per training row
        List<Var> inputs = new ArrayList<>(df.rowCount());
        for (int i = 0; i < df.rowCount(); i++) {
            Numeric line = Numeric.empty();
            for (String inputName : inputNames())
                line.addValue(df.value(i, inputName));
            inputs.add(line);
        }

        // start from zero coefficients: intercept plus one slot per input variable
        coef = Numeric.fill(inputNames().length + 1, 0);

        // encode the first target level as 0 and any other level as 1
        Numeric targetValues = Numeric.empty();
        df.var(firstTargetName()).stream().forEach(s -> targetValues.addValue(s.index() == 1 ? 0 : 1));

        IRLSOptimizer optimizer = new IRLSOptimizer();
        coef = optimizer.optimize(tol, maxRuns, logitF, logitFD, coef, inputs, targetValues);
        return true;
    }

    @Override
    protected CFit coreFit(Frame df, boolean withClasses, boolean withDistributions) {
        CFit cr = CFit.build(this, df, withClasses, withDistributions);
        for (int i = 0; i < df.rowCount(); i++) {
            double p = regress(df, i);
            if (withClasses) {
                // p is the probability of the second level; classify with a 0.5 threshold
                cr.firstClasses().setIndex(i, p < 0.5 ? 1 : 2);
            }
            if (withDistributions) {
                cr.firstDensity().setValue(i, 1, 1 - p);
                cr.firstDensity().setValue(i, 2, p);
            }
        }
        return cr;
    }
}
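/*
 * Hypothetical usage sketch (not part of the original source). It assumes the generic
 * train/fit entry points inherited from AbstractClassifier and a Frame `df` whose target
 * variable "class" is nominal with exactly two levels; the variable name and the exact
 * entry-point signatures are assumptions, not taken from this file:
 *
 *     BinaryLogistic model = new BinaryLogistic()
 *             .withTol(1e-5)
 *             .withMaxRuns(1_000_000);
 *     model.train(df, "class");
 *     CFit fit = model.fit(df);
 *     // fit.firstClasses() holds the predicted level (0.5 threshold),
 *     // fit.firstDensity() holds the probabilities of the two levels
 */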