package org.talend.esb.policy.correlation.impl; import java.io.IOException; import java.util.Arrays; import java.util.Collection; import javax.xml.parsers.ParserConfigurationException; import org.apache.cxf.binding.soap.SoapBinding; import org.apache.cxf.binding.soap.saaj.SAAJInInterceptor; import org.apache.cxf.binding.soap.saaj.SAAJStreamWriter; import org.apache.cxf.interceptor.Fault; import org.apache.cxf.message.Exchange; import org.apache.cxf.message.Message; import org.apache.cxf.message.MessageUtils; import org.apache.cxf.phase.AbstractPhaseInterceptor; import org.apache.cxf.phase.Phase; import org.apache.cxf.ws.addressing.ContextUtils; import org.apache.cxf.ws.policy.AbstractPolicyInterceptorProvider; import org.apache.cxf.ws.policy.AssertionInfo; import org.apache.cxf.ws.policy.AssertionInfoMap; import org.talend.esb.policy.correlation.CorrelationIDCallbackHandler; import org.talend.esb.policy.correlation.feature.CorrelationIDFeature; import org.talend.esb.policy.correlation.impl.CorrelationIDAssertion.MethodType; import org.w3c.dom.NodeList; import org.xml.sax.SAXException; public class CorrelationIDInterceptorProvider extends AbstractPolicyInterceptorProvider { private static final long serialVersionUID = 5698743589425687361L; public CorrelationIDInterceptorProvider() { super(Arrays.asList(CorrelationIDPolicyBuilder.CORRELATION_ID)); this.getOutInterceptors().add(new CorrelationIDPolicyOutInterceptor()); this.getOutFaultInterceptors().add(new CorrelationIDPolicyOutInterceptor()); this.getInInterceptors().add(new CorrelationIDPolicyInInterceptor()); this.getInFaultInterceptors().add(new CorrelationIDPolicyInInterceptor()); // Selector registers SAAJ interceptor for Soap messages only CorrelationIDFeatureSelectorInterceptor selector = new CorrelationIDFeatureSelectorInterceptor(); this.getInInterceptors().add(selector); this.getInFaultInterceptors().add(selector); } static class CorrelationIDPolicyOutInterceptor extends AbstractPhaseInterceptor<Message> { public CorrelationIDPolicyOutInterceptor() { super(Phase.PRE_STREAM); } @Override public void handleMessage(Message message) throws Fault { try { process(message); } catch (SAXException e) { throw new Fault(e); } catch (IOException e) { throw new Fault(e); } catch (ParserConfigurationException e) { throw new Fault(e); } } } static class CorrelationIDPolicyInInterceptor extends AbstractPhaseInterceptor<Message> { public CorrelationIDPolicyInInterceptor() { super(Phase.PRE_PROTOCOL); addAfter(SAAJInInterceptor.class.getName()); } @Override public void handleMessage(Message message) throws Fault { try { process(message); } catch (SAXException e) { throw new Fault(e); } catch (IOException e) { throw new Fault(e); } catch (ParserConfigurationException e) { throw new Fault(e); } } } static void process(Message message) throws SAXException, IOException, ParserConfigurationException { AssertionInfoMap aim = message.get(AssertionInfoMap.class); if (aim != null) { Collection<AssertionInfo> ais = aim.get(CorrelationIDPolicyBuilder.CORRELATION_ID); if (ais == null) { return; } for (AssertionInfo ai : ais) { if (ai.getAssertion() instanceof CorrelationIDAssertion) { CorrelationIDAssertion cAssertion = (CorrelationIDAssertion) ai.getAssertion(); MethodType mType = cAssertion.getMethodType(); // String value = cAssetrion.getValue(); String correlationId = null; // get ID from Http header correlationId = CorrelationIdProtocolHeaderCodec.readCorrelationId(message); // get ID from SOAP header if (null == correlationId) { correlationId = CorrelationIdSoapCodec.readCorrelationId(message); } // get from message if (null == correlationId) { // Get ID from Message correlationId = (String) message.get(CorrelationIDFeature.MESSAGE_CORRELATION_ID); } if ((message.getContent(javax.xml.stream.XMLStreamWriter.class) != null) && (message.getContent(javax.xml.stream.XMLStreamWriter.class) instanceof SAAJStreamWriter)) { NodeList nodeList = ((SAAJStreamWriter) message .getContent(javax.xml.stream.XMLStreamWriter.class)) .getDocument() .getElementsByTagNameNS("http://www.talend.com/esb/sam/correlationId/v1", "correlationId"); if(nodeList.getLength()>0) { correlationId = nodeList.item(0).getTextContent(); } } // get from message exchange if (null == correlationId) { // Get ID from Message exchange Exchange ex = message.getExchange(); if (null != ex) { Message reqMsg = null; if (MessageUtils.isOutbound(message)) { reqMsg = ex.getInMessage(); } else { reqMsg = ex.getOutMessage(); } if (null != reqMsg) { correlationId = (String) reqMsg.get(CorrelationIDFeature.MESSAGE_CORRELATION_ID); } } } // If correlationId is null we should add it to headers if (null == correlationId) { if (MethodType.XPATH.equals(mType)) { XPathProcessor proc = new XPathProcessor(cAssertion, message); correlationId = proc.getCorrelationID(); } else if (MethodType.CALLBACK.equals(mType)){ CorrelationIDCallbackHandler handler = (CorrelationIDCallbackHandler) message .get(CorrelationIDFeature.CORRELATION_ID_CALLBACK_HANDLER); if (null == handler) { handler = (CorrelationIDCallbackHandler) message .getContextualProperty(CorrelationIDFeature.CORRELATION_ID_CALLBACK_HANDLER); } if (handler != null) correlationId = handler.getCorrelationId(); } // Generate new ID if it was not set in callback or // request if (null == correlationId) { correlationId = ContextUtils.generateUUID(); } } message.put(CorrelationIDFeature.MESSAGE_CORRELATION_ID, correlationId); // if (!MessageUtils.isRequestor(message) && // MessageUtils.isOutbound(message)) {// RESP_OUT if (isRestMessage(message)) { // Add correlationId to http header if (null == CorrelationIdProtocolHeaderCodec.readCorrelationId(message)) { CorrelationIdProtocolHeaderCodec.writeCorrelationId(message, correlationId); } } else { // Add correlationId to soap header if (null == CorrelationIdSoapCodec.readCorrelationId(message)) { CorrelationIdSoapCodec.writeCorrelationId(message, correlationId); } } // } ai.setAsserted(true); } } } } private static boolean isRestMessage(Message message) { return !(message.getExchange().getBinding() instanceof SoapBinding); } }