/**
* Copyright (C) 2012 FuseSource, Inc.
* http://fusesource.com
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.fusesource.hawtdispatch.transport;
import org.fusesource.hawtdispatch.*;
import java.io.IOException;
import java.net.*;
import java.nio.channels.DatagramChannel;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.WritableByteChannel;
import java.util.LinkedList;
import java.util.concurrent.Executor;
/**
* <p>
* </p>
*
* @author <a href="http://hiramchirino.com">Hiram Chirino</a>
*/
public class UdpTransport extends ServiceBase implements Transport {
public static final SocketAddress ANY_ADDRESS = new SocketAddress() {
@Override
public String toString() {
return "*:*";
}
};
abstract static class SocketState {
void onStop(Task onCompleted) {
}
void onCanceled() {
}
boolean is(Class<? extends SocketState> clazz) {
return getClass()==clazz;
}
}
static class DISCONNECTED extends SocketState{}
class CONNECTING extends SocketState{
void onStop(Task onCompleted) {
trace("CONNECTING.onStop");
CANCELING state = new CANCELING();
socketState = state;
state.onStop(onCompleted);
}
void onCanceled() {
trace("CONNECTING.onCanceled");
CANCELING state = new CANCELING();
socketState = state;
state.onCanceled();
}
}
class CONNECTED extends SocketState {
public CONNECTED() {
localAddress = channel.socket().getLocalSocketAddress();
remoteAddress = channel.socket().getRemoteSocketAddress();
if(remoteAddress == null ) {
remoteAddress = ANY_ADDRESS;
}
}
void onStop(Task onCompleted) {
trace("CONNECTED.onStop");
CANCELING state = new CANCELING();
socketState = state;
state.add(createDisconnectTask());
state.onStop(onCompleted);
}
void onCanceled() {
trace("CONNECTED.onCanceled");
CANCELING state = new CANCELING();
socketState = state;
state.add(createDisconnectTask());
state.onCanceled();
}
Task createDisconnectTask() {
return new Task(){
public void run() {
listener.onTransportDisconnected();
}
};
}
}
class CANCELING extends SocketState {
private LinkedList<Task> runnables = new LinkedList<Task>();
private int remaining;
private boolean dispose;
public CANCELING() {
if( readSource!=null ) {
remaining++;
readSource.cancel();
}
if( writeSource!=null ) {
remaining++;
writeSource.cancel();
}
}
void onStop(Task onCompleted) {
trace("CANCELING.onCompleted");
add(onCompleted);
dispose = true;
}
void add(Task onCompleted) {
if( onCompleted!=null ) {
runnables.add(onCompleted);
}
}
void onCanceled() {
trace("CANCELING.onCanceled");
remaining--;
if( remaining!=0 ) {
return;
}
try {
channel.close();
} catch (IOException ignore) {
}
socketState = new CANCELED(dispose);
for (Task runnable : runnables) {
runnable.run();
}
if (dispose) {
dispose();
}
}
}
class CANCELED extends SocketState {
private boolean disposed;
public CANCELED(boolean disposed) {
this.disposed=disposed;
}
void onStop(Task onCompleted) {
trace("CANCELED.onStop");
if( !disposed ) {
disposed = true;
dispose();
}
onCompleted.run();
}
}
protected URI remoteLocation;
protected URI localLocation;
protected TransportListener listener;
protected ProtocolCodec codec;
protected DatagramChannel channel;
protected SocketState socketState = new DISCONNECTED();
protected DispatchQueue dispatchQueue;
private DispatchSource readSource;
private DispatchSource writeSource;
protected CustomDispatchSource<Integer, Integer> drainOutboundSource;
protected CustomDispatchSource<Integer, Integer> yieldSource;
protected boolean useLocalHost = true;
int receiveBufferSize = 1024*64;
int sendBufferSize = 1024*64;
public static final int IPTOS_LOWCOST = 0x02;
public static final int IPTOS_RELIABILITY = 0x04;
public static final int IPTOS_THROUGHPUT = 0x08;
public static final int IPTOS_LOWDELAY = 0x10;
int trafficClass = IPTOS_THROUGHPUT;
SocketAddress localAddress;
SocketAddress remoteAddress = ANY_ADDRESS;
Executor blockingExecutor;
private final Task CANCEL_HANDLER = new Task() {
public void run() {
socketState.onCanceled();
}
};
static final class OneWay {
final Object command;
final Retained retained;
public OneWay(Object command, Retained retained) {
this.command = command;
this.retained = retained;
}
}
public void connected(DatagramChannel channel) throws IOException, Exception {
this.channel = channel;
initializeChannel();
this.socketState = new CONNECTED();
}
protected void initializeChannel() throws Exception {
this.channel.configureBlocking(false);
DatagramSocket socket = channel.socket();
try {
socket.setReuseAddress(true);
} catch (SocketException e) {
}
try {
socket.setTrafficClass(trafficClass);
} catch (SocketException e) {
}
try {
socket.setReceiveBufferSize(receiveBufferSize);
} catch (SocketException e) {
}
try {
socket.setSendBufferSize(sendBufferSize);
} catch (SocketException e) {
}
if( channel!=null && codec!=null ) {
initializeCodec();
}
}
protected void initializeCodec() throws Exception {
codec.setTransport(this);
}
public void connecting(final URI remoteLocation, final URI localLocation) throws Exception {
this.channel = DatagramChannel.open();
initializeChannel();
this.remoteLocation = remoteLocation;
this.localLocation = localLocation;
socketState = new CONNECTING();
}
public DispatchQueue getDispatchQueue() {
return dispatchQueue;
}
public void setDispatchQueue(DispatchQueue queue) {
this.dispatchQueue = queue;
if(readSource!=null) readSource.setTargetQueue(queue);
if(writeSource!=null) writeSource.setTargetQueue(queue);
if(drainOutboundSource!=null) drainOutboundSource.setTargetQueue(queue);
if(yieldSource!=null) yieldSource.setTargetQueue(queue);
}
public void _start(Task onCompleted) {
try {
if ( socketState.is(CONNECTING.class) ) {
// Resolving host names might block.. so do it on the blocking executor.
this.blockingExecutor.execute(new Runnable() {
public void run() {
// No need to complete if we have been canceled.
if( ! socketState.is(CONNECTING.class) ) {
return;
}
try {
final InetSocketAddress localAddress = (localLocation != null) ?
new InetSocketAddress(InetAddress.getByName(localLocation.getHost()), localLocation.getPort())
: null;
String host = resolveHostName(remoteLocation.getHost());
final InetSocketAddress remoteAddress = new InetSocketAddress(host, remoteLocation.getPort());
// Done resolving.. switch back to the dispatch queue.
dispatchQueue.execute(new Task() {
@Override
public void run() {
try {
if(localAddress!=null) {
channel.socket().bind(localAddress);
}
channel.connect(remoteAddress);
} catch (IOException e) {
try {
channel.close();
} catch (IOException ignore) {
}
socketState = new CANCELED(true);
listener.onTransportFailure(e);
}
}
});
} catch (IOException e) {
try {
channel.close();
} catch (IOException ignore) {
}
socketState = new CANCELED(true);
listener.onTransportFailure(e);
}
}
});
} else if (socketState.is(CONNECTED.class) ) {
dispatchQueue.execute(new Task() {
public void run() {
try {
trace("was connected.");
onConnected();
} catch (IOException e) {
onTransportFailure(e);
}
}
});
} else {
System.err.println("cannot be started. socket state is: "+socketState);
}
} finally {
if( onCompleted!=null ) {
onCompleted.run();
}
}
}
public void _stop(final Task onCompleted) {
trace("stopping.. at state: "+socketState);
socketState.onStop(onCompleted);
}
protected String resolveHostName(String host) throws UnknownHostException {
String localName = InetAddress.getLocalHost().getHostName();
if (localName != null && isUseLocalHost()) {
if (localName.equals(host)) {
return "localhost";
}
}
return host;
}
protected void onConnected() throws IOException {
yieldSource = Dispatch.createSource(EventAggregators.INTEGER_ADD, dispatchQueue);
yieldSource.setEventHandler(new Task() {
public void run() {
drainInbound();
}
});
yieldSource.resume();
drainOutboundSource = Dispatch.createSource(EventAggregators.INTEGER_ADD, dispatchQueue);
drainOutboundSource.setEventHandler(new Task() {
public void run() {
flush();
}
});
drainOutboundSource.resume();
readSource = Dispatch.createSource(channel, SelectionKey.OP_READ, dispatchQueue);
writeSource = Dispatch.createSource(channel, SelectionKey.OP_WRITE, dispatchQueue);
readSource.setCancelHandler(CANCEL_HANDLER);
writeSource.setCancelHandler(CANCEL_HANDLER);
readSource.setEventHandler(new Task() {
public void run() {
drainInbound();
}
});
writeSource.setEventHandler(new Task() {
public void run() {
flush();
}
});
listener.onTransportConnected();
}
Task onDispose;
private void dispose() {
if( readSource!=null ) {
readSource.cancel();
readSource=null;
}
if( writeSource!=null ) {
writeSource.cancel();
writeSource=null;
}
this.codec = null;
if(onDispose!=null) {
onDispose.run();
onDispose = null;
}
}
public void onTransportFailure(IOException error) {
listener.onTransportFailure(error);
socketState.onCanceled();
}
public boolean full() {
return codec==null || codec.full();
}
boolean rejectingOffers;
public boolean offer(Object command) {
dispatchQueue.assertExecuting();
try {
if (!socketState.is(CONNECTED.class)) {
throw new IOException("Not connected.");
}
if (getServiceState() != STARTED) {
throw new IOException("Not running.");
}
ProtocolCodec.BufferState rc = codec.write(command);
rejectingOffers = codec.full();
switch (rc ) {
case FULL:
return false;
default:
drainOutboundSource.merge(1);
return true;
}
} catch (IOException e) {
onTransportFailure(e);
return false;
}
}
boolean writeResumedForCodecFlush = false;
/**
*
*/
public void flush() {
dispatchQueue.assertExecuting();
if (getServiceState() != STARTED || !socketState.is(CONNECTED.class)) {
return;
}
try {
if( codec.flush() == ProtocolCodec.BufferState.EMPTY && transportFlush() ) {
if( writeResumedForCodecFlush) {
writeResumedForCodecFlush = false;
suspendWrite();
}
rejectingOffers = false;
listener.onRefill();
} else {
if(!writeResumedForCodecFlush) {
writeResumedForCodecFlush = true;
resumeWrite();
}
}
} catch (IOException e) {
onTransportFailure(e);
}
}
protected boolean transportFlush() throws IOException {
return true;
}
public void drainInbound() {
if (!getServiceState().isStarted() || readSource.isSuspended()) {
return;
}
try {
long initial = codec.getReadCounter();
// Only process upto 2 x the read buffer worth of data at a time so we can give
// other connections a chance to process their requests.
while( codec.getReadCounter()-initial < codec.getReadBufferSize()<<2 ) {
Object command = codec.read();
if ( command!=null ) {
try {
listener.onTransportCommand(command);
} catch (Throwable e) {
e.printStackTrace();
onTransportFailure(new IOException("Transport listener failure."));
}
// the transport may be suspended after processing a command.
if (getServiceState() == STOPPED || readSource.isSuspended()) {
return;
}
} else {
return;
}
}
yieldSource.merge(1);
} catch (IOException e) {
onTransportFailure(e);
}
}
public SocketAddress getLocalAddress() {
return localAddress;
}
public SocketAddress getRemoteAddress() {
return remoteAddress;
}
private boolean assertConnected() {
try {
if ( !isConnected() ) {
throw new IOException("Not connected.");
}
return true;
} catch (IOException e) {
onTransportFailure(e);
}
return false;
}
public void suspendRead() {
if( isConnected() && readSource!=null ) {
readSource.suspend();
}
}
public void resumeRead() {
if( isConnected() && readSource!=null ) {
_resumeRead();
}
}
private void _resumeRead() {
readSource.resume();
dispatchQueue.execute(new Task(){
public void run() {
drainInbound();
}
});
}
protected void suspendWrite() {
if( isConnected() && writeSource!=null ) {
writeSource.suspend();
}
}
protected void resumeWrite() {
if( isConnected() && writeSource!=null ) {
writeSource.resume();
}
}
public TransportListener getTransportListener() {
return listener;
}
public void setTransportListener(TransportListener transportListener) {
this.listener = transportListener;
}
public ProtocolCodec getProtocolCodec() {
return codec;
}
public void setProtocolCodec(ProtocolCodec protocolCodec) throws Exception {
this.codec = protocolCodec;
if( channel!=null && codec!=null ) {
initializeCodec();
}
}
public boolean isConnected() {
return socketState.is(CONNECTED.class);
}
public boolean isClosed() {
return getServiceState() == STOPPED;
}
public boolean isUseLocalHost() {
return useLocalHost;
}
/**
* Sets whether 'localhost' or the actual local host name should be used to
* make local connections. On some operating systems such as Macs its not
* possible to connect as the local host name so localhost is better.
*/
public void setUseLocalHost(boolean useLocalHost) {
this.useLocalHost = useLocalHost;
}
private void trace(String message) {
// TODO:
}
public DatagramChannel getDatagramChannel() {
return channel;
}
public ReadableByteChannel getReadChannel() {
return channel;
}
public WritableByteChannel getWriteChannel() {
return channel;
}
public int getTrafficClass() {
return trafficClass;
}
public void setTrafficClass(int trafficClass) {
this.trafficClass = trafficClass;
}
public int getReceiveBufferSize() {
return receiveBufferSize;
}
public void setReceiveBufferSize(int receiveBufferSize) {
this.receiveBufferSize = receiveBufferSize;
}
public int getSendBufferSize() {
return sendBufferSize;
}
public void setSendBufferSize(int sendBufferSize) {
this.sendBufferSize = sendBufferSize;
}
public Executor getBlockingExecutor() {
return blockingExecutor;
}
public void setBlockingExecutor(Executor blockingExecutor) {
this.blockingExecutor = blockingExecutor;
}
}