/* * Copyright 2012-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.boot.devtools.tunnel.client; import java.io.ByteArrayOutputStream; import java.io.Closeable; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.Channels; import java.nio.channels.SocketChannel; import java.nio.channels.WritableByteChannel; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.springframework.util.SocketUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; /** * Tests for {@link TunnelClient}. * * @author Phillip Webb */ public class TunnelClientTests { @Rule public ExpectedException thrown = ExpectedException.none(); private int listenPort = SocketUtils.findAvailableTcpPort(); private MockTunnelConnection tunnelConnection = new MockTunnelConnection(); @Test public void listenPortMustBePositive() throws Exception { this.thrown.expect(IllegalArgumentException.class); this.thrown.expectMessage("ListenPort must be positive"); new TunnelClient(0, this.tunnelConnection); } @Test public void tunnelConnectionMustNotBeNull() throws Exception { this.thrown.expect(IllegalArgumentException.class); this.thrown.expectMessage("TunnelConnection must not be null"); new TunnelClient(1, null); } @Test public void typicalTraffic() throws Exception { TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection); client.start(); SocketChannel channel = SocketChannel .open(new InetSocketAddress(this.listenPort)); channel.write(ByteBuffer.wrap("hello".getBytes())); ByteBuffer buffer = ByteBuffer.allocate(5); channel.read(buffer); channel.close(); this.tunnelConnection.verifyWritten("hello"); assertThat(new String(buffer.array())).isEqualTo("olleh"); } @Test public void socketChannelClosedTriggersTunnelClose() throws Exception { TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection); client.start(); SocketChannel channel = SocketChannel .open(new InetSocketAddress(this.listenPort)); Thread.sleep(200); channel.close(); client.getServerThread().stopAcceptingConnections(); client.getServerThread().join(2000); assertThat(this.tunnelConnection.getOpenedTimes()).isEqualTo(1); assertThat(this.tunnelConnection.isOpen()).isFalse(); } @Test public void stopTriggersTunnelClose() throws Exception { TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection); client.start(); SocketChannel channel = SocketChannel .open(new InetSocketAddress(this.listenPort)); Thread.sleep(200); client.stop(); assertThat(this.tunnelConnection.getOpenedTimes()).isEqualTo(1); assertThat(this.tunnelConnection.isOpen()).isFalse(); assertThat(channel.read(ByteBuffer.allocate(1))).isEqualTo(-1); } @Test public void addListener() throws Exception { TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection); TunnelClientListener listener = mock(TunnelClientListener.class); client.addListener(listener); client.start(); SocketChannel channel = SocketChannel .open(new InetSocketAddress(this.listenPort)); Thread.sleep(200); channel.close(); client.getServerThread().stopAcceptingConnections(); client.getServerThread().join(2000); verify(listener).onOpen(any(SocketChannel.class)); verify(listener).onClose(any(SocketChannel.class)); } private static class MockTunnelConnection implements TunnelConnection { private final ByteArrayOutputStream written = new ByteArrayOutputStream(); private boolean open; private int openedTimes; @Override public WritableByteChannel open(WritableByteChannel incomingChannel, Closeable closeable) throws Exception { this.openedTimes++; this.open = true; return new TunnelChannel(incomingChannel, closeable); } public void verifyWritten(String expected) { verifyWritten(expected.getBytes()); } public void verifyWritten(byte[] expected) { synchronized (this.written) { assertThat(this.written.toByteArray()).isEqualTo(expected); this.written.reset(); } } public boolean isOpen() { return this.open; } public int getOpenedTimes() { return this.openedTimes; } private class TunnelChannel implements WritableByteChannel { private final WritableByteChannel incomingChannel; private final Closeable closeable; TunnelChannel(WritableByteChannel incomingChannel, Closeable closeable) { this.incomingChannel = incomingChannel; this.closeable = closeable; } @Override public boolean isOpen() { return MockTunnelConnection.this.open; } @Override public void close() throws IOException { MockTunnelConnection.this.open = false; this.closeable.close(); } @Override public int write(ByteBuffer src) throws IOException { int remaining = src.remaining(); ByteArrayOutputStream stream = new ByteArrayOutputStream(); Channels.newChannel(stream).write(src); byte[] bytes = stream.toByteArray(); synchronized (MockTunnelConnection.this.written) { MockTunnelConnection.this.written.write(bytes); } byte[] reversed = new byte[bytes.length]; for (int i = 0; i < reversed.length; i++) { reversed[i] = bytes[bytes.length - 1 - i]; } this.incomingChannel.write(ByteBuffer.wrap(reversed)); return remaining; } } } }