/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.core; import rapaio.data.Frame; import rapaio.data.MappedFrame; import rapaio.data.Mapping; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; import static rapaio.core.RandomSource.nextDouble; /** * User: Aurelian Tutuianu <padreati@yahoo.com> */ public final class SamplingTools { /** * Discrete sampling with repetition. * Nothing special, just using the uniform discrete sampler offered by the system. */ public static int[] sampleWR(final int populationSize, int sampleSize) { int[] sample = new int[sampleSize]; for (int i = 0; i < sampleSize; i++) { sample[i] = RandomSource.nextInt(populationSize); } return sample; } /** * Draws an uniform discrete sample without replacement. * <p> * Implements reservoir sampling. * * @param populationSize population size * @param sampleSizes sample size * @return sampling indexes */ public static int[][] multiSampleWOR(final int populationSize, final int... sampleSizes) { int total = Arrays.stream(sampleSizes).sum(); int[] sample = sampleWOR(populationSize, total); int[][] result = new int[sampleSizes.length][]; int start = 0; for (int i = 0; i < sampleSizes.length; i++) { result[i] = new int[sampleSizes[i]]; System.arraycopy(sample, start, result[i], 0, result[i].length); start += result[i].length; } return result; } /** * Draws an uniform discrete sample without replacement. * <p> * Implements reservoir sampling. * * @param populationSize population size * @param sampleSize sample size * @return sampling indexes */ public static int[] sampleWOR(final int populationSize, final int sampleSize) { if (sampleSize > populationSize) { throw new IllegalArgumentException("Can't draw a sample without replacement bigger than population size."); } int[] sample = new int[sampleSize]; for (int i = 0; i < sampleSize; i++) { sample[i] = i; } for (int i = sampleSize; i > 1; i--) { int j = RandomSource.nextInt(i); int tmp = sample[i - 1]; sample[i - 1] = sample[j]; sample[j] = tmp; } for (int i = sampleSize; i < populationSize; i++) { int j = RandomSource.nextInt(i + 1); if (j < sampleSize) { sample[j] = i; } } return sample; } /** * Generate discrete weighted random samples with replacement (same values might occur) * with building aliases according to the new probabilities. * <p> * Implementation based on Vose alias-method algorithm * * @param sampleSize sample size * @param freq sampling probabilities * @return sampling indexes */ public static int[] sampleWeightedWR(final int sampleSize, final double[] freq) { normalize(freq); double[] prob = Arrays.copyOf(freq, freq.length); for (int i = 0; i < prob.length; i++) { prob[i] *= prob.length; } int[] alias = new int[freq.length]; makeAliasWR(freq, prob, alias); int[] sample = new int[sampleSize]; for (int i = 0; i < sampleSize; i++) { int column = RandomSource.nextInt(prob.length); sample[i] = RandomSource.nextDouble() < prob[column] ? column : alias[column]; } return sample; } /** * Draw m <= n weighted random samples, weight by probabilities * without replacement. * <p> * Weighted random sampling without replacement. * Implements Efraimidis-Spirakis method. * * @param sampleSize number of samples * @param freq var of probabilities * @return sampling indexes * @see "http://link.springer.com/content/pdf/10.1007/978-0-387-30162-4_478.pdf" */ public static int[] sampleWeightedWOR(final int sampleSize, final double[] freq) { // validation if (sampleSize > freq.length) { throw new IllegalArgumentException("required sample size is bigger than population size"); } normalize(freq); int[] result = new int[sampleSize]; if (sampleSize == freq.length) { for (int i = 0; i < freq.length; i++) { result[i] = i; } return result; } int len = 1; while (len <= sampleSize) { len *= 2; } len = len * 2; int[] heap = new int[len]; double[] k = new double[sampleSize]; // fill with invalid ids for (int i = 0; i < len; i++) { heap[i] = -1; } // fill heap base for (int i = 0; i < sampleSize; i++) { heap[i + len / 2] = i; k[i] = Math.pow(nextDouble(), 1. / freq[i]); result[i] = i; } // learn heap for (int i = len / 2 - 1; i > 0; i--) { if (heap[i * 2] == -1) { heap[i] = -1; continue; } if (heap[i * 2 + 1] == -1) { heap[i] = heap[i * 2]; continue; } if (k[heap[i * 2]] < k[heap[i * 2 + 1]]) { heap[i] = heap[i * 2]; } else { heap[i] = heap[i * 2 + 1]; } } // exhaust the source int pos = sampleSize; while (pos < freq.length) { double r = nextDouble(); double xw = Math.log(r) / Math.log(k[heap[1]]); double acc = 0; while (pos < freq.length) { if (acc + freq[pos] < xw) { acc += freq[pos]; pos++; continue; } break; } if (pos == freq.length) break; // min replaced with the new selected value double tw = Math.pow(k[heap[1]], freq[pos]); double r2 = nextDouble() * (1. - tw) + tw; double ki = Math.pow(r2, 1 / freq[pos]); k[heap[1]] = ki; result[heap[1]] = pos++; int start = heap[1] + len / 2; while (start > 1) { start /= 2; if (heap[start * 2 + 1] == -1) { heap[start] = heap[start * 2]; continue; } if (k[heap[start * 2]] < k[heap[start * 2 + 1]]) { heap[start] = heap[start * 2]; } else { heap[start] = heap[start * 2 + 1]; } } } return result; } private static void normalize(double[] freq) { if (freq == null) { throw new IllegalArgumentException("sampling probability array cannot be null"); } double total = 0; for (double p : freq) { if (p < 0) { throw new IllegalArgumentException("frequencies must be positive."); } total += p; } if (total <= 0) { throw new IllegalArgumentException("sum of frequencies must be strict positive"); } if (total != 1.0) { for (int i = 0; i < freq.length; i++) { freq[i] /= total; } } } /** * Builds discrete random sampler without replacement */ private static void makeAliasWR(double[] p, double[] prob, int[] alias) { if (p.length == 0) throw new IllegalArgumentException("Probability var must be nonempty."); int[] dq = new int[p.length]; int smallPos = -1; int largePos = prob.length; for (int i = 0; i < prob.length; ++i) { if (prob[i] >= 1.) { dq[largePos - 1] = i; largePos--; } else { dq[smallPos + 1] = i; smallPos++; } } while (smallPos >= 0 && largePos <= p.length - 1) { int small = dq[smallPos--]; int large = dq[largePos++]; alias[small] = large; prob[large] = prob[large] + prob[small] - 1.; if (prob[large] >= 1.0) { dq[largePos - 1] = large; largePos--; } else { dq[smallPos + 1] = large; smallPos++; } } while (smallPos > 0) { prob[dq[smallPos - 1]] = 1.0; smallPos--; } while (largePos < dq.length) { prob[dq[largePos]] = 1.0; largePos++; } } public static List<Frame> randomSampleSlices(Frame frame, double... freq) { int total = 0; for (double f : freq) { total += (int) (f * frame.rowCount()); } if (total > frame.rowCount()) { throw new IllegalArgumentException("total counts greater than available number of rows"); } List<Frame> result = new ArrayList<>(); List<Integer> rows = IntStream.range(0, frame.rowCount()).mapToObj(i -> i).collect(Collectors.toList()); Collections.shuffle(rows, RandomSource.getRandom()); int start = 0; for (double f : freq) { int len = (int) (f * frame.rowCount()); result.add(frame.mapRows(Mapping.copy(rows.subList(start, start + len)))); start += len; } if (start < frame.rowCount()) { result.add(frame.mapRows(Mapping.copy(rows.subList(start, frame.rowCount())))); } return result; } public static List<Frame> randomSampleStratifiedSplit(Frame df, String strataName, double p) { if (p <= 0 || p >= 1) { throw new IllegalArgumentException("Percentage must be in interval (0, 1)"); } List<List<Integer>> maps = new ArrayList<>(); for (int i = 0; i < df.var(strataName).levels().length; i++) { maps.add(new ArrayList<>()); } df.var(strataName).stream().forEach(s -> maps.get(s.index()).add(s.row())); List<Integer> left = new ArrayList<>(); List<Integer> right = new ArrayList<>(); for (List<Integer> map : maps) { Collections.shuffle(map, RandomSource.getRandom()); left.addAll(map.subList(0, (int) (p * map.size()))); right.addAll(map.subList((int) (p * map.size()), map.size())); } Collections.shuffle(left, RandomSource.getRandom()); Collections.shuffle(right, RandomSource.getRandom()); List<Frame> list = new ArrayList<>(); list.add(df.mapRows(Mapping.wrap(left))); list.add(df.mapRows(Mapping.wrap(right))); return list; } public static Frame randomBootstrap(Frame frame) { return randomBootstrap(frame, 1.0); } public static Frame randomBootstrap(Frame frame, double percent) { return MappedFrame.byRow(frame, Mapping.copy(SamplingTools.sampleWR(frame.rowCount(), (int) (percent * frame.rowCount())))); } }