/* * Copyright 2012-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.boot.devtools.tunnel.server; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.net.SocketTimeoutException; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; import java.nio.channels.Channels; import java.util.concurrent.BlockingDeque; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.springframework.boot.devtools.tunnel.payload.HttpTunnelPayload; import org.springframework.boot.devtools.tunnel.server.HttpTunnelServer.HttpConnection; import org.springframework.http.HttpStatus; import org.springframework.http.server.ServerHttpAsyncRequestControl; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; /** * Tests for {@link HttpTunnelServer}. * * @author Phillip Webb */ public class HttpTunnelServerTests { private static final int DEFAULT_LONG_POLL_TIMEOUT = 10000; private static final byte[] NO_DATA = {}; private static final String SEQ_HEADER = "x-seq"; @Rule public ExpectedException thrown = ExpectedException.none(); private HttpTunnelServer server; @Mock private TargetServerConnection serverConnection; private MockHttpServletRequest servletRequest; private MockHttpServletResponse servletResponse; private ServerHttpRequest request; private ServerHttpResponse response; private MockServerChannel serverChannel; @Before public void setup() throws Exception { MockitoAnnotations.initMocks(this); this.server = new HttpTunnelServer(this.serverConnection); given(this.serverConnection.open(anyInt())).willAnswer(new Answer<ByteChannel>() { @Override public ByteChannel answer(InvocationOnMock invocation) throws Throwable { MockServerChannel channel = HttpTunnelServerTests.this.serverChannel; channel.setTimeout((Integer) invocation.getArguments()[0]); return channel; } }); this.servletRequest = new MockHttpServletRequest(); this.servletRequest.setAsyncSupported(true); this.servletResponse = new MockHttpServletResponse(); this.request = new ServletServerHttpRequest(this.servletRequest); this.response = new ServletServerHttpResponse(this.servletResponse); this.serverChannel = new MockServerChannel(); } @Test public void serverConnectionIsRequired() throws Exception { this.thrown.expect(IllegalArgumentException.class); this.thrown.expectMessage("ServerConnection must not be null"); new HttpTunnelServer(null); } @Test public void serverConnectedOnFirstRequest() throws Exception { verify(this.serverConnection, never()).open(anyInt()); this.server.handle(this.request, this.response); verify(this.serverConnection, times(1)).open(DEFAULT_LONG_POLL_TIMEOUT); } @Test public void longPollTimeout() throws Exception { this.server.setLongPollTimeout(800); this.server.handle(this.request, this.response); verify(this.serverConnection, times(1)).open(800); } @Test public void longPollTimeoutMustBePositiveValue() throws Exception { this.thrown.expect(IllegalArgumentException.class); this.thrown.expectMessage("LongPollTimeout must be a positive value"); this.server.setLongPollTimeout(0); } @Test public void initialRequestIsSentToServer() throws Exception { this.servletRequest.addHeader(SEQ_HEADER, "1"); this.servletRequest.setContent("hello".getBytes()); this.server.handle(this.request, this.response); this.serverChannel.disconnect(); this.server.getServerThread().join(); this.serverChannel.verifyReceived("hello"); } @Test public void initialRequestIsUsedForFirstServerResponse() throws Exception { this.servletRequest.addHeader(SEQ_HEADER, "1"); this.servletRequest.setContent("hello".getBytes()); this.server.handle(this.request, this.response); System.out.println("sending"); this.serverChannel.send("hello"); this.serverChannel.disconnect(); this.server.getServerThread().join(); assertThat(this.servletResponse.getContentAsString()).isEqualTo("hello"); this.serverChannel.verifyReceived("hello"); } @Test public void initialRequestHasNoPayload() throws Exception { this.server.handle(this.request, this.response); this.serverChannel.disconnect(); this.server.getServerThread().join(); this.serverChannel.verifyReceived(NO_DATA); } @Test public void typicalRequestResponseTraffic() throws Exception { MockHttpConnection h1 = new MockHttpConnection(); this.server.handle(h1); MockHttpConnection h2 = new MockHttpConnection("hello server", 1); this.server.handle(h2); this.serverChannel.verifyReceived("hello server"); this.serverChannel.send("hello client"); h1.verifyReceived("hello client", 1); MockHttpConnection h3 = new MockHttpConnection("1+1", 2); this.server.handle(h3); this.serverChannel.send("=2"); h2.verifyReceived("=2", 2); MockHttpConnection h4 = new MockHttpConnection("1+2", 3); this.server.handle(h4); this.serverChannel.send("=3"); h3.verifyReceived("=3", 3); this.serverChannel.disconnect(); this.server.getServerThread().join(); } @Test public void clientIsAwareOfServerClose() throws Exception { MockHttpConnection h1 = new MockHttpConnection("1", 1); this.server.handle(h1); this.serverChannel.disconnect(); this.server.getServerThread().join(); assertThat(h1.getServletResponse().getStatus()).isEqualTo(410); } @Test public void clientCanCloseServer() throws Exception { MockHttpConnection h1 = new MockHttpConnection(); this.server.handle(h1); MockHttpConnection h2 = new MockHttpConnection("DISCONNECT", 1); h2.getServletRequest().addHeader("Content-Type", "application/x-disconnect"); this.server.handle(h2); this.server.getServerThread().join(); assertThat(h1.getServletResponse().getStatus()).isEqualTo(410); assertThat(this.serverChannel.isOpen()).isFalse(); } @Test public void neverMoreThanTwoHttpConnections() throws Exception { MockHttpConnection h1 = new MockHttpConnection(); this.server.handle(h1); MockHttpConnection h2 = new MockHttpConnection("1", 2); this.server.handle(h2); MockHttpConnection h3 = new MockHttpConnection("2", 3); this.server.handle(h3); h1.waitForResponse(); assertThat(h1.getServletResponse().getStatus()).isEqualTo(429); this.serverChannel.disconnect(); this.server.getServerThread().join(); } @Test public void requestReceivedOutOfOrder() throws Exception { MockHttpConnection h1 = new MockHttpConnection(); MockHttpConnection h2 = new MockHttpConnection("1+2", 1); MockHttpConnection h3 = new MockHttpConnection("+3", 2); this.server.handle(h1); this.server.handle(h3); this.server.handle(h2); this.serverChannel.verifyReceived("1+2+3"); this.serverChannel.disconnect(); this.server.getServerThread().join(); } @Test public void httpConnectionsAreClosedAfterLongPollTimeout() throws Exception { this.server.setDisconnectTimeout(1000); this.server.setLongPollTimeout(100); MockHttpConnection h1 = new MockHttpConnection(); this.server.handle(h1); MockHttpConnection h2 = new MockHttpConnection(); this.server.handle(h2); Thread.sleep(400); this.serverChannel.disconnect(); this.server.getServerThread().join(); assertThat(h1.getServletResponse().getStatus()).isEqualTo(204); assertThat(h2.getServletResponse().getStatus()).isEqualTo(204); } @Test public void disconnectTimeout() throws Exception { this.server.setDisconnectTimeout(100); this.server.setLongPollTimeout(100); MockHttpConnection h1 = new MockHttpConnection(); this.server.handle(h1); this.serverChannel.send("hello"); this.server.getServerThread().join(); assertThat(this.serverChannel.isOpen()).isFalse(); } @Test public void disconnectTimeoutMustBePositive() throws Exception { this.thrown.expect(IllegalArgumentException.class); this.thrown.expectMessage("DisconnectTimeout must be a positive value"); this.server.setDisconnectTimeout(0); } @Test public void httpConnectionRespondWithPayload() throws Exception { HttpConnection connection = new HttpConnection(this.request, this.response); connection.waitForResponse(); connection.respond(new HttpTunnelPayload(1, ByteBuffer.wrap("hello".getBytes()))); assertThat(this.servletResponse.getStatus()).isEqualTo(200); assertThat(this.servletResponse.getContentAsString()).isEqualTo("hello"); assertThat(this.servletResponse.getHeader(SEQ_HEADER)).isEqualTo("1"); } @Test public void httpConnectionRespondWithStatus() throws Exception { HttpConnection connection = new HttpConnection(this.request, this.response); connection.waitForResponse(); connection.respond(HttpStatus.I_AM_A_TEAPOT); assertThat(this.servletResponse.getStatus()).isEqualTo(418); assertThat(this.servletResponse.getContentLength()).isEqualTo(0); } @Test public void httpConnectionAsync() throws Exception { ServerHttpAsyncRequestControl async = mock(ServerHttpAsyncRequestControl.class); ServerHttpRequest request = mock(ServerHttpRequest.class); given(request.getAsyncRequestControl(this.response)).willReturn(async); HttpConnection connection = new HttpConnection(request, this.response); connection.waitForResponse(); verify(async).start(); connection.respond(HttpStatus.NO_CONTENT); verify(async).complete(); } @Test public void httpConnectionNonAsync() throws Exception { testHttpConnectionNonAsync(0); testHttpConnectionNonAsync(100); } private void testHttpConnectionNonAsync(long sleepBeforeResponse) throws IOException, InterruptedException { ServerHttpRequest request = mock(ServerHttpRequest.class); given(request.getAsyncRequestControl(this.response)) .willThrow(new IllegalArgumentException()); final HttpConnection connection = new HttpConnection(request, this.response); final AtomicBoolean responded = new AtomicBoolean(); Thread connectionThread = new Thread() { @Override public void run() { connection.waitForResponse(); responded.set(true); } }; connectionThread.start(); assertThat(responded.get()).isFalse(); Thread.sleep(sleepBeforeResponse); connection.respond(HttpStatus.NO_CONTENT); connectionThread.join(); assertThat(responded.get()).isTrue(); } @Test public void httpConnectionRunning() throws Exception { HttpConnection connection = new HttpConnection(this.request, this.response); assertThat(connection.isOlderThan(100)).isFalse(); Thread.sleep(200); assertThat(connection.isOlderThan(100)).isTrue(); } /** * Mock {@link ByteChannel} used to simulate the server connection. */ private static class MockServerChannel implements ByteChannel { private static final ByteBuffer DISCONNECT = ByteBuffer.wrap(NO_DATA); private int timeout; private BlockingDeque<ByteBuffer> outgoing = new LinkedBlockingDeque<>(); private ByteArrayOutputStream written = new ByteArrayOutputStream(); private AtomicBoolean open = new AtomicBoolean(true); public void setTimeout(int timeout) { this.timeout = timeout; } public void send(String content) { send(content.getBytes()); } public void send(byte[] bytes) { this.outgoing.addLast(ByteBuffer.wrap(bytes)); } public void disconnect() { this.outgoing.addLast(DISCONNECT); } public void verifyReceived(String expected) { verifyReceived(expected.getBytes()); } public void verifyReceived(byte[] expected) { synchronized (this.written) { assertThat(this.written.toByteArray()).isEqualTo(expected); this.written.reset(); } } @Override public int read(ByteBuffer dst) throws IOException { try { ByteBuffer bytes = this.outgoing.pollFirst(this.timeout, TimeUnit.MILLISECONDS); if (bytes == null) { throw new SocketTimeoutException(); } if (bytes == DISCONNECT) { this.open.set(false); return -1; } int initialRemaining = dst.remaining(); bytes.limit(Math.min(bytes.limit(), initialRemaining)); dst.put(bytes); bytes.limit(bytes.capacity()); return initialRemaining - dst.remaining(); } catch (InterruptedException ex) { throw new IllegalStateException(ex); } } @Override public int write(ByteBuffer src) throws IOException { int remaining = src.remaining(); synchronized (this.written) { Channels.newChannel(this.written).write(src); } return remaining; } @Override public boolean isOpen() { return this.open.get(); } @Override public void close() throws IOException { this.open.set(false); } } /** * Mock {@link HttpConnection}. */ private static class MockHttpConnection extends HttpConnection { MockHttpConnection() { super(new ServletServerHttpRequest(new MockHttpServletRequest()), new ServletServerHttpResponse(new MockHttpServletResponse())); } MockHttpConnection(String content, int seq) { this(); MockHttpServletRequest request = getServletRequest(); request.setContent(content.getBytes()); request.addHeader(SEQ_HEADER, String.valueOf(seq)); } @Override protected ServerHttpAsyncRequestControl startAsync() { getServletRequest().setAsyncSupported(true); return super.startAsync(); } @Override protected void complete() { super.complete(); getServletResponse().setCommitted(true); } public MockHttpServletRequest getServletRequest() { return (MockHttpServletRequest) ((ServletServerHttpRequest) getRequest()) .getServletRequest(); } public MockHttpServletResponse getServletResponse() { return (MockHttpServletResponse) ((ServletServerHttpResponse) getResponse()) .getServletResponse(); } public void verifyReceived(String expectedContent, int expectedSeq) throws Exception { waitForServletResponse(); MockHttpServletResponse resp = getServletResponse(); assertThat(resp.getContentAsString()).isEqualTo(expectedContent); assertThat(resp.getHeader(SEQ_HEADER)).isEqualTo(String.valueOf(expectedSeq)); } public void waitForServletResponse() throws InterruptedException { while (!getServletResponse().isCommitted()) { Thread.sleep(10); } } } }