/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.api.java.sampling; import com.google.common.collect.Lists; import com.google.common.collect.Sets; import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; import org.apache.flink.testutils.junit.RetryOnFailure; import org.apache.flink.testutils.junit.RetryRule; import org.apache.flink.util.Preconditions; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Set; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * This test suite try to verify whether all the random samplers work as we expected, which mainly focus on: * <ul> * <li>Does sampled result fit into input parameters? we check parameters like sample fraction, sample size, * w/o replacement, and so on.</li> * <li>Does sampled result randomly selected? we verify this by measure how much does it distributed on source data. * Run Kolmogorov-Smirnov (KS) test between the random samplers and default reference samplers which is distributed * well-proportioned on source data. If random sampler select elements randomly from source, it would distributed * well-proportioned on source data as well. The KS test will fail to strongly reject the null hypothesis that * the distributions of sampling gaps are the same. * </li> * </ul> * * @see <a href="https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test">Kolmogorov Smirnov test</a> */ public class RandomSamplerTest { private static final int SOURCE_SIZE = 10000; private static final int DEFAULT_PARTITION_NUMBER = 10; private static final KolmogorovSmirnovTest ksTest = new KolmogorovSmirnovTest(); private static final List<Double> source = new ArrayList<Double>(SOURCE_SIZE); @Rule public final RetryRule retryRule = new RetryRule(); @SuppressWarnings({"unchecked", "rawtypes"}) private final List<Double>[] sourcePartitions = new List[DEFAULT_PARTITION_NUMBER]; @BeforeClass public static void init() { // initiate source data set. for (int i = 0; i < SOURCE_SIZE; i++) { source.add((double) i); } } private void initSourcePartition() { for (int i = 0; i< DEFAULT_PARTITION_NUMBER; i++) { sourcePartitions[i] = new ArrayList<Double>((int)Math.ceil((double)SOURCE_SIZE / DEFAULT_PARTITION_NUMBER)); } for (int i = 0; i< SOURCE_SIZE; i++) { int index = i % DEFAULT_PARTITION_NUMBER; sourcePartitions[index].add((double)i); } } @Test(expected = java.lang.IllegalArgumentException.class) public void testBernoulliSamplerWithUnexpectedFraction1() { verifySamplerFraction(-1, false); } @Test(expected = java.lang.IllegalArgumentException.class) public void testBernoulliSamplerWithUnexpectedFraction2() { verifySamplerFraction(2, false); } @Test @RetryOnFailure(times=3) public void testBernoulliSamplerFraction() { verifySamplerFraction(0.01, false); verifySamplerFraction(0.05, false); verifySamplerFraction(0.1, false); verifySamplerFraction(0.3, false); verifySamplerFraction(0.5, false); verifySamplerFraction(0.854, false); verifySamplerFraction(0.99, false); } @Test @RetryOnFailure(times=3) public void testBernoulliSamplerDuplicateElements() { verifyRandomSamplerDuplicateElements(new BernoulliSampler<Double>(0.01)); verifyRandomSamplerDuplicateElements(new BernoulliSampler<Double>(0.1)); verifyRandomSamplerDuplicateElements(new BernoulliSampler<Double>(0.5)); } @Test(expected = java.lang.IllegalArgumentException.class) public void testPoissonSamplerWithUnexpectedFraction1() { verifySamplerFraction(-1, true); } @Test @RetryOnFailure(times=3) public void testPoissonSamplerFraction() { verifySamplerFraction(0.01, true); verifySamplerFraction(0.05, true); verifySamplerFraction(0.1, true); verifySamplerFraction(0.5, true); verifySamplerFraction(0.854, true); verifySamplerFraction(0.99, true); verifySamplerFraction(1.5, true); } @Test(expected = java.lang.IllegalArgumentException.class) public void testReservoirSamplerUnexpectedSize1() { verifySamplerFixedSampleSize(-1, true); } @Test(expected = java.lang.IllegalArgumentException.class) public void testReservoirSamplerUnexpectedSize2() { verifySamplerFixedSampleSize(-1, false); } @Test @RetryOnFailure(times=3) public void testBernoulliSamplerDistribution() { verifyBernoulliSampler(0.01d); verifyBernoulliSampler(0.05d); verifyBernoulliSampler(0.1d); verifyBernoulliSampler(0.5d); } @Test @RetryOnFailure(times=3) public void testPoissonSamplerDistribution() { verifyPoissonSampler(0.01d); verifyPoissonSampler(0.05d); verifyPoissonSampler(0.1d); verifyPoissonSampler(0.5d); } @Test @RetryOnFailure(times=3) public void testReservoirSamplerSampledSize() { verifySamplerFixedSampleSize(1, true); verifySamplerFixedSampleSize(10, true); verifySamplerFixedSampleSize(100, true); verifySamplerFixedSampleSize(1234, true); verifySamplerFixedSampleSize(9999, true); verifySamplerFixedSampleSize(20000, true); verifySamplerFixedSampleSize(1, false); verifySamplerFixedSampleSize(10, false); verifySamplerFixedSampleSize(100, false); verifySamplerFixedSampleSize(1234, false); verifySamplerFixedSampleSize(9999, false); } @Test @RetryOnFailure(times=3) public void testReservoirSamplerSampledSize2() { RandomSampler<Double> sampler = new ReservoirSamplerWithoutReplacement<Double>(20000); Iterator<Double> sampled = sampler.sample(source.iterator()); assertTrue("ReservoirSamplerWithoutReplacement sampled output size should not beyond the source size.", getSize(sampled) == SOURCE_SIZE); } @Test @RetryOnFailure(times=3) public void testReservoirSamplerDuplicateElements() { verifyRandomSamplerDuplicateElements(new ReservoirSamplerWithoutReplacement<Double>(100)); verifyRandomSamplerDuplicateElements(new ReservoirSamplerWithoutReplacement<Double>(1000)); verifyRandomSamplerDuplicateElements(new ReservoirSamplerWithoutReplacement<Double>(5000)); } @Test @RetryOnFailure(times=3) public void testReservoirSamplerWithoutReplacement() { verifyReservoirSamplerWithoutReplacement(100, false); verifyReservoirSamplerWithoutReplacement(500, false); verifyReservoirSamplerWithoutReplacement(1000, false); verifyReservoirSamplerWithoutReplacement(5000, false); } @Test @RetryOnFailure(times=3) public void testReservoirSamplerWithReplacement() { verifyReservoirSamplerWithReplacement(100, false); verifyReservoirSamplerWithReplacement(500, false); verifyReservoirSamplerWithReplacement(1000, false); verifyReservoirSamplerWithReplacement(5000, false); } @Test @RetryOnFailure(times=3) public void testReservoirSamplerWithMultiSourcePartitions1() { initSourcePartition(); verifyReservoirSamplerWithoutReplacement(100, true); verifyReservoirSamplerWithoutReplacement(500, true); verifyReservoirSamplerWithoutReplacement(1000, true); verifyReservoirSamplerWithoutReplacement(5000, true); } @Test @RetryOnFailure(times=3) public void testReservoirSamplerWithMultiSourcePartitions2() { initSourcePartition(); verifyReservoirSamplerWithReplacement(100, true); verifyReservoirSamplerWithReplacement(500, true); verifyReservoirSamplerWithReplacement(1000, true); verifyReservoirSamplerWithReplacement(5000, true); } /* * Sample with fixed size, verify whether the sampled result size equals to input size. */ private void verifySamplerFixedSampleSize(int numSample, boolean withReplacement) { RandomSampler<Double> sampler; if (withReplacement) { sampler = new ReservoirSamplerWithReplacement<Double>(numSample); } else { sampler = new ReservoirSamplerWithoutReplacement<Double>(numSample); } Iterator<Double> sampled = sampler.sample(source.iterator()); assertEquals(numSample, getSize(sampled)); } /* * Sample with fraction, and verify whether the sampled result close to input fraction. */ private void verifySamplerFraction(double fraction, boolean withReplacement) { RandomSampler<Double> sampler; if (withReplacement) { sampler = new PoissonSampler<Double>(fraction); } else { sampler = new BernoulliSampler<Double>(fraction); } // take 20 times sample, and take the average result size for next step comparison. int totalSampledSize = 0; double sampleCount = 20; for (int i = 0; i < sampleCount; i++) { totalSampledSize += getSize(sampler.sample(source.iterator())); } double resultFraction = totalSampledSize / ((double) SOURCE_SIZE * sampleCount); assertTrue(String.format("expected fraction: %f, result fraction: %f", fraction, resultFraction), Math.abs((resultFraction - fraction) / fraction) < 0.2); } /* * Test sampler without replacement, and verify that there should not exist any duplicate element in sampled result. */ private void verifyRandomSamplerDuplicateElements(final RandomSampler<Double> sampler) { List<Double> list = Lists.newLinkedList(new Iterable<Double>() { @Override public Iterator<Double> iterator() { return sampler.sample(source.iterator()); } }); Set<Double> set = Sets.newHashSet(list); assertTrue("There should not have duplicate element for sampler without replacement.", list.size() == set.size()); } private int getSize(Iterator<?> iterator) { int size = 0; while (iterator.hasNext()) { iterator.next(); size++; } return size; } private void verifyBernoulliSampler(double fraction) { BernoulliSampler<Double> sampler = new BernoulliSampler<Double>(fraction); verifyRandomSamplerWithFraction(fraction, sampler, true); verifyRandomSamplerWithFraction(fraction, sampler, false); } private void verifyPoissonSampler(double fraction) { PoissonSampler<Double> sampler = new PoissonSampler<Double>(fraction); verifyRandomSamplerWithFraction(fraction, sampler, true); verifyRandomSamplerWithFraction(fraction, sampler, false); } private void verifyReservoirSamplerWithReplacement(int numSamplers, boolean sampleOnPartitions) { ReservoirSamplerWithReplacement<Double> sampler = new ReservoirSamplerWithReplacement<Double>(numSamplers); verifyRandomSamplerWithSampleSize(numSamplers, sampler, true, sampleOnPartitions); verifyRandomSamplerWithSampleSize(numSamplers, sampler, false, sampleOnPartitions); } private void verifyReservoirSamplerWithoutReplacement(int numSamplers, boolean sampleOnPartitions) { ReservoirSamplerWithoutReplacement<Double> sampler = new ReservoirSamplerWithoutReplacement<Double>(numSamplers); verifyRandomSamplerWithSampleSize(numSamplers, sampler, true, sampleOnPartitions); verifyRandomSamplerWithSampleSize(numSamplers, sampler, false, sampleOnPartitions); } /* * Verify whether random sampler sample with fraction from source data randomly. There are two default sample, one is * sampled from source data with certain interval, the other is sampled only from the first half region of source data, * If random sampler select elements randomly from source, it would distributed well-proportioned on source data as well, * so the K-S Test result would accept the first one, while reject the second one. */ private void verifyRandomSamplerWithFraction(double fraction, RandomSampler<Double> sampler, boolean withDefaultSampler) { double[] baseSample; if (withDefaultSampler) { baseSample = getDefaultSampler(fraction); } else { baseSample = getWrongSampler(fraction); } verifyKSTest(sampler, baseSample, withDefaultSampler); } /* * Verify whether random sampler sample with fixed size from source data randomly. There are two default sample, one is * sampled from source data with certain interval, the other is sampled only from the first half region of source data, * If random sampler select elements randomly from source, it would distributed well-proportioned on source data as well, * so the K-S Test result would accept the first one, while reject the second one. */ private void verifyRandomSamplerWithSampleSize(int sampleSize, RandomSampler<Double> sampler, boolean withDefaultSampler, boolean sampleWithPartitions) { double[] baseSample; if (withDefaultSampler) { baseSample = getDefaultSampler(sampleSize); } else { baseSample = getWrongSampler(sampleSize); } verifyKSTest(sampler, baseSample, withDefaultSampler, sampleWithPartitions); } private void verifyKSTest(RandomSampler<Double> sampler, double[] defaultSampler, boolean expectSuccess) { verifyKSTest(sampler, defaultSampler, expectSuccess, false); } private void verifyKSTest(RandomSampler<Double> sampler, double[] defaultSampler, boolean expectSuccess, boolean sampleOnPartitions) { double[] sampled = getSampledOutput(sampler, sampleOnPartitions); double pValue = ksTest.kolmogorovSmirnovStatistic(sampled, defaultSampler); double dValue = getDValue(sampled.length, defaultSampler.length); if (expectSuccess) { assertTrue(String.format("KS test result with p value(%f), d value(%f)", pValue, dValue), pValue <= dValue); } else { assertTrue(String.format("KS test result with p value(%f), d value(%f)", pValue, dValue), pValue > dValue); } } private double[] getSampledOutput(RandomSampler<Double> sampler, boolean sampleOnPartitions) { Iterator<Double> sampled; if (sampleOnPartitions) { DistributedRandomSampler<Double> reservoirRandomSampler = (DistributedRandomSampler<Double>)sampler; List<IntermediateSampleData<Double>> intermediateResult = Lists.newLinkedList(); for (int i = 0; i< DEFAULT_PARTITION_NUMBER; i++) { Iterator<IntermediateSampleData<Double>> partialIntermediateResult = reservoirRandomSampler.sampleInPartition(sourcePartitions[i].iterator()); while (partialIntermediateResult.hasNext()) { intermediateResult.add(partialIntermediateResult.next()); } } sampled = reservoirRandomSampler.sampleInCoordinator(intermediateResult.iterator()); } else { sampled = sampler.sample(source.iterator()); } List<Double> list = Lists.newArrayList(); while (sampled.hasNext()) { list.add(sampled.next()); } return transferFromListToArrayWithOrder(list); } /* * Some sample result may not order by the input sequence, we should make it in order to do K-S test. */ private double[] transferFromListToArrayWithOrder(List<Double> list) { Collections.sort(list); double[] result = new double[list.size()]; for (int i = 0; i < list.size(); i++) { result[i] = list.get(i); } return result; } private double[] getDefaultSampler(double fraction) { Preconditions.checkArgument(fraction > 0, "Sample fraction should be positive."); int size = (int) (SOURCE_SIZE * fraction); double step = 1 / fraction; double[] defaultSampler = new double[size]; for (int i = 0; i < size; i++) { defaultSampler[i] = Math.round(step * i); } return defaultSampler; } private double[] getDefaultSampler(int fixSize) { Preconditions.checkArgument(fixSize > 0, "Sample fraction should be positive."); double step = SOURCE_SIZE / (double) fixSize; double[] defaultSampler = new double[fixSize]; for (int i = 0; i < fixSize; i++) { defaultSampler[i] = Math.round(step * i); } return defaultSampler; } /* * Build a failed sample distribution which only contains elements in the first half of source data. */ private double[] getWrongSampler(double fraction) { Preconditions.checkArgument(fraction > 0, "Sample size should be positive."); int size = (int) (SOURCE_SIZE * fraction); int halfSourceSize = SOURCE_SIZE / 2; double[] wrongSampler = new double[size]; for (int i = 0; i < size; i++) { wrongSampler[i] = (double) i % halfSourceSize; } return wrongSampler; } /* * Build a failed sample distribution which only contains elements in the first half of source data. */ private double[] getWrongSampler(int fixSize) { Preconditions.checkArgument(fixSize > 0, "Sample size be positive."); int halfSourceSize = SOURCE_SIZE / 2; double[] wrongSampler = new double[fixSize]; for (int i = 0; i < fixSize; i++) { wrongSampler[i] = (double) i % halfSourceSize; } return wrongSampler; } /* * Calculate the D value of K-S test for p-value 0.001, m and n are the sample size */ private double getDValue(int m, int n) { Preconditions.checkArgument(m > 0, "input sample size should be positive."); Preconditions.checkArgument(n > 0, "input sample size should be positive."); double first = (double) m; double second = (double) n; return 1.95 * Math.sqrt((first + second) / (first * second)); } }