package edu.washington.escience.myria.operator;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;

import edu.washington.escience.myria.DbException;
import edu.washington.escience.myria.Schema;
import edu.washington.escience.myria.Type;
import edu.washington.escience.myria.storage.TupleBatch;
import edu.washington.escience.myria.storage.TupleBatchBuffer;
import edu.washington.escience.myria.util.SamplingType;

/**
 * Takes in some sampling parameters from the left child and uses that to sample
 * tuples from the right child.
 *
 * <p>The first batch produced by the left child is read as a
 * (WorkerID, StreamSize, SampleSize, SampleType) row. From it the operator
 * pre-computes a sorted array of tuple positions and then copies the tuples at
 * those positions out of the right child's stream.
 */
public class Sample extends BinaryOperator {
  /** Required for Java serialization. */
  private static final long serialVersionUID = 1L;

  /** Total number of tuples to expect from the right operator. */
  private int streamSize;

  /** Number of tuples to sample from the right operator. */
  private int sampleSize;

  /** Random generator used for index selection. */
  private final Random rand;

  /** The type of sampling to perform. */
  private SamplingType sampleType;

  /** True if operator has extracted sampling info. */
  private boolean computedSamplingInfo = false;

  /** Buffer for tuples that will be returned. */
  private TupleBatchBuffer ans;

  /** Global count of the tuples seen so far. */
  private int tuplesSeen = 0;

  /** Sorted array of tuple indices that will be taken as samples. */
  private int[] sampleIndices;

  /** Current index of the samples array. */
  private int curSampIdx = 0;

  /**
   * Instantiate a Sample operator using sampling info from the left operator
   * and the stream from the right operator.
   *
   * @param left
   *          inputs a (WorkerID, StreamSize, SampleSize, SampleType) tuple.
   * @param right
   *          tuples that will be sampled from.
   * @param randomSeed
   *          value to seed the random generator with, or null if no seed is
   *          specified.
   */
  public Sample(final Operator left, final Operator right, final Long randomSeed) {
    super(left, right);
    rand = new Random();
    if (randomSeed != null) {
      rand.setSeed(randomSeed);
    }
  }

  /**
   * Reads the sampling parameters from the left child on first use, then
   * returns batches made of the selected tuples of the right child.
   *
   * @return a batch of sampled tuples, or null if none is available from this
   *         call.
   * @throws Exception
   *           if the sampling parameters are invalid or a child fails.
   */
  @Override
  protected TupleBatch fetchNextReady() throws Exception {
    // Extract sampling info from left operator.
    if (!computedSamplingInfo) {
      TupleBatch samplingInfo = getLeft().nextReady();
      if (samplingInfo == null) {
        return null;
      }
      extractSamplingInfo(samplingInfo);
      getLeft().close();
      sampleIndices = generateSampleIndices();
      computedSamplingInfo = true;
    }

    // Return a ready tuple batch if possible.
    TupleBatch buffered = ans.popAny();
    if (buffered != null) {
      return buffered;
    }

    // Check if there's nothing left to sample.
    if (curSampIdx >= sampleIndices.length) {
      getRight().close();
      return null;
    }

    Operator right = getRight();
    for (TupleBatch tb = right.nextReady(); tb != null; tb = right.nextReady()) {
      if (curSampIdx >= sampleIndices.length) {
        // Done sampling.
        break;
      }
      // Exclusive upper bound of the global positions covered by this batch.
      int batchEnd = tuplesSeen + tb.numTuples();
      if (sampleIndices[curSampIdx] >= batchEnd) {
        // Next index is not in this batch. Continue with next batch.
        tuplesSeen = batchEnd;
        continue;
      }
      // sampleIndices is sorted, so every index that falls into this batch is
      // contiguous in the array. Duplicates (sampling with replacement) append
      // the same row more than once.
      while (curSampIdx < sampleIndices.length && sampleIndices[curSampIdx] < batchEnd) {
        // Convert the global position into a row offset within this batch.
        ans.append(tb, sampleIndices[curSampIdx] - tuplesSeen);
        curSampIdx++;
      }
      tuplesSeen = batchEnd;
      if (ans.hasFilledTB()) {
        return ans.popFilled();
      }
    }
    return ans.popAny();
  }

  /**
   * Validates the extracted sampling parameters and builds the sorted array of
   * positions to sample.
   *
   * @return a sorted array of tuple positions.
   * @throws DbException
   *           if the sampling type is not one this operator handles.
   */
  private int[] generateSampleIndices() throws DbException {
    // Cannot sampleWoR more tuples than there are.
    if (sampleType == SamplingType.WithoutReplacement) {
      Preconditions.checkState(
          sampleSize <= streamSize,
          "Cannot SampleWoR %s tuples from a population of size %s",
          sampleSize,
          streamSize);
    }

    // Generate target indices to accept as samples.
    if (sampleType == SamplingType.WithReplacement) {
      return generateIndicesWR(streamSize, sampleSize);
    }
    if (sampleType == SamplingType.WithoutReplacement) {
      return generateIndicesWoR(streamSize, sampleSize);
    }
    throw new DbException("Invalid sampleType: " + sampleType);
  }

  /**
   * Helper function to extract sampling information from a TupleBatch. Only
   * row 0 of the batch is read.
   *
   * @param tb
   *          batch holding the (WorkerID, StreamSize, SampleSize, SampleType)
   *          row.
   * @throws DbException
   *           if a column has an unsupported type or the sample type name is
   *           not a known SamplingType.
   */
  private void extractSamplingInfo(final TupleBatch tb) throws DbException {
    Preconditions.checkArgument(tb != null);

    int workerID = readIntColumn(tb, 0, "WorkerID");
    Preconditions.checkState(
        workerID == getNodeID(),
        "Invalid WorkerID for this worker. Expected %s, but received %s",
        getNodeID(),
        workerID);

    streamSize = readIntColumn(tb, 1, "StreamSize");
    Preconditions.checkState(streamSize >= 0, "streamSize cannot be negative");

    sampleSize = readIntColumn(tb, 2, "SampleSize");
    Preconditions.checkState(sampleSize >= 0, "sampleSize cannot be negative");

    Type sampleTypeColumn = tb.getSchema().getColumnType(3);
    if (sampleTypeColumn != Type.STRING_TYPE) {
      throw new DbException("SampleType column must be of type STRING");
    }
    String sampleTypeName = tb.getString(3, 0);
    try {
      sampleType = SamplingType.valueOf(sampleTypeName);
    } catch (IllegalArgumentException e) {
      throw new DbException("Invalid SampleType: " + sampleTypeName);
    }
  }

  /**
   * Reads row 0 of an INT or LONG column as an int.
   *
   * @param tb
   *          batch to read from.
   * @param column
   *          index of the column to read.
   * @param columnName
   *          name of the column, used in error messages.
   * @return the column value as an int.
   * @throws DbException
   *           if the column is neither INT nor LONG.
   * @throws IllegalStateException
   *           if a LONG value does not fit in an int.
   */
  private static int readIntColumn(final TupleBatch tb, final int column, final String columnName)
      throws DbException {
    Type columnType = tb.getSchema().getColumnType(column);
    if (columnType == Type.INT_TYPE) {
      return tb.getInt(column, 0);
    }
    if (columnType == Type.LONG_TYPE) {
      long value = tb.getLong(column, 0);
      // A plain (int) cast would silently wrap out-of-range values into a
      // wrong, possibly still plausible, number. Fail loudly instead.
      Preconditions.checkState(
          value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE,
          "%s value %s does not fit in an int",
          columnName,
          value);
      return (int) value;
    }
    throw new DbException(columnName + " column must be of type INT or LONG");
  }

  /**
   * Generates a sorted array of random numbers to be taken as samples.
   *
   * @param populationSize
   *          size of the population that will be sampled from.
   * @param sampleSize
   *          number of samples to draw from the population.
   * @return a sorted array of indices.
   * @throws IllegalArgumentException
   *           if samples are requested from an empty population.
   */
  private int[] generateIndicesWR(final int populationSize, final int sampleSize) {
    // Random.nextInt rejects a non-positive bound with a message that says
    // nothing about sampling, so report the real cause up front.
    Preconditions.checkArgument(
        populationSize > 0 || sampleSize == 0,
        "Cannot SampleWR %s tuples from a population of size %s",
        sampleSize,
        populationSize);
    int[] indices = new int[sampleSize];
    for (int i = 0; i < sampleSize; i++) {
      indices[i] = rand.nextInt(populationSize);
    }
    Arrays.sort(indices);
    return indices;
  }

  /**
   * Generates a sorted array of unique random numbers to be taken as samples.
   * The implementation uses Floyd's algorithm. For an explanation:
   * www.nowherenearithaca.com/2013/05/robert-floyds-tiny-and-beautiful.html
   *
   * @param populationSize
   *          size of the population that will be sampled from.
   * @param sampleSize
   *          number of samples to draw from the population. Must not exceed
   *          populationSize.
   * @return a sorted array of indices.
   */
  private int[] generateIndicesWoR(final int populationSize, final int sampleSize) {
    Set<Integer> indices = new HashSet<>(sampleSize);
    for (int i = populationSize - sampleSize; i < populationSize; i++) {
      int idx = rand.nextInt(i + 1);
      // If idx was already chosen, take i instead. i cannot be in the set yet
      // because every earlier draw was bounded by a smaller value.
      if (indices.contains(idx)) {
        indices.add(i);
      } else {
        indices.add(idx);
      }
    }
    int[] indicesArr = Ints.toArray(indices);
    Arrays.sort(indicesArr);
    return indicesArr;
  }

  /**
   * The output schema is the schema of the right child.
   *
   * @return the right child's schema, or null if there is no right child.
   */
  @Override
  public Schema generateSchema() {
    Operator right = getRight();
    if (right == null) {
      return null;
    }
    return right.getSchema();
  }

  /** Allocates the output buffer using this operator's schema. */
  @Override
  protected void init(final ImmutableMap<String, Object> execEnvVars) {
    ans = new TupleBatchBuffer(getSchema());
  }

  /** Drops the reference to the output buffer. */
  @Override
  public void cleanup() {
    ans = null;
  }
}