// Copyright 2016 Twitter. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package com.twitter.heron.spi.utils; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.HttpURLConnection; import java.net.InetSocketAddress; import java.net.URL; import java.time.Duration; import java.util.logging.Logger; import com.sun.net.httpserver.Headers; import com.sun.net.httpserver.HttpExchange; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Matchers; import org.mockito.Mockito; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; import com.twitter.heron.common.basics.Pair; import com.twitter.heron.common.basics.SysUtils; @RunWith(PowerMockRunner.class) @PrepareForTest({ SysUtils.class, NetworkUtils.class, ShellUtils.class}) public class NetworkUtilsTest { private static final Logger LOG = Logger.getLogger(NetworkUtilsTest.class.getName()); @Test public void testSendHttpResponse() throws Exception { HttpExchange exchange = Mockito.mock(HttpExchange.class); Mockito.doNothing().when(exchange).sendResponseHeaders(Matchers.anyInt(), Matchers.anyLong()); OutputStream os = Mockito.mock(OutputStream.class); Mockito.doReturn(os).when(exchange).getResponseBody(); Mockito.doNothing().when(os).write(Matchers.any(byte[].class)); Mockito.doNothing().when(os).close(); Assert.assertTrue(NetworkUtils.sendHttpResponse(exchange, new byte[0])); Mockito.verify(exchange).getResponseBody(); Mockito.verify(os, Mockito.atLeastOnce()).write(Matchers.any(byte[].class)); Mockito.verify(os, Mockito.atLeastOnce()).close(); } @Test public void testSendHttpResponseFail() throws Exception { HttpExchange exchange = Mockito.mock(HttpExchange.class); Mockito.doThrow(new IOException("Designed IO exception for testing")). when(exchange).sendResponseHeaders(Matchers.anyInt(), Matchers.anyLong()); Assert.assertFalse(NetworkUtils.sendHttpResponse(exchange, new byte[0])); Mockito.verify(exchange, Mockito.never()).getResponseBody(); Mockito.doNothing(). when(exchange).sendResponseHeaders(Matchers.anyInt(), Matchers.anyLong()); OutputStream os = Mockito.mock(OutputStream.class); Mockito.doReturn(os).when(exchange).getResponseBody(); Mockito.doThrow(new IOException("Designed IO exception for testing")). when(os).write(Matchers.any(byte[].class)); Assert.assertFalse(NetworkUtils.sendHttpResponse(exchange, new byte[0])); Mockito.verify(os, Mockito.atLeastOnce()).close(); Mockito.doNothing().when(os).write(Matchers.any(byte[].class)); Mockito.doThrow(new IOException("Designed IO exception for testing")) .when(os).close(); Assert.assertFalse(NetworkUtils.sendHttpResponse(exchange, new byte[0])); } @Test public void testReadHttpRequestBody() throws Exception { byte[] expectedBytes = "TO READ".getBytes(); InputStream is = Mockito.spy(new ByteArrayInputStream(expectedBytes)); HttpExchange exchange = Mockito.mock(HttpExchange.class); Headers headers = Mockito.mock(Headers.class); Mockito.doReturn(Integer.toString(expectedBytes.length)). when(headers).getFirst(Matchers.anyString()); Mockito.doReturn(headers).when(exchange).getRequestHeaders(); Mockito.doReturn(is).when(exchange).getRequestBody(); Assert.assertArrayEquals(expectedBytes, NetworkUtils.readHttpRequestBody(exchange)); Mockito.verify(is, Mockito.atLeastOnce()).close(); } @Test public void testReadHttpRequestBodyFail() throws Exception { HttpExchange exchange = Mockito.mock(HttpExchange.class); Headers headers = Mockito.mock(Headers.class); Mockito.doReturn(headers).when(exchange).getRequestHeaders(); Mockito.doReturn("-1"). when(headers).getFirst(Matchers.anyString()); Assert.assertArrayEquals(new byte[0], NetworkUtils.readHttpRequestBody(exchange)); Mockito.doReturn("10"). when(headers).getFirst(Matchers.anyString()); InputStream inputStream = Mockito.mock(InputStream.class); Mockito.doReturn(inputStream).when(exchange).getRequestBody(); Mockito.doThrow(new IOException("Designed IO exception for testing")) .when(inputStream).read(Matchers.any(byte[].class), Matchers.anyInt(), Matchers.anyInt()); Assert.assertArrayEquals(new byte[0], NetworkUtils.readHttpRequestBody(exchange)); Mockito.verify(inputStream, Mockito.atLeastOnce()).close(); } @Test public void testSendHttpPostRequest() throws Exception { URL url = new URL("http://"); int dataLength = 100; HttpURLConnection connection = Mockito.spy((HttpURLConnection) url.openConnection()); OutputStream os = Mockito.mock(OutputStream.class); Mockito.doReturn(os).when(connection).getOutputStream(); byte[] data = new byte[dataLength]; Assert.assertTrue(NetworkUtils.sendHttpPostRequest(connection, NetworkUtils.URL_ENCODE_TYPE, data)); Assert.assertEquals("POST", connection.getRequestMethod()); Assert.assertEquals("application/x-www-form-urlencoded", connection.getRequestProperty("Content-Type")); Assert.assertEquals(false, connection.getUseCaches()); Assert.assertEquals(true, connection.getDoOutput()); connection.disconnect(); } @Test public void testSendHttpPostRequestFail() throws Exception { URL url = new URL("http://"); HttpURLConnection connection = Mockito.spy((HttpURLConnection) url.openConnection()); Mockito.doThrow(new IOException("Designed IO exception for testing")). when(connection).getOutputStream(); Assert.assertFalse(NetworkUtils.sendHttpPostRequest(connection, NetworkUtils.URL_ENCODE_TYPE, new byte[1])); connection.disconnect(); } @Test public void testReadHttpResponseFail() throws Exception { HttpURLConnection connection = Mockito.mock(HttpURLConnection.class); // Unable to read response due to wrong response code Mockito.doReturn(HttpURLConnection.HTTP_NOT_FOUND).when(connection).getResponseCode(); Assert.assertArrayEquals(new byte[0], NetworkUtils.readHttpResponse(connection)); // Unable to read response due to wrong response content length Mockito.doReturn(HttpURLConnection.HTTP_OK).when(connection).getResponseCode(); Mockito.doReturn(-1).when(connection).getContentLength(); Assert.assertArrayEquals(new byte[0], NetworkUtils.readHttpResponse(connection)); Mockito.doThrow(new IOException("Designed IO exception for testing")). when(connection).getResponseCode(); Assert.assertArrayEquals(new byte[0], NetworkUtils.readHttpResponse(connection)); } @Test public void testReadHttpResponse() throws Exception { String expectedResponseString = "Hello World!"; byte[] expectedBytes = expectedResponseString.getBytes(); HttpURLConnection connection = Mockito.mock(HttpURLConnection.class); Mockito.doReturn(HttpURLConnection.HTTP_OK).when(connection).getResponseCode(); Mockito.doReturn(expectedBytes.length).when(connection).getContentLength(); InputStream is = new ByteArrayInputStream(expectedBytes); Mockito.doReturn(is).when(connection).getInputStream(); Assert.assertArrayEquals(expectedBytes, NetworkUtils.readHttpResponse(connection)); } /** * Test establishSSHTunnelIfNeeded() */ @Test public void testEstablishSSHTunnelIfNeeded() throws Exception { // Mock host to verified String mockHost = "host0"; int mockPort = 9049; String mockEndpoint = String.format("%s:%d", mockHost, mockPort); InetSocketAddress mockAddr = NetworkUtils.getInetSocketAddress(mockEndpoint); int mockFreePort = 9519; String tunnelHost = "tunnelHost"; Duration timeout = Duration.ofMillis(-1); int retryCount = -1; Duration retryInterval = Duration.ofMillis(-1); int verifyCount = -1; NetworkUtils.TunnelConfig tunnelConfig = new NetworkUtils.TunnelConfig( true, tunnelHost, timeout, retryCount, retryInterval, verifyCount); // Can reach directly, no need to ssh tunnel PowerMockito.spy(NetworkUtils.class); PowerMockito.doReturn(true).when(NetworkUtils.class, "isLocationReachable", Mockito.eq(mockAddr), Mockito.eq(timeout), Mockito.anyInt(), Mockito.eq(retryInterval)); Pair<InetSocketAddress, Process> ret = NetworkUtils.establishSSHTunnelIfNeeded(NetworkUtils.getInetSocketAddress(mockEndpoint), tunnelConfig, NetworkUtils.TunnelType.PORT_FORWARD); Assert.assertEquals(mockHost, ret.first.getHostName()); Assert.assertEquals(mockPort, ret.first.getPort()); Assert.assertEquals(mockEndpoint, ret.first.toString()); Assert.assertNull(ret.second); // Can not reach directly, basic setup PowerMockito.doReturn(false).when(NetworkUtils.class, "isLocationReachable", Mockito.eq(mockAddr), Mockito.eq(timeout), Mockito.anyInt(), Mockito.eq(retryInterval)); PowerMockito.spy(SysUtils.class); PowerMockito.doReturn(mockFreePort).when(SysUtils.class, "getFreePort"); Process process = Mockito.mock(Process.class); Mockito.doReturn(true).when(process).isAlive(); // Can not reach directly, failed to establish ssh tunnel either PowerMockito.spy(ShellUtils.class); PowerMockito.doReturn(process).when(ShellUtils.class, "establishSSHTunnelProcess", Mockito.anyString(), Mockito.anyInt(), Mockito.anyString(), Mockito.anyInt()); InetSocketAddress newAddress = NetworkUtils.getInetSocketAddress( String.format("%s:%d", NetworkUtils.LOCAL_HOST, mockFreePort)); PowerMockito.doReturn(false).when(NetworkUtils.class, "isLocationReachable", Mockito.eq(newAddress), Mockito.eq(timeout), Mockito.anyInt(), Mockito.eq(retryInterval)); ret = NetworkUtils.establishSSHTunnelIfNeeded(NetworkUtils.getInetSocketAddress(mockEndpoint), tunnelConfig, NetworkUtils.TunnelType.PORT_FORWARD); Assert.assertNull(ret.first); Assert.assertNull(ret.second); // Can not reach directly, but can establish ssh tunnel to reach the destination PowerMockito.doReturn(true).when(NetworkUtils.class, "isLocationReachable", Mockito.eq(newAddress), Mockito.eq(timeout), Mockito.anyInt(), Mockito.eq(retryInterval)); ret = NetworkUtils.establishSSHTunnelIfNeeded(NetworkUtils.getInetSocketAddress(mockEndpoint), tunnelConfig, NetworkUtils.TunnelType.PORT_FORWARD); Assert.assertEquals(NetworkUtils.LOCAL_HOST, ret.first.getHostName()); Assert.assertEquals(mockFreePort, ret.first.getPort()); Assert.assertEquals(process, ret.second); } @Test public void testGetInetSocketAddress() throws Exception { String host = "host"; int port = 999; String endpoint = String.format("%s:%d", host, port); InetSocketAddress address = NetworkUtils.getInetSocketAddress(endpoint); Assert.assertEquals(host, address.getHostString()); Assert.assertEquals(port, address.getPort()); Assert.assertEquals(endpoint, address.toString()); } }