package ecologylab.bigsemantics.distributed; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; import org.junit.Test; import ecologylab.bigsemantics.distributed.Task.State; /** * * @author quyin */ @SuppressWarnings({ "rawtypes", "unchecked" }) public class TestDispatcherMultiThreads { static class FakeTask extends Task { int maxFailCount; int waitTime; long beginTime = -1; long endTime; List<Worker<Task>> workers = new ArrayList<Worker<Task>>(); public FakeTask(String id, int priority, int maxFailCount, int waitTime) { super(id, priority); this.maxFailCount = maxFailCount; this.waitTime = waitTime; } @Override public Result perform() { try { Thread.sleep(waitTime); } catch (InterruptedException e) { return Result.FATAL; } return getFailCount() >= maxFailCount ? Result.OK : Result.ERROR; } } static class FakeWorker extends Worker<Task> { List<String> forbiddenDomains; public FakeWorker(String id, int numThreads, String... forbiddenDomains) { super(id, numThreads); this.forbiddenDomains = Arrays.asList(forbiddenDomains); } @Override public boolean canHandle(Task task) { String id = task.getId(); for (String forbiddenDomain : forbiddenDomains) { if (id.contains(forbiddenDomain)) { return false; } } return true; } @Override protected void performTask(Task task, TaskEventHandler<Task> handler) { if (task instanceof FakeTask) { FakeTask fakeTask = (FakeTask) task; fakeTask.workers.add(this); } super.performTask(task, handler); } } @Test public void testMultiThreads() throws InterruptedException { final Dispatcher dispatcher = new Dispatcher() { @Override protected void onDispatch(Task task) { if (task instanceof FakeTask) { FakeTask fakeTask = (FakeTask) task; if (fakeTask.beginTime < 0) { fakeTask.beginTime = System.currentTimeMillis(); } } } @Override protected boolean isTooManyFail(Task task) { return task.getFailCount() >= 4; } }; // add 4 workers, each with 4 threads and 2 domains that it cannot handle for (int i = 1; i <= 4; ++i) { FakeWorker worker = new FakeWorker("w" + i, 4, "" + i + ".com", "" + (11 - i) + ".com"); dispatcher.addWorker(worker); } // create 500 tasks: 10 domains x 5 fail counts x 10 wait times Random rand = new Random(System.currentTimeMillis()); // generating random priority List<FakeTask> tasks = new ArrayList<FakeTask>(); Map<Integer, Map<Integer, List<FakeTask>>> taskMap = new HashMap<Integer, Map<Integer, List<FakeTask>>>(); for (int i = 1; i <= 10; ++i) { for (int failCount = 0; failCount < 5; ++failCount) { taskMap.put(failCount, new HashMap<Integer, List<FakeTask>>()); for (int waitTime = 0; waitTime < 100; waitTime += 10) { taskMap.get(failCount).put(waitTime, new ArrayList<FakeTask>()); String id = String.format("http://%d.com/index.html?f=%d&t=%d", i, failCount, waitTime); FakeTask task = new FakeTask(id, rand.nextInt(), failCount, waitTime); TaskEventHandler<FakeTask> handler = new TaskEventHandler<FakeTask>() { @Override public void onComplete(FakeTask task) { task.endTime = System.currentTimeMillis(); System.out.format("Task %s completed\n", task); } @Override public void onFail(FakeTask task) { task.endTime = System.currentTimeMillis(); System.out.format("Task %s failed\n", task); } @Override public void onTerminate(FakeTask task) { task.endTime = System.currentTimeMillis(); System.out.format("Task %s terminated\n", task); } }; tasks.add(task); taskMap.get(failCount).get(waitTime).add(task); dispatcher.queueTask(task, handler); } } } // do it! Runner runner = new Runner() { @Override protected void body() throws Exception { dispatcher.dispatchTask(); } }; long allBeginTime = System.currentTimeMillis(); runner.start(); for (FakeTask task : tasks) { task.waitForDone(); } long allTotalTime = System.currentTimeMillis() - allBeginTime; // verify: // no worker does domains it cannot handle for (FakeTask task : tasks) { for (Worker worker : task.workers) { assertTrue(worker instanceof FakeWorker); for (String domain : ((FakeWorker) worker).forbiddenDomains) { assertFalse(task.getId().contains(domain)); } } } // some of them should have completed, while some terminated. for (FakeTask task : tasks) { if (task.maxFailCount == 4) { assertEquals(State.TERMINATED, task.getState()); } else { assertEquals(State.SUCCEEDED, task.getState()); } } // stats: // for each fail count x wait time: average of: total time, # of retries, # of different workers double ttime = 0; for (Integer failCount : taskMap.keySet()) { Map<Integer, List<FakeTask>> these = taskMap.get(failCount); for (Integer waitTime : these.keySet()) { List<FakeTask> those = these.get(waitTime); int n = those.size(); double time = 0; double retries = 0; double uniqw = 0; for (FakeTask task : those) { time += task.endTime - task.beginTime; retries += task.getFailCount() + 1; uniqw = (new HashSet<Worker>(task.workers)).size(); } ttime += time; System.out.format("FC=%d, WT=%d: Tmin = %d, Tmean = %.0f, Rmean = %.0f, Umean = %.0f\n", failCount, waitTime, failCount * waitTime, time / n, retries / n, uniqw / n); } } System.out.format("Total mean latency: %.2f\n", ttime / tasks.size()); System.out.format("Total throughput per sec: %.2f\n", tasks.size() / (allTotalTime / 1000.0)); } }