package de.jungblut.online.regression; import org.junit.Assert; import org.junit.Test; import de.jungblut.math.DoubleVector; import de.jungblut.math.activation.SigmoidActivationFunction; import de.jungblut.math.dense.DenseDoubleVector; public class TestRegressionClassifier { @Test public void testClassifier() { // weights from the learner test DenseDoubleVector weights = new DenseDoubleVector(new double[] { -159.7796434436107, 1.178953822695672, 2.0180958310781554 }); RegressionClassifier classifier = new RegressionClassifier(weights, new SigmoidActivationFunction()); DoubleVector prediction = classifier.predict(new DenseDoubleVector( new double[] { 1, 75d, 75d })); Assert.assertEquals(1d, prediction.get(0), 1e-4); prediction = classifier.predict(new DenseDoubleVector(new double[] { 1, 25d, 25d })); Assert.assertEquals(0d, prediction.get(0), 1e-4); } @Test(expected = IllegalArgumentException.class) public void testFailOnDimensionMismatch() { DenseDoubleVector weights = new DenseDoubleVector(new double[] { 0, 0, 0 }); RegressionClassifier classifier = new RegressionClassifier(weights, new SigmoidActivationFunction()); classifier.predict(new DenseDoubleVector(new double[] { 0, 0 })); } }