/*
* Copyright 2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.atomix.catalyst.transport.netty;
import io.atomix.catalyst.concurrent.Listener;
import io.atomix.catalyst.concurrent.Listeners;
import io.atomix.catalyst.concurrent.Scheduled;
import io.atomix.catalyst.concurrent.ThreadContext;
import io.atomix.catalyst.serializer.SerializationException;
import io.atomix.catalyst.transport.Connection;
import io.atomix.catalyst.util.Assert;
import io.atomix.catalyst.util.reference.ReferenceCounted;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import java.net.ConnectException;
import java.time.Duration;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import java.util.function.Function;
/**
* Netty connection.
*
* @author <a href="http://github.com/kuujo">Jordan Halterman</a>
*/
public class NettyConnection implements Connection {
static final byte REQUEST = 0x01;
static final byte RESPONSE = 0x02;
static final byte SUCCESS = 0x03;
static final byte FAILURE = 0x04;
private static final ThreadLocal<ByteBufInput> INPUT = new ThreadLocal<ByteBufInput>() {
@Override
protected ByteBufInput initialValue() {
return new ByteBufInput();
}
};
private static final ThreadLocal<ByteBufOutput> OUTPUT = new ThreadLocal<ByteBufOutput>() {
@Override
protected ByteBufOutput initialValue() {
return new ByteBufOutput();
}
};
private final Channel channel;
private final ThreadContext context;
private final Map<Class, HandlerHolder> handlers = new ConcurrentHashMap<>();
private final Listeners<Throwable> exceptionListeners = new Listeners<>();
private final Listeners<Connection> closeListeners = new Listeners<>();
private final long requestTimeout;
private volatile long requestId;
private volatile Throwable failure;
private volatile boolean closed;
private Scheduled timeout;
private final Map<Long, ContextualFuture> responseFutures = new ConcurrentSkipListMap<>();
private ChannelFuture writeFuture;
/**
* @throws NullPointerException if any argument is null
*/
public NettyConnection(Channel channel, ThreadContext context, NettyOptions options) {
this.channel = channel;
this.context = context;
this.requestTimeout = options.requestTimeout();
this.timeout = context.schedule(Duration.ofMillis(requestTimeout / 2), Duration.ofMillis(requestTimeout / 2), this::timeout);
}
/**
* Handles a request.
*/
void handleRequest(ByteBuf buffer) {
long requestId = buffer.readLong();
try {
Object request = readRequest(buffer);
HandlerHolder handler = handlers.get(request.getClass());
if (handler != null) {
handler.context.executor().execute(() -> handleRequest(requestId, request, handler));
} else {
handleRequestFailure(requestId, new SerializationException("unknown message type: " + request.getClass()), this.context);
}
} catch (SerializationException e) {
handleRequestFailure(requestId, e, this.context);
} finally {
buffer.release();
}
}
/**
* Handles a request.
*/
private void handleRequest(long requestId, Object request, HandlerHolder handler) {
@SuppressWarnings("unchecked")
CompletableFuture<Object> responseFuture = handler.handler.apply(request);
if (responseFuture != null) {
responseFuture.whenComplete((response, error) -> {
ThreadContext context = ThreadContext.currentContext();
if (context == null) {
this.context.executor().execute(() -> {
if (error == null) {
handleRequestSuccess(requestId, response, this.context);
} else {
handleRequestFailure(requestId, error, this.context);
}
});
} else {
if (error == null) {
handleRequestSuccess(requestId, response, context);
} else {
handleRequestFailure(requestId, error, context);
}
}
});
}
}
/**
* Handles a request response.
*/
private void handleRequestSuccess(long requestId, Object response, ThreadContext context) {
ByteBuf buffer = channel.alloc().buffer(10)
.writeByte(RESPONSE)
.writeLong(requestId)
.writeByte(SUCCESS);
try {
writeResponse(buffer, response, context);
} catch (SerializationException e) {
handleRequestFailure(requestId, e, context);
return;
}
channel.writeAndFlush(buffer, channel.voidPromise());
if (response instanceof ReferenceCounted) {
((ReferenceCounted) response).release();
}
}
/**
* Handles a request failure.
*/
private void handleRequestFailure(long requestId, Throwable error, ThreadContext context) {
ByteBuf buffer = channel.alloc().buffer(10)
.writeByte(RESPONSE)
.writeLong(requestId)
.writeByte(FAILURE);
try {
writeError(buffer, error, context);
} catch (SerializationException e) {
return;
}
channel.writeAndFlush(buffer, channel.voidPromise());
}
/**
* Handles response.
*/
void handleResponse(ByteBuf response) {
long requestId = response.readLong();
byte status = response.readByte();
switch (status) {
case SUCCESS:
try {
handleResponseSuccess(requestId, readResponse(response));
} catch (SerializationException e) {
handleResponseFailure(requestId, e);
}
break;
case FAILURE:
try {
handleResponseFailure(requestId, readError(response));
} catch (SerializationException e) {
handleResponseFailure(requestId, e);
}
break;
}
response.release();
}
/**
* Handles a successful response.
*/
@SuppressWarnings("unchecked")
private void handleResponseSuccess(long requestId, Object response) {
ContextualFuture future = responseFutures.remove(requestId);
if (future != null) {
future.context.executor().execute(() -> future.complete(response));
}
}
/**
* Handles a failure response.
*/
private void handleResponseFailure(long requestId, Throwable t) {
ContextualFuture future = responseFutures.remove(requestId);
if (future != null) {
future.context.executor().execute(() -> future.completeExceptionally(t));
}
}
/**
* Writes a request to the given buffer.
*/
private ByteBuf writeRequest(ByteBuf buffer, Object request, ThreadContext context) {
context.serializer().writeObject(request, OUTPUT.get().setByteBuf(buffer));
if (request instanceof ReferenceCounted) {
((ReferenceCounted) request).release();
}
return buffer;
}
/**
* Writes a response to the given buffer.
*/
private ByteBuf writeResponse(ByteBuf buffer, Object request, ThreadContext context) {
context.serializer().writeObject(request, OUTPUT.get().setByteBuf(buffer));
return buffer;
}
/**
* Writes an error to the given buffer.
*/
private ByteBuf writeError(ByteBuf buffer, Throwable t, ThreadContext context) {
context.serializer().writeObject(t, OUTPUT.get().setByteBuf(buffer));
return buffer;
}
/**
* Reads a request from the given buffer.
*/
private Object readRequest(ByteBuf buffer) {
return context.serializer().readObject(INPUT.get().setByteBuf(buffer));
}
/**
* Reads a response from the given buffer.
*/
private Object readResponse(ByteBuf buffer) {
return context.serializer().readObject(INPUT.get().setByteBuf(buffer));
}
/**
* Reads an error from the given buffer.
*/
private Throwable readError(ByteBuf buffer) {
return context.serializer().readObject(INPUT.get().setByteBuf(buffer));
}
/**
* Handles an exception.
*
* @param t The exception to handle.
*/
void handleException(Throwable t) {
if (failure == null) {
failure = t;
for (ContextualFuture<?> responseFuture : responseFutures.values()) {
responseFuture.context.executor().execute(() -> responseFuture.completeExceptionally(t));
}
responseFutures.clear();
for (Listener<Throwable> listener : exceptionListeners) {
listener.accept(t);
}
}
}
/**
* Handles the channel being closed.
*/
void handleClosed() {
if (!closed) {
closed = true;
for (ContextualFuture<?> responseFuture : responseFutures.values()) {
responseFuture.context.executor().execute(() -> responseFuture.completeExceptionally(new ConnectException("connection closed")));
}
responseFutures.clear();
for (Listener<Connection> listener : closeListeners) {
listener.accept(this);
}
timeout.cancel();
}
}
/**
* Times out requests.
*/
void timeout() {
long time = System.currentTimeMillis();
Iterator<Map.Entry<Long, ContextualFuture>> iterator = responseFutures.entrySet().iterator();
while (iterator.hasNext()) {
ContextualFuture future = iterator.next().getValue();
if (future.time + requestTimeout < time) {
iterator.remove();
future.context.executor().execute(() -> future.completeExceptionally(new TimeoutException("request timed out")));
} else {
break;
}
}
}
@Override
public CompletableFuture<Void> send(Object request) {
Assert.notNull(request, "request");
ThreadContext context = ThreadContext.currentContextOrThrow();
ContextualFuture<Void> future = new ContextualFuture<>(System.currentTimeMillis(), context);
long requestId = ++this.requestId;
ByteBuf buffer = this.channel.alloc().buffer(9)
.writeByte(REQUEST)
.writeLong(requestId);
try {
writeRequest(buffer, request, context);
} catch (SerializationException e) {
future.completeExceptionally(e);
return future;
}
responseFutures.put(requestId, future);
writeFuture = channel.writeAndFlush(buffer).addListener((channelFuture) -> {
if (channelFuture.isSuccess()) {
future.context.executor().execute(() -> future.complete(null));
} else {
future.context.executor().execute(() -> future.completeExceptionally(channelFuture.cause()));
}
});
return future;
}
@Override
public <T, U> CompletableFuture<U> sendAndReceive(T request) {
Assert.notNull(request, "request");
ThreadContext context = ThreadContext.currentContextOrThrow();
ContextualFuture<U> future = new ContextualFuture<>(System.currentTimeMillis(), context);
long requestId = ++this.requestId;
ByteBuf buffer = this.channel.alloc().buffer(9)
.writeByte(REQUEST)
.writeLong(requestId);
try {
writeRequest(buffer, request, context);
} catch (SerializationException e) {
future.completeExceptionally(e);
return future;
}
responseFutures.put(requestId, future);
writeFuture = channel.writeAndFlush(buffer).addListener((channelFuture) -> {
if (channelFuture.isSuccess()) {
if (closed) {
ContextualFuture responseFuture = responseFutures.remove(requestId);
if (responseFuture != null) {
responseFuture.context.executor().execute(() -> responseFuture.completeExceptionally(new ConnectException("connection closed")));
}
}
} else {
future.context.executor().execute(() -> future.completeExceptionally(channelFuture.cause()));
}
});
return future;
}
@Override
public <T, U> Connection handler(Class<T> type, Consumer<T> handler) {
return handler(type, r -> {
handler.accept(r);
return null;
});
}
@Override
public <T, U> Connection handler(Class<T> type, Function<T, CompletableFuture<U>> handler) {
Assert.notNull(type, "type");
handlers.put(type, new HandlerHolder(handler, ThreadContext.currentContextOrThrow()));
return null;
}
@Override
public Listener<Throwable> onException(Consumer<Throwable> listener) {
if (failure != null) {
listener.accept(failure);
}
return exceptionListeners.add(Assert.notNull(listener, "listener"));
}
@Override
public Listener<Connection> onClose(Consumer<Connection> listener) {
if (closed) {
listener.accept(this);
}
return closeListeners.add(Assert.notNull(listener, "listener"));
}
@Override
public CompletableFuture<Void> close() {
ThreadContext context = ThreadContext.currentContextOrThrow();
CompletableFuture<Void> future = new CompletableFuture<>();
if (writeFuture != null && !writeFuture.isDone()) {
writeFuture.addListener(channelFuture -> {
channel.close().addListener(closeFuture -> {
if (closeFuture.isSuccess()) {
context.executor().execute(() -> future.complete(null));
} else {
context.executor().execute(() -> future.completeExceptionally(closeFuture.cause()));
}
});
});
} else {
channel.close().addListener(closeFuture -> {
if (closeFuture.isSuccess()) {
context.executor().execute(() -> future.complete(null));
} else {
context.executor().execute(() -> future.completeExceptionally(closeFuture.cause()));
}
});
}
return future;
}
@Override
public int hashCode() {
return channel.hashCode();
}
@Override
public boolean equals(Object object) {
return object instanceof NettyConnection && ((NettyConnection) object).channel.equals(channel);
}
/**
* Holds message handler and thread context.
*/
protected static class HandlerHolder {
private final Function<Object, CompletableFuture<Object>> handler;
private final ThreadContext context;
@SuppressWarnings("unchecked")
private HandlerHolder(Function handler, ThreadContext context) {
this.handler = handler;
this.context = context;
}
}
/**
* Contextual future.
*/
private static class ContextualFuture<T> extends CompletableFuture<T> {
private final long time;
private final ThreadContext context;
private ContextualFuture(long time, ThreadContext context) {
this.time = time;
this.context = context;
}
}
}