package netflix.ocelli.rxnetty.internal;
import io.netty.channel.embedded.EmbeddedChannel;
import io.reactivex.netty.channel.Connection;
import io.reactivex.netty.channel.ConnectionImpl;
import io.reactivex.netty.client.ConnectionFactory;
import io.reactivex.netty.client.ConnectionObservable;
import io.reactivex.netty.client.ConnectionObservable.OnSubcribeFunc;
import io.reactivex.netty.client.ConnectionProvider;
import io.reactivex.netty.client.events.ClientEventListener;
import io.reactivex.netty.protocol.tcp.client.events.TcpClientEventListener;
import io.reactivex.netty.protocol.tcp.client.events.TcpClientEventPublisher;
import netflix.ocelli.Instance;
import netflix.ocelli.LoadBalancerStrategy;
import netflix.ocelli.loadbalancer.RoundRobinLoadBalancer;
import netflix.ocelli.rxnetty.FailureListener;
import org.junit.rules.ExternalResource;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;
import org.mockito.Mockito;
import rx.Observable;
import rx.Subscriber;
import rx.Subscription;
import rx.functions.Func1;
import rx.functions.Func3;
import rx.observers.TestSubscriber;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.List;
/**
 * JUnit rule that assembles an {@link AbstractLoadBalancer} over a set of hosts for use in tests.
 *
 * <p>The rule keeps the hosts, event listener factory and load balancing strategy passed to the
 * most recent {@code setup(...)} call and exposes them through getters. When a load balancer
 * factory is supplied at construction time it is used to create the load balancer; otherwise an
 * anonymous {@link AbstractLoadBalancer} is created whose connection providers delegate directly
 * to the supplied {@link ConnectionFactory}.
 */
public class LoadBalancerRule extends ExternalResource {

    private Observable<Instance<SocketAddress>> hosts;
    private Func1<FailureListener, ? extends TcpClientEventListener> eventListenerFactory;
    private LoadBalancerStrategy<HostConnectionProvider<String, String>> loadBalancingStratgey;
    private AbstractLoadBalancer<String, String> loadBalancer;
    private Func3<Observable<Instance<SocketAddress>>, Func1<FailureListener, ? extends TcpClientEventListener>,
            LoadBalancerStrategy<HostConnectionProvider<String, String>>,
            AbstractLoadBalancer<String, String>> lbFactory;

    /**
     * Creates a rule without a load balancer factory; {@code setup(...)} then builds the default
     * anonymous {@link AbstractLoadBalancer}.
     */
    public LoadBalancerRule() {
        this(null);
    }

    /**
     * Creates a rule that builds its load balancer through the given factory.
     *
     * @param lbFactory function invoked by {@code setup(...)} with the hosts, the event listener
     *                  factory and the strategy; may be {@code null}, in which case the default
     *                  load balancer is created instead
     */
    public LoadBalancerRule(Func3<Observable<Instance<SocketAddress>>,
            Func1<FailureListener, ? extends TcpClientEventListener>,
            LoadBalancerStrategy<HostConnectionProvider<String, String>>,
            AbstractLoadBalancer<String, String>> lbFactory) {
        this.lbFactory = lbFactory;
    }

    /**
     * Wraps the test statement in a statement that does nothing but evaluate it.
     *
     * @param base        the statement to run
     * @param description description of the test; not used
     * @return a statement delegating to {@code base}
     */
    @Override
    public Statement apply(final Statement base, Description description) {
        final Statement passThrough = new Statement() {
            @Override
            public void evaluate() throws Throwable {
                base.evaluate();
            }
        };
        return passThrough;
    }

    /**
     * @return the load balancer created by the most recent {@code setup(...)} call, or
     *         {@code null} if none has been made
     */
    public AbstractLoadBalancer<String, String> getLoadBalancer() {
        return loadBalancer;
    }

    /**
     * Sets up the load balancer with two freshly created dummy socket addresses.
     *
     * @return the configured host instances, collected from the hosts observable
     */
    public List<Instance<SocketAddress>> setupDefault() {
        final Instance<SocketAddress> first = new DummyInstance();
        final Instance<SocketAddress> second = new DummyInstance();
        setup(first.getValue(), second.getValue());
        return hosts.toList().toBlocking().single();
    }

    /**
     * Sets up the load balancer for the given addresses, using an event listener factory that
     * always returns {@code null}.
     *
     * @param hosts addresses of the hosts to balance over
     * @return the created load balancer
     */
    public AbstractLoadBalancer<String, String> setup(SocketAddress... hosts) {
        final Func1<FailureListener, TcpClientEventListener> nullListenerFactory =
                new Func1<FailureListener, TcpClientEventListener>() {
                    @Override
                    public TcpClientEventListener call(FailureListener failureListener) {
                        return null;
                    }
                };
        return setup(nullListenerFactory, hosts);
    }

    /**
     * Sets up the load balancer for the given addresses with a {@link RoundRobinLoadBalancer}
     * strategy (constructed with {@code -1}).
     *
     * @param eventListenerFactory factory producing the event listener for a failure listener
     * @param hosts                addresses of the hosts to balance over
     * @return the created load balancer
     */
    public AbstractLoadBalancer<String, String> setup(
            Func1<FailureListener, ? extends TcpClientEventListener> eventListenerFactory,
            SocketAddress... hosts) {
        final LoadBalancerStrategy<HostConnectionProvider<String, String>> roundRobin =
                new RoundRobinLoadBalancer<HostConnectionProvider<String, String>>(-1);
        return setup(eventListenerFactory, roundRobin, hosts);
    }

    /**
     * Sets up the load balancer for the given addresses, wrapping each one in a dummy
     * {@link Instance}.
     *
     * @param eventListenerFactory  factory producing the event listener for a failure listener
     * @param loadBalancingStratgey strategy used to pick a host
     * @param hosts                 addresses of the hosts to balance over
     * @return the created load balancer
     */
    public AbstractLoadBalancer<String, String> setup(
            Func1<FailureListener, ? extends TcpClientEventListener> eventListenerFactory,
            LoadBalancerStrategy<HostConnectionProvider<String, String>> loadBalancingStratgey,
            SocketAddress... hosts) {
        final List<Instance<SocketAddress>> wrapped = new ArrayList<Instance<SocketAddress>>(hosts.length);
        for (int idx = 0; idx < hosts.length; idx++) {
            wrapped.add(new DummyInstance(hosts[idx]));
        }
        return setup(eventListenerFactory, loadBalancingStratgey, wrapped);
    }

    /**
     * Records the supplied configuration on this rule and creates the load balancer, either via
     * the factory given at construction time or, when there is none, as the default anonymous
     * {@link AbstractLoadBalancer}.
     *
     * @param eventListenerFactory  factory producing the event listener for a failure listener
     * @param loadBalancingStratgey strategy used to pick a host
     * @param hosts                 host instances to balance over
     * @return the created load balancer
     */
    public AbstractLoadBalancer<String, String> setup(
            Func1<FailureListener, ? extends TcpClientEventListener> eventListenerFactory,
            LoadBalancerStrategy<HostConnectionProvider<String, String>> loadBalancingStratgey,
            List<Instance<SocketAddress>> hosts) {
        this.hosts = Observable.from(hosts);
        this.eventListenerFactory = eventListenerFactory;
        this.loadBalancingStratgey = loadBalancingStratgey;

        if (lbFactory == null) {
            loadBalancer = newDefaultLoadBalancer(this.hosts, eventListenerFactory, loadBalancingStratgey);
            return getLoadBalancer();
        }

        loadBalancer = lbFactory.call(this.hosts, eventListenerFactory, loadBalancingStratgey);
        return loadBalancer;
    }

    /**
     * Builds the load balancer used when no factory was supplied: each host gets a connection
     * provider that asks the connection factory for a connection to that host's address.
     */
    private AbstractLoadBalancer<String, String> newDefaultLoadBalancer(
            Observable<Instance<SocketAddress>> hostStream,
            Func1<FailureListener, ? extends TcpClientEventListener> listenerFactory,
            LoadBalancerStrategy<HostConnectionProvider<String, String>> strategy) {
        return new AbstractLoadBalancer<String, String>(hostStream, listenerFactory, strategy) {
            @Override
            protected ConnectionProvider<String, String> newConnectionProviderForHost(
                    final Instance<SocketAddress> host,
                    final ConnectionFactory<String, String> connectionFactory) {
                return new ConnectionProvider<String, String>(connectionFactory) {
                    @Override
                    public ConnectionObservable<String, String> nextConnection() {
                        return connectionFactory.newConnection(host.getValue());
                    }
                };
            }
        };
    }

    /**
     * @return the event listener factory passed to the most recent {@code setup(...)} call
     */
    public Func1<FailureListener, ? extends TcpClientEventListener> getEventListenerFactory() {
        return eventListenerFactory;
    }

    /**
     * @return the hosts observable created by the most recent {@code setup(...)} call
     */
    public Observable<Instance<SocketAddress>> getHosts() {
        return hosts;
    }

    /**
     * Maps every configured host to an {@link Instance} holding a {@link ConnectionProvider}
     * that obtains connections to that host from the given factory. The mapped instance shares
     * the lifecycle of the host instance it was derived from.
     *
     * @param cfMock connection factory the providers delegate to
     * @return the hosts, viewed as connection provider instances
     */
    public Observable<Instance<ConnectionProvider<String, String>>> getHostsAsConnectionProviders(
            final ConnectionFactory<String, String> cfMock) {
        return hosts.map(new Func1<Instance<SocketAddress>, Instance<ConnectionProvider<String, String>>>() {
            @Override
            public Instance<ConnectionProvider<String, String>> call(final Instance<SocketAddress> host) {
                return toProviderInstance(host, cfMock);
            }
        });
    }

    /**
     * Wraps a single host instance in an instance exposing a connection provider for that host.
     */
    private static Instance<ConnectionProvider<String, String>> toProviderInstance(
            final Instance<SocketAddress> host, final ConnectionFactory<String, String> factory) {
        final ConnectionProvider<String, String> provider = new ConnectionProvider<String, String>(factory) {
            @Override
            public ConnectionObservable<String, String> nextConnection() {
                return factory.newConnection(host.getValue());
            }
        };
        return new Instance<ConnectionProvider<String, String>>() {
            @Override
            public Observable<Void> getLifecycle() {
                return host.getLifecycle();
            }

            @Override
            public ConnectionProvider<String, String> getValue() {
                return provider;
            }
        };
    }

    /**
     * @return the strategy passed to the most recent {@code setup(...)} call
     */
    public LoadBalancerStrategy<HostConnectionProvider<String, String>> getLoadBalancingStratgey() {
        return loadBalancingStratgey;
    }

    /**
     * Subscribes to the given connection observable, waits for it to terminate and asserts that
     * it finished without errors after emitting exactly one connection.
     *
     * @param connectionObservable the observable to subscribe to
     * @return the single connection that was emitted
     */
    public Connection<String, String> connect(ConnectionObservable<String, String> connectionObservable) {
        final TestSubscriber<Connection<String, String>> probe =
                new TestSubscriber<Connection<String, String>>();
        connectionObservable.subscribe(probe);

        probe.awaitTerminalEvent();
        probe.assertNoErrors();
        probe.assertValueCount(1);

        final List<Connection<String, String>> received = probe.getOnNextEvents();
        return received.get(0);
    }

    /**
     * Creates a Mockito mock of {@link ConnectionFactory} that, for the address of every
     * configured host, is stubbed to return a connection observable emitting one connection
     * backed by an {@link EmbeddedChannel}.
     *
     * <p>This reads the hosts stored by {@code setup(...)}, so a setup method has to be called
     * first.
     *
     * @return the stubbed connection factory mock
     */
    public ConnectionFactory<String, String> newConnectionFactoryMock() {
        @SuppressWarnings("unchecked")
        final ConnectionFactory<String, String> cfMock = Mockito.mock(ConnectionFactory.class);

        final List<Instance<SocketAddress>> configuredHosts = hosts.toList().toBlocking().single();
        for (final Instance<SocketAddress> host : configuredHosts) {
            final EmbeddedChannel channel = new EmbeddedChannel();
            final TcpClientEventPublisher publisher = new TcpClientEventPublisher();
            final Connection<String, String> connection = ConnectionImpl.create(channel, publisher, publisher);

            // The observable is built as the thenReturn argument, i.e. after the stubbed call
            // on the mock has been recorded by Mockito.when(...).
            Mockito.when(cfMock.newConnection(host.getValue()))
                   .thenReturn(singleConnectionObservable(publisher, connection));
        }
        return cfMock;
    }

    /**
     * Builds a connection observable that emits the given connection and then completes, and
     * that routes event subscriptions to the given publisher.
     */
    private static ConnectionObservable<String, String> singleConnectionObservable(
            final TcpClientEventPublisher publisher, final Connection<String, String> connection) {
        return ConnectionObservable.createNew(new OnSubcribeFunc<String, String>() {
            @Override
            public Subscription subscribeForEvents(ClientEventListener eventListener) {
                return publisher.subscribe((TcpClientEventListener) eventListener);
            }

            @Override
            public void call(Subscriber<? super Connection<String, String>> subscriber) {
                subscriber.onNext(connection);
                subscriber.onCompleted();
            }
        });
    }

    /**
     * Host instance whose lifecycle observable never emits or terminates and whose value is a
     * fixed socket address.
     */
    private static class DummyInstance extends Instance<SocketAddress> {

        private final SocketAddress socketAddress;

        /** Creates an instance around a new anonymous socket address. */
        private DummyInstance() {
            socketAddress = new SocketAddress() {
                private static final long serialVersionUID = 711795406919943230L;

                @Override
                public String toString() {
                    return "Dummy socket address: " + hashCode();
                }
            };
        }

        /** Creates an instance around the given socket address. */
        private DummyInstance(SocketAddress socketAddress) {
            this.socketAddress = socketAddress;
        }

        @Override
        public Observable<Void> getLifecycle() {
            return Observable.never();
        }

        @Override
        public SocketAddress getValue() {
            return socketAddress;
        }
    }
}