package com.workshare.msnos.core.protocols.ip.udp;
import static java.lang.Math.min;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.MulticastSocket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.workshare.msnos.core.Cloud;
import com.workshare.msnos.core.Gateway;
import com.workshare.msnos.core.Identifiable;
import com.workshare.msnos.core.Message;
import com.workshare.msnos.core.Message.Payload;
import com.workshare.msnos.core.Message.Status;
import com.workshare.msnos.core.Receipt;
import com.workshare.msnos.core.protocols.ip.BaseEndpoint;
import com.workshare.msnos.core.protocols.ip.Endpoint.Type;
import com.workshare.msnos.core.protocols.ip.Endpoints;
import com.workshare.msnos.core.protocols.ip.MulticastSocketFactory;
import com.workshare.msnos.core.protocols.ip.Network;
import com.workshare.msnos.core.receipts.SingleReceipt;
import com.workshare.msnos.core.serializers.WireSerializer;
import com.workshare.msnos.soup.threading.Multicaster;
import com.workshare.msnos.soup.time.SystemTime;
/**
 * A {@link Gateway} that ships messages over UDP multicast.
 *
 * <p>On construction the gateway computes a range of candidate ports, binds a
 * multicast socket on the first port of the range that is available, joins the
 * configured multicast group and starts the given {@link UDPServer} on that
 * socket. Outgoing messages are serialized with the server's
 * {@link WireSerializer} and sent to the group on every port of the range.
 * Messages received by the server are dispatched to the registered listeners
 * through the given {@link Multicaster}.
 */
public class UDPGateway implements Gateway {

    private static final Logger logger = LoggerFactory.getLogger(UDPGateway.class);

    /** System property: first port of the range (defaults to 3728). */
    public static final String SYSP_PORT_NUM = "com.ws.nsnos.udp.port.number";
    /** System property: how many consecutive ports the range spans (defaults to 3). */
    public static final String SYSP_PORT_WIDTH = "com.ws.nsnos.udp.port.width";
    /** System property: multicast group address (defaults to 230.31.32.33). */
    public static final String SYSP_UDP_GROUP = "com.ws.nsnos.udp.group";
    /** System property: maximum size in bytes of a serialized message (defaults to 512). */
    public static final String SYSP_UDP_PACKET_SIZE = "com.ws.nsnos.udp.packet.size";
    /** System property: how many times a packet send is attempted (defaults to 3). */
    public static final String SYSP_RETRY_TIMES = "com.ws.nsnos.udp.group.retry.times";
    /** System property: boolean flag, read when the endpoints are collected. */
    public static final String SYSP_NET_IPV6ALSO = "com.ws.msnos.network.ipv6also";
    /** System property: boolean flag, read when the endpoints are collected. */
    public static final String SYSP_NET_VIRTUAL = "com.ws.msnos.network.includevirtual";

    private MulticastSocket socket;
    private InetAddress group;
    private int[] ports;

    private final Multicaster<Listener, Message> caster;
    private final WireSerializer sz;
    private final int packetSize;
    private final int retries;
    private final Endpoints endpoints;
    private final UDPServer server;

    /**
     * Creates the gateway, opens the multicast socket and starts the server.
     *
     * @param sockets factory used to create the multicast socket
     * @param aServer server that will receive the incoming packets; its serializer
     *                is also used for the outgoing messages
     * @param caster  dispatcher used to notify the listeners of incoming messages
     * @throws IOException if no port of the range can be bound, or the multicast
     *                     group cannot be resolved or joined
     */
    public UDPGateway(MulticastSocketFactory sockets, UDPServer aServer, Multicaster<Listener, Message> caster) throws IOException {
        this.caster = caster;
        this.sz = aServer.serializer();
        this.retries = Integer.getInteger(SYSP_RETRY_TIMES, 3);
        this.packetSize = Integer.getInteger(SYSP_UDP_PACKET_SIZE, 512);
        this.endpoints = createEndpoints();
        this.server = aServer;

        loadPorts();
        openSocket(sockets);
        startServer(aServer);
    }

    @Override
    public String name() {
        return "UDP";
    }

    /**
     * Stops the server and closes the socket. The socket is closed even if
     * stopping the server fails.
     */
    @Override
    public void close() throws IOException {
        try {
            server.stop();
        } finally {
            socket.close();
        }
    }

    /** Starts the server on the open socket and forwards every received message to the caster. */
    private void startServer(UDPServer server) {
        server.start(socket, packetSize);
        server.addListener(new Listener() {
            @Override
            public void onMessage(Message message) {
                caster.dispatch(message);
            }
        });
    }

    /**
     * Binds a multicast socket on the first available port of the range, then
     * joins the multicast group.
     *
     * @throws IOException if no port can be bound, or the group cannot be
     *                     resolved or joined
     */
    private void openSocket(MulticastSocketFactory sockets) throws IOException {
        for (int port : ports) {
            MulticastSocket msock = null;
            try {
                msock = sockets.create();
                msock.setReuseAddress(true);
                msock.bind(new InetSocketAddress(port));
                socket = msock;
                logger.info("Socket opened on port: {} ", port);
                break;
            } catch (IOException ex) {
                logger.warn("Unable to open multicast socket on port: {} ", port);
                logger.debug("Cause of the failed bind on port " + port, ex);
                // do not leak the socket that could not be bound
                if (msock != null) {
                    msock.close();
                }
            }
        }

        if (socket == null) {
            // Arrays.asList on an int[] would print the array identity, not the ports
            throw new IOException("Unable to open socket, I tried to bind on ports " + Arrays.toString(ports));
        }

        try {
            group = InetAddress.getByName(loadUDPGroup());
            socket.joinGroup(group);
        } catch (IOException ex) {
            // the constructor is about to fail: release the bound socket
            socket.close();
            throw ex;
        }
        logger.info("Joined group {}", group);
    }

    /** Registers the listener on the caster; the cloud argument is not used. */
    @Override
    public void addListener(Cloud cloud, Listener listener) {
        caster.addListener(listener);
    }

    @Override
    public Endpoints endpoints() {
        return endpoints;
    }

    /**
     * Sends the message to the multicast group on every port of the range.
     *
     * <p>When the serialized message exceeds the packet size, its payload is
     * split and one message per resulting payload is sent. The {@code cloud}
     * and {@code to} arguments are not used.
     *
     * @return a receipt for the original message, in {@link Status#PENDING}
     * @throws IOException if the payload is too big and cannot be split, or a
     *                     packet cannot be sent after all the attempts
     */
    @Override
    public Receipt send(Cloud cloud, Message message, Identifiable to) throws IOException {
        logger.debug("send message {} ", message);

        List<Payload> payloads;
        int fullMsgLength = sz.toBytes(message).length;
        // size of everything but the payload, i.e. the fixed cost of each split message
        int lengthWithoutPayload = fullMsgLength - sz.toBytes(message.getData()).length;
        if (fullMsgLength > packetSize) {
            payloads = getSplitPayloads(new ArrayList<Payload>(), message.getData(), lengthWithoutPayload);
        } else {
            payloads = Arrays.asList(message.getData());
        }

        for (Payload load : payloads) {
            Message msg = message.data(load);
            byte[] payload = sz.toBytes(msg);
            for (int port : ports) {
                DatagramPacket packet = new DatagramPacket(payload, payload.length, group, port);
                doSend(packet);
            }
        }

        return new SingleReceipt(this, Status.PENDING, message);
    }

    /**
     * Sends the packet, retrying on failure. At most {@code retries} attempts
     * are made (always at least one); between attempts the thread sleeps for a
     * growing pause capped at 5 milliseconds. The failure of the last attempt
     * is propagated to the caller.
     */
    private void doSend(DatagramPacket packet) throws IOException {
        int count = retries;
        long wait = 0;
        while (count-- > 1) {
            try {
                socket.send(packet);
                return;
            } catch (IOException ex) {
                logger.debug("Temporary unable to send the packet trough UDP - retrying...");
            }

            wait = min(5, wait + (retries - count));
            logger.debug("Enforcing transport pacing, wait for {} milliseconds before next try", wait);
            sleep(wait);
        }

        // last attempt: let the exception reach the caller
        socket.send(packet);
    }

    /** Sleeps for the given time; if interrupted, returns early and restores the interrupt flag. */
    private void sleep(long waitInMillis) {
        try {
            SystemTime.sleep(waitInMillis);
        } catch (InterruptedException ex) {
            // restore the flag so that callers can still detect the interruption
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Recursively splits the payload until each piece, added to the given
     * fixed message length, fits in a packet. The pieces are appended to
     * {@code payloads}, which is also returned.
     *
     * @param payloads  accumulator of the pieces that fit
     * @param payload   payload to split
     * @param msgLength serialized length of the message without its payload
     * @throws IOException if a payload that is too big cannot be split
     */
    private List<Payload> getSplitPayloads(List<Payload> payloads, Payload payload, int msgLength) throws IOException {
        // NOTE(review): termination relies on split() returning pieces smaller than the input - confirm
        Payload[] loads = payload.split();
        if (loads == null) {
            throw new IOException("Unable to send message: the payload is too big and unsplittable");
        }

        for (Payload load : loads) {
            if (sz.toBytes(load).length + msgLength > packetSize) {
                getSplitPayloads(payloads, load, msgLength);
            } else {
                payloads.add(load);
            }
        }

        return payloads;
    }

    /** Computes the port range: {@code width} consecutive ports starting at the base port. */
    private void loadPorts() {
        int port = loadBasePort();
        int width = loadPortWidth();

        ports = new int[width];
        for (int i = 0; i < width; i++) {
            ports[i] = port + i;
        }

        logger.debug("UDP mounted on ports: {} ", Arrays.toString(ports));
    }

    private int loadBasePort() {
        return Integer.getInteger(SYSP_PORT_NUM, 3728);
    }

    private int loadPortWidth() {
        return Integer.getInteger(SYSP_PORT_WIDTH, 3);
    }

    private String loadUDPGroup() {
        return System.getProperty(SYSP_UDP_GROUP, "230.31.32.33");
    }

    /** Returns the serializer used for the outgoing messages (the one of the server). */
    public WireSerializer serializer() {
        return sz;
    }

    /** Builds one UDP endpoint for each network returned by {@link Network#listAll}. */
    private Endpoints createEndpoints() {
        boolean ipv6Also = Boolean.getBoolean(SYSP_NET_IPV6ALSO);
        boolean includeVirtual = Boolean.getBoolean(SYSP_NET_VIRTUAL);
        logger.debug("Collecting endpoints: ipv6 {}, virtual {}", ipv6Also, includeVirtual);

        // presumably the first argument of listAll means "IPv4 only" - verify against Network
        Set<Network> nets = Network.listAll(!ipv6Also, includeVirtual);
        Set<BaseEndpoint> ends = new HashSet<BaseEndpoint>();
        for (Network net : nets) {
            ends.add(new BaseEndpoint(Type.UDP, net));
        }

        logger.debug("Loaded endpoints: {}", ends);
        return BaseEndpoint.create(ends);
    }
}