/*
* JBoss, Home of Professional Open Source.
* See the COPYRIGHT.txt file distributed with this work for information
* regarding copyright ownership. Some portions may be licensed
* to Red Hat, Inc. under one or more contributor license agreements.
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
* 02110-1301 USA.
*/
/**
*
*/
package org.teiid.transport;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.GenericFutureListener;
import java.io.IOException;
import java.net.InetAddress;
import java.net.SocketAddress;
import java.util.Map;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.teiid.logging.LogConstants;
import org.teiid.logging.LogManager;
import org.teiid.logging.MessageLevel;
import org.teiid.net.socket.ObjectChannel;
import org.teiid.runtime.RuntimePlugin;
/**
* Main class for creating Netty Nio Channels
*/
@Sharable
public class SSLAwareChannelHandler extends ChannelDuplexHandler {
public class ObjectChannelImpl implements ObjectChannel {
private final Channel channel;
public ObjectChannelImpl(Channel channel) {
this.channel = channel;
}
public void close() {
channel.close();
}
public boolean isOpen() {
return channel.isOpen();
}
public SocketAddress getRemoteAddress() {
return channel.remoteAddress();
}
@Override
public InetAddress getLocalAddress() {
throw new UnsupportedOperationException();
}
@Override
public Object read() throws IOException,
ClassNotFoundException {
throw new UnsupportedOperationException();
}
public Future<?> write(Object msg) {
final ChannelFuture future = channel.write(msg);
channel.flush();
future.addListener(completionListener);
return new Future<Void>() {
@Override
public boolean cancel(boolean arg0) {
return future.cancel(arg0);
}
@Override
public Void get() throws InterruptedException,
ExecutionException {
future.await();
if (!future.isSuccess()) {
throw new ExecutionException(future.cause());
}
return null;
}
@Override
public Void get(long arg0, TimeUnit arg1)
throws InterruptedException, ExecutionException,
TimeoutException {
if (future.await(arg0, arg1)) {
if (!future.isSuccess()) {
throw new ExecutionException(future.cause());
}
return null;
}
throw new TimeoutException();
}
@Override
public boolean isCancelled() {
return future.isCancelled();
}
@Override
public boolean isDone() {
return future.isDone();
}
};
}
}
private final ChannelListener.ChannelListenerFactory listenerFactory;
private Map<Channel, ChannelListener> listeners = new ConcurrentHashMap<Channel, ChannelListener>();
private AtomicLong objectsRead = new AtomicLong(0);
private AtomicLong objectsWritten = new AtomicLong(0);
private volatile int maxChannels;
private ChannelFutureListener completionListener = new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture arg0)
throws Exception {
if (arg0.isSuccess()) {
objectsWritten.getAndIncrement();
} else if (arg0.cause() != null) {
writeExceptionCaught(arg0.channel(), arg0.cause());
}
}
};
public SSLAwareChannelHandler(ChannelListener.ChannelListenerFactory listenerFactory) {
this.listenerFactory = listenerFactory;
}
@Override
public void channelActive(final ChannelHandlerContext ctx) throws Exception {
ChannelListener listener = this.listenerFactory.createChannelListener(new ObjectChannelImpl(ctx.channel()));
this.listeners.put(ctx.channel(), listener);
maxChannels = Math.max(maxChannels, this.listeners.size());
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
if (sslHandler != null) {
sslHandler.handshakeFuture().addListener(new GenericFutureListener<DefaultPromise<Channel>>() {
@Override
public void operationComplete(DefaultPromise<Channel> future)
throws Exception {
onConnection(ctx.channel());
}
});
} else {
onConnection(ctx.channel());
}
}
private void onConnection(Channel channel) throws Exception {
ChannelListener listener = this.listeners.get(channel);
if (listener != null) {
listener.onConnection();
}
}
private void writeExceptionCaught(Channel channel,
Throwable cause) {
ChannelListener listener = this.listeners.get(channel);
if (listener != null) {
listener.exceptionOccurred(cause);
} else {
int level = SocketClientInstance.getLevel(cause);
LogManager.log(level, LogConstants.CTX_TRANSPORT, LogManager.isMessageToBeRecorded(LogConstants.CTX_TRANSPORT, MessageLevel.DETAIL)||level<MessageLevel.WARNING?cause:null, RuntimePlugin.Util.gs(RuntimePlugin.Event.TEIID40114, cause.getMessage()));
channel.close();
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx,
Throwable cause) throws Exception {
writeExceptionCaught(ctx.channel(), cause);
}
public void messageReceived(ChannelHandlerContext ctx,
Object msg) throws Exception {
objectsRead.getAndIncrement();
ChannelListener listener = this.listeners.get(ctx.channel());
if (listener != null) {
listener.receivedMessage(msg);
}
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
messageReceived(ctx, msg);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
ChannelListener listener = this.listeners.remove(ctx.channel());
if (listener != null) {
LogManager.logDetail(LogConstants.CTX_TRANSPORT,
RuntimePlugin.Util.getString("SSLAwareChannelHandler.channel_closed")); //$NON-NLS-1$
listener.disconnected();
}
}
public long getObjectsRead() {
return this.objectsRead.get();
}
public long getObjectsWritten() {
return this.objectsWritten.get();
}
public int getConnectedChannels() {
return this.listeners.size();
}
public int getMaxConnectedChannels() {
return this.maxChannels;
}
}