/*
* Copyright 2015 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.linecorp.armeria.server.thrift;
import static org.junit.Assert.assertTrue;
import java.net.URI;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLException;
import org.apache.thrift.transport.TMemoryBuffer;
import org.apache.thrift.transport.TMemoryInputTransport;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import com.linecorp.armeria.internal.http.Http1ClientCodec;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientUpgradeHandler;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http2.DefaultHttp2Connection;
import io.netty.handler.codec.http2.DelegatingDecompressorFrameListener;
import io.netty.handler.codec.http2.Http2ClientUpgradeCodec;
import io.netty.handler.codec.http2.Http2Connection;
import io.netty.handler.codec.http2.Http2SecurityUtil;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.codec.http2.HttpConversionUtil;
import io.netty.handler.codec.http2.HttpConversionUtil.ExtensionHeaderNames;
import io.netty.handler.codec.http2.HttpToHttp2ConnectionHandler;
import io.netty.handler.codec.http2.HttpToHttp2ConnectionHandlerBuilder;
import io.netty.handler.codec.http2.InboundHttp2ToHttpAdapterBuilder;
import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.ApplicationProtocolConfig.Protocol;
import io.netty.handler.ssl.ApplicationProtocolConfig.SelectedListenerFailureBehavior;
import io.netty.handler.ssl.ApplicationProtocolConfig.SelectorFailureBehavior;
import io.netty.handler.ssl.ApplicationProtocolNames;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SupportedCipherSuiteFilter;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.util.concurrent.Promise;
/**
 * An extremely simple Thrift-over-HTTP/2 client which sends and receives a single Thrift request/response
 * per connection.
 *
 * <p>Every {@link #flush()} opens a new connection and waits for the server's HTTP/2 SETTINGS. It then
 * sends the buffered Thrift request as a {@code POST} to the configured path. The response body is made
 * available to the {@code read*()} and buffer-access methods.
 *
 * <p>An {@code http} URI performs a cleartext HTTP/1.1-to-HTTP/2 upgrade. An {@code https} URI negotiates
 * HTTP/2 via ALPN and trusts any server certificate, so this class is meant for tests only.
 */
final class THttp2Client extends TTransport {

    private final EventLoopGroup group = new NioEventLoopGroup(1);
    private final SslContext sslCtx;
    private final URI uri;
    private final String host;
    private final int port;
    private final String path;
    private final HttpHeaders defaultHeaders;

    /** Body of the most recently received response; {@code null} until the first successful flush. */
    private TMemoryInputTransport in;

    /** Accumulates the outgoing request; replaced by a fresh buffer at the start of every flush. */
    private TMemoryBuffer out = new TMemoryBuffer(128);

    /**
     * Creates a client for the given endpoint.
     *
     * @param uriStr an {@code http} or {@code https} URI; a missing port defaults to 80/443 and an
     *               empty path defaults to {@code /}
     * @param defaultHeaders headers added to every Thrift request; may be {@code null} for none
     * @throws TTransportException if the TLS context cannot be built
     * @throws IllegalArgumentException if the scheme is missing or unsupported, or the host or path
     *                                  is missing
     */
    THttp2Client(String uriStr, HttpHeaders defaultHeaders) throws TTransportException {
        uri = URI.create(uriStr);
        this.defaultHeaders = defaultHeaders;

        final String scheme = uri.getScheme();
        if (scheme == null) {
            // Switching on a null String would throw an uninformative NullPointerException.
            throw new IllegalArgumentException("scheme not specified: " + uriStr);
        }

        int port;
        switch (scheme) {
            case "http":
                port = uri.getPort();
                if (port < 0) {
                    port = 80;
                }
                sslCtx = null;
                break;
            case "https":
                port = uri.getPort();
                if (port < 0) {
                    port = 443;
                }
                try {
                    sslCtx = SslContextBuilder.forClient()
                                              .ciphers(Http2SecurityUtil.CIPHERS,
                                                       SupportedCipherSuiteFilter.INSTANCE)
                                              .trustManager(InsecureTrustManagerFactory.INSTANCE)
                                              .applicationProtocolConfig(new ApplicationProtocolConfig(
                                                      Protocol.ALPN,
                                                      // NO_ADVERTISE is currently the only mode supported
                                                      // by both OpenSsl and JDK providers.
                                                      SelectorFailureBehavior.NO_ADVERTISE,
                                                      // ACCEPT is currently the only mode supported by
                                                      // both OpenSsl and JDK providers.
                                                      SelectedListenerFailureBehavior.ACCEPT,
                                                      ApplicationProtocolNames.HTTP_2))
                                              .build();
                } catch (SSLException e) {
                    throw new TTransportException(TTransportException.UNKNOWN, e);
                }
                break;
            default:
                throw new IllegalArgumentException("unknown scheme: " + uri.getScheme());
        }

        final String host = uri.getHost();
        if (host == null) {
            throw new IllegalArgumentException("host not specified: " + uriStr);
        }

        // Use the raw (still percent-encoded) path: it is sent verbatim as the request target.
        String path = uri.getRawPath();
        if (path == null) {
            throw new IllegalArgumentException("path not specified: " + uriStr);
        }
        if (path.isEmpty()) {
            // e.g. "http://host:8080" - an empty :path is not a valid HTTP/2 request target.
            path = "/";
        }

        this.host = host;
        this.port = port;
        this.path = path;
    }

    /** Always {@code true}: a connection is opened per request in {@link #flush()}. */
    @Override
    public boolean isOpen() {
        return true;
    }

    /** No-op: connections are opened lazily in {@link #flush()}. */
    @Override
    public void open() { }

    /** Shuts down the event loop group; this does not wait for the shutdown to complete. */
    @Override
    public void close() {
        group.shutdownGracefully();
    }

    @Override
    public int read(byte[] buf, int off, int len) throws TTransportException {
        return in.read(buf, off, len);
    }

    @Override
    public int readAll(byte[] buf, int off, int len) throws TTransportException {
        return in.readAll(buf, off, len);
    }

    @Override
    public byte[] getBuffer() {
        return in.getBuffer();
    }

    @Override
    public int getBufferPosition() {
        return in.getBufferPosition();
    }

    @Override
    public int getBytesRemainingInBuffer() {
        return in.getBytesRemainingInBuffer();
    }

    @Override
    public void consumeBuffer(int len) {
        in.consumeBuffer(len);
    }

    /** Buffers request bytes until the next {@link #flush()}. */
    @Override
    public void write(byte[] buf, int off, int len) {
        out.write(buf, off, len);
    }

    /**
     * Sends the buffered request over a new HTTP/2 connection and stores the response for reading.
     *
     * <p>The request buffer is reset before sending, so each call transmits only what was written
     * since the previous call.
     *
     * @throws TTransportException if connecting, upgrading, sending or receiving fails
     */
    @Override
    public void flush() throws TTransportException {
        // Snapshot the pending request and start a fresh buffer. Otherwise a reused client would
        // resend every earlier request concatenated with the new one.
        final byte[] requestData = out.getArray();
        final int requestLength = out.length();
        out = new TMemoryBuffer(128);

        final THttp2ClientInitializer initHandler = new THttp2ClientInitializer();
        final Bootstrap b = new Bootstrap();
        b.group(group);
        b.channel(NioSocketChannel.class);
        b.handler(initHandler);

        Channel ch = null;
        try {
            ch = b.connect(host, port).syncUninterruptibly().channel();
            final THttp2ClientHandler handler = initHandler.clientHandler;

            // Wait until HTTP/2 upgrade is finished (the server's SETTINGS have been received).
            assertTrue(handler.settingsPromise.await(5, TimeUnit.SECONDS));
            handler.settingsPromise.get();

            // Send a Thrift request.
            final FullHttpRequest request = new DefaultFullHttpRequest(
                    HttpVersion.HTTP_1_1, HttpMethod.POST, path,
                    Unpooled.wrappedBuffer(requestData, 0, requestLength));
            request.headers().add(HttpHeaderNames.HOST, host);
            request.headers().set(ExtensionHeaderNames.SCHEME.text(), uri.getScheme());
            if (defaultHeaders != null) {
                request.headers().add(defaultHeaders);
            }
            ch.writeAndFlush(request).sync();

            // Wait until the Thrift response is received.
            assertTrue(handler.responsePromise.await(5, TimeUnit.SECONDS));
            final ByteBuf response = handler.responsePromise.get();

            // Pass the received Thrift response to the Thrift client. Always release the retained
            // buffer, even if copying fails.
            try {
                final byte[] array = new byte[response.readableBytes()];
                response.readBytes(array);
                in = new TMemoryInputTransport(array);
            } finally {
                response.release();
            }
        } catch (InterruptedException e) {
            // Preserve the caller's interrupt status before translating the exception.
            Thread.currentThread().interrupt();
            throw new TTransportException(TTransportException.UNKNOWN, e);
        } catch (Exception e) {
            throw new TTransportException(TTransportException.UNKNOWN, e);
        } finally {
            if (ch != null) {
                ch.close();
            }
        }
    }

    /**
     * Builds the client pipeline. With TLS it is: SSL handler, then the HTTP/2 connection handler,
     * then the response handler. In cleartext it is: HTTP/1 codec, then the upgrade handler, then
     * the handler that sends the upgrade request.
     */
    private final class THttp2ClientInitializer extends ChannelInitializer<SocketChannel> {

        /** Created in {@link #initChannel(SocketChannel)}; read by {@code flush()} after connecting. */
        THttp2ClientHandler clientHandler;

        @Override
        public void initChannel(SocketChannel ch) throws Exception {
            final ChannelPipeline p = ch.pipeline();
            final Http2Connection conn = new DefaultHttp2Connection(false);
            final HttpToHttp2ConnectionHandler connHandler = new HttpToHttp2ConnectionHandlerBuilder()
                    .connection(conn)
                    .frameListener(new DelegatingDecompressorFrameListener(
                            conn,
                            new InboundHttp2ToHttpAdapterBuilder(conn)
                                    .maxContentLength(Integer.MAX_VALUE)
                                    // Forward SETTINGS so the handler can tell when HTTP/2 is ready.
                                    .propagateSettings(true).build()))
                    .build();

            clientHandler = new THttp2ClientHandler(ch.eventLoop());

            if (sslCtx != null) {
                p.addLast(sslCtx.newHandler(p.channel().alloc()));
                p.addLast(connHandler);
                configureEndOfPipeline(p);
            } else {
                final Http1ClientCodec sourceCodec = new Http1ClientCodec();
                final HttpClientUpgradeHandler upgradeHandler = new HttpClientUpgradeHandler(
                        sourceCodec, new Http2ClientUpgradeCodec(connHandler), 65536);
                p.addLast(sourceCodec, upgradeHandler, new UpgradeRequestHandler());
            }
        }

        private void configureEndOfPipeline(ChannelPipeline p) {
            p.addLast(clientHandler);
        }

        /**
         * A handler that triggers the cleartext upgrade to HTTP/2 by sending an initial HTTP request.
         */
        private final class UpgradeRequestHandler extends ChannelInboundHandlerAdapter {
            @Override
            public void channelActive(ChannelHandlerContext ctx) throws Exception {
                final DefaultFullHttpRequest upgradeRequest =
                        new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.HEAD, "/");
                ctx.writeAndFlush(upgradeRequest);

                ctx.fireChannelActive();

                // Done with this handler, remove it from the pipeline.
                ctx.pipeline().remove(this);

                configureEndOfPipeline(ctx.pipeline());
            }
        }
    }

    /**
     * Completes {@link #settingsPromise} when HTTP/2 SETTINGS arrive and {@link #responsePromise}
     * with the body of the response on stream 3.
     */
    static final class THttp2ClientHandler extends SimpleChannelInboundHandler<Object> {

        final Promise<Void> settingsPromise;
        final Promise<ByteBuf> responsePromise;

        THttp2ClientHandler(EventLoop loop) {
            settingsPromise = loop.newPromise();
            responsePromise = loop.newPromise();
        }

        @Override
        protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
            if (msg instanceof Http2Settings) {
                // Use trySuccess(): a second SETTINGS frame must not fail the connection.
                settingsPromise.trySuccess(null);
                return;
            }

            if (msg instanceof FullHttpResponse) {
                final FullHttpResponse res = (FullHttpResponse) msg;
                final Integer streamId = res.headers().getInt(
                        HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text());
                if (streamId == null) {
                    responsePromise.tryFailure(new AssertionError("message without stream ID: " + msg));
                    return;
                }

                if (streamId == 1) {
                    // Response to the upgrade request, which is OK to ignore.
                    return;
                }

                if (streamId != 3) {
                    responsePromise.tryFailure(new AssertionError("unexpected stream ID: " + msg));
                    return;
                }

                // The message itself is auto-released by SimpleChannelInboundHandler. Retain the
                // content for the caller, and release it again if the promise was already completed.
                final ByteBuf content = res.content().retain();
                if (!responsePromise.trySuccess(content)) {
                    content.release();
                }
                return;
            }

            throw new IllegalStateException("unexpected message type: " + msg.getClass().getName());
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            // Fail both promises so flush() reports the real cause instead of timing out.
            settingsPromise.tryFailure(cause);
            responsePromise.tryFailure(cause);
        }

        @Override
        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
            // A closed connection can never complete the pending promises; fail them immediately.
            // tryFailure() is a no-op for promises that have already completed.
            final IllegalStateException cause =
                    new IllegalStateException("connection closed: " + ctx.channel());
            settingsPromise.tryFailure(cause);
            responsePromise.tryFailure(cause);
            super.channelInactive(ctx);
        }
    }
}