package org.talend.esb.mep.requestcallback.impl; import java.util.List; import java.util.Map; import javax.xml.bind.JAXBException; import javax.xml.namespace.QName; import org.apache.cxf.binding.soap.SoapMessage; import org.apache.cxf.headers.Header; import org.apache.cxf.interceptor.Fault; import org.apache.cxf.jaxb.JAXBDataBinding; import org.apache.cxf.jaxws.EndpointImpl; import org.apache.cxf.message.Exchange; import org.apache.cxf.message.Message; import org.apache.cxf.phase.AbstractPhaseInterceptor; import org.apache.cxf.phase.Phase; import org.apache.cxf.ws.addressing.AddressingProperties; import org.apache.cxf.ws.addressing.AttributedURIType; import org.apache.cxf.ws.addressing.ContextUtils; import org.apache.cxf.ws.addressing.EndpointReferenceType; import org.apache.cxf.ws.addressing.JAXWSAConstants; import org.apache.cxf.ws.addressing.MAPAggregator; import org.apache.cxf.ws.addressing.RelatesToType; //import org.apache.cxf.ws.addressing.impl.AddressingProperties; import org.apache.cxf.ws.security.SecurityConstants; import org.talend.esb.mep.requestcallback.feature.CallContext; import org.talend.esb.mep.requestcallback.feature.RequestCallbackFeature; import org.talend.esb.sam.agent.message.FlowIdHelper; /** * The Class CompressionOutInterceptor. */ public class RequestCallbackOutInterceptor extends AbstractPhaseInterceptor<SoapMessage> { public RequestCallbackOutInterceptor() { super(Phase.PRE_LOGICAL); addBefore(MAPAggregator.class.getName()); } /** * {@inheritDoc} */ @Override public void handleMessage(SoapMessage message) throws Fault { final Exchange e = message.getExchange(); if (!e.isOneWay()) { return; } doHandleSoapMessage(message); } private void doHandleSoapMessage(SoapMessage message) throws Fault { final Object callbackEndpoint = message.getContextualProperty( RequestCallbackFeature.CALLBACK_ENDPOINT_PROPERTY_NAME); if (callbackEndpoint != null) { final String callbackEndpointAddress; if (callbackEndpoint instanceof String) { callbackEndpointAddress = (String) callbackEndpoint; } else if (callbackEndpoint instanceof EndpointImpl) { callbackEndpointAddress = ((EndpointImpl) callbackEndpoint).getAddress(); } else { throw new IllegalArgumentException("Unsupported type of endpoint. "); } doHandleRequestSoapMessage(message, callbackEndpointAddress); return; } final CallContext ctx = (CallContext) message.getContextualProperty( RequestCallbackFeature.CALLCONTEXT_PROPERTY_NAME); if (ctx != null) { doHandleCallbackSoapMessage(message, ctx); return; } } private void doHandleRequestSoapMessage( SoapMessage message, String callbackEndpoint) throws Fault { final String callId = ContextUtils.generateUUID(); message.getHeaders().add(createHeader( RequestCallbackFeature.CALL_ID_HEADER_NAME, callId)); Map<String, Object> requestInfo = getCallInfo(message); if (requestInfo != null) { requestInfo.put(RequestCallbackFeature.CALL_ID_NAME, callId); } aggregateAddressing(message, callbackEndpoint, null); } private void doHandleCallbackSoapMessage( SoapMessage message, CallContext callContext) throws Fault { final String callId = callContext.getCallId(); final String correlationID = callContext.getCorrelationId(); final String callbackId = ContextUtils.generateUUID(); List<Header> headers = message.getHeaders(); if(correlationID!=null){ message.getHeaders().add(createHeader( RequestCallbackFeature.CORRELATION_ID_HEADER_NAME, correlationID)); } message.getHeaders().add(createHeader( RequestCallbackFeature.CALL_ID_HEADER_NAME, callId)); headers.add(createHeader( RequestCallbackFeature.CALLBACK_ID_HEADER_NAME, callbackId)); Map<String, Object> requestInfo = getCallInfo(message); if (requestInfo != null) { requestInfo.put(RequestCallbackFeature.CALL_ID_NAME, callId); requestInfo.put(RequestCallbackFeature.CALLBACK_ID_NAME, callbackId); } String flowId = callContext.getFlowId(); if (flowId != null && !flowId.isEmpty()) { FlowIdHelper.setFlowId(message, flowId); } aggregateAddressing(message, null, callContext.getRequestId()); // In case of encryption propagate stored requestor certificate to callback response propagateRequestorCertificate(message, callContext); } private void aggregateAddressing( SoapMessage message, String callbackEndpoint, String relatesTo) { final AddressingProperties maps = initAddressingProperties(message); if (callbackEndpoint != null) { EndpointReferenceType replyTo= maps.getReplyTo(); if (replyTo == null || ContextUtils.isGenericAddress(replyTo)) { EndpointReferenceType replyToRef = new EndpointReferenceType(); AttributedURIType address = new AttributedURIType(); address.setValue(callbackEndpoint); replyToRef.setAddress(address); maps.setReplyTo(replyToRef); } } if (maps.getRelatesTo() == null) { RelatesToType relatesToAttr = new RelatesToType(); relatesToAttr.setRelationshipType("message"); relatesToAttr.setValue(relatesTo); maps.setRelatesTo(relatesToAttr); } } private static Header createHeader(QName headerName, String value) throws Fault { try { return new Header(headerName, value, new JAXBDataBinding(String.class)); } catch (JAXBException e) { throw new Fault(e); } } private static AddressingProperties initAddressingProperties(SoapMessage message) { AddressingProperties maps = (AddressingProperties) message.getContextualProperty( JAXWSAConstants.CLIENT_ADDRESSING_PROPERTIES); if (maps == null) { maps = new AddressingProperties(); message.put(JAXWSAConstants.CLIENT_ADDRESSING_PROPERTIES, maps); } return maps; } @SuppressWarnings("unchecked") private static Map<String, Object> getCallInfo(SoapMessage message) { return (Map<String, Object>) message.getContextualProperty( RequestCallbackFeature.CALL_INFO_PROPERTY_NAME); } private static void propagateRequestorCertificate(Message message, CallContext callContext) { if (callContext.getRequestorSignatureCertificate() != null) { message.put(SecurityConstants.ENCRYPT_CERT, callContext.getRequestorSignatureCertificate()); } } }