package lbms.plugins.mldht.kad.tasks;
import static java.lang.Math.max;
import lbms.plugins.mldht.kad.IDMismatchDetector;
import lbms.plugins.mldht.kad.KBucketEntry;
import lbms.plugins.mldht.kad.Key;
import lbms.plugins.mldht.kad.NonReachableCache;
import lbms.plugins.mldht.kad.RPCCall;
import lbms.plugins.mldht.kad.RPCState;
import lbms.plugins.mldht.kad.SpamThrottle;
import java.net.InetAddress;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/*
* Issues:
*
* - spurious packet loss
* - remotes might fake IDs. possibly with collusion.
* - invalid results
* - duplicate IPs
* - duplicate IDs
* - wrong IDs -> might trip up fake ID detection!
* - IPs not belonging to DHT nodes -> DoS abuse
*
* Solution:
*
* - generally avoid querying an IP more than once
* - dedup result lists from each node
* - ignore responses with unexpected IDs. normally this could be abused to silence others, but...
* - allow duplicate requests if many *separate* sources suggest precisely the same <id, ip, port> tuple
*
* -> we can recover from all the above-listed issues because the terminal set of nodes should have some partial agreement about their neighbors
*
*
*
*/
public class IterativeLookupCandidates {
// Key this lookup is heading towards; comp() orders candidates by distance to it.
Key target;
// Every candidate entry seen so far, mapped to its node in the lookup graph.
Map<KBucketEntry, LookupGraphNode> candidates = new ConcurrentHashMap<>();
// maybe split out call tracking
// Each registered call mapped to the candidate entry it was issued for.
Map<RPCCall, KBucketEntry> calls;
// The same calls grouped by destination IP. The per-IP sets are plain HashSets; readers and writers synchronize on the set itself.
Map<InetAddress, Set<RPCCall>> callsByIp;
// Holds both IP addresses and node IDs of entries whose response has been accepted (hence the Object element type).
Collection<Object> accepted;
// When false, candidates that already have calls are rejected by lookupFilter and next2() skips its retransmit pass.
boolean allowRetransmits = true;
// Asked once per new candidate whether an ID inconsistency is expected; a positive answer marks the node as tainted.
IDMismatchDetector detector;
// Optional (may be null): source of previous failure counts used to probabilistically mark new candidates unreachable.
NonReachableCache nonReachableCache;
// Optional (may be null): tested against the IP of each new candidate to mark it throttled.
SpamThrottle throttle;
/**
 * Vertex of the lookup graph. Wraps one candidate entry and records which nodes suggested it,
 * which nodes it returned itself, the calls issued to it and a handful of status flags.
 * Identity (equals/hashCode) is delegated entirely to the wrapped entry.
 */
class LookupGraphNode {
    final KBucketEntry e;
    // nodes that listed this candidate in their results
    Set<LookupGraphNode> sources = new CopyOnWriteArraySet<>();
    // nodes this candidate listed in its own results
    Set<LookupGraphNode> returnedNodes = ConcurrentHashMap.newKeySet();
    // calls that were issued to this candidate
    List<RPCCall> calls = new CopyOnWriteArrayList<>();
    boolean tainted;
    boolean acceptedResponse;
    boolean root;
    int previouslyFailedCount;
    boolean unreachable;
    boolean throttled;

    public LookupGraphNode(KBucketEntry kbe) {
        this.e = kbe;
    }

    /** Records a call that was issued to this candidate. */
    void addCall(RPCCall c) {
        this.calls.add(c);
    }

    /** Records a node that suggested this candidate. */
    void addSource(LookupGraphNode toAdd) {
        this.sources.add(toAdd);
    }

    /** @return true if at least one call was issued but no response has been accepted */
    boolean callsNotSuccessful() {
        if (calls.isEmpty()) {
            return false;
        }
        return !wasAccepted();
    }

    /**
     * Weighted count of returned nodes whose calls have not succeeded. Each such node
     * contributes 1 divided by its number of sources (at least 1), so the blame is shared
     * between everyone who suggested it. The sum is rounded up.
     */
    int nonSuccessfulDescendantCalls() {
        double sharedBlame = returnedNodes.stream()
            .filter(LookupGraphNode::callsNotSuccessful)
            .mapToDouble(child -> 1.0 / Math.max(child.sources.size(), 1))
            .sum();
        return (int) Math.ceil(sharedBlame);
    }

    /** Records a node that this candidate returned in its results. */
    void addChild(LookupGraphNode toAdd) {
        this.returnedNodes.add(toAdd);
    }

    /** @return the wrapped entry */
    KBucketEntry toKbe() {
        return this.e;
    }

    /** Marks this node as having delivered an accepted response. */
    void accept() {
        this.acceptedResponse = true;
    }

    boolean wasAccepted() {
        return this.acceptedResponse;
    }

    @Override
    public boolean equals(Object other) {
        if (!(other instanceof LookupGraphNode)) {
            return false;
        }
        LookupGraphNode that = (LookupGraphNode) other;
        return e.equals(that.e);
    }

    @Override
    public int hashCode() {
        return e.hashCode();
    }

    @Override
    public String toString() {
        int failedDescendants = nonSuccessfulDescendantCalls();
        return "LookupNode desc:" + failedDescendants;
    }
}
/**
 * Creates an empty candidate tracker.
 *
 * @param target   key the lookup is converging on
 * @param detector consulted for every newly added candidate to decide whether it is tainted
 */
public IterativeLookupCandidates(Key target, IDMismatchDetector detector) {
    this.target = target;
    this.detector = detector;
    this.candidates = new ConcurrentHashMap<>();
    this.calls = new ConcurrentHashMap<>();
    this.callsByIp = new ConcurrentHashMap<>();
    this.accepted = new HashSet<>();
}
/**
 * Installs the cache that addCandidates() queries for previous failure counts.
 * While unset (null), new candidates are never marked unreachable.
 */
public void setNonReachableCache(NonReachableCache nonReachableCache) {
this.nonReachableCache = nonReachableCache;
}
/**
 * Installs the throttle that addCandidates() tests each new candidate's IP against.
 * While unset (null), new candidates are never marked throttled.
 */
public void setSpamThrottle(SpamThrottle throttle) {
this.throttle = throttle;
}
/**
 * Switches repeat queries to already-contacted candidates on or off.
 *
 * @param toggle true to permit retransmits, false to forbid them
 */
void allowRetransmits(boolean toggle) {
    this.allowRetransmits = toggle;
}
/**
 * Registers a call issued to a candidate: in the call map, in the per-IP index and on the
 * candidate's graph node.
 *
 * @param c   the call that was issued
 * @param kbe the candidate it was sent to; must already be a known candidate
 */
void addCall(RPCCall c, KBucketEntry kbe) {
    calls.put(c, kbe);

    InetAddress destinationIp = c.getRequest().getDestination().getAddress();
    Set<RPCCall> sameIpCalls = callsByIp.computeIfAbsent(destinationIp, unused -> new HashSet<>());
    // the per-IP set is a plain HashSet, so it is guarded by its own monitor
    synchronized (sameIpCalls) {
        sameIpCalls.add(c);
    }

    LookupGraphNode node = candidates.get(kbe);
    node.addCall(c);
}
/**
 * Tries to accept the response of a call.
 *
 * A response is dropped when the responder's ID does not match what the call expected, or
 * when a response from the same IP or the same ID has been accepted before. A node dropped
 * for an ID mismatch may still get a second chance if several nodes suggest it and it has
 * not answered successfully yet.
 *
 * @param c the answered call; must have been registered through addCall
 * @return the candidate entry if its response was accepted, otherwise null
 */
KBucketEntry acceptResponse(RPCCall c) {
    synchronized (this) {
        if (!c.matchesExpectedID()) {
            return null;
        }

        KBucketEntry kbe = calls.get(c);
        // adopt the version announced in the response if none is known for the entry yet
        if (!kbe.getVersion().isPresent()) {
            c.getResponse().getVersion().ifPresent(kbe::setVersion);
        }

        LookupGraphNode node = candidates.get(kbe);

        boolean seenBefore = accepted.contains(kbe.getAddress().getAddress()) || accepted.contains(kbe.getID());
        if (seenBefore) {
            return null;
        }

        accepted.add(kbe.getAddress().getAddress());
        accepted.add(kbe.getID());
        node.accept();
        return kbe;
    }
}
/**
 * Merges a batch of entries into the candidate graph.
 *
 * Within one batch an entry is skipped if its ID or its IP has already occurred earlier in
 * the same batch. Unknown entries get a new graph node; known ones are reused. When a source
 * is given, an edge source -> candidate is recorded in both directions.
 *
 * @param source  the node that returned these entries, or null for initial (root) entries
 * @param entries the entries to merge
 */
void addCandidates(KBucketEntry source, Collection<KBucketEntry> entries) {
    final boolean isRoot = source == null;
    final LookupGraphNode sourceNode = isRoot ? null : candidates.get(source);
    final Set<Object> seenInBatch = new HashSet<>();

    for (KBucketEntry entry : entries) {
        if (!seenInBatch.add(entry.getID())) {
            continue;
        }
        if (!seenInBatch.add(entry.getAddress().getAddress())) {
            continue;
        }

        LookupGraphNode graphNode = candidates.compute(entry, (kbe, existing) -> {
            LookupGraphNode result = existing != null ? existing : createGraphNode(kbe, isRoot);
            if (sourceNode != null) {
                result.addSource(sourceNode);
            }
            return result;
        });

        if (sourceNode != null) {
            sourceNode.addChild(graphNode);
        }
    }
}

/** Builds the graph node for a previously unknown candidate and initializes its status flags. */
private LookupGraphNode createGraphNode(KBucketEntry kbe, boolean isRoot) {
    LookupGraphNode created = new LookupGraphNode(kbe);
    created.root = isRoot;
    created.tainted = detector.isIdInconsistencyExpected(kbe.getAddress(), kbe.getID());

    if (nonReachableCache != null) {
        int failures = nonReachableCache.getFailures(kbe.getAddress());
        created.previouslyFailedCount = failures;
        // The roll is uniform over 0..20 and the threshold is capped at 19, so even the worst
        // offenders still get through on a roll of 19 or 20 (2 out of 21), which keeps
        // their counters going up.
        int roll = ThreadLocalRandom.current().nextInt(21);
        created.unreachable = Math.min(failures - 2, 19) > roll;
    }

    if (throttle != null) {
        created.throttled = throttle.test(kbe.getAddress().getAddress());
    }
    return created;
}
/**
 * Collects the entries of all nodes that suggested the given candidate.
 *
 * @param e a known candidate
 * @return a new set with the entries of its sources
 */
Set<KBucketEntry> getSources(KBucketEntry e) {
    LookupGraphNode node = candidates.get(e);
    Set<KBucketEntry> sourceEntries = new HashSet<>();
    for (LookupGraphNode sourceNode : node.sources) {
        sourceEntries.add(sourceNode.toKbe());
    }
    return sourceEntries;
}
/**
 * Ordering used to pick the next candidate: closest to the target first; on equal distance
 * the node with more sources wins.
 */
Comparator<LookupGraphNode> comp() {
    Comparator<KBucketEntry> byDistance = new KBucketEntry.DistanceOrder(target);
    Comparator<LookupGraphNode> moreSourcesFirst =
        (left, right) -> Integer.compare(right.sources.size(), left.sources.size());
    return Comparator.<LookupGraphNode, KBucketEntry>comparing(LookupGraphNode::toKbe, byDistance)
        .thenComparing(moreSourcesFirst);
}
/**
 * @return the best-ranked candidate (per comp()) that passes lookupFilter, if any
 */
Optional<KBucketEntry> next() {
    synchronized (this) {
        Optional<LookupGraphNode> best = allCand()
            .sorted(comp())
            .filter(lookupFilter)
            .findFirst();
        return best.map(node -> node.e);
    }
}
/**
 * Picks the next candidate to query, preferring candidates that have not been called yet.
 *
 * The first pass only considers candidates without calls. If it yields nothing and
 * retransmits are allowed, a second pass also considers already-called candidates.
 * In each pass postFilter is applied to the single best candidate only; if that one is
 * rejected the pass yields nothing rather than moving on to the runner-up.
 *
 * @param postFilter final check applied to the chosen entry
 * @return the chosen entry, or empty
 */
Optional<KBucketEntry> next2(Predicate<KBucketEntry> postFilter) {
    synchronized (this) {
        // sorting first and stopping at the first match keeps the number of filter invocations
        // low; the filter is more expensive than the sort
        Optional<KBucketEntry> untried = allCand()
            .sorted(comp())
            .filter(retransmitFilter(false))
            .filter(lookupFilter)
            .findFirst()
            .map(LookupGraphNode::toKbe)
            .filter(postFilter);

        if (untried.isPresent() || !allowRetransmits) {
            return untried;
        }

        return allCand()
            .sorted(comp())
            .filter(lookupFilter)
            .filter(retransmitFilter(true))
            .findFirst()
            .map(LookupGraphNode::toKbe)
            .filter(postFilter);
    }
}
/**
 * Builds a predicate on whether a node may be queried with respect to retransmits.
 *
 * @param retransmits whether nodes that already have calls are acceptable
 * @return predicate that rejects nodes with existing calls unless retransmits is true
 */
static Predicate<LookupGraphNode> retransmitFilter(boolean retransmits) {
    return node -> node.calls.isEmpty() || retransmits;
}
/**
 * Decides whether a candidate may be queried right now.
 *
 * A candidate is rejected when any of the following holds:
 * - it is flagged tainted, unreachable or throttled
 * - retransmits are disabled and it already has calls
 * - one of its calls was answered from a mismatching socket address
 * - a response from its IP or its ID has already been accepted
 * - it has sources, but every one of them has 3 or more non-successful descendant calls
 * - another call to the same IP is unsent or in flight, ended in an error, or was answered with a different ID
 * - the number of finished calls to its IP exceeds floor(log2(number of sources))
 */
Predicate<LookupGraphNode> lookupFilter = node -> {
    KBucketEntry kbe = node.e;

    if(node.tainted || node.unreachable || node.throttled)
        return false;

    // check if we can do retransmits
    if(!allowRetransmits && !node.calls.isEmpty())
        return false;

    // skip retransmits if we previously got a response but from the wrong socket address
    if(node.calls.stream().anyMatch(RPCCall::hasSocketMismatch))
        return false;

    InetAddress addr = kbe.getAddress().getAddress();

    if(accepted.contains(addr) || accepted.contains(kbe.getID()))
        return false;

    // only do requests to nodes which have at least one source where the source has not given us lots of bogus candidates
    if(node.sources.size() > 0 && node.sources.stream().noneMatch(source -> source.nonSuccessfulDescendantCalls() < 3))
        return false;

    int dups = 0;

    // also check other calls based on matching IP instead of strictly matching ip+port+id
    Set<RPCCall> byIp = callsByIp.get(addr);
    if(byIp != null) {
        synchronized(byIp) {
            for(RPCCall c : byIp) {
                // Read the state once per call. Re-reading it for every comparison lets a
                // concurrent UNSENT -> SENT transition slip between the reads, so that an
                // in-flight call would be counted as a finished duplicate instead of blocking.
                RPCState state = c.state();

                // in flight, not stalled
                if(state == RPCState.SENT || state == RPCState.UNSENT)
                    return false;

                // already got a response from that addr that does not match what we would expect from this candidate anyway
                if(state == RPCState.RESPONDED && !c.getResponse().getID().equals(kbe.getID()))
                    return false;

                // we don't strictly check the presence of IDs in error messages, so we can't compare those here
                if(state == RPCState.ERROR)
                    return false;

                dups++;
            }
        }
    }

    // log2 scale: the more sources vouch for this candidate, the more repeat calls to its IP are tolerated
    int sources = max(1, node.sources.size() + (node.root ? 1 : 0));
    int scaledSources = 31 - Integer.numberOfLeadingZeros(sources);

    return scaledSources >= dups;
};
/** @return a stream over the graph nodes of all known candidates, in no particular order */
Stream<LookupGraphNode> allCand() {
    Collection<LookupGraphNode> nodes = candidates.values();
    return nodes.stream();
}
/** Returns the graph node tracked for the given entry, or null if the entry is not a known candidate. */
LookupGraphNode nodeForEntry(KBucketEntry e) {
return candidates.get(e);
}
/**
 * @param kbe the candidate to look up
 * @return how many registered calls were issued for an entry equal to kbe
 */
int numCalls(KBucketEntry kbe) {
    int count = 0;
    for (KBucketEntry callee : calls.values()) {
        if (callee.equals(kbe)) {
            count++;
        }
    }
    return count;
}
/**
 * Counts the registered calls that were answered by exactly this candidate: the call is in
 * state RESPONDED and both the response's ID and its origin address match the entry.
 *
 * @param kbe the candidate to look up
 * @return number of matching responses
 */
int numRsps(KBucketEntry kbe) {
    int matching = 0;
    for (RPCCall call : calls.keySet()) {
        if (call.state() != RPCState.RESPONDED) {
            continue;
        }
        if (!call.getResponse().getID().equals(kbe.getID())) {
            continue;
        }
        if (call.getResponse().getOrigin().equals(kbe.getAddress())) {
            matching++;
        }
    }
    return matching;
}
}