package edu.washington.escience.myria.operator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.SortedMap;
import java.util.TreeMap;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import edu.washington.escience.myria.DbException;
import edu.washington.escience.myria.Schema;
import edu.washington.escience.myria.Type;
import edu.washington.escience.myria.column.Column;
import edu.washington.escience.myria.column.builder.ColumnBuilder;
import edu.washington.escience.myria.column.builder.ColumnFactory;
import edu.washington.escience.myria.storage.TupleBatch;
import edu.washington.escience.myria.util.SamplingType;
/**
* Given the number of tuples stored on each worker, computes how many tuples
* each worker should sample.
*/
public class SamplingDistribution extends UnaryOperator {
/** Required for Java serialization. */
private static final long serialVersionUID = 1L;
/** The output schema. */
private static final Schema SCHEMA =
Schema.ofFields(
"WorkerID",
Type.INT_TYPE,
"StreamSize",
Type.INT_TYPE,
"SampleSize",
Type.INT_TYPE,
"SampleType",
Type.STRING_TYPE);
/** Total number of tuples to sample. */
private int sampleSize = 0;
/** True if using a percentage instead of a specific tuple count. */
private boolean isPercentageSample = false;
/** Percentage of total tuples to sample. */
private float samplePercentage;
/** The type of sampling to perform. */
private final SamplingType sampleType;
/** Random generator used for creating the distribution. */
private Random rand;
/** Maps each worker ID to the sampling info for that worker. */
private SortedMap<Integer, WorkerInfo> workerInfo = new TreeMap<>();
/** Total number of tuples across all workers. */
private int totalTupleCount = 0;
/** Common initialization shared by the public constructors. */
private SamplingDistribution(Operator child, SamplingType sampleType, Long randomSeed) {
super(child);
this.sampleType = sampleType;
rand = new Random();
if (randomSeed != null) {
rand.setSeed(randomSeed);
}
}
/**
* Instantiate a SamplingDistribution operator using a specific sample size.
*
* @param sampleSize
* total number of tuples to sample.
* @param sampleType
* the type of sampling distribution to create
* @param child
* extracts (WorkerID, PartitionSize, StreamSize) information from
* this child.
* @param randomSeed
* value to seed the random generator with, or null if no seed is specified
*/
public SamplingDistribution(
Operator child, int sampleSize, SamplingType sampleType, Long randomSeed) {
this(child, sampleType, randomSeed);
Preconditions.checkArgument(sampleSize >= 0, "Sample Size must be >= 0: %s", sampleSize);
this.sampleSize = sampleSize;
}
/**
* Instantiate a SamplingDistribution operator using a percentage of total
* tuples.
*
* @param samplePercentage
* percentage of the total tuples to sample.
* @param sampleType
* the type of sampling distribution to create
* @param child
* extracts (WorkerID, PartitionSize, StreamSize) information from
* this child.
* @param randomSeed
* value to seed the random generator with, or null if no seed is specified
*/
public SamplingDistribution(
Operator child, float samplePercentage, SamplingType sampleType, Long randomSeed) {
this(child, sampleType, randomSeed);
this.isPercentageSample = true;
this.samplePercentage = samplePercentage;
Preconditions.checkArgument(
samplePercentage >= 0 && samplePercentage <= 100,
"Sample Percentage must be >= 0 && <= 100: %s",
samplePercentage);
}
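/*
 * Usage sketch (illustrative only; the child operator below is an assumption,
 * not something this class provides): the child must emit one row per worker
 * of the form (WorkerID, PartitionSize[, StreamSize]).
 *
 *   Operator workerCounts = ...; // produces one (WorkerID, PartitionSize) row per worker
 *   SamplingDistribution dist =
 *       new SamplingDistribution(workerCounts, 1000, SamplingType.WithReplacement, 42L);
 *   // When executed, dist emits a single TupleBatch with schema
 *   // (WorkerID, StreamSize, SampleSize, SampleType), one row per worker.
 */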
@Override
protected TupleBatch fetchNextReady() throws DbException {
// Drain out all the worker info.
while (!getChild().eos()) {
TupleBatch tb = getChild().nextReady();
if (tb == null) {
if (getChild().eos()) {
break;
}
return null;
}
extractWorkerInfo(tb);
}
getChild().close();
// Convert samplePct to sampleSize if using a percentage sample.
if (isPercentageSample) {
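// e.g., totalTupleCount = 10000 with samplePercentage = 2.5f yields sampleSize = 250.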
sampleSize = Math.round(totalTupleCount * (samplePercentage / 100));
}
Preconditions.checkState(
sampleSize >= 0 && sampleSize <= totalTupleCount,
"Cannot extract %s samples from a population of size %s",
sampleSize,
totalTupleCount);
// Generate a sampling distribution across the workers.
if (sampleType == SamplingType.WithReplacement) {
withReplacementDistribution(workerInfo, totalTupleCount, sampleSize);
} else if (sampleType == SamplingType.WithoutReplacement) {
withoutReplacementDistribution(workerInfo, totalTupleCount, sampleSize);
} else {
throw new DbException("Invalid sampleType: " + sampleType);
}
// Build and return a TupleBatch with the distribution.
// Assumes that the sampling information can fit into one tuple batch.
List<ColumnBuilder<?>> colBuilders = ColumnFactory.allocateColumns(SCHEMA);
for (Map.Entry<Integer, WorkerInfo> iWorker : workerInfo.entrySet()) {
colBuilders.get(0).appendInt(iWorker.getKey());
colBuilders.get(1).appendInt(iWorker.getValue().actualTupleCount);
colBuilders.get(2).appendInt(iWorker.getValue().sampleSize);
colBuilders.get(3).appendString(sampleType.name());
}
ImmutableList.Builder<Column<?>> columns = ImmutableList.builder();
for (ColumnBuilder<?> cb : colBuilders) {
columns.add(cb.build());
}
setEOS();
return new TupleBatch(SCHEMA, columns.build());
}
/** Helper function to extract worker information from a tuple batch. */
private void extractWorkerInfo(TupleBatch tb) throws DbException {
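// Expected child schema: (WorkerID, TupleCount[, ActualTupleCount]); the ID and
// count columns may each be either INT or LONG.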
Type col0Type = tb.getSchema().getColumnType(0);
Type col1Type = tb.getSchema().getColumnType(1);
boolean hasActualTupleCount = false;
Type col2Type = null;
if (tb.getSchema().numColumns() > 2) {
hasActualTupleCount = true;
col2Type = tb.getSchema().getColumnType(2);
}
for (int i = 0; i < tb.numTuples(); i++) {
int workerID;
if (col0Type == Type.INT_TYPE) {
workerID = tb.getInt(0, i);
} else if (col0Type == Type.LONG_TYPE) {
workerID = (int) tb.getLong(0, i);
} else {
throw new DbException("WorkerID must be of type INT or LONG");
}
Preconditions.checkState(workerID > 0, "WorkerID must be > 0: %s", workerID);
Preconditions.checkState(!workerInfo.containsKey(workerID), "Duplicate WorkerID: %s", workerID);
int tupleCount;
if (col1Type == Type.INT_TYPE) {
tupleCount = tb.getInt(1, i);
} else if (col1Type == Type.LONG_TYPE) {
tupleCount = (int) tb.getLong(1, i);
} else {
throw new DbException("TupleCount must be of type INT or LONG");
}
Preconditions.checkState(
tupleCount >= 0, "Worker cannot have a negative TupleCount: %s", tupleCount);
int actualTupleCount = tupleCount;
if (hasActualTupleCount) {
if (col2Type == Type.INT_TYPE) {
actualTupleCount = tb.getInt(2, i);
} else if (col2Type == Type.LONG_TYPE) {
actualTupleCount = (int) tb.getLong(2, i);
} else {
throw new DbException("ActualTupleCount must be of type INT or LONG");
}
Preconditions.checkState(
actualTupleCount >= 0,
"Worker cannot have a negative ActualTupleCount: %s",
actualTupleCount);
}
WorkerInfo wInfo = new WorkerInfo(tupleCount, actualTupleCount);
workerInfo.put(workerID, wInfo);
totalTupleCount += tupleCount;
}
}
/**
* Creates a WithReplacement distribution across the workers.
*
* @param workerInfo
* reference to the workerInfo to modify.
* @param totalTupleCount
* total # of tuples across all workers.
* @param sampleSize
* total # of samples to distribute across the workers.
*/
private void withReplacementDistribution(
SortedMap<Integer, WorkerInfo> workerInfo, int totalTupleCount, int sampleSize) {
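// Worked example (numbers are illustrative): with workerInfo {1 -> 40 tuples,
// 2 -> 60 tuples} and sampleSize = 3, each draw picks a uniform index in
// [0, 100); indices 0-39 are credited to worker 1 and 40-99 to worker 2, so a
// worker's expected sample share is proportional to its tuple count.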
for (int i = 0; i < sampleSize; i++) {
int sampleTupleIdx = rand.nextInt(totalTupleCount);
// Assign this tuple to the workerID that holds this sampleTupleIdx.
int tupleOffset = 0;
for (Map.Entry<Integer, WorkerInfo> iWorker : workerInfo.entrySet()) {
WorkerInfo wInfo = iWorker.getValue();
if (sampleTupleIdx < wInfo.tupleCount + tupleOffset) {
wInfo.sampleSize += 1;
break;
}
tupleOffset += wInfo.tupleCount;
}
}
}
/**
* Creates a WithoutReplacement distribution across the workers.
*
* @param workerInfo
* reference to the workerInfo to modify.
* @param totalTupleCount
* total # of tuples across all workers.
* @param sampleSize
* total # of samples to distribute across the workers.
*/
private void withoutReplacementDistribution(
SortedMap<Integer, WorkerInfo> workerInfo, int totalTupleCount, int sampleSize) {
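// Worked example (numbers are illustrative): with logical counts {1 -> 40,
// 2 -> 60} and sampleSize = 2, the first draw picks an index in [0, 100); if
// it falls in worker 2's range, worker 2's logical count drops to 59 and the
// second draw picks an index in [0, 99), so the same tuple cannot be chosen twice.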
SortedMap<Integer, Integer> logicalTupleCounts = new TreeMap<>();
for (Map.Entry<Integer, WorkerInfo> wInfo : workerInfo.entrySet()) {
logicalTupleCounts.put(wInfo.getKey(), wInfo.getValue().tupleCount);
}
for (int i = 0; i < sampleSize; i++) {
int sampleTupleIdx = rand.nextInt(totalTupleCount - i);
// Assign this tuple to the workerID that holds this sampleTupleIdx.
int tupleOffset = 0;
for (Map.Entry<Integer, WorkerInfo> iWorker : workerInfo.entrySet()) {
int wID = iWorker.getKey();
WorkerInfo wInfo = iWorker.getValue();
if (sampleTupleIdx < logicalTupleCounts.get(wID) + tupleOffset) {
wInfo.sampleSize += 1;
// Cannot sample the same tuple, so pretend it doesn't exist anymore.
logicalTupleCounts.put(wID, logicalTupleCounts.get(wID) - 1);
break;
}
tupleOffset += logicalTupleCounts.get(wID);
}
}
}
/**
* Returns the sample size of this operator. If the operator was created with a
* sample percentage, this value will be 0 until after fetchNextReady has run.
*/
public int getSampleSize() {
return sampleSize;
}
/**
* Returns the percentage of total tuples that this operator will distribute.
* This will be 0 if the operator was created with a specific sample size.
*/
public float getSamplePercentage() {
return samplePercentage;
}
/** Returns the type of sampling distribution that this operator will create. */
public SamplingType getSampleType() {
return sampleType;
}
@Override
public Schema generateSchema() {
return SCHEMA;
}
@Override
public void cleanup() {
workerInfo = null;
}
/** Encapsulates sampling information about a worker. */
private static class WorkerInfo {
/** # of tuples that this worker owns. */
int tupleCount;
/**
* Actual # of tuples that the worker has stored. May differ from tupleCount
* if the worker pre-sampled the data.
*/
int actualTupleCount;
/** # of tuples that the distribution assigned to this worker. */
int sampleSize = 0;
WorkerInfo(int tupleCount, int actualTupleCount) {
this.tupleCount = tupleCount;
this.actualTupleCount = actualTupleCount;
}
}
}