/*
 *  Copyright 2010 Brian S O'Neill
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package org.cojen.dirmi.io;

import java.io.IOException;

import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketTimeoutException;

import java.util.concurrent.TimeUnit;

import java.security.AccessControlContext;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;

import javax.net.SocketFactory;

import org.cojen.dirmi.ClosedException;
import org.cojen.dirmi.RejectedException;
import org.cojen.dirmi.RemoteTimeoutException;

import org.cojen.dirmi.util.Timer;

/**
 * Implements a connector using TCP/IP.
 *
 * <p>Sockets are obtained from a {@link SocketFactory}, optionally bound to a
 * local address, connected to the remote address, and then handed to
 * {@link #createChannel} so that a subclass can wrap them in a {@link Channel}.
 * Every channel returned by this connector is registered with an internal
 * group which is closed when {@link #close} is called.
 *
 * @author Brian S O'Neill
 */
abstract class SocketChannelConnector implements ChannelConnector {
    // Runs the asynchronous connect attempts started by connect(Listener).
    private final IOExecutor mExecutor;
    private final SocketAddress mRemoteAddress;
    // May be null, in which case the socket is not explicitly bound.
    private final SocketAddress mLocalAddress;
    private final SocketFactory mFactory;
    // Captured when the connector is constructed; sockets are created and
    // connected under this context rather than that of the connect caller.
    private final AccessControlContext mContext;
    // Channels created by this connector are registered here.
    private final CloseableGroup<Channel> mConnected;

    /**
     * Creates a connector which binds to any local address and uses the
     * default socket factory.
     *
     * @param executor executor used for asynchronous connects
     * @param remoteAddress address to connect to
     * @throws IllegalArgumentException if executor or remoteAddress is null
     */
    public SocketChannelConnector(IOExecutor executor, SocketAddress remoteAddress) {
        this(executor, remoteAddress, null);
    }

    /**
     * Creates a connector which uses the default socket factory.
     *
     * @param executor executor used for asynchronous connects
     * @param remoteAddress address to connect to
     * @param localAddress local address to bind to; pass null for any
     * @throws IllegalArgumentException if executor or remoteAddress is null
     */
    public SocketChannelConnector(IOExecutor executor,
                                  SocketAddress remoteAddress,
                                  SocketAddress localAddress)
    {
        this(executor, remoteAddress, localAddress, SocketFactory.getDefault());
    }

    /**
     * @param executor executor used for asynchronous connects
     * @param remoteAddress address to connect to
     * @param localAddress local address to bind to; pass null for any
     * @param factory factory which creates the unconnected sockets
     * @throws IllegalArgumentException if executor, remoteAddress or factory
     * is null
     */
    public SocketChannelConnector(IOExecutor executor,
                                  SocketAddress remoteAddress,
                                  SocketAddress localAddress,
                                  SocketFactory factory)
    {
        if (executor == null) {
            throw new IllegalArgumentException("Must provide an executor");
        }
        if (remoteAddress == null) {
            throw new IllegalArgumentException("Must provide a remote address");
        }
        if (factory == null) {
            throw new IllegalArgumentException("Must provide a SocketFactory");
        }
        mExecutor = executor;
        mRemoteAddress = remoteAddress;
        mLocalAddress = localAddress;
        mFactory = factory;
        mContext = AccessController.getContext();
        mConnected = new CloseableGroup<Channel>();
    }

    /**
     * Connects without a timeout.
     */
    @Override
    public Channel connect() throws IOException {
        // A negative timeout selects the untimed connect path.
        return connect(-1, null);
    }

    /**
     * Connects, waiting at most the given amount of time for the socket
     * connection to be established.
     *
     * @param timeout maximum time to wait; negative means no timeout, and
     * zero fails immediately
     * @param unit unit of the timeout; not consulted when timeout is negative
     * @throws RemoteTimeoutException if the timeout is zero, rounds down to
     * zero milliseconds, or elapses before the socket connects
     * @throws IOException if the socket or channel cannot be established
     */
    @Override
    public Channel connect(final long timeout, final TimeUnit unit) throws IOException {
        // NOTE(review): presumably throws if this connector has been closed.
        mConnected.checkClosed();

        if (timeout == 0) {
            throw new RemoteTimeoutException(timeout, unit);
        }

        Socket socket;
        try {
            socket = AccessController.doPrivileged(new PrivilegedExceptionAction<Socket>() {
                public Socket run() throws IOException {
                    return connectSocket(timeout, unit);
                }
            }, mContext);
        } catch (PrivilegedActionException e) {
            // Give a closed connector precedence over the connect failure.
            mConnected.checkClosed();
            // The action only declares IOException, so the cause is one.
            throw (IOException) e.getCause();
        }

        // The connected socket is owned by this method until a channel has
        // been created for it, so it must be closed if wrapping fails.
        boolean created = false;
        try {
            Channel channel = createChannel(SocketChannel.toSimpleSocket(socket));
            created = true;
            channel.register(mConnected);
            return channel;
        } finally {
            if (!created) {
                disconnect(socket);
            }
        }
    }

    /**
     * Connects using whatever time remains on the given timer.
     */
    @Override
    public Channel connect(Timer timer) throws IOException {
        mConnected.checkClosed();
        return connect(RemoteTimeoutException.checkRemaining(timer), timer.unit());
    }

    /**
     * Connects asynchronously, using the executor given to the constructor,
     * and reports the outcome to the listener. The listener receives
     * {@code connected} on success, {@code closed} if this connector is found
     * to be closed, {@code failed} for any other I/O failure, and
     * {@code rejected} if the executor refuses the task.
     */
    @Override
    public void connect(final Listener listener) {
        try {
            mExecutor.execute(new Runnable() {
                public void run() {
                    if (mConnected.isClosed()) {
                        listener.closed(new ClosedException());
                        return;
                    }

                    Channel channel;
                    try {
                        channel = connect();
                    } catch (IOException e) {
                        // Tell a failure caused by closing the connector
                        // apart from a genuine connect failure.
                        if (mConnected.isClosed()) {
                            listener.closed(e);
                        } else {
                            listener.failed(e);
                        }
                        return;
                    }

                    listener.connected(channel);
                }
            });
        } catch (RejectedException e) {
            listener.rejected(e);
        }
    }

    /**
     * Creates a socket, binds it if a local address was provided, connects it
     * to the remote address and enables TCP_NODELAY. The socket is closed if
     * any step after its creation fails.
     *
     * @param timeout maximum time to wait for the connection; negative means
     * no timeout
     * @param unit unit of the timeout; not consulted when timeout is negative
     * @throws RemoteTimeoutException if the timeout is not negative and
     * converts to zero milliseconds or less, or if it elapses while connecting
     */
    Socket connectSocket(long timeout, TimeUnit unit) throws IOException {
        // Resolve the timeout before creating the socket, so that nothing is
        // allocated for a timeout which has effectively expired already.
        // Zero selects the untimed connect.
        int connectMillis;
        if (timeout < 0) {
            connectMillis = 0;
        } else {
            long millis = unit.toMillis(timeout);
            if (millis <= 0) {
                throw new RemoteTimeoutException(timeout, unit);
            }
            // Socket timeouts are ints; anything larger is treated as no
            // timeout at all.
            connectMillis = millis > Integer.MAX_VALUE ? 0 : (int) millis;
        }

        Socket socket = mFactory.createSocket();

        boolean connected = false;
        try {
            if (mLocalAddress != null) {
                socket.bind(mLocalAddress);
            }

            if (connectMillis == 0) {
                socket.connect(mRemoteAddress);
            } else {
                try {
                    socket.connect(mRemoteAddress, connectMillis);
                } catch (SocketTimeoutException e) {
                    throw new RemoteTimeoutException(timeout, unit);
                }
            }

            socket.setTcpNoDelay(true);

            connected = true;
            return socket;
        } finally {
            // Covers checked and unchecked failures alike.
            if (!connected) {
                disconnect(socket);
            }
        }
    }

    /**
     * Closes the group which the channels created by this connector are
     * registered with.
     */
    @Override
    public void close() {
        mConnected.close();
    }

    @Override
    public String toString() {
        return "ChannelConnector {localAddress=" + mLocalAddress +
            ", remoteAddress=" + mRemoteAddress + '}';
    }

    /**
     * Returns the address this connector connects to.
     */
    @Override
    public final SocketAddress getRemoteAddress() {
        return mRemoteAddress;
    }

    /**
     * Returns the local address sockets are bound to, or null for any.
     */
    @Override
    public final SocketAddress getLocalAddress() {
        return mLocalAddress;
    }

    /**
     * Returns the executor passed to the constructor.
     */
    protected IOExecutor executor() {
        return mExecutor;
    }

    /**
     * Wraps a connected socket in a channel.
     *
     * @param socket connected socket; it is closed by the caller if this
     * method throws
     */
    abstract Channel createChannel(SimpleSocket socket) throws IOException;

    /**
     * Closes the socket as best-effort cleanup on a failure path.
     */
    private static void disconnect(Socket socket) {
        try {
            socket.close();
        } catch (IOException e2) {
            // Ignore. The caller is already propagating the original failure.
        }
    }
}