/**
* Copyright (c) 2015, WSO2 Inc. (http://www.wso2.org) All Rights Reserved.
*
* WSO2 Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.wso2.carbon.websocket.transport;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory;
import io.netty.handler.codec.http.websocketx.WebSocketFrameAggregator;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import org.apache.axiom.om.OMElement;
import org.apache.axis2.description.Parameter;
import org.apache.axis2.description.TransportOutDescription;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.wso2.carbon.websocket.transport.utils.SSLUtil;
import javax.net.ssl.SSLException;
import javax.xml.namespace.QName;
import java.net.URI;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Creates outbound WebSocket client connections and pools their {@link WebSocketClientHandler}s.
 *
 * <p>Handlers are pooled per source identifier and, within a source, per client identifier
 * (see {@link #getClientHandlerIdentifier(URI)}). A pooled handler is removed automatically when
 * its channel closes.
 */
public class WebsocketConnectionFactory {

    private static final Log log = LogFactory.getLog(WebsocketConnectionFactory.class);

    /** Host used when the endpoint URI does not specify one. */
    private static final String DEFAULT_HOST = "127.0.0.1";

    /** Port used for {@code ws} endpoints whose URI omits an explicit port. */
    private static final int DEFAULT_WS_PORT = 80;

    /** Port used for {@code wss} endpoints whose URI omits an explicit port. */
    private static final int DEFAULT_WSS_PORT = 443;

    private static WebsocketConnectionFactory instance = null;

    private final TransportOutDescription transportOut;

    // sourceIdentifier -> (clientHandlerIdentifier -> handler)
    private final ConcurrentHashMap<String, ConcurrentHashMap<String, WebSocketClientHandler>>
            channelHandlerPool = new ConcurrentHashMap<String, ConcurrentHashMap<String, WebSocketClientHandler>>();

    /**
     * Creates a factory bound to the given transport sender configuration.
     *
     * @param transportOut transport-out description whose parameters (trust store, dispatch
     *                     sequences) are read when new connections are created
     */
    public WebsocketConnectionFactory(TransportOutDescription transportOut) {
        this.transportOut = transportOut;
    }

    /**
     * Returns the shared factory, creating it on the first call.
     *
     * <p>Synchronized so that concurrent first callers cannot create more than one instance.
     * The {@code transportOut} argument is only used by the call that creates the instance;
     * later calls ignore it.
     *
     * @param transportOut transport-out description used if the instance has to be created
     * @return the shared factory instance
     */
    public static synchronized WebsocketConnectionFactory getInstance(TransportOutDescription transportOut) {
        if (instance == null) {
            instance = new WebsocketConnectionFactory(transportOut);
        }
        return instance;
    }

    /**
     * Returns a handler for the given endpoint and waits for its WebSocket handshake to complete.
     *
     * <p>When {@code handshakePresent} is true a new connection is always created. Otherwise a
     * pooled handler is reused if one exists, and a new connection is created only when none does.
     *
     * @param uri                   target WebSocket endpoint
     * @param sourceIdentifier      identifier of the source the connection is pooled under
     * @param handshakePresent      whether a fresh connection must be created
     * @param dispatchSequence      dispatch sequence to set on a newly created handler
     * @param dispatchErrorSequence dispatch error sequence to set on a newly created handler
     * @param contentType           content type mapped to a subprotocol, may be {@code null}
     * @return the handler, or {@code null} if a new connection could not be created
     *         (the reason is logged by {@link #cacheNewConnection})
     * @throws InterruptedException if interrupted while waiting for the handshake
     */
    public WebSocketClientHandler getChannelHandler(final URI uri,
                                                    final String sourceIdentifier,
                                                    final boolean handshakePresent,
                                                    final String dispatchSequence,
                                                    final String dispatchErrorSequence,
                                                    final String contentType) throws InterruptedException {
        WebSocketClientHandler channelHandler = null;
        if (!handshakePresent) {
            channelHandler = getChannelHandlerFromPool(sourceIdentifier, getClientHandlerIdentifier(uri));
        }
        if (channelHandler == null) {
            channelHandler = cacheNewConnection(uri, sourceIdentifier, dispatchSequence,
                    dispatchErrorSequence, contentType);
        }
        if (channelHandler == null) {
            // Creating the connection failed; returning null avoids an NPE on handshakeFuture().
            return null;
        }
        channelHandler.handshakeFuture().sync();
        return channelHandler;
    }

    /**
     * Builds the pool key for an endpoint from its host, port and path.
     *
     * <p>A ':' separates host and port so that, for example, {@code 10.0.0.1:880} and
     * {@code 10.0.0.18:80} map to different keys. The port is the raw {@link URI#getPort()}
     * value, which is -1 when the URI omits it.
     *
     * @param uri target endpoint
     * @return the client handler identifier
     */
    public String getClientHandlerIdentifier(final URI uri) {
        final String host = uri.getHost() == null ? DEFAULT_HOST : uri.getHost();
        final String subscriberPath = uri.getPath() == null ? "" : uri.getPath();
        return host + ":" + uri.getPort() + subscriberPath;
    }

    /**
     * Opens a new WebSocket connection to {@code uri} and registers its handler in the pool.
     *
     * <p>For the universal source identifier, the dispatch sequences configured on the transport
     * sender override the ones passed in. The handler is removed from the pool, and its event loop
     * group shut down, when the channel closes. If connecting fails, the event loop group is shut
     * down immediately.
     *
     * @param uri                   target endpoint ({@code ws} or {@code wss}; a missing scheme means {@code ws})
     * @param sourceIdentifier      identifier of the source the connection is pooled under
     * @param dispatchSequence      dispatch sequence to set on the handler
     * @param dispatchErrorSequence dispatch error sequence to set on the handler
     * @param contentType           content type mapped to a subprotocol, may be {@code null}
     * @return the new handler, or {@code null} if the scheme is unsupported, the SSL setup fails,
     *         or the connect is interrupted
     */
    public WebSocketClientHandler cacheNewConnection(final URI uri,
                                                     final String sourceIdentifier,
                                                     String dispatchSequence,
                                                     String dispatchErrorSequence,
                                                     String contentType) {
        if (log.isDebugEnabled()) {
            log.debug("Creating a Connection for the specified WS endpoint.");
        }
        final String scheme = uri.getScheme() == null ? WebsocketConstants.WS : uri.getScheme();
        if (!WebsocketConstants.WS.equalsIgnoreCase(scheme) && !WebsocketConstants.WSS.equalsIgnoreCase(scheme)) {
            log.warn("Unsupported scheme '" + scheme + "' for WS endpoint " + uri);
            return null;
        }
        final boolean ssl = WebsocketConstants.WSS.equalsIgnoreCase(scheme);
        final String host = uri.getHost() == null ? DEFAULT_HOST : uri.getHost();
        final int port = resolvePort(uri, ssl);

        final SslContext sslCtx;
        if (ssl) {
            try {
                sslCtx = buildClientSslContext();
            } catch (SSLException e) {
                log.error("Error occurred while building the SSL context for WSS endpoint", e);
                return null;
            }
            if (sslCtx == null) {
                // buildClientSslContext has already logged the missing configuration.
                return null;
            }
        } else {
            sslCtx = null;
        }

        if (sourceIdentifier.equals(WebsocketConstants.UNIVERSAL_SOURCE_IDENTIFIER)) {
            Parameter dispatchParam = transportOut.getParameter(WebsocketConstants.WEBSOCKET_OUTFLOW_DISPATCH_SEQUENCE);
            if (dispatchParam != null) {
                dispatchSequence = dispatchParam.getParameterElement().getText();
            }
            Parameter errorParam = transportOut.getParameter(WebsocketConstants.WEBSOCKET_OUTFLOW_DISPATCH_FAULT_SEQUENCE);
            if (errorParam != null) {
                dispatchErrorSequence = errorParam.getParameterElement().getText();
            }
        }

        final EventLoopGroup group = new NioEventLoopGroup();
        boolean connected = false;
        try {
            final WebSocketClientHandler handler = new WebSocketClientHandler(
                    WebSocketClientHandshakerFactory.newHandshaker(uri,
                            WebSocketVersion.V13,
                            contentType != null
                                    ? SubprotocolBuilderUtil.contentTypeToSyanapeSubprotocol(contentType)
                                    : null,
                            false,
                            new DefaultHttpHeaders()));
            // Configure the handler fully before its channel becomes active.
            handler.setDispatchSequence(dispatchSequence);
            handler.setDispatchErrorSequence(dispatchErrorSequence);

            Bootstrap b = new Bootstrap();
            b.group(group).channel(NioSocketChannel.class)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            ChannelPipeline p = ch.pipeline();
                            if (sslCtx != null) {
                                p.addLast(sslCtx.newHandler(ch.alloc(), host, port));
                            }
                            p.addLast(new HttpClientCodec(), new HttpObjectAggregator(8192),
                                    new WebSocketFrameAggregator(Integer.MAX_VALUE), handler);
                        }
                    });
            // Connect to the resolved host/port: uri.getHost() may be null and uri.getPort() may be -1.
            Channel channel = b.connect(host, port).sync().channel();
            connected = true;

            final String clientIdentifier = getClientHandlerIdentifier(uri);
            // Pool the handler before registering the close listener: if the channel has already
            // closed, Netty runs the listener at registration and it finds the entry to remove.
            addChannelHandler(sourceIdentifier, clientIdentifier, handler);
            channel.closeFuture().addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture channelFuture) throws Exception {
                    group.shutdownGracefully();
                    // Remove only this handler, so a newer connection under the same key is kept.
                    removeChannelHandlerIfCurrent(sourceIdentifier, clientIdentifier, handler);
                }
            });
            return handler;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Interruption occurred while connecting to the remote WS endpoint", e);
            return null;
        } finally {
            if (!connected) {
                // No channel was established, so the close listener will never release the group.
                group.shutdownGracefully();
            }
        }
    }

    /**
     * Resolves the port to connect to, using the scheme default when the URI has none.
     *
     * @param uri target endpoint
     * @param ssl whether the endpoint uses {@code wss}
     * @return the explicit URI port, or 443/80 for wss/ws when the URI port is -1
     */
    private static int resolvePort(URI uri, boolean ssl) {
        if (uri.getPort() != -1) {
            return uri.getPort();
        }
        return ssl ? DEFAULT_WSS_PORT : DEFAULT_WS_PORT;
    }

    /**
     * Builds a client SSL context that trusts the trust store configured on the transport sender.
     *
     * @return the SSL context, or {@code null} (after logging an error) if the trust store
     *         parameter, its location or its password is missing
     * @throws SSLException if the SSL context cannot be built
     */
    private SslContext buildClientSslContext() throws SSLException {
        Parameter trustParam = transportOut.getParameter(WebsocketConstants.TRUST_STORE_CONFIG_ELEMENT);
        OMElement tsEle = trustParam == null ? null : trustParam.getParameterElement().getFirstElement();
        if (tsEle == null) {
            log.error("Trust store configuration '" + WebsocketConstants.TRUST_STORE_CONFIG_ELEMENT
                    + "' is missing; cannot connect to WSS endpoint");
            return null;
        }
        OMElement locationEle = tsEle.getFirstChildWithName(new QName(WebsocketConstants.TRUST_STORE_LOCATION));
        OMElement passwordEle = tsEle.getFirstChildWithName(new QName(WebsocketConstants.TRUST_STORE_PASSWORD));
        if (locationEle == null || passwordEle == null) {
            log.error("Trust store location or password is not configured; cannot connect to WSS endpoint");
            return null;
        }
        return SslContextBuilder.forClient()
                .trustManager(SSLUtil.createTrustmanager(locationEle.getText(), passwordEle.getText()))
                .build();
    }

    /**
     * Registers a handler under the given source and client identifiers, replacing any existing one.
     *
     * <p>The per-source map is created atomically with {@code putIfAbsent}, so concurrent first adds
     * for the same source do not lose each other's entries.
     *
     * @param sourceIdentifier source the handler belongs to
     * @param clientIdentifier client handler identifier (see {@link #getClientHandlerIdentifier(URI)})
     * @param clientHandler    handler to store
     */
    public void addChannelHandler(String sourceIdentifier,
                                  String clientIdentifier,
                                  WebSocketClientHandler clientHandler) {
        ConcurrentHashMap<String, WebSocketClientHandler> handlerMap = channelHandlerPool.get(sourceIdentifier);
        if (handlerMap == null) {
            ConcurrentHashMap<String, WebSocketClientHandler> newMap =
                    new ConcurrentHashMap<String, WebSocketClientHandler>();
            handlerMap = channelHandlerPool.putIfAbsent(sourceIdentifier, newMap);
            if (handlerMap == null) {
                handlerMap = newMap;
            }
        }
        handlerMap.put(clientIdentifier, clientHandler);
    }

    /**
     * Looks up a pooled handler.
     *
     * @param sourceIdentifier source the handler belongs to
     * @param clientIdentifier client handler identifier
     * @return the pooled handler, or {@code null} if none is registered
     */
    public WebSocketClientHandler getChannelHandlerFromPool(String sourceIdentifier,
                                                            String clientIdentifier) {
        ConcurrentHashMap<String, WebSocketClientHandler> handlerMap = channelHandlerPool.get(sourceIdentifier);
        return handlerMap == null ? null : handlerMap.get(clientIdentifier);
    }

    /**
     * Removes the handler registered under the given identifiers, if any.
     *
     * @param sourceIdentifier source the handler belongs to
     * @param clientIdentifier client handler identifier
     */
    public void removeChannelHandler(String sourceIdentifier,
                                     String clientIdentifier) {
        ConcurrentHashMap<String, WebSocketClientHandler> handlerMap = channelHandlerPool.get(sourceIdentifier);
        if (handlerMap != null) {
            handlerMap.remove(clientIdentifier);
        }
    }

    /**
     * Removes the pool entry only if it still maps to {@code clientHandler}.
     *
     * @param sourceIdentifier source the handler belongs to
     * @param clientIdentifier client handler identifier
     * @param clientHandler    handler expected under that key
     */
    private void removeChannelHandlerIfCurrent(String sourceIdentifier,
                                               String clientIdentifier,
                                               WebSocketClientHandler clientHandler) {
        ConcurrentHashMap<String, WebSocketClientHandler> handlerMap = channelHandlerPool.get(sourceIdentifier);
        if (handlerMap != null) {
            handlerMap.remove(clientIdentifier, clientHandler);
        }
    }
}