package uk.ac.imperial.lsds.seepmaster.scheduler.loadbalancing;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import uk.ac.imperial.lsds.seep.api.DataReference;
import uk.ac.imperial.lsds.seep.comm.Connection;
import uk.ac.imperial.lsds.seep.comm.protocol.ProtocolCommandFactory;
import uk.ac.imperial.lsds.seep.comm.protocol.SeepCommand;
import uk.ac.imperial.lsds.seep.infrastructure.SeepEndPoint;
import uk.ac.imperial.lsds.seep.scheduler.Stage;
import uk.ac.imperial.lsds.seep.scheduler.StageType;
import uk.ac.imperial.lsds.seepmaster.infrastructure.master.ExecutionUnit;
import uk.ac.imperial.lsds.seepmaster.infrastructure.master.InfrastructureManager;
import uk.ac.imperial.lsds.seepmaster.scheduler.ClusterDatasetRegistry;
import uk.ac.imperial.lsds.seepmaster.scheduler.CommandToNode;
import uk.ac.imperial.lsds.seepmaster.scheduler.ScheduleTracker;
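/**
 * A {@link LoadBalancingStrategy} that splits a stage's input DataReferences
 * across the workers involved in the stage while maximizing input-data
 * locality: partitioned (shuffled) references are spread round-robin by
 * partition id, and non-partitioned managed references go to the worker that
 * already stores them.
 */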
public class DataParallelWithInputDataLocalityLoadBalancingStrategy implements LoadBalancingStrategy {
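	/**
	 * Builds one ScheduleStage command per worker involved in the next stage.
	 * Each command carries the subset of input DataReferences that worker
	 * should process, the stage's output DataReferences, and the ranked list
	 * of datasets held by that node.
	 */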
@Override
public List<CommandToNode> assignWorkToWorkers(Stage nextStage, InfrastructureManager inf, ScheduleTracker tracker) {
		// Workers that will participate in this stage
		Set<Connection> conns = getWorkersInvolvedInStage(nextStage, inf);
		int nextStageId = nextStage.getStageId();
		// All input DataReferences to process during the next stage, keyed by stream id
		Map<Integer, Set<DataReference>> drefs = nextStage.getInputDataReferences();
		// Split the input DataReferences across workers, favouring data locality over even load balancing
List<CommandToNode> commands = new ArrayList<>();
final int totalWorkers = conns.size();
int currentWorker = 0;
for(Connection c : conns) {
			// Input DataReferences selected for this worker, keyed by stream id
			Map<Integer, Set<DataReference>> perWorker = new HashMap<>();
			for(Map.Entry<Integer, Set<DataReference>> entry : drefs.entrySet()) {
				int streamId = entry.getKey();
				for(DataReference dr : entry.getValue()) {
					// EXTERNAL: the DataReference is not managed by SEEP, so each worker
					// reads the external source independently. Assign one reference and
					// move on to the next stream. Note that currentWorker is only
					// advanced once per connection (below); bumping it here as well
					// would skew the modulo matching used for partitioned references.
					if(! dr.isManaged()) {
						assignDataReferenceToWorker(perWorker, streamId, dr);
						break;
					}
					// MANAGED: decide whether this DataReference belongs to this worker.
					// It is assigned when it is partitioned (shuffled) or when it already
					// lives on this worker (locality = local).
					else {
						// SHUFFLE/PARTITIONED CASE: spread partitions round-robin, assigning
						// to this worker every DataReference whose partition id matches it
						// modulo the total number of workers
						if(dr.isPartitioned()) {
							int partitionSeqId = dr.getPartitionId();
							if(partitionSeqId % totalWorkers == currentWorker) {
								assignDataReferenceToWorker(perWorker, streamId, dr);
							}
						}
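						// e.g. with totalWorkers = 3, partition ids 0..5 map to workers
						// 0, 1, 2, 0, 1, 2 respectively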
						// NORMAL CASE: the DataReference lives on a single node, so assign
						// it to the worker that already stores it and keep the read local
						else if(dr.getControlEndPoint().getId() == c.getId()) {
							assignDataReferenceToWorker(perWorker, streamId, dr);
						}
}
}
			// Advance the round-robin worker index used for partitioned references
			currentWorker++;
}
			// Fetch this node's dataset ranking from the ClusterDatasetRegistry, then
			// build the command that schedules the stage on it, carrying the worker's
			// share of the input plus the stage's output DataReferences
			int euId = c.getId();
			List<Integer> rankedDatasets = tracker.getClusterDatasetRegistry()
					.getRankedDatasetForNode(euId, tracker.getScheduleDescription());
			SeepCommand esc = ProtocolCommandFactory.buildScheduleStageCommand(nextStageId,
					perWorker, nextStage.getOutputDataReferences(), rankedDatasets);
CommandToNode ctn = new CommandToNode(esc, c);
commands.add(ctn);
}
return commands;
}
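	/**
	 * Returns the connections of all workers involved in the given stage. For
	 * source and unique stages, where DataReferences may not yet carry EndPoint
	 * information, every execution unit in use is reported; for all later
	 * stages the involved nodes are taken from the stage itself.
	 */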
private Set<Connection> getWorkersInvolvedInStage(Stage stage, InfrastructureManager inf) {
Set<Connection> cons = new HashSet<>();
		// Source and unique stages: their input DataReferences do not necessarily
		// contain EndPoint information yet, so simply report all nodes in use
		if(stage.getStageType().equals(StageType.SOURCE_STAGE) || stage.getStageType().equals(StageType.UNIQUE_STAGE)) {
			// TODO: revisit this assumption; it probably won't hold later
for(ExecutionUnit eu : inf.executionUnitsInUse()) {
Connection conn = new Connection(eu.getControlEndPoint());
cons.add(conn);
}
}
		// For any later stage, the DataReferences contain the right EndPoint information
else {
Set<SeepEndPoint> eps = stage.getInvolvedNodes();
for(SeepEndPoint ep : eps) {
Connection c = new Connection(ep);
cons.add(c);
}
}
return cons;
}
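	/**
	 * Records the assignment of a DataReference to a worker under the given
	 * stream id, creating the per-stream set on first use.
	 */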
	private void assignDataReferenceToWorker(Map<Integer, Set<DataReference>> perWorker, int streamId, DataReference dr) {
		// Create the per-stream set lazily, then record the DataReference
		perWorker.computeIfAbsent(streamId, k -> new HashSet<>()).add(dr);
}
}