package test.r2.integ;
import com.linkedin.common.callback.Callback;
import com.linkedin.common.callback.FutureCallback;
import com.linkedin.common.util.None;
import com.linkedin.r2.filter.R2Constants;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.message.rest.RestRequest;
import com.linkedin.r2.message.rest.RestRequestBuilder;
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.rest.RestResponseBuilder;
import com.linkedin.r2.message.stream.StreamRequest;
import com.linkedin.r2.message.stream.StreamResponse;
import com.linkedin.r2.transport.common.Client;
import com.linkedin.r2.transport.common.RestRequestHandler;
import com.linkedin.r2.transport.common.Server;
import com.linkedin.r2.transport.common.StreamRequestHandler;
import com.linkedin.r2.transport.common.StreamRequestHandlerAdapter;
import com.linkedin.r2.transport.common.TransportClientFactory;
import com.linkedin.r2.transport.common.bridge.client.TransportClient;
import com.linkedin.r2.transport.common.bridge.client.TransportClientAdapter;
import com.linkedin.r2.transport.common.bridge.common.TransportCallback;
import com.linkedin.r2.transport.common.bridge.server.TransportCallbackAdapter;
import com.linkedin.r2.transport.common.bridge.server.TransportDispatcher;
import com.linkedin.r2.transport.http.client.HttpClientFactory;
import com.linkedin.r2.transport.http.common.HttpProtocolVersion;
import com.linkedin.r2.transport.http.server.HttpJettyServer;
import com.linkedin.r2.transport.http.server.HttpServerFactory;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Factory;
import org.testng.annotations.Test;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**
* @author Zhenkai Zhu
*/
public class TestQueryTunnel
{
private static int PORT = 9003;
private static int IS_TUNNELED_RESPONSE_CODE = 200;
private static int IS_NOT_TUNNELED_RESPONSE_CODE = 201;
private static int QUERY_TUNNEL_THRESHOLD = 8;
private Client _client;
private Server _server;
private TransportClientFactory _clientFactory;
private final boolean _clientROS;
private final boolean _serverROS;
private final HttpJettyServer.ServletType _servletType;
private final String _httpProtocolVersion;
private final int _port;
@Factory(dataProvider = "configs")
public TestQueryTunnel(boolean clientROS, boolean serverROS, String httpProtocolVersion, HttpJettyServer.ServletType servletType, int port)
{
_clientROS = clientROS;
_serverROS = serverROS;
_httpProtocolVersion = httpProtocolVersion;
_servletType = servletType;
_port = port;
}
@DataProvider
public static Object[][] configs()
{
return new Object[][] {
{true, true, HttpProtocolVersion.HTTP_1_1.name(), HttpJettyServer.ServletType.RAP, PORT},
{true, false, HttpProtocolVersion.HTTP_1_1.name(), HttpJettyServer.ServletType.RAP, PORT},
{false, true, HttpProtocolVersion.HTTP_1_1.name(), HttpJettyServer.ServletType.RAP, PORT},
{false, false, HttpProtocolVersion.HTTP_1_1.name(), HttpJettyServer.ServletType.RAP, PORT},
{true, true, HttpProtocolVersion.HTTP_1_1.name(), HttpJettyServer.ServletType.ASYNC_EVENT, PORT},
{true, false, HttpProtocolVersion.HTTP_1_1.name(), HttpJettyServer.ServletType.ASYNC_EVENT, PORT},
{false, true, HttpProtocolVersion.HTTP_1_1.name(), HttpJettyServer.ServletType.ASYNC_EVENT, PORT},
{false, false, HttpProtocolVersion.HTTP_1_1.name(), HttpJettyServer.ServletType.ASYNC_EVENT, PORT},
{true, true, HttpProtocolVersion.HTTP_2.name(), HttpJettyServer.ServletType.RAP, PORT},
{true, false, HttpProtocolVersion.HTTP_2.name(), HttpJettyServer.ServletType.RAP, PORT},
{false, true, HttpProtocolVersion.HTTP_2.name(), HttpJettyServer.ServletType.RAP, PORT},
{false, false, HttpProtocolVersion.HTTP_2.name(), HttpJettyServer.ServletType.RAP, PORT},
{true, true, HttpProtocolVersion.HTTP_2.name(), HttpJettyServer.ServletType.ASYNC_EVENT, PORT},
{true, false, HttpProtocolVersion.HTTP_2.name(), HttpJettyServer.ServletType.ASYNC_EVENT, PORT},
{false, true, HttpProtocolVersion.HTTP_2.name(), HttpJettyServer.ServletType.ASYNC_EVENT, PORT},
{false, false, HttpProtocolVersion.HTTP_2.name(), HttpJettyServer.ServletType.ASYNC_EVENT, PORT}
};
}
@BeforeClass
protected void setUp() throws Exception
{
Map<String, String> clientProperties = new HashMap<String, String>();
clientProperties.put(HttpClientFactory.HTTP_QUERY_POST_THRESHOLD, String.valueOf(QUERY_TUNNEL_THRESHOLD));
clientProperties.put(HttpClientFactory.HTTP_PROTOCOL_VERSION, _httpProtocolVersion);
_clientFactory = new HttpClientFactory();
final TransportClient transportClient = _clientFactory
.getClient(clientProperties);
_client = new TransportClientAdapter(transportClient, _clientROS);
final RestRequestHandler restHandler = new CheckQueryTunnelHandler();
final StreamRequestHandler streamHandler = new StreamRequestHandlerAdapter(restHandler);
TransportDispatcher dispatcher = new TransportDispatcher()
{
@Override
public void handleRestRequest(RestRequest req, Map<String, String> wireAttrs, RequestContext requestContext,
TransportCallback<RestResponse> callback)
{
restHandler.handleRequest(req, requestContext, new TransportCallbackAdapter<RestResponse>(callback));
}
@Override
public void handleStreamRequest(StreamRequest req, Map<String, String> wireAttrs,
RequestContext requestContext, TransportCallback<StreamResponse> callback)
{
streamHandler.handleRequest(req, requestContext, new TransportCallbackAdapter<StreamResponse>(callback));
}
};
_server = new HttpServerFactory(_servletType).createH2cServer(_port, dispatcher, _serverROS);
_server.start();
}
@Test
public void testShouldNotQueryTunnel() throws Exception
{
String shortQuery = buildQuery(QUERY_TUNNEL_THRESHOLD - 1);
RestResponse response = getResponse(shortQuery, new RequestContext());
Assert.assertEquals(response.getStatus(), IS_NOT_TUNNELED_RESPONSE_CODE);
Assert.assertEquals(response.getEntity().copyBytes(), shortQuery.getBytes());
}
@Test
public void testShouldQueryTunnel() throws Exception
{
String longQuery = buildQuery(QUERY_TUNNEL_THRESHOLD);
RestResponse response = getResponse(longQuery, new RequestContext());
Assert.assertEquals(response.getStatus(), IS_TUNNELED_RESPONSE_CODE);
Assert.assertEquals(response.getEntity().copyBytes(), longQuery.getBytes());
}
@Test
public void testForceQueryTunnel() throws Exception
{
String shortQuery = buildQuery(QUERY_TUNNEL_THRESHOLD - 1);
RequestContext requestContext = new RequestContext();
requestContext.putLocalAttr(R2Constants.FORCE_QUERY_TUNNEL, true);
RestResponse response = getResponse(shortQuery, requestContext);
Assert.assertEquals(response.getStatus(), IS_TUNNELED_RESPONSE_CODE);
Assert.assertEquals(response.getEntity().copyBytes(), shortQuery.getBytes());
}
private String buildQuery(int len)
{
StringBuilder builder = new StringBuilder("id=");
for (int i = 0; i < len - 3; i++)
{
builder.append("a");
}
return builder.toString();
}
private RestResponse getResponse(String query, RequestContext requestContext) throws Exception
{
URI uri = URI.create("http://localhost:" + _port + "/checkQuery?" + query);
RestRequestBuilder builder = new RestRequestBuilder(uri);
return _client.restRequest(builder.build(), requestContext).get(5000, TimeUnit.MILLISECONDS);
}
@AfterClass
protected void tearDown() throws Exception
{
final FutureCallback<None> callback = new FutureCallback<None>();
_client.shutdown(callback);
callback.get();
final FutureCallback<None> factoryCallback = new FutureCallback<None>();
_clientFactory.shutdown(factoryCallback);
factoryCallback.get();
_server.stop();
_server.waitForStop();
}
private class CheckQueryTunnelHandler implements RestRequestHandler
{
@Override
public void handleRequest(RestRequest request, RequestContext requestContext, Callback<RestResponse> callback)
{
RestResponseBuilder builder = new RestResponseBuilder().setEntity(request.getURI().getRawQuery().getBytes());
Object isQueryTunnel = requestContext.getLocalAttr(R2Constants.IS_QUERY_TUNNELED);
if (isQueryTunnel != null && (Boolean) isQueryTunnel)
{
builder.setStatus(IS_TUNNELED_RESPONSE_CODE).build();
}
else
{
builder.setStatus(IS_NOT_TUNNELED_RESPONSE_CODE).build();
}
callback.onSuccess(builder.build());
}
}
}