package edu.cmu.graphchi.util;

import edu.cmu.graphchi.ChiVertex;

import java.util.Random;

/**
 * Samples values from a multinomial distribution.
 *
 * <p>The sampler uses the alias method: the weights are turned into a table
 * with one bucket per index, each bucket holding a threshold and an "alias"
 * index. Building the table is linear in the number of weights; every sample
 * afterwards costs one {@code nextInt} and one {@code nextFloat} call.
 *
 * @author Aapo Kyrola
 */
public class MultinomialSampler {

    /**
     * Draws {@code n} indices into {@code weights}, where index {@code i} is
     * chosen with probability {@code weights[i] / sum(weights)}.
     *
     * <p>Weights are not validated. Negative, NaN or infinite entries are not
     * rejected and lead to unspecified (but in-range) results. If all weights
     * are zero the table degenerates and indices are drawn uniformly.
     *
     * <p>For every sample the generator is asked for exactly one
     * {@code nextInt(weights.length)} followed by one {@code nextFloat()}.
     *
     * @param r       source of randomness
     * @param weights unnormalised, non-negative weights; one entry per outcome
     * @param n       number of samples to draw
     * @param <VT>    unused; retained so existing call sites keep compiling
     * @return a new array of length {@code n} with values in
     *         {@code [0, weights.length)}
     * @throws IllegalArgumentException   if {@code weights} is empty and {@code n > 0}
     * @throws NullPointerException       if {@code weights} is null, or if {@code r}
     *                                    is null and {@code n > 0}
     * @throws NegativeArraySizeException if {@code n} is negative
     */
    public static <VT> int[] generateSamplesAliasMethod(Random r, float[] weights, int n) {
        int l = weights.length;
        int[] samples = new int[n];
        if (n == 0) {
            // Nothing to draw: skip building the table altogether.
            return samples;
        }
        if (l == 0) {
            throw new IllegalArgumentException(
                    "Cannot draw " + n + " samples from an empty weight array");
        }

        double[] values = new double[l];
        int[] aliases = new int[l];
        buildAliasTable(weights, values, aliases);

        for (int i = 0; i < n; i++) {
            int bucket = r.nextInt(l);
            float val = r.nextFloat();
            if (val < values[bucket]) {
                samples[i] = bucket;
            } else {
                int alias = aliases[bucket];
                // A bucket without an alias always has threshold 1.0, so this
                // fallback is purely defensive.
                samples[i] = (alias < 0) ? bucket : alias;
            }
        }
        return samples;
    }

    /**
     * Fills {@code values} (per-bucket acceptance thresholds) and
     * {@code aliases} (index to use when the threshold test fails, or -1 if
     * the bucket has no alias) for the given weights.
     *
     * <p>All arithmetic is done in {@code double}: a {@code float} accumulator
     * stops absorbing small weights once the running total is large, which
     * would skew the normalisation for long or unevenly scaled weight arrays.
     */
    private static void buildAliasTable(float[] weights, double[] values, int[] aliases) {
        int l = weights.length;

        double sum = 0.0;
        for (float w : weights) {
            sum += w;
        }

        // Two index stacks: buckets holding less than / at least the average mass.
        int[] aboveAverages = new int[l];
        int[] belowAverages = new int[l];
        int aboveIdx = 0;
        int belowIdx = 0;

        for (int i = 0; i < l; i++) {
            // Scale so that the average bucket holds exactly 1.0.
            // If sum is 0 this is NaN; NaN < 1.0 is false, so such buckets land
            // on the "above" stack and are clamped to 1.0 below (uniform draw).
            values[i] = weights[i] / sum * l;
            aliases[i] = -1;
            if (values[i] < 1.0) {
                belowAverages[belowIdx++] = i;
            } else {
                aboveAverages[aboveIdx++] = i;
            }
        }

        // Top up each under-full bucket with mass taken from an over-full one.
        while (aboveIdx > 0 && belowIdx > 0) {
            int small = belowAverages[--belowIdx];
            int large = aboveAverages[--aboveIdx];
            aliases[small] = large;
            values[large] -= 1.0 - values[small];
            if (values[large] < 1.0) {
                belowAverages[belowIdx++] = large;
            } else {
                aboveAverages[aboveIdx++] = large;
            }
        }

        // Whatever is left should hold exactly the average; clamp away round-off.
        while (aboveIdx > 0) {
            values[aboveAverages[--aboveIdx]] = 1.0;
        }
        while (belowIdx > 0) { // only reachable through numerical round-off
            values[belowAverages[--belowIdx]] = 1.0;
        }
    }
}