package com.faforever.client.connectivity;
import com.faforever.client.net.ConnectionState;
import com.faforever.client.relay.CreatePermissionMessage;
import com.faforever.client.remote.FafService;
import com.google.common.annotations.VisibleForTesting;
import javafx.beans.property.ObjectProperty;
import javafx.beans.property.ReadOnlyObjectProperty;
import javafx.beans.property.SimpleObjectProperty;
import org.apache.commons.compress.utils.IOUtils;
import org.ice4j.ChannelDataMessageEvent;
import org.ice4j.ResponseCollector;
import org.ice4j.StunException;
import org.ice4j.StunMessageEvent;
import org.ice4j.StunResponseEvent;
import org.ice4j.StunTimeoutEvent;
import org.ice4j.Transport;
import org.ice4j.TransportAddress;
import org.ice4j.attribute.Attribute;
import org.ice4j.attribute.DataAttribute;
import org.ice4j.attribute.ErrorCodeAttribute;
import org.ice4j.attribute.XorMappedAddressAttribute;
import org.ice4j.attribute.XorPeerAddressAttribute;
import org.ice4j.attribute.XorRelayedAddressAttribute;
import org.ice4j.message.ChannelData;
import org.ice4j.message.Message;
import org.ice4j.message.MessageFactory;
import org.ice4j.message.Request;
import org.ice4j.message.Response;
import org.ice4j.socket.IceUdpSocketWrapper;
import org.ice4j.socket.MultiplexedDatagramSocket;
import org.ice4j.socket.MultiplexingDatagramSocket;
import org.ice4j.socket.TurnDatagramPacketFilter;
import org.ice4j.stack.StunStack;
import org.ice4j.stack.TransactionID;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.ApplicationContext;
import javax.annotation.PreDestroy;
import javax.annotation.Resource;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.net.DatagramPacket;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import static com.github.nocatch.NoCatch.noCatch;
import static org.ice4j.attribute.Attribute.ERROR_CODE;
import static org.ice4j.attribute.Attribute.XOR_MAPPED_ADDRESS;
import static org.ice4j.attribute.Attribute.XOR_PEER_ADDRESS;
import static org.ice4j.attribute.Attribute.XOR_RELAYED_ADDRESS;
import static org.ice4j.attribute.RequestedTransportAttribute.UDP;
/**
 * {@link TurnServerAccessor} implementation on top of ice4j.
 *
 * <p>{@link #connect()} opens a local UDP socket, requests an address allocation from the configured TURN server and
 * starts a task that reads TURN ChannelData messages from that socket. Peers are bound to TURN channels with
 * {@link #bind(InetSocketAddress)}; data for bound peers is sent with {@link #send(DatagramPacket)}. Data received
 * from the server (as ChannelData or as STUN indication) is handed to the listeners registered with
 * {@link #addOnPacketListener(Consumer)}.
 *
 * <p>The two channel maps are accessed from several threads (callers, STUN response callbacks, the refresh task and
 * the receive task); every access is guarded by the monitor of {@code peerAddressToChannel}.
 */
public class TurnServerAccessorImpl implements TurnServerAccessor {

  private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

  /** Source of channel numbers for new bindings; the first number handed out is 0x4000. */
  private static final AtomicInteger CHANNEL_NUMBER = new AtomicInteger(0x4000);

  /**
   * The length in bytes of the Channel Number field of a TURN ChannelData message.
   */
  private static final int CHANNELDATA_CHANNELNUMBER_LENGTH = 2;

  /**
   * The length in bytes of the Length field of a TURN ChannelData message.
   */
  private static final int CHANNELDATA_LENGTH_LENGTH = 2;

  /** Peer address to bound channel number. Guarded by its own monitor, which also guards the reverse map. */
  private final Map<TransportAddress, Character> peerAddressToChannel;
  /** Bound channel number to peer address. Guarded by the monitor of {@link #peerAddressToChannel}. */
  private final Map<Character, TransportAddress> channelToPeerAddress;
  private final ObjectProperty<ConnectionState> connectionState;
  private final Collection<Consumer<DatagramPacket>> onPacketListeners;

  @Resource
  ScheduledExecutorService scheduledExecutorService;
  @Resource
  FafService fafService;
  @Resource
  ConnectivityService connectivityService;
  @Resource
  ApplicationContext applicationContext;

  @Value("${turn.host}")
  String turnHost;
  @Value("${turn.port}")
  int turnPort;
  // NOTE(review): this value is used both as the scheduling delay in milliseconds and as the lifetime passed to the
  // refresh request - confirm that the unit expected by MessageFactory.createRefreshRequest matches.
  @Value("${turn.refreshInterval}")
  int refreshInterval;

  private TransportAddress relayedAddress;
  private TransportAddress mappedAddress;
  private ScheduledFuture<?> refreshTask;
  private TransportAddress serverAddress;
  private TransportAddress localAddress;
  private MultiplexingDatagramSocket localSocket;
  private MultiplexedDatagramSocket channelDataSocket;
  private StunStack stunStack;

  public TurnServerAccessorImpl() {
    peerAddressToChannel = new HashMap<>();
    channelToPeerAddress = new HashMap<>();
    connectionState = new SimpleObjectProperty<>(ConnectionState.DISCONNECTED);
    onPacketListeners = new LinkedHashSet<>();
  }

  /**
   * Returns the address the local socket is bound to. Only valid after {@link #connect()} has created the socket.
   */
  @VisibleForTesting
  public InetSocketAddress getLocalSocketAddress() {
    return (InetSocketAddress) localSocket.getLocalSocketAddress();
  }

  /**
   * Permits a peer to send data through the TURN. Blocks until the server answered; a timeout surfaces as the
   * unchecked exception thrown by {@link #sendBlockingStunRequest(Request)}.
   *
   * @param address the peer's publicly visible address
   */
  private void permit(InetSocketAddress address) {
    logger.info("Permitting sends from {}", address);

    Request createPermissionRequest = MessageFactory.createCreatePermissionRequest(
        new TransportAddress(address, Transport.UDP),
        TransactionID.createNewTransactionID().getBytes()
    );

    // The request is sent exactly once; error responses are already logged by sendBlockingStunRequest.
    Response response = sendBlockingStunRequest(createPermissionRequest);
    if (response.isSuccessResponse()) {
      logger.debug("Permitted sends from {}", address);
    }
  }

  /**
   * Releases the allocation, forgets all channel bindings and closes the STUN stack and sockets. Local resources are
   * cleaned up even if releasing the allocation fails; in that case the failure is still propagated to the caller.
   */
  @Override
  @PreDestroy
  public void disconnect() {
    connectionState.set(ConnectionState.DISCONNECTED);
    try {
      releaseAllocation();
    } finally {
      synchronized (peerAddressToChannel) {
        peerAddressToChannel.clear();
        channelToPeerAddress.clear();
      }

      if (localAddress != null && stunStack != null) {
        stunStack.removeSocket(localAddress);
        stunStack.shutDown();
      }
      if (refreshTask != null) {
        refreshTask.cancel(true);
      }
      IOUtils.closeQuietly(localSocket);
      IOUtils.closeQuietly(channelDataSocket);
    }
  }

  /**
   * Returns the relayed address received with the last successful allocation, or {@code null} if there was none.
   */
  @Override
  public InetSocketAddress getRelayAddress() {
    return relayedAddress;
  }

  /**
   * Sends the payload of {@code packet} to the packet's destination as TURN ChannelData. If the destination is not
   * bound to a channel, a warning is logged and the packet is dropped.
   *
   * @param packet the data to send; its socket address selects the channel
   *
   * @throws RuntimeException if the STUN stack fails to send the data
   */
  @Override
  public void send(DatagramPacket packet) {
    SocketAddress socketAddress = packet.getSocketAddress();

    Character channelNumber;
    synchronized (peerAddressToChannel) {
      //noinspection SuspiciousMethodCalls
      channelNumber = peerAddressToChannel.get(socketAddress);
    }
    if (channelNumber == null) {
      logger.warn("Peer {} is not bound to a channel", socketAddress);
      return;
    }

    // The packet's offset only applies to the source buffer; the payload copy always starts at index 0.
    byte[] payload = new byte[packet.getLength()];
    System.arraycopy(packet.getData(), packet.getOffset(), payload, 0, payload.length);

    ChannelData channelData = new ChannelData();
    channelData.setData(payload);
    channelData.setChannelNumber(channelNumber);

    if (logger.isTraceEnabled()) {
      logger.trace("Writing {} bytes to channel {}: {}", packet.getLength(), (int) channelNumber,
          new String(payload, 0, payload.length, StandardCharsets.US_ASCII));
    }

    try {
      stunStack.sendChannelData(channelData, serverAddress, localAddress);
    } catch (StunException e) {
      throw new RuntimeException(e);
    }
  }

  @Override
  public ConnectionState getConnectionState() {
    return connectionState.get();
  }

  @Override
  public ReadOnlyObjectProperty<ConnectionState> connectionStateProperty() {
    return connectionState;
  }

  /**
   * Connects to the TURN server unless already connected: opens the local socket, allocates a relayed address,
   * permits the own external address and starts the ChannelData receive task. If any step fails, the connection
   * state is reset to {@link ConnectionState#DISCONNECTED} and the failure is rethrown unchecked.
   */
  @Override
  public void connect() {
    if (connectionState.get() == ConnectionState.CONNECTED) {
      return;
    }

    stunStack = applicationContext.getBean(StunStack.class);
    // NOTE(review): a listener is registered on every connect(); confirm FafService de-duplicates or drops them.
    fafService.addOnMessageListener(CreatePermissionMessage.class, message -> permit(message.getAddress()));
    serverAddress = new TransportAddress(turnHost, turnPort, Transport.UDP);
    connectionState.set(ConnectionState.CONNECTING);

    try {
      localSocket = new MultiplexingDatagramSocket(0);
      channelDataSocket = localSocket.getSocket(
          new TurnDatagramPacketFilter(serverAddress) {
            @Override
            public boolean accept(DatagramPacket packet) {
              return isChannelData(packet);
            }

            @Override
            protected boolean acceptMethod(char method) {
              return false;
            }
          });

      localAddress = new TransportAddress((InetSocketAddress) localSocket.getLocalSocketAddress(), Transport.UDP);
      stunStack.addSocket(new IceUdpSocketWrapper(localSocket), serverAddress);
      stunStack.addIndicationListener(localAddress, this::onIndication);

      releaseAllocation();
      allocateAddress(serverAddress);
      permit(connectivityService.getExternalSocketAddress());

      connectionState.set(ConnectionState.CONNECTED);
      scheduledExecutorService.execute(this::runInReceiveChannelDataThread);
    } catch (StunException | IOException e) {
      connectionState.set(ConnectionState.DISCONNECTED);
      throw new RuntimeException(e);
    } catch (RuntimeException e) {
      // Do not leave the state stuck at CONNECTING when e.g. a blocking STUN request timed out.
      connectionState.set(ConnectionState.DISCONNECTED);
      throw e;
    }
  }

  /**
   * Returns whether a channel binding for {@code socketAddress} has been confirmed by the server.
   */
  @Override
  @SuppressWarnings("SuspiciousMethodCalls")
  public boolean isBound(InetSocketAddress socketAddress) {
    synchronized (peerAddressToChannel) {
      return peerAddressToChannel.containsKey(socketAddress);
    }
  }

  /**
   * Requests a channel binding for {@code socketAddress} using the next free channel number. The binding becomes
   * effective (see {@link #isBound(InetSocketAddress)}) once the server's success response has been processed.
   */
  @Override
  public void bind(InetSocketAddress socketAddress) {
    bind(new TransportAddress(socketAddress, Transport.UDP), CHANNEL_NUMBER.getAndIncrement());
  }

  /**
   * Requests a channel binding unless {@code address} is already bound.
   */
  private void bind(TransportAddress address, int channelNumber) {
    synchronized (peerAddressToChannel) {
      if (peerAddressToChannel.containsKey(address)) {
        return;
      }

      logger.info("Binding '{}' to channel '{}'", address, channelNumber);
      sendChannelBindRequest(address, channelNumber);
    }
  }

  /**
   * Sends a ChannelBind request and records the binding in both maps when the server confirms it. Used for new
   * bindings as well as for re-sending existing bindings from the refresh task.
   */
  private void sendChannelBindRequest(TransportAddress address, int channelNumber) {
    Request channelBindRequest = MessageFactory.createChannelBindRequest(
        (char) channelNumber, address, TransactionID.createNewTransactionID().getBytes()
    );

    sendStunRequest(channelBindRequest, new ResponseCollector() {
      @Override
      public void processResponse(StunResponseEvent event) {
        if (!event.getResponse().isSuccessResponse()) {
          logger.warn("Binding for '{}' to channel '{}' failed", address, channelNumber);
          return;
        }

        logger.debug("Bound '{}' to channel '{}'", address, channelNumber);
        synchronized (peerAddressToChannel) {
          peerAddressToChannel.put(address, (char) channelNumber);
          channelToPeerAddress.put((char) channelNumber, address);
        }
      }

      @Override
      public void processTimeout(StunTimeoutEvent event) {
        logger.warn("Binding request for '{}' to channel '{}' timed out.", address, channelNumber);
      }
    });
  }

  /**
   * If a refresh task is active, asks the server to release the allocation (refresh request with lifetime 0) and
   * cancels the task. The task is cancelled even if the request fails.
   */
  private void releaseAllocation() {
    if (refreshTask == null || refreshTask.isCancelled()) {
      return;
    }

    logger.debug("Releasing previous allocation");
    try {
      sendBlockingStunRequest(MessageFactory.createRefreshRequest(0));
    } finally {
      refreshTask.cancel(true);
    }
  }

  /**
   * Sends {@code request} to the TURN server and waits for its response. Error responses are logged and returned;
   * a timeout or interruption is rethrown unchecked by {@code noCatch}.
   */
  @NotNull
  private Response sendBlockingStunRequest(Request request) {
    return noCatch(() -> {
      CompletableFuture<Response> responseFuture = new CompletableFuture<>();
      sendStunRequest(request, responseCollector(responseFuture));

      Response response = responseFuture.get();
      if (response.isErrorResponse()) {
        logger.warn("STUN error: {}", ((ErrorCodeAttribute) response.getAttribute(ERROR_CODE)).getReasonPhrase());
      }
      return response;
    });
  }

  /**
   * Sends {@code request} to the TURN server through the local socket; the outcome is reported to
   * {@code responseCollector}.
   */
  private void sendStunRequest(Request request, ResponseCollector responseCollector) {
    noCatch(() -> stunStack.sendRequest(request, serverAddress, localAddress, responseCollector));
  }

  /**
   * Creates a collector that completes {@code responseFuture} with the response, or exceptionally on timeout.
   */
  @NotNull
  private ResponseCollector responseCollector(final CompletableFuture<Response> responseFuture) {
    return new ResponseCollector() {
      @Override
      public void processResponse(StunResponseEvent event) {
        responseFuture.complete(event.getResponse());
      }

      @Override
      public void processTimeout(StunTimeoutEvent event) {
        logger.warn("STUN request timed out: {}", event.getMessage());
        responseFuture.completeExceptionally(new RuntimeException("STUN request " + event.getTransactionID() + " timed out"));
      }
    };
  }

  @Override
  public void addOnPacketListener(Consumer<DatagramPacket> listener) {
    onPacketListeners.add(listener);
  }

  @Override
  public void removeOnPacketListener(Consumer<DatagramPacket> listener) {
    onPacketListeners.remove(listener);
  }

  /**
   * Determines whether {@code packet} is a TURN ChannelData message meant for {@link #channelDataSocket}: it must
   * originate from the TURN server, be at least as long as the ChannelData header, have a non-zero value in the two
   * most significant bits of its first byte and carry a Length field that matches the packet's size (with or without
   * padding to a multiple of four bytes).
   *
   * @param packet the {@code DatagramPacket} to check
   *
   * @return {@code true} if {@code packet} looks like ChannelData from the TURN server; otherwise, {@code false}
   */
  private boolean isChannelData(DatagramPacket packet) {
    // Is it from our TURN server?
    if (!serverAddress.equals(packet.getSocketAddress())) {
      return false;
    }

    int packetLength = packet.getLength();
    if (packetLength < (CHANNELDATA_CHANNELNUMBER_LENGTH + CHANNELDATA_LENGTH_LENGTH)) {
      return false;
    }

    byte[] pData = packet.getData();
    int pOffset = packet.getOffset();

    /*
     * The first two bits should be 0b01 because of the current channel number range 0x4000 - 0x7FFE. But 0b10 and
     * 0b11, which are currently reserved, may be used in the future to extend the range of channel numbers.
     */
    if ((pData[pOffset] & 0xC0) == 0) {
      return false;
    }

    pOffset += CHANNELDATA_CHANNELNUMBER_LENGTH;
    packetLength -= CHANNELDATA_CHANNELNUMBER_LENGTH;

    // Both bytes are masked so that a high byte >= 0x80 does not sign-extend into a negative length.
    int length = ((pData[pOffset] & 0xFF) << 8) | (pData[pOffset + 1] & 0xFF);
    int padding = ((length % 4) > 0) ? 4 - (length % 4) : 0;

    /*
     * The Length field specifies the length in bytes of the Application Data field. The Length field does not include
     * the padding that is sometimes present in the data of the DatagramPacket.
     */
    return length == packetLength - padding - CHANNELDATA_LENGTH_LENGTH
        || length == packetLength - CHANNELDATA_LENGTH_LENGTH;
  }

  /**
   * Requests an address allocation, stores the relayed and mapped address from the response and schedules the
   * periodic refresh.
   *
   * @throws IllegalStateException if the response carries no relayed or mapped address (e.g. an error response)
   */
  private void allocateAddress(TransportAddress turnServerAddress) throws StunException, IOException {
    logger.info("Requesting address allocation at {}", serverAddress);

    Response response = sendBlockingStunRequest(MessageFactory.createAllocateRequest(UDP, false));

    XorRelayedAddressAttribute relayedAttribute = (XorRelayedAddressAttribute) response.getAttribute(XOR_RELAYED_ADDRESS);
    XorMappedAddressAttribute mappedAttribute = (XorMappedAddressAttribute) response.getAttribute(XOR_MAPPED_ADDRESS);
    if (relayedAttribute == null || mappedAttribute == null) {
      throw new IllegalStateException(
          "Allocation response from " + serverAddress + " contains no relayed or mapped address");
    }

    byte[] transactionID = response.getTransactionID();
    relayedAddress = relayedAttribute.getAddress(transactionID);
    mappedAddress = mappedAttribute.getAddress(transactionID);

    logger.info("Relayed address: {}, mapped address: {}", relayedAddress, mappedAddress);

    refreshTask = scheduleRefresh(refreshInterval);
  }

  /**
   * Schedules {@link #refreshAllocation(int)} to run every {@code interval} milliseconds.
   */
  private ScheduledFuture<?> scheduleRefresh(int interval) {
    return scheduledExecutorService.scheduleWithFixedDelay(
        () -> refreshAllocation(interval), interval, interval, TimeUnit.MILLISECONDS
    );
  }

  /**
   * Refreshes the allocation and re-sends a ChannelBind request for every known binding. Failures are logged and
   * not rethrown, because an exception escaping a periodic task would suppress all of its subsequent executions.
   */
  private void refreshAllocation(int lifetime) {
    try {
      logger.debug("Refreshing TURN allocation");
      sendBlockingStunRequest(MessageFactory.createRefreshRequest(lifetime));

      // Work on a snapshot: response callbacks modify the map while the requests are being sent.
      Map<TransportAddress, Character> bindings;
      synchronized (peerAddressToChannel) {
        bindings = new HashMap<>(peerAddressToChannel);
      }
      for (Map.Entry<TransportAddress, Character> binding : bindings.entrySet()) {
        // bind() would skip already-bound peers, so the request is sent directly.
        sendChannelBindRequest(binding.getKey(), binding.getValue());
      }
    } catch (RuntimeException e) {
      logger.warn("Refreshing TURN allocation failed", e);
    }
  }

  /**
   * Handles a STUN indication by forwarding its data, attributed to the contained peer address, to the packet
   * listeners. Indications lacking the data or peer address attribute are logged and ignored.
   */
  private void onIndication(StunMessageEvent event) {
    Message message = event.getMessage();

    Attribute dataAttribute = message.getAttribute(Attribute.DATA);
    Attribute peerAddressAttribute = message.getAttribute(XOR_PEER_ADDRESS);
    if (dataAttribute == null || peerAddressAttribute == null) {
      logger.warn("Ignoring indication without data or peer address: {}", message);
      return;
    }

    byte[] data = ((DataAttribute) dataAttribute).getData();
    TransportAddress sender = ((XorPeerAddressAttribute) peerAddressAttribute).getAddress(message.getTransactionID());

    if (logger.isTraceEnabled()) {
      logger.trace("Received {} bytes indication from '{}': {}", data.length, sender,
          new String(data, 0, data.length, StandardCharsets.US_ASCII));
    }

    DatagramPacket datagramPacket = new DatagramPacket(data, data.length);
    datagramPacket.setSocketAddress(sender);

    onPacketReceived(datagramPacket);
  }

  /**
   * Passes {@code datagramPacket} to every registered packet listener.
   */
  private void onPacketReceived(DatagramPacket datagramPacket) {
    onPacketListeners.forEach(consumer -> consumer.accept(datagramPacket));
  }

  /**
   * Reads ChannelData messages from {@link #channelDataSocket} for as long as the connection state is
   * {@link ConnectionState#CONNECTED}. Returns quietly when the socket has been closed; any other I/O error is
   * rethrown unchecked.
   */
  private void runInReceiveChannelDataThread() {
    int receiveBufferSize = 1500;
    DatagramPacket packet = new DatagramPacket(new byte[receiveBufferSize], receiveBufferSize);

    while (connectionState.get() == ConnectionState.CONNECTED) {
      try {
        channelDataSocket.receive(packet);
      } catch (IOException e) {
        if (channelDataSocket.isClosed()) {
          logger.debug("Channel data socket has been closed");
          return;
        }
        throw new RuntimeException(e);
      }

      handleChannelDataPacket(packet);
    }
    logger.info("Stopped reading channel data");
  }

  /**
   * Parses one received ChannelData message and dispatches its payload. Packets that are too short, whose Length
   * field exceeds the received data, or whose channel is not bound are dropped.
   */
  private void handleChannelDataPacket(DatagramPacket packet) {
    int headerLength = CHANNELDATA_CHANNELNUMBER_LENGTH + CHANNELDATA_LENGTH_LENGTH;
    if (packet.getLength() < headerLength) {
      return;
    }

    byte[] receivedData = packet.getData();
    int headerOffset = packet.getOffset();

    // The cast to char keeps only the lower 16 bits, which discards any sign extension of the high byte.
    char channelNumber = (char)
        ((receivedData[headerOffset] << 8) | (receivedData[headerOffset + 1] & 0xFF));
    char length = (char)
        ((receivedData[headerOffset + 2] << 8) | (receivedData[headerOffset + 3] & 0xFF));

    int payloadOffset = headerOffset + headerLength;
    int channelDataLength = packet.getLength() - headerLength;
    if (length > channelDataLength) {
      return;
    }

    TransportAddress peerAddress;
    synchronized (peerAddressToChannel) {
      peerAddress = channelToPeerAddress.get(channelNumber);
    }
    if (peerAddress == null) {
      logger.warn("Received {} bytes on unbound channel '{}' from {}", channelDataLength, (int) channelNumber, packet.getSocketAddress());
      return;
    }

    byte[] payload = new byte[length];
    System.arraycopy(receivedData, payloadOffset, payload, 0, length);

    if (logger.isTraceEnabled()) {
      logger.trace("Received {} bytes on channel {}: {}", channelDataLength, (int) channelNumber,
          new String(payload, 0, length, StandardCharsets.US_ASCII));
    }

    ChannelData channelData = new ChannelData();
    channelData.setChannelNumber(channelNumber);
    channelData.setData(payload);

    try {
      onChannelData(new ChannelDataMessageEvent(stunStack, peerAddress, localAddress, channelData));
    } catch (Exception e) {
      // A failing listener must not terminate the receive loop.
      logger.warn("Error while handling channel data", e);
    }
  }

  /**
   * Forwards the data of a ChannelData message, attributed to the event's remote address, to the packet listeners.
   */
  private void onChannelData(ChannelDataMessageEvent event) {
    ChannelData channelData = event.getChannelDataMessage();

    DatagramPacket datagramPacket = new DatagramPacket(channelData.getData(), channelData.getDataLength());
    datagramPacket.setSocketAddress(event.getRemoteAddress());

    onPacketReceived(datagramPacket);
  }
}