package uk.ac.imperial.lsds.seepmaster.scheduler;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CountDownLatch;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import uk.ac.imperial.lsds.seep.api.DataReference;

/**
 * Tracks completion of one scheduling stage across the execution units (EUs) involved in it.
 *
 * <p>Each involved EU is expected to report once through {@link #notifyOk}.
 * {@link #waitForStageToFinish} blocks until as many distinct EUs as were involved have reported.
 * The {@link DataReference}s reported by the EUs are merged per integer key and exposed through
 * {@link #getStageResults}.
 *
 * <p>The methods that touch the mutable tracking state synchronize on this instance, so
 * notifications may arrive from several threads while another thread waits.
 */
public class StageTracker {

	private static final Logger LOG = LoggerFactory.getLogger(StageTracker.class);

	private final int stageId;
	// Defensive copy: later changes to the caller's set must not alter this tracker.
	private final Set<Integer> euInvolved;
	// Released once euInvolved.size() distinct EUs have called notifyOk.
	private final CountDownLatch countDown;
	// EU ids that have already reported for this stage.
	private final Set<Integer> completed;
	// Merged results; the meaning of the integer key is defined by the callers (not visible here).
	private final Map<Integer, Set<DataReference>> results;

	/**
	 * Creates a tracker for the given stage.
	 *
	 * @param stageId    id of the stage being tracked; notifications for other ids are rejected
	 * @param euInvolved ids of the execution units expected to report for this stage
	 */
	public StageTracker(int stageId, Set<Integer> euInvolved) {
		this.stageId = stageId;
		this.euInvolved = new HashSet<>(euInvolved);
		this.countDown = new CountDownLatch(this.euInvolved.size());
		this.completed = new HashSet<>();
		this.results = new HashMap<>();
	}

	/**
	 * Returns the results merged so far, keyed by the integer keys the EUs reported.
	 *
	 * <p>This is the live internal map, not a copy. Read it after {@link #waitForStageToFinish}
	 * returns to see every EU's contribution.
	 *
	 * @return the internal map of merged results
	 */
	public synchronized Map<Integer, Set<DataReference>> getStageResults() {
		return results;
	}

	/**
	 * Blocks the calling thread until every involved EU has reported, or until the thread is
	 * interrupted.
	 *
	 * <p>On interruption the method returns early and restores the thread's interrupt status.
	 * Callers can check {@link #finishedSuccessfully()} to tell the two outcomes apart.
	 */
	public void waitForStageToFinish() {
		try {
			countDown.await();
		} catch (InterruptedException e) {
			// Re-assert the interrupt so code further up the stack can observe it.
			Thread.currentThread().interrupt();
			LOG.warn("Interrupted while waiting for stage {} to finish", stageId, e);
		}
	}

	/**
	 * Records that an execution unit finished its part of the stage and merges its results.
	 *
	 * <p>Duplicate notifications from the same EU are logged and ignored. A notification carrying a
	 * different stage id is treated as fatal and terminates the JVM.
	 *
	 * @param euId       id of the reporting execution unit
	 * @param stageId    stage the EU reports for; must equal this tracker's stage
	 * @param newResults results produced by the EU, keyed like {@link #getStageResults()};
	 *                   {@code null} or {@code null} value sets are treated as empty
	 */
	public synchronized void notifyOk(int euId, int stageId, Map<Integer, Set<DataReference>> newResults) {
		if (this.stageId != stageId) {
			// Fail fast: a notification for another stage means the scheduler state is inconsistent.
			LOG.error("ERROR, notifying for non-current stage: expected {}, got {} (EU {})",
					this.stageId, stageId, euId);
			System.exit(-1);
		}
		if (!completed.add(euId)) {
			LOG.warn("Notified {} that was already present", euId);
			return;
		}
		if (!euInvolved.contains(euId)) {
			// Still counted toward the latch; finishedSuccessfully() checks the involved set explicitly.
			LOG.warn("Notified {} that is not involved in stage {}", euId, this.stageId);
		}
		if (newResults != null) {
			for (Entry<Integer, Set<DataReference>> entry : newResults.entrySet()) {
				Set<DataReference> newDRefs = entry.getValue();
				if (newDRefs != null) {
					results.computeIfAbsent(entry.getKey(), k -> new HashSet<>()).addAll(newDRefs);
				}
			}
		}
		countDown.countDown();
	}

	/**
	 * Reports whether every involved execution unit has notified completion.
	 *
	 * @return {@code true} if all EUs passed to the constructor have called {@link #notifyOk}
	 */
	public synchronized boolean finishedSuccessfully() {
		return completed.containsAll(euInvolved);
	}
}