/**
* Copyright (c) 2015, WSO2 Inc. (http://www.wso2.org) All Rights Reserved.
*
* WSO2 Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.wso2.carbon.websocket.transport;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;

import org.apache.axiom.om.OMAbstractFactory;
import org.apache.axiom.om.OMElement;
import org.apache.axiom.soap.SOAPEnvelope;
import org.apache.axiom.soap.SOAPFactory;
import org.apache.axiom.util.UIDGenerator;
import org.apache.axis2.AxisFault;
import org.apache.axis2.builder.Builder;
import org.apache.axis2.builder.BuilderUtil;
import org.apache.axis2.builder.SOAPBuilder;
import org.apache.axis2.context.ConfigurationContext;
import org.apache.axis2.context.OperationContext;
import org.apache.axis2.context.ServiceContext;
import org.apache.axis2.description.InOutAxisOperation;
import org.apache.axis2.transport.TransportUtils;
import org.apache.commons.io.input.AutoCloseInputStream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.synapse.SynapseConstants;
import org.apache.synapse.core.axis2.MessageContextCreatorForAxis2;
import org.apache.synapse.inbound.InboundEndpointConstants;
import org.apache.synapse.inbound.InboundResponseSender;
import org.apache.synapse.mediators.MediatorFaultHandler;
import org.apache.synapse.mediators.base.SequenceMediator;

import org.wso2.carbon.core.multitenancy.utils.TenantAxisUtils;
import org.wso2.carbon.utils.multitenancy.MultitenantConstants;
import org.wso2.carbon.websocket.transport.service.ServiceReferenceHolder;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
/**
 * Netty inbound handler for the client (target) side of a WebSocket connection.
 *
 * <p>It drives the WebSocket opening handshake through the supplied {@link WebSocketClientHandshaker}.
 * It then turns frames received from the remote endpoint into Synapse message contexts and injects
 * them into the configured dispatch sequence, or the main sequence when none is configured.
 */
public class WebSocketClientHandler extends SimpleChannelInboundHandler<Object> {

    private static final Log log = LogFactory.getLog(WebSocketClientHandler.class);

    private final WebSocketClientHandshaker handshaker;
    /** Completed by {@link #handleHandshake}; failed by {@link #exceptionCaught}. Created in {@link #handlerAdded}. */
    private ChannelPromise handshakeFuture;
    private String dispatchSequence;
    private String dispatchErrorSequence;
    /** Context captured in {@link #channelActive}; used by methods that are not given a context. */
    private ChannelHandlerContext ctx;
    private InboundResponseSender responseSender;
    private String tenantDomain;

    /**
     * Creates a handler bound to the given handshaker.
     *
     * @param handshaker handshaker used to open and close the WebSocket connection
     */
    public WebSocketClientHandler(WebSocketClientHandshaker handshaker) {
        this.handshaker = handshaker;
    }

    /**
     * Sets the tenant domain used when creating message contexts.
     *
     * @param tenantDomain tenant domain name; must be set before any frame is processed
     */
    public void setTenantDomain(String tenantDomain) {
        this.tenantDomain = tenantDomain;
    }

    /**
     * Sets the name of the sequence that received frames are injected into.
     *
     * @param dispatchSequence sequence name; if {@code null} or unresolved, the main sequence is used
     */
    public void setDispatchSequence(String dispatchSequence) {
        this.dispatchSequence = dispatchSequence;
    }

    /**
     * Sets the name of the fault sequence pushed as the fault handler for injected messages.
     *
     * @param dispatchErrorSequence sequence name; if {@code null} or unresolved, the default fault
     *                              sequence is used
     */
    public void setDispatchErrorSequence(String dispatchErrorSequence) {
        this.dispatchErrorSequence = dispatchErrorSequence;
    }

    /**
     * Returns the future tracking the WebSocket handshake.
     *
     * @return the handshake promise, or {@code null} if the handler has not been added to a pipeline yet
     */
    public ChannelFuture handshakeFuture() {
        return handshakeFuture;
    }

    /**
     * Returns the channel context captured when the channel became active.
     *
     * @return the captured context, or {@code null} before {@link #channelActive} has run
     */
    public ChannelHandlerContext getChannelHandlerContext() {
        return this.ctx;
    }

    /**
     * Registers the inbound response sender. When set, injected messages are flagged as inbound
     * and carry this sender as the response worker.
     *
     * @param responseSender response sender for inbound-endpoint responses
     */
    public void registerWebsocketResponseSender(InboundResponseSender responseSender) {
        this.responseSender = responseSender;
    }

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        handshakeFuture = ctx.newPromise();
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) {
        this.ctx = ctx;
        this.handshaker.handshake(ctx.channel());
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) {
        if (log.isDebugEnabled()) {
            log.debug("WebSocket client disconnected on context id : " + ctx.channel().toString());
        }
    }

    /**
     * Completes the client handshake with the server's HTTP response, if it is not complete yet,
     * and marks {@link #handshakeFuture()} as successful.
     *
     * @param ctx channel context the response arrived on
     * @param msg the server's handshake response
     */
    public void handleHandshake(ChannelHandlerContext ctx, FullHttpResponse msg) {
        if (handshaker.isHandshakeComplete()) {
            return;
        }
        handshaker.finishHandshake(ctx.channel(), msg);
        if (log.isDebugEnabled()) {
            log.debug("WebSocket client connected to remote WS endpoint on context id : " + ctx.channel().toString());
        }
        handshakeFuture.setSuccess();
    }

    /**
     * Injects a "handshake present" message, carrying the captured channel context, into the dispatch
     * sequence. This happens only if the handshake is complete and a response sender is registered.
     * Any failure is logged, not propagated.
     */
    public void acknowledgeHandshake() {
        try {
            if (handshaker.isHandshakeComplete() && responseSender != null) {
                org.apache.synapse.MessageContext synCtx = getSynapseMessageContext(tenantDomain);
                synCtx.setProperty(WebsocketConstants.WEBSOCKET_TARGET_HANDSHAKE_PRESENT, true);
                synCtx.setProperty(WebsocketConstants.WEBSOCKET_TARGET_HANDLER_CONTEXT, ctx);
                injectToSequence(synCtx, dispatchSequence, dispatchErrorSequence);
            }
        } catch (Exception e) {
            log.error("Exception occured while injecting websocket frames to the Synapse engine", e);
        }
    }

    /**
     * Echoes a close frame to the remote endpoint and closes the channel once the write completes.
     *
     * @param frame the received frame; must be a {@link CloseWebSocketFrame}
     * @throws AxisFault declared for API compatibility
     */
    public void handleTargetWebsocketChannelTermination(WebSocketFrame frame) throws AxisFault {
        // retain(): SimpleChannelInboundHandler releases the inbound message after channelRead0 returns,
        // but the frame is written back out here.
        handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain()).addListener(ChannelFutureListener.CLOSE);
    }

    /**
     * Injects a binary frame into the dispatch sequence as a message-context property.
     *
     * @param frame the received binary frame
     * @throws AxisFault if the message context cannot be created
     */
    public void handleWebsocketBinaryFrame(WebSocketFrame frame) throws AxisFault {
        org.apache.synapse.MessageContext synCtx = getSynapseMessageContext(tenantDomain);
        synCtx.setProperty(WebsocketConstants.WEBSOCKET_BINARY_FRAME_PRESENT, true);
        synCtx.setProperty(WebsocketConstants.WEBSOCKET_BINARY_FRAME, frame);
        injectToSequence(synCtx, dispatchSequence, dispatchErrorSequence);
    }

    /**
     * Injects a text frame, unparsed, into the dispatch sequence as a message-context property.
     *
     * @param frame the received text frame
     * @throws AxisFault if the message context cannot be created
     */
    public void handlePassthroughTextFrame(WebSocketFrame frame) throws AxisFault {
        org.apache.synapse.MessageContext synCtx = getSynapseMessageContext(tenantDomain);
        synCtx.setProperty(WebsocketConstants.WEBSOCKET_TEXT_FRAME_PRESENT, true);
        synCtx.setProperty(WebsocketConstants.WEBSOCKET_TEXT_FRAME, frame);
        injectToSequence(synCtx, dispatchSequence, dispatchErrorSequence);
    }

    /**
     * Routes a frame received after the handshake:
     * <ul>
     *   <li>close frames terminate the channel;</li>
     *   <li>binary frames are passed through unless a Synapse subprotocol was negotiated;</li>
     *   <li>text frames are passed through, or built into a SOAP envelope when a Synapse subprotocol
     *       was negotiated.</li>
     * </ul>
     * All other frames, including binary frames on a Synapse subprotocol, are ignored. Failures are
     * logged, not propagated.
     *
     * @param ctx   channel context the frame arrived on
     * @param frame the received frame
     * @throws AxisFault declared for API compatibility; exceptions are caught and logged internally
     */
    public void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) throws AxisFault {
        try {
            if (!handshaker.isHandshakeComplete()) {
                log.error("Handshake incomplete at target handler. Failed to inject websocket frames to Synapse engine");
                return;
            }
            if (frame instanceof CloseWebSocketFrame) {
                handleTargetWebsocketChannelTermination(frame);
            } else if (frame instanceof BinaryWebSocketFrame) {
                if (!isSynapseSubprotocol()) {
                    handleWebsocketBinaryFrame(frame);
                }
            } else if (frame instanceof TextWebSocketFrame) {
                if (isSynapseSubprotocol()) {
                    handleSynapseSubprotocolTextFrame((TextWebSocketFrame) frame);
                } else {
                    handlePassthroughTextFrame(frame);
                }
            }
        } catch (Exception e) {
            log.error("Exception occured while injecting websocket frames to the Synapse engine", e);
        }
    }

    /**
     * Returns whether the negotiated subprotocol contains the Synapse subprotocol prefix.
     *
     * @return {@code false} when no subprotocol was negotiated
     */
    private boolean isSynapseSubprotocol() {
        String subprotocol = handshaker.actualSubprotocol();
        return subprotocol != null && subprotocol.contains(WebsocketConstants.SYNAPSE_SUBPROTOCOL_PREFIX);
    }

    /**
     * Builds the text payload into a SOAP envelope, using the content type derived from the Synapse
     * subprotocol, and injects it into the dispatch sequence.
     *
     * @param frame the received text frame
     * @throws Exception if building the message context or the document fails
     */
    private void handleSynapseSubprotocolTextFrame(TextWebSocketFrame frame) throws Exception {
        org.apache.synapse.MessageContext synCtx = getSynapseMessageContext(tenantDomain);
        String message = frame.text();
        String contentType = SubprotocolBuilderUtil.syanapeSubprotocolToContentType(handshaker.actualSubprotocol());
        org.apache.axis2.context.MessageContext axis2MsgCtx =
                ((org.apache.synapse.core.axis2.Axis2MessageContext) synCtx).getAxis2MessageContext();
        Builder builder = resolveBuilder(contentType, axis2MsgCtx);
        // Encode explicitly as UTF-8. The no-arg getBytes() uses the platform default charset and
        // would corrupt non-ASCII payloads on hosts whose default is not UTF-8.
        InputStream in = new AutoCloseInputStream(
                new ByteArrayInputStream(message.getBytes(StandardCharsets.UTF_8)));
        OMElement documentElement = builder.processDocument(in, contentType, axis2MsgCtx);
        synCtx.setEnvelope(TransportUtils.createSOAPEnvelope(documentElement));
        injectToSequence(synCtx, dispatchSequence, dispatchErrorSequence);
    }

    /**
     * Selects the Axis2 message builder for a content type. Falls back to {@link SOAPBuilder} when
     * the content type is {@code null} or no builder can be resolved.
     *
     * @param contentType content type, possibly with {@code ;}-separated parameters; may be {@code null}
     * @param axis2MsgCtx Axis2 message context used for builder lookup
     * @return a non-null builder
     */
    private static Builder resolveBuilder(String contentType,
                                          org.apache.axis2.context.MessageContext axis2MsgCtx) {
        if (contentType == null) {
            if (log.isDebugEnabled()) {
                log.debug("No content type specified. Using SOAP builder.");
            }
            return new SOAPBuilder();
        }
        int index = contentType.indexOf(';');
        String type = index > 0 ? contentType.substring(0, index) : contentType;
        Builder builder = null;
        try {
            builder = BuilderUtil.getBuilderFromSelector(type, axis2MsgCtx);
        } catch (AxisFault axisFault) {
            // Keep the cause so the stack trace is not lost.
            log.error("Error while creating message builder :: " + axisFault.getMessage(), axisFault);
        }
        if (builder == null) {
            if (log.isDebugEnabled()) {
                log.debug("No message builder found for type '" + type + "'. Falling back to SOAP.");
            }
            builder = new SOAPBuilder();
        }
        return builder;
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof FullHttpResponse) {
            handleHandshake(ctx, (FullHttpResponse) msg);
        } else if (msg instanceof WebSocketFrame) {
            handleWebSocketFrame(ctx, (WebSocketFrame) msg);
        }
    }

    /**
     * Logs the error, fails the handshake future if it is still pending, and closes the channel.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        log.error("Error encountered while processing the response", cause);
        if (!handshakeFuture.isDone()) {
            handshakeFuture.setFailure(cause);
        }
        ctx.close();
    }

    /**
     * Creates a Synapse message context for this connection. If a response sender is registered,
     * the context is flagged as inbound and given that sender as its response worker. The
     * handshaker URI is recorded as the subscriber path.
     *
     * @param tenantDomain tenant domain for the underlying Axis2 context
     * @return a new Synapse message context
     * @throws AxisFault if the context cannot be created
     */
    private org.apache.synapse.MessageContext getSynapseMessageContext(String tenantDomain) throws AxisFault {
        org.apache.synapse.MessageContext synCtx = createSynapseMessageContext(tenantDomain);
        if (responseSender != null) {
            synCtx.setProperty(SynapseConstants.IS_INBOUND, true);
            synCtx.setProperty(InboundEndpointConstants.INBOUND_ENDPOINT_RESPONSE_WORKER, responseSender);
        }
        synCtx.setProperty(WebsocketConstants.WEBSOCKET_SUBSCRIBER_PATH, handshaker.uri().toString());
        return synCtx;
    }

    /**
     * Creates a Synapse message context wrapping a fresh server-side Axis2 context. The Axis2 context
     * has an in-out operation context and a default SOAP 1.1 envelope. For non-super tenants, its
     * configuration context is switched to the tenant's.
     *
     * @param tenantDomain tenant domain; NOTE(review): a {@code null} value throws
     *                     NullPointerException here, so callers must set it first
     * @return a new Synapse message context
     * @throws AxisFault if the context cannot be created
     */
    private static org.apache.synapse.MessageContext createSynapseMessageContext(String tenantDomain) throws AxisFault {
        org.apache.axis2.context.MessageContext axis2MsgCtx = createAxis2MessageContext();
        ServiceContext svcCtx = new ServiceContext();
        OperationContext opCtx = new OperationContext(new InOutAxisOperation(), svcCtx);
        axis2MsgCtx.setServiceContext(svcCtx);
        axis2MsgCtx.setOperationContext(opCtx);
        if (!tenantDomain.equals(MultitenantConstants.SUPER_TENANT_DOMAIN_NAME)) {
            ConfigurationContext tenantConfigCtx =
                    TenantAxisUtils.getTenantConfigurationContext(tenantDomain, axis2MsgCtx.getConfigurationContext());
            axis2MsgCtx.setConfigurationContext(tenantConfigCtx);
            axis2MsgCtx.setProperty(MultitenantConstants.TENANT_DOMAIN, tenantDomain);
        } else {
            axis2MsgCtx.setProperty(MultitenantConstants.TENANT_DOMAIN,
                    MultitenantConstants.SUPER_TENANT_DOMAIN_NAME);
        }
        SOAPFactory fac = OMAbstractFactory.getSOAP11Factory();
        SOAPEnvelope envelope = fac.getDefaultEnvelope();
        axis2MsgCtx.setEnvelope(envelope);
        return MessageContextCreatorForAxis2.getSynapseMessageContext(axis2MsgCtx);
    }

    /**
     * Creates a server-side Axis2 message context with a fresh URN message id, the server
     * configuration context, and non-blocking client API disabled.
     *
     * @return a new Axis2 message context
     */
    private static org.apache.axis2.context.MessageContext createAxis2MessageContext() {
        org.apache.axis2.context.MessageContext axis2MsgCtx = new org.apache.axis2.context.MessageContext();
        axis2MsgCtx.setMessageID(UIDGenerator.generateURNString());
        axis2MsgCtx.setConfigurationContext(ServiceReferenceHolder.getInstance().getConfigurationContextService()
                .getServerConfigContext());
        axis2MsgCtx.setProperty(org.apache.axis2.context.MessageContext.CLIENT_API_NON_BLOCKING, Boolean.FALSE);
        axis2MsgCtx.setServerSide(true);
        return axis2MsgCtx;
    }

    /**
     * Pushes a fault handler for the resolved fault sequence, then injects the message into the
     * named sequence, or the main sequence if the name is {@code null} or unresolved.
     *
     * @param synCtx                message to inject
     * @param dispatchSequence      target sequence name; may be {@code null}
     * @param dispatchErrorSequence fault sequence name; may be {@code null}
     */
    private void injectToSequence(org.apache.synapse.MessageContext synCtx,
                                  String dispatchSequence, String dispatchErrorSequence) {
        SequenceMediator injectingSequence = null;
        if (dispatchSequence != null) {
            injectingSequence = (SequenceMediator) synCtx.getSequence(dispatchSequence);
        }
        if (injectingSequence == null) {
            injectingSequence = (SequenceMediator) synCtx.getMainSequence();
        }
        SequenceMediator faultSequence = getFaultSequence(synCtx, dispatchErrorSequence);
        MediatorFaultHandler mediatorFaultHandler = new MediatorFaultHandler(faultSequence);
        synCtx.pushFaultHandler(mediatorFaultHandler);
        if (log.isDebugEnabled()) {
            log.debug("injecting message to sequence : " + dispatchSequence);
        }
        synCtx.getEnvironment().injectMessage(synCtx, injectingSequence);
    }

    /**
     * Resolves the named fault sequence, falling back to the message context's default fault sequence.
     *
     * @param synCtx                message context used for lookup
     * @param dispatchErrorSequence fault sequence name; may be {@code null}
     * @return the resolved fault sequence
     */
    private SequenceMediator getFaultSequence(org.apache.synapse.MessageContext synCtx,
                                              String dispatchErrorSequence) {
        SequenceMediator faultSequence = null;
        if (dispatchErrorSequence != null) {
            faultSequence = (SequenceMediator) synCtx.getSequence(dispatchErrorSequence);
        }
        if (faultSequence == null) {
            faultSequence = (SequenceMediator) synCtx.getFaultSequence();
        }
        return faultSequence;
    }
}