package edu.washington.escience.myria.operator.network.distribute;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.annotation.Nonnull;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonSubTypes.Type;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.google.common.collect.Lists;
import edu.washington.escience.myria.storage.TupleBatch;
/**
 * A dataset is distributed in two steps: first, a partition function assigns a partition to each
 * tuple; second, each partition is mapped to a set of destinations, where a destination is an
 * output channel ID identifying a (worker ID, operator ID) pair.
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type")
@JsonSubTypes({
@Type(value = BroadcastDistributeFunction.class, name = "Broadcast"),
@Type(value = HyperCubeDistributeFunction.class, name = "HyperCube"),
@Type(value = HashDistributeFunction.class, name = "Hash"),
@Type(value = RoundRobinDistributeFunction.class, name = "RoundRobin"),
@Type(value = IdentityDistributeFunction.class, name = "Identity")
})
public abstract class DistributeFunction implements Serializable {

  /** Required for Java serialization. */
  private static final long serialVersionUID = 1L;

  /** The partition function that assigns each tuple of a batch to a partition. */
  protected PartitionFunction partitionFunction;

  /** The mapping from partition indices to lists of destination channel indices. */
  protected List<List<Integer>> partitionToDestination;

  /**
   * @param partitionFunction the partition function used for the first distribution step.
   */
  public DistributeFunction(final PartitionFunction partitionFunction) {
    this.partitionFunction = partitionFunction;
  }

  /**
   * Distributes a tuple batch to its destinations. An EOI batch is replicated to every
   * destination; otherwise the batch is partitioned and each partition is appended to the output
   * of every destination it maps to via {@link #partitionToDestination}.
   *
   * @param data the input data
   * @return a list of tuple batch lists, each represents output data of one destination.
   */
  public List<List<TupleBatch>> distribute(@Nonnull final TupleBatch data) {
    // Hoisted out of the loop conditions: getNumDestinations() rebuilds a HashSet on every call,
    // so evaluating it per iteration made these loops accidentally quadratic.
    final int numDestinations = getNumDestinations();
    final List<List<TupleBatch>> result = new ArrayList<List<TupleBatch>>(numDestinations);
    if (data.isEOI()) {
      // EOI markers must reach every destination so each consumer observes end-of-iteration.
      for (int i = 0; i < numDestinations; ++i) {
        result.add(Lists.newArrayList(data));
      }
    } else {
      for (int i = 0; i < numDestinations; ++i) {
        result.add(new ArrayList<TupleBatch>());
      }
      final TupleBatch[] tbs = partitionFunction.partition(data);
      for (int i = 0; i < tbs.length; ++i) {
        for (int channelIdx : partitionToDestination.get(i)) {
          result.get(channelIdx).add(tbs[i]);
        }
      }
    }
    return result;
  }

  /**
   * @return the number of destinations, i.e. the number of distinct channel indices appearing in
   *     {@link #partitionToDestination}.
   */
  public int getNumDestinations() {
    final Set<Integer> distinct = new HashSet<Integer>();
    for (List<Integer> destinations : partitionToDestination) {
      distinct.addAll(destinations);
    }
    return distinct.size();
  }

  /**
   * Populates {@link #partitionToDestination} for the given cluster configuration.
   *
   * @param numWorker the number of workers to distribute on
   * @param numOperatorId the number of involved operator IDs
   */
  public abstract void setDestinations(final int numWorker, final int numOperatorId);
}