package com.linkedin.d2.balancer.simple;
import com.linkedin.common.callback.Callback;
import com.linkedin.common.callback.FutureCallback;
import com.linkedin.common.util.None;
import com.linkedin.d2.balancer.ServiceUnavailableException;
import com.linkedin.d2.balancer.clients.RewriteClient;
import com.linkedin.d2.balancer.properties.ClusterProperties;
import com.linkedin.d2.balancer.properties.PropertyKeys;
import com.linkedin.d2.balancer.properties.ServiceProperties;
import com.linkedin.d2.balancer.properties.UriProperties;
import com.linkedin.d2.balancer.strategies.LoadBalancerStrategy;
import com.linkedin.d2.balancer.strategies.LoadBalancerStrategyFactory;
import com.linkedin.d2.balancer.strategies.degrader.DegraderLoadBalancerStrategyConfig;
import com.linkedin.d2.balancer.strategies.degrader.DegraderLoadBalancerStrategyFactoryV3;
import com.linkedin.d2.balancer.util.URIRequest;
import com.linkedin.d2.balancer.util.hashing.Ring;
import com.linkedin.d2.discovery.event.PropertyEventThread.PropertyEventShutdownCallback;
import com.linkedin.d2.discovery.event.SynchronousExecutorService;
import com.linkedin.d2.discovery.stores.mock.MockStore;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.message.rest.RestRequest;
import com.linkedin.r2.message.rest.RestRequestBuilder;
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.rest.RestResponseBuilder;
import com.linkedin.r2.message.stream.StreamRequest;
import com.linkedin.r2.message.stream.StreamResponse;
import com.linkedin.r2.transport.common.TransportClientFactory;
import com.linkedin.r2.transport.common.bridge.client.TransportClient;
import com.linkedin.r2.transport.common.bridge.common.TransportCallback;
import com.linkedin.r2.transport.common.bridge.common.TransportResponseImpl;
import com.linkedin.util.clock.Clock;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Delayed;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import static org.testng.Assert.assertFalse;
/**
* LoadBalancerSimulator simulates the transporting delays of different hosts for d2
* degraderloadbalancer debugging, testing and verifications.
*
* The simulator requires 5 inputs:
* . ServiceProperties, ClusterProperties and UriProperties: represent the d2 configurations.
* . DelayGenerator: provides the delays for each given Uri
* . QPSGenerator: provides the number of queries per interval
*
* To control the simulator:
* . Asynchronous call: run(long duration) and runUntil(long untilTime)
* . Synchronous call: runWait(long duration)
* . stop()
*
* To check the status:
* . getClientCounters(): returns the hits for each URI during last interval
* . getPoints(): returns the hashring points for each URI
*
*/
public class LoadBalancerSimulator
{
private static final Logger _log = LoggerFactory.getLogger(LoadBalancerSimulator.class);
private final MockStore<ServiceProperties> _serviceRegistry = new MockStore<>();
private final MockStore<ClusterProperties> _clusterRegistry = new MockStore<>();
private final MockStore<UriProperties> _uriRegistry = new MockStore<>();
private final SimpleLoadBalancer _loadBalancer;
private final TimedValueGenerator<String> _delayGenerator;
private final QPSGenerator _qpsGenerator;
private final ClockedExecutor _clockedExecutor;
private final ScheduledExecutorService _executorService;
private final Map<URI, Integer> _clientCounters = new HashMap<>();
// the delay in milliseconds to schedule the first request
private final int INIT_SCHEDULE_DELAY = 10;
// How often to reschedule next set of requests
private final long SCHEDULE_INTERVAL = DegraderLoadBalancerStrategyConfig.DEFAULT_UPDATE_INTERVAL_MS;
/**
* Return the expected delay at the given time
*/
interface DelayGenerator<T>
{
long nextDelay(T t);
}
/**
* Return the number of queries for the next interval
*/
interface QPSGenerator
{
int nextQPS();
}
/**
* For a stream of values which changes periodically, get the value at the specific time
*/
interface TimedValueGenerator<T>
{
long getValue(T t, long time, TimeUnit unit);
}
LoadBalancerSimulator(ServiceProperties serviceProperties, ClusterProperties clusterProperties,
UriProperties uriProperties, TimedValueGenerator<String> delayGenerator,
QPSGenerator qpsGenerator) throws ExecutionException, InterruptedException
{
_executorService = new SynchronousExecutorService();
_clockedExecutor = new ClockedExecutor();
// mock the properties to pass in simulation info
Map<String, Object> transportProperty = new HashMap<>(serviceProperties.getTransportClientProperties());
transportProperty.put("ClockedExecutor", _clockedExecutor);
Map<String, Object> strategyProperty = new HashMap<>(serviceProperties.getLoadBalancerStrategyProperties());
strategyProperty.put(PropertyKeys.CLOCK, _clockedExecutor);
strategyProperty.put(PropertyKeys.HTTP_LB_QUARANTINE_EXECUTOR_SERVICE, _clockedExecutor);
ServiceProperties updatedServiceProperties = new ServiceProperties(serviceProperties.getServiceName(),
serviceProperties.getClusterName(), serviceProperties.getPath(),
serviceProperties.getLoadBalancerStrategyList(),
strategyProperty, transportProperty,
serviceProperties.getDegraderProperties(),
serviceProperties.getPrioritizedSchemes(),
serviceProperties.getBanned());
_serviceRegistry.put(serviceProperties.getServiceName(), updatedServiceProperties);
_clusterRegistry.put(serviceProperties.getClusterName(), clusterProperties);
_uriRegistry.put(serviceProperties.getClusterName(), uriProperties);
_delayGenerator = delayGenerator;
_qpsGenerator = qpsGenerator;
// construct loadBalancer and start it
Map<String, LoadBalancerStrategyFactory<? extends LoadBalancerStrategy>> loadBalancerStrategyFactories =
new HashMap<>();
Map<String, TransportClientFactory> clientFactories = new HashMap<>();
loadBalancerStrategyFactories.put("degrader", new DegraderLoadBalancerStrategyFactoryV3());
DelayClientFactory delayClientFactory = new DelayClientFactory();
clientFactories.put("http", delayClientFactory);
clientFactories.put("https", delayClientFactory);
SimpleLoadBalancerState loadBalancerState =
new SimpleLoadBalancerState(_executorService,
_uriRegistry,
_clusterRegistry,
_serviceRegistry,
clientFactories,
loadBalancerStrategyFactories);
_loadBalancer = new SimpleLoadBalancer(loadBalancerState, 5, TimeUnit.SECONDS);
FutureCallback<None> balancerCallback = new FutureCallback<None>();
_loadBalancer.start(balancerCallback);
balancerCallback.get();
// schedule the RequestTask, which starts new set of requests repeatedly at the given interval
_clockedExecutor.scheduleWithFixedDelay(new RequestTask(updatedServiceProperties.getServiceName()),
INIT_SCHEDULE_DELAY, SCHEDULE_INTERVAL, TimeUnit.MILLISECONDS);
}
public void shutdown() throws Exception
{
_clockedExecutor.shutdown();
final CountDownLatch latch = new CountDownLatch(1);
PropertyEventShutdownCallback callback = () -> latch.countDown();
_loadBalancer.shutdown(callback);
if (!latch.await(60, TimeUnit.SECONDS))
{
Assert.fail("unable to shutdown state");
}
_log.info("LoadBalancer Shutdown @ {}", _clockedExecutor.currentTimeMillis());
}
public void updateUriProperties(UriProperties uriProperties)
{
_uriRegistry.put(uriProperties.getClusterName(), uriProperties);
}
/**
* Run the simulation until no task in the queue or stopped by explicitly call (Async)
* @return
*/
public Future<Void> run()
{
return run(0);
}
/**
* Run the simulation for the provided duration (Async)
* @param duration
* @return
*/
public Future<Void> run(long duration)
{
return _clockedExecutor.run(duration <= 0 ? 0 : _clockedExecutor._currentTimeMillis + duration);
}
/**
* Run the simulation until the givenTime (Async)
* @param expectedTime
* @return
*/
public Future<Void> runUntil(long expectedTime)
{
return _clockedExecutor.run(expectedTime);
}
/**
* Run the simulation for the given duration (Sync)
* @param duration
*/
public void runWait(long duration)
{
Future<Void> running = run(duration);
if (running != null)
{
try
{
running.get();
}
catch (InterruptedException | ExecutionException e)
{
_log.error("Simulation error: " + e);
}
}
}
public void stop()
{
_clockedExecutor.stop();
}
public Map<URI, Integer> getClientCounters()
{
return _clientCounters;
}
public Clock getClock()
{
return _clockedExecutor;
}
public ScheduledExecutorService getExecutorService()
{
return _clockedExecutor;
}
public ClockedExecutor getClockedExecutor()
{
return _clockedExecutor;
}
/**
* Given a serviceName and partition number, return the hashring points for each URI
* @param serviceName
* @param partition
* @return
* @throws ServiceUnavailableException
*/
public Map<URI, Integer> getPoints(String serviceName, int partition) throws ServiceUnavailableException
{
URI serviceUri = URI.create("d2://" + serviceName);
Ring<URI> ring = _loadBalancer.getRings(serviceUri).get(partition);
Map<URI, Integer> pointsMap = new HashMap<>();
Iterator<URI> iter = ring.getIterator(0);
iter.forEachRemaining(uri -> pointsMap.compute(uri, (k, v) -> v == null ? 1: v + 1));
return pointsMap;
}
/**
* Get the point for the given uri
* @param serviceName
* @param partition
* @param uri
* @return
*/
public int getPoint(String serviceName, int partition, URI uri)
{
try
{
Map<URI, Integer> points = getPoints(serviceName, partition);
return points.getOrDefault(uri, 0);
}
catch (ServiceUnavailableException e)
{
return 0;
}
}
public int getPoint(String serviceName, int partition, String uriString)
{
return getPoint(serviceName, partition, URI.create("http://" + uriString));
}
/**
* Get the hitting percentage of the given uri (ie 'uri count'/'total inquiries')
* @param uri
* @return
*/
public double getCountPercent(URI uri)
{
return getPercentageFromMap(uri, getClientCounters());
}
private double getPercentageFromMap(URI uri, Map<URI, Integer> map)
{
if (!map.containsKey(uri))
{
return 0.0;
}
Integer total = map.values().stream().reduce(0, Integer::sum);
if (total == 0)
{
return 0.0;
}
return 1.0 * map.get(uri) / total;
}
/**
* A runnable task to send out request
*/
private class RequestTask implements Runnable
{
private String _serviceName;
public RequestTask(String serviceName)
{
_serviceName = serviceName;
}
@Override
public void run()
{
int qps = 0;
Map<URI, Long> uriDelays = new HashMap<>();
_clientCounters.clear();
try
{
qps = _qpsGenerator.nextQPS();
}
catch(IllegalArgumentException e)
{
return;
}
for (int i = 0; i < qps; ++i)
{
// construct the requests
URIRequest uriRequest = new URIRequest("d2://" + _serviceName + "/" + i);
RestRequest restRequest = new RestRequestBuilder(uriRequest.getURI()).build();
RequestContext requestContext = new RequestContext();
RewriteClient client = null;
try
{
client = (RewriteClient) _loadBalancer.getClient(restRequest, requestContext);
}
catch (ServiceUnavailableException e)
{
_log.error("Could not find service for request {}", restRequest.getURI());
Assert.fail("Failed to find the service");
}
TransportCallback<RestResponse> restCallback = (response) -> {
assertFalse(response.hasError());
_log.debug("Got response for {} @ {}", response.getResponse(), _clockedExecutor.currentTimeMillis());
// Do nothing for now for the response
};
URI clientUri = client.getUri();
_log.debug("Adding trackerclient for {}", clientUri);
// Increase the counter for each URI
_clientCounters.compute(clientUri, (k, v) -> v == null ? 1 : v + 1);
// send out the request
client.restRequest(restRequest, requestContext, Collections.emptyMap(), restCallback);
}
}
}
/**
* A simulated TransportClient, which schedules a delayed task to return the response.
*/
@SuppressWarnings("unchecked")
private static class DelayClientFactory implements TransportClientFactory
{
@Override
public TransportClient getClient(Map<String, ? extends Object> properties)
{
ClockedExecutor clockedExecutor = (ClockedExecutor) properties.get("ClockedExecutor");
TimedValueGenerator<String> delayGen = (TimedValueGenerator<String>) properties.get("DelayGenerator");
return new DelayClient(clockedExecutor, delayGen);
}
/**
* DelayClient is a TransportClient that can delay the response with a given time
*/
private class DelayClient implements TransportClient
{
final private ClockedExecutor _clockedExecutor;
final private TimedValueGenerator<String> _delayGen;
DelayClient(ClockedExecutor executor, TimedValueGenerator<String> delayGen)
{
_clockedExecutor = executor;
_delayGen = delayGen;
}
@Override
public void streamRequest(StreamRequest request,
RequestContext requestContext,
Map<String, String> wireAttrs,
TransportCallback<StreamResponse> callback)
{
throw new IllegalArgumentException("StreamRequest is not supported yet");
}
@Override
public void restRequest(RestRequest request,
RequestContext requestContext,
Map<String, String> wireAttrs,
TransportCallback<RestResponse> callback)
{
Long delay = _delayGen.getValue(request.getURI().getAuthority(), _clockedExecutor.currentTimeMillis(),
TimeUnit.MILLISECONDS);
_clockedExecutor.schedule(new Runnable() {
@Override
public void run()
{
RestResponse restResponse = new RestResponseBuilder().setEntity(request.getURI().getRawPath().getBytes()).build();
callback.onResponse(TransportResponseImpl.success(restResponse));
}
}, delay, TimeUnit.MILLISECONDS);
}
@Override
public void shutdown(Callback<None> callback)
{
callback.onSuccess(None.none());
}
}
@Override
public void shutdown(Callback<None> callback)
{
callback.onSuccess(None.none());
}
}
/**
* A simulated service executor and clock
*/
public class ClockedExecutor implements Clock, ScheduledExecutorService
{
private volatile long _currentTimeMillis = 0l;
private volatile Boolean _stopped = true;
private volatile long _runUntil = 0l;
private PriorityBlockingQueue<ClockedTask> _taskList = new PriorityBlockingQueue<>();
private ExecutorService _executorService = Executors.newFixedThreadPool(1);
public Future<Void> run(long untilTime)
{
if (!_stopped)
{
throw new IllegalArgumentException("Already Started!");
}
if (_taskList.isEmpty())
{
return null;
}
_stopped = false;
_runUntil = untilTime;
Future<Void> taskExecutor = _executorService.submit(() -> {
while (!_stopped && !_taskList.isEmpty() && (_runUntil <= 0l || _runUntil > _currentTimeMillis))
{
ClockedTask task = _taskList.peek();
long expectTime = task.getScheduledTime();
if (expectTime > _runUntil)
{
_currentTimeMillis = _runUntil;
break;
}
_taskList.remove();
if (expectTime > _currentTimeMillis)
{
_currentTimeMillis = expectTime;
}
_log.debug("Processing task " + task.toString() + " total {}, time {}",
_taskList.size(), _currentTimeMillis);
task.run();
if (task.repeatCount() > 0 && !task.isCancelled() && !_stopped)
{
task.reschedule(_currentTimeMillis);
_taskList.add(task);
}
}
_stopped = true;
return null;
});
return taskExecutor;
}
@Override
public ScheduledFuture<Void> schedule(Runnable cmd, long delay, TimeUnit unit)
{
ClockedTask task = new ClockedTask("ScheduledTask", cmd, _currentTimeMillis + delay);
_taskList.add(task);
return task;
}
@Override
public <Void> ScheduledFuture<Void> schedule(Callable<Void> callable, long delay, TimeUnit unit)
{
throw new IllegalArgumentException("Not supported yet!");
}
@Override
public ScheduledFuture<Void> scheduleAtFixedRate(Runnable command, long initialDelay,
long period, TimeUnit unit)
{
throw new IllegalArgumentException("Not supported yet!");
}
@Override
public ScheduledFuture<Void> scheduleWithFixedDelay(Runnable cmd, long initDelay, long interval,
TimeUnit unit)
{
ClockedTask task = new ClockedTask("scheduledWithDelayTask", cmd, _currentTimeMillis
+ unit.convert(initDelay, TimeUnit.MILLISECONDS), interval, Long.MAX_VALUE);
_taskList.add(task);
return task;
}
public void scheduleWithRepeat(Runnable cmd, long initDelay, long interval, long repeatTimes)
{
ClockedTask task = new ClockedTask("scheduledWithRepeatTask", cmd, _currentTimeMillis + initDelay, interval, repeatTimes);
_taskList.add(task);
}
@Override
public void execute(Runnable cmd)
{
ClockedTask task = new ClockedTask("executTask", cmd, _currentTimeMillis);
_taskList.add(task);
}
public void stop()
{
_stopped = true;
}
@Override
public void shutdown()
{
_stopped = true;
_executorService.shutdown();
}
@Override
public List<Runnable> shutdownNow()
{
throw new IllegalArgumentException("Not supported yet!");
}
@Override
public boolean isShutdown()
{
return _stopped;
}
@Override
public boolean isTerminated()
{
return _stopped && _taskList.isEmpty();
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException
{
runUntil(unit.convert(timeout, TimeUnit.MILLISECONDS));
return true;
}
@Override
public <T> Future<T> submit(Callable<T> task)
{
throw new IllegalArgumentException("Not supported yet!");
}
@Override
public <T> Future<T> submit(Runnable task, T result)
{
throw new IllegalArgumentException("Not supported yet!");
}
@Override
public Future<?> submit(Runnable task)
{
if (task == null)
{
throw new NullPointerException();
}
RunnableFuture<Void> ftask = new FutureTask<>(()->{}, null);
// Simulation only: Run the task in current thread
task.run();
return ftask;
}
@Override
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks)
throws InterruptedException
{
throw new IllegalArgumentException("Not supported yet!");
}
@Override
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks,
long timeout, TimeUnit unit)
throws InterruptedException
{
throw new IllegalArgumentException("Not supported yet!");
}
@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks)
throws InterruptedException, ExecutionException
{
throw new IllegalArgumentException("Not supported yet!");
}
@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks,
long timeout, TimeUnit unit)
throws InterruptedException, ExecutionException, TimeoutException
{
throw new IllegalArgumentException("Not supported yet!");
}
@Override
public long currentTimeMillis()
{
return _currentTimeMillis;
}
@Override
public String toString()
{
return "ClockedExecutor [_currentTimeMillis: " + _currentTimeMillis + "_taskList:"
+ _taskList.stream().map(e -> e.toString()).collect(Collectors.joining(","));
}
private class ClockedTask implements Runnable, ScheduledFuture<Void>
{
final private String _name;
private long _expectTimeMillis = 0l;
private long _interval = 0l;
private Runnable _task;
private long _repeatTimes = 0l;
private CountDownLatch _done;
private boolean _cancelled = false;
ClockedTask(String name, Runnable task, long scheduledTime)
{
this(name, task, scheduledTime, 0l, 0l);
}
ClockedTask(String name, Runnable task, long scheduledTime, long interval, long repeat)
{
_name = name;
_task = task;
_expectTimeMillis = scheduledTime;
_interval = interval;
_repeatTimes = repeat;
_done = new CountDownLatch(1);
_cancelled = false;
}
@Override
public void run()
{
if (!_cancelled)
{
_task.run();
_done.countDown();
}
}
long repeatCount()
{
return _repeatTimes;
}
long getScheduledTime()
{
return _expectTimeMillis;
}
void reschedule(long currentTime)
{
if (!_cancelled && currentTime >= _expectTimeMillis && _repeatTimes-- > 0)
{
_expectTimeMillis += (_interval - (currentTime - _expectTimeMillis));
_done = new CountDownLatch(1);
}
}
@Override
public boolean cancel(boolean mayInterruptIfRunning)
{
_cancelled = true;
if (_done.getCount() > 0)
{
_done.countDown();
return true;
}
return false;
}
@Override
public boolean isCancelled()
{
return _cancelled;
}
@Override
public boolean isDone()
{
return _done.getCount() == 0;
}
@Override
public Void get() throws InterruptedException
{
_done.await();
return null;
}
@Override
public Void get(long timeout, TimeUnit unit) throws InterruptedException
{
_done.await(timeout, unit);
return null;
}
@Override
public long getDelay(TimeUnit unit)
{
return unit.convert(_expectTimeMillis - _currentTimeMillis, TimeUnit.MILLISECONDS);
}
@Override
public int compareTo(Delayed other)
{
return (int) (getDelay(TimeUnit.MILLISECONDS) - other.getDelay(TimeUnit.MILLISECONDS));
}
@Override
public String toString()
{
return "ClockedTask [_name=" + _name + "_expectedTime=" + _expectTimeMillis
+ "_repeatTimes=" + _repeatTimes + "_interval=" + _interval + "]";
}
}
}
}