package uk.ac.shef.dcs.jate; import java.util.ArrayList; import java.util.List; import java.util.concurrent.RecursiveTask; /** * Created by zqz on 15/09/2015. */ public abstract class JATERecursiveTaskWorker<S, T> extends RecursiveTask<T>{ private static final long serialVersionUID = -5145284438127806541L; protected List<S> tasks; protected int maxTasksPerThread; public JATERecursiveTaskWorker(List<S> tasks, int maxTasksPerWorker){ this.tasks = tasks; this.maxTasksPerThread=maxTasksPerWorker; } protected abstract JATERecursiveTaskWorker<S, T> createInstance(List<S> splitTasks); protected abstract T mergeResult(List<JATERecursiveTaskWorker<S, T>> workers); protected abstract T computeSingleWorker(List<S> tasks); @Override protected T compute() { if (this.tasks.size() > maxTasksPerThread) { List<JATERecursiveTaskWorker<S, T>> subWorkers = new ArrayList<>(); subWorkers.addAll(createSubWorkers()); for (JATERecursiveTaskWorker<S, T> subWorker : subWorkers) subWorker.fork(); return mergeResult(subWorkers); } else{ return computeSingleWorker(tasks); } } protected List<JATERecursiveTaskWorker<S, T>> createSubWorkers() { List<JATERecursiveTaskWorker<S, T>> subWorkers = new ArrayList<>(); int total = tasks.size() / 2; List<S> splitTask1 = new ArrayList<>(); for (int i = 0; i < total; i++) splitTask1.add(tasks.get(i)); JATERecursiveTaskWorker<S, T> subWorker1 = createInstance(splitTask1); List<S> splitTask2 = new ArrayList<>(); for (int i = total; i < tasks.size(); i++) splitTask2.add(tasks.get(i)); JATERecursiveTaskWorker<S, T> subWorker2 = createInstance(splitTask2); subWorkers.add(subWorker1); subWorkers.add(subWorker2); return subWorkers; } }