/*
 * Copyright 2012-2016 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.boot.devtools.tunnel.server;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import org.springframework.util.SocketUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.fail;

/**
 * Tests for {@link SocketTargetServerConnection}.
*
 * @author Phillip Webb
 */
public class SocketTargetServerConnectionTests {

	/** Timeout (milliseconds) passed to the connection by the data-transfer tests. */
	private static final int DEFAULT_TIMEOUT = 1000;

	private int port;

	private MockServer server;

	private SocketTargetServerConnection connection;

	@Before
	public void setup() throws IOException {
		this.port = SocketUtils.findAvailableTcpPort();
		this.server = new MockServer(this.port);
		StaticPortProvider portProvider = new StaticPortProvider(this.port);
		this.connection = new SocketTargetServerConnection(portProvider);
	}

	/**
	 * Close the mock server's listening socket after every test so the port is
	 * released, even when a test failed before the client connected.
	 * @throws IOException if the listening socket cannot be closed
	 */
	@After
	public void tearDown() throws IOException {
		this.server.close();
	}

	@Test
	public void readData() throws Exception {
		byte[] expected = "hello".getBytes(StandardCharsets.UTF_8);
		this.server.willSend(expected);
		this.server.start();
		try (ByteChannel channel = this.connection.open(DEFAULT_TIMEOUT)) {
			ByteBuffer buffer = ByteBuffer.allocate(expected.length);
			// A single read() may return fewer bytes than requested, so keep reading
			// until the buffer is full or the peer signals end-of-stream (-1).
			while (buffer.hasRemaining()) {
				if (channel.read(buffer) == -1) {
					break;
				}
			}
			assertThat(buffer.array()).isEqualTo(expected);
		}
		// Surface any failure that happened on the server thread.
		this.server.closeAndVerify();
	}

	@Test
	public void writeData() throws Exception {
		byte[] data = "hello".getBytes(StandardCharsets.UTF_8);
		this.server.expect(data);
		this.server.start();
		try (ByteChannel channel = this.connection.open(DEFAULT_TIMEOUT)) {
			ByteBuffer buffer = ByteBuffer.wrap(data);
			// A single write() is not guaranteed to drain the buffer.
			while (buffer.hasRemaining()) {
				channel.write(buffer);
			}
		}
		this.server.closeAndVerify();
	}

	@Test
	public void timeout() throws Exception {
		// The server accepts the connection but stays silent for 1000ms, so a read on
		// a connection opened with a timeout of 10 must fail with a socket timeout.
		this.server.delay(1000);
		this.server.start();
		try (ByteChannel channel = this.connection.open(10)) {
			long startTime = System.nanoTime();
			try {
				channel.read(ByteBuffer.allocate(5));
				fail("No socket timeout thrown");
			}
			catch (SocketTimeoutException ex) {
				// Expected. Measure with the monotonic clock so wall-clock adjustments
				// cannot skew the elapsed time.
				long runTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime);
				assertThat(runTime).isGreaterThanOrEqualTo(10L);
				assertThat(runTime).isLessThan(10000L);
			}
		}
	}

	/**
	 * Single-connection TCP server that plays the remote end of the connection. It
	 * accepts one client, optionally sleeps, optionally sends bytes, optionally reads
	 * an expected number of bytes, then closes the accepted channel.
	 */
	private static class MockServer {

		/** Upper bound (milliseconds) on how long a test waits for the server thread. */
		private static final long COMPLETION_TIMEOUT = 10000;

		private final ServerSocketChannel serverSocket;

		private byte[] send;

		private byte[] expect;

		private int delay;

		// Written by the server thread; read by the test thread only after the server
		// thread has terminated (Thread termination establishes happens-before).
		private ByteBuffer actualRead;

		private Exception error;

		private ServerThread thread;

		MockServer(int port) throws IOException {
			this.serverSocket = ServerSocketChannel.open();
			this.serverSocket.bind(new InetSocketAddress(port));
		}

		/**
		 * Set how long the server sleeps after accepting before doing any I/O.
		 * @param delay the delay in milliseconds
		 */
		public void delay(int delay) {
			this.delay = delay;
		}

		/**
		 * Set the bytes the server writes to the client after accepting.
		 * @param send the bytes to send, or {@code null} to send nothing
		 */
		public void willSend(byte[] send) {
			this.send = send;
		}

		/**
		 * Set the bytes the server expects to read from the client.
		 * @param expect the expected bytes, or {@code null} to read nothing
		 */
		public void expect(byte[] expect) {
			this.expect = expect;
		}

		/**
		 * Start the server thread. The thread is a daemon so that one stuck in
		 * {@code accept()} can never keep the JVM alive.
		 */
		public void start() {
			this.thread = new ServerThread();
			this.thread.setDaemon(true);
			this.thread.start();
		}

		/**
		 * Wait for the server thread to finish, rethrow any error it hit, and, if
		 * bytes were expected, assert that exactly those bytes were received.
		 * @throws InterruptedException if interrupted while waiting
		 */
		public void closeAndVerify() throws InterruptedException {
			awaitCompletion();
			if (this.expect != null) {
				assertThat(this.actualRead.array()).isEqualTo(this.expect);
			}
		}

		/**
		 * Close the listening socket, releasing the port. A server thread still
		 * blocked in {@code accept()} is woken with an exception, which it records.
		 * @throws IOException if the socket cannot be closed
		 */
		public void close() throws IOException {
			this.serverSocket.close();
		}

		private void awaitCompletion() throws InterruptedException {
			this.thread.join(COMPLETION_TIMEOUT);
			if (this.thread.isAlive()) {
				fail("Mock server did not complete within " + COMPLETION_TIMEOUT + "ms");
			}
			if (this.error != null) {
				// Assertions thrown on the server thread would never reach JUnit, so
				// the failure is captured there and rethrown here on the test thread.
				throw new AssertionError("Mock server failed", this.error);
			}
		}

		private class ServerThread extends Thread {

			@Override
			public void run() {
				try (SocketChannel channel = MockServer.this.serverSocket.accept()) {
					Thread.sleep(MockServer.this.delay);
					if (MockServer.this.send != null) {
						ByteBuffer buffer = ByteBuffer.wrap(MockServer.this.send);
						while (buffer.hasRemaining()) {
							channel.write(buffer);
						}
					}
					if (MockServer.this.expect != null) {
						ByteBuffer buffer = ByteBuffer
								.allocate(MockServer.this.expect.length);
						while (buffer.hasRemaining()) {
							// Stop on end-of-stream instead of spinning forever when
							// the client disconnects early.
							if (channel.read(buffer) == -1) {
								break;
							}
						}
						MockServer.this.actualRead = buffer;
					}
				}
				catch (Exception ex) {
					MockServer.this.error = ex;
				}
			}

		}

	}

}