package de.jungblut.online.regression.multinomial; import java.util.List; import java.util.function.IntFunction; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.commons.math3.random.RandomDataImpl; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import de.jungblut.math.DoubleVector; import de.jungblut.math.activation.SigmoidActivationFunction; import de.jungblut.math.dense.DenseDoubleVector; import de.jungblut.math.loss.LogLoss; import de.jungblut.online.minimizer.StochasticGradientDescent; import de.jungblut.online.minimizer.StochasticGradientDescent.StochasticGradientDescentBuilder; import de.jungblut.online.ml.FeatureOutcomePair; import de.jungblut.online.regression.RegressionLearner; public class TestMultinomialRegressionLearner { private RandomDataImpl rnd; @Before public void setup() { rnd = new RandomDataImpl(); rnd.reSeed(0); } @Test public void testSimpleMultinomialRegression() { IntFunction<RegressionLearner> factory = (i) -> { StochasticGradientDescent minimizer = StochasticGradientDescentBuilder .create(1e-4).progressReportInterval(100_000).build(); RegressionLearner learner = new RegressionLearner(minimizer, new SigmoidActivationFunction(), new LogLoss()); learner.setNumPasses(100); return learner; }; MultinomialRegressionLearner learner = new MultinomialRegressionLearner( factory); List<FeatureOutcomePair> trainingSet = generateData(); MultinomialRegressionModel model = learner .train(() -> trainingSet.stream()); double acc = computeClassificationAccuracy(generateData(), model); Assert.assertEquals(1d, acc, 0.1); } public double computeClassificationAccuracy(List<FeatureOutcomePair> data, MultinomialRegressionModel model) { double correct = 0; MultinomialRegressionClassifier clf = new MultinomialRegressionClassifier( model); for (FeatureOutcomePair pair : data) { DoubleVector prediction = clf.predict(pair.getFeature()); if (prediction.maxIndex() == pair.getOutcome().maxIndex()) { correct++; } } return correct / data.size(); } public List<FeatureOutcomePair> generateData() { // similar to the mickey mouse data set final int[] centersX = new int[] { 25, 50, 75 }; final int[] centersY = new int[] { 25, 150, 75 }; return IntStream .range(1, 5000) .mapToObj( (i) -> { int clz = i % centersX.length; double meanX = centersX[clz]; double meanY = centersY[clz]; double stddev = 5d; double[] feat = new double[] { 1, rnd.nextGaussian(meanX, stddev), rnd.nextGaussian(meanY, stddev) }; DoubleVector outcome = new DenseDoubleVector(centersX.length); outcome.set(clz, 1d); return new FeatureOutcomePair(new DenseDoubleVector(feat), outcome); }).collect(Collectors.toList()); } }