/*
* Copyright 2013 Atteo.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.atteo.moonshine.websocket;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.inject.Provider;
import javax.servlet.ServletContext;
import javax.servlet.ServletContextEvent;
import javax.servlet.ServletContextListener;
import javax.websocket.Decoder;
import javax.websocket.DeploymentException;
import javax.websocket.Encoder;
import javax.websocket.Endpoint;
import javax.websocket.Extension;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpointConfig;
import javax.xml.bind.annotation.XmlElement;
import javax.xml.bind.annotation.XmlIDREF;
import org.atteo.moonshine.TopLevelService;
import org.atteo.moonshine.services.ImportService;
import org.atteo.moonshine.webserver.ServletContainer;
import com.google.inject.Module;
import com.google.inject.PrivateModule;
/**
 * Base service which collects WebSocket endpoint definitions and hands them to the
 * {@link ServerContainer} when the servlet context is initialized.
 *
 * <p>
 * Endpoints are recorded through {@link #addEndpoint(Class)} and the
 * {@code addAnnotatedEndpoint} overloads. {@link #configure()} adds a
 * {@link ServletContextListener} to the imported {@link ServletContainer}; that listener
 * passes every recorded definition to
 * {@link ServerContainer#addEndpoint(ServerEndpointConfig)}.
 * </p>
 */
public abstract class WebSocketContainerService extends TopLevelService {
    /**
     * Servlet container to which the endpoint-registering listener is added.
     */
    @ImportService
    @XmlIDREF
    @XmlElement
    protected ServletContainer servletContainer;

    /** Definitions recorded so far, in registration order. */
    private final List<EndpointDefinition<?>> endpoints = new ArrayList<>();

    /**
     * Creates the definition object used to describe an endpoint.
     *
     * <p>
     * Subclasses may override this to supply a specialized {@link EndpointDefinition}.
     * </p>
     *
     * @param klass endpoint class
     * @param <T> endpoint type
     * @return new definition for the given class
     */
    protected <T> EndpointDefinition<T> createEndpointDefinition(Class<T> klass) {
        return new EndpointDefinition<>(klass);
    }

    /**
     * Adds ordinary endpoint.
     *
     * @param klass endpoint class
     * @param <T> endpoint type
     * @return builder which allows to further configure the recorded endpoint
     */
    public <T extends Endpoint> EndpointBuilder<T> addEndpoint(Class<T> klass) {
        EndpointDefinition<T> definition = createEndpointDefinition(klass);
        endpoints.add(definition);
        return definition;
    }

    /**
     * Adds annotated endpoint.
     *
     * <p>
     * NOTE(review): no pattern is set on the definition created here, so its
     * {@code getPath()} returns {@code null} unless a subclass supplies one through
     * {@link #createEndpointDefinition(Class)} - confirm the target container accepts this.
     * </p>
     *
     * @param endpoint endpoint class
     */
    public void addAnnotatedEndpoint(Class<?> endpoint) {
        EndpointDefinition<?> definition = createEndpointDefinition(endpoint);
        endpoints.add(definition);
    }

    /**
     * Adds annotated endpoint whose instances are obtained from the given provider.
     *
     * @param endpoint endpoint class
     * @param provider provider asked for the endpoint instance; when {@code null}
     *     the default instantiation of the {@link ServerEndpointConfig.Configurator} is used
     * @param <T> endpoint type
     */
    public <T> void addAnnotatedEndpoint(Class<T> endpoint, Provider<? extends T> provider) {
        EndpointDefinition<T> definition = createEndpointDefinition(endpoint);
        definition.provider(provider);
        endpoints.add(definition);
    }

    /**
     * Returns a module which, when configured, adds the endpoint-registering listener
     * to {@link #servletContainer}.
     */
    @Override
    public Module configure() {
        return new PrivateModule() {
            @Override
            protected void configure() {
                servletContainer.addListener(new Listener());
            }
        };
    }

    /**
     * Fluent configuration of a single recorded endpoint.
     *
     * @param <T> endpoint type
     */
    public interface EndpointBuilder<T> {
        /**
         * Sets the path of the endpoint.
         *
         * @param pattern endpoint path
         * @return builder for further configuration
         */
        EndpointBuilder<T> pattern(String pattern);

        /**
         * Sets the provider used to obtain endpoint instances.
         *
         * @param provider endpoint instance provider
         * @return builder for further configuration
         */
        EndpointBuilder<T> provider(Provider<? extends T> provider);

        /**
         * Adds an encoder class to the endpoint configuration.
         *
         * @param encoder encoder class
         * @return builder for further configuration
         */
        EndpointBuilder<T> addEncoder(Class<? extends Encoder> encoder);

        /**
         * Adds a decoder class to the endpoint configuration.
         *
         * @param decoder decoder class
         * @return builder for further configuration
         */
        EndpointBuilder<T> addDecoder(Class<? extends Decoder> decoder);

        /**
         * Adds an entry to the user properties of the endpoint configuration.
         *
         * @param key property name
         * @param value property value
         * @return builder for further configuration
         */
        EndpointBuilder<T> addUserProperty(String key, Object value);
    }

    /**
     * Registers all recorded endpoints with the {@link ServerContainer} stored in the
     * servlet context.
     */
    private class Listener implements ServletContextListener {
        /**
         * Looks up the {@link ServerContainer} in the servlet context and adds every
         * recorded endpoint definition to it.
         *
         * @throws IllegalStateException when endpoints were recorded but the servlet context
         *     attribute does not hold a {@link ServerContainer}
         * @throws RuntimeException wrapping the {@link DeploymentException} reported
         *     by the container for an endpoint
         */
        @Override
        public void contextInitialized(ServletContextEvent contextEvent) {
            // Nothing to deploy: do not insist on WebSocket support being present.
            if (endpoints.isEmpty()) {
                return;
            }

            ServletContext context = contextEvent.getServletContext();
            String attributeName = ServerContainer.class.getName();
            Object attribute = context.getAttribute(attributeName);
            // Fail with an explanation instead of a bare NullPointerException/ClassCastException.
            if (!(attribute instanceof ServerContainer)) {
                throw new IllegalStateException("Servlet context attribute '" + attributeName
                        + "' does not hold a ServerContainer (found: " + attribute
                        + "); is WebSocket support enabled in the servlet container?");
            }
            ServerContainer container = (ServerContainer) attribute;

            for (EndpointDefinition<?> endpointDefinition : endpoints) {
                try {
                    container.addEndpoint(endpointDefinition);
                } catch (DeploymentException ex) {
                    throw new RuntimeException("Cannot deploy WebSocket endpoint: "
                            + endpointDefinition.getEndpointClass(), ex);
                }
            }
        }

        @Override
        public void contextDestroyed(ServletContextEvent sce) {
            // Nothing to do here.
        }
    }

    /**
     * Mutable endpoint description which serves both as the {@link EndpointBuilder}
     * handed to callers and as the {@link ServerEndpointConfig} passed to the container.
     *
     * @param <T> endpoint type
     */
    public static class EndpointDefinition<T> implements ServerEndpointConfig, EndpointBuilder<T> {
        private final Class<T> endpointClass;
        /** Source of endpoint instances; {@code null} means default instantiation. */
        private Provider<? extends T> provider;
        /** Endpoint path; {@code null} until {@link #pattern(String)} is called. */
        private String pattern;
        private final List<Class<? extends Encoder>> encoders = new ArrayList<>();
        private final List<Class<? extends Decoder>> decoders = new ArrayList<>();
        protected final Map<String, Object> userProperties = new HashMap<>();

        /**
         * Creates a definition for the given endpoint class.
         *
         * @param endpointClass endpoint class
         */
        public EndpointDefinition(Class<T> endpointClass) {
            this.endpointClass = endpointClass;
        }

        @Override
        public EndpointBuilder<T> pattern(String pattern) {
            this.pattern = pattern;
            return this;
        }

        @Override
        public EndpointBuilder<T> provider(Provider<? extends T> provider) {
            this.provider = provider;
            return this;
        }

        @Override
        public EndpointBuilder<T> addEncoder(Class<? extends Encoder> encoder) {
            encoders.add(encoder);
            return this;
        }

        @Override
        public EndpointBuilder<T> addDecoder(Class<? extends Decoder> decoder) {
            decoders.add(decoder);
            return this;
        }

        @Override
        public EndpointBuilder<T> addUserProperty(String key, Object value) {
            userProperties.put(key, value);
            return this;
        }

        @Override
        public Class<T> getEndpointClass() {
            return endpointClass;
        }

        /**
         * Returns the value set through {@link #pattern(String)}, or {@code null}
         * when no pattern was set.
         */
        @Override
        public String getPath() {
            return pattern;
        }

        /**
         * Returns an empty list; subprotocols cannot be configured on this definition.
         */
        @Override
        public List<String> getSubprotocols() {
            return Collections.emptyList();
        }

        /**
         * Returns an empty list; extensions cannot be configured on this definition.
         */
        @Override
        public List<Extension> getExtensions() {
            return Collections.emptyList();
        }

        /**
         * Returns a configurator which obtains endpoint instances from the configured
         * provider and otherwise behaves like the default
         * {@link ServerEndpointConfig.Configurator}.
         */
        @Override
        public ServerEndpointConfig.Configurator getConfigurator() {
            return new ServerEndpointConfig.Configurator() {
                /**
                 * Returns the provider's instance when a provider is configured, else
                 * defers to the default implementation.
                 *
                 * @throws InstantiationException from the default implementation
                 * @throws ClassCastException when the provided object is not an instance
                 *     of the requested class
                 */
                @Override
                public <E> E getEndpointInstance(Class<E> requestedClass) throws InstantiationException {
                    // Read the provider at call time: it may be set after this configurator is created.
                    Provider<? extends T> instanceProvider = provider;
                    if (instanceProvider == null) {
                        return super.getEndpointInstance(requestedClass);
                    }
                    return requestedClass.cast(instanceProvider.get());
                }
            };
        }

        /**
         * Returns the live list of encoder classes added so far.
         */
        @Override
        public List<Class<? extends Encoder>> getEncoders() {
            return encoders;
        }

        /**
         * Returns the live list of decoder classes added so far.
         */
        @Override
        public List<Class<? extends Decoder>> getDecoders() {
            return decoders;
        }

        /**
         * Returns the live, modifiable map of user properties.
         */
        @Override
        public Map<String, Object> getUserProperties() {
            return userProperties;
        }
    }
}