/*
* Copyright WSO2, Inc. (http://wso2.com)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.wso2.carbon.cloud.gateway.agent.transport;
import org.apache.axiom.om.OMAbstractFactory;
import org.apache.axiom.om.OMDocument;
import org.apache.axiom.soap.*;
import org.apache.axis2.AxisFault;
import org.apache.axis2.Constants;
import org.apache.axis2.addressing.AddressingConstants;
import org.apache.axis2.context.MessageContext;
import org.apache.axis2.engine.AxisEngine;
import org.apache.axis2.transport.TransportUtils;
import org.apache.axis2.transport.base.threads.WorkerPool;
import org.apache.axis2.transport.http.HTTPConstants;
import org.apache.axis2.transport.http.HTTPTransportUtils;
import org.apache.axis2.transport.http.util.RESTUtil;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.thrift.TException;
import org.wso2.carbon.cloud.gateway.agent.CGAgentPollingTaskFlags;
import org.wso2.carbon.cloud.gateway.agent.heartbeat.CGAgentHeartBeatTask;
import org.wso2.carbon.cloud.gateway.agent.heartbeat.CGAgentHeartBeatTaskList;
import org.wso2.carbon.cloud.gateway.agent.observer.CGAgentObserver;
import org.wso2.carbon.cloud.gateway.agent.observer.CGAgentObserverImpl;
import org.wso2.carbon.cloud.gateway.agent.observer.CGAgentSubject;
import org.wso2.carbon.cloud.gateway.common.CGConstant;
import org.wso2.carbon.cloud.gateway.common.CGUtils;
import org.wso2.carbon.cloud.gateway.common.thrift.CGThriftClient;
import org.wso2.carbon.cloud.gateway.common.thrift.gen.Message;
import org.wso2.carbon.cloud.gateway.common.thrift.gen.NotAuthorizedException;
import org.wso2.carbon.context.CarbonContext;
import javax.xml.stream.XMLStreamException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* The polling task implementation for transport receiver, there will be a task per deployed
* service
*/
public class CGPollingTransportTaskManager {
private static final Log log = LogFactory.getLog(CGPollingTransportTaskManager.class);
private static final String WSA_TO = "To";
public enum STATE {STOPPED, STARTED, FAILURE}
private int concurrentClients = 1;
private String serviceName;
private CGPollingTransportEndpoint endpoint;
private WorkerPool workerPool = null;
private CGAgentSubject subject;
/**
* The token for secure communication
*/
private String token;
/**
* The size of a request message block.i.e. thirft client will request a message block of size
* requestBlockSize from the thrift server
*/
private int requestBlockSize;
/**
* The size of the response message block size that the thrift client should send server. i.e.
* thrift client will send a response message block of size responseBlockSize to the thirft
* server
*/
private int responseBlockSize;
/**
* The thirft server host name
*/
private String hostName;
/**
* The thirft server port the client should connect to
*/
private int port;
/**
* The client timeout when connecting to thrift server
*/
private int timeout;
private String trustStoreLocation;
private String trustStorePassWord;
/**
* Initial duration to suspend the polling tasks
*/
private int initialReconnectDuration = 10000;
/**
* progression factor for the heart beat task
*/
private double reconnectionProgressionFactor = 2.0;
/**
* Response message processing block size
*/
private int messageProcessingBlockSize;
/**
* The list of active polling tasks managed by this instance
*/
private final List<MessageExchangeTask> pollingTasks =
Collections.synchronizedList(new ArrayList<MessageExchangeTask>());
/**
* The number of worker thread per task for processing
*/
private int noOfDispatchingTask = 2;
private CGPollingTransportBuffers taskBuffers;
private long pollingTaskSuspendDuration = 15;
public void setConcurrentClients(int concurrentClients) {
this.concurrentClients = concurrentClients;
}
public String getServiceName() {
return serviceName;
}
public void setInitialReconnectDuration(int initialReconnectDuration) {
this.initialReconnectDuration = initialReconnectDuration;
}
public void setReconnectionProgressionFactor(double reconnectionProgressionFactor) {
this.reconnectionProgressionFactor = reconnectionProgressionFactor;
}
public void setHostName(String hostName) {
this.hostName = hostName;
}
public void setPort(int port) {
this.port = port;
}
public void setTimeout(int timeout) {
this.timeout = timeout;
}
public void setServiceName(String serviceName) {
this.serviceName = serviceName;
}
public void setEndpoint(CGPollingTransportEndpoint endpoint) {
this.endpoint = endpoint;
}
public void setWorkerPool(WorkerPool workerPool) {
this.workerPool = workerPool;
}
public void setRequestBlockSize(int requestBlockSize) {
this.requestBlockSize = requestBlockSize;
}
public void setResponseBlockSize(int responseBlockSize) {
this.responseBlockSize = responseBlockSize;
}
public void setToken(String token) {
this.token = token;
}
public void setTrustStoreLocation(String trustStoreLocation) {
this.trustStoreLocation = trustStoreLocation;
}
public void setTrustStorePassWord(String trustStorePassWord) {
this.trustStorePassWord = trustStorePassWord;
}
public void setNoOfDispatchingTask(int noOfDispatchingTask) {
this.noOfDispatchingTask = noOfDispatchingTask;
}
public void setSubject(CGAgentSubject subject) {
this.subject = subject;
}
public void setMessageProcessingBlockSize(int messageProcessingBlockSize) {
this.messageProcessingBlockSize = messageProcessingBlockSize;
}
public void setTaskBuffers(CGPollingTransportBuffers taskBuffers) {
this.taskBuffers = taskBuffers;
}
public void setPollingTaskSuspendDuration(long pollingTaskSuspendDuration) {
this.pollingTaskSuspendDuration = pollingTaskSuspendDuration;
}
public synchronized void start() {
// start the worker task for message dispatching from transport queue to actual
// processing pool
for (int i = 0; i < noOfDispatchingTask; i++) {
workerPool.execute(new MessageDispatchTask(taskBuffers));
}
// start receiving the messages on different N clients
for (int i = 0; i < concurrentClients; i++) {
CGThriftClient client = new CGThriftClient(
CGUtils.getCGThriftClient(
hostName,
port,
timeout,
trustStoreLocation,
trustStorePassWord));
workerPool.execute(new MessageExchangeTask(
client,
requestBlockSize,
responseBlockSize,
taskBuffers));
}
}
public synchronized void stop() {
synchronized (pollingTasks) {
for (MessageExchangeTask exchangeTask : pollingTasks) {
exchangeTask.requestShutDown();
}
}
log.info("Task manager for service '" + serviceName + "', shutdown");
}
/**
* A periodic task to poll remote Thrift server buffers and submitting messages for processing
*/
private final class MessageExchangeTask implements Runnable {
private CGThriftClient client;
private volatile STATE workerState = STATE.STOPPED;
private int responseBlockSize;
private int requestBlockSize;
private CGPollingTransportBuffers buffers;
private CarbonContext carbonContext;
private MessageExchangeTask(CGThriftClient client,
int requestBlockSize,
int responseBlockSize,
CGPollingTransportBuffers buffers) {
this.client = client;
this.requestBlockSize = requestBlockSize;
this.responseBlockSize = responseBlockSize;
this.buffers = buffers;
this.carbonContext = CarbonContext.getThreadLocalCarbonContext();
// add the created task to the task store
synchronized (pollingTasks) {
pollingTasks.add(this);
}
}
public void run() {
workerState = STATE.STARTED;
//if this service failed earlier make sure we start from fresh
String taskKey = hostName + ":" + port;
if (CGAgentHeartBeatTaskList.isScheduledHeartBeatTaskAvailable(taskKey)) {
CGAgentHeartBeatTaskList.removeScheduledHeartBeatTask(taskKey);
}
List<Message> requestMsgList, responseMsgList;
int responseMessageListSize;
// the busy loop which polls the thrift server for messages
try {
while (workerState == STATE.STARTED &&
!CGAgentPollingTaskFlags.isFlaggedForShutDown(serviceName)) {
try {
responseMsgList = buffers.getResponseMessageList(responseBlockSize);
responseMessageListSize = responseMsgList.size();
// submit the transport response buffer to server also process any requests
requestMsgList = client.exchange(responseMsgList, requestBlockSize, token);
if (requestMsgList != null && requestMsgList.size() > 0) {
buffers.getRequestMessageBuffer().addAll(requestMsgList);
}
// if there is no request messages AND response messages there is no point of polling in a busy
// loop, just wait some time and try again
if ((requestMsgList != null && requestMsgList.size() == 0) && responseMessageListSize == 0) {
try {
Thread.sleep(pollingTaskSuspendDuration);
} catch (InterruptedException ignore) {
// ignore the interrupted exception and make this thread sleep in next iteration
}
}
} catch (TException e) {
log.error("Polling Task Manager encountered an error..", e);
// should be a connection error with the remote server
// schedule a heart beat task and end this loop
registerObserver(hostName, serviceName, port);
scheduleHeartBeatTaskIfRequired(hostName, port);
return;
} catch (NotAuthorizedException e) {
// just logged the error and re-try in the next attempt
if (log.isDebugEnabled()) {
log.debug(e);
}
} catch (AxisFault e) {
// just log and re-try in the next attempt
if (log.isDebugEnabled()) {
log.debug(e);
}
}
}
} finally {
workerState = STATE.STOPPED;
synchronized (pollingTasks) {
pollingTasks.remove(this);
}
}
}
protected void requestShutDown() {
workerState = STATE.STOPPED;
}
private void registerObserver(String hostName, String serviceName, int port) {
CGAgentObserver o = new CGAgentObserverImpl(hostName, serviceName, port);
subject.addObserver(o);
}
private void scheduleHeartBeatTaskIfRequired(String host, int port) {
// scheduled a heat beat task for this host, if not already done
String heartBeatTaskKey = host + ":" + port;
if (!CGAgentHeartBeatTaskList.isScheduledHeartBeatTaskAvailable(heartBeatTaskKey)) {
CGAgentHeartBeatTaskList.addScheduledHeartBeatTask(heartBeatTaskKey);
workerPool.execute(new CGAgentHeartBeatTask(
subject,
reconnectionProgressionFactor,
initialReconnectDuration,
host,
port, carbonContext));
}
}
}
/**
* The message dispatch task which dispatch messages from the source buffers to actual
* processing logic
*/
private final class MessageDispatchTask implements Runnable {
private CGPollingTransportBuffers buffers;
private MessageDispatchTask(CGPollingTransportBuffers buffers) {
this.buffers = buffers;
}
public void run() {
while (true) {
Message msg = buffers.getRequestMessage();
if (msg != null) {
workerPool.execute(new MessageProcessingTask(msg, buffers));
}
}
}
}
/**
* Process any request messages
*/
private final class MessageProcessingTask implements Runnable {
private Message message;
private boolean isSOAP11;
private CGPollingTransportBuffers buffers;
private MessageProcessingTask(Message message, CGPollingTransportBuffers buffers) {
this.message = message;
this.buffers = buffers;
}
public void run() {
try {
handleIncomingMessage(message, buffers);
} catch (AxisFault axisFault) {
// there has been a fault while trying to execute the back end service
// send that fault to the client
try {
handleFaultMessage(message, buffers, axisFault);
} catch (Exception e) {
// do not let the task die!
log.error("Error while sending the fault message to the client. Client will not" +
" receive any errors!", e);
}
}
}
private void handleIncomingMessage(Message message, CGPollingTransportBuffers buffers) throws AxisFault {
if (message == null) {
log.warn("A null Message received!");
} else {
try {
MessageContext msgContext = endpoint.createMessageContext();
String msgId = message.getMessageId();
msgContext.setMessageID(msgId);
msgContext.setProperty(CGConstant.CG_CORRELATION_KEY, msgId);
msgContext.setProperty(CGConstant.CG_POLLING_TRANSPORT_BUF_KEY, buffers);
Map<String, String> trpHeaders = message.getTransportHeaders();
String contentType = message.getContentType();
HTTPTransportUtils.initializeMessageContext(
msgContext,
message.getSoapAction(),
message.getRequestURI(),
contentType);
msgContext.setProperty(Constants.OUT_TRANSPORT_INFO,
new CGPollingTransportOutTransportInfo(contentType));
if (message.isIsDoingREST()) {
msgContext.setAxisService(null); // fix the service dispatching
msgContext.setProperty(HTTPConstants.HTTP_METHOD, message.getHttpMethod());
RESTUtil.processXMLRequest(
msgContext,
new ByteArrayInputStream(message.getMessage()),
new ByteArrayOutputStream(),
contentType);
} else {
ByteArrayInputStream inputStream = new ByteArrayInputStream(message.getMessage());
msgContext.setProperty(Constants.Configuration.CONTENT_TYPE, contentType);
msgContext.setProperty(MessageContext.TRANSPORT_HEADERS, trpHeaders);
if (message.isIsDoingMTOM()) {
msgContext.setDoingMTOM(message.isIsDoingMTOM());
msgContext.setProperty(
org.apache.axis2.Constants.Configuration.ENABLE_MTOM,
org.apache.axis2.Constants.VALUE_TRUE);
} else if (message.isIsDoingREST()) {
msgContext.setDoingSwA(message.isIsDoingSwA());
msgContext.setProperty(
org.apache.axis2.Constants.Configuration.ENABLE_SWA,
org.apache.axis2.Constants.VALUE_TRUE);
}
InputStream gzipInputStream =
HTTPTransportUtils.handleGZip(msgContext, inputStream);
msgContext.setEnvelope(
TransportUtils.createSOAPMessage(
msgContext,
gzipInputStream,
contentType));
isSOAP11 = msgContext.isSOAP11();
populateIncomingTransporterName(msgContext);
AxisEngine.receive(msgContext);
}
} catch (XMLStreamException e) {
throw new AxisFault(e.getMessage(), e);
} catch (IOException e) {
throw new AxisFault(e.getMessage(), e);
}
}
}
private void handleFaultMessage(Message originalMsg,
CGPollingTransportBuffers buffers,
AxisFault axisFault) throws Exception {
Message thriftMsg = new Message();
thriftMsg.setMessageId(originalMsg.getMessageId());
SOAPFactory factory = (isSOAP11 ?
OMAbstractFactory.getSOAP11Factory() : OMAbstractFactory.getSOAP12Factory());
OMDocument soapFaultDocument = factory.createOMDocument();
SOAPEnvelope faultEnvelope = factory.getDefaultFaultEnvelope();
soapFaultDocument.addChild(faultEnvelope);
// create the fault element if it is needed
SOAPFault fault = faultEnvelope.getBody().getFault();
if (fault == null) {
fault = factory.createSOAPFault();
}
SOAPFaultCode code = factory.createSOAPFaultCode();
code.setText(axisFault.getMessage());
fault.setCode(code);
SOAPFaultReason reason = factory.createSOAPFaultReason();
reason.setText(axisFault.getMessage());
fault.setReason(reason);
ByteArrayOutputStream out = new ByteArrayOutputStream();
faultEnvelope.serialize(out);
thriftMsg.setMessage(out.toByteArray());
buffers.addResponseMessage(thriftMsg);
}
}
private void populateIncomingTransporterName(MessageContext messageContext) {
SOAPHeader header = messageContext.getEnvelope().getHeader();
if (header != null) {
ArrayList<SOAPHeaderBlock> addressingHeaders = header.getHeaderBlocksWithNSURI(AddressingConstants.Final.WSA_NAMESPACE);
if (addressingHeaders != null && addressingHeaders.size() > 0) {
for (SOAPHeaderBlock addressingHeader : addressingHeaders) {
if (WSA_TO.equals(addressingHeader.getLocalName())) {
String toAddress = addressingHeader.getText();
String[] address = toAddress.split(":");
if (address.length > 0) {
messageContext.setIncomingTransportName(address[0]);
}
break;
}
}
}
}
}
}