/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sgd;

import com.google.common.base.Charsets;
import com.google.common.base.Splitter;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.io.Resources;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public final class OnlineLogisticRegressionTest extends OnlineBaseTest {

  private static final Logger logger = LoggerFactory.getLogger(OnlineLogisticRegressionTest.class);

  /**
   * The CrossFoldLearner is probably the best learner to use for new applications.
   *
   * @throws IOException If test resources aren't readable.
   */
  @Test
  public void crossValidation() throws IOException {
    Vector target = readStandardData();
    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1())
      .lambda(1.0e-3)
      .learningRate(50);

    train(getInput(), target, lr);
    System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood());
    test(getInput(), target, lr, 0.05, 0.3);
  }

  @Test
  public void crossValidatedAuc() throws IOException {
    RandomUtils.useTestSeed();
    Random gen = RandomUtils.getRandom();

    Matrix data = readCsv("cancer.csv");
    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1())
      .stepOffset(10)
      .decayExponent(0.7)
      .lambda(1.0e-3)
      .learningRate(5);
    int k = 0;
    int[] ordering = permute(gen, data.numRows());
    for (int epoch = 0; epoch < 100; epoch++) {
      for (int row : ordering) {
        lr.train(row, (int) data.get(row, 9), data.viewRow(row));
        System.out.printf("%d,%d,%.3f\n", epoch, k++, lr.auc());
      }
      assertEquals(1, lr.auc(), 0.2);
    }
    assertEquals(1, lr.auc(), 0.1);
  }
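  /*
   * Illustrative helper: for n categories, OnlineLogisticRegression keeps n - 1 rows of beta
   * and treats category 0 as the reference class, so classifyFull returns
   * p(0) = 1 / (1 + sum_j exp(beta_j . x)) and p(j) = exp(beta_j . x) / (1 + sum_j exp(beta_j . x))
   * for j > 0. The hand-computed expectations in testClassify() below follow exactly this
   * arithmetic; this sketch just spells the formula out in one place and is a reading aid,
   * not part of the behavior under test.
   */
  private static double[] referenceSoftmax(double[] scores) {
    // scores[j] holds beta_j . x for each of the n - 1 non-reference categories
    double norm = 1;
    for (double score : scores) {
      norm += Math.exp(score);
    }
    double[] p = new double[scores.length + 1];
    p[0] = 1 / norm;                            // reference category 0
    for (int j = 0; j < scores.length; j++) {
      p[j + 1] = Math.exp(scores[j]) / norm;    // categories 1 .. n-1
    }
    return p;
  }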
  /**
   * Verifies that a classifier with known coefficients does the right thing.
   */
  @Test
  public void testClassify() {
    OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1));
    // set up some internal coefficients as if we had learned them
    lr.setBeta(0, 0, -1);
    lr.setBeta(1, 0, -2);

    // zero vector gives no information.  All classes are equal.
    Vector v = lr.classify(new DenseVector(new double[]{0, 0}));
    assertEquals(1 / 3.0, v.get(0), 1.0e-8);
    assertEquals(1 / 3.0, v.get(1), 1.0e-8);

    v = lr.classifyFull(new DenseVector(new double[]{0, 0}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / 3.0, v.get(0), 1.0e-8);
    assertEquals(1 / 3.0, v.get(1), 1.0e-8);
    assertEquals(1 / 3.0, v.get(2), 1.0e-8);

    // weights for the second vector component are still zero so all classifications are equally likely
    v = lr.classify(new DenseVector(new double[]{0, 1}));
    assertEquals(1 / 3.0, v.get(0), 1.0e-3);
    assertEquals(1 / 3.0, v.get(1), 1.0e-3);

    v = lr.classifyFull(new DenseVector(new double[]{0, 1}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / 3.0, v.get(0), 1.0e-3);
    assertEquals(1 / 3.0, v.get(1), 1.0e-3);
    assertEquals(1 / 3.0, v.get(2), 1.0e-3);

    // but the weights on the first component are non-zero
    v = lr.classify(new DenseVector(new double[]{1, 0}));
    assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
    assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);

    v = lr.classifyFull(new DenseVector(new double[]{1, 0}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
    assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
    assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1.0e-8);

    lr.setBeta(0, 1, 1);
    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1.0e-3);
    assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1.0e-3);
    assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1.0e-3);

    lr.setBeta(1, 1, 3);
    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1.0e-8);
    assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1.0e-8);
    assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1.0e-8);
  }
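  /*
   * Note on the encoding used by iris() below: each data line becomes a 5-dimensional vector
   * with a constant 1 in position 0 (the bias term) and the four measurements in positions
   * 1..4, while the species string is interned to a small integer code via Dictionary. As a
   * worked example (values are illustrative), a line "5.1,3.5,1.4,0.2,setosa" would become
   * the vector [1, 5.1, 3.5, 1.4, 0.2] with target dict.intern("setosa"). This layout is why
   * the model is built with 5 features even though there are only 4 measurements.
   */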
  @Test
  public void iris() throws IOException {
    // this test trains a 3-way classifier on the famous Iris dataset.
    // a similar exercise can be accomplished in R using this code:
    //    library(nnet)
    //    correct = rep(0,100)
    //    for (j in 1:100) {
    //      i = order(runif(150))
    //      train = iris[i[1:100],]
    //      test = iris[i[101:150],]
    //      m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train)
    //      correct[j] = mean(predict(m, newdata=test) == test$Species)
    //    }
    //    hist(correct)
    //
    // Note that depending on the training/test split, performance can be better or worse.
    // There is about a 5% chance of getting accuracy < 90% and about a 20% chance of getting
    // accuracy of 100%.
    //
    // This test uses a deterministic split that is neither outstandingly good nor bad.
    RandomUtils.useTestSeed();
    Splitter onComma = Splitter.on(",");

    // read the data
    List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);

    // holds features
    List<Vector> data = Lists.newArrayList();

    // holds target variable
    List<Integer> target = Lists.newArrayList();

    // for decoding target values
    Dictionary dict = new Dictionary();

    // for permuting data later
    List<Integer> order = Lists.newArrayList();

    for (String line : raw.subList(1, raw.size())) {
      // order gets a list of indexes
      order.add(order.size());

      // parse the predictor variables
      Vector v = new DenseVector(5);
      v.set(0, 1);
      int i = 1;
      Iterable<String> values = onComma.split(line);
      for (String value : Iterables.limit(values, 4)) {
        v.set(i++, Double.parseDouble(value));
      }
      data.add(v);

      // and the target
      target.add(dict.intern(Iterables.get(values, 4)));
    }

    // randomize the order ... original data has each species all together
    // note that this randomization is deterministic
    Random random = RandomUtils.getRandom();
    Collections.shuffle(order, random);

    // select training and test data
    List<Integer> train = order.subList(0, 100);
    List<Integer> test = order.subList(100, 150);
    logger.warn("Training set = {}", train);
    logger.warn("Test set = {}", test);

    // now train many times and collect information on accuracy each time
    int[] correct = new int[test.size() + 1];
    for (int run = 0; run < 200; run++) {
      OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
      // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
      for (int pass = 0; pass < 30; pass++) {
        Collections.shuffle(train, random);
        for (int k : train) {
          lr.train(target.get(k), data.get(k));
        }
      }

      // check the accuracy on held out data
      int x = 0;
      int[] count = new int[3];
      for (Integer k : test) {
        int r = lr.classifyFull(data.get(k)).maxValueIndex();
        count[r]++;
        x += r == target.get(k) ? 1 : 0;
      }
      correct[x]++;
    }

    // verify we never saw worse than 95% correct,
    for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
      assertEquals(String.format("%d trials had unacceptable accuracy of only %.0f%%: ",
          correct[i], 100.0 * i / test.size()), 0, correct[i]);
    }
    // nor perfect
    assertEquals(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size()]),
        0, correct[test.size()]);
  }
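  /*
   * Illustrative sketch: the knobs used below and in testSerializationAndDeSerialization(),
   * namely learningRate (mu0), alpha (decayFactor), stepOffset, and decayExponent (stored
   * internally as forgettingExponent), combine into a per-step annealed learning rate.
   * Assuming Mahout's usual schedule of mu0 * decayFactor^n * (n + stepOffset)^forgettingExponent
   * at training step n, this helper reproduces the arithmetic. Treat it as a reading aid
   * under that assumption, not as a verified copy of the implementation.
   */
  private static double annealedLearningRate(double mu0, double decayFactor, int stepOffset,
                                             double forgettingExponent, int n) {
    // exponential decay via decayFactor, polynomial decay via forgettingExponent (usually <= 0)
    return mu0 * Math.pow(decayFactor, n) * Math.pow(n + stepOffset, forgettingExponent);
  }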
  @Test
  public void testTrain() throws Exception {
    Vector target = readStandardData();

    // lambda here needs to be relatively small to avoid swamping the actual signal, but can be
    // larger than usual because the data are dense.  The learning rate doesn't matter too much
    // for this example, but should generally be < 1
    //
    // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias
    //   --target y --categories 2 --predictors V2 V3 V4 V5 V6 V7 --types n
    OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
      .lambda(1.0e-3)
      .learningRate(50);

    train(getInput(), target, lr);
    test(getInput(), target, lr, 0.05, 0.3);
  }

  /**
   * Verifies that all tuning parameters survive a serialization/deserialization round trip.
   */
  @Test
  public void testSerializationAndDeSerialization() throws Exception {
    OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
      .lambda(1.0e-3)
      .stepOffset(11)
      .alpha(0.01)
      .learningRate(50)
      .decayExponent(-0.02);

    lr.close();

    byte[] output;
    try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
         DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream)) {
      PolymorphicWritable.write(dataOutputStream, lr);
      output = byteArrayOutputStream.toByteArray();
    }

    OnlineLogisticRegression read;
    try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(output);
         DataInputStream dataInputStream = new DataInputStream(byteArrayInputStream)) {
      read = PolymorphicWritable.read(dataInputStream, OnlineLogisticRegression.class);
    }

    // all assertions inspect the deserialized copy, not the original model

    // lambda
    Assert.assertEquals(1.0e-3, read.getLambda(), 1.0e-7);

    // use reflection to check the private tuning parameters

    // stepOffset
    Field stepOffset = read.getClass().getDeclaredField("stepOffset");
    stepOffset.setAccessible(true);
    int stepOffsetVal = (Integer) stepOffset.get(read);
    Assert.assertEquals(11, stepOffsetVal);

    // decayFactor (alpha)
    Field decayFactor = read.getClass().getDeclaredField("decayFactor");
    decayFactor.setAccessible(true);
    double decayFactorVal = (Double) decayFactor.get(read);
    Assert.assertEquals(0.01, decayFactorVal, 1.0e-7);

    // learning rate (mu0)
    Field mu0 = read.getClass().getDeclaredField("mu0");
    mu0.setAccessible(true);
    double mu0Val = (Double) mu0.get(read);
    Assert.assertEquals(50, mu0Val, 1.0e-7);

    // forgettingExponent (decayExponent)
    Field forgettingExponent = read.getClass().getDeclaredField("forgettingExponent");
    forgettingExponent.setAccessible(true);
    double forgettingExponentVal = (Double) forgettingExponent.get(read);
    Assert.assertEquals(-0.02, forgettingExponentVal, 1.0e-7);
  }
}