// Copyright 2012 Citrix Systems, Inc. Licensed under the
// Apache License, Version 2.0 (the "License"); you may not use this
// file except in compliance with the License. Citrix Systems, Inc.
// reserves all rights not expressly granted by the License.
// You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Automatically generated by addcopyright.py at 04/03/2012
package com.cloud.utils.nio;
import java.io.IOException;
import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ClosedSelectorException;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;

import org.apache.log4j.Logger;

import com.cloud.utils.concurrency.NamedThreadFactory;
/**
 * NioConnection abstracts the NIO socket operations behind one selector thread
 * and a pool of worker threads.
 *
 * <p>The selector thread (named {@code <name>-Selector}) executes {@link #run()}.
 * It first calls the subclass's {@link #init()}, then dispatches accept, connect,
 * read and write readiness events. Work produced by those events is wrapped in
 * {@link Task}s created by the {@link HandlerFactory}. The tasks run on a thread
 * pool whose threads are named {@code <name>-Handler}. Accepted connections
 * complete an SSL handshake (see {@link #accept(SelectionKey)}) before they are
 * registered for reads.
 *
 * <p>Other threads request key changes by queueing {@link ChangeRequest}s through
 * {@link #register}, {@link #change} and {@link #close}. The selector thread
 * applies them in {@link #processTodos()}.
 */
public abstract class NioConnection implements Runnable {
    private static final Logger s_logger = Logger.getLogger(NioConnection.class);

    /** Selector polled by {@link #run()}; used right after {@link #init()}, so init() must set it. */
    protected Selector _selector;
    /** The selector thread; its monitor also coordinates start-up between start() and run(). */
    protected Thread _thread;
    /** Selector-loop flag; volatile because {@link #stop()} clears it from another thread. */
    protected volatile boolean _isRunning;
    /** True from a successful {@link #init()} until the selector loop exits; read by other threads. */
    protected volatile boolean _isStartup;
    protected int _port;
    /** Pending key changes; every access is guarded by {@code this}. */
    protected List<ChangeRequest> _todos;
    protected HandlerFactory _factory;
    protected String _name;
    protected ExecutorService _executor;
    /** Set once init() has returned or failed; guarded by the {@code _thread} monitor. */
    private boolean _initCompleted;

    /**
     * Creates the connection. No thread is started and no channel is opened until {@link #start()}.
     *
     * @param name    prefix for the selector and worker thread names
     * @param port    stored in {@code _port} (this class itself does not read it)
     * @param workers core size of the worker pool
     * @param factory creates the {@link Task}s executed on the worker pool
     */
    public NioConnection(String name, int port, int workers, HandlerFactory factory) {
        _name = name;
        _isRunning = false;
        _thread = null;
        _selector = null;
        _port = port;
        _factory = factory;
        // NOTE: with an unbounded LinkedBlockingQueue a ThreadPoolExecutor never grows past its core
        // size, so the 5 * workers maximum is effectively unused.
        _executor = new ThreadPoolExecutor(workers, 5 * workers, 1, TimeUnit.DAYS, new LinkedBlockingQueue<Runnable>(), new NamedThreadFactory(name + "-Handler"));
    }

    /**
     * Starts the selector thread and blocks until its {@link #init()} has finished, whether it
     * succeeded or failed. Use {@link #isStartup()} afterwards to check the outcome.
     */
    public void start() {
        _todos = new ArrayList<ChangeRequest>();
        _thread = new Thread(this, _name + "-Selector");
        _isRunning = true;
        _initCompleted = false; // published to the new thread by Thread.start()
        _thread.start();
        // Wait until init() is done. Waiting on a condition (not a bare wait()) avoids hanging when
        // run() notifies before we start waiting, and tolerates spurious wakeups.
        synchronized (_thread) {
            while (!_initCompleted) {
                try {
                    _thread.wait();
                } catch (InterruptedException e) {
                    s_logger.warn("Interrupted start thread ", e);
                    Thread.currentThread().interrupt(); // preserve the caller's interrupt status
                    return;
                }
            }
        }
    }

    /**
     * Shuts down the worker pool and asks the selector loop to exit.
     * The interrupt makes a blocked {@code select()} return so the loop re-checks {@code _isRunning}.
     */
    public void stop() {
        _executor.shutdown();
        _isRunning = false;
        if (_thread != null) {
            _thread.interrupt();
        }
    }

    /** @return whether the selector thread exists and is still alive; false before {@link #start()} */
    public boolean isRunning() {
        return _thread != null && _thread.isAlive();
    }

    /** @return true while the selector loop is active after a successful {@link #init()} */
    public boolean isStartup() {
        return _isStartup;
    }

    /**
     * Selector-thread body. It runs {@link #init()}, signals {@link #start()}, then dispatches
     * ready keys and queued change requests until {@link #stop()} is called.
     */
    @Override
    public void run() {
        boolean initialized = false;
        synchronized (_thread) {
            try {
                init();
                initialized = true;
                _isStartup = true;
            } catch (ConnectException e) {
                s_logger.error("Unable to connect to remote");
            } catch (IOException e) {
                s_logger.error("Unable to initialize the threads.", e);
            } catch (Exception e) {
                s_logger.error("Unable to initialize the threads due to unknown exception.", e);
            } finally {
                // Always release start(), including on failure or an Error thrown by init().
                _initCompleted = true;
                _thread.notifyAll();
            }
        }
        if (!initialized) {
            return;
        }
        while (_isRunning) {
            try {
                _selector.select();
                // Someone is ready for I/O, get the ready keys
                Set<SelectionKey> readyKeys = _selector.selectedKeys();
                Iterator<SelectionKey> i = readyKeys.iterator();
                if (s_logger.isTraceEnabled()) {
                    s_logger.trace("Keys Processing: " + readyKeys.size());
                }
                // Walk through the ready keys collection.
                while (i.hasNext()) {
                    SelectionKey sk = i.next();
                    i.remove();
                    if (!sk.isValid()) {
                        if (s_logger.isTraceEnabled()) {
                            s_logger.trace("Selection Key is invalid: " + sk.toString());
                        }
                        Link link = (Link)sk.attachment();
                        if (link != null) {
                            link.terminated();
                        } else {
                            closeConnection(sk);
                        }
                    } else if (sk.isReadable()) {
                        read(sk);
                    } else if (sk.isWritable()) {
                        write(sk);
                    } else if (sk.isAcceptable()) {
                        accept(sk);
                    } else if (sk.isConnectable()) {
                        connect(sk);
                    }
                }
                s_logger.trace("Keys Done Processing.");
                processTodos();
            } catch (ClosedSelectorException e) {
                // Every further select() would throw again; stop instead of spinning on the warning below.
                s_logger.warn("Selector has been closed; stopping the selector loop.", e);
                break;
            } catch (Throwable e) {
                s_logger.warn("Caught an exception but continuing on.", e);
            }
        }
        synchronized (_thread) {
            _isStartup = false;
        }
    }

    abstract void init() throws IOException;

    abstract void registerLink(InetSocketAddress saddr, Link link);

    abstract void unregisterLink(InetSocketAddress saddr);

    /**
     * Accepts a pending connection and performs the server-side SSL handshake in blocking mode.
     * It then registers the channel for reads and schedules a CONNECT task.
     * A failed handshake closes the channel silently (trace-logged).
     * If the channel cannot be handed to the selector, it is closed rather than leaked.
     */
    protected void accept(SelectionKey key) throws IOException {
        ServerSocketChannel serverSocketChannel = (ServerSocketChannel)key.channel();
        SocketChannel socketChannel = serverSocketChannel.accept();
        if (socketChannel == null) {
            // A selector-registered (non-blocking) server channel returns null if no connection is pending.
            return;
        }
        boolean handedOff = false;
        try {
            Socket socket = socketChannel.socket();
            socket.setKeepAlive(true);
            if (s_logger.isTraceEnabled()) {
                s_logger.trace("Connection accepted for " + socket);
            }
            // Begin SSL handshake in BLOCKING mode
            socketChannel.configureBlocking(true);
            SSLEngine sslEngine = null;
            try {
                SSLContext sslContext = Link.initSSLContext(false);
                sslEngine = sslContext.createSSLEngine();
                sslEngine.setUseClientMode(false);
                sslEngine.setNeedClientAuth(false);
                Link.doHandshake(socketChannel, sslEngine, false);
            } catch (Exception e) {
                if (s_logger.isTraceEnabled()) {
                    s_logger.trace("Socket " + socket + " closed on read. Probably -1 returned: " + e.getMessage());
                }
                return; // the finally block closes the channel
            }
            if (s_logger.isTraceEnabled()) {
                s_logger.trace("SSL: Handshake done");
            }
            socketChannel.configureBlocking(false);
            InetSocketAddress saddr = (InetSocketAddress)socket.getRemoteSocketAddress();
            Link link = new Link(saddr, this);
            link.setSSLEngine(sslEngine);
            link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link));
            handedOff = true; // the selector now owns the channel
            Task task = _factory.create(Task.Type.CONNECT, link, null);
            registerLink(saddr, link);
            _executor.execute(task);
        } finally {
            if (!handedOff) {
                closeQuietly(socketChannel);
            }
        }
    }

    /** Closes {@code channel} as best-effort cleanup, ignoring any IOException from close(). */
    private static void closeQuietly(SocketChannel channel) {
        try {
            channel.close();
        } catch (IOException ignore) {
            // nothing useful to do if closing fails
        }
    }

    /**
     * Closes the key's connection. If a {@link Link} is attached, it is marked terminated and
     * unregistered, and a DISCONNECT task is scheduled.
     */
    protected void terminate(SelectionKey key) {
        Link link = (Link)key.attachment();
        closeConnection(key);
        if (link != null) {
            link.terminated();
            Task task = _factory.create(Task.Type.DISCONNECT, link, null);
            unregisterLink(link.getSocketAddress());
            _executor.execute(task);
        }
    }

    /**
     * Reads from the key's channel through its attached {@link Link}. It schedules a DATA task when
     * {@code link.read} returns a non-null payload. Any exception terminates the connection.
     */
    protected void read(SelectionKey key) throws IOException {
        Link link = (Link)key.attachment();
        try {
            SocketChannel socketChannel = (SocketChannel)key.channel();
            if (s_logger.isTraceEnabled()) {
                s_logger.trace("Reading from: " + socketChannel.socket().toString());
            }
            byte[] data = link.read(socketChannel);
            if (data == null) {
                if (s_logger.isTraceEnabled()) {
                    s_logger.trace("Packet is incomplete. Waiting for more.");
                }
                return;
            }
            Task task = _factory.create(Task.Type.DATA, link, data);
            _executor.execute(task);
        } catch (Exception e) {
            logDebug(e, key, 1);
            terminate(key);
        }
    }

    /**
     * Trace-logs a closed socket at call-site location {@code loc}.
     * {@code e} is accepted for symmetry with logDebug but not logged.
     */
    protected void logTrace(Exception e, SelectionKey key, int loc) {
        if (s_logger.isTraceEnabled()) {
            Socket socket = null;
            if (key != null) {
                SocketChannel ch = (SocketChannel)key.channel();
                if (ch != null) {
                    socket = ch.socket();
                }
            }
            s_logger.trace("Location " + loc + ": Socket " + socket + " closed on read. Probably -1 returned.");
        }
    }

    /** Debug-logs a closed socket at call-site location {@code loc}, including {@code e}'s message. */
    protected void logDebug(Exception e, SelectionKey key, int loc) {
        if (s_logger.isDebugEnabled()) {
            Socket socket = null;
            if (key != null) {
                SocketChannel ch = (SocketChannel)key.channel();
                if (ch != null) {
                    socket = ch.socket();
                }
            }
            s_logger.debug("Location " + loc + ": Socket " + socket + " closed on read. Probably -1 returned: " + e.getMessage());
        }
    }

    /**
     * Applies all queued {@link ChangeRequest}s on the selector thread.
     * The queue is swapped out under the lock, and the requests are processed without holding it.
     */
    protected void processTodos() {
        List<ChangeRequest> todos;
        synchronized (this) {
            // Read the queue only under the lock that guards it, so requests added by other threads are seen.
            if (_todos.isEmpty()) {
                return; // Nothing to do.
            }
            todos = _todos;
            _todos = new ArrayList<ChangeRequest>();
        }
        if (s_logger.isTraceEnabled()) {
            s_logger.trace("Todos Processing: " + todos.size());
        }
        SelectionKey key;
        for (ChangeRequest todo : todos) {
            switch (todo.type) {
                case ChangeRequest.CHANGEOPS :
                    try {
                        key = (SelectionKey)todo.key;
                        if (key != null && key.isValid()) {
                            if (todo.att != null) {
                                key.attach(todo.att);
                                Link link = (Link)todo.att;
                                link.setKey(key);
                            }
                            key.interestOps(todo.ops);
                        }
                    } catch (CancelledKeyException e) {
                        s_logger.debug("key has been cancelled");
                    }
                    break;
                case ChangeRequest.REGISTER :
                    try {
                        key = ((SocketChannel)(todo.key)).register(_selector, todo.ops, todo.att);
                        if (todo.att != null) {
                            Link link = (Link)todo.att;
                            link.setKey(key);
                        }
                    } catch (ClosedChannelException e) {
                        s_logger.warn("Couldn't register socket: " + todo.key);
                        try {
                            ((SocketChannel)(todo.key)).close();
                        } catch (IOException ignore) {
                        } finally {
                            // The attachment is optional (see the success path above).
                            if (todo.att != null) {
                                Link link = (Link)todo.att;
                                link.terminated();
                            }
                        }
                    }
                    break;
                case ChangeRequest.CLOSE :
                    if (s_logger.isTraceEnabled()) {
                        s_logger.trace("Trying to close " + todo.key);
                    }
                    key = (SelectionKey)todo.key;
                    closeConnection(key);
                    if (key != null) {
                        Link link = (Link)key.attachment();
                        if (link != null) {
                            link.terminated();
                        }
                    }
                    break;
                default :
                    s_logger.warn("Shouldn't be here");
                    throw new RuntimeException("Shouldn't be here");
            }
        }
        s_logger.trace("Todos Done processing");
    }

    /**
     * Completes an outgoing connection and switches the key to OP_READ. It attaches a new
     * {@link Link} and schedules a CONNECT task. An IOException terminates the connection.
     */
    protected void connect(SelectionKey key) throws IOException {
        SocketChannel socketChannel = (SocketChannel)key.channel();
        try {
            if (!socketChannel.finishConnect()) {
                // Connection still pending: keep the current interest set and retry on the next event.
                return;
            }
            key.interestOps(SelectionKey.OP_READ);
            Socket socket = socketChannel.socket();
            if (!socket.getKeepAlive()) {
                socket.setKeepAlive(true);
            }
            if (s_logger.isDebugEnabled()) {
                s_logger.debug("Connected to " + socket);
            }
            Link link = new Link((InetSocketAddress)socket.getRemoteSocketAddress(), this);
            link.setKey(key);
            key.attach(link);
            Task task = _factory.create(Task.Type.CONNECT, link, null);
            _executor.execute(task);
        } catch (IOException e) {
            logTrace(e, key, 2);
            terminate(key);
        }
    }

    /** Runs {@code task} on the worker pool. */
    protected void scheduleTask(Task task) {
        _executor.execute(task);
    }

    /**
     * Flushes the attached {@link Link}'s pending output. If {@code link.write} returns true, the
     * connection is closed. Otherwise the key goes back to OP_READ. Any exception terminates the
     * connection.
     */
    protected void write(SelectionKey key) throws IOException {
        Link link = (Link)key.attachment();
        try {
            if (s_logger.isTraceEnabled()) {
                s_logger.trace("Writing to " + link.getSocketAddress().toString());
            }
            boolean close = link.write((SocketChannel)key.channel());
            if (close) {
                closeConnection(key);
                link.terminated();
            } else {
                key.interestOps(SelectionKey.OP_READ);
            }
        } catch (Exception e) {
            logDebug(e, key, 3);
            terminate(key);
        }
    }

    /**
     * Cancels {@code key} and closes its channel, ignoring close failures. A null key is ignored.
     * Works for any channel type; the invalid-key path in run() may pass a server-socket key.
     */
    protected void closeConnection(SelectionKey key) {
        if (key != null) {
            SelectableChannel channel = key.channel();
            key.cancel();
            try {
                if (channel != null) {
                    if (s_logger.isDebugEnabled()) {
                        Object target = (channel instanceof SocketChannel) ? ((SocketChannel)channel).socket() : channel;
                        s_logger.debug("Closing socket " + target);
                    }
                    channel.close();
                }
            } catch (IOException ignore) {
            }
        }
    }

    /** Queues registration of {@code key} with the selector for {@code ops}, attaching {@code att}. */
    public void register(int ops, SocketChannel key, Object att) {
        ChangeRequest todo = new ChangeRequest(key, ChangeRequest.REGISTER, ops, att);
        synchronized (this) {
            _todos.add(todo);
        }
        _selector.wakeup();
    }

    /** Queues an interest-ops change for {@code key}; a non-null {@code att} replaces the attachment. */
    public void change(int ops, SelectionKey key, Object att) {
        ChangeRequest todo = new ChangeRequest(key, ChangeRequest.CHANGEOPS, ops, att);
        synchronized (this) {
            _todos.add(todo);
        }
        _selector.wakeup();
    }

    /** Queues closing of {@code key}'s connection on the selector thread. */
    public void close(SelectionKey key) {
        ChangeRequest todo = new ChangeRequest(key, ChangeRequest.CLOSE, 0, null);
        synchronized (this) {
            _todos.add(todo);
        }
        _selector.wakeup();
    }

    /* Release the resource used by the instance */
    public void cleanUp() throws IOException {
        if (_selector != null) {
            _selector.close();
        }
    }

    /**
     * A key change queued by another thread and applied by {@link #processTodos()}.
     * {@code key} is a {@link SocketChannel} for REGISTER and a {@link SelectionKey} for CHANGEOPS/CLOSE.
     */
    public class ChangeRequest {
        public static final int REGISTER = 1;
        public static final int CHANGEOPS = 2;
        public static final int CLOSE = 3;
        public Object key;
        public int type;
        public int ops;
        public Object att;

        public ChangeRequest(Object key, int type, int ops, Object att) {
            this.key = key;
            this.type = type;
            this.ops = ops;
            this.att = att;
        }
    }
}