/*
* Copyright 2010 Brian S O'Neill
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.cojen.dirmi.io;
import java.io.InterruptedIOException;
import java.io.IOException;
import java.net.SocketAddress;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ClosedSelectorException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.AccessControlContext;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Iterator;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import org.cojen.dirmi.ClosedException;
import org.cojen.dirmi.RejectedException;
import org.cojen.dirmi.RemoteTimeoutException;
import org.cojen.dirmi.util.ScheduledTask;
import org.cojen.dirmi.util.Timer;
/**
* Factory for TCP channel acceptors and connectors that use selectable
* sockets.
*
* @author Brian S O'Neill
*/
public class RecyclableSocketChannelSelector implements SocketChannelSelector {
private final IOExecutor mExecutor;
private final Selector mSelector;
private final ConcurrentLinkedQueue<Selectable> mQueue;
/**
 * Creates a selector which hands ready notifications to the given executor,
 * backed by a newly opened {@link Selector}.
 *
 * @param executor executor which runs listener notifications
 * @throws IllegalArgumentException if executor is null
 * @throws IOException if a selector cannot be opened
 */
public RecyclableSocketChannelSelector(IOExecutor executor) throws IOException {
    // Arguments are evaluated left to right, so the executor is validated
    // before a selector is opened. Validating afterwards would leave the
    // freshly opened selector unclosed when the executor is null.
    this(requireExecutor(executor), Selector.open());
}

/**
 * Returns the given executor, rejecting null.
 *
 * @throws IllegalArgumentException if executor is null
 */
private static IOExecutor requireExecutor(IOExecutor executor) {
    if (executor == null) {
        throw new IllegalArgumentException();
    }
    return executor;
}
/**
 * Binds this instance to an executor and an already opened selector, and
 * creates the queue of pending registrations.
 *
 * @throws IllegalArgumentException if either argument is null
 */
private RecyclableSocketChannelSelector(IOExecutor executor, Selector selector) {
    if (executor == null) {
        throw new IllegalArgumentException();
    }
    if (selector == null) {
        throw new IllegalArgumentException();
    }
    mQueue = new ConcurrentLinkedQueue<Selectable>();
    mSelector = selector;
    mExecutor = executor;
}
/**
 * Runs the selection loop on the calling thread. Each pass blocks in
 * {@link Selector#select()}, registers any queued notification requests and
 * hands the attachment of every selected key to the executor. Returns
 * normally only when the selector is closed.
 *
 * @throws IOException if selection fails
 */
public void selectLoop() throws IOException {
    final IOExecutor executor = mExecutor;
    final Selector selector = mSelector;
    final ConcurrentLinkedQueue<Selectable> pending = mQueue;

    try {
        for (;;) {
            final int readyCount = selector.select();

            // Register everything queued by the notify methods, which wake
            // up the selector after queueing a request.
            boolean registeredAny = false;
            for (Selectable request = pending.poll();
                 request != null;
                 request = pending.poll())
            {
                registeredAny = true;
                request.register(selector);
            }

            if (readyCount != 0) {
                for (SelectionKey key : selector.selectedKeys()) {
                    Selectable task = (Selectable) key.attachment();
                    try {
                        executor.execute(task);
                    } catch (RejectedException rejection) {
                        // Try scheduling the task instead; if that is also
                        // rejected, report the first rejection to the task.
                        try {
                            executor.schedule(task, 0, TimeUnit.SECONDS);
                        } catch (RejectedException ignored) {
                            task.rejected(rejection);
                        }
                    } finally {
                        // The key is always cancelled once its task has
                        // been dispatched (or rejected).
                        key.cancel();
                    }
                }
                continue;
            }

            if (!selector.isOpen()) {
                return;
            }

            if (registeredAny) {
                continue;
            }

            // Workaround for unknown race condition in which closed
            // channels are not removed from the selector. If they remain,
            // select no longer blocks.
            for (SelectionKey key : selector.keys()) {
                if (key.isValid() && !key.channel().isOpen()) {
                    key.cancel();
                }
            }
        }
    } catch (ClosedSelectorException e) {
        // Selector was closed; return normally.
    }
}
/**
 * Closes the underlying selector.
 *
 * @throws IOException if closing the selector fails
 */
public void close() throws IOException {
    final Selector selector = mSelector;
    selector.close();
}
/**
 * Creates an acceptor bound to the given local address.
 *
 * @param localAddress address to bind the server socket to
 * @throws IOException if the server socket cannot be opened or bound
 */
public ChannelAcceptor newChannelAcceptor(SocketAddress localAddress) throws IOException {
    ChannelAcceptor acceptor = new NioChannelAcceptor(localAddress);
    return acceptor;
}
/**
 * Creates a connector for the given remote address, without a local
 * address to bind to.
 *
 * @param remoteAddress address to connect to
 */
public ChannelConnector newChannelConnector(SocketAddress remoteAddress) {
    final SocketAddress noLocalAddress = null;
    return newChannelConnector(remoteAddress, noLocalAddress);
}
/**
 * Creates a connector for the given remote address.
 *
 * @param remoteAddress address to connect to
 * @param localAddress optional local address to bind to; may be null
 */
public ChannelConnector newChannelConnector(SocketAddress remoteAddress,
                                            SocketAddress localAddress)
{
    ChannelConnector connector = new NioChannelConnector(remoteAddress, localAddress);
    return connector;
}
/**
 * Queues a request to notify the listener asynchronously once the channel
 * has been connected, then wakes up the selector so that the request gets
 * registered. Listener is called at most once per registration.
 */
void connectNotify(SocketChannel channel, CloseableGroup<Channel> connected,
                   ChannelConnector.Listener listener)
{
    Selectable request = new ConnectNotify(channel, connected, listener);
    mQueue.add(request);
    mSelector.wakeup();
}
/**
 * Queues a request to notify the listener asynchronously once a channel
 * has been accepted, then wakes up the selector so that the request gets
 * registered. Listener is called at most once per registration.
 */
void acceptNotify(AccessControlContext context, CloseableGroup<Channel> accepted,
                  ServerSocketChannel channel, ChannelAcceptor.Listener listener)
{
    Selectable request = new AcceptNotify(context, accepted, channel, listener);
    mQueue.add(request);
    mSelector.wakeup();
}
/**
 * Queues a request to notify the listener asynchronously once the channel
 * can be read from, then wakes up the selector so that the request gets
 * registered. Listener is called at most once per registration.
 */
public void inputNotify(SocketChannel channel, Channel.Listener listener) {
    final int readInterest = SelectionKey.OP_READ;
    mQueue.add(new ChannelNotify(channel, listener, readInterest));
    mSelector.wakeup();
}
/**
 * Queues a request to notify the listener asynchronously once the channel
 * can be written to, then wakes up the selector so that the request gets
 * registered. Listener is called at most once per registration.
 */
public void outputNotify(SocketChannel channel, Channel.Listener listener) {
    final int writeInterest = SelectionKey.OP_WRITE;
    mQueue.add(new ChannelNotify(channel, listener, writeInterest));
    mSelector.wakeup();
}
/**
 * Returns the executor which was supplied at construction.
 */
public IOExecutor executor() {
    return this.mExecutor;
}
/**
 * Task which knows how to register its channel with a selector. The select
 * loop hands it to the executor once its key has been selected.
 */
private abstract static class Selectable extends ScheduledTask<RuntimeException> {
    /**
     * Registers the channel of this task with the given selector.
     */
    abstract void register(Selector selector);

    /**
     * Called by the select loop when the executor refuses to run this task.
     *
     * @param cause rejection reported by the executor
     */
    abstract void rejected(RejectedException cause);
}
private static class ChannelNotify extends Selectable {
private final SocketChannel mChannel;
private final Channel.Listener mListener;
private final int mOps;
ChannelNotify(SocketChannel channel, Channel.Listener listener, int ops) {
mChannel = channel;
mListener = listener;
mOps = ops;
}
void register(Selector selector) {
try {
mChannel.register(selector, mOps, this);
} catch (ClosedChannelException e) {
mListener.closed(e);
} catch (RuntimeException e) {
try {
mChannel.close();
} catch (IOException e2) {
// Ignore.
}
mListener.closed(new IOException(e));
}
}
void rejected(RejectedException cause) {
mListener.rejected(cause);
}
protected void doRun() {
mListener.ready();
}
}
private class ConnectNotify extends Selectable {
private final CloseableGroup<Channel> mConnected;
private final SocketChannel mChannel;
private final ChannelConnector.Listener mListener;
ConnectNotify(SocketChannel channel, CloseableGroup<Channel> connected,
ChannelConnector.Listener listener)
{
mConnected = connected;
mChannel = channel;
mListener = listener;
}
void register(Selector selector) {
try {
mChannel.register(selector, SelectionKey.OP_CONNECT, this);
} catch (ClosedChannelException e) {
mListener.failed(e);
} catch (RuntimeException e) {
try {
mChannel.close();
} catch (IOException e2) {
// Ignore.
}
mListener.failed(new IOException(e));
}
}
void rejected(RejectedException cause) {
mListener.rejected(cause);
}
protected void doRun() {
try {
mChannel.finishConnect();
NioSocketChannel nsc =
new NioSocketChannel(RecyclableSocketChannelSelector.this, mChannel);
NioRecyclableSocketChannel nrsc = new NioRecyclableSocketChannel(executor(), nsc);
nrsc.register(mConnected);
mListener.connected(nrsc);
} catch (IOException e) {
mListener.failed(e);
}
}
}
private class AcceptNotify extends Selectable {
private final AccessControlContext mContext;
private final CloseableGroup<Channel> mAccepted;
final ServerSocketChannel mChannel;
private final ChannelAcceptor.Listener mListener;
AcceptNotify(AccessControlContext context, CloseableGroup<Channel> accepted,
ServerSocketChannel channel, ChannelAcceptor.Listener listener)
{
mContext = context;
mAccepted = accepted;
mChannel = channel;
mListener = listener;
}
void register(Selector selector) {
try {
mChannel.register(selector, SelectionKey.OP_ACCEPT, this);
} catch (ClosedChannelException e) {
mListener.closed(e);
} catch (RuntimeException e) {
try {
mChannel.close();
} catch (IOException e2) {
// Ignore.
}
mListener.closed(new IOException(e));
}
}
void rejected(RejectedException cause) {
mListener.rejected(cause);
}
protected void doRun() {
SocketChannel channel;
try {
try {
channel = AccessController
.doPrivileged(new PrivilegedExceptionAction<SocketChannel>() {
public SocketChannel run() throws IOException {
return mChannel.accept();
}
}, mContext);
channel.configureBlocking(false);
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
} catch (SecurityException e) {
mListener.failed(new IOException(e));
return;
} catch (Exception e) {
try {
mChannel.close();
} catch (IOException e2) {
// Ignore.
}
mListener.closed(e instanceof IOException ? (IOException) e : new IOException(e));
return;
}
try {
NioSocketChannel nsc =
new NioSocketChannel(RecyclableSocketChannelSelector.this, channel);
NioRecyclableSocketChannel nrsc = new NioRecyclableSocketChannel(executor(), nsc);
nrsc.register(mAccepted);
mListener.accepted(nrsc);
} catch (IOException e) {
mListener.failed(e);
}
}
}
/**
 * Converts a timeout value into a timer.
 *
 * @param timeout timeout duration; negative means no timer
 * @param unit unit of the timeout; not used when timeout is negative or zero
 * @return null if timeout is negative, a zero-duration timer if timeout is
 * zero, otherwise a timer of the given duration and unit
 */
static Timer toTimer(long timeout, TimeUnit unit) {
    if (timeout < 0) {
        return null;
    }
    if (timeout == 0) {
        return new Timer(0, TimeUnit.NANOSECONDS);
    }
    return new Timer(timeout, unit);
}
private class NioChannelAcceptor implements ChannelAcceptor {
private final SocketAddress mLocalAddress;
private final ServerSocketChannel mChannel;
private final AccessControlContext mContext;
private final CloseableGroup<Channel> mAccepted;
final ConcurrentLinkedQueue<Channel> mAcceptQueue;
NioChannelAcceptor(SocketAddress localAddress) throws IOException {
ServerSocketChannel ssc = ServerSocketChannel.open();
ssc.configureBlocking(false);
ssc.socket().setReuseAddress(true);
ssc.socket().bind(localAddress, SocketChannelAcceptor.LISTEN_BACKLOG);
mLocalAddress = ssc.socket().getLocalSocketAddress();
mChannel = ssc;
mContext = AccessController.getContext();
mAccepted = new CloseableGroup<Channel>();
mAcceptQueue = new ConcurrentLinkedQueue<Channel>();
}
@Override
public Channel accept() throws IOException {
return accept(-1, null);
}
@Override
public Channel accept(long timeout, TimeUnit unit) throws IOException {
return accept(toTimer(timeout, unit));
}
@Override
public Channel accept(Timer timer) throws IOException {
mAccepted.checkClosed();
class Listener implements ChannelAcceptor.Listener {
private Channel mChannel;
private IOException mException;
private boolean mAbandoned;
@Override
public synchronized void accepted(Channel channel) {
if (mAbandoned) {
mAcceptQueue.add(channel);
} else {
mChannel = channel;
notify();
}
}
@Override
public synchronized void rejected(RejectedException cause) {
mException = cause;
notify();
}
@Override
public synchronized void failed(IOException cause) {
mException = cause;
notify();
}
@Override
public synchronized void closed(IOException cause) {
mException = cause;
notify();
}
synchronized Channel waitForChannel(Timer timer) throws IOException {
while (true) {
if (mChannel != null) {
return mChannel;
}
if (mException != null) {
if (mException.getCause() instanceof SecurityException) {
throw (SecurityException) mException.getCause();
}
throw mException;
}
try {
try {
if (timer == null) {
wait();
} else {
long remaining = RemoteTimeoutException.checkRemaining(timer);
wait(timer.unit().toMillis(remaining));
}
} catch (InterruptedException e) {
throw new InterruptedIOException();
}
} catch (IOException e) {
mAbandoned = true;
throw e;
}
}
}
};
Listener listener = new Listener();
accept(listener);
return listener.waitForChannel(timer);
}
@Override
public void accept(final Listener listener) {
Channel channel = mAcceptQueue.poll();
if (channel != null) {
// FIXME: race conditions cause channel to be lost
// FIXME: separate thread for listener call
listener.accepted(channel);
}
acceptNotify(mContext, mAccepted, mChannel, new ChannelAcceptor.Listener() {
@Override
public void accepted(Channel channel) {
if (acceptedChannel(channel)) {
listener.accepted(channel);
} else {
listener.closed(new ClosedException());
}
}
@Override
public void rejected(RejectedException cause) {
listener.rejected(cause);
}
@Override
public void failed(IOException cause) {
listener.failed(cause);
}
@Override
public void closed(IOException cause) {
listener.closed(cause);
}
});
}
@Override
public void close() {
mAccepted.close();
try {
mChannel.close();
} catch (IOException e) {
// Ignore.
}
}
@Override
public Object getLocalAddress() {
return mLocalAddress;
}
@Override
public String toString() {
return "ChannelAcceptor {localAddress=" + mLocalAddress + '}';
}
boolean acceptedChannel(Channel channel) {
if (mAccepted.isClosed()) {
channel.disconnect();
return false;
}
return true;
}
}
/**
 * Connector which opens a non-blocking socket channel per connect request.
 * Connect notifications are delivered through the enclosing selector.
 */
private class NioChannelConnector implements ChannelConnector {
    final SocketAddress mRemoteAddress;
    final SocketAddress mLocalAddress;
    private final AccessControlContext mContext;
    private final CloseableGroup<Channel> mConnected;

    /**
     * @param remoteAddress address to connect to
     * @param localAddress optional local address to bind to; may be null
     * @throws IllegalArgumentException if remoteAddress is null
     */
    NioChannelConnector(SocketAddress remoteAddress, SocketAddress localAddress) {
        if (remoteAddress == null) {
            throw new IllegalArgumentException("Must provide a remote address");
        }
        mRemoteAddress = remoteAddress;
        mLocalAddress = localAddress;
        mContext = AccessController.getContext();
        mConnected = new CloseableGroup<Channel>();
    }

    @Override
    public Object getRemoteAddress() {
        return mRemoteAddress;
    }

    @Override
    public Object getLocalAddress() {
        return mLocalAddress;
    }

    /**
     * Connects, waiting indefinitely.
     */
    @Override
    public Channel connect() throws IOException {
        return connect(-1, null);
    }

    /**
     * Connects, waiting up to the given timeout. A negative timeout waits
     * indefinitely.
     */
    @Override
    public Channel connect(long timeout, TimeUnit unit) throws IOException {
        mConnected.checkClosed();
        ChannelConnectWaiter waiter = new ChannelConnectWaiter();
        connect(waiter);
        if (timeout < 0) {
            return waiter.waitForChannel();
        } else {
            return waiter.waitForChannel(timeout, unit);
        }
    }

    /**
     * Connects, waiting for at most the time remaining on the given timer.
     *
     * @param timer limits the wait; null waits indefinitely, matching the
     * behavior of the acceptor's timer based accept
     * @throws IOException if connecting fails or the timer check fails
     */
    @Override
    public Channel connect(Timer timer) throws IOException {
        mConnected.checkClosed();
        if (timer == null) {
            return connect(-1, null);
        }
        // Use the time remaining rather than the full duration, so that
        // time already spent against the timer is not granted again.
        return connect(RemoteTimeoutException.checkRemaining(timer), timer.unit());
    }

    /**
     * Opens a non-blocking socket channel, starts connecting it and asks
     * the selector to notify the listener once the connection completes.
     * An IOException while opening, binding or connecting is reported to
     * the listener as failed.
     */
    @Override
    public void connect(Listener listener) {
        SocketChannel sc;
        try {
            try {
                sc = AccessController
                    .doPrivileged(new PrivilegedExceptionAction<SocketChannel>() {
                        public SocketChannel run() throws IOException {
                            SocketChannel opened = SocketChannel.open();
                            boolean started = false;
                            try {
                                opened.configureBlocking(false);
                                if (mLocalAddress != null) {
                                    opened.socket().bind(mLocalAddress);
                                }
                                opened.connect(mRemoteAddress);
                                started = true;
                                return opened;
                            } finally {
                                if (!started) {
                                    // Do not leak the open channel when
                                    // setup fails, whatever the exception.
                                    try {
                                        opened.close();
                                    } catch (IOException e) {
                                        // Ignore; the original failure is
                                        // propagated.
                                    }
                                }
                            }
                        }
                    }, mContext);
            } catch (PrivilegedActionException e) {
                throw (IOException) e.getCause();
            }
        } catch (IOException e) {
            listener.failed(e);
            return;
        }

        connectNotify(sc, mConnected, listener);
    }

    /**
     * Closes the group of connected channels.
     */
    @Override
    public void close() {
        mConnected.close();
    }

    @Override
    public String toString() {
        return "ChannelConnector {localAddress=" + mLocalAddress +
            ", remoteAddress=" + mRemoteAddress + '}';
    }
}
}