package io.scalecube.transport;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ServerChannel;
import io.netty.handler.codec.MessageToByteEncoder;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.protobuf.ProtobufVarint32FrameDecoder;
import io.netty.handler.codec.protobuf.ProtobufVarint32LengthFieldPrepender;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.Observable;
import rx.schedulers.Schedulers;
import rx.subjects.PublishSubject;
import rx.subjects.Subject;
import java.net.BindException;
import java.net.InetAddress;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
/**
 * Netty-based {@link Transport} implementation.
 *
 * <p>Incoming connections are accepted on a server channel bound by {@link #bind0()}. Outgoing connections are
 * opened lazily, one per destination {@link Address}, and cached until the connect attempt fails or the channel
 * becomes inactive. Frames on the wire are prefixed with a varint32 length. Messages received from peers are
 * published to subscribers of {@link #listen()}.
 */
final class TransportImpl implements Transport {

  private static final Logger LOGGER = LoggerFactory.getLogger(TransportImpl.class);

  /**
   * Shared, already-completed promise used as a "fire and forget" marker: callers passing it are not interested
   * in the outcome, and completing it again is a no-op.
   */
  private static final CompletableFuture<Void> COMPLETED_PROMISE = CompletableFuture.completedFuture(null);

  private final TransportConfig config;
  private final Subject<Message, Message> incomingMessagesSubject = PublishSubject.<Message>create().toSerialized();
  // Outgoing connections (possibly still connecting), keyed by destination address.
  private final Map<Address, ChannelFuture> outgoingChannels = new ConcurrentHashMap<>();

  // Pipeline
  private final BootstrapFactory bootstrapFactory;
  private final IncomingChannelInitializer incomingChannelInitializer = new IncomingChannelInitializer();
  private final ExceptionHandler exceptionHandler = new ExceptionHandler();
  private final MessageToByteEncoder<Message> serializerHandler;
  private final MessageToMessageDecoder<ByteBuf> deserializerHandler;
  private final MessageHandler messageHandler;

  // Network emulator; both fields are assigned once the server channel has been bound successfully.
  private NetworkEmulator networkEmulator;
  private NetworkEmulatorHandler networkEmulatorHandler; // null when the network emulator is disabled

  private Address address;
  private ServerChannel serverChannel;
  private volatile boolean stopped = false;

  /**
   * Creates an unbound transport; call {@link #bind0()} to start accepting connections.
   *
   * @param config transport configuration; must not be null
   * @throws IllegalArgumentException if {@code config} is null
   */
  public TransportImpl(TransportConfig config) {
    checkArgument(config != null);
    this.config = config;
    this.serializerHandler = new MessageSerializerHandler();
    this.deserializerHandler = new MessageDeserializerHandler();
    this.messageHandler = new MessageHandler(incomingMessagesSubject);
    this.bootstrapFactory = new BootstrapFactory(config);
  }

  /**
   * Starts to accept connections on the local address.
   *
   * <p>The listen IP is resolved from the configuration. When port auto-increment is enabled, the next available
   * port is picked, and the whole bind is retried if it fails with "address already in use".
   *
   * @return future completed with this transport once bound, or completed exceptionally if binding fails
   */
  public CompletableFuture<Transport> bind0() {
    // Resolve listen IP address
    final InetAddress listenAddress =
        Addressing.getLocalIpAddress(config.getListenAddress(), config.getListenInterface(), config.isPreferIPv6());

    // Resolve listen port
    int bindPort = config.isPortAutoIncrement()
        ? Addressing.getNextAvailablePort(listenAddress, config.getPort(), config.getPortCount()) // Find available port
        : config.getPort();

    // Listen address
    address = Address.create(listenAddress.getHostAddress(), bindPort);

    ServerBootstrap server = bootstrapFactory.serverBootstrap().childHandler(incomingChannelInitializer);
    ChannelFuture bindFuture = server.bind(listenAddress, address.port());
    final CompletableFuture<Transport> result = new CompletableFuture<>();
    bindFuture.addListener((ChannelFutureListener) channelFuture -> {
      if (channelFuture.isSuccess()) {
        serverChannel = (ServerChannel) channelFuture.channel();
        networkEmulator = new NetworkEmulator(address, config.isUseNetworkEmulator());
        networkEmulatorHandler = config.isUseNetworkEmulator() ? new NetworkEmulatorHandler(networkEmulator) : null;
        LOGGER.info("Bound to: {}", address);
        result.complete(TransportImpl.this);
      } else {
        Throwable cause = channelFuture.cause();
        if (config.isPortAutoIncrement() && isAddressAlreadyInUseException(cause)) {
          LOGGER.warn("Can't bind to address {}, try again on different port [cause={}]", address, cause.toString());
          // Forward both outcomes of the retry; forwarding only success would leave result pending forever.
          bind0().whenComplete((transport, retryError) -> {
            if (retryError != null) {
              result.completeExceptionally(retryError);
            } else {
              result.complete(transport);
            }
          });
        } else {
          LOGGER.error("Failed to bind to: {}, cause: {}", address, cause);
          result.completeExceptionally(cause);
        }
      }
    });
    return result;
  }

  /**
   * Tells whether a bind failure was caused by the port already being taken.
   *
   * @param exception bind failure cause; may be null
   * @return true for a {@link BindException} or any exception whose message mentions "Address already in use"
   */
  private boolean isAddressAlreadyInUseException(Throwable exception) {
    if (exception == null) {
      return false;
    }
    return exception instanceof BindException
        || (exception.getMessage() != null && exception.getMessage().contains("Address already in use"));
  }

  /**
   * Returns the local address of this transport. It is assigned by {@link #bind0()}; before that call it is null
   * despite the {@code @Nonnull} annotation.
   */
  @Override
  @Nonnull
  public Address address() {
    return address;
  }

  /** Returns true once {@link #stop(CompletableFuture)} has been called. */
  @Override
  public boolean isStopped() {
    return stopped;
  }

  /**
   * Returns the network emulator. It is created only after the server channel has been bound successfully;
   * before that it is null despite the {@code @Nonnull} annotation.
   */
  @Nonnull
  @Override
  public NetworkEmulator networkEmulator() {
    return networkEmulator;
  }

  /** Stops the transport without waiting for the server channel to close. */
  @Override
  public final void stop() {
    stop(COMPLETED_PROMISE);
  }

  /**
   * Stops the transport: completes the incoming message stream, closes all outgoing channels and the server
   * channel, and shuts down the bootstrap factory.
   *
   * @param promise completed when the server channel is closed, or right away if it was never bound
   * @throws IllegalStateException if the transport is already stopped
   * @throws IllegalArgumentException if {@code promise} is null
   */
  @Override
  public final void stop(CompletableFuture<Void> promise) {
    checkState(!stopped, "Transport is stopped");
    checkArgument(promise != null);
    stopped = true;

    // Complete incoming messages observable
    try {
      incomingMessagesSubject.onCompleted();
    } catch (Exception ignored) {
      // Best effort: a failure raised while completing the stream (e.g. by a subscriber) must not abort shutdown.
    }

    // Close connected channels
    for (ChannelFuture channelFuture : outgoingChannels.values()) {
      if (channelFuture.isSuccess()) {
        channelFuture.channel().close();
      } else {
        // Connect attempt not successful (yet): close the channel once the attempt completes.
        channelFuture.addListener(ChannelFutureListener.CLOSE);
      }
    }
    outgoingChannels.clear();

    // Close server channel
    if (serverChannel != null) {
      composeFutures(serverChannel.close(), promise);
    } else {
      // Never bound: nothing to close, so don't leave the caller waiting forever.
      promise.complete(null);
    }

    // TODO [AK]: shutdown boss/worker threads and listen for their futures
    bootstrapFactory.shutdown();
  }

  /**
   * Returns the stream of messages received from remote peers, buffered on backpressure.
   *
   * @throws IllegalStateException if the transport is stopped
   */
  @Nonnull
  @Override
  public final Observable<Message> listen() {
    checkState(!stopped, "Transport is stopped");
    return incomingMessagesSubject.onBackpressureBuffer().asObservable();
  }

  /** Sends {@code message} to {@code address} without reporting the outcome (fire and forget). */
  @Override
  public void send(@CheckForNull Address address, @CheckForNull Message message) {
    send(address, message, COMPLETED_PROMISE);
  }

  /**
   * Sends {@code message} to {@code address}, connecting first if no cached connection exists. The message's
   * sender is set to this transport's address.
   *
   * @param address destination; must not be null
   * @param message message to send; must not be null
   * @param promise completed when the write finishes, or completed exceptionally if connecting or writing fails;
   *     must not be null
   * @throws IllegalStateException if the transport is stopped
   * @throws IllegalArgumentException if any argument is null
   */
  @Override
  public void send(@CheckForNull Address address, @CheckForNull Message message,
      @CheckForNull CompletableFuture<Void> promise) {
    checkState(!stopped, "Transport is stopped");
    checkArgument(address != null);
    checkArgument(message != null);
    checkArgument(promise != null);

    message.setSender(this.address);

    final ChannelFuture channelFuture = outgoingChannels.computeIfAbsent(address, this::connect);
    if (channelFuture.isSuccess()) {
      send(channelFuture.channel(), message, promise);
    } else {
      // Connect attempt still pending (or already failed): act once its outcome is known.
      channelFuture.addListener((ChannelFutureListener) connected -> {
        if (connected.isSuccess()) {
          send(connected.channel(), message, promise);
        } else {
          promise.completeExceptionally(connected.cause());
        }
      });
    }
  }

  /**
   * Writes and flushes {@code message} on {@code channel}. The fire-and-forget marker uses the channel's void
   * promise to avoid allocating a future nobody observes.
   */
  private void send(Channel channel, Message message, CompletableFuture<Void> promise) {
    if (promise == COMPLETED_PROMISE) {
      channel.writeAndFlush(message, channel.voidPromise());
    } else {
      composeFutures(channel.writeAndFlush(message), promise);
    }
  }

  /**
   * Propagates the outcome of a netty {@link ChannelFuture} to the given {@link CompletableFuture}.
   *
   * @param channelFuture netty channel future
   * @param promise future completed with {@code null} on success, or exceptionally with the channel future's cause;
   *     must not be null
   */
  private void composeFutures(ChannelFuture channelFuture, @Nonnull final CompletableFuture<Void> promise) {
    channelFuture.addListener((ChannelFutureListener) future -> {
      if (future.isSuccess()) {
        promise.complete(null);
      } else {
        promise.completeExceptionally(future.cause());
      }
    });
  }

  /**
   * Opens an outgoing connection to {@code address}. On failure the connection is dropped from
   * {@code outgoingChannels}, so the next send retries.
   *
   * <p>NOTE(review): this runs as the {@code computeIfAbsent} mapping function in {@code send}. If the failure
   * listener were ever invoked synchronously during that call, its map update would be a re-entrant modification,
   * which {@link ConcurrentHashMap} forbids. Confirm that connect failures are always reported asynchronously.
   */
  private ChannelFuture connect(Address address) {
    OutgoingChannelInitializer channelInitializer = new OutgoingChannelInitializer(address);
    Bootstrap client = bootstrapFactory.clientBootstrap().handler(channelInitializer);
    ChannelFuture connectFuture = client.connect(address.host(), address.port());

    // Register logger and cleanup listener
    connectFuture.addListener((ChannelFutureListener) channelFuture -> {
      if (channelFuture.isSuccess()) {
        LOGGER.debug("Connected from {} to {}: {}", TransportImpl.this.address, address, channelFuture.channel());
      } else {
        LOGGER.warn("Failed to connect from {} to {}", TransportImpl.this.address, address);
        // Only drop our own entry; never a newer connection cached for the same address.
        outgoingChannels.remove(address, channelFuture);
      }
    });
    return connectFuture;
  }

  /**
   * Pipeline for accepted connections: varint32 frame decoding, message deserialization, dispatch to the incoming
   * message stream, and exception handling.
   */
  @ChannelHandler.Sharable
  private final class IncomingChannelInitializer extends ChannelInitializer<Channel> {
    @Override
    protected void initChannel(Channel channel) throws Exception {
      ChannelPipeline pipeline = channel.pipeline();
      pipeline.addLast(new ProtobufVarint32FrameDecoder());
      pipeline.addLast(deserializerHandler);
      pipeline.addLast(messageHandler);
      pipeline.addLast(exceptionHandler);
    }
  }

  /**
   * Pipeline for outgoing connections: evicts the cached connection when the channel becomes inactive, then
   * applies varint32 length prefixing, message serialization, the optional network emulator, and exception
   * handling.
   */
  @ChannelHandler.Sharable
  private final class OutgoingChannelInitializer extends ChannelInitializer<Channel> {
    private final Address address;

    public OutgoingChannelInitializer(Address address) {
      this.address = address;
    }

    @Override
    protected void initChannel(Channel channel) throws Exception {
      ChannelPipeline pipeline = channel.pipeline();
      pipeline.addLast(new ChannelDuplexHandler() {
        @Override
        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
          LOGGER.debug("Disconnected from: {} {}", address, ctx.channel());
          outgoingChannels.remove(address);
          super.channelInactive(ctx);
        }
      });
      pipeline.addLast(new ProtobufVarint32LengthFieldPrepender());
      pipeline.addLast(serializerHandler);
      if (networkEmulatorHandler != null) {
        pipeline.addLast(networkEmulatorHandler);
      }
      pipeline.addLast(exceptionHandler);
    }
  }
}