/* * Copyright 2016 Netflix, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package io.reactivex.netty.protocol.http.server; import io.netty.channel.Channel; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.cookie.Cookie; import io.netty.handler.codec.http.cookie.ServerCookieEncoder; import io.reactivex.netty.channel.AllocatingTransformer; import io.reactivex.netty.channel.ChannelOperations; import io.reactivex.netty.channel.Connection; import io.reactivex.netty.channel.MarkAwarePipeline; import io.reactivex.netty.protocol.http.HttpHandlerNames; import io.reactivex.netty.protocol.http.TrailingHeaders; import io.reactivex.netty.protocol.http.sse.ServerSentEvent; import io.reactivex.netty.protocol.http.sse.server.ServerSentEventEncoder; import io.reactivex.netty.protocol.http.ws.server.WebSocketHandler; import io.reactivex.netty.protocol.http.ws.server.WebSocketHandshaker; import rx.Observable; import rx.Subscriber; import rx.functions.Action0; import rx.functions.Func0; import rx.functions.Func1; import rx.functions.Func2; import java.util.Date; import java.util.List; import java.util.Set; public final class HttpServerResponseImpl<C> extends HttpServerResponse<C> { private final State<C> state; private HttpServerResponseImpl(final State<C> state) { super(new OnSubscribe<Void>() { @Override public void call(Subscriber<? super Void> subscriber) { state.sendHeaders().unsafeSubscribe(subscriber); } }); this.state = state; } @Override public HttpResponseStatus getStatus() { return state.headers.status(); } @Override public boolean containsHeader(CharSequence name) { return state.headers.headers().contains(name); } @Override public boolean containsHeader(CharSequence name, CharSequence value, boolean ignoreCaseValue) { return state.headers.headers().contains(name, value, ignoreCaseValue); } @Override public String getHeader(CharSequence name) { return state.headers.headers().get(name); } @Override public String getHeader(CharSequence name, String defaultValue) { return state.headers.headers().get(name, defaultValue); } @Override public List<String> getAllHeaderValues(CharSequence name) { return state.headers.headers().getAll(name); } @Override public long getDateHeader(CharSequence name) { return state.headers.headers().getTimeMillis(name); } @Override public long getDateHeader(CharSequence name, long defaultValue) { return state.headers.headers().getTimeMillis(name, defaultValue); } @Override public int getIntHeader(CharSequence name) { return state.headers.headers().getInt(name); } @Override public int getIntHeader(CharSequence name, int defaultValue) { return state.headers.headers().getInt(name, defaultValue); } @Override public Set<String> getHeaderNames() { return state.headers.headers().names(); } @Override public HttpServerResponse<C> addHeader(CharSequence name, Object value) { if (state.allowUpdate()) { state.headers.headers().add(name, value); } return this; } @Override public HttpServerResponse<C> addCookie(Cookie cookie) { if (state.allowUpdate()) { state.headers.headers().add(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.STRICT.encode(cookie)); } return this; } @Override public HttpServerResponse<C> addDateHeader(CharSequence name, Date value) { if (state.allowUpdate()) { state.headers.headers().add(name, value); } return this; } @Override public HttpServerResponse<C> addDateHeader(CharSequence name, Iterable<Date> values) { if (state.allowUpdate()) { for (Date value : values) { state.headers.headers().add(name, value); } } return this; } @Override public HttpServerResponse<C> addHeader(CharSequence name, Iterable<Object> values) { if (state.allowUpdate()) { state.headers.headers().add(name, values); } return this; } @Override public HttpServerResponse<C> setDateHeader(CharSequence name, Date value) { if (state.allowUpdate()) { state.headers.headers().set(name, value); } return this; } @Override public HttpServerResponse<C> setHeader(CharSequence name, Object value) { if (state.allowUpdate()) { state.headers.headers().set(name, value); } return this; } @Override public HttpServerResponse<C> setDateHeader(CharSequence name, Iterable<Date> values) { if (state.allowUpdate()) { for (Date value : values) { state.headers.headers().set(name, value); } } return this; } @Override public HttpServerResponse<C> setHeader(CharSequence name, Iterable<Object> values) { if (state.allowUpdate()) { state.headers.headers().set(name, values); } return this; } @Override public HttpServerResponse<C> removeHeader(CharSequence name) { if (state.allowUpdate()) { state.headers.headers().remove(name); } return this; } @Override public HttpServerResponse<C> setStatus(HttpResponseStatus status) { if (state.allowUpdate()) { state.headers.setStatus(status); } return this; } @Override public HttpServerResponse<C> setTransferEncodingChunked() { if (state.allowUpdate()) { HttpUtil.setTransferEncodingChunked(state.headers, true); } return this; } @Override public HttpServerResponse<C> flushOnlyOnReadComplete() { // Does not need to be guarded by allowUpdate() as flush semantics can be changed anytime. state.connection.unsafeNettyChannel().attr(ChannelOperations.FLUSH_ONLY_ON_READ_COMPLETE).set(true); return this; } @Override public ResponseContentWriter<C> sendHeaders() { return state.sendHeaders(); } @Override public HttpServerResponse<ServerSentEvent> transformToServerSentEvents() { markAwarePipeline().addAfter(HttpHandlerNames.HttpServerEncoder.getName(), HttpHandlerNames.SseServerCodec.getName(), new ServerSentEventEncoder()); return _cast(); } @Override public <CC> HttpServerResponse<CC> transformContent(AllocatingTransformer<CC, C> transformer) { @SuppressWarnings("unchecked") Connection transformedC = state.connection.transformWrite(transformer); return new HttpServerResponseImpl<>(new State<CC>(state, transformedC)); } @Override public WebSocketHandshaker acceptWebSocketUpgrade(WebSocketHandler handler) { return WebSocketHandshaker.isUpgradeRequested(state.request) ? WebSocketHandshaker.newHandshaker(state.request, this, handler) : WebSocketHandshaker.newErrorHandshaker(new IllegalStateException("WebSocket upgrade was not requested.")); } @Override public Observable<Void> dispose() { return Observable.defer(new Func0<Observable<Void>>() { @Override public Observable<Void> call() { return (state.allowUpdate() ? write(Observable.<C>empty()) : Observable.<Void>empty()) .doOnSubscribe(new Action0() { @Override public void call() { state.connection .getResettableChannelPipeline() .reset(); } }); } }); } @Override public Channel unsafeNettyChannel() { return state.connection.unsafeNettyChannel(); } @Override public Connection<?, ?> unsafeConnection() { return state.connection; } @Override public ResponseContentWriter<C> write(Observable<C> msgs) { return state.sendHeaders().write(msgs); } @Override public <T extends TrailingHeaders> Observable<Void> write(Observable<C> contentSource, Func0<T> trailerFactory, Func2<T, C, T> trailerMutator) { return state.sendHeaders().write(contentSource, trailerFactory, trailerMutator); } @Override public <T extends TrailingHeaders> Observable<Void> write(Observable<C> contentSource, Func0<T> trailerFactory, Func2<T, C, T> trailerMutator, Func1<C, Boolean> flushSelector) { return state.sendHeaders().write(contentSource, trailerFactory, trailerMutator, flushSelector); } @Override public ResponseContentWriter<C> write(Observable<C> msgs, Func1<C, Boolean> flushSelector) { return state.sendHeaders().write(msgs, flushSelector); } @Override public ResponseContentWriter<C> writeAndFlushOnEach(Observable<C> msgs) { return state.sendHeaders().writeAndFlushOnEach(msgs); } @Override public ResponseContentWriter<C> writeString(Observable<String> msgs) { return state.sendHeaders().writeString(msgs); } @Override public <T extends TrailingHeaders> Observable<Void> writeString(Observable<String> contentSource, Func0<T> trailerFactory, Func2<T, String, T> trailerMutator) { return state.sendHeaders().writeString(contentSource, trailerFactory, trailerMutator); } @Override public <T extends TrailingHeaders> Observable<Void> writeString(Observable<String> contentSource, Func0<T> trailerFactory, Func2<T, String, T> trailerMutator, Func1<String, Boolean> flushSelector) { return state.sendHeaders().writeString(contentSource, trailerFactory, trailerMutator, flushSelector); } @Override public ResponseContentWriter<C> writeString(Observable<String> msgs, Func1<String, Boolean> flushSelector) { return state.sendHeaders().writeString(msgs, flushSelector); } @Override public ResponseContentWriter<C> writeStringAndFlushOnEach(Observable<String> msgs) { return state.sendHeaders().writeStringAndFlushOnEach(msgs); } @Override public ResponseContentWriter<C> writeBytes(Observable<byte[]> msgs) { return state.sendHeaders().writeBytes(msgs); } @Override public <T extends TrailingHeaders> Observable<Void> writeBytes(Observable<byte[]> contentSource, Func0<T> trailerFactory, Func2<T, byte[], T> trailerMutator) { return state.sendHeaders().writeBytes(contentSource, trailerFactory, trailerMutator); } @Override public <T extends TrailingHeaders> Observable<Void> writeBytes(Observable<byte[]> contentSource, Func0<T> trailerFactory, Func2<T, byte[], T> trailerMutator, Func1<byte[], Boolean> flushSelector) { return state.sendHeaders().writeBytes(contentSource, trailerFactory, trailerMutator, flushSelector); } @Override public ResponseContentWriter<C> writeBytes(Observable<byte[]> msgs, Func1<byte[], Boolean> flushSelector) { return state.sendHeaders().writeBytes(msgs, flushSelector); } @Override public ResponseContentWriter<C> writeBytesAndFlushOnEach(Observable<byte[]> msgs) { return state.sendHeaders().writeBytesAndFlushOnEach(msgs); } public static <T> HttpServerResponse<T> create(HttpServerRequest<?> request, @SuppressWarnings("rawtypes") Connection connection, HttpResponse headers) { final State<T> newState = new State<>(headers, connection, request); return new HttpServerResponseImpl<>(newState); } @SuppressWarnings("unchecked") private <CC> HttpServerResponse<CC> _cast() { return (HttpServerResponse<CC>) this; } private MarkAwarePipeline markAwarePipeline() { return state.connection.getResettableChannelPipeline().markIfNotYetMarked(); } private static class State<T> { private final HttpResponse headers; @SuppressWarnings("rawtypes") private final Connection connection; private final HttpServerRequest<?> request; /*This links the headers sent dynamic state from one response to a child response (created via a mutation method). If it is a simple boolean, then a copy of state will just lead to a copy by value and not reference.*/ private final HeaderSentStateHolder sentStateHolder; private State(HttpResponse headers, @SuppressWarnings("rawtypes") Connection connection, HttpServerRequest<?> request) { this.headers = headers; this.connection = connection; this.request = request; this.sentStateHolder = new HeaderSentStateHolder(); } public State(State<?> state, Connection connection) { this.headers = state.headers; this.request = state.request; this.sentStateHolder = state.sentStateHolder; this.connection = connection; } private boolean allowUpdate() { return !sentStateHolder.headersSent; } public ResponseContentWriter<T> sendHeaders() { if (allowUpdate()) { sentStateHolder.headersSent = true; return new ContentWriterImpl<>(connection, headers); } return new FailedContentWriter<>(); } } private static final class HeaderSentStateHolder implements Func0 { private boolean headersSent = false; @Override public Object call() { return headersSent; } } }