/******************************************************************************* * Copyright (c) 2012-2016 Codenvy, S.A. * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html * * Contributors: * Codenvy, S.A. - initial API and implementation *******************************************************************************/ package org.everrest.websockets; import org.everrest.core.impl.EverrestConfiguration; import org.everrest.core.impl.EverrestProcessor; import org.everrest.core.impl.provider.json.JsonException; import org.everrest.core.tools.SimplePrincipal; import org.everrest.core.tools.SimpleSecurityContext; import org.everrest.core.tools.WebApplicationDeclaredRoles; import org.everrest.websockets.message.BaseTextDecoder; import org.everrest.websockets.message.BaseTextEncoder; import org.everrest.websockets.message.JsonMessageConverter; import org.everrest.websockets.message.OutputMessage; import org.everrest.websockets.message.RestInputMessage; import javax.servlet.ServletContext; import javax.servlet.ServletContextEvent; import javax.servlet.ServletContextListener; import javax.servlet.http.HttpSession; import javax.websocket.DecodeException; import javax.websocket.Decoder; import javax.websocket.DeploymentException; import javax.websocket.EncodeException; import javax.websocket.Encoder; import javax.websocket.HandshakeResponse; import javax.websocket.server.HandshakeRequest; import javax.websocket.server.ServerContainer; import javax.websocket.server.ServerEndpointConfig; import javax.ws.rs.core.SecurityContext; import java.security.Principal; import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicLong; import static javax.websocket.server.ServerEndpointConfig.Builder.create; import static javax.websocket.server.ServerEndpointConfig.Configurator; /** * @author andrew00x */ public class ServerContainerInitializeListener implements ServletContextListener { public static final String EVERREST_PROCESSOR_ATTRIBUTE = EverrestProcessor.class.getName(); public static final String HTTP_SESSION_ATTRIBUTE = HttpSession.class.getName(); public static final String EVERREST_CONFIG_ATTRIBUTE = EverrestConfiguration.class.getName(); public static final String EXECUTOR_ATTRIBUTE = "everrest.Executor"; public static final String SECURITY_CONTEXT = SecurityContext.class.getName(); private static final AtomicLong sequence = new AtomicLong(1); private WebApplicationDeclaredRoles webApplicationDeclaredRoles; private EverrestConfiguration everrestConfiguration; private ServerEndpointConfig serverEndpointConfig; @Override public final void contextInitialized(ServletContextEvent sce) { final ServletContext servletContext = sce.getServletContext(); webApplicationDeclaredRoles = new WebApplicationDeclaredRoles(servletContext); everrestConfiguration = (EverrestConfiguration)servletContext.getAttribute(EVERREST_CONFIG_ATTRIBUTE); if (everrestConfiguration == null) { everrestConfiguration = new EverrestConfiguration(); } final ServerContainer serverContainer = (ServerContainer)servletContext.getAttribute("javax.websocket.server.ServerContainer"); try { serverEndpointConfig = createServerEndpointConfig(servletContext); serverContainer.addEndpoint(serverEndpointConfig); } catch (DeploymentException e) { throw new IllegalStateException(e.getMessage(), e); } } @Override public void contextDestroyed(ServletContextEvent sce) { if (serverEndpointConfig != null) { ExecutorService executor = (ExecutorService)serverEndpointConfig.getUserProperties().get(EXECUTOR_ATTRIBUTE); if (executor != null) { executor.shutdownNow(); } } } protected ServerEndpointConfig createServerEndpointConfig(ServletContext servletContext) { final List<Class<? extends Encoder>> encoders = new LinkedList<>(); final List<Class<? extends Decoder>> decoders = new LinkedList<>(); encoders.add(OutputMessageEncoder.class); decoders.add(InputMessageDecoder.class); final ServerEndpointConfig endpointConfig = create(WSConnectionImpl.class, "/ws") .configurator(createConfigurator()).encoders(encoders).decoders(decoders).build(); endpointConfig.getUserProperties().put(EVERREST_PROCESSOR_ATTRIBUTE, getEverrestProcessor(servletContext)); endpointConfig.getUserProperties().put(EVERREST_CONFIG_ATTRIBUTE, getEverrestConfiguration(servletContext)); endpointConfig.getUserProperties().put(EXECUTOR_ATTRIBUTE, createExecutor(servletContext)); return endpointConfig; } private Configurator createConfigurator() { return new Configurator() { @Override public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) { super.modifyHandshake(sec, request, response); final HttpSession httpSession = (HttpSession)request.getHttpSession(); if (httpSession != null) { sec.getUserProperties().put(HTTP_SESSION_ATTRIBUTE, httpSession); } final SecurityContext securityContext = createSecurityContext(request); sec.getUserProperties().put(SECURITY_CONTEXT, securityContext); } }; } protected EverrestProcessor getEverrestProcessor(ServletContext servletContext) { return (EverrestProcessor)servletContext.getAttribute(EVERREST_PROCESSOR_ATTRIBUTE); } protected EverrestConfiguration getEverrestConfiguration(ServletContext servletContext) { return everrestConfiguration; } protected ExecutorService createExecutor(ServletContext servletContext) { final EverrestConfiguration everrestConfiguration = getEverrestConfiguration(servletContext); return Executors.newFixedThreadPool(everrestConfiguration.getAsynchronousPoolSize(), new ThreadFactory() { @Override public Thread newThread(Runnable r) { final Thread t = new Thread(r, "everrest.WSConnection" + sequence.getAndIncrement()); t.setDaemon(true); return t; } }); } protected SecurityContext createSecurityContext(HandshakeRequest req) { final boolean isSecure = false; //todo: get somehow from request final Principal principal = req.getUserPrincipal(); if (principal == null) { return new SimpleSecurityContext(isSecure); } final String authenticationScheme = "BASIC"; //todo: get somehow from request final Set<String> userRoles = new LinkedHashSet<>(); for (String declaredRole : webApplicationDeclaredRoles.getDeclaredRoles()) { if (req.isUserInRole(declaredRole)) { userRoles.add(declaredRole); } } return new SimpleSecurityContext(new SimplePrincipal(principal.getName()), userRoles, authenticationScheme, isSecure); } public static class InputMessageDecoder extends BaseTextDecoder<RestInputMessage> { private final JsonMessageConverter jsonMessageConverter = new JsonMessageConverter(); @Override public RestInputMessage decode(String s) throws DecodeException { try { return jsonMessageConverter.fromString(s, RestInputMessage.class); } catch (JsonException e) { throw new DecodeException(s, e.getMessage(), e); } } @Override public boolean willDecode(String s) { return true; } } public static class OutputMessageEncoder extends BaseTextEncoder<OutputMessage> { private final JsonMessageConverter jsonMessageConverter = new JsonMessageConverter(); @Override public String encode(OutputMessage output) throws EncodeException { try { return jsonMessageConverter.toString(output); } catch (JsonException e) { throw new EncodeException(output, e.getMessage(), e); } } } }