/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.data.sample; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import rapaio.core.CoreTools; import rapaio.core.RandomSource; import rapaio.core.tests.ChiSquareTest; import rapaio.core.tools.DVector; import rapaio.data.Frame; import rapaio.data.Numeric; import rapaio.datasets.Datasets; import java.util.stream.DoubleStream; /** * Test for row sampling tools * * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 1/26/16. */ public class RowSamplerTest { private Frame df; private Numeric w; @Before public void setUp() throws Exception { df = Datasets.loadIrisDataset(); w = Numeric.from(df.rowCount(), row -> (double) df.index(row, "class")).withName("w"); Assert.assertEquals(w.stream().mapToDouble().sum(), 50 * (1 + 2 + 3), 1e-20); } @Test public void identitySamplerTest() { Sample s = RowSampler.identity().nextSample(df, w); Assert.assertTrue(s.df.deepEquals(df)); Assert.assertTrue(s.weights.deepEquals(w)); } @Test public void bootstrapTest() { RandomSource.setSeed(123); int N = 1_000; Numeric count = Numeric.empty().withName("bcount"); for (int i = 0; i < N; i++) { Sample s = RowSampler.bootstrap(1.0).nextSample(df, w); count.addValue(1.0 * s.mapping.rowStream().distinct().count() / df.rowCount()); } // close to 1 - 1 / exp(1) Assert.assertEquals(0.63328, CoreTools.mean(count).value(), 1e-5); } @Test public void subsampleTest() { RandomSource.setSeed(123); int N = 1_000; Numeric count = Numeric.fill(df.rowCount(), 0.0).withName("sscount"); for (int i = 0; i < N; i++) { Sample s = RowSampler.subsampler(0.5).nextSample(df, w); s.mapping.rowStream().forEach(r -> count.setValue(r, count.value(r) + 1)); } // uniform counts close to 500 count.printLines(); DVector freq = DVector.empty(true, df.rowCount()); for (int i = 0; i < df.rowCount(); i++) { freq.set(i, count.value(i)); } double[] p = DoubleStream.generate(() -> 1 / 150.).limit(150).toArray(); ChiSquareTest chiTest = ChiSquareTest.goodnessOfFitTest(freq, p); chiTest.printSummary(); // chi square goodness of fit Assert.assertTrue(chiTest.pValue() > 0.99); } @Test public void nameSamplerTest() { Assert.assertEquals("Identity", RowSampler.identity().name()); Assert.assertEquals("Bootstrap(p=1)", RowSampler.bootstrap().name()); Assert.assertEquals("Bootstrap(p=0.2)", RowSampler.bootstrap(0.2).name()); Assert.assertEquals("SubSampler(p=1)", RowSampler.subsampler(1.0).name()); Assert.assertEquals("SubSampler(p=0.2)", RowSampler.subsampler(0.2).name()); } }