/*
 *    This file is part of mlDHT.
 *
 *    mlDHT is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    mlDHT is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with mlDHT.  If not, see <http://www.gnu.org/licenses/>.
 */
package lbms.plugins.mldht.kad.tasks;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Deque;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import lbms.plugins.mldht.kad.DHT;
import lbms.plugins.mldht.kad.DHTConstants;
import lbms.plugins.mldht.kad.RPCServer;
import lbms.plugins.mldht.kad.tasks.Task.TaskState;

/**
 * Manages all DHT tasks: tasks are queued per RPC server and started once
 * that server has enough free RPC slots.
 *
 * @author Damokles
 */
public class TaskManager {

	private ConcurrentHashMap<RPCServer, ServerSet>	taskSets;
	private DHT				dht;
	private AtomicInteger	next_id = new AtomicInteger();

	/** Removes a finished task from its server's active set and starts queued tasks that now fit. */
	private TaskListener	finishListener = t -> {
		dht.getStats().taskFinished(t);
		setFor(t.getRPC()).ifPresent(s -> {
			synchronized (s.active) {
				s.active.remove(t);
			}
			s.dequeue();
		});
	};

	public TaskManager (DHT dht) {
		this.dht = dht;
		taskSets = new ConcurrentHashMap<>();
		next_id.set(1);
	}

	public void addTask(Task task) {
		addTask(task, false);
	}

	/** Per-server bookkeeping: a FIFO queue of pending tasks and the list of currently active ones. */
	class ServerSet {
		RPCServer server;
		Deque<Task> queued = new ArrayDeque<>();
		List<Task> active = new ArrayList<>();

		/** Moves tasks from the queue to the active set and starts them as long as the server has capacity. */
		void dequeue() {
			while (true) {
				Task t;
				synchronized (queued) {
					t = queued.peekFirst();
					if (t == null)
						break;
					if (!canStartTask(t.getRPC()))
						break;
					queued.removeFirst();
				}

				if (t.isFinished())
					continue;

				synchronized (active) {
					active.add(t);
				}
				dht.getScheduler().execute(t::start);
			}
		}

		boolean canStartTask(RPCServer srv) {
			// a task can be started if fewer than MAX_ACTIVE_TASKS are running on this server
			// and at least 16 RPC slots are still available
			int activeCalls = srv.getNumActiveRPCCalls();
			if (activeCalls + 16 >= DHTConstants.MAX_ACTIVE_CALLS)
				return false;
			int perServer = active.size();
			if (perServer < DHTConstants.MAX_ACTIVE_TASKS)
				return true;
			if (activeCalls >= (DHTConstants.MAX_ACTIVE_CALLS * 2) / 3)
				return false;
			// if all of the server's tasks have sent at least their initial volley and there is
			// still enough headroom, additional tasks may be admitted
			synchronized (active) {
				return active.stream().allMatch(t -> t.requestConcurrency() < t.getSentReqs());
			}
		}

		Collection<Task> snapshotActive() {
			synchronized (active) {
				return new ArrayList<>(active);
			}
		}

		Collection<Task> snapshotQueued() {
			synchronized (queued) {
				return new ArrayList<>(queued);
			}
		}
	}

	Optional<ServerSet> setFor(RPCServer srv) {
		if (srv.getState() != RPCServer.State.RUNNING)
			return Optional.empty();
		return Optional.ofNullable(taskSets.computeIfAbsent(srv, k -> {
			ServerSet ss = new ServerSet();
			ss.server = k;
			return ss;
		}));
	}

	public void dequeue(RPCServer k) {
		setFor(k).ifPresent(ServerSet::dequeue);
	}

	public void dequeue() {
		for (RPCServer srv : taskSets.keySet())
			setFor(srv).ifPresent(ServerSet::dequeue);
	}

	/**
	 * Add a task to manage.
	 *
	 * @param task the task to queue or, if already running, to track as active
	 * @param isPriority if true the task is placed at the head of the queue
	 */
	public void addTask (Task task, boolean isPriority) {
		int id = next_id.incrementAndGet();
		task.addListener(finishListener);
		task.setTaskID(id);

		Optional<ServerSet> s = setFor(task.getRPC());

		if (!s.isPresent()) {
			task.kill();
			return;
		}

		if (task.state.get() == TaskState.RUNNING) {
			synchronized (s.get().active) {
				s.get().active.add(task);
			}
			return;
		}

		if (!task.setState(TaskState.INITIAL, TaskState.QUEUED))
			return;

		synchronized (s.get().queued) {
			if (isPriority)
				s.get().queued.addFirst(task);
			else
				s.get().queued.addLast(task);
		}
	}

	public void removeServer(RPCServer srv) {
		ServerSet set = taskSets.get(srv);
		if (set == null)
			return;
		taskSets.remove(srv);
		synchronized (set.active) {
			set.active.forEach(Task::kill);
		}
		synchronized (set.queued) {
			set.queued.forEach(Task::kill);
		}
	}

	/// Get the number of running tasks
	public int getNumTasks () {
		return taskSets.values().stream().mapToInt(s -> s.active.size()).sum();
	}

	/// Get the number of queued tasks
	public int getNumQueuedTasks () {
		return taskSets.values().stream().mapToInt(s -> s.queued.size()).sum();
	}

	public Task[] getActiveTasks () {
		Task[] t = taskSets.values().stream().flatMap(s -> s.snapshotActive().stream()).toArray(Task[]::new);
		Arrays.sort(t);
		return t;
	}

	public Task[] getQueuedTasks () {
		return taskSets.values().stream().flatMap(s -> s.snapshotQueued().stream()).toArray(Task[]::new);
	}

	public boolean canStartTask (Task toCheck) {
		RPCServer srv = toCheck.getRPC();
		return canStartTask(srv);
	}

	public boolean canStartTask(RPCServer srv) {
		return setFor(srv).map(s -> s.canStartTask(srv)).orElse(false);
	}

	public int queuedCount(RPCServer srv) {
		Optional<ServerSet> set = setFor(srv);
		if (!set.isPresent())
			return 0;
		Collection<Task> q = set.get().queued;
		synchronized (q) {
			return q.size();
		}
	}

	@Override
	public String toString() {
		StringBuilder b = new StringBuilder();
		b.append("next id: ").append(next_id).append('\n');
		b.append("#### active: \n");

		for (Task t : getActiveTasks())
			b.append(t.toString()).append('\n');

		b.append("#### queued: \n");

		for (Task t : getQueuedTasks())
			b.append(t.toString()).append('\n');

		return b.toString();
	}

}
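
/*
 * Usage sketch (illustrative only, not part of the original source): how a caller
 * might hand a task to the manager. The `dht` and `lookupTask` variables are
 * assumed to be an already-initialized DHT and a Task bound to a running RPCServer;
 * their construction happens outside this file.
 *
 *   TaskManager manager = new TaskManager(dht);
 *   manager.addTask(lookupTask, false);      // queue with normal priority
 *   manager.dequeue(lookupTask.getRPC());    // start it if the server has free slots
 */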