package uk.ac.imperial.lsds.seepmaster.scheduler.memorymanagement;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.ac.imperial.lsds.seep.core.DatasetMetadata;
import uk.ac.imperial.lsds.seep.core.DatasetMetadataPackage;
import uk.ac.imperial.lsds.seep.scheduler.ScheduleDescription;
import uk.ac.imperial.lsds.seep.scheduler.Stage;
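/**
 * Memory management policy that ranks datasets per execution unit (euId) by an MDF value.
 * When a dataset is created, its value is the number of expected downstream accesses of its
 * producing stage, weighted by the cheaper of keeping the dataset in memory (size * dmRatio)
 * and recomputing it from its upstream stage; the value decays as the dataset's remaining
 * expected accesses are consumed.
 *
 * A minimal usage sketch, assuming the scheduler drives the policy per worker and stage
 * (variable names below are illustrative only):
 * <pre>
 *   policy.updateDatasetsForNode(euId, metadataReportedByWorker, stageId);
 *   List&lt;Integer&gt; evictionOrder = policy.rankDatasetsForNode(euId, aliveDatasetIds);
 * </pre>
 */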
public class MDFMemoryManagementPolicy implements MemoryManagementPolicy {
private final Logger LOG = LoggerFactory.getLogger(MDFMemoryManagementPolicy.class);
private ScheduleDescription sd;
private double dmRatio;
private Map<Integer, Map<Integer, Double>> euId_mdf = new HashMap<>(); // per execution unit: datasetId -> current MDF value
private Map<Integer, Integer> stageid_accesses = new HashMap<>(); // number of downstream dependants per stage, used as expected access count
private Map<Integer, Long> stageid_size = new HashMap<>(); // proportional to #nodes
private Map<Integer, Long> stageid_cost = new HashMap<>(); // proportional to #nodes
private Map<Integer, Double> stageid_ratio_inmem = new HashMap<>(); // proportional to #nodes
private Map<Integer, Integer> dataset_access_count = new HashMap<>(); // accesses observed so far per dataset
private Map<Integer, Integer> dataset_expected_count = new HashMap<>(); // total accesses expected per dataset
// Metrics
private long __totalUpdateTime = 0;
private long __totalRankTime = 0;
public MDFMemoryManagementPolicy(ScheduleDescription sd, double dmRatio) {
this.sd = sd;
this.dmRatio = dmRatio;
computeAccesses(sd);
}
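/**
 * Registers the datasets newly produced on this execution unit by the given stage,
 * assigning each an initial MDF value, and decays the values of the datasets the stage
 * consumed to reflect one fewer remaining expected access.
 */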
@Override
public void updateDatasetsForNode(int euId, DatasetMetadataPackage datasetsMetadata, int stageId) {
long start = System.currentTimeMillis();
if(! euId_mdf.containsKey(euId)) {
euId_mdf.put(euId, new HashMap<Integer, Double>());
}
// Get datasets generated by this stage
Set<DatasetMetadata> datasetsOfThisStage = new HashSet<>();
for (DatasetMetadata dm : datasetsMetadata.newDatasets) {
if (!dataset_expected_count.containsKey(dm.getDatasetId())) {
dataset_expected_count.put(dm.getDatasetId(), stageid_accesses.get(stageId));
// Initial MDF value: expected accesses weighted by the cheaper of keeping the
// dataset in memory (size * dmRatio) and recomputing it from upstream
euId_mdf.get(euId).put(dm.getDatasetId(),
stageid_accesses.get(stageId) * Math.min(dm.getSize() * dmRatio, computeRecomputeCostFor(stageId)));
datasetsOfThisStage.add(dm);
//System.out.println("Stage " + stageId + ", dataset " + dm.getDatasetId() + ", accesses " + stageid_accesses.get(stageId));
}
}
// Compute variables for the model
long sizeOfThisDataset = computeSizeOfDataset(datasetsOfThisStage);
long costOfDataset = computeCostOfDataset(datasetsOfThisStage);
double percDataInMem = computeRatioDataInMem(datasetsOfThisStage, sizeOfThisDataset);
// Store variables
stageid_size.put(stageId, sizeOfThisDataset);
stageid_cost.put(stageId, costOfDataset);
stageid_ratio_inmem.put(stageId, percDataInMem);
for (DatasetMetadata dm : datasetsMetadata.usedDatasets) {
if (dataset_access_count.containsKey(dm.getDatasetId())) {
// Decay factor (remaining - 1) / remaining: one of the dataset's remaining
// expected accesses has now been consumed
double decrement = dataset_expected_count.get(dm.getDatasetId()) - dataset_access_count.get(dm.getDatasetId());
decrement = (decrement - 1.) / decrement;
if (euId_mdf.get(euId).containsKey(dm.getDatasetId())) {
euId_mdf.get(euId).put(dm.getDatasetId(), euId_mdf.get(euId).get(dm.getDatasetId())*(decrement));
} else {
//Source dataset. This will mark it for removal immediately, which only works because all
// workers run their first stages in parallel.
dataset_expected_count.put(dm.getDatasetId(), 1);
}
dataset_access_count.put(dm.getDatasetId(), dataset_access_count.get(dm.getDatasetId())+1);
} else {
dataset_access_count.put(dm.getDatasetId(), 1);
}
}
long end = System.currentTimeMillis();
this.__totalUpdateTime = this.__totalUpdateTime + (end - start);
}
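/**
 * Prunes datasets that are no longer alive on the node or whose expected accesses have
 * been exhausted, then returns the remaining dataset ids ordered by MDF value
 * (ascending; see the FIXME in sortByValue regarding the ranking direction).
 */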
@Override
public List<Integer> rankDatasetsForNode(int euId, Set<Integer> datasetIds) {
long start = System.currentTimeMillis();
List<Integer> rankedDatasets = new ArrayList<>();
if(! euId_mdf.containsKey(euId)) {
return rankedDatasets;
}
// Now we use the datasets that are alive to prune the datasets in the node
removeEvictedDatasets(euId, datasetIds);
// Evict datasets based on access patterns
removeFinishedDatasets(euId);
// We get the datasets in the node, after pruning
Map<Integer, Double> datasetId_mdfValue = euId_mdf.get(euId);
Map<Integer, Double> sorted = sortByValue(datasetId_mdfValue);
System.out.println("MDF VALUES");
for(Double v : sorted.values()) {
System.out.print(v+" - ");
}
System.out.println();
// The sorted map is a LinkedHashMap, so keySet() preserves the ascending MDF order
rankedDatasets.addAll(sorted.keySet());
long end = System.currentTimeMillis();
this.__totalRankTime = this.__totalRankTime + (end - start);
return rankedDatasets;
}
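/**
 * Estimates the cost of recomputing the output of the given stage from its single upstream
 * stage: the upstream creation cost plus the memory cost of the fraction of the upstream
 * data that is already in memory. Returns Long.MAX_VALUE for source stages or when the
 * upstream output has not been produced yet, so that recomputation is never taken as the
 * cheaper option in those cases.
 */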
private double computeRecomputeCostFor(int stageId) {
double recomputeCost = Long.MAX_VALUE;
Stage s = sd.getStageWithId(stageId);
Set<Stage> upstream = s.getDependencies();
if (upstream.size() > 1) {
LOG.error("More than one upstream stage when computing recompute cost for stageId: {}", stageId);
}
if (upstream.isEmpty()) {
return recomputeCost; // source stage: no upstream to recompute from
}
int sid = upstream.iterator().next().getStageId();
if (!stageid_size.containsKey(sid)) {
return recomputeCost; // upstream output not produced yet; max cost ensures it is not selected
}
long size = stageid_size.get(sid);
long cost = stageid_cost.get(sid);
double percDataInMem = stageid_ratio_inmem.get(sid);
recomputeCost = cost + percDataInMem * (size * dmRatio);
return recomputeCost;
}
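/** Sums the sizes of the given datasets. */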
private long computeSizeOfDataset(Set<DatasetMetadata> datasetsMetadata) {
long size = 0;
for(DatasetMetadata dm : datasetsMetadata) {
size = size + dm.getSize();
}
return size;
}
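/** Sums the creation costs of the given datasets. */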
private long computeCostOfDataset(Set<DatasetMetadata> datasetsMetadata) {
long cost = 0;
for(DatasetMetadata dm : datasetsMetadata) {
cost = cost + dm.getCreationCost();
}
return cost;
}
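/** Computes the fraction of the given datasets' total size that is currently held in memory. */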
private double computeRatioDataInMem(Set<DatasetMetadata> datasetsMetadata, long sizeOfThisDataset) {
// doing this on actual size in case datasets are of different lengths in the future
long mem = 0;
for(DatasetMetadata dm : datasetsMetadata) {
if(dm.isInMem()) {
mem = mem + dm.getSize();
}
}
// +1 guards against division by zero; cast to double so the ratio is not truncated to 0 by integer division
sizeOfThisDataset++;
double r = (double) mem / sizeOfThisDataset;
return r;
}
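/**
 * Evicts datasets whose recorded accesses have reached their expected access count,
 * removing them from the bookkeeping maps and from every execution unit's MDF map so the
 * eviction is propagated to the cluster.
 */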
private void removeFinishedDatasets(int euId) {
// Datasets to evict
Set<Integer> toEvict = new HashSet<>();
// Check those datasets that must be evicted
for(Integer e : dataset_access_count.keySet()) {
Integer expected = dataset_expected_count.get(e);
// Defensive: skip datasets for which no expected access count was registered
if (expected != null && dataset_access_count.get(e) >= expected) {
toEvict.add(e);
}
}
for(Integer e : toEvict) {
dataset_access_count.remove(e);
dataset_expected_count.remove(e);
}
// Now evict them from the data structure for propagation to the cluster
for(Entry<Integer, Map<Integer, Double>> e : euId_mdf.entrySet()) {
Map<Integer, Double> entry = e.getValue();
for(Integer kill : toEvict) {
entry.remove(kill);
}
}
}
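/**
 * Precomputes, for every stage, how many downstream stages depend on it; this is used as
 * the expected number of accesses for the datasets the stage produces.
 */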
private void computeAccesses(ScheduleDescription sd) {
// TODO: will work if there is only one (logical) source
for(Stage s : sd.getStages()) {
int sid = s.getStageId();
int numDownstream = s.getDependants().size();
stageid_accesses.put(sid, numDownstream);
}
}
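/** Drops from the node's MDF map any dataset that is no longer among the dataset ids alive on the node. */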
private void removeEvictedDatasets(int euId, Set<Integer> datasetIdsToKeep) {
Map<Integer, Double> allEntries = euId_mdf.get(euId);
// Select entries to remove
Set<Integer> toRemove = new HashSet<>();
for(int id : allEntries.keySet()) {
if(! datasetIdsToKeep.contains(id)) {
toRemove.add(id);
}
}
// Remove the selection
for(int toRem : toRemove) {
allEntries.remove(toRem);
}
// allEntries is the live map stored in euId_mdf, so the removals above already apply
}
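/** Returns the entries of the given map copied into a LinkedHashMap, ordered by ascending value. */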
private Map<Integer, Double> sortByValue(Map<Integer, Double> map) {
Map<Integer, Double> result = new LinkedHashMap<>();
Stream<Entry<Integer, Double>> st = map.entrySet().stream();
// FIXME: if ranking precedence requires the highest MDF value first, reverse this
// comparator (or negate the values)
st.sorted(Comparator.comparingDouble(e -> e.getValue())).forEachOrdered(e -> result.put(e.getKey(), e.getValue()));
return result;
}
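/** Accumulated wall-clock time (ms) spent in updateDatasetsForNode. */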
@Override
public long __totalUpdateTime() {
return this.__totalUpdateTime;
}
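/** Accumulated wall-clock time (ms) spent in rankDatasetsForNode. */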
@Override
public long __totalRankTime() {
return this.__totalRankTime;
}
}