/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.sgd;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.junit.Test;
import java.io.IOException;
import java.util.Random;
/**
 * Tests for {@link OnlineLogisticRegression} and {@link CrossFoldLearner}.
 *
 * <p>Covers classification with hand-set coefficients and training on test data sets. The
 * data-loading and train/test helpers ({@code readStandardData}, {@code readCsv},
 * {@code getInput}, {@code permute}, {@code train}, {@code test}) are not defined here. They are
 * inherited from {@link OnlineBaseTest} or its superclasses.
 */
public final class OnlineLogisticRegressionTest extends OnlineBaseTest {

  /**
   * The CrossFoldLearner is probably the best learner to use for new applications.
   *
   * <p>Trains a {@link CrossFoldLearner} on the standard data set, prints its AUC and log
   * likelihood, and then checks it with the inherited {@code test(...)} helper.
   *
   * @throws IOException If test resources aren't readable.
   */
  @Test
  public void crossValidation() throws IOException {
    Vector target = readStandardData();
    // NOTE(review): constructor args presumably are (folds, categories, features, prior) -- confirm.
    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1())
        .lambda(1.0e-3)
        .learningRate(50);
    train(getInput(), target, lr);
    System.out.printf("%.2f %.5f%n", lr.auc(), lr.logLikelihood());
    test(getInput(), target, lr, 0.05, 0.3);
  }

  /**
   * Trains a {@link CrossFoldLearner} for 100 passes over {@code cancer.csv} and checks that the
   * AUC approaches 1.
   *
   * <p>Rows are visited in a permutation drawn from the test random generator. After every pass
   * the AUC must be within 0.2 of 1; after the last pass it must be within 0.1.
   *
   * @throws IOException If test resources aren't readable.
   */
  @Test
  public void crossValidatedAuc() throws IOException {
    // Use the test seed so the row permutation is reproducible between runs.
    RandomUtils.useTestSeed();
    Random gen = RandomUtils.getRandom();

    Matrix data = readCsv("cancer.csv");
    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1())
        .stepOffset(10)
        .decayExponent(0.7)
        .lambda(1.0e-3)
        .learningRate(5);

    // Total number of training examples seen so far, across all passes.
    int k = 0;
    int[] ordering = permute(gen, data.numRows());
    for (int epoch = 0; epoch < 100; epoch++) {
      for (int row : ordering) {
        // The row index is the first argument (presumably the tracking key that assigns the
        // example to a fold -- confirm). Column 9 is passed as the category.
        // NOTE(review): the full row, including column 9, is also the feature vector; if
        // column 9 is the label this leaks the target into the features -- confirm.
        lr.train(row, (int) data.get(row, 9), data.viewRow(row));
        k++;
      }
      // One progress line per pass (epoch, examples seen, AUC) keeps the test output readable.
      System.out.printf("%d,%d,%.3f%n", epoch, k, lr.auc());
      assertEquals(1, lr.auc(), 0.2);
    }
    assertEquals(1, lr.auc(), 0.1);
  }

  /**
   * Verifies that a classifier with known coefficients does the right thing.
   *
   * <p>The model has 3 categories and 2 features. As the expected values below show, category 0
   * is the reference category with a fixed logit of 0, and beta row {@code i} holds the weights
   * of category {@code i + 1}. {@code classify} returns the probabilities of categories 1 and 2
   * only. {@code classifyFull} returns all three, which must sum to 1.
   */
  @Test
  public void testClassify() {
    OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1));
    // Set up some internal coefficients as if we had learned them: category 1 has weight -1 and
    // category 2 has weight -2 on feature 0. Feature 1 keeps zero weight for now.
    lr.setBeta(0, 0, -1);
    lr.setBeta(1, 0, -2);

    // The zero vector gives no information, so all classes are equally likely.
    Vector v = lr.classify(new DenseVector(new double[]{0, 0}));
    assertEquals(1 / 3.0, v.get(0), 1.0e-8);
    assertEquals(1 / 3.0, v.get(1), 1.0e-8);
    v = lr.classifyFull(new DenseVector(new double[]{0, 0}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / 3.0, v.get(0), 1.0e-8);
    assertEquals(1 / 3.0, v.get(1), 1.0e-8);
    assertEquals(1 / 3.0, v.get(2), 1.0e-8);

    // Weights for the second feature are still zero, so all classes remain equally likely.
    v = lr.classify(new DenseVector(new double[]{0, 1}));
    assertEquals(1 / 3.0, v.get(0), 1.0e-3);
    assertEquals(1 / 3.0, v.get(1), 1.0e-3);
    v = lr.classifyFull(new DenseVector(new double[]{0, 1}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / 3.0, v.get(0), 1.0e-3);
    assertEquals(1 / 3.0, v.get(1), 1.0e-3);
    assertEquals(1 / 3.0, v.get(2), 1.0e-3);

    // The first feature has non-zero weights: logits for x = {1, 0} are 0, -1, -2.
    double denom = 1 + Math.exp(-1) + Math.exp(-2);
    v = lr.classify(new DenseVector(new double[]{1, 0}));
    assertEquals(Math.exp(-1) / denom, v.get(0), 1.0e-8);
    assertEquals(Math.exp(-2) / denom, v.get(1), 1.0e-8);
    v = lr.classifyFull(new DenseVector(new double[]{1, 0}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / denom, v.get(0), 1.0e-8);
    assertEquals(Math.exp(-1) / denom, v.get(1), 1.0e-8);
    assertEquals(Math.exp(-2) / denom, v.get(2), 1.0e-8);

    // Category 1 gets weight 1 on feature 1: logits for x = {1, 1} become 0, 0, -2.
    lr.setBeta(0, 1, 1);
    denom = 1 + Math.exp(0) + Math.exp(-2);
    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / denom, v.get(0), 1.0e-3);
    assertEquals(Math.exp(0) / denom, v.get(1), 1.0e-3);
    assertEquals(Math.exp(-2) / denom, v.get(2), 1.0e-3);

    // Category 2 gets weight 3 on feature 1: logits for x = {1, 1} become 0, 0, 1.
    lr.setBeta(1, 1, 3);
    denom = 1 + Math.exp(0) + Math.exp(1);
    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / denom, v.get(0), 1.0e-8);
    assertEquals(Math.exp(0) / denom, v.get(1), 1.0e-8);
    assertEquals(Math.exp(1) / denom, v.get(2), 1.0e-8);
  }

  /**
   * Trains a plain {@link OnlineLogisticRegression} (2 categories, 8 features, L1 prior) on the
   * standard data set and checks it with the inherited {@code test(...)} helper. It uses the
   * same lambda, learning rate and test thresholds as {@link #crossValidation()}.
   *
   * @throws Exception If the test data cannot be read or training fails.
   */
  @Test
  public void testTrain() throws Exception {
    Vector target = readStandardData();
    // lambda needs to be relatively small to avoid swamping the actual signal, but can be larger
    // than usual because the data are dense. The learning rate doesn't matter much for this
    // example; 50 matches the --rate option of this command-line run:
    //   --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias
    //   --target y --categories 2 --predictors V2 V3 V4 V5 V6 V7 --types n
    // NOTE(review): that run uses --features 21 while this model uses 8 features -- confirm the
    // two are meant to be comparable.
    OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
        .lambda(1.0e-3)
        .learningRate(50);
    train(getInput(), target, lr);
    test(getInput(), target, lr, 0.05, 0.3);
  }
}