/*
* Copyright 2015 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.linecorp.armeria.server.thrift;
import static com.linecorp.armeria.common.http.HttpSessionProtocols.HTTP;
import static com.linecorp.armeria.common.http.HttpSessionProtocols.HTTPS;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.TException;
import org.apache.thrift.async.AsyncMethodCallback;
import org.apache.thrift.protocol.TMessageType;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import com.google.common.base.Strings;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.RpcRequest;
import com.linecorp.armeria.common.RpcResponse;
import com.linecorp.armeria.common.http.HttpHeaderNames;
import com.linecorp.armeria.common.http.HttpHeaders;
import com.linecorp.armeria.common.http.HttpRequest;
import com.linecorp.armeria.common.http.HttpResponse;
import com.linecorp.armeria.common.logging.RequestLog;
import com.linecorp.armeria.common.logging.RequestLogAvailability;
import com.linecorp.armeria.common.thrift.ThriftCall;
import com.linecorp.armeria.common.thrift.ThriftProtocolFactories;
import com.linecorp.armeria.common.thrift.ThriftReply;
import com.linecorp.armeria.common.util.Exceptions;
import com.linecorp.armeria.server.Server;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.Service;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.SimpleDecoratingService;
import com.linecorp.armeria.server.logging.LoggingService;
import com.linecorp.armeria.service.test.thrift.main.HelloService;
import com.linecorp.armeria.service.test.thrift.main.HelloService.AsyncIface;
import com.linecorp.armeria.service.test.thrift.main.SleepService;
import io.netty.handler.ssl.util.SelfSignedCertificate;
public abstract class AbstractThriftOverHttpTest {
// A 16384-character string of 'A's; returned by the "/large" service as a response
// meant to be larger than an HTTP/1 TLS record.
private static final String LARGER_THAN_TLS = Strings.repeat("A", 16384);
// The server shared by every test in this class; built in the static initializer.
private static final Server server;
// Local ports of the active HTTP and HTTPS server ports; assigned in init().
private static int httpPort;
private static int httpsPort;
// When true, the log-collecting decorator registers a listener that adds each
// completed RequestLog to requestLogs. Volatile because it is written by test code
// and read inside the decorator's serve().
private static volatile boolean recordMessageLogs;
// Receives the RequestLogs captured while recordMessageLogs is true.
private static final BlockingQueue<RequestLog> requestLogs = new LinkedBlockingQueue<>();
/**
 * Base {@link AsyncIface} implementation whose reply text comes from the overridable
 * {@link #getResponse(String)} hook.
 */
abstract static class HelloServiceBase implements AsyncIface {

    /**
     * Builds the reply for {@code name}. This implementation returns
     * {@code "Hello, <name>!"}; subclasses may override it.
     */
    protected String getResponse(String name) {
        return "Hello, " + name + "!";
    }

    /**
     * Completes {@code resultHandler} with the text produced by {@link #getResponse(String)}.
     */
    @Override
    public void hello(String name, AsyncMethodCallback resultHandler) throws TException {
        final String reply = getResponse(name);
        resultHandler.onComplete(reply);
    }
}
/**
 * A {@link HelloServiceBase} that replies with {@code "Goodbye, <name>!"} instead of the
 * base greeting.
 */
static class HelloServiceChild extends HelloServiceBase {

    /** Returns {@code "Goodbye, <name>!"}. */
    @Override
    protected String getResponse(String name) {
        final String farewell = "Goodbye, " + name + "!";
        return farewell;
    }
}
// Configures and builds the shared test server. The server is only built here;
// it is started in init().
static {
    final ServerBuilder sb = new ServerBuilder();
    try {
        // Port 0 is requested for both protocols; the ports actually bound are read
        // back from the running server in init().
        sb.port(0, HTTP);
        sb.port(0, HTTPS);

        // HTTPS uses a self-signed certificate generated for 127.0.0.1.
        final SelfSignedCertificate ssc = new SelfSignedCertificate("127.0.0.1");
        sb.sslContext(HTTPS, ssc.certificate(), ssc.privateKey());

        // Replies "Hello, <name>!".
        sb.serviceAt("/hello", THttpService.of(
                (AsyncIface) (name, resultHandler) -> resultHandler.onComplete("Hello, " + name + '!')));

        // Replies "Goodbye, <name>!" through an overridden method of a subclass.
        // Each path is registered exactly once.
        sb.serviceAt("/hellochild", THttpService.of(new HelloServiceChild()));

        // Always fails with an Exception whose message is the given name.
        sb.serviceAt("/exception", THttpService.of(
                (AsyncIface) (name, resultHandler) ->
                        resultHandler.onError(Exceptions.clearTrace(new Exception(name)))));

        // Completes with the requested delay after that many milliseconds have elapsed.
        sb.serviceAt("/sleep", THttpService.of(
                (SleepService.AsyncIface) (milliseconds, resultHandler) ->
                        RequestContext.current().eventLoop().schedule(
                                () -> resultHandler.onComplete(milliseconds),
                                milliseconds, TimeUnit.MILLISECONDS)));

        // Response larger than a h1 TLS record
        sb.serviceAt("/large", THttpService.of(
                (AsyncIface) (name, resultHandler) -> resultHandler.onComplete(LARGER_THAN_TLS)));

        sb.decorator(LoggingService::new);

        // While recordMessageLogs is true, queue the RequestLog of each request once
        // it becomes COMPLETE so that tests can inspect it.
        final Function<Service<HttpRequest, HttpResponse>,
                       Service<HttpRequest, HttpResponse>> logCollectingDecorator =
                s -> new SimpleDecoratingService<HttpRequest, HttpResponse>(s) {
                    @Override
                    public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exception {
                        if (recordMessageLogs) {
                            ctx.log().addListener(requestLogs::add,
                                                  RequestLogAvailability.COMPLETE);
                        }
                        return delegate().serve(ctx, req);
                    }
                };
        sb.decorator(logCollectingDecorator);
    } catch (Exception e) {
        // A static initializer cannot throw a checked exception; rethrow with the cause kept.
        throw new Error(e);
    }
    server = sb.build();
}
/**
 * Starts the shared server, waits for startup to finish, and records the local ports of
 * its active HTTP and HTTPS ports.
 *
 * @throws IllegalStateException if the started server has no active port for HTTP or HTTPS
 */
@BeforeClass
public static void init() throws Exception {
    server.start().get();

    // Fail with a descriptive message instead of a bare NoSuchElementException
    // when the expected port is missing.
    httpPort = server.activePorts().values().stream()
                     .filter(p -> p.protocol() == HTTP)
                     .findAny()
                     .orElseThrow(() -> new IllegalStateException("no active HTTP port"))
                     .localAddress().getPort();
    httpsPort = server.activePorts().values().stream()
                      .filter(p -> p.protocol() == HTTPS)
                      .findAny()
                      .orElseThrow(() -> new IllegalStateException("no active HTTPS port"))
                      .localAddress().getPort();
}
/**
 * Stops the shared server once the tests of this class are done.
 */
@AfterClass
public static void destroy() throws Exception {
    // NOTE(review): the return value of stop() (if any) is ignored; if stop() is
    // asynchronous, shutdown completion is not awaited here -- confirm that is intended.
    server.stop();
}
/**
 * Resets the log-collection state before each test: turns collection off and drops any
 * logs left in the queue.
 */
@Before
public void beforeTest() {
    // Disable collection first so the decorator stops registering listeners,
    // then discard whatever was queued earlier.
    recordMessageLogs = false;
    requestLogs.clear();
}
/**
 * Calls {@code hello("Trustin")} on {@code /hello} over plaintext HTTP and expects
 * {@code "Hello, Trustin!"}.
 */
@Test
public void testHttpInvocation() throws Exception {
    try (TTransport http = newTransport("http", "/hello")) {
        final HelloService.Client helloClient = new HelloService.Client.Factory()
                .getClient(ThriftProtocolFactories.BINARY.getProtocol(http));

        final String greeting = helloClient.hello("Trustin");
        assertThat(greeting).isEqualTo("Hello, Trustin!");
    }
}
/**
 * Calls {@code hello("Trustin")} on {@code /hellochild}, which is backed by
 * {@code HelloServiceChild}, and expects the overridden reply {@code "Goodbye, Trustin!"}.
 */
@Test
public void testInheritedThriftService() throws Exception {
    try (TTransport http = newTransport("http", "/hellochild")) {
        final HelloService.Client childClient = new HelloService.Client.Factory()
                .getClient(ThriftProtocolFactories.BINARY.getProtocol(http));

        final String farewell = childClient.hello("Trustin");
        assertThat(farewell).isEqualTo("Goodbye, Trustin!");
    }
}
/**
 * Calls {@code hello("Trustin")} on {@code /hello} through the HTTPS port and expects
 * {@code "Hello, Trustin!"}.
 */
@Test
public void testHttpsInvocation() throws Exception {
    try (TTransport https = newTransport("https", "/hello")) {
        final HelloService.Client secureClient = new HelloService.Client.Factory()
                .getClient(ThriftProtocolFactories.BINARY.getProtocol(https));

        final String greeting = secureClient.hello("Trustin");
        assertThat(greeting).isEqualTo("Hello, Trustin!");
    }
}
/**
 * Calls {@code /large} through the HTTPS port and expects the whole
 * {@code LARGER_THAN_TLS} payload to come back unchanged.
 */
@Test
public void testLargeHttpsInvocation() throws Exception {
    try (TTransport https = newTransport("https", "/large")) {
        final HelloService.Client largeClient = new HelloService.Client.Factory()
                .getClient(ThriftProtocolFactories.BINARY.getProtocol(https));

        final String payload = largeClient.hello("Trustin");
        assertThat(payload).isEqualTo(LARGER_THAN_TLS);
    }
}
/**
 * Sends an {@code Accept} header holding two comma-separated media types
 * ({@code text/plain, *}{@code /*}) and expects the call to succeed with a reply that a
 * binary-protocol client can read.
 */
@Test
public void testAcceptHeaderWithCommaSeparatedMediaTypes() throws Exception {
    final HttpHeaders accept = HttpHeaders.of(HttpHeaderNames.ACCEPT, "text/plain, */*");

    try (TTransport http = newTransport("http", "/hello", accept)) {
        final HelloService.Client helloClient = new HelloService.Client.Factory()
                .getClient(ThriftProtocolFactories.BINARY.getProtocol(http));

        final String greeting = helloClient.hello("Trustin");
        assertThat(greeting).isEqualTo("Hello, Trustin!");
    }
}
/**
 * Sends an {@code Accept} header listing TTEXT with q=0.2 and TBINARY with q=0.5, and
 * expects a reply that a binary-protocol client can read.
 */
@Test
public void testAcceptHeaderWithQValues() throws Exception {
    // The server is expected to pick TBINARY: its q-value (0.5) is higher than TTEXT's (0.2).
    final HttpHeaders accept = HttpHeaders.of(
            HttpHeaderNames.ACCEPT,
            "application/x-thrift; protocol=TTEXT; q=0.2, " +
            "application/x-thrift; protocol=TBINARY; q=0.5");

    try (TTransport http = newTransport("http", "/hello", accept)) {
        final HelloService.Client helloClient = new HelloService.Client.Factory()
                .getClient(ThriftProtocolFactories.BINARY.getProtocol(http));

        final String greeting = helloClient.hello("Trustin");
        assertThat(greeting).isEqualTo("Hello, Trustin!");
    }
}
/**
 * Sends an {@code Accept} header listing TTEXT with q=0.2 and TBINARY without an explicit
 * q-value, and expects a reply that a binary-protocol client can read.
 */
@Test
public void testAcceptHeaderWithDefaultQValues() throws Exception {
    // The server is expected to pick TBINARY: its q-value (default 1.0) is higher than
    // TTEXT's (0.2).
    final HttpHeaders accept = HttpHeaders.of(
            HttpHeaderNames.ACCEPT,
            "application/x-thrift; protocol=TTEXT; q=0.2, " +
            "application/x-thrift; protocol=TBINARY");

    try (TTransport http = newTransport("http", "/hello", accept)) {
        final HelloService.Client helloClient = new HelloService.Client.Factory()
                .getClient(ThriftProtocolFactories.BINARY.getProtocol(http));

        final String greeting = helloClient.hello("Trustin");
        assertThat(greeting).isEqualTo("Hello, Trustin!");
    }
}
/**
 * Sends an {@code Accept} header whose most preferred media type is {@code text/plain},
 * followed by TBINARY with q=0.2, and expects a reply that a binary-protocol client can read.
 */
@Test
public void testAcceptHeaderWithUnsupportedMediaTypes() throws Exception {
    // The server is expected to pick TBINARY because it does not support the media type
    // with the highest preference (text/plain).
    final HttpHeaders accept = HttpHeaders.of(
            HttpHeaderNames.ACCEPT,
            "application/x-thrift; protocol=TBINARY; q=0.2, text/plain");

    try (TTransport http = newTransport("http", "/hello", accept)) {
        final HelloService.Client helloClient = new HelloService.Client.Factory()
                .getClient(ThriftProtocolFactories.BINARY.getProtocol(http));

        final String greeting = helloClient.hello("Trustin");
        assertThat(greeting).isEqualTo("Hello, Trustin!");
    }
}
/**
 * Makes a successful {@code hello} call to {@code /hello} with log collection enabled and
 * checks the request and response recorded in the resulting {@link RequestLog}.
 */
@Test(timeout = 10000)
public void testMessageLogsForCall() throws Exception {
    try (TTransport transport = newTransport("http", "/hello")) {
        final HelloService.Client client =
                new HelloService.Client.Factory().getClient(
                        ThriftProtocolFactories.BINARY.getProtocol(transport));
        // Collection is switched on just before the call under test.
        recordMessageLogs = true;
        client.hello("Trustin");
    }

    // Blocks until the log arrives; the test timeout bounds the wait.
    final RequestLog log = requestLogs.take();
    assertHelloExchangeLogged(log);

    final RpcResponse response = (RpcResponse) log.responseContent();
    assertThat(response.get()).isEqualTo("Hello, Trustin!");

    final ThriftReply rawResponse = (ThriftReply) log.rawResponseContent();
    assertThat(rawResponse.header().type).isEqualTo(TMessageType.REPLY);
    assertThat(rawResponse.header().name).isEqualTo("hello");
    assertThat(rawResponse.result()).isInstanceOf(HelloService.hello_result.class);
    assertThat(((HelloService.hello_result) rawResponse.result()).getSuccess())
            .isEqualTo("Hello, Trustin!");
}

/**
 * Makes a {@code hello} call to {@code /exception} with log collection enabled, expects the
 * client to fail with {@link TApplicationException}, and checks the request and the failure
 * recorded in the resulting {@link RequestLog}.
 */
@Test(timeout = 10000)
public void testMessageLogsForException() throws Exception {
    try (TTransport transport = newTransport("http", "/exception")) {
        final HelloService.Client client =
                new HelloService.Client.Factory().getClient(
                        ThriftProtocolFactories.BINARY.getProtocol(transport));
        // Collection is switched on just before the call under test.
        recordMessageLogs = true;
        assertThatThrownBy(() -> client.hello("Trustin")).isInstanceOf(TApplicationException.class);
    }

    // Blocks until the log arrives; the test timeout bounds the wait.
    final RequestLog log = requestLogs.take();
    assertHelloExchangeLogged(log);

    final RpcResponse response = (RpcResponse) log.responseContent();
    assertThat(response.cause()).isNotNull();

    final ThriftReply rawResponse = (ThriftReply) log.rawResponseContent();
    assertThat(rawResponse.header().type).isEqualTo(TMessageType.EXCEPTION);
    assertThat(rawResponse.header().name).isEqualTo("hello");
    assertThat(rawResponse.exception()).isNotNull();
}

/**
 * Asserts the parts of a logged {@code hello("Trustin")} exchange that are the same whether
 * the call succeeded or failed: the full request side, and the types of the response
 * envelope and contents.
 */
private static void assertHelloExchangeLogged(RequestLog log) {
    assertThat(log.requestEnvelope()).isInstanceOf(HttpHeaders.class);
    assertThat(log.requestContent()).isInstanceOf(RpcRequest.class);
    assertThat(log.rawRequestContent()).isInstanceOf(ThriftCall.class);

    final RpcRequest request = (RpcRequest) log.requestContent();
    assertThat(request.serviceType()).isEqualTo(HelloService.AsyncIface.class);
    assertThat(request.method()).isEqualTo("hello");
    assertThat(request.params()).containsExactly("Trustin");

    final ThriftCall rawRequest = (ThriftCall) log.rawRequestContent();
    assertThat(rawRequest.header().type).isEqualTo(TMessageType.CALL);
    assertThat(rawRequest.header().name).isEqualTo("hello");
    assertThat(rawRequest.args()).isInstanceOf(HelloService.hello_args.class);
    assertThat(((HelloService.hello_args) rawRequest.args()).getName()).isEqualTo("Trustin");

    assertThat(log.responseEnvelope()).isInstanceOf(HttpHeaders.class);
    assertThat(log.responseContent()).isInstanceOf(RpcResponse.class);
    assertThat(log.rawResponseContent()).isInstanceOf(ThriftReply.class);
}
/**
 * Creates a transport for {@code path} on the test server using {@code scheme}
 * ({@code "http"} or {@code "https"}) and no additional HTTP headers.
 */
protected final TTransport newTransport(String scheme, String path) throws TTransportException {
    final String uri = newUri(scheme, path);
    return newTransport(uri, HttpHeaders.EMPTY_HEADERS);
}
/**
 * Creates a transport for {@code path} on the test server using {@code scheme}
 * ({@code "http"} or {@code "https"}), passing {@code headers} on to the
 * implementation-specific factory method.
 */
protected final TTransport newTransport(String scheme, String path,
                                        HttpHeaders headers) throws TTransportException {
    final String uri = newUri(scheme, path);
    return newTransport(uri, headers);
}
/**
 * Creates the transport that the tests use to reach {@code uri}. Implemented by subclasses.
 *
 * @param uri the full URI of the service, as produced by {@code newUri}
 * @param headers the HTTP headers supplied by the test
 *                (presumably sent with each request -- confirm against implementations)
 */
protected abstract TTransport newTransport(String uri, HttpHeaders headers) throws TTransportException;
/**
 * Builds the URI of {@code path} on the test server at {@code 127.0.0.1}, using the port
 * recorded for the given scheme.
 *
 * @param scheme {@code "http"} or {@code "https"}
 * @param path the path appended after the port, e.g. {@code "/hello"}
 * @return {@code <scheme>://127.0.0.1:<port><path>}
 * @throws Error if {@code scheme} is neither {@code "http"} nor {@code "https"}
 */
protected static String newUri(String scheme, String path) {
    final int port;
    switch (scheme) {
        case "http":
            port = httpPort;
            break;
        case "https":
            port = httpsPort;
            break;
        default:
            // Same exception type as before, but now it says which scheme was rejected.
            throw new Error("unsupported scheme: " + scheme + " (expected: http or https)");
    }
    return scheme + "://127.0.0.1:" + port + path;
}
}