/* * This file is part of mlDHT. * * mlDHT is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 2 of the License, or * (at your option) any later version. * * mlDHT is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with mlDHT. If not, see <http://www.gnu.org/licenses/>. */ package lbms.plugins.mldht.kad.tasks; import static lbms.plugins.mldht.kad.tasks.CountedStat.FAILED; import static lbms.plugins.mldht.kad.tasks.CountedStat.RECEIVED; import static lbms.plugins.mldht.kad.tasks.CountedStat.SENT; import static lbms.plugins.mldht.kad.tasks.CountedStat.SENT_SINCE_RECEIVE; import static lbms.plugins.mldht.kad.tasks.CountedStat.STALLED; import the8472.utils.concurrent.SerializedTaskExecutor; import lbms.plugins.mldht.kad.DHT; import lbms.plugins.mldht.kad.DHT.LogLevel; import lbms.plugins.mldht.kad.DHTConstants; import lbms.plugins.mldht.kad.KBucketEntry; import lbms.plugins.mldht.kad.Key; import lbms.plugins.mldht.kad.Node; import lbms.plugins.mldht.kad.RPCCall; import lbms.plugins.mldht.kad.RPCCallListener; import lbms.plugins.mldht.kad.RPCServer; import lbms.plugins.mldht.kad.RPCState; import lbms.plugins.mldht.kad.messages.MessageBase; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.EnumSet; import java.util.List; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; /** * Performs a task on K nodes provided by a KClosestNodesSearch. * This is a base class for all tasks. * * @author Damokles */ public abstract class Task implements Comparable<Task> { protected Set<RPCCall> inFlight; protected Node node; protected String info; protected RPCServer rpc; public enum TaskState { INITIAL, QUEUED, RUNNING, FINISHED, KILLED; public boolean isTerminal() { return this == FINISHED || this == KILLED; } public boolean preStart() { return this == INITIAL || this == QUEUED; } } AtomicReference<TaskState> state = new AtomicReference<>(TaskState.INITIAL); long startTime; long firstResultTime; long finishTime; private int taskID; private List<TaskListener> listeners; private boolean lowPriority; protected final AtomicReference<TaskStats> counts = new AtomicReference<>(new TaskStats()); /** * Create a task. * @param rpc The RPC server to do RPC calls * @param node The node */ Task (RPCServer rpc, Node node) { if(rpc == null) throw new IllegalArgumentException("RPC must not be null"); this.rpc = rpc; this.node = node; inFlight = ConcurrentHashMap.newKeySet(); } boolean setState(TaskState expected, TaskState newState) { return setState(EnumSet.of(expected), newState); } boolean setState(Set<TaskState> expected, TaskState newState) { TaskState current; do { current = state.get(); if(!expected.contains(current)) return false; } while(!state.weakCompareAndSet(current, newState)); return true; } public RPCServer getRPC() { return rpc; } public int compareTo(Task o) { return taskID - o.taskID; } @Override public int hashCode() { return taskID; } final RPCCallListener preProcessingListener = new RPCCallListener() { public void stateTransition(RPCCall c, RPCState previous, RPCState current) { counts.updateAndGet(cnt -> { EnumSet<CountedStat> inc = EnumSet.noneOf(CountedStat.class); EnumSet<CountedStat> dec = EnumSet.noneOf(CountedStat.class); EnumSet<CountedStat> zero = EnumSet.noneOf(CountedStat.class); if(previous == RPCState.STALLED) dec.add(STALLED); if(current == RPCState.STALLED) inc.add(STALLED); if(current == RPCState.RESPONDED) { inc.add(RECEIVED); zero.add(SENT_SINCE_RECEIVE); } if(current == RPCState.TIMEOUT || current == RPCState.ERROR) inc.add(FAILED); return cnt.update(inc, dec, zero); }); switch(current) { case RESPONDED: inFlight.remove(c); if (!isFinished()) callFinished(c, c.getResponse()); break; case ERROR: inFlight.remove(c); break; case TIMEOUT: inFlight.remove(c); if (!isFinished()) callTimeout(c); break; default: break; } } }; final RPCCallListener postProcessingListener = new RPCCallListener() { public void stateTransition(RPCCall c, RPCState previous, RPCState current) { switch(current) { case RESPONDED: case TIMEOUT: case STALLED: case ERROR: serializedUpdate.run(); break; default: break; } } }; /** * Start the task, to be used when a task is queued. */ public void start () { if (setState(EnumSet.of(TaskState.INITIAL, TaskState.QUEUED), TaskState.RUNNING)) { DHT.logDebug("Starting Task: " + toString()); startTime = System.currentTimeMillis(); try { serializedUpdate.run(); } catch (Exception e) { DHT.log(e, LogLevel.Error); } } } private void runStuff() { if(isDone()) finish(); if (canDoRequest() && !isFinished()) { update(); // check again in case todo-queue has been drained by update() if(isDone()) finish(); } } private final Runnable serializedUpdate = SerializedTaskExecutor.onceMore(this::runStuff); /** * Will continue the task, this will be called every time we have * rpc slots available for this task. Should be implemented by derived classes. */ abstract void update (); /** * A call is finished and a response was received. * @param c The call * @param rsp The response */ abstract void callFinished (RPCCall c, MessageBase rsp); /** * A call timedout * @param c The call */ abstract void callTimeout (RPCCall c); /** * Do a call to the rpc server, increments the outstanding_reqs variable. * @param req THe request to send * @return true if call was made, false if not */ boolean rpcCall (MessageBase req, Key expectedID, Consumer<RPCCall> modifyCallBeforeSubmit) { if (!canDoRequest()) { // if we reject a request we need something to wakeup the task later rpc.onDeclog(serializedUpdate); return false; } RPCCall call = new RPCCall(req).setExpectedID(expectedID); // bump counters early to ensure task stays alive counts.updateAndGet(cnt -> cnt.update(EnumSet.of(SENT, SENT_SINCE_RECEIVE), EnumSet.noneOf(CountedStat.class), EnumSet.noneOf(CountedStat.class))); call.addListener(preProcessingListener); if(modifyCallBeforeSubmit != null) modifyCallBeforeSubmit.accept(call); call.addListener(postProcessingListener); inFlight.add(call); // asyncify since we're under a lock here rpc.getDHT().getScheduler().execute(() -> rpc.doCall(call)) ; return true; } public void setLowPriority(boolean lowPriority) { this.lowPriority = lowPriority; } public int requestConcurrency() { return lowPriority ? DHTConstants.MAX_CONCURRENT_REQUESTS_LOWPRIO : DHTConstants.MAX_CONCURRENT_REQUESTS; } static interface CandidateSupplier { boolean has(); KBucketEntry current(); void remove(KBucketEntry e); } protected CandidateSupplier candidates; enum RequestPermit { NONE_ALLOWED, FREE_SLOT, FREE_STALL_SLOT } RequestPermit checkFreeSlot() { TaskStats stats = counts.get(); int activeOnly = stats.activeOnly(); int activeAndStalled = stats.unanswered(); int concurrency = requestConcurrency(); // based on measurements the expected loss rate is ~50% on average (see RPCServer) // if we exceed that (+margin) don't let stalls trigger additional requests, wait for new responses/full timeouts if(activeAndStalled >= concurrency && stats.get(RECEIVED) * 3 < stats.get(SENT)) return RequestPermit.NONE_ALLOWED; if(activeAndStalled < concurrency) return RequestPermit.FREE_SLOT; if(activeOnly < concurrency /*&& stats.get(SENT_SINCE_RECEIVE) < concurrency*/) return RequestPermit.FREE_STALL_SLOT; return RequestPermit.NONE_ALLOWED; } /// See if we can do a request boolean canDoRequest () { return checkFreeSlot() != RequestPermit.NONE_ALLOWED; } boolean hasUnfinishedRequests() { return counts.get().unanswered() > 0; } /// Is the task finished public boolean isFinished () { return state.get().isTerminal(); } /// Set the task ID void setTaskID (int tid) { taskID = tid; } /// Get the task ID public int getTaskID () { return taskID; } /** * @return the Count of Failed Requests */ public int getFailedReqs () { return counts.get().get(FAILED); } /** * @return the Count of Received Responses */ public int getRecvResponses () { return counts.get().get(RECEIVED); } /** * @return the Count of Sent Requests */ public int getSentReqs () { return counts.get().get(SENT); } abstract public int getTodoCount (); /** * @return the info */ public String getInfo () { return info; } public long getStartTime() { return startTime; } public long getFinishedTime() { return finishTime; } public long getFirstResultTime() { return firstResultTime; } /** * @param info the info to set */ public void setInfo (String info) { this.info = info; } /** * @return number of requests that this task is actively waiting for */ public int getNumOutstandingRequestsExcludingStalled () { return counts.get().activeOnly(); } /** * @return number of requests that still haven't reached their final state but might have stalled */ public int getNumOutstandingRequests() { return counts.get().unanswered(); } public boolean isQueued () { return state.get() == TaskState.QUEUED; } /// Kills the task public void kill() { if(setState(EnumSet.complementOf(EnumSet.of(TaskState.FINISHED, TaskState.KILLED)), TaskState.KILLED)) notifyCompletionListeners(); } private void finish() { if(setState(EnumSet.complementOf(EnumSet.of(TaskState.FINISHED, TaskState.KILLED)), TaskState.FINISHED)) notifyCompletionListeners(); } private void notifyCompletionListeners() { finishTime = System.currentTimeMillis(); DHT.logDebug("Task "+getTaskID()+" finished: " + toString()); if (listeners != null) { for (TaskListener tl : listeners) { tl.finished(this); } } } protected abstract boolean isDone(); public void addListener (TaskListener listener) { if (listeners == null) { listeners = new ArrayList<>(1); } // listener is added after the task already terminated, thus it won't get the event, trigger it manually if(state.get().isTerminal()) listener.finished(this); listeners.add(listener); } public void removeListener (TaskListener listener) { if (listeners != null) { listeners.remove(listener); } } public Duration age() { return Duration.between(Instant.ofEpochMilli(startTime), Instant.now()); } @Override public String toString() { StringBuilder b = new StringBuilder(100); TaskStats stats = counts.get(); b.append(this.getClass().getSimpleName()); b.append(' ').append(getTaskID()); if(this instanceof TargetedTask) b.append(" target:").append(((TargetedTask)this).getTargetKey()); b.append(" todo:").append(getTodoCount()); if(!state.get().preStart()) { //b.append(" sent:").append(stats.get(SENT)); //b.append(" recv:").append(stats.get(RECEIVED)); b.append(" ").append(stats); } b.append(" srv: ").append(rpc.getDerivedID()); b.append(' ').append(state.get().toString()); if(startTime != 0) { if(finishTime == 0) b.append(" age:").append(age()); else if(finishTime > 0) b.append(" time to finish:").append(Duration.between(Instant.ofEpochMilli(startTime), Instant.ofEpochMilli(finishTime))); } b.append(" name:").append(info); return b.toString(); } }