/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.jooby.internal; import static java.util.Objects.requireNonNull; import java.nio.charset.Charset; import java.text.MessageFormat; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.Executor; import java.util.function.Function; import java.util.stream.Collectors; import javax.inject.Inject; import javax.inject.Named; import javax.inject.Provider; import javax.inject.Singleton; import org.jooby.Deferred; import org.jooby.Err; import org.jooby.Err.Handler; import org.jooby.MediaType; import org.jooby.Renderer; import org.jooby.Request; import org.jooby.Response; import org.jooby.Route; import org.jooby.Session; import org.jooby.Sse; import org.jooby.Status; import org.jooby.WebSocket; import org.jooby.WebSocket.Definition; import org.jooby.internal.parser.ParserExecutor; import org.jooby.spi.HttpHandler; import org.jooby.spi.NativeRequest; import org.jooby.spi.NativeResponse; import org.jooby.spi.NativeWebSocket; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.base.Strings; import com.google.common.base.Throwables; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.Sets; import com.google.inject.Injector; import com.google.inject.Key; import com.google.inject.name.Names; import com.typesafe.config.Config; import javaslang.control.Try; @Singleton public class HttpHandlerImpl implements HttpHandler { private static class RouteKey { private final String method; private final String path; private final MediaType consumes; private final List<MediaType> produces; private final String key; public RouteKey(final String method, final String path, final MediaType consumes, final List<MediaType> produces) { String c = consumes.name(); String p = produces.toString(); key = new StringBuilder(method.length() + path.length() + c.length() + p.length()) .append(method) .append(path) .append(c) .append(p) .toString(); this.method = method; this.path = path; this.consumes = consumes; this.produces = produces; } @Override public int hashCode() { return key.hashCode(); } @Override public boolean equals(final Object obj) { RouteKey that = (RouteKey) obj; return key.equals(that.key); } } private static final String NO_CACHE = "must-revalidate,no-cache,no-store"; private static final String WEB_SOCKET = "WebSocket"; private static final String UPGRADE = "Upgrade"; private static final String REFERER = "Referer"; private static final String PATH = "path"; private static final String CONTEXT_PATH = "contextPath"; private static final Key<Request> REQ = Key.get(Request.class); private static final Key<Response> RSP = Key.get(Response.class); private static final Key<Sse> SSE = Key.get(Sse.class); private static final Key<Session> SESS = Key.get(Session.class); private static final Key<String> DEF_EXEC = Key.get(String.class, Names.named("deferred")); private static final String BYTE_RANGE = "Range"; /** * The logging system. */ private Injector injector; private Set<Err.Handler> err; private String applicationPath; private RequestScope requestScope; private Set<Definition> socketDefs; private Config config; private int port; private String _method; private Charset charset; private List<Renderer> renderers; private ParserExecutor parserExecutor; private List<Locale> locales; private final LoadingCache<RouteKey, List<Route>> routeCache; private final String redirectHttps; private Function<String, String> rpath = null; private String contextPath; private boolean hasSockets; private final Map<String, Renderer> rendererMap; private StatusCodeProvider sc; /** * Global deferred executor. */ private Key<Executor> gexec; @Inject public HttpHandlerImpl(final Injector injector, final RequestScope requestScope, final Set<Route.Definition> routes, final Set<WebSocket.Definition> sockets, final @Named("application.path") String path, final ParserExecutor parserExecutor, final Set<Renderer> renderers, final Set<Err.Handler> err, final StatusCodeProvider sc, final Charset charset, final List<Locale> locale) { this.injector = requireNonNull(injector, "An injector is required."); this.requestScope = requireNonNull(requestScope, "A request scope is required."); this.socketDefs = requireNonNull(sockets, "Sockets are required."); this.hasSockets = socketDefs.size() > 0; this.applicationPath = normalizeURI(requireNonNull(path, "An application.path is required.")); this.err = requireNonNull(err, "An err handler is required."); this.sc = sc; this.config = injector.getInstance(Config.class); _method = Strings.emptyToNull(this.config.getString("server.http.Method").trim()); this.port = config.getInt("application.port"); this.charset = charset; this.locales = locale; this.parserExecutor = parserExecutor; this.renderers = new ArrayList<>(renderers); rendererMap = new HashMap<>(); this.renderers.forEach(r -> rendererMap.put(r.name(), r)); // route cache routeCache = routeCache(routes, config); // force https String redirectHttps = config.getString("application.redirect_https").trim(); this.redirectHttps = redirectHttps.length() > 0 ? redirectHttps : null; // custom path? if (applicationPath.equals("/")) { this.contextPath = ""; } else { this.contextPath = applicationPath; this.rpath = rootpath(applicationPath); } // global deferred executor this.gexec = Key.get(Executor.class, Names.named(injector.getInstance(DEF_EXEC))); } @Override public void handle(final NativeRequest request, final NativeResponse response) throws Exception { long start = System.currentTimeMillis(); Map<String, Object> locals = new HashMap<>(16); Map<Object, Object> scope = new HashMap<>(16); String method = _method == null ? request.method() : method(_method, request); String path = normalizeURI(request.path()); if (rpath != null) { path = rpath.apply(path); } // put request attributes first to make sure we don't override defaults Map<String, Object> nativeAttrs = request.attributes(); if (nativeAttrs.size() > 0) { locals.putAll(nativeAttrs); } // default locals locals.put(CONTEXT_PATH, contextPath); locals.put(PATH, path); Route notFound = RouteImpl.notFound(method, path); RequestImpl req = new RequestImpl(injector, request, contextPath, port, notFound, charset, locales, scope, locals, start); ResponseImpl rsp = new ResponseImpl(req, parserExecutor, response, notFound, renderers, rendererMap, locals, req.charset(), request.header(REFERER), request.header(BYTE_RANGE)); MediaType type = req.type(); // seed req & rsp scope.put(REQ, req); scope.put(RSP, rsp); // seed sse Provider<Sse> sse = () -> Try.of(() -> request.upgrade(Sse.class)) .getOrElseThrow(() -> new UnsupportedOperationException("Server-sent events")); scope.put(SSE, sse); // seed session Provider<Session> session = () -> req.session(); scope.put(SESS, session); boolean deferred = false; Throwable x = null; try { requestScope.enter(scope); // force https? if (redirectHttps != null) { if (!req.secure()) { rsp.redirect(MessageFormat.format(redirectHttps, path.substring(1))); return; } } // websocket? if (hasSockets) { if (upgrade(request)) { Optional<WebSocket> sockets = findSockets(socketDefs, path); if (sockets.isPresent()) { NativeWebSocket ws = request.upgrade(NativeWebSocket.class); ws.onConnect(() -> ((WebSocketImpl) sockets.get()).connect(injector, req, ws)); return; } } } // usual req/rsp List<Route> routes = routeCache .getUnchecked(new RouteKey(method, path, type, req.accept())); new RouteChain(req, rsp, routes).next(req, rsp); } catch (DeferredExecution ex) { deferred = true; onDeferred(scope, request, req, rsp, ex.deferred); } catch (Throwable ex) { x = ex; } finally { cleanup(req, rsp, true, x, !deferred); } } private boolean upgrade(final NativeRequest request) { Optional<String> upgrade = request.header(UPGRADE); return upgrade.isPresent() && upgrade.get().equalsIgnoreCase(WEB_SOCKET); } private void done(final RequestImpl req, final ResponseImpl rsp, final Throwable x, final boolean close) { // mark request/response as done. req.done(); if (close) { rsp.done(Optional.ofNullable(x)); } } private void onDeferred(final Map<Object, Object> scope, final NativeRequest request, final RequestImpl req, final ResponseImpl rsp, final Deferred deferred) { /** Deferred executor. */ Key<Executor> execKey = deferred.executor() .map(it -> Key.get(Executor.class, Names.named(it))) .orElse(gexec); /** Get executor. */ Executor executor = injector.getInstance(execKey); request.startAsync(executor, () -> { try { deferred.handler(req, (success, x) -> { boolean close = false; Optional<Throwable> failure = Optional.ofNullable(x); try { requestScope.enter(scope); if (success != null) { close = true; rsp.send(success); } } catch (Throwable exerr) { failure = Optional.of(failure.orElse(exerr)); } finally { Throwable cause = failure.orElse(null); if (cause != null) { close = true; } cleanup(req, rsp, close, cause, true); } }); } catch (Exception ex) { handleErr(req, rsp, ex); } }); } private void cleanup(final RequestImpl req, final ResponseImpl rsp, final boolean close, final Throwable x, final boolean done) { if (x != null) { handleErr(req, rsp, x); } if (done) { done(req, rsp, x, close); } requestScope.exit(); } private void handleErr(final RequestImpl req, final ResponseImpl rsp, final Throwable ex) { Logger log = LoggerFactory.getLogger(HttpHandler.class); try { log.debug("execution of: {}{} resulted in exception", req.method(), req.path(), ex); // execution failed, find status code Status status = sc.apply(ex); if (status == Status.REQUESTED_RANGE_NOT_SATISFIABLE) { String range = rsp.header("Content-Length").toOptional().map(it -> "bytes */" + it) .orElse("*"); rsp.reset(); rsp.header("Content-Range", range); } else { rsp.reset(); } rsp.header("Cache-Control", NO_CACHE); rsp.status(status); Err err = ex instanceof Err ? (Err) ex : new Err(status, ex); Iterator<Handler> it = this.err.iterator(); while (!rsp.committed() && it.hasNext()) { Err.Handler next = it.next(); log.debug("handling err with: {}", next); next.handle(req, rsp, err); } } catch (Throwable errex) { log.error("error handler resulted in exception: {}{}\nRoute:\n{}\n\nStacktrace:\n{}\nSource:", req.method(), req.path(), req.route().print(6), Throwables.getStackTraceAsString(errex), ex); } } private static String normalizeURI(final String uri) { int len = uri.length(); return len > 1 && uri.charAt(len - 1) == '/' ? uri.substring(0, len - 1) : uri; } private static List<Route> routes(final Set<Route.Definition> routeDefs, final String method, final String path, final MediaType type, final List<MediaType> accept) { List<Route> routes = findRoutes(routeDefs, method, path, type, accept); routes.add(RouteImpl.fallback((req, rsp, chain) -> { if (!rsp.status().isPresent()) { // 406 or 415 Err ex = handle406or415(routeDefs, method, path, type, accept); if (ex != null) { throw ex; } // 405 ex = handle405(routeDefs, method, path, type, accept); if (ex != null) { throw ex; } // favicon.ico if (path.equals("/favicon.ico")) { // default /favicon.ico handler: rsp.status(Status.NOT_FOUND).end(); } else { throw new Err(Status.NOT_FOUND, req.path(true)); } } }, method, path, "err", accept)); return routes; } private static List<Route> findRoutes(final Set<Route.Definition> routeDefs, final String method, final String path, final MediaType type, final List<MediaType> accept) { List<Route> routes = new ArrayList<>(); for (Route.Definition routeDef : routeDefs) { Optional<Route> route = routeDef.matches(method, path, type, accept); if (route.isPresent()) { routes.add(route.get()); } } return routes; } private static Optional<WebSocket> findSockets(final Set<WebSocket.Definition> sockets, final String path) { for (WebSocket.Definition socketDef : sockets) { Optional<WebSocket> match = socketDef.matches(path); if (match.isPresent()) { return match; } } return Optional.empty(); } private static Err handle405(final Set<Route.Definition> routeDefs, final String method, final String path, final MediaType type, final List<MediaType> accept) { if (alternative(routeDefs, method, path).size() > 0) { return new Err(Status.METHOD_NOT_ALLOWED, method); } return null; } private static List<Route> alternative(final Set<Route.Definition> routeDefs, final String verb, final String uri) { List<Route> routes = new LinkedList<>(); Set<String> verbs = Sets.newHashSet(Route.METHODS); verbs.remove(verb); for (String alt : verbs) { findRoutes(routeDefs, alt, uri, MediaType.all, MediaType.ALL) .stream() // skip glob pattern .filter(r -> !r.pattern().contains("*")) .forEach(routes::add); } return routes; } private static Err handle406or415(final Set<Route.Definition> routeDefs, final String method, final String path, final MediaType contentType, final List<MediaType> accept) { for (Route.Definition routeDef : routeDefs) { Optional<Route> route = routeDef.matches(method, path, MediaType.all, MediaType.ALL); if (route.isPresent() && !route.get().pattern().contains("*")) { if (!routeDef.canProduce(accept)) { return new Err(Status.NOT_ACCEPTABLE, accept.stream() .map(MediaType::name) .collect(Collectors.joining(", "))); } if (!contentType.isAny()) { return new Err(Status.UNSUPPORTED_MEDIA_TYPE, contentType.name()); } } } return null; } private static String method(final String methodParam, final NativeRequest request) throws Exception { Optional<String> header = request.header(methodParam); if (header.isPresent()) { return header.get(); } List<String> param = request.params(methodParam); return param.size() == 0 ? request.method() : param.get(0); } private static LoadingCache<RouteKey, List<Route>> routeCache(final Set<Route.Definition> routes, final Config conf) { return CacheBuilder.from(conf.getString("server.routes.Cache")) .build(new CacheLoader<RouteKey, List<Route>>() { @Override public List<Route> load(final RouteKey key) throws Exception { return routes(routes, key.method, key.path, key.consumes, key.produces); } }); } private static Function<String, String> rootpath(final String applicationPath) { return p -> { if (applicationPath.equals(p)) { return "/"; } else if (p.startsWith(applicationPath)) { return p.substring(applicationPath.length()); } else { // mark as failure return Route.errpath(p); } }; } }