/*
* Copyright 2016 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.linecorp.armeria.server;
import static java.util.Objects.requireNonNull;
import java.net.InetSocketAddress;
import java.time.Duration;
import java.util.concurrent.ExecutorService;
import javax.annotation.Nullable;
import javax.net.ssl.SSLSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.linecorp.armeria.common.NonWrappingRequestContext;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.logging.DefaultRequestLog;
import com.linecorp.armeria.common.logging.RequestLog;
import com.linecorp.armeria.common.logging.RequestLogBuilder;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.EventLoop;
/**
 * Default {@link ServiceRequestContext} implementation.
 *
 * <p>An instance is bound to a single {@link Channel} and a single {@link ServiceConfig}. Its request timeout
 * and maximum request length are initialized from the owning server's defaults and can be overridden per
 * request via {@link #setRequestTimeoutMillis(long)} and {@link #setMaxRequestLength(long)}.
 */
public class DefaultServiceRequestContext extends NonWrappingRequestContext implements ServiceRequestContext {

    private final Channel ch;
    private final ServiceConfig cfg;
    private final String mappedPath;
    @Nullable
    private final SSLSession sslSession;

    private final DefaultRequestLog log;
    private final Logger logger;

    // Created lazily on the first call to blockingTaskExecutor().
    @Nullable
    private ExecutorService blockingTaskExecutor;

    // volatile: setRequestTimeoutMillis() explicitly supports being called from outside the channel's
    // event loop, so the new value must be visible to whichever thread reads it afterwards.
    private volatile long requestTimeoutMillis;
    private long maxRequestLength;

    @Nullable
    private volatile RequestTimeoutChangeListener requestTimeoutChangeListener;

    // Cached toString() result; only populated when a channel is available.
    @Nullable
    private String strVal;

    /**
     * Creates a new instance.
     *
     * @param cfg the {@link ServiceConfig} of the service that handles the invocation
     * @param ch the {@link Channel} that handles the invocation
     * @param sessionProtocol the {@link SessionProtocol} of the invocation
     * @param method the method of the invocation
     * @param path the path of the invocation
     * @param mappedPath the value returned by {@link #mappedPath()}
     * @param request the request associated with this context
     * @param sslSession the {@link SSLSession} for this invocation if it is over TLS
     */
    public DefaultServiceRequestContext(
            ServiceConfig cfg, Channel ch, SessionProtocol sessionProtocol,
            String method, String path, String mappedPath, Object request,
            @Nullable SSLSession sslSession) {

        super(sessionProtocol, method, path, request);

        this.ch = ch;
        this.cfg = cfg;
        this.mappedPath = mappedPath;
        this.sslSession = sslSession;

        // Start the request log right away, using the virtual host's default hostname as the host.
        log = new DefaultRequestLog(this);
        log.startRequest(ch, sessionProtocol, cfg.virtualHost().defaultHostname(), method, path);
        logger = newLogger(cfg);

        // Per-request limits start from the server-wide defaults.
        final ServerConfig serverCfg = cfg.server().config();
        requestTimeoutMillis = serverCfg.defaultRequestTimeoutMillis();
        maxRequestLength = serverCfg.defaultMaxRequestLength();
    }

    /**
     * Creates the {@link Logger} returned by {@link #logger()}. The logger name is the server's service
     * logger prefix, a dot, and then either the logger name configured on the {@link ServiceConfig} or,
     * if none is configured, the logger name of its {@link PathMapping}.
     */
    private RequestContextAwareLogger newLogger(ServiceConfig cfg) {
        String loggerName = cfg.loggerName().orElse(null);
        if (loggerName == null) {
            loggerName = cfg.pathMapping().loggerName();
        }

        return new RequestContextAwareLogger(this, LoggerFactory.getLogger(
                cfg.server().config().serviceLoggerPrefix() + '.' + loggerName));
    }

    @Override
    protected Channel channel() {
        return ch;
    }

    @Override
    public Server server() {
        return cfg.server();
    }

    @Override
    public VirtualHost virtualHost() {
        return cfg.virtualHost();
    }

    @Override
    public PathMapping pathMapping() {
        return cfg.pathMapping();
    }

    @Override
    public <T extends Service<?, ?>> T service() {
        return cfg.service();
    }

    /**
     * Returns the server's blocking task executor wrapped by {@code makeContextAware()}. The wrapper is
     * created on first use and reused by subsequent calls.
     */
    @Override
    public ExecutorService blockingTaskExecutor() {
        if (blockingTaskExecutor != null) {
            return blockingTaskExecutor;
        }

        return blockingTaskExecutor = makeContextAware(server().config().blockingTaskExecutor());
    }

    @Override
    public EventLoop eventLoop() {
        return ch.eventLoop();
    }

    @Override
    public String mappedPath() {
        return mappedPath;
    }

    @Override
    public Logger logger() {
        return logger;
    }

    @Nullable
    @Override
    public SSLSession sslSession() {
        return sslSession;
    }

    @Override
    public long requestTimeoutMillis() {
        return requestTimeoutMillis;
    }

    /**
     * Updates the request timeout. If the value actually changes and a
     * {@link RequestTimeoutChangeListener} has been set, the listener is notified on the channel's
     * event loop: directly when the caller is already on it, otherwise via a submitted task.
     *
     * @throws IllegalArgumentException if {@code requestTimeoutMillis} is negative
     */
    @Override
    public void setRequestTimeoutMillis(long requestTimeoutMillis) {
        if (requestTimeoutMillis < 0) {
            throw new IllegalArgumentException(
                    "requestTimeoutMillis: " + requestTimeoutMillis + " (expected: >= 0)");
        }

        if (this.requestTimeoutMillis != requestTimeoutMillis) {
            this.requestTimeoutMillis = requestTimeoutMillis;

            // Read the volatile field once so that the null check and the call see the same listener.
            final RequestTimeoutChangeListener listener = requestTimeoutChangeListener;
            if (listener != null) {
                if (ch.eventLoop().inEventLoop()) {
                    listener.onRequestTimeoutChange(requestTimeoutMillis);
                } else {
                    ch.eventLoop().execute(() -> listener.onRequestTimeoutChange(requestTimeoutMillis));
                }
            }
        }
    }

    /**
     * Updates the request timeout using a {@link Duration}, converted to milliseconds.
     *
     * @throws NullPointerException if {@code requestTimeout} is {@code null}
     * @throws IllegalArgumentException if {@code requestTimeout} is negative
     */
    @Override
    public void setRequestTimeout(Duration requestTimeout) {
        setRequestTimeoutMillis(requireNonNull(requestTimeout, "requestTimeout").toMillis());
    }

    @Override
    public long maxRequestLength() {
        return maxRequestLength;
    }

    /**
     * Updates the maximum request length.
     *
     * @throws IllegalArgumentException if {@code maxRequestLength} is negative
     */
    @Override
    public void setMaxRequestLength(long maxRequestLength) {
        if (maxRequestLength < 0) {
            throw new IllegalArgumentException(
                    "maxRequestLength: " + maxRequestLength + " (expected: >= 0)");
        }

        this.maxRequestLength = maxRequestLength;
    }

    @Override
    public RequestLog log() {
        return log;
    }

    @Override
    public RequestLogBuilder logBuilder() {
        // The same DefaultRequestLog object serves as both the log and its builder.
        return log;
    }

    @Override
    public ByteBufAllocator alloc() {
        return ch.alloc();
    }

    /**
     * Sets the listener that is notified when the {@linkplain #requestTimeoutMillis() request timeout} of
     * the request is changed.
     *
     * <p>Note: This method is meant for internal use by server-side protocol implementation to reschedule
     * a timeout task when a user updates the request timeout configuration.
     *
     * @throws NullPointerException if {@code listener} is {@code null}
     * @throws IllegalStateException if a listener has been set already
     */
    public void setRequestTimeoutChangeListener(RequestTimeoutChangeListener listener) {
        requireNonNull(listener, "listener");
        if (requestTimeoutChangeListener != null) {
            throw new IllegalStateException("requestTimeoutChangeListener is set already.");
        }
        requestTimeoutChangeListener = listener;
    }

    /**
     * Returns a string of the form {@code <channel>[<protocol>://<hostname>:<port><path>#<method>]}, where
     * the channel prefix is omitted when no channel is available and the port is the remote address's
     * port, or {@code -1} when the remote address is unknown. The result is cached only when a channel
     * is available.
     */
    @Override
    public String toString() {
        String strVal = this.strVal;
        if (strVal != null) {
            return strVal;
        }

        final StringBuilder buf = new StringBuilder(96);

        // Prepend the current channel information if available.
        final Channel ch = channel();
        final boolean hasChannel = ch != null;
        if (hasChannel) {
            buf.append(ch);
        }

        buf.append('[')
           .append(sessionProtocol().uriText())
           .append("://")
           .append(virtualHost().defaultHostname());

        final InetSocketAddress raddr = remoteAddress();
        if (raddr != null) {
            buf.append(':').append(raddr.getPort());
        } else {
            buf.append(":-1"); // Port unknown.
        }

        buf.append(path())
           .append('#')
           .append(method())
           .append(']');

        strVal = buf.toString();

        // Do not cache a representation that lacks the channel information.
        if (hasChannel) {
            this.strVal = strVal;
        }

        return strVal;
    }
}