package org.talend.esb.policy.correlation.impl;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.activation.DataSource;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.soap.SOAPMessage;
import javax.xml.stream.XMLStreamReader;
import javax.xml.stream.XMLStreamWriter;
import javax.xml.transform.Source;
import javax.xml.transform.stream.StreamSource;
import org.apache.commons.jxpath.JXPathContext;
import org.apache.commons.jxpath.JXPathException;
import org.apache.cxf.databinding.DataWriter;
import org.apache.cxf.io.CachedOutputStream;
import org.apache.cxf.message.Exchange;
import org.apache.cxf.message.Message;
import org.apache.cxf.message.MessageContentsList;
import org.apache.cxf.message.MessageUtils;
import org.apache.cxf.service.Service;
import org.apache.cxf.service.model.BindingOperationInfo;
import org.apache.cxf.service.model.MessagePartInfo;
import org.apache.cxf.staxutils.StaxSource;
import org.apache.cxf.staxutils.StaxUtils;
import org.apache.cxf.wsdl.interceptors.BareOutInterceptor;
import org.apache.neethi.Assertion;
import org.talend.esb.policy.correlation.impl.xpath.XpathNamespace;
import org.talend.esb.policy.correlation.impl.xpath.XpathPart;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
/**
 * Builds a correlation ID for a SOAP message by evaluating the XPath
 * expressions declared in a {@link CorrelationIDAssertion}.
 * <p>
 * For inbound messages the SOAP body is read from the message's
 * {@link SOAPMessage} content. For outbound messages the part-writing logic
 * inherited from {@link BareOutInterceptor} is reused to serialize the message
 * parts into an in-memory buffer, which is then parsed into a DOM document.
 * <p>
 * NOTE(review): an instance keeps per-message state (buffer, writer, message),
 * so it is presumably meant to serve a single message only — confirm against callers.
 */
public class XPathProcessor extends BareOutInterceptor {

    /** Separates the correlation name from the correlation parts. */
    public static final String CORRELATION_NAME_SEPARATOR = "#";
    /** Separates two consecutive correlation parts. */
    public static final String CORRELATION_PART_SEPARATOR = ";";
    /** Separates a part name from its value. */
    public static final String CORRELATION_PART_NAME_VALUE_SEPARATOR = "=";
    public static final String TEMP_CORRELATION_ID = "org.talend.esb.temp.correlation.id";
    public static final String CORRELATION_ID_XPATH_ASSERTION = "org.talend.esb.correlation-id.xpath.assertion";
    public static final String ORIGINAL_OUT_STREAM_CTX_PROPERTY_NAME =
            "org.talend.correlation.id.original.out.stream";

    /** Receives the serialized outbound message parts. */
    private final ByteArrayOutputStream buffer;
    /** StAX writer on top of {@link #buffer}. */
    private final XMLStreamWriter xmlWriter;
    private final Message message;
    private final Assertion assertion;

    /**
     * @param assertion policy assertion; must be a {@link CorrelationIDAssertion}
     *                  for {@link #getCorrelationID()} to succeed
     * @param message   the message the correlation ID is computed for
     */
    public XPathProcessor(Assertion assertion, Message message) {
        super();
        this.message = message;
        this.assertion = assertion;
        buffer = new ByteArrayOutputStream();
        xmlWriter = StaxUtils.createXMLStreamWriter(buffer,
                getEncoding(message));
    }

    /**
     * Writes every present message part into the internal buffer instead of
     * the real output.
     * <p>
     * Stream-backed sources ({@link DataSource}, {@link StreamSource},
     * {@link StaxSource}) are consumed by reading them. Their content is
     * therefore cached, copied into the buffer, and the part in {@code objs}
     * is replaced by a fresh {@link StaxSource} over the cached content.
     *
     * @throws RuntimeException if a part cannot be read or the buffer cannot be flushed
     */
    @Override
    protected void writeParts(Message message, Exchange exchange,
            BindingOperationInfo operation, MessageContentsList objs,
            List<MessagePartInfo> parts) {
        Service service = exchange.getService();
        DataWriter<XMLStreamWriter> dataWriter = getDataWriter(message,
                service, XMLStreamWriter.class);
        for (MessagePartInfo part : parts) {
            if (!objs.hasValue(part)) {
                continue;
            }
            Object o = objs.get(part);
            try {
                XMLStreamReader reader = (o instanceof Source) ? openStreamingReader(o) : null;
                if (reader != null) {
                    // Read original Stream data to buffer
                    CachedOutputStream cos = new CachedOutputStream();
                    try {
                        StaxUtils.copy(reader, cos);
                    } finally {
                        reader.close();
                    }
                    StaxUtils.copy(StaxUtils.createXMLStreamReader(cos.getInputStream()), xmlWriter);
                    // Replace original (already consumed) source by cached one
                    StaxSource source = new StaxSource(StaxUtils.createXMLStreamReader(cos.getInputStream()));
                    objs.put(part, source);
                } else {
                    dataWriter.write(o, part, xmlWriter);
                }
            } catch (Exception e) {
                throw new RuntimeException("Can not read part of SOAP body", e);
            }
        }
        try {
            xmlWriter.flush();
        } catch (Exception e) {
            // An unflushed writer would leave the buffer incomplete and make the
            // later DOM parse fail with a misleading error, so fail here instead.
            throw new RuntimeException("Can not flush SOAP body buffer", e);
        }
    }

    /**
     * Opens a StAX reader over a stream-backed {@link Source} part.
     *
     * @param o a part value that is an instance of {@link Source}
     * @return a reader over the source content, or {@code null} when the
     *         source is not one of the stream-backed kinds handled here (the
     *         caller then delegates to the data writer)
     * @throws IOException if a {@link DataSource} cannot supply its stream
     */
    private static XMLStreamReader openStreamingReader(Object o) throws IOException {
        if (o instanceof DataSource) {
            return StaxUtils.createXMLStreamReader(((DataSource) o).getInputStream());
        }
        if (o instanceof StreamSource) {
            StreamSource s = (StreamSource) o;
            if (s.getInputStream() != null) {
                return StaxUtils.createXMLStreamReader(s.getInputStream());
            }
            // A StreamSource may be backed by a character Reader instead of a byte stream.
            if (s.getReader() != null) {
                return StaxUtils.createXMLStreamReader(s.getReader());
            }
            return null;
        }
        if (o instanceof StaxSource) {
            return ((StaxSource) o).getXMLStreamReader();
        }
        return null;
    }

    /**
     * Creates a data writer from the service's data binding, configured with
     * the exchange endpoint and the message.
     */
    @Override
    protected <T> DataWriter<T> getDataWriter(Message message, Service service,
            Class<T> output) {
        DataWriter<T> writer = service.getDataBinding().createWriter(output);
        writer.setProperty(DataWriter.ENDPOINT, message.getExchange()
                .getEndpoint());
        writer.setProperty(Message.class.getName(), message);
        return writer;
    }

    /**
     * Resolves the encoding for the message.
     * <p>
     * The lookup order is: the message itself, then the exchange's inbound
     * message, then {@code "UTF-8"}. The resolved value is stored back on the message.
     */
    private String getEncoding(Message message) {
        Exchange ex = message.getExchange();
        String encoding = (String) message.get(Message.ENCODING);
        if (encoding == null && ex != null && ex.getInMessage() != null) {
            encoding = (String) ex.getInMessage().get(Message.ENCODING);
            message.put(Message.ENCODING, encoding);
        }
        if (encoding == null) {
            encoding = "UTF-8";
            message.put(Message.ENCODING, encoding);
        }
        return encoding;
    }

    /**
     * Computes the correlation ID for the message.
     * <p>
     * The ID has the form {@code [name#]part1[;part2...]}, where each part is
     * {@code [partName=]value}. Parts whose XPath produced no value are left out.
     *
     * @return the correlation ID, or {@code null} if the assertion declares no parts
     * @throws RuntimeException if the assertion is not a {@link CorrelationIDAssertion},
     *         the SOAP body cannot be obtained, or a mandatory XPath part fails
     */
    public String getCorrelationID() {
        if (!(assertion instanceof CorrelationIDAssertion)) {
            throw new RuntimeException(
                    "Can not find correlation assertion");
        }
        CorrelationIDAssertion cAssertion = (CorrelationIDAssertion) assertion;
        Node body = getSoapBody(message);
        if (body == null) {
            throw new RuntimeException(
                    "SoapBody elements are not found in soap message");
        }
        List<XpathPart> parts = cAssertion.getCorrelationParts();
        if (parts == null || parts.isEmpty()) {
            return null;
        }
        List<XpathNamespace> namespaces = cAssertion.getCorrelationNamespaces();
        Map<String, String> res = processJXpathParts(parts, namespaces, body);
        return buildCorrelationIdFromXpathParts(parts,
                cAssertion.getCorrelationName(), res);
    }

    /**
     * Returns the node the XPath expressions are evaluated against.
     * <p>
     * Inbound messages: the SOAP body of the {@link SOAPMessage} content.
     * <p>
     * Outbound messages: the DOM document parsed from the serialized message
     * parts. NOTE(review): the buffer is never reset, so invoking this twice
     * for the same outbound message presumably appends a second copy and
     * breaks the parse — confirm the intended single use.
     */
    private Node getSoapBody(Message message) {
        if (!MessageUtils.isOutbound(message)) {
            // processing of incoming message
            SOAPMessage soap = message.getContent(SOAPMessage.class);
            if (soap == null) {
                throw new RuntimeException("Can not find SOAP message in context");
            }
            try {
                return soap.getSOAPBody();
            } catch (Exception ex) {
                throw new RuntimeException("Can not read SOAP body: " + ex, ex);
            }
        }
        // processing of outgoing message: serialize the parts and parse them
        loadSoapBodyToBuffer(message);
        try {
            DocumentBuilderFactory builderFactory =
                    DocumentBuilderFactory.newInstance();
            builderFactory.setNamespaceAware(true);
            DocumentBuilder builder = builderFactory.newDocumentBuilder();
            Document doc = builder.parse(
                    new ByteArrayInputStream(buffer.toByteArray()));
            return doc;
        } catch (Exception e) {
            throw new RuntimeException("Can not read SOAP body: " + e, e);
        }
    }

    /** Runs the inherited interceptor logic, which ends up in {@link #writeParts}. */
    private void loadSoapBodyToBuffer(Message message) {
        handleMessage(message);
    }

    /**
     * Joins the evaluated part values into a correlation ID string.
     * <p>
     * Parts whose value is absent from {@code partsValues} are skipped.
     *
     * @param parts       configured parts, in output order
     * @param cName       optional correlation name prefix (may be {@code null})
     * @param partsValues evaluated values keyed by XPath expression
     */
    private String buildCorrelationIdFromXpathParts(
            final List<XpathPart> parts, final String cName, final Map<String, String> partsValues) {
        StringBuilder builder = new StringBuilder();
        if (cName != null) {
            builder.append(cName);
            builder.append(CORRELATION_NAME_SEPARATOR);
        }
        boolean firstPart = true;
        for (XpathPart part : parts) {
            String partName = part.getName();
            String partValue = partsValues.get(part.getXpath());
            if (partValue == null) {
                continue;
            }
            if (!firstPart) {
                // Do not add part separator for first part
                builder.append(CORRELATION_PART_SEPARATOR);
            }
            firstPart = false;
            if (partName != null) {
                builder.append(partName);
                builder.append(CORRELATION_PART_NAME_VALUE_SEPARATOR);
            }
            builder.append(partValue);
        }
        return builder.toString();
    }

    /**
     * Evaluates each part's XPath against {@code body}.
     * <p>
     * Only namespace entries with both a prefix and a URI are registered.
     *
     * @return map from XPath expression to its string value; optional parts
     *         that could not be evaluated are absent, and a part whose value
     *         was {@code null} maps to {@code null}
     * @throws RuntimeException if an XPath does not compile, or a mandatory
     *         part fails to evaluate or yields a null/empty value
     */
    private Map<String, String> processJXpathParts(List<XpathPart> parts,
            List<XpathNamespace> namespaces, Node body) {
        Map<String, String> resultMap = new HashMap<String, String>();
        JXPathContext messageContext = JXPathContext.newContext(body);
        if (namespaces != null) {
            for (XpathNamespace namespace : namespaces) {
                String prefix = namespace.getPrefix();
                String uri = namespace.getUri();
                if (null != uri && null != prefix) {
                    messageContext.registerNamespace(prefix, uri);
                }
            }
        }
        for (XpathPart part : parts) {
            String xpath = part.getXpath();
            try {
                JXPathContext.compile(xpath);
            } catch (JXPathException ex) {
                throw new RuntimeException("Validation of XPATH expression"
                        + "{ name: " + part.getName() + "; xpath: "
                        + xpath + " } failed", ex);
            }
            String result;
            try {
                Object val = messageContext.getValue(xpath);
                result = (val == null) ? null : val.toString();
            } catch (RuntimeException ex) {
                if (!part.isOptional()) {
                    throw new RuntimeException(
                            "Evaluation of XPATH expression" + "{ name: "
                            + part.getName() + "; xpath: "
                            + xpath + " } failed", ex);
                }
                // An optional part that cannot be evaluated is simply left out.
                continue;
            }
            if ((result == null || result.isEmpty()) && !part.isOptional()) {
                throw new RuntimeException(
                        "Can not evaluate Xpath expression" + "{ name: "
                        + part.getName() + "; xpath: "
                        + xpath + " }");
            }
            // Store the null-safe string (the old code called val.toString() and NPE'd on null).
            resultMap.put(xpath, result);
        }
        return resultMap;
    }
}