/*
* Copyright 2016 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.linecorp.armeria.client.http;
import java.net.InetSocketAddress;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.linecorp.armeria.client.ClientRequestContext;
import com.linecorp.armeria.client.WriteTimeoutException;
import com.linecorp.armeria.client.http.HttpResponseDecoder.HttpResponseWrapper;
import com.linecorp.armeria.common.AbstractRequestContext;
import com.linecorp.armeria.common.ClosedSessionException;
import com.linecorp.armeria.common.http.HttpData;
import com.linecorp.armeria.common.http.HttpHeaders;
import com.linecorp.armeria.common.http.HttpObject;
import com.linecorp.armeria.common.http.HttpRequest;
import com.linecorp.armeria.common.logging.RequestLogBuilder;
import com.linecorp.armeria.common.stream.ClosedPublisherException;
import com.linecorp.armeria.common.util.Exceptions;
import com.linecorp.armeria.internal.http.HttpObjectEncoder;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.http2.Http2Error;
/**
 * A {@link Subscriber} that consumes the {@link HttpObject}s published by an {@link HttpRequest} and writes
 * them to a {@link Channel} through an {@link HttpObjectEncoder}.
 *
 * <p>This class is also registered as the {@link ChannelFutureListener} of every write it performs. After
 * each successful write it requests exactly one more object from the publisher. Once the final object has
 * been written, it schedules the response timeout instead.
 *
 * <p>The first headers, the subsequent writes and the write-timeout task are all dispatched to the
 * channel's {@link EventLoop}.
 */
final class HttpRequestSubscriber implements Subscriber<HttpObject>, ChannelFutureListener {

    private static final Logger logger = LoggerFactory.getLogger(HttpRequestSubscriber.class);

    /**
     * The progress of the request being written.
     */
    enum State {
        /** The first headers have not been sent yet; the write timeout is only honored in this state. */
        NEEDS_TO_WRITE_FIRST_HEADER,
        /** The first headers were sent; data or trailing headers may follow. */
        NEEDS_DATA_OR_TRAILING_HEADERS,
        /** The end of the stream was written or the request failed; further objects are not written. */
        DONE
    }

    /** The last context of the channel pipeline, used for every write and flush. */
    private final ChannelHandlerContext ctx;
    private final HttpObjectEncoder encoder;
    /** The request id; the wire stream id is derived from it by {@link #streamId()}. */
    private final int id;
    private final HttpRequest request;
    private final HttpResponseWrapper response;
    private final ClientRequestContext reqCtx;
    private final RequestLogBuilder logBuilder;
    /** How long to wait for the first headers to be written; a value {@code <= 0} disables the timeout. */
    private final long timeoutMillis;

    /** Set once by {@link #onSubscribe(Subscription)}. */
    private Subscription subscription;
    /** The pending write-timeout task, or {@code null} if none was scheduled or it was already cancelled. */
    private ScheduledFuture<?> timeoutFuture;
    private State state = State.NEEDS_TO_WRITE_FIRST_HEADER;

    HttpRequestSubscriber(Channel ch, HttpObjectEncoder encoder,
                          int id, HttpRequest request, HttpResponseWrapper response,
                          ClientRequestContext reqCtx, long timeoutMillis) {

        ctx = ch.pipeline().lastContext();

        this.encoder = encoder;
        this.id = id;
        this.request = request;
        this.response = response;
        this.reqCtx = reqCtx;
        logBuilder = reqCtx.logBuilder();
        this.timeoutMillis = timeoutMillis;
    }

    /**
     * Invoked on each write of an {@link HttpObject}.
     *
     * <p>On success, this method does one of two things:
     * <ul>
     *   <li>if the request is {@link State#DONE}, it schedules the response timeout;</li>
     *   <li>otherwise, it requests the next object.</li>
     * </ul>
     *
     * <p>On failure, it fails the request. Unless the cause is a {@link ClosedPublisherException}, it also
     * logs the cause (if unexpected) and closes the channel.
     */
    @Override
    public void operationComplete(ChannelFuture future) throws Exception {
        if (future.isSuccess()) {
            if (state == State.DONE) {
                // Successfully sent the request; schedule the response timeout.
                response.scheduleTimeout(ctx);
            } else {
                // Flow control: one object in flight at a time.
                subscription.request(1);
            }
            return;
        }

        final Throwable cause = future.cause();
        fail(cause);
        if (!(cause instanceof ClosedPublisherException)) {
            final Channel ch = future.channel();
            Exceptions.logIfUnexpected(logger, ch, HttpSession.get(ch).protocol(), cause);
            ch.close();
        }
    }

    /**
     * Stores the subscription and schedules the write timeout, if enabled. Then it dispatches the write of
     * the first headers to the event loop.
     */
    @Override
    public void onSubscribe(Subscription subscription) {
        assert this.subscription == null;
        this.subscription = subscription;

        final EventLoop eventLoop = ctx.channel().eventLoop();
        if (timeoutMillis > 0) {
            timeoutFuture = eventLoop.schedule(
                    () -> {
                        // Only time out if the first headers have not been sent yet.
                        if (state == State.NEEDS_TO_WRITE_FIRST_HEADER) {
                            if (reqCtx instanceof AbstractRequestContext) {
                                ((AbstractRequestContext) reqCtx).setTimedOut();
                            }
                            failAndRespond(WriteTimeoutException.get());
                        }
                    },
                    timeoutMillis, TimeUnit.MILLISECONDS);
        }

        // NB: This must be invoked at the end of this method because otherwise the callback methods in this
        //     class can be called before the member fields (subscription and timeoutFuture) are initialized.
        //     It is because the successful write of the first headers will trigger subscription.request(1).
        eventLoop.execute(this::writeFirstHeader);
    }

    /**
     * Starts the request log and writes the request's first headers.
     *
     * <p>If the session is no longer active, the request fails with a {@link ClosedSessionException}
     * instead. If the request has no content, the first headers also end the stream.
     */
    private void writeFirstHeader() {
        final Channel ch = ctx.channel();
        final HttpSession session = HttpSession.get(ch);
        if (!session.isActive()) {
            failAndRespond(ClosedSessionException.get());
            return;
        }

        final HttpHeaders firstHeaders = request.headers();
        final String authority = firstHeaders.authority();
        final String host = authority != null ? stripPort(authority)
                                              : ((InetSocketAddress) ch.remoteAddress()).getHostString();

        logBuilder.startRequest(
                ch, session.protocol(), host, firstHeaders.method().name(), firstHeaders.path());
        logBuilder.requestEnvelope(firstHeaders);

        // The state must be updated BEFORE writing. The write listener (operationComplete) may run while
        // write0() flushes, and it may trigger onNext()/onComplete() through subscription.request(1).
        // Assigning the state after the write would overwrite a DONE state set by those callbacks. For an
        // empty request, it would also prevent the response timeout from being scheduled.
        if (request.isEmpty()) {
            setDone();
            write0(firstHeaders, true, true);
        } else {
            state = State.NEEDS_DATA_OR_TRAILING_HEADERS;
            cancelTimeout();
            write0(firstHeaders, false, true);
        }
    }

    /**
     * Strips the port, if any, from the given {@code :authority} value.
     *
     * <p>A colon inside a bracketed IPv6 literal (e.g. {@code [::1]}) is not treated as a port separator.
     * A leading colon is left untouched.
     */
    private static String stripPort(String authority) {
        final int colonIdx = authority.lastIndexOf(':');
        if (colonIdx <= 0) {
            return authority;
        }
        // In "[::1]" the last colon precedes the ']'; in "[::1]:8080" it follows it.
        if (authority.indexOf(']', colonIdx) >= 0) {
            return authority;
        }
        return authority.substring(0, colonIdx);
    }

    /**
     * Writes a data or trailing-headers object published after the first headers.
     *
     * <p>Trailing headers always end the stream, and they must not carry a status. Objects received
     * after the request is {@link State#DONE} are ignored.
     *
     * @throws IllegalStateException if the object is neither {@link HttpData} nor {@link HttpHeaders}, or if
     *                               trailing headers contain a status; the request is failed first
     */
    @Override
    public void onNext(HttpObject o) {
        if (!(o instanceof HttpData) && !(o instanceof HttpHeaders)) {
            throw newIllegalStateException(
                    "published an HttpObject that's neither HttpHeaders nor HttpData: " + o);
        }

        boolean endOfStream = o.isEndOfStream();
        switch (state) {
            case NEEDS_DATA_OR_TRAILING_HEADERS: {
                if (o instanceof HttpHeaders) {
                    final HttpHeaders trailingHeaders = (HttpHeaders) o;
                    if (trailingHeaders.status() != null) {
                        throw newIllegalStateException("published a trailing HttpHeaders with status: " + o);
                    }
                    // Trailing headers always end the stream even if not explicitly set.
                    endOfStream = true;
                }
                break;
            }
            case DONE:
                // Already completed or failed; nothing more is written.
                return;
        }

        write(o, endOfStream, true);
    }

    /**
     * Fails the request with the publisher's error and closes or resets the response stream.
     */
    @Override
    public void onError(Throwable cause) {
        failAndRespond(cause);
    }

    /**
     * Ends the stream with an empty data frame, unless the stream has already ended.
     *
     * <p>Does nothing if the write-timeout task could not be cancelled. That happens when the task has
     * already run, which means the write timed out.
     */
    @Override
    public void onComplete() {
        if (!cancelTimeout()) {
            return;
        }

        if (state != State.DONE) {
            write(HttpData.EMPTY_DATA, true, true);
        }
    }

    /**
     * Validates the current state and channel, then dispatches {@link #write0} to the event loop.
     *
     * <p>If {@code endOfStream} is set, the request is marked {@link State#DONE} immediately, before the
     * actual write happens.
     *
     * @throws IllegalStateException if the request is already {@link State#DONE}
     */
    private void write(HttpObject o, boolean endOfStream, boolean flush) {
        if (state == State.DONE) {
            throw newIllegalStateException(
                    "a request publisher published an HttpObject after a trailing HttpHeaders: " + o);
        }

        final Channel ch = ctx.channel();
        if (!ch.isActive()) {
            fail(ClosedSessionException.get());
            return;
        }

        if (endOfStream) {
            setDone();
        }

        ch.eventLoop().execute(() -> write0(o, endOfStream, flush));
    }

    /**
     * Encodes the given object, updates the request log and registers this object as the write listener.
     * Optionally flushes the context.
     */
    private void write0(HttpObject o, boolean endOfStream, boolean flush) {
        final ChannelFuture future;
        if (o instanceof HttpData) {
            final HttpData data = (HttpData) o;
            future = encoder.writeData(ctx, id, streamId(), data, endOfStream);
            logBuilder.increaseRequestLength(data.length());
        } else if (o instanceof HttpHeaders) {
            future = encoder.writeHeaders(ctx, id, streamId(), (HttpHeaders) o, endOfStream);
        } else {
            // Should never reach here because we did validation in onNext().
            throw new Error("unexpected HttpObject: " + o);
        }

        if (endOfStream) {
            logBuilder.endRequest();
        }

        future.addListener(this);
        if (flush) {
            ctx.flush();
        }
    }

    /**
     * Maps the request id to an odd stream id. HTTP/2 client-initiated streams use odd identifiers.
     */
    private int streamId() {
        return (id << 1) + 1;
    }

    /**
     * Marks the request as {@link State#DONE} and ends the request log with the given cause.
     */
    private void fail(Throwable cause) {
        setDone();
        logBuilder.endRequest(cause);
    }

    /**
     * Cancels the write timeout, moves to {@link State#DONE} and cancels the subscription.
     */
    private void setDone() {
        cancelTimeout();
        state = State.DONE;
        subscription.cancel();
    }

    /**
     * Fails the request, closes the response if it is still open, and asks the encoder to reset the stream
     * if the channel is still active.
     *
     * <p>The reset uses {@link Http2Error#CANCEL} for a write timeout on an already-closed response.
     * Otherwise it uses {@link Http2Error#INTERNAL_ERROR}.
     */
    private void failAndRespond(Throwable cause) {
        fail(cause);

        final Channel ch = ctx.channel();
        final Http2Error error;
        if (response.isOpen()) {
            response.close(cause);
            error = Http2Error.INTERNAL_ERROR;
        } else if (cause instanceof WriteTimeoutException) {
            error = Http2Error.CANCEL;
        } else {
            Exceptions.logIfUnexpected(logger, ch,
                                       HttpSession.get(ch).protocol(),
                                       "a request publisher raised an exception", cause);
            error = Http2Error.INTERNAL_ERROR;
        }

        if (ch.isActive()) {
            encoder.writeReset(ctx, id, streamId(), error);
            ctx.flush();
        }
    }

    /**
     * Cancels the pending write-timeout task, if any, and forgets it.
     *
     * @return {@code true} if there was no pending task or it was cancelled successfully;
     *         {@code false} if it could not be cancelled (for example, because it has already run)
     */
    private boolean cancelTimeout() {
        final ScheduledFuture<?> timeoutFuture = this.timeoutFuture;
        if (timeoutFuture == null) {
            return true;
        }

        this.timeoutFuture = null;
        return timeoutFuture.cancel(false);
    }

    /**
     * Creates an {@link IllegalStateException} with the given message and fails the request with it. The
     * exception is returned so that the caller can throw it.
     */
    private IllegalStateException newIllegalStateException(String msg) {
        final IllegalStateException cause = new IllegalStateException(msg);
        fail(cause);
        return cause;
    }
}