/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.usergrid.websocket; import java.security.MessageDigest; import java.util.List; import java.util.concurrent.ConcurrentHashMap; import org.jboss.netty.buffer.ChannelBuffer; import org.jboss.netty.buffer.ChannelBuffers; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelFuture; import org.jboss.netty.channel.ChannelFutureListener; import org.jboss.netty.channel.ChannelHandlerContext; import org.jboss.netty.channel.ChannelPipeline; import org.jboss.netty.channel.ChannelStateEvent; import org.jboss.netty.channel.ExceptionEvent; import org.jboss.netty.channel.MessageEvent; import org.jboss.netty.channel.SimpleChannelUpstreamHandler; import org.jboss.netty.channel.group.ChannelGroup; import org.jboss.netty.channel.group.DefaultChannelGroup; import org.jboss.netty.handler.codec.http.DefaultHttpResponse; import org.jboss.netty.handler.codec.http.HttpHeaders; import org.jboss.netty.handler.codec.http.HttpHeaders.Names; import org.jboss.netty.handler.codec.http.HttpHeaders.Values; import org.jboss.netty.handler.codec.http.HttpRequest; import org.jboss.netty.handler.codec.http.HttpResponse; import org.jboss.netty.handler.codec.http.HttpResponseStatus; import 
org.jboss.netty.handler.codec.http.QueryStringDecoder; import org.jboss.netty.handler.codec.http.websocket.DefaultWebSocketFrame; import org.jboss.netty.handler.codec.http.websocket.WebSocketFrame; import org.jboss.netty.handler.codec.http.websocket.WebSocketFrameDecoder; import org.jboss.netty.handler.codec.http.websocket.WebSocketFrameEncoder; import org.jboss.netty.util.CharsetUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.usergrid.management.ManagementService; import org.apache.usergrid.persistence.EntityManagerFactory; import org.apache.usergrid.services.ServiceManagerFactory; import org.apache.shiro.mgt.SessionsSecurityManager; import org.apache.shiro.subject.Subject; import static org.apache.commons.lang.StringUtils.isEmpty; import static org.apache.commons.lang.StringUtils.removeEnd; import static org.apache.commons.lang.StringUtils.split; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONNECTION; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.ORIGIN; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.SEC_WEBSOCKET_KEY1; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.SEC_WEBSOCKET_KEY2; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.SEC_WEBSOCKET_LOCATION; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.SEC_WEBSOCKET_ORIGIN; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.WEBSOCKET_LOCATION; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.WEBSOCKET_ORIGIN; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.WEBSOCKET_PROTOCOL; import static org.jboss.netty.handler.codec.http.HttpHeaders.Values.WEBSOCKET; import static org.jboss.netty.handler.codec.http.HttpHeaders.isKeepAlive; import static 
org.jboss.netty.handler.codec.http.HttpHeaders.setContentLength;
import static org.jboss.netty.handler.codec.http.HttpMethod.GET;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK;
import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1;


/**
 * Netty 3 upstream handler for WebSocket connections.
 * <p>
 * It serves a demo page on {@code GET /}, answers WebSocket upgrade requests (both the
 * challenge-based Sec-WebSocket-Key1/Key2 handshake and the older challenge-less one),
 * echoes received text frames back in upper case, and maintains a registry, shared by all
 * handler instances, of {@link ChannelGroup}s keyed by subscription path.
 */
public class WebSocketChannelHandler extends SimpleChannelUpstreamHandler {

    private static final Logger logger = LoggerFactory.getLogger( WebSocketChannelHandler.class );

    /** Number of challenge bytes carried in the body of a Key1/Key2 handshake request. */
    private static final int CHALLENGE_BODY_LENGTH = 8;

    private final EntityManagerFactory emf;
    private final ServiceManagerFactory smf;
    private final ManagementService management;
    private final SessionsSecurityManager securityManager;
    private final boolean ssl;

    /** Set to true once an upgrade request has been received on this handler's connection. */
    boolean websocket = false;

    /** Shiro subject built from the security manager, or null when none was supplied. */
    Subject subject = null;

    /**
     * Registry of subscription path to the channels subscribed to it. It is static, so it is
     * shared by every handler instance. Mutations of a group's membership are done while
     * holding that group's monitor (see addSubscription / removeSubscription).
     */
    private static final ConcurrentHashMap<String, ChannelGroup> subscribers =
            new ConcurrentHashMap<String, ChannelGroup>();

    List<String> subscriptions;


    /**
     * @param emf entity manager factory exposed via {@link #getEmf()}
     * @param smf service manager factory exposed via {@link #getSmf()}
     * @param management management service exposed via {@link #getOrganizations()}
     * @param securityManager Shiro security manager; when non-null a {@link Subject} is built from it
     * @param ssl if true the advertised WebSocket location uses {@code wss://}, otherwise {@code ws://}
     */
    public WebSocketChannelHandler( EntityManagerFactory emf, ServiceManagerFactory smf, ManagementService management,
                                    SessionsSecurityManager securityManager, boolean ssl ) {
        super();
        this.emf = emf;
        this.smf = smf;
        this.management = management;
        this.securityManager = securityManager;
        this.ssl = ssl;
        if ( securityManager != null ) {
            subject = new Subject.Builder( securityManager ).buildSubject();
        }
    }


    public EntityManagerFactory getEmf() {
        return emf;
    }


    public ServiceManagerFactory getSmf() {
        return smf;
    }


    public ManagementService getOrganizations() {
        return management;
    }


    public SessionsSecurityManager getSecurityManager() {
        return securityManager;
    }


    /**
     * Builds the ws:// or wss:// URL for this request from the Host header and the request URI.
     * A trailing slash is stripped, and a bare "/" URI contributes no path at all.
     */
    private String getWebSocketLocation( HttpRequest req ) {
        String path = req.getUri();
        if ( path.equals( "/" ) ) {
            path = null;
        }
        if ( path != null ) {
            path = removeEnd( path, "/" );
        }
        String location =
                ( ssl ? "wss://" : "ws://" ) + req.getHeader( HttpHeaders.Names.HOST ) + ( path != null ? path : "" );
        logger.info( location );
        return location;
    }


    /**
     * Writes an HTTP response. Non-200 responses get a plain-text body containing the status
     * line. The connection is closed afterwards unless the request is keep-alive and the
     * status is 200.
     */
    private void sendHttpResponse( ChannelHandlerContext ctx, HttpRequest req, HttpResponse res ) {
        // Generate an error page if response status code is not OK (200).
        if ( res.getStatus().getCode() != 200 ) {
            res.setContent( ChannelBuffers.copiedBuffer( res.getStatus().toString(), CharsetUtil.UTF_8 ) );
            setContentLength( res, res.getContent().readableBytes() );
        }

        // Send the response and close the connection if necessary.
        ChannelFuture f = ctx.getChannel().write( res );
        if ( !isKeepAlive( req ) || ( res.getStatus().getCode() != 200 ) ) {
            f.addListener( ChannelFutureListener.CLOSE );
        }
    }


    /** Convenience overload that sends an empty response with the given status. */
    private void sendHttpResponse( ChannelHandlerContext ctx, HttpRequest req, HttpResponseStatus status ) {
        sendHttpResponse( ctx, req, new DefaultHttpResponse( HTTP_1_1, status ) );
    }


    /** Logs the unexpected exception and closes the channel. */
    @Override
    public void exceptionCaught( ChannelHandlerContext ctx, ExceptionEvent e ) {
        logger.warn( "Unexpected exception from downstream.", e.getCause() );
        e.getChannel().close();
    }


    @Override
    public void channelDisconnected( ChannelHandlerContext ctx, ChannelStateEvent e ) throws Exception {
        super.channelDisconnected( ctx, e );
        if ( websocket ) {
            logger.info( "Websocket disconnected" );
        }
    }


    /** Dispatches HTTP requests and WebSocket frames; any other message type is ignored. */
    @Override
    public void messageReceived( ChannelHandlerContext ctx, MessageEvent e ) throws Exception {
        Object msg = e.getMessage();
        if ( msg instanceof HttpRequest ) {
            handleHttpRequest( ctx, ( HttpRequest ) msg );
        }
        else if ( msg instanceof WebSocketFrame ) {
            handleWebSocketFrame( ctx, ( WebSocketFrame ) msg );
        }
    }


    /**
     * Handles a plain HTTP request. The outcome depends on the request:
     * <ul>
     * <li>a non-GET request gets 403;</li>
     * <li>{@code GET /} without an upgrade gets the demo page;</li>
     * <li>an upgrade request gets the WebSocket handshake, and the pipeline is switched to
     * WebSocket framing;</li>
     * <li>anything else gets 403.</li>
     * </ul>
     */
    private void handleHttpRequest( ChannelHandlerContext ctx, HttpRequest req ) throws Exception {
        // Allow only GET methods. Compare by value rather than by reference.
        if ( !GET.equals( req.getMethod() ) ) {
            sendHttpResponse( ctx, req, FORBIDDEN );
            return;
        }

        boolean isWsRequest = Values.UPGRADE.equalsIgnoreCase( req.getHeader( CONNECTION ) )
                && WEBSOCKET.equalsIgnoreCase( req.getHeader( Names.UPGRADE ) );

        // Send the demo page.
        if ( !isWsRequest && req.getUri().equals( "/" ) ) {
            HttpResponse res = new DefaultHttpResponse( HTTP_1_1, OK );
            ChannelBuffer content = WebSocketServerIndexPage.getContent( getWebSocketLocation( req ) );
            res.setHeader( CONTENT_TYPE, "text/html; charset=UTF-8" );
            setContentLength( res, content.readableBytes() );
            res.setContent( content );
            sendHttpResponse( ctx, req, res );
            return;
        }

        if ( !isWsRequest ) {
            // Send an error page otherwise.
            sendHttpResponse( ctx, req, FORBIDDEN );
            return;
        }

        // Serve the WebSocket handshake request.
        logger.info( "Starting new websocket connection..." );
        websocket = true;

        // Create the WebSocket handshake response.
        HttpResponse res =
                new DefaultHttpResponse( HTTP_1_1, new HttpResponseStatus( 101, "Web Socket Protocol Handshake" ) );
        res.addHeader( Names.UPGRADE, WEBSOCKET );
        res.addHeader( CONNECTION, Values.UPGRADE );

        QueryStringDecoder qs = new QueryStringDecoder( req.getUri() );
        String path = qs.getPath();
        logger.info( path );

        // Fill in the headers and contents depending on handshake method.
        if ( req.containsHeader( SEC_WEBSOCKET_KEY1 ) && req.containsHeader( SEC_WEBSOCKET_KEY2 ) ) {
            // NOTE(review): the 3-segment path check only applies to this challenge-based
            // handshake. The legacy branch below accepts any path. Confirm this is intended.
            String[] segments = split( path, '/' );
            if ( segments.length != 3 ) {
                logger.info( "Wrong number of path segments, expected 3, found " + segments.length );
                sendHttpResponse( ctx, req, FORBIDDEN );
                return;
            }
            String nsStr = segments[0];
            String collStr = segments[1];
            String idStr = segments[2];
            logger.info( nsStr + "/" + collStr + "/" + idStr );
            if ( isEmpty( nsStr ) || isEmpty( collStr ) || isEmpty( idStr ) ) {
                sendHttpResponse( ctx, req, FORBIDDEN );
                return;
            }

            // New handshake method with a challenge:
            res.addHeader( SEC_WEBSOCKET_ORIGIN, req.getHeader( ORIGIN ) );
            res.addHeader( SEC_WEBSOCKET_LOCATION, getWebSocketLocation( req ) );
            String protocol = req.getHeader( SEC_WEBSOCKET_PROTOCOL );
            if ( protocol != null ) {
                res.addHeader( SEC_WEBSOCKET_PROTOCOL, protocol );
            }

            // Calculate the answer of the challenge. Malformed keys or a short body are rejected
            // with 403. Previously they escaped as Arithmetic/NumberFormat/IndexOutOfBounds
            // exceptions.
            ChannelBuffer answer = computeChallengeResponse( req.getHeader( SEC_WEBSOCKET_KEY1 ),
                    req.getHeader( SEC_WEBSOCKET_KEY2 ), req.getContent() );
            if ( answer == null ) {
                logger.info( "Malformed websocket handshake challenge" );
                sendHttpResponse( ctx, req, FORBIDDEN );
                return;
            }
            res.setContent( answer );
        }
        else {
            // Old handshake method with no challenge:
            res.addHeader( WEBSOCKET_ORIGIN, req.getHeader( ORIGIN ) );
            res.addHeader( WEBSOCKET_LOCATION, getWebSocketLocation( req ) );
            String protocol = req.getHeader( WEBSOCKET_PROTOCOL );
            if ( protocol != null ) {
                res.addHeader( WEBSOCKET_PROTOCOL, protocol );
            }
        }

        // Upgrade the connection and send the handshake response. The HTTP encoder is swapped
        // out only after the response is written, so the handshake itself is still HTTP-encoded.
        ChannelPipeline p = ctx.getChannel().getPipeline();
        p.remove( "aggregator" );
        p.replace( "decoder", "wsdecoder", new WebSocketFrameDecoder() );

        ctx.getChannel().write( res );

        p.replace( "encoder", "wsencoder", new WebSocketFrameEncoder() );
    }


    /**
     * Computes the 16-byte MD5 challenge answer for a Sec-WebSocket-Key1/Key2 handshake.
     * <p>
     * The answer is MD5 of: key1 number (4 bytes), key2 number (4 bytes), and the 8-byte request
     * body. Each key number is the key's digits divided by the key's number of spaces.
     *
     * @param key1 value of the Sec-WebSocket-Key1 header
     * @param key2 value of the Sec-WebSocket-Key2 header
     * @param body request body; 8 bytes are consumed from it on success
     * @return the answer buffer, or null if either key or the body is malformed
     */
    private static ChannelBuffer computeChallengeResponse( String key1, String key2, ChannelBuffer body )
            throws java.security.NoSuchAlgorithmException {
        long a = decodeKey( key1 );
        long b = decodeKey( key2 );
        if ( a < 0 || b < 0 || body == null || body.readableBytes() < CHALLENGE_BODY_LENGTH ) {
            return null;
        }
        long c = body.readLong();

        ChannelBuffer input = ChannelBuffers.buffer( 16 );
        // Only the low 32 bits of each key number are written, as before.
        input.writeInt( ( int ) a );
        input.writeInt( ( int ) b );
        input.writeLong( c );
        return ChannelBuffers.wrappedBuffer( MessageDigest.getInstance( "MD5" ).digest( input.array() ) );
    }


    /**
     * Decodes one handshake key: the concatenated digits divided by the number of spaces.
     *
     * @return the non-negative key number, or -1 in any of these cases: the key is null, it has
     * no digits or no spaces, or its digits overflow a long
     */
    private static long decodeKey( String key ) {
        if ( key == null ) {
            return -1;
        }
        String digits = key.replaceAll( "[^0-9]", "" );
        int spaces = key.replaceAll( "[^ ]", "" ).length();
        if ( digits.length() == 0 || spaces == 0 ) {
            return -1;
        }
        try {
            return Long.parseLong( digits ) / spaces;
        }
        catch ( NumberFormatException tooManyDigits ) {
            return -1;
        }
    }


    /**
     * Echoes the frame's text back in upper case. Locale.ROOT makes the result independent of
     * the server's default locale (e.g. the Turkish dotted/dotless i).
     */
    private void handleWebSocketFrame( ChannelHandlerContext ctx, WebSocketFrame frame ) {
        ctx.getChannel().write( new DefaultWebSocketFrame( frame.getTextData().toUpperCase( java.util.Locale.ROOT ) ) );
    }


    // Note: subscriptions are added and removed relatively infrequently
    // during the lifecycle of a connection i.e. typical minimum lifespan
    // would be 10 seconds when someone opens an app, it connects, then
    // they close the app.

    /**
     * Returns the group registered for {@code path}, registering a new empty group if none
     * exists yet. Never returns null.
     */
    private ChannelGroup getChannelGroupWithDefault( String path ) {
        ChannelGroup group = subscribers.get( path );
        if ( group == null ) {
            ChannelGroup created = new DefaultChannelGroup();
            // putIfAbsent returns the PREVIOUS mapping, which is null when our group was stored.
            ChannelGroup existing = subscribers.putIfAbsent( path, created );
            group = ( existing != null ) ? existing : created;
        }
        return group;
    }


    /**
     * Adds {@code channel} to the group for {@code path}, creating the group if necessary.
     */
    public void addSubscription( String path, Channel channel ) {
        while ( true ) {
            ChannelGroup group = getChannelGroupWithDefault( path );
            synchronized ( group ) {
                // removeSubscription may have dropped this group from the map (after it became
                // empty) between the lookup and acquiring the lock. Adding to it then would
                // orphan the channel, so only add while the group is still the registered one.
                // Otherwise retry.
                if ( subscribers.get( path ) == group ) {
                    group.add( channel );
                    return;
                }
            }
        }
    }


    /**
     * Removes {@code channel} from the group for {@code path}, and unregisters the group once
     * it is empty. Does nothing if no group is registered for the path.
     */
    public void removeSubscription( String path, Channel channel ) {
        ChannelGroup group = subscribers.get( path );
        if ( group == null ) {
            return;
        }
        synchronized ( group ) {
            group.remove( channel );
            if ( group.isEmpty() ) {
                subscribers.remove( path, group );
            }
        }
    }


    /** Returns the group registered for {@code path}, or null if there is none. */
    public ChannelGroup getSubscriptionGroup( String path ) {
        return subscribers.get( path );
    }
}