package com.subgraph.orchid.circuits.path;

import java.util.ArrayList;
import java.util.List;

import com.subgraph.orchid.Router;
import com.subgraph.orchid.crypto.TorRandom;

/**
 * A list of routers, each paired with a bandwidth weight, from which a router can be
 * chosen at random with probability proportional to its weight.
 *
 * <p>Before selection, the floating-point weights are converted to integer "scaled"
 * weights whose sum is approximately {@code MAX_SCALE}. A single {@code long} random
 * draw then picks the router. The scaled weights are cached and recomputed lazily
 * after the router set or the weights change.
 *
 * <p>NOTE(review): no synchronization is performed; instances appear to be intended
 * for use by a single thread — confirm with callers.
 */
public class BandwidthWeightedRouters {

    /** A router together with its raw and scaled selection weight. */
    private static class WeightedRouter {
        private final Router router;
        // True while no bandwidth value is known; cleared by fixUnknownValues().
        private boolean isUnknown;
        // Selection weight as a floating-point bandwidth value.
        private double weightedBandwidth;
        // weightedBandwidth multiplied by the current scale factor, rounded.
        private long scaledBandwidth;

        WeightedRouter(Router router, double bw) {
            this.router = router;
            this.weightedBandwidth = bw;
        }

        void scaleBandwidth(double scaleFactor) {
            scaledBandwidth = Math.round(weightedBandwidth * scaleFactor);
        }
    }

    // Target sum of all scaled weights; presumably a quarter of Long.MAX_VALUE so that
    // summing the rounded per-router values cannot overflow a long.
    private static final long MAX_SCALE = Long.MAX_VALUE / 4;
    // Totals below this are treated as "no bandwidth information".
    private static final double EPSILON = 0.1;

    private final List<WeightedRouter> weightedRouters = new ArrayList<WeightedRouter>();
    private final TorRandom random = new TorRandom();

    private double totalExitBw;
    private double totalNonExitBw;
    private double totalGuardBw;

    // True when every scaledBandwidth reflects the current weightedBandwidth values.
    private boolean isScaled;
    private int unknownCount;

    /**
     * Adds a router with a known bandwidth weight and includes it in the running totals.
     *
     * @param router the router to add
     * @param weightedBandwidth the router's selection weight
     */
    void addRouter(Router router, double weightedBandwidth) {
        weightedRouters.add(new WeightedRouter(router, weightedBandwidth));
        adjustTotals(router, weightedBandwidth);
        isScaled = false;
    }

    /** Returns true if the total known bandwidth is below {@code EPSILON}. */
    boolean isTotalBandwidthZero() {
        return getTotalBandwidth() < EPSILON;
    }

    /** Returns the summed bandwidth of all exit and non-exit routers with a known value. */
    double getTotalBandwidth() {
        return totalExitBw + totalNonExitBw;
    }

    /** Returns the summed bandwidth of routers for which {@code isPossibleGuard()} is true. */
    double getTotalGuardBandwidth() {
        return totalGuardBw;
    }

    /** Returns the summed bandwidth of routers for which {@code isExit()} is true. */
    double getTotalExitBandwidth() {
        return totalExitBw;
    }

    /** Adds {@code bw} to the exit or non-exit total, and to the guard total if applicable. */
    private void adjustTotals(Router router, double bw) {
        if (router.isExit()) {
            totalExitBw += bw;
        } else {
            totalNonExitBw += bw;
        }
        if (router.isPossibleGuard()) {
            totalGuardBw += bw;
        }
    }

    /**
     * Adds a router whose bandwidth is not known. It carries zero weight until
     * {@link #fixUnknownValues()} assigns it an estimated value.
     *
     * @param router the router to add
     */
    void addRouterUnknown(Router router) {
        final WeightedRouter wr = new WeightedRouter(router, 0);
        wr.isUnknown = true;
        weightedRouters.add(wr);
        unknownCount += 1;
        // Invalidate the cached scaled weights, like every other mutator of the list.
        isScaled = false;
    }

    /** Returns the number of routers added, known and unknown. */
    int getRouterCount() {
        return weightedRouters.size();
    }

    /** Returns the number of routers still awaiting an estimated bandwidth. */
    int getUnknownCount() {
        return unknownCount;
    }

    /**
     * Assigns an estimated bandwidth to every router added with
     * {@link #addRouterUnknown(Router)}.
     *
     * <p>If the known total bandwidth is effectively zero, fast routers get 40000 and all
     * others get 20000. Otherwise every unknown router gets the mean bandwidth of the
     * known routers, truncated to a {@code long}.
     *
     * <p>NOTE(review): the mean is taken from the running totals, which
     * {@link #adjustWeights(double, double)} does not update. Confirm that callers fix
     * unknown values before adjusting weights.
     */
    void fixUnknownValues() {
        if (unknownCount == 0) {
            return;
        }
        if (isTotalBandwidthZero()) {
            fixUnknownValues(40000, 20000);
        } else {
            // A non-zero total implies at least one router was added via addRouter().
            final int knownCount = weightedRouters.size() - unknownCount;
            final long average = (long) (getTotalBandwidth() / knownCount);
            fixUnknownValues(average, average);
        }
    }

    /** Gives each unknown router {@code fastBw} if it is fast, else {@code slowBw}. */
    private void fixUnknownValues(long fastBw, long slowBw) {
        for (WeightedRouter wr : weightedRouters) {
            if (wr.isUnknown) {
                final long bw = wr.router.isFast() ? fastBw : slowBw;
                wr.weightedBandwidth = bw;
                wr.isUnknown = false;
                adjustTotals(wr.router, bw);
            }
        }
        unknownCount = 0;
        isScaled = false;
    }

    /**
     * Chooses a router at random, with probability proportional to its scaled weight.
     *
     * @return the chosen router. If every scaled weight is zero, a router is chosen
     *     uniformly at random. Returns {@code null} if no routers have been added.
     */
    Router chooseRandomRouterByWeight() {
        final long total = getScaledTotal();
        if (total == 0) {
            if (weightedRouters.isEmpty()) {
                return null;
            }
            final int idx = random.nextInt(weightedRouters.size());
            return weightedRouters.get(idx).router;
        }
        return chooseFirstElementAboveRandom(random.nextLong(total));
    }

    /**
     * Multiplies each router's weight by the role-specific factors and rescales.
     * Routers that are both exit and possible guard get {@code exitWeight * guardWeight}.
     * Routers with neither role are unchanged.
     *
     * <p>The exit/non-exit/guard totals are NOT updated. They continue to report the
     * values accumulated before this call.
     *
     * @param exitWeight factor applied to exit routers
     * @param guardWeight factor applied to possible guard routers
     */
    void adjustWeights(double exitWeight, double guardWeight) {
        for (WeightedRouter wr : weightedRouters) {
            final Router r = wr.router;
            if (r.isExit() && r.isPossibleGuard()) {
                wr.weightedBandwidth *= (exitWeight * guardWeight);
            } else if (r.isPossibleGuard()) {
                wr.weightedBandwidth *= guardWeight;
            } else if (r.isExit()) {
                wr.weightedBandwidth *= exitWeight;
            }
        }
        scaleRouterWeights();
    }

    /**
     * Returns the first router at which the running sum of scaled weights exceeds
     * {@code randomValue}, or the last router if no such point exists.
     */
    private Router chooseFirstElementAboveRandom(long randomValue) {
        long sum = 0;
        Router chosen = null;
        for (WeightedRouter wr : weightedRouters) {
            sum += wr.scaledBandwidth;
            if (sum > randomValue) {
                chosen = wr.router;
                /* Don't return early to avoid leaking timing information about choice */
                randomValue = Long.MAX_VALUE;
            }
        }
        if (chosen == null) {
            return weightedRouters.get(weightedRouters.size() - 1).router;
        }
        return chosen;
    }

    /** Returns the sum of all floating-point weights. */
    private double getWeightedTotal() {
        double total = 0.0;
        for (WeightedRouter wr : weightedRouters) {
            total += wr.weightedBandwidth;
        }
        return total;
    }

    /**
     * Recomputes every router's integer weight so that the scaled weights sum to
     * approximately {@code MAX_SCALE}.
     */
    private void scaleRouterWeights() {
        final double total = getWeightedTotal();
        // A zero, negative or NaN total has no meaningful scale factor. Dividing by zero
        // would give Infinity, and Math.round(+/-Infinity) yields Long.MAX_VALUE or
        // Long.MIN_VALUE, which overflow the running sums. Zeroing every scaled weight
        // instead makes chooseRandomRouterByWeight() fall back to a uniform choice.
        final double scaleFactor = (total > 0) ? (MAX_SCALE / total) : 0.0;
        for (WeightedRouter wr : weightedRouters) {
            wr.scaleBandwidth(scaleFactor);
        }
        isScaled = true;
    }

    /** Returns the sum of the scaled weights, rescaling first if the cache is stale. */
    private long getScaledTotal() {
        if (!isScaled) {
            scaleRouterWeights();
        }
        long total = 0;
        for (WeightedRouter wr : weightedRouters) {
            total += wr.scaledBandwidth;
        }
        return total;
    }
}