package com.faforever.client.connectivity;
import com.faforever.client.net.ConnectionState;
import com.faforever.client.relay.ConnectToPeerMessage;
import com.faforever.client.relay.CreatePermissionMessage;
import com.faforever.client.remote.FafService;
import com.faforever.client.test.AbstractPlainJavaFxTest;
import org.ice4j.ResponseCollector;
import org.ice4j.StunMessageEvent;
import org.ice4j.StunResponseEvent;
import org.ice4j.TransportAddress;
import org.ice4j.attribute.Attribute;
import org.ice4j.attribute.DataAttribute;
import org.ice4j.attribute.XorMappedAddressAttribute;
import org.ice4j.attribute.XorPeerAddressAttribute;
import org.ice4j.attribute.XorRelayedAddressAttribute;
import org.ice4j.message.ChannelData;
import org.ice4j.message.Message;
import org.ice4j.message.Response;
import org.ice4j.stack.MessageEventHandler;
import org.ice4j.stack.StunStack;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.stubbing.Stubber;
import org.springframework.context.ApplicationContext;
import org.testfx.util.WaitForAsyncUtils;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.ice4j.Transport.UDP;
import static org.ice4j.attribute.Attribute.XOR_MAPPED_ADDRESS;
import static org.ice4j.attribute.Attribute.XOR_RELAYED_ADDRESS;
import static org.junit.Assert.*;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
/**
 * Unit tests for {@link TurnServerAccessorImpl}.
 *
 * <p>The STUN stack, the FAF server connection, the connectivity service, the Spring context and the
 * scheduler are all mocks. Only {@link #testReceiveChannelData()} opens a UDP socket of its own, which
 * stands in for the TURN server.
 */
public class TurnServerAccessorImplTest extends AbstractPlainJavaFxTest {

  private static final String TURN_HOST = "example.com";
  private static final int TURN_PORT = 1337;

  private TurnServerAccessorImpl instance;

  @Mock
  private FafService fafService;
  @Mock
  private ScheduledExecutorService scheduledExecutorService;
  @Mock
  private ConnectivityService connectivityService;
  @Mock
  private ApplicationContext applicationContext;
  @Mock
  private StunStack stunStack;

  @Captor
  private ArgumentCaptor<Consumer<CreatePermissionMessage>> createPermissionListenerCaptor;
  @Captor
  private ArgumentCaptor<Consumer<ConnectToPeerMessage>> connectToPeerMessageListenerCaptor;
  @Captor
  private ArgumentCaptor<MessageEventHandler> indicationListenerCaptor;

  /**
   * Creates the instance under test and injects the mocks directly into its fields.
   *
   * <p>The {@link StunStack} is obtained through the application context, so the context mock hands out
   * the {@code stunStack} mock.
   */
  @Before
  public void setUp() throws Exception {
    instance = new TurnServerAccessorImpl();
    instance.scheduledExecutorService = scheduledExecutorService;
    instance.fafService = fafService;
    instance.turnHost = TURN_HOST;
    instance.turnPort = TURN_PORT;
    instance.connectivityService = connectivityService;
    instance.applicationContext = applicationContext;

    when(connectivityService.getExternalSocketAddress()).thenReturn(InetSocketAddress.createUnresolved("foo", 123));
    when(applicationContext.getBean(StunStack.class)).thenReturn(stunStack);
  }

  /** Disconnects after every test so that connection state does not carry over between tests. */
  @After
  public void tearDown() throws Exception {
    instance.disconnect();
  }

  /** Connecting must schedule a periodic task, and that task must run without throwing. */
  @Test
  public void testConnect() throws Exception {
    CountDownLatch refreshRequestLatch = new CountDownLatch(1);
    // Run the scheduled task once, immediately, on the calling thread. The latch is only counted down
    // if run() returned normally.
    doAnswer(invocation -> {
      invocation.getArgumentAt(0, Runnable.class).run();
      refreshRequestLatch.countDown();
      return null;
    }).when(scheduledExecutorService).scheduleWithFixedDelay(any(), anyLong(), anyLong(), any());

    connect();

    verify(scheduledExecutorService).scheduleWithFixedDelay(any(), anyLong(), anyLong(), any());
    assertTrue(refreshRequestLatch.await(3, TimeUnit.SECONDS));
  }

  /**
   * Connects the instance against a stubbed, always-successful STUN exchange.
   *
   * <p>Every {@code sendRequest} is answered with a success response. That response carries
   * XOR-RELAYED-ADDRESS 1.2.3.4:1234 and XOR-MAPPED-ADDRESS 4.3.2.1:1234. The helper also asserts the
   * DISCONNECTED to CONNECTED transition. It captures the {@link CreatePermissionMessage} listener
   * registered on {@link FafService} into {@link #createPermissionListenerCaptor}.
   */
  private void connect() throws IOException {
    StunResponseEvent responseEvent = mock(StunResponseEvent.class);
    Response response = mock(Response.class);
    when(response.isSuccessResponse()).thenReturn(true);
    when(responseEvent.getResponse()).thenReturn(response);

    XorRelayedAddressAttribute xorRelayedAddressAttribute = mock(XorRelayedAddressAttribute.class);
    XorMappedAddressAttribute xorMappedAddressAttribute = mock(XorMappedAddressAttribute.class);
    when(response.getAttribute(XOR_RELAYED_ADDRESS)).thenReturn(xorRelayedAddressAttribute);
    when(response.getAttribute(XOR_MAPPED_ADDRESS)).thenReturn(xorMappedAddressAttribute);
    when(xorRelayedAddressAttribute.getAddress()).thenReturn(new TransportAddress("1.2.3.4", 1234, UDP));
    when(xorMappedAddressAttribute.getAddress()).thenReturn(new TransportAddress("4.3.2.1", 1234, UDP));

    respond(responseEvent).when(stunStack).sendRequest(any(), any(), any(TransportAddress.class), any());

    assertThat(instance.getConnectionState(), is(ConnectionState.DISCONNECTED));
    instance.connect();

    verify(fafService).addOnMessageListener(eq(CreatePermissionMessage.class), createPermissionListenerCaptor.capture());
    assertThat(instance.getConnectionState(), is(ConnectionState.CONNECTED));
  }

  /**
   * Returns a stubber that answers {@code StunStack.sendRequest(...)}. The answer passes
   * {@code response} straight to the {@link ResponseCollector}, which is the fourth argument
   * (index 3) of the call.
   */
  private Stubber respond(StunResponseEvent response) {
    return doAnswer(invocation -> {
      ResponseCollector responseCollector = invocation.getArgumentAt(3, ResponseCollector.class);
      responseCollector.processResponse(response);
      return null;
    });
  }

  /** Disconnecting a connected instance must remove its local socket from the STUN stack. */
  @Test
  public void testDisconnect() throws Exception {
    connect();
    InetSocketAddress localSocketAddress = instance.getLocalSocketAddress();

    instance.disconnect();

    verify(stunStack).removeSocket(eq(new TransportAddress(localSocketAddress, UDP)));
    assertThat(instance.getConnectionState(), is(ConnectionState.DISCONNECTED));
  }

  /** Disconnecting an instance that never connected must not touch the STUN stack. */
  @Test
  public void testDisconnectNotDisconnected() throws Exception {
    assertThat(instance.getConnectionState(), is(ConnectionState.DISCONNECTED));

    instance.disconnect();

    verifyZeroInteractions(stunStack);
    assertThat(instance.getConnectionState(), is(ConnectionState.DISCONNECTED));
  }

  /**
   * Sending to a peer that has a permission and a bound channel must go out as ChannelData. The data
   * must go to the TURN server and through the instance's local socket.
   */
  @Test
  public void testSend() throws Exception {
    connect();
    InetSocketAddress remotePeerAddress = new InetSocketAddress("93.184.216.34", 1234);

    CreatePermissionMessage createPermissionMessage = new CreatePermissionMessage();
    createPermissionMessage.setAddress(remotePeerAddress);
    createPermissionListenerCaptor.getValue().accept(createPermissionMessage);

    byte[] bytes = new byte[]{0x00, 0x01, 0x02};
    DatagramPacket datagramPacket = new DatagramPacket(bytes, bytes.length);
    datagramPacket.setSocketAddress(remotePeerAddress);

    bindToChannel(remotePeerAddress);
    instance.send(datagramPacket);

    ArgumentCaptor<ChannelData> channelDataCaptor = ArgumentCaptor.forClass(ChannelData.class);
    ArgumentCaptor<TransportAddress> sendToCaptor = ArgumentCaptor.forClass(TransportAddress.class);
    ArgumentCaptor<TransportAddress> sendThroughCaptor = ArgumentCaptor.forClass(TransportAddress.class);
    verify(stunStack).sendChannelData(channelDataCaptor.capture(), sendToCaptor.capture(), sendThroughCaptor.capture());

    // 0x4000 is the lowest valid TURN channel number.
    assertThat(channelDataCaptor.getValue().getChannelNumber(), is(greaterThanOrEqualTo((char) 0x4000)));
    assertThat(channelDataCaptor.getValue().getDataLength(), is((char) 3));
    assertThat(sendToCaptor.getValue(), equalTo(new TransportAddress(TURN_HOST, TURN_PORT, UDP)));
    assertThat(sendThroughCaptor.getValue(), equalTo(new TransportAddress(instance.getLocalSocketAddress(), UDP)));
  }

  /**
   * Binds {@code socketAddress} to a channel. The STUN stack is re-stubbed so that the ChannelBind
   * request gets a success response.
   */
  private void bindToChannel(SocketAddress socketAddress) throws IOException {
    // Mock the binding response
    StunResponseEvent responseEvent = mock(StunResponseEvent.class);
    Response response = mock(Response.class);
    when(response.isSuccessResponse()).thenReturn(true);
    when(responseEvent.getResponse()).thenReturn(response);
    respond(responseEvent).when(stunStack).sendRequest(any(), any(), any(TransportAddress.class), any());

    instance.bind((InetSocketAddress) socketAddress);
  }

  /**
   * A ChannelData message arriving from the TURN server on a bound channel must reach the packet
   * listeners with its 4-byte ChannelData header removed.
   */
  @Test
  public void testReceiveChannelData() throws Exception {
    // Run tasks passed to execute() on a background thread, as a real executor would.
    doAnswer(invocation -> {
      WaitForAsyncUtils.async(invocation.getArgumentAt(0, Runnable.class));
      return null;
    }).when(scheduledExecutorService).execute(any());

    // This socket plays the TURN server: the instance is pointed at it and the test sends from it.
    try (DatagramSocket socket = new DatagramSocket(new InetSocketAddress(InetAddress.getLocalHost(), 0))) {
      instance.turnHost = socket.getLocalAddress().getHostAddress();
      instance.turnPort = socket.getLocalPort();
      connect();

      CompletableFuture<DatagramPacket> packetFuture = new CompletableFuture<>();
      instance.addOnPacketListener(packetFuture::complete);

      // Destination for the fake TURN server: the instance's local port on this host.
      InetSocketAddress clientSocketAddress = new InetSocketAddress(InetAddress.getLocalHost(), instance.getLocalSocketAddress().getPort());
      bindToChannel(socket.getLocalSocketAddress());

      // First two bytes are channel number, second two bytes are message length, rest is data
      byte[] channelData = new byte[]{0x40, 0x00, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04};
      DatagramPacket datagramPacket = new DatagramPacket(channelData, channelData.length);
      datagramPacket.setSocketAddress(clientSocketAddress);

      CreatePermissionMessage createPermissionMessage = new CreatePermissionMessage();
      createPermissionMessage.setAddress((InetSocketAddress) socket.getLocalSocketAddress());
      createPermissionListenerCaptor.getValue().accept(createPermissionMessage);

      socket.send(datagramPacket);

      DatagramPacket packetReceivedOnChannel = packetFuture.get(2, TimeUnit.SECONDS);
      assertArrayEquals(Arrays.copyOfRange(channelData, 4, channelData.length), packetReceivedOnChannel.getData());
    }
  }

  /**
   * A Data indication delivered to the STUN indication listener must reach the packet listeners. The
   * listener must get the DATA attribute's payload as data and the XOR-PEER-ADDRESS as the sender.
   */
  @Test
  public void testIndication() throws Exception {
    connect();
    verify(stunStack).addIndicationListener(any(), indicationListenerCaptor.capture());

    CompletableFuture<DatagramPacket> packetFuture = new CompletableFuture<>();
    instance.addOnPacketListener(packetFuture::complete);

    // Indications are plain STUN message events, not responses.
    StunMessageEvent event = mock(StunMessageEvent.class);
    Message message = mock(Message.class);
    when(event.getMessage()).thenReturn(message);

    byte[] data = "\bBind 102144".getBytes(US_ASCII);
    DataAttribute dataAttribute = mock(DataAttribute.class);
    when(dataAttribute.getData()).thenReturn(data);
    when(message.getAttribute(Attribute.DATA)).thenReturn(dataAttribute);

    TransportAddress sender = new TransportAddress("1.2.3.4", 1234, UDP);
    XorPeerAddressAttribute xorPeerAddressAttribute = mock(XorPeerAddressAttribute.class);
    when(xorPeerAddressAttribute.getAddress(any())).thenReturn(sender);
    when(message.getAttribute(Attribute.XOR_PEER_ADDRESS)).thenReturn(xorPeerAddressAttribute);

    indicationListenerCaptor.getValue().handleMessageEvent(event);

    DatagramPacket packet = packetFuture.get(2, TimeUnit.SECONDS);
    assertThat(packet.getData(), is(data));
    assertThat(packet.getSocketAddress(), is(sender));
  }
}