package org.ifsoft.websockets;
import org.jivesoftware.util.JiveGlobals;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.security.*;
import java.util.*;
import java.text.*;
import java.net.*;
import java.security.cert.Certificate;
import java.util.concurrent.ConcurrentHashMap;
import java.math.BigInteger;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import org.eclipse.jetty.websocket.api.annotations.*;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.jivesoftware.util.ParamUtils;
import org.jivesoftware.openfire.SessionManager;
import org.jivesoftware.openfire.StreamID;
import org.jivesoftware.openfire.session.ClientSession;
import org.jivesoftware.openfire.session.LocalClientSession;
import org.jivesoftware.openfire.net.VirtualConnection;
import org.jivesoftware.openfire.auth.UnauthorizedException;
import org.jivesoftware.openfire.auth.AuthToken;
import org.jivesoftware.openfire.auth.AuthFactory;
import org.jivesoftware.openfire.user.User;
import org.jivesoftware.openfire.user.UserAlreadyExistsException;
import org.jivesoftware.openfire.user.UserManager;
import org.jivesoftware.openfire.user.UserNotFoundException;
import org.jivesoftware.openfire.SessionPacketRouter;
import org.jivesoftware.openfire.XMPPServer;
import org.jivesoftware.database.SequenceManager;
import org.jivesoftware.openfire.component.InternalComponentManager;
import org.jivesoftware.openfire.plugin.ofmeet.OfMeetPlugin;
import org.jivesoftware.openfire.plugin.ofmeet.OpenfireLoginService;
import org.xmpp.packet.*;
import org.dom4j.*;
/**
 * Jetty WebSocket servlet that accepts upgrade requests offering the {@code xmpp} subprotocol and
 * binds each accepted socket to a new Openfire client session.
 *
 * <p>Accepted sockets are tracked in the ofmeet plugin's socket map, keyed by an MD5 digest of
 * user name, password and resource.
 */
public final class XMPPServlet extends WebSocketServlet
{
    private static final Logger Log = LoggerFactory.getLogger( "XMPPServlet" );

    /** Sockets keyed by connection digest; the map instance is obtained from {@link OfMeetPlugin}. */
    private final ConcurrentHashMap<String, XMPPServlet.XMPPWebSocket> sockets;
    private final OfMeetPlugin plugin;

    /** Looks up the ofmeet plugin and shares its socket map. */
    public XMPPServlet()
    {
        plugin = (OfMeetPlugin) XMPPServer.getInstance().getPluginManager().getPlugin("ofmeet");
        sockets = plugin.getSockets();
    }

    /**
     * Raises the maximum text message size to 64000000 and installs the creator that
     * authenticates each upgrade request.
     */
    @Override public void configure(WebSocketServletFactory factory)
    {
        factory.getPolicy().setMaxTextMessageSize(64000000);
        factory.setCreator(new WSocketCreator());
    }

    /**
     * Creates an {@link XMPPWebSocket} for upgrade requests that offer the {@code xmpp}
     * subprotocol, authenticating the user from the request parameters first.
     */
    public class WSocketCreator implements WebSocketCreator
    {
        /**
         * @return the new socket, or {@code null} (declining the upgrade) when the client does not
         *         offer {@code xmpp} or the connection could not be established
         */
        @Override public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp)
        {
            for (String subprotocol : req.getSubProtocols())
            {
                if ("xmpp".equals(subprotocol))
                {
                    XMPPWebSocket socket = new XMPPWebSocket();
                    if (doWebSocketConnect(req.getHttpServletRequest(), socket))
                    {
                        resp.setAcceptedSubProtocol(subprotocol);
                        return socket;
                    }
                    return null;
                }
            }
            return null;
        }

        /**
         * Authenticates the client described by {@code request} and attaches a new client session
         * to {@code socket}.
         *
         * <p>On any failure, the socket's entry in the socket map and any client session created
         * for it are released here. A refused socket is never handed to Jetty, so its
         * {@code onClose} will never run to do that cleanup.
         *
         * @return {@code true} if the socket is ready to be returned to Jetty
         */
        private boolean doWebSocketConnect(HttpServletRequest request, XMPPWebSocket socket)
        {
            boolean connected = false;
            try {
                connected = establishSession(request, socket);
            } catch ( Exception e ) {
                Log.error( "An error occurred while attempting to create a new socket", e );
            } finally {
                if (!connected) {
                    releaseFailedSocket(socket);
                }
            }
            return connected;
        }

        /**
         * Does the work of {@link #doWebSocketConnect}.
         *
         * @return {@code false} for rejections, which have already been logged
         * @throws Exception for unexpected failures; the caller logs them and cleans up
         */
        private boolean establishSession(HttpServletRequest request, XMPPWebSocket socket) throws Exception
        {
            // NOTE(review): servlet containers already URL-decode query parameters, so decoding
            // again turns '+' into a space and reinterprets '%'. Kept as-is because clients
            // presumably double-encode; confirm before changing.
            String username = decodeParameter(request, "username");
            String password = decodeParameter(request, "password");
            String resource = decodeParameter(request, "resource");
            if (username == null || password == null || resource == null) {
                Log.warn( "Rejecting web socket: the username, password and resource parameters are all required" );
                return false;
            }
            String register = ParamUtils.getParameter(request, "register");

            // A username of the literal string "null" denotes an anonymous client; the resource
            // then stands in for the node part of the JID.
            username = JID.escapeNode( username );
            String user = username.equals("null") ? resource : username;
            String digest = getMD5(user + password + resource);
            JID userJid = XMPPServer.getInstance().createJID(user, resource);
            Log.debug( "{} : doWebSocketConnect : Digest created for {} : {}", digest, userJid, register );

            LocalClientSession existing = (LocalClientSession) SessionManager.getInstance().getSession(userJid);
            boolean isExistingSession = existing != null;
            if (isExistingSession && !kickConflictingSession(existing)) {
                return false;
            }

            WSConnection wsConnection = new WSConnection( resolveRemoteAddress(request), request.getRemoteHost() );
            socket.setWSConnection(digest, wsConnection);

            AuthToken authToken = resolveAuthToken(username, password, resource, register, isExistingSession);
            if (authToken == null) {
                return false;
            }

            LocalClientSession session = SessionManager.getInstance().createClientSession( wsConnection, (Locale) null );
            // Attach the session to the socket straight away. If one of the remaining steps fails,
            // doWebSocketConnect can then remove the session instead of leaking it.
            socket.setSession( session );
            wsConnection.setRouter( new SessionPacketRouter( session ) );
            session.setAuthToken(authToken, resource);

            Log.debug( "Created new socket for digest {}", digest );
            Log.debug( "Total websockets created : {}", sockets.size() );
            return true;
        }

        /**
         * Applies the server's conflict policy to a session already bound to the requested JID.
         *
         * @return {@code true} if the existing session was closed and removed so the new connection
         *         may proceed; {@code false} if the new connection must be refused
         */
        private boolean kickConflictingSession(LocalClientSession existing)
        {
            SessionManager sessionManager = SessionManager.getInstance();
            int conflictLimit = sessionManager.getConflictKickLimit();
            if (conflictLimit == SessionManager.NEVER_KICK) {
                return false;
            }
            if (existing.incrementConflictCount() <= conflictLimit) {
                return false;
            }
            existing.close();
            sessionManager.removeSession(existing);
            return true;
        }

        /**
         * Returns the client address. The value of the header named by the
         * {@code websockets.header.remoteaddr} property is preferred when both are present.
         */
        private String resolveRemoteAddress(HttpServletRequest request)
        {
            String headerName = JiveGlobals.getProperty("websockets.header.remoteaddr");
            String headerValue = headerName == null ? null : request.getHeader(headerName);
            return headerValue != null ? headerValue : request.getRemoteAddr();
        }

        /**
         * Chooses the auth token for the new session: a token cached by OpenfireLoginService, an
         * anonymous token, a reuse token for a kicked session, or a freshly authenticated one.
         * Optionally registers the user first via in-band registration.
         *
         * @param username escaped user name, or the literal {@code "null"} for anonymous clients
         * @return the token, or {@code null} when authentication failed (already logged)
         */
        private AuthToken resolveAuthToken(String username, String password, String resource,
                                           String register, boolean isExistingSession)
        {
            boolean namedUser = !username.equals("null");
            if (namedUser && OpenfireLoginService.authTokens.containsKey(username)) {
                return OpenfireLoginService.authTokens.get(username);
            }
            if (!namedUser && password.equals("null")) {
                return new AuthToken(resource, true); // anonymous user
            }
            if (isExistingSession && (password.equals("dummy") || password.equals("reuse"))) {
                // NOTE(review): no password is checked here. When the conflict policy allows
                // kicking, any client that knows a connected user's name and resource can replace
                // that user's session. Presumably meant for client reconnects; consider requiring a
                // server-issued secret instead.
                return new AuthToken(username);
            }
            try {
                String userName = JID.unescapeNode(username);
                UserManager userManager = XMPPServer.getInstance().getUserManager();
                if ("true".equals(register) && XMPPServer.getInstance().getIQRegisterHandler().isInbandRegEnabled()) {
                    // Registration requested: create the user unless it already exists.
                    try {
                        userManager.getUser(userName);
                    } catch (UserNotFoundException notFound) {
                        userManager.createUser(userName, password, null, null);
                    }
                } else {
                    try {
                        userManager.getUser(userName);
                    } catch (UserNotFoundException notFound) {
                        Log.error( "user not found " + userName, notFound );
                        return null;
                    }
                }
                return AuthFactory.authenticate( userName, password );
            } catch ( UnauthorizedException e ) {
                Log.error( "An error occurred while attempting to create a web socket (USERNAME: " + username + " RESOURCE: " + resource + " ) : ", e );
                return null;
            } catch ( Exception e ) {
                Log.error( "An error occurred while attempting to create a web socket : ", e );
                return null;
            }
        }

        /** Undoes whatever {@link #establishSession} registered before it failed. */
        private void releaseFailedSocket(XMPPWebSocket socket)
        {
            try {
                String digest = socket.getDigest();
                if (digest != null) {
                    // Only drop the entry if it still points at this socket.
                    sockets.remove(digest, socket);
                }
                LocalClientSession session = socket.getSession();
                if (session != null) {
                    SessionManager.getInstance().removeSession(session);
                    socket.setSession(null);
                }
            } catch ( Exception e ) {
                Log.error( "An error occurred while releasing a rejected web socket", e );
            }
        }

        /**
         * Returns the URL-decoded request parameter, or {@code null} when ParamUtils reports it as
         * absent.
         */
        private String decodeParameter(HttpServletRequest request, String name) throws IOException
        {
            String value = ParamUtils.getParameter(request, name);
            return value == null ? null : URLDecoder.decode(value, "UTF-8");
        }

        /**
         * Returns the MD5 of {@code input} as 32 zero-padded lower-case hex digits.
         *
         * <p>NOTE(review): this uses the platform default charset. Switching to UTF-8 would change
         * digests for non-ASCII input on non-UTF-8 platforms, so confirm nothing else recomputes
         * them first. The input includes the password and the digest is written to debug logs.
         */
        private String getMD5(String input)
        {
            try {
                MessageDigest md = MessageDigest.getInstance("MD5");
                byte[] messageDigest = md.digest(input.getBytes());
                return String.format("%032x", new BigInteger(1, messageDigest));
            } catch (NoSuchAlgorithmException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * One client's WebSocket. It is paired with a {@link WSConnection} and, once authenticated,
     * with an Openfire client session.
     */
    @WebSocket public class XMPPWebSocket
    {
        private Session wsSession;
        private WSConnection wsConnection;
        private String digest;
        private LocalClientSession xmppSession;

        /** Links this socket to {@code wsConnection} and registers it in the socket map under {@code digest}. */
        public void setWSConnection(String digest, WSConnection wsConnection) {
            this.digest = digest;
            this.wsConnection = wsConnection;
            wsConnection.setSocket(this);
            sockets.put(digest, this);
            Log.debug("{} : setWSConnection", digest);
        }

        public String getDigest() {
            return digest;
        }

        public void setSession( LocalClientSession xmppSession ) {
            this.xmppSession = xmppSession;
        }

        public LocalClientSession getSession() {
            return xmppSession;
        }

        /** @return {@code true} once Jetty has connected this socket and the connection is still open */
        public boolean isOpen() {
            return wsSession != null && wsSession.isOpen();
        }

        @OnWebSocketConnect public void onConnect(Session wsSession)
        {
            this.wsSession = wsSession;
            wsConnection.setSecure(wsSession.isSecure());
            Log.debug("{} : onConnect", digest);
        }

        @OnWebSocketClose public void onClose(int statusCode, String reason)
        {
            try {
                if (digest != null) {
                    // Conditional removal: a newer socket registered under the same digest
                    // must not be evicted when this one closes late.
                    sockets.remove(digest, this);
                }
                if (xmppSession != null) xmppSession.close();
                xmppSession = null;
            } catch ( Exception e ) {
                Log.error( "An error occurred while attempting to remove the socket and xmppSession", e );
            }
            Log.debug( "{} : onClose : {} {}", digest, statusCode, reason );
        }

        @OnWebSocketError public void onError(Throwable error)
        {
            Log.error("XMPPWebSocket onError", error);
        }

        /** Parses a non-blank text frame as an XML stanza and hands it to the session's packet router. */
        @OnWebSocketMessage public void onTextMethod(String data)
        {
            if ("".equals(data.trim())) {
                return;
            }
            // XMPP (RFC 6120, section 11.1) forbids DTDs. Rejecting them before parsing stops a client
            // from declaring external entities (XXE) or entity-expansion bombs. Depending on the
            // dom4j version, DocumentHelper.parseText's default parser settings may otherwise
            // resolve them.
            if (data.contains("<!DOCTYPE")) {
                Log.warn( "{} : onMessage : Dropped stanza containing a DOCTYPE declaration", digest );
                return;
            }
            try {
                Log.debug( "{} : onMessage : Received : {}", digest, data );
                wsConnection.getRouter().route(DocumentHelper.parseText(data).getRootElement());
            } catch ( Exception e ) {
                Log.error( "An error occurred while attempting to route the packet : ", e );
            }
        }

        @OnWebSocketMessage public void onBinaryMethod(byte data[], int offset, int length)
        {
            // Binary frames are not part of the XMPP-over-WebSocket framing; ignored.
        }

        /** Sends {@code packet} asynchronously if the socket is open and the packet is not blank. */
        public void deliver(String packet)
        {
            if (wsSession == null || !wsSession.isOpen() || "".equals(packet.trim())) {
                return;
            }
            try {
                Log.debug( "{} : Delivered : \n{}", digest, packet );
                wsSession.getRemote().sendStringByFuture(packet);
            } catch (Exception e) {
                Log.error("XMPPWebSocket deliver", e);
                Log.warn( "{} : Could not deliver : \n{}", digest, packet );
            }
        }

        /** Closes the WebSocket, then unregisters this socket and removes its client session. */
        public void disconnect()
        {
            Log.debug( "{} : disconnect : XMPPWebSocket disconnect", digest );
            Log.debug( "Total websockets created : {}", sockets.size() );
            if (wsSession != null) {
                try {
                    if (wsSession.isOpen()) {
                        wsSession.close();
                    }
                } catch ( Exception e ) {
                    // The graceful close failed; fall back to a hard disconnect (best effort).
                    try {
                        wsSession.disconnect();
                    } catch ( Exception ignored ) {
                        // Nothing more can be done for this connection.
                    }
                }
            }
            try {
                if (digest != null) {
                    sockets.remove( digest, this );
                }
                if (xmppSession != null) {
                    SessionManager.getInstance().removeSession( xmppSession );
                }
            } catch ( Exception e ) {
                Log.error( "An error has occurred", e );
            }
            xmppSession = null;
        }
    }
}