/*
* Copyright 2016 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.linecorp.armeria.it.thrift;
import static com.linecorp.armeria.common.http.HttpHeaderNames.AUTHORIZATION;
import static com.linecorp.armeria.common.thrift.ThriftSerializationFormats.BINARY;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import org.apache.thrift.TException;
import org.apache.thrift.async.AsyncMethodCallback;
import org.junit.ClassRule;
import org.junit.Test;
import com.linecorp.armeria.client.Clients;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.http.HttpHeaders;
import com.linecorp.armeria.common.http.HttpRequest;
import com.linecorp.armeria.common.util.SafeCloseable;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.thrift.THttpService;
import com.linecorp.armeria.service.test.thrift.main.HelloService;
import com.linecorp.armeria.service.test.thrift.main.HelloService.Iface;
import com.linecorp.armeria.testing.server.ServerRule;
/**
 * Tests if the HTTP headers set via the thread-local manipulators of {@link Clients}
 * ({@code withHttpHeader()} and {@code withHttpHeaders()}) are sent with Thrift client calls
 * and are no longer applied once the manipulator has been closed.
 */
public class ThriftThreadLocalHttpHeaderTest {

    /** The only {@code authorization} header value accepted by {@link #helloService}. */
    private static final String SECRET = "QWxhZGRpbjpPcGVuU2VzYW1l";

    /** Prefix of the error message produced when an unexpected credential was received. */
    private static final String UNAUTHORIZED_PREFIX = "not authorized: ";

    /** Error message produced when the request carried no {@code authorization} header at all. */
    private static final String MISSING_CREDENTIAL_MESSAGE = "not authorized due to missing credential";

    /**
     * A service that greets the caller only if the current request carries {@link #SECRET} in its
     * {@code authorization} header. Otherwise it fails with a message that tells whether the header
     * was missing or which value was received, so that the tests can tell the two cases apart.
     */
    private static final HelloService.AsyncIface helloService = (name, resultHandler) -> {
        final HttpRequest httpReq = RequestContext.current().request();
        final HttpHeaders headers = httpReq.headers();

        if (headers.contains(AUTHORIZATION, SECRET)) {
            resultHandler.onComplete("Hello, " + name + '!');
            return;
        }

        final String errorMessage;
        if (headers.contains(AUTHORIZATION)) {
            errorMessage = UNAUTHORIZED_PREFIX + headers.get(AUTHORIZATION);
        } else {
            errorMessage = MISSING_CREDENTIAL_MESSAGE;
        }
        resultHandler.onError(new Exception(errorMessage));
    };

    /** Serves {@link #helloService} at {@code /hello} for all tests in this class. */
    @ClassRule
    public static final ServerRule server = new ServerRule() {
        @Override
        protected void configure(ServerBuilder sb) throws Exception {
            sb.serviceAt("/hello", THttpService.of(helloService));
        }
    };

    /**
     * A single header set with {@code withHttpHeader()} must authorize the call made inside the
     * try-with-resources scope and must not affect a call made after the scope has been closed.
     */
    @Test
    public void testSimpleManipulation() throws Exception {
        final HelloService.Iface client = newClient();
        try (SafeCloseable ignored = Clients.withHttpHeader(AUTHORIZATION, SECRET)) {
            assertThat(client.hello("trustin")).isEqualTo("Hello, trustin!");
        }

        // Ensure that the header manipulator set in the thread-local variable has been cleared.
        assertAuthorizationFailure(client, null);
    }

    /**
     * Nested manipulators must compose: the outer one sets the first half of the secret and the
     * inner one appends the second half to whatever the outer one has set.
     */
    @Test
    public void testNestedManipulation() throws Exception {
        // Split the secret into two pieces.
        final String secretA = SECRET.substring(0, SECRET.length() / 2);
        final String secretB = SECRET.substring(secretA.length());

        final HelloService.Iface client = newClient();
        try (SafeCloseable ignored = Clients.withHttpHeader(AUTHORIZATION, secretA)) {
            // Should fail with the first half of the secret.
            assertAuthorizationFailure(client, secretA);

            try (SafeCloseable ignored2 = Clients.withHttpHeaders(
                    h -> h.set(AUTHORIZATION, h.get(AUTHORIZATION) + secretB))) {
                // Should pass if both manipulators worked.
                assertThat(client.hello("foobar")).isEqualTo("Hello, foobar!");
            }

            // Should fail again with the first half of the secret.
            assertAuthorizationFailure(client, secretA);
        }

        // Ensure that the header manipulator set in the thread-local variable has been cleared.
        assertAuthorizationFailure(client, null);
    }

    /**
     * Same as {@link #testSimpleManipulation()} but with the asynchronous client. The call is
     * issued inside the manipulator scope while its result is awaited outside of it.
     */
    @Test
    public void testSimpleManipulationAsync() throws Exception {
        final HelloService.AsyncIface client = newAsyncClient();
        final BlockingQueue<Object> result = new ArrayBlockingQueue<>(1);
        final Callback callback = new Callback(result);

        try (SafeCloseable ignored = Clients.withHttpHeader(AUTHORIZATION, SECRET)) {
            client.hello("armeria", callback);
        }
        assertThat(result.poll(10, TimeUnit.SECONDS)).isEqualTo("Hello, armeria!");

        // Ensure that the header manipulator set in the thread-local variable has been cleared:
        // the failure must be the 'missing credential' one, i.e. no header was sent at all.
        client.hello("bar", callback);
        assertThat(result.poll(10, TimeUnit.SECONDS))
                .isInstanceOf(TException.class)
                .matches(o -> {
                    // Guard against a null message so that the assertion fails instead of throwing NPE.
                    final String message = ((Throwable) o).getMessage();
                    return message != null && message.contains(MISSING_CREDENTIAL_MESSAGE);
                }, "must fail due to missing credential");
    }

    /** A call made without any header manipulator must be rejected for the missing credential. */
    @Test
    public void testFailedAuthorization() throws Exception {
        assertAuthorizationFailure(newClient(), null);
    }

    /** Creates a synchronous client of the {@code /hello} service using the BINARY format. */
    private static Iface newClient() {
        return Clients.newClient(server.uri(BINARY, "/hello"), HelloService.Iface.class);
    }

    /** Creates an asynchronous client of the {@code /hello} service using the BINARY format. */
    private static HelloService.AsyncIface newAsyncClient() {
        return Clients.newClient(server.uri(BINARY, "/hello"), HelloService.AsyncIface.class);
    }

    /**
     * Asserts that {@code client.hello()} fails with a {@link TException} whose message tells
     * that the call was not authorized.
     *
     * @param client the client to invoke
     * @param expectedSecret the credential the server is expected to report as received,
     *                       or {@code null} if the server is expected to report a missing credential
     */
    private static void assertAuthorizationFailure(Iface client, String expectedSecret) {
        final String expectedMessage = expectedSecret != null ? UNAUTHORIZED_PREFIX + expectedSecret
                                                              : MISSING_CREDENTIAL_MESSAGE;

        assertThatThrownBy(() -> client.hello("foo"))
                .isInstanceOf(TException.class)
                .hasMessageContaining(expectedMessage);
    }

    /**
     * An {@link AsyncMethodCallback} that hands over the outcome of a call, either the response
     * or the exception, to the given queue so that the test thread can wait for it.
     */
    private static final class Callback implements AsyncMethodCallback<String> {

        private final BlockingQueue<Object> result;

        Callback(BlockingQueue<Object> result) {
            this.result = result;
        }

        @Override
        public void onComplete(String response) {
            result.add(response);
        }

        @Override
        public void onError(Exception exception) {
            result.add(exception);
        }
    }
}