/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sgd;

import com.google.common.io.Closeables;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.stats.GlobalOnlineAuc;
import org.apache.mahout.math.stats.OnlineAuc;
import org.junit.Test;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Random;

/**
 * Verifies that models serialized with {@link PolymorphicWritable} can be read back and continue to
 * behave like the original after further training.
 */
public final class ModelSerializerTest extends MahoutTestCase {

  /**
   * Writes {@code m} with {@link PolymorphicWritable#write} into an in-memory buffer and reads it
   * back with {@link PolymorphicWritable#read}.
   *
   * @param m     the object to round-trip
   * @param clazz the type the copy is read back as
   * @return the instance reconstructed from the serialized bytes
   * @throws IOException if writing or reading fails
   */
  private static <T extends Writable> T roundTrip(T m, Class<T> clazz) throws IOException {
    ByteArrayOutputStream buf = new ByteArrayOutputStream(1000);
    DataOutputStream dos = new DataOutputStream(buf);
    try {
      PolymorphicWritable.write(dos, m);
    } finally {
      // Close (which also flushes) the stream before its bytes are read back from buf.
      Closeables.closeQuietly(dos);
    }
    return PolymorphicWritable.read(
        new DataInputStream(new ByteArrayInputStream(buf.toByteArray())), clazz);
  }

  /**
   * Returns the largest absolute element-wise difference between the coefficient matrices of two
   * models.
   *
   * @param a first model
   * @param b second model
   * @return {@code max |a.beta - b.beta|}
   */
  private static double maxAbsBetaDifference(OnlineLogisticRegression a,
                                             OnlineLogisticRegression b) {
    // ABS rather than IDENTITY: MAX over signed differences ignores any coefficient of b that exceeds
    // the matching coefficient of a whenever some other difference is near zero.
    return a.getBeta().minus(b.getBeta()).aggregate(Functions.MAX, Functions.ABS);
  }

  /** A serialized {@link GlobalOnlineAuc} reports the same AUC and keeps tracking after more samples. */
  @Test
  public void onlineAucRoundtrip() throws IOException {
    RandomUtils.useTestSeed();
    OnlineAuc auc1 = new GlobalOnlineAuc();
    Random gen = RandomUtils.getRandom();
    for (int i = 0; i < 10000; i++) {
      auc1.addSample(0, gen.nextGaussian());
      auc1.addSample(1, gen.nextGaussian() + 1);
    }
    assertEquals(0.76, auc1.auc(), 0.01);

    OnlineAuc auc3 = roundTrip(auc1, OnlineAuc.class);

    assertEquals(auc1.auc(), auc3.auc(), 0);

    // Both instances receive fresh (different) draws from the same distribution, so their AUC
    // estimates should stay close but need not be identical.
    for (int i = 0; i < 1000; i++) {
      auc1.addSample(0, gen.nextGaussian());
      auc1.addSample(1, gen.nextGaussian() + 1);

      auc3.addSample(0, gen.nextGaussian());
      auc3.addSample(1, gen.nextGaussian() + 1);
    }

    assertEquals(auc1.auc(), auc3.auc(), 0.01);
  }

  /** A serialized {@link OnlineLogisticRegression} has the same coefficients before and after more training. */
  @Test
  public void onlineLogisticRegressionRoundTrip() throws IOException {
    OnlineLogisticRegression olr = new OnlineLogisticRegression(2, 5, new L1());
    train(olr, 100);
    OnlineLogisticRegression olr3 = roundTrip(olr, OnlineLogisticRegression.class);
    assertEquals(0, maxAbsBetaDifference(olr, olr3), 1.0e-6);

    train(olr, 100);
    train(olr3, 100);

    assertEquals(0, maxAbsBetaDifference(olr, olr3), 1.0e-6);
  }

  /** A serialized {@link CrossFoldLearner} reports the same AUC and tracks the original under equal training. */
  @Test
  public void crossFoldLearnerRoundTrip() throws IOException {
    CrossFoldLearner learner = new CrossFoldLearner(5, 2, 5, new L1());
    train(learner, 100);
    CrossFoldLearner olr3 = roundTrip(learner, CrossFoldLearner.class);
    double auc1 = learner.auc();
    assertTrue(String.format("%.3f > 0.85", auc1), auc1 > 0.85);
    assertEquals(auc1, olr3.auc(), 1.0e-6);

    // Train the original and the copy identically so any divergence in AUC reflects state lost in
    // serialization rather than unequal amounts of training.
    train(learner, 100);
    train(olr3, 100);

    assertEquals(learner.auc(), olr3.auc(), 0.02);
    double auc2 = learner.auc();
    assertTrue(String.format("%.3f > %.3f", auc2, auc1), auc2 > auc1);
  }

  /** A serialized {@link AdaptiveLogisticRegression} reports the same AUC and tracks the original under equal training. */
  @Test
  public void adaptiveLogisticRegressionRoundTrip() throws IOException {
    AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 5, new L1());
    learner.setInterval(200);
    train(learner, 400);
    AdaptiveLogisticRegression olr3 = roundTrip(learner, AdaptiveLogisticRegression.class);
    double auc1 = learner.auc();
    assertTrue(String.format("%.3f > 0.85", auc1), auc1 > 0.85);
    assertEquals(auc1, olr3.auc(), 1.0e-6);

    // Train the original and the copy identically so any divergence in AUC reflects state lost in
    // serialization rather than unequal amounts of training.
    train(learner, 1000);
    train(olr3, 1000);

    assertEquals(learner.auc(), olr3.auc(), 0.005);
    double auc2 = learner.auc();
    assertTrue(String.format("%.3f > %.3f", auc2, auc1), auc2 > auc1);
  }

  /**
   * Trains {@code olr} on {@code n} synthetic examples. Each example is a 5-dimensional vector of
   * standard Gaussian draws, labelled 1 when a uniform draw falls below {@code beta . x} for the fixed
   * {@code beta = [1, -1, 0, 0.5, -0.5]}, and 0 otherwise.
   *
   * <p>NOTE(review): a new generator is obtained from {@link RandomUtils#getRandom()} on every call;
   * whether repeated calls replay the same samples depends on RandomUtils' test-seed handling —
   * confirm.
   *
   * @param olr the learner to train
   * @param n   number of training examples
   */
  private static void train(OnlineLearner olr, int n) {
    Vector beta = new DenseVector(new double[]{1, -1, 0, 0.5, -0.5});
    Random gen = RandomUtils.getRandom();
    for (int i = 0; i < n; i++) {
      Vector x = randomVector(gen, 5);

      int target = gen.nextDouble() < beta.dot(x) ? 1 : 0;
      olr.train(target, x);
    }
  }

  /**
   * Returns a dense vector of length {@code n} whose elements are independent standard Gaussian draws
   * from {@code gen}.
   *
   * @param gen source of randomness
   * @param n   vector length
   * @return the freshly filled vector
   */
  private static Vector randomVector(final Random gen, int n) {
    Vector x = new DenseVector(n);
    x.assign(new DoubleFunction() {
      @Override
      public double apply(double v) {
        return gen.nextGaussian();
      }
    });
    return x;
  }
}